### Compute and save Jacobians and Bures distances for analysis

Example code that extracts internal representations, computes decoding Jacobians and the Bures distances between them, and saves each intermediate result to disk so it can be reloaded later without recomputation.

In [1]:
import torch
import torchinfo
import torch.nn as nn
from PIL import Image
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix

import dsutils 
import metrics
import jsutils
import extract_internal_reps

import os
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline

from sklearn.manifold import MDS
from sklearn.decomposition import PCA

import pickle


# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device} for inference')

Using cpu for inference


In [2]:
# Names of all models registered under torchvision.models.
# These names are what the extraction loop later passes to extract_rep_gen.

avail_models = torchvision.models.list_models(module=torchvision.models)
avail_models

['alexnet',
 'convnext_base',
 'convnext_large',
 'convnext_small',
 'convnext_tiny',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b3',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_v2_l',
 'efficientnet_v2_m',
 'efficientnet_v2_s',
 'googlenet',
 'inception_v3',
 'maxvit_t',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'regnet_x_16gf',
 'regnet_x_1_6gf',
 'regnet_x_32gf',
 'regnet_x_3_2gf',
 'regnet_x_400mf',
 'regnet_x_800mf',
 'regnet_x_8gf',
 'regnet_y_128gf',
 'regnet_y_16gf',
 'regnet_y_1_6gf',
 'regnet_y_32gf',
 'regnet_y_3_2gf',
 'regnet_y_400mf',
 'regnet_y_800mf',
 'regnet_y_8gf',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext101_64x4d',
 'resnext50_32x4d',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'sh

In [None]:
# Extract representations resulting from probe inputs in data_dir
#
# Every model in avail_models is processed in turn and written to its own
# '<model>_internal_rep_classifier.pkl' file as a single-entry dict
# {model: [x1, model_2nd]}.

data_dir = '../imagenet-sample-images'

repDict = {}

for model in avail_models:
    # Empty the dict first so each pickle contains exactly one model's entry
    # and earlier models are not kept alive in memory.
    repDict.clear()
    internal_rep, model_2nd = extract_internal_reps.extract_rep_gen(model, data_dir, weights="first")
    repDict[model] = [internal_rep, model_2nd]
    with open(f"{model}_internal_rep_classifier.pkl", 'wb') as f:
        pickle.dump(repDict, f)
    print(f"{model} done")

Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


alexnet done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


convnext_base done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


convnext_large done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


convnext_small done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


convnext_tiny done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


densenet121 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


densenet161 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


densenet169 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


densenet201 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_b0 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_b1 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_b2 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_b3 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_b4 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_b5 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_b6 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_b7 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_v2_l done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_v2_m done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


efficientnet_v2_s done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


googlenet done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


inception_v3 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


maxvit_t done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


mnasnet0_5 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


mnasnet0_75 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


mnasnet1_0 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


mnasnet1_3 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


mobilenet_v2 done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


mobilenet_v3_large done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


mobilenet_v3_small done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


regnet_x_16gf done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


regnet_x_1_6gf done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


regnet_x_32gf done
Using cpu for inference


Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main
Using cache found in /mnt/home/sharvey/.cache/torch/hub/pytorch_vision_main


In [None]:
# Extract representations resulting from probe inputs in data_dir

# OLD METHOD (kept for reference only; everything in this cell is commented out)

# data_dir = '../imagenet-sample-images'

# internal_reps = []
# model_2nds = []
# repDict = {}

# for model in model_names:
#     x1, model_2nd = extract_internal_reps.extract_rep(model, data_dir)
#     repDict[model] = [x1,model_2nd]
#     model_2nds.append(model_2nd)
#     internal_reps.append(x1)
#     print(model + " done")

# Save extracted representations

# with open('internal_reps_with_model_2nd_half_full_classifier.pkl', 'wb') as f:
#     pickle.dump(repDict, f)

In [None]:
# Load extracted representations
#
# Rebuilds repDict by merging the per-model pickle files for the models listed
# in model_names.
# NOTE(review): pickle.load can execute arbitrary code -- only load files that
# were produced by this notebook or another trusted source.

repDict = {}

# Models you want to load
model_names = ["alexnet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "vgg16"]
N_models = len(model_names)

for model_name in model_names:
    with open(f"{model_name}_internal_rep_classifier.pkl", 'rb') as f:
        repDict.update(pickle.load(f))
    print(model_name)

In [None]:
# Compute decoding Jacobians and save

# Takes a while

# J_dict_rbyn accumulates the Jacobians of EVERY model in model_names.
# It must not be cleared between iterations: the conversion, flattening and
# Bures-distance cells further down pass J_dict_rbyn to jsutils together with
# each name in model_names, so an entry for every model has to survive this loop.
# Trade-off: all models' Jacobians are held in memory at once.
J_dict_rbyn = {}

for model_name in model_names:
    # repDict[model_name] is [representation, model_2nd] as stored by the
    # extraction cell; both are forwarded unchanged to jsutils.
    Js = jsutils.decoding_jacobian(repDict[model_name][0], repDict[model_name][1])
    J_dict_rbyn[model_name] = Js
    # On-disk format is unchanged: each file holds a single-entry
    # {model_name: Js} dict, so the files can still be merged with dict.update().
    with open(model_name + '_decoding_Js_rbyn.pkl', 'wb') as f:
        pickle.dump({model_name: Js}, f)
    print(model_name + " done")

In [None]:
# Convert Jacobian lists from a list of M, r by N Jacobians to a list of r, M by N Jacobians (if desired)
#
# One '<model>_decoding_Js_mbyn.pkl' file is written per model, each holding a
# single-entry dict {model_name: converted Jacobians}.
# NOTE(review): presumably convert_Jacobian looks up model_name inside
# J_dict_rbyn, so J_dict_rbyn needs an entry for every model -- confirm.

J_dict_mbyn = {}
for model_name in model_names:
    # Reset so the pickle written below contains only the current model.
    J_dict_mbyn.clear()
    J_dict_mbyn[model_name] = jsutils.convert_Jacobian(J_dict_rbyn, model_name)
    with open(f"{model_name}_decoding_Js_mbyn.pkl", 'wb') as f:
        pickle.dump(J_dict_mbyn, f)
    print(f"{model_name} done")

In [None]:
# Flatten Jacobians
#
# flatten_Jacobian is applied to the whole J_dict_rbyn dictionary at once; the
# result is then split back into one single-entry pickle per model.

J_dict = jsutils.flatten_Jacobian(J_dict_rbyn)

for model_name in model_names:
    # Look the entry up before opening the file so that a missing model does
    # not leave an empty pickle behind.
    single_model_Js = {model_name: J_dict[model_name]}
    with open(f"{model_name}_decoding_Js_stacked.pkl", 'wb') as f:
        pickle.dump(single_model_Js, f)
    print(f"{model_name} done")

In [None]:
# Load Jacobians
# Loads the stacked Jacobians of a chosen subset of models into J_dict by
# merging the per-model '<model>_decoding_Js_stacked.pkl' files.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.

J_dict = {}

model_names = ["alexnet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152","vgg16"]
N_models = len(model_names)

for model_name in model_names:
    with open(f"{model_name}_decoding_Js_stacked.pkl", 'rb') as f:
        J_dict.update(pickle.load(f))
    print(model_name)

In [None]:
# Compute Bures distances (takes a while)
# Input is the r-by-N Jacobian dictionary; the result is saved together with the
# model name list so the ordering can be recovered when the file is reloaded.

bures_dists_all_rbyn = jsutils.compute_Jacobian_Bures_distances(
    J_dict_rbyn,
    model_names,
)
np.savez(
    'bures_dists_penultimate_decoding_Jrbyn.npz',
    bures_dists=bures_dists_all_rbyn,
    model_names=model_names,
)

In [None]:
# Compute distances between the stacked (flattened) Jacobians (takes a while)
# Because of the size of the flattened arrays, this cell uses the Procrustes
# distance instead of the Bures distance for speed of computation.
# The variable, file name and 'bures_dists' key keep their "bures" naming even
# though the stored values come from compute_Jacobian_Procrustes_distances.

bures_dists_all = jsutils.compute_Jacobian_Procrustes_distances(
    J_dict,
    model_names,
)
np.savez(
    'bures_dists_penultimate_decoding_Jstacked.npz',
    bures_dists=bures_dists_all,
    model_names=model_names,
)