# Sheet to prototype and visualize explanation metrics. 

There are several steps in this workflow. 
   - Function to compare sensitivity: for both perturbation types below, a small change to the input should result in a small change in the explanation, and a large change should result in a large change in the explanation.
   - Random Noise
   - Gaussian blur

   
   - Function to compare faithfulness (to be completed)




In [None]:
import torchvision.transforms as transforms
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch
import matplotlib.pyplot as plt


def transform_batch_of_images(images):
    """Convert the input to a tensor and normalise it channel-wise.

    The input is first passed through ``transforms.ToTensor`` and the result
    through ``transforms.Normalize`` with the per-channel mean/std constants
    below.

    Args:
        images: input accepted by ``transforms.ToTensor``.
            NOTE(review): despite the name, ``ToTensor`` is applied to the
            whole argument at once -- confirm callers pass a single image.
    Returns:
        The normalised tensor produced by the two-step pipeline.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    # Tensor conversion runs first, normalisation second.
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    return pipeline(images)

def get_data(path_to_data: str = '/home/charleshiggins/RL-LRP/baselines/trainVggBaselineForCIFAR10/data'):
    """Build a DataLoader over the CIFAR-10 test split.

    The ``ToTensor`` transform is attached to the dataset (``transform=``),
    not to the ``DataLoader``: ``DataLoader`` has no ``transforms`` keyword,
    and without a dataset-level transform the samples are not tensors and
    cannot be collated into batches.

    Args:
        path_to_data (str): root directory holding the CIFAR-10 data.
    Returns:
        torch.utils.data.DataLoader: unshuffled loader with batch size 64
        over the ``train=False`` split.
    """
    test_set = datasets.CIFAR10(
        root=path_to_data,
        train=False,
        transform=transforms.ToTensor(),
    )
    val_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=64,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    return val_loader

def blur_image_batch(images, kernel_size):
    """Apply a square Gaussian blur to every image of a batch.

    Args:
        images (torch.Tensor): batch of images; it is iterated along its
            first dimension and each element is blurred separately.
        kernel_size (int): side length of the square Gaussian kernel.
    Returns:
        torch.Tensor: the blurred images stacked back into one batch.
    """
    blurred = []
    for single_image in images:
        # The same size is used for both kernel dimensions.
        blurred.append(
            TF.gaussian_blur(single_image, kernel_size=[kernel_size, kernel_size])
        )
    return torch.stack(blurred)

def add_random_noise_batch(images, noise_level):
    """Perturb a batch of images with additive Gaussian noise.

    Args:
        images (torch.Tensor): images to perturb.
        noise_level (float): factor applied to the standard-normal samples
            before they are added to the images.
    Returns:
        torch.Tensor: a new tensor equal to ``images`` plus the scaled noise.
    """
    # One standard-normal sample per element, matching shape/dtype/device.
    scaled_noise = torch.randn_like(images) * noise_level
    return images + scaled_noise

def compute_distance_between_images(images1, images2):
    """Compute the per-sample cosine distance between two batches.

    Each sample is flattened to a vector and the distance is
    ``1 - cosine_similarity`` of the paired vectors.

    Args:
        images1 (torch.Tensor): first batch, first dimension is the batch.
        images2 (torch.Tensor): second batch, same layout as ``images1``.
    Returns:
        torch.Tensor: 1-D tensor with one distance per sample.
    """
    # reshape (rather than view) so that non-contiguous inputs, e.g. tensors
    # that were permuted or transposed, are flattened instead of raising.
    images1_flat = images1.reshape(images1.size(0), -1)
    images2_flat = images2.reshape(images2.size(0), -1)

    # Cosine similarity along the flattened feature dimension, turned into a
    # distance.
    cosine_similarity = F.cosine_similarity(images1_flat, images2_flat)
    return 1 - cosine_similarity


def visulise_panel_image(image, model, kernel_size, noise_level):
    """Placeholder for the panel visualisation; it draws nothing.

    NOTE(review): the name is misspelled ("visulise") and the body consists
    only of a matplotlib import, so calling this function returns None
    without using any of its arguments. The working implementation is
    ``visualise_panel_image``; this looks like a leftover stub -- confirm
    before removing.

    Args:
        image (torch.Tensor): image to be visualised (unused here)
        model (torch.nn.Module): model to be visualised (unused here)
        kernel_size (int): size of the Gaussian kernel (unused here)
        noise_level (float): level of noise to be added (unused here)
    """
    import matplotlib.pyplot as plt

def visualise_panel_image(image, model, kernel_size, noise_level):
    """Show an image, its blurred and noisy variants, and the model output.

    A four-panel matplotlib figure is displayed: the original image, the
    Gaussian-blurred image, the noise-perturbed image and the model output
    for the original image. Every panel is squeezed and permuted with
    ``(1, 2, 0)``, so each tensor must be 3-D after squeezing.
    NOTE(review): this presumes the model output is image-shaped -- confirm.

    Args:
        image (torch.Tensor): a single 3-D image (a batch dimension is added)
            or an already batched image.
        model (torch.nn.Module): model called on the original, blurred and
            noisy inputs.
        kernel_size (int): side length of the Gaussian blur kernel.
        noise_level (float): scale of the additive random noise.
    """
    # Promote a single unbatched image to a batch of one.
    batch = image.unsqueeze(0) if image.dim() == 3 else image

    blurred_batch = blur_image_batch(batch, kernel_size)
    noisy_batch = add_random_noise_batch(batch, noise_level)

    # All three forward passes are executed; only the first result is shown.
    clean_outputs = model(batch)
    model(blurred_batch)
    model(noisy_batch)

    _, axes = plt.subplots(1, 4, figsize=(15, 5))
    input_panels = (
        (batch, 'Original Image'),
        (blurred_batch, 'Blurred Image'),
        (noisy_batch, 'Noisy Image'),
    )
    for axis, (tensor, title) in zip(axes, input_panels):
        axis.imshow(tensor.squeeze().permute(1, 2, 0).cpu().numpy())
        axis.set_title(title)

    # The model output is detached from the autograd graph before plotting.
    output_picture = clean_outputs.squeeze().detach().permute(1, 2, 0).cpu().numpy()
    axes[3].imshow(output_picture)  # Example visualization
    axes[3].set_title('Model Output on Original')

    for axis in axes:
        axis.axis('off')
    plt.show()

def condense_to_heatmap(images):
    """
    Collapse the channel dimension of a batch by keeping the per-pixel maximum.

    Args:
        images (torch.Tensor): batch laid out as (batch_size, channels, height, width).

    Returns:
        torch.Tensor: heatmaps of shape (batch_size, height, width).
    """
    # Reducing over dim=1 yields (values, indices); only the values are kept.
    channel_max = images.max(dim=1)
    return channel_max.values

def compute_sparseness_of_heatmap(input_images):
    """Compute two sparseness statistics for a batch of explanations.

    The input is first condensed to one heatmap per sample with
    ``condense_to_heatmap`` (maximum over the channel dimension).

    Args:
        input_images (torch.Tensor): batch of shape
            (batch_size, channels, height, width).
    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - sparseness: per-sample fraction of heatmap pixels whose
              absolute value is below 0.01.
            - gini_indices: per-sample ``1 - G``, where ``G`` is the Gini
              coefficient of the heatmap values. Both tensors live on the
              same device as the heatmaps.

    Note:
        A heatmap whose values sum to zero produces a non-finite entry in
        ``gini_indices`` because the coefficient divides by that sum.
    """
    heatmaps = condense_to_heatmap(input_images)

    threshold = 0.01  # Define near-zero threshold
    near_zero = (heatmaps.abs() < threshold).float()
    sparseness = near_zero.mean(dim=[1, 2])  # mean across spatial dimensions

    # Gini coefficient for the whole batch at once: sort every heatmap's
    # values in ascending order and weight them by their 1-based rank.
    flat = heatmaps.reshape(heatmaps.size(0), -1)
    sorted_values, _ = torch.sort(flat, dim=1)
    n = flat.size(1)
    ranks = torch.arange(1, n + 1, device=flat.device)
    totals = sorted_values.sum(dim=1)
    gini = (2 * ranks * sorted_values).sum(dim=1) / (n * totals) - (n + 1) / n

    # The complement of the coefficient is what gets reported.
    gini_indices = 1 - gini
    return sparseness, gini_indices


In [None]:
# workflow for the visualisation and data analysis
import torch
# Load data
# generate the blurred images, noisy images, and the ground truth heatmap images
# then for each, calculate the distance between the heatmaps over the blurred images and the ground truth heatmap images
def _extract_heatmap(explanation):
    """Return the heatmap from an explanation method's result.

    Some methods return a tuple such as ``(output, class_idx)``; in that case
    the first element is used.
    NOTE(review): assumes the heatmap is the first tuple element -- confirm.
    """
    if isinstance(explanation, (tuple, list)):
        return explanation[0]
    return explanation


def process_batch(
    input_batch: torch.Tensor,
    input_labels: torch.Tensor,
    methods: list,
    kernel_size_min: int,
    kernel_size_max: int,
    noise_level_min: float,
    noise_level_max: float):
    """Measure explanation sensitivity and sparseness for one batch.

    For every explanation method the heatmap of the unperturbed batch is
    compared (cosine distance) with the heatmaps the same method produces on
    perturbed copies of the batch: small/large random noise and small/large
    Gaussian blur. The perturbed copies are created once and shared by all
    methods so that the methods are compared on identical inputs.

    Args:
        input_batch (torch.Tensor): batch of input images.
        input_labels (torch.Tensor): labels belonging to ``input_batch``.
        methods (list): tuples ``(name, method, model)`` where ``method`` is
            called as ``method(images, labels, model)``.
        kernel_size_min (int): Gaussian kernel size of the small blur.
        kernel_size_max (int): Gaussian kernel size of the large blur.
        noise_level_min (float): noise scale of the small perturbation.
        noise_level_max (float): noise scale of the large perturbation.
    Returns:
        dict: per-sample tensors keyed by
        ``{name}_distance_noise_small``, ``{name}_distance_noise_large``,
        ``{name}_distance_blur_small``, ``{name}_distance_blur_large``,
        ``{name}_sparseness`` and ``{name}_gini``.
    """
    # Perturbations do not depend on the method, so build them only once.
    perturbed_batches = {
        "noise_small": add_random_noise_batch(input_batch, noise_level_min),
        "noise_large": add_random_noise_batch(input_batch, noise_level_max),
        "blur_small": blur_image_batch(input_batch, kernel_size_min),
        "blur_large": blur_image_batch(input_batch, kernel_size_max),
    }

    results_dictionary = {}
    for name, method, model in methods:
        # Reference explanation on the unperturbed inputs.
        ground_truth_heatmap = _extract_heatmap(
            method(input_batch, input_labels, model)
        )
        for tag, perturbed_batch in perturbed_batches.items():
            # Explain the perturbed inputs and compare explanation with
            # explanation (not explanation with raw pixels).
            perturbed_heatmap = _extract_heatmap(
                method(perturbed_batch, input_labels, model)
            )
            results_dictionary[f"{name}_distance_{tag}"] = (
                compute_distance_between_images(
                    perturbed_heatmap, ground_truth_heatmap
                )
            )
        # The helper returns two tensors; store them under separate keys so
        # every dictionary value is a single per-sample tensor.
        sparseness, gini_indices = compute_sparseness_of_heatmap(
            ground_truth_heatmap
        )
        results_dictionary[f"{name}_sparseness"] = sparseness
        results_dictionary[f"{name}_gini"] = gini_indices
    return results_dictionary


def main():
    """Run every explanation method over the dataset and save the metrics.

    Each batch of the data loader is passed to ``process_batch``; the
    per-sample results are gathered per key, concatenated once at the end
    and written to ``results.csv`` through a pandas DataFrame.
    """
    # Imported locally so this function does not depend on a module-level
    # pandas import.
    import pandas as pd

    # define params
    kernel_size_min = 3
    kernel_size_max = 5
    noise_level_min = 0.1
    noise_level_max = 0.2
    # get the data
    data_loader = get_data("data")
    # get the models
    learner_model = get_learner_model()
    teacher_model = get_teacher_model()
    # define the methods as (name, function, model)
    methods = [
        ("LRP", perform_lrp_plain, teacher_model),
        ("LossLRP", perform_loss_lrp, learner_model),
        ("GradCAM", perform_gradcam, teacher_model),
    ]
    # Collect the per-batch tensors in lists and concatenate once at the end
    # instead of re-concatenating the growing tensor on every batch.
    collected = {}
    for input_batch, input_labels in data_loader:
        results = process_batch(
            input_batch,
            input_labels,
            methods,
            kernel_size_min,
            kernel_size_max,
            noise_level_min,
            noise_level_max,
        )
        for key, value in results.items():
            collected.setdefault(key, []).append(value)
    # Hand plain NumPy arrays to pandas: detached from autograd and on CPU.
    table = {
        key: torch.cat(chunks, dim=0).detach().cpu().numpy()
        for key, chunks in collected.items()
    }
    # convert to pandas dataframe
    df = pd.DataFrame(table)
    # save results
    df.to_csv("results.csv")
    
    
        
    

In [None]:
from experiments import WrapperNet
def preprocess_image(image):
    """Preprocess an image (not implemented yet).

    The function is currently a stub: the argument is ignored and ``None``
    is returned.

    Args:
        image (torch.Tensor): image intended to be preprocessed
    Returns:
        None: no preprocessing is performed at the moment
    """
    return None

def perform_lrp_plain(image, label, model):
    """Run a WrapperNet model on the images to obtain its LRP result.

    Args:
        image (torch.Tensor): Tensor of images to be explained
        label (torch.Tensor): labels of the images (i.e. the class)
        model (WrapperNet): wrapped model, called as ``model(image, label)``
    Returns:
        tuple: the two values produced by ``model(image, label)``, in order.
            NOTE(review): presumably (heatmaps, class index) -- confirm
            against WrapperNet.
    Raises:
        AssertionError: if ``model`` is not a WrapperNet.
    """
    assert isinstance(model, WrapperNet), "Model must be a WrapperNet for LRP"
    # The model call must yield exactly two values.
    model_output, target_idx = model(image, label)
    return model_output, target_idx

def perform_loss_lrp(image, label, model):
    """Run a WrapperNet model on the images to obtain its loss-LRP result.

    Args:
        image (torch.Tensor): Tensor of images to be explained
        label (torch.Tensor): labels of the images (i.e. the class)
        model (WrapperNet): wrapped model, called as ``model(image, label)``
    Returns:
        tuple: the two values produced by ``model(image, label)``, in order.
            NOTE(review): presumably (heatmaps, class index) -- confirm
            against WrapperNet.
    Raises:
        AssertionError: if ``model`` is not a WrapperNet.
    """
    assert isinstance(model, WrapperNet), "Model must be a WrapperNet for LossLRP"
    # The model call must yield exactly two values.
    model_output, target_idx = model(image, label)
    return model_output, target_idx

def perform_gradcam(image, label, model):
    """Perform GradCAM on the images (not implemented yet).

    The function is currently a stub: all arguments are ignored and ``None``
    is returned.

    Args:
        image (torch.Tensor): Tensor of images to be explained
        label (torch.Tensor): labels of the images (i.e. the class)
        model (torch.nn.Module): model to be visualised
    Returns:
        None: no heatmaps are computed at the moment
    """
    return None
