## Requirements

### Imports

In [3]:
import os
import torch
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from captum.attr import IntegratedGradients, Saliency, GradientShap, Occlusion
from captum.attr import visualization as viz

In [4]:
import training

In [None]:
# Fixed seed so torch-based randomness (e.g. GradientShap sampling) is reproducible.
SEED = 42
torch.manual_seed(SEED)

### Paths

In [6]:
# Project layout, resolved relative to the notebook's working directory.
# SRC_DIR / ROOT_DIR stay Path objects; every derived location is a plain str.
SRC_DIR = Path.cwd()
ROOT_DIR = SRC_DIR.parent

DATA_DIR = str(ROOT_DIR / 'dataset')
PREPROCESSED_DIR = str(Path(DATA_DIR) / 'preprocessed')
CSV_PATH = str(Path(DATA_DIR, 'csv_mappings', 'train.csv'))

MODEL_DIR = str(ROOT_DIR / 'models')
BASELINE_DIR = str(Path(MODEL_DIR) / 'baselines_finetuned')
RESULT_DIR = str(Path(BASELINE_DIR) / 'results')

### Load Data

In [7]:
BATCH_SIZE: int = 32  # batch size handed to training.get_data_loaders

In [8]:
# Label index -> display name; the integer labels from the loaders index this map.
CLASS_NAMES = dict(enumerate((
    "amanita",
    "boletus",
    "chantelle",
    "deterrimus",
    "rufus",
    "torminosus",
    "aurantiacum",
    "procera",
    "involutus",
    "russula",
)))

In [9]:
# Build the train / validation / test loaders from PREPROCESSED_DIR and CSV_PATH.
train_loader, val_loader, test_loader = training.get_data_loaders(
    PREPROCESSED_DIR,
    CSV_PATH,
    BATCH_SIZE,
)

### Training

##### Training config

In [10]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Derived from the label map so the class count passed to the model loader
# can never drift out of sync with CLASS_NAMES.
NUM_CLASSES = len(CLASS_NAMES)

# NOTE(review): EPOCHS / PATIENCE / LEARNING_RATE / SCHEDULER are not used by the
# explanation cells visible here; presumably kept to mirror the training config.
EPOCHS = 20
PATIENCE = 3
LEARNING_RATE = 0.0001

SCHEDULER = 'StepLR'  # one of: 'StepLR', 'OneCycleLR', None

##### Load model to explain

In [11]:
# Backbone to explain. Options:
#   alexnet, resnet, vgg16, densenet, efficientnet, custom_alexnet, custom_resnet
model_type = "alexnet"

In [None]:
# Load the `model_type` network on DEVICE; finetuned=True presumably selects the
# fine-tuned checkpoint rather than the base weights.
model = training.load_model_for_explaining(
    model_type,
    NUM_CLASSES,
    DEVICE,
    finetuned=True,
)

##### Explain model

In [15]:
def explain_with_integrated_gradients(model, input_tensor, target_class, device,
                                      n_steps=50, internal_batch_size=None):
    """Compute Integrated Gradients attributions for ``target_class``.

    Args:
        model: classifier to explain; switched to eval mode in place.
        input_tensor: batched input tensor, moved to ``device`` first.
        target_class: index of the model output to attribute.
        device: torch device the model runs on.
        n_steps: number of integration steps (50 is Captum's default).
        internal_batch_size: if given, Captum evaluates the ``n_steps`` scaled
            inputs in chunks of this size, bounding peak memory on large
            backbones. ``None`` evaluates them in a single batch.

    Returns:
        NumPy array of attributions (same shape as ``input_tensor``).
    """
    model.eval()
    input_tensor = input_tensor.to(device)

    ig = IntegratedGradients(model)
    # The convergence delta is not used by callers, so it is not requested;
    # this skips the extra evaluation Captum performs to compute it.
    attributions = ig.attribute(
        input_tensor,
        target=target_class,
        n_steps=n_steps,
        internal_batch_size=internal_batch_size,
    )
    return attributions.detach().cpu().numpy()


def explain_with_saliency(model, input_tensor, target_class, device):
    """Return vanilla-gradient (Saliency) attributions as a NumPy array.

    The model is put in eval mode and the input is moved to ``device``.
    Captum's ``Saliency`` is called with its defaults, which return the
    absolute value of the input gradients for ``target_class``.
    """
    model.eval()
    grads = Saliency(model).attribute(input_tensor.to(device), target=target_class)
    return grads.cpu().detach().numpy()


def explain_with_gradient_shap(model, input_tensor, target_class, device,
                               n_samples=5, stdevs=0.0):
    """Compute GradientShap attributions against an all-zeros baseline.

    GradientShap samples random points between the baseline and the input,
    so results depend on torch's RNG state (seed it for reproducibility).

    Args:
        model: classifier to explain; switched to eval mode in place.
        input_tensor: batched input tensor, moved to ``device`` first.
        target_class: index of the model output to attribute.
        device: torch device the model runs on.
        n_samples: random samples per input (5 is Captum's default).
        stdevs: std-dev of Gaussian noise added to each sample
            (0.0 is Captum's default, i.e. no noise).

    Returns:
        NumPy array of attributions (same shape as ``input_tensor``).
    """
    model.eval()
    input_tensor = input_tensor.to(device)

    # zeros_like allocates on input_tensor's device already; no .to() needed.
    baseline = torch.zeros_like(input_tensor)
    gradient_shap = GradientShap(model)
    attributions = gradient_shap.attribute(
        input_tensor,
        baselines=baseline,
        n_samples=n_samples,
        stdevs=stdevs,
        target=target_class,
    )
    return attributions.detach().cpu().numpy()


def explain_with_occlusion(model, input_tensor, target_class, device,
                           window_size=15, stride=1, perturbations_per_eval=1):
    """Compute Occlusion attributions with a square window spanning all channels.

    Cost warning: Captum runs one occluded forward evaluation per window
    position. With ``stride=1`` that is (H - window_size + 1) * (W - window_size + 1)
    positions, e.g. 44,100 for a 224x224 input and a 15px window. Pass a larger
    ``stride`` (e.g. 8) and/or a larger ``perturbations_per_eval`` for
    interactive use.

    Args:
        model: classifier to explain; switched to eval mode in place.
        input_tensor: batched channel-first input; ``shape[1]`` is taken as the
            channel count.
        target_class: index of the model output to attribute.
        device: torch device the model runs on.
        window_size: spatial side length of the occluding patch.
        stride: spatial step between patch positions (1 matches Captum's default).
        perturbations_per_eval: occluded copies batched per forward pass;
            higher is faster but uses more memory.

    Returns:
        NumPy array of attributions (same shape as ``input_tensor``).
    """
    model.eval()
    input_tensor = input_tensor.to(device)

    num_channels = input_tensor.shape[1]
    occlusion = Occlusion(model)
    attributions = occlusion.attribute(
        input_tensor,
        target=target_class,
        # The window covers every channel, so only the spatial dims slide.
        sliding_window_shapes=(num_channels, window_size, window_size),
        strides=(1, stride, stride),
        perturbations_per_eval=perturbations_per_eval,
    )
    return attributions.detach().cpu().numpy()


def visualize_attributions(attributions, original_image, class_name,
                           plt_fig_axis=None, use_pyplot=True):
    """Draw positive attributions as a heat map blended over the image.

    Args:
        attributions: NumPy array whose first element is a channel-first
            ``(C, H, W)`` map; it is transposed to ``(H, W, C)`` for Captum.
        original_image: NumPy array laid out the same way as ``attributions``.
        class_name: text appended to the "Explanation for class:" title.
        plt_fig_axis: optional ``(figure, axis)`` to draw into. When omitted,
            Captum creates a fresh figure for this single plot.
        use_pyplot: forwarded to Captum. When True Captum shows the figure
            itself, so pass False when drawing into a shared figure that the
            caller shows once.

    Returns:
        The ``(figure, axis)`` pair returned by ``viz.visualize_image_attr``.
    """
    attr_hwc = np.transpose(attributions[0], (1, 2, 0))
    image_hwc = np.transpose(original_image[0], (1, 2, 0))
    # NOTE(review): Captum appears to rescale original_image only when its max is
    # <= 1.0; if the loaders apply mean/std normalization the blended background
    # may look distorted -- confirm against training.get_data_loaders.
    return viz.visualize_image_attr(
        attr_hwc,
        image_hwc,
        method="blended_heat_map",
        sign="positive",
        plt_fig_axis=plt_fig_axis,
        show_colorbar=True,
        title=f"Explanation for class: {class_name}",
        use_pyplot=use_pyplot,
    )


def generate_explanations(model, input_tensor, target_class, class_names, device):
    """Explain one input with four Captum methods and plot them in a 2x2 grid.

    Args:
        model: classifier to explain.
        input_tensor: single-sample batch; a CPU NumPy copy is used as the
            background image for every panel.
        target_class: output index to attribute; also looked up in
            ``class_names`` for the panel titles.
        class_names: mapping from class index to display name.
        device: device passed through to each ``explain_with_*`` helper.
    """
    original_image = input_tensor.cpu().detach().numpy()
    target_class_name = class_names[target_class]

    # (panel label, attribution function) in row-major grid order.
    methods = (
        ("IG", explain_with_integrated_gradients),
        ("Saliency", explain_with_saliency),
        ("Gradient SHAP", explain_with_gradient_shap),
        ("Occlusion", explain_with_occlusion),
    )

    # Captum does not draw into the "current" pyplot subplot: without an explicit
    # (fig, axis) it opens and shows a new figure per call, leaving the grid empty.
    # Hand it each grid cell and show the assembled figure once.
    fig, axes = plt.subplots(2, 2, figsize=(20, 10))
    for axis, (label, explain) in zip(axes.flat, methods):
        attributions = explain(model, input_tensor, target_class, device)
        visualize_attributions(
            attributions,
            original_image,
            f"{label}: {target_class_name}",
            plt_fig_axis=(fig, axis),
            use_pyplot=False,
        )

    plt.show()


In [None]:
# Explain the first sample of the first test batch, using its ground-truth label
# as the attribution target.
batch = next(iter(test_loader), None)
if batch is None:
    # Without this guard an empty loader would leave input_tensor/target_class
    # undefined, or silently reuse stale values from an earlier notebook run.
    raise RuntimeError("test_loader yielded no batches; nothing to explain")
inputs, labels = batch
input_tensor = inputs[0].unsqueeze(0)  # keep a leading batch dimension of 1
target_class = labels[0].item()

generate_explanations(model, input_tensor, target_class, CLASS_NAMES, DEVICE)