# Import Libraries and Download Data

In [None]:
# Install the third-party packages this notebook imports below
# (IPython shell escapes; these lines only run inside a notebook kernel).
!pip install imutils
!pip install segmentation_models_pytorch
!pip install captum
!pip install albumentations
!pip install gdown 
import gdown 
# Fetch the dataset archive from Google Drive into ./data.zip.
# presumably this unpacks to ./KVASIR-SEG/, the directory read later — TODO confirm
url = 'https://drive.google.com/uc?id=1HSpQOnJoqP6frkqy1GqWuaU5-wgDDFYF' 
output = 'data.zip'
gdown.download(url, output)
# Fetch a checkpoint into ./best_model.pth (the path read by torch.load further down).
url = 'https://drive.google.com/uc?id=12b9pu931vmw0mNn2RFt-95_qTu8s_G_4' 
output = 'best_model.pth'
gdown.download(url, output)
# Extract the archive into the current working directory.
!unzip data.zip

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
from collections import defaultdict, OrderedDict
import shutil
import time
import copy
import math
import random
from imutils import paths
import warnings

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from numpy import unravel_index

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

from torchvision import transforms
from torchvision import datasets

from PIL import *
import albumentations as A
import skimage

import segmentation_models_pytorch as smp

from captum.attr import visualization as viz
from captum.attr import GuidedGradCam, Saliency, DeepLift, GuidedBackprop, LayerGradCam, LayerDeepLift, LayerAttribution
from captum.metrics import sensitivity_max, infidelity

# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')
# Report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

# Load Data

In [None]:
def visualize(**images):
    """Show the given images side by side in one row of subplots.

    Each keyword argument is one panel. The keyword becomes the panel title
    (underscores replaced by spaces, then title-cased) and the value is the
    image. A value whose first dimension is 2 or 3 is treated as channel-first
    and permuted to channel-last (so it must provide ``.permute``); any other
    value is passed to ``imshow`` after ``np.squeeze``.

    Calling the function without images does nothing.
    """
    n_images = len(images)
    if n_images == 0:
        return
    # squeeze=False keeps the axes container two-dimensional even for a single
    # panel; with the default, one panel yields a bare Axes object and
    # indexing it raises a TypeError.
    _, axes = plt.subplots(1, n_images, figsize=(4 * n_images, 4), squeeze=False)
    for ax, (name, image) in zip(axes[0], images.items()):
        if image.shape[0] == 3 or image.shape[0] == 2:
            ax.imshow(np.squeeze(image.permute(1, 2, 0)))
        else:
            ax.imshow(np.squeeze(image))
        ax.set_title(name.replace('_', ' ').title(), fontsize=20)
    plt.show()
    
class EndoscopyDataset(Dataset):
    """Dataset pairing image files with mask files, given as two path lists.

    Args:
        images: sequence of image file paths.
        masks: sequence of mask file paths, index-aligned with ``images``.
        augmentations: optional callable invoked as
            ``augmentations(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries (channel-last arrays).
    """

    def __init__(self, images, masks, augmentations=None):
        self.input_images = images
        self.target_masks = masks
        self.augmentations = augmentations

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        """Return ``[image, mask]`` as channel-first float tensors for item ``idx``."""
        resize = transforms.Resize((400, 400), interpolation=transforms.InterpolationMode.NEAREST)
        image_pipeline = transforms.Compose([resize, transforms.ToTensor()])
        mask_pipeline = transforms.Compose([resize, transforms.Grayscale(), transforms.ToTensor()])

        rgb_image = Image.open(os.path.join(self.input_images[idx])).convert('RGB')
        rgb_mask = Image.open(os.path.join(self.target_masks[idx])).convert('RGB')

        # The augmentation callable works on channel-last numpy arrays, so
        # convert the tensors to HWC before handing them over.
        image_hwc = image_pipeline(rgb_image).permute((1, 2, 0)).cpu().detach().numpy()
        mask_hwc = mask_pipeline(rgb_mask).permute((1, 2, 0)).cpu().detach().numpy()

        if self.augmentations:
            result = self.augmentations(image=image_hwc, mask=mask_hwc)
            image_hwc, mask_hwc = result['image'], result['mask']

        # Back to channel-first float tensors for the model.
        image_chw = torch.tensor(image_hwc, dtype=torch.float).permute((2, 0, 1))
        mask_chw = torch.tensor(mask_hwc, dtype=torch.float).permute((2, 0, 1))

        return [image_chw, mask_chw]
    
# Batch sizes and worker count used by the three DataLoaders below.
train_batch_size = 8
val_batch_size = 4
test_batch_size = 4
num_workers = 2

main_dir = './KVASIR-SEG/'

# Image and mask paths are sorted independently; they pair up by position,
# which assumes matching file names in images/ and masks/ — TODO confirm.
train_images = sorted(list(paths.list_files(main_dir + 'train/images/', contains="jpg")))
val_images = sorted(list(paths.list_files(main_dir + 'val/images/', contains="jpg")))
test_images = sorted(list(paths.list_files(main_dir + 'test/images/', contains="jpg")))

train_masks = sorted(list(paths.list_files(main_dir + 'train/masks/', contains="jpg")))
val_masks = sorted(list(paths.list_files(main_dir + 'val/masks/', contains="jpg")))
test_masks = sorted(list(paths.list_files(main_dir + 'test/masks/', contains="jpg")))

# A list, not a set literal: a set has no defined iteration order, so the
# transforms would be applied in an arbitrary order.
augmentations = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=(-90, 90)),
    A.VerticalFlip(p=0.5),
    A.Transpose(p=0.5),
    A.GaussianBlur(p=0.5),
    A.augmentations.geometric.transforms.Affine(scale=(0.9, 1.1), translate_percent=0.1),
])

# Only the training split is augmented.
dataset = {
    'train': EndoscopyDataset(train_images, train_masks, augmentations),
    'val': EndoscopyDataset(val_images, val_masks, None),
    'test': EndoscopyDataset(test_images, test_masks, None)
}

dataloader = {
    'train': DataLoader(dataset['train'], batch_size=train_batch_size, shuffle=True, num_workers=num_workers),
    'val': DataLoader(dataset['val'], batch_size=val_batch_size, shuffle=True, num_workers=num_workers),
    'test': DataLoader(dataset['test'], batch_size=test_batch_size, shuffle=False, num_workers=num_workers)
}

# Sanity check: print shape and value range of one random training sample and show it.
image, mask = dataset['train'][random.randint(0, len(dataset['train'])-1)]
print(image.shape, image.min(), image.max())
print(mask.shape, mask.min(), mask.max())
# The mask holds floats; casting straight to int would truncate every value
# below 1.0 to background, so threshold at 0.5 to build the label image.
mask_labels = (mask.detach().cpu().numpy()[0] > 0.5).astype(np.int64)
visualize(
    original_image = image,
    ground_truth_mask = mask,
    polyp = skimage.segmentation.mark_boundaries(image.permute(1, 2, 0).detach().cpu().numpy(), mask_labels, color=(0, 0, 1), mode='outer')
)

# Train Model

In [None]:
class Loss(smp.utils.base.Loss):
    """Combined segmentation loss: ``(1 - IoU) + (1 - F1) + binary cross-entropy``.

    Args:
        eps: smoothing constant forwarded to the IoU and F-score functions.
        activation: activation spec handed to ``smp.base.modules.Activation``
            and applied to the prediction before the loss terms are computed.
        **kwargs: forwarded to the base ``Loss`` constructor.
    """

    def __init__(self, eps=1.0, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.activation = smp.base.modules.Activation(activation)
        self._name = 'loss'

    def forward(self, y_pr, y_gt):
        """Return the summed loss for prediction ``y_pr`` against target ``y_gt``."""
        activated = self.activation(y_pr)
        # IoU and F1 are evaluated at the 0.3 threshold used throughout the notebook.
        jaccard_term = 1 - smp.utils.functional.iou(activated, y_gt, eps=self.eps, threshold=0.3)
        dice_term = 1 - smp.utils.functional.f_score(activated, y_gt, beta=1, eps=self.eps, threshold=0.3)
        bce_term = nn.functional.binary_cross_entropy(activated, y_gt)
        return jaccard_term + dice_term + bce_term

In [None]:
# Set to False to skip the training loop below.
training = True
epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# U-Net fetched through torch.hub with pretrained weights.
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=1, init_features=32, pretrained=True)

loss = Loss()

# Accuracy is logged alongside IoU and F-score because the plotting cells
# further down read an `accuracy` column from the epoch logs.
metrics = [
    smp.utils.metrics.IoU(threshold=0.3),
    smp.utils.metrics.Fscore(threshold=0.3),
    smp.utils.metrics.Accuracy(threshold=0.3),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0005),
])

lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1, T_mult=2, eta_min=0.00001,
)

train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=device,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=device,
    verbose=True,
)

if training:

    # Validation metrics of the best epoch so far (selected by IoU).
    best_model = {'loss': 0.0, 'iou_score': 0.0, 'fscore': 0.0}
    train_logs_list, valid_logs_list = [], []

    for i in range(0, epochs):
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(dataloader['train'])
        valid_logs = valid_epoch.run(dataloader['val'])
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Advance the schedule once per epoch; without this call the scheduler
        # above has no effect and the learning rate stays at its initial value.
        lr_scheduler.step()

        if best_model['iou_score'] < valid_logs['iou_score']:
            # NOTE(review): the checkpoint is written under main_dir, while the
            # "Load Best Model" cell reads './best_model.pth' — confirm which
            # file is meant to be evaluated.
            torch.save(model, main_dir + 'best_model.pth')
            best_model['loss'] = valid_logs['loss']
            best_model['iou_score'] = valid_logs['iou_score']
            best_model['fscore'] = valid_logs['fscore']
            print('Model saved!')

In [None]:
# Show the validation loss / IoU / F-score recorded for the best epoch of the training loop.
print(best_model)

In [None]:
# One row per epoch, one column per logged quantity, for each split.
train_logs_df, valid_logs_df = (
    pd.DataFrame(epoch_logs) for epoch_logs in (train_logs_list, valid_logs_list)
)

In [None]:
# Training vs. validation loss per epoch; the figure is also written to loss_plot.png.
plt.figure(figsize=(10, 10))
for split_label, split_logs in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(split_logs.index.tolist(), split_logs['loss'].tolist(), lw=1, label=split_label)
plt.title('Loss Plot', fontsize=20)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('Loss', fontsize=20)
plt.grid()
plt.legend(loc='best', fontsize=16)
plt.savefig('loss_plot.png')
plt.show()

In [None]:
# Training vs. validation IoU per epoch; the figure is also written to iou_score_plot.png.
plt.figure(figsize=(10, 10))
for split_label, split_logs in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(split_logs.index.tolist(), split_logs['iou_score'].tolist(), lw=1, label=split_label)
plt.title('IoU Score Plot', fontsize=20)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('IoU Score', fontsize=20)
plt.grid()
plt.legend(loc='best', fontsize=16)
plt.savefig('iou_score_plot.png')
plt.show()

In [None]:
# Training vs. validation F1 score per epoch; the figure is also written to fscore_plot.png.
plt.figure(figsize=(10, 10))
for split_label, split_logs in (('Train', train_logs_df), ('Valid', valid_logs_df)):
    plt.plot(split_logs.index.tolist(), split_logs['fscore'].tolist(), lw=1, label=split_label)
plt.title('F1 Score Plot', fontsize=20)
plt.xlabel('Epochs', fontsize=20)
plt.ylabel('F1 Score', fontsize=20)
plt.grid()
plt.legend(loc='best', fontsize=16)
plt.savefig('fscore_plot.png')
plt.show()

In [None]:
# Training vs. validation accuracy per epoch; the figure is also written to accuracy_plot.png.
# The epoch logs only contain the metrics configured for training, so check
# for the column first instead of failing with an AttributeError.
if 'accuracy' in train_logs_df.columns and 'accuracy' in valid_logs_df.columns:
    plt.figure(figsize=(10,10))
    plt.plot(train_logs_df.index.tolist(), train_logs_df['accuracy'].tolist(), lw=1, label = 'Train')
    plt.plot(valid_logs_df.index.tolist(), valid_logs_df['accuracy'].tolist(), lw=1, label = 'Valid')
    plt.xlabel('Epochs', fontsize=20)
    plt.ylabel('Accuracy Score', fontsize=20)
    plt.title('Accuracy Score Plot', fontsize=20)
    plt.legend(loc='best', fontsize=16)
    plt.grid()
    plt.savefig('accuracy_plot.png')
    plt.show()
else:
    print("No 'accuracy' column in the training logs; add an accuracy metric to `metrics` to enable this plot.")

# Load Best Model

In [None]:
# Defined here as well so this cell works when the training cell was skipped.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint stores a whole pickled model object, so the model's class
# must be importable when it is loaded. Presumably the hub call below is what
# makes it importable if the training cell was not run — TODO confirm.
# best_model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=1, init_features=32, pretrained=True)
# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoints from a source you trust.
# NOTE(review): recent PyTorch releases may need weights_only=False to load a
# full pickled model — confirm against the installed version.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
best_model = torch.load('./best_model.pth', map_location=device)
best_model.to(device)

# Test Model

In [None]:
%matplotlib inline

best_model.eval()

IOUs = []
F1s = []

with torch.no_grad():
    for i, (inputs, labels) in enumerate(dataloader['test']):
        inputs = inputs.to(device)
        labels = labels.to(device)

        pred_mask = best_model(inputs)

        for i in range(len(inputs)):
            test_image = inputs[i]
            test_mask = labels[i]
            predMask = pred_mask[i]
            
            iou = smp.utils.functional.iou(predMask, test_mask, threshold=0.3)
            IOUs.append(iou.cpu().detach())

            f1 = smp.utils.functional.f_score(predMask, test_mask, threshold=0.3)
            F1s.append(f1.cpu().detach())
            
            predMask = torch.where(predMask >= 0.3, 1, 0)

In [None]:
# Report the mean per-image IoU and F1 collected by the test loop.
for metric_label, metric_values in (('Test IOU: ', IOUs), ('Test F1: ', F1s)):
    print(metric_label + str(np.mean(metric_values)))

# Interpret

In [None]:
def get_border(mask):
    """Return the outer-boundary positions of ``mask`` as a list of index tuples.

    ``mask`` is a three-dimensional tensor. It is moved to the CPU and
    converted to numpy, its boundaries are found with
    ``skimage.segmentation.find_boundaries(mode='outer')``, and every boundary
    element is returned as a tuple of its three indices.
    """
    mask_np = mask.detach().cpu().numpy()
    edge_map = skimage.segmentation.find_boundaries(mask_np, mode='outer').astype(np.uint8)
    hits = np.where(edge_map == 1)
    # One row per boundary element, one column per axis.
    coords = np.stack((hits[0], hits[1], hits[2]), axis=1)
    return [tuple(row) for row in coords]

In [None]:
def get_attr(method, image, targets):
    """Compute one channel-averaged attribution map per target.

    Args:
        method: attribution object exposing ``attribute(inputs, target=...)``.
        image: batched input tensor; dimensions 2 and 3 are its spatial size.
        targets: iterable of output indices, each passed as ``target``.

    Returns:
        Numpy array stacking, for every target, the attribution of the first
        batch element averaged over its channel axis. Attributions that are
        spatially smaller than ``image`` are first resized to the image size.
    """
    per_target_maps = []
    for target in targets:
        # DeepLift-based methods are asked explicitly not to return the convergence delta.
        extra_kwargs = {'return_convergence_delta': False} if isinstance(method, DeepLift) else {}
        attribution = method.attribute(image.to(device), target=target, **extra_kwargs)
        if attribution.shape[2] < image.shape[2]:
            attribution = LayerAttribution.interpolate(attribution, (image.shape[2], image.shape[3]))
        first_sample = attribution.detach().cpu().numpy()[0]
        per_target_maps.append(np.mean(first_sample, 0))
    return np.array(per_target_maps)

In [None]:
def agg_attr(out, size):
    """Aggregate a stack of attribution maps into one ``(size, size, 1)`` map.

    Each map ``out[k]`` is divided by its own maximum absolute value (an
    all-zero map is left unchanged), then the normalised maps are summed over
    the first axis.

    Args:
        out: array whose first axis indexes the maps; each map must hold
            ``size * size`` values.
        size: side length of the square output.

    Returns:
        Float array of shape ``(size, size, 1)``. An empty stack (nothing to
        aggregate) yields an all-zero map instead of failing in the reshape.
    """
    out = np.asarray(out)
    if out.shape[0] == 0:
        return np.zeros((size, size, 1))
    normalised = np.zeros(out.shape)
    for k, attr_map in enumerate(out):
        peak = np.max(np.abs(attr_map))
        # Guard against dividing an all-zero map by zero.
        normalised[k] = attr_map / (peak if peak != 0 else 1)
    return np.sum(normalised, 0).reshape(size, size, 1)

In [None]:
def interpret_model(model, image, image_name):
    """Save input-attribution heat maps of ``model`` for one image.

    The model's prediction is binarised at 0.3 and the border of the predicted
    region is used as the set of attribution targets. For each of Saliency,
    Guided Backpropagation, Guided Grad-CAM (on ``model.conv``) and DeepLift,
    the per-target attributions are aggregated into one map, drawn blended
    over the image, and saved as ``<image_name>_<method>.png``.

    Args:
        model: network to explain; must expose a ``conv`` layer.
        image: channel-first image tensor without a batch dimension.
        image_name: prefix of the output file names.
    """
    figure_size = (5, 5)
    img_cpu = image.cpu().permute(1, 2, 0).detach().numpy()
    img_batch = image.unsqueeze(0)
    pred_mask = model(img_batch.to(device))[0]
    binary_mask = torch.where(pred_mask >= 0.3, 1, 0)

    # The border depends only on the prediction, so compute it once rather
    # than once per attribution method.
    border = get_border(binary_mask)
    if not border:
        print('No region predicted above the 0.3 threshold; nothing to attribute.')
        return

    # (file-name tag, factory) pairs; each method is built right before use.
    explainers = (
        ('Saliency', lambda: Saliency(model)),
        ('Guided Backpropagation', lambda: GuidedBackprop(model)),
        ('Guided Grad-CAM', lambda: GuidedGradCam(model, model.conv)),
        ('DeepLift', lambda: DeepLift(model)),
    )
    for file_tag, build_explainer in explainers:
        attributions = get_attr(build_explainer(), img_batch, border)
        figure, _ = viz.visualize_image_attr_multiple(agg_attr(attributions, image.shape[1]), original_image=img_cpu, signs=["all"],
                                                      methods=["blended_heat_map"], show_colorbar=True, fig_size=figure_size)
        figure.savefig(image_name + '_' + file_tag + '.png', format='png', dpi=1200)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(figure)

In [None]:
def interpret_layers_with_gradcam(model, image, image_name):
    """Save a Layer Grad-CAM heat map for every major block of ``model``.

    The model's prediction is binarised at 0.3 and the border of the predicted
    region is used as the set of attribution targets. For each of the layers
    ``encoder1..4``, ``bottleneck``, ``decoder4..1`` and ``conv``, the
    per-target attributions are aggregated into one map, drawn blended over
    the image, and saved as ``<image_name>_Layer Grad-CAM_<layer>.png``.

    Args:
        model: network to explain; must expose the layers listed above.
        image: channel-first image tensor without a batch dimension.
        image_name: prefix of the output file names.
    """
    figure_size = (5, 5)
    img_cpu = image.cpu().permute(1, 2, 0).detach().numpy()
    img_batch = image.unsqueeze(0)
    pred_mask = model(img_batch.to(device))[0]
    binary_mask = torch.where(pred_mask >= 0.3, 1, 0)

    # The border depends only on the prediction, so compute it once rather
    # than once per layer.
    border = get_border(binary_mask)
    if not border:
        print('No region predicted above the 0.3 threshold; nothing to attribute.')
        return

    # (model attribute, figure-title suffix, file-name suffix)
    layers = (
        ('encoder1', 'Encoder1', 'Encoder1'),
        ('encoder2', 'Encoder2', 'Encoder2'),
        ('encoder3', 'Encoder3', 'Encoder3'),
        ('encoder4', 'Encoder4', 'Encoder4'),
        ('bottleneck', 'Bottleneck', 'Bottleneck'),
        ('decoder4', 'Decoder4', 'Decoder4'),
        ('decoder3', 'Decoder3', 'Decoder3'),
        ('decoder2', 'Decoder2', 'Decoder2'),
        ('decoder1', 'Decoder1', 'Decoder1'),
        ('conv', 'Last Layer', 'LastLayer'),
    )
    for attr_name, title_suffix, file_suffix in layers:
        lgc = LayerGradCam(model, getattr(model, attr_name))
        lgc_out = get_attr(lgc, img_batch, border)
        figure, _ = viz.visualize_image_attr_multiple(agg_attr(lgc_out, image.shape[1]), original_image=img_cpu, signs=["all"],
                                                      methods=["blended_heat_map"], show_colorbar=True,
                                                      titles=["Grad-CAM - " + title_suffix], fig_size=figure_size)
        figure.savefig(image_name + '_Layer Grad-CAM_' + file_suffix + '.png', format='png', dpi=1200)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(figure)

In [None]:
def interpret_layers_with_deeplift(model, image, image_name):
    """Save a Layer DeepLift heat map for every major block of ``model``.

    The model's prediction is binarised at 0.3 and the border of the predicted
    region is used as the set of attribution targets. For each of the layers
    ``encoder1..4``, ``bottleneck``, ``decoder4..1`` and ``conv``, the
    per-target attributions are aggregated into one map, drawn blended over
    the image, and saved as ``<image_name>_Layer DeepLift_<layer>.png``.

    Args:
        model: network to explain; must expose the layers listed above.
        image: channel-first image tensor without a batch dimension.
        image_name: prefix of the output file names.
    """
    figure_size = (5, 5)
    img_cpu = image.cpu().permute(1, 2, 0).detach().numpy()
    img_batch = image.unsqueeze(0)
    pred_mask = model(img_batch.to(device))[0]
    binary_mask = torch.where(pred_mask >= 0.3, 1, 0)

    # The border depends only on the prediction, so compute it once rather
    # than once per layer.
    border = get_border(binary_mask)
    if not border:
        print('No region predicted above the 0.3 threshold; nothing to attribute.')
        return

    # (model attribute, figure-title suffix, file-name suffix)
    layers = (
        ('encoder1', 'Encoder1', 'Encoder1'),
        ('encoder2', 'Encoder2', 'Encoder2'),
        ('encoder3', 'Encoder3', 'Encoder3'),
        ('encoder4', 'Encoder4', 'Encoder4'),
        ('bottleneck', 'Bottleneck', 'Bottleneck'),
        ('decoder4', 'Decoder4', 'Decoder4'),
        ('decoder3', 'Decoder3', 'Decoder3'),
        ('decoder2', 'Decoder2', 'Decoder2'),
        ('decoder1', 'Decoder1', 'Decoder1'),
        ('conv', 'Last Layer', 'LastLayer'),
    )
    for attr_name, title_suffix, file_suffix in layers:
        ldl = LayerDeepLift(model, getattr(model, attr_name))
        ldl_out = get_attr(ldl, img_batch, border)
        figure, _ = viz.visualize_image_attr_multiple(agg_attr(ldl_out, image.shape[1]), original_image=img_cpu, signs=["all"],
                                                      methods=["blended_heat_map"], show_colorbar=True,
                                                      titles=["DeepLift - " + title_suffix], fig_size=figure_size)
        figure.savefig(image_name + '_Layer DeepLift_' + file_suffix + '.png', format='png', dpi=1200)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(figure)

# Evaluate Interpretations

In [None]:
def perturb_fn(inputs):
    """Perturbation function for the infidelity metric.

    Draws Gaussian noise (mean 0, standard deviation 0.001) with the shape of
    ``inputs`` and returns ``(noise, inputs - noise)``.

    The noise is created on the device of ``inputs`` rather than on a global
    device, so the subtraction never mixes devices.
    """
    noise = torch.tensor(np.random.normal(0, 0.001, inputs.shape)).float().to(inputs.device)
    return noise, inputs - noise

In [None]:
def infidelity_score_model_interpretations(model, image):
    """Mean infidelity of four input-attribution methods on one image.

    The model's prediction is binarised at 0.3 and every border position of
    the predicted region is used as an attribution target. For each of
    Saliency, Guided Backpropagation, Guided Grad-CAM (on ``model.conv``) and
    DeepLift, the infidelity is computed per target and averaged.

    Args:
        model: network to explain; must expose a ``conv`` layer.
        image: channel-first image tensor without a batch dimension.

    Returns:
        List of four floats in the method order above. If the prediction has
        no border (nothing above the threshold) every entry is ``nan``.
    """
    methods = [Saliency(model), GuidedBackprop(model), GuidedGradCam(model, model.conv), DeepLift(model)]

    img_batch = image.unsqueeze(0).to(device)
    pred_mask = model(img_batch)[0]
    binary_mask = torch.where(pred_mask >= 0.3, 1, 0)
    border = get_border(binary_mask)
    if not border:
        # Nothing to attribute; avoid dividing by zero below.
        return [float('nan')] * len(methods)

    infidelity_scores = []
    for method in methods:
        total = 0.0
        for target in border:
            if isinstance(method, DeepLift):
                attribution = method.attribute(img_batch, target=target, return_convergence_delta=False)
            else:
                attribution = method.attribute(img_batch, target=target)
            # The attribution explains one output position, so infidelity has
            # to be measured against that same position, not the whole map.
            score = infidelity(model, perturb_fn, img_batch, attribution, target=target, n_perturb_samples=1)
            # Python floats keep the result usable with numpy on any device.
            total += score.item()
        infidelity_scores.append(total / len(border))

    return infidelity_scores

In [None]:
def infidelity_score_layer_interpretations(model, test_dataset):
    """Mean infidelity of Layer Grad-CAM and Layer DeepLift on one image.

    The model's prediction is binarised at 0.3 and every border position of
    the predicted region is used as an attribution target. For each method
    family and each of the layers ``encoder1..4``, ``bottleneck``,
    ``decoder4..1`` and ``conv``, the infidelity is computed per target,
    averaged over targets, and then averaged over layers.

    Args:
        model: network to explain; must expose the layers listed above.
        test_dataset: despite its name, a single channel-first image tensor
            without a batch dimension (the name is kept for compatibility).

    Returns:
        ``[grad_cam_score, deeplift_score]`` as floats. If the prediction has
        no border (nothing above the threshold) both entries are ``nan``.
    """
    # Use the argument; the body previously read a global named `image`, so
    # every call scored the same unrelated sample.
    image = test_dataset

    layers = [model.encoder1, model.encoder2, model.encoder3, model.encoder4, model.bottleneck,
              model.decoder4, model.decoder3, model.decoder2, model.decoder1, model.conv]
    methods = [[LayerGradCam(model, layer) for layer in layers],
               [LayerDeepLift(model, layer) for layer in layers]]

    img_batch = image.unsqueeze(0).to(device)
    pred_mask = model(img_batch)[0]
    binary_mask = torch.where(pred_mask >= 0.3, 1, 0)
    border = get_border(binary_mask)
    if not border:
        # Nothing to attribute; avoid dividing by zero below.
        return [float('nan')] * len(methods)

    height, width = image.shape[1], image.shape[2]
    infidelity_scores = []
    for family in methods:
        family_total = 0.0
        for method in family:
            layer_total = 0.0
            for target in border:
                attribution = method.attribute(img_batch, target=target)
                # Bring layer-resolution attributions up to the input resolution.
                if attribution.shape[2] < height:
                    attribution = LayerAttribution.interpolate(attribution, (height, width))
                # Layer DeepLift attributions are averaged down to one channel.
                if isinstance(method, LayerDeepLift):
                    attribution = torch.mean(attribution, 1, keepdims=True)
                # Repeat the single channel three times to match the input
                # (assumes a 3-channel image), and score against the same
                # output position that the attribution explains.
                score = infidelity(model, perturb_fn, img_batch, attribution.repeat(1, 3, 1, 1),
                                   target=target, n_perturb_samples=1)
                layer_total += score.item()
            family_total += layer_total / len(border)
        infidelity_scores.append(family_total / len(family))

    return infidelity_scores

In [None]:
# Infidelity of the input-attribution methods over the whole test split.
model_infid = []
for idx in range(len(dataset['test'])):
    img, _ = dataset['test'][idx]
    # float() accepts Python numbers and single-element tensors on any device.
    model_infid.append([float(score) for score in infidelity_score_model_interpretations(best_model, img)])
# Average over images only (axis 0) so each method keeps its own score;
# images that produced no score (nan) are ignored.
print(np.nanmean(model_infid, axis=0))

In [None]:
# Infidelity of the layer-attribution methods over the whole test split.
layer_infid = []
for idx in range(len(dataset['test'])):
    img, _ = dataset['test'][idx]
    # float() accepts Python numbers and single-element tensors on any device.
    layer_infid.append([float(score) for score in infidelity_score_layer_interpretations(best_model, img)])
# Average over images only (axis 0) so each method family keeps its own
# score; images that produced no score (nan) are ignored.
print(np.nanmean(layer_infid, axis=0))