# Practical 3 (Path 1) — Parameter-Efficient Fine-Tuning with LoRA (Vision, Minimal Dependencies)

**Road to SKA: Foundation Models, Embeddings, and Latent Spaces**

This practical demonstrates **Path 1**: LoRA on a small vision encoder *without* heavy EO plumbing.

You will:

1. Train a **base CNN classifier** (a stand-in for a pretrained foundation encoder).
2. Starting from the same base weights, adapt to a **shifted target task** using:
   - **Full fine-tuning** (update all weights)
   - **Head-only fine-tuning** (update only the last layer)
   - **LoRA fine-tuning** (freeze the base weights; update only small low-rank adapter matrices)
3. Compare:
   - accuracy vs training time
   - number of trainable parameters

---

## References
- LoRA paper (Hu et al., 2021 / ICLR 2022): https://arxiv.org/abs/2106.09685  
- Microsoft LoRA repo / `loralib`: https://github.com/microsoft/LoRA  
- Hugging Face PEFT LoRA docs (concept + config): https://huggingface.co/docs/peft/en/conceptual_guides/lora  

> We implement LoRA **directly in PyTorch** here so it works anywhere (no Transformers dependency).


## 0. Setup

CPU-friendly; uses a CUDA GPU or an Apple-Silicon (MPS) device if one is available.

Required:
- `torch`, `torchvision`, `numpy`, `matplotlib`, `scikit-learn`

Optional:
- `tqdm` (progress bars)
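Before running the cells, you can check which of the listed packages are importable with `importlib.util.find_spec`, which returns `None` for a module it cannot find. This is a small sketch; the helper name `missing_packages` is hypothetical, and note that scikit-learn is imported as `sklearn`, not by its pip name.

```python
import importlib.util

# Import names, not pip names: scikit-learn is imported as "sklearn"
REQUIRED = ["torch", "torchvision", "numpy", "matplotlib", "sklearn"]
OPTIONAL = ["tqdm"]

def missing_packages(names):
    """Return the top-level module names that cannot be found (hypothetical helper)."""
    return [name for name in names if importlib.util.find_spec(name) is None]

print("Missing required:", missing_packages(REQUIRED) or "none")
print("Missing optional:", missing_packages(OPTIONAL) or "none")
```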


In [None]:
# Optional installs (uncomment if needed)
# %pip -q install torch torchvision scikit-learn tqdm

import time
import math
import random
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from sklearn.metrics import accuracy_score, classification_report

try:
    from tqdm.auto import tqdm
except Exception:
    tqdm = lambda x, **kwargs: x


## 1. Configuration

We create a **domain shift**:

- **Base training:** MNIST with standard images  
- **Target adaptation:** MNIST with random rotation + Gaussian noise, and **few-shot** labels (a 2,000-image subset)

This mimics adapting a pretrained encoder to a new domain with limited labels.


In [None]:
@dataclass
class Config:
    data_dir: str = "./data"
    seed: int = 42
    batch_size: int = 128
    base_epochs: int = 5
    target_epochs: int = 5
    lr_base: float = 1e-3
    lr_target: float = 1e-3
    max_base_train: int = 20000
    max_target_train: int = 2000
    max_test: int = 5000
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    rotate_deg: float = 25.0
    noise_std: float = 0.15

cfg = Config()

random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

# Device selection: CUDA > MPS (Apple Silicon) > CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device:", device)
print(cfg)

## 2. Data: Base vs Target

- Base transform: standard MNIST
- Target transform: random rotation (up to ±`cfg.rotate_deg` degrees) + Gaussian noise (synthetic shift)
- Test transform: standard MNIST (no shift)
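To see what the target transform does to a single image, here is a NumPy-only sketch of its two steps. The rotation is our approximation of `transforms.RandomRotation`: nearest-neighbour sampling about the image centre with zero fill and an angle drawn uniformly from ±`degrees` (torchvision's defaults), although the sign convention of our angle may differ from torchvision's. The noise step mirrors `AddGaussianNoise`: add Gaussian noise, then clamp to [0, 1]. In the real pipeline the rotation acts on the PIL image before `ToTensor`; the function names below are hypothetical.

```python
import numpy as np

def rotate_nearest(img, deg):
    """Rotate a 2-D image about its centre (nearest neighbour, zero fill, same size)."""
    h, w = img.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    theta = np.deg2rad(deg)
    ys, xs = np.mgrid[0:h, 0:w]
    # Inverse mapping: for each output pixel, look up the source pixel it came from
    src_x = np.cos(theta) * (xs - cx) + np.sin(theta) * (ys - cy) + cx
    src_y = -np.sin(theta) * (xs - cx) + np.cos(theta) * (ys - cy) + cy
    src_x = np.rint(src_x).astype(int)
    src_y = np.rint(src_y).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)                  # pixels rotated in from outside stay 0
    out[inside] = img[src_y[inside], src_x[inside]]
    return out

def add_gaussian_noise(img, std, rng):
    """Same idea as AddGaussianNoise: additive noise, then clamp to [0, 1]."""
    return np.clip(img + rng.normal(0.0, std, size=img.shape), 0.0, 1.0)

def target_shift(img, max_deg, std, rng):
    deg = rng.uniform(-max_deg, max_deg)      # random angle in [-max_deg, max_deg)
    return add_gaussian_noise(rotate_nearest(img, deg), std, rng)

rng = np.random.default_rng(0)
img = np.zeros((28, 28))
img[6:22, 12:16] = 1.0                        # a crude vertical stroke, roughly a "1"
shifted = target_shift(img, max_deg=25.0, std=0.15, rng=rng)
print(shifted.shape, float(shifted.min()), float(shifted.max()))
```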


In [None]:
class AddGaussianNoise:
    def __init__(self, std: float = 0.1):
        self.std = std
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.clamp(x + torch.randn_like(x) * self.std, 0.0, 1.0)

base_tfm = transforms.Compose([transforms.ToTensor()])
target_tfm = transforms.Compose([
    transforms.RandomRotation(degrees=cfg.rotate_deg),
    transforms.ToTensor(),
    AddGaussianNoise(std=cfg.noise_std),
])
test_tfm = transforms.Compose([transforms.ToTensor()])

base_train = datasets.MNIST(cfg.data_dir, train=True, download=True, transform=base_tfm)
target_train = datasets.MNIST(cfg.data_dir, train=True, download=True, transform=target_tfm)
test_ds = datasets.MNIST(cfg.data_dir, train=False, download=True, transform=test_tfm)

def subset(ds, n: Optional[int], seed: int):
    if n is None or n >= len(ds):
        return ds
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(ds), size=n, replace=False)
    return Subset(ds, idx)

base_train_s = subset(base_train, cfg.max_base_train, cfg.seed)
target_train_s = subset(target_train, cfg.max_target_train, cfg.seed + 1)
test_s = subset(test_ds, cfg.max_test, cfg.seed + 2)

# num_workers=0 loads data in the main process, avoiding multiprocessing issues in notebooks; increase it if your environment supports worker processes
base_loader = DataLoader(base_train_s, batch_size=cfg.batch_size, shuffle=True, num_workers=0)
target_loader = DataLoader(target_train_s, batch_size=cfg.batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_s, batch_size=cfg.batch_size, shuffle=False, num_workers=0)

# Visualise base vs target samples
xb, yb = next(iter(base_loader))
xt, yt = next(iter(target_loader))

fig, axes = plt.subplots(2, 8, figsize=(10, 3))
for row, (imgs, labels) in enumerate([(xb, yb), (xt, yt)]):
    for i, ax in enumerate(axes[row]):
        ax.imshow(imgs[i, 0], cmap="gray")
        ax.set_title(int(labels[i]))
        # Hide ticks but keep the axes on: ax.axis("off") would also hide the row labels below
        ax.set_xticks([]); ax.set_yticks([])
axes[0, 0].set_ylabel("Base")
axes[1, 0].set_ylabel("Target")
plt.tight_layout()
plt.show()

## 3. Base model (small CNN)

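This is our stand-in for a “pretrained encoder + classifier head”: `features` plays the role of the encoder, and `classifier` (two linear layers) is the head. Head-only fine-tuning later updates just the final `nn.Linear`.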
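The input size of the first `nn.Linear`, `64 * 7 * 7`, follows from the output-size formula that PyTorch documents for `Conv2d` and `MaxPool2d` (floor division, default dilation 1). A short sketch tracing a 28×28 MNIST image through `features`; `conv_out` is a hypothetical helper.

```python
def conv_out(n, k, stride=1, padding=0, dilation=1):
    """Spatial output size of Conv2d / MaxPool2d along one dimension (floor mode)."""
    return (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1

n = 28                                  # MNIST height/width
n = conv_out(n, k=3, padding=1)         # Conv2d(1, 32, 3, padding=1): size preserved
n = conv_out(n, k=2, stride=2)          # MaxPool2d(2): stride defaults to the kernel size
n = conv_out(n, k=3, padding=1)         # Conv2d(32, 64, 3, padding=1)
n = conv_out(n, k=2, stride=2)          # MaxPool2d(2)
flat_dim = 64 * n * n                   # channels x height x width after Flatten
print(n, flat_dim)  # 7 3136
```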


In [None]:
class SmallCNN(nn.Module):
    def __init__(self, n_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128), nn.ReLU(),
            nn.Linear(128, n_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

def count_params(model: nn.Module) -> Tuple[int, int]:
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

base_model = SmallCNN().to(device)
print("Params (total, trainable):", count_params(base_model))


## 4. Pretrain the base model

A few epochs on clean MNIST are enough to get a reasonable base model.


In [None]:
def train_epoch(model, loader, opt, loss_fn):
    model.train()
    total_loss = 0.0
    y_true, y_pred = [], []
    for x, y in tqdm(loader, leave=False):
        x, y = x.to(device), y.to(device)
        opt.zero_grad(set_to_none=True)
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        opt.step()
        total_loss += loss.item() * x.size(0)
        y_true.append(y.detach().cpu().numpy())
        y_pred.append(logits.argmax(dim=1).detach().cpu().numpy())
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    return total_loss / len(loader.dataset), accuracy_score(y_true, y_pred)

@torch.no_grad()
def eval_model(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        y_true.append(y.detach().cpu().numpy())
        y_pred.append(logits.argmax(dim=1).detach().cpu().numpy())
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    return accuracy_score(y_true, y_pred)

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(base_model.parameters(), lr=cfg.lr_base)

hist_base = {"train_acc": [], "test_acc": []}
for ep in range(1, cfg.base_epochs + 1):
    tl, ta = train_epoch(base_model, base_loader, opt, loss_fn)
    va = eval_model(base_model, test_loader)
    hist_base["train_acc"].append(ta)
    hist_base["test_acc"].append(va)
    print(f"[Base] epoch {ep:02d} loss={tl:.4f} train_acc={ta:.3f} test_acc={va:.3f}")

plt.figure(figsize=(6,4))
plt.plot(hist_base["train_acc"], label="train acc")
plt.plot(hist_base["test_acc"], label="test acc")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.tight_layout()
plt.show()


## 5. LoRA implementation (Linear layers)

LoRA adds a low-rank update to a frozen linear weight:

\[
W' = W + \Delta W,\quad \Delta W = \frac{\alpha}{r}BA
\]

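Here $W \in \mathbb{R}^{d_{out} \times d_{in}}$ is the frozen weight, $B \in \mathbb{R}^{d_{out} \times r}$ and $A \in \mathbb{R}^{r \times d_{in}}$ are the trainable adapter matrices, and the rank $r$ is much smaller than $d_{in}$ and $d_{out}$. $A$ is initialised randomly and $B$ starts at zero, so $\Delta W = 0$ and the adapted model initially matches the base model exactly.

We implement LoRA for `nn.Linear` and then swap every linear layer in the model for a LoRA-wrapped version; the convolutional layers stay frozen.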
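A NumPy sketch of the arithmetic that `LoRALinear.forward` performs, using the `nn.Linear` weight layout (`out × in`) and arbitrary toy sizes. It checks two properties: with $B = 0$ the adapter is a no-op, and the factorised path gives the same output as the dense merged weight $W + \frac{\alpha}{r}BA$. The second property is why the update can be folded into $W$ after training; the LoRA paper notes that this avoids extra inference latency.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 32, 16, 4, 8.0       # toy sizes, not SmallCNN's
scaling = alpha / r                          # same role as LoRALinear.scaling

W = rng.normal(size=(d_out, d_in))           # frozen base weight, nn.Linear layout (out x in)
b = rng.normal(size=d_out)                   # frozen bias
A = rng.normal(size=(r, d_in)) * 0.1         # trainable down-projection (random init)
B = np.zeros((d_out, r))                     # trainable up-projection (zero init)
x = rng.normal(size=(5, d_in))               # a batch of 5 inputs

def lora_forward(x, W, b, A, B, scaling):
    # Base path plus low-rank path; the d_out x d_in update is never materialised
    return x @ W.T + b + scaling * (x @ A.T) @ B.T

# 1) With B = 0 the adapter contributes nothing: training starts from the base model
y0 = lora_forward(x, W, b, A, B, scaling)
assert np.allclose(y0, x @ W.T + b)

# 2) After "training" (simulated here by a random B) the factorised path
#    matches a dense merged weight W + (alpha / r) * B @ A
B = rng.normal(size=(d_out, r))
delta_W = scaling * B @ A
W_merged = W + delta_W
assert np.allclose(lora_forward(x, W, b, A, B, scaling), x @ W_merged.T + b)
print("rank of delta_W:", np.linalg.matrix_rank(delta_W))   # at most r = 4
```

The adapter stores `r * (d_in + d_out)` numbers instead of `d_in * d_out`, which is where the parameter savings come from.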


In [None]:
class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank adapter."""
    def __init__(self, linear: nn.Linear, r: int = 8, alpha: float = 16.0, dropout: float = 0.0):
        super().__init__()
        assert isinstance(linear, nn.Linear)
        self.base = linear
        for p in self.base.parameters():
            p.requires_grad = False

        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.r = int(r)
        self.alpha = float(alpha)
        self.scaling = self.alpha / max(self.r, 1)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

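        if self.r > 0:
            # Create the adapters on the wrapped layer's device/dtype, so LoRA can be
            # applied after the model has already been moved to the GPU/MPS device.
            factory = dict(device=linear.weight.device, dtype=linear.weight.dtype)
            self.A = nn.Parameter(torch.zeros(self.r, self.in_features, **factory))
            self.B = nn.Parameter(torch.zeros(self.out_features, self.r, **factory))
            nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
            nn.init.zeros_(self.B)  # B = 0 => delta W = 0, so training starts from the base model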
        else:
            self.register_parameter("A", None)
            self.register_parameter("B", None)

    def forward(self, x):
        y = self.base(x)
        if self.r > 0:
            x_d = self.dropout(x)
            delta = (x_d @ self.A.t()) @ self.B.t()
            y = y + self.scaling * delta
        return y

def apply_lora_to_linears(model: nn.Module, r: int, alpha: float, dropout: float):
    """Recursively replace all nn.Linear layers with LoRALinear."""
    for name, module in list(model.named_children()):
        if isinstance(module, nn.Linear):
            setattr(model, name, LoRALinear(module, r=r, alpha=alpha, dropout=dropout))
        else:
            apply_lora_to_linears(module, r=r, alpha=alpha, dropout=dropout)

def set_trainable(model: nn.Module, trainable: bool):
    for p in model.parameters():
        p.requires_grad = trainable

def set_trainable_last_layer(model: nn.Module):
    set_trainable(model, False)
    for p in model.classifier[-1].parameters():
        p.requires_grad = True

def set_trainable_lora_only(model: nn.Module):
    set_trainable(model, False)
    for m in model.modules():
        if isinstance(m, LoRALinear):
            if m.A is not None: m.A.requires_grad = True
            if m.B is not None: m.B.requires_grad = True


## 6. Create adaptation variants

- **Full fine-tune**
- **Head-only**
- **LoRA-only**

All start from the same base weights.
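Before running the next cell, you can predict the trainable-parameter counts by hand from the layer shapes of `SmallCNN`. This is a pure-Python sketch: the helper names are hypothetical, and `r = 8` matches `cfg.lora_rank`. It also estimates Adam's optimizer state under the assumption of float32 and two moment buffers (`exp_avg`, `exp_avg_sq`) per trainable element. Because `fit` passes only parameters with `requires_grad=True` to Adam, frozen weights get no optimizer state.

```python
def linear_params(d_in, d_out):
    return d_in * d_out + d_out                 # weight + bias

def conv_params(c_in, c_out, k):
    return c_out * c_in * k * k + c_out         # weight + bias

def lora_params(d_in, d_out, r):
    return r * d_in + d_out * r                 # A: (r, d_in), B: (d_out, r)

linears = [(64 * 7 * 7, 128), (128, 10)]        # the two nn.Linear layers in SmallCNN.classifier
base_total = (conv_params(1, 32, 3) + conv_params(32, 64, 3)
              + sum(linear_params(i, o) for i, o in linears))

trainable = {
    "Full FT": base_total,
    "Head-only": linear_params(128, 10),
    "LoRA (r=8)": sum(lora_params(i, o, r=8) for i, o in linears),
}
for name, n in trainable.items():
    adam_kib = 2 * n * 4 / 1024                 # two float32 moment buffers per trainable element
    print(f"{name:>10}: {n:>7,d} trainable ({100 * n / base_total:5.2f}%), Adam state ~{adam_kib:,.0f} KiB")
```

With these shapes, LoRA trains roughly 6.5% of the base parameters, about 15× fewer than full fine-tuning. That is still more than head-only (about 0.3%), because the first linear layer has 3,136 inputs. The total reported for `m_lora` grows by the same 27,216 adapter parameters.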


In [None]:
import copy

def clone_model(m: nn.Module) -> nn.Module:
    return copy.deepcopy(m)

m_full = clone_model(base_model).to(device)
set_trainable(m_full, True)

m_head = clone_model(base_model).to(device)
set_trainable_last_layer(m_head)

m_lora = clone_model(base_model).to(device)
apply_lora_to_linears(m_lora, r=cfg.lora_rank, alpha=cfg.lora_alpha, dropout=cfg.lora_dropout)
set_trainable_lora_only(m_lora)

print("Full FT params (total, trainable):", count_params(m_full))
print("Head-only params (total, trainable):", count_params(m_head))
print("LoRA params (total, trainable):", count_params(m_lora))

trainable_names = [n for n,p in m_lora.named_parameters() if p.requires_grad]
print("LoRA trainable tensors:", len(trainable_names))
print("Example trainables:", trainable_names[:10])


## 7. Fine-tune on the shifted few-shot target data

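We train each method on `target_loader` (rotated + noisy images) and evaluate on `test_loader`, which holds clean, unshifted MNIST test images. The reported test accuracy therefore reflects the clean domain, not the target domain; the first extension in Section 9 adds a shifted test set.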


In [None]:
def fit(model: nn.Module, train_loader, test_loader, lr: float, epochs: int, label: str):
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    hist = {"test_acc": [], "time_s": []}

    for ep in range(1, epochs+1):
        t0 = time.perf_counter()
        tl, ta = train_epoch(model, train_loader, opt, loss_fn)
        train_s = time.perf_counter() - t0  # time the training pass only (evaluation excluded)
        va = eval_model(model, test_loader)
        hist["test_acc"].append(va)
        hist["time_s"].append(train_s)
        print(f"[{label}] epoch {ep:02d} loss={tl:.4f} train_acc={ta:.3f} test_acc={va:.3f} train_time={train_s:.2f}s")
    return hist

hist_full = fit(m_full, target_loader, test_loader, cfg.lr_target, cfg.target_epochs, "FULL")
hist_head = fit(m_head, target_loader, test_loader, cfg.lr_target, cfg.target_epochs, "HEAD")
hist_lora = fit(m_lora, target_loader, test_loader, cfg.lr_target, cfg.target_epochs, "LoRA")

plt.figure(figsize=(6,4))
plt.plot(hist_full["test_acc"], label="Full FT")
plt.plot(hist_head["test_acc"], label="Head-only")
plt.plot(hist_lora["test_acc"], label="LoRA")
plt.xlabel("epoch")
plt.ylabel("test accuracy")
plt.legend()
plt.tight_layout()
plt.show()

plt.figure(figsize=(6,4))
plt.plot(np.cumsum(hist_full["time_s"]), label="Full FT")
plt.plot(np.cumsum(hist_head["time_s"]), label="Head-only")
plt.plot(np.cumsum(hist_lora["time_s"]), label="LoRA")
plt.xlabel("epoch")
plt.ylabel("cumulative seconds")
plt.legend()
plt.tight_layout()
plt.show()


## 8. Final evaluation (classification report)

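How do the adapted models behave across classes? Per-class precision and recall show whether a method fails on particular digits rather than uniformly.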
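`classification_report` summarises per-class precision and recall. The confusion matrix behind those numbers shows *which* digits get mistaken for which. Below is a NumPy sketch on a toy example, where rows are true classes and columns are predicted classes. scikit-learn's `sklearn.metrics.confusion_matrix` computes the same matrix, and the helper names here are hypothetical. Once the next cell has run, the same helpers could be applied to `y_true` and `p_lora` with `n_classes=10`.

```python
import numpy as np

def confusion_np(y_true, y_pred, n_classes):
    """cm[i, j] = number of samples with true class i predicted as class j."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at is unbuffered, so repeated (true, pred) pairs are all counted
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def recall_precision(cm):
    diag = np.diag(cm).astype(float)
    support = cm.sum(axis=1)        # true samples per class (row sums)
    predicted = cm.sum(axis=0)      # predictions per class (column sums)
    # Classes with no samples / no predictions get 0, like zero_division=0
    recall = np.divide(diag, support, out=np.zeros_like(diag), where=support > 0)
    precision = np.divide(diag, predicted, out=np.zeros_like(diag), where=predicted > 0)
    return recall, precision

y_true_toy = [0, 0, 1, 1, 2, 2]
y_pred_toy = [0, 1, 1, 1, 2, 0]
cm = confusion_np(y_true_toy, y_pred_toy, n_classes=3)
recall, precision = recall_precision(cm)
print(cm)
print("recall:", recall.round(3), "precision:", precision.round(3))
```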


In [None]:
@torch.no_grad()
def predict(model, loader):
    model.eval()
    ys, ps = [], []
    for x, y in loader:
        x = x.to(device)
        logits = model(x)
        ps.append(logits.argmax(dim=1).cpu().numpy())
        ys.append(y.numpy())
    return np.concatenate(ys), np.concatenate(ps)

y_true, p_full = predict(m_full, test_loader)
_, p_head = predict(m_head, test_loader)
_, p_lora = predict(m_lora, test_loader)

print("=== Full fine-tune ===")
print(classification_report(y_true, p_full, zero_division=0))
print("=== Head-only ===")
print(classification_report(y_true, p_head, zero_division=0))
print("=== LoRA ===")
print(classification_report(y_true, p_lora, zero_division=0))


## 9. Discussion & extensions

1. Which method performs best **per trainable parameter**?
2. When does LoRA beat head-only fine-tuning?
3. How does rank `r` change the tradeoff?

**Extensions**
- Make a **shifted test set** (apply rotation+noise at test time) and evaluate there too.
- Try `r ∈ {2, 4, 8, 16}` and plot accuracy vs trainable params.
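- Swap the base CNN for a pretrained model and repeat (more realistic). Note that `apply_lora_to_linears` only wraps `nn.Linear` layers, so in a convolutional backbone only the fully connected layer(s) would receive adapters.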
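For the rank extension, you can compute the x-axis of an accuracy-vs-trainable-parameters plot without training. With LoRA on both linear layers of `SmallCNN`, each adapter holds $r\,(d_{in} + d_{out})$ parameters, so the count grows linearly in $r$. A sketch follows; the helper name is hypothetical, and the accuracies still have to come from re-running `fit` for each rank.

```python
# (d_in, d_out) of the two nn.Linear layers in SmallCNN.classifier
linears = [(64 * 7 * 7, 128), (128, 10)]
head_only = 128 * 10 + 10        # trainable params of the head-only baseline

def lora_trainable(r, layers=linears):
    # A has shape (r, d_in) and B has shape (d_out, r) in every wrapped layer
    return sum(r * (d_in + d_out) for d_in, d_out in layers)

for r in (2, 4, 8, 16):
    n = lora_trainable(r)
    print(f"r={r:>2}: {n:>6,d} LoRA params ({n / head_only:.1f}x head-only)")
```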
