<a href="https://colab.research.google.com/github/KianShokraneh/Captum-Attribution-Metrics-Analysis/blob/main/Captum_Attributions_Initial_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!gdown --id 1zeu1TYA3KFZxiJKwPRWLG50mPAnEj3jl

Downloading...
From (original): https://drive.google.com/uc?id=1zeu1TYA3KFZxiJKwPRWLG50mPAnEj3jl
From (redirected): https://drive.google.com/uc?id=1zeu1TYA3KFZxiJKwPRWLG50mPAnEj3jl&confirm=t&uuid=a09d205e-bbcc-45ab-a5aa-16e243566831
To: /content/resnet18_cifar10.pth
100% 44.8M/44.8M [00:02<00:00, 16.1MB/s]


In [None]:
!pip install torch torchvision captum


Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from captum.attr import IntegratedGradients, GradientShap, Saliency, GuidedBackprop, Occlusion
from captum.metrics import infidelity, sensitivity_max
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

model = resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 10)  # Adjust the final layer for CIFAR-10
model = model.to(device)
model.load_state_dict(torch.load('resnet18_cifar10.pth'))
model.eval()

Files already downloaded and verified


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
def perturb_fn(inputs):
    noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float().to(inputs.device)
    return noise, inputs + noise

infidelity_ig_list, infidelity_gs_list, infidelity_saliency_list, infidelity_gbp_list, infidelity_occ_list = [], [], [], [], []
sensitivity_ig_list, sensitivity_gs_list, sensitivity_saliency_list, sensitivity_gbp_list, sensitivity_occ_list = [], [], [], [], []

samples_count = 100

for idx, (inputs, labels) in enumerate(test_loader):
    if idx >= samples_count:
      break

    inputs, labels = inputs.to(device), labels.to(device)

    output = model(inputs)
    _, predicted_class = torch.max(output, 1)

    # Apply Integrated Gradients
    ig = IntegratedGradients(model)
    baseline = torch.zeros_like(inputs)
    attributions_ig = ig.attribute(inputs, baselines=baseline, target=predicted_class.item(), return_convergence_delta=False)

    # Apply Gradient SHAP
    rand_img_dist = torch.cat([inputs * 0, inputs * 1])
    gs = GradientShap(model)
    attributions_gs = gs.attribute(inputs, baselines=rand_img_dist, target=predicted_class.item(), n_samples=100)

    # Apply Saliency
    saliency = Saliency(model)
    attributions_saliency = saliency.attribute(inputs, target=predicted_class.item())

    # Apply Guided Backpropagation
    guided_bp = GuidedBackprop(model)
    attributions_gbp = guided_bp.attribute(inputs, target=predicted_class.item())

    # Apply Occlusion
    occlusion = Occlusion(model)
    attributions_occ = occlusion.attribute(inputs, strides=(3, 8, 8), target=predicted_class.item(), sliding_window_shapes=(3, 15, 15), baselines=0)

    # Calculate infidelity
    infidelity_ig_list.append(infidelity(model, perturb_fn, inputs, attributions_ig, target=predicted_class.item(), n_perturb_samples=10).mean().item())
    infidelity_gs_list.append(infidelity(model, perturb_fn, inputs, attributions_gs, target=predicted_class.item(), n_perturb_samples=10).mean().item())
    infidelity_saliency_list.append(infidelity(model, perturb_fn, inputs, attributions_saliency, target=predicted_class.item(), n_perturb_samples=10).mean().item())
    infidelity_gbp_list.append(infidelity(model, perturb_fn, inputs, attributions_gbp, target=predicted_class.item(), n_perturb_samples=10).mean().item())
    infidelity_occ_list.append(infidelity(model, perturb_fn, inputs, attributions_occ, target=predicted_class.item(), n_perturb_samples=10).mean().item())

    # Calculate sensitivity
    sensitivity_ig_list.append(sensitivity_max(ig.attribute, inputs, target=predicted_class.item(), n_perturb_samples=10).mean().item())
    sensitivity_gs_list.append(sensitivity_max(gs.attribute, inputs, baselines=rand_img_dist, target=predicted_class.item(), n_perturb_samples=10).mean().item())
    sensitivity_saliency_list.append(sensitivity_max(saliency.attribute, inputs, target=predicted_class.item(), n_perturb_samples=10).mean().item())
    sensitivity_gbp_list.append(sensitivity_max(guided_bp.attribute, inputs, target=predicted_class.item(), n_perturb_samples=10).mean().item())
    sensitivity_occ_list.append(sensitivity_max(occlusion.attribute, inputs, strides=(3, 8, 8), target=predicted_class.item(), sliding_window_shapes=(3, 15, 15), baselines=0, n_perturb_samples=10).mean().item())

    if (idx + 1) % 10 == 0 or (idx + 1) == samples_count:
        print(f'Processed {idx + 1}/{samples_count} samples')

# Compute overall infidelity and sensitivity metrics
overall_infidelity_ig = np.mean(infidelity_ig_list)
overall_infidelity_gs = np.mean(infidelity_gs_list)
overall_infidelity_saliency = np.mean(infidelity_saliency_list)
overall_infidelity_gbp = np.mean(infidelity_gbp_list)
overall_infidelity_occ = np.mean(infidelity_occ_list)

overall_sensitivity_ig = np.mean(sensitivity_ig_list)
overall_sensitivity_gs = np.mean(sensitivity_gs_list)
overall_sensitivity_saliency = np.mean(sensitivity_saliency_list)
overall_sensitivity_gbp = np.mean(sensitivity_gbp_list)
overall_sensitivity_occ = np.mean(sensitivity_occ_list)

# Compute standard deviation for stability
stability_infidelity_ig = np.std(infidelity_ig_list)
stability_infidelity_gs = np.std(infidelity_gs_list)
stability_infidelity_saliency = np.std(infidelity_saliency_list)
stability_infidelity_gbp = np.std(infidelity_gbp_list)
stability_infidelity_occ = np.std(infidelity_occ_list)

stability_sensitivity_ig = np.std(sensitivity_ig_list)
stability_sensitivity_gs = np.std(sensitivity_gs_list)
stability_sensitivity_saliency = np.std(sensitivity_saliency_list)
stability_sensitivity_gbp = np.std(sensitivity_gbp_list)
stability_sensitivity_occ = np.std(sensitivity_occ_list)

# Print the overall metrics
print(f'Overall Infidelity of Integrated Gradients: {overall_infidelity_ig:.4f}')
print(f'Overall Infidelity of Gradient SHAP: {overall_infidelity_gs:.4f}')
print(f'Overall Infidelity of Saliency: {overall_infidelity_saliency:.4f}')
print(f'Overall Infidelity of Guided Backpropagation: {overall_infidelity_gbp:.4f}')
print(f'Overall Infidelity of Occlusion: {overall_infidelity_occ:.4f}')

print(f'Overall Sensitivity of Integrated Gradients: {overall_sensitivity_ig:.4f}')
print(f'Overall Sensitivity of Gradient SHAP: {overall_sensitivity_gs:.4f}')
print(f'Overall Sensitivity of Saliency: {overall_sensitivity_saliency:.4f}')
print(f'Overall Sensitivity of Guided Backpropagation: {overall_sensitivity_gbp:.4f}')
print(f'Overall Sensitivity of Occlusion: {overall_sensitivity_occ:.4f}')


Processed 10/100 samples
Processed 20/100 samples
Processed 30/100 samples
Processed 40/100 samples
Processed 50/100 samples
Processed 60/100 samples
Processed 70/100 samples
Processed 80/100 samples
Processed 90/100 samples
Processed 100/100 samples
Overall Infidelity of Integrated Gradients: 0.0006
Overall Infidelity of Gradient SHAP: 0.0004
Overall Infidelity of Saliency: 0.0007
Overall Infidelity of Guided Backpropagation: 0.0004
Overall Infidelity of Occlusion: 2.1232
Overall Sensitivity of Integrated Gradients: 0.5774
Overall Sensitivity of Gradient SHAP: 3.3291
Overall Sensitivity of Saliency: 0.5496
Overall Sensitivity of Guided Backpropagation: 0.5442
Overall Sensitivity of Occlusion: 0.0384


In [None]:
print('Overall Infidelity Metrics:')
print(f'Integrated Gradients: {overall_infidelity_ig:.4f} (Stability: {stability_infidelity_ig:.4f})')
print(f'Gradient SHAP: {overall_infidelity_gs:.4f} (Stability: {stability_infidelity_gs:.4f})')
print(f'Saliency: {overall_infidelity_saliency:.4f} (Stability: {stability_infidelity_saliency:.4f})')
print(f'Guided Backpropagation: {overall_infidelity_gbp:.4f} (Stability: {stability_infidelity_gbp:.4f})')
print(f'Occlusion: {overall_infidelity_occ:.4f} (Stability: {stability_infidelity_occ:.4f})')
print()
print('Overall Sensitivity Metrics:')
print(f'Integrated Gradients: {overall_sensitivity_ig:.4f} (Stability: {stability_sensitivity_ig:.4f})')
print(f'Gradient SHAP: {overall_sensitivity_gs:.4f} (Stability: {stability_sensitivity_gs:.4f})')
print(f'Saliency: {overall_sensitivity_saliency:.4f} (Stability: {stability_sensitivity_saliency:.4f})')
print(f'Guided Backpropagation: {overall_sensitivity_gbp:.4f} (Stability: {stability_sensitivity_gbp:.4f})')
print(f'Occlusion: {overall_sensitivity_occ:.4f} (Stability: {stability_sensitivity_occ:.4f})')


Overall Infidelity Metrics:
Integrated Gradients: 0.0006 (Stability: 0.0006)
Gradient SHAP: 0.0004 (Stability: 0.0004)
Saliency: 0.0007 (Stability: 0.0006)
Guided Backpropagation: 0.0004 (Stability: 0.0004)
Occlusion: 2.1232 (Stability: 2.0690)

Overall Sensitivity Metrics:
Integrated Gradients: 0.5774 (Stability: 0.0535)
Gradient SHAP: 3.3291 (Stability: 5.8345)
Saliency: 0.5496 (Stability: 0.0373)
Guided Backpropagation: 0.5442 (Stability: 0.0924)
Occlusion: 0.0384 (Stability: 0.0161)
