# Sheet to prototype and visualize explanation metrics. 

There are several steps in this workflow. 
   - Function to compare sensitivity: for both perturbation types below, a small change to the input should result in a small change in the explanation, and a large change to the input should result in a large change in the explanation.
   - Random Noise
   - Gaussian blur

   
   - Function to compare faithfulness (to be completed)




In [5]:
import torchvision.transforms as transforms
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


def transform_batch_of_images(images):
    """Convert ``images`` to a tensor and normalise it channel-wise.

    ``transforms.ToTensor`` runs first and its result is passed to
    ``transforms.Normalize``.

    Args:
        images: input accepted by ``transforms.ToTensor``.
    Returns:
        The normalised tensor.
    """
    # Per-channel mean/std -- presumably the ImageNet statistics expected by
    # the pretrained torchvision models; confirm against the model used.
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(images)

def get_data_imagenette(path = "/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP/data/Imagenette"):
    """Build a shuffled DataLoader over the Imagenette validation split.

    Args:
        path (str): root directory of the dataset. The data must already be
            there because the dataset is opened with ``download=False``.
    Returns:
        DataLoader: yields batches of 32 samples; every image is resized to
        256, centre-cropped to 224 and converted to a tensor (no
        normalisation is applied here).
    """
    # Resize -> centre crop -> tensor conversion.
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    # Validation split only; nothing is downloaded.
    validation_set = datasets.Imagenette(
        root=path,
        split='val',
        transform=preprocessing,
        download=False,
    )

    return DataLoader(validation_set, batch_size=32, shuffle=True, num_workers=2)

def get_data(path_to_data:str = '/home/charleshiggins/RL-LRP/baselines/trainVggBaselineForCIFAR10/data'):
    """Return a DataLoader over the CIFAR10 test split stored at ``path_to_data``.

    The batching options (``batch_size``, ``shuffle``, ``num_workers``,
    ``pin_memory``) belong to ``DataLoader``, not to the dataset class, and
    the dataset keyword is ``transform`` (singular). The dataset is therefore
    built first and then wrapped in a loader.

    N.B. The main workflow in this notebook uses ``get_data_imagenette``
    instead of CIFAR10; this function is kept as the CIFAR10 alternative.

    Args:
        path_to_data (str): path to the data directory. The data must already
            be present there; nothing is downloaded.
    Returns:
        DataLoader: unshuffled batches of 64 (image tensor, label) pairs.
    """
    test_set = datasets.CIFAR10(
        root=path_to_data,
        train=False,
        transform=transforms.ToTensor(),
    )
    val_loader = DataLoader(
        test_set,
        batch_size=64,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    return val_loader

def blur_image_batch(images, kernel_size):
    """Apply a square Gaussian blur to every image of a batch.

    Args:
        images (torch.Tensor): batch of images; it is iterated along its
            first dimension and each item is blurred separately.
        kernel_size (int): side length of the Gaussian kernel, used for both
            height and width.
    Returns:
        torch.Tensor: the blurred images stacked back into one batch.
    """
    kernel = [kernel_size, kernel_size]
    blurred = []
    for single_image in images:
        blurred.append(TF.gaussian_blur(single_image, kernel_size=kernel))
    return torch.stack(blurred)

def add_random_noise_batch(images, noise_level):
    """Return ``images`` perturbed by scaled standard-normal noise.

    Args:
        images (torch.Tensor): images to perturb (left unmodified).
        noise_level (float): factor applied to the sampled noise.
    Returns:
        torch.Tensor: ``images`` plus noise of the same shape.
    """
    # One N(0, 1) sample per element, scaled by the requested level.
    perturbation = torch.randn_like(images) * noise_level
    return torch.add(images, perturbation)

def compute_distance_between_images(images1, images2):
    """Compute the per-sample cosine distance between two batches of images.

    Each sample is flattened to a vector and compared with the sample at the
    same batch index in the other tensor.

    Args:
        images1 (torch.Tensor): Tensor of treated images
        images2 (torch.Tensor): Tensor of ground-truth images
    Returns:
        torch.Tensor: 1-D tensor with one distance (``1 - cosine similarity``)
        per batch element.
    """
    # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
    # tensors (e.g. after ``permute``), ``reshape`` copies only when needed.
    images1_flat = images1.reshape(images1.size(0), -1)
    images2_flat = images2.reshape(images2.size(0), -1)

    # Cosine similarity along the flattened feature dimension, converted to a
    # distance.
    cosine_similarity = F.cosine_similarity(images1_flat, images2_flat, dim=1)
    return 1 - cosine_similarity


def visulise_panel_image(image, model, kernel_size, noise_level):
    """Unfinished stub (note the misspelt name); it draws nothing.

    The function only imports matplotlib and returns ``None``. None of its
    arguments are used.

    Args:
        image (torch.Tensor): unused
        model (torch.nn.Module): unused
        kernel_size (int): unused
        noise_level (float): unused
    """
    from matplotlib import pyplot as plt

def visualise_panel_image(image, model, kernel_size, noise_level, method=None, label=None):
    """Plot an image, four perturbed versions of it, and their heatmaps.

    The top row shows the original, small/large Gaussian blur and small/large
    random noise; the bottom row shows the matching explanation heatmaps.

    The previous body ignored ``kernel_size`` / ``noise_level`` and read
    ``kernel_size_min``, ``kernel_size_max``, ``noise_level_min``,
    ``noise_level_max``, ``method`` and ``label`` from names that are not
    defined in this function. They are now taken from the arguments.

    Args:
        image (torch.Tensor): a single image or a batch holding one image; a
            3-D tensor gets a leading batch dimension added.
        model: model forwarded to ``method``.
        kernel_size (int | tuple): either a ``(small, large)`` pair of
            Gaussian kernel sizes, or one int used as the small size, with
            ``kernel_size + 2`` (the next kernel size of the same parity)
            as the large one.
        noise_level (float | tuple): either a ``(small, large)`` pair of
            noise levels, or one float used as the small level, with twice
            that value as the large one.
        method (callable): explanation function called as
            ``method(images, label, model)``, e.g. ``perform_gradcam``.
        label: target label(s) forwarded to ``method``.
    Raises:
        ValueError: if ``method`` or ``label`` is not supplied.
    """
    if method is None or label is None:
        raise ValueError(
            "visualise_panel_image needs both `method` and `label` to compute the heatmaps"
        )

    # Assume the image tensor is already in batch format, if not, unsqueeze it
    if image.dim() == 3:
        image = image.unsqueeze(0)

    if isinstance(kernel_size, (tuple, list)):
        kernel_size_min, kernel_size_max = kernel_size
    else:
        kernel_size_min, kernel_size_max = kernel_size, kernel_size + 2
    if isinstance(noise_level, (tuple, list)):
        noise_level_min, noise_level_max = noise_level
    else:
        noise_level_min, noise_level_max = noise_level, 2 * noise_level

    # (title prefix, treated image) for each column of the figure.
    panels = [
        ('Original', image),
        ('Small Blurred', blur_image_batch(image, kernel_size_min)),
        ('Large Blurred', blur_image_batch(image, kernel_size_max)),
        ('Small Noisy', add_random_noise_batch(image, noise_level_min)),
        ('Large Noisy', add_random_noise_batch(image, noise_level_max)),
    ]

    _, ax = plt.subplots(2, 5, figsize=(15, 5))
    for column, (title, treated) in enumerate(panels):
        # NOTE(review): condense_to_heatmap is not defined in the visible part
        # of this notebook -- it is assumed to exist at module level and to
        # reduce an attribution tensor to a single-channel map.
        heatmap = condense_to_heatmap(
            method(preprocess_images(treated), label, model)
        ).detach().cpu()

        ax[0][column].imshow(treated.squeeze().detach().permute(1, 2, 0).cpu().numpy())
        ax[0][column].set_title(f'{title} Image')
        ax[1][column].imshow(heatmap.squeeze(0).numpy(), cmap='hot')
        ax[1][column].set_title(f'{title} Heatmap')

    for axis in ax.flat:
        axis.axis('off')
    plt.show()

In [6]:
# workflow for the visualisation and data analysis
import torch
import pandas as pd
# Load data
# generate the blurred images, noisy images, and the ground truth heatmap images
# then for each, calculate the distance between the heatmaps over the blurred images and the ground truth heatmap images
def process_batch(
    input_batch:torch.Tensor, 
    input_labels:torch.Tensor,  
    methods: list, 
    kernel_size_min: int, 
    kernel_size_max: int, 
    noise_level_min: float, 
    noise_level_max: float):
    """Measure how much each method's heatmaps move under input perturbations.

    For every explanation method the heatmap of the clean batch is compared
    (via ``compute_distance_between_images``) with the heatmaps of four
    perturbed batches: small/large random noise and small/large Gaussian blur.
    Sparseness statistics of the clean heatmap are recorded as well.

    The clean batch goes through ``preprocess_images`` exactly like the
    perturbed ones. Otherwise the reference heatmap would be computed on
    un-normalised inputs and every distance would include the normalisation
    shift. The perturbed batches are built once and shared by all methods,
    so the methods are compared on identical inputs.

    Args:
        input_batch (torch.Tensor): batch of un-normalised images.
        input_labels (torch.Tensor): labels forwarded to every method.
        methods (list): ``(name, method, model)`` triples; ``method`` is
            called as ``method(images, labels, model)``.
        kernel_size_min (int): Gaussian kernel size of the small blur.
        kernel_size_max (int): Gaussian kernel size of the large blur.
        noise_level_min (float): scale of the small random noise.
        noise_level_max (float): scale of the large random noise.
    Returns:
        dict: for every method name the keys ``{name}_distance_noise_small``,
        ``{name}_distance_noise_large``, ``{name}_distance_blur_small``,
        ``{name}_distance_blur_large``, ``{name}_sparseness_original`` and
        ``{name}_sparseness_gini``.
    """
    # Perturb first, then normalise, so the perturbation acts on raw pixels.
    reference_batch = preprocess_images(input_batch)
    perturbed_batches = {
        "noise_small": preprocess_images(add_random_noise_batch(input_batch, noise_level_min)),
        "noise_large": preprocess_images(add_random_noise_batch(input_batch, noise_level_max)),
        "blur_small": preprocess_images(blur_image_batch(input_batch, kernel_size_min)),
        "blur_large": preprocess_images(blur_image_batch(input_batch, kernel_size_max)),
    }

    results_dictionary = {}
    for name, method, model in methods:
        # Reference ("ground truth") heatmap of the unperturbed batch.
        ground_truth_heatmap = method(reference_batch, input_labels, model)

        for tag, perturbed in perturbed_batches.items():
            perturbed_heatmap = method(perturbed, input_labels, model)
            results_dictionary[f"{name}_distance_{tag}"] = compute_distance_between_images(
                ground_truth_heatmap, perturbed_heatmap
            )

        # NOTE(review): compute_sparseness_of_heatmap is not defined in the
        # visible part of this notebook; it is assumed to return a pair.
        sparseness_original, sparseness_gini = compute_sparseness_of_heatmap(ground_truth_heatmap)
        results_dictionary[f"{name}_sparseness_original"] = sparseness_original
        results_dictionary[f"{name}_sparseness_gini"] = sparseness_gini
    return results_dictionary

    
        
    

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [7]:
import torch
import sys
sys.path.append('/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP')
from experiments import WrapperNet
from captum.attr import GuidedGradCam

def preprocess_images(image_batch):
    """Normalise a batch of images channel-wise.

    Args:
        image_batch (torch.Tensor): 4-D tensor of images.
    Returns:
        torch.Tensor: the normalised batch.
    Raises:
        ValueError: if ``image_batch`` is not a tensor or is not 4-D. The
            checks are done separately so that a non-tensor input (which has
            no ``.dim()`` / ``.shape``) still produces this error rather than
            an ``AttributeError`` raised while the message is being built.
    """
    if not isinstance(image_batch, torch.Tensor):
        raise ValueError(
            f"Input must be a tensor of images -- unknown format {type(image_batch)}"
        )
    if image_batch.dim() != 4:
        raise ValueError(
            f"Input must be a tensor of images -- unknown format {type(image_batch)} "
            f"and dimension {image_batch.dim()} (shape {tuple(image_batch.shape)})"
        )

    # Per-channel mean/std -- presumably the ImageNet statistics the
    # pretrained model was trained with; confirm against the model used.
    normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    return normalize_transform(image_batch)
        

def perform_lrp_plain(image, label, model):
    """Run the LRP wrapper model and return its second output.

    Args:
        image (torch.Tensor): Tensor of images to be explained
        label (torch.Tensor): labels of the images (i.e. the class)
        model (WrapperNet): wrapped model, called as ``model(image, label)``
    Returns:
        The second element of the pair returned by the model -- presumably
        the relevance heatmaps; the first element is discarded.
    """
    assert isinstance(model, WrapperNet), "Model must be a WrapperNet for LRP"
    _, relevance = model(image, label)
    return relevance

def perform_loss_lrp(image, label, model):
    """Run the loss-trained LRP wrapper model and return its second output.

    Args:
        image (torch.Tensor): Tensor of images to be explained
        label (torch.Tensor): labels of the images (i.e. the class)
        model (WrapperNet): wrapped model, called as ``model(image, label)``
    Returns:
        The second element of the pair returned by the model -- presumably
        the relevance heatmaps; the first element is discarded.
    """
    assert isinstance(model, WrapperNet), "Model must be a WrapperNet for LossLRP"
    _, relevance = model(image, label)
    return relevance


def get_input_output_layers(model):
    """
    Find the first and last ``Conv2d`` modules of a model.

    Modules are visited in the order produced by ``model.modules()``.

    Args:
    - model: The PyTorch model

    Returns:
    - input_layer: The first convolutional layer encountered
    - output_layer: The last convolutional layer encountered (the same
      object as ``input_layer`` when there is only one)

    Raises:
    - ValueError: if the model has no ``Conv2d`` module
    """
    first_conv = None
    last_conv = None
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            if first_conv is None:
                first_conv = module
            last_conv = module

    if first_conv is None:
        raise ValueError("The model does not contain any Conv2d layers.")

    return first_conv, last_conv

def perform_gradcam(images, labels, model):
    """Compute Guided Grad-CAM attributions for a batch of images.

    Despite the function name, the attribution object is captum's
    ``GuidedGradCam``, attached to the last ``Conv2d`` layer of the model.

    Args:
        images (torch.Tensor): Tensor of images to be explained
        labels (torch.Tensor): labels of the images, passed as ``target``
        model (torch.nn.Module): model to be explained; it is switched to
            evaluation mode as a side effect
    Returns:
        The attributions returned by ``GuidedGradCam.attribute``.
    """
    # Evaluation mode before anything else.
    model.eval()

    # Only the last convolutional layer is needed; the first is ignored.
    _, last_conv = get_input_output_layers(model)

    explainer = GuidedGradCam(model, last_conv)
    return explainer.attribute(images, target=labels)



In [8]:
import torchvision
from tqdm.notebook import tqdm
def get_learner_model():
    """Placeholder for loading the learner model; not implemented, returns ``None``."""
    return None

def get_teacher_model():
    """Return a frozen, pretrained torchvision VGG16 in evaluation mode.

    The model is loaded with the ``IMAGENET1K_V1`` weights, all of its
    parameters have ``requires_grad`` disabled, and it is placed on the GPU
    when CUDA is available (otherwise on the CPU).
    """
    pretrained_weights = torchvision.models.VGG16_Weights.IMAGENET1K_V1
    vgg = torchvision.models.vgg16(weights=pretrained_weights)

    # Freeze every parameter, then switch to inference behaviour.
    vgg.requires_grad_(False)
    vgg.eval()

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return vgg.to(target_device)

def main():
    """Run the sensitivity evaluation over Imagenette and save it as CSV.

    Every batch of the validation loader is passed through ``process_batch``
    for each explanation method; the per-batch results are collected and
    written to ``test_results.csv``.
    """
    # define params
    kernel_size_min = 3
    kernel_size_max = 5
    noise_level_min = 0.1
    noise_level_max = 0.2
    # get the data
    data_loader = get_data_imagenette()
    # get the model
    # learner_model = get_learner_model()
    teacher_model = get_teacher_model()
    # get_teacher_model may place the model on the GPU; the batches must be
    # moved to the same device or the forward pass fails.
    device = next(teacher_model.parameters()).device
    # define the methods
    methods = [
        ("LRP", perform_lrp_plain, WrapperNet(teacher_model, hybrid_loss=True)),
        # ("LossLRP", perform_loss_lrp, learner_model),
        ("GradCAM", perform_gradcam, teacher_model),
    ]
    # process the data
    collected = {}
    # Wrapping the loader itself (not enumerate) lets tqdm show the total.
    for input_batch, input_labels in tqdm(data_loader):
        results = process_batch(
            input_batch.to(device),
            input_labels.to(device),
            methods,
            kernel_size_min,
            kernel_size_max,
            noise_level_min,
            noise_level_max,
        )
        for key, value in results.items():
            # Detach and move to CPU so no autograd graph or GPU memory is
            # kept alive across batches; reshape(-1) also turns a scalar
            # result into a length-1 vector so it can be concatenated.
            chunk = torch.as_tensor(value).detach().reshape(-1).cpu()
            collected.setdefault(key, []).append(chunk)
    # Concatenate once per column instead of after every batch.
    table = {
        key: pd.Series(torch.cat(chunks, dim=0).numpy())
        for key, chunks in collected.items()
    }
    # convert to pandas dataframe (Series tolerate columns of unequal length)
    df = pd.DataFrame(table)
    # save results
    df.to_csv("test_results.csv")
    

In [9]:
# Run the whole evaluation; main() writes its results to test_results.csv.
main()

  return torch.nn.functional.log_softmax(y), relevance


<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 224, 224])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 224, 224])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 224, 224])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 224, 224])




<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 224, 224])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 224, 224])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 224, 224])
<class 'torch.Tensor'>
<class 'torch.Tensor'>
torch.Size([32, 3, 224, 224])
torch.Size([32, 3, 224, 224])
LRP_distance_noise_small
<class 'torch.Tensor'>
LRP_distance_noise_large
<class 'torch.Tensor'>
LRP_distance_blur_small
<class 'torch.Tensor'>
LRP_distance_blur_large
<class 'torch.Tensor'>
LRP_sparseness_original
<class 'torch.Tensor'>
LRP_sparseness_gini
<class 'torch.Tensor'>
GradCAM_distance_noise_small
<class 'torch.Tensor'>
GradCAM_distance_noise_large
<class 'torch.Tensor'>
GradCAM_distance_blur_small
<class 'torch.Tensor'>
GradCAM_distance_blur_large
<class 'torch.Tensor'>
GradCAM_sparseness_original
<class 'torch.Tensor'>
GradCAM_sparseness_gini
<class 'torch.T

KeyboardInterrupt: 