In [None]:
# Compute and save decoding Jacobians and the distances between models' Jacobians (Bures; Procrustes for the flattened Jacobians) for analysis

In [None]:
import torch
import torchinfo
import torch.nn as nn
from PIL import Image
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix

import dsutils 
import metrics
import jsutils
import extract_internal_reps

import os
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline

from sklearn.manifold import MDS
from sklearn.decomposition import PCA

import pickle


# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device} for inference')

In [None]:
# models available

model_names = ["alexnet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "inceptionv3", "densenet", "mobilenetv2","vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn"]


In [None]:
# Run each model on the probe inputs in data_dir and keep both values returned by
# extract_internal_reps.extract_rep: the internal representation (x1) and model_2nd
# (presumably the remaining "second half" of the network, per the pickle filename used
# below — confirm in extract_internal_reps).
# repDict maps model name -> [x1, model_2nd]. internal_reps and model_2nds hold the same
# objects as lists, in model_names order.

data_dir = '../imagenet-sample-images'

repDict = {}
internal_reps, model_2nds = [], []

for model_name in model_names:
    rep, second_half = extract_internal_reps.extract_rep(model_name, data_dir)
    repDict[model_name] = [rep, second_half]
    internal_reps.append(rep)
    model_2nds.append(second_half)
    print(f"{model_name} done")

In [None]:
# Persist repDict ({model_name: [x1, model_2nd]}) so extraction need not be rerun.
reps_path = 'internal_reps_with_model_2nd_half.pkl'
with open(reps_path, 'wb') as reps_file:
    pickle.dump(repDict, reps_file)

In [None]:
# Compute decoding Jacobians for every model and save one pickle per model.
# Takes a while.
#
# J_dict_rbyn accumulates ALL models, because the conversion, flattening and
# Bures-distance cells below read every model from it. Each file, however, stores only
# its own model as {model_name: Js}, the same layout as the *_decoding_Js_stacked.pkl files.
# (Previously this cell called J_dict.clear(). J_dict is not defined until the flatten
# cell, so that line raised NameError on a fresh kernel, and each file received the
# whole cumulative dict.)

J_dict_rbyn = {}

for model_name in model_names:
    # repDict[model_name] is [x1, model_2nd] as stored by the extraction cell.
    # Per the conversion cell's comment, Js is a list of M Jacobians, each r x N (assumed).
    Js = jsutils.decoding_jacobian(repDict[model_name][0], repDict[model_name][1])
    J_dict_rbyn[model_name] = Js
    with open(model_name + '_decoding_Js_rbyn.pkl', 'wb') as f:
        pickle.dump({model_name: Js}, f)
    print(model_name + " done")

In [None]:
# Optional: use jsutils.convert_Jacobian to turn each model's list of M (r x N) Jacobians
# into a list of r (M x N) Jacobians, and save one pickle per model.
#
# J_dict_mbyn is emptied at the start of every iteration. Each file therefore holds only
# its own model, and after the loop only the last model remains in J_dict_mbyn.

J_dict_mbyn = {}
for model_name in model_names:
    J_dict_mbyn.clear()
    J_dict_mbyn[model_name] = jsutils.convert_Jacobian(J_dict_rbyn, model_name)
    mbyn_path = f'{model_name}_decoding_Js_mbyn.pkl'
    with open(mbyn_path, 'wb') as f:
        pickle.dump(J_dict_mbyn, f)
    print(f'{model_name} done')

In [None]:
# Flatten the Jacobians of all models with jsutils.flatten_Jacobian. The result is indexed
# by model name below, so it is assumed to be a {model_name: flattened Jacobians} mapping.
# Save one pickle per model as {model_name: flattened Jacobians}.
J_dict = jsutils.flatten_Jacobian(J_dict_rbyn)

for model_name in model_names:
    stacked_path = f'{model_name}_decoding_Js_stacked.pkl'
    with open(stacked_path, 'wb') as f:
        pickle.dump({model_name: J_dict[model_name]}, f)
    print(f'{model_name} done')

In [None]:
# Load the flattened ("stacked") Jacobians of a subset of models from disk into J_dict.
# NOTE: this rebinds model_names to the subset. Every cell below, and any earlier cell
# rerun after this one, then only sees these models.
# pickle.load can execute arbitrary code, so only load files this notebook produced.

model_names = ["alexnet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "vgg16"]
N_models = len(model_names)

J_dict = {}
for model_name in model_names:
    with open(f'{model_name}_decoding_Js_stacked.pkl', 'rb') as f:
        J_dict.update(pickle.load(f))
    print(model_name)

In [None]:
# Bures distances between the models' unflattened decoding Jacobians (takes a while),
# restricted to the model_names currently bound.
# NOTE(review): reads J_dict_rbyn from the Jacobian-computation cell. The load cell does
# not restore it, so after a kernel restart that cell must be rerun first.

bures_dists_all_rbyn = jsutils.compute_Jacobian_Bures_distances(J_dict_rbyn, model_names)
np.savez(
    'bures_dists_penultimate_decoding_Jrbyn.npz',
    bures_dists=bures_dists_all_rbyn,
    model_names=model_names,
)

In [None]:
# Distances between the models' flattened (stacked) decoding Jacobians (takes a while).
# The flattened arrays are large, so this uses the Procrustes distance rather than Bures,
# for speed.
# NOTE: the output filename and the 'bures_dists' key say "bures" for compatibility with
# existing consumers, but the stored values are Procrustes distances.

bures_dists_all = jsutils.compute_Jacobian_Procrustes_distances(J_dict, model_names)
np.savez(
    'bures_dists_penultimate_decoding_Jstacked.npz',
    bures_dists=bures_dists_all,
    model_names=model_names,
)