In [None]:
# Notebook setup cell: fetch the ILSVRC2012 validation-image archive and the
# devkit archive into a local ./ImageNet directory, then return to the parent
# directory.
# NOTE(review): the analysis cell further down opens the dataset at
# '/data/ImageNet', not ./ImageNet -- confirm the archives end up where that
# cell expects them.
!mkdir ImageNet
%cd ImageNet
!wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
!wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
%cd ..

In [None]:
from torchvision.models import (resnet18, ResNet18_Weights, resnet34,
                                ResNet34_Weights, resnet50,
                                ResNet50_Weights, resnet101,
                                ResNet101_Weights, resnet152, ResNet152_Weights)
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import numpy as np


def get_layers_names(net):
    """List the direct children of ``net``, skipping ``nn.Dropout`` modules.

    Args:
        net: Module whose ``named_children()`` are inspected.

    Returns:
        A list in child order. An ordinary child contributes its attribute
        name (``str``). A child that is an ``nn.ModuleList`` contributes one
        ``(module_list, module_list_name, key)`` tuple per entry of the list,
        where ``key`` is the entry's key in the list's ``_modules`` mapping.
    """
    collected = []
    for child_name, child_module in net.named_children():
        if isinstance(child_module, nn.Dropout):
            continue
        if not isinstance(child_module, nn.ModuleList):
            collected.append(child_name)
            continue
        # Expand a ModuleList into one tuple per contained sub-module key.
        collected.extend(
            (child_module, child_name, key) for key in child_module._modules
        )
    return collected


# Module-level store filled by the hooks built in get_activation():
# maps the name given to get_activation() -> most recent detached output.
activation = {}


def get_activation(name):
    """Build a forward hook that records a module's output under ``name``.

    Args:
        name: Key under which the output is stored in the module-level
            ``activation`` dict; each call of the hook overwrites the
            previous entry for that key.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature that stores ``output.detach()`` and returns ``None``.
    """

    def _record(_module, _inputs, result):
        # detach() so the stored tensor is not attached to the autograd graph.
        activation[name] = result.detach()

    return _record


def get_correlation(batch_size, layers1, layers2):
    """Mean, over neurons of ``layers1``, of the best Pearson match in ``layers2``.

    Both inputs are indexed with ``[:, 0]`` and flattened to
    ``(batch_size, n_neurons)``; each column is one neuron observed over the
    batch. For every neuron of ``layers1`` the largest Pearson correlation
    with any neuron of ``layers2`` is taken, and those maxima are averaged.

    Undefined correlations (NaN, e.g. a neuron that is constant over the
    batch) are ignored when taking the maximum; a neuron with no defined
    correlation at all contributes ``-1``.

    All pairwise correlations come from a single ``torch.corrcoef`` call on
    the concatenated neuron matrix instead of one call (and one device-to-host
    transfer) per neuron pair.

    Args:
        batch_size: Number of samples; must equal the size of dim 0 of both
            tensors, since it is used as the leading size of the reshape.
        layers1: Floating-point activation tensor with at least two dims.
        layers2: Floating-point activation tensor with at least two dims.

    Returns:
        NumPy scalar holding the mean of the per-neuron maximum correlations.
    """
    # NOTE(review): only index 0 along dim 1 is compared -- presumably the
    # first channel of a (batch, channel, ...) tensor; confirm this is the
    # intended subset.
    neurons1 = layers1[:, 0].reshape(batch_size, -1)
    neurons2 = layers2[:, 0].reshape(batch_size, -1)
    n1, n2 = neurons1.shape[1], neurons2.shape[1]

    if n1 == 0 or n2 == 0:
        # Nothing to compare against: every neuron keeps the -1 floor.
        best = torch.full((n1,), -1.0, dtype=neurons1.dtype)
    else:
        # corrcoef treats rows as variables, so stack neurons as rows.
        joint = torch.corrcoef(torch.cat((neurons1.T, neurons2.T), dim=0))
        # Upper-right block holds the layers1-vs-layers2 correlations.
        cross = joint[:n1, n1:]
        # NaN must never win the max; -1 is the lowest possible correlation.
        cross = torch.nan_to_num(cross, nan=-1.0)
        best = cross.max(dim=1).values

    return best.detach().cpu().numpy().mean()


def mean_correlation(batch_size, layers1, layers2):
    """Symmetrised correlation score between two activation tensors.

    ``get_correlation`` is directional (it averages over the neurons of its
    first tensor), so it is evaluated in both directions and the two results
    are averaged.

    Args:
        batch_size: Number of samples, forwarded to ``get_correlation``.
        layers1: First activation tensor.
        layers2: Second activation tensor.

    Returns:
        The arithmetic mean of the two directional scores.
    """
    forward_score = get_correlation(batch_size, layers1, layers2)
    backward_score = get_correlation(batch_size, layers2, layers1)
    return (forward_score + backward_score) / 2


if __name__ == "__main__":
    # Compare the "layer4" activations of five pretrained ResNets on one batch
    # of ImageNet validation images and print the symmetrised correlation
    # score for every pair of models.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    batch_size = 10
    hooked_layer = "layer4"

    weights = [
        ResNet18_Weights.DEFAULT,
        ResNet34_Weights.DEFAULT,
        ResNet50_Weights.DEFAULT,
        ResNet101_Weights.DEFAULT,
        ResNet152_Weights.DEFAULT,
    ]
    builders = [resnet18, resnet34, resnet50, resnet101, resnet152]
    # NOTE: this rebinds `models`, shadowing the `torchvision.models` module
    # imported at the top of the file; the name is kept in case other cells
    # rely on it.
    models = [
        build(weights=w).to(device) for build, w in zip(builders, weights)
    ]

    for model in models:
        model.eval()

    # Load ImageNet dataset.
    # The preprocessing bundled with the first (ResNet-18) weights is applied
    # to the data, so every model receives exactly the same input tensors.
    transform = weights[0].transforms()

    imagenet_data = datasets.ImageNet('/data/ImageNet', split='val', transform=transform)
    data_loader = DataLoader(imagenet_data, batch_size=batch_size, shuffle=True)

    # Register exactly one hook per model, up front. Registering inside the
    # batch loop would stack an additional hook on each model per batch.
    hook_handles = [
        getattr(model, hooked_layer).register_forward_hook(get_activation(hooked_layer))
        for model in models
    ]
    try:
        with torch.no_grad():
            for images, _labels in data_loader:
                images = images.to(device)
                activations = []
                for model in models:
                    # The forward pass is run only for its hook side effect;
                    # the entry is read right away because the next model's
                    # hook overwrites the same key.
                    model(images)
                    activations.append(activation[hooked_layer])

                # Use the real batch length rather than repeating the constant.
                n_samples = images.shape[0]
                for act1_i in range(len(activations)):
                    for act2_i in range(act1_i + 1, len(activations)):
                        act1 = activations[act1_i]
                        act2 = activations[act2_i]
                        print(act1_i, act2_i, mean_correlation(n_samples, act1, act2))
                # Only the first batch is analysed.
                break
    finally:
        # Detach the hooks so later forward passes are not affected.
        for handle in hook_handles:
            handle.remove()