# Lab 3.1.2: DoRA Comparison - Solutions

This notebook contains complete solutions for all exercises in the DoRA lab.

## Exercise 1: Understanding Weight Decomposition

**Task**: Implement magnitude and direction decomposition from scratch.

In [None]:
import torch
import torch.nn as nn
import numpy as np

def decompose_weight(W: torch.Tensor) -> tuple:
    """
    Split ``W`` into per-column magnitudes and unit-norm directions.

    Implements ``W = m * d`` with ``m = ||W||_c`` (the L2 norm of each
    column, i.e. reduced over dim 0) and ``d = W / ||W||_c``. Here ``m`` is
    simply computed from ``W``; in DoRA it is the quantity that becomes a
    trainable parameter.

    Args:
        W: Weight tensor to decompose (typically a 2-D matrix).

    Returns:
        ``(magnitude, direction)``. For a 2-D ``W``, ``magnitude`` has shape
        ``(1, W.shape[1])`` (kept as a row so it broadcasts against ``W``)
        and ``direction`` has the same shape as ``W``.
    """
    eps = 1e-8  # keeps all-zero columns from dividing by zero
    col_norms = W.norm(p=2, dim=0, keepdim=True)
    return col_norms, W / (col_norms + eps)

# Test: decompose a random 768x768 matrix and check both invariants.
W = torch.randn(768, 768)
m, d = decompose_weight(W)

# Verify: W ≈ m * d (not bit-exact because decompose_weight divides by
# magnitude + 1e-8; the mean absolute error should be tiny)
W_reconstructed = m * d
print(f"Reconstruction error: {torch.mean(torch.abs(W - W_reconstructed)):.8f}")

# Verify: d has unit norm columns (prints the norms of the first 5 columns)
column_norms = torch.norm(d, p=2, dim=0)
print(f"Direction column norms (should be ~1): {column_norms[:5]}")

## Exercise 2: Full DoRA Layer Implementation

In [None]:
class DoRALinear(nn.Module):
    """
    Complete DoRA implementation with proper initialization.

    The effective weight is ``W' = m * (W0 + s*B@A) / ||W0 + s*B@A||_c``.
    The terms are:

    - ``W0``: the frozen base weight, shape ``(out_features, in_features)``.
    - ``s``: equal to ``alpha / rank``.
    - ``||.||_c``: the L2 norm of each column (reduced over dim 0), so ``m``
      has shape ``(1, in_features)``.

    Dropout is applied only to the input of the low-rank (adapter) path, as
    with ``lora_dropout`` in LoRA. The frozen base path always sees the
    undropped input. In eval mode, or with ``dropout=0``, the output equals
    ``x @ W'.T``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        rank: Rank ``r`` of the low-rank update ``B @ A``.
        alpha: Scaling numerator; the update is multiplied by ``alpha / rank``.
        dropout: Dropout probability on the adapter-path input.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.1
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scaling = alpha / rank

        # Frozen base weight (simulated - in practice this comes from pretrained model)
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02, requires_grad=False)

        # Compute initial magnitude from base weight
        with torch.no_grad():
            initial_magnitude = torch.norm(self.weight, p=2, dim=0, keepdim=True)

        # Trainable magnitude parameter (initialized from base weight), so that
        # W' == W0 at initialization (lora_B starts at zero).
        self.magnitude = nn.Parameter(initial_magnitude.clone())

        # LoRA components for direction update
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

        # Dropout for the adapter path only
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Low-rank direction update, scaled by alpha / rank
        lora_delta = (self.lora_B @ self.lora_A) * self.scaling

        # Column-wise norm of the updated weight; epsilon guards against /0
        weight_norm = torch.norm(self.weight + lora_delta, p=2, dim=0, keepdim=True) + 1e-8

        # m / ||W0 + dW||_c rescales every column of the updated weight to length m
        column_scale = self.magnitude / weight_norm

        # W' = scale * W0 + scale * dW. Splitting the two terms lets dropout act
        # only on the adapter path, so the frozen pretrained path is not perturbed.
        base_out = x @ (column_scale * self.weight).T
        adapter_out = self.dropout(x) @ (column_scale * lora_delta).T
        return base_out + adapter_out

    def get_trainable_params(self):
        """Count trainable parameters (magnitude, lora_A and lora_B)."""
        return sum(p.numel() for p in [self.magnitude, self.lora_A, self.lora_B])

# Test: push a batch of 4 vectors through a 768 -> 768 DoRA layer.
# The freshly built module is in training mode, so its dropout is active here.
dora = DoRALinear(768, 768, rank=8)
x = torch.randn(4, 768)
out = dora(x)
print(f"Input: {x.shape}, Output: {out.shape}")
# magnitude (1 x 768) + lora_A (8 x 768) + lora_B (768 x 8) = 13,056
print(f"Trainable params: {dora.get_trainable_params():,}")

## Exercise 3: Gradient Analysis

In [None]:
def compare_gradient_flow():
    """
    Compare gradient magnitudes in LoRA vs DoRA.

    The comparison is like-for-like:

    - Both layers share the same frozen base weight and ``lora_A`` init.
    - Both scale the low-rank update by the same ``alpha / rank``.
    - Both run without dropout.

    Any difference in the printed gradient norms therefore comes from DoRA's
    magnitude/direction reparameterization. ``lora_B`` starts at zero in
    both layers, so the ``lora_A`` gradients are zero at this initial point.

    Prints the L2 norm of each adapter parameter's gradient after one
    backward pass of an MSE loss on random inputs and targets.
    """
    # Standard LoRA, with the same alpha / rank scaling as DoRALinear
    class LoRALinear(nn.Module):
        def __init__(self, in_f, out_f, rank=8, alpha=16.0):
            super().__init__()
            self.scaling = alpha / rank
            self.weight = nn.Parameter(torch.randn(out_f, in_f) * 0.02, requires_grad=False)
            self.lora_A = nn.Parameter(torch.randn(rank, in_f) * 0.01)
            self.lora_B = nn.Parameter(torch.zeros(out_f, rank))

        def forward(self, x):
            base = x @ self.weight.T
            lora = (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
            return base + lora

    # Create both. DoRA's dropout is disabled because LoRALinear has none;
    # active dropout would also make the DoRA gradients random.
    lora = LoRALinear(768, 768, rank=8)
    dora = DoRALinear(768, 768, rank=8, dropout=0.0)

    # Start both adapters from the same base weight and A initialization
    with torch.no_grad():
        lora.weight.copy_(dora.weight)
        lora.lora_A.copy_(dora.lora_A)

    # Forward and backward on the same random batch
    x = torch.randn(4, 768)
    target = torch.randn(4, 768)

    # LoRA gradients
    lora_out = lora(x)
    lora_loss = ((lora_out - target) ** 2).mean()
    lora_loss.backward()

    lora_A_grad = lora.lora_A.grad.norm().item()
    lora_B_grad = lora.lora_B.grad.norm().item()

    # DoRA gradients
    dora_out = dora(x)
    dora_loss = ((dora_out - target) ** 2).mean()
    dora_loss.backward()

    dora_A_grad = dora.lora_A.grad.norm().item()
    dora_B_grad = dora.lora_B.grad.norm().item()
    dora_m_grad = dora.magnitude.grad.norm().item()

    print("Gradient Norms:")
    print(f"  LoRA  A: {lora_A_grad:.6f}, B: {lora_B_grad:.6f}")
    print(f"  DoRA  A: {dora_A_grad:.6f}, B: {dora_B_grad:.6f}, M: {dora_m_grad:.6f}")
    print("\nDoRA magnitude gradient provides additional learning signal!")

# Print gradient norms for freshly initialized LoRA and DoRA layers.
compare_gradient_flow()

## Exercise 4: Training Comparison on a GLUE Task (SST-2)

In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import copy

def train_and_compare(base_model_id="google/bert_uncased_L-2_H-128_A-2", epochs=3, seed=42):
    """
    Train LoRA vs DoRA on SST-2 sentiment classification.

    Both adapters use identical hyperparameters and differ only in
    ``use_dora``. Adapter initialization and batch order are seeded with
    ``seed``, so both runs start from the same point and see the same data
    order.

    Args:
        base_model_id: Hugging Face model id of the base model.
        epochs: Number of passes over the 500-example SST-2 training subset.
        seed: Seed for adapter initialization and data shuffling.

    Returns:
        ``(lora_losses, dora_losses)``: lists of per-epoch mean training loss.
    """
    from torch.utils.data import DataLoader
    from torch.optim import AdamW

    # Load tiny BERT for quick demo
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    def make_config(use_dora):
        # Identical LoRA hyperparameters; only the DoRA switch differs
        return LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["query", "value"],
            use_dora=use_dora,
        )

    # Create models from deep copies of one base model, so both share the
    # same (randomly initialized) classification head.
    base_model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id, num_labels=2
    )

    # Re-seed before each wrap so the LoRA A matrices start identically
    torch.manual_seed(seed)
    lora_model = get_peft_model(copy.deepcopy(base_model), make_config(False))
    torch.manual_seed(seed)
    dora_model = get_peft_model(copy.deepcopy(base_model), make_config(True))

    # print_trainable_parameters() prints its own summary and returns None,
    # so call it directly rather than passing its result to print().
    print("LoRA trainable params:")
    lora_model.print_trainable_parameters()
    print("DoRA trainable params:")
    dora_model.print_trainable_parameters()

    # Load SST-2 subset
    dataset = load_dataset("glue", "sst2", split="train[:500]")

    def tokenize(examples):
        return tokenizer(
            examples["sentence"],
            padding="max_length",
            truncation=True,
            max_length=128
        )

    dataset = dataset.map(tokenize, batched=True)
    dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

    def train_model(model, name):
        """Train ``model`` for ``epochs`` epochs and return per-epoch mean losses."""
        model.train()
        # Optimize only the parameters PEFT left trainable (adapters + head)
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = AdamW(trainable, lr=5e-5)
        # A fresh, identically seeded generator gives both models the same batch order
        loader = DataLoader(
            dataset,
            batch_size=16,
            shuffle=True,
            generator=torch.Generator().manual_seed(seed),
        )

        losses = []
        for epoch in range(epochs):
            epoch_loss = 0.0
            for batch in loader:
                outputs = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"]
                )
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(loader)
            losses.append(avg_loss)
            print(f"{name} Epoch {epoch+1}: Loss = {avg_loss:.4f}")

        return losses

    print("\n=== Training LoRA ===")
    lora_losses = train_model(lora_model, "LoRA")

    print("\n=== Training DoRA ===")
    dora_losses = train_model(dora_model, "DoRA")

    return lora_losses, dora_losses

# Run comparison (downloads the model and the SST-2 data; uncomment to run)
# lora_losses, dora_losses = train_and_compare()

## Exercise 5: Visualize Weight Changes

In [None]:
import matplotlib.pyplot as plt

def visualize_weight_decomposition():
    """
    Plot a toy LoRA-style update next to a toy DoRA-style update.

    Top row: heatmaps of three 64x64 weights.

    - The original weight.
    - The weight after an unconstrained additive perturbation (LoRA-style).
    - The weight after separately perturbing the column magnitudes and the
      unit-norm column directions, then recombining them (DoRA-style).

    Bottom row: the per-column L2 norms of the original, LoRA-style and
    DoRA-style weights.

    Seeds the global torch RNG with 42, saves the figure to
    ``dora_visualization.png`` and then shows it.
    """
    n = 64
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))

    torch.manual_seed(42)
    W_original = torch.randn(n, n) * 0.5

    # LoRA-style: one additive update, column norms change as a side effect
    W_lora = W_original + torch.randn(n, n) * 0.1

    # DoRA-style: perturb magnitudes and directions independently
    m_original = torch.norm(W_original, p=2, dim=0, keepdim=True)
    m_updated = m_original * (1 + torch.randn(1, n) * 0.1)
    d_original = W_original / (m_original + 1e-8)
    d_noisy = d_original + torch.randn(n, n) * 0.05
    d_updated = d_noisy / (torch.norm(d_noisy, p=2, dim=0, keepdim=True) + 1e-8)
    W_dora = m_updated * d_updated

    # Top row: weight heatmaps
    heatmaps = [
        (W_original, 'Original W'),
        (W_lora, 'W after LoRA'),
        (W_dora, 'W after DoRA'),
    ]
    for ax, (mat, title) in zip(axes[0], heatmaps):
        ax.imshow(mat.numpy(), cmap='coolwarm', aspect='auto')
        ax.set_title(title)

    # Bottom row: per-column magnitudes
    bar_panels = [
        (m_original.squeeze(), 'Column Magnitudes (Original)', {'label': 'Original'}),
        (torch.norm(W_lora, p=2, dim=0), 'Magnitudes after LoRA (uncontrolled)', {'color': 'orange'}),
        (m_updated.squeeze(), 'Magnitudes after DoRA (explicitly learned)', {'color': 'green'}),
    ]
    for ax, (mags, title, extra) in zip(axes[1], bar_panels):
        ax.bar(range(n), mags.numpy(), alpha=0.7, **extra)
        ax.set_title(title)
        ax.set_xlabel('Column')

    plt.tight_layout()
    plt.savefig('dora_visualization.png', dpi=150)
    plt.show()

    print("\nKey insight: DoRA gives explicit control over magnitude!")

# Save dora_visualization.png to the working directory and display the figure.
visualize_weight_decomposition()

## Key Insights from Solutions

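1. **Weight Decomposition**: $W = m \cdot d$, where $m = \lVert W \rVert_c$ holds the per-column L2 norms (magnitude) and $d = W / \lVert W \rVert_c$ has unit-norm columns (direction)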
2. **DoRA Advantage**: Separate learning of magnitude provides additional gradient signal
3. **Training Stability**: DoRA is often reported to give smoother loss curves; compare the per-epoch losses from Exercise 4 to check this on your own run
4. **Performance**: the DoRA paper reports an average gain of about 3.7 accuracy points over LoRA on commonsense reasoning benchmarks (LLaMA-7B)