# 🐇 Down the Orthogonal Rabbit Hole

## Exploring Orthogonal Subspaces in Fine-Tuning AI Models

**by Frank La Vigne** · Interactive Notebook Edition

---

| 90%+ Parameter Reduction | Stable Training | Minimal Forgetting |
|:---:|:---:|:---:|

---

### Table of Contents

1. [What is OFT?](#1)
2. [Orthogonal Matrix Properties -- Interactive](#2)
3. [Visualizing Orthogonal Transformations](#3)
4. [Hyperspherical Energy Preservation](#4)
5. [The Deeper Rabbit Hole: OSFT](#5)
6. [SVD Decomposition Explorer -- Interactive](#6)
7. [Gradient Projection Demo -- Interactive](#7)
8. [Hands-on: NumPy Warm-Up](#8)
9. [Hands-on: PyTorch OSFT](#9)
10. [OFT vs LoRA -- Parameter Calculator](#10)
11. [Training Dynamics Simulation](#11)
12. [Real-World Applications](#12)
13. [Wrap-Up](#13)

---
## ⚙️ Setup

Run this cell first to install dependencies and import everything we need.

In [None]:
# Install dependencies (uncomment if needed)
# !pip install numpy torch matplotlib ipywidgets scipy

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from scipy.stats import special_ortho_group
import warnings
warnings.filterwarnings('ignore')

# Plotting defaults
plt.rcParams.update({
    'figure.facecolor': '#0d0d0f',
    'axes.facecolor': '#16161a',
    'axes.edgecolor': '#2a2a32',
    'axes.labelcolor': '#e4e4e7',
    'text.color': '#e4e4e7',
    'xtick.color': '#8888a0',
    'ytick.color': '#8888a0',
    'grid.color': '#1f1f2a',
    'grid.alpha': 0.5,
    'figure.dpi': 120,
    'font.family': 'monospace',
    'font.size': 9,
})

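# Utility: generate a random orthogonal matrix via QR decomposition
def generate_orthogonal_matrix(dim):
    random_matrix = np.random.randn(dim, dim)
    q, r = np.linalg.qr(random_matrix)
    # Flip column signs so diag(r) is positive; this removes QR's sign
    # ambiguity and makes q uniformly distributed over orthogonal matrices
    d = np.diag(np.sign(np.diag(r)))
    return q @ d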

print("Setup complete -- all imports loaded.")

---
<a id="1"></a>
## 🚪 [1] Through the Looking Glass: What is OFT?

Just as Alice stepped through the looking glass and found a world that was familiar yet rearranged, **Orthogonal Fine-Tuning (OFT)** adapts pre-trained models by **rotating** their learned representations rather than distorting them.

OFT is a parameter-efficient fine-tuning (PEFT) technique that applies **orthogonal transformations** to weight matrices. Unlike LoRA, OFT preserves the *hyperspherical energy*: a measure of how a layer's neurons (its weight vectors) are spread over the unit hypersphere, which depends only on the pairwise angles between them.

### 🔮 The Magic Mirror Properties

An orthogonal matrix **Q** satisfies: **Q^T Q = QQ^T = I**

- **Preserves distances:** `||Qx|| = ||x||`
- **Preserves angles:** `(Qx)·(Qy) = x·y`, so the Cheshire Cat's grin keeps its shape
- **Transpose is the inverse:** `Q^{-1} = Q^T`, because `Q^T Q = I`
- **Represents rotations and reflections:** Turn the mirror, don't bend it
- **Determinant is +/-1:** +1 for rotations, -1 for reflections
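A quick numerical check of the properties above, offered as a sketch: it builds a random orthogonal matrix with the same QR-plus-sign-fix trick as the setup cell (re-implemented here as `random_orthogonal` so the snippet stands alone) and tests each property on random vectors.

```python
import numpy as np

def random_orthogonal(dim, rng):
    # QR of a Gaussian matrix, with column signs fixed so diag(R) > 0
    q, r = np.linalg.qr(rng.standard_normal((dim, dim)))
    return q * np.sign(np.diag(r))

rng = np.random.default_rng(0)
Q = random_orthogonal(5, rng)
x, y = rng.standard_normal(5), rng.standard_normal(5)

checks = {
    "Q^T Q = I": np.allclose(Q.T @ Q, np.eye(5)),
    "Q Q^T = I": np.allclose(Q @ Q.T, np.eye(5)),
    "||Qx|| = ||x||": np.isclose(np.linalg.norm(Q @ x), np.linalg.norm(x)),
    "dot product kept": np.isclose((Q @ x) @ (Q @ y), x @ y),
    "Q^-1 = Q^T": np.allclose(np.linalg.inv(Q), Q.T),
    "|det Q| = 1": np.isclose(abs(np.linalg.det(Q)), 1.0),
}
for name, ok in checks.items():
    print(f"{name:18s} {ok}")
```

Preserving both norms and dot products is what guarantees angles are preserved, since the cosine of an angle is a dot product divided by two norms.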

### 🎭 The Looking Glass Formula

The key insight: **W' = W R** (a matrix product)

Where **W** is the frozen pre-trained weight matrix, **R** is an orthogonal matrix learned during fine-tuning, and **W'** is the adapted weight matrix.
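To make the formula concrete, here is a minimal NumPy sketch of an OFT-style adapter. It assumes the Cayley parameterization `R = (I + S)(I - S)^{-1}` with `S` skew-symmetric, and a block-diagonal `R` assembled from small blocks, as described in the OFT paper; the function names `cayley` and `block_diag_orthogonal` are illustrative, not a library API. It also assumes the PyTorch weight layout (out x in), so `R` acts on the input dimension and each row of `W` is one neuron. With all free parameters at zero, `R` is the identity, so the adapted layer starts out identical to the pre-trained one.

```python
import numpy as np

def cayley(S):
    """Map a skew-symmetric matrix S to the orthogonal matrix (I + S)(I - S)^-1."""
    I = np.eye(S.shape[0])
    return (I + S) @ np.linalg.inv(I - S)

def block_diag_orthogonal(params, block_size):
    """Build a block-diagonal orthogonal R from one free matrix per block.
    params has shape (n_blocks, block_size, block_size)."""
    n_blocks = params.shape[0]
    R = np.zeros((n_blocks * block_size, n_blocks * block_size))
    for b in range(n_blocks):
        S = params[b] - params[b].T            # skew-symmetrize the free parameters
        lo, hi = b * block_size, (b + 1) * block_size
        R[lo:hi, lo:hi] = cayley(S)
    return R

rng = np.random.default_rng(0)
out_features, in_features, block_size = 6, 8, 4
W = rng.standard_normal((out_features, in_features))   # frozen pre-trained weight

# At initialization the free parameters are zero, so R = I and W' = W
params = np.zeros((in_features // block_size, block_size, block_size))
R0 = block_diag_orthogonal(params, block_size)
print("R starts as identity:", np.allclose(R0, np.eye(in_features)))

# After some (pretend) training the parameters are non-zero, but R stays orthogonal
params = 0.1 * rng.standard_normal(params.shape)
R = block_diag_orthogonal(params, block_size)
W_adapted = W @ R
print("R orthogonal:", np.allclose(R.T @ R, np.eye(in_features)))
print("neuron norms kept:", np.allclose(np.linalg.norm(W_adapted, axis=1),
                                        np.linalg.norm(W, axis=1)))
```

The Cayley map is one convenient way to keep `R` exactly orthogonal while the optimizer works on unconstrained numbers; other parameterizations (e.g. a matrix exponential of a skew-symmetric matrix) would also work.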

---
<a id="2"></a>
## 🃏 [2] The Queen's Matrix Garden -- Interactive

An orthogonal matrix is like the Queen's decree -- it can rearrange the cards (rotate them), but must preserve their ranks and suits. **Run the cell below and click the button** to generate new random orthogonal matrices and verify their properties.

> *"Off with their heads!" cries the Queen -- but an orthogonal matrix is merciful. It moves the cards without changing their essence.*

In [None]:
# Interactive Matrix Explorer
button_gen = widgets.Button(description='Shuffle the Cards', layout=widgets.Layout(width='200px'))
output_matrix = widgets.Output()

def generate_and_display(_=None):
    with output_matrix:
        clear_output(wait=True)
        Q = generate_orthogonal_matrix(4)
        QTQ = Q.T @ Q
        det = np.linalg.det(Q)

        fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))

        im0 = axes[0].imshow(Q, cmap='RdBu_r', vmin=-1, vmax=1, aspect='equal')
        axes[0].set_title('Matrix Q', fontsize=11, fontweight='bold')
        for i in range(4):
            for j in range(4):
                axes[0].text(j, i, f'{Q[i,j]:.3f}', ha='center', va='center',
                           fontsize=8, color='white' if abs(Q[i,j]) > 0.5 else '#aaa')
        axes[0].set_xticks([]); axes[0].set_yticks([])
        plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

        im1 = axes[1].imshow(QTQ, cmap='RdBu_r', vmin=-1, vmax=1, aspect='equal')
        axes[1].set_title('Q^T x Q  (should be Identity)', fontsize=11, fontweight='bold')
        for i in range(4):
            for j in range(4):
                val = QTQ[i,j]
                color = '#4ade80' if (i == j and abs(val - 1) < 0.01) else '#888'
                axes[1].text(j, i, f'{val:.3f}', ha='center', va='center',
                           fontsize=8, color=color, fontweight='bold' if i==j else 'normal')
        axes[1].set_xticks([]); axes[1].set_yticks([])
        plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()

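        is_orth = np.allclose(QTQ, np.eye(4), atol=1e-10)
        # Check norm and angle preservation on random test vectors instead of assuming them
        u, v = np.random.randn(4), np.random.randn(4)
        keeps_norm = np.isclose(np.linalg.norm(Q @ u), np.linalg.norm(u))
        keeps_angle = np.isclose((Q @ u) @ (Q @ v), u @ v)
        print(f"  Determinant:       {det:.6f}")
        print(f"  Is Orthogonal:     {is_orth}")
        print(f"  Preserves Norm:    {keeps_norm}")
        print(f"  Preserves Angles:  {keeps_angle}")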

button_gen.on_click(generate_and_display)
display(button_gen, output_matrix)
generate_and_display()

---
<a id="3"></a>
## 🔄 [3] Visualizing Orthogonal Transformations

Watch how orthogonal transformations preserve geometric relationships while rotating the feature space. **Drag the slider** to change the rotation angle. Notice the dashed lines connecting two reference points stay the same length!

> *Everything moves, but relationships stay true. The mirror spins, but nothing warps.*

In [None]:
# Interactive Rotation Visualization
np.random.seed(42)
n_pts = 80
orig_pts = np.random.randn(n_pts, 2) * 1.2

angle_slider = widgets.FloatSlider(value=45, min=0, max=360, step=1,
    description='Angle:', layout=widgets.Layout(width='500px'),
    style={'description_width': '60px'})
output_rot = widgets.Output()

def draw_rotation(change=None):
    angle = angle_slider.value
    with output_rot:
        clear_output(wait=True)
        rad = np.radians(angle)
        R = np.array([[np.cos(rad), -np.sin(rad)],
                      [np.sin(rad),  np.cos(rad)]])
        transformed = orig_pts @ R.T

        fig, axes = plt.subplots(1, 3, figsize=(14, 4.2))
        for ax in axes:
            ax.set_xlim(-4.5, 4.5); ax.set_ylim(-4.5, 4.5)
            ax.set_aspect('equal'); ax.grid(True, alpha=0.3)
            ax.axhline(0, color='#2a2a3a', linewidth=0.8)
            ax.axvline(0, color='#2a2a3a', linewidth=0.8)

        axes[0].scatter(orig_pts[:,0], orig_pts[:,1], alpha=0.5, s=12, c='#60a5fa')
        axes[0].plot(orig_pts[:2,0], orig_pts[:2,1], '--', color='#60a5fa', alpha=0.5)
        axes[0].set_title('Original', fontsize=10)

        axes[1].scatter(transformed[:,0], transformed[:,1], alpha=0.7, s=12, c='#fb923c')
        axes[1].plot(transformed[:2,0], transformed[:2,1], '--', color='#fb923c', alpha=0.7)
        axes[1].set_title(f'Rotated {angle:.0f} degrees', fontsize=10)

        axes[2].scatter(orig_pts[:,0], orig_pts[:,1], alpha=0.3, s=10, c='#60a5fa', label='Original')
        axes[2].scatter(transformed[:,0], transformed[:,1], alpha=0.6, s=10, c='#fb923c', label='Rotated')
        axes[2].plot(orig_pts[:2,0], orig_pts[:2,1], '--', color='#60a5fa', alpha=0.4)
        axes[2].plot(transformed[:2,0], transformed[:2,1], '--', color='#fb923c', alpha=0.6)
        axes[2].legend(fontsize=8, loc='upper right')
        axes[2].set_title('Overlay -- Structure Preserved', fontsize=10)
        plt.tight_layout()
        plt.show()

        d_orig = np.linalg.norm(orig_pts[0] - orig_pts[1])
        d_trans = np.linalg.norm(transformed[0] - transformed[1])
        print(f"  Distance (original):    {d_orig:.6f}")
        print(f"  Distance (transformed): {d_trans:.6f}")
        print(f"  Preserved: {np.isclose(d_orig, d_trans)}")

angle_slider.observe(draw_rotation, names='value')
display(angle_slider, output_rot)
draw_rotation()

---
<a id="4"></a>
## 🎩 [4] Hyperspherical Energy Preservation

### The Mad Hatter's Tea Party Problem

**Traditional fine-tuning** is like the Hatter shouting "Move down!" -- everyone shifts chaotically. **OFT** is like rotating the entire table. Everyone maintains their relative positions.

**Run the cell** to compare how different transformations affect hyperspherical energy. Click regenerate for fresh random features.

> *Down here in Wonderland, we don't break what already works -- we just rotate it to see it from a new angle.*

In [None]:
# Hyperspherical Energy Comparison
button_energy = widgets.Button(description='Regenerate', layout=widgets.Layout(width='160px'))
output_energy = widgets.Output()

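def compute_hyperspherical_energy(features):
    # Simplified proxy: mean absolute pairwise cosine similarity between rows.
    # Lower values mean the rows are spread more evenly over the unit hypersphere.
    # A rotation leaves every pairwise cosine, and hence this score, unchanged.
    norms = np.linalg.norm(features, axis=-1, keepdims=True)
    normalized = features / (norms + 1e-8)
    sims = normalized @ normalized.T
    n = features.shape[0]
    mask = 1 - np.eye(n)  # ignore each row's similarity with itself
    energy = np.sum(np.abs(sims) * mask) / (n * (n - 1))
    return energy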

def draw_energy(_=None):
    with output_energy:
        clear_output(wait=True)
        n_samp, n_feat = 50, 8
        orig = np.random.randn(n_samp, n_feat)

        Q = generate_orthogonal_matrix(n_feat)
        oft = orig @ Q.T

        M = np.random.randn(n_feat, n_feat) * 0.3 + np.eye(n_feat)
        rand = orig @ M.T

        A = np.random.randn(4, n_feat) * 0.1
        B = np.random.randn(n_feat, 4) * 0.1
        lora = orig + orig @ A.T @ B.T

        energies = {
            'Original': compute_hyperspherical_energy(orig),
            'OFT': compute_hyperspherical_energy(oft),
            'Random': compute_hyperspherical_energy(rand),
            'LoRA': compute_hyperspherical_energy(lora),
        }
        colors_map = {'Original':'#6b7280', 'OFT':'#60a5fa', 'Random':'#f87171', 'LoRA':'#fbbf24'}

        fig, ax = plt.subplots(figsize=(9, 3.8))
        bars = ax.bar(energies.keys(), energies.values(),
                      color=[colors_map[k] for k in energies], width=0.5, edgecolor='none')
        for bar, (name, val) in zip(bars, energies.items()):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.003,
                    f'{val:.4f}', ha='center', va='bottom', fontsize=9, color='#e4e4e7')
        ax.set_ylabel('Hyperspherical Energy')
        ax.set_title('Energy Comparison -- OFT Preserves Best', fontsize=11, fontweight='bold')
        ax.grid(axis='y', alpha=0.2)
        plt.tight_layout()
        plt.show()

        orig_e = energies['Original']
        for name, e in energies.items():
            delta = abs(e - orig_e)
            marker = 'OK' if delta < 0.01 else '~' if delta < 0.05 else 'X'
            print(f"  [{marker}] {name:10s}: {e:.4f}  (delta = {delta:.4f})")

button_energy.on_click(draw_energy)
display(button_energy, output_energy)
draw_energy()
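
The cell above uses a cosine-similarity proxy. The OFT paper instead defines hyperspherical energy as the sum of inverse distances between every pair of normalized neurons. The sketch below implements that definition (the function name `riesz_hyperspherical_energy` is ours, and rows of `W` are assumed to be the neurons) and checks that rotating every neuron by the same orthogonal matrix leaves the energy unchanged, while a generic linear map does not.

```python
import numpy as np

def riesz_hyperspherical_energy(W):
    """Sum over i != j of 1 / ||w_i_hat - w_j_hat||, where the rows of W are
    the neurons and w_hat is each neuron scaled to unit length."""
    W_hat = W / np.linalg.norm(W, axis=1, keepdims=True)
    diffs = W_hat[:, None, :] - W_hat[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    off_diag = ~np.eye(W.shape[0], dtype=bool)
    return np.sum(1.0 / dists[off_diag])

rng = np.random.default_rng(1)
W = rng.standard_normal((16, 8))                    # 16 neurons, 8 inputs each

q, r = np.linalg.qr(rng.standard_normal((8, 8)))
R = q * np.sign(np.diag(r))                         # random orthogonal matrix
M = np.eye(8) + 0.3 * rng.standard_normal((8, 8))   # generic (non-orthogonal) map

he_orig = riesz_hyperspherical_energy(W)
he_rot = riesz_hyperspherical_energy(W @ R)
he_gen = riesz_hyperspherical_energy(W @ M)
print(f"original: {he_orig:.4f}  rotated: {he_rot:.4f}  generic: {he_gen:.4f}")
```

The invariance follows directly: each normalized neuron becomes `w_hat @ R`, and `||(w_i_hat - w_j_hat) @ R|| = ||w_i_hat - w_j_hat||` for orthogonal `R`.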

---
<a id="5"></a>
## 🗝️ [5] The Deeper Rabbit Hole: OSFT

OSFT is all about teaching your model new tricks without it forgetting the old ones. If you've ever fine-tuned a model and watched it suddenly get dumber at stuff it used to know -- that's **catastrophic forgetting**. OSFT is how we fight back.

### The Problem: Eating the Wrong Mushroom

Standard fine-tuning is like Alice eating random mushrooms -- she might grow taller (learn new tasks) but forget how to get back to normal size.

### The Solution: The Caterpillar's Wisdom

*"One side makes you taller, the other makes you shorter."* OSFT identifies **safe directions** where updates won't harm critical knowledge.

**Key Ideas:**
- Break down weight matrices with **SVD** (like putting on X-Ray specs for your model)
- Spot which directions in parameter space are pulling their weight vs. idle
- Keep updates out of the "critical" directions, funnel them into unused space
- End result: new learning without trashing old knowledge

> *"Who are YOU?" said the Caterpillar. OSFT answers: "I'm the same model, just viewing from a different angle -- my core identity (critical subspace) intact."*

---
<a id="6"></a>
## 🔍 [6] The Cheshire Cat's Grin: SVD Decomposition -- Interactive

SVD is like the Cheshire Cat revealing which parts of itself are essential (the grin -- critical directions) and which can vanish (the body -- safe for modification).

**Drag the rank slider** to see which singular values are protected (red=critical) vs. available for updates (green=safe).

> *The larger the singular value, the more "grin-like" -- essential to the model's identity. Smaller values can fade without losing who the cat really is.*

In [None]:
# Interactive SVD Explorer
rank_slider = widgets.IntSlider(value=2, min=1, max=6, step=1,
    description='Rank cutoff:', layout=widgets.Layout(width='400px'),
    style={'description_width': '100px'})
button_svd = widgets.Button(description='Regenerate', layout=widgets.Layout(width='150px'))
output_svd = widgets.Output()

svd_svals = None

def regenerate_svd(_=None):
    global svd_svals
    base = np.array([3.5, 2.1, 1.3, 0.7, 0.25, 0.05])
    svd_svals = np.sort(base + np.random.uniform(-0.3, 0.3, size=6))[::-1]
    draw_svd()

def draw_svd(change=None):
    if svd_svals is None:
        return
    cutoff = rank_slider.value
    with output_svd:
        clear_output(wait=True)
        n = len(svd_svals)
        colors = ['#f87171' if i < cutoff else '#4ade80' for i in range(n)]

        fig, ax = plt.subplots(figsize=(9, 3.8))
        bars = ax.bar(range(n), svd_svals, color=colors, width=0.6, edgecolor='none')
        for i, (bar, s) in enumerate(zip(bars, svd_svals)):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.08,
                    f'{s:.3f}', ha='center', va='bottom', fontsize=9, color='#e4e4e7')

        ax.axvline(x=cutoff - 0.5, color='#e4e4e7', linestyle='--', linewidth=1.2, alpha=0.6)
        ax.text(cutoff - 0.5, max(svd_svals) * 1.05, '<-- critical | safe -->',
                ha='center', fontsize=8, color='#8888a0')
        ax.set_xticks(range(n))
        ax.set_xticklabels([f's{i+1}' for i in range(n)])
        ax.set_ylabel('Singular Value')
        ax.set_title('SVD: Critical vs Safe Directions', fontsize=11, fontweight='bold')
        ax.grid(axis='y', alpha=0.2)

        crit_patch = mpatches.Patch(color='#f87171', label=f'Critical (top {cutoff})')
        safe_patch = mpatches.Patch(color='#4ade80', label=f'Safe ({n-cutoff} remaining)')
        ax.legend(handles=[crit_patch, safe_patch], fontsize=8, loc='upper right')
        plt.tight_layout()
        plt.show()

        for i, s in enumerate(svd_svals):
            tag = "[CRITICAL]" if i < cutoff else "[SAFE]    "
            print(f"  s{i+1} = {s:.4f}  {tag}")

rank_slider.observe(draw_svd, names='value')
button_svd.on_click(regenerate_svd)
display(widgets.HBox([rank_slider, button_svd]), output_svd)
regenerate_svd()
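
Why is it reasonable to call the small singular values "safe"? Dropping every singular value below the cutoff changes the matrix by exactly the root-sum-square of the dropped values (in Frobenius norm), and by the Eckart-Young theorem no other matrix of that rank does better. The sketch below checks this on a random matrix whose spectrum is borrowed from the `base` values in the explorer above (an illustrative assumption, not data from a real model).

```python
import numpy as np

rng = np.random.default_rng(2)
n = 6
# Build W = U diag(s) V^T with a known, decaying spectrum
U, _ = np.linalg.qr(rng.standard_normal((n, n)))
V, _ = np.linalg.qr(rng.standard_normal((n, n)))
s = np.array([3.5, 2.1, 1.3, 0.7, 0.25, 0.05])
W = U @ np.diag(s) @ V.T

errors, predictions = [], []
for cutoff in range(1, n + 1):
    # Keep only the top-`cutoff` singular directions
    W_r = U[:, :cutoff] @ np.diag(s[:cutoff]) @ V[:, :cutoff].T
    errors.append(np.linalg.norm(W - W_r))               # Frobenius norm of what was dropped
    predictions.append(np.sqrt(np.sum(s[cutoff:] ** 2)))
    kept = np.sum(s[:cutoff] ** 2) / np.sum(s ** 2)
    print(f"cutoff={cutoff}: error={errors[-1]:.4f} "
          f"(predicted {predictions[-1]:.4f}), energy kept={kept:.1%}")
```

With this spectrum, the top two directions already carry most of the squared-singular-value energy, which is the intuition behind protecting them and leaving the tail free for updates.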

---
<a id="7"></a>
## 🧭 [7] Finding Your Way: Gradient Projection -- Interactive

*"Would you tell me, please, which way I ought to go from here?" asked Alice.*

The **critical direction** (red dashed) must not be disturbed. The **gray arrow** is where vanilla training would step. OSFT projects that gradient to produce the **safe update** (green) -- orthogonal to the critical direction.

**Drag the sliders** to change the gradient direction and watch the projection update live.

In [None]:
# Interactive Gradient Projection Demo
gx_slider = widgets.FloatSlider(value=0.5, min=-2, max=2, step=0.1,
    description='Gradient X:', layout=widgets.Layout(width='400px'),
    style={'description_width': '90px'})
gy_slider = widgets.FloatSlider(value=0.8, min=-2, max=2, step=0.1,
    description='Gradient Y:', layout=widgets.Layout(width='400px'),
    style={'description_width': '90px'})
output_proj = widgets.Output()

crit_raw = np.array([1.0, 0.5])
crit_n = crit_raw / np.linalg.norm(crit_raw)

def draw_projection(change=None):
    gx, gy = gx_slider.value, gy_slider.value
    grad = np.array([gx, gy])
    dot = np.dot(grad, crit_n)
    proj = dot * crit_n
    safe = grad - proj

    with output_proj:
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(7, 6))
        lim = 3
        ax.set_xlim(-lim, lim); ax.set_ylim(-lim, lim)
        ax.set_aspect('equal'); ax.grid(True, alpha=0.2)
        ax.axhline(0, color='#2a2a3a', linewidth=0.8)
        ax.axvline(0, color='#2a2a3a', linewidth=0.8)

        t = lim * 1.5
        ax.plot([-crit_n[0]*t, crit_n[0]*t], [-crit_n[1]*t, crit_n[1]*t],
                '--', color='#f87171', linewidth=1.5, alpha=0.6, label='Critical Direction')

        if np.linalg.norm(grad) > 0.05:
            ax.annotate('', xy=grad, xytext=(0,0),
                arrowprops=dict(arrowstyle='->', color='#8888a0', lw=2.2))
            ax.text(grad[0]+0.1, grad[1]+0.1, 'Original', fontsize=8, color='#8888a0')
        if np.linalg.norm(proj) > 0.05:
            ax.annotate('', xy=proj, xytext=(0,0),
                arrowprops=dict(arrowstyle='->', color='#fbbf24', lw=1.5))
            ax.text(proj[0]+0.1, proj[1]-0.15, 'Projection', fontsize=8, color='#fbbf24')
        if np.linalg.norm(safe) > 0.05:
            ax.annotate('', xy=safe, xytext=(0,0),
                arrowprops=dict(arrowstyle='->', color='#4ade80', lw=2.8))
            ax.text(safe[0]+0.1, safe[1]+0.1, 'OSFT Update', fontsize=8, color='#4ade80',
                    fontweight='bold')
        if np.linalg.norm(proj) > 0.05:
            ax.plot([grad[0], safe[0]], [grad[1], safe[1]], ':', color='#fbbf24', alpha=0.4)

        ax.set_title('Gradient Projection -- OSFT in Action', fontsize=11, fontweight='bold')
        ax.legend(fontsize=8, loc='upper left')
        plt.tight_layout()
        plt.show()

        orig_norm = np.linalg.norm(grad)
        safe_norm = np.linalg.norm(safe)
        reduction = ((orig_norm - safe_norm) / orig_norm * 100) if orig_norm > 0 else 0
        print(f"  Original Gradient:         [{gx:.2f}, {gy:.2f}]")
        print(f"  Projection onto Critical:  [{proj[0]:.2f}, {proj[1]:.2f}]")
        print(f"  OSFT Update (safe):        [{safe[0]:.2f}, {safe[1]:.2f}]")
        print(f"  Magnitude Reduction:       {reduction:.1f}%")
        print(f"")
        print(f"  The green arrow = where OSFT directs the update,")
        print(f"  safely away from critical learned features.")

gx_slider.observe(draw_projection, names='value')
gy_slider.observe(draw_projection, names='value')
display(gx_slider, gy_slider, output_proj)
draw_projection()
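
The demo uses a single critical direction in 2-D. The same projection extends to any number of critical directions in any dimension: stack them as orthonormal columns `C` and compute `g_safe = g - C (C^T g)`. In the sketch below the critical directions are random placeholders (an assumption; in OSFT they would come from an SVD), and the checks confirm that the safe update has no component along any critical direction and that projecting twice changes nothing.

```python
import numpy as np

def project_out(g, C):
    """Remove from g its component in span(C); C must have orthonormal columns."""
    return g - C @ (C.T @ g)

rng = np.random.default_rng(3)
dim, n_critical = 10, 3
C, _ = np.linalg.qr(rng.standard_normal((dim, n_critical)))   # orthonormal critical directions
g = rng.standard_normal(dim)                                  # raw gradient

g_safe = project_out(g, C)
print("components along critical dirs:", np.round(C.T @ g_safe, 12))
print("idempotent:", np.allclose(project_out(g_safe, C), g_safe))
print(f"norm kept: {np.linalg.norm(g_safe) / np.linalg.norm(g):.1%}")
```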

---
<a id="8"></a>
## 🧪 [8] Hands-On: NumPy Warm-Up

Before touching PyTorch, let's warm up. We take a toy weight matrix, run SVD to split it into important vs. not-so-important directions, and project a gradient update into the "safe" zone.

In [None]:
# A toy weight matrix (e.g. from a linear layer)
W = np.array([[2.0, 0.5, 0.0],
              [0.0, 1.5, 0.1],
              [0.0, 0.0, 0.2]])
print("Original weight matrix W:")
print(W)

# Perform SVD decomposition
U, S, Vt = np.linalg.svd(W)
print("\nSingular values:", S)

# Define high-rank vs low-rank subspaces
rank_cutoff = 1  # keep top-1 singular vector as important
U_high = U[:, :rank_cutoff]
V_high = Vt[:rank_cutoff, :].T

# Any gradient update
grad = np.array([[0.1, -0.2, 0.05],
                 [0.05, 0.1, -0.1],
                 [-0.2, 0.0, 0.2]])

# Remove the gradient component inside the critical block span{u_i v_j^T}
# (afterwards U_high.T @ proj @ V_high is zero)
proj = grad - U_high @ (U_high.T @ grad @ V_high) @ V_high.T

print("\nOriginal gradient:")
print(grad)
print("\nProjected gradient (OSFT):")
print(proj)
print("\nThe projected gradient steers clear of the critical directions.")
print("That's OSFT on training wheels.")
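
It is worth being precise about what that projection removes. It subtracts `U_high U_high^T G V_high V_high^T`, the part of the gradient inside the critical block, so `U_high.T @ proj @ V_high` is zero; the result can still have components along the critical left direction (paired with non-critical right directions) and vice versa. A stricter variant, `(I - U_high U_high^T) G (I - V_high V_high^T)`, removes those cross terms too. The sketch below compares both on the same toy `W` and `grad`; which variant a given OSFT implementation uses is an assumption here, not something this notebook specifies.

```python
import numpy as np

W = np.array([[2.0, 0.5, 0.0],
              [0.0, 1.5, 0.1],
              [0.0, 0.0, 0.2]])
grad = np.array([[0.1, -0.2, 0.05],
                 [0.05, 0.1, -0.1],
                 [-0.2, 0.0, 0.2]])

U, S, Vt = np.linalg.svd(W)
r = 1
U_high, V_high = U[:, :r], Vt[:r, :].T
P_U = U_high @ U_high.T                  # projector onto critical left directions
P_V = V_high @ V_high.T                  # projector onto critical right directions
I = np.eye(3)

block_proj = grad - P_U @ grad @ P_V     # same as the projection in the cell above
strict_proj = (I - P_U) @ grad @ (I - P_V)

for name, P in [("block", block_proj), ("strict", strict_proj)]:
    print(f"{name:6s} U_high^T P V_high = {(U_high.T @ P @ V_high).item():+.2e}   "
          f"||U_high^T P|| = {np.linalg.norm(U_high.T @ P):.3f}   "
          f"||P V_high|| = {np.linalg.norm(P @ V_high):.3f}")
```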

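### Training Loop Sketch

A runnable PyTorch version of the loop, assuming every adapted layer is an `nn.Linear` and using a fixed `rank_cutoff` as a stand-in for an importance-based retention ratio:

```python
import torch
import torch.nn as nn

def osft_step(model: nn.Module, loss: torch.Tensor, lr: float = 0.1, rank_cutoff: int = 1) -> None:
    """One SGD step where each nn.Linear gradient is projected away from its top-r block.
    rank_cutoff is a fixed stand-in for an importance-based retention_ratio(layer)."""
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        for layer in model.modules():
            if not isinstance(layer, nn.Linear) or layer.weight.grad is None:
                continue
            W, grad = layer.weight, layer.weight.grad
            U, S, Vt = torch.linalg.svd(W, full_matrices=False)  # recomputed every step
            r = min(rank_cutoff, S.numel())
            U_high, V_high = U[:, :r], Vt[:r, :].T
            grad_proj = grad - U_high @ (U_high.T @ grad @ V_high) @ V_high.T
            W.sub_(lr * grad_proj)  # plain SGD update with the projected gradient

# Usage: call once per training step, e.g.
# for step in range(100):
#     osft_step(model, nn.functional.mse_loss(model(x), y))
```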

---
<a id="9"></a>
## ⚡ [9] Hands-On: PyTorch OSFT

Time to level up. Here's a small PyTorch model with OSFT-style gradient projection applied. The model takes a step -- but only in the directions we allow.

> *That's how OSFT threads the needle: new learning without wrecking the old foundation.*

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 3, bias=False)
    def forward(self, x):
        return self.fc(x)

model = SmallNet()
x = torch.randn(5, 3)
y = torch.randn(5, 3)

opt = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
output = model(x)
loss = loss_fn(output, y)
loss.backward()

# Inspect original gradient
grad = model.fc.weight.grad.detach().numpy()
print("Original gradient:")
print(grad)

# SVD on weights
W = model.fc.weight.detach().numpy()
U, S, Vt = np.linalg.svd(W)
print(f"\nSingular values of W: {S}")

# Keep top-1 singular vector as "critical"
rank_cutoff = 1
U_high = U[:, :rank_cutoff]
V_high = Vt[:rank_cutoff, :].T

# Project gradient into safe subspace
grad_proj = grad - U_high @ (U_high.T @ grad @ V_high) @ V_high.T
model.fc.weight.grad = torch.from_numpy(grad_proj).float()
opt.step()

print("\nProjected gradient (OSFT):")
print(grad_proj)
print("\nUpdated weights (after OSFT projection):")
print(model.fc.weight.data.numpy())
print("\nThe model stepped -- but only in the directions we allowed.")
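
A quick way to confirm the step behaved as intended, sketched as a standalone snippet that repeats the cell above with a fixed seed and a manual SGD step (the seed and sizes are arbitrary choices): because the projected gradient has no component in the critical block, `u_1^T W v_1`, which equals the top singular value before the step, is left untouched by the update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 3, bias=False)
x, y = torch.randn(5, 3), torch.randn(5, 3)
lr, rank_cutoff = 0.1, 1

loss = nn.functional.mse_loss(layer(x), y)
loss.backward()

with torch.no_grad():
    W_before = layer.weight.clone()
    U, S, Vh = torch.linalg.svd(W_before)
    U_high, V_high = U[:, :rank_cutoff], Vh[:rank_cutoff, :].T
    G = layer.weight.grad
    G_proj = G - U_high @ (U_high.T @ G @ V_high) @ V_high.T
    layer.weight.sub_(lr * G_proj)           # manual SGD step with the projected gradient

    before = (U_high.T @ W_before @ V_high).item()
    after = (U_high.T @ layer.weight @ V_high).item()
print(f"top singular value: {S[0].item():.6f}")
print(f"u1^T W v1 before:   {before:.6f}")
print(f"u1^T W v1 after:    {after:.6f}")
```

The other singular values and vectors are free to move, which is exactly the room OSFT leaves for new learning.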

---
<a id="10"></a>
## 🧮 [10] OFT vs LoRA -- Parameter Efficiency Calculator

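**Full Fine-Tuning** rewrites every parameter. **LoRA** needs `rank x (input + output)` parameters. For **OFT**, the calculator uses a deliberately simplified count of `rank^2`; the exact number depends on how the orthogonal matrix R is structured and parameterized. Adjust the values below to see the savings.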

In [None]:
# Interactive Parameter Calculator
in_w = widgets.IntText(value=512, description='Input features:', style={'description_width': '120px'})
out_w = widgets.IntText(value=256, description='Output features:', style={'description_width': '120px'})
rank_w = widgets.IntText(value=16, description='Rank:', style={'description_width': '120px'})
output_calc = widgets.Output()

def calc_params(change=None):
    i, o, r = in_w.value, out_w.value, rank_w.value
    total = i * o
    oft = r * r
    lora = r * (i + o)

    with output_calc:
        clear_output(wait=True)
        methods = [
            ('Full Fine-Tuning', total, '#f87171'),
            ('LoRA',             lora,  '#fbbf24'),
            ('OFT',              oft,   '#60a5fa'),
        ]

        fig, ax = plt.subplots(figsize=(9, 2.8))
        names = [m[0] for m in methods]
        counts = [m[1] for m in methods]
        bar_colors = [m[2] for m in methods]

        bars = ax.barh(names, counts, color=bar_colors, height=0.5, edgecolor='none')
        ax.set_xscale('log')
        ax.set_xlabel('Parameters (log scale)')
        ax.set_title('Parameter Count Comparison', fontsize=11, fontweight='bold')
        ax.grid(axis='x', alpha=0.2)

        for bar, count in zip(bars, counts):
            pct = count / total * 100
            ax.text(count * 1.5, bar.get_y() + bar.get_height()/2,
                    f'{count:,}  ({pct:.2f}%)', va='center', fontsize=9, color='#e4e4e7')

        plt.tight_layout()
        plt.show()

        print(f"  Layer: {i} x {o} = {total:,} total parameters\n")
        for name, count, _ in methods:
            pct = count / total * 100
            print(f"  {name:20s}: {count:>10,}  ({pct:>7.2f}%)")
        print(f"\n  OFT Parameter Reduction: {(1 - oft/total)*100:.2f}%")

for w in [in_w, out_w, rank_w]:
    w.observe(calc_params, names='value')
display(widgets.HBox([in_w, out_w, rank_w]), output_calc)
calc_params()
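
The calculator's `rank^2` figure matches the special case of one dense `rank x rank` block reused along the whole diagonal of R. The OFT paper uses a block-diagonal R, so the real count depends on the partition. The sketch below assumes R acts on the input dimension (as in `W' = W R`), is split into equal blocks of size `block_size`, and that each block is either counted densely or Cayley-parameterized from a skew-symmetric matrix with `b(b-1)/2` free entries; `oft_param_counts` is a hypothetical helper, not a library function.

```python
def oft_param_counts(dim, block_size):
    """Hypothetical helper: trainable-parameter counts for a block-diagonal
    orthogonal R of shape (dim, dim) built from equal blocks of size block_size."""
    if dim % block_size != 0:
        raise ValueError("dim must be divisible by block_size")
    n_blocks = dim // block_size
    dense_block = block_size * block_size              # all entries of one block
    skew_block = block_size * (block_size - 1) // 2    # free entries of a skew-symmetric block
    return {
        "n_blocks": n_blocks,
        "independent blocks, dense": n_blocks * dense_block,
        "independent blocks, skew (Cayley)": n_blocks * skew_block,
        "one shared block, dense": dense_block,
        "one shared block, skew (Cayley)": skew_block,
    }

in_features, out_features, lora_rank = 512, 256, 16   # same defaults as the calculator
counts = oft_param_counts(in_features, block_size=16)
print(f"Full fine-tuning: {in_features * out_features:,}")
print(f"LoRA (rank {lora_rank}): {lora_rank * (in_features + out_features):,}")
for name, value in counts.items():
    print(f"OFT {name}: {value:,}")
```

With these defaults, even the most generous OFT variant (32 independent dense blocks, 8,192 parameters) stays below LoRA's 12,288 and far below the 131,072 of full fine-tuning.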

### Side-by-Side Comparison

| | OFT | LoRA | OSFT |
|---|---|---|---|
| **Approach** | Orthogonal rotations | Low-rank additive updates | Orthogonal + subspace protection |
| **Preserves geometry** | Yes | No | Yes |
| **Catastrophic forgetting** | Reduced | Possible | Greatly reduced |
| **Best for** | Domain adaptation, few-shot | General LLM fine-tuning | Continual learning, multi-task |
| **Parameters** | rank^2 | rank x (in + out) | rank^2 + SVD overhead |

---
<a id="11"></a>
## 📈 [11] Training Dynamics Comparison

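This is a stylized simulation: the curves are hand-crafted functions chosen to illustrate the qualitative story, not results from real training runs. **Full Fine-Tuning** (red) tracks the baseline for 60 steps and then declines as catastrophic forgetting sets in, while **OSFT** (green) maintains steady, superior performance throughout.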

In [None]:
# Training Dynamics Simulation
button_train = widgets.Button(description='Run Simulation', layout=widgets.Layout(width='170px'),
                              button_style='success')
button_reset = widgets.Button(description='Reset', layout=widgets.Layout(width='100px'))
output_train = widgets.Output()

def run_training(_=None):
    with output_train:
        clear_output(wait=True)
        max_steps = 100
        steps = np.arange(max_steps)

        def gen(s, method):
            base = 0.5 + (s / max_steps) * 0.4
            if method == 'oft':  return base + np.sin(s/10) * 0.05
            if method == 'lora': return base + np.sin(s/8) * 0.06 - 0.05
            if method == 'osft': return base + np.sin(s/12) * 0.03 + 0.05
            if method == 'full': return base if s < 60 else base - (s-60)/100
            return base

        methods_cfg = {
            'OFT':     ('#60a5fa', 'oft'),
            'LoRA':    ('#fbbf24', 'lora'),
            'OSFT':    ('#4ade80', 'osft'),
            'Full FT': ('#f87171', 'full'),
        }

        fig, ax = plt.subplots(figsize=(10, 4.5))
        for name, (color, key) in methods_cfg.items():
            data = [gen(s, key) for s in steps]
            ax.plot(steps, data, color=color, linewidth=2.2, label=name)

        ax.set_xlabel('Training Steps')
        ax.set_ylabel('Performance')
        ax.set_title('Training Dynamics -- Fine-Tuning Methods Compared', fontsize=11, fontweight='bold')
        ax.legend(fontsize=9, loc='lower right')
        ax.grid(True, alpha=0.2)
        ax.set_ylim(0.3, 1.05)

        ax.annotate('<-- catastrophic\n    forgetting', xy=(75, gen(75, 'full')),
                    xytext=(82, 0.55), fontsize=8, color='#f87171',
                    arrowprops=dict(arrowstyle='->', color='#f87171', lw=1.2))
        plt.tight_layout()
        plt.show()

        print("  Final simulated performance after 100 steps:")
        for name, (_, key) in methods_cfg.items():
            val = gen(99, key)
            print(f"     {name:10s}: {val:.3f}")

def reset_training(_=None):
    with output_train:
        clear_output(wait=True)
        print("  Ready -- click Run Simulation")

button_train.on_click(run_training)
button_reset.on_click(reset_training)
display(widgets.HBox([button_train, button_reset]), output_train)
print("  Ready -- click Run Simulation")

---
<a id="12"></a>
## 🌍 [12] Real-World Applications

| Use Case | Technique | Description |
|---|---|---|
| **Domain Adaptation** | OFT | Fine-tune to related domains without losing general knowledge |
| **Few-Shot Learning** | OFT | Adapt with limited data while maintaining robustness |
| **Continual Learning** | OSFT | Learn new tasks sequentially without forgetting |
| **Enterprise Chatbots** | OSFT | Add new product knowledge without erasing old FAQs |
| **Medical AI** | OSFT | Stay current with research without forgetting fundamentals |
| **Legal Models** | OSFT | Incorporate new regulations while maintaining legal knowledge |

### Benchmark Results
- **Text classification sequences:** Keeps performance steady across 5, 10, 15+ tasks
- **TRACE benchmark:** Boosted LLaMA-2-7B's accuracy by ~7 points over O-LoRA

---
<a id="13"></a>
## 🎪 [13] Lessons from Wonderland: Wrap-Up

### The Looking Glass Principle (OFT)
Like spinning a mirror rather than cracking it, orthogonal transformations rotate the feature space without distortion. Alice stays Alice, just viewed from a new angle.

### The Cheshire Grin Strategy (OSFT)
Keep the grin (critical directions), let the body fade (safe subspace). OSFT identifies what's essential and protects it, sharply reducing catastrophic forgetting.

### The Mad Hatter's Efficiency
90%+ parameter reduction! Fewer parameters to tune, but the tea party keeps its charm.

### The Queen's Stability
Because an orthogonal R preserves norms, the adapter cannot amplify or shrink the signals passing through it, which helps keep training stable. Off with gradient chaos!

Think of it like renovating a house. **OSFT adds a new room without tearing down the walls that are already holding up the place.**

---

> *"Curiouser and curiouser!" cried Alice. And indeed -- the deeper you go down this orthogonal rabbit hole, the more elegant the mathematics becomes.*

---

*Created by **Frank La Vigne** -- Interactive Notebook Edition*

*"It's no use going back to yesterday, because I was a different person then." -- But with OSFT, your model remembers who it was!*