# Week 10: Low-Rank Adaptation (LoRA)

Implementing LoRA from scratch to fine-tune large models efficiently.

## Learning Objectives
1. Understand reparameterization
2. Implement a LoRA Linear Layer
3. Fine-tune a small BERT-like model

In [None]:
import torch
import torch.nn as nn
import math

## 1. LoRA Layer Implementation

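Key concept: $W = W_0 + \frac{\alpha}{r} BA$, where $W_0$ is the frozen pretrained weight and only $A$ and $B$ are trained.
$B$ has shape $(d_{out}, r)$ and $A$ has shape $(r, d_{in})$, with $r \ll \min(d_{in}, d_{out})$.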

In [None]:
class LoRALinear(nn.Module):
    """Linear layer with a frozen base weight and a trainable low-rank update.

    Computes ``linear(x) + (x @ A.T) @ B.T * (alpha / rank)`` where ``A`` has
    shape ``(rank, in_features)`` and ``B`` has shape ``(out_features, rank)``.
    ``B`` starts at zero, so a freshly built layer returns exactly what the
    wrapped ``nn.Linear`` returns.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Rank ``r`` of the update; must be positive.
        alpha: Numerator of the scaling factor ``alpha / rank``.
        bias: Whether the frozen linear layer carries a bias term.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, in_features, out_features, rank=4, alpha=16, bias=True):
        super().__init__()
        # Fail early with a clear message instead of a ZeroDivisionError below.
        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank}")

        # Frozen pretrained weights (the bias, when present, is frozen too).
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        for frozen_param in self.linear.parameters():
            frozen_param.requires_grad_(False)

        # LoRA hyper-parameters
        self.lora_rank = rank
        self.lora_alpha = alpha
        self.scaling = alpha / rank

        # A: Gaussian init scaled by 1/sqrt(rank), B: zero init so the
        # low-rank path contributes nothing until training updates B.
        self.lora_A = nn.Parameter(
            torch.randn(rank, in_features) * (1 / math.sqrt(rank))
        )
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, x):
        """Return the frozen projection of ``x`` plus the scaled LoRA update."""
        # Original path
        h_frozen = self.linear(x)

        # LoRA path: project down to `rank` first, then up to out_features,
        # so the full (out_features, in_features) product is never built.
        h_lora = (x @ self.lora_A.T) @ self.lora_B.T * self.scaling

        return h_frozen + h_lora

In [None]:
# Smoke-test LoRALinear on a square 768 -> 768 projection.
in_dim = out_dim = 768
rank = 8

layer = LoRALinear(in_dim, out_dim, rank=rank)
x = torch.randn(1, 10, in_dim)  # 3-D input whose last dimension is in_dim

output = layer(x)
print(f"Output shape: {output.shape}")

# Collect (size, trainable?) once, then report both counts from it.
param_sizes = [(p.numel(), p.requires_grad) for p in layer.parameters()]
trainable_count = sum(size for size, needs_grad in param_sizes if needs_grad)
total_count = sum(size for size, _ in param_sizes)
print(f"Trainable parameters: {trainable_count}")
print(f"Total parameters: {total_count}")

## 2. Replacing Layers in a Model

Function to recursively replace Linear layers with LoRALinear.

In [None]:
def apply_lora(model, rank=4, alpha=16):
    """Replace every ``nn.Linear`` inside ``model`` with a ``LoRALinear``, in place.

    Each replacement receives a copy of the original weight (and bias, when
    the original has one) as its frozen base layer, and is placed on the same
    device and dtype as the layer it replaces. Modules that are already
    ``LoRALinear`` are left untouched, so calling this twice does not wrap
    their inner linear layer again.

    Args:
        model: Module whose children are searched recursively.
        rank: LoRA rank passed to each ``LoRALinear``.
        alpha: LoRA alpha passed to each ``LoRALinear``.

    Returns:
        None. ``model`` is modified in place.
    """
    # Snapshot the children because entries are replaced while looping.
    for name, module in list(model.named_children()):
        if isinstance(module, LoRALinear):
            # Already adapted; recursing would wrap its frozen inner linear.
            continue

        if not isinstance(module, nn.Linear):
            # Recurse into containers and other composite modules.
            apply_lora(module, rank, alpha)
            continue

        # Create LoRA layer
        lora_layer = LoRALinear(
            module.in_features,
            module.out_features,
            rank=rank,
            alpha=alpha,
        )
        # Keep the LoRA matrices on the same device/dtype as the source layer,
        # otherwise forward() mixes devices or dtypes.
        lora_layer.to(device=module.weight.device, dtype=module.weight.dtype)

        # Copy weights to frozen linear
        lora_layer.linear.weight.data = module.weight.data.clone()
        if module.bias is not None:
            lora_layer.linear.bias.data = module.bias.data.clone()
        else:
            # The source has no bias: drop the randomly initialised one so the
            # replacement computes the same function as the original layer.
            lora_layer.linear.bias = None

        # Replace
        setattr(model, name, lora_layer)

In [None]:
# Example model used to demonstrate layer replacement
class SimpleMLP(nn.Module):
    """Two-layer perceptron: 128 -> 256 -> ReLU -> 10."""

    def __init__(self):
        super().__init__()
        hidden_layer = nn.Linear(128, 256)
        activation = nn.ReLU()
        output_layer = nn.Linear(256, 10)
        self.net = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        out = self.net(x)
        return out

# Build the example MLP and show its module tree before modification.
model = SimpleMLP()
print("Before LoRA:", model)

# apply_lora mutates `model` in place (alpha keeps its default of 16);
# printing again shows the LoRALinear modules that replaced the Linear layers.
apply_lora(model, rank=8)
print("\nAfter LoRA:", model)