
# Exercises XP: LoRA Implementation Lab
Replace each `TODO` before running the next section.

## What you'll learn

- The fundamentals of LoRA (Low-Rank Adaptation) and why it makes fine-tuning more parameter-efficient.
- How to implement LoRA matrices `A` and `B`, plus how to wrap existing `nn.Linear` layers.
- Differences between standard linear layers, LoRA-enhanced layers, and merged-weight alternatives.
- How to freeze base parameters so that only the LoRA adapters receive updates.

## What you will create

- A reusable `LoRALayer` module and two linear wrappers (`LinearWithLoRA`, `LinearWithLoRAMerged`).
- A 3-layer MLP that can be swapped between standard and LoRA-enhanced variants.
- A minimal MNIST training loop plus accuracy helpers to compare the pretrained baseline with the LoRA fine-tuned model.
- A workflow to freeze baseline weights and fine-tune only the LoRA layers.

> **Learning point**  
> Keep the student and teacher notebooks open side by side. Follow the numbered exercises, run setup only once, and watch tensor shapes as you add LoRA adapters.

# Part 0: Environment Setup

Install the CPU-friendly PyTorch stack plus torchvision for MNIST. Reuse caches across reruns to save time.

In [None]:
%pip install --quiet torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

In [None]:
import copy
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

BASE_SEED = 123
torch.manual_seed(BASE_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cpu


# Exercise 1: Implement `LoRALayer`

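Create the low-rank matrices `A` and `B`, scale their product by `alpha / rank`, and test the module on a toy tensor. Because `B` is initialized to zeros, the first forward pass returns an all-zero update.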

In [None]:
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha
        self.rank = rank # Added to store rank as an instance attribute

    def forward(self, x):
        # Low-rank update Δy = x @ A @ B, scaled by α/r (LoRA convention)
        scale = self.alpha / self.rank
        return (x @ self.A @ self.B) * scale

# Hyperparameters for the sandbox test
random_seed = 123
torch.manual_seed(random_seed)


in_dim = 8
out_dim = 5
rank = 2
alpha = 4.0

layer = LoRALayer(in_dim, out_dim, rank, alpha)

batch = 3
x = torch.randn(batch, in_dim)

print("Input x:\n", x)
print("\nLayer:\n", layer)
y = layer(x)
print("\nLoRA update output (Δy):\n", y)
print("Output shape:", y.shape)  # should be (batch, out_dim)

Input x:
 tensor([[-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010],
        [ 0.9624,  0.2492, -0.4845, -2.0929, -0.8199, -0.4210, -0.9620,  1.2825],
        [-0.3430, -0.6821, -0.9887, -1.7018, -0.7498, -1.1285,  0.4135,  0.2892]])

Layer:
 LoRALayer()

LoRA update output (Δy):
 tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], grad_fn=<MulBackward0>)
Output shape: torch.Size([3, 5])
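
The all-zero `Δy` is expected at initialization, and it has a consequence for the first training step. The block below is a minimal sketch, assuming the `LoRALayer` definition from this exercise (repeated so the block runs on its own). The names `probe` and `x_probe` and the `.sum()` loss are illustrative placeholders, not part of the exercise. It checks that the gradient reaches `B` but is exactly zero for `A` while `B` is still zero.

```python
import torch
import torch.nn as nn


class LoRALayer(nn.Module):
    # Same definition as in the exercise, repeated so this block is self-contained.
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha
        self.rank = rank

    def forward(self, x):
        return (x @ self.A @ self.B) * (self.alpha / self.rank)


torch.manual_seed(123)
probe = LoRALayer(in_dim=8, out_dim=5, rank=2, alpha=4.0)
x_probe = torch.randn(3, 8)

# Arbitrary scalar "loss", used only to trigger backpropagation.
probe(x_probe).sum().backward()

# dL/dA is proportional to B, so it vanishes while B is zero;
# dL/dB is proportional to x @ A, which is random and non-zero.
grad_A_is_zero = bool(torch.all(probe.A.grad == 0))
grad_B_is_nonzero = bool(torch.any(probe.B.grad != 0))
print("grad wrt A all zero:", grad_A_is_zero)
print("grad wrt B has non-zero entries:", grad_B_is_nonzero)
```

In other words, `B` moves first, and `A` only starts receiving gradient once `B` is no longer zero.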


# Exercise 2: Wrap `nn.Linear` with LoRA

Combine a linear projection with a trainable `LoRALayer` (the base weights are frozen later, before fine-tuning). Confirm that the adapter output is added on top of the base layer's output.

In [None]:
import torch
import torch.nn as nn

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features,
            linear.out_features,
            rank,
            alpha,
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)

base_linear = nn.Linear(in_dim, out_dim)
layer_lora_1 = LinearWithLoRA(base_linear, rank, alpha)
print("LinearWithLoRA output:", layer_lora_1(x))

LinearWithLoRA output: tensor([[-0.1224,  0.2353,  0.2788, -0.9573,  0.7254],
        [ 0.6466,  0.4186,  0.2505,  0.9226,  0.0839],
        [ 0.1024,  0.6865,  0.7498,  0.1414, -0.4729]], grad_fn=<AddBackward0>)


# Exercise 3: Swap a simple network layer with LoRA

Start from a single-layer perceptron, then replace its linear block with `LinearWithLoRA`. The outputs should match before training because `B` is initialized to zeros, so the LoRA update contributes nothing yet.

In [None]:
import torch
import torch.nn as nn

class SingleLayerNet(nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.layer = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.layer(x)

num_features = 8
num_classes  = 5
sample_input = torch.randn(4, num_features)

single_net = SingleLayerNet(num_features=num_features, num_classes=num_classes)

with torch.no_grad():
    baseline_output = single_net(sample_input)

single_net.layer = LinearWithLoRA(single_net.layer, rank=rank, alpha=alpha)

with torch.no_grad():
    lora_output = single_net(sample_input)

print("Outputs match before training?", torch.allclose(baseline_output, lora_output, atol=1e-6))

Outputs match before training? True


# Exercise 4: Merged-weight LoRA layer

Fuse the LoRA matrices with the frozen weights to create a drop-in linear layer that behaves exactly like `LinearWithLoRA`.
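
The fusion rests on a short identity. `nn.Linear` stores its weight $W$ with shape `(out_features, in_features)` and computes $y = xW^\top + b$, while `LoRALayer` stores $A$ with shape `(in_dim, rank)` and $B$ with shape `(rank, out_dim)`. Writing $r$ for the rank and $\alpha$ for the scaling hyperparameter:

$$
\begin{aligned}
y &= xW^\top + b + \tfrac{\alpha}{r}\, xAB \\
  &= x\left(W + \tfrac{\alpha}{r}\,(AB)^\top\right)^\top + b
\end{aligned}
$$

The merged layer can therefore apply a single linear map with weight $W + \tfrac{\alpha}{r}(AB)^\top$. The transpose is needed because $AB$ has shape `(in_dim, out_dim)` while $W$ has shape `(out_dim, in_dim)`.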

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearWithLoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features,
            linear.out_features,
            rank,
            alpha,
        )

    def forward(self, x):
        lora_matrix = self.lora.A @ self.lora.B
        scale = self.lora.alpha / self.lora.rank
        combined_weight = self.linear.weight + scale * lora_matrix.T
        return F.linear(x, combined_weight, self.linear.bias)


layer_lora_2 = LinearWithLoRAMerged(base_linear, rank, alpha)
print("Merged LoRA output:", layer_lora_2(x))


Merged LoRA output: tensor([[-0.1224,  0.2353,  0.2788, -0.9573,  0.7254],
        [ 0.6466,  0.4186,  0.2505,  0.9226,  0.0839],
        [ 0.1024,  0.6865,  0.7498,  0.1414, -0.4729]],
       grad_fn=<AddmmBackward0>)
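
The printed outputs of `LinearWithLoRA` and `LinearWithLoRAMerged` agree here, but with `B` still at zero both simply reproduce the base layer. A stronger check, sketched below, first moves `B` away from zero and then compares the two wrappers. The classes are repeated so the block is self-contained. The random `B` is only a stand-in for a trained adapter, and the names `shared_linear`, `unmerged`, `merged`, and `x_check` are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer(nn.Module):
    # Same definition as Exercise 1, repeated so the block runs on its own.
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha
        self.rank = rank

    def forward(self, x):
        return (x @ self.A @ self.B) * (self.alpha / self.rank)


class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        return self.linear(x) + self.lora(x)


class LinearWithLoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        scale = self.lora.alpha / self.lora.rank
        combined_weight = self.linear.weight + scale * (self.lora.A @ self.lora.B).T
        return F.linear(x, combined_weight, self.linear.bias)


torch.manual_seed(123)
shared_linear = nn.Linear(8, 5)
unmerged = LinearWithLoRA(shared_linear, rank=2, alpha=4.0)
merged = LinearWithLoRAMerged(shared_linear, rank=2, alpha=4.0)

# Pretend training has happened: give B random values, then copy the
# same A and B into the merged wrapper so both use identical adapters.
with torch.no_grad():
    unmerged.lora.B.normal_()
merged.lora.load_state_dict(unmerged.lora.state_dict())

x_check = torch.randn(3, 8)
with torch.no_grad():
    wrappers_agree = torch.allclose(unmerged(x_check), merged(x_check), atol=1e-5)
    differs_from_base = not torch.allclose(unmerged(x_check), shared_linear(x_check), atol=1e-5)

print("Merged and unmerged agree:", wrappers_agree)
print("Adapter changes the base output:", differs_from_base)
```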


# Exercise 5: Build an MLP and prepare MNIST

Stack three linear layers with ReLU activations, then set up the loss, optimizer, and MNIST loaders for pretraining.

In [19]:
class MultilayerPerceptron(nn.Module):
    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(num_features, num_hidden_1),
            nn.ReLU(),
            nn.Linear(num_hidden_1, num_hidden_2),
            nn.ReLU(),
            nn.Linear(num_hidden_2, num_classes),
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)   # flatten (B, 1, 28, 28) -> (B, 784); this step was missing
        x = self.layers(x)
        return x


In [22]:
# Architecture
num_features = 784
num_hidden_1 = 256
num_hidden_2 = 128
num_classes = 10

# Settings
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate = 1e-3
num_epochs = 5

model = MultilayerPerceptron(
    num_features=num_features,
    num_hidden_1=num_hidden_1,
    num_hidden_2=num_hidden_2,
    num_classes=num_classes,
)

criterion = nn.CrossEntropyLoss()
optimizer_pretrained = torch.optim.Adam(model.parameters(), lr=learning_rate)
model.to(DEVICE)
print(DEVICE)
print(model)
print(optimizer_pretrained)

cpu
MultilayerPerceptron(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)


## Loading dataset

In [23]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


BATCH_SIZE = 64

train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)

test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor(), download=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


Image batch dimensions: torch.Size([64, 1, 28, 28])
Image label dimensions: torch.Size([64])


## Define evaluation

In [24]:
def compute_accuracy(model, data_loader, device):
    model.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.view(features.size(0), -1).to(device)
            targets = targets.to(device)
            logits = model(features)
            _, predicted_labels = torch.max(logits, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum().item()
    return correct_pred / num_examples


## Training

In [25]:
import time


def train(num_epochs, model, optimizer, train_loader, device):
    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            features = features.view(features.size(0), -1).to(device)
            targets = targets.to(device)

            logits = model(features)
            loss = criterion(logits, targets)
            optimizer.zero_grad()

            loss.backward()
            optimizer.step()

            if not batch_idx % 400:
                # .item() turns the 0-dim loss tensor into a Python float before formatting
                print('Epoch: %03d/%03d|Batch %03d/%03d| Loss: %.4f' %
                      (epoch+1, num_epochs, batch_idx, len(train_loader), loss.item()))

        with torch.set_grad_enabled(False):
            print('Epoch: %03d/%03d training accuracy: %.2f%%' %
                  (epoch+1, num_epochs, compute_accuracy(model, train_loader, device)*100))

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))



In [26]:
train(num_epochs, model, optimizer_pretrained, train_loader, DEVICE)
# compute_accuracy returns a fraction in [0, 1]; the % format spec converts it to a percentage
print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.0%}')

Epoch: 001/005|Batch 000/938| Loss: 2.3046
Epoch: 001/005|Batch 400/938| Loss: 0.3282
Epoch: 001/005|Batch 800/938| Loss: 0.1142
Epoch: 001/005 training accuracy: 96.60%
Epoch: 002/005|Batch 000/938| Loss: 0.1085
Epoch: 002/005|Batch 400/938| Loss: 0.0925
Epoch: 002/005|Batch 800/938| Loss: 0.1300
Epoch: 002/005 training accuracy: 97.84%
Epoch: 003/005|Batch 000/938| Loss: 0.0439
Epoch: 003/005|Batch 400/938| Loss: 0.0800
Epoch: 003/005|Batch 800/938| Loss: 0.0912
Epoch: 003/005 training accuracy: 98.16%
Epoch: 004/005|Batch 000/938| Loss: 0.1841
Epoch: 004/005|Batch 400/938| Loss: 0.0838
Epoch: 004/005|Batch 800/938| Loss: 0.0130
Epoch: 004/005 training accuracy: 98.78%
Epoch: 005/005|Batch 000/938| Loss: 0.0452
Epoch: 005/005|Batch 400/938| Loss: 0.1017
Epoch: 005/005|Batch 800/938| Loss: 0.0533
Epoch: 005/005 training accuracy: 99.00%
Time elapsed: 1.95 min
Total Training Time: 1.95 min
Test accuracy: 98%


# Replacing Linear with LoRA Layers

In [30]:
import torch.nn.functional as F

class LinearWithLoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features,
            linear.out_features,
            rank,
            alpha,
        )

    def forward(self, x):
        lora_matrix = self.lora.A @ self.lora.B
        scale = self.lora.alpha / self.lora.rank
        combined_weight = self.linear.weight + scale * lora_matrix.T
        return F.linear(x, combined_weight, self.linear.bias)


In [32]:
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank, dtype=torch.float32))
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha
        self.rank = rank

    def forward(self, x):
        return (x @ self.A @ self.B) * (self.alpha / self.rank)


In [33]:
import copy


model_lora = copy.deepcopy(model)

model_lora.layers[0] = LinearWithLoRAMerged(model_lora.layers[0], rank=4, alpha=8)
model_lora.layers[2] = LinearWithLoRAMerged(model_lora.layers[2], rank=4, alpha=8)
model_lora.layers[4] = LinearWithLoRAMerged(model_lora.layers[4], rank=4, alpha=8)

model_lora.to(DEVICE)
optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
print(model_lora)

# compute_accuracy returns a fraction in [0, 1]; the % format spec converts it to a percentage
print(f'Test accuracy orig model: {compute_accuracy(model, test_loader, DEVICE):.0%}')
print(f'Test accuracy LoRA model: {compute_accuracy(model_lora, test_loader, DEVICE):.0%}')


MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithLoRAMerged(
      (linear): Linear(in_features=784, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRAMerged(
      (linear): Linear(in_features=256, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithLoRAMerged(
      (linear): Linear(in_features=128, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)
Test accuracy orig model: 98%
Test accuracy LoRA model: 98%
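
Replacing `layers[0]`, `layers[2]`, and `layers[4]` by hand works for a three-layer MLP, but the same swap can be written once for any depth. The sketch below assumes the `LoRALayer` and `LinearWithLoRAMerged` definitions used above (repeated so the block runs on its own). `wrap_linear_layers` is a hypothetical helper, not part of the exercise, and the small `nn.Sequential` stands in for the pretrained MLP.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank, dtype=torch.float32))
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha
        self.rank = rank

    def forward(self, x):
        return (x @ self.A @ self.B) * (self.alpha / self.rank)


class LinearWithLoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        scale = self.lora.alpha / self.lora.rank
        combined_weight = self.linear.weight + scale * (self.lora.A @ self.lora.B).T
        return F.linear(x, combined_weight, self.linear.bias)


def wrap_linear_layers(module, rank, alpha):
    """Hypothetical helper: recursively wrap every nn.Linear child in LinearWithLoRAMerged."""
    # Iterate over a snapshot so that replacing children does not disturb the loop.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWithLoRAMerged(child, rank, alpha))
        else:
            wrap_linear_layers(child, rank, alpha)
    return module


torch.manual_seed(123)
# Stand-in for the pretrained MLP (same layer sizes, random weights).
mlp = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10),
)
x_demo = torch.randn(4, 784)

with torch.no_grad():
    before = mlp(x_demo)
wrap_linear_layers(mlp, rank=4, alpha=8)
with torch.no_grad():
    after = mlp(x_demo)

num_wrapped = sum(isinstance(m, LinearWithLoRAMerged) for m in mlp.modules())
print("Wrapped layers:", num_wrapped)
print("Outputs unchanged after wrapping:", torch.allclose(before, after, atol=1e-6))
```

Because every `B` starts at zero, wrapping leaves the network's predictions unchanged, which is the same reason the two test accuracies above are identical.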


## Freezing the Original Linear Layers

In [34]:
def freeze_linear_layers(model):
    for child in model.children():
        if isinstance(child, nn.Linear):
            for param in child.parameters():
                param.requires_grad = False
        else:
            freeze_linear_layers(child)

freeze_linear_layers(model_lora)
for name, param in model_lora.named_parameters():
    print(f'{name}:{param.requires_grad}')

layers.0.linear.weight:False
layers.0.linear.bias:False
layers.0.lora.A:True
layers.0.lora.B:True
layers.2.linear.weight:False
layers.2.linear.bias:False
layers.2.lora.A:True
layers.2.lora.B:True
layers.4.linear.weight:False
layers.4.linear.bias:False
layers.4.lora.A:True
layers.4.lora.B:True
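
The listing shows which tensors stay trainable. A quick back-of-the-envelope count shows how few parameters that is. The sketch below uses only the layer sizes and rank from this notebook (784-256-128-10, `rank=4`). `linear_params` and `lora_params` are illustrative helper names.

```python
def linear_params(in_dim, out_dim):
    # Weight matrix plus bias vector of an nn.Linear layer.
    return in_dim * out_dim + out_dim


def lora_params(in_dim, out_dim, rank):
    # A has in_dim * rank entries, B has rank * out_dim entries.
    return rank * (in_dim + out_dim)


layer_dims = [(784, 256), (256, 128), (128, 10)]
rank = 4

frozen = sum(linear_params(i, o) for i, o in layer_dims)
trainable = sum(lora_params(i, o, rank) for i, o in layer_dims)
share = trainable / (frozen + trainable)

print(f"Frozen base parameters:    {frozen}")     # 235146
print(f"Trainable LoRA parameters: {trainable}")  # 6248
print(f"Trainable share:           {share:.2%}")  # 2.59%
```

With these sizes, the adapters account for roughly 2.6% of all parameters in `model_lora`.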


In [35]:
optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
train(num_epochs, model_lora, optimizer_lora, train_loader, DEVICE)
# compute_accuracy returns a fraction in [0, 1]; the % format spec converts it to a percentage
print(f'Test accuracy LoRA finetune: {compute_accuracy(model_lora, test_loader, DEVICE):.0%}')

print(f'Test accuracy orig model: {compute_accuracy(model, test_loader, DEVICE):.0%}')
print(f'Test accuracy LoRA model: {compute_accuracy(model_lora, test_loader, DEVICE):.0%}')

Epoch: 001/005|Batch 000/938| Loss: 0.0116
Epoch: 001/005|Batch 400/938| Loss: 0.0158
Epoch: 001/005|Batch 800/938| Loss: 0.0367
Epoch: 001/005 training accuracy: 99.51%
Epoch: 002/005|Batch 000/938| Loss: 0.0020
Epoch: 002/005|Batch 400/938| Loss: 0.0158
Epoch: 002/005|Batch 800/938| Loss: 0.0095
Epoch: 002/005 training accuracy: 99.62%
Epoch: 003/005|Batch 000/938| Loss: 0.0403
Epoch: 003/005|Batch 400/938| Loss: 0.0034
Epoch: 003/005|Batch 800/938| Loss: 0.0053
Epoch: 003/005 training accuracy: 99.66%
Epoch: 004/005|Batch 000/938| Loss: 0.0050
Epoch: 004/005|Batch 400/938| Loss: 0.0063
Epoch: 004/005|Batch 800/938| Loss: 0.0278
Epoch: 004/005 training accuracy: 99.65%
Epoch: 005/005|Batch 000/938| Loss: 0.0091
Epoch: 005/005|Batch 400/938| Loss: 0.0086
Epoch: 005/005|Batch 800/938| Loss: 0.0017
Epoch: 005/005 training accuracy: 99.69%
Time elapsed: 1.98 min
Total Training Time: 1.98 min
Test accuracy LoRA finetune: 98%
Test accuracy orig model: 98%
Test accuracy LoRA model: 98%
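
After fine-tuning, `LinearWithLoRAMerged` still rebuilds the combined weight on every forward pass. One option for inference is to fold the adapter into a plain `nn.Linear` once and drop the wrapper. The sketch below is one way to do that under the notebook's conventions. `fold_lora` is a hypothetical helper, not part of the exercise, and the random `B` stands in for a trained adapter.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank, dtype=torch.float32))
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha
        self.rank = rank

    def forward(self, x):
        return (x @ self.A @ self.B) * (self.alpha / self.rank)


class LinearWithLoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        scale = self.lora.alpha / self.lora.rank
        combined_weight = self.linear.weight + scale * (self.lora.A @ self.lora.B).T
        return F.linear(x, combined_weight, self.linear.bias)


def fold_lora(wrapper):
    """Hypothetical helper: return a plain nn.Linear whose weight already includes the LoRA update."""
    linear, lora = wrapper.linear, wrapper.lora
    has_bias = linear.bias is not None
    fused = nn.Linear(linear.in_features, linear.out_features, bias=has_bias)
    fused = fused.to(linear.weight.device)
    with torch.no_grad():
        delta = (lora.alpha / lora.rank) * (lora.A @ lora.B).T
        fused.weight.copy_(linear.weight + delta)
        if has_bias:
            fused.bias.copy_(linear.bias)
    return fused


torch.manual_seed(123)
wrapper = LinearWithLoRAMerged(nn.Linear(8, 5), rank=2, alpha=4.0)
with torch.no_grad():
    wrapper.lora.B.normal_()  # stand-in for a trained adapter

plain = fold_lora(wrapper)
x_demo = torch.randn(3, 8)
with torch.no_grad():
    outputs_match = torch.allclose(wrapper(x_demo), plain(x_demo), atol=1e-5)

print("Folded layer type:", type(plain).__name__)
print("Outputs match the wrapper:", outputs_match)
```

The folded layer carries only a weight and a bias, so it has the same size as the original `nn.Linear`. The trade-off is that the adapter can no longer be detached or swapped afterwards.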
