# LoRA: Low-Rank Adaptation from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/lora_fine_tuning.ipynb)

This notebook implements LoRA (Low-Rank Adaptation) from scratch — the parameter-efficient fine-tuning method that lets you adapt massive models by training only a tiny fraction of parameters.

We cover:
1. Why full fine-tuning is expensive
2. Low-rank matrix decomposition intuition
3. LoRA implementation from raw operations
4. Training a LoRA-adapted model
5. Adapter merging for zero-overhead inference
6. Rank analysis and parameter savings
7. Multi-adapter switching

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import matplotlib.pyplot as plt
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 0. Mathematical Foundations

### The Fine-Tuning Problem

A pretrained model has weight matrix $W_0 \in \mathbb{R}^{d \times k}$. Full fine-tuning learns:

$$W = W_0 + \Delta W$$

where $\Delta W \in \mathbb{R}^{d \times k}$ has the same size as $W_0$. For a 7B model, the updates summed over all weight matrices amount to **billions** of parameters.

### LoRA Insight

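Aghajanyan et al. (2021) showed that pretrained models have **low intrinsic dimensionality**: fine-tuning inside a small, randomly projected subspace of the parameters recovers most of the performance of full fine-tuning. This motivates the hypothesis that the weight updates during fine-tuning also lie in a low-rank subspace.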

LoRA (Hu et al., 2021) constrains $\Delta W$ to be low-rank:

$$\Delta W = B \cdot A$$

where $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, with $r \ll \min(d, k)$.

The adapted forward pass becomes:

$$h = W_0 x + \Delta W x = W_0 x + B A x$$
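The identity above can be checked numerically. The following is a minimal sketch; the sizes `d`, `k`, `r` and the random tensors are illustrative assumptions, not values from a real model. It confirms that the factored path `B @ (A @ x)` matches the dense update, and that the rank of `B @ A` cannot exceed `r`.

```python
import torch

torch.manual_seed(0)
d, k, r = 6, 5, 2  # illustrative sizes (assumed)

W0 = torch.randn(d, k, dtype=torch.float64)
B = torch.randn(d, r, dtype=torch.float64)
A = torch.randn(r, k, dtype=torch.float64)
x = torch.randn(k, dtype=torch.float64)

# Path 1: build the dense d x k update, then apply the adapted weight.
h_dense = (W0 + B @ A) @ x

# Path 2: keep the factors separate; the intermediate A @ x has only r entries.
h_factored = W0 @ x + B @ (A @ x)

# The product of a d x r and an r x k matrix has rank at most r.
update_rank = torch.linalg.matrix_rank(B @ A).item()

print(torch.allclose(h_dense, h_factored), update_rank)
```

The factored path never materializes the full $d \times k$ update, which is why the implementation later in the notebook multiplies by $A$ first.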

### Parameter Savings

For $W_0 \in \mathbb{R}^{4096 \times 4096}$:
- Full fine-tuning: $4096 \times 4096 = 16.8M$ parameters
- LoRA ($r=16$): $(4096 \times 16) + (16 \times 4096) = 131K$ parameters
- **Reduction: 128x fewer parameters!**
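The arithmetic above is easy to reproduce. This small helper (the name `lora_param_counts` is hypothetical, introduced only for illustration) counts the trainable parameters for one weight matrix under both schemes.

```python
def lora_param_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one d x k weight."""
    full = d * k
    lora = d * r + r * k  # B is d x r, A is r x k
    return full, lora


full, lora = lora_param_counts(4096, 4096, 16)
print(f"full: {full:,}  lora: {lora:,}  reduction: {full / lora:.0f}x")
# prints: full: 16,777,216  lora: 131,072  reduction: 128x
```

For a square matrix the reduction simplifies to $d / (2r)$, so halving the rank doubles the savings.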

### Initialization

- $A$ is initialized with small random values (Gaussian)
- $B$ is initialized to **zero** → $\Delta W = BA = 0$ at start
- This ensures the adapted model starts exactly at the pretrained model
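A short sketch of what this initialization implies in practice, using small assumed sizes and a toy loss (the sum of the outputs). With $B = 0$ the adapted output equals the pretrained output exactly. The first gradient step also moves only $B$, because the gradient with respect to $A$ is multiplied by $B^\top$, which is zero.

```python
import torch

torch.manual_seed(0)
d, k, r = 8, 8, 2  # illustrative sizes (assumed)

W0 = torch.randn(d, k)                               # frozen pretrained weight
A = (torch.randn(r, k) * 0.01).requires_grad_(True)  # small Gaussian init
B = torch.zeros(d, r, requires_grad=True)            # zero init
x = torch.randn(k)

h_adapted = W0 @ x + B @ (A @ x)

# The adapter contributes nothing at the start.
print(torch.equal(h_adapted.detach(), W0 @ x))

# Toy loss: gradients reach B immediately, but A's gradient is exactly zero.
h_adapted.sum().backward()
print(A.grad.abs().max().item(), B.grad.abs().max().item())
```

If both matrices were initialized to zero, neither would ever receive a gradient, which is one reason to keep $A$ random.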

### Scaling Factor

LoRA uses a scaling factor $\alpha / r$:

$$h = W_0 x + \frac{\alpha}{r} B A x$$

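Dividing by $r$ keeps the magnitude of the update roughly comparable as $r$ changes, so $\alpha$ and the learning rate need less retuning when you vary the rank.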
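To make the factor concrete, here is a tiny sketch that tabulates $\alpha / r$ for a fixed $\alpha$. The helper name `lora_scaling` and the value `alpha = 16.0` are assumptions chosen for illustration.

```python
def lora_scaling(alpha: float, rank: int) -> float:
    """Multiplier applied to B @ A @ x in the adapted forward pass."""
    return alpha / rank


alpha = 16.0  # hypothetical value chosen for illustration
for rank in (4, 8, 16, 32):
    print(f"r={rank:2d}  alpha/r={lora_scaling(alpha, rank):.2f}")
```

With $\alpha$ held fixed, the multiplier shrinks as the rank grows, which offsets the larger number of rank-one terms summed in $BA$.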

## 1. Low-Rank Approximation Intuition

In [None]:
# Demonstrate: weight update matrices are approximately low-rank
torch.manual_seed(42)
d, k = 64, 64

# Simulate: pretrained weights and fine-tuned weights
W_pretrained = torch.randn(d, k, device=device)

# Simulate fine-tuning: the change is approximately low-rank
# (real fine-tuning updates tend to lie in a low-dimensional subspace)
true_rank = 4
delta_B = torch.randn(d, true_rank, device=device) * 0.1
delta_A = torch.randn(true_rank, k, device=device) * 0.1
delta_W = delta_B @ delta_A  # rank-4 update
delta_W += torch.randn(d, k, device=device) * 0.001  # small noise

# Compute SVD to see the rank structure
U, S, Vh = torch.linalg.svd(delta_W)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Singular values — sharp drop shows low-rank structure
axes[0].bar(range(len(S)), S.cpu().numpy())
axes[0].set_xlabel('Singular value index')
axes[0].set_ylabel('Magnitude')
axes[0].set_title('Singular Values of ΔW\n(Sharp drop → low-rank structure)')
axes[0].axvline(x=true_rank - 0.5, color='red', linestyle='--', label=f'True rank = {true_rank}')
axes[0].legend()

# Reconstruction error vs rank
errors = []
for r in range(1, min(d, k) + 1):
    approx = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]
    err = (delta_W - approx).norm().item() / delta_W.norm().item()
    errors.append(err)

axes[1].plot(range(1, len(errors) + 1), errors, linewidth=2)
axes[1].set_xlabel('Rank of approximation')
axes[1].set_ylabel('Relative error')
axes[1].set_title('Reconstruction Error vs Rank\n(rank 4 captures almost everything)')
axes[1].axvline(x=true_rank, color='red', linestyle='--', label=f'r = {true_rank}')
axes[1].set_yscale('log')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Cumulative energy
energy = (S ** 2).cumsum(0) / (S ** 2).sum()
axes[2].plot(range(1, len(energy) + 1), energy.cpu().numpy(), linewidth=2)
axes[2].set_xlabel('Number of singular values')
axes[2].set_ylabel('Fraction of variance explained')
axes[2].set_title('Cumulative Variance Explained')
axes[2].axhline(y=0.99, color='gray', linestyle='--', alpha=0.5, label='99%')
axes[2].axvline(x=true_rank, color='red', linestyle='--', label=f'r = {true_rank}')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
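The reconstruction-error plot has a closed form behind it. For a truncated SVD, the Frobenius error of the rank-$r$ approximation equals the square root of the sum of the squared discarded singular values (the Eckart-Young result). The sketch below checks this on a small synthetic matrix; the sizes and noise level are assumptions.

```python
import torch

torch.manual_seed(0)

# Synthetic matrix: rank-4 signal plus small dense noise (assumed setup).
M = torch.randn(20, 4, dtype=torch.float64) @ torch.randn(4, 20, dtype=torch.float64)
M += 1e-3 * torch.randn(20, 20, dtype=torch.float64)

U, S, Vh = torch.linalg.svd(M)

r = 4
M_r = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]  # best rank-r approximation

err_direct = torch.linalg.norm(M - M_r)         # Frobenius norm of the residual
err_from_tail = torch.sqrt((S[r:] ** 2).sum())  # computed from discarded singular values

print(torch.allclose(err_direct, err_from_tail))
```

This is why the error curve can be read directly off the singular value spectrum without forming each approximation.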

## 2. LoRA Implementation

In [None]:
def init_lora(d_in, d_out, rank, alpha=1.0, device=None):
    """Initialize LoRA adapter matrices.
    
    Args:
        d_in: input dimension
        d_out: output dimension  
        rank: LoRA rank (r)
        alpha: scaling factor
    
    Returns:
        dict with A, B matrices and scaling factor
    """
    lora = {
        'A': torch.randn(rank, d_in, device=device) * (1 / math.sqrt(d_in)),  # Kaiming-like
        'B': torch.zeros(d_out, rank, device=device),  # Zero init → ΔW = 0 at start
        'scaling': alpha / rank,
    }
    # Mark as trainable (requires_grad)
    lora['A'].requires_grad_(True)
    lora['B'].requires_grad_(True)
    return lora

def lora_forward(x, W_frozen, lora):
    """Forward pass through a LoRA-adapted linear layer.
    
    h = W_frozen @ x + (scaling * B @ A) @ x
    """
    # Original frozen path
    h = x @ W_frozen.T
    
    # LoRA adapter path
    # Efficient: compute A@x first (small), then B@(Ax)
    lora_out = x @ lora['A'].T  # (..., rank)  — project to low rank
    lora_out = lora_out @ lora['B'].T  # (..., d_out) — project back up
    
    return h + lora['scaling'] * lora_out

# Test
torch.manual_seed(42)
d_in, d_out = 64, 64
rank = 4

# Pretrained (frozen) weight
W_frozen = torch.randn(d_out, d_in, device=device) * 0.1
W_frozen.requires_grad_(False)  # Frozen!

# LoRA adapter
lora = init_lora(d_in, d_out, rank, alpha=8.0, device=device)

# Forward pass
x = torch.randn(2, 8, d_in, device=device)
out = lora_forward(x, W_frozen, lora)

print(f'Input shape: {x.shape}')
print(f'Output shape: {out.shape}')
print(f'\nParameter count:')
print(f'  Frozen W: {d_in * d_out} (NOT trained)')
print(f'  LoRA A:   {rank * d_in}')
print(f'  LoRA B:   {d_out * rank}')
print(f'  Total trainable: {rank * d_in + d_out * rank}')
print(f'  Reduction: {d_in * d_out / (rank * d_in + d_out * rank):.1f}x fewer trainable params')
print(f'\nAt initialization (B=0), ΔW is zero:')
delta_W = lora['scaling'] * (lora['B'] @ lora['A'])
print(f'  ||ΔW|| = {delta_W.norm().item():.10f}')
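The same logic can be packaged as a `torch.nn.Module`, which is convenient when you want a standard optimizer to pick up the trainable parameters. This is a minimal sketch, not the notebook's implementation: the class name `LoRALinear` is hypothetical, and it wraps a bias-free `nn.Linear` as an assumption.

```python
import math

import torch
from torch import nn


class LoRALinear(nn.Module):
    """Hypothetical wrapper: a frozen nn.Linear plus a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int, alpha: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the pretrained weights
        # Same initialization as init_lora: random A, zero B.
        self.A = nn.Parameter(
            torch.randn(rank, base.in_features) / math.sqrt(base.in_features)
        )
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        # Project down to rank dimensions first, then back up to out_features.
        return self.base(x) + self.scaling * (x @ self.A.T @ self.B.T)


torch.manual_seed(0)
layer = LoRALinear(nn.Linear(64, 64, bias=False), rank=4, alpha=8.0)
x = torch.randn(2, 8, 64)

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(layer(x).shape, trainable)
```

Only `A` and `B` report `requires_grad=True`, so filtering `layer.parameters()` on that flag yields exactly the adapter weights.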

## 3. Training a LoRA-Adapted Model

Let's train a LoRA adapter on top of a single frozen linear layer, using a simple task whose target is a low-rank update.

In [None]:
# Create a simple task: learn a low-rank update on top of a frozen linear layer
torch.manual_seed(42)
d_model = 32
seq_len = 8
n_samples = 200
rank = 4

# Pretrained model: a simple linear layer (frozen)
W_pretrained = torch.randn(d_model, d_model, device=device) * 0.1
W_pretrained.requires_grad_(False)

# Task data: input → target (a linear map: frozen weights plus a rank-4 update)
X_data = torch.randn(n_samples, seq_len, d_model, device=device)

# Target: a specific low-rank transformation of input
W_target_B = torch.randn(d_model, 4, device=device) * 0.3
W_target_A = torch.randn(4, d_model, device=device) * 0.3
Y_data = X_data @ W_pretrained.T + X_data @ W_target_A.T @ W_target_B.T
Y_data = Y_data.detach()

print(f'Task: learn a rank-4 update to a frozen {d_model}x{d_model} weight matrix')
print(f'Training data: {n_samples} samples, seq_len={seq_len}, d_model={d_model}')

In [None]:
# Train with LoRA
lora_adapter = init_lora(d_model, d_model, rank=rank, alpha=8.0, device=device)
lr = 0.01
n_epochs = 100

losses_lora = []

for epoch in range(n_epochs):
    # Forward
    Y_pred = lora_forward(X_data, W_pretrained, lora_adapter)
    loss = ((Y_pred - Y_data) ** 2).mean()
    losses_lora.append(loss.item())
    
    # Backward
    loss.backward()
    
    # Update only LoRA parameters (W_pretrained is frozen)
    with torch.no_grad():
        lora_adapter['A'] -= lr * lora_adapter['A'].grad
        lora_adapter['B'] -= lr * lora_adapter['B'].grad
        lora_adapter['A'].grad.zero_()
        lora_adapter['B'].grad.zero_()
    
    if epoch % 20 == 0:
        print(f'Epoch {epoch:3d}: loss = {loss.item():.6f}')

print(f'\nFinal loss: {losses_lora[-1]:.6f}')
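The manual update above is plain gradient descent. As an alternative sketch, the same loop can hand the adapter tensors to a built-in optimizer. Everything here is an assumption made for a self-contained example: the sizes, the synthetic rank-2 target, the choice of `torch.optim.Adam`, the learning rate, and a scaling fixed at 1.

```python
import math

import torch

torch.manual_seed(0)
d, r, n = 16, 2, 256  # illustrative sizes (assumed)

W0 = torch.randn(d, d) * 0.1                                  # frozen base weight
target_delta = (torch.randn(d, r) @ torch.randn(r, d)) * 0.1  # assumed rank-r target
X = torch.randn(n, d)
Y = X @ (W0 + target_delta).T

A = (torch.randn(r, d) / math.sqrt(d)).requires_grad_(True)
B = torch.zeros(d, r, requires_grad=True)

# Only the adapter tensors are registered, so W0 can never be updated.
opt = torch.optim.Adam([A, B], lr=1e-2)

losses = []
for step in range(300):
    pred = X @ W0.T + (X @ A.T) @ B.T  # scaling fixed at 1 for brevity
    loss = ((pred - Y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Passing only `[A, B]` to the optimizer is what keeps the base weight frozen here. `W0` never requires gradients, so no gradient buffer is allocated for it.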

In [None]:
# Compare: full fine-tuning with same data
W_full = W_pretrained.clone().detach().requires_grad_(True)
losses_full = []

for epoch in range(n_epochs):
    Y_pred = X_data @ W_full.T
    loss = ((Y_pred - Y_data) ** 2).mean()
    losses_full.append(loss.item())
    
    loss.backward()
    with torch.no_grad():
        W_full -= lr * W_full.grad
        W_full.grad.zero_()

print(f'Full fine-tuning final loss: {losses_full[-1]:.6f}')
print(f'LoRA (rank={rank}) final loss: {losses_lora[-1]:.6f}')

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses_full, linewidth=2, label=f'Full fine-tuning ({d_model*d_model} params)')
ax.plot(losses_lora, linewidth=2, label=f'LoRA rank={rank} ({rank*d_model*2} params)')
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title('LoRA vs Full Fine-Tuning')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_yscale('log')
plt.tight_layout()
plt.show()

print(f'\nLoRA uses {d_model*d_model / (rank*d_model*2):.1f}x fewer trainable parameters')
print('Compare the final losses printed above to see how close LoRA gets on this low-rank task.')

## 4. Adapter Merging — Zero-Cost Inference

After training, LoRA adapters can be **merged** into the frozen weights for zero additional inference cost.

In [None]:
def merge_lora(W_frozen, lora):
    """Merge LoRA adapter into frozen weights.
    
    W_merged = W_frozen + scaling * B @ A
    
    After merging, inference uses a single matrix multiply — no overhead.
    """
    delta_W = lora['scaling'] * (lora['B'] @ lora['A'])
    return W_frozen + delta_W

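def unmerge_lora(W_merged, W_frozen, lora):
    """Reverse a merge: subtract the adapter update to recover the base weights.

    Useful before swapping in a different adapter. W_frozen is unused and is
    kept only so the signature stays unchanged.
    """
    return W_merged - lora['scaling'] * (lora['B'] @ lora['A'])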

# Merge and verify
W_merged = merge_lora(W_pretrained, lora_adapter)

# Forward with separate LoRA
x_test = torch.randn(1, 4, d_model, device=device)
out_separate = lora_forward(x_test, W_pretrained, lora_adapter)

# Forward with merged weights (single matmul — no overhead!)
out_merged = x_test @ W_merged.T

print('Verification: separate LoRA vs merged weights')
print(f'  Max difference: {(out_separate - out_merged).abs().max().item():.2e}')
print(f'  Are equal: {torch.allclose(out_separate, out_merged, atol=1e-5)}')

print(f'\nInference cost:')
print(f'  With LoRA:  3 matmuls (x @ W.T, plus x @ A.T and then @ B.T for the adapter)')
print(f'  Merged:     1 matmul (same cost as original model!)')
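Merging is reversible, which is what makes adapter swapping cheap. The sketch below merges one adapter, subtracts it again, and merges a second. The adapters are random stand-ins rather than trained ones, and `scaling = 2.0` is an assumed value standing in for $\alpha / r$.

```python
import torch

torch.manual_seed(0)
d, r = 16, 2
scaling = 2.0  # stands in for alpha / r (assumed)

W_base = torch.randn(d, d, dtype=torch.float64)


def random_adapter():
    """Stand-in for a trained adapter: random B and A (illustration only)."""
    return {
        "B": torch.randn(d, r, dtype=torch.float64),
        "A": torch.randn(r, d, dtype=torch.float64),
    }


def delta(adapter):
    """Dense update contributed by one adapter."""
    return scaling * (adapter["B"] @ adapter["A"])


task_1, task_2 = random_adapter(), random_adapter()

W = W_base + delta(task_1)  # merge task 1
W = W - delta(task_1)       # unmerge: back to the base weights, up to rounding
W = W + delta(task_2)       # merge task 2

print(torch.allclose(W, W_base + delta(task_2)))
```

The round trip is exact only up to floating-point rounding, so keeping an untouched copy of the base weights is the safer choice when adapters are swapped many times.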

In [None]:
# Visualize the learned ΔW and its rank structure
delta_W_learned = lora_adapter['scaling'] * (lora_adapter['B'] @ lora_adapter['A'])
U, S, Vh = torch.linalg.svd(delta_W_learned.detach())

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# The learned ΔW
im = axes[0].imshow(delta_W_learned.detach().cpu().numpy(), cmap='RdBu', aspect='auto')
axes[0].set_title(f'Learned ΔW = B @ A\n(rank = {rank})')
axes[0].set_xlabel('Input dim')
axes[0].set_ylabel('Output dim')
plt.colorbar(im, ax=axes[0])

# Singular values show the rank
axes[1].bar(range(min(20, len(S))), S[:20].cpu().numpy())
axes[1].set_xlabel('Singular value index')
axes[1].set_ylabel('Magnitude')
axes[1].set_title(f'Singular Values of Learned ΔW\n(rank at most {rank})')
axes[1].axvline(x=rank - 0.5, color='red', linestyle='--', label=f'Rank boundary (r={rank})')
axes[1].legend()

# Comparison with full fine-tuning update
delta_full = (W_full - W_pretrained).detach()
U_f, S_f, _ = torch.linalg.svd(delta_full)
axes[2].bar(range(min(20, len(S_f))), S_f[:20].cpu().numpy(), alpha=0.6, label='Full FT ΔW')
axes[2].bar(range(min(20, len(S))), S[:20].cpu().numpy(), alpha=0.6, label=f'LoRA ΔW (r={rank})')
axes[2].set_xlabel('Singular value index')
axes[2].set_ylabel('Magnitude')
axes[2].set_title('ΔW Rank: Full vs LoRA')
axes[2].legend()

plt.tight_layout()
plt.show()

## 5. Effect of Rank

How does the choice of rank $r$ affect LoRA performance?

In [None]:
# Train LoRA with different ranks
ranks = [1, 2, 4, 8, 16, 32]
final_losses = []
param_counts = []

for r in ranks:
    torch.manual_seed(42)
    lora_r = init_lora(d_model, d_model, rank=r, alpha=8.0, device=device)
    
    for epoch in range(n_epochs):
        Y_pred = lora_forward(X_data, W_pretrained, lora_r)
        loss = ((Y_pred - Y_data) ** 2).mean()
        loss.backward()
        with torch.no_grad():
            lora_r['A'] -= lr * lora_r['A'].grad
            lora_r['B'] -= lr * lora_r['B'].grad
            lora_r['A'].grad.zero_()
            lora_r['B'].grad.zero_()
    
    final_losses.append(loss.item())
    param_counts.append(r * d_model * 2)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(ranks, final_losses, 'o-', linewidth=2, markersize=8)
axes[0].set_xlabel('LoRA Rank (r)')
axes[0].set_ylabel('Final Loss')
axes[0].set_title('Loss vs LoRA Rank')
axes[0].set_yscale('log')
axes[0].grid(True, alpha=0.3)
axes[0].axvline(x=4, color='red', linestyle='--', alpha=0.5, label='True rank of target = 4')
axes[0].legend()

axes[1].plot(param_counts, final_losses, 'o-', linewidth=2, markersize=8)
axes[1].set_xlabel('Trainable Parameters')
axes[1].set_ylabel('Final Loss')
axes[1].set_title('Loss vs Parameter Count')
axes[1].set_yscale('log')
axes[1].grid(True, alpha=0.3)
full_params = d_model * d_model
axes[1].axvline(x=full_params, color='gray', linestyle='--', alpha=0.5, label=f'Full FT ({full_params} params)')
axes[1].legend()

plt.tight_layout()
plt.show()

print('Rank vs Final Loss:')
for r, l, p in zip(ranks, final_losses, param_counts):
    print(f'  r={r:2d}: loss={l:.6f}, params={p:5d} ({100*p/full_params:.1f}% of full)')

## 6. Multi-Adapter Switching

A key advantage of LoRA: you can train multiple adapters for different tasks and swap them efficiently.

In [None]:
# Simulate multiple task-specific adapters on the same base model
torch.manual_seed(42)

# Base model (shared, frozen)
W_base = torch.randn(d_model, d_model, device=device) * 0.1

# Task A adapter: stand-in for a trained adapter (a random rank-1 update, not actually trained)
lora_task_A = init_lora(d_model, d_model, rank=4, alpha=8.0, device=device)
with torch.no_grad():
    lora_task_A['B'][:, 0] = torch.randn(d_model, device=device) * 0.5
    lora_task_A['A'][0, :] = torch.randn(d_model, device=device) * 0.5

# Task B adapter: a second stand-in with a different random rank-1 update
lora_task_B = init_lora(d_model, d_model, rank=4, alpha=8.0, device=device)
with torch.no_grad():
    lora_task_B['B'][:, 0] = torch.randn(d_model, device=device) * 0.3
    lora_task_B['A'][0, :] = -torch.randn(d_model, device=device) * 0.3

x_test = torch.randn(1, 4, d_model, device=device)

# Switch adapters — just change which A,B you use!
out_base = x_test @ W_base.T
out_A = lora_forward(x_test, W_base, lora_task_A)
out_B = lora_forward(x_test, W_base, lora_task_B)

print('Multi-adapter switching:')
print(f'  Base model output norm:  {out_base.norm().item():.4f}')
print(f'  + Task A adapter norm:   {out_A.norm().item():.4f}')
print(f'  + Task B adapter norm:   {out_B.norm().item():.4f}')
print(f'\nAdapter storage per task: {4 * d_model * 2 * 4 / 1024:.1f} KB (rank=4, d_model={d_model}, float32)')
print(f'Base model storage:       {d_model * d_model * 4 / 1024:.1f} KB')
print(f'\n→ Store one base model + tiny adapters for each task!')

In [None]:
# Real-world parameter savings
print('=' * 70)
print('PARAMETER SAVINGS: LoRA at Scale')
print('=' * 70)

model_configs = [
    ('7B (LLaMA)', 4096, 32),
    ('13B', 5120, 40),
    ('70B', 8192, 80),
]

lora_ranks = [4, 8, 16, 64]

for model_name, d, n_layers in model_configs:
    print(f'\n{model_name} (d_model={d}, {n_layers} layers):')
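    # Each layer has W_Q, W_K, W_V, W_O; LoRA is typically applied to Q and V.
    # Simplification: every adapted matrix is treated as d x d, which ignores
    # variants such as grouped-query attention where K/V projections are narrower.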
    n_adapted_matrices = 2 * n_layers  # Q and V per layer
    full_params = n_adapted_matrices * d * d
    print(f'  Full fine-tuning (Q,V only): {full_params/1e6:.1f}M params')
    
    for r in lora_ranks:
        lora_params = n_adapted_matrices * (r * d + d * r)
        reduction = full_params / lora_params
        print(f'  LoRA r={r:2d}: {lora_params/1e6:>8.2f}M params ({reduction:>6.0f}x reduction)')

print('\n' + '=' * 70)
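Parameter counts translate directly into storage per adapter. This rough estimate reuses the 7B configuration from the table above. The helper name `adapter_megabytes` is hypothetical, and it assumes 2 bytes per parameter (16-bit storage) and square Q and V projections.

```python
def adapter_megabytes(d_model, n_layers, rank, matrices_per_layer=2, bytes_per_param=2):
    """Approximate on-disk size of one LoRA adapter, in megabytes (1e6 bytes)."""
    # Each adapted matrix contributes A (rank x d_model) and B (d_model x rank).
    params = n_layers * matrices_per_layer * (2 * rank * d_model)
    return params * bytes_per_param / 1e6


size_mb = adapter_megabytes(d_model=4096, n_layers=32, rank=8)
print(f"{size_mb:.2f} MB")  # prints: 8.39 MB
```

Under these assumptions an adapter is a few megabytes, compared with gigabytes for the base model, which is what makes storing one adapter per task practical.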

## Summary

In this notebook we implemented from scratch:

1. **Low-rank approximation**: weight updates during fine-tuning tend to be well approximated by low-rank matrices

2. **LoRA adapters**: decompose $\Delta W = BA$ where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, $r \ll \min(d, k)$

3. **Training**: only $A$ and $B$ are updated; the base model stays frozen

4. **Adapter merging**: $W_{\text{merged}} = W_0 + \frac{\alpha}{r}BA$ gives zero-overhead inference

5. **Multi-task switching**: store one base model plus a tiny adapter per task

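**Key insight:** Fine-tuning updates to pretrained models tend to have low intrinsic rank. LoRA exploits this to cut trainable parameters by orders of magnitude (32x to 512x for the adapted matrices in the 7B configuration above), often with little loss in quality. This makes fine-tuning accessible: with QLoRA, which combines LoRA with 4-bit quantization of the base model, a 65B model can be adapted on a single 48 GB GPU.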