In [None]:
import torch
from matplotlib import pyplot as plt
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import seaborn as sns
import pandas as pd
# torch.hub repository (pinned tag) that all pretrained models are loaded from.
hub_repo = 'pytorch/vision:v0.10.0'

# Geometry of the first convolution layer under study.
input_channels = 3
kernel_size = 3

# Hub entry points, grouped by the kernel size of their first convolution.
models_by_kernel_size = {
    3: ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19',
        'vgg19_bn', 'deeplabv3_mobilenet_v3_large', 'inception_v3',
        'lraspp_mobilenet_v3_large', 'mobilenet_v2', 'mobilenet_v3_large',
        'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'squeezenet1_1'],
    7: ['deeplabv3_resnet101', 'deeplabv3_resnet50', 'densenet121', 'densenet161',
        'densenet169', 'densenet201', 'fcn_resnet101', 'fcn_resnet50', 'googlenet',
        'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d',
        'resnext50_32x4d', 'squeezenet1_0', 'wide_resnet101_2', 'wide_resnet50_2'],
}

# Fail here, with a clear message, instead of leaving `model_names` undefined
# and hitting a NameError in a later cell.
if kernel_size not in models_by_kernel_size:
    raise ValueError(
        f"Unsupported kernel_size={kernel_size!r}; "
        f"expected one of {sorted(models_by_kernel_size)}"
    )
model_names = models_by_kernel_size[kernel_size]

Calculate the PCA of the dataset patches.

In [None]:
def extract_patches(images, patch_size):
    """Return every patch_size x patch_size patch of `images`, one flattened patch per row.

    Args:
        images: tensor in channels-last layout, shape (N, H, W, C).
        patch_size: side length of the square sliding window (stride 1).

    Returns:
        Tensor of shape (N * (H - patch_size + 1) * (W - patch_size + 1),
        C * patch_size ** 2). Each row is flattened in (channel, row, column)
        order, which is the order `weight.flatten(1)` gives for a convolution
        weight of shape (out, C, kH, kW), so patch features and filter
        coefficients line up index by index.
    """
    # `unfold` replaces the unfolded axis by the window count and appends the
    # window axis last: (N, H, W, C) -> (N, H', W, C, kH) -> (N, H', W', C, kH, kW).
    # The channel axis is deliberately NOT unfolded: doing so would move it
    # behind kH/kW and break the (C, kH, kW) ordering documented above.
    windows = images.unfold(1, patch_size, 1).unfold(2, patch_size, 1)
    return windows.flatten(0, 2).flatten(1)


# download dataset
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
# rescale to [0,1]; the raw `.data` array (channels last) is used directly
images = torch.tensor(cifar_testset.data / cifar_testset.data.max(), dtype=torch.double)
# patchify with the same window size as the first-layer kernels
image_patches = extract_patches(images, kernel_size)
assert image_patches.shape[-1] == input_channels * kernel_size ** 2, "unexpected patch layout"

# Full-rank PCA: columns of `pca_components` are the principal directions of the patches.
_, _, pca_components = torch.pca_lowrank(image_patches, q=image_patches.shape[-1])

Download every pretrained model from the PyTorch vision hub whose first convolution layer has the given kernel size, and keep only its first-layer weights.

In [None]:
# Map each model name to the first tensor of its state_dict, flattened to
# one row per leading-axis entry.
model_first_layers = {}
for modelname in model_names:
    model = torch.hub.load(hub_repo, modelname, pretrained=True)
    first_tensor = next(iter(model.state_dict().values()))
    model_first_layers[modelname] = first_tensor.flatten(1)


After computing the PCA components from the test images and downloading the models for comparison, we are ready to project each model's first-layer filters onto the components and calculate the energy along each component.


In [None]:
# For every model: project its first-layer filters onto the patch PCA basis,
# take the norm over all filters per component ("energy"), and normalise the
# profile to a maximum of 1. Left panel: the profiles; right panel: pairwise
# correlation of the profiles between models.
figs, (ax1, ax2) = plt.subplots(1, 2, figsize=(25, 15))

# Loop-invariant: cast the (double) PCA basis once to float32 for the matmul with the weights.
components = pca_components.to(torch.float32)

# Collect one record per model and build the frame once; `DataFrame.append`
# was removed in pandas 2.0 and growing a frame row by row is quadratic.
energy_rows = []
for m, first_layer in model_first_layers.items():
    energy_profile = torch.norm(first_layer @ components, dim=0).numpy()
    energy_profile = energy_profile / energy_profile.max()
    x = range(1, energy_profile.shape[0] + 1)  # 1-based component index
    ax1.plot(x, energy_profile, label=m)
    energy_rows.append({"Model Name": m, **dict(zip(x, energy_profile))})

ax1.set(
    title="First-layer filter energy per PCA component",
    xlabel="PCA component",
    ylabel="Normalised energy",
)
ax1.legend()

df = pd.DataFrame(energy_rows).set_index("Model Name")
# Rows are models, so transpose to correlate the models with each other.
sns.heatmap(df.T.corr(), vmin=0, vmax=1, annot=True, ax=ax2, square=True)
ax2.set_title("Correlation of energy profiles between models");