In [None]:
import os
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models, transforms

In [None]:
# visualization from:
# https://github.com/FrancescoSaverioZuppichini/A-journey-into-Convolutional-Neural-Network-visualization-
from cnn_visualizations.utils import *
from cnn_visualizations.visualisation.core import *

In [None]:
def get_abs_path(n_parent: int = 0):
    """Return the absolute, symlink-resolved directory ``n_parent`` levels above the CWD.

    ``n_parent=0`` (the default) gives the current working directory itself;
    non-positive values all behave like 0.
    """
    hops = ['..'] * n_parent
    return Path(*hops).resolve()

In [None]:
# Project root: one directory above the notebook's working directory.
path = get_abs_path(1)
model_path = path / 'models' / 'deep_geo_guessr.pt'
images_path = path / 'visualization_examples'
# Path.glob yields entries in filesystem order, which is not guaranteed to be
# stable across machines; sorting makes images[0] / images[1] used later
# refer to the same files everywhere.
images_paths = sorted(str(p) for p in images_path.glob('**/*.png'))
# Fail here with a clear message rather than with an IndexError further down.
assert images_paths, f'No .png files found under {images_path}'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

In [None]:
# Fixed 224x224 resize, tensor conversion, then per-channel normalisation with
# the mean/std values torchvision documents for its ImageNet-pretrained models.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

images = []
for image_path in images_paths:
    # `with` closes the underlying file handle. convert('RGB') guarantees the
    # three channels Normalize expects (PNGs are often RGBA, paletted or greyscale).
    with Image.open(image_path) as pil_image:
        rgb_image = pil_image.convert('RGB')
    image = data_transform(rgb_image).to(device)
    # Add a leading batch dimension of size 1 for the model.
    image = image.unsqueeze(0)
    images.append(image)

In [None]:
class CountryClassificator(nn.Module):
    """ResNet-18 backbone with its final layer replaced by a ``num_classes``-way
    linear head followed by softmax, so ``forward`` returns class probabilities.

    The head is kept as ``nn.Sequential(Linear, Softmax)`` so parameter names
    (``model.fc.0.weight`` / ``model.fc.0.bias``) stay compatible with saved
    state dicts.

    Args:
        num_classes: number of output classes of the linear head.
    """

    def __init__(self, num_classes):
        super(CountryClassificator, self).__init__()

        # pretrained=True initialises the backbone from torchvision's ImageNet
        # weights (downloaded on first use).
        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Sequential(
            nn.Linear(self.model.fc.in_features, num_classes),
            # Explicit dim=1 is the class axis of the (batch, num_classes)
            # output. It is the same dim the implicit choice picked for 2-D
            # input, but without the deprecation UserWarning.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Run the backbone and head; returns per-class probabilities for batch ``x``."""
        return self.model(x)


NUM_CLASSES = 5  # must match the head size of the checkpoint, else load_state_dict fails

model = CountryClassificator(NUM_CLASSES)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and
# vice versa). NOTE(review): torch.load unpickles the file, so only load
# trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
model.to(device)
print('Model loaded')

In [None]:
# This cell traces with images[0] and plots images[1]; make the requirement
# explicit instead of failing with a bare IndexError.
assert len(images) >= 2, f'Expected at least 2 example images, found {len(images)}'

# NOTE(review): module2traced (from cnn_visualizations) presumably returns the
# model's layers in execution order, traced using this example input — confirm.
model_traced = module2traced(model, images[0])
plt.rcParams['figure.figsize'] = (14, 14)
vis = Weights(model, device)
# NOTE(review): index 1 presumably selects the first layer after the stem input
# entry — confirm against module2traced's ordering.
layer = model_traced[1]
run_vis_plot(vis, images[1], layer, ncols=4, nrows=4)