# Tier 0 — Baseline Training + Spectral Introspection (PACS, ResNet-18)

In [None]:
# ===== Colab / Drive setup =====
from google.colab import drive
drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/SoMA_PACS"
import os
os.makedirs(PROJECT_ROOT, exist_ok=True)


Mounted at /content/drive


In [None]:
!pip -q install datasets huggingface_hub pillow


In [None]:
from datasets import load_dataset
import os
from PIL import Image
from tqdm import tqdm


In [None]:
import shutil, os

PACS_ROOT = "/content/drive/MyDrive/DG_PACS/PACS"
assert "PACS" in PACS_ROOT
if not os.path.exists(PACS_ROOT):
    os.makedirs(PACS_ROOT, exist_ok=True)
    print("Created PACS folder.")
else:
    print("PACS folder exists. Skipping deletion.")



PACS folder exists. Skipping deletion.


In [None]:
from torchvision import datasets
import os

for d in ["photo","art_painting","cartoon","sketch"]:
    ds = datasets.ImageFolder(os.path.join(PACS_ROOT, d))
    print(d, ds.class_to_idx)
    assert len(ds.class_to_idx) == 7, f"{d} does not have 7 classes!"



photo {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
art_painting {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
cartoon {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
sketch {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}


In [None]:
from torchvision import datasets
import os

def get_class_to_idx(domain):
    """Return the class-name -> label-index mapping for one PACS domain folder.

    Args:
        domain: Sub-folder name under PACS_ROOT (e.g. "photo").

    Returns:
        The `class_to_idx` dict that ImageFolder derives from the folder layout.
    """
    domain_dir = os.path.join(PACS_ROOT, domain)
    return datasets.ImageFolder(root=domain_dir).class_to_idx

mappings = {d: get_class_to_idx(d) for d in ["photo","art_painting","cartoon","sketch"]}
for d, m in mappings.items():
    print(d, m)

# Compare
base = mappings["photo"]
for d in mappings:
    print(d, "matches photo:", mappings[d] == base)


photo {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
art_painting {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
cartoon {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
sketch {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
photo matches photo: True
art_painting matches photo: True
cartoon matches photo: True
sketch matches photo: True


In [None]:
# ===== Reproducibility =====
import random, numpy as np, torch

def set_seed(seed=42):
    """Seed every RNG the notebook uses and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to `random`, NumPy, and PyTorch (CPU and
            all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# ===== Imports =====
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm


In [None]:
# ===== PACS paths (edit if needed) =====
# PACS_ROOT = "/content/drive/MyDrive/datasets/PACS"

SOURCE_DOMAINS = ["photo", "art_painting", "cartoon"]
TARGET_DOMAIN = "sketch"

IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 2


In [None]:
train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])

test_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])


In [None]:
def make_loader(domain, tf, shuffle):
    """Build a DataLoader over one PACS domain folder.

    Args:
        domain: Sub-folder name under PACS_ROOT.
        tf: Transform applied to every image.
        shuffle: Whether the loader reshuffles samples each epoch.

    Returns:
        A DataLoader using the notebook-level BATCH_SIZE and NUM_WORKERS.
    """
    domain_dir = os.path.join(PACS_ROOT, domain)
    dataset = datasets.ImageFolder(root=domain_dir, transform=tf)
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )
    return DataLoader(dataset, **loader_kwargs)

train_loaders = [make_loader(d, train_tf, True) for d in SOURCE_DOMAINS]
test_loader = make_loader(TARGET_DOMAIN, test_tf, False)


In [None]:
class ResNet18_FeatureHook(nn.Module):
    """ImageNet-pretrained ResNet-18 with a new classifier head and BN feature hooks.

    A forward hook is attached to `bn2` of every residual block in the
    requested stages, so each forward pass records that BatchNorm output
    (i.e. before the residual add and the block's final ReLU) in
    `self._features` under the key "<stage>.<block index>.bn2".

    Args:
        num_classes: Output size of the replacement `fc` layer.
        hook_layers: Names of the backbone stages whose blocks are hooked.
            Defaults to ("layer2", "layer3", "layer4"); pass
            ("layer1", "layer2", "layer3", "layer4") to record every stage
            without redefining the class.
    """

    def __init__(self, num_classes, hook_layers=("layer2", "layer3", "layer4")):
        super().__init__()
        self.backbone = models.resnet18(weights="IMAGENET1K_V1")
        # Swap the 1000-way ImageNet head for a task-specific one, reading the
        # input width from the layer being replaced rather than hard-coding it.
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

        # name -> most recent hooked activation; reset on every forward().
        self._features = {}

        def make_hook(name):
            # Bind `name` per hook so each closure writes to its own key.
            def hook(module, inp, out):
                # Stored as-is (not detached): under autograd the tensor keeps
                # its graph until the next forward() replaces the dict.
                self._features[name] = out
            return hook

        # register hooks: BN output before ReLU
        for lname in hook_layers:
            layer = getattr(self.backbone, lname)
            for i, block in enumerate(layer):
                block.bn2.register_forward_hook(
                    make_hook(f"{lname}.{i}.bn2")
                )

    def forward(self, x):
        """Run the backbone on `x`; hooked activations land in `self._features`."""
        # Start from an empty dict so stale activations never leak across calls.
        self._features = {}
        return self.backbone(x)


# Training loop (Tier-0 backbone)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet18_FeatureHook(num_classes=7).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 192MB/s]


In [None]:
def train_epoch(loaders):
    """Run one optimisation pass over every loader, one loader after another.

    Uses the notebook-level `model`, `optimizer`, `criterion` and `device`.

    Args:
        loaders: Iterable of DataLoaders yielding (images, labels) batches.

    Returns:
        (mean loss per sample, accuracy) over all batches seen in this pass.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    # Loaders are consumed sequentially: all batches of the first, then the next.
    batches = (batch for domain_loader in loaders for batch in domain_loader)
    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        # criterion returns a batch mean, so re-weight by batch size.
        loss_sum += batch_loss.item() * batch_size
        n_correct += (outputs.argmax(1) == targets).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen


In [None]:
@torch.no_grad()
def eval_epoch(loader):
    """Return top-1 accuracy of the notebook-level `model` on `loader`.

    The model is switched to eval mode and left there.
    """
    model.eval()
    hits = 0
    seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        predictions = model(inputs).argmax(1)
        hits += (predictions == targets).sum().item()
        seen += inputs.size(0)
    return hits / seen


In [None]:
import os
# Debugging aid: request synchronous CUDA kernel launches so device-side errors
# surface at the call that caused them.
# NOTE(review): the model was already moved to `device` in an earlier cell, so
# CUDA may already be initialised when this runs; confirm the variable still
# takes effect, and consider removing it for timing-sensitive runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [None]:
# Number of full passes over the source-domain loaders.
EPOCHS = 15
for ep in range(EPOCHS):
    # One optimisation pass over every loader in train_loaders, followed by a
    # no-grad evaluation on test_loader (the held-out TARGET_DOMAIN).
    tr_loss, tr_acc = train_epoch(train_loaders)
    te_acc = eval_epoch(test_loader)

    # Target accuracy is only logged; this loop performs no checkpoint
    # selection or early stopping based on it.
    print(f"[{ep:02d}] loss={tr_loss:.3f} | src_acc={tr_acc:.3f} | tgt_acc={te_acc:.3f}")


[00] loss=0.386 | src_acc=0.874 | tgt_acc=0.661
[01] loss=0.213 | src_acc=0.929 | tgt_acc=0.632
[02] loss=0.144 | src_acc=0.953 | tgt_acc=0.680
[03] loss=0.123 | src_acc=0.958 | tgt_acc=0.652
[04] loss=0.085 | src_acc=0.972 | tgt_acc=0.725
[05] loss=0.058 | src_acc=0.981 | tgt_acc=0.718
[06] loss=0.105 | src_acc=0.967 | tgt_acc=0.733
[07] loss=0.063 | src_acc=0.982 | tgt_acc=0.708
[08] loss=0.061 | src_acc=0.981 | tgt_acc=0.688
[09] loss=0.074 | src_acc=0.976 | tgt_acc=0.754
[10] loss=0.031 | src_acc=0.990 | tgt_acc=0.722
[11] loss=0.028 | src_acc=0.991 | tgt_acc=0.732
[12] loss=0.063 | src_acc=0.979 | tgt_acc=0.699
[13] loss=0.044 | src_acc=0.986 | tgt_acc=0.708
[14] loss=0.038 | src_acc=0.987 | tgt_acc=0.736


Save trained backbone

In [None]:
BACKBONE_PATH = f"{PROJECT_ROOT}/resnet18_pacs_base.pt"

torch.save(model.state_dict(), BACKBONE_PATH)
print("Saved PACS-trained backbone.")


Saved PACS-trained backbone.


# TIER 0 — SPECTRAL INTROSPECTION

In [None]:
import numpy.linalg as LA

def conv_svd(conv):
    """Thin SVD of a conv layer's weight, flattened to one row per output channel.

    Args:
        conv: Object exposing a `weight` tensor whose first axis is C_out
            (e.g. an nn.Conv2d with weight [C_out, C_in, k, k]).

    Returns:
        (U, S): left singular vectors and singular values of the
        [C_out, C_in*k*k] matrix. The right singular vectors are discarded.
    """
    weight = conv.weight.detach().cpu().numpy()
    flattened = weight.reshape(weight.shape[0], -1)
    left_vectors, singular_values, _ = LA.svd(flattened, full_matrices=False)
    return left_vectors, singular_values


In [None]:
# "<stage>.<block index>" -> singular values of that block's conv2 weight.
spectra = {}

for lname in ["layer2", "layer3", "layer4"]:
    for i, block in enumerate(getattr(model.backbone, lname)):
        # Only the second conv of each residual block is decomposed.
        conv = block.conv2
        # U (left singular vectors) is not stored; only S goes into `spectra`.
        U, S = conv_svd(conv)
        spectra[f"{lname}.{i}"] = S


In [None]:
import pickle

with open(f"{PROJECT_ROOT}/tier0_spectra.pkl", "wb") as f:
    pickle.dump(spectra, f)


# Tier-0 Feature Statistics (Energy by Spectral Rank)

In [None]:
@torch.no_grad()
def collect_feature_energy(loader):
    """Per-channel activation energy of every hooked layer, averaged over samples.

    For each hooked feature map the spatial dimensions are global-average-pooled
    to a [B, C] matrix, squared, and averaged over all samples in `loader`.
    The average is weighted per sample, so a smaller final batch does not get
    the same weight as a full batch.

    Uses the notebook-level `model` (switched to eval mode) and `device`.

    Args:
        loader: DataLoader yielding (images, labels) batches; labels are ignored.

    Returns:
        Dict mapping hook name -> float32 NumPy array of shape [C].
        Empty dict if the loader yields no batches.
    """
    model.eval()
    energy_sums = {}   # hook name -> running sum over samples of h**2, shape [C]
    n_samples = 0

    for x, _labels in loader:
        x = x.to(device)
        model(x)  # forward pass only for its side effect of filling model._features
        n_samples += x.size(0)

        for k, feat in model._features.items():
            # GAP -> channel vector
            h = feat.mean(dim=[2, 3])  # [B, C]
            # Accumulate in float64 on CPU to keep the running sum accurate.
            batch_sum = (h ** 2).sum(0).double().cpu()
            if k in energy_sums:
                energy_sums[k] += batch_sum
            else:
                energy_sums[k] = batch_sum

    return {
        k: (total / n_samples).float().numpy()
        for k, total in energy_sums.items()
    }


In [None]:
# NOTE(review): train_loaders[0] is the loader for SOURCE_DOMAINS[0] only, and
# it was built with train_tf (random horizontal flip) and shuffle=True, so "src"
# covers a single source domain under training-time augmentation. Confirm this
# is intended rather than all source domains with test_tf.
src_energy = collect_feature_energy(train_loaders[0])
tgt_energy = collect_feature_energy(test_loader)

# Persist both energy dicts side by side for later comparison.
with open(f"{PROJECT_ROOT}/tier0_feature_energy.pkl", "wb") as f:
    pickle.dump({"src": src_energy, "tgt": tgt_energy}, f)


Saved to Drive

- tier0_spectra.pkl → singular values of `conv2` for every residual block in layer2–layer4

- tier0_feature_energy.pkl → channel-wise activation energy (src = first source domain, `photo`; tgt = `sketch`)

- resnet18_pacs_base.pt → trained baseline checkpoint (the model's `state_dict`)



Scientific baseline

- no spectral intervention

- no leakage

- fixed feature definition

- BN explicitly visible

# =============================================================================
# TIER 1 - SCENARIO 1: Probing the PRETRAINED ImageNet Model

This tests whether the pretrained model (before any PACS finetuning) already has spectral structure that separates domain from class information.


In [None]:
# -----------------------------------------------------------------------------
# Cell 0: Re-define the Model Class with ALL Layer Hooks (Run this first!)
# -----------------------------------------------------------------------------
# This ensures hooks are registered for ALL layers including layer1

class ResNet18_FeatureHook(nn.Module):
    """ImageNet-pretrained ResNet-18 that records `bn2` outputs of every block.

    Hooks cover layer1 through layer4. After each forward pass,
    `self._features` maps "<stage>.<block index>.bn2" to that BatchNorm's output.

    Args:
        num_classes: Output size of the replacement `fc` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = models.resnet18(weights="IMAGENET1K_V1")
        self.backbone.fc = nn.Linear(512, num_classes)

        self._features = {}

        # Attach one recorder per residual block, across all four stages.
        stages = ("layer1", "layer2", "layer3", "layer4")
        for stage_name in stages:
            stage = getattr(self.backbone, stage_name)
            for block_idx, block in enumerate(stage):
                key = f"{stage_name}.{block_idx}.bn2"
                block.bn2.register_forward_hook(self._recorder(key))

    def _recorder(self, key):
        """Build a forward hook that stores the module output under `key`."""
        def record(module, inputs, output):
            self._features[key] = output
        return record

    def forward(self, x):
        """Run the backbone; hooked activations are refreshed in `self._features`."""
        self._features = {}  # drop activations from the previous call
        return self.backbone(x)

print("ResNet18_FeatureHook class defined with hooks for layer1, layer2, layer3, layer4")

ResNet18_FeatureHook class defined with hooks for layer1, layer2, layer3, layer4


In [None]:
# -----------------------------------------------------------------------------
# Cell 1: Load Fresh Pretrained Model and Verify Hooks
# -----------------------------------------------------------------------------

print("Loading fresh pretrained ResNet-18 (ImageNet weights only)...")

pretrained_model = ResNet18_FeatureHook(num_classes=7).to(device)

# Reset fc layer (doesn't affect feature extraction)
nn.init.xavier_uniform_(pretrained_model.backbone.fc.weight)
nn.init.zeros_(pretrained_model.backbone.fc.bias)

pretrained_model.eval()

# VERIFY: Run a test forward pass and check that hooks are working
print("\nVerifying hooks are registered correctly...")
test_input = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    _ = pretrained_model(test_input)

print(f"Keys in _features after forward pass: {sorted(pretrained_model._features.keys())}")

# Check we have all expected keys
expected_keys = [f"{l}.{i}.bn2" for l in ["layer1", "layer2", "layer3", "layer4"] for i in [0, 1]]
missing_keys = [k for k in expected_keys if k not in pretrained_model._features]

if missing_keys:
    print(f"\n⚠️  WARNING: Missing hooks for: {missing_keys}")
    print("Please re-run the cell that defines ResNet18_FeatureHook class!")
else:
    print("\n✓ All hooks registered correctly!")
    print("Pretrained model ready. This model has NEVER seen PACS data.")

Loading fresh pretrained ResNet-18 (ImageNet weights only)...

Verifying hooks are registered correctly...
Keys in _features after forward pass: ['layer1.0.bn2', 'layer1.1.bn2', 'layer2.0.bn2', 'layer2.1.bn2', 'layer3.0.bn2', 'layer3.1.bn2', 'layer4.0.bn2', 'layer4.1.bn2']

✓ All hooks registered correctly!
Pretrained model ready. This model has NEVER seen PACS data.


In [None]:
# -----------------------------------------------------------------------------
# Cell 2: Feature Extraction with Per-Domain Tracking (with error handling)
# -----------------------------------------------------------------------------

@torch.no_grad()
def extract_gap_features_v2(model, domains, layer_name):
    """
    Extract globally-averaged-pooled features with domain tracking.

    Domain labels are the POSITION of each domain in `domains`, so they are
    always contiguous in [0, len(domains)) and line up with the returned
    `domain_names` list. A fixed global id table would produce labels such as
    1..3 for a 3-domain subset, which is out of range for a 3-way probe and
    misaligned with position-based per-domain masks. For the canonical order
    ["photo", "art_painting", "cartoon", "sketch"] the labels are unchanged.

    Args:
        model: The model to extract features from; must fill `model._features`
            on every forward pass.
        domains: List of domain names to process
        layer_name: Which layer to extract from (e.g., "layer4.1")

    Returns:
        X: (N, C) feature matrix
        y_class: (N,) class labels
        y_domain: (N,) domain labels (0-indexed position in `domains`)
        domain_names: List mapping domain index to name

    Raises:
        KeyError: If the requested layer was not hooked on `model`.
    """
    # The key format in _features
    feature_key = f"{layer_name}.bn2"

    X, y_class, y_domain = [], [], []
    model.eval()

    for d_id, domain in enumerate(domains):
        # Always use the deterministic test transform, unshuffled.
        loader = make_loader(domain, test_tf, shuffle=False)

        for x, y in loader:
            x = x.to(device)
            _ = model(x)

            # Check if the key exists
            if feature_key not in model._features:
                available_keys = list(model._features.keys())
                raise KeyError(
                    f"Key '{feature_key}' not found in model._features.\n"
                    f"Available keys: {sorted(available_keys)}\n"
                    f"Please re-run the cell that defines ResNet18_FeatureHook class."
                )

            feat = model._features[feature_key]
            h = feat.mean(dim=[2, 3]).cpu().numpy()  # Global Average Pooling

            X.append(h)
            y_class.append(y.numpy())
            y_domain.append(np.full(h.shape[0], d_id))

    return (
        np.concatenate(X),
        np.concatenate(y_class),
        np.concatenate(y_domain),
        domains,
    )

In [None]:
# -----------------------------------------------------------------------------
# Cell 3: Probe Training with Per-Domain Accuracy Breakdown
# -----------------------------------------------------------------------------
# This function trains a linear probe and computes accuracy both overall
# and broken down by domain.

def train_probe_with_breakdown(dataset, num_classes, y_domain_full, domain_names,
                                seed=0, epochs=50, train_frac=0.7):
    """
    Train a linear probe and return both overall and per-domain accuracy.

    The per-domain breakdown shows how well the probe classifies samples
    FROM each domain. For a domain probe, this tells us which domains are
    easiest/hardest to identify.

    Side effect: reseeds the global PyTorch and NumPy RNGs with `seed`.

    Args:
        dataset: ProjectedDataset with X and y
        num_classes: Number of output classes
        y_domain_full: Full array of domain labels (before train/val split);
            values are compared against positions in `domain_names`
        domain_names: List of domain name strings
        seed: Random seed
        epochs: Training epochs
        train_frac: Fraction for training (rest is validation)

    Returns:
        overall_acc: Float, overall validation accuracy
        per_domain_acc: Dict mapping domain name -> accuracy for that domain's samples
            (NaN when a domain has no validation samples)
    """
    # Seeds drive the split (NumPy) and the probe init / batch shuffling (PyTorch).
    torch.manual_seed(seed)
    np.random.seed(seed)

    n = len(dataset)
    n_train = int(train_frac * n)

    # Create shuffled indices
    indices = np.random.permutation(n)
    train_indices = indices[:n_train]
    val_indices = indices[n_train:]

    train_ds = torch.utils.data.Subset(dataset, train_indices.tolist())
    val_ds = torch.utils.data.Subset(dataset, val_indices.tolist())

    train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
    # Validation stays unshuffled so predictions align with val_indices order.
    val_loader = DataLoader(val_ds, batch_size=256, shuffle=False)

    # Initialize and train probe
    # NOTE(review): LinearProbe is not defined in the visible part of the
    # notebook; presumably a single linear layer defined in another cell.
    probe = LinearProbe(dataset.X.shape[1], num_classes).to(device)
    opt = torch.optim.Adam(probe.parameters(), lr=1e-2)
    loss_fn = nn.CrossEntropyLoss()

    probe.train()
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss_fn(probe(x), y).backward()
            opt.step()

    # Evaluate on validation set
    probe.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            pred = probe(x).argmax(1).cpu().numpy()
            all_preds.append(pred)
            all_labels.append(y.numpy())

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Get domain labels for validation samples
    # (same order as all_preds because val_loader is unshuffled)
    val_domain_labels = y_domain_full[val_indices]

    # Overall accuracy
    overall_acc = (all_preds == all_labels).mean()

    # Per-domain accuracy
    # For each domain, how accurate is the probe on samples FROM that domain?
    per_domain_acc = {}
    for d_idx, d_name in enumerate(domain_names):
        # A domain is identified by its position in domain_names.
        mask = val_domain_labels == d_idx
        if mask.sum() > 0:
            per_domain_acc[d_name] = (all_preds[mask] == all_labels[mask]).mean()
        else:
            per_domain_acc[d_name] = np.nan

    return overall_acc, per_domain_acc

In [None]:
# -----------------------------------------------------------------------------
# Cell 4: Main Probing Function with Full Reporting
# -----------------------------------------------------------------------------
# This runs the complete probe analysis for a given model, reporting both
# combined accuracy and per-domain breakdown.

def run_full_probe_analysis(model, layer_name, conv, domains, ranks, seeds,
                            probe_type="domain", exclude_photo=False):
    """
    Probe major / minor / random subspaces of one layer at several ranks.

    Features are extracted once; for each (rank, subspace kind, seed) they are
    projected and a linear probe is trained, then accuracies are aggregated
    over seeds.

    Args:
        model: Model to analyze
        layer_name: Layer name (e.g., "layer4.1")
        conv: The conv layer to get SVD from
        domains: List of domains to include
        ranks: List of ranks to test
        seeds: List of random seeds for averaging
        probe_type: "domain" or "class"
        exclude_photo: If True, exclude photo from analysis

    Returns:
        results: results[rank][kind] is a dict with 'overall_mean',
            'overall_std' and 'per_domain' (domain -> {'mean', 'std'}).
    """
    is_domain_probe = probe_type == "domain"

    if exclude_photo:
        domains = [name for name in domains if name != "photo"]

    # Domain probes predict one of the selected domains; class probes the 7 classes.
    if is_domain_probe:
        num_classes = len(domains)
    else:
        num_classes = 7

    U, S = get_conv_svd(conv)
    X, y_class, y_domain, domain_names = extract_gap_features_v2(model, domains, layer_name)
    y_labels = y_domain if is_domain_probe else y_class

    def summarize(values):
        # Mean/std over seeds, or NaN when no seed produced a value.
        if values:
            return {'mean': np.mean(values), 'std': np.std(values)}
        return {'mean': np.nan, 'std': np.nan}

    results = {}
    for r in ranks:
        rank_results = {}

        for kind in ("major", "minor", "random"):
            overall_accs = []
            domain_accs = {name: [] for name in domains}

            for seed in seeds:
                basis = get_subspace(U, kind, r, seed)
                projected = ProjectedDataset(X, y_labels, basis)

                overall_acc, per_domain = train_probe_with_breakdown(
                    projected,
                    num_classes=num_classes,
                    y_domain_full=y_domain,
                    domain_names=domains,
                    seed=seed,
                )

                overall_accs.append(overall_acc)
                for name in domains:
                    # Skip domains that were absent from this seed's validation split.
                    if name in per_domain and not np.isnan(per_domain[name]):
                        domain_accs[name].append(per_domain[name])

            rank_results[kind] = {
                'overall_mean': np.mean(overall_accs),
                'overall_std': np.std(overall_accs),
                'per_domain': {name: summarize(domain_accs[name]) for name in domains},
            }

        results[r] = rank_results

    return results


def print_probe_results(results, layer_name, probe_type, domains):
    """Print a per-rank table of probe accuracies plus the minor-major gap.

    Args:
        results: Nested dict as returned by run_full_probe_analysis.
        layer_name: Layer label shown in the header.
        probe_type: "domain" or "class"; the hypothesis verdict is printed only
            for "domain".
        domains: Domain names to list in the per-domain column (abbreviated to
            their first four characters; NaN entries are omitted).
    """
    rule = "=" * 70
    print(f"\n{rule}")
    print(f"{probe_type.upper()} PROBE @ {layer_name}")
    print(rule)

    for r, by_kind in results.items():
        print(f"\n--- Rank {r} ---")
        print(f"{'Subspace':<10} {'Overall':<18} | Per-Domain Accuracy")
        print("-" * 70)

        for kind in ("major", "minor", "random"):
            entry = by_kind[kind]
            overall_str = f"{entry['overall_mean']:.3f} ± {entry['overall_std']:.3f}"

            domain_means = [(d, entry['per_domain'][d]['mean']) for d in domains]
            shown = [
                f"{d[:4]}:{mean_acc:.2f}"
                for d, mean_acc in domain_means
                if not np.isnan(mean_acc)
            ]
            per_domain_str = ", ".join(shown)

            print(f"{kind:<10} {overall_str:<18} | {per_domain_str}")

        # Gap analysis: positive means the minor subspace probes better.
        gap = by_kind['minor']['overall_mean'] - by_kind['major']['overall_mean']
        gap_sign = "+" if gap > 0 else ""
        print(f"\n  Gap (minor - major): {gap_sign}{gap:.3f}")

        if probe_type != "domain":
            continue
        if gap > 0.02:
            print("  → Minor has MORE domain info (supports SoMA hypothesis)")
        elif gap < -0.02:
            print("  → Major has MORE domain info (contradicts SoMA hypothesis)")
        else:
            print("  → No clear difference")

In [None]:
# -----------------------------------------------------------------------------
# DIAGNOSTIC: Check what's in pretrained_model._features
# -----------------------------------------------------------------------------

print("Testing pretrained_model hooks...")
pretrained_model.eval()

# Run a single forward pass
test_batch = torch.randn(2, 3, 224, 224).to(device)
with torch.no_grad():
    _ = pretrained_model(test_batch)

print(f"\nNumber of keys: {len(pretrained_model._features)}")
print(f"Keys: {sorted(pretrained_model._features.keys())}")

# Check shapes
for key, feat in pretrained_model._features.items():
    print(f"  {key}: {feat.shape}")
# ```

# **Expected output:**
# ```
# Testing pretrained_model hooks...

# Number of keys: 8
# Keys: ['layer1.0.bn2', 'layer1.1.bn2', 'layer2.0.bn2', 'layer2.1.bn2', 'layer3.0.bn2', 'layer3.1.bn2', 'layer4.0.bn2', 'layer4.1.bn2']
#   layer1.0.bn2: torch.Size([2, 64, 56, 56])
#   layer1.1.bn2: torch.Size([2, 64, 56, 56])
#   layer2.0.bn2: torch.Size([2, 128, 28, 28])
  # ...

Testing pretrained_model hooks...

Number of keys: 8
Keys: ['layer1.0.bn2', 'layer1.1.bn2', 'layer2.0.bn2', 'layer2.1.bn2', 'layer3.0.bn2', 'layer3.1.bn2', 'layer4.0.bn2', 'layer4.1.bn2']
  layer1.0.bn2: torch.Size([2, 64, 56, 56])
  layer1.1.bn2: torch.Size([2, 64, 56, 56])
  layer2.0.bn2: torch.Size([2, 128, 28, 28])
  layer2.1.bn2: torch.Size([2, 128, 28, 28])
  layer3.0.bn2: torch.Size([2, 256, 14, 14])
  layer3.1.bn2: torch.Size([2, 256, 14, 14])
  layer4.0.bn2: torch.Size([2, 512, 7, 7])
  layer4.1.bn2: torch.Size([2, 512, 7, 7])


In [None]:
# -----------------------------------------------------------------------------
# Cell 5: Run Scenario 1 - Domain Probe on PRETRAINED Model (All 4 Domains)
# -----------------------------------------------------------------------------
# This probes the pretrained model on ALL conv blocks to see if domain info
# is naturally concentrated in the minor subspace BEFORE any finetuning.

print("=" * 70)
print("SCENARIO 1: Probing PRETRAINED ImageNet Model")
print("Testing if pretrained spectral structure already separates domain info")
print("=" * 70)

# Configuration
domains_all = ["photo", "art_painting", "cartoon", "sketch"]
ranks = [4, 8, 16]
seeds = [0, 1, 2]

# Get conv blocks from pretrained model (ALL layers)
pretrained_conv_blocks = {}
for lname in ["layer1", "layer2", "layer3", "layer4"]:
    layer = getattr(pretrained_model.backbone, lname)
    for i, block in enumerate(layer):
        key = f"{lname}.{i}"
        pretrained_conv_blocks[key] = block.conv2

# Store all results
scenario1_domain_results = {}

print("\n===== DOMAIN PROBES (ALL CONV BLOCKS) - PRETRAINED MODEL =====")

# Loop over ALL conv blocks
for layer_name, conv in pretrained_conv_blocks.items():

    results = run_full_probe_analysis(
        model=pretrained_model,
        layer_name=layer_name,
        conv=conv,
        domains=domains_all,
        ranks=ranks,
        seeds=seeds,
        probe_type="domain",
        exclude_photo=False
    )

    scenario1_domain_results[layer_name] = results
    print_probe_results(results, layer_name, "domain", domains_all)

SCENARIO 1: Probing PRETRAINED ImageNet Model
Testing if pretrained spectral structure already separates domain info

===== DOMAIN PROBES (ALL CONV BLOCKS) - PRETRAINED MODEL =====

DOMAIN PROBE @ layer1.0

--- Rank 4 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.746 ± 0.009      | phot:0.33, art_:0.70, cart:0.66, sket:1.00
minor      0.696 ± 0.012      | phot:0.41, art_:0.73, cart:0.34, sket:1.00
random     0.655 ± 0.023      | phot:0.23, art_:0.62, cart:0.41, sket:1.00

  Gap (minor - major): -0.051
  → Major has MORE domain info (contradicts SoMA hypothesis)

--- Rank 8 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.790 ± 0.002      | phot:0.49, art_:0.67, cart:0.76, sket:0.99
minor      0.777 ± 0.013      | phot:0.65, art_:0.69, cart:0.56, sket:1.00
random     0.777 ± 0.002      | phot:0.63, art_:0

In [None]:
# -----------------------------------------------------------------------------
# Cell 7: Run Scenario 1 - Class Probe on PRETRAINED Model (All Conv Blocks)
# -----------------------------------------------------------------------------

print("\n" + "=" * 70)
print("SCENARIO 1: Class Probe on PRETRAINED Model - ALL CONV BLOCKS")
print("=" * 70)

scenario1_class_results = {}

for layer_name, conv in pretrained_conv_blocks.items():

    results = run_full_probe_analysis(
        model=pretrained_model,
        layer_name=layer_name,
        conv=conv,
        domains=domains_all,
        ranks=ranks,
        seeds=seeds,
        probe_type="class",
        exclude_photo=False
    )

    scenario1_class_results[layer_name] = results
    print_probe_results(results, layer_name, "class", domains_all)


SCENARIO 1: Class Probe on PRETRAINED Model - ALL CONV BLOCKS

CLASS PROBE @ layer1.0

--- Rank 4 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.245 ± 0.007      | phot:0.30, art_:0.29, cart:0.27, sket:0.19
minor      0.266 ± 0.007      | phot:0.25, art_:0.28, cart:0.28, sket:0.26
random     0.256 ± 0.017      | phot:0.29, art_:0.27, cart:0.29, sket:0.22

  Gap (minor - major): +0.021

--- Rank 8 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.294 ± 0.013      | phot:0.34, art_:0.31, cart:0.31, sket:0.26
minor      0.299 ± 0.011      | phot:0.29, art_:0.29, cart:0.31, sket:0.30
random     0.319 ± 0.021      | phot:0.36, art_:0.30, cart:0.37, sket:0.29

  Gap (minor - major): +0.006

--- Rank 16 ---
Subspace   Overall            | Per-Domain Accuracy
------------------------------------------------------

# Tier-2: ΔW-Based Domain-Sensitive Subspace Discovery
Starting from the ImageNet-pretrained base, fine-tune on one domain at a time and analyze the resulting weight change ΔW = W_adapted − W_pretrained.

In [None]:
import copy
import numpy as np
import torch
import torch.nn as nn
from torchvision import models

device = "cuda" if torch.cuda.is_available() else "cpu"

# =============================================================================
# Load PRETRAINED model (ImageNet only, never seen PACS)
# =============================================================================

def load_pretrained_model():
    """Build a fresh ImageNet-pretrained ResNet-18 wrapper (no PACS fine-tuning).

    The backbone weights come from ResNet18_FeatureHook.__init__; only the
    7-way `fc` head is re-initialised here (Xavier-uniform weight, zero bias).

    Returns:
        The model, moved to `device`.
    """
    net = ResNet18_FeatureHook(num_classes=7).to(device)
    head = net.backbone.fc
    nn.init.xavier_uniform_(head.weight)
    nn.init.zeros_(head.bias)
    return net


def load_finetuned_model():
    """Rebuild the model and restore the checkpoint stored at BACKBONE_PATH.

    Returns:
        The model on `device` with the saved state dict loaded.
    """
    net = ResNet18_FeatureHook(num_classes=7).to(device)
    checkpoint = torch.load(BACKBONE_PATH, map_location=device)
    net.load_state_dict(checkpoint)
    return net


In [None]:
# =============================================================================
# Train single domain starting from PRETRAINED (not finetuned)
# =============================================================================

def freeze_bn_stats(model):
    """Put every BatchNorm2d in eval mode while keeping its affine params trainable.

    In eval mode a BatchNorm layer uses (and does not update) its running
    mean/var. A later `model.train()` call switches these layers back to train
    mode, so call this after entering train mode.

    Args:
        model: Module whose BatchNorm2d sub-modules are modified in place.
    """
    bn_layers = [mod for mod in model.modules() if isinstance(mod, nn.BatchNorm2d)]
    for bn in bn_layers:
        bn.eval()  # freeze running mean/var
        for affine_param in (bn.weight, bn.bias):
            affine_param.requires_grad = True

from collections import defaultdict
def make_balanced_loader(domain, tf, batch_size):
    """Build a class-balanced DataLoader for one PACS domain.

    Every class contributes the same number of images: the size of the
    smallest class. For larger classes the first entries in ImageFolder's
    sample order are kept; the kept subset is then drawn in random order.

    Args:
        domain: Sub-folder name under PACS_ROOT.
        tf: Transform applied to every image.
        batch_size: Batch size of the returned loader.

    Returns:
        DataLoader sampling only the balanced subset of indices.
    """
    dataset = datasets.ImageFolder(
        root=os.path.join(PACS_ROOT, domain),
        transform=tf,
    )

    # label -> dataset indices, in ImageFolder order
    per_class = {}
    for sample_idx, (_, label) in enumerate(dataset.samples):
        per_class.setdefault(label, []).append(sample_idx)

    # Truncate every class to the size of the rarest one.
    quota = min(len(members) for members in per_class.values())
    kept = [idx for members in per_class.values() for idx in members[:quota]]

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(kept),
        num_workers=NUM_WORKERS,
    )

def train_single_domain_from_pretrained(domain, epochs=3, lr=1e-4):
    """
    Start from pretrained ImageNet model and train on single domain.
    This captures: "How does pretrained → domain d adaptation occur?"

    BatchNorm running statistics are kept frozen during adaptation; the BN
    affine parameters and all other weights are trained.

    Args:
        domain: PACS domain folder to adapt to.
        epochs: Number of passes over the class-balanced loader.
        lr: Adam learning rate.

    Returns:
        The adapted model (left in train mode with BatchNorm layers in eval mode).
    """
    model = load_pretrained_model()  # PRETRAINED, not finetuned!

    loader = make_balanced_loader(domain, train_tf, BATCH_SIZE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Order matters: model.train() recursively puts EVERY sub-module, BatchNorm
    # included, into train mode. Freezing must therefore happen afterwards,
    # otherwise the running mean/var would still be updated during training.
    model.train()
    freeze_bn_stats(model)

    for epoch in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    return model


In [None]:
def build_deltaW_channel_subspace(dW_list):
    """
    Eigen-decompose the summed output-channel Gram matrix of several ΔW matrices.

    dW_list: list of ΔW matrices for one layer
             each ΔW has shape [C_out, D]

    Returns:
        U_domain: [C_out, C_out] eigenvectors as columns (descending eigenvalue)
        S_domain: eigenvalues, descending
    """
    n_channels = dW_list[0].shape[0]

    # Sum of ΔW ΔWᵀ over all matrices (uncentered channel-channel second moment).
    gram = np.zeros((n_channels, n_channels))
    for delta in dW_list:
        gram += delta @ delta.T

    # eigh returns ascending eigenvalues for a symmetric matrix; flip the order.
    eigvals, eigvecs = np.linalg.eigh(gram)
    order = np.argsort(eigvals)[::-1]

    return eigvecs[:, order], eigvals[order]


def compute_alignment_metrics(U, V):
    """
    Principal-angle alignment between the subspaces spanned by U and V.

    Args:
        U, V: [C_out, r] orthonormal basis matrices (one basis vector per column).

    Returns:
        mean_angle:  mean principal angle, in degrees
        overlap:     mean squared cosine of the principal angles
        proj_energy: ||U^T V||_F^2 divided by the number of columns of V

    Note:
        A later cell in this notebook defines another function with this same
        name that returns only (mean_angle, overlap); whichever cell ran last
        is the one in effect.
    """
    cross = U.T @ V  # [r, r] matrix of pairwise inner products

    # The singular values of U^T V are the cosines of the principal angles;
    # clipping guards arccos against round-off slightly outside [-1, 1].
    cosines = np.clip(np.linalg.svd(cross, compute_uv=False), -1.0, 1.0)

    mean_angle = np.degrees(np.arccos(cosines)).mean()
    overlap = np.mean(cosines ** 2)
    proj_energy = np.linalg.norm(cross, ord="fro") ** 2 / V.shape[1]

    return mean_angle, overlap, proj_energy


# =============================================================================
# Compute ΔW from pretrained model
# =============================================================================

def extract_delta_W_from_pretrained(pretrained_model, adapted_model, conv_path):
    """
    Compute ΔW = W_adapted - W_pretrained for one convolution.

    Args:
        pretrained_model: reference model; must expose `.backbone`.
        adapted_model: fine-tuned model with the same layout.
        conv_path: (stage attribute name, block index, conv attribute name).

    Returns:
        NumPy array of the weight difference flattened to [C_out, -1].
    """
    stage, block_idx, conv_name = conv_path[0], conv_path[1], conv_path[2]

    def _conv_weights(model):
        # Walk backbone.<stage>[block_idx].<conv_name> and pull weights to NumPy.
        block = getattr(model.backbone, stage)[block_idx]
        return getattr(block, conv_name).weight.detach().cpu().numpy()

    w_base = _conv_weights(pretrained_model)
    w_adapted = _conv_weights(adapted_model)

    delta = w_adapted - w_base
    return delta.reshape(delta.shape[0], -1)



In [None]:


# =============================================================================
# Run the corrected Tier 2 analysis
# =============================================================================

# Layers to probe: display name -> (stage attribute, block index, conv attribute)
conv_layers = {
    "layer2.1": ("layer2", 1, "conv2"),
    "layer3.1": ("layer3", 1, "conv2"),
    "layer4.1": ("layer4", 1, "conv2"),
}

# Reference weights: every ΔW below is measured against this model
pretrained_model = load_pretrained_model()

# One list of ΔW matrices per probed layer (one entry per source domain)
deltaWs_from_pretrained = {lname: [] for lname in conv_layers}

print("Computing ΔW from PRETRAINED model to each domain...")
for domain in SOURCE_DOMAINS:
    print(f"  Training on domain: {domain}")
    adapted = train_single_domain_from_pretrained(domain, epochs=3)

    for lname, path in conv_layers.items():
        deltaWs_from_pretrained[lname].append(
            extract_delta_W_from_pretrained(pretrained_model, adapted, path)
        )

# Per layer: (eigenvectors, eigenvalues) of the summed ΔW ΔW^T matrix
delta_subspaces_pretrained = {
    lname: build_deltaW_channel_subspace(dW_list)
    for lname, dW_list in deltaWs_from_pretrained.items()
}

print("Constructed ΔW subspaces from pretrained base.")


Computing ΔW from PRETRAINED model to each domain...
  Training on domain: photo
  Training on domain: art_painting
  Training on domain: cartoon
Constructed ΔW subspaces from pretrained base.


In [None]:
# =============================================================================
# Get SVD subspaces from PRETRAINED model (what SoMA would actually use)
# =============================================================================

def get_soma_subspace_from_model(model, layer_path, kind, r):
    """
    Get a major/minor/random rank-r output-channel subspace for one conv layer.

    Args:
        model: model exposing `.backbone`.
        layer_path: (stage attribute name, block index, conv attribute name).
        kind: "major" (leading left singular vectors of the flattened weight),
              "minor" (trailing left singular vectors) or
              "random" (QR-orthonormalised Gaussian matrix drawn from the
              global NumPy RNG — seed it beforehand for reproducibility).
        r: number of basis vectors requested.

    Returns:
        Array with C_out rows whose columns are the basis vectors.

    Raises:
        ValueError: if `kind` is not one of the three supported values.
    """
    conv = getattr(
        getattr(model.backbone, layer_path[0])[layer_path[1]],
        layer_path[2]
    )

    W = conv.weight.detach().cpu().numpy()
    C_out = W.shape[0]

    if kind == "random":
        # The random basis does not depend on the weights, so skip the SVD.
        Q, _ = np.linalg.qr(np.random.randn(C_out, r))
        return Q

    if kind not in ("minor", "major"):
        raise ValueError(f"Unknown kind: {kind}")

    U, _, _ = np.linalg.svd(W.reshape(C_out, -1), full_matrices=False)

    if kind == "major":
        return U[:, :r]

    # An explicit start index avoids the `-0:` pitfall: U[:, -r:] with r == 0
    # would return *all* columns instead of none.
    return U[:, max(U.shape[1] - r, 0):]

In [None]:
# =============================================================================
# Compute alignment between ΔW subspace and SVD subspaces of PRETRAINED model
# =============================================================================

print("\n" + "="*70)
print("APPROACH A: ΔW from Pretrained, SVD from Pretrained")
print("="*70)

ranks = [4, 8, 16]
n_random = 10  # number of random bases averaged per (layer, rank)

# alignment_results_pretrained[layer][rank][kind] ->
#     {"angle": (mean, std), "overlap": (mean, std)}
alignment_results_pretrained = {}

for lname, (U_domain, S_domain) in delta_subspaces_pretrained.items():
    print(f"\n=== Alignment @ {lname} ===")
    alignment_results_pretrained[lname] = {}

    for r in ranks:
        print(f"\nRank {r}")
        V = U_domain[:, :r]  # Top-r ΔW directions

        alignment_results_pretrained[lname][r] = {}

        for kind in ["minor", "major", "random"]:
            angles, overlaps = [], []

            # Only the random basis changes with the seed; the major/minor
            # bases are a deterministic function of the pretrained weights,
            # so evaluating them once gives the same mean (and zero spread)
            # without repeating the SVD n_random times.
            n_trials = n_random if kind == "random" else 1

            for seed in range(n_trials):
                np.random.seed(seed)
                # SVD from PRETRAINED model (what SoMA actually uses)
                U = get_soma_subspace_from_model(pretrained_model, conv_layers[lname], kind, r)

                # Index instead of unpacking: a later cell redefines
                # compute_alignment_metrics to return only (angle, overlap),
                # and this cell must keep working if it is re-run afterwards.
                metrics = compute_alignment_metrics(U, V)
                angles.append(metrics[0])
                overlaps.append(metrics[1])

            alignment_results_pretrained[lname][r][kind] = {
                "angle": (np.mean(angles), np.std(angles)),
                "overlap": (np.mean(overlaps), np.std(overlaps)),
            }

            print(f"  {kind:>6} | angle={np.mean(angles):.1f}° | overlap={np.mean(overlaps):.3f}")


APPROACH A: ΔW from Pretrained, SVD from Pretrained

=== Alignment @ layer2.1 ===

Rank 4
   minor | angle=83.4° | overlap=0.018
   major | angle=81.7° | overlap=0.036
  random | angle=81.8° | overlap=0.030

Rank 8
   minor | angle=81.0° | overlap=0.032
   major | angle=78.3° | overlap=0.062
  random | angle=77.4° | overlap=0.066

Rank 16
   minor | angle=76.9° | overlap=0.070
   major | angle=69.4° | overlap=0.157
  random | angle=71.9° | overlap=0.128

=== Alignment @ layer3.1 ===

Rank 4
   minor | angle=85.8° | overlap=0.007
   major | angle=81.0° | overlap=0.037
  random | angle=84.1° | overlap=0.016

Rank 8
   minor | angle=83.9° | overlap=0.016
   major | angle=76.0° | overlap=0.086
  random | angle=81.1° | overlap=0.033

Rank 16
   minor | angle=80.8° | overlap=0.037
   major | angle=72.3° | overlap=0.125
  random | angle=77.6° | overlap=0.062

=== Alignment @ layer4.1 ===

Rank 4
   minor | angle=85.8° | overlap=0.008
   major | angle=86.5° | overlap=0.005
  random | angle=85

In [None]:
# =============================================================================
# VALIDATION: Does held-out domain (Sketch) align with ΔW subspace from 3 sources?
# =============================================================================

HELD_OUT_DOMAIN = "sketch"

print("\n" + "="*70)
print("VALIDATION: Does Sketch's ΔW align with subspace from Photo/Art/Cartoon?")
print("="*70)

# Train on the held-out domain, starting from the pretrained weights
print(f"\nTraining on held-out domain: {HELD_OUT_DOMAIN}")
adapted_sketch = train_single_domain_from_pretrained(HELD_OUT_DOMAIN, epochs=3)

# ΔW of the held-out domain for every probed layer
deltaW_sketch = {}
for lname, path in conv_layers.items():
    dW = extract_delta_W_from_pretrained(pretrained_model, adapted_sketch, path)
    deltaW_sketch[lname] = dW

# For each layer, check how much of the held-out ΔW lies inside the subspace
# derived from the three source domains.
print("\nAlignment of Sketch's ΔW with 3-domain derived subspace:")

for lname in conv_layers.keys():
    print(f"\n=== {lname} ===")

    # Eigenvectors (descending) of the 3-source-domain ΔW matrix
    U_3domain, S_3domain = delta_subspaces_pretrained[lname]

    dW_sketch = deltaW_sketch[lname]

    # Squared Frobenius norm of the full update; it does not depend on the
    # rank, so compute it once per layer instead of once per rank.
    total_variance = np.linalg.norm(dW_sketch, 'fro')**2

    for r in [4, 8, 16]:
        U_sub = U_3domain[:, :r]  # Top-r directions from 3 domains

        # Orthogonal projection of the held-out ΔW onto span(U_sub)
        dW_projected = U_sub @ (U_sub.T @ dW_sketch)

        # Fraction of squared Frobenius norm retained by the projection
        projected_variance = np.linalg.norm(dW_projected, 'fro')**2
        fraction_captured = projected_variance / total_variance

        print(f"  Rank {r}: {fraction_captured:.1%} of Sketch's ΔW variance captured")

# Also compute principal angles between Sketch's ΔW and 3-domain subspace
print("\nPrincipal angles between Sketch's ΔW direction and 3-domain subspace:")

for lname in conv_layers.keys():
    print(f"\n=== {lname} ===")

    U_3domain, _ = delta_subspaces_pretrained[lname]
    dW_sketch = deltaW_sketch[lname]

    # Left singular vectors of the held-out ΔW (leading directions first)
    U_sketch, S_sketch, _ = np.linalg.svd(dW_sketch, full_matrices=False)

    for r in [4, 8, 16]:
        V_3domain = U_3domain[:, :r]
        V_sketch = U_sketch[:, :r]

        # Index instead of unpacking: a later cell redefines
        # compute_alignment_metrics to return only (angle, overlap), and this
        # cell must keep working if it is re-run afterwards.
        metrics = compute_alignment_metrics(V_sketch, V_3domain)
        mean_angle, overlap = metrics[0], metrics[1]
        print(f"  Rank {r}: angle={mean_angle:.1f}°, overlap={overlap:.3f}")


VALIDATION: Does Sketch's ΔW align with subspace from Photo/Art/Cartoon?

Training on held-out domain: sketch

Alignment of Sketch's ΔW with 3-domain derived subspace:

=== layer2.1 ===
  Rank 4: 7.2% of Sketch's ΔW variance captured
  Rank 8: 12.7% of Sketch's ΔW variance captured
  Rank 16: 23.3% of Sketch's ΔW variance captured

=== layer3.1 ===
  Rank 4: 2.9% of Sketch's ΔW variance captured
  Rank 8: 6.3% of Sketch's ΔW variance captured
  Rank 16: 11.7% of Sketch's ΔW variance captured

=== layer4.1 ===
  Rank 4: 0.5% of Sketch's ΔW variance captured
  Rank 8: 1.2% of Sketch's ΔW variance captured
  Rank 16: 2.3% of Sketch's ΔW variance captured

Principal angles between Sketch's ΔW direction and 3-domain subspace:

=== layer2.1 ===
  Rank 4: angle=68.4°, overlap=0.158
  Rank 8: angle=67.6°, overlap=0.183
  Rank 16: angle=60.2°, overlap=0.295

=== layer3.1 ===
  Rank 4: angle=80.8°, overlap=0.034
  Rank 8: angle=75.1°, overlap=0.089
  Rank 16: angle=69.0°, overlap=0.167

=== lay

In [None]:
# -----------------------------------------------------------------------------
# Get Major/Minor subspaces from pretrained model
# -----------------------------------------------------------------------------

def get_svd_subspaces(model, conv_path):
    """
    Thin SVD of one conv layer's weight, flattened to [C_out, -1].

    Args:
        model: model exposing `.backbone`.
        conv_path: (stage attribute name, block index, conv attribute name).

    Returns:
        (U, S, Vh) as produced by np.linalg.svd(..., full_matrices=False).
    """
    block = getattr(model.backbone, conv_path[0])[conv_path[1]]
    weights = getattr(block, conv_path[2]).weight.detach().cpu().numpy()

    # One row per output channel
    flat = weights.reshape(weights.shape[0], -1)

    left, singular_values, right_h = np.linalg.svd(flat, full_matrices=False)
    return left, singular_values, right_h

# Cache the SVD factors of every probed layer of the pretrained model
pretrained_svd = {}
for lname in conv_layers:
    U, S, Vh = get_svd_subspaces(pretrained_model, conv_layers[lname])
    pretrained_svd[lname] = dict(U=U, S=S, Vh=Vh)
    print(f"  {lname}: U shape = {U.shape}, rank = {len(S)}")


def compute_alignment_metrics(U, V):
    """
    Compute alignment between two subspaces via their principal angles.

    NOTE: this definition shadows the earlier `compute_alignment_metrics` in
    this notebook, which returns a third value (projection energy). Once this
    cell has run, callers receive only two values.

    Args:
        U, V: [d, r] orthonormal bases (one basis vector per column).

    Returns:
        mean_angle: mean principal angle in degrees
        overlap:    mean cos²θ over the principal angles
    """
    # Singular values of U^T V are the cosines of the principal angles;
    # clip to keep arccos well-defined under round-off.
    cosines = np.clip(np.linalg.svd(U.T @ V, compute_uv=False), -1.0, 1.0)

    mean_angle = np.degrees(np.arccos(cosines)).mean()
    overlap = np.mean(cosines ** 2)
    return mean_angle, overlap

def get_deltaW_principal_directions(dW, r):
    """Return the r leading left singular vectors of ΔW, one per column."""
    left_vectors = np.linalg.svd(dW, full_matrices=False)[0]
    return left_vectors[:, :r]

print("\n" + "="*70)
print("SANITY CHECK: Sketch's ΔW Alignment with Pretrained SVD Subspaces")
print("="*70)

ranks = [4, 8, 16]

for lname in conv_layers:
    print(f"\n=== {lname} ===")

    dW_sketch = deltaW_sketch[lname]
    U_pretrained = pretrained_svd[lname]['U']
    d = U_pretrained.shape[0]  # Output dimension

    for r in ranks:
        # Leading directions of the held-out domain's ΔW
        V_sketch = get_deltaW_principal_directions(dW_sketch, r)

        # Candidate bases taken from the pretrained SVD:
        # top-r, bottom-r and a centred window of r singular vectors
        mid_start = d // 2 - r // 2
        U_major, U_minor = U_pretrained[:, :r], U_pretrained[:, -r:]
        U_middle = U_pretrained[:, mid_start:mid_start + r]

        # Random orthonormal reference (fixed seed, redrawn for every rank)
        np.random.seed(42)
        Q_random, _ = np.linalg.qr(np.random.randn(d, r))

        # Alignment of each candidate basis with the ΔW directions
        angle_major, overlap_major = compute_alignment_metrics(U_major, V_sketch)
        angle_minor, overlap_minor = compute_alignment_metrics(U_minor, V_sketch)
        angle_middle, overlap_middle = compute_alignment_metrics(U_middle, V_sketch)
        angle_random, overlap_random = compute_alignment_metrics(Q_random, V_sketch)

        print(f"\n  Rank {r}:")
        print(f"    Major:  angle={angle_major:.1f}°, overlap={overlap_major:.4f}")
        print(f"    Minor:  angle={angle_minor:.1f}°, overlap={overlap_minor:.4f}")
        print(f"    Middle: angle={angle_middle:.1f}°, overlap={overlap_middle:.4f}")
        print(f"    Random: angle={angle_random:.1f}°, overlap={overlap_random:.4f}")

  layer2.1: U shape = (128, 128), rank = 128
  layer3.1: U shape = (256, 256), rank = 256
  layer4.1: U shape = (512, 512), rank = 512

SANITY CHECK: Sketch's ΔW Alignment with Pretrained SVD Subspaces

=== layer2.1 ===

  Rank 4:
    Major:  angle=78.6°, overlap=0.0509
    Minor:  angle=83.2°, overlap=0.0190
    Middle: angle=81.1°, overlap=0.0348
    Random: angle=82.1°, overlap=0.0273

  Rank 8:
    Major:  angle=77.1°, overlap=0.0703
    Minor:  angle=79.4°, overlap=0.0452
    Middle: angle=78.3°, overlap=0.0546
    Random: angle=78.1°, overlap=0.0647

  Rank 16:
    Major:  angle=70.0°, overlap=0.1474
    Minor:  angle=75.4°, overlap=0.0879
    Middle: angle=72.5°, overlap=0.1214
    Random: angle=71.9°, overlap=0.1309

=== layer3.1 ===

  Rank 4:
    Major:  angle=84.3°, overlap=0.0152
    Minor:  angle=84.7°, overlap=0.0174
    Middle: angle=85.3°, overlap=0.0102
    Random: angle=83.5°, overlap=0.0157

  Rank 8:
    Major:  angle=79.1°, overlap=0.0494
    Minor:  angle=82.0°, o

# EXPERIMENT 3: Subspace Comparison for Domain Generalization


In [None]:
# =============================================================================
# EXPERIMENT 3: Subspace Comparison for Domain Generalization
# =============================================================================
#
# This experiment definitively tests whether the MINOR subspace is special,
# or whether ANY low-rank constraint provides similar benefits.
#
# Variants:
#   - SoMA-Minor:  Constrain updates to bottom-r singular vectors
#   - SoMA-Major:  Constrain updates to top-r singular vectors
#   - SoMA-Middle: Constrain updates to middle-r singular vectors
#   - SoMA-Random: Constrain updates to random orthonormal basis
#   - Full-FT:     No constraint (baseline)
#
# Protocol:
#   - Train on 3 source domains (Photo, Art, Cartoon)
#   - Evaluate on held-out target domain (Sketch)
#   - Use rank=8 for all constrained variants
#
# =============================================================================

import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
from tqdm import tqdm
import os

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

RANK = 8  # Adapter rank for all constrained variants
EPOCHS = 10  # Training epochs per run (see run_single_experiment)
LR = 1e-4  # Adam learning rate
BATCH_SIZE = 32  # Batch size for every DataLoader built in this experiment
NUM_WORKERS = 2  # DataLoader worker processes

# Domains used for training vs. the held-out domain used only for evaluation
SOURCE_DOMAINS = ["photo", "art_painting", "cartoon"]
TARGET_DOMAIN = "sketch"

# Layers to apply adapters (following SoMA paper - typically mid-to-late layers)
# Each entry is (stage attribute on the backbone, block index, conv attribute).
ADAPTER_LAYERS = [
    ("layer2", 0, "conv2"),
    ("layer2", 1, "conv2"),
    ("layer3", 0, "conv2"),
    ("layer3", 1, "conv2"),
    ("layer4", 0, "conv2"),
    ("layer4", 1, "conv2"),
]

# Random seeds for reproducibility (every variant is run once per seed)
SEEDS = [42, 123, 456]


# -----------------------------------------------------------------------------
# Data Loading
# -----------------------------------------------------------------------------

IMG_SIZE = 224

# Per-channel normalization statistics shared by both pipelines
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random horizontal flip, tensor conversion, normalize
train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Evaluation pipeline: identical, minus the random flip
test_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])


def get_domain_dataset(domain, transform):
    """Build an ImageFolder dataset over the `domain` sub-directory of PACS_ROOT."""
    domain_dir = os.path.join(PACS_ROOT, domain)
    return datasets.ImageFolder(root=domain_dir, transform=transform)


def get_source_loader(batch_size=BATCH_SIZE):
    """Shuffled training loader over the concatenation of all source domains."""
    combined = ConcatDataset(
        [get_domain_dataset(name, train_tf) for name in SOURCE_DOMAINS]
    )
    loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    return DataLoader(combined, **loader_kwargs)


def get_target_loader(batch_size=BATCH_SIZE):
    """Unshuffled loader over the target domain, built with the test transform."""
    loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)
    return DataLoader(get_domain_dataset(TARGET_DOMAIN, test_tf), **loader_kwargs)


# -----------------------------------------------------------------------------
# Subspace-Constrained Adapter (LoRA-style)
# -----------------------------------------------------------------------------

class SubspaceAdapter(nn.Module):
    """
    Wrap a Conv2d and add a low-rank update confined to a fixed output subspace.

    The wrapped layer computes ``conv(x) + scale * (B @ A) * x`` where the
    product ``B @ A`` acts as an extra convolution kernel of rank at most r:

    * ``B`` ([C_out, r]) is the supplied basis, registered as a buffer so it
      is never trained — this is what enforces the subspace constraint.
    * ``A`` ([r, C_in * k1 * k2]) is the only parameter this module adds.

    Args:
        conv_layer: the convolution to wrap (kept as ``self.conv``).
        subspace_basis: NumPy array of shape [C_out, r].
        scale: multiplier applied to the adapter branch.

    Raises:
        ValueError: if the convolution is grouped, or if the basis does not
            have one row per output channel.
    """

    def __init__(self, conv_layer, subspace_basis, scale=1.0):
        super().__init__()

        # The unfold-based adapter path below assumes a dense convolution.
        if getattr(conv_layer, "groups", 1) != 1:
            raise ValueError("SubspaceAdapter only supports convolutions with groups=1")

        C_out, C_in, k1, k2 = conv_layer.weight.shape
        if subspace_basis.ndim != 2 or subspace_basis.shape[0] != C_out:
            raise ValueError(
                f"subspace_basis must have shape [{C_out}, r], got {tuple(subspace_basis.shape)}"
            )

        self.conv = conv_layer
        self.scale = scale

        r = subspace_basis.shape[1]

        self.C_out = C_out
        self.C_in = C_in
        self.k = k1  # kept for backward compatibility; see kernel_size
        self.kernel_size = (k1, k2)
        self.r = r
        # Mirror the wrapped conv's geometry so both branches produce the
        # same spatial output size.
        self.stride = conv_layer.stride
        self.padding = conv_layer.padding
        self.dilation = getattr(conv_layer, "dilation", 1)

        # A is trainable (down-projection)
        # NOTE(review): A starts at small random values, so the wrapped layer
        # already differs slightly from the original conv before any training.
        # Zero-initialising A would make the adapter start as an exact no-op —
        # confirm which behaviour is intended.
        self.A = nn.Parameter(torch.randn(r, C_in * k1 * k2) * 0.01)

        # B is FROZEN (enforces subspace constraint)
        self.register_buffer('B', torch.from_numpy(subspace_basis.astype(np.float32)))

    def forward(self, x):
        # Original convolution output
        out = self.conv(x)

        # Adapter path: express the convolution as a matrix product over
        # unfolded patches. x_unf is (N, C_in*k1*k2, H_out*W_out).
        x_unf = F.unfold(
            x,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )

        # Apply A: (N, r, H_out*W_out)
        adapter_out = self.A @ x_unf

        # Apply B: (N, C_out, H_out*W_out)
        adapter_out = self.B @ adapter_out

        # Fold the flattened spatial axis back: (N, C_out, H_out, W_out)
        adapter_out = adapter_out.reshape(out.shape)

        return out + self.scale * adapter_out


# -----------------------------------------------------------------------------
# Model with Subspace Adapters
# -----------------------------------------------------------------------------

class ResNet18WithAdapters(nn.Module):
    """
    ResNet-18 with subspace-constrained adapters on specified layers.

    Each selected convolution inside the backbone is replaced by a
    `SubspaceAdapter` whose frozen basis comes from the SVD of that
    convolution's own (pretrained) weights, or from a seeded random
    orthonormal matrix.
    """

    def __init__(self, num_classes, adapter_layers, subspace_type, rank, random_seed=42):
        """
        Args:
            num_classes: Number of output classes
            adapter_layers: List of (layer_name, block_idx, conv_name) tuples
            subspace_type: 'minor', 'major', 'middle', 'random', or 'none'
            rank: Rank of adapters
            random_seed: Seed for random subspace generation
        """
        super().__init__()

        from torchvision import models

        # Load pretrained backbone and swap in a classifier head of the right size
        self.backbone = models.resnet18(weights="IMAGENET1K_V1")
        self.backbone.fc = nn.Linear(512, num_classes)

        self.subspace_type = subspace_type
        self.rank = rank
        self.random_seed = random_seed
        self.adapter_layers = adapter_layers
        self.adapters = nn.ModuleDict()

        # Store the random bases used (for reporting)
        self.random_bases = {}

        # 'none' leaves the backbone untouched (no adapters inserted)
        if subspace_type != 'none':
            self._add_adapters()

    def _get_conv_layer(self, layer_name, block_idx, conv_name):
        """Get reference to a specific conv layer"""
        layer = getattr(self.backbone, layer_name)
        block = layer[block_idx]
        return getattr(block, conv_name)

    def _set_conv_layer(self, layer_name, block_idx, conv_name, new_module):
        """Replace a specific conv layer with adapter"""
        layer = getattr(self.backbone, layer_name)
        block = layer[block_idx]
        setattr(block, conv_name, new_module)

    def _compute_subspace_basis(self, conv_layer, layer_key):
        """
        Compute the [C_out, rank] subspace basis for a given conv layer.

        Args:
            conv_layer: convolution whose weights define the SVD subspaces.
            layer_key: string identifying the layer; for 'random' it is mixed
                into the seed so each layer gets its own basis.

        Returns:
            NumPy array whose `self.rank` columns span the chosen subspace.

        Raises:
            ValueError: for an unknown subspace type, or a rank outside
                1..(number of available directions).
        """
        import zlib

        W = conv_layer.weight.detach().cpu().numpy()
        C_out = W.shape[0]
        W_mat = W.reshape(C_out, -1)

        r = self.rank
        d = C_out

        if self.subspace_type == 'random':
            if not 0 < r <= d:
                raise ValueError(f"rank must be in 1..{d}, got {r}")
            # Random orthonormal basis (consistent across runs with same seed).
            # The built-in hash() of a str is salted per interpreter process,
            # so it cannot be used for a reproducible seed; CRC32 is stable.
            key_offset = zlib.crc32(layer_key.encode("utf-8")) % 10000
            # A private RandomState yields the same stream as seeding the
            # global RNG would, without clobbering the caller's global state.
            rng = np.random.RandomState(self.random_seed + key_offset)
            random_matrix = rng.randn(d, r)
            basis, _ = np.linalg.qr(random_matrix)
            # Store for reporting
            self.random_bases[layer_key] = basis.copy()
            return basis

        if self.subspace_type not in ('minor', 'major', 'middle'):
            raise ValueError(f"Unknown subspace type: {self.subspace_type}")

        # Left singular vectors, ordered from largest to smallest singular value
        U, S, Vh = np.linalg.svd(W_mat, full_matrices=False)

        # Guard the slices below: r == 0 would turn U[:, -r:] into "all
        # columns", and r larger than the available directions would silently
        # return fewer than r vectors.
        n_dirs = U.shape[1]
        if not 0 < r <= n_dirs:
            raise ValueError(f"rank must be in 1..{n_dirs}, got {r}")

        if self.subspace_type == 'minor':
            # Bottom-r singular vectors
            return U[:, -r:]
        if self.subspace_type == 'major':
            # Top-r singular vectors
            return U[:, :r]

        # Middle-r singular vectors (window centred on index d // 2)
        mid_start = d // 2 - r // 2
        return U[:, mid_start:mid_start + r]

    def _add_adapters(self):
        """Wrap every conv listed in `self.adapter_layers` with a SubspaceAdapter."""
        for layer_name, block_idx, conv_name in self.adapter_layers:
            # Underscores rather than dots: nn.ModuleDict keys may not contain "."
            layer_key = f"{layer_name}_{block_idx}_{conv_name}"

            # Get original conv layer
            conv = self._get_conv_layer(layer_name, block_idx, conv_name)

            # Compute subspace basis
            basis = self._compute_subspace_basis(conv, layer_key)

            # Create adapter
            adapter = SubspaceAdapter(conv, basis, scale=1.0)

            # Replace conv with adapter
            self._set_conv_layer(layer_name, block_idx, conv_name, adapter)

            # Store reference
            self.adapters[layer_key] = adapter

    def freeze_backbone(self):
        """
        Freeze everything in the backbone, then re-enable the trainable parts:
        adapter A matrices, the classifier head and BatchNorm affine parameters.
        BatchNorm layers are also put in eval mode.
        """
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Only unfreeze A (B is a buffer, not a parameter)
        for adapter in self.adapters.values():
            adapter.A.requires_grad = True

        # Unfreeze classifier
        for param in self.backbone.fc.parameters():
            param.requires_grad = True

        # BatchNorm affine params
        for m in self.backbone.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = True
                m.bias.requires_grad = True

    def forward(self, x):
        """Run the (adapter-augmented) backbone."""
        return self.backbone(x)

    def get_trainable_params(self):
        """Return number of trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# -----------------------------------------------------------------------------
# Training and Evaluation Functions
# -----------------------------------------------------------------------------

def train_epoch(model, loader, optimizer, criterion):
    """
    Run one optimization pass over `loader`.

    Returns:
        (mean loss per sample, fraction of correctly classified samples)
    """
    model.train()

    # model.train() also switches BatchNorm to training mode; put those
    # layers back in eval mode for the duration of the epoch.
    bn_layers = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    for bn in bn_layers:
        bn.eval()

    loss_sum, n_correct, n_seen = 0, 0, 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the final average is
        # per-sample even when the last batch is smaller.
        batch_size = inputs.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def evaluate(model, loader):
    """Return the classification accuracy of `model` over `loader` (eval mode, no gradients)."""
    model.eval()

    n_correct, n_seen = 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        predictions = model(inputs).argmax(1)
        n_correct += (predictions == targets).sum().item()
        n_seen += inputs.size(0)

    return n_correct / n_seen


# -----------------------------------------------------------------------------
# Full Finetuning Baseline
# -----------------------------------------------------------------------------

class ResNet18FullFinetune(nn.Module):
    """Plain ResNet-18 (no adapters) used as the full-finetuning baseline."""

    def __init__(self, num_classes):
        super().__init__()
        from torchvision import models

        # Pretrained backbone with a freshly initialised classifier head
        net = models.resnet18(weights="IMAGENET1K_V1")
        net.fc = nn.Linear(512, num_classes)
        self.backbone = net

    def forward(self, x):
        return self.backbone(x)

    def freeze_backbone(self):
        """
        Mirror of the adapter model's method for a uniform training API.

        No parameter is frozen here: BatchNorm layers are switched to eval
        mode and their affine parameters are explicitly marked trainable.
        """
        bn_layers = (m for m in self.backbone.modules() if isinstance(m, nn.BatchNorm2d))
        for bn in bn_layers:
            bn.eval()
            bn.weight.requires_grad = True
            bn.bias.requires_grad = True

    def get_trainable_params(self):
        """Count the scalar parameters that currently require gradients."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return sum(p.numel() for p in trainable)


# -----------------------------------------------------------------------------
# Run Single Experiment
# -----------------------------------------------------------------------------

def run_single_experiment(subspace_type, seed):
    """
    Train one model variant on the source domains and score it on the target.

    Args:
        subspace_type: 'minor', 'major', 'middle', 'random', or 'full'
        seed: Random seed (applied to torch and NumPy, and used as the
            adapter model's random-basis seed)

    Returns:
        dict with the variant name, seed, trainable-parameter count, final
        and best target accuracy, per-source-domain accuracy and the
        per-epoch history.
    """
    # Seed both RNGs before any model or loader is created
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Build the requested variant
    if subspace_type != 'full':
        model = ResNet18WithAdapters(
            num_classes=7,
            adapter_layers=ADAPTER_LAYERS,
            subspace_type=subspace_type,
            rank=RANK,
            random_seed=seed
        ).to(device)
    else:
        model = ResNet18FullFinetune(num_classes=7).to(device)

    # Select which parameters train (variant-specific)
    model.freeze_backbone()
    trainable_params = model.get_trainable_params()

    source_loader = get_source_loader()
    target_loader = get_target_loader()

    # Only parameters left trainable by freeze_backbone are optimized
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad),
        lr=LR
    )
    criterion = nn.CrossEntropyLoss()

    best_target_acc = 0
    history = {'train_loss': [], 'train_acc': [], 'target_acc': []}

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, source_loader, optimizer, criterion)
        target_acc = evaluate(model, target_loader)

        for key, value in (('train_loss', train_loss),
                           ('train_acc', train_acc),
                           ('target_acc', target_acc)):
            history[key].append(value)

        # Track the best target accuracy seen at any epoch
        best_target_acc = max(best_target_acc, target_acc)

    # Accuracy of the final weights on the target domain
    final_target_acc = evaluate(model, target_loader)

    # Accuracy on each source domain separately, using the test-time transform
    source_accs = {}
    for domain in SOURCE_DOMAINS:
        domain_loader = DataLoader(
            get_domain_dataset(domain, test_tf),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )
        source_accs[domain] = evaluate(model, domain_loader)

    return dict(
        subspace_type=subspace_type,
        seed=seed,
        trainable_params=trainable_params,
        final_target_acc=final_target_acc,
        best_target_acc=best_target_acc,
        source_accs=source_accs,
        history=history,
    )


# -----------------------------------------------------------------------------
# Main Experiment Runner
# -----------------------------------------------------------------------------

def run_experiment_3():
    """
    Run every variant of Experiment 3 once per seed in SEEDS.

    Returns:
        Flat list of result dicts from `run_single_experiment`, ordered by
        variant and then by seed.
    """
    banner = "=" * 70

    print(banner)
    print("EXPERIMENT 3: Subspace Comparison for Domain Generalization")
    print(banner)
    print(f"\nConfiguration:")
    print(f"  Rank: {RANK}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LR}")
    print(f"  Source domains: {SOURCE_DOMAINS}")
    print(f"  Target domain: {TARGET_DOMAIN}")
    print(f"  Seeds: {SEEDS}")
    print(f"  Adapter layers: {len(ADAPTER_LAYERS)} layers")

    all_results = []

    # Four constrained variants plus the unconstrained baseline
    for variant in ('minor', 'major', 'middle', 'random', 'full'):
        print(f"\n{'='*70}")
        print(f"Running variant: {variant.upper()}")
        print(f"{'='*70}")

        for seed in SEEDS:
            print(f"\n  Seed {seed}...")
            result = run_single_experiment(variant, seed)
            all_results.append(result)

            print(f"    Trainable params: {result['trainable_params']:,}")
            print(f"    Final target acc: {result['final_target_acc']:.4f}")
            print(f"    Best target acc:  {result['best_target_acc']:.4f}")
            print(f"    Source accs: {', '.join([f'{d[:4]}={v:.3f}' for d,v in result['source_accs'].items()])}")

    return all_results


# -----------------------------------------------------------------------------
# Results Analysis and Reporting
# -----------------------------------------------------------------------------

def analyze_results(results, similarity_margin=0.02):
    """Tabulate Experiment 3 runs and print a comparison report.

    Parameters
    ----------
    results : iterable of dict
        One dict per run. Each dict is read for the keys 'subspace_type',
        'seed', 'trainable_params', 'final_target_acc', 'best_target_acc' and
        'source_accs' (itself a dict looked up by 'photo', 'art_painting' and
        'cartoon').
    similarity_margin : float, optional
        If the minor-subspace mean is within this distance of the average of
        the other constrained variants, the report says minor is "NOT special".
        Defaults to 0.02.

    Returns
    -------
    pandas.DataFrame
        One row per run with the columns 'variant', 'seed',
        'trainable_params', 'target_acc', 'best_target_acc', 'source_photo',
        'source_art' and 'source_cartoon'. Empty (but with those columns) when
        ``results`` is empty.

    Notes
    -----
    * A source domain missing from 'source_accs' is stored as NaN rather than
      0, so it cannot be mistaken for a 0% accuracy or pull a mean down.
    * Comparisons involving a variant that has no runs are reported as "n/a",
      and no interpretation is printed for incomplete results. Previously the
      NaN mean fell through to the "hypothesis is contradicted" branch.
    """
    import pandas as pd

    nan = float('nan')
    columns = [
        'variant', 'seed', 'trainable_params', 'target_acc', 'best_target_acc',
        'source_photo', 'source_art', 'source_cartoon',
    ]
    rows = [
        {
            'variant': r['subspace_type'],
            'seed': r['seed'],
            'trainable_params': r['trainable_params'],
            'target_acc': r['final_target_acc'],
            'best_target_acc': r['best_target_acc'],
            'source_photo': r['source_accs'].get('photo', nan),
            'source_art': r['source_accs'].get('art_painting', nan),
            'source_cartoon': r['source_accs'].get('cartoon', nan),
        }
        for r in results
    ]
    # Passing `columns` keeps the schema stable even when there are no rows.
    df = pd.DataFrame(rows, columns=columns)

    # Summary statistics
    print("\n" + "=" * 70)
    print("EXPERIMENT 3 RESULTS SUMMARY")
    print("=" * 70)

    if df.empty:
        print("\nNo results to analyze.")
        return df

    # NOTE: the domain name in this heading is hard-coded.
    print("\n--- Target Domain Accuracy (Sketch) ---")
    summary = df.groupby('variant')['target_acc'].agg(['mean', 'std']).round(4)
    summary = summary.sort_values('mean', ascending=False)
    print(summary)

    print("\n--- Best Target Accuracy (over training) ---")
    summary_best = df.groupby('variant')['best_target_acc'].agg(['mean', 'std']).round(4)
    summary_best = summary_best.sort_values('mean', ascending=False)
    print(summary_best)

    print("\n--- Source Domain Accuracies ---")
    for variant in df['variant'].unique():
        v_df = df[df['variant'] == variant]
        print(f"\n{variant.upper()}:")
        print(f"  Photo:   {v_df['source_photo'].mean():.4f} ± {v_df['source_photo'].std():.4f}")
        print(f"  Art:     {v_df['source_art'].mean():.4f} ± {v_df['source_art'].std():.4f}")
        print(f"  Cartoon: {v_df['source_cartoon'].mean():.4f} ± {v_df['source_cartoon'].std():.4f}")

    print("\n--- Trainable Parameters ---")
    params = df.groupby('variant')['trainable_params'].first()
    print(params)

    # Statistical comparison
    print("\n" + "=" * 70)
    print("KEY COMPARISONS")
    print("=" * 70)

    constrained = ('minor', 'major', 'middle', 'random')
    accs = {
        name: df[df['variant'] == name]['target_acc'].values
        for name in constrained + ('full',)
    }
    # Guarding on size avoids numpy's empty-mean RuntimeWarning and makes the
    # "no runs" case explicit.
    means = {name: (vals.mean() if vals.size else nan) for name, vals in accs.items()}

    for position, other in enumerate(('major', 'middle', 'random', 'full')):
        lead = "\n" if position == 0 else ""
        label = f"{lead}Minor vs {other.capitalize()}"
        if accs['minor'].size == 0 or accs[other].size == 0:
            print(f"{label}: n/a (missing runs)")
        else:
            diff = means['minor'] - means[other]
            print(f"{label}: {means['minor']:.4f} vs {means[other]:.4f} (diff: {diff:+.4f})")

    # Interpretation
    print("\n" + "=" * 70)
    print("INTERPRETATION")
    print("=" * 70)

    def _mean_pm_std(values):
        # ndarray.std() is the population std (ddof=0), unlike the pandas
        # .std() (ddof=1) used in the tables above.
        if values.size == 0:
            return "n/a (no runs)"
        return f"{values.mean():.4f} ± {values.std():.4f}"

    constrained_accs = np.concatenate([accs[name] for name in constrained])
    print(f"\nAll constrained methods: {_mean_pm_std(constrained_accs)}")
    print(f"Full finetuning:         {_mean_pm_std(accs['full'])}")

    # Check if minor is special
    missing = [name for name in constrained if accs[name].size == 0]
    if missing:
        print(f"\nMinor vs average of other constrained: n/a (no runs for: {', '.join(missing)})")
        print("\n→ Cannot interpret: results are incomplete.")
        return df

    minor_vs_others = means['minor'] - np.mean([means['major'], means['middle'], means['random']])
    print(f"\nMinor vs average of other constrained: {minor_vs_others:+.4f}")

    if abs(minor_vs_others) < similarity_margin:
        print("\n→ Minor is NOT special. All subspace constraints perform similarly.")
        print("→ SoMA's benefit likely comes from regularization, not subspace selection.")
    elif minor_vs_others > 0:
        # Covers a gap exactly equal to the margin, which used to fall into
        # the "underperforms" branch.
        print("\n→ Minor outperforms other subspaces. SoMA's hypothesis may have merit.")
    else:
        print("\n→ Minor underperforms other subspaces. SoMA's hypothesis is contradicted.")

    return df


# -----------------------------------------------------------------------------
# Run Everything
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Step 1: launch the sweep and keep the list it returns.
    results = run_experiment_3()

    # Step 2: print the report and get the per-run table back.
    df = analyze_results(results)

    # Step 3: write the table to a CSV under PROJECT_ROOT and say where it went.
    results_csv = f"{PROJECT_ROOT}/experiment3_results.csv"
    df.to_csv(results_csv, index=False)
    print(f"\nResults saved to {results_csv}")

Using device: cuda
EXPERIMENT 3: Subspace Comparison for Domain Generalization

Configuration:
  Rank: 8
  Epochs: 10
  Learning rate: 0.0001
  Source domains: ['photo', 'art_painting', 'cartoon']
  Target domain: sketch
  Seeds: [42, 123, 456]
  Adapter layers: 6 layers

Running variant: MINOR

  Seed 42...
    Trainable params: 142,215
    Final target acc: 0.5500
    Best target acc:  0.5862
    Source accs: phot=0.990, art_=0.965, cart=0.956

  Seed 123...
    Trainable params: 142,215
    Final target acc: 0.5330
    Best target acc:  0.5948
    Source accs: phot=0.995, art_=0.965, cart=0.958

  Seed 456...
    Trainable params: 142,215
    Final target acc: 0.3691
    Best target acc:  0.4910
    Source accs: phot=0.993, art_=0.962, cart=0.954

Running variant: MAJOR

  Seed 42...


KeyboardInterrupt: 

In [None]:
# -----------------------------------------------------------------------------
# Main Experiment Runner
# -----------------------------------------------------------------------------

def run_experiment_3(variants=None, seeds=None):
    """Run Experiment 3: one training run per (variant, seed) pair.

    Parameters
    ----------
    variants : iterable of str, optional
        Values passed as the first argument of ``run_single_experiment``.
        Defaults to ``['major', 'middle', 'random', 'full']``.
    seeds : iterable of int, optional
        Values passed as the second argument of ``run_single_experiment``.
        Defaults to ``[42]``.

    Returns
    -------
    list
        Whatever ``run_single_experiment`` returned for each pair, ordered by
        variant first and by seed second.

    Notes
    -----
    ``run_experiment_3()`` with no arguments behaves exactly as it did when
    the seed and variant lists were fixed inside the body. Pass
    ``seeds=[...]`` to run other seeds instead of redefining the function.
    """
    if variants is None:
        variants = ['major', 'middle', 'random', 'full']
    else:
        variants = list(variants)
    # list() lets a one-shot iterator of seeds be reused for every variant.
    seeds = [42] if seeds is None else list(seeds)

    banner = "=" * 70
    print(banner)
    print("EXPERIMENT 3: Subspace Comparison for Domain Generalization")
    print(banner)
    print("\nConfiguration:")
    print(f"  Rank: {RANK}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LR}")
    print(f"  Source domains: {SOURCE_DOMAINS}")
    print(f"  Target domain: {TARGET_DOMAIN}")
    print(f"  Seeds: {seeds}")
    print(f"  Adapter layers: {len(ADAPTER_LAYERS)} layers")

    all_results = []

    for variant in variants:
        print(f"\n{banner}")
        print(f"Running variant: {variant.upper()}")
        print(f"{banner}")

        for seed in seeds:
            print(f"\n  Seed {seed}...")
            result = run_single_experiment(variant, seed)
            all_results.append(result)

            # Domain names are shortened to their first four characters.
            source_summary = ', '.join(
                f'{d[:4]}={v:.3f}' for d, v in result['source_accs'].items()
            )
            print(f"    Trainable params: {result['trainable_params']:,}")
            print(f"    Final target acc: {result['final_target_acc']:.4f}")
            print(f"    Best target acc:  {result['best_target_acc']:.4f}")
            print(f"    Source accs: {source_summary}")

    return all_results


# -----------------------------------------------------------------------------
# Run Everything
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Run the sweep and bind the returned list to `results`.
    # NOTE: this rebinds `results`, so anything an earlier cell stored under
    # that name is replaced; no analysis or saving happens in this cell.
    results = run_experiment_3()


EXPERIMENT 3: Subspace Comparison for Domain Generalization

Configuration:
  Rank: 8
  Epochs: 10
  Learning rate: 0.0001
  Source domains: ['photo', 'art_painting', 'cartoon']
  Target domain: sketch
  Seeds: [42]
  Adapter layers: 6 layers

Running variant: MAJOR

  Seed 42...
    Trainable params: 142,215
    Final target acc: 0.5668
    Best target acc:  0.5750
    Source accs: phot=0.991, art_=0.966, cart=0.971

Running variant: MIDDLE

  Seed 42...
    Trainable params: 142,215
    Final target acc: 0.6139
    Best target acc:  0.6139
    Source accs: phot=0.992, art_=0.970, cart=0.962

Running variant: RANDOM

  Seed 42...
    Trainable params: 142,215
    Final target acc: 0.6124
    Best target acc:  0.6124
    Source accs: phot=0.990, art_=0.967, cart=0.965

Running variant: FULL

  Seed 42...
    Trainable params: 11,180,103
    Final target acc: 0.6760
    Best target acc:  0.7597
    Source accs: phot=0.984, art_=0.968, cart=0.997


In [None]:
# -----------------------------------------------------------------------------
# Main Experiment Runner
# -----------------------------------------------------------------------------

def run_experiment_3(variants=None, seeds=None):
    """Run Experiment 3: one training run per (variant, seed) pair.

    Parameters
    ----------
    variants : iterable of str, optional
        Values passed as the first argument of ``run_single_experiment``.
        Defaults to ``['major', 'middle', 'random', 'full']``.
    seeds : iterable of int, optional
        Values passed as the second argument of ``run_single_experiment``.
        Defaults to ``[456]``.

    Returns
    -------
    list
        Whatever ``run_single_experiment`` returned for each pair, ordered by
        variant first and by seed second.

    Notes
    -----
    ``run_experiment_3()`` with no arguments behaves exactly as it did when
    the seed and variant lists were fixed inside the body. Pass
    ``seeds=[...]`` to run other seeds instead of redefining the function.
    """
    if variants is None:
        variants = ['major', 'middle', 'random', 'full']
    else:
        variants = list(variants)
    # list() lets a one-shot iterator of seeds be reused for every variant.
    seeds = [456] if seeds is None else list(seeds)

    banner = "=" * 70
    print(banner)
    print("EXPERIMENT 3: Subspace Comparison for Domain Generalization")
    print(banner)
    print("\nConfiguration:")
    print(f"  Rank: {RANK}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LR}")
    print(f"  Source domains: {SOURCE_DOMAINS}")
    print(f"  Target domain: {TARGET_DOMAIN}")
    print(f"  Seeds: {seeds}")
    print(f"  Adapter layers: {len(ADAPTER_LAYERS)} layers")

    all_results = []

    for variant in variants:
        print(f"\n{banner}")
        print(f"Running variant: {variant.upper()}")
        print(f"{banner}")

        for seed in seeds:
            print(f"\n  Seed {seed}...")
            result = run_single_experiment(variant, seed)
            all_results.append(result)

            # Domain names are shortened to their first four characters.
            source_summary = ', '.join(
                f'{d[:4]}={v:.3f}' for d, v in result['source_accs'].items()
            )
            print(f"    Trainable params: {result['trainable_params']:,}")
            print(f"    Final target acc: {result['final_target_acc']:.4f}")
            print(f"    Best target acc:  {result['best_target_acc']:.4f}")
            print(f"    Source accs: {source_summary}")

    return all_results


# -----------------------------------------------------------------------------
# Run Everything
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Run the sweep and bind the returned list to `results`.
    # NOTE: this rebinds `results`, so anything an earlier cell stored under
    # that name is replaced; no analysis or saving happens in this cell.
    results = run_experiment_3()


EXPERIMENT 3: Subspace Comparison for Domain Generalization

Configuration:
  Rank: 8
  Epochs: 10
  Learning rate: 0.0001
  Source domains: ['photo', 'art_painting', 'cartoon']
  Target domain: sketch
  Seeds: [456]
  Adapter layers: 6 layers

Running variant: MAJOR

  Seed 456...
    Trainable params: 142,215
    Final target acc: 0.5574
    Best target acc:  0.6363
    Source accs: phot=0.994, art_=0.974, cart=0.971

Running variant: MIDDLE

  Seed 456...
    Trainable params: 142,215
    Final target acc: 0.4248
    Best target acc:  0.5383
    Source accs: phot=0.994, art_=0.966, cart=0.960

Running variant: RANDOM

  Seed 456...
    Trainable params: 142,215
    Final target acc: 0.4782
    Best target acc:  0.5566
    Source accs: phot=0.989, art_=0.964, cart=0.962

Running variant: FULL

  Seed 456...
    Trainable params: 11,180,103
    Final target acc: 0.7333
    Best target acc:  0.7585
    Source accs: phot=0.982, art_=0.970, cart=0.982


In [None]:
# -----------------------------------------------------------------------------
# Main Experiment Runner
# -----------------------------------------------------------------------------

def run_experiment_3(variants=None, seeds=None):
    """Run Experiment 3: one training run per (variant, seed) pair.

    Parameters
    ----------
    variants : iterable of str, optional
        Values passed as the first argument of ``run_single_experiment``.
        Defaults to ``['major', 'middle', 'random', 'full']``.
    seeds : iterable of int, optional
        Values passed as the second argument of ``run_single_experiment``.
        Defaults to ``[123]``.

    Returns
    -------
    list
        Whatever ``run_single_experiment`` returned for each pair, ordered by
        variant first and by seed second.

    Notes
    -----
    ``run_experiment_3()`` with no arguments behaves exactly as it did when
    the seed and variant lists were fixed inside the body. Pass
    ``seeds=[...]`` to run other seeds instead of redefining the function.
    """
    if variants is None:
        variants = ['major', 'middle', 'random', 'full']
    else:
        variants = list(variants)
    # list() lets a one-shot iterator of seeds be reused for every variant.
    seeds = [123] if seeds is None else list(seeds)

    banner = "=" * 70
    print(banner)
    print("EXPERIMENT 3: Subspace Comparison for Domain Generalization")
    print(banner)
    print("\nConfiguration:")
    print(f"  Rank: {RANK}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LR}")
    print(f"  Source domains: {SOURCE_DOMAINS}")
    print(f"  Target domain: {TARGET_DOMAIN}")
    print(f"  Seeds: {seeds}")
    print(f"  Adapter layers: {len(ADAPTER_LAYERS)} layers")

    all_results = []

    for variant in variants:
        print(f"\n{banner}")
        print(f"Running variant: {variant.upper()}")
        print(f"{banner}")

        for seed in seeds:
            print(f"\n  Seed {seed}...")
            result = run_single_experiment(variant, seed)
            all_results.append(result)

            # Domain names are shortened to their first four characters.
            source_summary = ', '.join(
                f'{d[:4]}={v:.3f}' for d, v in result['source_accs'].items()
            )
            print(f"    Trainable params: {result['trainable_params']:,}")
            print(f"    Final target acc: {result['final_target_acc']:.4f}")
            print(f"    Best target acc:  {result['best_target_acc']:.4f}")
            print(f"    Source accs: {source_summary}")

    return all_results


# -----------------------------------------------------------------------------
# Run Everything
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Run the sweep and bind the returned list to `results`.
    # NOTE: this rebinds `results`, so anything an earlier cell stored under
    # that name is replaced; no analysis or saving happens in this cell.
    results = run_experiment_3()


EXPERIMENT 3: Subspace Comparison for Domain Generalization

Configuration:
  Rank: 8
  Epochs: 10
  Learning rate: 0.0001
  Source domains: ['photo', 'art_painting', 'cartoon']
  Target domain: sketch
  Seeds: [123]
  Adapter layers: 6 layers

Running variant: MAJOR

  Seed 123...
    Trainable params: 142,215
    Final target acc: 0.6266
    Best target acc:  0.6266
    Source accs: phot=0.993, art_=0.969, cart=0.968

Running variant: MIDDLE

  Seed 123...
    Trainable params: 142,215
    Final target acc: 0.5052
    Best target acc:  0.5169
    Source accs: phot=0.992, art_=0.969, cart=0.958

Running variant: RANDOM

  Seed 123...
