In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Conv1D

from tests import *

In [17]:
# Set device (GPU when available) and seed PyTorch's global RNG so that DataLoader
# shuffling and any random weight initialization are reproducible across Restart & Run All
SEED = 0
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
# Load the "winogrande_xs" configuration of WinoGrande, a small Hugging Face dataset
dataset = load_dataset(path="allenai/winogrande", name="winogrande_xs")

In [None]:
# Tokenizer and model initialization
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 has no dedicated pad token, so its end-of-sequence token doubles as padding
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)

In [20]:
# GPT-2 uses custom Conv1D layers, which are linear layers whose weight is stored transposed,
# as (in_features, out_features). Converting them to standard nn.Linear layers simplifies the
# architecture without changing what the model computes.
# Collect the targets first so the module tree is not modified while it is being iterated.
conv1d_layers = [(name, layer) for name, layer in model.named_modules() if isinstance(layer, Conv1D)]
for name, layer in conv1d_layers:
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)  # "" resolves to the model itself
    in_features, out_features = layer.weight.shape
    has_bias = layer.bias is not None
    # Allocate the replacement on the same device/dtype as the original weights; nn.Linear
    # would otherwise use the default (CPU) device even though the model may be on the GPU.
    linear_layer = nn.Linear(
        in_features,
        out_features,
        bias=has_bias,
        device=layer.weight.device,
        dtype=layer.weight.dtype,
    )
    with torch.no_grad():
        linear_layer.weight.copy_(layer.weight.T)
        if has_bias:
            linear_layer.bias.copy_(layer.bias)
    setattr(parent, child_name, linear_layer)

# GPT-2 also fuses the q, k and v projections into a single weight matrix (c_attn).
# That is mathematically equivalent to three separate matrices, but separate q_proj, k_proj
# and v_proj layers are clearer and make it easier to apply LoRA to each of them,
# so we split the merged matrix into 3 separate matrices.
class AttentionProjections(nn.Module):
    """Drop-in replacement for a fused qkv ``nn.Linear`` built from three separate projections.

    ``forward`` concatenates the q, k and v outputs along the last dimension, so the module
    returns exactly what the fused layer returned.

    Args:
        merged_weight: fused weight in ``nn.Linear`` layout, shape ``(3 * dim, dim)``;
            rows are ordered q, then k, then v.
        merged_bias: optional fused bias of shape ``(3 * dim,)``. When ``None`` the
            projections are created without bias.

    Raises:
        ValueError: if ``merged_weight`` does not have shape ``(3 * dim, dim)``.
    """

    def __init__(self, merged_weight, merged_bias=None):
        super().__init__()
        out_features, dim = merged_weight.shape
        if out_features != 3 * dim:
            raise ValueError(
                f"expected a fused qkv weight of shape (3 * dim, dim), got {tuple(merged_weight.shape)}"
            )
        has_bias = merged_bias is not None
        # Keep the new layers on the same device/dtype as the weights being split
        factory_kwargs = {"device": merged_weight.device, "dtype": merged_weight.dtype}
        self.q_proj = nn.Linear(dim, dim, bias=has_bias, **factory_kwargs)
        self.k_proj = nn.Linear(dim, dim, bias=has_bias, **factory_kwargs)
        self.v_proj = nn.Linear(dim, dim, bias=has_bias, **factory_kwargs)
        projections = (self.q_proj, self.k_proj, self.v_proj)
        with torch.no_grad():
            # Chunks of `dim` rows: [0, dim) -> q, [dim, 2*dim) -> k, [2*dim, 3*dim) -> v
            for proj, weight_chunk in zip(projections, merged_weight.split(dim)):
                proj.weight.copy_(weight_chunk)
            if has_bias:
                for proj, bias_chunk in zip(projections, merged_bias.split(dim)):
                    proj.bias.copy_(bias_chunk)

    def forward(self, x):
        """Return ``cat([q_proj(x), k_proj(x), v_proj(x)], dim=-1)``, matching the fused layer."""
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        return torch.cat([q, k, v], dim=-1)

# Replace every fused c_attn projection. Only plain nn.Linear modules named exactly "c_attn"
# are matched, and the targets are collected up front, so re-running this cell does not try to
# split the already-split AttentionProjections modules or their q/k/v children.
c_attn_layers = [
    (name, layer)
    for name, layer in model.named_modules()
    if isinstance(layer, nn.Linear) and name.rpartition(".")[2] == "c_attn"
]
for name, layer in c_attn_layers:
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)
    merged_bias = None if layer.bias is None else layer.bias.detach()
    setattr(parent, child_name, AttentionProjections(layer.weight.detach(), merged_bias))


In [21]:
# Preprocess dataset
def tokenize_function(examples, max_length=64):
    """Tokenize a batch of examples as ``sentence + " " + answer`` for causal LM training.

    Args:
        examples: batched mapping (as passed by ``datasets.map(..., batched=True)``) with
            ``"sentence"`` and ``"answer"`` string lists.
        max_length: fixed length every sequence is padded/truncated to.

    Returns:
        The tokenizer output for the batch (for GPT-2: ``input_ids`` and ``attention_mask``).
    """
    concatenated_examples = [
        sentence + " " + answer
        for sentence, answer in zip(examples["sentence"], examples["answer"])
    ]
    return tokenizer(concatenated_examples, padding="max_length", truncation=True, max_length=max_length)

# Apply the function using map. Keep attention_mask alongside input_ids: padding reuses the
# eos token, so the mask is what lets the training loop exclude padded positions from the loss.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask"])

In [22]:
# DataLoader: shuffled mini-batches of 4 examples from the tokenized training split
train_dataloader = DataLoader(
    dataset=tokenized_datasets["train"],
    shuffle=True,
    batch_size=4,
)

In [None]:
# TASK 1: Implement a LoRA layer that adds a low-rank trainable matrix to the frozen weights.
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` and add a trainable low-rank update (LoRA).

    The layer computes ``base_layer(x) + scaling * (x @ lora_A.T @ lora_B.T)``, i.e. the
    effective weight is ``W + scaling * (lora_B @ lora_A)`` with ``lora_A`` of shape
    ``(rank, in_features)`` and ``lora_B`` of shape ``(out_features, rank)``. ``lora_B``
    starts at zero, so a freshly wrapped layer returns exactly what ``base_layer`` returns.

    Args:
        base_layer: the ``nn.Linear`` to adapt; its parameters are frozen here.
        rank: rank of the low-rank update; must be positive.
        alpha: scaling numerator; the update is multiplied by ``alpha / rank``.
            Defaults to ``rank`` (a scaling of 1).

    Raises:
        ValueError: if ``rank`` is not positive.
    """

    def __init__(self, base_layer, rank=4, alpha=None):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"rank must be positive, got {rank}")
        self.base_layer = base_layer
        self.rank = rank
        self.scaling = (rank if alpha is None else alpha) / rank
        # Only the low-rank factors are trained; the pretrained weights stay fixed
        for param in self.base_layer.parameters():
            param.requires_grad_(False)
        # Create the factors next to the base weights (same device and dtype)
        factory_kwargs = {"device": base_layer.weight.device, "dtype": base_layer.weight.dtype}
        self.lora_A = nn.Parameter(torch.empty(rank, base_layer.in_features, **factory_kwargs))
        self.lora_B = nn.Parameter(torch.zeros(base_layer.out_features, rank, **factory_kwargs))
        # A gets the same init nn.Linear uses for its weight; B = 0 keeps the initial update at zero
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)

    def forward(self, x):
        """Return ``base_layer(x)`` plus the scaled low-rank update applied to ``x``."""
        lora_update = (x @ self.lora_A.T) @ self.lora_B.T
        return self.base_layer(x) + self.scaling * lora_update

In [9]:
# TEST: check the LoRALinear class with the helper provided by `from tests import *`
# (NOTE(review): the exact checks live in the tests module — presumably the forward output; confirm there)
test_lora_layer_forward(LoRALinear)

In [None]:
# Show the module tree, e.g. to confirm which layers are plain Linear / split q, k, v projections
print(repr(model))

In [11]:
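# TASK 2: Wrap every attention q_proj, k_proj, v_proj and attention output projection (GPT-2's attn.c_proj, not the MLP's c_proj) in a LoRALinear layer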
# TODO

In [12]:
# TEST: Check that the correct layers are LoRA layers
# (helper from `from tests import *`; NOTE(review): presumably it expects the attention
# q/k/v/output projections to be LoRALinear after TASK 2 — confirm the exact criteria in the tests module)
test_lora_layers(model)

In [13]:
# TASK 3: Freeze the model so that gradients are enabled only for the LoRA parameters (requires_grad=False everywhere else)
# TODO

In [None]:
# TEST: Check that only LoRA parameters are trainable
# Adjust the lora_param_names to the actual parameter names used in your LoRA implementation
# (NOTE(review): presumably matched against parameter names from the model — confirm in the tests module)
test_only_lora_trainable(model, lora_param_names=["lora_A", "lora_B"])

In [None]:
# Simple Training Loop (Few Steps)
# You should see the loss go down
NUM_TRAIN_STEPS = 5  # run for a few steps only
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4)

model.train()
for step, batch in enumerate(train_dataloader):
    if step >= NUM_TRAIN_STEPS:
        break
    optimizer.zero_grad()
    input_ids = batch["input_ids"].to(device)
    if "attention_mask" in batch:
        attention_mask = batch["attention_mask"].to(device)
    else:
        # Fallback when only input_ids were kept: padding uses the eos token
        # (tokenizer.pad_token = eos_token), and the tokenizer is assumed not to insert eos
        # into the text itself, so every pad_token_id position is treated as padding.
        attention_mask = (input_ids != tokenizer.pad_token_id).long()
    # Exclude padding from the loss: label positions equal to -100 are ignored by the LM loss.
    # Otherwise every padded position would also be scored and trained on.
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    print(f"Step {step}, Loss: {loss.item()}")
    loss.backward()
    optimizer.step()