# Imports

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2
import os
import random
import gc
import copy
import pickle
import sys

from PIL import Image
from time import time
from dataclasses import dataclass, asdict
from google.colab import drive
from IPython.display import clear_output
from collections import Counter
from itertools import combinations

import matplotlib.pyplot as plt
from matplotlib import rcParams

from torchvision import models
import torchvision.transforms as transforms
from torchvision.datasets import VOCSegmentation
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import torch

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, jaccard_score, f1_score, accuracy_score

from skimage.io import imread
from skimage.transform import resize

In [None]:
# Fetch the Transformer-Explainability repository; the `baselines.ViT`
# modules imported later in this notebook are expected to come from it.
!git clone https://github.com/hila-chefer/Transformer-Explainability.git

# Work from inside the clone so that `baselines...` imports resolve.
# NOTE: this changes the process-wide working directory, so every relative
# path used afterwards (dataset root, CSV outputs) is relative to the clone.
os.chdir(f'./Transformer-Explainability')

# Presumably a dependency of the cloned ViT code — TODO confirm.
!pip install einops

# VOC Dataset

In [None]:
# PASCAL VOC 2012 segmentation, "val" split, downloaded under ./data.
# `transform` is applied to the image only, so images arrive as tensors while
# the segmentation target is left as the dataset returns it
# (presumably a PIL image — TODO confirm).
dataset = VOCSegmentation(root='data', year='2012', image_set='val', download=True, transform=transforms.ToTensor())

# Utilities

In [None]:
def show_cam_on_image(img, mask):
    """Overlay a JET-coloured rendering of ``mask`` on ``img``.

    ``mask`` is multiplied by 255 and cast to uint8 before colour-mapping.
    The coloured map is rescaled to [0, 1], added to ``img`` and the sum is
    divided by its own maximum, so the returned float array peaks at 1.
    """
    mask_u8 = np.uint8(255 * mask)
    colour_map = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    overlay = np.float32(colour_map) / 255 + np.float32(img)
    return overlay / np.max(overlay)

def print_top_classes(predictions):
    """Return ``(probability, class_index)`` of the top-1 prediction.

    ``predictions`` holds raw scores with classes along dim 1; softmax is
    applied over that dim.  Both results are extracted with ``.item()``, so
    the batch must contain exactly one row.  Despite the name, nothing is
    printed.
    """
    probabilities = torch.softmax(predictions, dim=1)
    best = torch.topk(probabilities, k=1, dim=1)
    return best.values.item(), best.indices.item()

def manipulate_img_and_mask(image, true_mask, ignore_label=255):
    """Prepare an image / ground-truth pair for the ViT pipeline.

    The image is quantised to 8 bits, turned back into a float tensor and
    resized to 224x224.  The mask is binarised into foreground (1) and
    background (0) at the image's original height and width.

    Args:
        image: Image tensor laid out as (C, H, W).  Assumed to hold values in
            [0, 1] as produced by ``transforms.ToTensor`` — TODO confirm for
            other callers.
        true_mask: Segmentation mask convertible with ``np.array`` whose
            values are label indices (0 = background).
        ignore_label: Label value that must NOT count as foreground.  PASCAL
            VOC marks void / object-boundary pixels with 255; with the plain
            ``> 0`` test those borders were counted as object pixels and
            inflated the ground truth.  Pass ``None`` to treat every
            non-zero label as foreground (the previous behaviour).

    Returns:
        ``(image_resized, mask)``: the 224x224 image tensor and a uint8
        array of 0/1 values with the image's original height and width.
    """
    # Convert tensor image to an 8-bit HxWxC numpy array
    image_np = image.permute(1, 2, 0).numpy()
    image_np = (image_np * 255).astype(np.uint8)

    # Convert mask to numpy and make sure it matches the image dimensions;
    # nearest-neighbour keeps the values valid label indices.
    true_mask_np = np.array(true_mask)
    true_mask_np_resized = cv2.resize(
        true_mask_np,
        (image_np.shape[1], image_np.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

    # Binarize the true mask (foreground is 1, background is 0), excluding
    # the void label so that it falls on the background side.
    foreground = true_mask_np_resized > 0
    if ignore_label is not None:
        foreground &= true_mask_np_resized != ignore_label
    true_mask_np_resized = foreground.astype(np.uint8)

    # Convert the 8-bit image back to a float tensor
    image_tensor = transforms.ToTensor()(image_np)

    # Resize the image to the 224x224 input size
    transform_resize = transforms.Compose([
        transforms.Resize((224, 224)),
    ])
    image_resized = transform_resize(image_tensor)

    return image_resized, true_mask_np_resized

# Model

In [None]:
from baselines.ViT.ViT_explanation_generator import LRP
from baselines.ViT.ViT_explanation_generator import Baselines
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit_LRP_new
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from torchvision import transforms

# Two pretrained ViT-B/16 (224x224) instances, one per explanation backend:
# model_A feeds the `Baselines` generator (rollout / attention-CAM) and is
# also used for class prediction, model_B feeds the LRP generator and the
# input-gradient saliency.
# Both are used for inference only, so switch them to eval mode once here
# instead of relying on a later call to do it.
model_A = vit_LRP_new(pretrained=True).cuda().eval()
model_B = vit_LRP(pretrained=True).cuda().eval()

b = Baselines(model_A)
attribution_generator = LRP(model_B)

# 1way

In [None]:
def generate_LRP(original_image, class_index=None):
    """Return the detached "transformer_attribution" relevance for one image.

    The image is given a batch dimension, moved to CUDA and passed to the
    module-level ``attribution_generator``; ``class_index`` is forwarded as
    its ``index`` argument.
    """
    batch = original_image.unsqueeze(0).cuda()
    relevance = attribution_generator.generate_LRP(
        batch,
        method="transformer_attribution",
        index=class_index,
    )
    return relevance.detach()

def generate_saliency(original_image, class_index=None):
    """Input-gradient saliency of ``model_B`` reduced to a 14x14 map.

    Args:
        original_image: Image tensor laid out as (C, H, W).  It is neither
            modified nor marked as requiring gradients.
        class_index: Logit to differentiate; when ``None`` the largest logit
            is used.

    Returns:
        Tensor of shape (1, 1, 14, 14) on ``original_image``'s device: the
        per-pixel maximum over channels of the absolute input gradient,
        bilinearly resized to 14x14.
    """
    # Differentiate with respect to a private copy.  Calling requires_grad_()
    # on the caller's tensor changed it permanently, and because .grad is
    # never cleared the gradients of repeated calls on the same image were
    # summed together.
    input_tensor = original_image.detach().clone().unsqueeze(0).cuda().requires_grad_(True)

    output = model_B(input_tensor)
    loss = output[0, class_index] if class_index is not None else output.max()
    model_B.zero_grad()
    loss.backward()

    # (1, C, H, W) -> (1, 1, H, W): strongest absolute gradient per pixel.
    saliency = input_tensor.grad.detach().abs().max(dim=1, keepdim=True)[0]
    saliency = torch.nn.functional.interpolate(saliency, size=(14, 14), mode='bilinear')
    return saliency.to(original_image.device)

def generate_rollout(input_image, class_index=None, start_layer=3):
    """Return the rollout attribution computed by the module-level ``b``.

    ``class_index`` is accepted so the signature matches the other
    generators, but it is not used here.
    """
    batch = input_image.unsqueeze(0).cuda()
    return b.generate_rollout(batch, start_layer=start_layer)

def generate_CAM(input_image, class_index=None):
    """Return the attention-CAM attribution computed by the module-level ``b``.

    The image is batched and moved to CUDA; ``class_index`` is forwarded as
    the ``index`` argument of ``generate_cam_attn``.
    """
    batch = input_image.unsqueeze(0).cuda()
    return b.generate_cam_attn(batch, index=class_index)

# Utility function to combine attributions and visualize
def combine_and_visualize_attributions_1way(input_image, method, use_thresholding=True):
    """Turn one attribution method's output into a 224x224 map and an overlay.

    Args:
        input_image: Image tensor laid out as (C, H, W) with H = W = 224.
        method: Callable taking the image and returning an attribution with
            14 * 14 elements.
        use_thresholding: When true, binarise the map with Otsu's threshold.

    Returns:
        ``(vis, combined_attr)``: ``vis`` is the uint8 heatmap overlay;
        ``combined_attr`` is the 224x224 map — floats in [0, 1], or uint8
        0/1 values when ``use_thresholding`` is set.
    """
    device = input_image.device
    attr = method(input_image).reshape(1, 1, 14, 14).to(device)

    # Upsample the 14x14 patch grid to pixel resolution (16 px per patch).
    combined_attr = torch.nn.functional.interpolate(attr, scale_factor=16, mode='bilinear')
    combined_attr = combined_attr.reshape(224, 224).cpu().detach().numpy()

    # Min-max normalise.  A constant (or NaN) map has no usable range and
    # would otherwise divide by zero, so it becomes an all-zero map.
    attr_min = combined_attr.min()
    attr_range = combined_attr.max() - attr_min
    if attr_range > 0:
        combined_attr = (combined_attr - attr_min) / attr_range
    else:
        combined_attr = np.zeros_like(combined_attr)

    if use_thresholding:
        combined_attr = combined_attr * 255
        combined_attr = combined_attr.astype(np.uint8)
        # With THRESH_OTSU the threshold argument (0) is ignored.
        _, combined_attr = cv2.threshold(combined_attr, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        combined_attr[combined_attr == 255] = 1

    image_transformer_attribution = input_image.permute(1, 2, 0).cpu().detach().numpy()
    image_min = image_transformer_attribution.min()
    image_range = image_transformer_attribution.max() - image_min
    if image_range > 0:
        image_transformer_attribution = (image_transformer_attribution - image_min) / image_range
    else:
        image_transformer_attribution = np.zeros_like(image_transformer_attribution)

    vis = show_cam_on_image(image_transformer_attribution, combined_attr)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis, combined_attr

# Function to visualize each method with different combine methods
def visualize_methods_1way(input_image, use_thresholding=True):
    """Run every single attribution method on one image.

    The class predicted by ``model_A`` is passed to each generator so that
    all methods explain the same class.  Previously the index was computed
    but never used, leaving each generator to choose its own target.

    Args:
        input_image: Image tensor laid out as (C, H, W).
        use_thresholding: Forwarded to
            ``combine_and_visualize_attributions_1way``.

    Returns:
        List of ``(method_name, vis, mask)`` tuples, one per method.
    """
    methods = {
        'LRP': generate_LRP,
        'saliency': generate_saliency,
        'rollout': generate_rollout,
        'CAM': generate_CAM,
    }

    # Prediction only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        output = model_A(input_image.unsqueeze(0).cuda())
    class_index = output.argmax().item()

    results = []
    for method_name, method_func in methods.items():
        vis, mask = combine_and_visualize_attributions_1way(
            input_image,
            # Bind method_func now; a plain closure would see the loop variable.
            lambda img, func=method_func: func(img, class_index),
            use_thresholding,
        )
        results.append((f"{method_name}", vis, mask))

    return results

In [None]:
# Sanity-check visualisation of the single-method attributions on one
# example (dataset index 20).
image, true_mask = dataset[20]
image_resized, true_mask_np_resized = manipulate_img_and_mask(image, true_mask)

# Convert tensor image to numpy array
# NOTE(review): image_np_resized is not used again in this cell.
image_np_resized = image_resized.permute(1, 2, 0).numpy()
image_np_resized = (image_np_resized * 255).astype(np.uint8)
results = visualize_methods_1way(image_resized, use_thresholding=True)

# One row per attribution method: overlay | predicted mask | ground truth.
fig, axes = plt.subplots(len(results), 3, figsize=(15, len(results)*3))
for ax_row, (name, result, mask) in zip(axes, results):
    ax_row[0].imshow(result)
    ax_row[0].set_title(name)
    ax_row[0].axis('off')

    ax_row[1].imshow(mask, cmap='gray')
    ax_row[1].set_title(f"{name} - Mask")
    ax_row[1].axis('off')

    # Ground truth resized to 224x224 and binarised (any non-zero label).
    # NOTE(review): the title says "Lesion" although the data is VOC —
    # possibly left over from another dataset; confirm.
    ax_row[2].imshow((np.array(transforms.Resize((224, 224))(true_mask)) > 0).astype(np.uint8), cmap='gray')
    ax_row[2].set_title("GT Lesion Mask")
    ax_row[2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
import os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms

# NOTE(review): this folder is created but nothing in this cell writes to it.
output_folder = "1way_folder"
os.makedirs(output_folder, exist_ok=True)

# One dict of segmentation metrics per (image, attribution method).
all_results_one_way = []

for idx, (image_data, lesion_data) in enumerate(dataset):
    image_resized, true_mask_np_resized = manipulate_img_and_mask(image_data, lesion_data)
    output = model_A(image_resized.unsqueeze(0).cuda())
    top_prob, top_class = print_top_classes(output)

    # Only score images whose top-1 softmax probability exceeds 0.85.
    if top_prob > 0.85:
        results = visualize_methods_1way(image_resized, use_thresholding=True)
        if results != []:
            for name, result, mask in results:
                # With thresholding on, mask holds 0/1 values; > 0.5 keeps the 1s.
                predicted_mask_np = (mask > 0.5).astype(np.uint8)
                # Bring the ground truth to 224x224 to match the predicted mask.
                true_mask_resized = cv2.resize(true_mask_np_resized, (224, 224), interpolation=cv2.INTER_NEAREST)

                # Flatten the masks for metric calculation
                true_mask_flat = true_mask_resized.flatten()
                predicted_mask_flat = predicted_mask_np.flatten()

                jaccard = jaccard_score(true_mask_flat, predicted_mask_flat)
                f1 = f1_score(true_mask_flat, predicted_mask_flat)
                pixel_accuracy = accuracy_score(true_mask_flat, predicted_mask_flat)

                # Store metrics
                all_results_one_way.append({
                    "Image Index": idx,
                    "Method": name,
                    "Jaccard Index (IoU)": jaccard,
                    "F1 Score": f1,
                    "Pixel Accuracy": pixel_accuracy
                })

In [None]:
def save_and_display_results(file_name_for_saving, results):
    """Write segmentation metrics to a CSV file and print per-method means.

    Args:
        file_name_for_saving: Destination path of the CSV file.
        results: List of dicts, each with at least the keys "Method",
            "Jaccard Index (IoU)", "F1 Score" and "Pixel Accuracy".  May be
            empty.

    Returns:
        None.  The CSV is always written; the statistics are printed only
        when there is at least one row.
    """
    results_df = pd.DataFrame(results)
    csv_path = file_name_for_saving
    results_df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # With no rows there is no "Method" column, so groupby would raise
    # KeyError; report it instead of crashing after the long evaluation run.
    if results_df.empty:
        print("No results to summarise.")
        return

    print("Statistics by Method and Combine Method:")
    stats = results_df.groupby("Method")[["Jaccard Index (IoU)", "F1 Score", "Pixel Accuracy"]].mean()
    print(stats)

In [None]:
# Persist the single-method segmentation metrics and print per-method means.
save_and_display_results("VOCmetrics_results_1WAY.csv", all_results_one_way)

In [None]:
import numpy as np
import torch
import cv2
from sklearn.metrics import auc 

def deletion_metric(model, image, attribution_map, class_index=None, steps=100):
    """
    Computes the Deletion Metric for a given attribution map.

    Pixels are blacked out in order of decreasing attribution, in ``steps``
    equal-sized batches, and the model's confidence for the tracked class is
    recorded after every batch.

    Parameters:
    - model: Trained classification model (a ``torch.nn.Module``).  It is
      switched to eval mode; inputs are moved to the device of its first
      parameter.
    - image: Input image tensor (C, H, W).  It is not modified.
    - attribution_map: Array-like with H * W importance scores, any numeric
      dtype (higher = more important).
    - class_index: Class index to track model confidence for (optional).
      Defaults to the class predicted for the unmodified image.
    - steps: Number of deletion steps (positive integer).

    Returns:
    - auc_score: Area under the confidence curve (lower = better attribution),
      with the x axis running from 0 (nothing deleted) to 1 (all deleted).
    - confidence_drop: List of ``steps + 1`` confidences: the initial one
      followed by the confidence after each deletion step.

    Raises:
    - ValueError: If ``steps`` is not positive or the attribution map does
      not have one score per image pixel.
    """
    if steps < 1:
        raise ValueError("steps must be a positive integer")

    model.eval()

    # Follow the model's device rather than assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else image.device

    # Private, contiguous copy so pixels can be blanked through a flat view
    # without touching the caller's tensor.
    working = image.detach().to(device=device, dtype=torch.float32).clone().contiguous()
    num_channels, height, width = working.shape
    total_pixels = height * width
    flat_pixels = working.view(num_channels, total_pixels)

    # Rank pixels by importance (descending).  Cast to float first: negating
    # an unsigned integer map (e.g. a uint8 mask) wraps around and would
    # reverse the ranking.
    scores = np.asarray(attribution_map, dtype=np.float64).reshape(-1)
    if scores.size != total_pixels:
        raise ValueError(
            f"attribution_map has {scores.size} values but the image has {total_pixels} pixels"
        )
    importance_order = torch.from_numpy(np.argsort(-scores, kind="stable")).to(device)

    # Initial model confidence before deletion
    with torch.no_grad():
        output = model(working.unsqueeze(0))
        if class_index is None:
            class_index = output.argmax().item()
        initial_confidence = torch.softmax(output, dim=1)[0, class_index].item()

    confidence_drop = [initial_confidence]

    for step in range(1, steps + 1):
        # Integer boundaries spread any remainder over the steps, so the
        # last step always reaches the final pixel (x = 1 on the curve).
        start = (step - 1) * total_pixels // steps
        stop = step * total_pixels // steps

        # Black out this batch of pixels across all channels at once.
        flat_pixels[:, importance_order[start:stop]] = 0

        # Recalculate model confidence
        with torch.no_grad():
            output = model(working.unsqueeze(0))
            confidence = torch.softmax(output, dim=1)[0, class_index].item()

        confidence_drop.append(confidence)

    # Area under the confidence curve by the trapezoidal rule, computed with
    # NumPy so the function does not depend on an import from another cell.
    x_axis = np.linspace(0, 1, len(confidence_drop))
    curve = np.asarray(confidence_drop, dtype=np.float64)
    auc_score = float(np.sum((curve[1:] + curve[:-1]) * np.diff(x_axis)) / 2.0)

    return auc_score, confidence_drop

In [None]:
# One dict per (image, attribution method) holding the deletion-curve AUC.
all_expl_results_one_way = []
for idx in tqdm(range(len(dataset)), desc="Processing", unit="image"):
    image, true_mask = dataset[idx]

    image_resized, true_mask_np_resized = manipulate_img_and_mask(image, true_mask)

    # Convert tensor image to numpy array
    # NOTE(review): image_np_resized is not used later in this loop.
    image_np_resized = image_resized.permute(1, 2, 0).numpy()
    image_np_resized = (image_np_resized * 255).astype(np.uint8)

    # Get model prediction and probability
    output = model_A(image_resized.unsqueeze(0).cuda())
    top_prob, top_class = print_top_classes(output)
    # Only score images whose top-1 softmax probability exceeds 0.85.
    if top_prob > 0.85:
        # Thresholding is off so the deletion order comes from the
        # continuous attribution values rather than a binary mask.
        results = visualize_methods_1way(image_resized, use_thresholding=False)
        if results != []:
            for name, result, mask in results:

                # Only the AUC is kept; the confidence curve is discarded.
                auccc, _ = deletion_metric(model_A, image_resized, mask)
                all_expl_results_one_way.append({
                "Image Index": idx,
                "Method": name,
                "Deletion Accuracy": auccc
                })

In [None]:
def save_and_display_results1(file_name_for_saving, results):
    """Write deletion-metric results to a CSV file and print per-method means.

    Args:
        file_name_for_saving: Destination path of the CSV file.
        results: List of dicts, each with at least the keys "Method" and
            "Deletion Accuracy".  May be empty.

    Returns:
        None.  The CSV is always written; the statistics are printed only
        when there is at least one row.
    """
    results_df = pd.DataFrame(results)
    csv_path = file_name_for_saving
    results_df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # With no rows there is no "Method" column, so groupby would raise
    # KeyError; report it instead of crashing after the long evaluation run.
    if results_df.empty:
        print("No results to summarise.")
        return

    print("Statistics by Method and Combine Method:")
    stats = results_df.groupby("Method")[["Deletion Accuracy"]].mean()
    print(stats)

# Persist the single-method deletion scores and print per-method means.
save_and_display_results1("VOCmetrics_expl_results_1WAY_NOThresholding.csv", all_expl_results_one_way)

# 2way

In [None]:
def combine_and_visualize_attributions_2way(input_image, method1, method2, combine_method='sqrt', use_thresholding=True):
    """Combine two attribution maps element-wise and build an overlay.

    Args:
        input_image: Image tensor laid out as (C, H, W) with H = W = 224.
        method1, method2: Callables taking the image and returning an
            attribution with 14 * 14 elements.
        combine_method: ``'multiply'`` for the element-wise product or
            ``'sqrt'`` for the square root of that product.
        use_thresholding: When true, binarise the map with Otsu's threshold.

    Returns:
        ``(vis, combined_attr)``: ``vis`` is the uint8 heatmap overlay;
        ``combined_attr`` is the 224x224 map — floats in [0, 1], or uint8
        0/1 values when ``use_thresholding`` is set.

    Raises:
        ValueError: If ``combine_method`` is not one of the two supported
            names.
    """
    # Reject unknown modes before the expensive attribution passes; they
    # used to surface later as an unbound-variable error.
    if combine_method not in ('sqrt', 'multiply'):
        raise ValueError(
            f"Unknown combine_method {combine_method!r}; expected 'sqrt' or 'multiply'"
        )

    device = input_image.device
    attr1 = method1(input_image).reshape(1, 1, 14, 14).to(device)
    attr2 = method2(input_image).reshape(1, 1, 14, 14).to(device)

    combined_attr = attr1 * attr2
    if combine_method == 'sqrt':
        combined_attr = torch.sqrt(combined_attr)

    # Upsample the 14x14 patch grid to pixel resolution (16 px per patch).
    combined_attr = torch.nn.functional.interpolate(combined_attr, scale_factor=16, mode='bilinear')
    combined_attr = combined_attr.reshape(224, 224).cpu().detach().numpy()

    # Min-max normalise.  A constant (or NaN) map has no usable range and
    # would otherwise divide by zero, so it becomes an all-zero map.
    attr_min = combined_attr.min()
    attr_range = combined_attr.max() - attr_min
    if attr_range > 0:
        combined_attr = (combined_attr - attr_min) / attr_range
    else:
        combined_attr = np.zeros_like(combined_attr)

    if use_thresholding:
        combined_attr = combined_attr * 255
        combined_attr = combined_attr.astype(np.uint8)
        # With THRESH_OTSU the threshold argument (0) is ignored.
        _, combined_attr = cv2.threshold(combined_attr, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        combined_attr[combined_attr == 255] = 1

    image_transformer_attribution = input_image.permute(1, 2, 0).cpu().detach().numpy()
    image_min = image_transformer_attribution.min()
    image_range = image_transformer_attribution.max() - image_min
    if image_range > 0:
        image_transformer_attribution = (image_transformer_attribution - image_min) / image_range
    else:
        image_transformer_attribution = np.zeros_like(image_transformer_attribution)

    vis = show_cam_on_image(image_transformer_attribution, combined_attr)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis, combined_attr

# Example usage
def visualize_combined_methods_2way(input_image, method1_name, method2_name, class_index, combine_method='sqrt', use_thresholding=True):
    """Look up two generators by name and combine their attributions.

    ``method1_name`` and ``method2_name`` must each be one of 'LRP',
    'saliency', 'rollout' or 'CAM'.  Both generators are called with
    ``class_index``; the result is whatever
    ``combine_and_visualize_attributions_2way`` returns.
    """
    generators = {
        'LRP': generate_LRP,
        'saliency': generate_saliency,
        'rollout': generate_rollout,
        'CAM': generate_CAM,
    }
    first_generator = generators[method1_name]
    second_generator = generators[method2_name]

    def first_attribution(img):
        return first_generator(img, class_index)

    def second_attribution(img):
        return second_generator(img, class_index)

    return combine_and_visualize_attributions_2way(
        input_image,
        first_attribution,
        second_attribution,
        combine_method,
        use_thresholding,
    )

# Function to visualize all 2-way combinations
def visualize_all_combinations_2way(input_image, combine_methods=('sqrt', 'multiply'), use_thresholding=True):
    """Evaluate every pair of attribution methods under each combine mode.

    Args:
        input_image: Image tensor laid out as (C, H, W).
        combine_methods: Iterable of combine-mode names.  The default is an
            immutable tuple rather than a shared mutable list.
        use_thresholding: Forwarded to the combination helper.

    Returns:
        List of ``(label, vis, mask)`` tuples, where ``label`` reads
        ``"<method A> + <method B> (<combine mode>)"``.
    """
    methods = ['LRP', 'saliency', 'rollout', 'CAM']

    # Determine the predicted class index.  Prediction only, so no autograd
    # graph is needed for this forward pass.
    with torch.no_grad():
        output = model_A(input_image.unsqueeze(0).cuda())
    class_index = output.argmax().item()

    results = []
    for combo in combinations(methods, 2):
        for combine_method in combine_methods:
            vis, mask = visualize_combined_methods_2way(input_image, combo[0], combo[1], class_index, combine_method, use_thresholding)
            results.append((f"{' + '.join(combo)} ({combine_method})", vis, mask))

    return results

# Sanity-check visualisation of the pairwise combinations on one example
# (dataset index 20).
image, true_mask = dataset[20]
image_resized, true_mask_np_resized = manipulate_img_and_mask(image, true_mask)

# Convert tensor image to numpy array
# NOTE(review): image_np_resized is not used again in this cell.
image_np_resized = image_resized.permute(1, 2, 0).numpy()
image_np_resized = (image_np_resized * 255).astype(np.uint8)

# Visualize all combinations
results = visualize_all_combinations_2way(image_resized, combine_methods=['sqrt', 'multiply'], use_thresholding=True)

# Display the results
# One row per combination: overlay | predicted mask | ground truth.
fig, axes = plt.subplots(len(results), 3, figsize=(15, len(results)*3))
for ax_row, (name, result, mask) in zip(axes, results):
    ax_row[0].imshow(result)
    ax_row[0].set_title(name)
    ax_row[0].axis('off')

    ax_row[1].imshow(mask, cmap='gray')
    ax_row[1].set_title(f"{name} - Mask")
    ax_row[1].axis('off')

    # Ground truth resized to 224x224 and binarised (any non-zero label).
    # NOTE(review): the title says "Lesion" although the data is VOC —
    # possibly left over from another dataset; confirm.
    ax_row[2].imshow((np.array(transforms.Resize((224, 224))(true_mask)) > 0).astype(np.uint8), cmap='gray')
    ax_row[2].set_title("GT Lesion Mask")
    ax_row[2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
import os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms

# One dict of segmentation metrics per (image, method pair, combine mode).
all_results_two_way = []

for idx, (image_data, lesion_data) in enumerate(dataset):
    image_resized, true_mask_np_resized = manipulate_img_and_mask(image_data, lesion_data)
    output = model_A(image_resized.unsqueeze(0).cuda())
    top_prob, top_class = print_top_classes(output)

    # Only score images whose top-1 softmax probability exceeds 0.85.
    if top_prob > 0.85:
        results = visualize_all_combinations_2way(image_resized, combine_methods=['sqrt', 'multiply'], use_thresholding=True)
        if results != []:
            for name, result, mask in results:
                # With thresholding on, mask holds 0/1 values; > 0.5 keeps the 1s.
                predicted_mask_np = (mask > 0.5).astype(np.uint8)
                # Bring the ground truth to 224x224 to match the predicted mask.
                true_mask_resized = cv2.resize(true_mask_np_resized, (224, 224), interpolation=cv2.INTER_NEAREST)

                # Flatten the masks for metric calculation
                true_mask_flat = true_mask_resized.flatten()
                predicted_mask_flat = predicted_mask_np.flatten()

                jaccard = jaccard_score(true_mask_flat, predicted_mask_flat)
                f1 = f1_score(true_mask_flat, predicted_mask_flat)
                pixel_accuracy = accuracy_score(true_mask_flat, predicted_mask_flat)

                # Store metrics
                all_results_two_way.append({
                    "Image Index": idx,
                    "Method": name,
                    "Jaccard Index (IoU)": jaccard,
                    "F1 Score": f1,
                    "Pixel Accuracy": pixel_accuracy
                })

In [None]:
# Persist the pairwise segmentation metrics and print per-combination means.
save_and_display_results("VOCmetrics_results_2WAY.csv", all_results_two_way)

In [None]:
# One dict per (image, method pair, combine mode) holding the deletion AUC.
all_expl_results_two_way = []
for idx in tqdm(range(len(dataset)), desc="Processing", unit="image"):
    image, true_mask = dataset[idx]
    image_resized, true_mask_np_resized = manipulate_img_and_mask(image, true_mask)

    # Convert tensor image to numpy array
    # NOTE(review): image_np_resized is not used later in this loop.
    image_np_resized = image_resized.permute(1, 2, 0).numpy()
    image_np_resized = (image_np_resized * 255).astype(np.uint8)

    # Get model prediction and probability
    output = model_A(image_resized.unsqueeze(0).cuda())
    top_prob, top_class = print_top_classes(output)
    # Only score images whose top-1 softmax probability exceeds 0.85.
    if top_prob > 0.85:
        # Thresholding is off so the deletion order comes from the
        # continuous combined attribution rather than a binary mask.
        results = visualize_all_combinations_2way(image_resized, combine_methods=['sqrt', 'multiply'], use_thresholding=False)
        if results != []:
            for name, result, mask in results:
                # Only the AUC is kept; the confidence curve is discarded.
                auccc, _ = deletion_metric(model_A, image_resized, mask)
                all_expl_results_two_way.append({
                "Image Index": idx,
                "Method": name,
                "Deletion Accuracy": auccc
                })

In [None]:
# Persist the pairwise deletion scores and print per-combination means.
save_and_display_results1("VOCmetrics_expl_results_2WAY_NOThresholding.csv", all_expl_results_two_way)

# 3way

In [None]:
def combine_and_visualize_attributions_3way(input_image, methods, combine_method='sqrt', use_thresholding=True, class_index=None):
    """Combine three attribution maps element-wise and build an overlay.

    Args:
        input_image: Image tensor laid out as (C, H, W) with H = W = 224.
        methods: Sequence of exactly three attribution callables.  Those
            named ``generate_saliency``, ``generate_CAM`` or ``generate_LRP``
            are called with ``class_index``; any other callable receives
            only the image.
        combine_method: ``'multiply'`` for the element-wise product of the
            three maps or ``'sqrt'`` for the square root of that product.
        use_thresholding: When true, binarise the map with Otsu's threshold.
        class_index: Class to explain.  This used to be hard-coded to 1
            ("for demonstration"), so every 3-way result explained class 1
            regardless of the image.  ``None`` lets each generator fall back
            to its own default target.

    Returns:
        ``(vis, combined_attr)``: ``vis`` is the uint8 heatmap overlay;
        ``combined_attr`` is the 224x224 map — floats in [0, 1], or uint8
        0/1 values when ``use_thresholding`` is set.

    Raises:
        ValueError: If ``combine_method`` is unknown or ``methods`` does not
            contain exactly three callables.
    """
    # Validate before the expensive attribution passes.
    if combine_method not in ('sqrt', 'multiply'):
        raise ValueError(
            f"Unknown combine_method {combine_method!r}; expected 'sqrt' or 'multiply'"
        )
    if len(methods) != 3:
        raise ValueError(f"Expected exactly 3 methods, got {len(methods)}")

    class_aware = ('generate_saliency', 'generate_CAM', 'generate_LRP')
    device = input_image.device
    attributions = []
    for method in methods:
        if getattr(method, '__name__', None) in class_aware:
            attr = method(input_image, class_index=class_index)
        else:
            attr = method(input_image)
        attributions.append(attr.reshape(1, 1, 14, 14).to(device))

    combined_attr = attributions[0] * attributions[1] * attributions[2]
    if combine_method == 'sqrt':
        combined_attr = torch.sqrt(combined_attr)

    # Upsample the 14x14 patch grid to pixel resolution (16 px per patch).
    combined_attr = torch.nn.functional.interpolate(combined_attr, scale_factor=16, mode='bilinear')
    combined_attr = combined_attr.reshape(224, 224).cpu().detach().numpy()

    # Min-max normalise.  A constant (or NaN) map has no usable range and
    # would otherwise divide by zero, so it becomes an all-zero map.
    attr_min = combined_attr.min()
    attr_range = combined_attr.max() - attr_min
    if attr_range > 0:
        combined_attr = (combined_attr - attr_min) / attr_range
    else:
        combined_attr = np.zeros_like(combined_attr)

    if use_thresholding:
        combined_attr = combined_attr * 255
        combined_attr = combined_attr.astype(np.uint8)
        # With THRESH_OTSU the threshold argument (0) is ignored.
        _, combined_attr = cv2.threshold(combined_attr, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        combined_attr[combined_attr == 255] = 1

    image_transformer_attribution = input_image.permute(1, 2, 0).cpu().detach().numpy()
    image_min = image_transformer_attribution.min()
    image_range = image_transformer_attribution.max() - image_min
    if image_range > 0:
        image_transformer_attribution = (image_transformer_attribution - image_min) / image_range
    else:
        image_transformer_attribution = np.zeros_like(image_transformer_attribution)

    vis = show_cam_on_image(image_transformer_attribution, combined_attr)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis, combined_attr

# Name-based front end for the 3-way combination
def visualize_combined_methods_3way(input_image, method_names, class_index, combine_method='sqrt', use_thresholding=True):
    """Look up three generators by name and combine their attributions.

    Args:
        input_image: Image tensor laid out as (C, H, W).
        method_names: Three names out of 'LRP', 'saliency', 'rollout', 'CAM'.
        class_index: Class to explain; it is now actually forwarded to the
            combination helper instead of being dropped.
        combine_method: Forwarded to the combination helper.
        use_thresholding: Forwarded to the combination helper.

    Returns:
        The ``(vis, combined_attr)`` pair of
        ``combine_and_visualize_attributions_3way``.
    """
    methods = {
        'LRP': generate_LRP,
        'saliency': generate_saliency,
        'rollout': generate_rollout,
        'CAM': generate_CAM,
    }
    selected_methods = [methods[name] for name in method_names]

    return combine_and_visualize_attributions_3way(
        input_image,
        selected_methods,
        combine_method,
        use_thresholding,
        class_index=class_index,
    )

# Function to visualize all 3-way combinations
def visualize_all_combinations_3way(input_image, combine_methods=('sqrt', 'multiply'), use_thresholding=True):
    """Evaluate every triple of attribution methods under each combine mode.

    Args:
        input_image: Image tensor laid out as (C, H, W).
        combine_methods: Iterable of combine-mode names.  The default is an
            immutable tuple rather than a shared mutable list.
        use_thresholding: Forwarded to the combination helper.

    Returns:
        List of ``(label, vis, mask)`` tuples, where ``label`` reads
        ``"<A> + <B> + <C> (<combine mode>)"``.
    """
    methods = ['LRP', 'saliency', 'rollout', 'CAM']

    # Determine the predicted class index.  Prediction only, so no autograd
    # graph is needed for this forward pass.
    with torch.no_grad():
        output = model_A(input_image.unsqueeze(0).cuda())
    class_index = output.argmax().item()

    results = []
    for combo in combinations(methods, 3):
        for combine_method in combine_methods:
            vis, mask = visualize_combined_methods_3way(input_image, combo, class_index, combine_method, use_thresholding)
            results.append((f"{' + '.join(combo)} ({combine_method})", vis, mask))

    return results

# Sanity-check visualisation of the three-method combinations on one
# example (dataset index 20).
image, true_mask = dataset[20]
image_resized, true_mask_np_resized = manipulate_img_and_mask(image, true_mask)
# NOTE(review): image_np_resized is not used again in this cell.
image_np_resized = image_resized.permute(1, 2, 0).numpy()
image_np_resized = (image_np_resized * 255).astype(np.uint8)

results = visualize_all_combinations_3way(image_resized, combine_methods=['sqrt', 'multiply'], use_thresholding=True)

# Display the results
# One row per combination: overlay | predicted mask | ground truth.
fig, axes = plt.subplots(len(results), 3, figsize=(15, len(results)*3))
for ax_row, (name, result, mask) in zip(axes, results):
    ax_row[0].imshow(result)
    ax_row[0].set_title(name)
    ax_row[0].axis('off')

    ax_row[1].imshow(mask, cmap='gray')
    ax_row[1].set_title(f"{name} - Mask")
    ax_row[1].axis('off')

    # Ground truth resized to 224x224 and binarised (any non-zero label).
    # NOTE(review): the title says "Lesion" although the data is VOC —
    # possibly left over from another dataset; confirm.
    ax_row[2].imshow((np.array(transforms.Resize((224, 224))(true_mask)) > 0).astype(np.uint8), cmap='gray')
    ax_row[2].set_title("GT Lesion Mask")
    ax_row[2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# One dict per (image, method triple, combine mode) holding the deletion AUC.
all_expl_results_three_way = []
for idx in tqdm(range(len(dataset)), desc="Processing", unit="image"):
    image, true_mask = dataset[idx]
    image_resized, true_mask_np_resized = manipulate_img_and_mask(image, true_mask)

    # Convert tensor image to numpy array
    # NOTE(review): image_np_resized is not used later in this loop.
    image_np_resized = image_resized.permute(1, 2, 0).numpy()
    image_np_resized = (image_np_resized * 255).astype(np.uint8)

    # Get model prediction and probability
    output = model_A(image_resized.unsqueeze(0).cuda())
    top_prob, top_class = print_top_classes(output)
    # Only score images whose top-1 softmax probability exceeds 0.85.
    if top_prob > 0.85:
        # Thresholding is off so the deletion order comes from the
        # continuous combined attribution rather than a binary mask.
        results = visualize_all_combinations_3way(image_resized, combine_methods=['sqrt', 'multiply'], use_thresholding=False)
        if results != []:
            for name, result, mask in results:
                # Only the AUC is kept; the confidence curve is discarded.
                auccc, _ = deletion_metric(model_A, image_resized, mask)
                all_expl_results_three_way.append({
                "Image Index": idx,
                "Method": name,
                "Deletion Accuracy": auccc
                })

In [None]:
# Persist the three-method deletion scores and print per-combination means.
save_and_display_results1("VOCmetrics_expl_results_3WAY_NOThresholding.csv", all_expl_results_three_way)

In [None]:
import os
from PIL import Image
import numpy as np
import torch
from torchvision import transforms

# One dict of segmentation metrics per (image, method triple, combine mode).
all_results_three_way = []

# NOTE(review): tqdm wraps enumerate(...) here, which has no length, so the
# progress bar presumably shows a running count without a total — confirm.
for idx, (image_data, lesion_data) in tqdm(enumerate(dataset)):
    image_resized, true_mask_np_resized = manipulate_img_and_mask(image_data, lesion_data)
    output = model_A(image_resized.unsqueeze(0).cuda())
    top_prob, top_class = print_top_classes(output)

    # Only score images whose top-1 softmax probability exceeds 0.85.
    if top_prob > 0.85:
        results = visualize_all_combinations_3way(image_resized, combine_methods=['sqrt', 'multiply'], use_thresholding=True)
        if results != []:
            for name, result, mask in results:
                # With thresholding on, mask holds 0/1 values; > 0.5 keeps the 1s.
                predicted_mask_np = (mask > 0.5).astype(np.uint8)
                # Bring the ground truth to 224x224 to match the predicted mask.
                true_mask_resized = cv2.resize(true_mask_np_resized, (224, 224), interpolation=cv2.INTER_NEAREST)

                # Flatten the masks for metric calculation
                true_mask_flat = true_mask_resized.flatten()
                predicted_mask_flat = predicted_mask_np.flatten()

                jaccard = jaccard_score(true_mask_flat, predicted_mask_flat)
                f1 = f1_score(true_mask_flat, predicted_mask_flat)
                pixel_accuracy = accuracy_score(true_mask_flat, predicted_mask_flat)

                # Store metrics
                all_results_three_way.append({
                    "Image Index": idx,
                    "Method": name,
                    "Jaccard Index (IoU)": jaccard,
                    "F1 Score": f1,
                    "Pixel Accuracy": pixel_accuracy
                })

In [None]:
# Persist the three-method segmentation metrics and print per-combination means.
save_and_display_results("VOCmetrics_results_3WAY.csv", all_results_three_way)