In [1]:
from utils import *
# from objectnet_test import *
from torch.nn.functional import relu
seed = 1
import matplotlib.pyplot as plt
set_seed(seed)
torch.set_default_dtype(torch.float32)
from collections import defaultdict
from sal_resnet import resnet50
import pytorch_lightning as pl
from datamodules import WaterbirdsDataModule, ImageNet9DataModule
from lightning_modules import *
from torchvision.datasets import CIFAR100, ImageNet
from torchvision import transforms
from inspect_utils import *

In [2]:
# Scale loader workers and batch capacity with the number of visible GPUs
# (both evaluate to 0 when no CUDA device is present).
num_workers = torch.cuda.device_count() * 4
gpu_size = torch.cuda.device_count() * 256
# Prefer CUDA when available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Build the ImageNet-9 data module (with get_masks=True) and create validation
# loaders for three of its datasets.
# NOTE(review): presumably onlyfg = foreground-only images, nofg = images with the
# foreground removed, orig = unmodified images — confirm against datamodules.py.
data_module = ImageNet9DataModule(get_masks=True)
data_module.setup()
# gpu_size is 0 without a GPU, so this floors the batch size at 100.
batch_size = max(gpu_size, 100)
# shuffle_val=False keeps the validation order fixed across the three loaders.
fg_loader = data_module.onlyfg_dataset.make_loaders(num_workers, batch_size, mode='val', shuffle_val=False)
bg_loader = data_module.nofg_dataset.make_loaders(num_workers, batch_size, mode='val', shuffle_val=False)
orig_loader = data_module.orig_dataset.make_loaders(num_workers,batch_size, mode='val', shuffle_val=False)

==> Preparing dataset ImageNet9..
==> Preparing dataset ImageNet9..
==> Preparing dataset ImageNet9..


In [4]:
# Global store for captured module outputs, keyed by hook name.
activation = {}

def read_output(inp, out, name):
    """Store a detached CPU copy of a module's output in `activation[name]`.

    inp:  the module's input (unused; kept so the signature matches the hook reader protocol).
    out:  the module's output; when it is a tuple only its first element is stored.
    name: key under which the output is recorded.
    """
    captured = out[0] if isinstance(out, tuple) else out
    activation[name] = captured.detach().cpu()

def get_hook(name, read_fn, modify_fn=None):
    """Build a forward hook that records (and optionally rewrites) a module's output.

    name:      label passed to the callbacks; the rewritten output is recorded under name + "_after".
    read_fn:   callable(input, output, name) invoked to record the output.
    modify_fn: optional callable(input, output, name) returning a replacement output.

    The returned hook yields None when no modify_fn is given, and the replacement
    output otherwise.
    """
    def hook(model, input, output):
        # Always record the untouched output first.
        read_fn(input, output, name)
        if modify_fn is None:
            return None
        replaced = modify_fn(input, output, name)
        read_fn(input, replaced, name + "_after")
        return replaced
    return hook

In [5]:
# Restore the trained predictor from its checkpoint, then pull out the wrapped
# network (moved to `device`) and the input normalizer it carries.
lightning_model = ImageNet9Predictor.load_from_checkpoint('./models/imagenet9/mtype-resnet50_num_classes-9_lr-1e-05_weight_decay-1e-05/epoch=01-step=1400-orig_val_acc=0.98.ckpt')
model = lightning_model.model.to(device)
normalizer = lightning_model.normalizer
# Attributes of `model` whose forward outputs are captured into `activation`
# under these same names.
hook_names = ['avgpool', 'layer4']
# One read-only forward hook per module; the returned handles are kept in `hooks`.
hooks = [getattr(model, hook_name).register_forward_hook(get_hook(hook_name, read_output)) for hook_name in hook_names]
# Assigning to `_` keeps the notebook from echoing the return value of eval().
_ = model.eval()
# for i, layer in enumerate(model.encoder.layers):
#     layer.register_forward_hook(get_hook(f"layer{i}", read_output))
# _=model.encoder.register_forward_hook(get_hook(f"encoder", read_output))

In [7]:
# Class labels, each prefixed with its class index.
label_names = [
    '0_dog',
    '1_bird',
    '2_wheeled vehicle',
    '3_reptile',
    '4_carnivore',
    '5_insect',
    '6_musical instrument',
    '7_primate',
    '8_fish',
]

In [23]:
# Run the model over the original-image loader and collect the hooked activations
# together with hit counts.
# NOTE(review): presumably the hit counts are per class for the original /
# foreground-only / background-only inputs — confirm in inspect_utils.get_intermediate_acts.
all_acts, orig_hits, fg_hits, bg_hits, total = get_intermediate_acts(model, normalizer, orig_loader, activation, label_names)

In [9]:
import pickle
# Writing side of the cache, kept for reference (run once to create the file):
# with open('/cmlscratch/sriramb/cnn_activations.pkl', 'wb') as fp:
#     pickle.dump((dict(all_acts), orig_hits, fg_hits, bg_hits, total), fp)
# Reload the cached activations and hit counts; this overwrites any values of the
# same names already in the kernel.
# NOTE(review): the path is absolute and machine-specific, and pickle.load can run
# arbitrary code from the file — only load caches you produced yourself.
with open('/cmlscratch/sriramb/cnn_activations.pkl', 'rb') as fp:
    all_acts, orig_hits, fg_hits, bg_hits, total = pickle.load(fp)

In [16]:
# Convert the hit counters to NumPy arrays so they support elementwise arithmetic.
orig_hits, fg_hits, bg_hits, total = map(np.array, (orig_hits, fg_hits, bg_hits, total))

In [18]:
# Elementwise difference between the fg hit rate and the bg hit rate
# (each count divided by `total`); displayed as the cell output.
fg_hits/total - bg_hits/total

array([0.87333333, 0.57555556, 0.18222222, 0.50222222, 0.71111111,
       0.01555556, 0.01333333, 0.80888889, 0.13333333])

In [153]:
def proj_on_subspace(x, Y):
    """Return the products of the unit-normalized rows of Y with unit-normalized x.

    Both arguments are scaled to unit L2 norm along their last dimension before
    the matrix product `Y @ x` is taken.
    """
    unit_rows = Y / Y.norm(dim=-1, keepdim=True)
    unit_x = x / x.norm(dim=-1, keepdim=True)
    return unit_rows @ unit_x

def gram_schmidt(X):
    '''
    Orthonormalize the columns of X with a reduced QR decomposition.

    X: (D, N)
    Q: (D, min(D, N)), orthonormal columns
    '''
    Q, _ = torch.linalg.qr(X)
    return Q

def subspace_intersection(X_core, X_spur):
    '''
    Split the columns of X_spur into the part lying in span(X_core) and the rest.

    X_core: (D, N_core) columns spanning the core subspace
    X_spur: (D, N_spur) columns spanning the spurious subspace

    Returns a tuple (X_core, X_spur_on_core, X_spur_not_core):
      X_core:          (D, min(D, N_core)) orthonormal basis of the core subspace
      X_spur_on_core:  (D, N_spur) projection of each spurious column onto the core subspace
      X_spur_not_core: (D, min(D, N_spur)) orthonormalized residual of X_spur after
                       removing its core component
    '''
    X_core = gram_schmidt(X_core)
    # With orthonormal columns Q, the projector onto span(Q) is Q Q^T (Q^T Q is
    # just the identity). Apply it as Q (Q^T X) so the D x D matrix is never built.
    X_spur_on_core = X_core @ (X_core.T @ X_spur)
    X_spur_not_core = X_spur - X_spur_on_core
    # NOTE(review): if the residual is rank-deficient (a spurious column lies
    # entirely inside the core subspace), QR still returns min(D, N_spur) columns
    # and the surplus ones are not guaranteed to be orthogonal to the core.
    X_spur_not_core = gram_schmidt(X_spur_not_core)
    return X_core, X_spur_on_core, X_spur_not_core


In [156]:
topk = 3
# Final-layer weights, edited in place below. `.data` shares storage with the
# model parameter, so this changes the model itself. Defined here so the cell
# does not rely on a later cell having been run first.
W = model.fc.weight.data
for i in range(9):
    # NOTE(review): assumes all_acts[i][:, 2] / all_acts[i][:, 1] are the
    # foreground-only / background-only activations of class i, shaped
    # (n_samples, D) after squeeze — TODO confirm against the structure of all_acts.
    fg_acts = all_acts[i][:,2].squeeze()
    bg_acts = all_acts[i][:,1].squeeze()
    fg_means = fg_acts.mean(dim=0)
    bg_means = bg_acts.mean(dim=0)
    # torch.lobpcg returns (eigenvalues, eigenvectors) in that order, with the
    # eigenvectors stored as the columns of a (D, k) matrix.
    fg_eigvals, fg_eigvecs = torch.lobpcg(torch.cov(fg_acts.transpose(0,1)), k=20)
    bg_eigvals, bg_eigvecs = torch.lobpcg(torch.cov(bg_acts.transpose(0,1)), k=20)
    # Stack the mean and the first `topk` eigenvectors as columns -> (D, topk + 1),
    # the (D, N) layout that subspace_intersection / gram_schmidt work with.
    fg_bases = torch.cat((fg_means[:, None], fg_eigvecs[:, :topk]), dim=1)
    bg_bases = torch.cat((bg_means[:, None], bg_eigvecs[:, :topk]), dim=1)
    X_core, X_spur_on_core, X_spur_not_core = subspace_intersection(fg_bases, bg_bases)
    # Stored activations live on the CPU while the model may be on the GPU.
    X_spur_not_core = X_spur_not_core.to(W.device)
    # Remove from the class-i weight vector its component along the
    # background-only directions.
    W[i] = W[i] - (W[i] @ X_spur_not_core) @ X_spur_not_core.transpose(0,1)

tensor([[1., 0., 0.],
        [-0., 1., 0.],
        [-0., -0., 1.]])

In [41]:
def norm(x):
    """Center x over dim 0 (only when it has more than one row) and scale to unit L2 norm.

    The norm is taken along the last dimension; 1e-8 is added to the denominator
    so all-zero rows do not divide by zero.
    """
    centered = x - x.mean(dim=0, keepdim=True) if len(x) > 1 else x
    lengths = centered.norm(dim=-1, keepdim=True)
    return centered / (lengths + 1e-8)

def formt(m, s):
    """Format a mean `m` and spread `s` as the range 'm-s to m+s' (4 decimals).

    When `s` is NaN only `m` is rendered, using its default formatting.
    """
    if not s.isnan():
        low, high = (m - s).item(), (m + s).item()
        return f'{low:.4f} to {high:.4f}'
    return f'{m}'

In [47]:
# Final-layer weights and the activations captured at the 'avgpool' hook.
W = model.fc.weight.data
acts = all_acts['avgpool']

# For each of the 9 classes, print the spread of the per-sample mean activation.
# NOTE(review): judging by the commented-out report that follows this loop,
# index 0 / 2 / 1 along dim 1 are the orig / fg / bg variants — confirm
# against get_intermediate_acts.
for i in range(9):
    # Unit-norm class weight vector; only used by the commented-out report.
    w = W[i]/W[i].norm()
    o = acts[i][:,0].squeeze()#.mean(0, keepdim=True)#/(all_acts[i][:,0].squeeze().norm(dim=-1, keepdim=True) + 1e-8)
    a = acts[i][:,2].squeeze()#.mean(0, keepdim=True)#/(all_acts[i][:,2].squeeze().norm(dim=-1, keepdim=True) + 1e-8)
    b = acts[i][:,1].squeeze()#.mean(0, keepdim=True)#/(all_acts[i][:,1].squeeze().norm(dim=-1, keepdim=True) + 1e-8)
#     rand_inds = torch.randperm(a.shape[0])
#     a = a[rand_inds]
    # Keep only rows whose first entry is non-NaN in both a and b, and apply the
    # same row mask to all three tensors so they stay aligned.
    m = (~b.isnan()[:,0]) & (~a.isnan()[:,0])
    a = a[m]
    b = b[m]
    o = o[m]
    # Std across rows of each row's mean over dim 1, for a and then b.
    print(f'''
        {a.mean(1).std()} {b.mean(1).std()}
    '''
    )
#     print(f'''
#         Class {label_names[i]}
#           w.orig: {formt((w*o).sum(dim=-1).mean(), (w*o).sum(dim=-1).std())}, acc: {orig_hits[i]/450}
#           w.fg: {formt((w*a).sum(dim=-1).mean(), (w*a).sum(dim=-1).std())}, acc: {fg_hits[i]/450}
#           w.bg: {formt((w*b).sum(dim=-1).mean(), (w*b).sum(dim=-1).std())}, acc: {bg_hits[i]/450}
#           fg.bg: {formt((norm(a)*norm(b)).sum(dim=-1).mean(), (norm(a)*norm(b)).sum(dim=-1).std())} 
#           '''
#          )



        0.09228075295686722 0.07021445780992508
    

        0.09317027032375336 0.07988784462213516
    

        0.10053576529026031 0.08723432570695877
    

        0.11255035549402237 0.09857594221830368
    

        0.10635022819042206 0.0765133649110794
    

        0.09958212822675705 0.0812973827123642
    

        0.12768259644508362 0.11645633727312088
    

        0.10658541321754456 0.07435311377048492
    

        0.10775458812713623 0.09488870203495026
    


torch.Size([450, 2048])

In [145]:
# 20 eigenpairs of the class-0 activation covariance, once for index 2 and once
# for index 1 along dim 1 (named fg and bg respectively).
fg_eigvals, fg_eigvecs = torch.lobpcg(
    torch.cov(all_acts[0].squeeze()[:, 2].t()),
    k=20,
)
bg_eigvals, bg_eigvecs = torch.lobpcg(
    torch.cov(all_acts[0].squeeze()[:, 1].t()),
    k=20,
)

In [149]:
# Mean over dim 0 of the class-0 activations at index 2 (fg) and index 1 (bg).
fg_mean = torch.mean(all_acts[0].squeeze()[:, 2], dim=0)
bg_mean = torch.mean(all_acts[0].squeeze()[:, 1], dim=0)

In [150]:
# Cosine similarity between fg_mean and bg_mean: their dot product divided by
# the product of their L2 norms, all taken along the last dimension.
(fg_mean*bg_mean).sum(dim=-1)/(fg_mean.norm(dim=-1)*bg_mean.norm(dim=-1,))

tensor(0.7157)