# OSFT (Orthogonal Subspace Fine-Tuning) Tutorial Notebook

## by Frank La Vigne

Alright, here’s the deal. OSFT is all about teaching your model new tricks without it forgetting the old ones. If you’ve ever fine-tuned a model and watched it suddenly get dumber at stuff it used to know — that’s catastrophic forgetting. OSFT is how we fight back.

## Key Ideas
- Break down weight matrices with singular value decomposition (SVD), like putting on X-ray specs for your model.
- Use the singular values to spot which directions are pulling their weight (large singular values) vs. which ones are mostly idle (small singular values).
- Keep updates out of the “critical” directions and funnel them into the low-importance space.
- End result: your model learns new stuff without trashing the old knowledge.
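For the notation-minded, here are the same ideas in symbols. This is a sketch that uses the names from the code cells below (`U_high`, `V_high`, rank cutoff $r$). SVD splits a weight matrix $W$ into its top-$r$ (critical) part and the rest, and the projection strips the gradient $G$ of its component in the critical block:

$$
W = U \Sigma V^\top
  = \underbrace{U_{\text{high}} \Sigma_{\text{high}} V_{\text{high}}^\top}_{\text{top-}r\text{ directions (critical)}}
  + \underbrace{U_{\text{low}} \Sigma_{\text{low}} V_{\text{low}}^\top}_{\text{remaining directions (safe)}}
$$

$$
\tilde{G} = G - U_{\text{high}} \left( U_{\text{high}}^\top G \, V_{\text{high}} \right) V_{\text{high}}^\top
$$

Because $U_{\text{high}}^\top U_{\text{high}} = I_r$ and $V_{\text{high}}^\top V_{\text{high}} = I_r$:

$$
U_{\text{high}}^\top \tilde{G} \, V_{\text{high}}
= U_{\text{high}}^\top G \, V_{\text{high}}
- \left( U_{\text{high}}^\top U_{\text{high}} \right) \left( U_{\text{high}}^\top G \, V_{\text{high}} \right) \left( V_{\text{high}}^\top V_{\text{high}} \right)
= 0
$$

Since the first-order change of the $i$-th singular value is $\delta\sigma_i = u_i^\top (\delta W)\, v_i$ (assuming distinct singular values), a step along $\tilde{G}$ leaves the top-$r$ singular values unchanged to first order.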

# 1. Install Required Libraries

In [None]:
```python
%pip install -r requirements.txt
```

# 2. Warm-Up with NumPy: Seeing the Subspace

Before we touch PyTorch, let’s warm up with NumPy. We’ll:
1. Take a toy weight matrix.
2. Run SVD to split it into important vs. not-so-important directions.
3. Project a gradient update into the “safe” zone.

In [None]:
```python
import numpy as np

# A toy weight matrix (e.g. from a linear layer)
W = np.array([[2.0, 0.5, 0.0],
              [0.0, 1.5, 0.1],
              [0.0, 0.0, 0.2]])
print("Original weight matrix W:\n", W)

# Perform SVD: W = U @ diag(S) @ Vt, with singular values sorted largest first
U, S, Vt = np.linalg.svd(W)
print("Singular values:", S)

# Split into high-importance (top singular directions) vs. low-importance subspaces
rank_cutoff = 1  # treat the top-1 singular direction pair as important
U_high = U[:, :rank_cutoff]     # top left singular vector(s), shape (3, 1)
V_high = Vt[:rank_cutoff, :].T  # top right singular vector(s), shape (3, 1)

# An arbitrary gradient update (stand-in for a real one)
grad = np.array([[0.1, -0.2, 0.05],
                 [0.05, 0.1, -0.1],
                 [-0.2, 0.0, 0.2]])

# Remove the gradient's component in the critical block spanned by U_high and V_high,
# so that U_high.T @ proj @ V_high == 0
proj = grad - U_high @ (U_high.T @ grad @ V_high) @ V_high.T

print("Original gradient:\n", grad)
print("Projected gradient (OSFT):\n", proj)
```

Look at that: the projected gradient no longer has any component in the critical block (`U_high.T @ proj @ V_high` comes out as zero, up to floating-point noise). That’s OSFT on training wheels.
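Want proof instead of a promise? This sketch re-runs the warm-up and checks a few properties numerically: the critical block `U_high.T @ proj @ V_high` is zero, projecting twice changes nothing, and the removed piece is orthogonal (Frobenius inner product) to what's left. It also shows what this projection does *not* guarantee: `U_high.T @ proj` on its own is generally nonzero. The `project_two_sided` helper is an illustrative stricter variant we're assuming for comparison (it blocks `span(U_high)` on the left and `span(V_high)` on the right); it is not taken from the notebook and not presented as the OSFT paper's exact rule.

```python
import numpy as np

def project_out_top_block(grad, U_high, V_high):
    """The notebook's projection: remove the U_high/V_high block of grad."""
    return grad - U_high @ (U_high.T @ grad @ V_high) @ V_high.T

def project_two_sided(grad, U_high, V_high):
    """Illustrative stricter variant (an assumption, for comparison only):
    keep only the part orthogonal to span(U_high) on the left AND span(V_high) on the right."""
    left = np.eye(U_high.shape[0]) - U_high @ U_high.T
    right = np.eye(V_high.shape[0]) - V_high @ V_high.T
    return left @ grad @ right

W = np.array([[2.0, 0.5, 0.0],
              [0.0, 1.5, 0.1],
              [0.0, 0.0, 0.2]])
grad = np.array([[0.1, -0.2, 0.05],
                 [0.05, 0.1, -0.1],
                 [-0.2, 0.0, 0.2]])

U, S, Vt = np.linalg.svd(W)
r = 1
U_high, V_high = U[:, :r], Vt[:r, :].T

proj = project_out_top_block(grad, U_high, V_high)
strict = project_two_sided(grad, U_high, V_high)

# 1. The critical block is gone: U_high^T proj V_high == 0
block = U_high.T @ proj @ V_high
# 2. Projecting a second time changes nothing (it really is a projection)
idempotent_gap = np.abs(project_out_top_block(proj, U_high, V_high) - proj).max()
# 3. The removed piece is Frobenius-orthogonal to what remains
removed = grad - proj
inner = np.sum(removed * proj)
# 4. The notebook's version can still move along U_high (with other right vectors);
#    the strict variant cannot
left_leak = np.abs(U_high.T @ proj).max()
strict_leak = np.abs(U_high.T @ strict).max()

print("U_high^T proj V_high:", block.ravel())
print("idempotence gap:", idempotent_gap)
print("<removed, proj>_F:", inner)
print("left leak (notebook / strict):", left_leak, strict_leak)
```

Which variant you want depends on how strictly you protect the top directions: the strict version keeps a smaller subspace, so it throws away more of each gradient.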

# 3. How the Training Loop Looks (Pseudo-code)

Here’s the play-by-play in pseudocode:

In [None]:
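```python
'''
for each training step:
    loss = compute_loss(batch)
    loss.backward()                               # one backward pass fills every layer's grad
    for each layer l in model:
        W = layer.weight
        U, S, Vt = svd(W)
        r = int(retention_ratio(layer) * len(S))  # number of top directions to protect, based on importance
        U_high = U[:, :r]                         # critical left singular vectors
        V_high = Vt[:r, :].T                      # critical right singular vectors

        grad = layer.weight.grad
        grad_proj = grad - U_high @ (U_high.T @ grad @ V_high) @ V_high.T

        apply_update(layer, grad_proj)
'''
```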
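Here's one way to turn that pseudocode into something you can actually run: a minimal PyTorch sketch, assuming a tiny two-layer network on random data. `retention_ratio` is a hypothetical placeholder that protects half of each layer's singular directions; the real importance-based rule isn't spelled out here, so swap in your own. Like the pseudocode, it recomputes the SVD every step, which is simple but costly on big layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible toy run


def retention_ratio(layer: nn.Linear) -> float:
    # HYPOTHETICAL placeholder for the pseudocode's importance-based rule:
    # protect a fixed half of every layer's singular directions.
    return 0.5


@torch.no_grad()
def project_grad_(layer: nn.Linear) -> None:
    """Remove the top-r singular block from layer.weight.grad, in place."""
    U, S, Vh = torch.linalg.svd(layer.weight, full_matrices=False)
    r = max(1, int(retention_ratio(layer) * S.numel()))  # always protect at least one
    U_high = U[:, :r]      # (out_features, r)
    V_high = Vh[:r, :].T   # (in_features, r)
    G = layer.weight.grad
    G.sub_(U_high @ (U_high.T @ G @ V_high) @ V_high.T)


# Tiny two-layer network and random regression data
model = nn.Sequential(nn.Linear(4, 8, bias=False), nn.ReLU(),
                      nn.Linear(8, 2, bias=False))
opt = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()
x, y = torch.randn(32, 4), torch.randn(32, 2)

for step in range(20):
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()                    # one backward pass fills every .grad
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            project_grad_(layer)       # OSFT-style projection, per layer
    opt.step()                         # SGD applies the projected gradients
    if step % 5 == 0:
        print(f"step {step:2d}  loss {loss.item():.4f}")
```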

# 4. Hands-on PyTorch Demo

Time to level up. Let’s try this on a small PyTorch model. We’ll run a single training step on dummy data and apply the OSFT-style projection to the gradient before the optimizer uses it.

In [None]:
```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)  # make the dummy data reproducible

# Define a small model: a single 3x3 linear layer, no bias
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 3, bias=False)

    def forward(self, x):
        return self.fc(x)

model = SmallNet()

# Dummy data
x = torch.randn(5, 3)
y = torch.randn(5, 3)

# Optimizer
opt = optim.SGD(model.parameters(), lr=0.1)

# Forward + backward
loss_fn = nn.MSELoss()
opt.zero_grad()
output = model(x)
loss = loss_fn(output, y)
loss.backward()

# Inspect original gradient
grad = model.fc.weight.grad.detach().numpy()
print("Original gradient:\n", grad)

# Perform SVD on a snapshot of the weights (.copy() so opt.step() can't change it later)
W = model.fc.weight.detach().numpy().copy()
U, S, Vt = np.linalg.svd(W)

# Treat the top-1 singular direction pair as "critical"
rank_cutoff = 1
U_high = U[:, :rank_cutoff]
V_high = Vt[:rank_cutoff, :].T

# Remove the gradient's component in the critical U_high/V_high block
grad_proj = grad - U_high @ (U_high.T @ grad @ V_high) @ V_high.T

# Replace gradient in model with projected version
model.fc.weight.grad = torch.from_numpy(grad_proj).float()

# Apply update
opt.step()

print("Updated weights (after OSFT-style projection):\n", model.fc.weight.data)
```

There it is. The model took a step, but only in the directions we allowed. That’s how OSFT threads the needle: new learning without wrecking the old foundation.
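Swapping out `.grad` by hand works for one step, but in a longer run you'd want the projection applied automatically. One option (an illustrative design, not something the OSFT paper or PyTorch prescribes) is a tensor hook via `Tensor.register_hook`, with the SVD basis computed once up front. Computing the basis once, rather than every step as in the pseudocode, is an assumption here, and `attach_osft_hook` is a hypothetical helper name. The sketch then confirms that the actual weight change has no component in the protected block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def attach_osft_hook(weight: nn.Parameter, rank_cutoff: int = 1):
    """HYPOTHETICAL helper: compute the SVD basis once, then project every
    future gradient of `weight` through a tensor hook."""
    with torch.no_grad():
        U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
        U_high = U[:, :rank_cutoff].clone()
        V_high = Vh[:rank_cutoff, :].T.clone()

    def hook(grad):
        # Return a new gradient with the critical U_high/V_high block removed;
        # autograd uses it in place of the original before filling .grad
        return grad - U_high @ (U_high.T @ grad @ V_high) @ V_high.T

    weight.register_hook(hook)
    return U_high, V_high


fc = nn.Linear(3, 3, bias=False)
U_high, V_high = attach_osft_hook(fc.weight, rank_cutoff=1)
opt = torch.optim.SGD(fc.parameters(), lr=0.1)

x, y = torch.randn(5, 3), torch.randn(5, 3)
W_before = fc.weight.detach().clone()

opt.zero_grad()
F.mse_loss(fc(x), y).backward()  # the hook fires here
opt.step()

delta = fc.weight.detach() - W_before
# The step has (numerically) no component in the protected block
print("U_high^T dW V_high:", (U_high.T @ delta @ V_high).item())
```

The check holds because plain SGD applies the projected gradient as-is. An optimizer like Adam rescales each coordinate separately, which can put some of the update back into the protected block.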

# 5. Why This Matters
- **Text classification sequences:** Keeps performance steady across 5, 10, 15+ tasks.
- **TRACE benchmark:** Boosted LLaMA-2-7B’s accuracy by ~7 points over O-LoRA.
- **Enterprise bots:** Add new product knowledge without erasing FAQs from last year.
- **Medical/legal models:** Stay current with new research, but don’t forget the basics.

# 6. Wrap-Up
OSFT in a nutshell:
- SVD shows us the model’s “critical” vs. “safe” directions.
- We project updates into the safe zone.
- The model grows new skills while keeping the old ones sharp.

Think of it like renovating a house. OSFT adds a new room without tearing down the walls that are already holding up the place.