In [None]:
import os
from pathlib import Path
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models, transforms

In [None]:
def get_abs_path(n_parent: int = 0):
    """Return the resolved absolute path ``n_parent`` directories above the CWD.

    ``n_parent=0`` yields the current working directory itself.
    """
    relative = Path(*(['..'] * n_parent))
    return relative.resolve()

def min_max_normalization(image):
    """Linearly rescale ``image`` so its values span [0, 1].

    Works for anything supporting ``.min()``, ``.max()`` and arithmetic
    (NumPy arrays and torch tensors are both used in this notebook).
    A constant input maps to all zeros instead of 0/0 = NaN.
    """
    low = image.min()
    span = image.max() - low
    if span == 0:
        # Constant input: avoid NaN output (which imshow renders as garbage).
        span = 1
    return (image - low) / span

In [None]:
path = get_abs_path(1)
model_path = path / 'models' / 'deep_geo_guessr.pt'
images_path = path / 'visualization_examples'
# Sorted so positional lookups later on (images[0], images[2]) pick the same
# file on every machine; raw glob order depends on the filesystem.
images_paths = [str(png) for png in sorted(images_path.glob('**/*.png'))]
assert images_paths, f'No .png files found under {images_path}'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {0} device'.format(device))

In [None]:
path = get_abs_path(1)
data_path = path / 'data'
# iterdir() order is arbitrary, but class_names[i] must name the model's
# output i. Sorting reproduces torchvision ImageFolder's ordering.
# NOTE(review): presumably the model was trained with that ordering (the
# hard-coded plot titles further down are alphabetical too) — confirm.
class_names = sorted(d.name for d in data_path.iterdir() if d.is_dir())
class_labels = {name: index for index, name in enumerate(class_names)}
print('Labels:', class_labels)

In [None]:
# Resize to the network input size and normalize with the ImageNet channel
# mean/std that torchvision's pretrained ResNet weights expect.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

images = []
for image_path in images_paths:
    # The context manager closes the file handle. convert('RGB') drops a PNG
    # alpha channel and expands palette/grayscale images, so Normalize always
    # receives exactly 3 channels.
    with Image.open(image_path) as pil_image:
        image = data_transform(pil_image.convert('RGB'))
    # Add the batch dimension: each entry is a (1, 3, 224, 224) tensor.
    image = image.to(device).unsqueeze(0)
    images.append(image)

In [None]:
class CountryClassificator(nn.Module):
    """ResNet-18 backbone with a dropout + linear head for country classification.

    Args:
        num_classes: number of logits produced by the replacement ``fc`` head.
        pretrained: forwarded to ``torchvision.models.resnet18``. It defaults to
            ``True`` (the original behaviour). Pass ``False`` when a full
            checkpoint is loaded right afterwards, so the ImageNet weights are
            not downloaded only to be overwritten.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        self.model = models.resnet18(pretrained=pretrained)
        # Replace the ImageNet classifier. in_features is read from the
        # original fc layer before it is overwritten.
        self.model.fc = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.model.fc.in_features, num_classes),
        )

    def forward(self, x):
        """Return the un-normalized class scores (logits) for the batch ``x``."""
        return self.model(x)


# num_classes must match the head size stored in the checkpoint.
model = CountryClassificator(num_classes=5)
# map_location remaps tensors saved on a GPU onto the current device, so the
# checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints (consider weights_only=True on torch >= 1.13).
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
model.to(device)
print('Model loaded')

# Different CNN visualization techniques
Author: Francesco Saverio Zuppichini\
[Github](https://github.com/FrancescoSaverioZuppichini/A-journey-into-Convolutional-Neural-Network-visualization-)

In [None]:
from cnn_visualizations_zuppichini.utils import *
from cnn_visualizations_zuppichini.visualisation.core import *
from cnn_visualizations_zuppichini.visualisation.core.utils import imshow
from cnn_visualizations_zuppichini.visualisation.core.utils import image_net_postprocessing

# Weights Visualization

In [None]:
plt.rcParams['figure.figsize'] = 14, 14
# module2traced is given an example input (presumably to trace the forward pass).
model_traced = module2traced(model, images[0])
vis = Weights(model, device)

# The layer does not depend on the example image, so look it up once.
# NOTE(review): index 2 of the traced module list comes from the original
# notebook; confirm which ResNet-18 layer it selects.
layer = model_traced[2]
for example in images[:1]:
    run_vis_plot(vis, example, layer, ncols=4, nrows=4)

# Class Activation Mapping

In [None]:
plt.rcParams['figure.figsize'] = (24, 12)  # wide canvas for one-row, five-panel plots

def get_images(outs):
    """Return a list containing the first element of every item in ``outs``."""
    return [out[0] for out in outs]

In [None]:
vis = ClassActivationMapping(model, device)

# One map per output class, all computed on the same example image (index 2).
classes = list(range(5))
cam_kwargs = {'postprocessing': image_net_postprocessing, 'guide': True}
outs = [vis(images[2].to(device), None, target_class=target, **cam_kwargs)
        for target in classes]

processed_images = get_images(outs)

# Titles are paired positionally with classes 0..4.
subplot(processed_images,
        rows_titles=['france', 'greece', 'portugal', 'spain', 'switzerland'],
        nrows=1, ncols=5, parse=tensor2img)

# Grad-CAM

In [None]:
vis = GradCam(model, device)

# One Grad-CAM map per output class, computed on the same example image (index 2).
# Note: passing min_max_normalization(images[2]) instead of the
# transform-normalized tensor was tried. It gave worse results (and raised no
# warning), so the normalized tensor is used as-is.
classes = list(range(5))
outs = [vis(images[2].to(device),
            None,
            postprocessing=image_net_postprocessing,
            target_class=c,
            guide=True) for c in classes]

processed_images = get_images(outs)

# Titles are paired positionally with classes 0..4.
subplot(processed_images,
        rows_titles=['france', 'greece', 'portugal', 'spain', 'switzerland'],
        nrows=1,
        ncols=5,
        parse=tensor2img)

# Interesting Regions

In [None]:
def gradcam2crop(cam, original_img, TR):
    """Black out every pixel of the image that lies outside the high-activation CAM region.

    Args:
        cam: CPU tensor holding a 2-D activation map at any resolution.
            NOTE(review): assumed 2-D, as required by cv2.resize — confirm.
        original_img: normalized image batch of shape (B, C, H, W). Only the
            first image is used.
        TR: threshold compared against the min-max normalized CAM, whose
            values lie in [0, 1].

    Returns:
        The first image after ``image_net_postprocessing`` and ``tensor2img``,
        with every pixel whose normalized activation is <= ``TR`` set to 0.
        The caller's ``cam`` tensor is left unmodified.
    """
    _, _, height, width = original_img.shape
    # Tensor.numpy() shares memory with the tensor. astype() makes a copy, so
    # the in-place normalization below cannot corrupt the caller's CAM.
    cam = cam.numpy().astype(np.float32)
    cam -= cam.min()
    peak = cam.max()
    if peak > 0:  # a constant map would otherwise become 0/0 = NaN
        cam /= peak

    # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
    cam = cv2.resize(cam, (width, height))
    mask = cam > TR

    original_img = tensor2img(image_net_postprocessing(original_img[0].squeeze()))

    crop = original_img.copy()
    crop[~mask] = 0
    return crop

In [None]:
TR = 0.4  # keep pixels whose min-max normalized activation exceeds this
vis = GradCam(model, device)

# Run Grad-CAM once to populate vis.cam; the returned images are not needed.
_ = vis(images[2], None, postprocessing=image_net_postprocessing)

cam_cpu = vis.cam.cpu()
crop = gradcam2crop(cam_cpu, images[2].cpu(), TR)
plt.imshow(crop)

# TorchPRISM
Author: Tomasz Szandała\
[Github](https://github.com/szandala/TorchPRISM)

In [None]:
from torchprism import PRISM

# Stack the per-image (1, C, H, W) tensors into one (N, C, H, W) batch.
input_batch = torch.stack([img.squeeze(0) for img in images])

# Only the forward pass needs no_grad; plotting does not.
with torch.no_grad():
    PRISM.prune_old_hooks(None)
    PRISM.register_hooks(model)
    output = model(input_batch)
    output = F.softmax(output, dim=-1)
    # Top-1 softmax probability and predicted class index for each image.
    percentages, output = torch.max(output, 1)
    prism_maps = PRISM.get_maps().permute(0, 2, 3, 1).detach().cpu().numpy()

columns = input_batch.shape[0]
# squeeze=False keeps `ax` 2-D even for a single image, so ax[row, column]
# indexing works for any number of columns.
fig, ax = plt.subplots(nrows=2, ncols=columns, squeeze=False)
input_batch = input_batch.permute(0, 2, 3, 1).detach().cpu().numpy()

for column in range(columns):
    class_name = class_names[output[column].item()]
    percentage = percentages[column]
    # Top row: the input image with its predicted class and confidence.
    ax[0, column].imshow(min_max_normalization(input_batch[column]))
    ax[0, column].set_title(f'{class_name}\n{percentage.item()*100:.2f}%', fontsize=22)
    ax[0, column].axis('off')
    # Bottom row: the matching PRISM map.
    ax[1, column].imshow(prism_maps[column])
    ax[1, column].axis('off')

fig.suptitle('PRISM\n', fontsize=30)
fig.tight_layout()
plt.show()