# Practical 3A — Parameter-Efficient Fine-Tuning with LoRA

**Road to SKA: Foundation Models, Embeddings, and Latent Spaces**

This practical demonstrates running LoRA on a small vision encoder *without* heavy EO plumbing.

You will:

1. Train a **base CNN classifier** (a stand-in for a pretrained foundation encoder).
2. Adapt the base model to a **shifted target task** using three strategies:
   - **Full fine-tuning** (update all weights)
   - **Head-only fine-tuning** (freeze everything except the last layer)
   - **LoRA fine-tuning** (freeze the base weights and update only small low-rank adapter matrices)
3. Compare:
   - accuracy vs training time
   - number of trainable parameters

---

## References
- LoRA paper (Hu et al., 2021 / ICLR 2022): https://arxiv.org/abs/2106.09685  
- Microsoft LoRA repo / `loralib`: https://github.com/microsoft/LoRA  
- Hugging Face PEFT LoRA docs (concept + config): https://huggingface.co/docs/peft/en/conceptual_guides/lora  

> We implement LoRA **directly in PyTorch** here so it works anywhere (no Transformers dependency).


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cambridge-iccs/R2SKA_Advanced_Tutorial/blob/main/Session3A_LoRA_Finetuning.ipynb)

---

## Environment Setup (Colab / Local)

Run the cell below to detect your environment and set up paths. This notebook uses PyTorch, torchvision, and common scientific Python packages, all of which are pre-installed on Colab.

In [None]:
# Detect environment and set up paths
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running in Google Colab")
    DATA_ROOT = '/content/data'
    # PyTorch is pre-installed on Colab
else:
    print("Running locally")
    DATA_ROOT = './data'

print(f"Data directory: {DATA_ROOT}")

## 1. Setup

CPU-friendly. Uses GPU if available.

Required:
- `torch`, `torchvision`, `numpy`, `matplotlib`, `scikit-learn`

Optional:
- `tqdm` (progress bars)


In [None]:
# Optional installs (uncomment if needed)
# %pip -q install torch torchvision scikit-learn tqdm

import time
import math
import random
from dataclasses import dataclass, fields
from typing import Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from sklearn.metrics import accuracy_score, classification_report

try:
    from tqdm.auto import tqdm
except ImportError:
    # Fallback: a no-op wrapper so loops still run without progress bars
    def tqdm(iterable, **kwargs):
        return iterable


## 2. Configuration

We create a **domain shift**:

- **Base training:** MNIST with standard images  
- **Target adaptation:** MNIST with rotation + noise, and **few-shot** labels (a small labelled subset, 2,000 images by default)

This mimics adapting a pretrained encoder to a new domain with limited labels.


In [None]:
@dataclass
class Config:
    data_dir: str = DATA_ROOT    # Uses environment-detected path
    seed: int = 42
    batch_size: int = 128
    base_epochs: int = 5
    target_epochs: int = 5
    lr_base: float = 1e-3
    lr_target: float = 1e-3
    max_base_train: int = 20000
    max_target_train: int = 2000
    max_test: int = 5000
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    rotate_deg: float = 45.0
    noise_std: float = 0.15

    # Formatted string representation for readable output
    def __repr__(self):
        lines = [f"{self.__class__.__name__}:"]
        for f in fields(self):
            lines.append(f"  {f.name:20s} = {getattr(self, f.name)!r}")
        return "\n".join(lines)

cfg = Config()

random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

# Device selection: CUDA > MPS (Apple Silicon) > CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device:", device)
print(cfg)

## 3. Data: Base vs Target

- Base transform: standard MNIST  
- Target transform: random rotation (up to ±45° by default) + Gaussian noise (synthetic shift)


In [None]:
class AddGaussianNoise:
    """
    Add Gaussian noise to the data.
    
    """
    def __init__(self, std: float = 0.1):
        self.std = std
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.clamp(x + torch.randn_like(x) * self.std, 0.0, 1.0)

# Stack transformations to operate on the data
base_tfm = transforms.Compose([transforms.ToTensor()])  # Create tensor
target_tfm = transforms.Compose([
    transforms.RandomRotation(degrees=cfg.rotate_deg),  # Rand rotate
    transforms.ToTensor(),
    AddGaussianNoise(std=cfg.noise_std),                # Add noise
])

base_train = datasets.MNIST(cfg.data_dir, train=True, download=True, transform=base_tfm)
target_train = datasets.MNIST(cfg.data_dir, train=True, download=True, transform=target_tfm)
test_ds = datasets.MNIST(cfg.data_dir, train=False, download=True, transform=base_tfm)

# Sample a subset of the data
def subset(ds, n: Optional[int], seed: int):
    if n is None or n >= len(ds):
        return ds
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(ds), size=n, replace=False)
    return Subset(ds, idx)

base_train_s = subset(base_train, cfg.max_base_train, cfg.seed)
target_train_s = subset(target_train, cfg.max_target_train, cfg.seed + 1)
test_s = subset(test_ds, cfg.max_test, cfg.seed + 2)

# num_workers=0 for compatibility (set higher for multiprocessing)
base_loader = DataLoader(base_train_s, batch_size=cfg.batch_size, shuffle=True, num_workers=0)
target_loader = DataLoader(target_train_s, batch_size=cfg.batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_s, batch_size=cfg.batch_size, shuffle=False, num_workers=0)

# Visualise base vs target transforms on THE SAME images
# Get a batch of indices to use for both
sample_indices = list(range(8))

fig, axes = plt.subplots(2, 8, figsize=(10, 3))

for i, idx in enumerate(sample_indices):
    # Get the same underlying image with base transform
    img_base, label = base_train[idx]
    # Get with target transform (will have rotation + noise)
    img_target, _ = target_train[idx]

    for row, img in enumerate((img_base, img_target)):
        ax = axes[row, i]
        ax.imshow(img[0], cmap="gray")
        ax.set_title(int(label))
        # Hide ticks only: axis("off") would also hide the row labels set below
        ax.set_xticks([])
        ax.set_yticks([])

axes[0, 0].set_ylabel("Base", fontsize=11)
axes[1, 0].set_ylabel("Target", fontsize=11)
plt.suptitle("Same images: Base (top) vs Target with rotation + noise (bottom)")
plt.tight_layout()
plt.show()

## 4. Base model: small CNN

This is our “pretrained encoder + classifier head”.


In [None]:
class SmallCNN(nn.Module):
    def __init__(self, n_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128), nn.ReLU(),
            nn.Linear(128, n_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

def count_params(model: nn.Module) -> Tuple[int, int]:
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

base_model = SmallCNN().to(device)
print("Params (total, trainable):", count_params(base_model))


## 5. Pretrain the base model

Short training just to get a reasonable base model.


In [None]:
def train_epoch(model, loader, opt, loss_fn):
    model.train()
    total_loss = 0.0
    y_true, y_pred = [], []
    for x, y in tqdm(loader, leave=False):
        x, y = x.to(device), y.to(device)
        opt.zero_grad(set_to_none=True)
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        opt.step()
        total_loss += loss.item() * x.size(0)
        y_true.append(y.detach().cpu().numpy())
        y_pred.append(logits.argmax(dim=1).detach().cpu().numpy())
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    return total_loss / len(loader.dataset), accuracy_score(y_true, y_pred)

@torch.no_grad()
def eval_model(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        y_true.append(y.detach().cpu().numpy())
        y_pred.append(logits.argmax(dim=1).detach().cpu().numpy())
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    return accuracy_score(y_true, y_pred)

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(base_model.parameters(), lr=cfg.lr_base)

hist_base = {"train_acc": [], "test_acc": []}
for ep in range(1, cfg.base_epochs + 1):
    tl, ta = train_epoch(base_model, base_loader, opt, loss_fn)
    va = eval_model(base_model, test_loader)
    hist_base["train_acc"].append(ta)
    hist_base["test_acc"].append(va)
    print(f"[Base] epoch {ep:02d} loss={tl:.4f} train_acc={ta:.3f} test_acc={va:.3f}")

plt.figure(figsize=(6,4))
plt.plot(hist_base["train_acc"], label="train acc")
plt.plot(hist_base["test_acc"], label="test acc")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.tight_layout()
plt.show()


## 6. LoRA implementation (Linear layers)

When adapting large pretrained models, full fine-tuning updates **all** weights. For a model with millions or billions of parameters, this means:
- Storing a complete copy of gradients and optimizer states
- Risk of catastrophic forgetting (overwriting useful pretrained features)
- Needing separate weight copies for each downstream task

**Low-Rank Adaptation (LoRA)** builds on a hypothesis from Hu et al. (2021), paraphrased here:

> *The change in weights during model adaptation has a low "intrinsic rank".*

In other words, even though weight matrices are large (e.g., 4096 × 4096), the **update** $\Delta W$ needed for adaptation can be well-approximated by a much smaller, low-rank matrix.

### How LoRA Works

Instead of updating $W$ directly, LoRA **freezes** the pretrained weights and adds a parallel low-rank branch:

$$W' = W + \Delta W, \quad \text{where} \quad \Delta W = \frac{\alpha}{r} B A$$

- $W \in \mathbb{R}^{d_{out} \times d_{in}}$ — original frozen weight matrix
- $A \in \mathbb{R}^{r \times d_{in}}$ — projects input down to rank $r$
- $B \in \mathbb{R}^{d_{out} \times r}$ — projects back up to output dimension
- $r \ll \min(d_{in}, d_{out})$ — the **rank** (typically 4–64)
- $\alpha$ — scaling hyperparameter (controls magnitude of adaptation)
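
The shapes above can be checked with a few lines of PyTorch. This is a minimal sketch with illustrative sizes (`d_in`, `d_out`, `r`, and `alpha` are chosen arbitrarily, and `B` is random here rather than zero purely so the rank is visible); it is not the adapter class used later in this practical.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes only (assumptions, not values from the practical)
d_in, d_out, r, alpha = 64, 32, 4, 8.0

A = torch.randn(r, d_in, dtype=torch.float64)    # down-projection: d_in -> r
B = torch.randn(d_out, r, dtype=torch.float64)   # up-projection:   r -> d_out
delta_W = (alpha / r) * (B @ A)                  # same shape as W, rank at most r

print(tuple(delta_W.shape))                      # (d_out, d_in)
print(int(torch.linalg.matrix_rank(delta_W)))    # never exceeds r

# Applying A then B is equivalent to multiplying by the full delta_W,
# but never materialises the (d_out x d_in) matrix.
x = torch.randn(5, d_in, dtype=torch.float64)
two_step = (alpha / r) * (x @ A.t()) @ B.t()
one_step = x @ delta_W.t()
print(torch.allclose(two_step, one_step))
```

The two-step form is what makes the branch cheap: the input passes through an `r`-dimensional bottleneck instead of a full-size matrix.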

### Parameter Savings

For a weight matrix $W$ of shape $(d_{out}, d_{in})$:
- **Full fine-tuning**: $d_{out} \times d_{in}$ trainable parameters
- **LoRA**: $r \times (d_{in} + d_{out})$ trainable parameters

**Example**: A linear layer with shape (4096, 4096):
- Full: 16.8M parameters
- LoRA (r=8): 65.5K parameters → **256× fewer!**
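
The arithmetic for this example can be reproduced directly. The helper names below are hypothetical and biases are ignored, as in the formulas above.

```python
def full_params(d_out: int, d_in: int) -> int:
    """Trainable weights when the whole matrix is updated (bias ignored)."""
    return d_out * d_in


def lora_params(d_out: int, d_in: int, r: int) -> int:
    """Trainable weights in A (r x d_in) plus B (d_out x r)."""
    return r * (d_in + d_out)


d_out, d_in, r = 4096, 4096, 8
full = full_params(d_out, d_in)
lora = lora_params(d_out, d_in, r)

print(f"full: {full:,}")             # full: 16,777,216
print(f"LoRA: {lora:,}")             # LoRA: 65,536
print(f"ratio: {full // lora}x")     # ratio: 256x
```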

### Initialization Strategy

- **A** is initialized with Kaiming/He initialization (random)
- **B** is initialized to **zero**

This means $\Delta W = BA = 0$ at the start, so the model begins exactly at the pretrained solution. Training then learns the adjustment from that starting point.
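
A short check of this behaviour, assuming a toy `nn.Linear(16, 4)` as a stand-in for a pretrained layer (the sizes are arbitrary). It also illustrates why only one of the two matrices is zeroed: with `B = 0`, the first backward pass delivers a gradient to `B` but none to `A`, so training can still get started.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

base = nn.Linear(16, 4)   # toy stand-in for a pretrained layer (assumption)
r = 2

A = nn.Parameter(torch.empty(r, 16))
nn.init.kaiming_uniform_(A, a=math.sqrt(5))   # random A
B = nn.Parameter(torch.zeros(4, r))           # zero B

x = torch.randn(3, 16)
y_base = base(x)
y_lora = y_base + (x @ A.t()) @ B.t()

# The adapter contributes exactly zero at initialisation
print(torch.equal(y_base, y_lora))

# One backward pass: B receives a gradient, A does not (yet)
y_lora.sum().backward()
print(A.grad.abs().max().item())       # 0.0 while B is still zero
print(B.grad.abs().max().item() > 0)   # B can start moving
```

Once `B` has moved away from zero after the first optimizer step, `A` starts receiving gradients as well.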

### The Scaling Factor $\alpha/r$

The ratio $\alpha/r$ controls how much the low-rank update contributes:
- Larger $\alpha$ → stronger adaptation signal
- Common practice: set $\alpha = 2r$ (so scaling = 2.0)
- Dividing by $r$ keeps the size of the update roughly comparable across ranks, which reduces the need to retune the learning rate when $r$ changes
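
To make the two conventions concrete, the sketch below tabulates the scaling factor for a few ranks, once with $\alpha$ held fixed at 16 and once with $\alpha = 2r$. `lora_scaling` is a hypothetical helper, not part of any library.

```python
def lora_scaling(alpha: float, r: int) -> float:
    """Multiplier applied to the low-rank branch."""
    return alpha / r


for r in (2, 4, 8, 16):
    fixed = lora_scaling(16.0, r)      # alpha held at 16 while r varies
    tied = lora_scaling(2.0 * r, r)    # alpha = 2r, so the factor stays at 2.0
    print(f"r={r:2d}  alpha=16 -> {fixed:.1f}   alpha=2r -> {tied:.1f}")
```

With the defaults used in this practical (`lora_rank=8`, `lora_alpha=16.0`) both conventions give a scaling of 2.0.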

### Why LoRA Works Well

1. **Preserves pretrained knowledge**: Frozen base weights retain learned features
2. **Regularization effect**: The low-rank constraint limits how far the model can move, which can reduce overfitting on small datasets
3. **Efficient multi-task**: Store only small adapter weights per task, share the base model
4. **No inference overhead**: Can merge $BA$ into $W$ after training: $W_{merged} = W + \frac{\alpha}{r}BA$
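
The merge in point 4 can be verified numerically. The sketch below uses raw tensors with arbitrary sizes and random values standing in for trained adapter weights (all assumptions); it shows that the merged matrix reproduces the adapter output.

```python
import torch

torch.manual_seed(0)

# Arbitrary illustrative sizes (assumptions)
d_in, d_out, r, alpha = 12, 6, 3, 6.0
scaling = alpha / r

dt = torch.float64
W = torch.randn(d_out, d_in, dtype=dt)   # frozen base weight
b = torch.randn(d_out, dtype=dt)         # frozen base bias
A = torch.randn(r, d_in, dtype=dt)       # pretend these two were trained
B = torch.randn(d_out, r, dtype=dt)
x = torch.randn(4, d_in, dtype=dt)

# Adapter form: base path plus scaled low-rank path
y_adapter = x @ W.t() + b + scaling * (x @ A.t()) @ B.t()

# Merged form: fold the update into a single weight matrix
W_merged = W + scaling * (B @ A)
y_merged = x @ W_merged.t() + b

print(torch.allclose(y_adapter, y_merged))
```

The equivalence holds at inference time, when any dropout on the low-rank path is inactive.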

In [None]:
class LoRALinear(nn.Module):
    """
    Wraps a frozen nn.Linear with a trainable low-rank adapter.
    
    Forward pass computes: y = W_frozen @ x + (alpha/r) * B @ A @ x
    
    Parameters
    ----------
    linear : nn.Linear
        The original linear layer to wrap (will be frozen)
    r : int
        Rank of the low-rank adaptation matrices A and B
    alpha : float
        Scaling factor; the update is scaled by alpha/r
    dropout : float
        Dropout probability applied to input before the low-rank path
    """
    def __init__(self, linear: nn.Linear, r: int = 8, alpha: float = 16.0, dropout: float = 0.0):
        super().__init__()
        assert isinstance(linear, nn.Linear)
        
        # Store and freeze the original pretrained weights
        self.base = linear
        for p in self.base.parameters():
            p.requires_grad = False

        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.r = int(r)
        self.alpha = float(alpha)
        self.scaling = self.alpha / max(self.r, 1)  # alpha/r scaling factor
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        if self.r > 0:
            # A: projects input from d_in -> r (down-projection)
            # B: projects from r -> d_out (up-projection)
            # ΔW = B @ A has shape (d_out, d_in) but rank at most r
            self.A = nn.Parameter(torch.zeros(self.r, self.in_features))
            self.B = nn.Parameter(torch.zeros(self.out_features, self.r))
            
            # Initialize A with Kaiming uniform (like nn.Linear)
            nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
            # Initialize B to zero so ΔW = 0 at start (begin at pretrained solution)
            nn.init.zeros_(self.B)
        else:
            self.register_parameter("A", None)
            self.register_parameter("B", None)

    def forward(self, x):
        # Frozen pretrained path: y = W @ x + b
        y = self.base(x)
        
        if self.r > 0:
            # Low-rank adaptation path: Δy = (alpha/r) * (x @ A.T) @ B.T
            # Equivalent to: Δy = (alpha/r) * x @ (B @ A).T
            x_d = self.dropout(x)
            delta = (x_d @ self.A.t()) @ self.B.t()  # shape: (batch, d_out)
            y = y + self.scaling * delta
        return y


def apply_lora_to_linears(model: nn.Module, r: int, alpha: float, dropout: float):
    """
    Recursively replace all nn.Linear layers in the model with LoRALinear.
    
    This freezes the original weights and adds trainable low-rank adapters.
    """
    for name, module in list(model.named_children()):
        if isinstance(module, nn.Linear):
            setattr(model, name, LoRALinear(module, r=r, alpha=alpha, dropout=dropout))
        else:
            apply_lora_to_linears(module, r=r, alpha=alpha, dropout=dropout)


def set_trainable(model: nn.Module, trainable: bool):
    """Set requires_grad for all parameters in the model."""
    for p in model.parameters():
        p.requires_grad = trainable


def set_trainable_last_layer(model: nn.Module):
    """Freeze all weights except the final classification layer."""
    set_trainable(model, False)
    for p in model.classifier[-1].parameters():
        p.requires_grad = True


def set_trainable_lora_only(model: nn.Module):
    """Freeze everything except LoRA adapter matrices (A and B)."""
    set_trainable(model, False)
    for m in model.modules():
        if isinstance(m, LoRALinear):
            if m.A is not None: m.A.requires_grad = True
            if m.B is not None: m.B.requires_grad = True

## 7. Create adaptation variants

- **Full fine-tune**
- **Head-only**
- **LoRA-only**

All start from the same base weights.
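
As a sanity check on the counts printed by the next cell, the expected numbers can be worked out by hand from the `SmallCNN` layer shapes and the default `lora_rank=8`. The helper functions are hypothetical and exist only for this calculation.

```python
def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    """Weights plus biases of a k x k convolution."""
    return c_in * c_out * k * k + c_out


def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return d_in * d_out + d_out


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Adapter weights: A (r x d_in) plus B (d_out x r)."""
    return r * (d_in + d_out)


flat = 64 * 7 * 7   # flattened feature size entering the classifier
total = (conv2d_params(1, 32, 3) + conv2d_params(32, 64, 3)
         + linear_params(flat, 128) + linear_params(128, 10))
head_only = linear_params(128, 10)
lora_only = lora_params(flat, 128, 8) + lora_params(128, 10, 8)

print("full FT trainable:  ", total)       # 421642
print("head-only trainable:", head_only)   # 1290
print("LoRA trainable:     ", lora_only)   # 27216
print("LoRA model total:   ", total + lora_only)
```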


In [None]:
import copy

def clone_model(m: nn.Module) -> nn.Module:
    return copy.deepcopy(m)

m_full = clone_model(base_model).to(device)
set_trainable(m_full, True)

m_head = clone_model(base_model).to(device)
set_trainable_last_layer(m_head)

# For LoRA: apply adapters BEFORE moving to device, so A and B tensors are also moved
m_lora = clone_model(base_model)
apply_lora_to_linears(m_lora, r=cfg.lora_rank, alpha=cfg.lora_alpha, dropout=cfg.lora_dropout)
m_lora = m_lora.to(device)  # Now move everything (including A, B) to device
set_trainable_lora_only(m_lora)

print("Full FT params (total, trainable):", count_params(m_full))
print("Head-only params (total, trainable):", count_params(m_head))
print("LoRA params (total, trainable):", count_params(m_lora))

trainable_names = [n for n,p in m_lora.named_parameters() if p.requires_grad]
print("LoRA trainable tensors:", len(trainable_names))
print("Example trainables:", trainable_names[:10])

## 8. Fine-tune on the shifted few-shot target data

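We train each method on `target_loader` and evaluate on `test_loader`. Note that `test_loader` uses the base (unshifted) transform, so the reported accuracy measures how well each model performs on clean MNIST after adapting on shifted data.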


In [None]:
def fit(model: nn.Module, train_loader, test_loader, lr: float, epochs: int, label: str):
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    hist = {"test_acc": [], "time_s": []}

    for ep in range(1, epochs+1):
        t0 = time.time()
        tl, ta = train_epoch(model, train_loader, opt, loss_fn)
        va = eval_model(model, test_loader)
        hist["test_acc"].append(va)
        hist["time_s"].append(time.time() - t0)
        print(f"[{label}] epoch {ep:02d} loss={tl:.4f} train_acc={ta:.3f} test_acc={va:.3f} time={hist['time_s'][-1]:.2f}s")
    return hist

hist_full = fit(m_full, target_loader, test_loader, cfg.lr_target, cfg.target_epochs, "FULL")
hist_head = fit(m_head, target_loader, test_loader, cfg.lr_target, cfg.target_epochs, "HEAD")
hist_lora = fit(m_lora, target_loader, test_loader, cfg.lr_target, cfg.target_epochs, "LoRA")

plt.figure(figsize=(6,4))
plt.plot(hist_full["test_acc"], label="Full FT")
plt.plot(hist_head["test_acc"], label="Head-only")
plt.plot(hist_lora["test_acc"], label="LoRA")
plt.xlabel("epoch")
plt.ylabel("test accuracy")
plt.legend()
plt.tight_layout()
plt.show()

plt.figure(figsize=(6,4))
plt.plot(np.cumsum(hist_full["time_s"]), label="Full FT")
plt.plot(np.cumsum(hist_head["time_s"]), label="Head-only")
plt.plot(np.cumsum(hist_lora["time_s"]), label="LoRA")
plt.xlabel("epoch")
plt.ylabel("cumulative seconds")
plt.legend()
plt.tight_layout()
plt.show()


## 9. Final evaluation (classification report)

How do the adapted models behave across classes?


In [None]:
@torch.no_grad()
def predict(model, loader):
    model.eval()
    ys, ps = [], []
    for x, y in loader:
        x = x.to(device)
        logits = model(x)
        ps.append(logits.argmax(dim=1).cpu().numpy())
        ys.append(y.numpy())
    return np.concatenate(ys), np.concatenate(ps)

y_true, p_full = predict(m_full, test_loader)
_, p_head = predict(m_head, test_loader)
_, p_lora = predict(m_lora, test_loader)

print("=== Full fine-tune ===")
print(classification_report(y_true, p_full, zero_division=0))
print("=== Head-only ===")
print(classification_report(y_true, p_head, zero_division=0))
print("=== LoRA ===")
print(classification_report(y_true, p_lora, zero_division=0))


## 10. Discussion & extensions

1. Which method performs best **per trainable parameter**?
2. When does LoRA beat head-only fine-tuning?
3. How does rank `r` change the tradeoff?

**Extensions**
- Make a **shifted test set** (apply rotation+noise at test time) and evaluate there too.
- Try `r ∈ {2, 4, 8, 16}` and plot accuracy vs trainable params.
- Swap the base CNN for a pretrained model and repeat (more realistic).
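
For the rank sweep, the x-axis values (trainable parameters per rank) can be computed before any training. This sketch assumes adapters on both linear layers of `SmallCNN`, as `apply_lora_to_linears` does, and hard-codes the base model's parameter count.

```python
BASE_TOTAL = 421_642                          # SmallCNN parameters, biases included
LINEAR_SHAPES = [(64 * 7 * 7, 128), (128, 10)]  # (d_in, d_out) of the two linear layers


def lora_trainable(r: int) -> int:
    """Adapter parameters when every linear layer gets rank-r A and B."""
    return sum(r * (d_in + d_out) for d_in, d_out in LINEAR_SHAPES)


for r in (2, 4, 8, 16):
    n = lora_trainable(r)
    print(f"r={r:2d}  trainable={n:6,d}  ({100 * n / BASE_TOTAL:.1f}% of base)")
```

The count grows linearly with `r`, so doubling the rank doubles the adapter size.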

---

### What models is LoRA designed for?

You may notice LoRA is *slower* than full fine-tuning here. This is expected for small models: LoRA adds extra matrix multiplications in the forward pass, and the backward-pass savings are minimal at this scale. LoRA's speed advantage only emerges on much larger models (roughly tens of millions of parameters and up), where backward-pass computation dominates and memory savings enable larger batch sizes.

| Model Size | LoRA Speed Benefit |
|------------|-------------------|
| <1M params | Slower (like this tutorial) |
| 1–10M params | Break-even |
| 10–100M params | Modest speedup |
| 100M+ params | Significant speedup + memory savings |

To demonstrate LoRA's speed benefits, you'd need a larger model such as ResNet-18 (~11M params) or a small ViT (~5-10M params).

For models like BERT-base (110M params) or GPT-2 (124M params), full fine-tuning requires storing gradients + Adam optimizer states (2× params for momentum and variance) for all weights. LoRA cuts this dramatically, enabling fine-tuning on GPUs that couldn't otherwise fit the training.
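
A back-of-the-envelope estimate makes the saving concrete. The sketch below assumes 32-bit floats, Adam, and ignores activation memory; the adapter placement (rank 8 on two 768 × 768 projections in each of 12 layers) is a hypothetical choice for illustration, not something prescribed by this practical.

```python
def training_state_bytes(n_total: int, n_trainable: int, bytes_per_value: int = 4) -> int:
    """Weights + gradients + Adam moments (activations not included)."""
    weights = n_total * bytes_per_value
    grads = n_trainable * bytes_per_value
    adam_moments = 2 * n_trainable * bytes_per_value   # momentum and variance
    return weights + grads + adam_moments


n_base = 110_000_000                       # BERT-base scale
n_adapter = 12 * 2 * 8 * (768 + 768)       # hypothetical adapter placement

full = training_state_bytes(n_base, n_base)
lora = training_state_bytes(n_base + n_adapter, n_adapter)

print(f"adapter params: {n_adapter:,}")    # adapter params: 294,912
print(f"full FT: {full / 1e9:.2f} GB")     # full FT: 1.76 GB
print(f"LoRA:    {lora / 1e9:.2f} GB")     # LoRA:    0.44 GB
```

Under these assumptions almost all of the remaining memory is the frozen weights themselves; the gradient and optimizer terms nearly vanish.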

LoRA was designed for foundation models (100M–100B+ params) where full fine-tuning is impractical or impossible. At the ~420K-parameter scale used here, the value is in **parameter efficiency** and **regularization**, not raw speed. This tutorial makes that tradeoff visible, which is useful for understanding when to reach for LoRA in practice.