## 05-03 Hook Functions

In [14]:
from pathlib import Path

import numpy as np 
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.models as models
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

### 1. Hook Methods

Tensor hooks (`Tensor.register_hook`)

In [3]:
# Non-leaf tensors do not keep .grad after backward(), so a hook is used
# to capture the gradient that flows into `a`.
a_grad = []


def grad_hook(grad_of_a):
    """Stash the incoming gradient; returning None leaves it unchanged."""
    a_grad.append(grad_of_a)


# Leaf tensors: their gradients are stored in .grad.
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

# Graph: y = (w + x) * (w + 1)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

handle = a.register_hook(grad_hook)
y.backward()

# a, b and y are intermediate results, so only w and x report a gradient here.
print('gradient:', w.grad, x.grad, a.grad, b.grad, y.grad)
print('a_grad[0]:', a_grad[0])
handle.remove()

gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])


In [7]:
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)


def grad_hook(grad):
    """Replace the gradient reaching `w` with a scaled copy (x2, then x3).

    Tensor hooks must not modify `grad` in place; returning a new tensor is
    the supported way to substitute the gradient that gets accumulated.
    """
    doubled = grad * 2
    return doubled * 3


handle = w.register_hook(grad_hook)

y.backward()

# dy/dw = a + b = 3 + 2 = 5; the hook scales it by 6 -> 30.
print('w.grad:', w.grad)
handle.remove()

w.grad: tensor([30.])


Module hooks (`register_forward_hook`, `register_forward_pre_hook`, `register_backward_hook`)

In [10]:
class Net(nn.Module):
    """Tiny conv -> max-pool network used to demonstrate module hooks."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x


def forward_hook(module, data_input, data_output):
    """Record the module's output (feature map) and its input tuple."""
    fmap_block.append(data_output)
    input_block.append(data_input)


def forward_pre_hook(module, data_input):
    """Print the input tuple before the module's forward() runs."""
    print('forward_pre_hook input:{}'.format(data_input))


def backward_hook(module, grad_input, grad_output):
    """Print the gradients reported by the legacy module backward hook."""
    print('backward hook input: {}'.format(grad_input))
    print('backward hook output: {}'.format(grad_output))


net = Net()
# Give conv1 known weights so the outputs can be checked by hand.
# Mutating parameters in place is done under no_grad() so autograd ignores it.
with torch.no_grad():
    net.conv1.weight[0].fill_(1)
    net.conv1.weight[1].fill_(2)
    net.conv1.bias.zero_()

fmap_block, input_block = [], []
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
# NOTE: register_backward_hook is deprecated in favour of
# register_full_backward_hook. The legacy hook reports gradients w.r.t. the
# conv autograd node's inputs (input, weight, bias), not the module's forward
# inputs. That is why the printed grad_input is (None, weight_grad, bias_grad).
net.conv1.register_backward_hook(backward_hook)

fake_img = torch.ones(1, 1, 4, 4)
output = net(fake_img)

loss_func = nn.L1Loss()
torch.manual_seed(0)  # reproducible random target
target = torch.randn_like(output)
# nn.L1Loss is called as loss_func(input, target).
loss = loss_func(output, target)
loss.backward()

print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
backward hook input: (None, tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output: (tensor([[[[0.5000, 0.0000],
          [0.0000, 0.0000]],

         [[0.5000, 0.0000],
          [0.0000, 0.0000]]]]),)
output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],

         [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)

feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9.,  9.],
          [ 9.,  9.]],

         [[18., 18.],
          [18., 18.]]]], grad_fn=<MkldnnConvolutionBackward>)

input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
        

### 2. Feature Map (FMAP) Visualization

In [15]:
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.49139968, 0.48215827, 0.44653124],
        [0.24703233, 0.24348505, 0.26158768]
    )
])

image_path = Path('./image/lena.png')
img_pil = Image.open(image_path).convert('RGB')
# Add a batch dimension -> (1, C, H, W); out-of-place so the cell is safe to re-run.
img_tensor = image_transforms(img_pil).unsqueeze(0)

# `pretrained=` is deprecated in newer torchvision (use `weights=`); kept here
# for compatibility with the torchvision version this notebook was run on.
alexnet = models.alexnet(pretrained=True)
alexnet.eval()  # inference mode: disables dropout in the classifier head


def hook_func(m, i, o):
    """Forward hook: append a conv layer's output to fmap_dict under its weight-shape key.

    NOTE: torchvision's AlexNet follows each conv with ReLU(inplace=True), so the
    stored tensor is later overwritten in place and ends up holding post-ReLU
    activations. Store `o.clone()` instead if raw conv outputs are wanted.
    """
    fmap_dict[str(m.weight.shape)].append(o)


# Keyed by conv weight shape; two convs with identical weight shapes would share
# one list, and only the first map per key is plotted below.
fmap_dict = {}
hook_handles = []
for name, sub_module in alexnet.named_modules():
    if isinstance(sub_module, nn.Conv2d):
        fmap_dict.setdefault(str(sub_module.weight.shape), [])
        # Register on the module object itself rather than re-resolving it from
        # `name`, which would assume every name has exactly one dot.
        hook_handles.append(sub_module.register_forward_hook(hook_func))

with torch.no_grad():
    output = alexnet(img_tensor)

# The hooks are only needed for this single forward pass.
for handle in hook_handles:
    handle.remove()

writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")
for layer_name, fmap_list in fmap_dict.items():
    # (1, C, H, W) -> (C, 1, H, W): each channel becomes one grayscale grid tile.
    # Out-of-place so fmap_dict is unchanged and this loop can be re-run safely.
    fmap = fmap_list[0].transpose(0, 1)

    nrow = int(np.sqrt(fmap.shape[0]))
    fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
    writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)
# Flush pending events to disk and release the file handle.
writer.close()