# Tier 0 — Baseline Training + Spectral Introspection (PACS, ResNet-50)

In [None]:
!fusermount -u /content/drive 2>/dev/null || true
!umount /content/drive 2>/dev/null || true


In [None]:
!rm -rf /content/drive
!mkdir -p /content/drive


In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [2]:
PROJECT_ROOT = "/content/drive/MyDrive/SoMA_PACS_R50"
PACS_ROOT = "/content/drive/MyDrive/DG_PACS/PACS"

In [None]:
!pip -q install datasets huggingface_hub pillow


In [None]:
from datasets import load_dataset
import os
from PIL import Image
from tqdm import tqdm


In [6]:
import shutil, os

PACS_ROOT = "/content/drive/MyDrive/DG_PACS/PACS"
assert "PACS" in PACS_ROOT
if not os.path.exists(PACS_ROOT):
    os.makedirs(PACS_ROOT, exist_ok=True)
    print("Created PACS folder.")
else:
    print("PACS folder exists. Skipping deletion.")



Created PACS folder.


In [7]:
from datasets import load_dataset
from tqdm.auto import tqdm
import os

ds = load_dataset("flwrlabs/pacs")

DOMAIN_MAP = {
    "photo": "photo",
    "art_painting": "art_painting",
    "cartoon": "cartoon",
    "sketch": "sketch",
}

def safe(s):
    return str(s).strip().lower().replace(" ", "_")

# Per-domain counters to ensure unique filenames
counters = {d: { } for d in DOMAIN_MAP.values()}

for ex in tqdm(ds["train"], desc="Exporting PACS"):
    img = ex["image"]

    domain = ex["domain"]
    label  = ex["label"]

    if not isinstance(domain, str):
        domain = ds["train"].features["domain"].int2str(domain)
    if not isinstance(label, str):
        label = ds["train"].features["label"].int2str(label)

    domain = DOMAIN_MAP[safe(domain)]
    label  = safe(label)

    # initialize counter
    counters[domain].setdefault(label, 0)
    idx = counters[domain][label]
    counters[domain][label] += 1

    out_dir = os.path.join(PACS_ROOT, domain, label)
    os.makedirs(out_dir, exist_ok=True)

    out_path = os.path.join(out_dir, f"{label}_{idx:05d}.jpg")
    img.save(out_path, quality=95)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

Exporting PACS:   0%|          | 0/9991 [00:00<?, ?it/s]

In [8]:
from torchvision import datasets
import os

for d in ["photo","art_painting","cartoon","sketch"]:
    ds = datasets.ImageFolder(os.path.join(PACS_ROOT, d))
    print(d, ds.class_to_idx)
    assert len(ds.class_to_idx) == 7, f"{d} does not have 7 classes!"



photo {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
art_painting {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
cartoon {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
sketch {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}


In [9]:
from torchvision import datasets
import os

def get_class_to_idx(domain):
    ds = datasets.ImageFolder(root=os.path.join(PACS_ROOT, domain))
    return ds.class_to_idx

mappings = {d: get_class_to_idx(d) for d in ["photo","art_painting","cartoon","sketch"]}
for d, m in mappings.items():
    print(d, m)

# Compare
base = mappings["photo"]
for d in mappings:
    print(d, "matches photo:", mappings[d] == base)


photo {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
art_painting {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
cartoon {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
sketch {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
photo matches photo: True
art_painting matches photo: True
cartoon matches photo: True
sketch matches photo: True


In [None]:
# ===== Reproducibility =====
import random, numpy as np, torch

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# ===== Imports =====
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm


In [None]:
# ===== PACS paths (edit if needed) =====
# PACS_ROOT = "/content/drive/MyDrive/datasets/PACS"

SOURCE_DOMAINS = ["photo", "art_painting", "cartoon"]
TARGET_DOMAIN = "sketch"

IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 2


In [None]:
train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])

test_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225]),
])


In [None]:
def make_loader(domain, tf, shuffle):
    ds = datasets.ImageFolder(
        root=os.path.join(PACS_ROOT, domain),
        transform=tf
    )
    return DataLoader(ds, batch_size=BATCH_SIZE,
                      shuffle=shuffle, num_workers=NUM_WORKERS)

train_loaders = [make_loader(d, train_tf, True) for d in SOURCE_DOMAINS]
test_loader = make_loader(TARGET_DOMAIN, test_tf, False)


In [None]:
class ResNet50_FeatureHook(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.backbone.fc = nn.Linear(2048, num_classes)

        self._features = {}

        def make_hook(name):
            def hook(module, inp, out):
                self._features[name] = out
            return hook

        for lname in ["layer1", "layer2", "layer3", "layer4"]:
            layer = getattr(self.backbone, lname)
            for i, block in enumerate(layer):
                block.bn2.register_forward_hook(
                    make_hook(f"{lname}.{i}.bn2")
                )

    def forward(self, x):
        self._features = {}
        return self.backbone(x)


# Training loop (Tier-0 backbone)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet50_FeatureHook(num_classes=7).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 234MB/s]


In [None]:
def train_epoch(loaders):
    model.train()
    total_loss, total_correct, total = 0, 0, 0

    for loader in loaders:
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * x.size(0)
            total_correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)

    return total_loss / total, total_correct / total


In [None]:
@torch.no_grad()
def eval_epoch(loader):
    model.eval()
    total_correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        total_correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)
    return total_correct / total


In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [None]:
EPOCHS = 15
for ep in range(EPOCHS):
    tr_loss, tr_acc = train_epoch(train_loaders)
    te_acc = eval_epoch(test_loader)

    print(f"[{ep:02d}] loss={tr_loss:.3f} | src_acc={tr_acc:.3f} | tgt_acc={te_acc:.3f}")


[00] loss=0.512 | src_acc=0.840 | tgt_acc=0.559
[01] loss=0.346 | src_acc=0.890 | tgt_acc=0.593
[02] loss=0.224 | src_acc=0.926 | tgt_acc=0.628
[03] loss=0.232 | src_acc=0.921 | tgt_acc=0.680
[04] loss=0.181 | src_acc=0.941 | tgt_acc=0.667
[05] loss=0.107 | src_acc=0.965 | tgt_acc=0.670
[06] loss=0.116 | src_acc=0.963 | tgt_acc=0.739
[07] loss=0.110 | src_acc=0.964 | tgt_acc=0.693
[08] loss=0.159 | src_acc=0.950 | tgt_acc=0.694
[09] loss=0.071 | src_acc=0.977 | tgt_acc=0.658
[10] loss=0.079 | src_acc=0.976 | tgt_acc=0.594
[11] loss=0.059 | src_acc=0.983 | tgt_acc=0.599
[12] loss=0.066 | src_acc=0.980 | tgt_acc=0.723
[13] loss=0.058 | src_acc=0.980 | tgt_acc=0.695
[14] loss=0.081 | src_acc=0.972 | tgt_acc=0.685


Save trained backbone

In [None]:
BACKBONE_PATH = f"{PROJECT_ROOT}/resnet50_pacs_base.pt"

torch.save(model.state_dict(), BACKBONE_PATH)
print("Saved PACS-trained backbone.")


Saved PACS-trained backbone.


In [None]:
BACKBONE_PATH = f"{PROJECT_ROOT}/resnet50_pacs_base.pt"

# TIER 0 — SPECTRAL INTROSPECTION

In [None]:
import numpy.linalg as LA

def conv_svd(conv):
    # conv.weight: [C_out, C_in, k, k]
    W = conv.weight.detach().cpu().numpy()
    Cout = W.shape[0]
    Wmat = W.reshape(Cout, -1)
    U, S, Vt = LA.svd(Wmat, full_matrices=False)
    return U, S


In [None]:
spectra = {}

for lname in ["layer2", "layer3", "layer4"]:
    for i, block in enumerate(getattr(model.backbone, lname)):
        conv = block.conv2
        U, S = conv_svd(conv)
        spectra[f"{lname}.{i}"] = S


In [None]:
import pickle

with open(f"{PROJECT_ROOT}/tier0_spectra.pkl", "wb") as f:
    pickle.dump(spectra, f)


# Tier-0 Feature Statistics (Energy by Spectral Rank)

In [None]:
@torch.no_grad()
def collect_feature_energy(loader):
    model.eval()
    energy = {}

    for x, _ in loader:
        x = x.to(device)
        _ = model(x)

        for k, feat in model._features.items():
            # GAP → channel vector
            h = feat.mean(dim=[2,3])  # [B, C]
            e = (h**2).mean(0).cpu().numpy()
            energy.setdefault(k, []).append(e)

    for k in energy:
        energy[k] = np.mean(energy[k], axis=0)

    return energy


In [None]:
src_energy = collect_feature_energy(train_loaders[0])
tgt_energy = collect_feature_energy(test_loader)

with open(f"{PROJECT_ROOT}/tier0_feature_energy.pkl", "wb") as f:
    pickle.dump({"src": src_energy, "tgt": tgt_energy}, f)


Saved to Drive

- tier0_spectra.pkl → singular values per conv layer

- tier0_feature_energy.pkl → channel-wise activation energy (src vs tgt)

- trained baseline checkpoint (implicitly in model state)



Scientific baseline

- no spectral intervention

- no leakage

- fixed feature definition

- BN explicitly visible

# =============================================================================
# TIER 1 - SCENARIO 1: Probing the PRETRAINED ImageNet Model

This tests whether the pretrained model (before any PACS finetuning) already has spectral structure that separates domain from class information.


In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset

class ProjectedDataset(Dataset):
    """
    Stores (X_proj, y) where X_proj = X @ Usub.

    X    : (N, d) numpy array or torch tensor
    y    : (N,) labels (list/np/torch)
    Usub : (d, r) numpy array or torch tensor (orthonormal columns recommended)
    """
    def __init__(self, X, y, Usub):
        # X
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X)
        self.X = X.float().cpu()

        # y
        if isinstance(y, np.ndarray):
            y = torch.from_numpy(y)
        elif isinstance(y, (list, tuple)):
            y = torch.tensor(y)
        self.y = y.long().cpu()

        # Usub
        if isinstance(Usub, np.ndarray):
            Usub = torch.from_numpy(Usub)
        Usub = Usub.float().cpu()

        # shape checks
        if self.X.ndim != 2:
            raise ValueError(f"X must be 2D (N,d). Got {self.X.shape}")
        if Usub.ndim != 2:
            raise ValueError(f"Usub must be 2D (d,r). Got {Usub.shape}")
        if self.X.shape[1] != Usub.shape[0]:
            raise ValueError(f"Dim mismatch: X has d={self.X.shape[1]} but Usub has d={Usub.shape[0]}")

        # Project: (N,d) @ (d,r) -> (N,r)
        self.X = self.X @ Usub

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


In [None]:
import torch
import torch.nn as nn

class LinearProbe(nn.Module):
    """
    Minimal linear classifier for probing.
    Input: feature_dim
    Output: num_classes logits
    """
    def __init__(self, feature_dim: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


In [None]:
import numpy as np
import numpy.linalg as LA

def get_conv_svd(conv):
    """
    Returns (U, S) for conv.weight reshaped to [C_out, C_in*k*k].
    Matches the exact SVD convention used elsewhere in your notebook.
    """
    W = conv.weight.detach().cpu().numpy()     # [C_out, C_in, k, k]
    C_out = W.shape[0]
    Wmat = W.reshape(C_out, -1)                # [C_out, C_in*k*k]
    U, S, _ = LA.svd(Wmat, full_matrices=False)
    return U, S

# (optional) keep backward compatibility with earlier cells
conv_svd = get_conv_svd


In [None]:
# -----------------------------------------------------------------------------
# Cell 0: Re-define the Model Class with ALL Layer Hooks (Run this first!)
# -----------------------------------------------------------------------------
# This ensures hooks are registered for ALL layers including layer1

class ResNet50_FeatureHook(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.backbone.fc = nn.Linear(2048, num_classes)  # 512 → 2048

        self._features = {}

        def make_hook(name):
            def hook(module, inp, out):
                self._features[name] = out
            return hook

        # Register hooks for ALL layers (layer1, layer2, layer3, layer4)
        for lname in ["layer1", "layer2", "layer3", "layer4"]:
            layer = getattr(self.backbone, lname)
            for i, block in enumerate(layer):
                block.bn2.register_forward_hook(
                    make_hook(f"{lname}.{i}.bn2")
                )

    def forward(self, x):
        self._features = {}  # Clear before each forward pass
        return self.backbone(x)

print("ResNet50_FeatureHook class defined with hooks for layer1, layer2, layer3, layer4")

ResNet50_FeatureHook class defined with hooks for layer1, layer2, layer3, layer4


In [None]:
# -----------------------------------------------------------------------------
# Cell 1: Load Fresh Pretrained Model and Verify Hooks
# -----------------------------------------------------------------------------

print("Loading fresh pretrained ResNet-50 (ImageNet weights only)...")

pretrained_model = ResNet50_FeatureHook(num_classes=7).to(device)

# Reset fc layer (doesn't affect feature extraction)
nn.init.xavier_uniform_(pretrained_model.backbone.fc.weight)
nn.init.zeros_(pretrained_model.backbone.fc.bias)

pretrained_model.eval()

# VERIFY: Run a test forward pass and check that hooks are working
print("\nVerifying hooks are registered correctly...")
test_input = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    _ = pretrained_model(test_input)

print(f"Keys in _features after forward pass: {sorted(pretrained_model._features.keys())}")

# Check we have all expected keys
# ResNet-50 block counts: [3, 4, 6, 3]
expected_keys = []
for l, count in [("layer1", 3), ("layer2", 4), ("layer3", 6), ("layer4", 3)]:
    for i in range(count):
        expected_keys.append(f"{l}.{i}.bn2")

missing_keys = [k for k in expected_keys if k not in pretrained_model._features]

if missing_keys:
    print(f"\n⚠️  WARNING: Missing hooks for: {missing_keys}")
    print("Please re-run the cell that defines ResNet50_FeatureHook class!")
else:
    print("\n✓ All hooks registered correctly!")
    print("Pretrained model ready. This model has NEVER seen PACS data.")

Loading fresh pretrained ResNet-50 (ImageNet weights only)...

Verifying hooks are registered correctly...
Keys in _features after forward pass: ['layer1.0.bn2', 'layer1.1.bn2', 'layer1.2.bn2', 'layer2.0.bn2', 'layer2.1.bn2', 'layer2.2.bn2', 'layer2.3.bn2', 'layer3.0.bn2', 'layer3.1.bn2', 'layer3.2.bn2', 'layer3.3.bn2', 'layer3.4.bn2', 'layer3.5.bn2', 'layer4.0.bn2', 'layer4.1.bn2', 'layer4.2.bn2']

✓ All hooks registered correctly!
Pretrained model ready. This model has NEVER seen PACS data.


In [None]:
# -----------------------------------------------------------------------------
# Cell 2: Feature Extraction with Per-Domain Tracking (with error handling)
# -----------------------------------------------------------------------------

@torch.no_grad()
def extract_gap_features_v2(model, domains, layer_name):
    """
    Extract globally-averaged-pooled features with domain tracking.

    Args:
        model: The model to extract features from
        domains: List of domain names to process
        layer_name: Which layer to extract from (e.g., "layer4.1")

    Returns:
        X: (N, C) feature matrix
        y_class: (N,) class labels
        y_domain: (N,) domain labels (0-indexed)
        domain_names: List mapping domain index to name
    """
    DOMAIN_TO_ID = {
        "photo": 0,
        "art_painting": 1,
        "cartoon": 2,
        "sketch": 3,
    }

    # The key format in _features
    feature_key = f"{layer_name}.bn2"

    X, y_class, y_domain = [], [], []
    model.eval()

    for domain in domains:
        loader = make_loader(domain, test_tf, shuffle=False)
        d_id = DOMAIN_TO_ID[domain]

        for x, y in loader:
            x = x.to(device)
            _ = model(x)

            # Check if the key exists
            if feature_key not in model._features:
                available_keys = list(model._features.keys())
                raise KeyError(
                    f"Key '{feature_key}' not found in model._features.\n"
                    f"Available keys: {sorted(available_keys)}\n"
                    f"Please re-run the cell that defines ResNet50_FeatureHook class."
                )

            feat = model._features[feature_key]
            h = feat.mean(dim=[2, 3]).cpu().numpy()  # Global Average Pooling

            X.append(h)
            y_class.append(y.numpy())
            y_domain.append(np.full(h.shape[0], d_id))

    return (
        np.concatenate(X),
        np.concatenate(y_class),
        np.concatenate(y_domain),
        domains,
    )

In [None]:
# -----------------------------------------------------------------------------
# Cell 3: Probe Training with Per-Domain Accuracy Breakdown
# -----------------------------------------------------------------------------
# This function trains a linear probe and computes accuracy both overall
# and broken down by domain.

def train_probe_with_breakdown(dataset, num_classes, y_domain_full, domain_names,
                                seed=0, epochs=50, train_frac=0.7):
    """
    Train a linear probe and return both overall and per-domain accuracy.

    The per-domain breakdown shows how well the probe classifies samples
    FROM each domain. For a domain probe, this tells us which domains are
    easiest/hardest to identify.

    Args:
        dataset: ProjectedDataset with X and y
        num_classes: Number of output classes
        y_domain_full: Full array of domain labels (before train/val split)
        domain_names: List of domain name strings
        seed: Random seed
        epochs: Training epochs
        train_frac: Fraction for training (rest is validation)

    Returns:
        overall_acc: Float, overall validation accuracy
        per_domain_acc: Dict mapping domain name -> accuracy for that domain's samples
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    n = len(dataset)
    n_train = int(train_frac * n)

    # Create shuffled indices
    indices = np.random.permutation(n)
    train_indices = indices[:n_train]
    val_indices = indices[n_train:]

    train_ds = torch.utils.data.Subset(dataset, train_indices.tolist())
    val_ds = torch.utils.data.Subset(dataset, val_indices.tolist())

    train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=256, shuffle=False)

    # Initialize and train probe
    probe = LinearProbe(dataset.X.shape[1], num_classes).to(device)
    opt = torch.optim.Adam(probe.parameters(), lr=1e-2)
    loss_fn = nn.CrossEntropyLoss()

    probe.train()
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss_fn(probe(x), y).backward()
            opt.step()

    # Evaluate on validation set
    probe.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            pred = probe(x).argmax(1).cpu().numpy()
            all_preds.append(pred)
            all_labels.append(y.numpy())

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Get domain labels for validation samples
    val_domain_labels = y_domain_full[val_indices]

    # Overall accuracy
    overall_acc = (all_preds == all_labels).mean()

    # Per-domain accuracy
    # For each domain, how accurate is the probe on samples FROM that domain?
    per_domain_acc = {}
    for d_idx, d_name in enumerate(domain_names):
        mask = val_domain_labels == d_idx
        if mask.sum() > 0:
            per_domain_acc[d_name] = (all_preds[mask] == all_labels[mask]).mean()
        else:
            per_domain_acc[d_name] = np.nan

    return overall_acc, per_domain_acc

In [None]:
import numpy as np

def get_subspace(U, kind: str, r: int, seed: int = 0):
    """
    Build an orthonormal basis Usub (d x r) from an SVD basis U (d x k).

    kind:
      - "major"  : top singular directions (first r columns)
      - "minor"  : bottom singular directions (last r columns)
      - "middle" : centered chunk of r columns
      - "random" : random r-dim subspace (orthonormal), *within span(U)*

    Returns:
      Usub: np.ndarray of shape (d, r) with orthonormal columns
    """
    if U is None:
        raise ValueError("U is None")

    U = np.asarray(U)
    if U.ndim != 2:
        raise ValueError(f"U must be 2D, got shape {U.shape}")

    d, k = U.shape
    if r <= 0:
        raise ValueError(f"r must be > 0, got {r}")
    if r > k:
        raise ValueError(f"r={r} exceeds available basis columns k={k} (U shape {U.shape})")

    kind = kind.lower().strip()

    if kind in ["major", "top"]:
        return U[:, :r]

    if kind in ["minor", "bottom"]:
        return U[:, k - r : k]

    if kind in ["middle", "mid"]:
        start = (k - r) // 2
        return U[:, start : start + r]

    if kind in ["random", "rand"]:
        # Random r-dim subspace inside span(U): U @ Q where Q is k x r with orthonormal cols
        rng = np.random.RandomState(seed)
        A = rng.randn(k, r)
        Q, _ = np.linalg.qr(A)          # k x r, orthonormal columns
        return U @ Q                    # d x r, still orthonormal (since U has orthonormal cols)

    raise ValueError(f"Unknown kind='{kind}'. Expected one of: major, minor, middle, random.")


In [None]:
# -----------------------------------------------------------------------------
# Cell 4: Main Probing Function with Full Reporting
# -----------------------------------------------------------------------------
# This runs the complete probe analysis for a given model, reporting both
# combined accuracy and per-domain breakdown.

def run_full_probe_analysis(model, layer_name, conv, domains, ranks, seeds,
                            probe_type="domain", exclude_photo=False):
    """
    Run complete probe analysis with per-domain breakdown.

    Args:
        model: Model to analyze
        layer_name: Layer name (e.g., "layer4.1")
        conv: The conv layer to get SVD from
        domains: List of domains to include
        ranks: List of ranks to test
        seeds: List of random seeds for averaging
        probe_type: "domain" or "class"
        exclude_photo: If True, exclude photo from analysis

    Returns:
        results: Nested dict with all results
    """
    # Optionally exclude photo
    if exclude_photo:
        domains = [d for d in domains if d != "photo"]

    num_classes = len(domains) if probe_type == "domain" else 7

    # Get SVD of conv layer
    U, S = get_conv_svd(conv)

    # Extract features
    X, y_class, y_domain, domain_names = extract_gap_features_v2(model, domains, layer_name)

    # Choose labels based on probe type
    y_labels = y_domain if probe_type == "domain" else y_class

    results = {}

    for r in ranks:
        results[r] = {}

        for kind in ["major", "minor", "random"]:
            overall_accs = []
            per_domain_accs = {d: [] for d in domains}

            for seed in seeds:
                # Get subspace basis
                Usub = get_subspace(U, kind, r, seed)

                # Create projected dataset
                ds = ProjectedDataset(X, y_labels, Usub)

                # Train probe with breakdown
                overall_acc, per_domain = train_probe_with_breakdown(
                    ds, num_classes=num_classes,
                    y_domain_full=y_domain,
                    domain_names=domains,
                    seed=seed
                )

                overall_accs.append(overall_acc)
                for d in domains:
                    if d in per_domain and not np.isnan(per_domain[d]):
                        per_domain_accs[d].append(per_domain[d])

            # Store results
            results[r][kind] = {
                'overall_mean': np.mean(overall_accs),
                'overall_std': np.std(overall_accs),
                'per_domain': {
                    d: {
                        'mean': np.mean(per_domain_accs[d]) if per_domain_accs[d] else np.nan,
                        'std': np.std(per_domain_accs[d]) if per_domain_accs[d] else np.nan
                    }
                    for d in domains
                }
            }

    return results


def print_probe_results(results, layer_name, probe_type, domains):
    """Pretty-print probe results with per-domain breakdown."""

    print(f"\n{'='*70}")
    print(f"{probe_type.upper()} PROBE @ {layer_name}")
    print(f"{'='*70}")

    for r, r_results in results.items():
        print(f"\n--- Rank {r} ---")
        print(f"{'Subspace':<10} {'Overall':<18} | Per-Domain Accuracy")
        print("-" * 70)

        for kind in ["major", "minor", "random"]:
            data = r_results[kind]
            overall_str = f"{data['overall_mean']:.3f} ± {data['overall_std']:.3f}"

            # Per-domain string
            per_domain_parts = []
            for d in domains:
                d_short = d[:4]  # Abbreviate domain names
                d_acc = data['per_domain'][d]['mean']
                if not np.isnan(d_acc):
                    per_domain_parts.append(f"{d_short}:{d_acc:.2f}")
            per_domain_str = ", ".join(per_domain_parts)

            print(f"{kind:<10} {overall_str:<18} | {per_domain_str}")

        # Print gap analysis
        major_overall = r_results['major']['overall_mean']
        minor_overall = r_results['minor']['overall_mean']
        gap = minor_overall - major_overall
        gap_sign = "+" if gap > 0 else ""
        print(f"\n  Gap (minor - major): {gap_sign}{gap:.3f}")
        if probe_type == "domain":
            if gap > 0.02:
                print("  → Minor has MORE domain info (supports SoMA hypothesis)")
            elif gap < -0.02:
                print("  → Major has MORE domain info (contradicts SoMA hypothesis)")
            else:
                print("  → No clear difference")

In [None]:
# -----------------------------------------------------------------------------
# DIAGNOSTIC: Check what's in pretrained_model._features
# -----------------------------------------------------------------------------

print("Testing pretrained_model hooks...")
pretrained_model.eval()

# Run a single forward pass
test_batch = torch.randn(2, 3, 224, 224).to(device)
with torch.no_grad():
    _ = pretrained_model(test_batch)

print(f"\nNumber of keys: {len(pretrained_model._features)}")
print(f"Keys: {sorted(pretrained_model._features.keys())}")

# Check shapes
for key, feat in pretrained_model._features.items():
    print(f"  {key}: {feat.shape}")
# ```

# **Expected output:**
# ```
# Testing pretrained_model hooks...

# Number of keys: 8
# Keys: ['layer1.0.bn2', 'layer1.1.bn2', 'layer2.0.bn2', 'layer2.1.bn2', 'layer3.0.bn2', 'layer3.1.bn2', 'layer4.0.bn2', 'layer4.1.bn2']
#   layer1.0.bn2: torch.Size([2, 64, 56, 56])
#   layer1.1.bn2: torch.Size([2, 64, 56, 56])
#   layer2.0.bn2: torch.Size([2, 128, 28, 28])
  # ...

Testing pretrained_model hooks...

Number of keys: 16
Keys: ['layer1.0.bn2', 'layer1.1.bn2', 'layer1.2.bn2', 'layer2.0.bn2', 'layer2.1.bn2', 'layer2.2.bn2', 'layer2.3.bn2', 'layer3.0.bn2', 'layer3.1.bn2', 'layer3.2.bn2', 'layer3.3.bn2', 'layer3.4.bn2', 'layer3.5.bn2', 'layer4.0.bn2', 'layer4.1.bn2', 'layer4.2.bn2']
  layer1.0.bn2: torch.Size([2, 64, 56, 56])
  layer1.1.bn2: torch.Size([2, 64, 56, 56])
  layer1.2.bn2: torch.Size([2, 64, 56, 56])
  layer2.0.bn2: torch.Size([2, 128, 28, 28])
  layer2.1.bn2: torch.Size([2, 128, 28, 28])
  layer2.2.bn2: torch.Size([2, 128, 28, 28])
  layer2.3.bn2: torch.Size([2, 128, 28, 28])
  layer3.0.bn2: torch.Size([2, 256, 14, 14])
  layer3.1.bn2: torch.Size([2, 256, 14, 14])
  layer3.2.bn2: torch.Size([2, 256, 14, 14])
  layer3.3.bn2: torch.Size([2, 256, 14, 14])
  layer3.4.bn2: torch.Size([2, 256, 14, 14])
  layer3.5.bn2: torch.Size([2, 256, 14, 14])
  layer4.0.bn2: torch.Size([2, 512, 7, 7])
  layer4.1.bn2: torch.Size([2, 512, 7, 7])
  layer4.2.bn2:

In [None]:
# -----------------------------------------------------------------------------
# Cell 5: Run Scenario 1 - Domain Probe on PRETRAINED Model (All 4 Domains)
# -----------------------------------------------------------------------------
# This probes the pretrained model on ALL conv blocks to see if domain info
# is naturally concentrated in the minor subspace BEFORE any finetuning.

print("=" * 70)
print("SCENARIO 1: Probing PRETRAINED ImageNet Model")
print("Testing if pretrained spectral structure already separates domain info")
print("=" * 70)

# Configuration
domains_all = ["photo", "art_painting", "cartoon", "sketch"]
ranks = [4, 8, 16]
seeds = [0, 1, 2]

# Get conv blocks from pretrained model (ALL layers)
# ResNet-50 has [3, 4, 6, 3] blocks per layer
pretrained_conv_blocks = {}
block_counts = {"layer1": 3, "layer2": 4, "layer3": 6, "layer4": 3}
for lname, count in block_counts.items():
    layer = getattr(pretrained_model.backbone, lname)
    for i in range(count):
        key = f"{lname}.{i}"
        pretrained_conv_blocks[key] = layer[i].conv2

# Store all results
scenario1_domain_results = {}

print("\n===== DOMAIN PROBES (ALL CONV BLOCKS) - PRETRAINED MODEL =====")

# Loop over ALL conv blocks
for layer_name, conv in pretrained_conv_blocks.items():

    results = run_full_probe_analysis(
        model=pretrained_model,
        layer_name=layer_name,
        conv=conv,
        domains=domains_all,
        ranks=ranks,
        seeds=seeds,
        probe_type="domain",
        exclude_photo=False
    )

    scenario1_domain_results[layer_name] = results
    print_probe_results(results, layer_name, "domain", domains_all)

SCENARIO 1: Probing PRETRAINED ImageNet Model
Testing if pretrained spectral structure already separates domain info

===== DOMAIN PROBES (ALL CONV BLOCKS) - PRETRAINED MODEL =====

DOMAIN PROBE @ layer1.0

--- Rank 4 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.620 ± 0.002      | phot:0.36, art_:0.59, cart:0.19, sket:1.00
minor      0.396 ± 0.008      | phot:0.00, art_:0.00, cart:0.00, sket:1.00
random     0.700 ± 0.017      | phot:0.24, art_:0.76, cart:0.46, sket:1.00

  Gap (minor - major): -0.224
  → Major has MORE domain info (contradicts SoMA hypothesis)

--- Rank 8 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.770 ± 0.001      | phot:0.49, art_:0.66, cart:0.69, sket:1.00
minor      0.592 ± 0.009      | phot:0.00, art_:0.91, cart:0.04, sket:1.00
random     0.742 ± 0.027      | phot:0.32, art_:0

In [None]:
import os, json, pickle, gzip
from datetime import datetime
import torch, numpy as np

# --- 1) Mount Drive (Colab) ---
try:
    from google.colab import drive
    drive.mount("/content/drive")
    DRIVE_ROOT = "/content/drive/MyDrive"
except Exception as e:
    # If you're not on Colab, set DRIVE_ROOT manually to your Drive mount path
    print("Not in Colab or drive mount failed:", e)
    DRIVE_ROOT = "/content/drive/MyDrive"  # change if needed

# --- 2) Create a timestamped output folder ---
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir = os.path.join(DRIVE_ROOT, "DG_PACS", "Scenario1_PretrainedDomainProbe", ts)
os.makedirs(out_dir, exist_ok=True)
print("Saving to:", out_dir)

# --- 3) Pack metadata/config (so you know exactly what produced the results) ---
meta = {
    "scenario": "Scenario 1 - Domain Probe on PRETRAINED ImageNet Model",
    "timestamp": ts,
    "domains_all": globals().get("domains_all", None),
    "ranks": globals().get("ranks", None),
    "seeds": globals().get("seeds", None),
    "pretrained_conv_block_keys": list(globals().get("pretrained_conv_blocks", {}).keys()),
    "python": __import__("sys").version,
    "torch": torch.__version__,
    "numpy": np.__version__,
}
with open(os.path.join(out_dir, "meta.json"), "w") as f:
    json.dump(meta, f, indent=2)
print("✅ Saved meta.json")

# --- 4) Save the raw results dict (best: pickle, optionally gzipped) ---
results_obj = globals().get("scenario1_domain_results", None)
if results_obj is None:
    raise RuntimeError("scenario1_domain_results not found in globals(). Did Cell 5 finish successfully?")

pkl_path = os.path.join(out_dir, "scenario1_domain_results.pkl")
with open(pkl_path, "wb") as f:
    pickle.dump(results_obj, f, protocol=pickle.HIGHEST_PROTOCOL)
print("✅ Saved raw results:", pkl_path)

# Optional: gzip-compressed copy (often much smaller)
gz_path = os.path.join(out_dir, "scenario1_domain_results.pkl.gz")
with gzip.open(gz_path, "wb") as f:
    pickle.dump(results_obj, f, protocol=pickle.HIGHEST_PROTOCOL)
print("✅ Saved compressed raw results:", gz_path)

# Optional: torch save as well (sometimes convenient)
pt_path = os.path.join(out_dir, "scenario1_domain_results.pt")
torch.save(results_obj, pt_path)
print("✅ Saved torch results:", pt_path)

# --- 5) Save a human-readable dump (so you can inspect without loading pickle structures) ---
dump_path = os.path.join(out_dir, "scenario1_domain_results_readable.txt")
with open(dump_path, "w") as f:
    f.write("SCENARIO 1 RESULTS (READABLE DUMP)\n")
    f.write(json.dumps(meta, indent=2))
    f.write("\n\n")

    for layer_name, layer_res in results_obj.items():
        f.write("="*80 + "\n")
        f.write(f"LAYER: {layer_name}\n")
        f.write("-"*80 + "\n")
        # Best-effort readable print of whatever structure you return
        f.write(repr(layer_res))
        f.write("\n\n")

print("✅ Saved readable dump:", dump_path)

print("\nAll done. You can reload later with:")
print("  import pickle; obj = pickle.load(open('.../scenario1_domain_results.pkl','rb'))")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Saving to: /content/drive/MyDrive/DG_PACS/Scenario1_PretrainedDomainProbe/20251228_134614
✅ Saved meta.json
✅ Saved raw results: /content/drive/MyDrive/DG_PACS/Scenario1_PretrainedDomainProbe/20251228_134614/scenario1_domain_results.pkl
✅ Saved compressed raw results: /content/drive/MyDrive/DG_PACS/Scenario1_PretrainedDomainProbe/20251228_134614/scenario1_domain_results.pkl.gz
✅ Saved torch results: /content/drive/MyDrive/DG_PACS/Scenario1_PretrainedDomainProbe/20251228_134614/scenario1_domain_results.pt
✅ Saved readable dump: /content/drive/MyDrive/DG_PACS/Scenario1_PretrainedDomainProbe/20251228_134614/scenario1_domain_results_readable.txt

All done. You can reload later with:
  import pickle; obj = pickle.load(open('.../scenario1_domain_results.pkl','rb'))


In [None]:
X, y_class, y_domain, domain_names = extract_gap_features_v2(pretrained_model, domains_all, "layer1.0")
print("y_class unique:", np.unique(y_class)[:20], "count:", len(np.unique(y_class)))
print("y_class min/max:", y_class.min(), y_class.max())


y_class unique: [0 1 2 3 4 5 6] count: 7
y_class min/max: 0 6


In [None]:
# -----------------------------------------------------------------------------
# Cell 7: Run Scenario 1 - Class Probe on PRETRAINED Model (All Conv Blocks)
# -----------------------------------------------------------------------------

print("\n" + "=" * 70)
print("SCENARIO 1: Class Probe on PRETRAINED Model - ALL CONV BLOCKS")
print("=" * 70)

scenario1_class_results = {}

for layer_name, conv in pretrained_conv_blocks.items():

    results = run_full_probe_analysis(
        model=pretrained_model,
        layer_name=layer_name,
        conv=conv,
        domains=domains_all,
        ranks=ranks,
        seeds=seeds,
        probe_type="class",
        exclude_photo=False
    )

    scenario1_class_results[layer_name] = results
    print_probe_results(results, layer_name, "class", domains_all)


SCENARIO 1: Class Probe on PRETRAINED Model - ALL CONV BLOCKS

CLASS PROBE @ layer1.0

--- Rank 8 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.271 ± 0.006      | phot:0.31, art_:0.29, cart:0.26, sket:0.25
minor      0.229 ± 0.013      | phot:0.26, art_:0.24, cart:0.23, sket:0.21
random     0.279 ± 0.011      | phot:0.29, art_:0.31, cart:0.29, sket:0.26

  Gap (minor - major): -0.042

CLASS PROBE @ layer1.1

--- Rank 8 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.267 ± 0.004      | phot:0.27, art_:0.25, cart:0.25, sket:0.28
minor      0.291 ± 0.008      | phot:0.27, art_:0.25, cart:0.28, sket:0.33
random     0.279 ± 0.010      | phot:0.27, art_:0.25, cart:0.24, sket:0.32

  Gap (minor - major): +0.025

CLASS PROBE @ layer1.2

--- Rank 8 ---
Subspace   Overall            | Per-Domain Accuracy
-------

# TIER 1 - SCENARIO 2: Probing the FINETUNED Model

This tests whether finetuning on PACS source domains changes the spectral structure.
We probe the model that has been trained on Photo, Art, and Cartoon to see if
domain/class information now concentrates differently in major vs minor subspaces.

In [None]:
# -----------------------------------------------------------------------------
# Cell 8: Load Finetuned Model for Scenario 2
# -----------------------------------------------------------------------------

print("=" * 70)
print("SCENARIO 2: Probing FINETUNED Model (trained on 3 source domains)")
print("Testing if finetuning changes spectral structure")
print("=" * 70)

# Load the finetuned model
print("\nLoading finetuned model...")
finetuned_model = ResNet50_FeatureHook(num_classes=7).to(device)
finetuned_model.load_state_dict(torch.load(BACKBONE_PATH, map_location=device))
finetuned_model.eval()

# Verify hooks work
print("Verifying hooks...")
test_input = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    _ = finetuned_model(test_input)

print(f"Feature keys: {len(finetuned_model._features)} layers hooked")

# Get conv blocks from FINETUNED model
finetuned_conv_blocks = {}
block_counts = {"layer1": 3, "layer2": 4, "layer3": 6, "layer4": 3}
for lname, count in block_counts.items():
    layer = getattr(finetuned_model.backbone, lname)
    for i in range(count):
        key = f"{lname}.{i}"
        finetuned_conv_blocks[key] = layer[i].conv2

print(f"Conv blocks to analyze: {len(finetuned_conv_blocks)}")
print("Finetuned model ready for probing.")

SCENARIO 2: Probing FINETUNED Model (trained on 3 source domains)
Testing if finetuning changes spectral structure

Loading finetuned model...
Verifying hooks...
Feature keys: 16 layers hooked
Conv blocks to analyze: 16
Finetuned model ready for probing.


In [None]:
# -----------------------------------------------------------------------------
# Cell 9: Run Scenario 2 - Domain Probe on FINETUNED Model (All Conv Blocks)
# -----------------------------------------------------------------------------

print("\n" + "=" * 70)
print("SCENARIO 2: Domain Probe on FINETUNED Model - ALL CONV BLOCKS")
print("=" * 70)

# Configuration (same as Scenario 1)
domains_all = ["photo", "art_painting", "cartoon", "sketch"]
ranks = [4, 8, 16]
seeds = [0, 1, 2]

scenario2_domain_results = {}

for layer_name, conv in finetuned_conv_blocks.items():

    results = run_full_probe_analysis(
        model=finetuned_model,
        layer_name=layer_name,
        conv=conv,
        domains=domains_all,
        ranks=ranks,
        seeds=seeds,
        probe_type="domain",
        exclude_photo=False
    )

    scenario2_domain_results[layer_name] = results
    print_probe_results(results, layer_name, "domain", domains_all)


SCENARIO 2: Domain Probe on FINETUNED Model - ALL CONV BLOCKS

DOMAIN PROBE @ layer1.0

--- Rank 4 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.653 ± 0.009      | phot:0.33, art_:0.62, cart:0.32, sket:1.00
minor      0.396 ± 0.008      | phot:0.00, art_:0.00, cart:0.00, sket:1.00
random     0.676 ± 0.017      | phot:0.19, art_:0.70, cart:0.45, sket:1.00

  Gap (minor - major): -0.257
  → Major has MORE domain info (contradicts SoMA hypothesis)

--- Rank 8 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.755 ± 0.005      | phot:0.39, art_:0.60, cart:0.74, sket:1.00
minor      0.515 ± 0.013      | phot:0.00, art_:0.58, cart:0.00, sket:1.00
random     0.702 ± 0.013      | phot:0.29, art_:0.64, cart:0.55, sket:1.00

  Gap (minor - major): -0.239
  → Major has MORE domain info (contradicts SoMA hypothesis)


In [None]:
# -----------------------------------------------------------------------------
# Cell 10: Run Scenario 2 - Class Probe on FINETUNED Model (All Conv Blocks)
# -----------------------------------------------------------------------------

print("\n" + "=" * 70)
print("SCENARIO 2: Class Probe on FINETUNED Model - ALL CONV BLOCKS")
print("=" * 70)

scenario2_class_results = {}

for layer_name, conv in finetuned_conv_blocks.items():

    results = run_full_probe_analysis(
        model=finetuned_model,
        layer_name=layer_name,
        conv=conv,
        domains=domains_all,
        ranks=ranks,
        seeds=seeds,
        probe_type="class",
        exclude_photo=False
    )

    scenario2_class_results[layer_name] = results
    print_probe_results(results, layer_name, "class", domains_all)


SCENARIO 2: Class Probe on FINETUNED Model - ALL CONV BLOCKS

CLASS PROBE @ layer1.0

--- Rank 4 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.276 ± 0.004      | phot:0.31, art_:0.31, cart:0.22, sket:0.28
minor      0.189 ± 0.014      | phot:0.14, art_:0.22, cart:0.17, sket:0.20
random     0.248 ± 0.009      | phot:0.25, art_:0.28, cart:0.24, sket:0.24

  Gap (minor - major): -0.086

--- Rank 8 ---
Subspace   Overall            | Per-Domain Accuracy
----------------------------------------------------------------------
major      0.292 ± 0.005      | phot:0.32, art_:0.31, cart:0.26, sket:0.29
minor      0.220 ± 0.011      | phot:0.22, art_:0.26, cart:0.21, sket:0.20
random     0.265 ± 0.019      | phot:0.30, art_:0.31, cart:0.27, sket:0.23

  Gap (minor - major): -0.072

--- Rank 16 ---
Subspace   Overall            | Per-Domain Accuracy
-------------------------------------------------------

In [None]:
# -----------------------------------------------------------------------------
# Cell 11: Compare Scenario 1 vs Scenario 2 Results
# -----------------------------------------------------------------------------

print("\n" + "=" * 70)
print("TIER 1 SUMMARY: Comparing PRETRAINED vs FINETUNED")
print("=" * 70)

print("\n--- DOMAIN PROBE GAPS (Minor - Major) at Rank 8 ---")
print(f"{'Layer':<12} {'Pretrained':<15} {'Finetuned':<15} {'Change':<15}")
print("-" * 57)

for layer_name in finetuned_conv_blocks.keys():
    if layer_name in scenario1_domain_results and layer_name in scenario2_domain_results:
        pre_results = scenario1_domain_results[layer_name]
        fine_results = scenario2_domain_results[layer_name]

        if 8 in pre_results and 8 in fine_results:
            pre_gap = pre_results[8]['minor']['overall_mean'] - pre_results[8]['major']['overall_mean']
            fine_gap = fine_results[8]['minor']['overall_mean'] - fine_results[8]['major']['overall_mean']
            change = fine_gap - pre_gap
            print(f"{layer_name:<12} {pre_gap:+.3f}          {fine_gap:+.3f}          {change:+.3f}")

print("\n--- CLASS PROBE GAPS (Major - Minor) at Rank 8 ---")
print(f"{'Layer':<12} {'Pretrained':<15} {'Finetuned':<15} {'Change':<15}")
print("-" * 57)

for layer_name in finetuned_conv_blocks.keys():
    if layer_name in scenario1_class_results and layer_name in scenario2_class_results:
        pre_results = scenario1_class_results[layer_name]
        fine_results = scenario2_class_results[layer_name]

        if 8 in pre_results and 8 in fine_results:
            pre_gap = pre_results[8]['major']['overall_mean'] - pre_results[8]['minor']['overall_mean']
            fine_gap = fine_results[8]['major']['overall_mean'] - fine_results[8]['minor']['overall_mean']
            change = fine_gap - pre_gap
            print(f"{layer_name:<12} {pre_gap:+.3f}          {fine_gap:+.3f}          {change:+.3f}")

print("\n" + "=" * 70)
print("INTERPRETATION")
print("=" * 70)
print("""
If domain probe gaps remain NEGATIVE after finetuning:
  → Major subspace STILL contains more domain info
  → Finetuning did NOT shift domain info to minor subspace
  → SoMA's hypothesis is contradicted

If class probe gaps remain POSITIVE after finetuning:
  → Major subspace STILL contains more class info
  → Both domain AND class info concentrate in major
  → Spectral separation is FALSE
""")


TIER 1 SUMMARY: Comparing PRETRAINED vs FINETUNED

--- DOMAIN PROBE GAPS (Minor - Major) at Rank 8 ---
Layer        Pretrained      Finetuned       Change         
---------------------------------------------------------
layer1.0     -0.179          -0.239          -0.061
layer1.1     -0.031          -0.054          -0.023
layer1.2     +0.076          -0.024          -0.100
layer2.0     -0.071          -0.080          -0.009
layer2.1     +0.077          -0.123          -0.199
layer2.2     -0.023          -0.090          -0.067
layer2.3     +0.132          -0.113          -0.245
layer3.0     -0.322          -0.309          +0.013
layer3.1     -0.052          -0.110          -0.058
layer3.2     -0.052          -0.241          -0.189
layer3.3     -0.173          -0.146          +0.027
layer3.4     -0.156          -0.167          -0.010
layer3.5     -0.134          -0.183          -0.048
layer4.0     -0.120          -0.104          +0.016
layer4.1     -0.139          -0.164          -0.0

In [None]:
# -----------------------------------------------------------------------------
# Cell 12: Save All Tier 1 Results
# -----------------------------------------------------------------------------

import pickle

tier1_all_results = {
    'scenario1': {
        'domain': scenario1_domain_results,
        'class': scenario1_class_results,
    },
    'scenario2': {
        'domain': scenario2_domain_results,
        'class': scenario2_class_results,
    },
    'metadata': {
        'model': 'ResNet-50',
        'dataset': 'PACS',
        'domains': domains_all,
        'ranks': ranks,
        'seeds': seeds,
    }
}

tier1_save_path = f"{PROJECT_ROOT}/tier1_complete_results.pkl"
with open(tier1_save_path, 'wb') as f:
    pickle.dump(tier1_all_results, f)

print(f"Tier 1 results saved to: {tier1_save_path}")

Tier 1 results saved to: /content/drive/MyDrive/SoMA_PACS_R50/tier1_complete_results.pkl


# Tier-2: ΔW-Based Domain-Sensitive Subspace Discovery
Working with pretrained base, and just finetuning that in one direction only

In [None]:
import copy
import numpy as np
import torch
import torch.nn as nn
from torchvision import models

device = "cuda" if torch.cuda.is_available() else "cpu"

# =============================================================================
# Load PRETRAINED model (ImageNet only, never seen PACS)
# =============================================================================

def load_pretrained_model():
    """Load fresh ImageNet-pretrained ResNet-50 (NOT finetuned on PACS)"""
    model = ResNet50_FeatureHook(num_classes=7).to(device)
    # The ResNet50_FeatureHook already loads ImageNet weights in __init__
    # We just need to ensure fc is properly initialized (it doesn't matter for ΔW)
    nn.init.xavier_uniform_(model.backbone.fc.weight)
    nn.init.zeros_(model.backbone.fc.bias)
    return model


def load_finetuned_model():
    """Load the model finetuned on 3 source domains"""
    model = ResNet50_FeatureHook(num_classes=7).to(device)
    state = torch.load(BACKBONE_PATH, map_location=device)
    model.load_state_dict(state)
    return model


In [None]:
# =============================================================================
# Train single domain starting from PRETRAINED (not finetuned)
# =============================================================================

def freeze_bn_stats(model):
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()                  # freeze running mean/var
            m.weight.requires_grad = True
            m.bias.requires_grad = True

from collections import defaultdict
def make_balanced_loader(domain, tf, batch_size):
    ds = datasets.ImageFolder(
        root=os.path.join(PACS_ROOT, domain),
        transform=tf
    )

    # group indices by class
    class_to_indices = defaultdict(list)
    for idx, (_, y) in enumerate(ds.samples):
        class_to_indices[y].append(idx)

    # sample equal counts per class
    min_count = min(len(v) for v in class_to_indices.values())
    balanced_indices = []
    for v in class_to_indices.values():
        balanced_indices.extend(v[:min_count])

    sampler = torch.utils.data.SubsetRandomSampler(balanced_indices)

    return DataLoader(
        ds,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=NUM_WORKERS
    )

def train_single_domain_from_pretrained(domain, epochs=3, lr=1e-4):
    """
    Start from pretrained ImageNet model and train on single domain.
    This captures: "How does pretrained → domain d adaptation occur?"
    """
    model = load_pretrained_model()  # PRETRAINED, not finetuned!
    freeze_bn_stats(model)

    loader = make_balanced_loader(domain, train_tf, BATCH_SIZE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    return model


In [None]:
def build_deltaW_channel_subspace(dW_list):
    """
    dW_list: list of ΔW matrices for one layer
             each ΔW has shape [C_out, D]

    Returns:
        U_domain: [C_out, C_out] eigenvectors (descending order)
        S_domain: eigenvalues
    """
    C_out = dW_list[0].shape[0]
    C = np.zeros((C_out, C_out))

    for dW in dW_list:
        C += dW @ dW.T   # channel-channel covariance

    # Eigen-decomposition (symmetric PSD matrix)
    eigvals, eigvecs = np.linalg.eigh(C)

    # Sort descending
    idx = np.argsort(eigvals)[::-1]
    eigvals = eigvals[idx]
    eigvecs = eigvecs[:, idx]

    return eigvecs, eigvals


def compute_alignment_metrics(U, V):
    """
    U, V: [C_out, r] orthonormal basis matrices

    Returns:
        mean_angle_deg
        overlap (mean cos^2)
        projection_energy
    """
    # Cross-subspace matrix
    M = U.T @ V   # [r, r]

    # Singular values = cos(theta_i)
    sv = np.linalg.svd(M, compute_uv=False)
    sv = np.clip(sv, -1.0, 1.0)

    # Principal angles (degrees)
    angles = np.degrees(np.arccos(sv))
    mean_angle = angles.mean()

    # Overlap (mean squared cosine)
    overlap = np.mean(sv ** 2)

    # Projection energy (equivalent to overlap)
    proj_energy = np.linalg.norm(M, ord="fro") ** 2 / V.shape[1]

    return mean_angle, overlap, proj_energy


# =============================================================================
# Compute ΔW from pretrained model
# =============================================================================

def extract_delta_W_from_pretrained(pretrained_model, adapted_model, conv_path):
    """
    Compute ΔW = W_adapted - W_pretrained
    """
    base_conv = getattr(
        getattr(pretrained_model.backbone, conv_path[0])[conv_path[1]],
        conv_path[2]
    )

    adapt_conv = getattr(
        getattr(adapted_model.backbone, conv_path[0])[conv_path[1]],
        conv_path[2]
    )

    W0 = base_conv.weight.detach().cpu().numpy()
    W1 = adapt_conv.weight.detach().cpu().numpy()

    dW = W1 - W0
    C_out = dW.shape[0]
    return dW.reshape(C_out, -1)



In [None]:

# =============================================================================
# Run the corrected Tier 2 analysis
# =============================================================================

conv_layers = {
    "layer2.3": ("layer2", 3, "conv2"),  # Last block of layer2 (has 4 blocks: 0,1,2,3)
    "layer3.5": ("layer3", 5, "conv2"),  # Last block of layer3 (has 6 blocks: 0-5)
    "layer4.2": ("layer4", 2, "conv2"),  # Last block of layer4 (has 3 blocks: 0,1,2)
}

# Load pretrained model as the BASE for ΔW computation
pretrained_model = load_pretrained_model()

# Store ΔW for each domain
deltaWs_from_pretrained = {k: [] for k in conv_layers}

print("Computing ΔW from PRETRAINED model to each domain...")
for domain in SOURCE_DOMAINS:
    print(f"  Training on domain: {domain}")
    adapted = train_single_domain_from_pretrained(domain, epochs=3)

    for lname, path in conv_layers.items():
        dW = extract_delta_W_from_pretrained(pretrained_model, adapted, path)
        deltaWs_from_pretrained[lname].append(dW)

# Build domain-sensitive subspace from these ΔWs
delta_subspaces_pretrained = {}
for lname, dW_list in deltaWs_from_pretrained.items():
    U_domain, S_domain = build_deltaW_channel_subspace(dW_list)
    delta_subspaces_pretrained[lname] = (U_domain, S_domain)

print("Constructed ΔW subspaces from pretrained base.")


Computing ΔW from PRETRAINED model to each domain...
  Training on domain: photo
  Training on domain: art_painting
  Training on domain: cartoon
Constructed ΔW subspaces from pretrained base.


In [None]:
# =============================================================================
# Get SVD subspaces from PRETRAINED model (what SoMA would actually use)
# =============================================================================

def get_soma_subspace_from_model(model, layer_path, kind, r):
    """
    Get major/minor/random subspace from a specific model's weights.
    """
    conv = getattr(
        getattr(model.backbone, layer_path[0])[layer_path[1]],
        layer_path[2]
    )

    W = conv.weight.detach().cpu().numpy()
    C_out = W.shape[0]
    Wmat = W.reshape(C_out, -1)

    U, S, _ = np.linalg.svd(Wmat, full_matrices=False)

    if kind == "minor":
        return U[:, -r:]
    if kind == "major":
        return U[:, :r]
    if kind == "random":
        Q, _ = np.linalg.qr(np.random.randn(C_out, r))
        return Q

    raise ValueError(f"Unknown kind: {kind}")

In [None]:
# =============================================================================
# Compute alignment between ΔW subspace and SVD subspaces of PRETRAINED model
# =============================================================================

print("\n" + "="*70)
print("APPROACH A: ΔW from Pretrained, SVD from Pretrained")
print("="*70)

ranks = [4, 8, 16]
n_random = 10

alignment_results_pretrained = {}

for lname, (U_domain, S_domain) in delta_subspaces_pretrained.items():
    print(f"\n=== Alignment @ {lname} ===")
    alignment_results_pretrained[lname] = {}

    for r in ranks:
        print(f"\nRank {r}")
        V = U_domain[:, :r]  # Top-r ΔW directions

        alignment_results_pretrained[lname][r] = {}

        for kind in ["minor", "major", "random"]:
            angles, overlaps = [], []

            for seed in range(n_random):
                np.random.seed(seed)
                # SVD from PRETRAINED model (what SoMA actually uses)
                U = get_soma_subspace_from_model(pretrained_model, conv_layers[lname], kind, r)

                mean_angle, overlap, _ = compute_alignment_metrics(U, V)
                angles.append(mean_angle)
                overlaps.append(overlap)

            alignment_results_pretrained[lname][r][kind] = {
                "angle": (np.mean(angles), np.std(angles)),
                "overlap": (np.mean(overlaps), np.std(overlaps)),
            }

            print(f"  {kind:>6} | angle={np.mean(angles):.1f}° | overlap={np.mean(overlaps):.3f}")


APPROACH A: ΔW from Pretrained, SVD from Pretrained

=== Alignment @ layer2.3 ===

Rank 4
   minor | angle=85.1° | overlap=0.011
   major | angle=70.2° | overlap=0.187
  random | angle=81.6° | overlap=0.030

Rank 8
   minor | angle=83.6° | overlap=0.018
   major | angle=70.7° | overlap=0.171
  random | angle=77.4° | overlap=0.067

Rank 16
   minor | angle=80.6° | overlap=0.039
   major | angle=64.6° | overlap=0.240
  random | angle=72.4° | overlap=0.122

=== Alignment @ layer3.5 ===

Rank 4
   minor | angle=86.8° | overlap=0.004
   major | angle=69.1° | overlap=0.189
  random | angle=83.7° | overlap=0.018

Rank 8
   minor | angle=85.1° | overlap=0.010
   major | angle=67.6° | overlap=0.189
  random | angle=81.3° | overlap=0.033

Rank 16
   minor | angle=82.1° | overlap=0.027
   major | angle=63.8° | overlap=0.239
  random | angle=77.6° | overlap=0.063

=== Alignment @ layer4.2 ===

Rank 4
   minor | angle=88.1° | overlap=0.002
   major | angle=84.8° | overlap=0.013
  random | angle=86

In [None]:
# =============================================================================
# APPROACH B: ΔW from FINETUNED Base
# =============================================================================
# This tests whether domain adaptation from a finetuned model aligns differently.
# Key insight: If major alignment persists, it's due to gradient flow dynamics,
# not something specific to pretrained weights.

def train_single_domain_from_finetuned(domain, epochs=2, lr=1e-4):
    """
    Start from finetuned (3-domain) model and specialize on single domain.
    This captures: "How does multi-domain → single domain specialization occur?"

    Uses fewer epochs (2) since we're specializing, not training from scratch.
    """
    model = load_finetuned_model()  # Load 3-domain finetuned model
    freeze_bn_stats(model)

    loader = make_balanced_loader(domain, train_tf, BATCH_SIZE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    return model

print("Approach B training function defined.")

Approach B training function defined.


In [None]:
# =============================================================================
# Compute ΔW from FINETUNED model for each source domain
# =============================================================================

# Load finetuned model as the BASE for Approach B
finetuned_base = load_finetuned_model()

# Store ΔW for each domain
deltaWs_from_finetuned = {k: [] for k in conv_layers}

print("Computing ΔW from FINETUNED model to each domain...")
for domain in SOURCE_DOMAINS:
    print(f"  Specializing on domain: {domain}")
    specialized = train_single_domain_from_finetuned(domain, epochs=2)

    for lname, path in conv_layers.items():
        # Reuse the same extraction function - it just computes W1 - W0
        dW = extract_delta_W_from_pretrained(finetuned_base, specialized, path)
        deltaWs_from_finetuned[lname].append(dW)

# Build domain-sensitive subspace from these ΔWs
delta_subspaces_finetuned = {}
for lname, dW_list in deltaWs_from_finetuned.items():
    U_domain, S_domain = build_deltaW_channel_subspace(dW_list)
    delta_subspaces_finetuned[lname] = (U_domain, S_domain)

print("Constructed ΔW subspaces from finetuned base.")

Computing ΔW from FINETUNED model to each domain...
  Specializing on domain: photo
  Specializing on domain: art_painting
  Specializing on domain: cartoon
Constructed ΔW subspaces from finetuned base.


In [None]:
# =============================================================================
# Get SVD subspaces from FINETUNED model weights
# =============================================================================

def get_soma_subspace_from_finetuned(layer_path, kind, r):
    """
    Get major/minor/random subspace from FINETUNED model's weights.
    """
    model = load_finetuned_model()

    conv = getattr(
        getattr(model.backbone, layer_path[0])[layer_path[1]],
        layer_path[2]
    )

    W = conv.weight.detach().cpu().numpy()
    C_out = W.shape[0]
    Wmat = W.reshape(C_out, -1)

    U, S, _ = np.linalg.svd(Wmat, full_matrices=False)

    if kind == "minor":
        return U[:, -r:]
    if kind == "major":
        return U[:, :r]
    if kind == "random":
        np.random.seed(42)
        Q, _ = np.linalg.qr(np.random.randn(C_out, r))
        return Q

    raise ValueError(f"Unknown kind: {kind}")

print("Finetuned SVD extraction function defined.")

Finetuned SVD extraction function defined.


In [None]:
# =============================================================================
# Compute alignment between ΔW subspace and SVD subspaces of FINETUNED model
# =============================================================================

print("\n" + "="*70)
print("APPROACH B: ΔW from Finetuned, SVD from Finetuned")
print("="*70)

ranks = [4, 8, 16]
n_random = 10

alignment_results_finetuned = {}

for lname, (U_domain, S_domain) in delta_subspaces_finetuned.items():
    print(f"\n=== Alignment @ {lname} ===")
    alignment_results_finetuned[lname] = {}

    for r in ranks:
        print(f"\nRank {r}")
        alignment_results_finetuned[lname][r] = {}

        # Get top-r directions from ΔW subspace
        U_delta_r = U_domain[:, :r]

        for kind in ["minor", "major"]:
            U_svd = get_soma_subspace_from_finetuned(conv_layers[lname], kind, r)
            angle, overlap, _ = compute_alignment_metrics(U_delta_r, U_svd)

            alignment_results_finetuned[lname][r][kind] = {
                'angle': angle,
                'overlap': overlap
            }
            print(f"   {kind:>6} | angle={angle:.1f}° | overlap={overlap:.3f}")

        # Random baseline (average over multiple random bases)
        random_angles = []
        random_overlaps = []
        for seed in range(n_random):
            np.random.seed(seed)
            C_out = U_domain.shape[0]
            Q, _ = np.linalg.qr(np.random.randn(C_out, r))
            angle, overlap, _ = compute_alignment_metrics(U_delta_r, Q)
            random_angles.append(angle)
            random_overlaps.append(overlap)

        alignment_results_finetuned[lname][r]['random'] = {
            'angle': np.mean(random_angles),
            'overlap': np.mean(random_overlaps)
        }
        print(f"  random | angle={np.mean(random_angles):.1f}° | overlap={np.mean(random_overlaps):.3f}")


APPROACH B: ΔW from Finetuned, SVD from Finetuned

=== Alignment @ layer2.3 ===

Rank 4
    minor | angle=87.2° | overlap=0.003
    major | angle=64.8° | overlap=0.245
  random | angle=81.4° | overlap=0.031

Rank 8
    minor | angle=83.4° | overlap=0.020
    major | angle=62.4° | overlap=0.266
  random | angle=77.2° | overlap=0.069

Rank 16
    minor | angle=80.5° | overlap=0.040
    major | angle=59.1° | overlap=0.317
  random | angle=72.2° | overlap=0.126

=== Alignment @ layer3.5 ===

Rank 4
    minor | angle=87.5° | overlap=0.004
    major | angle=64.7° | overlap=0.242
  random | angle=84.4° | overlap=0.015

Rank 8
    minor | angle=86.5° | overlap=0.006
    major | angle=59.5° | overlap=0.297
  random | angle=81.6° | overlap=0.030

Rank 16
    minor | angle=84.3° | overlap=0.015
    major | angle=51.5° | overlap=0.412
  random | angle=77.5° | overlap=0.063

=== Alignment @ layer4.2 ===

Rank 4
    minor | angle=88.6° | overlap=0.001
    major | angle=46.4° | overlap=0.487
  rando

In [None]:
print(lname, alignment_results_pretrained[lname][8]['minor']['angle'])


layer2.3 (np.float64(83.56280060547576), np.float64(1.4210854715202004e-14))


In [None]:
# =============================================================================
# Compare Approach A vs Approach B Results
# =============================================================================

def angle_deg(x):
    """alignment_results[*][*][*]['angle'] is stored as (angle_deg, residual)."""
    if isinstance(x, (tuple, list)):
        return float(x[0])
    return float(x)

print("\n" + "="*70)
print("COMPARISON: Approach A (Pretrained Base) vs Approach B (Finetuned Base)")
print("="*70)

print("\n--- Alignment Angles at Rank 8 ---")
print(f"{'Layer':<12} {'Base':<12} {'Minor':<12} {'Major':<12} {'Random':<12}")
print("-" * 60)

for lname in conv_layers.keys():
    if lname in alignment_results_pretrained and lname in alignment_results_finetuned:
        # Approach A
        pre_minor  = angle_deg(alignment_results_pretrained[lname][8]['minor']['angle'])
        pre_major  = angle_deg(alignment_results_pretrained[lname][8]['major']['angle'])
        pre_random = angle_deg(alignment_results_pretrained[lname][8]['random']['angle'])

        # Approach B
        fine_minor  = angle_deg(alignment_results_finetuned[lname][8]['minor']['angle'])
        fine_major  = angle_deg(alignment_results_finetuned[lname][8]['major']['angle'])
        fine_random = angle_deg(alignment_results_finetuned[lname][8]['random']['angle'])

        print(f"{lname:<12} {'Pretrained':<12} {pre_minor:<12.1f} {pre_major:<12.1f} {pre_random:<12.1f}")
        print(f"{'':<12} {'Finetuned':<12} {fine_minor:<12.1f} {fine_major:<12.1f} {fine_random:<12.1f}")
        print()

print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)
print("""
Key Question: Does the base model (pretrained vs finetuned) affect alignment?

If BOTH approaches show:
  - All subspaces ~85° (near orthogonal)
  - No major/minor differentiation

Then: The lack of alignment is FUNDAMENTAL, not an artifact of pretrained weights.
      Domain adaptation doesn't preferentially use any spectral subspace.

If Approach B shows DIFFERENT pattern:
  - Major shows better alignment than minor

Then: Finetuning creates structure that aligns with gradient flow.
      But this is an artifact of training history, not intrinsic to weights.
""")



COMPARISON: Approach A (Pretrained Base) vs Approach B (Finetuned Base)

--- Alignment Angles at Rank 8 ---
Layer        Base         Minor        Major        Random      
------------------------------------------------------------
layer2.3     Pretrained   83.6         70.7         77.4        
             Finetuned    83.4         62.4         77.2        

layer3.5     Pretrained   85.1         67.6         81.3        
             Finetuned    86.5         59.5         81.6        

layer4.2     Pretrained   85.6         82.7         84.2        
             Finetuned    88.1         36.8         83.7        


INTERPRETATION

Key Question: Does the base model (pretrained vs finetuned) affect alignment?

If BOTH approaches show:
  - All subspaces ~85° (near orthogonal)
  - No major/minor differentiation

Then: The lack of alignment is FUNDAMENTAL, not an artifact of pretrained weights.
      Domain adaptation doesn't preferentially use any spectral subspace.

If Approach B sh

In [None]:
# =============================================================================
# Save Tier 2 Complete Results
# =============================================================================
import pickle


tier2_all_results = {
    'approach_a': {
        'alignment': alignment_results_pretrained,
        'delta_subspaces': {k: (U.tolist(), S.tolist()) for k, (U, S) in delta_subspaces_pretrained.items()},
    },
    'approach_b': {
        'alignment': alignment_results_finetuned,
        'delta_subspaces': {k: (U.tolist(), S.tolist()) for k, (U, S) in delta_subspaces_finetuned.items()},
    },
    'metadata': {
        'model': 'ResNet-50',
        'dataset': 'PACS',
        'source_domains': SOURCE_DOMAINS,
        'conv_layers': list(conv_layers.keys()),
        'ranks': ranks,
    }
}

tier2_save_path = f"{PROJECT_ROOT}/tier2_complete_results.pkl"
with open(tier2_save_path, 'wb') as f:
    pickle.dump(tier2_all_results, f)

print(f"Tier 2 results saved to: {tier2_save_path}")

Tier 2 results saved to: /content/drive/MyDrive/SoMA_PACS_R50/tier2_complete_results.pkl


In [None]:
# =============================================================================
# VALIDATION: Does held-out domain (Sketch) align with ΔW subspace from 3 sources?
# =============================================================================

HELD_OUT_DOMAIN = "sketch"

print("\n" + "="*70)
print("VALIDATION: Does Sketch's ΔW align with subspace from Photo/Art/Cartoon?")
print("="*70)

# Train on Sketch starting from pretrained
print(f"\nTraining on held-out domain: {HELD_OUT_DOMAIN}")
adapted_sketch = train_single_domain_from_pretrained(HELD_OUT_DOMAIN, epochs=3)

# Compute ΔW for Sketch
deltaW_sketch = {}
for lname, path in conv_layers.items():
    dW = extract_delta_W_from_pretrained(pretrained_model, adapted_sketch, path)
    deltaW_sketch[lname] = dW

# For each layer, check if Sketch's ΔW aligns with the subspace derived from 3 domains
print("\nAlignment of Sketch's ΔW with 3-domain derived subspace:")

for lname in conv_layers.keys():
    print(f"\n=== {lname} ===")

    # Get the ΔW subspace from 3 source domains
    U_3domain, S_3domain = delta_subspaces_pretrained[lname]

    # Get Sketch's ΔW
    dW_sketch = deltaW_sketch[lname]

    # Compute how much of Sketch's ΔW variance is captured by 3-domain subspace
    for r in [4, 8, 16]:
        U_sub = U_3domain[:, :r]  # Top-r directions from 3 domains

        # Project Sketch's ΔW onto this subspace
        dW_projected = U_sub @ (U_sub.T @ dW_sketch)

        # Compute fraction of variance captured
        total_variance = np.linalg.norm(dW_sketch, 'fro')**2
        projected_variance = np.linalg.norm(dW_projected, 'fro')**2
        fraction_captured = projected_variance / total_variance

        print(f"  Rank {r}: {fraction_captured:.1%} of Sketch's ΔW variance captured")

# Also compute principal angles between Sketch's ΔW and 3-domain subspace
print("\nPrincipal angles between Sketch's ΔW direction and 3-domain subspace:")

for lname in conv_layers.keys():
    print(f"\n=== {lname} ===")

    U_3domain, _ = delta_subspaces_pretrained[lname]
    dW_sketch = deltaW_sketch[lname]

    # Get top direction of Sketch's ΔW
    U_sketch, S_sketch, _ = np.linalg.svd(dW_sketch, full_matrices=False)

    for r in [4, 8, 16]:
        V_3domain = U_3domain[:, :r]
        V_sketch = U_sketch[:, :r]

        mean_angle, overlap, _ = compute_alignment_metrics(V_sketch, V_3domain)
        print(f"  Rank {r}: angle={mean_angle:.1f}°, overlap={overlap:.3f}")


VALIDATION: Does Sketch's ΔW align with subspace from Photo/Art/Cartoon?

Training on held-out domain: sketch

Alignment of Sketch's ΔW with 3-domain derived subspace:

=== layer2.3 ===
  Rank 4: 9.5% of Sketch's ΔW variance captured
  Rank 8: 15.7% of Sketch's ΔW variance captured
  Rank 16: 25.6% of Sketch's ΔW variance captured

=== layer3.5 ===
  Rank 4: 4.8% of Sketch's ΔW variance captured
  Rank 8: 8.0% of Sketch's ΔW variance captured
  Rank 16: 13.2% of Sketch's ΔW variance captured

=== layer4.2 ===
  Rank 4: 1.0% of Sketch's ΔW variance captured
  Rank 8: 1.7% of Sketch's ΔW variance captured
  Rank 16: 3.5% of Sketch's ΔW variance captured

Principal angles between Sketch's ΔW direction and 3-domain subspace:

=== layer2.3 ===
  Rank 4: angle=63.8°, overlap=0.225
  Rank 8: angle=59.4°, overlap=0.308
  Rank 16: angle=55.4°, overlap=0.367

=== layer3.5 ===
  Rank 4: angle=69.7°, overlap=0.181
  Rank 8: angle=70.3°, overlap=0.171
  Rank 16: angle=64.2°, overlap=0.236

=== lay

In [None]:
# -----------------------------------------------------------------------------
# Get Major/Minor subspaces from pretrained model
# -----------------------------------------------------------------------------

def get_svd_subspaces(model, conv_path):
    """Get SVD of pretrained weights"""
    conv = getattr(
        getattr(model.backbone, conv_path[0])[conv_path[1]],
        conv_path[2]
    )

    W = conv.weight.detach().cpu().numpy()
    C_out = W.shape[0]
    Wmat = W.reshape(C_out, -1)

    U, S, Vh = np.linalg.svd(Wmat, full_matrices=False)
    return U, S, Vh

pretrained_svd = {}
for lname, path in conv_layers.items():
    U, S, Vh = get_svd_subspaces(pretrained_model, path)
    pretrained_svd[lname] = {'U': U, 'S': S, 'Vh': Vh}
    print(f"  {lname}: U shape = {U.shape}, rank = {len(S)}")


def compute_alignment_metrics(U, V):
    """
    Compute alignment between two subspaces.
    U, V: [d, r] orthonormal bases
    Returns: mean_angle (degrees), overlap (mean cos²θ)
    """
    M = U.T @ V
    sv = np.linalg.svd(M, compute_uv=False)
    sv = np.clip(sv, -1.0, 1.0)

    angles = np.degrees(np.arccos(sv))
    mean_angle = angles.mean()
    overlap = np.mean(sv ** 2)

    return mean_angle, overlap

def get_deltaW_principal_directions(dW, r):
    """Get top-r principal directions of ΔW"""
    U_dw, S_dw, _ = np.linalg.svd(dW, full_matrices=False)
    return U_dw[:, :r]

print("\n" + "="*70)
print("SANITY CHECK: Sketch's ΔW Alignment with Pretrained SVD Subspaces")
print("="*70)

ranks = [4, 8, 16]

for lname in conv_layers.keys():
    print(f"\n=== {lname} ===")

    U_pretrained = pretrained_svd[lname]['U']
    dW_sketch = deltaW_sketch[lname]
    d = U_pretrained.shape[0]  # Output dimension

    for r in ranks:
        # Get Sketch's ΔW principal directions
        V_sketch = get_deltaW_principal_directions(dW_sketch, r)

        # Major: top-r of pretrained
        U_major = U_pretrained[:, :r]

        # Minor: bottom-r of pretrained
        U_minor = U_pretrained[:, -r:]

        # Middle: middle-r of pretrained
        mid_start = d // 2 - r // 2
        U_middle = U_pretrained[:, mid_start:mid_start + r]

        # Random
        np.random.seed(42)
        Q_random, _ = np.linalg.qr(np.random.randn(d, r))

        # Compute alignments
        angle_major, overlap_major = compute_alignment_metrics(U_major, V_sketch)
        angle_minor, overlap_minor = compute_alignment_metrics(U_minor, V_sketch)
        angle_middle, overlap_middle = compute_alignment_metrics(U_middle, V_sketch)
        angle_random, overlap_random = compute_alignment_metrics(Q_random, V_sketch)

        print(f"\n  Rank {r}:")
        print(f"    Major:  angle={angle_major:.1f}°, overlap={overlap_major:.4f}")
        print(f"    Minor:  angle={angle_minor:.1f}°, overlap={overlap_minor:.4f}")
        print(f"    Middle: angle={angle_middle:.1f}°, overlap={overlap_middle:.4f}")
        print(f"    Random: angle={angle_random:.1f}°, overlap={overlap_random:.4f}")

  layer2.3: U shape = (128, 128), rank = 128
  layer3.5: U shape = (256, 256), rank = 256
  layer4.2: U shape = (512, 512), rank = 512

SANITY CHECK: Sketch's ΔW Alignment with Pretrained SVD Subspaces

=== layer2.3 ===

  Rank 4:
    Major:  angle=77.6°, overlap=0.0656
    Minor:  angle=83.5°, overlap=0.0164
    Middle: angle=83.4°, overlap=0.0194
    Random: angle=81.0°, overlap=0.0327

  Rank 8:
    Major:  angle=71.2°, overlap=0.1448
    Minor:  angle=81.1°, overlap=0.0326
    Middle: angle=78.7°, overlap=0.0573
    Random: angle=78.3°, overlap=0.0591

  Rank 16:
    Major:  angle=65.7°, overlap=0.2154
    Minor:  angle=76.4°, overlap=0.0807
    Middle: angle=72.5°, overlap=0.1177
    Random: angle=71.9°, overlap=0.1256

=== layer3.5 ===

  Rank 4:
    Major:  angle=75.7°, overlap=0.1000
    Minor:  angle=83.8°, overlap=0.0166
    Middle: angle=85.0°, overlap=0.0123
    Random: angle=86.3°, overlap=0.0057

  Rank 8:
    Major:  angle=72.8°, overlap=0.1210
    Minor:  angle=83.3°, o

In [None]:
# =============================================================================
# EXPERIMENT 3: Subspace Comparison for Domain Generalization
# =============================================================================
#
# This experiment definitively tests whether the MINOR subspace is special,
# or whether ANY low-rank constraint provides similar benefits.
#
# Variants:
#   - SoMA-Minor:  Constrain updates to bottom-r singular vectors
#   - SoMA-Major:  Constrain updates to top-r singular vectors
#   - SoMA-Middle: Constrain updates to middle-r singular vectors
#   - SoMA-Random: Constrain updates to random orthonormal basis
#   - Full-FT:     No constraint (baseline)
#
# Protocol:
#   - Train on 3 source domains (Photo, Art, Cartoon)
#   - Evaluate on held-out target domain (Sketch)
#   - Use rank=8 for all constrained variants
#
# =============================================================================

import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms
from tqdm import tqdm
import os

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

RANK = 8  # Adapter rank for all constrained variants
EPOCHS = 10
LR = 1e-4
BATCH_SIZE = 32
NUM_WORKERS = 2

SOURCE_DOMAINS = ["photo", "art_painting", "cartoon"]
TARGET_DOMAIN = "sketch"

# Layers to apply adapters (following SoMA paper - typically mid-to-late layers)
# ADAPTER_LAYERS = [
#     ("layer2", 1, "conv2"),
#     ("layer2", 3, "conv2"),  # layer2 has 4 blocks (0-3)
#     ("layer3", 2, "conv2"),
#     ("layer3", 5, "conv2"),  # layer3 has 6 blocks (0-5)
#     ("layer4", 1, "conv2"),
#     ("layer4", 2, "conv2"),  # layer4 has 3 blocks (0-2)
# ]
# -----------------------------------------------------------------------------
# Budget-matched adapter placement (fair vs ResNet18)
# ResNet18 has 2 blocks per stage in layer2-4 => 6 blocks total.
# So for ResNet50 we adapt 2 blocks per stage in layer2-4 => 6 blocks total.
# We choose: (one "early-ish" block) + (last block) per stage.
# This rule yields: layer2 [1,3], layer3 [2,5], layer4 [1,2] for ResNet50.
# -----------------------------------------------------------------------------

def pick_two_blocks(n_blocks: int):
    return [n_blocks // 3, n_blocks - 1]  # deterministic, no tuning

def make_adapter_layers(model, stages=("layer2", "layer3", "layer4"), conv_name="conv2"):
    layers = []
    for stage_name in stages:
        stage = getattr(model.backbone, stage_name)
        n = len(stage)
        for idx in pick_two_blocks(n):
            layers.append((stage_name, idx, conv_name))
    return layers

# Build from the actual model so it works for ResNet50 without hardcoding
tmp = ResNet50_FeatureHook(num_classes=7)  # or whatever your backbone wrapper class is
ADAPTER_LAYERS = make_adapter_layers(tmp, stages=("layer2","layer3","layer4"), conv_name="conv2")

print("ADAPTER_LAYERS:", ADAPTER_LAYERS)


# Random seeds for reproducibility
SEEDS = [42]


# -----------------------------------------------------------------------------
# Data Loading
# -----------------------------------------------------------------------------

IMG_SIZE = 224

train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def get_domain_dataset(domain, transform):
    return datasets.ImageFolder(
        root=os.path.join(PACS_ROOT, domain),
        transform=transform
    )


def get_source_loader(batch_size=BATCH_SIZE):
    """Combined loader for all source domains"""
    source_datasets = [get_domain_dataset(d, train_tf) for d in SOURCE_DOMAINS]
    combined = ConcatDataset(source_datasets)
    return DataLoader(combined, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)


def get_target_loader(batch_size=BATCH_SIZE):
    """Loader for target domain (evaluation only)"""
    target_dataset = get_domain_dataset(TARGET_DOMAIN, test_tf)
    return DataLoader(target_dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)


# -----------------------------------------------------------------------------
# Subspace-Constrained Adapter (LoRA-style)
# -----------------------------------------------------------------------------

class SubspaceAdapter(nn.Module):
    def __init__(self, conv_layer, subspace_basis, scale=1.0):
        super().__init__()

        self.conv = conv_layer
        self.scale = scale

        C_out, C_in, k1, k2 = conv_layer.weight.shape
        r = subspace_basis.shape[1]

        self.C_out = C_out
        self.C_in = C_in
        self.k = k1
        self.r = r
        self.stride = conv_layer.stride
        self.padding = conv_layer.padding

        # A is trainable (down-projection)
        self.A = nn.Parameter(torch.randn(r, C_in * k1 * k2) * 0.01)

        # B is FROZEN (enforces subspace constraint)
        self.register_buffer('B', torch.from_numpy(subspace_basis.astype(np.float32)))

    def forward(self, x):
        # Original convolution output
        out = self.conv(x)

        # Adapter path: unfold input, apply low-rank transform
        # This is equivalent to a rank-r convolution

        # Unfold input to patches: (N, C_in*k*k, H_out*W_out)
        x_unf = F.unfold(x, kernel_size=self.k, stride=self.stride, padding=self.padding)

        # Apply A: (N, r, H_out*W_out)
        adapter_out = self.A @ x_unf

        # Apply B: (N, C_out, H_out*W_out)
        adapter_out = self.B @ adapter_out

        # Reshape to match conv output: (N, C_out, H_out, W_out)
        adapter_out = adapter_out.view(out.shape)

        return out + self.scale * adapter_out


# -----------------------------------------------------------------------------
# Model with Subspace Adapters
# -----------------------------------------------------------------------------

class ResNet50WithAdapters(nn.Module):
    def __init__(self, num_classes, adapter_layers, subspace_type, rank, random_seed=42):
        super().__init__()
        from torchvision import models
        self.backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.backbone.fc = nn.Linear(2048, num_classes)  # 512 → 2048

        self.subspace_type = subspace_type
        self.rank = rank
        self.random_seed = random_seed
        self.adapter_layers = adapter_layers
        self.adapters = nn.ModuleDict()

        # Store the random bases used (for reporting)
        self.random_bases = {}

        if subspace_type != 'none':
            self._add_adapters()

    def _get_conv_layer(self, layer_name, block_idx, conv_name):
        """Get reference to a specific conv layer"""
        layer = getattr(self.backbone, layer_name)
        block = layer[block_idx]
        return getattr(block, conv_name)

    def _set_conv_layer(self, layer_name, block_idx, conv_name, new_module):
        """Replace a specific conv layer with adapter"""
        layer = getattr(self.backbone, layer_name)
        block = layer[block_idx]
        setattr(block, conv_name, new_module)

    def _compute_subspace_basis(self, conv_layer, layer_key):
        """Compute the subspace basis for a given conv layer"""
        W = conv_layer.weight.detach().cpu().numpy()
        C_out = W.shape[0]
        W_mat = W.reshape(C_out, -1)

        # Compute SVD
        U, S, Vh = np.linalg.svd(W_mat, full_matrices=False)

        r = self.rank
        d = C_out

        if self.subspace_type == 'minor':
            # Bottom-r singular vectors
            basis = U[:, -r:]
        elif self.subspace_type == 'major':
            # Top-r singular vectors
            basis = U[:, :r]
        elif self.subspace_type == 'middle':
            # Middle-r singular vectors
            mid_start = d // 2 - r // 2
            basis = U[:, mid_start:mid_start + r]
        elif self.subspace_type == 'random':
            # Random orthonormal basis (consistent across runs with same seed)
            np.random.seed(self.random_seed + hash(layer_key) % 10000)
            random_matrix = np.random.randn(d, r)
            basis, _ = np.linalg.qr(random_matrix)
            # Store for reporting
            self.random_bases[layer_key] = basis.copy()
        else:
            raise ValueError(f"Unknown subspace type: {self.subspace_type}")

        return basis

    def _add_adapters(self):
        """Add adapters to specified layers"""
        for layer_name, block_idx, conv_name in self.adapter_layers:
            # layer_key = f"{layer_name}.{block_idx}.{conv_name}"
            layer_key = f"{layer_name}_{block_idx}_{conv_name}"

            # Get original conv layer
            conv = self._get_conv_layer(layer_name, block_idx, conv_name)

            # Compute subspace basis
            basis = self._compute_subspace_basis(conv, layer_key)

            # Create adapter
            adapter = SubspaceAdapter(conv, basis, scale=1.0)

            # Replace conv with adapter
            self._set_conv_layer(layer_name, block_idx, conv_name, adapter)

            # Store reference
            self.adapters[layer_key] = adapter

    def freeze_backbone(self):
      for param in self.backbone.parameters():
          param.requires_grad = False

      # Only unfreeze A (B is a buffer, not a parameter)
      for adapter in self.adapters.values():
          adapter.A.requires_grad = True

      # Unfreeze classifier
      for param in self.backbone.fc.parameters():
          param.requires_grad = True

      # BatchNorm affine params
      for m in self.backbone.modules():
          if isinstance(m, nn.BatchNorm2d):
              m.eval()
              m.weight.requires_grad = True
              m.bias.requires_grad = True

    def forward(self, x):
        return self.backbone(x)

    def get_trainable_params(self):
        """Return number of trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# -----------------------------------------------------------------------------
# Training and Evaluation Functions
# -----------------------------------------------------------------------------

def train_epoch(model, loader, optimizer, criterion):
    """Train for one epoch"""
    model.train()

    # Keep BatchNorm in eval mode
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    total_loss = 0
    correct = 0
    total = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    return total_loss / total, correct / total


@torch.no_grad()
def evaluate(model, loader):
    """Evaluate on a dataset"""
    model.eval()

    correct = 0
    total = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    return correct / total

from torchvision import models
# -----------------------------------------------------------------------------
# Full Finetuning Baseline
# -----------------------------------------------------------------------------

class ResNet50FullFinetune(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.backbone.fc = nn.Linear(2048, num_classes)  # 512 → 2048

    def forward(self, x):
        return self.backbone(x)

    def freeze_backbone(self):
        """For full finetuning, we don't freeze anything except BN stats"""
        for m in self.backbone.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = True
                m.bias.requires_grad = True

    def get_trainable_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# -----------------------------------------------------------------------------
# Run Single Experiment
# -----------------------------------------------------------------------------

def run_single_experiment(subspace_type, seed):
    """
    Run a single experiment with given subspace type and seed.

    Args:
        subspace_type: 'minor', 'major', 'middle', 'random', or 'full'
        seed: Random seed

    Returns:
        dict with results
    """
    # Set seeds
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Create model
    if subspace_type == 'full':
        model = ResNet50FullFinetune(num_classes=7).to(device)
    else:
        model = ResNet50WithAdapters(
            num_classes=7,
            adapter_layers=ADAPTER_LAYERS,
            subspace_type=subspace_type,
            rank=RANK,
            random_seed=seed
        ).to(device)

    # Freeze backbone (except adapters and classifier)
    model.freeze_backbone()

    # Get trainable params count
    trainable_params = model.get_trainable_params()

    # Create data loaders
    source_loader = get_source_loader()
    target_loader = get_target_loader()

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR
    )
    criterion = nn.CrossEntropyLoss()

    # Training loop
    best_target_acc = 0
    history = {'train_loss': [], 'train_acc': [], 'target_acc': []}

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, source_loader, optimizer, criterion)
        target_acc = evaluate(model, target_loader)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['target_acc'].append(target_acc)

        if target_acc > best_target_acc:
            best_target_acc = target_acc

    # Final evaluation
    final_target_acc = evaluate(model, target_loader)

    # Also evaluate on source domains individually
    source_accs = {}
    for domain in SOURCE_DOMAINS:
        loader = DataLoader(
            get_domain_dataset(domain, test_tf),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
        )
        source_accs[domain] = evaluate(model, loader)

    return {
        'subspace_type': subspace_type,
        'seed': seed,
        'trainable_params': trainable_params,
        'final_target_acc': final_target_acc,
        'best_target_acc': best_target_acc,
        'source_accs': source_accs,
        'history': history,
    }


# -----------------------------------------------------------------------------
# Main Experiment Runner
# -----------------------------------------------------------------------------

def run_experiment_3():
    """Run the complete Experiment 3"""

    print("=" * 70)
    print("EXPERIMENT 3: Subspace Comparison for Domain Generalization")
    print("=" * 70)
    print(f"\nConfiguration:")
    print(f"  Rank: {RANK}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LR}")
    print(f"  Source domains: {SOURCE_DOMAINS}")
    print(f"  Target domain: {TARGET_DOMAIN}")
    print(f"  Seeds: {SEEDS}")
    print(f"  Adapter layers: {len(ADAPTER_LAYERS)} layers")

    # Variants to test
    variants = ['minor', 'major', 'middle', 'random', 'full']

    all_results = []

    for variant in variants:
        print(f"\n{'='*70}")
        print(f"Running variant: {variant.upper()}")
        print(f"{'='*70}")

        variant_results = []

        for seed in SEEDS:
            print(f"\n  Seed {seed}...")
            result = run_single_experiment(variant, seed)
            variant_results.append(result)

            print(f"    Trainable params: {result['trainable_params']:,}")
            print(f"    Final target acc: {result['final_target_acc']:.4f}")
            print(f"    Best target acc:  {result['best_target_acc']:.4f}")
            print(f"    Source accs: {', '.join([f'{d[:4]}={v:.3f}' for d,v in result['source_accs'].items()])}")

        all_results.extend(variant_results)

    return all_results


# -----------------------------------------------------------------------------
# Results Analysis and Reporting
# -----------------------------------------------------------------------------

def analyze_results(results):
    """Analyze and print summary of results"""
    import pandas as pd

    # Convert to DataFrame
    rows = []
    for r in results:
        rows.append({
            'variant': r['subspace_type'],
            'seed': r['seed'],
            'trainable_params': r['trainable_params'],
            'target_acc': r['final_target_acc'],
            'best_target_acc': r['best_target_acc'],
            'source_photo': r['source_accs'].get('photo', 0),
            'source_art': r['source_accs'].get('art_painting', 0),
            'source_cartoon': r['source_accs'].get('cartoon', 0),
        })

    df = pd.DataFrame(rows)

    # Summary statistics
    print("\n" + "=" * 70)
    print("EXPERIMENT 3 RESULTS SUMMARY")
    print("=" * 70)

    print("\n--- Target Domain Accuracy (Sketch) ---")
    summary = df.groupby('variant')['target_acc'].agg(['mean', 'std']).round(4)
    summary = summary.sort_values('mean', ascending=False)
    print(summary)

    print("\n--- Best Target Accuracy (over training) ---")
    summary_best = df.groupby('variant')['best_target_acc'].agg(['mean', 'std']).round(4)
    summary_best = summary_best.sort_values('mean', ascending=False)
    print(summary_best)

    print("\n--- Source Domain Accuracies ---")
    for variant in df['variant'].unique():
        v_df = df[df['variant'] == variant]
        print(f"\n{variant.upper()}:")
        print(f"  Photo:   {v_df['source_photo'].mean():.4f} ± {v_df['source_photo'].std():.4f}")
        print(f"  Art:     {v_df['source_art'].mean():.4f} ± {v_df['source_art'].std():.4f}")
        print(f"  Cartoon: {v_df['source_cartoon'].mean():.4f} ± {v_df['source_cartoon'].std():.4f}")

    print("\n--- Trainable Parameters ---")
    params = df.groupby('variant')['trainable_params'].first()
    print(params)

    # Statistical comparison
    print("\n" + "=" * 70)
    print("KEY COMPARISONS")
    print("=" * 70)

    minor_acc = df[df['variant'] == 'minor']['target_acc'].values
    major_acc = df[df['variant'] == 'major']['target_acc'].values
    middle_acc = df[df['variant'] == 'middle']['target_acc'].values
    random_acc = df[df['variant'] == 'random']['target_acc'].values
    full_acc = df[df['variant'] == 'full']['target_acc'].values

    print(f"\nMinor vs Major: {minor_acc.mean():.4f} vs {major_acc.mean():.4f} (diff: {minor_acc.mean() - major_acc.mean():+.4f})")
    print(f"Minor vs Middle: {minor_acc.mean():.4f} vs {middle_acc.mean():.4f} (diff: {minor_acc.mean() - middle_acc.mean():+.4f})")
    print(f"Minor vs Random: {minor_acc.mean():.4f} vs {random_acc.mean():.4f} (diff: {minor_acc.mean() - random_acc.mean():+.4f})")
    print(f"Minor vs Full: {minor_acc.mean():.4f} vs {full_acc.mean():.4f} (diff: {minor_acc.mean() - full_acc.mean():+.4f})")

    # Interpretation
    print("\n" + "=" * 70)
    print("INTERPRETATION")
    print("=" * 70)

    constrained_accs = np.concatenate([minor_acc, major_acc, middle_acc, random_acc])
    constrained_mean = constrained_accs.mean()
    constrained_std = constrained_accs.std()

    print(f"\nAll constrained methods: {constrained_mean:.4f} ± {constrained_std:.4f}")
    print(f"Full finetuning:         {full_acc.mean():.4f} ± {full_acc.std():.4f}")

    # Check if minor is special
    minor_vs_others = minor_acc.mean() - np.mean([major_acc.mean(), middle_acc.mean(), random_acc.mean()])
    print(f"\nMinor vs average of other constrained: {minor_vs_others:+.4f}")

    if abs(minor_vs_others) < 0.02:
        print("\n→ Minor is NOT special. All subspace constraints perform similarly.")
        print("→ SoMA's benefit likely comes from regularization, not subspace selection.")
    elif minor_vs_others > 0.02:
        print("\n→ Minor outperforms other subspaces. SoMA's hypothesis may have merit.")
    else:
        print("\n→ Minor underperforms other subspaces. SoMA's hypothesis is contradicted.")

    return df


# -----------------------------------------------------------------------------
# Run Everything
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Run experiments
    results = run_experiment_3()

    # Analyze results
    df = analyze_results(results)

    # Save results
    df.to_csv(f"{PROJECT_ROOT}/experiment3_results.csv", index=False)
    print(f"\nResults saved to {PROJECT_ROOT}/experiment3_results.csv")

Using device: cuda
ADAPTER_LAYERS: [('layer2', 1, 'conv2'), ('layer2', 3, 'conv2'), ('layer3', 2, 'conv2'), ('layer3', 5, 'conv2'), ('layer4', 1, 'conv2'), ('layer4', 2, 'conv2')]
EXPERIMENT 3: Subspace Comparison for Domain Generalization

Configuration:
  Rank: 8
  Epochs: 10
  Learning rate: 0.0001
  Source domains: ['photo', 'art_painting', 'cartoon']
  Target domain: sketch
  Seeds: [42]
  Adapter layers: 6 layers

Running variant: MINOR

  Seed 42...
    Trainable params: 196,487
    Final target acc: 0.5996
    Best target acc:  0.5996
    Source accs: phot=0.999, art_=0.992, cart=0.992

Running variant: MAJOR

  Seed 42...
    Trainable params: 196,487
    Final target acc: 0.6154
    Best target acc:  0.6442
    Source accs: phot=0.999, art_=0.991, cart=0.992

Running variant: MIDDLE

  Seed 42...
    Trainable params: 196,487
    Final target acc: 0.5846
    Best target acc:  0.5867
    Source accs: phot=0.999, art_=0.984, cart=0.988

Running variant: RANDOM

  Seed 42...
   

In [None]:
# -----------------------------------------------------------------------------
# Main Experiment Runner
# -----------------------------------------------------------------------------

def run_experiment_3():
    """Run the complete Experiment 3"""
    SEEDS = [42]

    print("=" * 70)
    print("EXPERIMENT 3: Subspace Comparison for Domain Generalization")
    print("=" * 70)
    print(f"\nConfiguration:")
    print(f"  Rank: {RANK}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LR}")
    print(f"  Source domains: {SOURCE_DOMAINS}")
    print(f"  Target domain: {TARGET_DOMAIN}")
    print(f"  Seeds: {SEEDS}")
    print(f"  Adapter layers: {len(ADAPTER_LAYERS)} layers")

    # Variants to test
    variants = ['major', 'middle', 'random', 'full']

    all_results = []

    for variant in variants:
        print(f"\n{'='*70}")
        print(f"Running variant: {variant.upper()}")
        print(f"{'='*70}")

        variant_results = []

        for seed in SEEDS:
            print(f"\n  Seed {seed}...")
            result = run_single_experiment(variant, seed)
            variant_results.append(result)

            print(f"    Trainable params: {result['trainable_params']:,}")
            print(f"    Final target acc: {result['final_target_acc']:.4f}")
            print(f"    Best target acc:  {result['best_target_acc']:.4f}")
            print(f"    Source accs: {', '.join([f'{d[:4]}={v:.3f}' for d,v in result['source_accs'].items()])}")

        all_results.extend(variant_results)

    return all_results


# -----------------------------------------------------------------------------
# Run Everything
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Run experiments
    results = run_experiment_3()


EXPERIMENT 3: Subspace Comparison for Domain Generalization

Configuration:
  Rank: 8
  Epochs: 10
  Learning rate: 0.0001
  Source domains: ['photo', 'art_painting', 'cartoon']
  Target domain: sketch
  Seeds: [42]
  Adapter layers: 6 layers

Running variant: MAJOR

  Seed 42...
    Trainable params: 142,215
    Final target acc: 0.5668
    Best target acc:  0.5750
    Source accs: phot=0.991, art_=0.966, cart=0.971

Running variant: MIDDLE

  Seed 42...
    Trainable params: 142,215
    Final target acc: 0.6139
    Best target acc:  0.6139
    Source accs: phot=0.992, art_=0.970, cart=0.962

Running variant: RANDOM

  Seed 42...
    Trainable params: 142,215
    Final target acc: 0.6124
    Best target acc:  0.6124
    Source accs: phot=0.990, art_=0.967, cart=0.965

Running variant: FULL

  Seed 42...
    Trainable params: 11,180,103
    Final target acc: 0.6760
    Best target acc:  0.7597
    Source accs: phot=0.984, art_=0.968, cart=0.997


In [None]:
# -----------------------------------------------------------------------------
# Main Experiment Runner
# -----------------------------------------------------------------------------

def run_experiment_3():
    """Run the complete Experiment 3"""
    SEEDS = [456]

    print("=" * 70)
    print("EXPERIMENT 3: Subspace Comparison for Domain Generalization")
    print("=" * 70)
    print(f"\nConfiguration:")
    print(f"  Rank: {RANK}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LR}")
    print(f"  Source domains: {SOURCE_DOMAINS}")
    print(f"  Target domain: {TARGET_DOMAIN}")
    print(f"  Seeds: {SEEDS}")
    print(f"  Adapter layers: {len(ADAPTER_LAYERS)} layers")

    # Variants to test
    variants = ['major', 'middle', 'random', 'full']

    all_results = []

    for variant in variants:
        print(f"\n{'='*70}")
        print(f"Running variant: {variant.upper()}")
        print(f"{'='*70}")

        variant_results = []

        for seed in SEEDS:
            print(f"\n  Seed {seed}...")
            result = run_single_experiment(variant, seed)
            variant_results.append(result)

            print(f"    Trainable params: {result['trainable_params']:,}")
            print(f"    Final target acc: {result['final_target_acc']:.4f}")
            print(f"    Best target acc:  {result['best_target_acc']:.4f}")
            print(f"    Source accs: {', '.join([f'{d[:4]}={v:.3f}' for d,v in result['source_accs'].items()])}")

        all_results.extend(variant_results)

    return all_results


# -----------------------------------------------------------------------------
# Run Everything
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Run experiments
    results = run_experiment_3()


EXPERIMENT 3: Subspace Comparison for Domain Generalization

Configuration:
  Rank: 8
  Epochs: 10
  Learning rate: 0.0001
  Source domains: ['photo', 'art_painting', 'cartoon']
  Target domain: sketch
  Seeds: [456]
  Adapter layers: 6 layers

Running variant: MAJOR

  Seed 456...
    Trainable params: 196,487
    Final target acc: 0.6434
    Best target acc:  0.6434
    Source accs: phot=0.999, art_=0.992, cart=0.994

Running variant: MIDDLE

  Seed 456...
    Trainable params: 196,487
    Final target acc: 0.5492
    Best target acc:  0.6035
    Source accs: phot=0.993, art_=0.983, cart=0.982

Running variant: RANDOM

  Seed 456...


In [None]:
# -----------------------------------------------------------------------------
# Main Experiment Runner
# -----------------------------------------------------------------------------

def run_experiment_3():
    """Run the complete Experiment 3"""
    SEEDS = [123]

    print("=" * 70)
    print("EXPERIMENT 3: Subspace Comparison for Domain Generalization")
    print("=" * 70)
    print(f"\nConfiguration:")
    print(f"  Rank: {RANK}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Learning rate: {LR}")
    print(f"  Source domains: {SOURCE_DOMAINS}")
    print(f"  Target domain: {TARGET_DOMAIN}")
    print(f"  Seeds: {SEEDS}")
    print(f"  Adapter layers: {len(ADAPTER_LAYERS)} layers")

    # Variants to test
    variants = ['major', 'middle', 'random', 'full']

    all_results = []

    for variant in variants:
        print(f"\n{'='*70}")
        print(f"Running variant: {variant.upper()}")
        print(f"{'='*70}")

        variant_results = []

        for seed in SEEDS:
            print(f"\n  Seed {seed}...")
            result = run_single_experiment(variant, seed)
            variant_results.append(result)

            print(f"    Trainable params: {result['trainable_params']:,}")
            print(f"    Final target acc: {result['final_target_acc']:.4f}")
            print(f"    Best target acc:  {result['best_target_acc']:.4f}")
            print(f"    Source accs: {', '.join([f'{d[:4]}={v:.3f}' for d,v in result['source_accs'].items()])}")

        all_results.extend(variant_results)

    return all_results


# -----------------------------------------------------------------------------
# Run Everything
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Run experiments
    results = run_experiment_3()


EXPERIMENT 3: Subspace Comparison for Domain Generalization

Configuration:
  Rank: 8
  Epochs: 10
  Learning rate: 0.0001
  Source domains: ['photo', 'art_painting', 'cartoon']
  Target domain: sketch
  Seeds: [123]
  Adapter layers: 6 layers

Running variant: MAJOR

  Seed 123...
    Trainable params: 196,487
    Final target acc: 0.6411
    Best target acc:  0.6633
    Source accs: phot=0.994, art_=0.994, cart=0.991

Running variant: MIDDLE

  Seed 123...
    Trainable params: 196,487
    Final target acc: 0.6350
    Best target acc:  0.6355
    Source accs: phot=0.999, art_=0.991, cart=0.994

Running variant: RANDOM

  Seed 123...
    Trainable params: 196,487
    Final target acc: 0.6449
    Best target acc:  0.6457
    Source accs: phot=0.999, art_=0.993, cart=0.993

Running variant: FULL

  Seed 123...
    Trainable params: 23,522,375
    Final target acc: 0.7468
    Best target acc:  0.7468
    Source accs: phot=0.997, art_=0.974, cart=0.994


In [None]:
# -----------------------------------------------------------------------------
# Results Analysis and Reporting
# -----------------------------------------------------------------------------

def analyze_results(results):
    """Analyze and print summary of results"""
    import pandas as pd

    # Convert to DataFrame
    rows = []
    for r in results:
        rows.append({
            'variant': r['subspace_type'],
            'seed': r['seed'],
            'trainable_params': r['trainable_params'],
            'target_acc': r['final_target_acc'],
            'best_target_acc': r['best_target_acc'],
            'source_photo': r['source_accs'].get('photo', 0),
            'source_art': r['source_accs'].get('art_painting', 0),
            'source_cartoon': r['source_accs'].get('cartoon', 0),
        })

    df = pd.DataFrame(rows)

    # Summary statistics
    print("\n" + "=" * 70)
    print("EXPERIMENT 3 RESULTS SUMMARY")
    print("=" * 70)

    print("\n--- Target Domain Accuracy (Sketch) ---")
    summary = df.groupby('variant')['target_acc'].agg(['mean', 'std']).round(4)
    summary = summary.sort_values('mean', ascending=False)
    print(summary)

    print("\n--- Best Target Accuracy (over training) ---")
    summary_best = df.groupby('variant')['best_target_acc'].agg(['mean', 'std']).round(4)
    summary_best = summary_best.sort_values('mean', ascending=False)
    print(summary_best)

    print("\n--- Source Domain Accuracies ---")
    for variant in df['variant'].unique():
        v_df = df[df['variant'] == variant]
        print(f"\n{variant.upper()}:")
        print(f"  Photo:   {v_df['source_photo'].mean():.4f} ± {v_df['source_photo'].std():.4f}")
        print(f"  Art:     {v_df['source_art'].mean():.4f} ± {v_df['source_art'].std():.4f}")
        print(f"  Cartoon: {v_df['source_cartoon'].mean():.4f} ± {v_df['source_cartoon'].std():.4f}")

    print("\n--- Trainable Parameters ---")
    params = df.groupby('variant')['trainable_params'].first()
    print(params)

    # Statistical comparison
    print("\n" + "=" * 70)
    print("KEY COMPARISONS")
    print("=" * 70)

    minor_acc = df[df['variant'] == 'minor']['target_acc'].values
    major_acc = df[df['variant'] == 'major']['target_acc'].values
    middle_acc = df[df['variant'] == 'middle']['target_acc'].values
    random_acc = df[df['variant'] == 'random']['target_acc'].values
    full_acc = df[df['variant'] == 'full']['target_acc'].values

    print(f"\nMinor vs Major: {minor_acc.mean():.4f} vs {major_acc.mean():.4f} (diff: {minor_acc.mean() - major_acc.mean():+.4f})")
    print(f"Minor vs Middle: {minor_acc.mean():.4f} vs {middle_acc.mean():.4f} (diff: {minor_acc.mean() - middle_acc.mean():+.4f})")
    print(f"Minor vs Random: {minor_acc.mean():.4f} vs {random_acc.mean():.4f} (diff: {minor_acc.mean() - random_acc.mean():+.4f})")
    print(f"Minor vs Full: {minor_acc.mean():.4f} vs {full_acc.mean():.4f} (diff: {minor_acc.mean() - full_acc.mean():+.4f})")

    # Interpretation
    print("\n" + "=" * 70)
    print("INTERPRETATION")
    print("=" * 70)

    constrained_accs = np.concatenate([minor_acc, major_acc, middle_acc, random_acc])
    constrained_mean = constrained_accs.mean()
    constrained_std = constrained_accs.std()

    print(f"\nAll constrained methods: {constrained_mean:.4f} ± {constrained_std:.4f}")
    print(f"Full finetuning:         {full_acc.mean():.4f} ± {full_acc.std():.4f}")

    # Check if minor is special
    minor_vs_others = minor_acc.mean() - np.mean([major_acc.mean(), middle_acc.mean(), random_acc.mean()])
    print(f"\nMinor vs average of other constrained: {minor_vs_others:+.4f}")

    if abs(minor_vs_others) < 0.02:
        print("\n→ Minor is NOT special. All subspace constraints perform similarly.")
        print("→ SoMA's benefit likely comes from regularization, not subspace selection.")
    elif minor_vs_others > 0.02:
        print("\n→ Minor outperforms other subspaces. SoMA's hypothesis may have merit.")
    else:
        print("\n→ Minor underperforms other subspaces. SoMA's hypothesis is contradicted.")

    return df


In [None]:
    # Analyze results
    df = analyze_results(results)

    # Save results
    df.to_csv(f"{PROJECT_ROOT}/experiment3_results.csv", index=False)
    print(f"\nResults saved to {PROJECT_ROOT}/experiment3_results.csv")

# UPDATED TIER 3 CODE

In [3]:
# =============================================================================
# Tier 3 (ResNet-50) — Comparable to your ResNet-18 Tier 3
# =============================================================================
#
# Goals (match ResNet-18 Tier 3 logic):
# - Start from ImageNet-pretrained ResNet-50
# - Train on source domains (photo, art_painting, cartoon), test on sketch
# - Domain-balanced steps: equal contribution per step from each source domain
# - Class-balanced sampling per domain (approx via WeightedRandomSampler)
# - Freeze BN running stats (BN in eval), keep BN affine trainable
# - Rank r = 8 low-rank subspace-constrained adapters (LoRA-style A trainable, B fixed)
#
# Variants (match your ResNet-18 set):
#   1) soma_minor:      B = bottom-r left singular vectors of W (U_minor)
#   2) soma_major:      B = top-r left singular vectors of W (U_major)
#   3) random:          B = random orthonormal basis (rank-matched)
#   4) deltaW_subspace: B = top-r left singular vectors of mean ΔW̄ from small per-domain adaptations
#   5) full_finetune:   standard finetuning baseline (no adapters)
#
# Scopes (match ResNet-18 idea, only layer3/layer4):
#   - scope = "layer4_only"
#   - scope = "layer3_4"
#
# Transition-block confound control (reviewer concern):
#   We run TWO placement schemes for each scope:
#   - scheme = "endpoints": pick [first block, last block] in each adapted stage
#                           (role-matched: transition + late refinement)
#   - scheme = "late_only": pick [last two blocks] in each adapted stage
#                           (avoids transition/downsample/projection confound)
#
# This gives a defensible story:
# - main: endpoints (role-matched to ResNet-18’s 2 blocks/stage)
# - robustness: late_only (removes transition confound)
#
# =============================================================================
import os, copy, math, random, pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets, transforms, models
from tqdm import tqdm

import pandas as pd

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# EDIT if needed
# PACS_ROOT = globals().get("PACS_ROOT", "/content/drive/MyDrive/DG_PACS/PACS")
# PROJECT_ROOT = globals().get("PROJECT_ROOT", "/content/drive/MyDrive/DG_PACS/Tier3_R50")
# os.makedirs(PROJECT_ROOT, exist_ok=True)

NUM_CLASSES = 7
SOURCE_DOMAINS = ["photo", "art_painting", "cartoon"]
TARGET_DOMAIN = "sketch"

# Tier 3 settings (match ResNet-18 Tier 3)
RANK = 8
EPOCHS = 10
LR = 1e-4
WEIGHT_DECAY = 0.0
NUM_WORKERS = 2

# Domain-balanced: total batch split equally across domains
BATCH_SIZE_TOTAL = 30  # must be divisible by 3 => 10 per domain

SEEDS = [42]  # extend later if needed

# ΔW basis computation settings (small per-domain adaptation)
DELTAW_SAMPLES_PER_CLASS = 20
DELTAW_STEPS = 200
DELTAW_LR = 2e-4
DELTAW_WEIGHT_DECAY = 0.0

VARIANTS = ["soma_minor", "deltaW_subspace", "random", "soma_major", "full_finetune"]

IMG_SIZE = 224
train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

test_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


Using device: cuda


In [4]:
# -------------------------
# Repro
# -------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# -------------------------
# Data helpers
# -------------------------
def get_domain_dataset(domain: str, tfm):
    path = os.path.join(PACS_ROOT, domain)
    if not os.path.isdir(path):
        raise FileNotFoundError(f"Missing domain folder: {path}")
    return datasets.ImageFolder(root=path, transform=tfm)

def make_class_balanced_sampler(imagefolder_ds: datasets.ImageFolder, seed: int):
    targets = np.array(imagefolder_ds.targets)
    class_counts = np.bincount(targets, minlength=NUM_CLASSES)
    class_counts = np.maximum(class_counts, 1)
    class_weights = 1.0 / class_counts
    sample_weights = class_weights[targets]

    g = torch.Generator()
    g.manual_seed(seed)

    return WeightedRandomSampler(
        weights=torch.as_tensor(sample_weights, dtype=torch.double),
        num_samples=len(sample_weights),
        replacement=True,
        generator=g
    )

def make_small_balanced_subset(imagefolder_ds: datasets.ImageFolder, per_class: int, seed: int):
    rng = np.random.RandomState(seed)
    targets = np.array(imagefolder_ds.targets)
    indices = []
    for c in range(NUM_CLASSES):
        idx_c = np.where(targets == c)[0]
        if len(idx_c) == 0:
            continue
        rng.shuffle(idx_c)
        take = min(per_class, len(idx_c))
        indices.extend(idx_c[:take].tolist())
    rng.shuffle(indices)
    return Subset(imagefolder_ds, indices)

def make_domain_loaders(batch_total: int, seed: int):
    assert batch_total % len(SOURCE_DOMAINS) == 0, "BATCH_SIZE_TOTAL must be divisible by #source domains"
    bpd = batch_total // len(SOURCE_DOMAINS)

    loaders = {}
    for d in SOURCE_DOMAINS:
        ds = get_domain_dataset(d, train_tf)
        sampler = make_class_balanced_sampler(ds, seed=seed + (hash(d) % 10000))
        loaders[d] = DataLoader(
            ds, batch_size=bpd, sampler=sampler,
            num_workers=NUM_WORKERS, pin_memory=True
        )

    steps_per_epoch = min(len(loaders[d]) for d in SOURCE_DOMAINS)
    return loaders, steps_per_epoch

@torch.no_grad()
def make_eval_loader(domain: str):
    ds = get_domain_dataset(domain, test_tf)
    return DataLoader(ds, batch_size=64, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# -------------------------
# BN policy
# -------------------------
def freeze_bn_running_stats(model: nn.Module):
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
            if m.weight is not None:
                m.weight.requires_grad = True
            if m.bias is not None:
                m.bias.requires_grad = True

# =============================================================================
# Adapter (LoRA-style): ΔW = reshape(B @ A), B fixed, A trainable
# =============================================================================
class SubspaceConvAdapter(nn.Module):
    def __init__(self, conv: nn.Conv2d, B: np.ndarray, scale: float = 1.0):
        super().__init__()
        self.base_conv = conv
        self.scale = scale

        # Freeze base conv params
        for p in self.base_conv.parameters():
            p.requires_grad = False

        C_out, C_in, kH, kW = conv.weight.shape
        r = B.shape[1]
        assert B.shape[0] == C_out

        self.C_out, self.C_in, self.kH, self.kW, self.r = C_out, C_in, kH, kW, r
        self.A = nn.Parameter(torch.randn(r, C_in * kH * kW) * 0.01)
        self.register_buffer("B", torch.from_numpy(B.astype(np.float32)))

    def forward(self, x):
        y0 = self.base_conv(x)
        dW_mat = self.B @ self.A
        dW = dW_mat.view(self.C_out, self.C_in, self.kH, self.kW)
        y1 = F.conv2d(
            x, dW, bias=None,
            stride=self.base_conv.stride,
            padding=self.base_conv.padding,
            dilation=self.base_conv.dilation,
            groups=self.base_conv.groups
        )
        return y0 + self.scale * y1

# =============================================================================
# ResNet-50 block selection (defensible + transition confound control)
# =============================================================================
def stage_block_count_resnet50(stage_name: str) -> int:
    return {"layer1": 3, "layer2": 4, "layer3": 6, "layer4": 3}[stage_name]

def pick_blocks(stage_name: str, scheme: str):
    n = stage_block_count_resnet50(stage_name)
    if scheme == "endpoints":
        # role-matched: transition block + late refinement block
        return [0, n - 1]
    elif scheme == "late_only":
        # confound control: avoid transition/downsample/projection block
        return [max(0, n - 2), n - 1] if n > 1 else [0]
    else:
        raise ValueError(f"Unknown scheme: {scheme}")

def make_adapter_layers(scope: str, scheme: str):
    if scope == "layer4_only":
        stages = ["layer4"]
    elif scope == "layer3_4":
        stages = ["layer3", "layer4"]
    else:
        raise ValueError(f"Unknown scope: {scope}")

    layers = []
    for st in stages:
        for bi in pick_blocks(st, scheme):
            layers.append((st, bi, "conv2"))
    return layers

def layer_key(layer_name: str, block_idx: int):
    return f"{layer_name}.{block_idx}.conv2"

def get_conv2(model: nn.Module, layer_name: str, block_idx: int):
    return getattr(model, layer_name)[block_idx].conv2

def set_conv2(model: nn.Module, layer_name: str, block_idx: int, new_module: nn.Module):
    getattr(model, layer_name)[block_idx].conv2 = new_module

# =============================================================================
# Bases: SoMA (from W), Random, ΔW̄
# =============================================================================
def svd_left_vectors(W_conv: torch.Tensor):
    W = W_conv.detach().cpu().numpy()
    C_out = W.shape[0]
    Wm = W.reshape(C_out, -1)
    U, S, Vh = np.linalg.svd(Wm, full_matrices=False)
    return U

def orthonormal_random_basis(C_out: int, r: int, seed: int):
    rng = np.random.RandomState(seed)
    M = rng.randn(C_out, r)
    Q, _ = np.linalg.qr(M)
    return Q[:, :r]

def compute_deltaW_bases_for_layers(adapter_layers, seed: int):
    """
    Compute B from top-r left singular vectors of ΔW̄ (mean over per-domain small adaptations),
    starting from ImageNet base, using only source domains.
    """
    set_seed(seed)

    base = models.resnet50(weights="IMAGENET1K_V1")
    base.fc = nn.Linear(2048, NUM_CLASSES)
    base = base.to(device)

    # Snapshot base weights
    W_base = {}
    for (ln, bi, _) in adapter_layers:
        W_base[layer_key(ln, bi)] = get_conv2(base, ln, bi).weight.detach().cpu().clone()

    dW_sum = {k: torch.zeros_like(v) for k, v in W_base.items()}

    for d in SOURCE_DOMAINS:
        ds_full = get_domain_dataset(d, train_tf)
        ds_small = make_small_balanced_subset(ds_full, per_class=DELTAW_SAMPLES_PER_CLASS,
                                              seed=seed + (hash(d) % 9999))
        loader = DataLoader(ds_small, batch_size=24, shuffle=True,
                            num_workers=NUM_WORKERS, pin_memory=True)

        m = copy.deepcopy(base).to(device)

        # Freeze all
        for p in m.parameters():
            p.requires_grad = False

        # Unfreeze fc
        for p in m.fc.parameters():
            p.requires_grad = True

        # Unfreeze only the selected conv2 weights
        for (ln, bi, _) in adapter_layers:
            conv = get_conv2(m, ln, bi)
            conv.weight.requires_grad = True
            if conv.bias is not None:
                conv.bias.requires_grad = True

        freeze_bn_running_stats(m)

        opt = torch.optim.Adam(
            [p for p in m.parameters() if p.requires_grad],
            lr=DELTAW_LR,
            weight_decay=DELTAW_WEIGHT_DECAY
        )
        ce = nn.CrossEntropyLoss()

        m.train()
        it = iter(loader)
        for step in range(DELTAW_STEPS):
            try:
                x, y = next(it)
            except StopIteration:
                it = iter(loader)
                x, y = next(it)
            x, y = x.to(device), y.to(device)

            freeze_bn_running_stats(m)
            opt.zero_grad()
            logits = m(x)
            loss = ce(logits, y)
            loss.backward()
            opt.step()

        # Accumulate ΔW
        for (ln, bi, _) in adapter_layers:
            k = layer_key(ln, bi)
            W_adapt = get_conv2(m, ln, bi).weight.detach().cpu()
            dW_sum[k] += (W_adapt - W_base[k])

        del m
        torch.cuda.empty_cache()

    deltaW_bases = {}
    for k, dW in dW_sum.items():
        dW_bar = dW / float(len(SOURCE_DOMAINS))
        U = svd_left_vectors(dW_bar)
        deltaW_bases[k] = U[:, :RANK]
    return deltaW_bases

# =============================================================================
# Models
# =============================================================================
class ResNet50WithSubspaceAdapters(nn.Module):
    def __init__(self, variant: str, adapter_layers, deltaW_bases=None, seed: int = 0):
        super().__init__()
        assert variant in ["soma_minor", "soma_major", "random", "deltaW_subspace"]
        self.variant = variant
        self.adapter_layers = adapter_layers
        self.seed = seed

        self.backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.backbone.fc = nn.Linear(2048, NUM_CLASSES)

        for (ln, bi, _) in adapter_layers:
            k = layer_key(ln, bi)
            conv = get_conv2(self.backbone, ln, bi)

            if variant == "soma_minor":
                U = svd_left_vectors(conv.weight)
                B = U[:, -RANK:]
            elif variant == "soma_major":
                U = svd_left_vectors(conv.weight)
                B = U[:, :RANK]
            elif variant == "random":
                C_out = conv.weight.shape[0]
                layer_seed = seed + (hash(k) % 100000)
                B = orthonormal_random_basis(C_out, RANK, layer_seed)
            elif variant == "deltaW_subspace":
                assert deltaW_bases is not None and k in deltaW_bases
                B = deltaW_bases[k]
            else:
                raise ValueError("bad variant")

            adapter = SubspaceConvAdapter(conv, B, scale=1.0)
            set_conv2(self.backbone, ln, bi, adapter)

        self.freeze_for_adapter_training()

    def freeze_for_adapter_training(self):
        for p in self.backbone.parameters():
            p.requires_grad = False

        for m in self.backbone.modules():
            if isinstance(m, SubspaceConvAdapter):
                m.A.requires_grad = True

        for p in self.backbone.fc.parameters():
            p.requires_grad = True

        freeze_bn_running_stats(self.backbone)

    def forward(self, x):
        return self.backbone(x)

    def trainable_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

class ResNet50FullFinetune(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.backbone.fc = nn.Linear(2048, NUM_CLASSES)

    def prepare_train(self):
        for p in self.backbone.parameters():
            p.requires_grad = True
        freeze_bn_running_stats(self.backbone)

    def forward(self, x):
        return self.backbone(x)

    def trainable_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# =============================================================================
# Train/eval
# =============================================================================
def train_one_epoch(model, domain_loaders, steps_per_epoch, optimizer, criterion):
    model.train()
    freeze_bn_running_stats(model)

    iters = {d: iter(domain_loaders[d]) for d in SOURCE_DOMAINS}

    total_loss, total_correct, total_seen = 0.0, 0, 0
    for _ in tqdm(range(steps_per_epoch), leave=False):
        xs, ys = [], []
        for d in SOURCE_DOMAINS:
            try:
                x, y = next(iters[d])
            except StopIteration:
                iters[d] = iter(domain_loaders[d])
                x, y = next(iters[d])
            xs.append(x)
            ys.append(y)

        x = torch.cat(xs, dim=0).to(device)
        y = torch.cat(ys, dim=0).to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        total_correct += (logits.argmax(1) == y).sum().item()
        total_seen += x.size(0)

    return total_loss / max(1, total_seen), total_correct / max(1, total_seen)

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    freeze_bn_running_stats(model)

    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)
    return correct / max(1, total)

# =============================================================================
# Scenario runner (shared by all scenario cells)
# =============================================================================
def run_scenario(scope: str, scheme: str, scenario_name: str, variants=VARIANTS, seeds=SEEDS, save=True):
    """
    Runs one scenario = (scope, scheme) across all variants.
    Returns a dataframe for easy side-by-side comparison.
    """
    results = []
    adapter_layers = make_adapter_layers(scope, scheme)

    print(f"\n=== Scenario: {scenario_name} ===")
    print(f"scope={scope} | scheme={scheme}")
    print("adapter_layers:", adapter_layers)

    for seed in seeds:
        set_seed(seed)

        # Precompute ΔW bases ONCE per seed for this scenario (only if needed)
        deltaW_bases = None
        if "deltaW_subspace" in variants:
            print(f"\n[Seed {seed}] Computing ΔW bases...")
            deltaW_bases = compute_deltaW_bases_for_layers(adapter_layers, seed=seed)
            print("[Done] ΔW bases computed.\n")

        # Data loaders (domain-balanced)
        domain_loaders, steps_per_epoch = make_domain_loaders(BATCH_SIZE_TOTAL, seed=seed)
        target_loader = make_eval_loader(TARGET_DOMAIN)

        for variant in variants:
            print(f"[Seed {seed}] Variant: {variant}")
            if variant == "full_finetune":
                model = ResNet50FullFinetune().to(device)
                model.prepare_train()
                trainable = model.trainable_params()
            else:
                model = ResNet50WithSubspaceAdapters(
                    variant=variant,
                    adapter_layers=adapter_layers,
                    deltaW_bases=deltaW_bases,
                    seed=seed
                ).to(device)
                trainable = model.trainable_params()

            optimizer = torch.optim.Adam(
                [p for p in model.parameters() if p.requires_grad],
                lr=LR,
                weight_decay=WEIGHT_DECAY
            )
            criterion = nn.CrossEntropyLoss()

            best_target = 0.0
            hist = {"train_loss": [], "train_acc": [], "target_acc": []}

            for ep in range(EPOCHS):
                tr_loss, tr_acc = train_one_epoch(model, domain_loaders, steps_per_epoch, optimizer, criterion)
                tgt_acc = evaluate(model, target_loader)
                hist["train_loss"].append(tr_loss)
                hist["train_acc"].append(tr_acc)
                hist["target_acc"].append(tgt_acc)
                best_target = max(best_target, tgt_acc)
                print(f"  ep {ep+1:02d}/{EPOCHS} | loss={tr_loss:.4f} acc={tr_acc:.4f} | sketch={tgt_acc:.4f}")

            # Source accuracies
            source_accs = {}
            for d in SOURCE_DOMAINS:
                loader = make_eval_loader(d)
                source_accs[d] = evaluate(model, loader)

            final_target = evaluate(model, target_loader)

            results.append({
                "scenario": scenario_name,
                "scope": scope,
                "scheme": scheme,
                "variant": variant,
                "seed": seed,
                "rank": RANK,
                "trainable_params": trainable,
                "sketch_final": final_target,
                "sketch_best": best_target,
                "photo": source_accs["photo"],
                "art": source_accs["art_painting"],
                "cartoon": source_accs["cartoon"],
            })

    df = pd.DataFrame(results)

    if save:
        pkl_path = os.path.join(PROJECT_ROOT, f"{scenario_name}_results.pkl")
        csv_path = os.path.join(PROJECT_ROOT, f"{scenario_name}_results.csv")
        with open(pkl_path, "wb") as f:
            pickle.dump(results, f)
        df.to_csv(csv_path, index=False)
        print("\nSaved:", pkl_path)
        print("Saved:", csv_path)

    print("\nSummary (Sketch final):")
    print(df.groupby(["variant"])["sketch_final"].agg(["mean","std"]).sort_values("mean", ascending=False))

    return df


Scenario 1 — Scope 1 (layer4 only), endpoints (role-matched)

In [None]:
df_s1 = run_scenario(
    scope="layer4_only",
    scheme="endpoints",
    scenario_name="S1_layer4_endpoints"
)
df_s1



=== Scenario: S1_layer4_endpoints ===
scope=layer4_only | scheme=endpoints
adapter_layers: [('layer4', 0, 'conv2'), ('layer4', 2, 'conv2')]

[Seed 42] Computing ΔW bases...
[Done] ΔW bases computed.

[Seed 42] Variant: soma_minor




  ep 01/10 | loss=0.9046 acc=0.7469 | sketch=0.5434




  ep 02/10 | loss=0.2158 acc=0.9293 | sketch=0.5263




  ep 03/10 | loss=0.1497 acc=0.9493 | sketch=0.5920




  ep 04/10 | loss=0.1295 acc=0.9575 | sketch=0.5956




  ep 05/10 | loss=0.0924 acc=0.9703 | sketch=0.5676




  ep 06/10 | loss=0.0758 acc=0.9743 | sketch=0.5940




  ep 07/10 | loss=0.0646 acc=0.9788 | sketch=0.6251




  ep 08/10 | loss=0.0647 acc=0.9804 | sketch=0.6325




  ep 09/10 | loss=0.0460 acc=0.9848 | sketch=0.6294




  ep 10/10 | loss=0.0381 acc=0.9888 | sketch=0.6236
[Seed 42] Variant: deltaW_subspace




  ep 01/10 | loss=0.8563 acc=0.7407 | sketch=0.3797




  ep 02/10 | loss=0.2186 acc=0.9246 | sketch=0.4019




  ep 03/10 | loss=0.1518 acc=0.9475 | sketch=0.5337




  ep 04/10 | loss=0.1242 acc=0.9577 | sketch=0.4917




  ep 05/10 | loss=0.1117 acc=0.9619 | sketch=0.5495




  ep 06/10 | loss=0.0844 acc=0.9717 | sketch=0.5136




  ep 07/10 | loss=0.0608 acc=0.9810 | sketch=0.5877




  ep 08/10 | loss=0.0661 acc=0.9774 | sketch=0.5968




  ep 09/10 | loss=0.0406 acc=0.9868 | sketch=0.5566




  ep 10/10 | loss=0.0386 acc=0.9866 | sketch=0.5890
[Seed 42] Variant: random




  ep 01/10 | loss=0.8531 acc=0.7605 | sketch=0.4166




  ep 02/10 | loss=0.2149 acc=0.9303 | sketch=0.4808




  ep 03/10 | loss=0.1329 acc=0.9545 | sketch=0.5131




  ep 04/10 | loss=0.1088 acc=0.9645 | sketch=0.5378




  ep 05/10 | loss=0.1012 acc=0.9671 | sketch=0.5546




  ep 06/10 | loss=0.0805 acc=0.9725 | sketch=0.5584




  ep 07/10 | loss=0.0670 acc=0.9810 | sketch=0.5554




  ep 08/10 | loss=0.0593 acc=0.9820 | sketch=0.5594




  ep 09/10 | loss=0.0451 acc=0.9876 | sketch=0.6058




  ep 10/10 | loss=0.0360 acc=0.9902 | sketch=0.6022
[Seed 42] Variant: soma_major




  ep 01/10 | loss=0.8135 acc=0.7485 | sketch=0.4317




  ep 02/10 | loss=0.1910 acc=0.9369 | sketch=0.5151




  ep 03/10 | loss=0.1307 acc=0.9575 | sketch=0.5083




  ep 04/10 | loss=0.1078 acc=0.9631 | sketch=0.4093




  ep 05/10 | loss=0.0869 acc=0.9711 | sketch=0.4795




  ep 06/10 | loss=0.0727 acc=0.9768 | sketch=0.6080




  ep 07/10 | loss=0.0588 acc=0.9814 | sketch=0.5261




  ep 08/10 | loss=0.0550 acc=0.9832 | sketch=0.5396




  ep 09/10 | loss=0.0499 acc=0.9848 | sketch=0.5732




  ep 10/10 | loss=0.0385 acc=0.9892 | sketch=0.6032
[Seed 42] Variant: full_finetune




  ep 01/10 | loss=0.3929 acc=0.8689 | sketch=0.5905




  ep 02/10 | loss=0.2149 acc=0.9329 | sketch=0.6154




  ep 03/10 | loss=0.1646 acc=0.9485 | sketch=0.6610




  ep 04/10 | loss=0.1542 acc=0.9511 | sketch=0.6396




  ep 05/10 | loss=0.1576 acc=0.9505 | sketch=0.6035




  ep 06/10 | loss=0.1141 acc=0.9677 | sketch=0.4866




  ep 07/10 | loss=0.0817 acc=0.9741 | sketch=0.6518




  ep 08/10 | loss=0.0990 acc=0.9719 | sketch=0.6495




  ep 09/10 | loss=0.0995 acc=0.9683 | sketch=0.6498




  ep 10/10 | loss=0.0969 acc=0.9701 | sketch=0.6264

Saved: /content/drive/MyDrive/SoMA_PACS_R50/S1_layer4_endpoints_results.pkl
Saved: /content/drive/MyDrive/SoMA_PACS_R50/S1_layer4_endpoints_results.csv

Summary (Sketch final):
                     mean  std
variant                       
full_finetune    0.626368  NaN
soma_minor       0.623568  NaN
soma_major       0.603207  NaN
random           0.602189  NaN
deltaW_subspace  0.588954  NaN


Unnamed: 0,scenario,scope,scheme,variant,seed,rank,trainable_params,sketch_final,sketch_best,photo,art,cartoon
0,S1_layer4_endpoints,layer4_only,endpoints,soma_minor,42,8,141191,0.623568,0.632476,0.999401,0.987793,0.975256
1,S1_layer4_endpoints,layer4_only,endpoints,deltaW_subspace,42,8,141191,0.588954,0.596844,1.0,0.98877,0.984642
2,S1_layer4_endpoints,layer4_only,endpoints,random,42,8,141191,0.602189,0.605752,0.998802,0.989746,0.978669
3,S1_layer4_endpoints,layer4_only,endpoints,soma_major,42,8,141191,0.603207,0.608043,0.998802,0.986328,0.977389
4,S1_layer4_endpoints,layer4_only,endpoints,full_finetune,42,8,23522375,0.626368,0.660982,0.98982,0.955566,0.984215


In [None]:
!find "/content/drive" -maxdepth 5 -type d -name "PACS" 2>/dev/null


In [None]:
PACS_ROOT = "<that path>"


Scenario 2 — Scope 1 (layer4 only), late_only (confound control)

In [None]:
df_s2 = run_scenario(
    scope="layer4_only",
    scheme="late_only",
    scenario_name="S2_layer4_lateonly"
)
df_s2



=== Scenario: S2_layer4_lateonly ===
scope=layer4_only | scheme=late_only
adapter_layers: [('layer4', 1, 'conv2'), ('layer4', 2, 'conv2')]

[Seed 42] Computing ΔW bases...
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 206MB/s]


[Done] ΔW bases computed.

[Seed 42] Variant: soma_minor




  ep 01/10 | loss=0.9238 acc=0.7491 | sketch=0.4385




  ep 02/10 | loss=0.2062 acc=0.9297 | sketch=0.5897




  ep 03/10 | loss=0.1537 acc=0.9495 | sketch=0.5862




  ep 04/10 | loss=0.1142 acc=0.9617 | sketch=0.5951




  ep 05/10 | loss=0.0887 acc=0.9715 | sketch=0.6526




  ep 06/10 | loss=0.0875 acc=0.9705 | sketch=0.5890




  ep 07/10 | loss=0.0724 acc=0.9758 | sketch=0.5742




  ep 08/10 | loss=0.0477 acc=0.9862 | sketch=0.5856




  ep 09/10 | loss=0.0544 acc=0.9808 | sketch=0.5671




  ep 10/10 | loss=0.0408 acc=0.9870 | sketch=0.5905
[Seed 42] Variant: deltaW_subspace




  ep 01/10 | loss=0.7798 acc=0.7485 | sketch=0.3207




  ep 02/10 | loss=0.2130 acc=0.9275 | sketch=0.5319




  ep 03/10 | loss=0.1557 acc=0.9451 | sketch=0.4561




  ep 04/10 | loss=0.1146 acc=0.9623 | sketch=0.4637




  ep 05/10 | loss=0.1110 acc=0.9615 | sketch=0.5235




  ep 06/10 | loss=0.0857 acc=0.9697 | sketch=0.5630




  ep 07/10 | loss=0.0680 acc=0.9766 | sketch=0.5615




  ep 08/10 | loss=0.0513 acc=0.9826 | sketch=0.5688




  ep 09/10 | loss=0.0512 acc=0.9856 | sketch=0.5098




  ep 10/10 | loss=0.0401 acc=0.9868 | sketch=0.5638
[Seed 42] Variant: random




  ep 01/10 | loss=0.8401 acc=0.7667 | sketch=0.3663




  ep 02/10 | loss=0.2174 acc=0.9297 | sketch=0.4930




  ep 03/10 | loss=0.1339 acc=0.9565 | sketch=0.5902




  ep 04/10 | loss=0.1083 acc=0.9625 | sketch=0.5246




  ep 05/10 | loss=0.0908 acc=0.9697 | sketch=0.6091




  ep 06/10 | loss=0.0830 acc=0.9750 | sketch=0.6103




  ep 07/10 | loss=0.0541 acc=0.9834 | sketch=0.5638




  ep 08/10 | loss=0.0571 acc=0.9802 | sketch=0.6108




  ep 09/10 | loss=0.0580 acc=0.9818 | sketch=0.5645




  ep 10/10 | loss=0.0425 acc=0.9876 | sketch=0.6114
[Seed 42] Variant: soma_major




  ep 01/10 | loss=0.7217 acc=0.7515 | sketch=0.4518




  ep 02/10 | loss=0.2024 acc=0.9301 | sketch=0.5579




  ep 03/10 | loss=0.1288 acc=0.9561 | sketch=0.5393




  ep 04/10 | loss=0.1156 acc=0.9595 | sketch=0.4882




  ep 05/10 | loss=0.0890 acc=0.9695 | sketch=0.5541




  ep 06/10 | loss=0.0682 acc=0.9770 | sketch=0.5897




  ep 07/10 | loss=0.0749 acc=0.9737 | sketch=0.5174




  ep 08/10 | loss=0.0555 acc=0.9830 | sketch=0.5154




  ep 09/10 | loss=0.0554 acc=0.9826 | sketch=0.5515




  ep 10/10 | loss=0.0363 acc=0.9904 | sketch=0.6126
[Seed 42] Variant: full_finetune




  ep 01/10 | loss=0.5051 acc=0.8307 | sketch=0.6732




  ep 02/10 | loss=0.2506 acc=0.9224 | sketch=0.5635




  ep 03/10 | loss=0.1575 acc=0.9515 | sketch=0.5961




  ep 04/10 | loss=0.1656 acc=0.9463 | sketch=0.6007




  ep 05/10 | loss=0.1211 acc=0.9607 | sketch=0.5612




  ep 06/10 | loss=0.1045 acc=0.9667 | sketch=0.6836




Scenario 3 — Scope 2 (layer3 + layer4), endpoints (role-matched)

In [None]:
df_s3 = run_scenario(
    scope="layer3_4",
    scheme="endpoints",
    scenario_name="S3_layer34_endpoints"
)
df_s3


Scenario 4 — Scope 2 (layer3 + layer4), late_only (confound control)

In [None]:
df_s4 = run_scenario(
    scope="layer3_4",
    scheme="late_only",
    scenario_name="S4_layer34_lateonly"
)
df_s4


In [None]:
df_all = pd.concat([df_s1, df_s2, df_s3, df_s4], ignore_index=True)

# Side-by-side means (Sketch final)
pivot = df_all.pivot_table(
    index="variant",
    columns="scenario",
    values="sketch_final",
    aggfunc="mean"
).sort_values(by=list(df_all["scenario"].unique())[0], ascending=False)

pivot
