In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from src.network import VGGModel
from src.dataset import MNISTDataset
from src.utils import train, evaluate, evaluate_per_class

In [None]:
# Model backbone
# Seed torch's global RNG so weight initialisation (and, presumably, DataLoader shuffling
# inside MNISTDataset) is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"
vgg16 = VGGModel(num_classes=10, device=device)
criterion = nn.CrossEntropyLoss()
# Only the classifier head is handed to the optimizer, so only its weights are updated
optimizer = optim.Adam(vgg16.model.classifier.parameters(), lr=1e-4)

# Dataset
dataset = MNISTDataset()

In [None]:
# Set TRAIN_MODEL = False to skip training and reuse the weights saved at MODEL_PATH
import os

TRAIN_MODEL = True
N_EPOCHS = 10
MODEL_PATH = "./models/vgg16_mnist.pth"

if TRAIN_MODEL:
    for e in range(N_EPOCHS):
        print(f"Epoch {e}...")
        train(vgg16.model, dataset.train_loader, criterion, optimizer, device)
        evaluate(vgg16.model, dataset.test_loader, device)

    # Save the trained model; torch.save does not create missing directories
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    torch.save(vgg16.model.state_dict(), MODEL_PATH)


In [None]:
# Before running the explainers, create the result folders
import os

folders = ['./results/MNIST/LIME/images/',
           './results/MNIST/SHAP/images/',
           './results/MNIST/LRP/images/']

for folder in folders:
    if os.path.isdir(folder):
        print(f"Already exists: {folder}")
    else:
        # exist_ok avoids a race if the folder appears between the check and the call;
        # a regular *file* with this name still raises instead of failing later in savefig
        os.makedirs(folder, exist_ok=True)
        print(f"Created: {folder}")

# Load the trained weights; map_location lets weights saved on another device load here
vgg16.model.load_state_dict(
    torch.load("./models/vgg16_mnist.pth", map_location=device, weights_only=True)
)
vgg16.model.eval()

# Check the accuracy on the test set
evaluate_per_class(vgg16.model,
                   dataset.test_loader,
                   device)

### LIME

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random
from lime import lime_image
from skimage.segmentation import mark_boundaries

# A fixed random_state makes LIME's segmentation and perturbation sampling reproducible
LIME_RANDOM_STATE = 0
explainer = lime_image.LimeImageExplainer(random_state=LIME_RANDOM_STATE)

# Convert LIME-style channels-last numpy images into NCHW tensors for the model
def to_model_size(x):
    """Convert a channels-last image or batch to a float32 NCHW torch tensor.

    Args:
        x: array-like of shape (H, W, C) for one image, or (N, H, W, C) for a batch
           (LIME passes batches of perturbed images in this layout).

    Returns:
        A new float32 tensor of shape (N, C, H, W); N == 1 for a single image.
    """
    # Always cast to float32: a float64 array would otherwise reach the model as a double
    # tensor and fail with a dtype mismatch (assumes the model runs in float32 — TODO confirm).
    batch = torch.tensor(np.asarray(x), dtype=torch.float32)  # copies; never aliases x
    if batch.ndim == 3:
        batch = batch.unsqueeze(0)  # single image -> batch of one
    return batch.permute(0, 3, 1, 2)

def to_explainer_size(x):
    """Take the first sample of an (N, C, H, W) tensor and return it as an (H, W, C) numpy array."""
    first_sample = x[0]
    channels_last = first_sample.permute(1, 2, 0)
    return channels_last.cpu().numpy()

# Wrap the model so LIME can call it on numpy batches
def predict_fn(x):
    """Classifier function for LIME.

    Args:
        x: numpy image (H, W, C) or batch (N, H, W, C), as produced by LIME.

    Returns:
        numpy array of shape (N, num_classes) with the model outputs for each image.
    """
    # NOTE(review): the model is trained with nn.CrossEntropyLoss, which suggests it returns
    # raw logits, while LIME documents classifier_fn as returning probabilities. Confirm
    # whether VGGModel applies a softmax before reading LIME weights as probability effects.
    x_model = to_model_size(x).to(device)  # numpy (N)HWC -> torch NCHW
    # Inference only: LIME calls this on every batch of perturbed samples, so skip autograd
    with torch.no_grad():
        outputs = vgg16.model(x_model)
    return outputs.detach().cpu().numpy()

# Pick images from the test set and compute the recursive (fixed-point) LIME explanation.
# Each iteration explains the current image, keeps only the superpixels LIME marks as
# positively contributing to its top label, and feeds that reduced image back in until it
# stops changing. Pixels are only ever zeroed, so the loop always terminates.
lime_results_path = "./results/MNIST/LIME/results.txt"
open(lime_results_path, "w").close()  # truncate so re-running the cell does not append duplicates

reductions = []
for idx in range(100):
    print(f"Recursive explanation for Image {idx}...")
    # 1. Take the idx-th test image and store input, label, and prediction
    image, y = dataset.test_dataset[idx]
    image = image[None, :, :, :].to(device)  # add batch dim and send to device

    with torch.no_grad():
        prediction = int(torch.argmax(vgg16.model(image)))
    predictions = [prediction]
    correctness = (y == prediction)

    # LIME only sees channel 0 (replicated to RGB below); keep it as the reference input
    input_channel = image[0, 0].detach().cpu().numpy()
    current = input_channel

    # 2. Loop to find the fixed point
    iteration = 0
    prev_image_rgb_first = None
    while True:
        # Convert grayscale (H, W) to RGB (H, W, 3) for LIME
        image_rgb = np.stack([current] * 3, axis=-1)

        # to_model_size maps HWC -> NCHW (the former permute(2, 1, 0) swapped H and W)
        with torch.no_grad():
            logits = vgg16.model(to_model_size(image_rgb).to(device))
        predictions.append(int(torch.argmax(logits)))

        explanation = explainer.explain_instance(
            image_rgb,
            classifier_fn=predict_fn,
            top_labels=5,
            hide_color=0,
            num_samples=1000,
            progress_bar=False
        )

        temp, mask = explanation.get_image_and_mask(
            explanation.top_labels[0],
            positive_only=True,
            negative_only=False,
            num_features=50,
            hide_rest=True
        )

        # mark_boundaries works on float images
        img_boundry2 = mark_boundaries(temp.astype(np.float32), mask)

        # Next input: keep only the positively-explained superpixels (the region `temp`
        # shows with hide_rest=True) and zero everything else.
        image_rgb_first = np.where(mask > 0, image_rgb[:, :, 0], 0).astype(image_rgb.dtype)

        # Plot input, explanation, and next iterate side by side
        fig, axes = plt.subplots(1, 3, figsize=(8, 4))
        panels = [(image_rgb, "Input"),
                  (img_boundry2, "Boundary"),
                  (image_rgb_first, "Fixed Point Iteration")]
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel, cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(f"./results/MNIST/LIME/images/image-{idx}-recursive-iteration{iteration}.png")
        plt.close(fig)  # figures created in a loop otherwise accumulate in memory
        iteration += 1

        # Fixed point: this iteration produced the same image as the previous one.
        # Check for None first so the comparison never runs against None.
        if prev_image_rgb_first is not None and np.array_equal(image_rgb_first, prev_image_rgb_first):
            break
        prev_image_rgb_first = image_rgb_first
        current = image_rgb_first

    # The result is single-channel, so count against channel 0 of the input. Counting all
    # 3 input channels would understate the reduction (~3x if the channels are replicas).
    non_zero_pixels_input = int(np.count_nonzero(input_channel > 0))
    non_zero_pixels_result = int(np.count_nonzero(image_rgb_first > 0))
    reduction = non_zero_pixels_result / non_zero_pixels_input
    print(f"Reduction: {reduction}")
    reductions.append(reduction)

    with open(lime_results_path, "a") as f:
        f.write(f"Image {idx}:\n")
        f.write(f"\tPredictions: {predictions}\n")
        f.write(f"\tConsistency: {len(set(predictions))==1}\n")
        f.write(f"\tCorrectness: {correctness}\n")
        f.write(f"\tGround truth: {y}\n")
        f.write("\n")

print(f"{np.mean(reductions)} \\pm {np.std(reductions)}")

### SHAP

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import shap

# Materialise the whole test split in one pass. The images form a single (N, C, H, W) tensor,
# which also serves as the GradientExplainer background set; the labels become a plain list.
test_images, test_labels = zip(*dataset.test_dataset)
X_test = torch.stack(test_images)
y_test = list(test_labels)

explainer = shap.GradientExplainer(vgg16.model, X_test)

# Recursive SHAP: repeatedly zero out (in every channel) the pixels whose SHAP value for the
# ground-truth class is negative, until a masking step removes nothing new.
shap_results_path = "./results/MNIST/SHAP/results.txt"
open(shap_results_path, "w").close()  # truncate so re-running the cell does not append duplicates

reductions = []
for idx in range(100):
    # Clone: x is masked in place below, and X_test is also the explainer's background
    # data. A view would silently corrupt the background for every later image.
    x = X_test[idx].clone().reshape(1, 3, 32, 32)
    y = y_test[idx]

    with torch.no_grad():
        prediction = int(torch.argmax(vgg16.model(x)))
    predictions = [prediction]
    correctness = (y == prediction)

    # SHAP values of channel 0 for class y, shape (1, 32, 32, 1)
    shap_values = explainer.shap_values(x)[:, 0, :, :, y:y+1]

    # NOTE(review): upstream shap.image_plot does not document a `save` argument; this relies
    # on the installed shap version/fork accepting it — confirm.
    shap.image_plot(shap_values,
                    np.array(x[:, 0, :, :]),
                    save=f"./results/MNIST/SHAP/images/image-{idx}-original-iteration-0.png")  # original SHAP

    iteration = 1
    while True:
        # Zero the negatively-attributed pixels in all channels the model sees; the
        # (32, 32) mask broadcasts over the batch and channel dimensions.
        x_prev = x.clone()
        negative_mask = torch.from_numpy(np.asarray(shap_values[0, :, :, 0] < 0))
        x.masked_fill_(negative_mask, 0.0)
        if torch.equal(x, x_prev):
            break  # fixed point: this masking step removed nothing new

        # Model prediction on the masked image
        with torch.no_grad():
            prediction = int(torch.argmax(vgg16.model(x)))
        predictions.append(prediction)

        # SHAP values of the masked image
        shap_values = explainer.shap_values(x)[:, 0, :, :, y:y+1]
        shap.image_plot(shap_values,
                        np.array(x[:, 0, :, :]),
                        save=f"./results/MNIST/SHAP/images/image-{idx}-original-iteration-{iteration}.png")  # original SHAP

        iteration += 1

    # Both counts span all channels of the (untouched) input and the final masked image
    non_zero_pixels_input = int((X_test[idx] > 0).sum())
    non_zero_pixels_result = int((x > 0).sum())
    reduction = non_zero_pixels_result / non_zero_pixels_input
    print(f"Reduction: {reduction}")
    reductions.append(reduction)

    with open(shap_results_path, "a") as f:
        f.write(f"Image {idx}:\n")
        f.write(f"\tPredictions: {predictions}\n")
        f.write(f"\tConsistency: {len(set(predictions))==1}\n")
        f.write(f"\tCorrectness: {correctness}\n")
        f.write(f"\tGround truth: {y}\n")
        f.write("\n")

print(f"{np.mean(reductions)} \\pm {np.std(reductions)}")

### LRP

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import sys
from torchvision import transforms
from torchvision.models import vgg16 as vgg16pretrained  # avoid collision with vgg16.model (our pre-trained model)
from torchvision.models import VGG16_Weights

sys.path.append('./PyTorchRelevancePropagation/')
from PyTorchRelevancePropagation.src.lrp import LRPModel

# NOTE(review): LRP relevance is computed with the ImageNet-pretrained torchvision VGG16
# (`vgg_pretrained`), not with the MNIST-trained `vgg16.model` whose predictions are logged
# in the next cell — confirm this is intended.
vgg_pretrained = vgg16pretrained(weights=VGG16_Weights.DEFAULT)
lrp_model = LRPModel(vgg_pretrained)

# Both pipelines resize a PIL image to the 32x32 size used throughout this notebook and
# convert it to a float tensor. transform_lrp is not used in the visible cells.
LRP_IMAGE_SIZE = (32, 32)
transform_lrp = transforms.Compose([transforms.Resize(LRP_IMAGE_SIZE), transforms.ToTensor()])
transform_model = transforms.Compose([transforms.Resize(LRP_IMAGE_SIZE), transforms.ToTensor()])

In [None]:
# Recursive LRP: scale the input by the row-normalised LRP relevance map, recompute the
# relevance, and stop once the relevance map no longer changes.
vgg16.model.to("cpu")
img_shape = (3, 32, 32)
MAX_LRP_ITERATIONS = 100  # safety cap: exact float equality of successive explanations may never occur
lrp_results_path = "./results/MNIST/LRP/results.txt"
open(lrp_results_path, "w").close()  # truncate so re-running the cell does not append duplicates


def _save_lrp_figure(image, relevance, titles, path):
    """Save a 1x2 figure (input image | relevance map) to `path`, then close it."""
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(image[0].permute(1, 2, 0).detach().numpy())
    axes[1].imshow(relevance.detach().numpy())
    for ax, title in zip(axes, titles):
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)  # figures created in a loop otherwise accumulate in memory


reductions = []
for idx in range(100):
    print(f"Computing fpe for Image[{idx}]...{99 - idx} left")

    # Take the idx-th MNIST test image. Clone it because it is rescaled in place below.
    x, y = dataset.test_dataset[idx]
    x = x.clone().view(1, *x.shape)
    non_zero_pixels_input = int((x > 0).sum())
    explanation = lrp_model.forward(x)

    # Prediction of the MNIST model on the untouched input
    with torch.no_grad():
        predicted = int(torch.argmax(vgg16.model(x.view(1, *img_shape)), dim=1))
    predictions = [predicted]
    correctness = (y == predicted)

    _save_lrp_figure(x, explanation, ("Input Image", "Explanation"),
                     f"./results/MNIST/LRP/images/image-{idx}-original.png")

    iteration = 1
    # Fixed point explanation
    while True:
        # Use the L2-normalised explanation (F.normalize defaults to dim=1) as a soft mask,
        # broadcast over all channels
        x = x.view(1, 3, 32, 32)
        x *= torch.nn.functional.normalize(explanation)

        # The MNIST model sees the image after a PIL round trip (uint8 quantisation)
        x_model = transform_model(transforms.ToPILImage()(x[0]))
        x_model = transforms.Resize((32, 32))(x_model)

        # Get the new explanation and prediction
        new_explanation = lrp_model.forward(x)
        with torch.no_grad():
            predicted = int(torch.argmax(vgg16.model(x_model.view(1, *img_shape)), dim=1))
        predictions.append(predicted)

        _save_lrp_figure(x, new_explanation, ("Fixed Point Input Image", "Fixed Point Explanation"),
                         f"./results/MNIST/LRP/images/image-{idx}-recursive-iteration-{iteration}.png")

        if torch.equal(explanation, new_explanation):
            break
        if iteration >= MAX_LRP_ITERATIONS:
            print(f"No fixed point after {iteration} iterations for Image[{idx}]; stopping")
            break

        # Update the previous explanation
        explanation = new_explanation
        iteration += 1

    non_zero_pixels_result = int((x_model > 0).sum())
    reduction = non_zero_pixels_result / non_zero_pixels_input
    print(f"Reduction: {reduction}")
    reductions.append(reduction)

    with open(lrp_results_path, "a") as f:
        f.write(f"Image {idx}:\n")
        f.write(f"\tPredictions: {predictions}\n")
        f.write(f"\tConsistency: {len(set(predictions))==1}\n")
        f.write(f"\tCorrectness: {correctness}\n")
        f.write(f"\tGround truth: {y}\n")
        f.write("\n")

print(f"{np.mean(reductions)} \\pm {np.std(reductions)}")