<a href="https://colab.research.google.com/github/MahdiTheGreat/Intro-to-language-modeling/blob/main/LoRA_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy
import pickle

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

In [None]:
!wget https://www.cse.chalmers.se/~richajo/diverse/l7/books.data
!wget https://www.cse.chalmers.se/~richajo/diverse/l7/s7_pretrained.model

--2024-11-25 10:00:52--  https://www.cse.chalmers.se/~richajo/diverse/l7/books.data
Resolving www.cse.chalmers.se (www.cse.chalmers.se)... 129.16.222.93
Connecting to www.cse.chalmers.se (www.cse.chalmers.se)|129.16.222.93|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6192712 (5.9M)
Saving to: ‘books.data.1’


2024-11-25 10:00:54 (6.36 MB/s) - ‘books.data.1’ saved [6192712/6192712]

--2024-11-25 10:00:54--  https://www.cse.chalmers.se/~richajo/diverse/l7/s7_pretrained.model
Resolving www.cse.chalmers.se (www.cse.chalmers.se)... 129.16.222.93
Connecting to www.cse.chalmers.se (www.cse.chalmers.se)|129.16.222.93|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1580288 (1.5M)
Saving to: ‘s7_pretrained.model.1’


2024-11-25 10:00:55 (2.12 MB/s) - ‘s7_pretrained.model.1’ saved [1580288/1580288]



In [None]:
# Load the feature matrix and the labels from the downloaded file.
# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
with open('books.data', 'rb') as f:
    books_X, books_Y = pickle.load(f)

print('X shape:', books_X.shape)
print('Y length:', len(books_Y))

# The first `split_ix` examples are used for training, the rest for testing.
split_ix = 1500
books_X_tr, books_X_te = books_X[:split_ix], books_X[split_ix:]
books_Y_tr, books_Y_te = books_Y[:split_ix], books_Y[split_ix:]

X shape: (2000, 768)
Y length: 2000


In [None]:
# The file stores a complete pickled nn.Sequential rather than a state dict,
# so the weights-only loader cannot read it; request full unpickling
# explicitly instead of relying on the version-dependent default.
# NOTE: full unpickling can execute arbitrary code; only load trusted files.
pretrained = torch.load('s7_pretrained.model', weights_only=False)
pretrained

  pretrained = torch.load('s7_pretrained.model')


Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Linear(in_features=512, out_features=1, bias=True)
)

In [None]:
def batcher(batch):
    """Collate a list of (features, label) pairs into two tensors.

    Args:
        batch: Sequence of (x, y) pairs as produced by the DataLoader.

    Returns:
        A tuple (X, Y): X stacks the feature rows, Y holds the labels
        converted to floating point.
    """
    feature_rows = [features for features, _ in batch]
    labels = [label for _, label in batch]
    X = torch.as_tensor(feature_rows)
    # Multiplying by 1.0 promotes integer/boolean labels to float.
    Y = torch.as_tensor(labels) * 1.0
    return X, Y

In [None]:
def eval_model(model):
    """Return the accuracy of `model` on the held-out test split.

    The first output column is treated as a logit: a positive value is a
    positive prediction, and a label counts as positive when it is > 0.

    Args:
        model: Callable mapping a feature batch to a (batch, 1) output.

    Returns:
        Fraction of test examples classified correctly.
    """
    test_pairs = list(zip(books_X_te, books_Y_te))
    loader = DataLoader(test_pairs, batch_size=32, shuffle=False, collate_fn=batcher)

    correct = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for Xb, Yb in loader:
            logits = model(Xb)[:, 0]
            hits = (logits > 0) == (Yb > 0)
            correct += hits.sum().item()
    return correct / len(books_Y_te)

In [None]:
# Baseline: test-split accuracy of the pre-trained model before any fine-tuning.
eval_model(pretrained)

0.794

# Basic fine-tuning

We create a new model where we copy the weights from the pre-trained model.

In [None]:
torch.manual_seed(0)

# Same architecture as the pre-trained model: Linear -> ReLU -> Linear.
finetuned = nn.Sequential(
    nn.Linear(in_features=768, out_features=512),
    nn.ReLU(),
    nn.Linear(in_features=512, out_features=1)
)

# Overwrite the random initialisation of both linear layers (indices 0 and 2)
# with independent copies of the pre-trained weights and biases.
for layer_ix in (0, 2):
    finetuned[layer_ix].weight.data = pretrained[layer_ix].weight.data.clone()
    finetuned[layer_ix].bias.data = pretrained[layer_ix].bias.data.clone()

In [None]:
# Sanity check: with the weights copied over, the clone should score the same
# as the pre-trained model.
eval_model(finetuned)

0.794

In [None]:
def train(model, n_epochs=10):
    """Fine-tune `model` on the training split and report progress per epoch.

    Uses Adam (lr=1e-3) with a binary cross-entropy-with-logits loss on the
    first output column. After every epoch the mean training loss and the
    test accuracy are printed.

    Args:
        model: Module mapping a feature batch to a (batch, 1) logit output.
        n_epochs: Number of passes over the training data.
    """
    dl = DataLoader(list(zip(books_X_tr, books_Y_tr)), batch_size=32, shuffle=True, collate_fn=batcher)

    # Only trainable parameters go to the optimizer. `requires_grad` is the
    # boolean flag; `requires_grad_` is a method and would always be truthy,
    # which silently disables this filter.
    params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.Adam(params, lr=1e-3)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    for epoch in range(n_epochs):
        total_loss = 0
        for Xb, Yb in dl:
            model_out = model(Xb)[:, 0]
            loss = loss_fn(model_out, Yb)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        mean_loss = total_loss / len(dl)
        acc = eval_model(model)
        print(f'loss = {mean_loss:.4f}, acc = {acc:.4f}')

Your task:
- Complete `count_trainable_parameters` below.
- Count the total number of trainable parameters in the model you fine-tuned.
- Use the function `train` to fine-tune the cloned model.

In [None]:
def count_trainable_parameters(model):
    """Return the number of scalar parameters in `model` that require gradients.

    Args:
        model: A torch module.

    Returns:
        Total element count over all parameters with `requires_grad` set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Implementing LoRA

In [None]:
class LinearBlockWithLoRA(nn.Module):
    """Frozen linear layer with a trainable low-rank (LoRA) correction.

    The block computes ``W(X) + B(A(X))``. ``W`` is a frozen copy of a
    pre-trained linear layer, while ``A`` (in_dim -> r) and ``B``
    (r -> out_dim) are small bias-free linear maps that are trained.
    """

    def __init__(self, W, r):
        """
        Initializes the LinearBlockWithLoRA.

        Args:
            W (nn.Linear): Pre-trained linear layer. It is deep-copied, so
                the caller's module is neither modified nor shared.
            r (int): Rank of the low-rank update.
        """
        super().__init__()

        # Keep a private, frozen copy of the pre-trained layer so that
        # training this block can never change the original model.
        self.W = copy.deepcopy(W)
        self.W.requires_grad_(False)

        # nn.Linear stores its weight as (out_features, in_features).
        out_dim, in_dim = self.W.weight.shape

        # Trainable low-rank factors.
        self.A = nn.Linear(in_features=in_dim, out_features=r, bias=False)
        self.B = nn.Linear(in_features=r, out_features=out_dim, bias=False)

        # Zero-initialise B so that B(A(X)) == 0 at the start: the block then
        # initially reproduces the pre-trained layer exactly.
        nn.init.zeros_(self.B.weight)

    def forward(self, X):
        """
        Forward pass for the LinearBlockWithLoRA.

        Args:
            X (torch.Tensor): Input tensor whose last dimension is in_dim.

        Returns:
            torch.Tensor: ``W(X) + B(A(X))``, with last dimension out_dim.
        """
        # Output of the frozen pre-trained layer.
        W_out = self.W(X)

        # Low-rank update: (..., in_dim) -> (..., r) -> (..., out_dim).
        lora_out = self.B(self.A(X))

        return W_out + lora_out


Your task:
- Complete `LinearBlockWithLoRA` above
- Set up a model using this new block to replace the first linear layer. Initialize parameters from the pre-trained model. (Don't forget to switch off gradient computation for `W`.)
- Count the parameters in the new model.
- Train the new model.

In [None]:

torch.manual_seed(0)

# Wrap a *fresh* linear layer rather than pretrained[0] itself, so that the
# pre-trained model is never shared with (or modified through) the new model.
lora_model = nn.Sequential(
    LinearBlockWithLoRA(
        nn.Linear(in_features=pretrained[0].in_features,
                  out_features=pretrained[0].out_features),
        r=8,
    ),
    nn.ReLU(),
    nn.Linear(in_features=512, out_features=1)
)

# Initialise the first layer and the output layer from the pre-trained model;
# load_state_dict copies the values into the new parameters.
lora_model[0].W.load_state_dict(pretrained[0].state_dict())
lora_model[2].load_state_dict(pretrained[2].state_dict())

# Freeze the pre-trained first layer. `requires_grad_` sets the flag on every
# parameter of the module; assigning `module.requires_grad = False` would only
# create an unrelated attribute and leave the weights trainable.
lora_model[0].W.requires_grad_(False)

train(lora_model, n_epochs=10)

print(count_trainable_parameters(lora_model))

loss = 0.3777, acc = 0.8060
loss = 0.3460, acc = 0.8100
loss = 0.3201, acc = 0.8180
loss = 0.3003, acc = 0.8120
loss = 0.2878, acc = 0.8220
loss = 0.2665, acc = 0.8220
loss = 0.2608, acc = 0.8280
loss = 0.2496, acc = 0.8180
loss = 0.2358, acc = 0.8140
loss = 0.2304, acc = 0.8220
404481
