In [1]:
import math
import os
from pathlib import Path
from typing import Callable, Optional
import numpy as np
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm
from wilds import get_dataset

from models import ColorMNISTConvNetGAP, SPDTwoLayerFC
from spd.run_spd import get_lr_schedule_fn, get_lr_with_warmup
from spd.hooks import HookedRootModule
from spd.log import logger
from spd.models.base import SPDModel
from spd.module_utils import (
    get_nested_module_attr,
    collect_nested_module_attrs,
)
from spd.types import Probability
from spd.utils import set_seed
from train_mnist import SpuriousMNIST
from delta_attr_run_spd import make_cf
from delta_attr_run_spd import SPDConfig
from delta_attr_run_spd import calculate_attributions
from delta_attr_run_spd import make_cf_digit_swap

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
config = SPDConfig(
        seed            = 0,
        out_dir         = "spd_models_new",
        batch_size      = 64,
        lr              = 3e-4,         
        lr_schedule      = "linear",
        steps            = 10000,     
        lr_warmup_pct   = 0.05,

        C               = 2,              
        m_fc1           = 320,            
        m_fc2           = 320,

        teacher_ckpt    = "checkpoints/best_model.pth",

        param_match_coeff           = 0.0,    
        relative_param_match_coeff  = 2.0,    
        distill_coeff               = 1.0,    

        alpha_condition     = 1.0,            
        cond_coeff          = 0.0,            

        topk                = None,              
        batch_topk          = False,
        topk_recon_coeff    = 0.0,            

        schatten_coeff      = 0.0,
        schatten_pnorm      = 0.9,
        unit_norm_matrices  = False,
        relative_schatten_coeff = 0.0,
        lp_sparsity_coeff   = 0.0,            
        pnorm               = None,
        lambda_r            = 1.00,
        mu_r                = 1.00,
        warmup_steps        = 1000,

        print_freq      = 50,
        save_freq       = 1000,
    )

In [3]:
mnist_model = ColorMNISTConvNetGAP()
mnist_model.load_state_dict(torch.load(config.teacher_ckpt))
mnist_model.eval()
mnist_model.to(device)

spd_fc = SPDTwoLayerFC(
    in_features=28*28,
    hidden_dim=128,
    num_classes=10,
    C=config.C,
    m_fc1=config.m_fc1,
    m_fc2=config.m_fc2
)

spd_fc.load_state_dict(torch.load('spd_models_routing_loss/waterbird_spd_best.pth'))
spd_fc.eval()
spd_fc.to(device)



  mnist_model.load_state_dict(torch.load(config.teacher_ckpt))
  spd_fc.load_state_dict(torch.load('spd_models_routing_loss/waterbird_spd_best.pth'))


SPDTwoLayerFC(
  (fc1): LinearComponent(
    (hook_pre): HookPoint()
    (hook_component_acts): HookPoint()
    (hook_post): HookPoint()
  )
  (fc2): LinearComponent(
    (hook_pre): HookPoint()
    (hook_component_acts): HookPoint()
    (hook_post): HookPoint()
  )
)

In [4]:
train_transform = T.Compose([
    T.ToTensor()
])

train_subset = SpuriousMNIST(root_dir="colorized-MNIST/training", transform=train_transform)

loader = DataLoader(train_subset, batch_size=config.batch_size, shuffle=True)

In [5]:
trunk = mnist_model.conv  # up to conv part
teacher_fc1 = mnist_model.fc1
teacher_fc2 = mnist_model.fc2

In [6]:
batch = next(iter(loader))

data_iter = iter(loader)
epoch = 0

# If you want param matching:
param_names = ["fc1","fc2"]
n_params = 0
for param_name in param_names:
    n_params += get_nested_module_attr(mnist_model, param_name + ".weight").numel()

imgs, digit_label, background_label = batch
imgs = imgs.to(device)
digit_label = digit_label.to(device)
background_label = background_label.to(device)
imgs_cf = make_cf(imgs)


with torch.no_grad():
    feats_cf = trunk(imgs_cf).mean(1).flatten(1) 

with torch.no_grad():
    feats = trunk(imgs)
    feats = feats.mean(dim=1)
    feats = feats.flatten(1)  # [B, 512]

#=========================
# (2) feats_with_grad
#=========================
feats_with_grad = feats.detach().clone().requires_grad_(True)

#=========================
# (3) teacher forward pass manually, storing in a teacher_cache
#=========================
feats_cf_grad = feats_cf.detach().clone().requires_grad_(True)
t_cache_cf = {}

t_cache_cf["fc1.hook_pre"]  = feats_cf_grad
h1_cf = teacher_fc1(feats_cf_grad)
t_cache_cf["fc1.hook_post"] = h1_cf
h_relu_cf = torch.relu(h1_cf)
t_cache_cf["fc2.hook_pre"]  = h_relu_cf
teacher_out_cf              = teacher_fc2(h_relu_cf)
t_cache_cf["fc2.hook_post"] = teacher_out_cf

teacher_cache = {}

teacher_cache["fc1.hook_pre"] = feats_with_grad
teacher_h_pre = teacher_fc1(feats_with_grad)
teacher_cache["fc1.hook_post"] = teacher_h_pre

teacher_h = torch.relu(teacher_h_pre)
teacher_cache["fc2.hook_pre"] = teacher_h

teacher_out = teacher_fc2(teacher_h)
teacher_cache["fc2.hook_post"] = teacher_out

#=========================
# (4) SPD forward pass with hooking
#=========================

spd_fc.reset_hooks()
cache_cf, fwd_hooks_cf, _ = spd_fc.get_caching_hooks()
with spd_fc.hooks(fwd_hooks_cf, [], reset_hooks_end=True):
    h1_spd_cf = spd_fc.fc1(feats_cf)
    h_spd_cf  = torch.relu(h1_spd_cf)
    spd_out_cf = spd_fc.fc2(h_spd_cf)
spd_fc.reset_hooks()
cache_dict, fwd_hooks, _ = spd_fc.get_caching_hooks()
with spd_fc.hooks(fwd_hooks, [], reset_hooks_end=True):
    spd_h_pre = spd_fc.fc1(feats)
    spd_h = torch.relu(spd_h_pre)
    spd_out = spd_fc.fc2(spd_h)

#=========================
# (5) gather SPD activations
#=========================
pre_cf  = {k:v for k,v in cache_cf.items() if k.endswith("hook_pre")}
post_cf = {k:v for k,v in cache_cf.items() if k.endswith("hook_post")}
comp_cf = {k.removesuffix(".hook_component_acts"):v
        for k,v in cache_cf.items() if k.endswith("hook_component_acts")}
pre_weight_acts = {}
post_weight_acts = {}
comp_acts = {}
for k,v in cache_dict.items():
    if k.endswith("hook_pre"):
        pre_weight_acts[k] = v
    elif k.endswith("hook_post"):
        post_weight_acts[k] = v
    elif k.endswith("hook_component_acts"):
        comp_acts[k.removesuffix(".hook_component_acts")] = v  # e.g. "fc1", "fc2"

#=========================
# (6) teacher pre/post from teacher_cache
#=========================
teacher_pre_acts = {k:v for k,v in teacher_cache.items() if k.endswith("hook_pre")}
teacher_post_acts= {k:v for k,v in teacher_cache.items() if k.endswith("hook_post")}

#=========================
# (7) calculate attributions
#=========================
attrib_cf = calculate_attributions(
    model          = spd_fc,
    input_x        = feats_cf,
    out            = spd_out_cf,
    teacher_out    = teacher_out_cf if config.distil_from_target else spd_out_cf,
    pre_acts       = {k:v for k,v in t_cache_cf.items() if k.endswith("hook_pre")},
    post_acts      = {k:v for k,v in t_cache_cf.items() if k.endswith("hook_post")},
    component_acts = comp_cf,
    attribution_type = config.attribution_type,
)

attributions = calculate_attributions(
    model=spd_fc,
    input_x=feats,
    out=spd_out,
    teacher_out=teacher_out if getattr(config,"distil_from_target",True) else spd_out,
    pre_acts=teacher_pre_acts,
    post_acts=teacher_post_acts,
    component_acts=comp_acts,
    attribution_type=config.attribution_type
)

delta_attrib = attributions - attrib_cf 


In [7]:
'''
When ablating component 0, this should show that the f(x | theta - a_0) = f( x'| theta - a_0)
then a_0(x) = a_0(x')
so |a_0(x) - a_0(x')| should be close to 0? 
'''
print(f'delta attrib for component 0: {delta_attrib[:, 0].abs().mean()}')
print(f'delta attrib for component 1: {delta_attrib[:, 1].abs().mean()}')


delta attrib for component 0: 8.545426368713379
delta attrib for component 1: 4.292652130126953


In [8]:
batch = next(iter(loader))
imgs, digit_label, background_label = batch
imgs = imgs.to(device)
digit_label = digit_label.to(device)
background_label = background_label.to(device)

imgs_cf = make_cf(imgs)
imgs_d_cf = make_cf_digit_swap(imgs, digit_label, background_label)

feats = trunk(imgs).mean(dim=1).flatten(1)
feats_cf = trunk(imgs_cf).mean(dim=1).flatten(1)
feats_d_cf = trunk(imgs_d_cf).mean(dim=1).flatten(1)

full_spd_weight_1 = spd_fc.fc1.A[0] @ spd_fc.fc1.B[0] +  spd_fc.fc1.A[1] @ spd_fc.fc1.B[1]
full_spd_weight_2 = spd_fc.fc2.A[0] @ spd_fc.fc2.B[0] +  spd_fc.fc2.A[1] @ spd_fc.fc2.B[1]

full_recon = feats @ full_spd_weight_1
full_recon = torch.relu(full_recon)
full_recon = full_recon @ full_spd_weight_2

full_recon_cf = feats_cf @ full_spd_weight_1
full_recon_cf = torch.relu(full_recon_cf)
full_recon_cf = full_recon_cf @ full_spd_weight_2


full_recon_d_cf = feats_d_cf @ full_spd_weight_1
full_recon_d_cf = torch.relu(full_recon_d_cf)
full_recon_d_cf = full_recon_d_cf @ full_spd_weight_2


out_recon = spd_fc(feats) # using this because the original paper code does this for some reason, and not use teacher model
out_recon_cf = spd_fc(feats_cf)
out_recon_d_cf = spd_fc(feats_d_cf)

spd_weight_1 = spd_fc.fc1.A[1] @ spd_fc.fc1.B[1]
spd_weight_2 = spd_fc.fc2.A[1] @ spd_fc.fc2.B[1]

out_spd_ablate = feats @ spd_weight_1 
out_spd_ablate = torch.relu(out_spd_ablate)
out_spd_ablate = out_spd_ablate @ spd_weight_2 

out_spd_ablate_cf = feats_cf @ spd_weight_1 
out_spd_ablate_cf = torch.relu(out_spd_ablate_cf)
out_spd_ablate_cf = out_spd_ablate_cf @ spd_weight_2 

out_spd_ablate_d_cf = feats_d_cf @ spd_weight_1 
out_spd_ablate_d_cf = torch.relu(out_spd_ablate_d_cf)
out_spd_ablate_d_cf = out_spd_ablate_d_cf @ spd_weight_2 

true_attrib_scores = ((out_recon - out_spd_ablate) ** 2).mean(dim=-1) # f(x | theta) - f(x | theta - P)
true_attrib_scores_cf = ((out_recon_cf - out_spd_ablate_cf)**2).mean(dim=-1) # f(x_cf | theta) - f(x_cf | theta - P)
true_attrib_scores_d_cf = ((out_recon_d_cf - out_spd_ablate_d_cf)**2).mean(dim=-1) # f(x_d_cf | theta) - f(x_d_cf | theta - P)

ablate_score_cf = out_spd_ablate - out_spd_ablate_cf # f(x | theta - P) - f(x_cf | theta - P), want this to be low for component 0 
ablate_score_d_cf = out_spd_ablate - out_spd_ablate_d_cf # f(x | theta - P) - f(x_d_cf | theta - P), want this to be high for component 0 

# true_attrib_scores - true_attrib_scores_cf = f(x | theta) - f(x | theta - P) - f(x_cf | theta) + f(x_cf | theta - P)
# = f(x | theta) - f(x_cf | theta) + f(x_cf | theta - P) - f(x | theta - P) 
# = f(x | theta) - f(x_cf | theta) - ablate_score_cf 

print(f'|f(x | theta) - f(x | theta - P) - f(x_cf | theta) + f(x_cf | theta - P)|: {(true_attrib_scores - true_attrib_scores_cf).abs().mean()}')
print(f'|f(x_cf | theta - P_0) - f(x | theta - P_0)|: {ablate_score_cf.abs().mean()}')

print(f'|f(x | theta) - f(x | theta - P) - f(x_d_cf | theta) + f(x_d_cf | theta - P)|: {(true_attrib_scores - true_attrib_scores_d_cf).abs().mean()}')
print(f'|f(x_cf | theta - P_0) - f(x | theta - P_0): {ablate_score_d_cf.abs().mean()}')

print(f'|f(x | theta) - f(x | theta - P_0): {(full_recon - full_recon_cf).abs().mean()}')
print(f'|f(x | theta) - f(x | theta - P_0): {(full_recon - full_recon_d_cf).abs().mean()}')



# print(f'Component 0 ablation score on background values (should be low I think): {((fin_ablate - fin_ablate_cf).abs()).mean()}')
# print(f'Component 0 ablation score on background values (should be high I think): {(fin_ablate - fin_ablate_d_cf).abs().mean()}')


|f(x | theta) - f(x | theta - P) - f(x_cf | theta) + f(x_cf | theta - P)|: 15.57197093963623
|f(x_cf | theta - P_0) - f(x | theta - P_0)|: 2.7106430530548096
|f(x | theta) - f(x | theta - P) - f(x_d_cf | theta) + f(x_d_cf | theta - P)|: 13.398056030273438
|f(x_cf | theta - P_0) - f(x | theta - P_0): 2.1955978870391846
|f(x | theta) - f(x | theta - P_0): 8.833394050598145
|f(x | theta) - f(x | theta - P_0): 5.780861854553223


In [9]:
batch = next(iter(loader))
imgs, digit_label, background_label = batch
imgs = imgs.to(device)
digit_label = digit_label.to(device)
background_label = background_label.to(device)

imgs_cf = make_cf(imgs)
imgs_d_cf = make_cf_digit_swap(imgs, digit_label, background_label)

feats = trunk(imgs).mean(dim=1).flatten(1)
feats_cf = trunk(imgs_cf).mean(dim=1).flatten(1)
feats_d_cf = trunk(imgs_d_cf).mean(dim=1).flatten(1)

out_recon = spd_fc(feats) # using this because the original paper code does this for some reason, and not use teacher model
out_recon_cf = spd_fc(feats_cf)
out_recon_d_cf = spd_fc(feats_d_cf)

spd_weight_1 = spd_fc.fc1.A[1] @ spd_fc.fc1.B[1]
spd_weight_2 = spd_fc.fc2.A[1] @ spd_fc.fc2.B[1]

out_spd_ablate = feats @ spd_weight_1 
out_spd_ablate = torch.relu(out_spd_ablate)
out_spd_ablate = out_spd_ablate @ spd_weight_2 

out_spd_ablate_cf = feats_cf @ spd_weight_1 
out_spd_ablate_cf = torch.relu(out_spd_ablate_cf)
out_spd_ablate_cf = out_spd_ablate_cf @ spd_weight_2 

out_spd_ablate_d_cf = feats_d_cf @ spd_weight_1 
out_spd_ablate_d_cf = torch.relu(out_spd_ablate_d_cf)
out_spd_ablate_d_cf = out_spd_ablate_d_cf @ spd_weight_2 

fin_ablate = ((out_recon - out_spd_ablate) ** 2).mean(dim=-1)
fin_ablate_cf = out_spd_ablate - out_spd_ablate_cf
fin_ablate_d_cf = (out_spd_ablate - out_spd_ablate_d_cf)

print(f'Component 0 ablation score on background values (should be low I think): {((fin_ablate_cf)**2).mean()}')
print(f'Component 0 ablation score on background values (should be high I think): {((fin_ablate_d_cf)**2).mean()}')


Component 0 ablation score on background values (should be low I think): 10.33073616027832
Component 0 ablation score on background values (should be high I think): 8.602184295654297


In [10]:
batch = next(iter(loader))
imgs, digit_label, background_label = batch
imgs = imgs.to(device)
digit_label = digit_label.to(device)
background_label = background_label.to(device)

imgs_cf = make_cf(imgs)
imgs_d_cf = make_cf_digit_swap(imgs, digit_label, background_label)

feats = trunk(imgs).mean(dim=1).flatten(1)
feats_cf = trunk(imgs_cf).mean(dim=1).flatten(1)
feats_d_cf = trunk(imgs_d_cf).mean(dim=1).flatten(1)

out_recon = spd_fc(feats) # using this because the original paper code does this for some reason, and not use teacher model
out_recon_cf = spd_fc(feats_cf)
out_recon_d_cf = spd_fc(feats_d_cf)

spd_weight_1 = spd_fc.fc1.A[0] @ spd_fc.fc1.B[0]
spd_weight_2 = spd_fc.fc2.A[0] @ spd_fc.fc2.B[0]

out_spd_ablate = feats @ spd_weight_1 
out_spd_ablate = torch.relu(out_spd_ablate)
out_spd_ablate = out_spd_ablate @ spd_weight_2 

out_spd_ablate_cf = feats_cf @ spd_weight_1 
out_spd_ablate_cf = torch.relu(out_spd_ablate_cf)
out_spd_ablate_cf = out_spd_ablate_cf @ spd_weight_2 

out_spd_ablate_d_cf = feats_d_cf @ spd_weight_1 
out_spd_ablate_d_cf = torch.relu(out_spd_ablate_d_cf)
out_spd_ablate_d_cf = out_spd_ablate_d_cf @ spd_weight_2 

fin_ablate = ((out_recon - out_spd_ablate) ** 2).mean(dim=-1)
fin_ablate_cf = (out_spd_ablate - out_spd_ablate_cf)
fin_ablate_d_cf = (out_spd_ablate - out_spd_ablate_d_cf)

# out_recon = f(x | theta)
# out_spd_ablate = f(x | theta - P)
# f(x | theta - P ) - f(x' | theta  - P)
# out_spd_ablate - out_spd_ablate_cf 



print(f'Component 1 ablation score on background values (should be high I think): {((fin_ablate_cf)**2).mean()}')
print(f'Component 1 ablation score on background values (should be low I think): {((fin_ablate_d_cf)**2).mean()}')


Component 1 ablation score on background values (should be high I think): 4.569861888885498
Component 1 ablation score on background values (should be low I think): 1.8238804340362549
