In [None]:
import os
import glob
from PIL import Image

# (image_path, mask_path) pairs for every JPEG that has a SAM mask.
jpeg_files_with_masks = []
# Not filled in this cell; kept because the name is declared here in the original.
jpeg_files_with_student_masks = []
extract_dir = "/content/drive/MyDrive/research/Splits"
sam_mask_dir = "/content/drive/MyDrive/research/sam_masks"

# Walk split_1 ... split_10; inside each split every sub-folder of
# "imagenet_data" is scanned for *.JPEG files.
for slip_id in range(1, 11):
    slip_path = os.path.join(extract_dir, f"split_{slip_id}", "imagenet_data")

    # glob returns entries in filesystem order; sorting keeps the pairing
    # (and therefore every downstream result table) in a reproducible order.
    nested_folders = sorted(glob.glob(os.path.join(slip_path, "*")))
    for folder in nested_folders:
        files = sorted(glob.glob(os.path.join(folder, "*.JPEG")))
        for jpeg_path in files:
            # splitext removes only the trailing extension, whereas
            # str.replace would also strip ".JPEG" from the middle of a name.
            filename = os.path.splitext(os.path.basename(jpeg_path))[0]
            # The mask shares the image's base name, with a .png extension.
            mask_path = os.path.join(sam_mask_dir, f"{filename}.png")

            # Keep only images for which a mask file actually exists.
            if os.path.exists(mask_path):
                jpeg_files_with_masks.append((jpeg_path, mask_path))

print(f"Paired {len(jpeg_files_with_masks)} images with SAM masks.")

Paired 176 images with SAM masks.


In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2
import os
import random
import gc
import copy
import pickle
import sys

from PIL import Image
from time import time
from dataclasses import dataclass, asdict
from google.colab import drive
from IPython.display import clear_output
from collections import Counter
from itertools import combinations

import matplotlib.pyplot as plt
from matplotlib import rcParams

from torchvision import models
import torchvision.transforms as transforms
from torchvision.datasets import VOCSegmentation
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import torch

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, jaccard_score, f1_score, accuracy_score

from skimage.io import imread
from skimage.transform import resize

In [None]:
# Fetch the Transformer-Explainability repository; the `baselines.ViT...`
# imports used further down are presumably resolved from inside it — confirm.
!git clone https://github.com/hila-chefer/Transformer-Explainability.git

# Make the cloned repository the working directory.
# NOTE(review): the clone and the chdir are both relative to the current
# working directory, so re-running this cell nests a second clone inside the first.
os.chdir(f'./Transformer-Explainability')

# einops is not imported directly in this notebook; presumably a dependency
# of the cloned code — confirm. The version is not pinned.
!pip install einops

Cloning into 'Transformer-Explainability'...
remote: Enumerating objects: 386, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 386 (delta 3), reused 2 (delta 2), pack-reused 381 (from 2)[K
Receiving objects: 100% (386/386), 3.85 MiB | 33.69 MiB/s, done.
Resolving deltas: 100% (194/194), done.


In [None]:
def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` onto ``img``.

    ``mask`` is multiplied by 255 and cast to uint8 before colour-mapping.
    The colour map (rescaled by 1/255) is added to ``img`` as float32 and the
    sum is divided by its own maximum, so the returned array peaks at 1.0.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    return overlay / np.max(overlay)

In [None]:
from baselines.ViT.ViT_explanation_generator import LRP
from baselines.ViT.ViT_explanation_generator import Baselines
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_LRP_new
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from torchvision import transforms

# Two pretrained ViT-B/16 instances: model_A backs the baseline explainers
# (rollout, attention CAM), model_B backs LRP and the gradient saliency.
model_A = vit_LRP_new(pretrained=True).cuda()
model_B = vit_LRP(pretrained=True).cuda()

# Attributions are computed on fixed, pretrained networks: switch both to
# inference mode so train-only layers (e.g. dropout) cannot perturb the
# forward passes being explained.
model_A.eval()
model_B.eval()

b = Baselines(model_A)
attribution_generator = LRP(model_B)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth" to /root/.cache/torch/hub/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth


In [None]:
def generate_LRP(original_image, class_index=None):
    """Return the detached "transformer_attribution" LRP map for one image.

    The image is given a batch dimension and moved to CUDA before being passed
    to the module-level ``attribution_generator``; ``class_index`` is forwarded
    as its ``index`` argument.
    """
    batch = original_image.unsqueeze(0).cuda()
    relevance = attribution_generator.generate_LRP(
        batch,
        method="transformer_attribution",
        index=class_index,
    )
    return relevance.detach()

def generate_saliency(original_image, class_index=None):
    """Gradient saliency of ``model_B`` for one image, resized to 14x14.

    Parameters:
    - original_image: image tensor (C, H, W); it is not modified.
    - class_index: logit to differentiate; when None the largest logit is used.

    Returns:
    - Tensor of shape (1, 1, 14, 14): per-pixel maximum over channels of the
      absolute input gradient, bilinearly resized.
    """
    # Differentiate with respect to a private leaf copy. Calling
    # requires_grad_() on the caller's tensor would leave it flagged for
    # autograd and let its .grad accumulate across repeated calls on the
    # same image (the combination loops call this several times per image).
    image = original_image.detach().clone().requires_grad_(True)

    output = model_B(image.unsqueeze(0).cuda())
    loss = output[0, class_index] if class_index is not None else output.max()
    model_B.zero_grad()
    loss.backward()

    # Collapse channels: keep the strongest absolute gradient per pixel.
    saliency = image.grad.detach().abs().max(dim=0, keepdim=True)[0]
    saliency = torch.nn.functional.interpolate(saliency.unsqueeze(0), size=(14, 14), mode='bilinear')
    return saliency

def generate_rollout(input_image,class_index=None, start_layer=3):
    """Return the rollout map from the module-level baseline explainer ``b``.

    ``class_index`` is accepted so that all generators share one call
    signature; it is not used here. ``start_layer`` is forwarded unchanged.
    """
    batch = input_image.unsqueeze(0).cuda()
    return b.generate_rollout(batch, start_layer=start_layer)

def generate_CAM(input_image, class_index=None):
    """Return the attention-CAM map from the module-level baseline explainer ``b``.

    The image is batched and moved to CUDA; ``class_index`` is forwarded as
    the explainer's ``index`` argument.
    """
    batch = input_image.unsqueeze(0).cuda()
    return b.generate_cam_attn(batch, index=class_index)

# Utility function to combine attributions and visualize
def combine_and_visualize_attributions_1way(input_image, method, use_thresholding=True):
    """Turn one attribution method's 14x14 map into a 224x224 mask and overlay.

    Parameters:
    - input_image: image tensor (C, H, W); the overlay assumes 224x224.
    - method: callable taking the image and returning a tensor with 196
      elements (it is reshaped to 1x1x14x14).
    - use_thresholding: when True the map is binarised with Otsu's threshold
      (values 0/1, uint8); otherwise the min-max normalised float map is kept.

    Returns:
    - vis: uint8 overlay image, converted with COLOR_RGB2BGR.
    - combined_attr: the 224x224 attribution map / mask.
    """
    def _unit_range(values):
        # Min-max rescale to [0, 1]. A constant (or NaN-ranged) array has no
        # usable range; return zeros instead of dividing by zero and feeding
        # NaNs into the uint8 cast / Otsu threshold below.
        span = values.max() - values.min()
        if span > 0:
            return (values - values.min()) / span
        return np.zeros_like(values)

    device = input_image.device
    attr = method(input_image).reshape(1, 1, 14, 14).to(device)

    # 14x14 patch grid -> 224x224 pixel grid.
    combined_attr = torch.nn.functional.interpolate(attr, scale_factor=16, mode='bilinear')
    combined_attr = combined_attr.reshape(224, 224).cpu().detach().numpy()
    combined_attr = _unit_range(combined_attr)

    if use_thresholding:
        combined_attr = combined_attr * 255
        combined_attr = combined_attr.astype(np.uint8)
        # The explicit threshold (0) is ignored when THRESH_OTSU is set.
        _, combined_attr = cv2.threshold(combined_attr, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        combined_attr[combined_attr == 255] = 1

    # Bring the (normalised) input image back to [0, 1] for display.
    image_transformer_attribution = input_image.permute(1, 2, 0).cpu().detach().numpy()
    image_transformer_attribution = _unit_range(image_transformer_attribution)
    vis = show_cam_on_image(image_transformer_attribution, combined_attr)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis, combined_attr

def compute_metrics(mask, gt):
    """Overlap metrics between a predicted mask and a ground-truth mask.

    Both inputs are treated as binary: any non-zero element counts as
    foreground, so 0/1, 0/255 and boolean masks all give the same result.

    Parameters:
    - mask: predicted mask (array-like).
    - gt: ground-truth mask of the same shape.

    Returns:
    - jaccard: intersection over union (0 when both masks are empty).
    - f1: 2*TP / (2*TP + FP + FN) (0 when the denominator is 0).
    - pix_acc: fraction of ground-truth foreground covered by the mask,
      i.e. TP / |gt| (0 when the ground truth is empty).
    """
    # Binarise first: the pixel counts below must count pixels, not sum
    # intensities (a 0/255 mask would otherwise inflate the false positives).
    mask = np.asarray(mask).astype(bool)
    gt = np.asarray(gt).astype(bool)

    inter = np.logical_and(gt, mask).sum()
    union = np.logical_or(gt, mask).sum()
    gt_pixels = gt.sum()

    tp = inter
    fp = mask.sum() - tp
    fn = gt_pixels - tp
    denom = 2 * tp + fp + fn

    jaccard = float(inter / union) if union else 0.0
    f1 = float(2 * tp / denom) if denom > 0 else 0.0
    pix_acc = float(inter / gt_pixels) if gt_pixels > 0 else 0.0
    return jaccard, f1, pix_acc

# Function to visualize each method with different combine methods
def visualize_methods_1way(input_image, use_thresholding=True):
    """Run every single attribution method on one image.

    The class predicted by ``model_A`` is explained by all methods, matching
    what the 2-way combination path does.

    Parameters:
    - input_image: image tensor (C, H, W).
    - use_thresholding: forwarded to combine_and_visualize_attributions_1way.

    Returns:
    - List of (method_name, overlay_image, mask) tuples, one per method, in
      the order LRP, saliency, rollout, CAM.
    """
    methods = {
        'LRP': generate_LRP,
        'saliency': generate_saliency,
        'rollout': generate_rollout,
        'CAM': generate_CAM,
    }

    # Determine the predicted class index
    output = model_A(input_image.unsqueeze(0).cuda())
    class_index = output.argmax().item()

    results = []
    for method_name, method_func in methods.items():
        # Pass the predicted class explicitly (previously it was computed and
        # then dropped, leaving each method to pick its own target class).
        # `fn=method_func` binds the current method to the lambda.
        vis, mask = combine_and_visualize_attributions_1way(
            input_image,
            lambda img, fn=method_func: fn(img, class_index),
            use_thresholding,
        )
        results.append((f"{method_name}", vis, mask))

    return results

In [None]:
# Preprocessing for the ViT: RGB array -> PIL -> 224x224 -> tensor, then
# (x - 0.5) / 0.5 on every channel.
transform_ = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

# One row per (image, method) with the overlap metrics against the SAM mask.
all_results_one_way = []
for img_path, mask_paths in tqdm(jpeg_files_with_masks, desc="Processing images with masks"):
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None; fail with the path
        # instead of an opaque error inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tens_ = transform_(img_rgb)

    # Ground truth: grey-scale mask resized to the model resolution; any
    # non-zero pixel counts as foreground.
    true_mask = Image.open(mask_paths).convert("L")
    true_mask_resized = transforms.Resize((224, 224))(true_mask)
    true_mask_np = (np.array(true_mask_resized) > 0).astype(np.uint8)

    results = visualize_methods_1way(tens_, use_thresholding=True)
    for name, result, mask in results:
        iou, f1, px = compute_metrics(mask, true_mask_np)
        all_results_one_way.append({
            "Image Path": img_path,
            "Method": name,
            "Jaccard Index (IoU)": iou,
            "F1 Score": f1,
            "Pixel Accuracy": px
        })
Processing images with masks: 100%|██████████| 176/176 [03:31<00:00,  1.20s/it]


In [None]:
# Persist the per-image 1-way rows, then average each metric per method.
results_df = pd.DataFrame(all_results_one_way)
results_df.to_csv("imagenet_sam_1way_total.csv", index=False)

stats = (
    results_df
    .groupby("Method")[["Jaccard Index (IoU)", "F1 Score", "Pixel Accuracy"]]
    .mean()
    .reset_index()
)

print(stats)
stats.to_csv("imagenetsam_1way.csv", index=False)


     Method  Jaccard Index (IoU)  F1 Score  Pixel Accuracy
0       CAM             0.143344  0.212258        0.193979
1       LRP             0.425939  0.567156        0.533221
2   rollout             0.365452  0.507887        0.663221
3  saliency             0.079444  0.127842        0.133708


In [None]:
import numpy as np
import torch
import cv2
from sklearn.metrics import auc

def deletion_metric(model, image, attribution_map, class_index=None, steps=100):
    """
    Computes the Deletion Metric for a given attribution map.

    Pixels are removed from the image in decreasing order of attribution and
    the model's softmax confidence for the tracked class is recorded after
    every removal step.

    Parameters:
    - model: Trained model used for classification (put into eval mode here).
    - image: Input image tensor (C, H, W).
    - attribution_map: Array with one score per pixel (H * W elements).
    - class_index: Class to track; defaults to the model's prediction on the
      unmodified image.
    - steps: Number of deletion steps (positive integer).

    Returns:
    - auc_score: Area under the confidence curve (lower = better attribution).
    - confidence_drop: List of steps + 1 confidences (initial one first).

    Raises:
    - ValueError: if steps < 1 or the map size does not match the image.

    Notes:
    - Each step removes total_pixels // steps pixels, so up to steps - 1 of
      the least important pixels are never removed.
    - "Removed" pixels are set to 0 in the model's input space.
      NOTE(review): with the 0.5/0.5 normalisation used by `transform_` in
      this notebook that value is mid-grey, not black — confirm it is the
      intended baseline.
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    model.eval()
    # Run on the device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device

    scores = np.asarray(attribution_map)
    if scores.dtype.kind in "ub":
        # Negating unsigned/boolean data wraps around (or is rejected), which
        # would scramble the ranking; rank such maps as floats.
        scores = scores.astype(np.float32)

    # Working copy of the image in (H, W, C) layout.
    image_np = image.permute(1, 2, 0).detach().cpu().numpy()
    height, width = image_np.shape[:2]
    total_pixels = height * width
    if scores.size != total_pixels:
        raise ValueError(
            f"attribution_map has {scores.size} elements, expected {total_pixels} ({height}x{width})"
        )

    # Flat pixel indices, most important first.
    importance_order = np.argsort(-scores.flatten())
    modified_image = image_np.copy()

    # Initial model confidence before deletion
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(device))
        if class_index is None:
            class_index = output.argmax().item()
        initial_confidence = torch.softmax(output, dim=1)[0, class_index].item()

    confidence_drop = [initial_confidence]
    pixels_per_step = total_pixels // steps

    for step in range(1, steps + 1):
        # Next chunk of most-important pixels (deletions are cumulative).
        pixels_to_mask = importance_order[(step - 1) * pixels_per_step: step * pixels_per_step]

        # Vectorised blank-out across all channels; equivalent to converting
        # each flat index with divmod(idx, width) one pixel at a time.
        rows, cols = np.unravel_index(pixels_to_mask, (height, width))
        modified_image[rows, cols, :] = 0

        modified_image_tensor = torch.from_numpy(modified_image).permute(2, 0, 1).float().to(device)

        # Recalculate model confidence
        with torch.no_grad():
            output = model(modified_image_tensor.unsqueeze(0))
            confidence = torch.softmax(output, dim=1)[0, class_index].item()

        confidence_drop.append(confidence)

    # Area under the confidence curve; x is the step position scaled to [0, 1].
    x_axis = np.linspace(0, 1, len(confidence_drop))
    auc_score = auc(x_axis, confidence_drop)

    return auc_score, confidence_drop

In [None]:
# One row per (image, method) with the deletion-curve AUC of the raw
# (non-thresholded) attribution map.
all_expl_results_one_way = []

for img_path, mask_paths in tqdm(jpeg_files_with_masks, desc="Processing images with masks"):
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tens_ = transform_(img_rgb)

    # The deletion metric needs only the image and the attribution map, so
    # the SAM mask (mask_paths) is not loaded here.
    results = visualize_methods_1way(tens_, use_thresholding=False)
    for name, result, mask in results:
        auccc, _ = deletion_metric(model_A, tens_, mask)
        all_expl_results_one_way.append({
            "Image Index": img_path,
            "Method": name,
            "Deletion Accuracy": auccc
        })
Processing images with masks: 100%|██████████| 176/176 [21:53<00:00,  7.46s/it]


In [None]:
def save_and_display_results1(file_name_for_saving, results):
    """Write deletion-metric rows to CSV and print the per-method mean.

    Parameters:
    - file_name_for_saving: path of the CSV file to write.
    - results: list of dicts with at least the keys "Method" and
      "Deletion Accuracy"; may be empty.
    """
    results_df = pd.DataFrame(results)
    csv_path = file_name_for_saving
    results_df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    print("Statistics by Method and Combine Method:")
    if results_df.empty:
        # A frame built from no rows has no "Method" column, so groupby
        # would raise KeyError; report an empty table instead.
        stats = pd.DataFrame(columns=["Deletion Accuracy"])
    else:
        stats = results_df.groupby("Method")[["Deletion Accuracy"]].mean()
    print(stats)

# Save the per-image 1-way deletion scores and print the per-method mean.
save_and_display_results1("NETmetrics_expl_results_1WAY_NOThresholding.csv", all_expl_results_one_way)

Results saved to NETmetrics_expl_results_1WAY_NOThresholding.csv
Statistics by Method and Combine Method:
          Deletion Accuracy
Method                     
CAM                0.401487
LRP                0.190977
rollout            0.238535
saliency           0.443203


# 2-way: combinations of two attribution methods

In [None]:
def combine_and_visualize_attributions_2way(input_image, method1, method2, combine_method='sqrt', use_thresholding=True):
    """Combine two 14x14 attribution maps into a 224x224 mask and overlay.

    Parameters:
    - input_image: image tensor (C, H, W); the overlay assumes 224x224.
    - method1, method2: callables taking the image and returning a tensor
      with 196 elements (each is reshaped to 1x1x14x14).
    - combine_method: 'multiply' (element-wise product) or 'sqrt' (square
      root of the product).
    - use_thresholding: when True the combined map is binarised with Otsu's
      threshold (values 0/1, uint8); otherwise the normalised float map is kept.

    Returns:
    - vis: uint8 overlay image, converted with COLOR_RGB2BGR.
    - combined_attr: the 224x224 combined map / mask.

    Raises:
    - ValueError: if combine_method is not 'sqrt' or 'multiply'.
    """
    if combine_method not in ('sqrt', 'multiply'):
        # Previously an unknown value fell through both branches and failed
        # later with an unbound-variable error.
        raise ValueError(f"combine_method must be 'sqrt' or 'multiply', got {combine_method!r}")

    def _unit_range(values):
        # Min-max rescale to [0, 1]. A constant (or NaN-ranged) array has no
        # usable range; return zeros instead of dividing by zero.
        span = values.max() - values.min()
        if span > 0:
            return (values - values.min()) / span
        return np.zeros_like(values)

    device = input_image.device
    attr1 = method1(input_image).reshape(1, 1, 14, 14).to(device)
    attr2 = method2(input_image).reshape(1, 1, 14, 14).to(device)

    combined_attr = attr1 * attr2
    if combine_method == 'sqrt':
        combined_attr = torch.sqrt(combined_attr)

    # 14x14 patch grid -> 224x224 pixel grid.
    combined_attr = torch.nn.functional.interpolate(combined_attr, scale_factor=16, mode='bilinear')
    combined_attr = combined_attr.reshape(224, 224).cpu().detach().numpy()
    combined_attr = _unit_range(combined_attr)

    if use_thresholding:
        combined_attr = combined_attr * 255
        combined_attr = combined_attr.astype(np.uint8)
        # The explicit threshold (0) is ignored when THRESH_OTSU is set.
        _, combined_attr = cv2.threshold(combined_attr, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        combined_attr[combined_attr == 255] = 1

    # Bring the (normalised) input image back to [0, 1] for display.
    image_transformer_attribution = input_image.permute(1, 2, 0).cpu().detach().numpy()
    image_transformer_attribution = _unit_range(image_transformer_attribution)
    vis = show_cam_on_image(image_transformer_attribution, combined_attr)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis, combined_attr

def visualize_combined_methods_2way(input_image, method1_name, method2_name, class_index, combine_method='sqrt', use_thresholding=True):
    """Combine two attribution methods, selected by name, for one image.

    ``method1_name`` / ``method2_name`` must be keys of the table below
    ('LRP', 'saliency', 'rollout', 'CAM'). Both methods are called with
    ``class_index``; the result is whatever
    ``combine_and_visualize_attributions_2way`` returns.
    """
    generators = {
        'LRP': generate_LRP,
        'saliency': generate_saliency,
        'rollout': generate_rollout,
        'CAM': generate_CAM,
    }
    first = generators[method1_name]
    second = generators[method2_name]

    def run_first(img):
        return first(img, class_index)

    def run_second(img):
        return second(img, class_index)

    return combine_and_visualize_attributions_2way(
        input_image, run_first, run_second, combine_method, use_thresholding
    )

def visualize_all_combinations_2way(input_image, combine_methods=['sqrt', 'multiply'], use_thresholding=True):
    """Evaluate every pair of attribution methods under every combine rule.

    Parameters:
    - input_image: image tensor (C, H, W).
    - combine_methods: combine rules to apply to each pair (not modified).
    - use_thresholding: forwarded to combine_and_visualize_attributions_2way.

    Returns:
    - List of ("<m1> + <m2> (<rule>)", overlay_image, mask) tuples, pairs in
      itertools.combinations order over LRP, saliency, rollout, CAM.
    """
    generators = {
        'LRP': generate_LRP,
        'saliency': generate_saliency,
        'rollout': generate_rollout,
        'CAM': generate_CAM,
    }

    # Determine the predicted class index
    output = model_A(input_image.unsqueeze(0).cuda())
    class_index = output.argmax().item()

    # Each attribution map depends only on the image and the class, not on
    # its partner or the combine rule, so compute every map once per image
    # instead of once per (pair, rule).
    attributions = {
        name: generate(input_image, class_index)
        for name, generate in generators.items()
    }

    results = []
    for combo in combinations(generators, 2):
        first, second = attributions[combo[0]], attributions[combo[1]]
        for combine_method in combine_methods:
            # The combiner expects callables; hand it the precomputed maps.
            vis, mask = combine_and_visualize_attributions_2way(
                input_image,
                lambda img, attr=first: attr,
                lambda img, attr=second: attr,
                combine_method,
                use_thresholding,
            )
            results.append((f"{' + '.join(combo)} ({combine_method})", vis, mask))

    return results

# One row per (image, method pair, combine rule) with the overlap metrics
# against the SAM mask.
all_results_two_way = []
for img_path, mask_paths in tqdm(jpeg_files_with_masks, desc="Processing images with masks", mininterval=8.0):
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None; fail with the path
        # instead of an opaque error inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tens_ = transform_(img_rgb)

    # Ground truth: grey-scale mask resized to the model resolution; any
    # non-zero pixel counts as foreground.
    true_mask = Image.open(mask_paths).convert("L")
    true_mask_resized = transforms.Resize((224, 224))(true_mask)
    true_mask_np = (np.array(true_mask_resized) > 0).astype(np.uint8)

    results = visualize_all_combinations_2way(tens_, combine_methods=['sqrt', 'multiply'], use_thresholding=True)
    for name, result, mask in results:
        iou, f1, px = compute_metrics(mask, true_mask_np)
        all_results_two_way.append({
            "Image Path": img_path,
            "Method": name,
            "Jaccard Index (IoU)": iou,
            "F1 Score": f1,
            "Pixel Accuracy": px
        })
Processing images with masks: 100%|██████████| 176/176 [06:12<00:00,  2.12s/it]


In [None]:
# Persist the per-image 2-way rows, then average each metric per combination.
results_df = pd.DataFrame(all_results_two_way)
results_df.to_csv("imagenet_sam_2way_total.csv", index=False)

stats = (
    results_df
    .groupby("Method")[["Jaccard Index (IoU)", "F1 Score", "Pixel Accuracy"]]
    .mean()
    .reset_index()
)
print(stats)
stats.to_csv("imagenetsam_2way.csv", index=False)

                           Method  Jaccard Index (IoU)  F1 Score  \
0            LRP + CAM (multiply)             0.130902  0.200635   
1                LRP + CAM (sqrt)             0.216166  0.310348   
2        LRP + rollout (multiply)             0.412267  0.547230   
3            LRP + rollout (sqrt)             0.523251  0.657107   
4       LRP + saliency (multiply)             0.217715  0.327822   
5           LRP + saliency (sqrt)             0.369727  0.509925   
6        rollout + CAM (multiply)             0.134932  0.201772   
7            rollout + CAM (sqrt)             0.206309  0.290058   
8       saliency + CAM (multiply)             0.091899  0.143442   
9           saliency + CAM (sqrt)             0.163829  0.237325   
10  saliency + rollout (multiply)             0.118137  0.187263   
11      saliency + rollout (sqrt)             0.191790  0.293633   

    Pixel Accuracy  
0         0.159289  
1         0.277168  
2         0.489180  
3         0.687073  
4         

In [None]:
# One row per (image, method pair, combine rule) with the deletion-curve AUC
# of the raw (non-thresholded) combined map.
all_expl_results_two_way = []

for img_path, mask_paths in tqdm(jpeg_files_with_masks, desc="Processing images with masks"):
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tens_ = transform_(img_rgb)

    # The deletion metric needs only the image and the attribution map, so
    # the SAM mask (mask_paths) is not loaded here.
    results = visualize_all_combinations_2way(tens_, combine_methods=['sqrt', 'multiply'], use_thresholding=False)
    for name, result, mask in results:
        auccc, _ = deletion_metric(model_A, tens_, mask)
        all_expl_results_two_way.append({
            "Image Index": img_path,
            "Method": name,
            "Deletion Accuracy": auccc
        })
Processing images with masks: 100%|██████████| 176/176 [59:56<00:00, 20.43s/it]


In [None]:
# NOTE: this repeats the definition from the 1-way section of the notebook;
# the two are kept identical so either cell can be run first.
def save_and_display_results1(file_name_for_saving, results):
    """Write deletion-metric rows to CSV and print the per-method mean.

    Parameters:
    - file_name_for_saving: path of the CSV file to write.
    - results: list of dicts with at least the keys "Method" and
      "Deletion Accuracy"; may be empty.
    """
    results_df = pd.DataFrame(results)
    csv_path = file_name_for_saving
    results_df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")
    print("Statistics by Method and Combine Method:")
    if results_df.empty:
        # A frame built from no rows has no "Method" column, so groupby
        # would raise KeyError; report an empty table instead.
        stats = pd.DataFrame(columns=["Deletion Accuracy"])
    else:
        stats = results_df.groupby("Method")[["Deletion Accuracy"]].mean()
    print(stats)

# Save the per-image 2-way deletion scores and print the per-combination mean.
save_and_display_results1("NETmetrics_expl_results_2WAY_NOThresholding.csv", all_expl_results_two_way)

Results saved to NETmetrics_expl_results_2WAY_NOThresholding.csv
Statistics by Method and Combine Method:
                               Deletion Accuracy
Method                                          
LRP + CAM (multiply)                    0.378963
LRP + CAM (sqrt)                        0.373167
LRP + rollout (multiply)                0.186527
LRP + rollout (sqrt)                    0.180897
LRP + saliency (multiply)               0.242719
LRP + saliency (sqrt)                   0.230390
rollout + CAM (multiply)                0.390902
rollout + CAM (sqrt)                    0.385437
saliency + CAM (multiply)               0.409378
saliency + CAM (sqrt)                   0.401523
saliency + rollout (multiply)           0.349698
saliency + rollout (sqrt)               0.340655


# 3-way: combinations of three attribution methods

In [None]:
def combine_and_visualize_attributions_3way(input_image, methods, combine_method='sqrt', use_thresholding=True, class_index=None):
    """Combine several 14x14 attribution maps into a 224x224 mask and overlay.

    Parameters:
    - input_image: image tensor (C, H, W); the overlay assumes 224x224.
    - methods: non-empty sequence of attribution callables (three in the
      3-way experiments). Each must return a tensor with 196 elements.
    - combine_method: 'multiply' (element-wise product of all maps) or
      'sqrt' (square root of that product — a square root, not a cube root,
      even for three maps).
    - use_thresholding: when True the combined map is binarised with Otsu's
      threshold (values 0/1, uint8); otherwise the normalised float map is kept.
    - class_index: class to explain, passed to the class-aware generators;
      None leaves the choice of class to each generator.

    Returns:
    - vis: uint8 overlay image, converted with COLOR_RGB2BGR.
    - combined_attr: the 224x224 combined map / mask.

    Raises:
    - ValueError: if combine_method is unknown or methods is empty.
    """
    if combine_method not in ('sqrt', 'multiply'):
        raise ValueError(f"combine_method must be 'sqrt' or 'multiply', got {combine_method!r}")

    def _unit_range(values):
        # Min-max rescale to [0, 1]. A constant (or NaN-ranged) array has no
        # usable range; return zeros instead of dividing by zero.
        span = values.max() - values.min()
        if span > 0:
            return (values - values.min()) / span
        return np.zeros_like(values)

    device = input_image.device
    attributions = []
    for method in methods:
        if getattr(method, '__name__', '') in ['generate_saliency', 'generate_CAM', 'generate_LRP']:
            # Explain the requested class (this used to be hard-coded to
            # class 1 regardless of what the caller asked for).
            attr = method(input_image, class_index=class_index).reshape(1, 1, 14, 14).to(device)
        else:
            attr = method(input_image).reshape(1, 1, 14, 14).to(device)
        attributions.append(attr)

    if not attributions:
        raise ValueError("methods must contain at least one attribution method")

    # Left-to-right element-wise product of all maps.
    combined_attr = attributions[0]
    for attr in attributions[1:]:
        combined_attr = combined_attr * attr
    if combine_method == 'sqrt':
        combined_attr = torch.sqrt(combined_attr)

    # 14x14 patch grid -> 224x224 pixel grid.
    combined_attr = torch.nn.functional.interpolate(combined_attr, scale_factor=16, mode='bilinear')
    combined_attr = combined_attr.reshape(224, 224).cpu().detach().numpy()
    combined_attr = _unit_range(combined_attr)

    if use_thresholding:
        combined_attr = combined_attr * 255
        combined_attr = combined_attr.astype(np.uint8)
        # The explicit threshold (0) is ignored when THRESH_OTSU is set.
        _, combined_attr = cv2.threshold(combined_attr, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        combined_attr[combined_attr == 255] = 1

    # Bring the (normalised) input image back to [0, 1] for display.
    image_transformer_attribution = input_image.permute(1, 2, 0).cpu().detach().numpy()
    image_transformer_attribution = _unit_range(image_transformer_attribution)
    vis = show_cam_on_image(image_transformer_attribution, combined_attr)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis, combined_attr

# Name-based front end for the 3-way combiner.
def visualize_combined_methods_3way(input_image, method_names, class_index, combine_method='sqrt', use_thresholding=True):
    """Combine the attribution methods named in ``method_names`` for one image.

    Each name must be one of 'LRP', 'saliency', 'rollout', 'CAM'.
    ``class_index`` is forwarded so the class-aware generators explain that
    class. Returns the (overlay, mask) pair produced by
    combine_and_visualize_attributions_3way.
    """
    methods = {
        'LRP': generate_LRP,
        'saliency': generate_saliency,
        'rollout': generate_rollout,
        'CAM': generate_CAM,
    }
    selected_methods = [methods[name] for name in method_names]

    return combine_and_visualize_attributions_3way(
        input_image,
        selected_methods,
        combine_method,
        use_thresholding,
        class_index=class_index,
    )

# Function to visualize all 3-way combinations
def visualize_all_combinations_3way(input_image, combine_methods=['sqrt', 'multiply'], use_thresholding=True):
    """Evaluate every triple of attribution methods under every combine rule.

    The class predicted by ``model_A`` is handed to
    ``visualize_combined_methods_3way`` for each triple. Returns a list of
    ("<m1> + <m2> + <m3> (<rule>)", overlay_image, mask) tuples, triples in
    itertools.combinations order over LRP, saliency, rollout, CAM.
    """
    triples = list(combinations(['LRP', 'saliency', 'rollout', 'CAM'], 3))

    # Class predicted for the unmodified image.
    logits = model_A(input_image.unsqueeze(0).cuda())
    predicted = logits.argmax().item()

    collected = []
    for trio in triples:
        for rule in combine_methods:
            overlay, trio_mask = visualize_combined_methods_3way(
                input_image, trio, predicted, rule, use_thresholding
            )
            label = f"{' + '.join(trio)} ({rule})"
            collected.append((label, overlay, trio_mask))
    return collected


# One row per (image, method triple, combine rule) with the overlap metrics
# against the SAM mask.
all_results_three_way = []
for img_path, mask_paths in tqdm(jpeg_files_with_masks, desc="Processing images with masks", mininterval=8.0):
    # Image: OpenCV reads BGR, the transform expects RGB.
    img_bgr = cv2.imread(img_path)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tens_ = transform_(img_rgb)

    # Ground truth: grey-scale mask at model resolution, non-zero = foreground.
    true_mask = Image.open(mask_paths).convert("L")
    true_mask_resized = transforms.Resize((224, 224))(true_mask)
    true_mask_np = (np.array(true_mask_resized) > 0).astype(np.uint8)

    results = visualize_all_combinations_3way(
        tens_,
        combine_methods=['sqrt', 'multiply'],
        use_thresholding=True,
    )
    for name, result, mask in results:
        iou, f1, px = compute_metrics(mask, true_mask_np)
        row = {
            "Image Path": img_path,
            "Method": name,
            "Jaccard Index (IoU)": iou,
            "F1 Score": f1,
            "Pixel Accuracy": px,
        }
        all_results_three_way.append(row)

In [None]:
# Persist the per-image 3-way rows, then average each metric per combination.
results_df = pd.DataFrame(all_results_three_way)
results_df.to_csv("imagenet_sam_3way_total.csv", index=False)

stats = (
    results_df
    .groupby("Method")[["Jaccard Index (IoU)", "F1 Score", "Pixel Accuracy"]]
    .mean()
    .reset_index()
)
print(stats)
stats.to_csv("imagenetsam_3way.csv", index=False)


                                Method  Jaccard Index (IoU)  F1 Score  \
0       LRP + rollout + CAM (multiply)             0.218392  0.323411   
1           LRP + rollout + CAM (sqrt)             0.375814  0.505797   
2      LRP + saliency + CAM (multiply)             0.139217  0.216847   
3          LRP + saliency + CAM (sqrt)             0.278574  0.398569   
4  LRP + saliency + rollout (multiply)             0.201951  0.305029   
5      LRP + saliency + rollout (sqrt)             0.374451  0.515473   
6  saliency + rollout + CAM (multiply)             0.102894  0.159366   
7      saliency + rollout + CAM (sqrt)             0.222965  0.318297   

   Pixel Accuracy  
0        0.257308  
1        0.483598  
2        0.170882  
3        0.379368  
4        0.252577  
5        0.525543  
6        0.138337  
7        0.339636  


In [None]:
# One row per (image, method triple, combine rule) with the deletion-curve
# AUC of the raw (non-thresholded) combined map.
all_expl_results_three_way = []

for img_path, mask_paths in tqdm(jpeg_files_with_masks, desc="Processing images with masks"):
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tens_ = transform_(img_rgb)

    # The deletion metric needs only the image and the attribution map, so
    # the SAM mask (mask_paths) is not loaded here.
    results = visualize_all_combinations_3way(tens_, combine_methods=['sqrt', 'multiply'], use_thresholding=False)
    for name, result, mask in results:
        auccc, _ = deletion_metric(model_A, tens_, mask)
        all_expl_results_three_way.append({
            "Image Index": img_path,
            "Method": name,
            "Deletion Accuracy": auccc
        })

Processing images with masks: 100%|██████████| 176/176 [41:59<00:00, 14.31s/it]


In [None]:
# Save the per-image 3-way deletion scores and print the per-combination mean
# (uses save_and_display_results1 defined in an earlier cell).
save_and_display_results1("NETmetrics_expl_results_3WAY_NOThresholding.csv", all_expl_results_three_way)

Results saved to NETmetrics_expl_results_3WAY_NOThresholding.csv
Statistics by Method and Combine Method:
                                     Deletion Accuracy
Method                                                
LRP + rollout + CAM (multiply)                0.275882
LRP + rollout + CAM (sqrt)                    0.263707
LRP + saliency + CAM (multiply)               0.305666
LRP + saliency + CAM (sqrt)                   0.291612
LRP + saliency + rollout (multiply)           0.247940
LRP + saliency + rollout (sqrt)               0.236018
saliency + rollout + CAM (multiply)           0.337381
saliency + rollout + CAM (sqrt)               0.320341
