In [1]:
!pip install captum



In [2]:
from google.colab import drive
drive.mount('/content/drive')

import os
if not os.path.exists('/content/xai-stability-benchmark'):
    os.makedirs('/content/xai-stability-benchmark', exist_ok=True)

%cd /content/xai-stability-benchmark
!mkdir -p notebooks src data results figures

Mounted at /content/drive
/content/xai-stability-benchmark


In [3]:
import torch
import torchvision
from torchvision import models, transforms, datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance
import json
import pickle
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from scipy.stats import spearmanr
from captum.attr import IntegratedGradients, Saliency, LayerGradCam, LayerAttribution

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

model = models.resnet18(pretrained=True).to(device)
model.eval()
print("Model loaded")

Device: cuda
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 418MB/s]


Model loaded


In [4]:
DATA_DIR = '/content/drive/MyDrive/xai-stability-data'

class XAIMethod:
    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.method_name = "base"

    def preprocess(self, pil_image):
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        return transform(pil_image).unsqueeze(0).to(self.device)

    def postprocess_attribution(self, attribution_tensor):
        if isinstance(attribution_tensor, torch.Tensor):
            attr = attribution_tensor.squeeze().cpu().detach().numpy()
        else:
            attr = np.array(attribution_tensor).squeeze()

        if len(attr.shape) == 3:
            attr = np.mean(np.abs(attr), axis=0)
        else:
            attr = np.abs(attr)

        return attr

In [5]:
class IntegratedGradientsMethod(XAIMethod):
    def __init__(self, model, device='cuda'):
        super().__init__(model, device)
        self.method_name = "integrated_gradients"
        self.ig = IntegratedGradients(model)

    def generate_attribution(self, input_tensor, target_class):
        input_tensor = input_tensor.clone().detach().requires_grad_(True)
        attributions = self.ig.attribute(input_tensor, target=target_class, n_steps=50)
        return self.postprocess_attribution(attributions)

In [6]:
class GradCAMMethod(XAIMethod):
    def __init__(self, model, device='cuda'):
        super().__init__(model, device)
        self.method_name = "gradcam"
        self.gradcam = LayerGradCam(model, model.layer4[-1])

    def generate_attribution(self, input_tensor, target_class):
        attributions = self.gradcam.attribute(input_tensor, target=target_class)
        attributions = LayerAttribution.interpolate(attributions, input_tensor.shape[-2:])
        return self.postprocess_attribution(attributions)

In [7]:
class VanillaGradientsMethod(XAIMethod):
    def __init__(self, model, device='cuda'):
        super().__init__(model, device)
        self.method_name = "vanilla_gradients"
        self.saliency = Saliency(model)

    def generate_attribution(self, input_tensor, target_class):
        input_tensor = input_tensor.clone().detach().requires_grad_(True)
        attributions = self.saliency.attribute(input_tensor, target=target_class)
        return self.postprocess_attribution(attributions)

In [8]:
class PerturbationGenerator:
    def __init__(self, image):
        if isinstance(image, torch.Tensor):
            image = transforms.ToPILImage()(image)
        self.original_image = image

    def rotate(self, angle):
        return self.original_image.rotate(angle, fillcolor=(128, 128, 128))

    def translate(self, x_shift, y_shift):
        return self.original_image.transform(
            self.original_image.size, Image.AFFINE,
            (1, 0, x_shift, 0, 1, y_shift), fillcolor=(128, 128, 128))

    def add_gaussian_noise(self, sigma=0.01):
        img_array = np.array(self.original_image).astype(np.float32) / 255.0
        noise = np.random.normal(0, sigma, img_array.shape)
        noisy = np.clip(img_array + noise, 0, 1)
        return Image.fromarray((noisy * 255).astype(np.uint8))

    def adjust_brightness(self, factor):
        enhancer = ImageEnhance.Brightness(self.original_image)
        return enhancer.enhance(factor)

    def jpeg_compression(self, quality):
        from io import BytesIO
        buffer = BytesIO()
        self.original_image.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)
        return Image.open(buffer)

    def generate_all_perturbations(self, config=None):
        if config is None:
            config = {
                'rotation': [-5, 5],
                'translation': [(-10, -10), (10, 10)],
                'noise_sigma': [0.01],
                'brightness': [0.9, 1.1],
                'jpeg_quality': [85]
            }

        perturbations = {'original': self.original_image}

        for angle in config['rotation']:
            perturbations[f'rotate_{angle}'] = self.rotate(angle)
        for i, (x, y) in enumerate(config['translation']):
            perturbations[f'translate_{i}'] = self.translate(x, y)
        for sigma in config['noise_sigma']:
            perturbations[f'noise_{sigma}'] = self.add_gaussian_noise(sigma)
        for factor in config['brightness']:
            perturbations[f'brightness_{factor}'] = self.adjust_brightness(factor)
        for quality in config['jpeg_quality']:
            perturbations[f'jpeg_{quality}'] = self.jpeg_compression(quality)

        return perturbations

In [9]:
class SimilarityMetrics:
    @staticmethod
    def normalize_attribution(attr):
        attr_min = attr.min()
        attr_max = attr.max()
        if attr_max - attr_min > 0:
            return (attr - attr_min) / (attr_max - attr_min)
        return attr

    @staticmethod
    def ssim_similarity(attr1, attr2):
        attr1_norm = SimilarityMetrics.normalize_attribution(attr1)
        attr2_norm = SimilarityMetrics.normalize_attribution(attr2)

        if attr1_norm.shape != attr2_norm.shape:
            raise ValueError(f"Shape mismatch: {attr1_norm.shape} vs {attr2_norm.shape}")

        score, _ = ssim(attr1_norm, attr2_norm, full=True, data_range=1.0)
        return score

    @staticmethod
    def spearman_similarity(attr1, attr2):
        flat1 = attr1.flatten()
        flat2 = attr2.flatten()
        corr, _ = spearmanr(flat1, flat2)

        if np.isnan(corr):
            corr = 0.0

        return corr

    @staticmethod
    def top_k_overlap(attr1, attr2, k=100):
        flat1 = attr1.flatten()
        flat2 = attr2.flatten()

        k = min(k, len(flat1))
        top_k1 = set(np.argsort(flat1)[-k:])
        top_k2 = set(np.argsort(flat2)[-k:])

        intersection = len(top_k1 & top_k2)
        union = len(top_k1 | top_k2)

        return intersection / union if union > 0 else 0.0

    @staticmethod
    def combined_similarity(attr1, attr2, weights=None):
        if weights is None:
            weights = {'ssim': 1/3, 'spearman': 1/3, 'topk': 1/3}

        ssim_score = SimilarityMetrics.ssim_similarity(attr1, attr2)
        spearman_score = SimilarityMetrics.spearman_similarity(attr1, attr2)
        topk_score = SimilarityMetrics.top_k_overlap(attr1, attr2, k=100)

        ssim_normalized = (ssim_score + 1) / 2
        spearman_normalized = (spearman_score + 1) / 2

        combined = (
            weights['ssim'] * ssim_normalized +
            weights['spearman'] * spearman_normalized +
            weights['topk'] * topk_score
        )

        return combined

In [10]:
class FASSCalculator:
    def __init__(self, xai_method, similarity_weights=None):
        self.xai_method = xai_method
        self.similarity_weights = similarity_weights

    def compute_fass_for_image(self, original_image, perturbations_dict, target_class):
        attributions = {}

        for name, img in perturbations_dict.items():
            input_tensor = self.xai_method.preprocess(img)
            attr = self.xai_method.generate_attribution(input_tensor, target_class)
            attributions[name] = attr

        names = list(attributions.keys())
        similarities = []

        for i in range(len(names)):
            for j in range(i + 1, len(names)):
                sim = SimilarityMetrics.combined_similarity(
                    attributions[names[i]],
                    attributions[names[j]],
                    weights=self.similarity_weights
                )
                similarities.append({
                    'pair': (names[i], names[j]),
                    'similarity': sim
                })

        fass_score = np.mean([s['similarity'] for s in similarities])

        return {
            'fass_score': fass_score,
            'pairwise_similarities': similarities,
            'attributions': attributions,
            'num_perturbations': len(perturbations_dict)
        }

In [11]:
print("Loading CIFAR-10 dataset...")
transform_simple = transforms.Compose([transforms.ToTensor()])
cifar_test = datasets.CIFAR10(root=f'{DATA_DIR}/cifar10', train=False, download=False, transform=transform_simple)
print(f"Loaded {len(cifar_test)} test images")

NUM_IMAGES = 500
BATCH_SIZE = 50

methods = {
    'integrated_gradients': IntegratedGradientsMethod(model, device),
    'gradcam': GradCAMMethod(model, device),
    'vanilla_gradients': VanillaGradientsMethod(model, device)
}

results_all_methods = {}

for method_name, xai_method in methods.items():
    print(f"\nProcessing method: {method_name}")
    print("="*60)

    fass_calc = FASSCalculator(xai_method)
    all_fass_scores = []
    all_detailed_results = []

    for batch_start in range(0, NUM_IMAGES, BATCH_SIZE):
        batch_end = min(batch_start + BATCH_SIZE, NUM_IMAGES)
        print(f"\nBatch {batch_start//BATCH_SIZE + 1}: Images {batch_start}-{batch_end}")

        for img_idx in tqdm(range(batch_start, batch_end)):
            try:
                image_tensor, label = cifar_test[img_idx]
                image_pil = transforms.ToPILImage()(image_tensor)

                input_tensor = xai_method.preprocess(image_pil)
                with torch.no_grad():
                    output = model(input_tensor)
                    pred_class = output.argmax(dim=1).item()

                perturb_gen = PerturbationGenerator(image_pil)
                perturbations = perturb_gen.generate_all_perturbations()

                result = fass_calc.compute_fass_for_image(image_pil, perturbations, pred_class)
                all_fass_scores.append(result['fass_score'])
                all_detailed_results.append({
                    'image_idx': img_idx,
                    'fass_score': result['fass_score'],
                    'true_label': label,
                    'pred_class': pred_class
                })

            except Exception as e:
                print(f"\nError on image {img_idx}: {str(e)}")
                continue

        if (batch_start + BATCH_SIZE) % 100 == 0:
            checkpoint = {
                'method': method_name,
                'scores': all_fass_scores,
                'results': all_detailed_results
            }
            with open(f'results/checkpoint_{method_name}_{batch_start}.pkl', 'wb') as f:
                pickle.dump(checkpoint, f)

    results_all_methods[method_name] = {
        'mean_fass': np.mean(all_fass_scores),
        'std_fass': np.std(all_fass_scores),
        'median_fass': np.median(all_fass_scores),
        'min_fass': np.min(all_fass_scores),
        'max_fass': np.max(all_fass_scores),
        'all_scores': all_fass_scores,
        'detailed_results': all_detailed_results
    }





Loading CIFAR-10 dataset...
Loaded 10000 test images

Processing method: integrated_gradients

Batch 1: Images 0-50


100%|██████████| 50/50 [00:31<00:00,  1.58it/s]



Batch 2: Images 50-100


100%|██████████| 50/50 [00:30<00:00,  1.65it/s]



Batch 3: Images 100-150


100%|██████████| 50/50 [00:30<00:00,  1.65it/s]



Batch 4: Images 150-200


100%|██████████| 50/50 [00:31<00:00,  1.61it/s]



Batch 5: Images 200-250


100%|██████████| 50/50 [00:30<00:00,  1.64it/s]



Batch 6: Images 250-300


100%|██████████| 50/50 [00:30<00:00,  1.63it/s]



Batch 7: Images 300-350


100%|██████████| 50/50 [00:30<00:00,  1.65it/s]



Batch 8: Images 350-400


100%|██████████| 50/50 [00:30<00:00,  1.64it/s]



Batch 9: Images 400-450


100%|██████████| 50/50 [00:30<00:00,  1.65it/s]



Batch 10: Images 450-500


100%|██████████| 50/50 [00:30<00:00,  1.64it/s]



Processing method: gradcam

Batch 1: Images 0-50


100%|██████████| 50/50 [00:11<00:00,  4.54it/s]



Batch 2: Images 50-100


100%|██████████| 50/50 [00:11<00:00,  4.51it/s]



Batch 3: Images 100-150


100%|██████████| 50/50 [00:10<00:00,  4.57it/s]



Batch 4: Images 150-200


100%|██████████| 50/50 [00:10<00:00,  4.55it/s]



Batch 5: Images 200-250


100%|██████████| 50/50 [00:10<00:00,  4.56it/s]



Batch 6: Images 250-300


100%|██████████| 50/50 [00:10<00:00,  4.59it/s]



Batch 7: Images 300-350


100%|██████████| 50/50 [00:11<00:00,  4.51it/s]



Batch 8: Images 350-400


100%|██████████| 50/50 [00:10<00:00,  4.59it/s]



Batch 9: Images 400-450


100%|██████████| 50/50 [00:11<00:00,  4.50it/s]



Batch 10: Images 450-500


100%|██████████| 50/50 [00:11<00:00,  4.48it/s]



Processing method: vanilla_gradients

Batch 1: Images 0-50


100%|██████████| 50/50 [00:26<00:00,  1.86it/s]



Batch 2: Images 50-100


100%|██████████| 50/50 [00:26<00:00,  1.88it/s]



Batch 3: Images 100-150


100%|██████████| 50/50 [00:26<00:00,  1.88it/s]



Batch 4: Images 150-200


100%|██████████| 50/50 [00:27<00:00,  1.84it/s]



Batch 5: Images 200-250


100%|██████████| 50/50 [00:27<00:00,  1.83it/s]



Batch 6: Images 250-300


100%|██████████| 50/50 [00:27<00:00,  1.85it/s]



Batch 7: Images 300-350


100%|██████████| 50/50 [00:27<00:00,  1.84it/s]



Batch 8: Images 350-400


100%|██████████| 50/50 [00:27<00:00,  1.85it/s]



Batch 9: Images 400-450


100%|██████████| 50/50 [00:27<00:00,  1.84it/s]



Batch 10: Images 450-500


100%|██████████| 50/50 [00:26<00:00,  1.86it/s]


In [13]:
print(f"\n{method_name} Summary:")
print(f"Mean FASS: {results_all_methods[method_name]['mean_fass']:.4f}")
print(f"Std FASS: {results_all_methods[method_name]['std_fass']:.4f}")


vanilla_gradients Summary:
Mean FASS: 0.4208
Std FASS: 0.0187


In [14]:
print("\nFinal Results:")
print("="*60)
for method_name, stats in results_all_methods.items():
    print(f"{method_name:25s} FASS: {stats['mean_fass']:.4f} ± {stats['std_fass']:.4f}")

with open('results/final_results.pkl', 'wb') as f:
    pickle.dump(results_all_methods, f)

with open('results/final_results.json', 'w') as f:
    results_serializable = {
        method: {
            'mean_fass': float(stats['mean_fass']),
            'std_fass': float(stats['std_fass']),
            'median_fass': float(stats['median_fass']),
            'min_fass': float(stats['min_fass']),
            'max_fass': float(stats['max_fass'])
        }
        for method, stats in results_all_methods.items()
    }
    json.dump(results_serializable, f, indent=2)

print("\nResults saved to results/final_results.pkl and results/final_results.json")
print("Experiment complete")


Final Results:
integrated_gradients      FASS: 0.4613 ± 0.0241
gradcam                   FASS: 0.5576 ± 0.0183
vanilla_gradients         FASS: 0.4208 ± 0.0187

Results saved to results/final_results.pkl and results/final_results.json
Experiment complete
