In [1]:
from IPython.display import display, HTML
# Make the notebook fill the whole window by overriding the width of `.container`.
_full_width_css = "<style>.container { width:100% !important; }</style>"
display(HTML(_full_width_css))

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import saliency.core as saliency

# Local import
from thumbnail_classification import Classifier

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Target (height, width) for every image; cell 4 also uses it to size the classifier input.
img_size = (32, 32)

# Preprocessing: PIL image -> tensor, then resize to `img_size`.
# NOTE(review): normalization (transforms.Normalize([0.5], [0.5])) is not applied here —
# confirm this matches the preprocessing the saved model was trained with.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(img_size)]
)

In [4]:
# Load every image under ./imgs/ (one sub-folder per class) with the preprocessing above.
dataset = datasets.ImageFolder("./imgs/", transform = transform)

# Rebuild the classifier for 3-channel `img_size` inputs and restore its trained weights.
model = Classifier(device = torch.device("cpu"), image_size = (3, *img_size), n_classes = len(dataset.classes))
# map_location keeps this working on CPU-only machines even if the checkpoint was saved
# from a GPU. torch.load unpickles the file, so only load checkpoints from trusted sources.
model.load_state_dict(torch.load("./model.pt", map_location = torch.device("cpu")))
# Inference mode: any dropout / batch-norm layers in Classifier behave deterministically
# while gradients and saliency maps are computed below.
model.eval()

# Class name -> label index, exactly as assigned by ImageFolder (sorted folder names).
class_dict = dict(dataset.class_to_idx)
# All images stacked into a single batch tensor; the labels are discarded.
imgs = torch.stack([x for x, _ in dataset])

In [5]:
def call_model_function(images, class_idx = None, expected_keys = None):
    """Run the global `model` on `images` and return gradients for saliency maps.

    Args:
        images: batch fed directly to `model`. For the gradient branch it must be a
            float tensor with ``requires_grad=True``. NOTE(review): the
            ``movedim(..., 1, 3)`` below implies a 4-D NCHW batch; confirm at the call site.
        class_idx: index of the class whose softmax probability is differentiated.
            Required; ``None`` raises ValueError.
        expected_keys: collection of saliency output keys requested by the caller.
            ``None`` is treated as an empty collection.

    Returns:
        If ``saliency.base.INPUT_OUTPUT_GRADIENTS`` is in `expected_keys`: a numpy
        array holding d softmax(model(images))[:, class_idx] / d images, with axis 1
        moved to the last position (NCHW -> NHWC for a 4-D batch).
        Otherwise: ``model.grads`` after back-propagating a one-hot gradient through
        the softmax output. NOTE(review): `grads` is an attribute of the project-local
        Classifier (presumably filled by a hook); confirm it exists.

    Raises:
        ValueError: if `class_idx` is None. Indexing with None would add a new axis
            instead of selecting a class, silently producing meaningless gradients.

    NOTE(review): the `saliency` library's documented contract calls this function
    with a ``call_model_args`` keyword and expects a dict keyed by the requested keys.
    This function takes `class_idx` and returns a bare value, so confirm it is adapted
    (e.g. via a wrapper) before passing it to saliency directly.
    """
    if class_idx is None:
        raise ValueError("class_idx must be given to select the target class")
    if expected_keys is None:
        expected_keys = ()

    probs = F.softmax(model(images), dim = 1)

    if saliency.base.INPUT_OUTPUT_GRADIENTS in expected_keys:
        target_probs = probs[:, class_idx]
        # grad_outputs must have the shape of the tensor being differentiated
        # (one value per sample), not of the full probability matrix.
        grads = torch.autograd.grad(target_probs, images, grad_outputs = torch.ones_like(target_probs))
        # Move the channel axis from dim 1 to dim 3 (NCHW -> NHWC).
        return torch.movedim(grads[0], 1, 3).detach().numpy()

    # Back-propagate only through the target class's probability.
    one_hot = torch.zeros_like(probs)
    one_hot[:, class_idx] = 1
    model.zero_grad()
    probs.backward(gradient = one_hot, retain_graph = True)
    return model.grads