In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader
from torchvision import datasets

# Make the parent directory importable for the project-local modules below.
sys.path.append('..')
from experiments import perform_gradcam, perform_lrp_captum
from internal_utils import preprocess_images, condense_to_heatmap, blur_image_batch, add_random_noise_batch, get_data_imagenette, get_teacher_model

def _panel_rgb(image_batch):
    """Return the image in ``image_batch`` as an (H, W, C) numpy array for ``imshow``.

    Accepts a channel-first (C, H, W) tensor or a batch of one (1, C, H, W).
    """
    rgb = image_batch[0] if image_batch.dim() == 4 else image_batch
    # Detach + CPU: the explanation method may have enabled gradients on the
    # input, and ``.numpy()`` rejects tensors that require grad or live on GPU.
    rgb = rgb.detach().cpu().permute(1, 2, 0)
    if rgb.is_floating_point():
        # Added noise can push pixels outside [0, 1]; matplotlib clips float RGB
        # to that range anyway, so clamp explicitly and avoid its warning.
        rgb = rgb.clamp(0.0, 1.0)
    return rgb.numpy()


def _panel_heatmap(heatmap):
    """Return ``heatmap`` as a CPU numpy array and a colour limit symmetric about zero."""
    hm = heatmap.detach().cpu().squeeze(0)
    limit = hm.abs().max().item()
    if not limit > 0:  # all-zero (or NaN) map: avoid a degenerate colour range
        limit = 1.0
    return hm.numpy(), limit


def visualise_panel_image(image, model, kernel_size_min, kernel_size_max, noise_level_min, noise_level_max, method, label):
    """Plot a 2x5 panel: an image, four perturbed copies, and their explanation heatmaps.

    Top row: the original image, blurred with ``kernel_size_min`` and
    ``kernel_size_max`` (via ``blur_image_batch``), and noised with
    ``noise_level_min`` and ``noise_level_max`` (via ``add_random_noise_batch``).
    Bottom row: for each of those inputs, the heatmap
    ``condense_to_heatmap(method(preprocess_images(x), label, model))``.

    Args:
        image: channel-first image tensor, either (C, H, W) or a batch of one.
        model: model passed through to ``method``.
        kernel_size_min, kernel_size_max: blur kernel sizes.
        noise_level_min, noise_level_max: noise levels.
        method: explanation function, called as ``method(inputs, label, model)``.
        label: target passed through to ``method`` unchanged.

    Heatmaps use the diverging ``seismic`` colormap with limits symmetric about
    zero, so white always means zero relevance within each heatmap.
    Returns None; the figure is shown with ``plt.show()``.
    """
    if image.dim() == 3:
        image = image.unsqueeze(0)

    # (title stem, input batch) for each panel column, in display order.
    # Perturbations are all built before any explanation is computed.
    variants = [
        ('Original', image),
        ('Small Blurred', blur_image_batch(image, kernel_size_min)),
        ('Large Blurred', blur_image_batch(image, kernel_size_max)),
        ('Small Noisy', add_random_noise_batch(image, noise_level_min)),
        ('Large Noisy', add_random_noise_batch(image, noise_level_max)),
    ]
    # Compute every explanation before creating the figure, so an error in
    # ``method`` does not leave an empty figure behind.
    heatmaps = [
        _panel_heatmap(condense_to_heatmap(method(preprocess_images(batch), label, model)))
        for _, batch in variants
    ]

    _, ax = plt.subplots(2, len(variants), figsize=(15, 5))
    for col, ((name, batch), (heatmap, limit)) in enumerate(zip(variants, heatmaps)):
        ax[0][col].imshow(_panel_rgb(batch))
        ax[0][col].set_title(f'{name} Image')
        ax[1][col].imshow(heatmap, cmap='seismic', vmin=-limit, vmax=limit)
        ax[1][col].set_title(f'{name} Heatmap')

    for axis in ax.flat:
        axis.axis('off')
    plt.show()

In [None]:
# Seed torch's global RNG before anything random runs, so re-running the
# notebook reproduces the same batch and the same noisy panels. Assumes the
# loader's shuffling and add_random_noise_batch draw from torch's RNG —
# TODO confirm in internal_utils.
SEED = 0
torch.manual_seed(SEED)

data = get_data_imagenette()
input_images, labels = next(iter(data))
# Inference mode (dropout off, BatchNorm uses running stats) so explanations
# do not depend on training-mode layers. Assumes get_teacher_model returns an
# nn.Module, as its later use inside WrapperNet suggests.
model = get_teacher_model().eval()

# Perturbation strengths shared by every panel below. A kernel size of 1 is
# presumably an identity blur — confirm against blur_image_batch.
kernel_size_min = 1
kernel_size_max = 7
noise_level_min = 0.1
noise_level_max = 0.2


In [None]:

# One Imagenette sample; `image` and `label` are reused by the following cells.
imagenette_index = 8
image, label = input_images[imagenette_index], labels[imagenette_index]

visualise_panel_image(
    image=image,
    model=model,
    kernel_size_min=kernel_size_min,
    kernel_size_max=kernel_size_max,
    noise_level_min=noise_level_min,
    noise_level_max=noise_level_max,
    method=perform_gradcam,
    label=label,
)

In [None]:
# Same sample and perturbations, explained with Captum's LRP.
visualise_panel_image(
    image=image,
    model=model,
    kernel_size_min=kernel_size_min,
    kernel_size_max=kernel_size_max,
    noise_level_min=noise_level_min,
    noise_level_max=noise_level_max,
    method=perform_lrp_captum,
    label=label,
)

In [None]:
from experiments import perform_lrp_plain, WrapperNet
import torch

# Plain LRP runs on the teacher wrapped in WrapperNet and receives the label
# with an extra leading dimension, unlike the Grad-CAM / Captum calls.
visualise_panel_image(
    image=image,
    model=WrapperNet(model, hybrid_loss=True),
    kernel_size_min=kernel_size_min,
    kernel_size_max=kernel_size_max,
    noise_level_min=noise_level_min,
    noise_level_max=noise_level_max,
    method=perform_lrp_plain,
    label=label.unsqueeze(0),
)

In [None]:
import torch

# Relative to '..', the directory this notebook already puts on sys.path for
# `experiments` / `internal_utils`, instead of an absolute per-machine path.
# Assumes that directory is the RL-LRP repo root — adjust if not.
CIFAR_CHECKPOINT_PATH = (
    Path('..') / 'data' / 'trained_CIFAR_models' / 'checkpoint_299_2024-07-21_17-27-55.tar'
)
# torch.load unpickles the file: only load checkpoints from trusted sources.
checkpoint = torch.load(CIFAR_CHECKPOINT_PATH, map_location=torch.device('cpu'))

In [None]:
# List the keys of the state_dict saved in the raw checkpoint.
raw_state_dict = checkpoint['state_dict']
raw_state_dict.keys()

In [None]:
import baselines.trainVggBaselineForCIFAR10.vgg as vgg
from internal_utils import update_dictionary_patch

# VGG-11 from the CIFAR-10 baseline code, wrapped like the teacher above.
model = WrapperNet(vgg.vgg11(), hybrid_loss=True)

# The patched checkpoint exposes a 'new_state_dict' entry, presumably with keys
# renamed to match the wrapped model — see update_dictionary_patch.
checkpoint = update_dictionary_patch(checkpoint)
model.load_state_dict(checkpoint['new_state_dict'])

In [None]:
# Put the CIFAR model in inference mode before explaining it, then display its
# structure as the cell output.
model = model.eval()
model

In [None]:
from internal_utils import get_CIFAR10_dataloader

# First batch of the CIFAR-10 test split. Note: this rebinds `data`, which
# earlier held the Imagenette loader.
data_loader = get_CIFAR10_dataloader(train=False)
cifar_iter = iter(data_loader)
data, target = next(cifar_iter)
# Drop the iterator so any loader worker processes are released, as they are
# when the iterator is a temporary.
del cifar_iter


In [None]:
# One CIFAR-10 test sample; `model` is already a WrapperNet here, so it is
# passed directly, with the label given an extra leading dimension.
cifar_index = 10
image, label = data[cifar_index], target[cifar_index]
visualise_panel_image(
    image=image,
    model=model,
    kernel_size_min=kernel_size_min,
    kernel_size_max=kernel_size_max,
    noise_level_min=noise_level_min,
    noise_level_max=noise_level_max,
    method=perform_lrp_plain,
    label=label.unsqueeze(0),
)

In [None]:
# Captum LRP on `model.model` — presumably the bare VGG inside WrapperNet.
visualise_panel_image(
    image=image,
    model=model.model,
    kernel_size_min=kernel_size_min,
    kernel_size_max=kernel_size_max,
    noise_level_min=noise_level_min,
    noise_level_max=noise_level_max,
    method=perform_lrp_captum,
    label=label,
)