In [None]:
import sys
sys.path.append('..')
from experiments import perform_gradcam, perform_lrp_captum
from internal_utils import preprocess_images, condense_to_heatmap, blur_image_batch, add_random_noise_batch, get_data_imagenette, get_teacher_model, get_CIFAR10_dataloader
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms.functional as TF

def visualise_panel_image(image, model, kernel_size_min, kernel_size_max, noise_level_min, noise_level_max, method, label):
    """Visualise the panel of images for the model.

    Draws a 2x5 grid. The top row shows the input image and four perturbed
    copies: blurred with ``kernel_size_min`` / ``kernel_size_max`` (via
    ``blur_image_batch``) and noised with ``noise_level_min`` /
    ``noise_level_max`` (via ``add_random_noise_batch``). The bottom row shows,
    for each column, ``condense_to_heatmap(method(preprocess_images(batch), label, model))``.

    Args:
        image: image tensor, either unbatched ``(C, H, W)`` or a batch of one
            ``(1, C, H, W)``.
        model: model forwarded to ``method``.
        kernel_size_min, kernel_size_max: blur kernel sizes for the two blurred copies.
        noise_level_min, noise_level_max: noise levels for the two noisy copies.
        method: attribution function, called as ``method(batch, label, model)``.
        label: target label(s), forwarded unchanged to ``method``.
    """
    # Assume the image tensor is already in batch format, if not, unsqueeze it
    if image.dim() == 3:
        image = image.unsqueeze(0)

    def _to_numpy(tensor):
        # Detach and move to CPU so imshow works for tensors on any device / with grad history.
        return tensor.detach().cpu().numpy()

    # All perturbations are produced before any heatmap is computed (same order as before).
    variants = [
        ('Original', image),
        ('Small Blurred', blur_image_batch(image, kernel_size_min)),
        ('Large Blurred', blur_image_batch(image, kernel_size_max)),
        ('Small Noisy', add_random_noise_batch(image, noise_level_min)),
        ('Large Noisy', add_random_noise_batch(image, noise_level_max)),
    ]
    heatmaps = [
        condense_to_heatmap(method(preprocess_images(batch), label, model)).detach()
        for _, batch in variants
    ]

    fig, ax = plt.subplots(2, len(variants), figsize=(15, 5))
    for col, ((name, batch), heatmap) in enumerate(zip(variants, heatmaps)):
        ax[0][col].imshow(_to_numpy(batch.squeeze().permute(1, 2, 0)))
        ax[0][col].set_title(f'{name} Image')

        heat = _to_numpy(heatmap.squeeze(0))
        # 'seismic' is a diverging colormap: use limits symmetric around 0 so that
        # white means zero and red/blue keep their sign meaning in every panel.
        vlim = float(abs(heat).max()) or 1.0
        ax[1][col].imshow(heat, cmap='seismic', vmin=-vlim, vmax=vlim)
        ax[1][col].set_title(f'{name} Heatmap')

    for axis in ax.flat:
        axis.axis('off')
    plt.show()

In [None]:
from experiments import WrapperNet, WrapperNetContrastive
import torch
from internal_utils import update_dictionary_patch
from baselines.trainVggBaselineForCIFAR10.vgg import vgg11

def get_teacher_model(teacher_checkpoint_path, map_location="cpu"):
    """Build a ``vgg11`` teacher and load its weights from a training checkpoint.

    NOTE: this redefinition shadows the ``get_teacher_model`` imported from
    ``internal_utils`` in the first cell.

    Args:
        teacher_checkpoint_path: path of a checkpoint file readable by ``torch.load``.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so that a
            checkpoint saved on a GPU also loads on a CPU-only machine.
            ``load_state_dict`` copies values into the model's existing
            parameters, so their device is unaffected by this choice.

    Returns:
        The ``vgg11`` model. If patching/loading the state dict fails, a message
        including the error is printed and the model is returned with its
        freshly initialised weights.
    """
    checkpoint = torch.load(teacher_checkpoint_path, map_location=map_location)
    # assume teacher model is vgg11 for now
    teacher = vgg11()
    try:
        checkpoint = update_dictionary_patch(checkpoint)
        teacher.load_state_dict(checkpoint['new_state_dict'])
    except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit propagate
        # Best-effort as before, but surface the reason instead of hiding it.
        print(f'Incorrect patch specified: {err!r}')
    return teacher
import os

data = get_CIFAR10_dataloader()
input_images, labels = next(iter(data))
# The default is an absolute path on the original author's machine; set the
# TEACHER_CHECKPOINT_PATH environment variable instead of editing the notebook.
TEACHER_CHECKPOINT_PATH = os.environ.get(
    "TEACHER_CHECKPOINT_PATH",
    "/home/charleshiggins/RL-LRP/baselines/trainVggBaselineForCIFAR10/save_vgg11/checkpoint_299.tar",
)
teacher_model = WrapperNet(get_teacher_model(TEACHER_CHECKPOINT_PATH), hybrid_loss=True)
# The teacher is never optimised in this notebook, so put it in eval mode.
# Otherwise train-mode layers (vgg11 presumably has Dropout, cf. the nn.Dropout
# filter in remove_grad_for_all_but_last_layer) make its heatmap targets random.
# NOTE(review): assumes WrapperNet is an nn.Module (it exposes .parameters()).
teacher_model.eval()
# define params
learner_model = WrapperNet(vgg11(), hybrid_loss=True)

In [None]:
# Draw a batch from `data` again, overwriting the `input_images`/`labels` fetched in
# the model-loading cell. NOTE(review): if the dataloader shuffles, this is a
# different batch from that one.
input_images, labels = next(iter(data))

In [None]:
# First image/label of the current batch, used for the explanation panels below.
sample_image = input_images[0]
sample_label = labels[0]


In [None]:
from experiments import perform_lrp_plain
# Teacher explanations of the sample under increasing blur / noise.
visualise_panel_image(
    sample_image.unsqueeze(0),
    teacher_model,
    kernel_size_min=3,
    kernel_size_max=15,
    noise_level_min=0.1,
    noise_level_max=0.5,
    method=perform_lrp_plain,
    label=sample_label.unsqueeze(0),
)

In [None]:
# Same panel for the learner model.
visualise_panel_image(
    sample_image.unsqueeze(0),
    learner_model,
    kernel_size_min=3,
    kernel_size_max=15,
    noise_level_min=0.1,
    noise_level_max=0.5,
    method=perform_lrp_plain,
    label=sample_label.unsqueeze(0),
)

In [None]:
# Preprocess the whole batch once; `pp_images` feeds the forward passes and the training loop below.
pp_images = preprocess_images(input_images)


In [None]:
print("The target tensor should be: {}".format(labels))

In [None]:
# Forward both models on the preprocessed batch; each call returns (output, heatmap).
# Run before the training cell, the learner is still the freshly built WrapperNet(vgg11()).
output, heatmap = learner_model(pp_images)
output_target, heatmap_target = teacher_model(pp_images)

In [None]:
import torch.nn as nn
class CosineDistanceLoss(torch.nn.Module):
    """Batch-mean cosine distance ``1 - cos(input1[i], input2[i])``.

    Each input is reshaped to ``(batch, -1)`` before comparison, so any trailing
    shape is accepted as long as both inputs hold the same number of elements
    per sample.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        """Return the mean over the batch of ``1 - cosine_similarity`` as a 0-d tensor."""
        # reshape (unlike view) also works on non-contiguous tensors, e.g. after permute/transpose.
        input1_flat = input1.reshape(input1.size(0), -1)
        input2_flat = input2.reshape(input2.size(0), -1)

        cosine_dist = 1 - F.cosine_similarity(input1_flat, input2_flat)
        return cosine_dist.mean()
    

# Define SSIM loss (we'll minimize 1 - SSIM)
class SSIMLoss(torch.nn.Module):
    """Loss ``1 - SSIM(img1, img2)`` built on ``pytorch_msssim.SSIM``.

    Args:
        data_range: forwarded to ``SSIM`` (value range of the inputs).
            NOTE(review): defaults to 1.0 — confirm the heatmaps compared with
            this loss actually lie in a range of width 1; otherwise SSIM's
            stabilising constants are mis-scaled.
        size_average: forwarded to ``SSIM``.
        channel: forwarded to ``SSIM`` (number of input channels).
    """

    def __init__(self, data_range=1.0, size_average=True, channel=3):
        super().__init__()
        # Local import: the module-level pytorch_msssim import lives in a later
        # notebook cell, so without it this class silently depends on cell order.
        from pytorch_msssim import SSIM
        self.ssim_module = SSIM(data_range=data_range, size_average=size_average, channel=channel)

    def forward(self, img1, img2):
        """Return ``1 - SSIM(img1, img2)``."""
        return 1 - self.ssim_module(img1, img2)
    
def remove_grad_for_all_but_last_layer(learner_model, optimizer, scheduler,
                                       trainable_layer="classifier.6", lr=0.25,
                                       momentum=0.9, weight_decay=5e-4):
    """Freeze every leaf layer of ``learner_model.model`` except ``trainable_layer``; rebuild the optimizer.

    Args:
        learner_model: wrapper whose ``.model`` attribute holds the network to freeze.
        optimizer: the current optimizer; it is replaced by a new SGD optimizer.
        scheduler: returned unchanged. NOTE(review): it remains bound to the old
            ``optimizer``, so stepping it does not affect the returned optimizer.
        trainable_layer: leaf modules whose qualified name contains this string
            keep ``requires_grad`` (default ``"classifier.6"``).
        lr, momentum, weight_decay: SGD hyper-parameters of the new optimizer.

    Returns:
        ``(learner_model, new_optimizer, scheduler)``.

    Raises:
        ValueError: if no trainable parameter is left (e.g. ``trainable_layer`` matched nothing).
    """
    # Parameter-free layer types; skipping them only keeps the log short.
    skipped_types = [nn.ReLU, nn.MaxPool2d, nn.AdaptiveAvgPool2d, nn.LogSoftmax, nn.Dropout]
    for name, module in learner_model.model.named_modules():
        # Visit leaf modules only: containers (nn.Sequential, wrappers, the model itself)
        # own no parameters beyond those of their children.
        has_children = len(list(module.children())) > 0
        if has_children or isinstance(module, nn.Sequential) or type(module) in skipped_types:
            continue
        if trainable_layer not in name:
            print(f"removing grad from: {name}")
            for param in module.parameters():
                param.requires_grad = False
        else:
            print(f"Grad will continue for {name}")

    # Hand the optimizer only the still-trainable parameters. SGD updates every
    # registered param whose .grad is not None; depending on zero_grad's
    # set_to_none behaviour, frozen params could otherwise keep receiving
    # weight-decay/momentum updates.
    trainable_params = [p for p in learner_model.parameters() if p.requires_grad]
    if not trainable_params:
        raise ValueError(
            f"No trainable parameters left after freezing; does '{trainable_layer}' match a layer name?"
        )
    optimizer = torch.optim.SGD(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay)
    return learner_model, optimizer, scheduler

In [None]:
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
optimizer = torch.optim.SGD(learner_model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
# NOTE: the scheduler is created but never stepped in the loop below.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
# Alternative objectives kept for experimentation; only ssim_loss is used below.
criterion = torch.nn.CrossEntropyLoss()
mse_loss = torch.nn.MSELoss()
cos_loss = CosineDistanceLoss()
ssim_loss = SSIMLoss()
EPOCHS = 200
CHANGE_POINT = 100  # after this iteration only the last layer keeps training

# The teacher heatmap is a fixed target: compute it once without building a graph.
with torch.no_grad():
    _, teacher_heatmap = teacher_model(pp_images)

for i in range(EPOCHS):
    # Up to and including CHANGE_POINT the learner is called with the labels; afterwards without.
    if i <= CHANGE_POINT:
        model_out, model_heatmap = learner_model(pp_images, labels)
    else:
        model_out, model_heatmap = learner_model(pp_images)
    optimizer.zero_grad()
    loss = ssim_loss(model_heatmap, teacher_heatmap)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(learner_model.parameters(), 1)
    optimizer.step()

    correct = model_out.argmax(dim=1).eq(labels).sum().item()
    correct_pct = 100 * correct / labels.shape[0]
    # .item() prints the plain number rather than the tensor repr (with grad_fn).
    print(f"iteration: {i} \t accuracy: {correct_pct}\t loss: {loss.item()}")
    if i == CHANGE_POINT:
        learner_model, optimizer, scheduler = remove_grad_for_all_but_last_layer(learner_model, optimizer, scheduler)
        print("REMOVED GRADIENTS")

In [None]:
# Display the learner's outputs from the last training iteration.
model_out

In [None]:
# Learner explanations after training, on the same sample (first image of the
# training batch) used for the pre-training panels, so the two can be compared.
visualise_panel_image(sample_image.unsqueeze(0), learner_model, 3, 15, 0.1, 0.5, perform_lrp_plain, sample_label.unsqueeze(0))


In [None]:
# Display the learner model's repr.
learner_model

In [None]:
import torch.nn as nn
