In [1]:
# loading the 0.0005% of the dataset
# from datasets import load_dataset
# dataset = load_dataset("cerebras/SlimPajama-627B", split="train[:0.0005%]")

In [2]:
import zstandard as zstd 
import json 

# opening using zstandard 
with open("example_train_0.jsonl.zst", "rb") as original:
    doc = zstd.ZstdDecompressor()
    with open("slimpajama_sample.jsonl", "wb") as out:
        doc.copy_stream(original, out)


In [3]:
dataset = []
with open("slimpajama_sample.jsonl", "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        data = json.loads(line)
        dataset.append(data["text"])

In [4]:
# Using GPT2 Tokenizer since rho-1 was trained using Tinyllama as both a base model and a reference model 
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

In [5]:
import torch
tokenizer.pad_token = tokenizer.eos_token
encoded_dataset = tokenizer(
    dataset,                         
    padding=True,
    truncation=True,
    max_length=128,                
    return_tensors="pt",           
    return_attention_mask=True, 
)
torch.save(encoded_dataset, "tokenizer_dataset.pt")

In [6]:
# This is the line for converting ids to tokens to see them 
# tokens = tokenizer.convert_ids_to_tokens(encoded_dataset["input_ids"][0])
# print(tokens)

In [7]:
from torch.utils.data import DataLoader, TensorDataset
import torch.serialization
from transformers.tokenization_utils_base import BatchEncoding

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.serialization.add_safe_globals([BatchEncoding])
encoded_dataset = torch.load("tokenizer_dataset.pt")
input_ids = encoded_dataset["input_ids"]
attention_mask = encoded_dataset["attention_mask"]

train_dataset = TensorDataset(input_ids, attention_mask)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)


In [8]:
import torch
from transformers import GPT2LMHeadModel
from torch.optim import AdamW

# Training a reference model 
reference_model = GPT2LMHeadModel.from_pretrained("distilgpt2")
reference_model.resize_token_embeddings(len(tokenizer))
reference_model.to(device)
reference_model.train()

# Optimizer setup: RHO-1 paper: LR 5e-5
optimizer = AdamW(reference_model.parameters(), lr=5e-5)

num_epochs = 1  # We use 3 for the Rho 
for epoch in range(num_epochs):
    total_loss = 0
    for i, batch in enumerate(train_dataloader):
        batch_input_ids, batch_attention_mask = [b for b in batch]

        outputs = reference_model(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            labels=batch_input_ids
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_dataloader)
reference_model.save_pretrained("reference_model")
tokenizer.save_pretrained("reference_model")

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


 batch 1 completed and  Total Loss is : 3.9176
 batch 2 completed and  Total Loss is : 8.2188
 batch 3 completed and  Total Loss is : 12.1910
 batch 4 completed and  Total Loss is : 16.8535
 batch 5 completed and  Total Loss is : 22.7946
 batch 6 completed and  Total Loss is : 28.0647
 batch 7 completed and  Total Loss is : 33.0705
 batch 8 completed and  Total Loss is : 36.9488
 batch 9 completed and  Total Loss is : 40.1713
 batch 10 completed and  Total Loss is : 44.2876
 batch 11 completed and  Total Loss is : 48.3998
 batch 12 completed and  Total Loss is : 52.1979
 batch 13 completed and  Total Loss is : 56.4609
 batch 14 completed and  Total Loss is : 61.3824
 batch 15 completed and  Total Loss is : 65.5958
 batch 16 completed and  Total Loss is : 69.6157
 batch 17 completed and  Total Loss is : 73.8085
 batch 18 completed and  Total Loss is : 77.6510
 batch 19 completed and  Total Loss is : 81.2258
 batch 20 completed and  Total Loss is : 85.8323
 batch 21 completed and  Total 

('reference_model/tokenizer_config.json',
 'reference_model/special_tokens_map.json',
 'reference_model/vocab.json',
 'reference_model/merges.txt',
 'reference_model/added_tokens.json')