# 🔄 Orthogonal Fine-Tuning (OFT) Tutorial

## Through the Looking Glass: An Introduction

Just as Alice stepped through the looking glass and found a world that was familiar yet rearranged, Orthogonal Fine-Tuning (OFT) adapts pre-trained models by **rotating** their learned representations rather than distorting them.

OFT is a parameter-efficient fine-tuning (PEFT) technique that applies **orthogonal transformations** to weight matrices: think of it as spinning the looking glass to see your model's knowledge from a new perspective, without warping the reflection. Unlike additive PEFT methods such as LoRA, OFT preserves the hyperspherical energy of each adapted layer, a measure of the pairwise angular relationships between its neurons (weight vectors), which is intended to keep adaptation stable.

## 1. 🔮 The Magic Mirror Properties: Understanding Orthogonal Matrices

An orthogonal matrix Q is like a magic mirror with special properties. It satisfies: **Q<sup>T</sup>Q = QQ<sup>T</sup> = I**

### Key properties of the looking glass:
- **Preserves distances:** Alice stays the same size (||Qx|| = ||x||)
- **Preserves angles:** The Cheshire Cat's grin keeps its shape
- **Transpose is the inverse:** The mirror can always be turned back (Q<sup>T</sup>Q = I, so Q<sup>-1</sup> = Q<sup>T</sup>)
- **Represents rotations and reflections:** Turn the mirror, don't bend it
- **Determinant is ±1:** The mirror's magic constant
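
The properties in this list can be checked numerically. The snippet below is a minimal sketch using NumPy only; the helper name `angle` and the random seed are illustrative choices, not part of the tutorial's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Build an orthogonal matrix from the QR decomposition of a random matrix.
Q, _ = np.linalg.qr(rng.standard_normal((4, 4)))

x = rng.standard_normal(4)
y = rng.standard_normal(4)

def angle(u, v):
    """Angle in radians between two vectors (hypothetical helper)."""
    cos = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.arccos(np.clip(cos, -1.0, 1.0))

norm_preserved = np.isclose(np.linalg.norm(Q @ x), np.linalg.norm(x))
angle_preserved = np.isclose(angle(Q @ x, Q @ y), angle(x, y))
inverse_is_transpose = np.allclose(np.linalg.inv(Q), Q.T)

print("Norm preserved:      ", norm_preserved)
print("Angle preserved:     ", angle_preserved)
print("Inverse is transpose:", inverse_is_transpose)
```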

In [10]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.stats import special_ortho_group

# Generate a random orthogonal matrix
def generate_orthogonal_matrix(dim):
    """Generate a random orthogonal matrix using QR decomposition"""
    random_matrix = np.random.randn(dim, dim)
    q, r = np.linalg.qr(random_matrix)
    return q

# Verify orthogonality
Q = generate_orthogonal_matrix(4)
print("Random orthogonal matrix Q:")
print(Q)
print("\nQ^T @ Q (should be identity):")
print(np.round(Q.T @ Q, 4))
print("\nDeterminant (should be ±1):", np.linalg.det(Q))

Random orthogonal matrix Q:
[[-0.60507124  0.67209757  0.38633691  0.18143165]
 [-0.24292475 -0.64775666  0.7198619   0.05654843]
 [-0.74974867 -0.30378398 -0.50238365 -0.30529147]
 [-0.11293046 -0.19081788 -0.28311358  0.9331034 ]]

Q^T @ Q (should be identity):
[[ 1.  0.  0.  0.]
 [ 0.  1. -0.  0.]
 [ 0. -0.  1. -0.]
 [ 0.  0. -0.  1.]]

Determinant (should be ±1): -0.9999999999999996
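
The determinant printed above is -1, which means this particular Q includes a reflection rather than being a pure rotation. If a pure rotation (determinant +1) is wanted, one simple option is to flip the sign of a single column. The sketch below assumes a hypothetical helper named `random_rotation`; it is not part of the tutorial's code.

```python
import numpy as np

def random_rotation(dim, rng):
    """Random orthogonal matrix forced to have determinant +1 (hypothetical helper)."""
    q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flipping one column flips the sign of the determinant
    return q

rng = np.random.default_rng(0)
R = random_rotation(4, rng)

print("det(R):", np.round(np.linalg.det(R), 6))
print("Still orthogonal:", np.allclose(R.T @ R, np.eye(4)))
```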


## 2. 🪞 Visualizing Orthogonal Transformations

Let's visualize how orthogonal transformations preserve geometric relationships while rotating the feature space.

Watch how the points rotate together, like Wonderland itself spinning—everything moves, but relationships stay true. The mirror spins, but nothing warps!

In [None]:
# Create sample 2D data
np.random.seed(42)
n_points = 100
original_data = np.random.randn(n_points, 2)

# Create a 2D rotation matrix (orthogonal)
theta = np.pi / 4  # 45 degrees
rotation_matrix = np.array([
    [np.cos(theta), -np.sin(theta)],
    [np.sin(theta), np.cos(theta)]
])

# Apply orthogonal transformation
transformed_data = original_data @ rotation_matrix.T

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original data
axes[0].scatter(original_data[:, 0], original_data[:, 1], alpha=0.6)
axes[0].set_title('Original Data')
axes[0].set_xlim(-4, 4)
axes[0].set_ylim(-4, 4)
axes[0].grid(True, alpha=0.3)
axes[0].set_aspect('equal')

# Transformed data
axes[1].scatter(transformed_data[:, 0], transformed_data[:, 1], alpha=0.6, color='orange')
axes[1].set_title('After Orthogonal Transformation')
axes[1].set_xlim(-4, 4)
axes[1].set_ylim(-4, 4)
axes[1].grid(True, alpha=0.3)
axes[1].set_aspect('equal')

# Overlay comparison
axes[2].scatter(original_data[:, 0], original_data[:, 1], alpha=0.6, label='Original')
axes[2].scatter(transformed_data[:, 0], transformed_data[:, 1], alpha=0.6, color='orange', label='Transformed')
axes[2].set_title('Comparison: Rotation Preserves Structure')
axes[2].set_xlim(-4, 4)
axes[2].set_ylim(-4, 4)
axes[2].grid(True, alpha=0.3)
axes[2].set_aspect('equal')
axes[2].legend()

plt.tight_layout()
plt.show()

# Verify preservation of distances
original_distances = np.linalg.norm(original_data[0] - original_data[1])
transformed_distances = np.linalg.norm(transformed_data[0] - transformed_data[1])
print(f"Distance between first two points:")
print(f"  Original: {original_distances:.4f}")
print(f"  Transformed: {transformed_distances:.4f}")
print(f"  Preserved: {np.isclose(original_distances, transformed_distances)}")
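
The check above compares a single pair of points. As a sketch of a fuller check, the block below rebuilds a similar 2D dataset (its own seed and generator, so the points differ from the cell above) and compares every pairwise distance before and after the same 45 degree rotation. The helper name `pairwise_distances` is an illustrative choice.

```python
import numpy as np

rng = np.random.default_rng(42)
points = rng.standard_normal((100, 2))

theta = np.pi / 4  # 45 degrees
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
rotated = points @ rotation.T

def pairwise_distances(data):
    """Matrix of Euclidean distances between every pair of rows."""
    diffs = data[:, None, :] - data[None, :, :]
    return np.linalg.norm(diffs, axis=-1)

max_change = np.abs(pairwise_distances(points) - pairwise_distances(rotated)).max()
print(f"Largest change in any pairwise distance: {max_change:.2e}")
```

Any difference that shows up should be at the level of floating-point rounding.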

## 3. 🎭 The Looking Glass Formula: OFT Core Implementation

The key insight of OFT is to parameterize weight updates as:

**W' = W × R**

Where:
- **W** is the original pre-trained weight matrix (the scene in the mirror)
- **R** is an orthogonal matrix learned during fine-tuning (the rotation of the looking glass)
- **W'** is the adapted weight matrix (the new view through the rotated mirror)

The orthogonal mirror **R** rotates (not reshapes) your model's knowledge, preserving all the geometric relationships that were learned during pre-training.
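
To see what "rotates, not reshapes" means for the formula above, treat each row of W as one neuron and look at the Gram matrix W W<sup>T</sup>, which holds every neuron-to-neuron inner product. Since R R<sup>T</sup> = I, we get (WR)(WR)<sup>T</sup> = W W<sup>T</sup>. The sketch below uses small random matrices as stand-ins for pre-trained weights; the sizes are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
out_features, in_features = 6, 8

W = rng.standard_normal((out_features, in_features))   # stand-in for pre-trained weights
R, _ = np.linalg.qr(rng.standard_normal((in_features, in_features)))  # orthogonal

W_adapted = W @ R

# Each row of W is one neuron; the Gram matrix holds all neuron-to-neuron inner products.
gram_before = W @ W.T
gram_after = W_adapted @ W_adapted.T

weights_changed = not np.allclose(W, W_adapted)
gram_unchanged = np.allclose(gram_before, gram_after)

print("Weights changed:      ", weights_changed)
print("Gram matrix unchanged:", gram_unchanged)
```

The individual weights move, but the norms of the neurons and the angles between them do not.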

In [None]:
class OFTLayer(nn.Module):
    """
    Orthogonal Fine-Tuning layer implementation.
    Applies an orthogonal transformation to the weight matrix.
    """
    def __init__(self, in_features, out_features, rank=16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = min(rank, min(in_features, out_features))
        
        # Pre-trained weights (frozen)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.weight.requires_grad = False
        
        # Orthogonal transformation parameters
        # We use Cayley parameterization for stable orthogonal updates
        self.cayley_map = nn.Parameter(torch.zeros(self.rank, self.rank))
        
    def get_orthogonal_matrix(self):
        """
        Compute orthogonal matrix using Cayley transform:
        R = (I + A)^(-1) (I - A)
        where A is skew-symmetric. The two factors commute, so this
        equals the other common form (I - A)(I + A)^(-1).
        """
        # Make skew-symmetric
        A = self.cayley_map - self.cayley_map.t()
        I = torch.eye(self.rank, device=A.device)
        
        # Cayley transform
        R = torch.linalg.solve(I + A, I - A)
        return R
    
    def forward(self, x):
        # Get orthogonal transformation
        R = self.get_orthogonal_matrix()
        
        # For simplicity, apply R to a subspace of the weight matrix
        # In practice, this would be more sophisticated
        W_adapted = self.weight.clone()
        W_adapted[:self.rank, :self.rank] = self.weight[:self.rank, :self.rank] @ R
        
        return x @ W_adapted.t()

# Example usage
oft_layer = OFTLayer(128, 64, rank=16)
input_tensor = torch.randn(32, 128)  # batch_size=32, in_features=128
output = oft_layer(input_tensor)
print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output.shape}")
print(f"Trainable parameters: {sum(p.numel() for p in oft_layer.parameters() if p.requires_grad)}")
print(f"Total parameters: {sum(p.numel() for p in oft_layer.parameters())}")
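
The Cayley step in `get_orthogonal_matrix` can be checked in isolation. The NumPy sketch below (the function name `cayley` is an illustrative choice) confirms three things: a zero seed gives R = I, so the layer starts out identical to the pre-trained weights; a nonzero seed gives an orthogonal R with determinant +1; and the two ways of ordering the factors agree.

```python
import numpy as np

def cayley(seed):
    """Cayley transform of the skew-symmetric part of `seed`."""
    A = seed - seed.T                      # skew-symmetric: A.T == -A
    I = np.eye(seed.shape[0])
    return np.linalg.solve(I + A, I - A)   # (I + A)^{-1} (I - A)

rng = np.random.default_rng(0)
seed = 0.1 * rng.standard_normal((4, 4))
R = cayley(seed)

A = seed - seed.T
I = np.eye(4)
other_order = (I - A) @ np.linalg.inv(I + A)   # (I - A)(I + A)^{-1}

print("Zero seed gives identity:", np.allclose(cayley(np.zeros((4, 4))), I))
print("R is orthogonal:         ", np.allclose(R.T @ R, I))
print("det(R):                  ", np.round(np.linalg.det(R), 6))
print("Both factor orders agree:", np.allclose(R, other_order))
```

One consequence of this parameterization is that it only reaches rotations with determinant +1, never reflections. Also, although the seed has `rank * rank` entries, only its skew-symmetric part matters, which leaves `rank * (rank - 1) / 2` independent degrees of freedom.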

## 4. 📊 Comparing OFT with LoRA

Let's implement a simple comparison between OFT and LoRA to understand their differences.

While both methods achieve parameter efficiency, they work in fundamentally different ways:
- **OFT** rotates the feature space (multiplicative, preserves geometry)
- **LoRA** adds low-rank updates (additive, more flexible but can distort relationships)
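
One concrete way to see the difference between the two bullets is through the singular values of the weight matrix. A multiplicative orthogonal update leaves them untouched, while an additive low-rank update generally shifts them. The sketch below uses random stand-in matrices and deliberately large LoRA factors to make the effect visible; a trained LoRA update would typically be much smaller.

```python
import numpy as np

rng = np.random.default_rng(0)
out_features, in_features, rank = 6, 8, 2

W = rng.standard_normal((out_features, in_features))   # stand-in for pre-trained weights
R, _ = np.linalg.qr(rng.standard_normal((in_features, in_features)))
lora_B = rng.standard_normal((out_features, rank))     # illustrative, not trained
lora_A = rng.standard_normal((rank, in_features))      # illustrative, not trained

W_oft = W @ R                   # multiplicative, orthogonal
W_lora = W + lora_B @ lora_A    # additive, low-rank

def singular_values(M):
    """Singular values of M in descending order."""
    return np.linalg.svd(M, compute_uv=False)

oft_keeps_spectrum = np.allclose(singular_values(W), singular_values(W_oft))
lora_keeps_spectrum = np.allclose(singular_values(W), singular_values(W_lora))

print("OFT keeps singular values: ", oft_keeps_spectrum)
print("LoRA keeps singular values:", lora_keeps_spectrum)
```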

In [None]:
class LoRALayer(nn.Module):
    """
    Low-Rank Adaptation (LoRA) layer for comparison.
    """
    def __init__(self, in_features, out_features, rank=16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        
        # Pre-trained weights (frozen)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.weight.requires_grad = False
        
        # LoRA parameters
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        
    def forward(self, x):
        # Original transformation + low-rank update
        return x @ self.weight.t() + x @ self.lora_A.t() @ self.lora_B.t()

# Compare parameter efficiency
in_features, out_features = 512, 256
rank = 16

oft_layer = OFTLayer(in_features, out_features, rank)
lora_layer = LoRALayer(in_features, out_features, rank)

oft_params = sum(p.numel() for p in oft_layer.parameters() if p.requires_grad)
lora_params = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)
total_params = in_features * out_features

print("Parameter Efficiency Comparison:")
print(f"Original layer parameters: {total_params:,}")
print(f"OFT trainable parameters: {oft_params:,} ({oft_params/total_params*100:.2f}%)")
print(f"LoRA trainable parameters: {lora_params:,} ({lora_params/total_params*100:.2f}%)")

## 5. 🎩 The Mad Hatter's Tea Party: Hyperspherical Energy Preservation

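A key advantage of OFT is preserving **hyperspherical energy**, a measure of how the directions of a set of vectors are spread over the unit hypersphere. Keeping it unchanged means the angular relationships between those vectors are maintained.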

### The Mad Hatter's Tea Party Problem

Imagine you're at the Mad Hatter's tea party, and everyone's seat represents a learned feature in your model:

- **Traditional fine-tuning** is like the Hatter shouting "Move down!" - everyone shifts chaotically, and suddenly the March Hare is sitting where the Dormouse should be. The entire seating arrangement (your model's learned relationships) gets scrambled.

- **OFT is different.** It's like rotating the entire table instead of moving individual seats. Everyone maintains their relative positions—if Alice was between the Hatter and the Hare before, she still is after the rotation. The *relationships* stay intact.

By preserving angular relationships (hyperspherical energy), OFT maintains the semantic structure learned during pre-training. Features that were similar before fine-tuning remain similar after. The intent is more stable training and better generalization than an unconstrained update would give.

💫 *Down here in Wonderland, we don't break what already works - we just rotate it to see it from a new angle.*

In [None]:
def compute_hyperspherical_energy(features):
    """
    Compute a logarithmic hyperspherical energy: the mean of
    -log(1 - cosine similarity) over all ordered pairs of distinct rows.
    Lower energy indicates more uniform distribution on the hypersphere.
    """
    # Normalize features to unit sphere
    normalized = features / (torch.norm(features, dim=-1, keepdim=True) + 1e-8)
    
    # Compute pairwise cosine similarities
    similarities = torch.matmul(normalized, normalized.t())
    
    # Keep only off-diagonal pairs. On the diagonal the cosine is 1 (or slightly
    # above 1 after rounding), so the log would be -inf or NaN there.
    n = features.shape[0]
    off_diagonal = ~torch.eye(n, device=features.device).bool()
    gaps = (1 - similarities[off_diagonal]).clamp(min=1e-8)
    energy = -torch.log(gaps).mean()
    
    return energy

# Generate sample features
torch.manual_seed(42)
n_samples = 50
n_features = 64

# Original features
original_features = torch.randn(n_samples, n_features)

# Apply different transformations
# 1. Orthogonal transformation (OFT-style)
Q = torch.tensor(generate_orthogonal_matrix(n_features), dtype=torch.float32)
oft_features = original_features @ Q.t()

# 2. Random transformation (non-orthogonal)
random_matrix = torch.randn(n_features, n_features) * 0.5 + torch.eye(n_features)
random_features = original_features @ random_matrix.t()

# 3. Low-rank update (LoRA-style)
A = torch.randn(16, n_features) * 0.1
B = torch.randn(n_features, 16) * 0.1
lora_features = original_features + original_features @ A.t() @ B.t()

# Compute energies
original_energy = compute_hyperspherical_energy(original_features)
oft_energy = compute_hyperspherical_energy(oft_features)
random_energy = compute_hyperspherical_energy(random_features)
lora_energy = compute_hyperspherical_energy(lora_features)

print("Hyperspherical Energy Comparison:")
print(f"Original features: {original_energy:.4f}")
print(f"After OFT (orthogonal): {oft_energy:.4f} (change: {abs(oft_energy - original_energy):.4f})")
print(f"After random transform: {random_energy:.4f} (change: {abs(random_energy - original_energy):.4f})")
print(f"After LoRA update: {lora_energy:.4f} (change: {abs(lora_energy - original_energy):.4f})")
print("\nThe orthogonal transform leaves the energy unchanged up to floating-point error.")
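
The function above scores each pair with a logarithm of one minus the cosine similarity. Another common definition is the Riesz s-energy, which sums the inverse distances between unit-normalized vectors raised to the power s. The NumPy sketch below assumes that definition with s = 1; the function name `riesz_energy` and the random data are illustrative.

```python
import numpy as np

def riesz_energy(vectors, s=1.0):
    """Sum over ordered pairs i != j of ||v_i - v_j||^(-s) after unit-normalizing rows."""
    unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    diffs = unit[:, None, :] - unit[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    off_diagonal = ~np.eye(len(unit), dtype=bool)   # skip each vector's distance to itself
    return np.sum(dists[off_diagonal] ** (-s))

rng = np.random.default_rng(42)
features = rng.standard_normal((50, 64))
Q, _ = np.linalg.qr(rng.standard_normal((64, 64)))        # orthogonal
M = np.eye(64) + 0.5 * rng.standard_normal((64, 64))      # generic, non-orthogonal

e_original = riesz_energy(features)
e_orthogonal = riesz_energy(features @ Q.T)
e_generic = riesz_energy(features @ M.T)

print(f"Original:             {e_original:.4f}")
print(f"After orthogonal map: {e_orthogonal:.4f}")
print(f"After generic map:    {e_generic:.4f}")
```

Under either definition the conclusion is the same: an orthogonal map keeps every pairwise angle, so the energy cannot change.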

## 6. ⚙️ Practical OFT Module for Transformer Models

Let's implement a more practical OFT module that can be integrated into transformer-based models.

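This module uses a **block-diagonal structure** for efficiency, like having multiple smaller looking glasses instead of one giant mirror. Each block rotates independently, which keeps the computation tractable while the assembled matrix stays orthogonal. To keep the example short it only rotates the top-left corner of the weight matrix, so treat it as a teaching sketch rather than a drop-in replacement for a library implementation.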
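
As a small sketch of why the block-diagonal layout helps, the NumPy code below assembles four 8 x 8 rotations into one 32 x 32 matrix, checks that the result is still orthogonal, and compares the number of stored entries with a single dense 32 x 32 seed. The helper names `cayley` and `block_diagonal` and the sizes are illustrative assumptions.

```python
import numpy as np

def cayley(seed):
    """Orthogonal matrix from the skew-symmetric part of `seed`."""
    A = seed - seed.T
    I = np.eye(seed.shape[0])
    return np.linalg.solve(I + A, I - A)

def block_diagonal(blocks):
    """Place square blocks along the diagonal of an otherwise zero matrix."""
    size = sum(b.shape[0] for b in blocks)
    out = np.zeros((size, size))
    offset = 0
    for b in blocks:
        n = b.shape[0]
        out[offset:offset + n, offset:offset + n] = b
        offset += n
    return out

rng = np.random.default_rng(0)
num_blocks, block_size = 4, 8
blocks = [cayley(0.1 * rng.standard_normal((block_size, block_size)))
          for _ in range(num_blocks)]
R = block_diagonal(blocks)

dim = num_blocks * block_size
full_params = dim * dim                      # one dense seed matrix
block_params = num_blocks * block_size ** 2  # one small seed per block

print("R shape:", R.shape)
print("R is orthogonal:", np.allclose(R.T @ R, np.eye(dim)))
print(f"Dense seed entries: {full_params}, block-diagonal seed entries: {block_params}")
```

The trade-off is that directions in different blocks never mix, so the family of rotations is smaller than the full orthogonal group.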

In [None]:
class OFTModule(nn.Module):
    """
    Teaching sketch of an OFT module with block-diagonal structure for efficiency.
    """
    def __init__(self, original_module, rank=16, num_blocks=4, alpha=1.0):
        super().__init__()
        self.original_module = original_module
        self.rank = rank
        self.num_blocks = num_blocks
        self.alpha = alpha  # scaling for the adaptation

        if not isinstance(original_module, nn.Linear):
            raise ValueError("OFT currently supports only Linear layers")
        self.in_features = original_module.in_features
        self.out_features = original_module.out_features

        # Choose block size
        self.block_size = min(rank, min(self.in_features, self.out_features) // num_blocks)

        # Trainable skew-symmetric seeds for each block (square)
        self.blocks = nn.ParameterList([
            nn.Parameter(torch.zeros(self.block_size, self.block_size))
            for _ in range(num_blocks)
        ])

        # Freeze base weights/bias
        for p in self.original_module.parameters():
            p.requires_grad = False

    def get_block_diagonal_orthogonal(self):
        """
        Differentiable block-diagonal orthogonal matrix via Cayley transform.
        """
        device = self.blocks[0].device
        I = torch.eye(self.block_size, device=device)

        per_block_R = []
        for B in self.blocks:
            A = B - B.t()                       # skew-symmetric
            # (I - A)(I + A)^{-1} is also common; either is fine if consistent.
            block_R = torch.linalg.solve(I + A, I - A)
            per_block_R.append(block_R)

        # Differentiable assembly (no in-place writes)
        R = torch.block_diag(*per_block_R)      # shape: (k, k)
        return R

    def forward(self, x):
        # Original output (constant w.r.t. OFT params)
        original_output = self.original_module(x)

        # Build rotation
        R = self.get_block_diagonal_orthogonal()
        k = R.shape[0]

        # Base weight (constant path); DO NOT detach here—let autograd trace through the parts that depend on R
        W_base = self.original_module.weight    # [out_features, in_features]

        # Top-left block adapted (differentiable wrt R)
        W_tl = W_base[:k, :k]                   # view (read-only)
        W_tl_adapted = W_tl @ R                 # depends on R (grad flows)

        # Reconstruct full adapted weight without in-place writes or aliasing
        # Top row: [TL_adapted | TR_const]
        top_right = W_base[:k, k:]
        top = torch.cat([W_tl_adapted, top_right], dim=1)

        # Bottom rows unchanged
        bottom = W_base[k:, :]

        adapted_weight = torch.cat([top, bottom], dim=0)

        # Compute adapted output with differentiable weight
        adapted_output = nn.functional.linear(x, adapted_weight, self.original_module.bias)

        # Blend
        return (1 - self.alpha) * original_output + self.alpha * adapted_output


## 7. 🧪 Training Example with OFT

Let's demonstrate how to train a model using OFT for a simple classification task.

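In this example, we'll create a simple classifier and compare a regular model with an OFT-adapted version. With these settings the OFT model trains about 98% fewer parameters than the regular one. Note that the labels here are random and the frozen base weights are untrained, so expect test accuracy near chance (about 10%); the cell demonstrates the mechanics of OFT training, not its quality.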

In [None]:
# Create a simple dataset
from torch.utils.data import TensorDataset, DataLoader

# Generate synthetic classification data
torch.manual_seed(42)
n_samples = 1000
n_features = 64
n_classes = 10

X = torch.randn(n_samples, n_features)
y = torch.randint(0, n_classes, (n_samples,))

# Split into train and test
train_size = int(0.8 * n_samples)
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# Create dataloaders
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Define a simple model with OFT adaptation
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, use_oft=False):
        super().__init__()
        self.use_oft = use_oft
        
        # Create layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        
        # Apply OFT if requested
        if use_oft:
            self.fc1 = OFTModule(self.fc1, rank=8, num_blocks=2)
            self.fc2 = OFTModule(self.fc2, rank=8, num_blocks=2)
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create two models: one with OFT, one without
model_oft = SimpleClassifier(n_features, 128, n_classes, use_oft=True)
model_regular = SimpleClassifier(n_features, 128, n_classes, use_oft=False)

# Count parameters
oft_trainable = sum(p.numel() for p in model_oft.parameters() if p.requires_grad)
regular_trainable = sum(p.numel() for p in model_regular.parameters() if p.requires_grad)

print("Model Comparison:")
print(f"Regular model trainable parameters: {regular_trainable:,}")
print(f"OFT model trainable parameters: {oft_trainable:,}")
print(f"Parameter reduction: {(1 - oft_trainable/regular_trainable)*100:.1f}%")
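
The parameter counts printed by the cell above can be reproduced by hand. The sketch below mirrors the block-size rule used in `OFTModule` (`min(rank, min(in_features, out_features) // num_blocks)`); the helper names `linear_params` and `oft_params` are illustrative.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a dense layer."""
    return in_features * out_features + out_features

def oft_params(in_features, out_features, rank=8, num_blocks=2):
    """Trainable entries when each block is a block_size x block_size seed matrix."""
    block_size = min(rank, min(in_features, out_features) // num_blocks)
    return num_blocks * block_size ** 2

# fc1: 64 -> 128, fc2: 128 -> 10, as in SimpleClassifier above
regular = linear_params(64, 128) + linear_params(128, 10)
oft = oft_params(64, 128) + oft_params(128, 10)
reduction = 1 - oft / regular

print(f"Regular: {regular:,}")
print(f"OFT:     {oft:,}")
print(f"Reduction: {reduction * 100:.1f}%")
```

This gives 9,610 trainable parameters for the regular model and 178 for the OFT model, a reduction of 98.1%. The second layer has only 10 outputs, so its block size shrinks from 8 to 5.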

In [None]:
# Training function
def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    train_losses = []
    test_accuracies = []

    
    for epoch in range(epochs):
        # Training
        model.train()
        epoch_loss = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        train_losses.append(epoch_loss / len(train_loader))
        
        # Testing
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                outputs = model(X_batch)
                _, predicted = torch.max(outputs, 1)
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()
        
        accuracy = correct / total
        test_accuracies.append(accuracy)
        
        if (epoch + 1) % 2 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {train_losses[-1]:.4f}, Test Acc: {accuracy:.4f}")
    
    return train_losses, test_accuracies

# Train OFT model
print("Training OFT Model:")
oft_losses, oft_accs = train_model(model_oft, train_loader, test_loader, epochs=10)

# For comparison, we could also train the regular model
# Note: In practice, OFT is used for fine-tuning pre-trained models

## 8. 🌹 The Queen's Decree: Advantages and Use Cases of OFT

### Key Advantages

1. **Stability** 👑 - "All ways are MY ways!" Every singular value of an orthogonal matrix is 1, so the learned rotation by itself neither amplifies nor shrinks the signals passing through it. The Queen's rules (mathematical constraints) are meant to keep training orderly.

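2. **Parameter Efficiency** 🎩 - "Why is a raven like a writing desk?" Nobody at the table knows, but OFT's answer to efficiency is clear: in the classifier example above it trains over 90% fewer parameters than full fine-tuning. Fewer parameters to tune, but the tea party keeps its charm.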

3. **Geometric Preservation** 🪞 - The looking glass principle: rotations preserve the learned feature relationships, maintaining the model's core understanding.

4. **Better Generalization** 💫 - Less prone to overfitting on small datasets because we preserve the robust representations learned during pre-training.
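
The stability point can be illustrated with plain linear algebra. In the sketch below, a vector is multiplied 20 times by an orthogonal matrix and 20 times by an unconstrained matrix. This covers only the linear part of the story; it is not a statement about full training dynamics with nonlinearities and optimizers. The sizes and seed are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, steps = 8, 20

Q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))      # orthogonal
M = np.eye(dim) + 0.5 * rng.standard_normal((dim, dim))   # unconstrained

x = rng.standard_normal(dim)
v_orth, v_free = x.copy(), x.copy()
for _ in range(steps):
    v_orth = Q @ v_orth
    v_free = M @ v_free

ratio_orth = np.linalg.norm(v_orth) / np.linalg.norm(x)
ratio_free = np.linalg.norm(v_free) / np.linalg.norm(x)
singular_values = np.linalg.svd(Q, compute_uv=False)

print(f"Norm ratio after {steps} orthogonal steps:    {ratio_orth:.6f}")
print(f"Norm ratio after {steps} unconstrained steps: {ratio_free:.6e}")
print("Singular values of Q all equal 1:", np.allclose(singular_values, 1.0))
```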

### Best Use Cases

- **Domain Adaptation** 🎯 - When fine-tuning to a related but different domain without losing general knowledge
- **Few-Shot Learning** 🔬 - When training data is limited and you need to maintain robustness
- **Multi-Task Learning** 📚 - Preserving shared representations across tasks
- **Continual Learning** 🤖 - Reducing catastrophic forgetting (more on this with OSFT!)

In [None]:
# Demonstration: Stability comparison
def analyze_gradient_flow(model, input_data, target):
    """
    Analyze gradient magnitudes through the network.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    
    # Forward pass
    output = model(input_data)
    loss = criterion(output, target)
    
    # Backward pass
    loss.backward()
    
    # Collect gradient norms
    gradient_norms = []
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            grad_norm = param.grad.norm().item()
            gradient_norms.append(grad_norm)
            print(f"{name}: gradient norm = {grad_norm:.6f}")
    
    return gradient_norms

# Create test data
test_input = torch.randn(32, n_features)
test_target = torch.randint(0, n_classes, (32,))

print("Gradient Flow Analysis for OFT Model:")
print("="*50)
grad_norms = analyze_gradient_flow(model_oft, test_input, test_target)
print(f"\nAverage gradient norm: {np.mean(grad_norms):.6f}")
print(f"Gradient variance: {np.var(grad_norms):.6f}")
print("\nNote: one batch is only a snapshot. Compare these norms across steps and against a baseline to judge stability.")

## 🎪 Summary: Lessons from Wonderland

Orthogonal Fine-Tuning (OFT) represents a significant advancement in parameter-efficient fine-tuning:

### 🪞 The Looking Glass Principle
Like spinning a mirror rather than cracking it, orthogonal transformations rotate the feature space without distortion. Alice stays Alice, just viewed from a new angle—geometric relationships preserved perfectly.

### 🎩 The Mad Hatter's Efficiency  
Train 90%+ fewer parameters than full fine-tuning, as in the classifier example above. The orthogonal constraint ensures that fine-tuning rotates the feature space rather than arbitrarily distorting it.

### 🌹 The Queen's Stability
Orthogonal constraints are designed to keep training more orderly and stable than unconstrained methods. Off with gradient chaos!

### 💫 Better for What Matters
OFT is particularly well-suited for tasks requiring preservation of learned features—domain adaptation, few-shot learning, and scenarios where you can't afford to forget what the model already knows.

🐇 *"Curiouser and curiouser!" cried Alice. And indeed—the deeper you go down this orthogonal rabbit hole, the more elegant the mathematics becomes.*