In [1]:
# from datasets import load_dataset
dataset = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)
target_samples = list(dataset.take(20_000))

In [2]:
# # saving the dataset to disk for east retrival 
from datasets import Dataset
dataset = Dataset.from_list(target_samples)
dataset.save_to_disk("tiny_slimpajama")

In [3]:
# load the dataset from the disk
from datasets import load_from_disk
dataset = load_from_disk("tiny_slimpajama")

In [4]:
# Tokenize the dataset 
from transformers import GPT2Tokenizer, GPT2LMHeadModel 
import torch 

tokenizer = GPT2Tokenizer.from_pretrained("reference_model")
reference_model = GPT2LMHeadModel.from_pretrained("reference_model")
tokenizer.pad_token = tokenizer.eos_token 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
reference_model.to(device)
reference_model.eval() 

data = tokenizer(
    dataset["text"], 
    padding = True,
    truncation = True,
    max_length = 128,
    return_tensors = "pt",
    return_attention_mask = True,
)
torch.save(data, "tokens_target_model.pt")


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [5]:
# Printing Tokens For Testing 
# tokens = tokenizer.convert_ids_to_tokens(data["input_ids"][0])
# print(tokens)

In [6]:
from torch.utils.data import DataLoader, TensorDataset 
import torch.serialization
from transformers.tokenization_utils_base import BatchEncoding

torch.serialization.add_safe_globals([BatchEncoding])
data = torch.load("tokens_target_model.pt")
big_dataset = TensorDataset(data["input_ids"], data["attention_mask"])
dataloader_big_dataset = DataLoader(big_dataset, batch_size = 4, shuffle=False)

In [13]:
# Use the reference model to calculate the LRM for the new dataset 

import torch.nn.functional as F 
reference_losses = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i, batch in enumerate(dataloader_big_dataset):
    input_ids, attention_mask = [b.to(device) for b in batch]

    with torch.no_grad(): # not track the gradients and update weights 
        outputs = reference_model(input_ids = input_ids, attention_mask=attention_mask)
        
        logits = outputs.logits # [batch_size, sequence_length, vocab_size]
        
        cut_logits = logits[:, :-1, :]
        cut_labels = input_ids[:, 1:]
        cut_attention = attention_mask[:, 1:]

        probs = F.softmax(cut_logits, dim=-1)
        true_token_probs = probs.gather(2, cut_labels.unsqueeze(-1)).squeeze(-1)
        reference_loss = -torch.log(true_token_probs + 1e-9) # We add small value to avoild log(0)
        reference_loss = reference_loss* cut_attention 
        reference_losses.append(reference_loss)

torch.save(torch.cat(reference_losses, dim=0), "reference_loss_final.pt")

In [14]:
# Initializing the Target Model 
from transformers import GPT2Tokenizer, GPT2LMHeadModel 
from torch.optim import AdamW

target_model = GPT2LMHeadModel.from_pretrained("distilgpt2")
target_model.resize_token_embeddings(len(tokenizer))
target_model.to(device)
target_model.eval() 
optimizer = AdamW(target_model.parameters(), lr=5e-5)

In [15]:
target_losses = []

for batch in dataloader_big_dataset:
    input_ids, attention_mask = [b.to(device) for b in batch]

    with torch.no_grad(): 
        outputs = target_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_attention = attention_mask[:, 1:].contiguous()

        probs = F.softmax(shift_logits, dim=-1)
        true_token_probs = probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        target_loss = -torch.log(true_token_probs + 1e-9) * shift_attention

        target_losses.append(target_loss)

torch.save(torch.cat(reference_losses, dim=0), "target_loss_final.pt")

In [16]:
ref_los = torch.load("reference_loss_final.pt")   
target_los = torch.load("target_loss_final.pt")  

excess_loss = target_los - ref_los  

k = 0.4  # top 40%
flat_excess = excess_loss[ref_los > 0].flatten()  
threshold = torch.quantile(flat_excess, 1 - k)

mask = (excess_loss >= threshold) & (ref_los > 0)  # top-k mask

pad = torch.zeros(mask.size(0), 1).bool() # Pad BOS token as the first token doesn't count
topk_mask_full = torch.cat([pad, mask], dim=1) 

# Checking evaluate slm loss
masked_loss = target_los * mask.float()
slm_loss = masked_loss.sum() / mask.float().sum()
print(f" slm loss over top {int(k*100)}% tokens is {slm_loss:.4f}")
torch.save(topk_mask_full.float(), "topk_mask.pt")


 slm loss over top 40% tokens is 3.7647


In [17]:
# Training the target model on the hard token 

topk_mask_full = torch.load("topk_mask.pt").float()
start_idx = 0 

for epoch in range(1):
    for batch_idx, batch in enumerate(dataloader_big_dataset):
        input_ids, attention_mask = [b.to(device) for b in batch]
        batch_size = input_ids.shape[0]

        mask_batch = topk_mask_full[start_idx:start_idx + batch_size, :].to(device)
        start_idx += batch_size

        outputs = target_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        shift_logits     = logits[:, :-1, :]
        shift_labels     = input_ids[:, 1:]
        shift_mask       = mask_batch[:, :-1]
        shift_attention  = attention_mask[:, 1:]


        probs = F.softmax(shift_logits, dim=-1)
        true_token_probs = probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        token_loss = -torch.log(true_token_probs + 1e-9)

        selected_loss = token_loss * shift_mask
        slm_loss = selected_loss.sum() / shift_mask.sum()

        slm_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

target_model.save_pretrained("slm_target_model")
tokenizer.save_pretrained("slm_target_model")


('slm_target_model/tokenizer_config.json',
 'slm_target_model/special_tokens_map.json',
 'slm_target_model/vocab.json',
 'slm_target_model/merges.txt',
 'slm_target_model/added_tokens.json')