In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from PIL import Image
import matplotlib.pyplot as plt
import torch
import numpy as np

from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz

import os
import random

: 

In [None]:
# Compute device for the whole notebook. Everything runs on the CPU; to use a GPU
# instead, set: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"

# Release unoccupied memory held by PyTorch's CUDA caching allocator
# (e.g. left over from earlier runs in this kernel).
torch.cuda.empty_cache()

: 

In [None]:
# Rebuild the exact ResNet-18 variant the checkpoint was trained with, so the
# saved state_dict keys and tensor shapes line up.
model = resnet18(pretrained=False)  # random init; real weights come from the checkpoint
# Single-channel (grayscale) stem instead of the default 3-channel one
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# 10-way classification head
model.fc = nn.Linear(model.fc.in_features, 10)

# map_location remaps stored tensors onto `device`, so a checkpoint that was saved
# from a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load('./Models/resnet18_mnist.pth', map_location=device)
model.load_state_dict(state_dict)

# Inference only: move to the target device and switch BatchNorm/Dropout to eval mode
model = model.to(device).eval()

: 

In [None]:
# Preprocessing pipeline: resize every image to 224x224, convert it to a tensor,
# then normalize with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST *test* split, stored under ./data (fetched on first use).
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)


: 

In [None]:
# Take one preprocessed sample from the test split and show it alongside its
# ground-truth label. squeeze() drops singleton axes so imshow gets a 2-D array.
SAMPLE_INDEX = 0
image, label = test_dataset[SAMPLE_INDEX]

plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"True Label: {label}")
plt.show()

: 

In [None]:
# Classify the sample. A batch dimension is added only when `image` does not
# already have one, so re-running this cell does not keep stacking dimensions.
# Later cells read `image` as this batched tensor on `device`.
if image.dim() == 3:
    image = image.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
image = image.to(device)

with torch.no_grad():
    output = model(image)
    probabilities = F.softmax(output, dim=1)  # per-class probabilities for each batch row
    # argmax over the class axis of the single sample (not over a flattened tensor)
    predicted_label = output[0].argmax().item()
    prediction_score = probabilities[0][predicted_label].item()  # confidence for the predicted class

print(f"Predicted Label: {predicted_label}")
print(f"Confidence Score for the Predicted Label: {prediction_score:.4f}")
print(f"Probability Distribution over Classes: {probabilities[0].tolist()}")

: 

In [None]:
# Integrated Gradients attributions for the predicted class.
# No `baselines` argument is passed, so the integration path starts from an
# all-zeros tensor. NOTE(review): after Normalize((0.5,), (0.5,)) an all-zeros
# input is a mid-gray image, not black — confirm that is the intended reference.
pred_label_idx = predicted_label
transformed_img = image  # batched, normalized input from the inference cell

integrated_gradients = IntegratedGradients(model)
attributions_ig = integrated_gradients.attribute(
    transformed_img,
    target=pred_label_idx,
    n_steps=200,
    # Evaluate the 200 interpolation points in chunks of 25 instead of a single
    # 200-image batch through ResNet-18, bounding peak memory. The attribution
    # result is the same; only the evaluation batching changes.
    internal_batch_size=25,
)

: 

In [None]:
# Sanity check: the attribution map must have the same shape as the input it
# explains, because the plotting cells lay one next to the other.
attr_shape = tuple(attributions_ig.shape)
input_shape = tuple(transformed_img.shape)
assert attr_shape == input_shape, f"attribution shape {attr_shape} != input shape {input_shape}"
attr_shape, input_shape

: 

In [None]:
# Colormap that runs from white at the low end of the scale to black at 25% of
# the scale, staying black above that.
default_cmap = LinearSegmentedColormap.from_list(
    'custom blue',
    [(0, '#ffffff'),
     (0.25, '#000000'),
     (1, '#000000')],
    N=256
)

# captum's image visualizers take (H, W, C) numpy arrays. Permute the single
# (C, H, W) sample instead of hard-coding the 224x224 size, so this keeps working
# if the Resize in the transform changes.
attr_reshaped = attributions_ig[0].permute(1, 2, 0).cpu().detach().numpy()
# Undo Normalize((0.5,), (0.5,)) (x -> x * 0.5 + 0.5) so the image shown in the
# "original_image" panels is back in [0, 1] rather than [-1, 1].
img_reshaped = (transformed_img[0] * 0.5 + 0.5).permute(1, 2, 0).cpu().detach().numpy()

_ = viz.visualize_image_attr(attr_reshaped,
                             img_reshaped,
                             method='heat_map',
                             cmap=default_cmap,
                             show_colorbar=True,
                             sign='positive',
                             outlier_perc=1)


: 

In [None]:
# SmoothGrad-squared on top of Integrated Gradients: IG is run on 10 noisy
# copies of the input and the squared attribution maps are averaged.
noise_tunnel = NoiseTunnel(integrated_gradients)

attributions_ig_nt = noise_tunnel.attribute(
    image,
    nt_samples=10,
    nt_type='smoothgrad_sq',
    target=predicted_label,
)

# Plot the noise-tunnel result itself (previously the plain IG map was re-plotted).
attr_ig_nt_hwc = attributions_ig_nt[0].permute(1, 2, 0).cpu().detach().numpy()
_ = viz.visualize_image_attr_multiple(attr_ig_nt_hwc,
                                      img_reshaped,
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      cmap=default_cmap,
                                      show_colorbar=True)

: 

In [None]:
# GradientShap samples baselines and adds Gaussian noise; seed for reproducibility.
torch.manual_seed(0)
np.random.seed(0)
input = image  # NOTE(review): shadows the builtin `input`; the occlusion cells read this name

gradient_shap = GradientShap(model)

# Baseline distribution: an all-zeros image and the input itself (not random images).
rand_img_dist = torch.cat([input * 0, input * 1])

attributions_gs = gradient_shap.attribute(input,
                                          n_samples=50,
                                          stdevs=0.0001,
                                          baselines=rand_img_dist,
                                          target=pred_label_idx)

# Plot the GradientShap attributions (previously the IG map was re-plotted).
attr_gs_hwc = attributions_gs[0].permute(1, 2, 0).cpu().detach().numpy()
_ = viz.visualize_image_attr_multiple(attr_gs_hwc,
                                      img_reshaped,
                                      ["original_image", "heat_map"],
                                      ["all", "absolute_value"],
                                      cmap=default_cmap,
                                      show_colorbar=True)

: 

In [None]:
# Fine-grained occlusion: a 16x16 patch (all channels) is slid over the image in
# steps of 8 pixels and replaced by 0; the change in the target output is recorded.
# NOTE(review): 0 in normalized space is mid-gray, not black — confirm intended.
occlusion = Occlusion(model)

attributions_occ = occlusion.attribute(
    input,
    target=pred_label_idx,
    sliding_window_shapes=(1, 16, 16),
    strides=(1, 8, 8),
    baselines=0,
)


: 

In [None]:
# Plot the 16x16-window occlusion map computed above (previously the IG map was
# re-plotted). Convert the single (C, H, W) sample to the (H, W, C) layout captum expects.
attr_occ_hwc = attributions_occ[0].permute(1, 2, 0).cpu().detach().numpy()
_ = viz.visualize_image_attr_multiple(attr_occ_hwc,
                                      img_reshaped,
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      show_colorbar=True,
                                      outlier_perc=2,
                                     )

: 

In [None]:
# Coarse occlusion: an 80x80 patch slid in steps of 48 pixels and replaced by 0.
occlusion = Occlusion(model)

attributions_occ = occlusion.attribute(input,
                                       strides = (1, 48, 48),
                                       target=pred_label_idx,
                                       sliding_window_shapes=(1,80, 80),
                                       baselines=0)

# Plot this coarse occlusion result (previously the IG map was re-plotted).
attr_occ_hwc = attributions_occ[0].permute(1, 2, 0).cpu().detach().numpy()
_ = viz.visualize_image_attr_multiple(attr_occ_hwc,
                                      img_reshaped,
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      show_colorbar=True,
                                      outlier_perc=2,
                                     )

: 