# Interpretability and Explainability

In [None]:
import albumentations as A
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
from collections import defaultdict
from tqdm import tqdm

from modules import *
from networks import *
from training import *
from interpretability import *

COORDS = 'polar'  # cartesian, polar
ARCH = 'dual'  # dual, cascade
MODEL = 'ref'  # rau, ref, swin

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Swin models use 224x224 inputs; the other backbones use 256x256.
SIZE = 256 if MODEL != 'swin' else 224

# Checkpoints are stored per coordinate system (models/<COORDS>/...), so the
# polar warp (applied to image and mask alike) is only used for polar models;
# cartesian models must see un-warped inputs.
preprocessing = [
    A.Resize(height=SIZE, width=SIZE, interpolation=cv.INTER_AREA),
    A.Lambda(image=sharpen, p=1.0),
]
if COORDS == 'polar':
    preprocessing.append(A.Lambda(image=polar_transform, mask=polar_transform))

transform = A.Compose([
    *preprocessing,
    A.Normalize(),
    ToTensorV2(),
])

# Load the DRISHTI ROI test split as two index-aligned lists (images, masks).
all_images = load_files_from_dir(['../data/DRISHTI/ROI/TestImages'])
all_masks = load_files_from_dir(['../data/DRISHTI/ROI/TestMasks'])
if len(all_images) != len(all_masks):
    # Pairs are matched purely by position, so a count mismatch would
    # silently pair the wrong image with a mask.
    raise ValueError(
        f'Found {len(all_images)} images but {len(all_masks)} masks'
    )

# Shuffle images and masks with the same permutation so pairs stay aligned.
# A fixed seed keeps the sampled image (and every figure below) reproducible
# across runs; change SEED to inspect a different sample.
SEED = 0
rng = np.random.default_rng(SEED)
indices = rng.permutation(len(all_images))
all_images = [all_images[i] for i in indices]
all_masks = [all_masks[i] for i in indices]

loader = load_dataset(
    all_images,
    all_masks,
    transform,
    batch_size=1,
    shuffle=False,
)

# Take a single sample (batch_size=1) and move it to the compute device.
images, masks = next(iter(loader))
images = images.float().to(device)
masks = masks.long().to(device)

# Load models. Paths are built with os.path.join so the notebook runs on any
# OS (the previous backslash-separated raw strings only resolved on Windows).
# Binary checkpoint for the configured coordinate system / backbone.
path = os.path.join('..', 'models', COORDS, MODEL, 'binary.pth')
checkpoint = load_checkpoint(path, map_location=device)
base_model = checkpoint['model']
base_model = base_model.eval()

# Checkpoint for the configured architecture (dual / cascade); this is the
# model inspected by the cells below.
path = os.path.join('..', 'models', COORDS, MODEL, f'{ARCH}.pth')
checkpoint = load_checkpoint(path, map_location=device)
model = checkpoint['model']
model = model.eval()
model

## Activation Visualization

In [15]:
# Layers whose activations are captured, keyed by the label used in plots.
# The architectures expose different submodules: the cascade model has a
# single `decoder`, while the dual model has separate `decoder1` (OD) and
# `decoder2` (OC) decoders. Uncomment entries to inspect additional blocks.
if ARCH == 'cascade':
    layers = {
        # 'Encoder block 1': model.encoder.en1,
        'Encoder block 2': model.encoder.en2,
        # 'Encoder block 3': model.encoder.en3,
        # 'Encoder block 4': model.encoder.en4,
        'Encoder block 5': model.encoder.en5,

        'Decoder block 1': model.decoder.de4,
        # 'Decoder block 2': model.decoder.de3,
        # 'Decoder block 3': model.decoder.de2,
        'Decoder block 4': model.decoder.de1,
        # 'Output': model.decoder.last,

        # 'Side input 1': model.encoder.side1,
        # 'Side input 2': model.encoder.side2,
        # 'Side input 3': model.encoder.side3,
    }
elif ARCH == 'dual':
    layers = {
        # 'Encoder block 1': model.encoder.en1,
        # 'Encoder block 2': model.encoder.en2,
        # 'Encoder block 3': model.encoder.en3,
        # 'Encoder block 4': model.encoder.en4,
        # 'Encoder block 5': model.encoder.en5,

        'OD Decoder block 1': model.decoder1.de4,
        # 'OD Decoder block 2': model.decoder1.de3,
        # 'OD Decoder block 3': model.decoder1.de2,
        'OD Decoder block 4': model.decoder1.de1,
        # 'OD Output': model.decoder1.last,

        'OC Decoder block 1': model.decoder2.de4,
        # 'OC Decoder block 2': model.decoder2.de3,
        # 'OC Decoder block 3': model.decoder2.de2,
        'OC Decoder block 4': model.decoder2.de1,
        # 'OC Output': model.decoder2.last,

        # 'Side input 1': model.encoder.side1,
        # 'Side input 2': model.encoder.side2,
        # 'Side input 3': model.encoder.side3,
    }
else:
    raise ValueError(f"Unknown ARCH {ARCH!r}; expected 'dual' or 'cascade'")

activation_maps = ActivationMaps(model, layers, images)
try:
    activations = activation_maps.get_activations()
finally:
    # Remove the hooks even if the forward pass fails, so they do not stay
    # attached to `model` for the Grad-CAM cell below.
    activation_maps.unregister_hooks()

ConvCBAM module activations saved in 'OD Decoder block 1' with shape=(1, 80, 32, 32)
ConvCBAM module activations saved in 'OD Decoder block 4' with shape=(1, 32, 256, 256)
ConvCBAM module activations saved in 'OC Decoder block 1' with shape=(1, 80, 32, 32)
ConvCBAM module activations saved in 'OC Decoder block 4' with shape=(1, 32, 256, 256)


In [None]:
# Render the captured activation maps for every hooked layer, in the
# insertion order of `layers`.
for layer_name in layers:
    activation_maps.show(layer_name)

In [None]:
# TODO: Plot selected activations for the thesis visualization
# Plot a 4x8 grid of per-channel activation maps for one hooked layer.
# Use a key that was actually hooked above: the previous hard-coded 'conv1'
# is not a key of `layers` (and so presumably not of `activations`).
layer_name = next(iter(layers))
fmaps = activations[layer_name]
if isinstance(fmaps, (list, tuple)):
    # NOTE(review): assumes the first entry is the captured output — confirm
    # against ActivationMaps.get_activations.
    fmaps = fmaps[0]
if isinstance(fmaps, torch.Tensor):
    # matplotlib needs host memory and no attached autograd graph.
    fmaps = fmaps.detach().cpu().numpy()
fmaps = np.asarray(fmaps)
# Hooks report shapes like (1, C, H, W); drop leading batch dimension(s)
# to get a (C, H, W) stack of channel maps.
while fmaps.ndim > 3:
    fmaps = fmaps[0]

fig, ax = plt.subplots(4, 8, figsize=(16, 8))
plt.suptitle(f'{layer_name} activations')
ax = ax.ravel()

for i, axis in enumerate(ax):
    # Layers with fewer than 32 channels leave the remaining cells blank;
    # layers with more only show their first 32 channels.
    if i < len(fmaps):
        axis.imshow(fmaps[i])
    axis.axis('off')

plt.tight_layout()
plt.show()

## Grad-CAM (Gradient-weighted Class Activation Mapping)

In [None]:
model_layers = dict(model.named_modules())
# print(model_layers.keys())

# The cascade model has a single `decoder`; the dual model has `decoder1` (OD)
# and `decoder2` (OC). The previous hard-coded 'decoder.de1.conv' only exists
# in the cascade model. Switch to 'decoder2' to explain the OC branch instead.
decoder_name = 'decoder1' if ARCH == 'dual' else 'decoder'
target_name = f'{decoder_name}.de1.conv'
try:
    target_layer = model_layers[target_name]
except KeyError:
    raise KeyError(
        f'Layer {target_name!r} not found in the {ARCH!r} model; '
        'uncomment the print above to list available module names'
    ) from None

gradcam = GradCAM(model, target_layer)
try:
    out = gradcam(images)
    print(out.shape)
    gradcam.show()
finally:
    # Detach the hooks even if the forward/backward pass or plotting fails.
    gradcam.unregister_hooks()