In [None]:
import fastai
from fastai import *
from fastai.vision.core import *
from fastai.vision.data import *
from fastai.vision.learner import *
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
# The import cell never binds the name `fastbook`, so import it here before
# calling its notebook setup hook (otherwise this cell raises NameError).
import fastbook

fastbook.setup_book()

In [None]:
# Root of the extracted Oxford-IIIT Pets archive (fetched by fastai's untar_data), narrowed to its image folder.
path = untar_data(URLs.PETS).joinpath('images')

In [None]:
def is_cat(x):
    """Label function: True when the file name `x` starts with an uppercase character.

    Raises IndexError for an empty string, since the first character is indexed directly.
    NOTE(review): relies on the dataset's naming convention (presumably cat breeds are
    capitalised in Oxford-IIIT Pets). Confirm against the data.
    """
    leading_char = x[0]
    return leading_char.isupper()

In [None]:
# Label every image with is_cat applied to its file name, hold out a seeded 20%
# validation split, and resize each item to 224 before batching.
pet_files = get_image_files(path)
dls = ImageDataLoaders.from_name_func(
    path,
    pet_files,
    valid_pct=0.2,
    seed=21,
    label_func=is_cat,
    item_tfms=Resize(224),
)

In [None]:
# ResNet-34 backbone with a head sized for `dls`; error_rate is reported as the metric.
learn = vision_learner(dls, resnet34, metrics=[error_rate])

In [None]:
# Transfer-learn for a single epoch using fastai's fine_tune schedule.
learn.fine_tune(epochs=1)

In [None]:
# Wrap the bundled sample cat photo in a one-item test DataLoader so it goes through
# the DataLoaders' transforms. The single batch is a 1-tuple, unpacked into `x`.
img = PILImage.create(image_cat())
cat_dl = dls.test_dl([img])
(x,) = first(cat_dl)

In [None]:
class Hook:
    """Context manager that records the output of `module` on each forward pass.

    The latest output is kept, detached from autograd and cloned, in `self.output`.
    The torch hook handle lives in `self.hook` and is removed on `__exit__`.
    """

    def __init__(self, module):
        handle = module.register_forward_hook(self.hook_fn)
        self.hook = handle

    def hook_fn(self, mod, inputs, out):
        # detach() drops the autograd graph; clone() keeps a private copy of the data.
        snapshot = out.detach()
        self.output = snapshot.clone()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Returns None, so any exception raised inside the `with` still propagates.
        self.hook.remove()

In [None]:
# Capture learn.model[0]'s output for one inference pass (no autograd graph needed).
# The `with` block guarantees the forward hook is removed even if the forward raises,
# so re-running this cell never stacks extra hooks on learn.model[0].
# `hook.output` stays readable after the block exits.
with Hook(learn.model[0]) as hook:
    with torch.no_grad():
        output = learn.model.eval()(x)

In [None]:
# Keep the activation of the single image in the batch, then detach the forward hook
# (removing it directly through its handle, exactly what Hook.__exit__ does).
act = hook.output[0]
hook.hook.remove()

In [None]:
# Class probabilities for the sample image, followed by the label order they refer to.
probs = F.softmax(output, dim=-1)
print(probs)
print(dls.vocab)

In [None]:
# Class Activation Map: combine the channels of `act` with the head's final Linear
# weights, giving one spatial map per class ('ck,kij->cij').
# NOTE(review): assumes act is (channels, H, W) and that the final Linear has as many
# input features as act has channels. Confirm if the head changes.
cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)
img_dec = dls.train.decode((x,))[0][0]
# Stretch the low-resolution map over the whole decoded image. Taking the extent from
# the image (assumed channel-first, ..., H, W) instead of hard-coding 224 keeps the
# overlay aligned if the Resize size changes.
img_h, img_w = img_dec.shape[-2:]
_, ax = plt.subplots()
img_dec.show(ctx=ax)
ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0, img_w, img_h, 0),
          interpolation='bilinear', cmap='magma')
ax.set_title(f"CAM for class {dls.vocab[1]}")
plt.show()

In [None]:
class GradHook:
    """Context manager that records the gradient w.r.t. `module`'s output during backward.

    After a backward pass flows through `module`, `self.grad_output` holds a detached
    clone of the first entry of the hook's `grad_output` tuple. The hook handle is
    kept in `self.hook` and removed on `__exit__`.
    """

    def __init__(self, module):
        # `register_backward_hook` is deprecated: on modules made of several autograd
        # nodes (e.g. an nn.Sequential) it attaches only to the last node.
        # Prefer the full backward hook and fall back on torch versions that lack it.
        register = getattr(module, "register_full_backward_hook", None)
        if register is None:
            register = module.register_backward_hook
        self.hook = register(self.hook_fn)

    def hook_fn(self, module, grad_input, grad_output):
        self.grad_output = grad_output[0].detach().clone()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.hook.remove()

In [None]:
# Grad-CAM on learn.model[0]: record its output (forward hook) and the gradient of the
# class-`cls` score w.r.t. that output (backward hook), keeping the batch dimension.
cls = 1  # index into dls.vocab
# Send the input to wherever the model's weights live rather than hard-coding
# .cuda(), which crashes on CPU-only machines.
model_device = next(learn.model.parameters()).device
with GradHook(learn.model[0]) as grad_hook:
    with Hook(learn.model[0]) as fwd_hook:
        output = learn.model.eval()(x.to(model_device))
        act = fwd_hook.output
    # backward() must run while GradHook is still registered.
    output[0, cls].backward()
    grad = grad_hook.grad_output

In [None]:
# Grad-CAM: one weight per channel (gradient averaged over the spatial dims), then a
# weighted sum over the activation channels gives a 2-D heat map.
# `grad` and `act` come from a batch of one image, so drop the batch dimension first.
# Without [0], mean(dim=[1, 2]) averages over channels and height instead of height
# and width, and the sum leaves a 3-D tensor that imshow cannot plot.
w = grad[0].mean(dim=[1, 2], keepdim=True)
cam_map = (w * act[0]).sum(0)

In [None]:
# Overlay the Grad-CAM map from learn.model[0] on the decoded input image. The extent
# comes from the image (assumed channel-first, ..., H, W) rather than a hard-coded 224.
img_h, img_w = img_dec.shape[-2:]
_, ax = plt.subplots()
img_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0, img_w, img_h, 0),
          interpolation='bilinear', cmap='magma')
ax.set_title(f"Grad-CAM at learn.model[0] for class {dls.vocab[cls]}")
plt.show()

In [None]:
# Repeat the Grad-CAM capture on an earlier stage: the second-to-last child of the body.
# NOTE(review): presumably resnet34's layer3. Confirm if the architecture changes.
target = learn.model[0][-2]
# Send the input to the model's device instead of hard-coding .cuda().
model_device = next(learn.model.parameters()).device
with GradHook(target) as grad_hook:
    with Hook(target) as fwd_hook:
        output = learn.model.eval()(x.to(model_device))
        act = fwd_hook.output
    # backward() must run while GradHook is still registered.
    output[0, cls].backward()
    grad = grad_hook.grad_output

In [None]:
# Grad-CAM weights for the earlier layer: per-channel spatial mean of the gradient,
# then a weighted sum of the activation channels.
# Drop the batch dimension (batch of one) first. Otherwise mean(dim=[1, 2]) averages
# over channels and height and the result is not a 2-D map.
w = grad[0].mean(dim=[1, 2], keepdim=True)
cam_map = (w * act[0]).sum(0)

In [None]:
# Overlay the earlier-layer Grad-CAM map on the decoded input image. The extent comes
# from the image (assumed channel-first, ..., H, W) rather than a hard-coded 224.
img_h, img_w = img_dec.shape[-2:]
_, ax = plt.subplots()
img_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0, img_w, img_h, 0),
          interpolation='bilinear', cmap='magma')
ax.set_title(f"Grad-CAM at learn.model[0][-2] for class {dls.vocab[cls]}")
plt.show()
