# Interpretability and Explainability

In [None]:
import albumentations as A
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
from collections import defaultdict
from tqdm import tqdm

from modules import *
from networks import *
from training import *
from interpretability import *

# ---- Experiment configuration -----------------------------------------------
COORDS = 'polar'  # cartesian, polar
ARCH = 'cascade'  # dual, cascade
MODEL = 'ref'  # rau, ref, swin
SEED = None  # set to an int to make the sampled test image reproducible

if COORDS not in ('cartesian', 'polar'):
    raise ValueError(f"COORDS must be 'cartesian' or 'polar', got {COORDS!r}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Swin models are fed 224x224 inputs; every other model here gets 256x256.
SIZE = 256 if MODEL != 'swin' else 224

# Preprocessing: resize -> sharpen -> (polar warp) -> normalize -> tensor.
# Checkpoints are loaded from models/<COORDS>/..., so the polar warp is only
# applied when COORDS == 'polar'; cartesian models get unwarped inputs.
# polar_transform is applied to both image and mask so the labels stay aligned.
transform_steps = [
    A.Resize(height=SIZE, width=SIZE, interpolation=cv.INTER_AREA),
    A.Lambda(image=sharpen, p=1.0),
]
if COORDS == 'polar':
    transform_steps.append(A.Lambda(image=polar_transform, mask=polar_transform))
transform_steps += [
    A.Normalize(),
    ToTensorV2(),
]
transform = A.Compose(transform_steps)

# ---- Data --------------------------------------------------------------------
all_images = load_files_from_dir(['../data/DRISHTI/ROI/TestImages'])
all_masks = load_files_from_dir(['../data/DRISHTI/ROI/TestMasks'])
if len(all_images) == 0:
    raise FileNotFoundError('No test images found under ../data/DRISHTI/ROI/TestImages')
if len(all_images) != len(all_masks):
    raise ValueError(
        f'Image/mask count mismatch: {len(all_images)} images vs {len(all_masks)} masks'
    )

# Shuffle images and masks with the same permutation so pairs stay aligned;
# the first batch drawn below is then a random test sample.
indices = np.random.default_rng(SEED).permutation(len(all_images))
all_images = [all_images[i] for i in indices]
all_masks = [all_masks[i] for i in indices]

loader = load_dataset(
    all_images,
    all_masks,
    transform,
    batch_size=1,
    shuffle=False,  # order was already randomized above
)

# Take a single sample (batch_size=1) to inspect.
images, masks = next(iter(loader))
images = images.float().to(device)
masks = masks.long().to(device)

# ---- Models ------------------------------------------------------------------
# Build paths with os.path.join rather than hard-coded Windows backslashes so
# the notebook also runs on POSIX systems.
model_dir = os.path.join('..', 'models', COORDS, MODEL)

path = os.path.join(model_dir, 'binary.pth')
checkpoint = load_checkpoint(path, map_location=device)
base_model = checkpoint['model'].eval()

path = os.path.join(model_dir, f'{ARCH}.pth')
checkpoint = load_checkpoint(path, map_location=device)
model = checkpoint['model'].eval()  # nn.Module.eval() returns the module itself
model

## Activation Visualization

In [None]:
# Capture intermediate feature maps from every encoder stage (en1-en5), every
# decoder stage (de1-de4) and the decoder's final `last` module.
# NOTE(review): ActivationMaps comes from one of the project star-imports
# (likely `interpretability`); given unregister_hooks(), it presumably attaches
# hooks to these modules. Confirm whether the forward pass on `images` happens
# in __init__ or in get_activations().
activation_maps = ActivationMaps(model, {
    'conv1': model.encoder.en1,
    'conv2': model.encoder.en2,
    'conv3': model.encoder.en3,
    'conv4': model.encoder.en4,
    'conv5': model.encoder.en5,
    'conv6': model.decoder.de1,
    'conv7': model.decoder.de2,
    'conv8': model.decoder.de3,
    'conv9': model.decoder.de4,
    'last': model.decoder.last,
}, images)
try:
    activations = activation_maps.get_activations()
finally:
    # Remove the hooks even if collection raises, so a failed run does not
    # leave them attached to `model` for the cells that follow.
    activation_maps.unregister_hooks()

In [None]:
# Plot the conv1 feature maps of the sampled image on a 4x8 grid.
layer_name = 'conv1'
n_rows, n_cols = 4, 8

# NOTE(review): assumes activations[layer_name][0] is batch-first
# (batch, channels, H, W); the second [0] then picks the single sample
# (the loader uses batch_size=1). Confirm against ActivationMaps.
feature_maps = activations[layer_name][0][0]
# Show at most n_rows * n_cols channels. A layer with fewer channels leaves
# blank cells instead of raising IndexError.
n_shown = min(n_rows * n_cols, len(feature_maps))

fig, ax = plt.subplots(n_rows, n_cols, figsize=(16, 8))
fig.suptitle(f'{layer_name} activations')
ax = ax.ravel()
for i, axis in enumerate(ax):
    if i < n_shown:
        axis.imshow(feature_maps[i])
    axis.axis('off')
plt.tight_layout()
plt.show()

# Alternatively (presumably ActivationMaps' own plotting helper):
# activation_maps.show(layer_name)

## Grad-CAM (Gradient-weighted Class Activation Mapping)

In [None]:
# Grad-CAM for the `conv` submodule of the first decoder stage.
model_layers = dict(model.named_modules())
target_name = 'decoder.de1.conv'
try:
    target_layer = model_layers[target_name]
except KeyError:
    # Layer names may differ between MODEL/ARCH variants; list what exists.
    raise KeyError(
        f'{target_name!r} not found in model; available layers: {sorted(model_layers)}'
    ) from None

gradcam = GradCAM(model, target_layer)
try:
    out = gradcam(images)
    print(out.shape)
    gradcam.show()
finally:
    # Detach the hooks even if the Grad-CAM pass or plotting fails, so they
    # do not stay registered on `model`.
    gradcam.unregister_hooks()