# Getting necessary requirements for the experiment

# Required imports and helper functions

In [None]:
# Side length (pixels) of the square (IMAGE_SIZE, IMAGE_SIZE) resize applied by
# the image/mask loaders; also passed as the VAE input height.
IMAGE_SIZE = 256

In [None]:
import torch
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms import ToTensor, Normalize, Resize
from torchvision.datasets import ImageFolder


import pl_bolts
from pl_bolts.models.autoencoders import VAE
from pl_bolts.models.self_supervised import SimCLR, SwAV
import torchvision

from captum.attr import IntegratedGradients, GradientShap

import tqdm.notebook as tqdm
import cv2

import os
import warnings
warnings.filterwarnings("ignore")


In [None]:
# Load the COCOA explanation target (contrastive corpus similarity), the RISE feature-attribution method, and the baseline helper
from contrastive_corpus_similarity import ContrastiveCorpusSimilarity
from rise import RISE
from utils import get_black_baseline

# Get the models

In [None]:
from lightly.loss import NTXentLoss
from lightly.models.modules.heads import SimCLRProjectionHead
import pytorch_lightning as pl

class SimCLRModel(pl.LightningModule):
    """SimCLR model: ResNet-18 backbone + SimCLR projection head, NT-Xent loss.

    Presumably defined here so that ``torch.load`` can unpickle the custom
    model saved in ``./full_model3.pth`` — confirm against how it was saved.

    Args:
        max_epochs: length (in scheduler steps, i.e. epochs under Lightning's
            default per-epoch stepping) of the cosine LR schedule. ``None``
            means "use the attached trainer's ``max_epochs``".
    """

    def __init__(self, max_epochs=None):
        super().__init__()

        # create a ResNet backbone and remove the classification head
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        hidden_dim = resnet.fc.in_features
        self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, 128)

        self.criterion = NTXentLoss()

        # Schedule length for configure_optimizers; previously this was read
        # from an undefined global and raised NameError.
        self._max_epochs = max_epochs

    def forward(self, x):
        """Return the projection-head embedding of the flattened backbone features."""
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return z

    def training_step(self, batch, batch_idx):
        """Compute the NT-Xent loss between the two augmented views in ``batch``."""
        (x0, x1), _, _ = batch
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):
        """SGD with momentum/weight decay plus a cosine-annealed learning rate."""
        optim = torch.optim.SGD(
            self.parameters(), lr=6e-2, momentum=0.9, weight_decay=5e-4
        )
        # getattr: instances restored from an older pickle (torch.load of a
        # whole module) bypass __init__ and may lack the attribute.
        t_max = getattr(self, "_max_epochs", None)
        if t_max is None:
            t_max = self.trainer.max_epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, t_max)
        return [optim], [scheduler]

In [None]:
# Loaders for the encoders compared in the experiments below

def load_supervised():
    """Return a pretrained torchvision ResNet-50 feature extractor.

    The last child module (the ``fc`` classifier) is dropped and the pooled
    features are flattened. The encoder is moved to DEVICE and put in eval mode.
    """
    resnet = torchvision.models.resnet50(pretrained=True)
    # Split off the final child (the classifier head); keep the rest in order.
    *feature_layers, _classifier = resnet.children()
    encoder = nn.Sequential(*feature_layers, nn.Flatten()).to(DEVICE)
    encoder.eval()
    return encoder

def load_swav():
    """Load the pl_bolts SwAV ImageNet checkpoint on DEVICE, in eval mode."""
    checkpoint_url = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
    # strict=False tolerates missing/unexpected keys in the checkpoint's state dict.
    model = SwAV.load_from_checkpoint(checkpoint_url, strict=False).to(DEVICE)
    model.eval()
    return model


def load_simclr_custom(path='./full_model3.pth'):
    """Load the custom-trained SimCLR model that was saved as a whole pickled module.

    Unpickling presumably needs the ``SimCLRModel`` class above to be defined
    in this session first — confirm against how the file was saved.

    Args:
        path: location of the saved model; the default keeps the original file.

    Returns:
        The model on DEVICE in eval mode, consistent with the other loaders.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects -- only load trusted files.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # whole pickled modules; pass weights_only=False there if loading fails.
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    simclr = torch.load(path, map_location=DEVICE)
    simclr = simclr.to(DEVICE)
    simclr.eval()
    return simclr

def load_vae():
    """Return the encoder part of a pl_bolts VAE with 'stl10-resnet18' weights.

    The result is wrapped with a trailing ``nn.Flatten``, moved to DEVICE and
    put in eval mode.
    """
    pretrained = VAE(input_height=IMAGE_SIZE).from_pretrained('stl10-resnet18')
    # Keep every child module except the last three. Presumably this leaves
    # only the encoder (dropping decoder / fc_mu / fc_var) — TODO confirm
    # against the installed pl_bolts VAE definition.
    kept_modules = list(pretrained.children())[:-3]
    encoder = nn.Sequential(*kept_modules, nn.Flatten())
    encoder = encoder.to(DEVICE)
    encoder.eval()
    return encoder

In [None]:
def load_simclr():
    """Load the pl_bolts SimCLR ImageNet checkpoint on DEVICE, in eval mode."""
    checkpoint_url = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
    # strict=False tolerates missing/unexpected keys in the checkpoint's state dict.
    model = SimCLR.load_from_checkpoint(checkpoint_url, strict=False).to(DEVICE)
    model.eval()
    return model


In [None]:
# Download models: call every loader once up front (results are discarded),
# presumably so the remote weights are fetched/cached before the experiments.
for _loader in (load_supervised, load_swav, load_simclr, load_vae, load_simclr_custom):
    _loader()

print('done')

In [None]:
def load_img_from_drive(img_name, shape=IMAGE_SIZE, root='../242/test'):
    """Load one two4two test image as a normalized single-image batch on DEVICE.

    Args:
        img_name: file name of the image inside ``root``.
        shape: side length of the square resize target (defaults to IMAGE_SIZE).
        root: directory holding the test images; the default keeps the
            previously hard-coded location.

    Returns:
        Float tensor of shape (1, 3, shape, shape) on DEVICE, normalized with
        the ImageNet channel mean/std.
    """
    path = os.path.join(root, img_name)
    # Close the file handle deterministically; convert() returns an independent image.
    with Image.open(path) as raw:
        img = raw.convert('RGB')

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((shape, shape)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    # Add a leading batch dimension of size 1.
    x = transform(img).unsqueeze(0)

    return x.to(DEVICE)

In [None]:
def load_mask_from_drive(mask_name, shape=IMAGE_SIZE, root='../242/test'):
    """Load a two4two ground-truth mask as a binary {0, 1} tensor.

    Args:
        mask_name: file name of the mask image inside ``root``.
        shape: side length of the square resize target (defaults to IMAGE_SIZE).
        root: directory holding the masks; the default keeps the previously
            hard-coded location.

    Returns:
        int64 tensor of shape (1, shape, shape) (on the CPU): 1 wherever any
        channel of the resized mask is non-zero, else 0.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    path = os.path.join(root, mask_name)
    mask = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if mask is None:
        raise FileNotFoundError(f"could not read mask image: {path}")

    target_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        torchvision.transforms.Resize((shape, shape)),
    ])
    # Collapse the colour channels into a single map with a leading channel dim.
    mask = target_transform(mask).sum(dim=0).unsqueeze(dim=0)

    # NOTE(review): Resize uses its default (bilinear) interpolation, so the
    # `> 0` threshold can slightly dilate mask edges when the source size
    # differs from `shape`; consider InterpolationMode.NEAREST for masks.
    return (mask > 0).long()

In [None]:
def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to the CPU."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

# Get the baseline method used for each sample

In [None]:
from functools import partial

# Baseline input for the attribution methods, built from each sample
# (presumably an all-black image, per the helper's name — see utils).
# `partial` with no bound arguments was a no-op wrapper, so alias directly.
get_baseline = get_black_baseline



# Get the reference (corpus) and foil sets

In [None]:
# NOTE: the images are structured according to
# ./mixed/mixed/<image_name>.png
# (ImageFolder expects one sub-directory per class, hence the nested folder.)
data_path = "./mixed/"

# Same preprocessing as load_img_from_drive: resize to IMAGE_SIZE (instead of
# a separately hard-coded 256) and normalize with the ImageNet statistics, so
# corpus images and test images cannot drift apart.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Resize the images to a specific size
    transforms.ToTensor(),  # Convert the images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the images from the directory using the ImageFolder dataset
dataset = ImageFolder(root=data_path, transform=transform)

# Create a data loader to iterate over the reference (corpus) set in batches
batch_size = 32
data_loader_ref = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
# NOTE: the images are structured according to
# ./noise/noise/<image_name>.png
# (ImageFolder expects one sub-directory per class, hence the nested folder.)
data_path = "./noise/"

# Same preprocessing as load_img_from_drive: resize to IMAGE_SIZE (instead of
# a separately hard-coded 256) and normalize with the ImageNet statistics, so
# foil images and test images cannot drift apart.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Resize the images to a specific size
    transforms.ToTensor(),  # Convert the images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the images from the directory using the ImageFolder dataset
dataset = ImageFolder(root=data_path, transform=transform)

# Create a data loader to iterate over the foil set in batches
batch_size = 32
data_loader_foil = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Initialize Contrastive Corpus Similarity

In [None]:
# Load every encoder first, in a fixed order.
model_super = load_supervised()
model_sim = load_simclr()
model_swav = load_swav()
model_vae = load_vae()
model_sim_custom = load_simclr_custom()

# All explanation targets share the same corpus/foil data and settings;
# only the encoder differs between them.
_ccs_kwargs = dict(
    corpus_dataloader=data_loader_ref,
    foil_dataloader=data_loader_foil,
    normalize=True,
    batch_size=32,
)

ccs_super = ContrastiveCorpusSimilarity(encoder=model_super.to(DEVICE), **_ccs_kwargs)
ccs_sim = ContrastiveCorpusSimilarity(encoder=model_sim.to(DEVICE), **_ccs_kwargs)
ccs_swav = ContrastiveCorpusSimilarity(encoder=model_swav.to(DEVICE), **_ccs_kwargs)
ccs_vae = ContrastiveCorpusSimilarity(encoder=model_vae.to(DEVICE), **_ccs_kwargs)
ccs_sim_custom = ContrastiveCorpusSimilarity(encoder=model_sim_custom.to(DEVICE), **_ccs_kwargs)

In [None]:
# One explanation target per encoder. The experiment cells below pair this
# list with model_name_list / model_list via zip(), which stops at the
# shortest list, so the order here must match those lists.
ccs_list = [ccs_super, ccs_sim, ccs_swav, ccs_vae, ccs_sim_custom]

# RISE Experiment

In [None]:
def _localization_metrics(attr, mask, k=1000):
    """Score one attribution map against a binary ground-truth mask.

    Args:
        attr: attribution tensor; it is averaged over dim 1 and squeezed, so it
            is presumably (1, C, H, W) — TODO confirm for each attribution method.
        mask: ground-truth tensor from ``load_mask_from_drive`` (non-zero = object).
        k: number of top-attributed pixels used for the top-k intersection.

    Returns:
        Tuple ``(top_k_intersection, pointing_game_hit, relevance_rank_accuracy)``
        of Python floats.
    """
    a = to_np(attr.mean(dim=1).squeeze())
    s = to_np(mask.squeeze()).astype(bool)

    # Flattened indices in ascending attribution order; shared by top-k and RRA.
    order = np.argsort(a, axis=None)

    # Top-k intersection: fraction of the k highest-attributed pixels inside the mask.
    top_k_binary_mask = np.zeros(a.shape)
    np.put_along_axis(top_k_binary_mask, order[-k:], 1, axis=None)
    tki = 1.0 / k * np.sum(np.logical_and(s, top_k_binary_mask.astype(bool)))

    # Pointing game: hit if any maximally-attributed pixel lies on the object.
    a_flat = a.flatten()
    s_flat = s.flatten()
    max_index = np.argwhere(a_flat == np.max(a_flat))
    # float() instead of np.float, which was removed in NumPy 1.24.
    hit = float(np.any(s_flat[max_index]))

    # Relevance rank accuracy: share of the mask's n pixels found among the
    # n highest-attributed pixels.
    gt_indices = np.flatnonzero(s_flat)
    n = len(gt_indices)
    if n == 0:
        # Empty mask: no pixel can be hit.
        rank_accuracy = 0.0
    else:
        hits = len(np.intersect1d(gt_indices, order[-n:]))
        rank_accuracy = hits / float(n)

    return tki, hit, rank_accuracy


list_images = sorted(os.listdir('../242/test'))

# Custom SimCLR is disabled here; zip() below stops at the shortest list, so
# ccs_sim_custom (last in ccs_list) is skipped.
model_name_list = ['Supervised', 'SimCLR', 'SwAV', 'VAE']  # , 'Custom SimCLR']
model_list = [load_supervised, load_simclr, load_swav, load_vae]  # , load_simclr_custom]
# NOTE(review): mask_bs / num_batches are not read anywhere in this cell.
mask_bs = 100
num_batches = 30
print_every = 1  # print the running means after every `print_every` images

tk_list = []
pg_list = []
rra_list = []

print(model_name_list)
# The sorted listing is assumed to alternate <image>, <mask> for each sample
# (~5k pairs in total). Evaluate the first 4650 pairs, clamped to the number
# of complete pairs actually present.
num_pairs = min(4650, len(list_images) // 2)
for i in range(num_pairs):
    image_name = list_images[2 * i]
    mask_name = list_images[2 * i + 1]
    x = load_img_from_drive(image_name)
    baseline = get_baseline(x)
    mask = load_mask_from_drive(mask_name)

    # One attribution map per encoder (each encoder lives inside its ccs target).
    attr_list = []
    for _loader, _name, ccs in zip(model_list, model_name_list, ccs_list):
        # RISE is perturbation-based, so it is run without autograd.
        with torch.no_grad():
            rise = RISE(ccs)
            # .to(DEVICE) instead of .cuda() so the CPU fallback also works.
            attr = rise.attribute(
                x.to(DEVICE),
                (8, 8),
                baselines=baseline.to(DEVICE),
                mask_prob=0.5,
                n_samples=3000,
                normalize_by_mask_prob=True,
            )
        attr_list.append(attr)

    # Score each model's attribution map against the ground-truth mask.
    scores = [_localization_metrics(attr, mask) for attr in attr_list]
    tk_list.append([tki for tki, _, _ in scores])
    pg_list.append([hit for _, hit, _ in scores])
    rra_list.append([rra for _, _, rra in scores])

    if i % print_every == 0:
        print(i)
        print('Pointing Game:')
        print(np.array(pg_list).mean(axis=0))
        print('Top-K:')
        print(np.array(tk_list).mean(axis=0))
        print('Relevance Rank Acc.:')
        print(np.array(rra_list).mean(axis=0))

# Integrated Gradients Experiment

In [None]:
def _localization_metrics(attr, mask, k=1000):
    """Score one attribution map against a binary ground-truth mask.

    Args:
        attr: attribution tensor; it is averaged over dim 1 and squeezed, so it
            is presumably (1, C, H, W) — TODO confirm for each attribution method.
        mask: ground-truth tensor from ``load_mask_from_drive`` (non-zero = object).
        k: number of top-attributed pixels used for the top-k intersection.

    Returns:
        Tuple ``(top_k_intersection, pointing_game_hit, relevance_rank_accuracy)``
        of Python floats.
    """
    a = to_np(attr.mean(dim=1).squeeze())
    s = to_np(mask.squeeze()).astype(bool)

    # Flattened indices in ascending attribution order; shared by top-k and RRA.
    order = np.argsort(a, axis=None)

    # Top-k intersection: fraction of the k highest-attributed pixels inside the mask.
    top_k_binary_mask = np.zeros(a.shape)
    np.put_along_axis(top_k_binary_mask, order[-k:], 1, axis=None)
    tki = 1.0 / k * np.sum(np.logical_and(s, top_k_binary_mask.astype(bool)))

    # Pointing game: hit if any maximally-attributed pixel lies on the object.
    a_flat = a.flatten()
    s_flat = s.flatten()
    max_index = np.argwhere(a_flat == np.max(a_flat))
    # float() instead of np.float, which was removed in NumPy 1.24.
    hit = float(np.any(s_flat[max_index]))

    # Relevance rank accuracy: share of the mask's n pixels found among the
    # n highest-attributed pixels.
    gt_indices = np.flatnonzero(s_flat)
    n = len(gt_indices)
    if n == 0:
        # Empty mask: no pixel can be hit.
        rank_accuracy = 0.0
    else:
        hits = len(np.intersect1d(gt_indices, order[-n:]))
        rank_accuracy = hits / float(n)

    return tki, hit, rank_accuracy


list_images = sorted(os.listdir('../242/test'))

# All five encoders (including Custom SimCLR) are evaluated in this cell.
model_name_list = ['Supervised', 'SimCLR', 'SwAV', 'VAE', 'Custom SimCLR']
model_list = [load_supervised, load_simclr, load_swav, load_vae, load_simclr_custom]
# NOTE(review): mask_bs / num_batches are not read anywhere in this cell.
mask_bs = 100
num_batches = 30
print_every = 1  # print the running means after every `print_every` images

tk_list = []
pg_list = []
rra_list = []

print(model_name_list)
# The sorted listing is assumed to alternate <image>, <mask> for each sample
# (~5k pairs in total). Evaluate the first 4650 pairs, clamped to the number
# of complete pairs actually present.
num_pairs = min(4650, len(list_images) // 2)
for i in range(num_pairs):
    image_name = list_images[2 * i]
    mask_name = list_images[2 * i + 1]
    x = load_img_from_drive(image_name)
    baseline = get_baseline(x)
    mask = load_mask_from_drive(mask_name)

    # One attribution map per encoder (each encoder lives inside its ccs target).
    attr_list = []
    for _loader, _name, ccs in zip(model_list, model_name_list, ccs_list):
        # NOTE(review): Integrated Gradients needs gradients even though it runs
        # under no_grad here; this relies on captum re-enabling autograd
        # internally when computing gradients — confirm for the installed version.
        with torch.no_grad():
            ig = IntegratedGradients(ccs)
            # .to(DEVICE) instead of .cuda() so the CPU fallback also works.
            attr = ig.attribute(
                x.to(DEVICE),
                baselines=baseline.to(DEVICE),
            )
        attr_list.append(attr)

    # Score each model's attribution map against the ground-truth mask.
    scores = [_localization_metrics(attr, mask) for attr in attr_list]
    tk_list.append([tki for tki, _, _ in scores])
    pg_list.append([hit for _, hit, _ in scores])
    rra_list.append([rra for _, _, rra in scores])

    if i % print_every == 0:
        print(i)
        print('Pointing Game:')
        print(np.array(pg_list).mean(axis=0))
        print('Top-K:')
        print(np.array(tk_list).mean(axis=0))
        print('Relevance Rank Acc.:')
        print(np.array(rra_list).mean(axis=0))

# GradientShap Experiment

In [None]:
def _localization_metrics(attr, mask, k=1000):
    """Score one attribution map against a binary ground-truth mask.

    Args:
        attr: attribution tensor; it is averaged over dim 1 and squeezed, so it
            is presumably (1, C, H, W) — TODO confirm for each attribution method.
        mask: ground-truth tensor from ``load_mask_from_drive`` (non-zero = object).
        k: number of top-attributed pixels used for the top-k intersection.

    Returns:
        Tuple ``(top_k_intersection, pointing_game_hit, relevance_rank_accuracy)``
        of Python floats.
    """
    a = to_np(attr.mean(dim=1).squeeze())
    s = to_np(mask.squeeze()).astype(bool)

    # Flattened indices in ascending attribution order; shared by top-k and RRA.
    order = np.argsort(a, axis=None)

    # Top-k intersection: fraction of the k highest-attributed pixels inside the mask.
    top_k_binary_mask = np.zeros(a.shape)
    np.put_along_axis(top_k_binary_mask, order[-k:], 1, axis=None)
    tki = 1.0 / k * np.sum(np.logical_and(s, top_k_binary_mask.astype(bool)))

    # Pointing game: hit if any maximally-attributed pixel lies on the object.
    a_flat = a.flatten()
    s_flat = s.flatten()
    max_index = np.argwhere(a_flat == np.max(a_flat))
    # float() instead of np.float, which was removed in NumPy 1.24.
    hit = float(np.any(s_flat[max_index]))

    # Relevance rank accuracy: share of the mask's n pixels found among the
    # n highest-attributed pixels.
    gt_indices = np.flatnonzero(s_flat)
    n = len(gt_indices)
    if n == 0:
        # Empty mask: no pixel can be hit.
        rank_accuracy = 0.0
    else:
        hits = len(np.intersect1d(gt_indices, order[-n:]))
        rank_accuracy = hits / float(n)

    return tki, hit, rank_accuracy


list_images = sorted(os.listdir('../242/test'))

# Custom SimCLR is disabled here. The name list previously still contained
# 'Custom SimCLR' although zip() only ever evaluated four models, so the
# printed header did not match the result columns.
model_name_list = ['Supervised', 'SimCLR', 'SwAV', 'VAE']  # , 'Custom SimCLR']
model_list = [load_supervised, load_simclr, load_swav, load_vae]  # , load_simclr_custom]
# NOTE(review): mask_bs / num_batches are not read anywhere in this cell.
mask_bs = 100
num_batches = 30
print_every = 1  # print the running means after every `print_every` images

tk_list = []
pg_list = []
rra_list = []

print(model_name_list)
# The sorted listing is assumed to alternate <image>, <mask> for each sample
# (~5k pairs in total). Evaluate the first 4650 pairs, clamped to the number
# of complete pairs actually present.
num_pairs = min(4650, len(list_images) // 2)
for i in range(num_pairs):
    image_name = list_images[2 * i]
    mask_name = list_images[2 * i + 1]
    x = load_img_from_drive(image_name)
    baseline = get_baseline(x)
    mask = load_mask_from_drive(mask_name)

    # One attribution map per encoder (each encoder lives inside its ccs target).
    attr_list = []
    for _loader, _name, ccs in zip(model_list, model_name_list, ccs_list):
        # NOTE(review): GradientShap needs gradients even though it runs under
        # no_grad here; this relies on captum re-enabling autograd internally
        # when computing gradients — confirm for the installed version.
        with torch.no_grad():
            gs = GradientShap(ccs)
            # .to(DEVICE) instead of .cuda() so the CPU fallback also works.
            attr = gs.attribute(
                x.to(DEVICE),
                baselines=baseline.to(DEVICE),
                n_samples=50,
                stdevs=0.2,
            )
        attr_list.append(attr)

    # Score each model's attribution map against the ground-truth mask.
    scores = [_localization_metrics(attr, mask) for attr in attr_list]
    tk_list.append([tki for tki, _, _ in scores])
    pg_list.append([hit for _, hit, _ in scores])
    rra_list.append([rra for _, _, rra in scores])

    if i % print_every == 0:
        print(i)
        print('Pointing Game:')
        print(np.array(pg_list).mean(axis=0))
        print('Top-K:')
        print(np.array(tk_list).mean(axis=0))
        print('Relevance Rank Acc.:')
        print(np.array(rra_list).mean(axis=0))