In [1]:
import math
import os
from pathlib import Path
from typing import Callable, Optional
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm
from wilds import get_dataset

from models import WaterbirdResNet18, SPDTwoLayerFC
from spd.hooks import HookedRootModule
from spd.log import logger
from spd.models.base import SPDModel
from spd.module_utils import (
    get_nested_module_attr,
    collect_nested_module_attrs,
)
from spd.types import Probability
from spd.utils import set_seed
from train_resnet import WaterbirdsSubset
from run_spd import WaterbirdSPDConfig

In [2]:
import einops


def calc_topk_mask(attribution_scores: torch.Tensor, topk: float, batch_topk: bool) -> torch.Tensor:
    """
    Compute a boolean mask for top-k subcomponents. Same logic as TMS:
      - if batch_topk=True, multiply topk by batch_size & sort globally
      - else take topk as an integer K
    """
    batch_size = attribution_scores.shape[0]
    # topk is float that might be fraction of batch_size
    # we handle rounding, etc.
    k = int(topk * batch_size) if batch_topk else int(topk)

    if batch_topk:
        # flatten "b c" -> " (b c)"
        # find topk globally
        shape_ = attribution_scores.shape
        reshaped = einops.rearrange(attribution_scores, "b ... c -> ... (b c)")
        topk_indices = torch.topk(reshaped, k, dim=-1).indices
        mask = torch.zeros_like(reshaped, dtype=torch.bool)
        mask.scatter_(dim=-1, index=topk_indices, value=True)
        # reshape back
        mask = einops.rearrange(mask, "... (b c) -> b ... c", b=batch_size)
        return mask
    else:
        # topk in the last dimension for each sample
        topk_indices = attribution_scores.topk(k, dim=-1).indices
        mask = torch.zeros_like(attribution_scores, dtype=torch.bool)
        mask.scatter_(dim=-1, index=topk_indices, value=True)
        return mask


def calc_lp_sparsity_loss(
    out: torch.Tensor,
    attributions: torch.Tensor,
    step_pnorm: float,
) -> torch.Tensor:
    """
    Lp-sparsity on subnetwork attributions. Similar to TMS code:
       - we first normalize by out.shape[-1]
       - then raise absolute value to step_pnorm*0.5
    """
    # e.g. divide attributions by out_dim
    d_model_out = out.shape[-1]
    scaled_attrib = attributions / d_model_out
    # do (abs(...) + 1e-16)**(0.5*step_pnorm)
    return (scaled_attrib.abs() + 1e-16) ** (0.5 * step_pnorm)


@torch.inference_mode()
def calc_activation_attributions(
    component_acts: dict[str, torch.Tensor],
) -> torch.Tensor:
    """
    Example "activation" approach: sum of L2 over the subcomponent_acts
    shape: (batch, C) or (batch, n_instances, C).
    """
    # Just sum up squares of each subcomponent
    # e.g. each entry in component_acts is (batch, C, d_out)
    # sum over d_out dimension
    first_key = next(iter(component_acts.keys()))
    out_shape = component_acts[first_key].shape[:-1]  # strip d_out
    attributions = torch.zeros(out_shape, device=component_acts[first_key].device)
    for val in component_acts.values():
        attributions += val.pow(2).sum(dim=-1)
    return attributions


def calc_grad_attributions(
    model_out: torch.Tensor,  # teacher or spd output
    post_weight_acts: dict[str, torch.Tensor],
    pre_weight_acts: dict[str, torch.Tensor],
    component_weights: dict[str, torch.Tensor],
    C: int,
) -> torch.Tensor:
    """
    Like TMS's gradient approach: for each output dimension,
    grad wrt post acts * subcomponent partial forward -> subcomponent attribution
    Summed across output dims and squared.
    """
    import torch.autograd as autograd

    # unify keys
    post_names = [k.removesuffix(".hook_post") for k in post_weight_acts.keys()]
    pre_names = [k.removesuffix(".hook_pre") for k in pre_weight_acts.keys()]
    comp_names = list(component_weights.keys())
    assert set(post_names) == set(pre_names) == set(comp_names), "layer name mismatch"

    batch_prefix = model_out.shape[:-1]  # e.g. (batch,) or (batch, n_inst)
    out_dim = model_out.shape[-1]
    attribution_scores = torch.zeros((*batch_prefix, C), device=model_out.device)

    # get subcomponent partial forward
    component_acts = {}
    for nm in pre_names:
        # shape pre: (batch..., d_in), comp_W: (C, d_in, d_out)
        # => (batch..., C, d_out)
        pre_ = pre_weight_acts[nm + ".hook_pre"].detach()
        w_ = component_weights[nm]
        partial = einops.einsum(
            pre_, w_, "... d_in, C d_in d_out -> ... C d_out"
        )
        component_acts[nm] = partial

    for feature_idx in range(out_dim):
        # sum up that scalar
        # grads = autograd.grad(
        #     model_out[..., feature_idx].sum(),
        #     list(post_weight_acts.values()),
        #     retain_graph=True,
        # )
        # grads = autograd.grad(
        #     model_out[..., feature_idx].sum(),
        #     list(post_weight_acts.values()),
        #     retain_graph=True,
        #     allow_unused=True  # Add this parameter
        # )
        # grads is tuple of same length as post_weight_acts
        # feature_attrib = torch.zeros((*batch_prefix, C), device=model_out.device)
        # for grad_val, nm_post in zip(grads, post_weight_acts.keys()):
        #     nm_clean = nm_post.removesuffix(".hook_post")
        #     feature_attrib += einops.einsum(
        #         grad_val, component_acts[nm_clean],
        #         "... d_out, ... C d_out -> ... C"
        #     )

        # In calc_grad_attributions, modify the autograd.grad call:
        grads = autograd.grad(
            model_out[..., feature_idx].sum(),
            list(post_weight_acts.values()),
            retain_graph=True,
        )
        feature_attrib = torch.zeros((*batch_prefix, C), device=model_out.device)
        # Then handle potential None values in grads:
        for grad_val, nm_post in zip(grads, post_weight_acts.keys()):
            if grad_val is not None:  # Skip None gradients
                nm_clean = nm_post.removesuffix(".hook_post")
                feature_attrib += einops.einsum(
                    grad_val, component_acts[nm_clean],
                    "... d_out, ... C d_out -> ... C"
                )
        # square then accumulate
        attribution_scores += feature_attrib**2

    return attribution_scores


def calculate_attributions(
    model: SPDModel,
    input_x: torch.Tensor,
    out: torch.Tensor,
    teacher_out: torch.Tensor,  # or target_out
    pre_acts: dict[str, torch.Tensor],
    post_acts: dict[str, torch.Tensor],
    component_acts: dict[str, torch.Tensor],  # might not be used if gradient type
    attribution_type: str = "gradient",
) -> torch.Tensor:
    """
    We unify "ablation", "gradient", or "activation" approach; for simplicity we do "gradient" or "activation".
    """
    if attribution_type == "activation":
        return calc_activation_attributions(component_acts)
    elif attribution_type == "gradient":
        # we call teacher_out or out as the "model_out"? It's up to you.
        # Usually we do grad wrt target_out
        # We'll do wrt teacher_out or spd_out. For TMS code, it's "target_out" or "distil_from_target".
        # let's do teacher_out for distillation
        # or do out if you want to see SPD's own gradient. We'll do teacher_out.
        # If teacher_out is [B, 2], we do .sum() => need to check
        return calc_grad_attributions(
            model_out=teacher_out,
            post_weight_acts={k: v for k, v in post_acts.items()},
            pre_weight_acts={k: v for k, v in pre_acts.items()},
            component_weights=collect_nested_module_attrs(model, "component_weights", include_attr_name=False),
            C=model.C,
        )
    else:
        raise ValueError(f"Unsupported attribution type {attribution_type}")


def calc_recon_mse(pred: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """MSE across batch, shape e.g. (batch, #classes) => scalar"""
    return ((pred - ref) ** 2).mean(dim=-1).mean()


###############################
# The actual training script
###############################

from pydantic import BaseModel, PositiveInt, PositiveFloat, Field


class WaterbirdSPDConfig(BaseModel):
    # Basic
    seed: int = 0
    batch_size: PositiveInt = 32
    steps: PositiveInt = 500
    lr: float = 1e-3
    print_freq: int = 50
    save_freq: Optional[int] = None
    out_dir: Optional[str] = None

    # Distillation
    distill_coeff: float = 1.0  # how strongly we do MSE with teacher
    # If you want param match
    param_match_coeff: float = 0.0

    # For subcomponent #0 background detection
    alpha_condition: float = 1.0

    # SPD subcomponent config
    C: PositiveInt = 40
    m_fc1: PositiveInt = 16
    m_fc2: PositiveInt = 16

    # LR schedule
    lr_schedule: str = "constant"  # or "linear", "cosine", "exponential"
    lr_exponential_halflife: float | None = None
    lr_warmup_pct: float = 0.0

    unit_norm_matrices: bool = False
    schatten_coeff: float | None = None
    schatten_pnorm: float | None = None
    # teacher ckpt
    teacher_ckpt: str = "waterbird_resnet18_best.pth"

    # topk config
    topk: float | None = None
    batch_topk: bool = True
    topk_recon_coeff: float | None = None
    distil_from_target: bool = True

    # lp sparsity
    lp_sparsity_coeff: float | None = None
    pnorm: float | None = None

    # attribution type
    attribution_type: str = "gradient"  # or "activation"

def set_As_and_Bs_to_unit_norm(spd_fc: torch.nn.Module):
    """
    In-place normalization so that each (A_c, B_c) has ||A_c||_F = 1 and/or ||B_c||_F = 1,
    preventing unbounded scale growth in factorized weights.
    """
    with torch.no_grad():
        # FC1
        A1 = spd_fc.fc1.A  # shape [C, d_in, m_fc1]
        B1 = spd_fc.fc1.B  # shape [C, m_fc1, hidden_dim]
        # Normalize each subcomponent's A and B
        for c in range(A1.shape[0]):
            normA = A1[c].norm(p=2)
            if normA > 1e-9:
                A1[c] /= normA
            normB = B1[c].norm(p=2)
            if normB > 1e-9:
                B1[c] /= normB

        # FC2 (same logic; if you only want to do it for one layer, remove these lines)
        A2 = spd_fc.fc2.A  # shape [C, hidden_dim, m_fc2]
        B2 = spd_fc.fc2.B  # shape [C, m_fc2, num_classes]
        for c in range(A2.shape[0]):
            normA = A2[c].norm(p=2)
            if normA > 1e-9:
                A2[c] /= normA
            normB = B2[c].norm(p=2)
            if normB > 1e-9:
                B2[c] /= normB


def fix_normalized_adam_gradients(spd_fc: torch.nn.Module):
    """
    Removes the gradient component corresponding to pure scale changes in each factor pair (A_c, B_c).
    In rank factorization, scaling one factor by alpha and dividing the other by alpha is a no-op.
    This prevents Adam from chasing that redundant direction.
    """
    with torch.no_grad():
        # For each layer in SPD, do an orthogonal projection of grads
        # that removes the direction that changes the norm scale of (A_c, B_c).

        def _proj_out_scale_direction(A, B, gradA, gradB):
            """
            If d d/d(alpha) [A*alpha, B/alpha] = 0 means no scale changes:
            We'll find the derivative direction that changes scale
            and zero it out from (gradA, gradB).
            """
            # Flatten
            A_vec = A.view(-1)
            B_vec = B.view(-1)
            gA_vec = gradA.view(-1)
            gB_vec = gradB.view(-1)

            # The "scale" direction is something like: d/d(alpha) [A*alpha, B/alpha],
            # which at alpha=1 is (A, -B).
            # So the direction vector is v = [A, -B].
            # We compute the component of (gA, gB) in that direction and remove it.
            v = torch.cat([A_vec, -B_vec], dim=0)
            v_norm_sq = v.dot(v) + 1e-12

            g = torch.cat([gA_vec, gB_vec], dim=0)
            scale_coeff = g.dot(v) / v_norm_sq  # how much of g is in direction of v

            # new_g = g - scale_coeff * v
            new_g = g - scale_coeff * v

            # put back
            new_gA = new_g[:A_vec.shape[0]].view_as(gradA)
            new_gB = new_g[A_vec.shape[0]:].view_as(gradB)

            gradA.copy_(new_gA)
            gradB.copy_(new_gB)

        # FC1
        A1 = spd_fc.fc1.A
        B1 = spd_fc.fc1.B
        if A1.requires_grad and B1.requires_grad:
            for c in range(A1.shape[0]):
                if A1.grad is not None and B1.grad is not None:
                    gradA1 = A1.grad[c]
                    gradB1 = B1.grad[c]
                    _proj_out_scale_direction(A1[c], B1[c], gradA1, gradB1)

        # FC2
        A2 = spd_fc.fc2.A
        B2 = spd_fc.fc2.B
        if A2.requires_grad and B2.requires_grad:
            for c in range(A2.shape[0]):
                if A2.grad is not None and B2.grad is not None:
                    gradA2 = A2.grad[c]
                    gradB2 = B2.grad[c]
                    _proj_out_scale_direction(A2[c], B2[c], gradA2, gradB2)


def calc_schatten_loss(
    As: dict[str, torch.Tensor],
    Bs: dict[str, torch.Tensor],
    mask: torch.Tensor,
    p: float,
    n_params: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Approximate rank penalty: sum_{c} ||P_c||_p^p, where P_c = A_c * B_c^T,
    implemented in a factorized form. 
    We multiply by a topk or attributions-based mask if desired (like in TMS).
    """
    # We assume all layers in As and Bs have the same shape for dimension 0 (the "C" dimension).
    # If you use topk_mask, shape should be (batch, C). We'll average over batch below.

    assert As.keys() == Bs.keys(), "As and Bs must have identical keys"
    batch_size = mask.shape[0]

    schatten_penalty = torch.zeros((), device=device)  # scalar
    for layer_name in As.keys():
        # A: shape [C, d_in, m], B: shape [C, m, d_out]
        # or possibly [batch, C, ...] if you handle multi-instance
        A = As[layer_name]
        B = Bs[layer_name]

        # S_A = sum_{i,j} A^2 over d_in
        # S_B = sum_{k,l} B^2 over d_out
        # We'll do something like TMS: 
        #    S_A = einops.einsum(A, A, "... C d_in m, ... C d_in m -> ... C m")
        #    S_B = einops.einsum(B, B, "... C m d_out, ... C m d_out -> ... C m")
        # Then multiply S_AB = S_A * S_B, apply the topk mask, etc.

        S_A = einops.einsum(A, A, "... C d_in m, ... C d_in m -> ... C m")
        S_B = einops.einsum(B, B, "... C m d_out, ... C m d_out -> ... C m")
        S_AB = S_A * S_B  # shape [batch..., C, m] or [C, m] if no batch dimension

        # Now apply mask along the "C" dimension
        # shape of mask is [batch, C]. We broadcast over the "m" dimension.
        # We'll do an einsum:
        # S_AB_topk = einops.einsum(S_AB, mask, "... C m, batch C -> batch ... C m")
        # Then sum and do the p/2 exponent.
        if S_AB.ndim == 2:
            # no extra batch dimension on the parameter side => broadcast
            # reshape so S_AB => (1, C, m) for broadcasting
            S_AB = S_AB.unsqueeze(0)  # shape (1, C, m)

        S_AB_topk = einops.einsum(S_AB, mask, "b C m, b C -> b C m")
        # Now apply ( +1e-16 )^(0.5 * p)
        schatten_penalty += ((S_AB_topk + 1e-16) ** (0.5 * p)).sum()

    # normalizations
    # 1) divide by number of parameters n_params
    # 2) divide by batch_size
    schatten_penalty = schatten_penalty / (n_params * batch_size)
    return schatten_penalty


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
config = WaterbirdSPDConfig(
    batch_size=32,
    steps=500,
    lr=1e-3,
    print_freq=50,
    save_freq=200,
    out_dir="waterbird_spd_out",
    seed=0,
    distill_coeff=1.0,
    param_match_coeff=0.0,
    alpha_condition=1.0,
    C=40,
    m_fc1=16,
    m_fc2=16,
    lp_sparsity_coeff=0.01,
    pnorm=2.0,
    topk=2.0,
    batch_topk=True,
    topk_recon_coeff=0.1,
    teacher_ckpt="checkpoints/waterbird_resnet18_best.pth",
    attribution_type="gradient",
    lr_schedule="constant",
)

In [4]:
ckpt = torch.load(config.teacher_ckpt, map_location="cpu")
if "model_state_dict" in ckpt:
    state_dict = ckpt["model_state_dict"]
else:
    state_dict = ckpt 


# 1) load teacher
teacher_model = WaterbirdResNet18(num_classes=2, hidden_dim=512)

missing, unexpected = teacher_model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

  ckpt = torch.load(config.teacher_ckpt, map_location="cpu")


Missing keys: []
Unexpected keys: []


In [5]:
spd_fc = SPDTwoLayerFC(
    in_features=512,
    hidden_dim=512,
    num_classes=2,
    C=config.C,
    m_fc1=config.m_fc1,
    m_fc2=config.m_fc2,
).to(device)

In [6]:
def count_factor_params():
    count = 0
    for pname, pval in spd_fc.named_parameters():
        if "fc1" in pname or "fc2" in pname:  # e.g. only final SPD layers
            count += pval.numel()
    return count
n_params_spd = count_factor_params()

# param names if we do param match
param_names = ["fc1", "fc2"]

waterbird_dataset = get_dataset(dataset="waterbirds", download=False)

all_indices = np.arange(len(waterbird_dataset))
np.random.shuffle(all_indices)

train_size = 2000 
val_size = 1000 

train_indices = all_indices[:train_size].tolist()

val_indices = all_indices[train_size:train_size+val_size].tolist()

train_transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),  # Adding data augmentation to reduce overfitting
    T.ToTensor()
])

val_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor()
])

# Create the training subset
train_subset = WaterbirdsSubset(
    waterbird_dataset, 
    indices=train_indices,
    transform=train_transform
)
loader = DataLoader(train_subset, batch_size=config.batch_size, shuffle=True)

In [7]:
teacher_model.to(device)
teacher_model.eval()

trunk = teacher_model.features
teacher_fc1 = teacher_model.fc1
teacher_fc2 = teacher_model.fc2


# 4) optimizer
opt = optim.AdamW(spd_fc.parameters(), lr=config.lr, weight_decay=0.0)
from spd.run_spd import get_lr_schedule_fn, get_lr_with_warmup
lr_sched_fn = get_lr_schedule_fn(config.lr_schedule, config.lr_exponential_halflife)

mse = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()

# 5) training loop
steps_per_epoch = max(1, len(train_indices)//config.batch_size)
data_iter = iter(loader)
epoch = 0

from math import ceil

In [8]:
step = 1 
step_lr = get_lr_with_warmup(
    step=step,
    steps=config.steps,
    lr=config.lr,
    lr_schedule_fn=lr_sched_fn,
    lr_warmup_pct=config.lr_warmup_pct,
)
for g in opt.param_groups:
    g["lr"] = step_lr

# fetch batch
try:
    batch_data = next(data_iter)
except StopIteration:
    epoch += 1
    data_iter = iter(loader)
    batch_data = next(data_iter)

imgs, bird_label, meta = batch_data
imgs = imgs.to(device)
bird_label = bird_label.to(device)
background_label = meta.float().to(device)  # 0 or 1

opt.zero_grad(set_to_none=True)

if config.unit_norm_matrices:
    set_As_and_Bs_to_unit_norm(spd_fc)

# Step 1: Extract features from trunk (no gradients needed for this part)
with torch.no_grad():
    feats = trunk(imgs)          # [B,512,1,1]
    feats = feats.flatten(1)     # [B,512]

# Step 2: Create tensor with gradients for teacher forward pass
feats_with_grad = feats.detach().clone().requires_grad_(True)

# Step 3: Teacher forward pass with manual activation caching
teacher_cache = {}
# Cache pre-activation for fc1
teacher_cache["fc1.hook_pre"] = feats_with_grad

# Forward through fc1
teacher_h_pre = teacher_fc1(feats_with_grad)
teacher_cache["fc1.hook_post"] = teacher_h_pre

# Apply ReLU
teacher_h = torch.relu(teacher_h_pre)
teacher_cache["fc2.hook_pre"] = teacher_h

# Forward through fc2
teacher_out = teacher_fc2(teacher_h)
teacher_cache["fc2.hook_post"] = teacher_out

# Step 4: SPD forward pass with caching hooks
spd_fc.reset_hooks()
cache_dict, fwd, bwd = spd_fc.get_caching_hooks()
with spd_fc.hooks(fwd_hooks=fwd, bwd_hooks=[], reset_hooks_end=True):
    spd_h_pre = spd_fc.fc1(feats)  # Use the original feats (no grad needed)
    spd_h = torch.relu(spd_h_pre)
    spd_out = spd_fc.fc2(spd_h)

# Step 5: Gather SPD activations from the cache
pre_weight_acts = {}
post_weight_acts = {}
comp_acts = {}
for k, v in cache_dict.items():
    if k.endswith("hook_pre"):
        pre_weight_acts[k] = v
    elif k.endswith("hook_post"):
        post_weight_acts[k] = v
    elif k.endswith("hook_component_acts"):
        comp_acts[k] = v

# Step 6: Split teacher activations into pre and post
teacher_pre_acts = {k: v for k, v in teacher_cache.items() if k.endswith("hook_pre")}
teacher_post_acts = {k: v for k, v in teacher_cache.items() if k.endswith("hook_post")}

# Step 7: Calculate attributions
attributions = calculate_attributions(
    model=spd_fc,
    input_x=feats,
    out=spd_out,
    teacher_out=teacher_out if config.distil_from_target else spd_out,
    pre_acts=teacher_pre_acts,
    post_acts=teacher_post_acts,
    component_acts=comp_acts,
    attribution_type=config.attribution_type,
)

In [14]:
target_params = {}
spd_params = {}
for param_name in param_names:
    target_params[param_name] = get_nested_module_attr(teacher_model, param_name + ".weight")
    spd_params[param_name] = get_nested_module_attr(spd_fc, param_name + ".weight")

In [15]:
spd_params

{'fc1': tensor([[-0.0014,  0.0033, -0.0019,  ..., -0.0006,  0.0006,  0.0003],
         [ 0.0022, -0.0036, -0.0025,  ..., -0.0016,  0.0033,  0.0069],
         [-0.0017, -0.0018, -0.0032,  ...,  0.0010, -0.0016,  0.0052],
         ...,
         [-0.0009, -0.0010, -0.0004,  ...,  0.0034,  0.0026, -0.0001],
         [-0.0026, -0.0017,  0.0044,  ..., -0.0067, -0.0007,  0.0039],
         [ 0.0027,  0.0042,  0.0021,  ...,  0.0015, -0.0005, -0.0066]],
        device='cuda:0', grad_fn=<ViewBackward0>),
 'fc2': tensor([[-0.0118, -0.0112],
         [ 0.0582, -0.0097],
         [ 0.1283,  0.0153],
         ...,
         [-0.0542, -0.0239],
         [ 0.0123,  0.1034],
         [ 0.0827,  0.0289]], device='cuda:0', grad_fn=<ViewBackward0>)}

In [21]:
spd_params['fc2'].shape

torch.Size([512, 2])

In [18]:
target_params['fc2'].shape

torch.Size([2, 512])