# Tier 0 — Baseline Training + Spectral Introspection (PACS, ResNet-18)

In [None]:
# ===== Colab / Drive setup =====
import os
from google.colab import drive

# Mount Google Drive so checkpoints and analysis artifacts persist across sessions.
drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/SoMA_PACS"
os.makedirs(PROJECT_ROOT, exist_ok=True)


Mounted at /content/drive


In [None]:
!pip -q install datasets huggingface_hub pillow


In [None]:
from datasets import load_dataset
import os
from PIL import Image
from tqdm import tqdm


In [None]:
import shutil, os

PACS_ROOT = "/content/drive/MyDrive/NEW_PACS/PACS"
# Guard so the export below never targets an unrelated directory.
assert "PACS" in PACS_ROOT

already_present = os.path.exists(PACS_ROOT)
if already_present:
    print("PACS folder exists. Skipping deletion.")
else:
    os.makedirs(PACS_ROOT, exist_ok=True)
    print("Created PACS folder.")



Created PACS folder.


In [None]:
from datasets import load_dataset
from tqdm.auto import tqdm
import os

ds = load_dataset("flwrlabs/pacs")

# Identity mapping from Hugging Face domain names to on-disk folder names.
# It is kept explicit so an unexpected domain name fails loudly (KeyError).
DOMAIN_MAP = {name: name for name in ("photo", "art_painting", "cartoon", "sketch")}

def safe(s):
    """Normalise a name for use as a path component: trim, lowercase, spaces -> '_'."""
    cleaned = str(s).strip().lower()
    return "_".join(cleaned.split(" "))

# Per-domain, per-class counters so that exported filenames are unique.
counters = {d: {} for d in DOMAIN_MAP.values()}

# Hoist loop-invariant lookups out of the ~10k-iteration loop.
train_split = ds["train"]
domain_feature = train_split.features["domain"]
label_feature = train_split.features["label"]

for ex in tqdm(train_split, desc="Exporting PACS"):
    img = ex["image"]
    domain = ex["domain"]
    label = ex["label"]

    # Domain/label may be stored as integer class ids; map them back to names.
    if not isinstance(domain, str):
        domain = domain_feature.int2str(domain)
    if not isinstance(label, str):
        label = label_feature.int2str(label)

    domain = DOMAIN_MAP[safe(domain)]
    label = safe(label)

    idx = counters[domain].get(label, 0)
    counters[domain][label] = idx + 1

    out_dir = os.path.join(PACS_ROOT, domain, label)
    os.makedirs(out_dir, exist_ok=True)

    # JPEG cannot store alpha/palette modes (e.g. RGBA, P); normalise to RGB so
    # a single such image does not abort the whole export.
    if img.mode != "RGB":
        img = img.convert("RGB")

    out_path = os.path.join(out_dir, f"{label}_{idx:05d}.jpg")
    img.save(out_path, quality=95)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Exporting PACS:   0%|          | 0/9991 [00:00<?, ?it/s]

In [None]:
from torchvision import datasets
import os

# Sanity check: every domain folder must expose the same 7 PACS classes.
# Use a dedicated name instead of re-binding `ds`, which holds the Hugging Face
# DatasetDict from the export cell (re-using one name for two kinds of object
# is a hidden-state hazard).
for d in ["sketch", "photo", "art_painting", "cartoon"]:
    folder = datasets.ImageFolder(os.path.join(PACS_ROOT, d))
    print(d, folder.class_to_idx)
    assert len(folder.class_to_idx) == 7, f"{d} does not have 7 classes!"



sketch {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
photo {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
art_painting {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
cartoon {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}


In [None]:
from torchvision import datasets
import os

def get_class_to_idx(domain):
    """Return the class-name -> label-index mapping ImageFolder builds for one PACS domain."""
    domain_dir = os.path.join(PACS_ROOT, domain)
    return datasets.ImageFolder(root=domain_dir).class_to_idx

mappings = {
    domain: get_class_to_idx(domain)
    for domain in ("photo", "art_painting", "cartoon", "sketch")
}
for domain, mapping in mappings.items():
    print(domain, mapping)

# Every domain must share photo's class -> index mapping; otherwise the same
# integer label would mean different classes in different domains.
base = mappings["photo"]
for domain, mapping in mappings.items():
    print(domain, "matches photo:", mapping == base)


photo {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
art_painting {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
cartoon {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
sketch {'dog': 0, 'elephant': 1, 'giraffe': 2, 'guitar': 3, 'horse': 4, 'house': 5, 'person': 6}
photo matches photo: True
art_painting matches photo: True
cartoon matches photo: True
sketch matches photo: True


In [None]:
# ===== Reproducibility =====
import random, numpy as np, torch

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all GPUs), and force deterministic cuDNN."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(42)

# ===== Imports =====
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm


In [None]:
# ===== PACS paths (edit if needed) =====
# Must be the same directory the export cell wrote to and the class-mapping
# sanity checks validated. Pointing it at a different folder means the checks
# above validate data that is never trained on.
PACS_ROOT = "/content/drive/MyDrive/NEW_PACS/PACS"

SOURCE_DOMAINS = ["photo", "art_painting", "cartoon"]
TARGET_DOMAIN = "sketch"  # held-out domain, used for evaluation only

IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 2


In [None]:
# ImageNet channel statistics, matching the IMAGENET1K_V1 pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_tf(augment):
    """Resize -> (optional horizontal flip) -> tensor -> ImageNet normalisation."""
    steps = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return transforms.Compose(steps)


train_tf = _build_tf(augment=True)
test_tf = _build_tf(augment=False)


In [None]:
def make_loader(domain, tf, shuffle):
    """DataLoader over the ImageFolder of a single PACS domain."""
    folder = datasets.ImageFolder(root=os.path.join(PACS_ROOT, domain), transform=tf)
    return DataLoader(
        folder,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )

# One loader per source domain (augmented, shuffled); target is un-augmented.
train_loaders = [make_loader(src, train_tf, True) for src in SOURCE_DOMAINS]
test_loader = make_loader(TARGET_DOMAIN, test_tf, False)


In [None]:
class ResNet18_FeatureHook(nn.Module):
    """ImageNet-pretrained ResNet-18 with a new head and hooks on every block's bn2.

    After each forward pass, ``self._features`` maps keys such as
    ``"layer3.1.bn2"`` to the output of that BasicBlock's ``bn2`` layer. In
    torchvision's BasicBlock this is presumably *before* the residual addition
    and the block's final ReLU (confirm against the torchvision version in use).
    The dict is reset at the start of every forward.

    Args:
        num_classes: size of the replacement classification head.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = models.resnet18(weights="IMAGENET1K_V1")
        # Replace the ImageNet head; 512 is ResNet-18's final feature width.
        self.backbone.fc = nn.Linear(512, num_classes)

        self._features = {}

        def make_hook(name):
            def hook(module, inp, out):
                # Detach so the stored activation does not keep the previous
                # batch's autograd graph alive between training steps. The
                # feature consumers read _features under torch.no_grad().
                self._features[name] = out.detach()
            return hook

        # Hook every BasicBlock in all four stages (layer1 .. layer4).
        for lname in ["layer1", "layer2", "layer3", "layer4"]:
            layer = getattr(self.backbone, lname)
            for i, block in enumerate(layer):
                block.bn2.register_forward_hook(
                    make_hook(f"{lname}.{i}.bn2")
                )

    def forward(self, x):
        # Start from an empty dict so stale features from a previous batch
        # cannot be mistaken for this batch's.
        self._features = {}
        return self.backbone(x)



# Training loop (Tier-0 backbone)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tier-0 baseline: ImageNet-initialised ResNet-18 with a 7-way PACS head.
model = ResNet18_FeatureHook(num_classes=7).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 181MB/s]


In [None]:
def train_epoch(loaders):
    """Run one optimisation pass over every loader and return (mean loss, accuracy).

    Uses the global ``model``, ``optimizer``, ``criterion`` and ``device``.

    NOTE(review): loaders are consumed one after another (all batches of the
    first source domain, then the next, ...), not interleaved, so each epoch
    ends on the last source domain.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for loader in loaders:
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

            batch_n = inputs.size(0)
            loss_sum += loss.item() * batch_n  # undo the per-batch mean
            n_correct += (logits.argmax(1) == targets).sum().item()
            n_seen += batch_n

    return loss_sum / n_seen, n_correct / n_seen


In [None]:
@torch.no_grad()
def eval_epoch(loader):
    """Top-1 accuracy of the global ``model`` over ``loader`` (eval mode, no gradients)."""
    model.eval()
    n_correct = n_seen = 0
    for inputs, targets in loader:
        preds = model(inputs.to(device)).argmax(1)
        n_correct += (preds == targets.to(device)).sum().item()
        n_seen += inputs.size(0)
    return n_correct / n_seen


In [None]:
import os

# NOTE(review): CUDA_LAUNCH_BLOCKING is normally read when the CUDA context is
# created. The model was already moved to `device` in an earlier cell, so this
# most likely has no effect in the running session. It looks like a debugging
# leftover: confirm, and drop it, because when it does take effect it
# serialises every kernel launch.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")


In [None]:
EPOCHS = 15

# Target accuracy is only reported here, never used for model selection:
# the weights saved afterwards are simply those of the final epoch.
for ep in range(EPOCHS):
    tr_loss, tr_acc = train_epoch(train_loaders)
    te_acc = eval_epoch(test_loader)
    print(f"[{ep:02d}] loss={tr_loss:.3f} | src_acc={tr_acc:.3f} | tgt_acc={te_acc:.3f}")


[00] loss=0.386 | src_acc=0.874 | tgt_acc=0.661
[01] loss=0.213 | src_acc=0.929 | tgt_acc=0.632
[02] loss=0.144 | src_acc=0.953 | tgt_acc=0.680
[03] loss=0.123 | src_acc=0.958 | tgt_acc=0.652
[04] loss=0.085 | src_acc=0.972 | tgt_acc=0.725
[05] loss=0.058 | src_acc=0.981 | tgt_acc=0.718
[06] loss=0.105 | src_acc=0.967 | tgt_acc=0.733
[07] loss=0.063 | src_acc=0.982 | tgt_acc=0.708
[08] loss=0.061 | src_acc=0.981 | tgt_acc=0.688
[09] loss=0.074 | src_acc=0.976 | tgt_acc=0.754
[10] loss=0.031 | src_acc=0.990 | tgt_acc=0.722
[11] loss=0.028 | src_acc=0.991 | tgt_acc=0.732
[12] loss=0.063 | src_acc=0.979 | tgt_acc=0.699
[13] loss=0.044 | src_acc=0.986 | tgt_acc=0.708
[14] loss=0.038 | src_acc=0.987 | tgt_acc=0.736


Save trained backbone

In [None]:
BACKBONE_PATH = f"{PROJECT_ROOT}/resnet18_pacs_base.pt"

# Persist only the state_dict; reload it into ResNet18_FeatureHook(num_classes=7).
state_to_save = model.state_dict()
torch.save(state_to_save, BACKBONE_PATH)
print("Saved PACS-trained backbone.")


Saved PACS-trained backbone.


# TIER 0 — SPECTRAL INTROSPECTION

In [None]:
import numpy.linalg as LA

def conv_svd(conv):
    """SVD of a conv weight viewed as a matrix over output channels.

    The weight [C_out, C_in, k, k] is flattened to [C_out, C_in*k*k].

    Returns:
        U: left singular vectors, [C_out, min(C_out, C_in*k*k)]
        S: singular values in descending order
    """
    W = conv.weight.detach().cpu().numpy()
    Cout = W.shape[0]  # (was `W.shape[0]a`, a syntax error)
    Wmat = W.reshape(Cout, -1)
    U, S, _ = LA.svd(Wmat, full_matrices=False)
    return U, S


In [None]:
# Singular values of each BasicBlock's conv2 in stages 2-4, keyed "layerX.i".
spectra = {
    f"{lname}.{i}": conv_svd(block.conv2)[1]
    for lname in ("layer2", "layer3", "layer4")
    for i, block in enumerate(getattr(model.backbone, lname))
}


In [None]:
import pickle

spectra_path = f"{PROJECT_ROOT}/tier0_spectra.pkl"
with open(spectra_path, "wb") as fh:
    pickle.dump(spectra, fh)


# Tier-0 Feature Statistics (Channel-wise Activation Energy)

In [None]:
@torch.no_grad()
def collect_feature_energy(loader):
    """Per-channel mean squared GAP activation for every hooked bn2 layer.

    For each batch the global ``model`` is run. Each hooked feature map
    (assumed [B, C, H, W]) is global-average-pooled to [B, C] and squared,
    and the squares are accumulated.

    Sums are accumulated per *sample* and divided by the total sample count.
    Averaging per-batch means instead would over-weight a smaller final batch.

    Returns:
        dict mapping hook key -> [C] array of E[h_c^2] over all samples
        (empty dict if the loader yields nothing).
    """
    model.eval()
    energy_sums = {}
    n_samples = 0

    for x, _ in loader:
        x = x.to(device)
        _ = model(x)
        n_samples += x.size(0)

        for k, feat in model._features.items():
            h = feat.mean(dim=[2, 3])  # GAP -> [B, C]
            batch_sum = (h.double() ** 2).sum(0).cpu().numpy()
            energy_sums[k] = energy_sums.get(k, 0.0) + batch_sum

    return {k: s / n_samples for k, s in energy_sums.items()}


In [None]:
# Measure source and target under the same preprocessing. train_loaders[0]
# applies RandomHorizontalFlip (train_tf) while test_loader does not, which
# would add augmentation noise to the source side only. This uses the same
# first source domain (SOURCE_DOMAINS[0]) with test_tf and no shuffling.
src_loader = make_loader(SOURCE_DOMAINS[0], test_tf, shuffle=False)
src_energy = collect_feature_energy(src_loader)
tgt_energy = collect_feature_energy(test_loader)

with open(f"{PROJECT_ROOT}/tier0_feature_energy.pkl", "wb") as f:
    pickle.dump({"src": src_energy, "tgt": tgt_energy}, f)


Saved to Drive

- tier0_spectra.pkl → singular values per conv layer

- tier0_feature_energy.pkl → channel-wise activation energy (src vs tgt)

- resnet18_pacs_base.pt → trained Tier-0 baseline checkpoint (model state_dict)



Scientific baseline

- no spectral intervention

- no leakage

- fixed feature definition

- BatchNorm outputs (each block's bn2) explicitly exposed via forward hooks

# Tier 1A — Do Minor Singular Subspaces Encode More Domain Information?

Is the minor singular subspace more domain-specific, while the major subspace is more class-semantic?

In [None]:
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import numpy.linalg as LA

In [None]:
# pickle is only acceptable here because the file was written by this notebook.
spectra_path = f"{PROJECT_ROOT}/tier0_spectra.pkl"
with open(spectra_path, mode="rb") as fh:
    spectra = pickle.load(fh)


In [None]:
def get_conv_svd(conv):
    """SVD of a conv layer's weight in output-channel space.

    The weight [C_out, C_in, k, k] is flattened to [C_out, C_in*k*k].

    Returns:
        U: [C_out, r] left singular vectors, r = min(C_out, C_in*k*k)
        S: [r] singular values, descending
    """
    weight = conv.weight.detach().cpu().numpy()
    flat = weight.reshape(weight.shape[0], -1)
    left, sing, _ = LA.svd(flat, full_matrices=False)
    return left, sing



Build Subspace Bases (Major / Minor / Random)

In [None]:
def get_subspace(U, kind, r, seed=0):
    """Return an orthonormal basis of shape [C, r] in the column space of U's rows.

    Args:
        U: [C, k] matrix whose columns are ordered by importance (e.g. SVD U).
        kind: 'major' (first r columns), 'minor' (last r columns) or
              'random' (QR of a seeded Gaussian [C, r] matrix).
        r: subspace rank; must satisfy 1 <= r <= k for major/minor and
           1 <= r <= C for random.
        seed: seed for the local RNG used only by 'random'.

    Raises:
        ValueError: on an unknown ``kind`` or an out-of-range ``r``. Without the
        check, r=0 silently returned the *whole* basis (``U[:, -0:]``) and an
        oversized r silently returned fewer than r columns.
    """
    C = U.shape[0]

    if r < 1:
        raise ValueError(f"rank r must be >= 1, got {r}")

    if kind in ("major", "minor") and r > U.shape[1]:
        raise ValueError(f"rank r={r} exceeds available basis size {U.shape[1]}")

    if kind == "major":
        return U[:, :r]

    if kind == "minor":
        return U[:, -r:]

    if kind == "random":
        if r > C:
            raise ValueError(f"rank r={r} exceeds ambient dimension {C}")
        rng = np.random.default_rng(seed)
        Q, _ = np.linalg.qr(rng.standard_normal((C, r)))
        return Q

    raise ValueError(f"Unknown subspace kind: {kind}")



In [None]:
@torch.no_grad()
def extract_gap_features(domains, layer_name):
    """Global-average-pooled bn2 features of one block, over several domains.

    Runs the global ``model`` (eval mode) on each domain's un-augmented,
    unshuffled loader and reads the hooked output ``f"{layer_name}.bn2"``
    (assumed [B, C, H, W]).

    Returns:
        (X [N, C], class labels [N], domain ids [N]) with domain ids
        photo=0, art_painting=1, cartoon=2, sketch=3.
    """
    domain_ids = {"photo": 0, "art_painting": 1, "cartoon": 2, "sketch": 3}
    hook_key = f"{layer_name}.bn2"

    feats, class_labels, domain_labels = [], [], []
    model.eval()

    for domain in domains:
        loader = make_loader(domain, test_tf, shuffle=False)
        d_id = domain_ids[domain]

        for x, y in loader:
            _ = model(x.to(device))
            pooled = model._features[hook_key].mean(dim=[2, 3]).cpu().numpy()
            feats.append(pooled)
            class_labels.append(y.numpy())
            domain_labels.append(np.full(pooled.shape[0], d_id))

    return tuple(np.concatenate(parts) for parts in (feats, class_labels, domain_labels))



In [None]:
class ProjectedDataset(Dataset):
    """Tensor dataset of features projected onto a subspace basis (X @ Usub) with labels."""

    def __init__(self, X, y, Usub):
        projected = X @ Usub
        self.X = torch.tensor(projected, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        features, label = self.X[idx], self.y[idx]
        return features, label



In [None]:
class LinearProbe(nn.Module):
    """Single affine layer mapping (projected) features to class logits."""

    def __init__(self, in_dim, out_dim):
        super(LinearProbe, self).__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x):
        logits = self.fc(x)
        return logits


def train_probe(dataset, num_classes, seed=0, epochs=50, train_frac=0.7):
    """Fit a linear probe on a random train split and return held-out accuracy.

    The global torch RNG is reseeded with ``seed``, which fixes the split, the
    probe initialisation and the minibatch order.
    """
    torch.manual_seed(seed)

    n_total = len(dataset)
    n_fit = int(train_frac * n_total)
    fit_ds, held_ds = random_split(dataset, [n_fit, n_total - n_fit])
    fit_loader = DataLoader(fit_ds, batch_size=256, shuffle=True)
    held_loader = DataLoader(held_ds, batch_size=256, shuffle=False)

    probe = LinearProbe(dataset.X.shape[1], num_classes).to(device)
    opt = torch.optim.Adam(probe.parameters(), lr=1e-2)
    loss_fn = nn.CrossEntropyLoss()

    for _epoch in range(epochs):
        for xb, yb in fit_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            loss_fn(probe(xb), yb).backward()
            opt.step()

    hits, seen = 0, 0
    with torch.no_grad():
        for xb, yb in held_loader:
            xb, yb = xb.to(device), yb.to(device)
            hits += (probe(xb).argmax(1) == yb).sum().item()
            seen += yb.size(0)

    return hits / seen


# Tier-1.1: Domain probe (core SoMA assumption)

In [None]:
# -------- Tier 1.1: Domain Probing --------
# Domain-ID linear probes on GAP'd bn2 features, projected onto the major /
# minor / random rank-r subspaces of each block's conv2 output-channel space.

layers = {
    "layer2.1": model.backbone.layer2[1].conv2,
    "layer3.1": model.backbone.layer3[1].conv2,
    "layer4.1": model.backbone.layer4[1].conv2,
}

domains = ["photo", "art_painting", "cartoon", "sketch"]
ranks = [4, 8, 16]
seeds = [0, 1, 2]

for layer_name, conv in layers.items():
    print(f"\n=== Domain Probe @ {layer_name} ===")

    U, S = get_conv_svd(conv)
    # extract_gap_features takes (domains, layer_name) and reads the global
    # `model`. Passing `model` as an extra positional argument raised TypeError.
    X, _, y_domain = extract_gap_features(domains, layer_name)

    for r in ranks:
        print(f"\nRank {r}")
        for kind in ["major", "minor", "random"]:
            accs = []
            for seed in seeds:
                Usub = get_subspace(U, kind, r, seed)
                ds = ProjectedDataset(X, y_domain, Usub)
                accs.append(train_probe(ds, num_classes=4, seed=seed))

            print(f"{kind:>6}: {np.mean(accs):.3f} ± {np.std(accs):.3f}")



=== Domain Probe @ layer2.1 ===

Rank 4
 major: 0.678 ± 0.007
 minor: 0.672 ± 0.003
random: 0.624 ± 0.064

Rank 8
 major: 0.752 ± 0.003
 minor: 0.750 ± 0.003
random: 0.792 ± 0.023

Rank 16
 major: 0.842 ± 0.003
 minor: 0.823 ± 0.003
random: 0.826 ± 0.012

=== Domain Probe @ layer3.1 ===

Rank 4
 major: 0.704 ± 0.004
 minor: 0.547 ± 0.013
random: 0.659 ± 0.040

Rank 8
 major: 0.775 ± 0.007
 minor: 0.669 ± 0.010
random: 0.714 ± 0.053

Rank 16
 major: 0.800 ± 0.003
 minor: 0.731 ± 0.007
random: 0.822 ± 0.017

=== Domain Probe @ layer4.1 ===

Rank 4
 major: 0.535 ± 0.016
 minor: 0.478 ± 0.004
random: 0.583 ± 0.022

Rank 8
 major: 0.654 ± 0.001
 minor: 0.528 ± 0.012
random: 0.622 ± 0.012

Rank 16
 major: 0.730 ± 0.005
 minor: 0.603 ± 0.010
random: 0.708 ± 0.012


# Tier-1.2: Class probe (complementary assumption)

In [None]:
# -------- Tier 1.2: Class Probing --------
# Class-label probes, trained and evaluated on the source domains only.

source_domains = ["photo", "art_painting", "cartoon"]

for layer_name, conv in layers.items():
    print(f"\n=== Class Probe @ {layer_name} ===")

    U, S = get_conv_svd(conv)
    # extract_gap_features takes (domains, layer_name) and reads the global
    # `model`. Passing `model` as an extra positional argument raised TypeError.
    X, y_class, _ = extract_gap_features(source_domains, layer_name)

    for r in ranks:
        print(f"\nRank {r}")
        for kind in ["major", "minor", "random"]:
            accs = []
            for seed in seeds:
                Usub = get_subspace(U, kind, r, seed)
                ds = ProjectedDataset(X, y_class, Usub)
                accs.append(train_probe(ds, num_classes=7, seed=seed))

            print(f"{kind:>6}: {np.mean(accs):.3f} ± {np.std(accs):.3f}")



=== Class Probe @ layer2.1 ===

Rank 4
 major: 0.245 ± 0.006
 minor: 0.299 ± 0.002
random: 0.288 ± 0.016

Rank 8
 major: 0.336 ± 0.010
 minor: 0.331 ± 0.004
random: 0.361 ± 0.010

Rank 16
 major: 0.402 ± 0.014
 minor: 0.374 ± 0.005
random: 0.422 ± 0.012

=== Class Probe @ layer3.1 ===

Rank 4
 major: 0.342 ± 0.012
 minor: 0.285 ± 0.004
random: 0.338 ± 0.012

Rank 8
 major: 0.470 ± 0.002
 minor: 0.332 ± 0.010
random: 0.416 ± 0.018

Rank 16
 major: 0.624 ± 0.005
 minor: 0.431 ± 0.005
random: 0.528 ± 0.013

=== Class Probe @ layer4.1 ===

Rank 4
 major: 0.938 ± 0.004
 minor: 0.471 ± 0.005
random: 0.846 ± 0.018

Rank 8
 major: 0.970 ± 0.004
 minor: 0.704 ± 0.011
random: 0.936 ± 0.010

Rank 16
 major: 0.975 ± 0.001
 minor: 0.821 ± 0.004
random: 0.962 ± 0.007


Domain & Class Probe for all layers

In [None]:
BACKBONE_PATH = f"{PROJECT_ROOT}/resnet18_pacs_base.pt"

# Rebuild the architecture and load the saved Tier-0 weights, so the probes
# below run against the checkpoint on disk rather than whatever is in memory.
model = ResNet18_FeatureHook(num_classes=7).to(device)

state = torch.load(BACKBONE_PATH, map_location=device)
model.load_state_dict(state)

# Trailing semicolon suppresses the (very long) module repr in the cell output.
model.eval();


ResNet18_FeatureHook(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=T

Testing that the right model was loaded

In [None]:
# Should reproduce the final-epoch target accuracy from training if the right
# checkpoint was loaded.
te_acc = eval_epoch(loader=test_loader)
print("Target accuracy:", te_acc)


Target accuracy: 0.7358106388393993


In [None]:
# Enumerate all conv2 layers in ResNet-18, keyed "layerX.i" (8 blocks total).
conv_blocks = {
    f"{lname}.{i}": block.conv2
    for lname in ("layer1", "layer2", "layer3", "layer4")
    for i, block in enumerate(getattr(model.backbone, lname))
}


In [None]:
domains_all = ["photo", "art_painting", "cartoon", "sketch"]
domains_source = domains_all[:3]  # every domain except the target (sketch)

ranks, seeds = [4, 8, 16], [0, 1, 2]


In [None]:
print("===== DOMAIN PROBES (ALL CONV BLOCKS) =====")

for layer_name, conv in conv_blocks.items():
    print(f"\n=== Domain probe @ {layer_name} ===")

    U, S = get_conv_svd(conv)                                        # SVD of this block's conv2
    X, _, y_domain = extract_gap_features(domains_all, layer_name)   # GAP'd bn2 features

    for r in ranks:
        print(f"Rank {r}")
        for kind in ("major", "minor", "random"):
            accs = [
                train_probe(
                    ProjectedDataset(X, y_domain, get_subspace(U, kind, r, seed)),
                    num_classes=4,
                    seed=seed,
                )
                for seed in seeds
            ]
            print(f"{kind:>6}: {np.mean(accs):.3f} ± {np.std(accs):.3f}")


===== DOMAIN PROBES (ALL CONV BLOCKS) =====

=== Domain probe @ layer1.0 ===
Rank 4
 major: 0.749 ± 0.004
 minor: 0.585 ± 0.016
random: 0.656 ± 0.031
Rank 8
 major: 0.772 ± 0.004
 minor: 0.715 ± 0.017
random: 0.730 ± 0.026
Rank 16
 major: 0.843 ± 0.004
 minor: 0.780 ± 0.003
random: 0.828 ± 0.014

=== Domain probe @ layer1.1 ===
Rank 4
 major: 0.601 ± 0.010
 minor: 0.617 ± 0.013
random: 0.626 ± 0.076
Rank 8
 major: 0.700 ± 0.009
 minor: 0.706 ± 0.012
random: 0.705 ± 0.030
Rank 16
 major: 0.823 ± 0.002
 minor: 0.797 ± 0.004
random: 0.805 ± 0.029

=== Domain probe @ layer2.0 ===
Rank 4
 major: 0.607 ± 0.008
 minor: 0.733 ± 0.015
random: 0.661 ± 0.054
Rank 8
 major: 0.739 ± 0.008
 minor: 0.775 ± 0.010
random: 0.791 ± 0.018
Rank 16
 major: 0.835 ± 0.006
 minor: 0.836 ± 0.007
random: 0.836 ± 0.008

=== Domain probe @ layer2.1 ===
Rank 4
 major: 0.678 ± 0.007
 minor: 0.672 ± 0.003
random: 0.624 ± 0.064
Rank 8
 major: 0.752 ± 0.003
 minor: 0.750 ± 0.003
random: 0.792 ± 0.023
Rank 16
 major: 0.

In [None]:
print("\n===== CLASS PROBES (ALL CONV BLOCKS) =====")

for layer_name, conv in conv_blocks.items():
    print(f"\n=== Class probe @ {layer_name} ===")

    U, S = get_conv_svd(conv)                                           # SVD of this block's conv2
    X, y_class, _ = extract_gap_features(domains_source, layer_name)    # source domains only

    for r in ranks:
        print(f"Rank {r}")
        for kind in ("major", "minor", "random"):
            accs = [
                train_probe(
                    ProjectedDataset(X, y_class, get_subspace(U, kind, r, seed)),
                    num_classes=7,
                    seed=seed,
                )
                for seed in seeds
            ]
            print(f"{kind:>6}: {np.mean(accs):.3f} ± {np.std(accs):.3f}")



===== CLASS PROBES (ALL CONV BLOCKS) =====

=== Class probe @ layer1.0 ===
Rank 4
 major: 0.299 ± 0.008
 minor: 0.265 ± 0.007
random: 0.278 ± 0.011
Rank 8
 major: 0.332 ± 0.002
 minor: 0.296 ± 0.003
random: 0.326 ± 0.017
Rank 16
 major: 0.394 ± 0.004
 minor: 0.348 ± 0.013
random: 0.366 ± 0.004

=== Class probe @ layer1.1 ===
Rank 4
 major: 0.279 ± 0.008
 minor: 0.236 ± 0.005
random: 0.274 ± 0.018
Rank 8
 major: 0.331 ± 0.006
 minor: 0.292 ± 0.005
random: 0.301 ± 0.020
Rank 16
 major: 0.397 ± 0.005
 minor: 0.363 ± 0.007
random: 0.373 ± 0.005

=== Class probe @ layer2.0 ===
Rank 4
 major: 0.262 ± 0.012
 minor: 0.234 ± 0.008
random: 0.287 ± 0.017
Rank 8
 major: 0.378 ± 0.009
 minor: 0.321 ± 0.011
random: 0.358 ± 0.019
Rank 16
 major: 0.439 ± 0.007
 minor: 0.366 ± 0.002
random: 0.409 ± 0.005

=== Class probe @ layer2.1 ===
Rank 4
 major: 0.245 ± 0.006
 minor: 0.299 ± 0.002
random: 0.288 ± 0.016
Rank 8
 major: 0.336 ± 0.010
 minor: 0.331 ± 0.004
random: 0.361 ± 0.010
Rank 16
 major: 0.402 

# Tier-2: ΔW-Based Domain-Sensitive Subspace Discovery

In [None]:
import copy
import numpy as np
import torch
import torch.nn as nn
import numpy.linalg as LA


In [None]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)


Using device: cuda


Load frozen Tier-0 backbone

In [None]:
BACKBONE_PATH = f"{PROJECT_ROOT}/resnet18_pacs_base.pt"

def load_base_model():
    """Fresh ResNet18_FeatureHook on `device` with the saved Tier-0 weights.

    No .eval()/.train() call is made here; callers set the mode they need.
    """
    net = ResNet18_FeatureHook(num_classes=7).to(device)
    net.load_state_dict(torch.load(BACKBONE_PATH, map_location=device))
    return net


Freeze BatchNorm running statistics: This prevents BN from silently encoding domain identity.

In [None]:
def freeze_bn_stats(model):
    """Put every BatchNorm2d into eval mode while keeping its affine params trainable.

    Eval mode stops running mean/var updates. Any later ``model.train()`` call
    switches the BN layers back to training mode, so call this *after* it.
    """
    bn_layers = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for bn in bn_layers:
        bn.eval()
        bn.weight.requires_grad_(True)
        bn.bias.requires_grad_(True)


Class-balanced domain loader - This ensures ΔW reflects domain, not class imbalance.



In [None]:
from collections import defaultdict

def make_balanced_loader(domain, tf, batch_size):
    """DataLoader over one domain with the same number of samples per class.

    Each class is truncated to the size of the smallest class (first samples in
    ImageFolder order), and the kept indices are drawn in random order.
    """
    folder = datasets.ImageFolder(root=os.path.join(PACS_ROOT, domain), transform=tf)

    by_class = defaultdict(list)
    for sample_idx, (_path, cls) in enumerate(folder.samples):
        by_class[cls].append(sample_idx)

    per_class = min(map(len, by_class.values()))
    keep = [i for members in by_class.values() for i in members[:per_class]]

    return DataLoader(
        folder,
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(keep),
        num_workers=NUM_WORKERS,
    )


Train single-domain model (light adaptation)

In [None]:
def train_single_domain(domain, epochs=2, lr=1e-4):
    """Briefly fine-tune a fresh copy of the Tier-0 backbone on one domain.

    Uses a class-balanced loader with train_tf. BatchNorm running statistics are
    kept frozen while BN affine parameters (and all other weights) train.
    Returns the adapted model; its weights minus the base weights give ΔW.
    """
    model = load_base_model()
    loader = make_balanced_loader(domain, train_tf, BATCH_SIZE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Order matters: model.train() recursively switches every submodule,
    # including BatchNorm, back to training mode. BN must therefore be frozen
    # *after* it. Freezing first is silently undone and lets the running
    # mean/var drift toward the adaptation domain.
    model.train()
    freeze_bn_stats(model)

    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    return model

# Note:

# Few epochs

# Small learning rate

# No over-fitting

# This captures early domain pressure.


Extract ΔW for a convolutional layer - ΔW lives in exactly the same space as Tier-1 SVD.

In [None]:
def extract_delta_W(base_model, adapted_model, conv_path):
    """ΔW = W_adapted - W_base for one conv, flattened to [C_out, C_in*k*k].

    conv_path example: ("layer3", 1, "conv2") -> backbone.layer3[1].conv2
    """
    stage, block_idx, attr = conv_path[0], conv_path[1], conv_path[2]

    def _weight(net):
        conv = getattr(getattr(net.backbone, stage)[block_idx], attr)
        return conv.weight.detach().cpu().numpy()

    delta = _weight(adapted_model) - _weight(base_model)
    return delta.reshape(delta.shape[0], -1)


Collect ΔW across domains

In [None]:
# Layer1 is intentionally excluded because it mostly captures low-level visual primitives (edges, textures), where ΔW is dominated by generic statistics rather than domain-specific adaptation.

# Which conv weights to diff: (stage, block index, attribute) per layer key.
conv_layers = {
    "layer2.1": ("layer2", 1, "conv2"),
    "layer3.1": ("layer3", 1, "conv2"),
    "layer4.1": ("layer4", 1, "conv2"),
}

base_model = load_base_model()
deltaWs = {lname: [] for lname in conv_layers}

# One light adaptation per source domain; deltaWs[layer][i] is the ΔW for
# SOURCE_DOMAINS[i].
for domain in SOURCE_DOMAINS:
    print(f"Adapting to domain: {domain}")
    adapted = train_single_domain(domain)
    for lname, path in conv_layers.items():
        deltaWs[lname].append(extract_delta_W(base_model, adapted, path))


Adapting to domain: photo
Adapting to domain: art_painting
Adapting to domain: cartoon


In [None]:
import numpy as np
import os

SAVE_PATH = f"{PROJECT_ROOT}/tier2_deltaWs_source.npz"

# Flatten {layer: [ΔW_0, ΔW_1, ...]} into "<layer>_<i>" archive keys; i follows
# the order in which the domains were adapted.
arrays_to_save = {}
for lname, dW_list in deltaWs.items():
    for i, dW in enumerate(dW_list):
        arrays_to_save[f"{lname}_{i}"] = dW
np.savez(SAVE_PATH, **arrays_to_save)

print("Saved raw ΔW matrices to:", SAVE_PATH)


Saved raw ΔW matrices to: /content/drive/MyDrive/SoMA_PACS/tier2_deltaWs_source.npz


In [None]:
def load_deltaWs(path, conv_layers, num_domains=3):
    """Load ΔW matrices saved as "<layer>_<i>" keys in an .npz archive.

    Args:
        path: .npz file written by np.savez.
        conv_layers: iterable of layer keys (e.g. the conv_layers dict).
        num_domains: how many per-domain entries ("_0" .. "_{n-1}") per layer.

    Returns:
        dict layer -> list of num_domains ndarrays.

    The archive is opened in a ``with`` block so its file handle is released.
    np.load on an .npz returns an NpzFile that otherwise stays open. The arrays
    are fully read before the file closes.
    """
    with np.load(path) as data:
        return {
            lname: [data[f"{lname}_{i}"] for i in range(num_domains)]
            for lname in conv_layers
        }

deltaWs = load_deltaWs(
    path=f"{PROJECT_ROOT}/tier2_deltaWs_source.npz",
    conv_layers=conv_layers,
    num_domains=len(SOURCE_DOMAINS),
)
print("Loaded ΔW matrices from disk.")


Loaded ΔW matrices from disk.


In [None]:
# Build the ΔW channel subspaces here from `deltaWs`, using the same
# construction as Tier-2.2: eigenvectors of sum_d ΔW_d ΔW_dᵀ, sorted by
# descending eigenvalue. The cell then does not depend on `delta_subspaces`,
# which is only defined further down and would raise NameError on a fresh
# kernel (Restart & Run All).
subspaces_to_save = {}
for lname, dW_list in deltaWs.items():
    C_out = dW_list[0].shape[0]
    cov = np.zeros((C_out, C_out))
    for dW in dW_list:
        cov += dW @ dW.T  # channel-channel covariance
    eigvals, eigvecs = np.linalg.eigh(cov)
    order = np.argsort(eigvals)[::-1]
    subspaces_to_save[lname] = (eigvecs[:, order], eigvals[order])

np.savez(
    f"{PROJECT_ROOT}/tier2_deltaW_subspaces.npz",
    **{
        f"{lname}_U": U
        for lname, (U, S) in subspaces_to_save.items()
    }
)

print("Saved ΔW subspaces.")


Saved ΔW subspaces.


In [None]:
def load_deltaW_subspaces(path, conv_layers):
    """Load per-layer ΔW bases saved under "<layer>_U" keys.

    Eigenvalues were not saved, so each entry is ``(U, None)``. The archive is
    opened in a ``with`` block so the NpzFile handle is closed; the arrays are
    fully read before that.
    """
    with np.load(path) as data:
        return {
            lname: (data[f"{lname}_U"], None)
            for lname in conv_layers
        }

subspace_path = f"{PROJECT_ROOT}/tier2_deltaW_subspaces.npz"
delta_subspaces = load_deltaW_subspaces(subspace_path, conv_layers)
print("Loaded ΔW subspaces.")


Loaded ΔW subspaces.


# Tier-2.2 — Empirical Domain-Sensitive Subspace (Channel Space)

Build channel–channel covariance from ΔW

In [None]:
def build_deltaW_channel_subspace(dW_list):
    """Eigenbasis of the summed channel covariance sum_d ΔW_d ΔW_dᵀ.

    Args:
        dW_list: ΔW matrices for one layer, each [C_out, D].

    Returns:
        (U_domain [C_out, C_out], S_domain [C_out]): eigenvectors as columns
        and eigenvalues, both ordered by descending eigenvalue.
    """
    n_ch = dW_list[0].shape[0]
    cov = np.zeros((n_ch, n_ch))
    for delta in dW_list:
        cov = cov + delta @ delta.T

    # eigh is for symmetric matrices; it returns ascending eigenvalues.
    evals, evecs = np.linalg.eigh(cov)
    order = np.argsort(evals)[::-1]
    return evecs[:, order], evals[order]


Construct ΔW domain subspaces per layer

In [None]:
# (U_domain, S_domain) per layer, from the source-domain ΔW matrices.
delta_subspaces = {
    lname: build_deltaW_channel_subspace(dW_list)
    for lname, dW_list in deltaWs.items()
}

print("Constructed empirical ΔW domain-sensitive subspaces.")


Constructed empirical ΔW domain-sensitive subspaces.


# Tier-2.3 — Alignment of ΔW Subspaces with the SoMA Minor Subspace

SoMA subspace constructor

In [None]:
def get_soma_subspace(conv, kind, r):
    """Orthonormal [C_out, r] basis derived from a conv weight's SVD.

    The function is called by the alignment sweep below. Leaving it
    commented out made that cell fail with NameError on a fresh kernel.

    Args:
        conv: conv module whose weight [C_out, C_in, k, k] is flattened to
              [C_out, C_in*k*k] before the SVD.
        kind: 'minor' (last r left singular vectors), 'major' (first r) or
              'random' (QR of a Gaussian draw from NumPy's *global* RNG, so the
              caller's np.random.seed(...) controls it).
        r: subspace rank, 1 <= r <= number of left singular vectors.

    Raises:
        ValueError: unknown ``kind`` or out-of-range ``r``.
    """
    W = conv.weight.detach().cpu().numpy()
    C_out = W.shape[0]
    Wmat = W.reshape(C_out, -1)

    U, S, _ = np.linalg.svd(Wmat, full_matrices=False)

    if not 1 <= r <= U.shape[1]:
        raise ValueError(f"rank r={r} must be in [1, {U.shape[1]}]")

    if kind == "minor":
        return U[:, -r:]
    if kind == "major":
        return U[:, :r]
    if kind == "random":
        Q, _ = np.linalg.qr(np.random.randn(C_out, r))
        return Q

    raise ValueError(f"Unknown subspace kind: {kind}")

Alignment metric computation

In [None]:
def compute_alignment_metrics(U, V):
    """Principal-angle statistics between two subspaces.

    Args:
        U, V: orthonormal basis matrices, [C_out, r_U] and [C_out, r_V].

    Returns:
        mean_angle_deg: mean principal angle in degrees
        overlap: mean squared cosine of the principal angles
        projection_energy: ||UᵀV||_F² / r_V. It coincides with ``overlap``
            when V has no more columns than U.
    """
    cross = U.T @ V  # [r_U, r_V]

    # Singular values of UᵀV are the cosines of the principal angles; clip to
    # guard arccos against round-off slightly above 1.
    cosines = np.clip(np.linalg.svd(cross, compute_uv=False), -1.0, 1.0)

    mean_angle = np.degrees(np.arccos(cosines)).mean()
    overlap = (cosines ** 2).mean()
    proj_energy = np.linalg.norm(cross, ord="fro") ** 2 / V.shape[1]

    return mean_angle, overlap, proj_energy


Alignment sweep (SoMA vs ΔW)

In [None]:
# Build conv_blocks mapping (SoMA reference layers): the conv2 of the second
# BasicBlock in stages 2-4, taken from a freshly loaded Tier-0 base model.

conv_layers = {
    "layer2.1": ("layer2", 1, "conv2"),
    "layer3.1": ("layer3", 1, "conv2"),
    "layer4.1": ("layer4", 1, "conv2"),
}
base_model = load_base_model()
conv_blocks = {
    key: getattr(getattr(base_model.backbone, stage)[idx], attr)
    for key, (stage, idx, attr) in conv_layers.items()
}

print("Defined conv_blocks:", list(conv_blocks.keys()))


Defined conv_blocks: ['layer2.1', 'layer3.1', 'layer4.1']


In [None]:
ranks = [2, 4, 8, 16]
n_random = 10

alignment_results = {}

for lname, (U_domain, S_domain) in delta_subspaces.items():
    print(f"\n=== Alignment @ {lname} ===")
    per_rank = alignment_results[lname] = {}
    conv = conv_blocks[lname]

    for r in ranks:
        print(f"\nRank {r}")
        V = U_domain[:, :r]  # top-r empirical ΔW directions
        per_kind = per_rank[r] = {}

        for kind in ["minor", "major", "random"]:
            # Seeding the global NumPy RNG drives the 'random' baseline;
            # 'major'/'minor' are deterministic, so their std is 0.
            samples = []
            for seed in range(n_random):
                np.random.seed(seed)
                samples.append(compute_alignment_metrics(get_soma_subspace(conv, kind, r), V))
            angles, overlaps, energies = (np.array(col) for col in zip(*samples))

            per_kind[kind] = {
                "mean_angle_deg": (np.mean(angles), np.std(angles)),
                "overlap": (np.mean(overlaps), np.std(overlaps)),
                "proj_energy": (np.mean(energies), np.std(energies)),
            }

            print(
                f"{kind:>6} | "
                f"angle={np.mean(angles):.2f}° ± {np.std(angles):.2f} | "
                f"overlap={np.mean(overlaps):.3f} ± {np.std(overlaps):.3f} | "
                f"energy={np.mean(energies):.3f} ± {np.std(energies):.3f}"
            )



=== Alignment @ layer2.1 ===

Rank 2
 minor | angle=88.70° ± 0.00 | overlap=0.001 ± 0.000 | energy=0.001 ± 0.000
 major | angle=71.25° ± 0.00 | overlap=0.123 ± 0.000 | energy=0.123 ± 0.000
random | angle=84.26° ± 2.56 | overlap=0.014 ± 0.011 | energy=0.014 ± 0.011

Rank 4
 minor | angle=84.59° ± 0.00 | overlap=0.012 ± 0.000 | energy=0.012 ± 0.000
 major | angle=75.95° ± 0.00 | overlap=0.085 ± 0.000 | energy=0.085 ± 0.000
random | angle=81.61° ± 1.29 | overlap=0.031 ± 0.009 | energy=0.031 ± 0.009

Rank 8
 minor | angle=82.35° ± 0.00 | overlap=0.024 ± 0.000 | energy=0.024 ± 0.000
 major | angle=70.81° ± 0.00 | overlap=0.146 ± 0.000 | energy=0.146 ± 0.000
random | angle=77.92° ± 1.28 | overlap=0.061 ± 0.012 | energy=0.061 ± 0.012

Rank 16
 minor | angle=78.92° ± 0.00 | overlap=0.051 ± 0.000 | energy=0.051 ± 0.000
 major | angle=61.52° ± 0.00 | overlap=0.271 ± 0.000 | energy=0.271 ± 0.000
random | angle=72.06° ± 0.96 | overlap=0.125 ± 0.012 | energy=0.125 ± 0.012

=== Alignment @ layer3.1

(Optional) Save results

In [None]:
import pandas as pd

# Flatten alignment_results[layer][rank][subspace] -> one CSV row per
# (layer, rank, subspace), splitting each (mean, std) tuple into two columns.
rows = [
    {
        "layer": lname,
        "rank": r,
        "subspace": kind,
        "mean_angle_deg": metrics["mean_angle_deg"][0],
        "angle_std": metrics["mean_angle_deg"][1],
        "overlap": metrics["overlap"][0],
        "overlap_std": metrics["overlap"][1],
        "proj_energy": metrics["proj_energy"][0],
        "energy_std": metrics["proj_energy"][1],
    }
    for lname, layer_data in alignment_results.items()
    for r, r_data in layer_data.items()
    for kind, metrics in r_data.items()
]

df = pd.DataFrame(rows)
df.to_csv(f"{PROJECT_ROOT}/tier2_alignment_metrics.csv", index=False)

print("Saved Tier-2.3 alignment results.")

In [None]:
# Fine-tune on the held-out Sketch domain; the next cell diffs the result against
# base_model to obtain ΔW_sketch per monitored conv layer.
# NOTE(review): assumes train_single_domain trains a copy and leaves base_model's
# weights untouched — confirm, otherwise every ΔW would be zero.
print("Adapting baseline model to Sketch domain...")
adapted_sketch = train_single_domain("sketch")


Adapting baseline model to Sketch domain...


In [None]:
# ΔW per monitored conv layer: Sketch-adapted weights minus baseline weights
# (as computed by extract_delta_W), keyed by layer name.
deltaWs_sketch = {
    lname: extract_delta_W(base_model, adapted_sketch, path)
    for lname, path in conv_layers.items()
}

for lname, dW in deltaWs_sketch.items():
    print(f"{lname}: ΔW shape = {dW.shape}")


layer2.1: ΔW shape = (128, 1152)
layer3.1: ΔW shape = (256, 2304)
layer4.1: ΔW shape = (512, 4608)


Convert Sketch ΔW into a subspace

In [None]:
def get_deltaW_subspace(dW, r):
    """
    Orthonormal rank-r basis for the column space of a ΔW matrix.

    Takes the first r left singular vectors of the thin SVD of `dW`
    ([C_out, D]), i.e. the output-channel directions carrying most of ΔW's
    energy. The result has shape [C_out, min(r, C_out, D)].
    """
    left_singular_vectors = np.linalg.svd(dW, full_matrices=False)[0]
    return left_singular_vectors[:, :r]


Align Sketch ΔW subspace with SoMA subspaces

In [None]:
import inspect

ranks = [2, 4, 8, 16]
n_random = 10

print("\n=== Sketch ΔW vs SoMA Subspace Alignment ===")

# get_soma_subspace is redefined further down this notebook with an explicit
# `seed` argument backed by its own RNG, which ignores np.random.seed. Pass the
# seed whenever the current definition accepts it, so re-running this cell after
# the Tier-3 cells still draws n_random *different* random subspaces instead of
# silently repeating one.
_soma_accepts_seed = "seed" in inspect.signature(get_soma_subspace).parameters

sketch_alignment_results = {}

for lname in conv_layers:
    print(f"\n=== Layer {lname} ===")
    sketch_alignment_results[lname] = {}

    dW = deltaWs_sketch[lname]
    conv = conv_blocks[lname]

    for r in ranks:
        print(f"\nRank {r}")
        sketch_alignment_results[lname][r] = {}

        # Sketch ΔW subspace
        U_sketch = get_deltaW_subspace(dW, r)

        for kind in ["minor", "major", "random"]:
            # minor / major do not depend on the seed (their std is 0 in the
            # recorded output), so one evaluation yields the same mean and a
            # std of 0 without repeating the SVD; only "random" is resampled.
            n_draws = n_random if kind == "random" else 1
            angles, overlaps, energies = [], [], []

            for seed in range(n_draws):
                np.random.seed(seed)

                # SoMA subspace
                if _soma_accepts_seed:
                    U_soma = get_soma_subspace(conv, kind, r, seed=seed)
                else:
                    U_soma = get_soma_subspace(conv, kind, r)

                mean_angle, overlap, proj_energy = compute_alignment_metrics(
                    U_soma, U_sketch
                )

                angles.append(mean_angle)
                overlaps.append(overlap)
                energies.append(proj_energy)

            sketch_alignment_results[lname][r][kind] = {
                "mean_angle_deg": (np.mean(angles), np.std(angles)),
                "overlap": (np.mean(overlaps), np.std(overlaps)),
                "proj_energy": (np.mean(energies), np.std(energies)),
            }

            print(
                f"{kind:>6} | "
                f"angle={np.mean(angles):.2f}° ± {np.std(angles):.2f} | "
                f"overlap={np.mean(overlaps):.3f} ± {np.std(overlaps):.3f} | "
                f"energy={np.mean(energies):.3f} ± {np.std(energies):.3f}"
            )



=== Sketch ΔW vs SoMA Subspace Alignment ===

=== Layer layer2.1 ===

Rank 2
 minor | angle=86.56° ± 0.00 | overlap=0.006 ± 0.000 | energy=0.006 ± 0.000
 major | angle=80.82° ± 0.00 | overlap=0.035 ± 0.000 | energy=0.035 ± 0.000
random | angle=85.68° ± 1.70 | overlap=0.008 ± 0.006 | energy=0.008 ± 0.006

Rank 4
 minor | angle=84.65° ± 0.00 | overlap=0.014 ± 0.000 | energy=0.014 ± 0.000
 major | angle=79.65° ± 0.00 | overlap=0.052 ± 0.000 | energy=0.052 ± 0.000
random | angle=81.75° ± 1.56 | overlap=0.030 ± 0.011 | energy=0.030 ± 0.011

Rank 8
 minor | angle=81.52° ± 0.00 | overlap=0.033 ± 0.000 | energy=0.033 ± 0.000
 major | angle=72.98° ± 0.00 | overlap=0.119 ± 0.000 | energy=0.119 ± 0.000
random | angle=77.64° ± 1.45 | overlap=0.065 ± 0.014 | energy=0.065 ± 0.014

Rank 16
 minor | angle=77.17° ± 0.00 | overlap=0.072 ± 0.000 | energy=0.072 ± 0.000
 major | angle=64.74° ± 0.00 | overlap=0.224 ± 0.000 | energy=0.224 ± 0.000
random | angle=72.19° ± 0.45 | overlap=0.123 ± 0.005 | energy

Check whether the Sketch ΔW aligns with the 3-domain ΔW subspace (Approach B)

In [None]:
def project_matrix_onto_subspace(U, dW):
    """
    Measure how much of a matrix's energy lies inside a column subspace.

    U  : [C_out, r] basis of the subspace (e.g. the 3-domain ΔW subspace).
         The columns are orthonormalized with a reduced QR first, so the result
         depends only on span(U) (a non-orthonormal basis used to give
         "fractions" above 1).
    dW : [C_out, D] matrix to test (e.g. the Sketch ΔW).

    Returns:
        projection_energy : ||Q^T dW||_F^2 / ||dW||_F^2, with Q an orthonormal
                            basis of span(U); fraction of ΔW energy captured.
        effective_angle_deg : arccos(sqrt(projection_energy)) in degrees
                            (0° = fully inside the subspace, 90° = orthogonal).
        Both are NaN when dW is all zeros, where the fraction is undefined.

    Raises:
        ValueError if U and dW are not 2-D or their row counts differ.
    """
    U = np.asarray(U)
    dW = np.asarray(dW)
    if U.ndim != 2 or dW.ndim != 2 or U.shape[0] != dW.shape[0]:
        raise ValueError(
            f"expected U [C_out, r] and dW [C_out, D]; got {U.shape} and {dW.shape}"
        )

    total_energy = np.linalg.norm(dW, ord="fro") ** 2
    if total_energy == 0:
        # No ΔW at all: avoid 0/0 (RuntimeWarning) and report "undefined" explicitly.
        return float("nan"), float("nan")

    # Orthonormal basis of span(U); U^T dW is only a projection when U is orthonormal.
    Q, _ = np.linalg.qr(U)
    P = Q.T @ dW

    proj_energy = (np.linalg.norm(P, ord="fro") ** 2) / total_energy

    # Effective angle (for interpretability); clip guards tiny round-off above 1.
    eff_angle = np.degrees(
        np.arccos(np.sqrt(np.clip(proj_energy, 0.0, 1.0)))
    )

    return proj_energy, eff_angle


In [None]:
ranks = [2, 4, 8, 16]

print("\n=== Sketch ΔW vs 3-Domain ΔW Subspace (Approach B) ===")

sketch_vs_deltaW_results = {}

for lname, (U_domain, S_domain) in delta_subspaces.items():
    print(f"\n=== Layer {lname} ===")
    # Same dict object is stored globally and filled through this local alias.
    layer_results = sketch_vs_deltaW_results[lname] = {}

    dW_sketch = deltaWs_sketch[lname]
    C_out = dW_sketch.shape[0]

    for r in ranks:
        # Leading r directions of the 3-domain ΔW subspace.
        proj_energy, eff_angle = project_matrix_onto_subspace(U_domain[:, :r], dW_sketch)

        # Expected captured energy for a uniformly random r-dim subspace of R^C_out.
        random_baseline = r / C_out

        layer_results[r] = dict(
            projection_energy=proj_energy,
            effective_angle_deg=eff_angle,
            random_baseline=random_baseline,
        )

        print(
            f"Rank {r:2d} | "
            f"proj={proj_energy:.3f} | "
            f"angle={eff_angle:.1f}° | "
            f"random≈{random_baseline:.3f}"
        )



=== Sketch ΔW vs 3-Domain ΔW Subspace (Approach B) ===

=== Layer layer2.1 ===
Rank  2 | proj=0.041 | angle=78.4° | random≈0.016
Rank  4 | proj=0.097 | angle=71.8° | random≈0.031
Rank  8 | proj=0.173 | angle=65.4° | random≈0.062
Rank 16 | proj=0.289 | angle=57.5° | random≈0.125

=== Layer layer3.1 ===
Rank  2 | proj=0.039 | angle=78.6° | random≈0.008
Rank  4 | proj=0.083 | angle=73.3° | random≈0.016
Rank  8 | proj=0.154 | angle=66.9° | random≈0.031
Rank 16 | proj=0.256 | angle=59.6° | random≈0.062

=== Layer layer4.1 ===
Rank  2 | proj=0.208 | angle=62.8° | random≈0.004
Rank  4 | proj=0.482 | angle=46.0° | random≈0.008
Rank  8 | proj=0.666 | angle=35.3° | random≈0.016
Rank 16 | proj=0.708 | angle=32.7° | random≈0.031


# Tier 3.1 — Subspace Swap Experiment

In [None]:
import torch, numpy as np, random

def set_seed(seed=0):
    """
    Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and every CUDA
    device), and switch cuDNN to deterministic, non-autotuned convolutions.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

set_seed(0)


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Path to the baseline (source-trained) ResNet-18
base_model_path = f"{PROJECT_ROOT}/resnet18_pacs_base.pt"

def load_baseline_model(path=None):
    """
    Build a fresh ResNet18_FeatureHook (7 classes) on `device` and load the
    baseline state dict into it.

    path : checkpoint to load; defaults to base_model_path. The default keeps
           this signature compatible with the zero-argument form used by the
           Tier-3.4 cells, so either calling convention works whichever
           definition ran last.
    """
    if path is None:
        path = base_model_path
    model = ResNet18_FeatureHook(num_classes=7).to(device)
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    return model


In [None]:
def get_soma_subspace(conv, kind, r, seed=0):
    """
    Returns an orthonormal basis U [C_out, r] for the requested subspace of the
    output-channel space of conv2 weights (weight reshaped to [C_out, C_in*kh*kw]).

    kind:
      "major"  - the r left singular vectors with the largest singular values
      "minor"  - the r left singular vectors with the smallest singular values
                 of the thin SVD
      "middle" - r consecutive singular vectors centred in the spectrum
      "random" - orthonormal basis from the QR of a standard-normal [C_out, r]
                 matrix drawn with np.random.default_rng(seed); the global
                 NumPy RNG (np.random.seed) has no effect on it.

    Raises ValueError for an unknown kind, or when r is outside
    [1, number of available directions] (r=0 previously returned the whole
    basis through U[:, -0:]).
    """
    W = conv.weight.detach().cpu().numpy()
    C_out = W.shape[0]

    if kind == "random":
        if not 1 <= r <= C_out:
            raise ValueError(f"rank r={r} must be in [1, {C_out}] for kind='random'")
        rng = np.random.default_rng(seed)
        Q, _ = np.linalg.qr(rng.standard_normal((C_out, r)))
        return Q

    if kind not in ("major", "minor", "middle"):
        raise ValueError(f"Unknown subspace kind: {kind}")

    # SVD only for the spectral kinds; k = min(C_out, C_in*kh*kw) directions exist.
    U, _, _ = np.linalg.svd(W.reshape(C_out, -1), full_matrices=False)
    k = U.shape[1]
    if not 1 <= r <= k:
        raise ValueError(f"rank r={r} must be in [1, {k}] for kind={kind!r}")

    if kind == "major":
        return U[:, :r]

    if kind == "minor":
        return U[:, k - r:]

    # "middle": window centred on k // 2, clamped to stay inside the k available
    # columns (indexing by C_out overshot when C_in*kh*kw < C_out).
    start = min(max(k // 2 - r // 2, 0), k - r)
    return U[:, start:start + r]


In [None]:
# ===== Tier-3 Data Loaders =====

# Source domains (balanced sampling) are used for training; Sketch is the held-out target.
train_loaders = [
    make_balanced_loader(domain, train_tf, BATCH_SIZE)
    for domain in ("photo", "art_painting", "cartoon")
]

test_loader = make_loader("sketch", test_tf, shuffle=False)

print("Tier-3 loaders ready.")


Tier-3 loaders ready.


Gradient projection hook (core mechanism)

In [None]:
def register_gradient_projection(param, U_np):
    """
    Constrain the gradient of a weight tensor to span(U) along its
    output-channel axis.

    param : torch.nn.Parameter whose first dimension is C_out
            (e.g. a conv weight [C_out, C_in, k, k])
    U_np  : numpy array [C_out, r], expected to have orthonormal columns
            (U U^T is then the orthogonal projector onto span(U))

    Every time the gradient w.r.t. `param` is computed, the hook reshapes it to
    [C_out, D] and replaces it with U U^T g before it is accumulated into
    param.grad.

    Returns the RemovableHandle from param.register_hook, so the constraint
    can be detached with handle.remove().

    Raises ValueError if U is not [param.shape[0], r].
    """
    # Convert once, outside the hook, matching the parameter's dtype/device so
    # the matmul also works for float64 / half-precision parameters.
    U = torch.tensor(U_np, dtype=param.dtype, device=param.device)
    U.requires_grad_(False)
    if U.ndim != 2 or U.shape[0] != param.shape[0]:
        raise ValueError(
            f"U must be [{param.shape[0]}, r] for a parameter of shape "
            f"{tuple(param.shape)}; got {tuple(U.shape)}"
        )

    def hook(grad):
        # reshape (not view): works for non-contiguous grads, e.g. channels_last.
        g = grad.reshape(grad.shape[0], -1)   # [C_out, D]
        g_proj = U @ (U.T @ g)                # projection in torch
        return g_proj.reshape(grad.shape)

    return param.register_hook(hook)



Apply subspace constraint to selected layers

In [None]:
def apply_subspace_constraints(model, conv_layers, subspaces):
    """
    Attach a gradient-projection hook to every conv weight named in `subspaces`.

    subspaces   : {layer_name: U (C_out x r)}; only these layers are constrained
    conv_layers : {layer_name: (layer, idx, conv_name)} locating each conv as
                  model.backbone.<layer>[idx].<conv_name>
    """
    for lname in subspaces:
        layer, idx, conv_name = conv_layers[lname]
        block = getattr(model.backbone, layer)[idx]
        register_gradient_projection(getattr(block, conv_name).weight, subspaces[lname])


Training loop (shared across all variants)

In [None]:
def train_with_constraints(
    model,
    train_loaders,
    epochs=3,
    lr=1e-4,
    optimizer_cls=torch.optim.Adam,
):
    """
    Train `model` with cross-entropy on each source-domain loader in turn.

    Subspace constraints are not applied here. They come from gradient hooks
    already registered on the constrained conv weights (see
    register_gradient_projection), which rewrite those gradients during
    backward(). Every other parameter is optimized freely.

    model         : nn.Module producing logits for CrossEntropyLoss
    train_loaders : iterable of loaders yielding (x, y); within each epoch they
                    are consumed one after another (not interleaved)
    epochs        : number of passes over all loaders
    lr            : learning rate passed to the optimizer
    optimizer_cls : optimizer class, called as optimizer_cls(params, lr=lr);
                    the default (Adam) keeps the original behaviour

    NOTE: a projected gradient only guarantees an update inside span(U) for
    optimizers whose step is linear in past gradients (plain SGD, with or
    without momentum, and no weight decay). Adam divides every coordinate by
    its own running RMS, so the applied update generally leaves span(U) unless
    U is axis-aligned; pass optimizer_cls=torch.optim.SGD for an exact constraint.
    """
    optimizer = optimizer_cls(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):
        for loader in train_loaders:
            for x, y in loader:
                x, y = x.to(device), y.to(device)

                optimizer.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()   # projection hooks fire here
                optimizer.step()



Evaluation (target domain)

In [None]:
@torch.no_grad()
def eval_accuracy(model, loader):
    """
    Top-1 accuracy of `model` over every (x, y) batch in `loader`.

    Switches the model to eval mode (and leaves it there); gradients are
    disabled by the decorator.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        hits = model(inputs).argmax(dim=1).eq(targets)
        n_correct += hits.sum().item()
        n_seen += targets.size(0)
    return n_correct / n_seen


In [None]:
ranks = [4, 8, 16]
variants = ["minor", "major", "middle", "random"]

tier3_results = {}

for r in ranks:
    print(f"\n=== Rank {r} ===")
    tier3_results[r] = {}

    for variant in variants:
        print(f"\nVariant: {variant}")

        # 0. Reset every RNG before each run so all variants start from the same
        #    training randomness (e.g. loader sampling / augmentation). Without
        #    this, each variant inherits whatever RNG state the previous run left
        #    behind, and accuracy gaps mix the subspace effect with run-to-run
        #    noise. (Single seed only; Tier 3.4 repeats over several seeds.)
        set_seed(0)

        # 1. Fresh baseline (no hooks carried over from the previous run)
        model = load_baseline_model(base_model_path)

        # 2. Build subspaces from the baseline conv weights
        subspaces = {}
        for lname, (layer, idx, conv_name) in conv_layers.items():
            conv = getattr(getattr(model.backbone, layer)[idx], conv_name)
            subspaces[lname] = get_soma_subspace(conv, variant, r)

        # 3. Apply constraints
        apply_subspace_constraints(model, conv_layers, subspaces)

        # 4. Train
        train_with_constraints(model, train_loaders)

        # 5. Evaluate on the held-out target domain (sketch)
        acc = eval_accuracy(model, test_loader)
        tier3_results[r][variant] = acc

        print(f"Target accuracy: {acc:.3f}")



=== Rank 4 ===

Variant: minor
Target accuracy: 0.711

Variant: major
Target accuracy: 0.744

Variant: middle
Target accuracy: 0.735

Variant: random
Target accuracy: 0.728

=== Rank 8 ===

Variant: minor
Target accuracy: 0.658

Variant: major
Target accuracy: 0.745

Variant: middle
Target accuracy: 0.717

Variant: random
Target accuracy: 0.728

=== Rank 16 ===

Variant: minor
Target accuracy: 0.729

Variant: major
Target accuracy: 0.737

Variant: middle
Target accuracy: 0.712

Variant: random
Target accuracy: 0.730


# Tier 3.4 — BatchNorm Confound Controls

In [None]:
import os, copy, random
import numpy as np
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

def set_seed(seed=0):
    """Seed every RNG used here and pin cuDNN to deterministic convolution algorithms."""
    # Python / NumPy global RNGs
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch: CPU generator plus every visible GPU
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for repeatable kernel selection
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


In [None]:
# You must have PROJECT_ROOT already
base_model_path = f"{PROJECT_ROOT}/resnet18_pacs_base.pt"
assert os.path.exists(base_model_path), f"Missing baseline weights at {base_model_path}"

def load_baseline_model(path=None):
    """
    Build a fresh ResNet18_FeatureHook (7 classes) on `device` and load the
    baseline weights.

    path : optional checkpoint path, defaulting to base_model_path. Accepting it
           keeps the Tier-3.1 call load_baseline_model(base_model_path) working
           after this cell has redefined the function (that call raised a
           TypeError against the zero-argument version).
    """
    if path is None:
        path = base_model_path
    model = ResNet18_FeatureHook(num_classes=7).to(device)
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    return model


In [None]:
# Uses your existing make_balanced_loader, make_loader, train_tf, test_tf, BATCH_SIZE
SOURCE_AND_TARGET_NOTE = None  # placeholder removed below
del SOURCE_AND_TARGET_NOTE
# Training: balanced loaders over the three source domains, visited in this order.
train_loaders = [
    make_balanced_loader(domain_name, train_tf, BATCH_SIZE)
    for domain_name in ["photo", "art_painting", "cartoon"]
]
# Evaluation: the held-out Sketch domain, in fixed order.
test_loader = make_loader("sketch", test_tf, shuffle=False)

print("Tier-3 data loaders ready.")


Tier-3 data loaders ready.


In [None]:
def apply_subspace_constraints(model, conv_layers, kind, r=None, seed=0):
    """
    Apply gradient projection constraints to conv2 weights.

    model       : nn.Module exposing model.backbone.<layer>[idx].<conv_name>
    conv_layers : dict {lname: (layer, idx, conv_name)}
    kind        : {"minor", "major", "middle", "random"}, resolved per layer
                  with get_soma_subspace(conv, kind, r, seed=seed);
                  OR a dict {lname: U} of precomputed [C_out, r] bases (the
                  Tier-3.1 convention apply_subspace_constraints(model,
                  conv_layers, subspaces)). Only the listed layers are then
                  constrained, and r / seed are ignored. Accepting this form
                  keeps the earlier cells runnable after this redefinition.
    r           : rank of subspace; required when kind is a string
    seed        : random seed (used only for 'random')

    Raises TypeError if kind is a string and r is not given.
    """
    def _locate(layer, idx, conv_name):
        return getattr(getattr(model.backbone, layer)[idx], conv_name)

    if isinstance(kind, dict):
        for lname, U in kind.items():
            layer, idx, conv_name = conv_layers[lname]
            register_gradient_projection(_locate(layer, idx, conv_name).weight, U)
        return

    if r is None:
        raise TypeError(
            "apply_subspace_constraints() missing required argument 'r' "
            "when kind is a subspace name"
        )

    for lname, (layer, idx, conv_name) in conv_layers.items():
        conv = _locate(layer, idx, conv_name)

        # get subspace basis (numpy)
        U = get_soma_subspace(conv, kind, r, seed=seed)

        # register gradient projection hook
        register_gradient_projection(conv.weight, U)


In [None]:
def set_bn_mode(model, mode: str):
    """
    Configure every BatchNorm2d layer of `model` for one of three regimes.

    mode in {"bn_default", "bn_frozen", "bn_affine_only"}:
      bn_default     - train mode (running stats update), affine params trainable
      bn_frozen      - eval mode (running stats frozen), affine params frozen
      bn_affine_only - eval mode (running stats frozen), affine params trainable

    BN layers built with affine=False have no weight/bias, so only their mode is
    set (previously this raised AttributeError on None.requires_grad_).

    Caveat: a later model.train() call recursively puts these layers back in
    train mode, which re-enables running-stat updates.
    """
    assert mode in {"bn_default", "bn_frozen", "bn_affine_only"}

    update_stats = mode == "bn_default"
    train_affine = mode in {"bn_default", "bn_affine_only"}

    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.train(update_stats)   # train(False) == eval(): running stats frozen
            if m.weight is not None:
                m.weight.requires_grad_(train_affine)
            if m.bias is not None:
                m.bias.requires_grad_(train_affine)


def replace_bn_with_gn(module: nn.Module, num_groups: int = 32):
    """
    Recursively replace all BatchNorm2d layers with GroupNorm.
    This removes running-stat confounds entirely.

    The group count is the largest g <= num_groups that divides the channel
    count (at least 1). Each new GroupNorm is moved to the device/dtype of the
    BN layer it replaces. Layers created with nn.GroupNorm(...) live on the CPU
    in float32, which broke forward passes of a model already moved to CUDA.

    Note: the new layers start from GroupNorm's identity affine init
    (weight=1, bias=0); the BN's learned gamma/beta are not carried over.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            num_channels = child.num_features
            # GN requires num_channels divisible by num_groups; adjust safely.
            g = max(1, min(num_groups, num_channels))
            while num_channels % g != 0 and g > 1:
                g -= 1
            gn = nn.GroupNorm(num_groups=g, num_channels=num_channels, affine=True)

            ref = child.weight if child.weight is not None else child.running_mean
            if ref is not None:
                gn = gn.to(device=ref.device, dtype=ref.dtype)

            setattr(module, name, gn)
        else:
            replace_bn_with_gn(child, num_groups=num_groups)


def apply_norm_config(model, norm_cfg: str):
    """
    Put the model's normalization layers into the requested configuration and
    return the (same, mutated) model.

    norm_cfg in {"bn_default", "bn_frozen", "bn_affine_only", "groupnorm"}:
    the three BN modes are delegated to set_bn_mode; "groupnorm" swaps every
    BatchNorm2d for GroupNorm and puts the whole model in train mode.
    """
    assert norm_cfg in {"bn_default", "bn_frozen", "bn_affine_only", "groupnorm"}

    if norm_cfg != "groupnorm":
        set_bn_mode(model, norm_cfg)
        return model

    replace_bn_with_gn(model)
    model.train()  # GroupNorm keeps no running statistics, so train mode is safe
    return model


In [None]:
def get_soma_subspace(conv, kind, r, seed=0):
    """
    Returns U basis [C_out, r] as numpy array (orthonormal) spanning a subspace
    of the output-channel space of conv.weight reshaped to [C_out, C_in*kh*kw].

    kind:
      "major"  - r left singular vectors with the largest singular values
      "minor"  - r left singular vectors with the smallest singular values
      "middle" - r consecutive singular vectors centred in the spectrum
      "random" - QR of a standard-normal [C_out, r] matrix from
                 np.random.default_rng(seed) (ignores np.random.seed)

    Raises ValueError for an unknown kind, or when r is outside
    [1, number of available directions] (r=0 previously returned every
    column through U[:, -0:]).
    """
    W = conv.weight.detach().cpu().numpy()
    C_out = W.shape[0]

    if kind == "random":
        if not 1 <= r <= C_out:
            raise ValueError(f"rank r={r} must be in [1, {C_out}] for kind='random'")
        rng = np.random.default_rng(seed)
        Q, _ = np.linalg.qr(rng.standard_normal((C_out, r)))
        return Q

    if kind not in ("major", "minor", "middle"):
        raise ValueError(f"Unknown kind: {kind}")

    # Thin SVD: k = min(C_out, C_in*kh*kw) singular directions are available.
    U, _, _ = np.linalg.svd(W.reshape(C_out, -1), full_matrices=False)
    k = U.shape[1]
    if not 1 <= r <= k:
        raise ValueError(f"rank r={r} must be in [1, {k}] for kind={kind!r}")

    if kind == "major":
        return U[:, :r]
    if kind == "minor":
        return U[:, k - r:]
    # "middle": centred on k // 2 and clamped inside [0, k); using C_out here
    # overshot (empty slice) when C_in*kh*kw < C_out.
    start = min(max(k // 2 - r // 2, 0), k - r)
    return U[:, start:start + r]


In [None]:
def register_gradient_projection(param, U_np):
    """
    Project gradient onto span(U) along the parameter's first (C_out) axis.

    param : torch.nn.Parameter, first dimension C_out
    U_np  : numpy [C_out, r], expected to have orthonormal columns

    The hook reshapes the gradient to [C_out, D], replaces it with U U^T g and
    restores the original shape; autograd then accumulates that into
    param.grad.

    Returns the RemovableHandle from param.register_hook (call .remove() to
    lift the constraint). Raises ValueError if U is not [param.shape[0], r].
    """
    # Match the parameter's dtype/device (a fixed float32 U broke float64/half params).
    U = torch.tensor(U_np, dtype=param.dtype, device=param.device)
    U.requires_grad_(False)
    if U.ndim != 2 or U.shape[0] != param.shape[0]:
        raise ValueError(
            f"U must be [{param.shape[0]}, r] for a parameter of shape "
            f"{tuple(param.shape)}; got {tuple(U.shape)}"
        )

    def hook(grad):
        g = grad.reshape(grad.shape[0], -1)      # [C_out, D]; reshape tolerates non-contiguous grads
        g_proj = U @ (U.T @ g)                   # [C_out, D]
        return g_proj.reshape(grad.shape)

    return param.register_hook(hook)


def apply_subspace_constraints(model, conv_layers, kind, r=None, seed=0):
    """
    Apply gradient projection to conv2 weights for selected layers.

    conv_layers : dict {lname: (layer, idx, "conv2")}
    kind        : subspace name for get_soma_subspace(conv, kind, r, seed=seed),
                  OR a dict {lname: U} of precomputed [C_out, r] bases (the
                  Tier-3.1 convention apply_subspace_constraints(model,
                  conv_layers, subspaces)); with a dict only the listed layers
                  are constrained and r / seed are ignored. Accepting it keeps
                  the Tier-3.1 cell runnable after this redefinition.
    r           : subspace rank; required when kind is a string
    seed        : seed for kind="random"

    Raises TypeError if kind is a string and r is not given.
    """
    def _locate(layer, idx, conv_name):
        return getattr(getattr(model.backbone, layer)[idx], conv_name)

    if isinstance(kind, dict):
        for lname, U in kind.items():
            layer, idx, conv_name = conv_layers[lname]
            register_gradient_projection(_locate(layer, idx, conv_name).weight, U)
        return

    if r is None:
        raise TypeError(
            "apply_subspace_constraints() missing required argument 'r' "
            "when kind is a subspace name"
        )

    for lname, (layer, idx, conv_name) in conv_layers.items():
        conv = _locate(layer, idx, conv_name)
        U = get_soma_subspace(conv, kind, r, seed=seed)
        register_gradient_projection(conv.weight, U)


In [None]:
@torch.no_grad()
def eval_accuracy(model, loader):
    """
    Fraction of samples in `loader` whose argmax prediction equals the label.
    Leaves `model` in eval mode.
    """
    model.eval()
    tallies = [0, 0]  # [correct, total]
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        tallies[0] += (model(x).argmax(1) == y).sum().item()
        tallies[1] += y.size(0)
    correct, total = tallies
    return correct / total


def train_with_logging(model, train_loaders, epochs=3, lr=1e-4, optimizer_cls=torch.optim.Adam):
    """
    Train with cross-entropy over the source loaders and log optimization stats.

    model         : nn.Module producing logits; gradient-projection hooks (if
                    any) must already be registered on its parameters
    train_loaders : iterable of loaders yielding (x, y), consumed one after
                    another within each epoch
    epochs, lr    : number of passes / learning rate
    optimizer_cls : optimizer class called as optimizer_cls(params, lr=lr);
                    default Adam keeps the original behaviour. With Adam the
                    per-coordinate RMS scaling means updates are not guaranteed
                    to stay in the projected subspace; SGD keeps them there.

    BatchNorm layers that are in eval mode on entry (apply_norm_config's
    "bn_frozen" / "bn_affine_only") are kept in eval mode while training.
    Previously the blanket model.train() flipped them back to train mode, so
    their running statistics kept updating and the BN control was void.

    Returns dict with:
      epoch_loss  : per-epoch sample-weighted mean loss
      grad_norm   : per-epoch mean global L2 norm of the (post-hook) gradients
      update_norm : [global L2 norm of final - initial trainable parameters]
    """
    optimizer = optimizer_cls(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    logs = {
        "epoch_loss": [],
        "grad_norm": [],
        "update_norm": [],
    }

    # snapshot params for update norm
    with torch.no_grad():
        init_params = {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}

    # model.train() recurses into every submodule; remember which BN layers were
    # deliberately frozen (eval mode) and restore them right after.
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    frozen_bn = [m for m in model.modules() if isinstance(m, bn_types) and not m.training]
    model.train()
    for m in frozen_bn:
        m.eval()

    for ep in range(epochs):
        ep_loss_sum, ep_n = 0.0, 0
        grad_norm_sum, grad_n = 0.0, 0

        for loader in train_loaders:
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                logits = model(x)
                loss = criterion(logits, y)
                loss.backward()

                # grad norm (after projection, since hooks already applied);
                # one host sync per step instead of one per parameter.
                with torch.no_grad():
                    per_param = [p.grad.detach().float().norm() for p in model.parameters() if p.grad is not None]
                    if per_param:
                        grad_norm_sum += torch.stack(per_param).double().norm().item()
                    grad_n += 1

                optimizer.step()

                ep_loss_sum += loss.item() * x.size(0)
                ep_n += x.size(0)

        logs["epoch_loss"].append(ep_loss_sum / max(ep_n, 1))
        logs["grad_norm"].append(grad_norm_sum / max(grad_n, 1))

    # update norm (final minus initial)
    with torch.no_grad():
        total_u = 0.0
        for n, p in model.named_parameters():
            if n in init_params:
                total_u += (p.detach() - init_params[n]).float().norm().item() ** 2
        logs["update_norm"].append(total_u ** 0.5)

    return logs


In [None]:
import pandas as pd

norm_cfgs = ["bn_default", "bn_affine_only", "bn_frozen", "groupnorm"]
subspaces = ["minor", "major", "middle", "random"]
ranks = [4, 8, 16]
seeds = [0, 1, 2]

# 4 norm configs x 3 ranks x 4 subspaces x 3 seeds = 144 training runs.
# Every finished run is checkpointed to Drive, so a Colab disconnect part-way
# through the grid does not lose the runs that already completed.
partial_csv_path = f"{PROJECT_ROOT}/tier3_bn_grid_partial.csv"

results = []

for norm_cfg in norm_cfgs:
    for r in ranks:
        for kind in subspaces:
            for seed in seeds:
                set_seed(seed)

                # 1) fresh baseline
                model = load_baseline_model()

                # 2) apply normalization configuration
                model = apply_norm_config(model, norm_cfg)

                # 3) apply subspace constraint (seed matters for random)
                apply_subspace_constraints(model, conv_layers, kind, r, seed=seed)

                # 4) train with logging
                logs = train_with_logging(model, train_loaders, epochs=3, lr=1e-4)

                # 5) eval
                acc = eval_accuracy(model, test_loader)

                epoch_loss = logs["epoch_loss"]
                row = {
                    "norm_cfg": norm_cfg,
                    "rank": r,
                    "subspace": kind,
                    "seed": seed,
                    "target_acc": acc,
                    "loss_ep0": epoch_loss[0] if len(epoch_loss) > 0 else None,
                    "loss_ep1": epoch_loss[1] if len(epoch_loss) > 1 else None,
                    "loss_ep2": epoch_loss[2] if len(epoch_loss) > 2 else None,
                    "gradnorm_mean": float(np.mean(logs["grad_norm"])) if len(logs["grad_norm"]) else None,
                    "update_norm": logs["update_norm"][0] if len(logs["update_norm"]) else None,
                }
                results.append(row)

                # checkpoint everything finished so far
                pd.DataFrame(results).to_csv(partial_csv_path, index=False)

                print(
                    f"[{norm_cfg}] r={r:2d} {kind:>6} seed={seed} "
                    f"acc={acc:.3f} loss={logs['epoch_loss'][-1]:.3f} "
                    f"gn={np.mean(logs['grad_norm']):.3f} upd={logs['update_norm'][0]:.3f}"
                )


[bn_default] r= 4  minor seed=0 acc=0.735 loss=0.005 gn=0.610 upd=4.324
[bn_default] r= 4  minor seed=1 acc=0.721 loss=0.007 gn=0.543 upd=4.498
[bn_default] r= 4  minor seed=2 acc=0.674 loss=0.006 gn=0.661 upd=4.460
[bn_default] r= 4  major seed=0 acc=0.731 loss=0.005 gn=0.654 upd=4.476
[bn_default] r= 4  major seed=1 acc=0.733 loss=0.007 gn=0.537 upd=4.586


In [None]:
import pandas as pd

df = pd.DataFrame(results)

# Aggregate mean ± std by condition (std is NaN for a single seed)
summary = (
    df.groupby(["norm_cfg", "rank", "subspace"])
      .agg(target_acc_mean=("target_acc", "mean"),
           target_acc_std=("target_acc", "std"),
           gradnorm_mean=("gradnorm_mean", "mean"),
           update_norm_mean=("update_norm", "mean"))
      .reset_index()
      .sort_values(["norm_cfg", "rank", "target_acc_mean"], ascending=[True, True, False])
)

display(summary)

# Save
csv_path = f"{PROJECT_ROOT}/tier3_bn_grid_results.csv"
summary_path = f"{PROJECT_ROOT}/tier3_bn_grid_summary.csv"
npz_path = f"{PROJECT_ROOT}/tier3_bn_grid_results.npz"

df.to_csv(csv_path, index=False)
summary.to_csv(summary_path, index=False)

# Raw results as one array per column with plain numeric / unicode dtypes, so
# np.load(npz_path) works WITHOUT allow_pickle=True. Storing each row as a
# dtype=object array (the previous layout) forces pickling on save and
# allow_pickle=True on load, which is not a safe reload.
npz_arrays = {}
for col in df.columns:
    values = df[col].to_numpy()
    if values.dtype == object:
        values = values.astype(str)
    npz_arrays[col] = values
np.savez(npz_path, **npz_arrays)

print("Saved:")
print(" -", csv_path)
print(" -", summary_path)
print(" -", npz_path)
