In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Conv1D
import math

from tests import *

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


In [2]:
#!pip install -r requirements.txt

In [3]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Load a small Hugging Face dataset
dataset = load_dataset("allenai/winogrande", "winogrande_xs")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [5]:
# Tokenizer and model initialization
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 doesn't have a pad token, using eos_token instead



In [6]:
# GPT2 uses custom conv1d layers which are just linear layers with a weight transpose
# Therefore we can just convert them to standard linear layers to simplify the architecture
for name, layer in model.named_modules():
    if not isinstance(layer, Conv1D):
        continue
    parent = model.get_submodule(name[:name.rfind(".")])
    has_bias = torch.any(layer.bias.data)
    linear_layer = torch.nn.Linear(*layer.weight.shape, bias=has_bias)
    with torch.no_grad():
        linear_layer.weight.copy_(layer.weight.T)
        if has_bias:
            linear_layer.bias.copy_(layer.bias)
    setattr(parent, name.split(".")[-1], linear_layer)

# GPT2 also uses a merged weight matrix for qkv which nowadays is not really done anymore
# While this is equivalent to having 3 separate weight matrices, it is more clear and easier to implement with LoRA to have 3 separate weight matrices
# Therefore we will split the weight matrix into 3 separate weight matrices
class AttentionProjections(nn.Module):
    def __init__(self, merged_weight):
        super().__init__()
        dim = merged_weight.shape[1]
        q, k, v = layer.weight.data.split(dim)
        q_bias, k_bias, v_bias = layer.bias.data.split(dim)
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        with torch.no_grad():
            self.q_proj.weight.copy_(q)
            self.k_proj.weight.copy_(k)
            self.v_proj.weight.copy_(v)
            self.q_proj.bias.copy_(q_bias)
            self.k_proj.bias.copy_(k_bias)
            self.v_proj.bias.copy_(v_bias)

    def forward(self, x):
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        return torch.cat([q, k, v], dim=-1)

for name, layer in model.named_modules():
    if "c_attn" in name:
        parent = model.get_submodule(name[:name.rfind(".")])
        setattr(parent, name.split(".")[-1], AttentionProjections(layer.weight.data))


In [7]:
# Preprocess dataset
def tokenize_function(examples):
    concatenated_examples = [s + " " + a for s, a in zip(examples["sentence"], examples["answer"])]
    return tokenizer(concatenated_examples, padding="max_length", truncation=True, max_length=64)

# Apply the function using map
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type="torch", columns=["input_ids"])

In [8]:
# DataLoader
train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=4, shuffle=True)

In [9]:
# TASK 1: Implement a LoRA layer that adds a low-rank trainable matrix to the frozen weights.
class LoRALinear(nn.Module):
    """
    Task: Implement a LoRA layer that adds a low-rank trainable matrix to the frozen weights.
    """
    def __init__(self, base_layer, rank=4, alpha=1.0):
        super().__init__()
        self.base = base_layer

        for param in self.base.parameters():
            param.requires_grad = False

        in_dim = base_layer.in_features
        out_dim = base_layer.out_features

        self.lora_A = nn.Parameter(torch.empty((rank, in_dim)))
        self.lora_B = nn.Parameter(torch.empty((out_dim, rank)))

        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        self.scaling = alpha / rank

    def forward(self, x):
        base_out = self.base(x)
        lora_out = (x @ self.lora_A.T) @ self.lora_B.T
        return base_out + lora_out * self.scaling



In [10]:
test_lora_layer_forward(LoRALinear)

In [11]:
# Print the model architecture
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): AttentionProjections(
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_features=3072, out_featur

In [12]:
# TASK 2: Replace all q, k, v, o layers with LoRA

for name, module in model.named_modules():
    if isinstance(module, AttentionProjections):
        # Replace q_proj, k_proj, v_proj inside AttentionProjections
        module.q_proj = LoRALinear(module.q_proj, rank=8, alpha=16)
        module.k_proj = LoRALinear(module.k_proj, rank=8, alpha=16)
        module.v_proj = LoRALinear(module.v_proj, rank=8, alpha=16)
    elif "attn.c_proj" in name:
        parent = model.get_submodule(name[:name.rfind(".")])
        orig_layer = getattr(parent, name.split(".")[-1])
        setattr(parent, name.split(".")[-1], LoRALinear(orig_layer, rank=8, alpha=16))



In [13]:
# TEST: Check that the correct layers are LoRA layers
test_lora_layers(model)

In [14]:
for name, module in model.named_modules():
    if isinstance(module, LoRALinear):
        print("LoRA Layer:", name)


LoRA Layer: transformer.h.0.attn.c_attn.q_proj
LoRA Layer: transformer.h.0.attn.c_attn.k_proj
LoRA Layer: transformer.h.0.attn.c_attn.v_proj
LoRA Layer: transformer.h.0.attn.c_proj
LoRA Layer: transformer.h.1.attn.c_attn.q_proj
LoRA Layer: transformer.h.1.attn.c_attn.k_proj
LoRA Layer: transformer.h.1.attn.c_attn.v_proj
LoRA Layer: transformer.h.1.attn.c_proj
LoRA Layer: transformer.h.2.attn.c_attn.q_proj
LoRA Layer: transformer.h.2.attn.c_attn.k_proj
LoRA Layer: transformer.h.2.attn.c_attn.v_proj
LoRA Layer: transformer.h.2.attn.c_proj
LoRA Layer: transformer.h.3.attn.c_attn.q_proj
LoRA Layer: transformer.h.3.attn.c_attn.k_proj
LoRA Layer: transformer.h.3.attn.c_attn.v_proj
LoRA Layer: transformer.h.3.attn.c_proj
LoRA Layer: transformer.h.4.attn.c_attn.q_proj
LoRA Layer: transformer.h.4.attn.c_attn.k_proj
LoRA Layer: transformer.h.4.attn.c_attn.v_proj
LoRA Layer: transformer.h.4.attn.c_proj
LoRA Layer: transformer.h.5.attn.c_attn.q_proj
LoRA Layer: transformer.h.5.attn.c_attn.k_proj
L

In [15]:
# TASK 3: ensure gradients are only enabled for LoRA parameters

# Disable gradients for all parameters
for param in model.parameters():
    param.requires_grad = False

# Enable gradients only for LoRA A and B
for name, module in model.named_modules():
    if isinstance(module, LoRALinear):
        module.lora_A.requires_grad = True
        module.lora_B.requires_grad = True


In [16]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print("Trainierbar:", name)


Trainierbar: transformer.h.0.attn.c_attn.q_proj.lora_A
Trainierbar: transformer.h.0.attn.c_attn.q_proj.lora_B
Trainierbar: transformer.h.0.attn.c_attn.k_proj.lora_A
Trainierbar: transformer.h.0.attn.c_attn.k_proj.lora_B
Trainierbar: transformer.h.0.attn.c_attn.v_proj.lora_A
Trainierbar: transformer.h.0.attn.c_attn.v_proj.lora_B
Trainierbar: transformer.h.0.attn.c_proj.lora_A
Trainierbar: transformer.h.0.attn.c_proj.lora_B
Trainierbar: transformer.h.1.attn.c_attn.q_proj.lora_A
Trainierbar: transformer.h.1.attn.c_attn.q_proj.lora_B
Trainierbar: transformer.h.1.attn.c_attn.k_proj.lora_A
Trainierbar: transformer.h.1.attn.c_attn.k_proj.lora_B
Trainierbar: transformer.h.1.attn.c_attn.v_proj.lora_A
Trainierbar: transformer.h.1.attn.c_attn.v_proj.lora_B
Trainierbar: transformer.h.1.attn.c_proj.lora_A
Trainierbar: transformer.h.1.attn.c_proj.lora_B
Trainierbar: transformer.h.2.attn.c_attn.q_proj.lora_A
Trainierbar: transformer.h.2.attn.c_attn.q_proj.lora_B
Trainierbar: transformer.h.2.attn.c_at

In [17]:
# TEST: Check that only LoRA parameters are trainable
# Adjust the lora_param_names to the actual parameter names used in your LoRA implementation
test_only_lora_trainable(model, lora_param_names=["lora_A", "lora_B"])

In [18]:
# Simple Training Loop (Few Steps)
# You should see the loss go down
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4)

model.train()
for step, batch in enumerate(train_dataloader):
    if step >= 5:  # Run for a few steps only
        break
    optimizer.zero_grad()
    input_ids = batch["input_ids"].to(device)
    outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss
    print(f"Step {step}, Loss: {loss.item()}")
    loss.backward()
    optimizer.step()

Step 0, Loss: 10.741665840148926
Step 1, Loss: 10.323408126831055
Step 2, Loss: 9.493453979492188
Step 3, Loss: 9.08095645904541
Step 4, Loss: 7.785666465759277
