In [None]:
# Initialization

import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import pickle
import seaborn as sns
import pandas as pd

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Seed for the DataLoader's shuffling so that the subset of validation images
# visited by the analysis loop is the same on every run.
SEED = 0

# Load ImageNet

# Per-image preprocessing applied by the dataset: resize the shorter side to 256,
# center-crop to 224x224 and convert to a float tensor in [0, 1]. Normalization is
# kept as a separate transform (below) so callers can keep the un-normalized tensor
# and normalize it themselves right before the forward pass.
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Standard ImageNet per-channel mean / std normalization.
normalization = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

data_dir = '../ILSVRC2012'
val_dataset = datasets.ImageNet(root=data_dir, split='val', transform=data_transforms)

BATCH_SIZE = 1  # one image per batch

dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # A dedicated seeded generator makes the shuffled order reproducible
    # without touching the global torch RNG.
    generator=torch.Generator().manual_seed(SEED),
)
dataset_sizes = len(val_dataset)
class_names = val_dataset.classes

In [None]:
# Registry of supported torchvision classifiers. Each key is the name of the
# constructor in `torchvision.models`, and each value is that constructor.
# `get_model` looks model names up here.
available_models = {
    model_name: getattr(torchvision.models, model_name)
    for model_name in (
        'resnet18',
        'alexnet',
        'googlenet',
        'mobilenet_v3_small',
        'mobilenet_v3_large',
        'mnasnet1_0',
        'vgg16',
        'efficientnet_b1',
        'densenet161',
    )
}

def get_model(model_type):
    """Load a torchvision classifier with its default pre-trained weights.

    Args:
        model_type: a key of ``available_models`` (e.g. ``'vgg16'``).

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: if ``model_type`` is not a key of ``available_models``.
    """
    # Check if the specified model name is valid
    if model_type not in available_models:
        raise ValueError(f"Invalid model name. Available models: {list(available_models.keys())}")
    # Load the model with pre-trained weights
    model = available_models[model_type](weights="DEFAULT").to(device)
    # nn.Module instances start in training mode. Switch to inference mode so
    # layers such as dropout behave deterministically from the very first
    # forward pass, not only after something else calls eval().
    model.eval()

    return model


In [None]:
def get_explanation(model, inputs_no_norm, targets, explanation_mode, normalize=None):
    """Compute input-gradient explanations for the given target classes.

    The inputs are normalized, passed through ``model`` with gradient tracking on
    the normalized tensor, and reduced to a scalar objective chosen by
    ``explanation_mode``. The gradient of that objective with respect to the
    normalized inputs is returned.

    Args:
        model: classifier whose output is indexed as ``outputs[row, class]``
            (i.e. assumed (batch, num_classes) logits). It is put in eval mode
            and its parameter gradients are reset as side effects.
        inputs_no_norm: input batch before normalization. It is not modified.
        targets: class index or indices. Either a single integer (Python/NumPy/
            0-dim tensor), applied to every row, or a 1-D sequence/tensor with
            one index per row.
        explanation_mode: which objective to differentiate:
            - ``'normal'``:   target logit.
            - ``'max'``:      target logit minus the largest non-target logit.
            - ``'weighted'``: target softmax probability.
            - ``'mean'``:     target logit minus the mean of all other logits.
        normalize: optional callable applied to the inputs before the forward
            pass. Defaults to the module-level ``normalization`` transform.

    Returns:
        A detached tensor shaped like the normalized inputs, holding
        d(objective)/d(normalized input), summed over the batch objective.

    Raises:
        ValueError: if ``explanation_mode`` is not one of the modes above.
    """
    valid_modes = ('normal', 'max', 'weighted', 'mean')
    if explanation_mode not in valid_modes:
        # Previously an unknown mode surfaced as an UnboundLocalError on `loss`.
        raise ValueError(
            f"Invalid explanation_mode {explanation_mode!r}; expected one of {valid_modes}"
        )
    if normalize is None:
        normalize = normalization

    model.eval()

    # Work on a detached copy so the caller's tensor never becomes part of the graph.
    inputs = normalize(inputs_no_norm.detach().clone())
    inputs.requires_grad_(True)
    model.zero_grad()

    outputs = model(inputs)

    rows = torch.arange(outputs.size(0), device=outputs.device)
    # Put targets on the same device as the outputs so they can be compared
    # with `topk` indices ('max') and used for indexing.
    targets = torch.as_tensor(targets, device=outputs.device)
    target_logits = outputs[rows, targets]

    if explanation_mode == 'normal':
        objective = target_logits
    elif explanation_mode == 'max':
        # Highest-scoring class other than the target: the runner-up when the
        # target is the top-1 prediction, otherwise the top-1 prediction itself.
        top2 = torch.topk(outputs, 2, dim=1).indices
        runner_up = torch.where(top2[:, 0] == targets, top2[:, 1], top2[:, 0])
        objective = target_logits - outputs[rows, runner_up]
    elif explanation_mode == 'weighted':
        objective = torch.softmax(outputs, dim=1)[rows, targets]
    else:  # 'mean'
        # Weight +1 on the target and -1/(C-1) on every other class, giving
        # target logit minus the mean of the others. Guard against C == 1.
        num_others = max(outputs.size(1) - 1, 1)
        weights = torch.full_like(outputs, -1.0 / num_others)
        weights[rows, targets] = 1.0
        objective = weights * outputs

    objective.sum().backward()
    return inputs.grad.detach().clone()

In [None]:
# For each of the first NUM_IMAGES validation images (in the DataLoader's
# shuffled order), rank all classes by logit and explain a subset of them:
# - the TOP_K highest-ranked classes,
# - every MIDDLE_STRIDE-th class in between,
# - the BOTTOM_K lowest-ranked classes.
# Each explained class yields one record:
# (image index, class logit, L2 norm of its input-gradient explanation).
NUM_IMAGES = 1000
TOP_K = 50
BOTTOM_K = 50
MIDDLE_STRIDE = 20

model = get_model('vgg16')
# Explicitly switch to inference mode (e.g. dropout off) before the first
# forward pass below. A freshly constructed nn.Module starts in training mode,
# and get_explanation only calls eval() after these logits are computed.
model.eval()

logit_explanation_norm = []

num_images = min(NUM_IMAGES, len(dataloader))
for i, (inputs, _labels) in enumerate(tqdm(dataloader, total=num_images, desc="Processing images")):
    if i >= NUM_IMAGES:
        break
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(normalization(inputs.clone().detach()))

    # Assumes one image per batch: only logits[0] is analysed.
    num_classes = logits.size(1)
    ranked_classes = torch.topk(logits, num_classes)[1][0].cpu().numpy()
    selected_classes = np.concatenate((
        ranked_classes[:TOP_K],
        ranked_classes[TOP_K:num_classes - BOTTOM_K:MIDDLE_STRIDE],
        ranked_classes[num_classes - BOTTOM_K:],
    ))

    for target in selected_classes:
        explanation = get_explanation(model, inputs, target, 'normal')
        # Store plain Python numbers so no device tensors are kept alive.
        logit_explanation_norm.append(
            (i, logits[0, target].item(), torch.norm(explanation).item())
        )


# Save results.
# zip(*records) transposes the list of (image index, logit, gradient norm)
# tuples into three columns, and each column becomes a NumPy array via torch.
ids, x, y = (torch.tensor(column).numpy() for column in zip(*logit_explanation_norm))

df = pd.DataFrame({'Logit Value': x, 'Gradient Norm': y, 'ID': ids})

# Persist the DataFrame with the standard pickle module.
with open("test-id_logit_explanation_norm.pkl", 'wb') as file:
    pickle.dump(df, file)


# Plot gradient norm against logit value. The plot has two layers:
# - a second-order polynomial fit (red, no scatter),
# - every (image, class) point, colored by image index.
# A dedicated figure/axes avoids drawing into whichever figure happens to be
# current in the kernel.
fig, ax = plt.subplots()
sns.regplot(x='Logit Value', y='Gradient Norm', data=df, order=2, scatter=False, color='red', ax=ax)
sns.scatterplot(x='Logit Value', y='Gradient Norm', hue='ID', data=df, palette='viridis', legend=False, ax=ax)
ax.set_title('Class logit vs. norm of its input-gradient explanation')

fig.savefig("logit_gradientnorm_relationship.png", bbox_inches='tight', pad_inches=0)
plt.show()