# Main Experiment

The orchestrator notebook: run every experiment and visualize the results here.

In [None]:
import os
import time

import matplotlib.pyplot as plt
import pandas as pd
import torch

from src.data_utils import get_split_cifar10
from src.models import get_resnet18
from src.trainer import train_baseline, train_constrained, evaluate
from src.decompositions import SVDProjector, QRProjector, RSVDProjector, NMFProjector

# --- CONFIGURATION ---
# Use the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

BATCH_SIZE = 64  # handed to get_split_cifar10 when the loaders are built
EPOCHS_A = 15  # training epochs for the Task A expert [cite: 26]
EPOCHS_B = 10  # Task B training epochs, identical for every method compared

# --- PHASE 1: INFRASTRUCTURE [cite: 3] ---
print("--- Phase 1: Data Setup ---")
# Unpacked as (train, test) for Task A followed by (train, test) for Task B.
train_A, test_A, train_B, test_B = get_split_cifar10(BATCH_SIZE)

# --- PHASE 2: BASELINE EXPERT [cite: 18] ---
print("--- Phase 2: Training Expert Model (Task A) ---")
model_expert = get_resnet18(num_classes=10).to(DEVICE)
model_expert = train_baseline(model_expert, train_A, EPOCHS_A, DEVICE)

# Checkpoint the expert [cite: 29]. torch.save does not create parent
# directories. Create "checkpoints/" up front so a fresh clone does not crash
# here after the whole expert training run has already finished.
os.makedirs("checkpoints", exist_ok=True)
torch.save(model_expert.state_dict(), "checkpoints/task_a_expert.pth")

# Reference Task A accuracy (in percent). Forgetting is measured against it.
acc_A_initial = evaluate(model_expert, test_A, DEVICE)
print(f"Task A Expert Accuracy: {acc_A_initial:.2f}%")

# --- PHASE 3 & 4: EXPERIMENT LOOP [cite: 31, 58] ---
# Each method restarts from the same Task A expert, trains on Task B, and is
# then scored on both tasks. Every entry in `results` is one method's outcome.
results = []
methods = [
    ("Naive", None),
    ("SVD", SVDProjector()),
    ("QR", QRProjector()),
    ("RSVD", RSVDProjector()),
    ("NMF", NMFProjector()),
]

# The torch RNG is reset to this seed before every run, so all methods start
# from the same RNG state. Without this, the comparison mixes method effects
# with run-to-run noise.
# NOTE(review): this assumes train_B shuffles with torch's default generator.
# Randomness drawn from NumPy/scikit-learn inside a projector is NOT controlled
# by torch.manual_seed; confirm for RSVD/NMF.
RUN_SEED = 0

# Read the expert checkpoint once. load_state_dict copies these tensors into
# each fresh model, so the same dict can safely be reused for every run.
expert_state = torch.load("checkpoints/task_a_expert.pth", map_location=DEVICE)


def _sync_device():
    """Wait for queued CUDA work to finish so wall-clock timings are accurate."""
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()


print("--- Phase 3 & 4: Running Experiments ---")
for name, projector in methods:
    print(f"\n--- Experiment: {name} ---")

    # Reload Expert Weights [cite: 68]
    model_current = get_resnet18(num_classes=10).to(DEVICE)
    model_current.load_state_dict(expert_state)
    torch.manual_seed(RUN_SEED)

    prep_time = 0.0
    if projector is not None:
        # Phase 3: Subspace Extraction [cite: 31]
        # CUDA kernels launch asynchronously, so synchronize on both sides of
        # the timed region. Otherwise prep_time may only measure launch overhead.
        _sync_device()
        start_time = time.perf_counter()
        projector.compute_subspaces(model_current)
        _sync_device()
        prep_time = time.perf_counter() - start_time
        print(f"Subspace computed in {prep_time:.2f}s")

        # Phase 4: Constrained Training [cite: 58]
        model_current = train_constrained(model_current, train_B, EPOCHS_B, DEVICE, projector)
    else:
        # Naive Baseline: the same Task B training with no projector involved.
        model_current = train_baseline(model_current, train_B, EPOCHS_B, DEVICE)

    # Phase 5: Metrics [cite: 79]
    acc_A_final = evaluate(model_current, test_A, DEVICE)  # Retention
    acc_B_final = evaluate(model_current, test_B, DEVICE)  # Learning
    forgetting = acc_A_initial - acc_A_final  # positive = Task A accuracy lost

    results.append({
        "Method": name,
        "Task A Acc (Retention)": acc_A_final,
        "Task B Acc (Plasticity)": acc_B_final,
        "Forgetting": forgetting,
        "Prep Time (s)": prep_time,
    })

# --- PHASE 5: REPORTING [cite: 84] ---
df = pd.DataFrame(results)
print("\n--- Final Results ---")
# to_string() prints every column in a single table. The default repr wraps
# wide frames into several stacked sections at pandas' display width.
print(df.to_string(index=False, float_format="{:.2f}".format))

# Visualization [cite: 87]
# Retention is an accuracy in percent and Forgetting is a drop in percentage
# points, so both bars share one y-axis.
ax = df.plot(
    x="Method",
    y=["Task A Acc (Retention)", "Forgetting"],
    kind="bar",
    rot=0,  # method names are short; keep the tick labels horizontal
)
ax.set_ylabel("Percent")
ax.set_title("Catastrophic Forgetting Analysis")
plt.tight_layout()
plt.show()