In [109]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig, TaskType, PeftType
from BidirectionalSwitchForLoraLlama import LoraLlamaBidirectionalSwitch
from peft import get_peft_model, LoraConfig, TaskType, PeftType, PeftModel
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import json
import math
from tqdm import tqdm
from DataHandling import MLMSlimPajamaDataset, dataloader_collate_function
import pandas as pd
import gc
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR

# Hugging Face Hub id of the base chat model; used below to load both the model and its tokenizer.
model_checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

In [2]:
# Automatically reload edited local modules (e.g. DataHandling, BidirectionalSwitchForLoraLlama)
# before executing code, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Tokenizer and causal-LM weights come from the same checkpoint id; device_map="auto" lets the
# library place the model weights on the available device(s).
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
llm = AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
    device_map="auto",
)

## Data prep

Take the first 100 samples from the SlimPajama dataset (streamed from `train/chunk1`) for training, masking 20% of their tokens.

In [4]:
# Stream only the shards under train/chunk1 of SlimPajama-627B instead of materialising the corpus.
data_files = ["train/chunk1/*.jsonl.zst"]
slim_pajama = load_dataset(
    "cerebras/SlimPajama-627B",
    streaming=True,
    data_files=data_files,
)

Resolving data files:   0%|          | 0/5912 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/5912 [00:00<?, ?it/s]

In [88]:
# Build the masked-LM training set (100 samples, 20% masking, matching the export filename)
# from the streamed "train" split, then write it to JSON so it can be reused.
num_samples, masking_ratio = 100, 0.2
dataset = MLMSlimPajamaDataset(tokenizer, num_samples, masking_ratio)
dataset.create_from_dataset(slim_pajama["train"])
dataset.export_dataset("slim_pajama_100_samples_20_percent_masking.json")
# Alternative: reload a previously exported file instead of rebuilding, e.g.
# dataset.load_from_disk("slim_pajama_first_1000.json")

  1%|█▌                                                                                                                                                        | 1/100 [00:00<00:39,  2.54it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (5199 > 2048). Running this sequence through the model will result in indexing errors
 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 99/100 [00:00<00:00, 178.61it/s]





## Define MLM head

In [81]:
class MLMHead(torch.nn.Module):
    """Linear masked-language-modelling head mapping hidden states to vocabulary logits.

    The defaults match this notebook: TinyLlama-1.1B hidden size 2048 and the 32000 original
    vocabulary ids (the appended [MASK] id is never a prediction target). Parameter names
    (``linear.weight`` / ``linear.bias``) are unchanged, so previously saved state dicts load.
    """

    def __init__(self, hidden_size: int = 2048, vocab_size: int = 32000):
        """
        Args:
            hidden_size: width of the incoming hidden states (last dimension of ``x``).
            vocab_size: number of logits produced per position.
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(*x.shape[:-1], vocab_size)`` for hidden states ``x``."""
        return self.linear(x)

## Add embedding of the [MASK] token to the embedding matrix

In [82]:
# Append one embedding row for the new [MASK] token. Its id is tokenizer.vocab_size (the id the
# mask-embedding helpers and the training loss compare against), so derive the new size from the
# tokenizer instead of hard-coding 32001.
llm.resize_token_embeddings(tokenizer.vocab_size + 1)

Embedding(32001, 2048)

In [83]:
# Sanity check: the input embedding matrix now has the extra [MASK] row.
llm.model.embed_tokens

Embedding(32001, 2048)

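The embedding of the new [MASK] token has to be learned, so `requires_grad` must be `True` for it. `get_mask_token_embedding` therefore returns a copy of the [MASK] row of the embedding matrix that can be trained. This copy is a separate tensor, so training it does not change the model's embedding matrix unless the learned values are written back into it.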

In [84]:
def get_mask_token_embedding(llm, tokenizer, hidden_size=2048):
    """Return a detached copy of the [MASK] token's row of the embedding matrix.

    The [MASK] token is the row appended by ``resize_token_embeddings``, i.e. id
    ``tokenizer.vocab_size``. The first parameter of ``llm`` whose shape is
    ``(tokenizer.vocab_size + 1, hidden_size)`` is taken as the embedding matrix.

    The result is a new leaf tensor (``detach().clone()``): setting ``requires_grad`` on it is
    always allowed, and training it does not modify the model's embedding matrix.

    Args:
        llm: model (plain or PEFT-wrapped) whose parameters are searched.
        tokenizer: tokenizer whose ``vocab_size`` is the [MASK] token id.
        hidden_size: embedding width (2048 for TinyLlama-1.1B).

    Raises:
        ValueError: if no parameter has the expected shape (e.g. embeddings not resized).
    """
    expected_shape = torch.Size([tokenizer.vocab_size + 1, hidden_size])
    for parameter in llm.parameters():
        # Compare the loop variable itself, not a notebook-level global.
        if parameter.shape == expected_shape:
            return parameter[tokenizer.vocab_size].detach().clone()
    raise ValueError(
        f"No parameter of shape {tuple(expected_shape)} found; "
        "was resize_token_embeddings called?"
    )

def save_mask_token_embedding(llm, tokenizer, path="mask_token_embedding_weights.pt", hidden_size=2048):
    """Save the [MASK] embedding row (see ``get_mask_token_embedding``) to ``path`` with ``torch.save``.

    Raises:
        ValueError: propagated from ``get_mask_token_embedding`` when no embedding matrix matches.
    """
    torch.save(get_mask_token_embedding(llm, tokenizer, hidden_size), path)

## Define the model with LoRA weights

In [85]:
# Rank-8 LoRA adapters (scaling factor 32, dropout 0.1) for a causal-LM model. No target_modules
# are listed, so PEFT falls back to its default target modules for this architecture.
# NOTE: get_peft_model inserts the adapter layers into `llm` itself, so `llm` is modified too.
lora_config = LoraConfig(
    r=8,                           # rank of the low-rank update
    lora_alpha=32,                 # LoRA scaling factor
    lora_dropout=0.1,              # dropout on the LoRA branch
    bias="none",                   # no bias parameters are trained
    task_type=TaskType.CAUSAL_LM,  # causal language modelling
)
lora_model = get_peft_model(llm, lora_config)

In [86]:
# List every parameter tensor with its trainability; per the output below, only the rank-8
# LoRA matrices require gradients.
# NOTE: the loop variable `param` stays bound at notebook level after this cell.
for param in lora_model.parameters():
    shape, trainable = param.shape, param.requires_grad
    print(f"param.shape = {shape} param.requires_grad = {trainable}")

param.shape = torch.Size([32001, 2048]) param.requires_grad = False
param.shape = torch.Size([2048, 2048]) param.requires_grad = False
param.shape = torch.Size([8, 2048]) param.requires_grad = True
param.shape = torch.Size([2048, 8]) param.requires_grad = True
param.shape = torch.Size([256, 2048]) param.requires_grad = False
param.shape = torch.Size([256, 2048]) param.requires_grad = False
param.shape = torch.Size([8, 2048]) param.requires_grad = True
param.shape = torch.Size([256, 8]) param.requires_grad = True
param.shape = torch.Size([2048, 2048]) param.requires_grad = False
param.shape = torch.Size([5632, 2048]) param.requires_grad = False
param.shape = torch.Size([5632, 2048]) param.requires_grad = False
param.shape = torch.Size([2048, 5632]) param.requires_grad = False
param.shape = torch.Size([2048]) param.requires_grad = False
param.shape = torch.Size([2048]) param.requires_grad = False
param.shape = torch.Size([2048, 2048]) param.requires_grad = False
param.shape = torch.Size(

# Training

### Preparing parameters for training

Three groups:
- LoRA parameters
- MLM head parameters
- Mask token representation

In [None]:
# Re-create every trainable component so that a run starts from fresh parameters.
experiment_name = input("Enter the name of the experiment")

# get_peft_model injects LoRA layers into `llm` in place, so wrapping `llm` again would not start
# from a clean base model. Strip adapters left by an earlier wrapping before re-wrapping.
previous_lora_model = globals().get("lora_model")
if isinstance(previous_lora_model, PeftModel):
    llm = previous_lora_model.unload()

lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,  # Type of task (e.g., causal language modeling)
        r=8,  # Low-rank dimension
        lora_alpha=32,  # Scaling factor
        lora_dropout=0.1,  # Dropout probability for LoRA
        bias="none",  # Bias configuration
    )
lora_model = get_peft_model(llm, lora_config)

# Copy of the [MASK] embedding row. NOTE(review): it is handed to the optimizer but never fed to
# the model in the training loop, so it receives no gradient and the model's embedding is unchanged.
mask_token_representation = get_mask_token_embedding(lora_model, tokenizer)
mask_token_representation.requires_grad = True

mlm_head = MLMHead()
# 4 is the largest batch size that fits in GPU memory; 8 raised a CUDA OutOfMemoryError.
batch_size = 4
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=lambda x: dataloader_collate_function(x, tokenizer, 2048),
)
# Only tensors that can receive gradients go to the optimizer; PEFT froze the base weights.
parameters = (
    [mask_token_representation]
    + [p for p in lora_model.parameters() if p.requires_grad]
    + list(mlm_head.parameters())
)
writer = SummaryWriter(f"runs/spaudel/{experiment_name}")
num_epochs = 10
num_batches = len(dataloader)

# Define the optimizer and LR schedule here so the hyper-parameters logged below exist on a
# fresh kernel (otherwise they are only defined in the training cell further down).
init_lr = 1e-6
optimizer = Adam(parameters, lr=init_lr)
lr_scheduler = ExponentialLR(optimizer, gamma=1)  # gamma=1: constant learning rate

writer.add_text("init_lr", str(init_lr))
writer.add_text("lr_scheduler", f"{type(lr_scheduler).__name__}(gamma={lr_scheduler.gamma})")
writer.add_text("num_epochs", str(num_epochs))
writer.add_text("batch_size", str(batch_size))
last_epoch_trained = 0

In [100]:
# Dedicated TensorBoard run directory for the reduced ("babied") learning-rate continuation.
writer = SummaryWriter("runs/spaudel/100_samples_with_20_percent_masking_babied_lr")

In [102]:
# Continue training the LoRA adapters and the MLM head with a masked-language-modelling loss,
# using a small ("babied") constant learning rate.
num_epochs = 2
init_lr = 1e-6
optimizer = Adam(parameters, lr=init_lr)
# gamma=1 keeps the learning rate constant; the scheduler is kept so the LR is still logged.
lr_scheduler = ExponentialLR(optimizer, 1)
# Offset added to epoch_idx so TensorBoard steps continue after earlier runs.
last_epoch_trained = 29
# [MASK] is the row appended after the original vocabulary, so its id equals vocab_size.
mask_token_id = tokenizer.vocab_size


def _masked_lm_loss(scores, original_input_ids, masked_input_ids, mask_token_id):
    """Mean cross-entropy of the MLM predictions over the [MASK] positions of a batch.

    Computes ``-mean(log(softmax(scores))[target])`` over masked positions, but via
    ``cross_entropy`` (log-softmax internally): ``log`` of a softmax that underflows to 0 gives
    -inf/NaN, and no (batch, seq, vocab) one-hot target tensor is materialised.

    Args:
        scores: MLM-head logits whose leading dims match the id tensors, last dim = vocab.
        original_input_ids: unmasked token ids (prediction targets).
        masked_input_ids: model inputs in which masked positions hold ``mask_token_id``.
        mask_token_id: id of the [MASK] token.

    Returns:
        Scalar loss tensor, or None if the batch has no [MASK] token (the mean would be NaN).
    """
    masked_positions = (masked_input_ids == mask_token_id).to(scores.device)
    if not bool(masked_positions.any()):
        return None
    targets = original_input_ids.to(scores.device)[masked_positions]
    return torch.nn.functional.cross_entropy(scores[masked_positions], targets)


# from_pretrained returns models in eval mode; switch to train mode so the configured LoRA
# dropout (lora_dropout=0.1) is actually applied during training.
lora_model.train()
mlm_head.train()

for epoch_idx in range(num_epochs):
    epoch = last_epoch_trained + epoch_idx
    writer.add_scalar("learning_rate", lr_scheduler.get_last_lr()[0], epoch)
    loss_sum, loss_count = 0.0, 0
    progress = tqdm(enumerate(dataloader), total=num_batches)
    for iteration_idx, batch in progress:
        progress.set_description(f"Epoch : {epoch} Iteration: {iteration_idx}")
        optimizer.zero_grad()
        masked_input_ids, attention_mask, original_input_ids = (
            tensor.to(lora_model.device) for tensor in batch
        )
        with LoraLlamaBidirectionalSwitch(lora_model):
            model_output = lora_model(
                input_ids=masked_input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
        last_hidden_state = model_output.hidden_states[-1]
        # No-op after the first batch; the head follows the device of the final hidden state.
        mlm_head = mlm_head.to(last_hidden_state.device)
        scores = mlm_head(last_hidden_state)
        loss = _masked_lm_loss(scores, original_input_ids, masked_input_ids, mask_token_id)
        if loss is None:
            continue  # nothing was masked in this batch, so there is nothing to learn from
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        writer.add_scalar("training_loss", loss_value, epoch * num_batches + iteration_idx)
        loss_sum += loss_value
        loss_count += 1
    # Average over the batches that contributed a loss.
    average_loss = loss_sum / max(loss_count, 1)
    writer.add_scalar("epoch_average_loss", average_loss, epoch)
    # ExponentialLR.step() takes no metric argument (unlike ReduceLROnPlateau.step(metric)).
    lr_scheduler.step()

Epoch : 29 Iteration: 24: : 25it [01:11,  2.86s/it]
Epoch : 30 Iteration: 24: : 25it [01:10,  2.82s/it]


In [103]:
# Save the trained LoRA adapter (weights + adapter config) to this directory.
peft_weights_dir = "babied_lr_peft_weights"
lora_model.save_pretrained(peft_weights_dir)



In [106]:
# Move the head to CPU (Module.to works in place) so the saved state dict holds CPU tensors,
# then write it to disk.
cpu_device = torch.device("cpu")
mlm_head = mlm_head.to(cpu_device)
torch.save(mlm_head.state_dict(), "mlm_head.pt")

In [107]:
# save_mask_token_embedding(lora_model, tokenizer) is not called here: PEFT's save_pretrained is
# relied on to store the (resized) embedding layers.
# NOTE(review): the embedding matrix is frozen (requires_grad=False in the listing above) and the
# optimised `mask_token_representation` is a separate copy that the forward pass never uses, so any
# saved [MASK] row is the value set by resize_token_embeddings, not a trained one.

In [120]:
# Merge the saved adapter into a freshly loaded base model. Reusing `llm` is unsafe: the
# get_peft_model calls above injected LoRA layers into it in place, and merge_and_unload would
# then overwrite `llm`'s weights (and break `lora_model`) as well.
base_llm = AutoModelForCausalLM.from_pretrained(model_checkpoint, device_map="auto")
# Match the resized vocabulary (original ids + [MASK]) the adapter checkpoint was saved with.
base_llm.resize_token_embeddings(tokenizer.vocab_size + 1)
merged_model = PeftModel.from_pretrained(base_llm, "babied_lr_peft_weights")
merged_model = merged_model.merge_and_unload()

In [121]:
# Display the merged model: the LoRA wrappers are gone and plain Linear layers remain.
merged_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32001, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Line

In [123]:
# Rebuild the MLM head and restore the weights saved above. weights_only=True restricts
# unpickling to tensors and plain containers (all a state dict needs), and map_location="cpu"
# keeps loading independent of the device the tensors were saved from.
mlm_head_reloaded = MLMHead()
mlm_head_reloaded.load_state_dict(torch.load("mlm_head.pt", map_location="cpu", weights_only=True))

<All keys matched successfully>