# Understanding Pytorch Hooks

> ref: https://www.kaggle.com/sironghuang/understanding-pytorch-hooks/notebook

In [1]:
import numpy as np

import torch
import torch.nn as nn

### 1. Toy example to understand Pytorch hooks

 ![](fig.png)

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 2)  # y = xA^T + b 
        self.s1 = nn.Sigmoid()
        self.fc2 = nn.Linear(2, 2)
        self.s2 = nn.Sigmoid()
        self.fc1.weight = torch.nn.Parameter(torch.Tensor([[0.15, 0.25],
                                                           [0.20, 0.30]]))
        self.fc1.bias = torch.nn.Parameter(torch.Tensor([0.35]))
        self.fc2.weight = torch.nn.Parameter(torch.Tensor([[0.40, 0.50],
                                                           [0.45, 0.55]]))
        self.fc2.bias = torch.nn.Parameter(torch.Tensor([0.6]))
        
    def forward(self, x):
        x= self.fc1(x)
        x = self.s1(x)
        x= self.fc2(x)
        x = self.s2(x)
        return x

net = Net()
print(net)

Net(
  (fc1): Linear(in_features=2, out_features=2, bias=True)
  (s1): Sigmoid()
  (fc2): Linear(in_features=2, out_features=2, bias=True)
  (s2): Sigmoid()
)


In [3]:
# parameters: weight and bias
weight1 = list(net.parameters())[0]
weight2 = list(net.parameters())[2]

print(list(net.parameters()))

[Parameter containing:
tensor([[0.1500, 0.2500],
        [0.2000, 0.3000]], requires_grad=True), Parameter containing:
tensor([0.3500], requires_grad=True), Parameter containing:
tensor([[0.4000, 0.5000],
        [0.4500, 0.5500]], requires_grad=True), Parameter containing:
tensor([0.6000], requires_grad=True)]


In [4]:
# input data
data = torch.Tensor([0.05, 0.1])
data = torch.unsqueeze(data, dim=0)
data.required_grad = True

target = torch.Tensor([0.01, 0.99])  # a dummy target, for example
target = torch.unsqueeze(target, dim=0)

print('data  :', data)
print('target:', target)

data  : tensor([[0.0500, 0.1000]])
target: tensor([[0.0100, 0.9900]])


In [5]:
out = net(data)
print('out :', out)

out : tensor([[0.7569, 0.7677]], grad_fn=<SigmoidBackward>)


In [6]:
class Hook():
    def __init__(self, module, backward=False):
        self.module = module
        if backward==False:
            self.hook = module.register_forward_hook(self.hook_fw)
        else:
            self.hook = module.register_backward_hook(self.hook_bw)
            
    def hook_fw(self, module, input, output):
        self.input = input
        self.output = output
        
    def hook_bw(self, module, grad_in, grad_out):
        self.input = grad_in
        self.output = grad_out
        
    def close(self):
        self.hook.remove()

In [7]:
list(net._modules.items())

[('fc1', Linear(in_features=2, out_features=2, bias=True)),
 ('s1', Sigmoid()),
 ('fc2', Linear(in_features=2, out_features=2, bias=True)),
 ('s2', Sigmoid())]

In [8]:
# register hooks on each layer
# layer[0] is the name, and layer[1] is the instance
hookF = [Hook(layer[1]) for layer in list(net._modules.items())]
hookB = [Hook(layer[1],backward=True) for layer in list(net._modules.items())]

out=net(data)
# backprop once to get the backward hook results
out.backward(torch.tensor([[1, 1]],dtype=torch.float), retain_graph=True)
#! loss.backward(retain_graph=True)  # doesn't work with backward hooks, 
#! since it's not a network layer but an aggregated result from the outputs of last layer vs target 

print('***'*3+'  Forward Hooks Inputs & Outputs  '+'***'*3)
for i, hook in enumerate(hookF):
    print(hook.module)
    print('layer {}, input : {}'.format(i+1, hook.input))
    print('layer {}, output: {}'.format(i+1, hook.output))
    print('---'*17)
print('\n')
print('***'*3+'  Backward Hooks Inputs & Outputs  '+'***'*3)
for i, hook in enumerate(hookB):
    print(hook.module)
    print('layer {}, input : {}'.format(i+1, hook.input))
    print('layer {}, output: {}'.format(i+1, hook.output))         
    print('---'*17)
print('\n')
print('***'*3+'      Gradients of parameters      '+'***'*3)
for name, p in net.named_parameters():
    print(name)
    print(p.grad)

*********  Forward Hooks Inputs & Outputs  *********
Linear(in_features=2, out_features=2, bias=True)
layer 1, input : (tensor([[0.0500, 0.1000]]),)
layer 1, output: tensor([[0.3825, 0.3900]], grad_fn=<ThAddmmBackward>)
---------------------------------------------------
Sigmoid()
layer 2, input : (tensor([[0.3825, 0.3900]], grad_fn=<ThAddmmBackward>),)
layer 2, output: tensor([[0.5945, 0.5963]], grad_fn=<SigmoidBackward>)
---------------------------------------------------
Linear(in_features=2, out_features=2, bias=True)
layer 3, input : (tensor([[0.5945, 0.5963]], grad_fn=<SigmoidBackward>),)
layer 3, output: tensor([[1.1359, 1.1955]], grad_fn=<ThAddmmBackward>)
---------------------------------------------------
Sigmoid()
layer 4, input : (tensor([[1.1359, 1.1955]], grad_fn=<ThAddmmBackward>),)
layer 4, output: tensor([[0.7569, 0.7677]], grad_fn=<SigmoidBackward>)
---------------------------------------------------


*********  Backward Hooks Inputs & Outputs  *********
Linear(in_fe

In [9]:
net.zero_grad()

### What is the input and output of forward and backward pass?

##### Things to notice:

1. Because backward pass runs from back to the start, it's ***parameter order*** should be reversed compared to the forward pass. Therefore, to be it clearer, I'll use a different naming convention below.
2. For forward pass, ***previous layer*** of layer 2 is layer1; for backward pass, previous layer of layer 2 is layer 3.
3. ***Model output*** is the output of last layer in forward pass.

##### `layer.register_backward_hook(module, input, output)`

* `input`: previous layer's output
* `output`: current layer's output

##### `layer.register_backward_hook(module, grad_out, grad_in)`

* `grad_in`: gradient of model output wrt. layer output       # from forward pass 
    * = a tensor that represent the error of each neuron in this layer (= gradient of model output wrt. layer output = how much it should be improved)
    * For the last layer: eg. [1, 1] <=> gradient of model output wrt. itself, which means calculate all gradients as normal
    * It can also be considered as a weight map: eg. [1, 0] turn off the second gradient; [2, 1] put double weight on first gradient etc.
* `grad_out`: `grad_in` * (gradient of layer output wrt. layer input)
    * = next layer's error(due to chain rule)
    
Check the print from the cell above to confirm and enhance your understanding!

In [53]:
# the 4th layer - sigmoid
forward_output = hookF[-1].output

grad_out = torch.tensor([[1., 1.]])  # sigmoid layer
grad_in = grad_out * (forward_output * (1 - forward_output))

print('grad_in :', grad_in)
print('grad_out:', grad_out)

grad_in : tensor([[0.1840, 0.1783]], grad_fn=<ThMulBackward>)
grad_out: tensor([[1., 1.]])


In [54]:
# the 3th layer - linear
grad_out = grad_in
grad_in = grad_out @ weight2

print('grad_in  :', grad_in)
print('grad_out :', grad_out)

input = hookF[-2].input[0]

# x = [[x_1, x_2]]
# grad_out = [[grad_o1, grad_o2]]
# grad_(A^T)_ij = grad_o_j * d o_j / d (A^T)_ij = grad_oj * x_i => grad_(A^T) = x^T @ grad_out
# grad_w = grad_A = grad_out^T @ x
grad_w2 = torch.transpose(grad_out, 0, 1) @ input
grad_b2 = torch.sum(grad_out)

print('grad_w2  :', grad_w2)
print('grad_b2  :', grad_b2)

grad_in  : tensor([[0.1538, 0.1901]], grad_fn=<MmBackward>)
grad_out : tensor([[0.1840, 0.1783]], grad_fn=<ThMulBackward>)
grad_w2  : tensor([[0.1094, 0.1097],
        [0.1060, 0.1063]], grad_fn=<MmBackward>)
grad_b2  : tensor(0.3623, grad_fn=<SumBackward0>)


In [55]:
# the 2nd layer - sigmoid
forward_output = hookF[-3].output

grad_out = grad_in
grad_in = grad_out * (forward_output * (1 - forward_output))

print('grad_in :', grad_in)
print('grad_out:', grad_out)

grad_in : tensor([[0.0371, 0.0458]], grad_fn=<ThMulBackward>)
grad_out: tensor([[0.1538, 0.1901]], grad_fn=<MmBackward>)


In [56]:
# the 1st layer - linear
grad_out = grad_in
grad_in = grad_out @ weight1

print('grad_in  :', grad_in)
print('grad_out :', grad_out)

input = hookF[0].input[0]

grad_w2 = torch.transpose(grad_out, 0, 1) @ input
grad_b2 = torch.sum(grad_out)

print('grad_w2  :', grad_w2)
print('grad_b2  :', grad_b2)

grad_in  : tensor([[0.0147, 0.0230]], grad_fn=<MmBackward>)
grad_out : tensor([[0.0371, 0.0458]], grad_fn=<ThMulBackward>)
tensor([[0.0500, 0.1000]])
grad_w2  : tensor([[0.0019, 0.0037],
        [0.0023, 0.0046]], grad_fn=<MmBackward>)
grad_b2  : tensor(0.0828, grad_fn=<SumBackward0>)


### Modify gradients with hooks

Hook function doesn't change gradients by default

But if ***return*** is called, the returned value will be the gradient output

In [9]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1)
        self.relu2 = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        return x

net = ConvNet()
print(net)

ConvNet(
  (conv1): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
  (relu1): ReLU(inplace)
  (conv2): Conv2d(2, 1, kernel_size=(1, 1), stride=(1, 1))
  (relu2): ReLU(inplace)
)


In [10]:
data = torch.Tensor([[[[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]]]])
data.required_grad = True
out = net(data)

print(data)
print(data.size())
print(out)
print(out.size())

tensor([[[[1., 2., 3.],
          [4., 5., 6.],
          [7., 8., 9.]]]])
torch.Size([1, 1, 3, 3])
tensor([[[[0.6557]]]], grad_fn=<ThresholdBackward1>)
torch.Size([1, 1, 1, 1])


In [11]:
hookF = [Hook(layer[1]) for layer in list(net._modules.items())]
hookB = [Hook(layer[1],backward=True) for layer in list(net._modules.items())]

out = net(data)
out.backward(torch.tensor([[[[1]]]], dtype=torch.float))

print('***'*3+'  Forward Hooks Inputs & Outputs  '+'***'*3)
for i, hook in enumerate(hookF):
    print(hook.module)
    print('layer {}, input : {}'.format(i+1, hook.input))
    print('layer {}, output: {}'.format(i+1, hook.output))
    print('---'*17)
print('\n')
print('***'*3+'  Backward Hooks Inputs & Outputs  '+'***'*3)
for i, hook in enumerate(hookB):
    print(hook.module)
    print('layer {}, input : {}'.format(i+1, hook.input))
    print('layer {}, output: {}'.format(i+1, hook.output))         
    print('---'*17)
print('\n')
print('***'*3+'      Gradients of parameters      '+'***'*3)
for name, p in net.named_parameters():
    print(name)
    print(p.grad)

*********  Forward Hooks Inputs & Outputs  *********
Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
layer 1, input : (tensor([[[[1., 2., 3.],
          [4., 5., 6.],
          [7., 8., 9.]]]]),)
layer 1, output: tensor([[[[0.]],

         [[0.]]]], grad_fn=<ThresholdBackward1>)
---------------------------------------------------
ReLU(inplace)
layer 2, input : (tensor([[[[0.]],

         [[0.]]]], grad_fn=<ThresholdBackward1>),)
layer 2, output: tensor([[[[0.]],

         [[0.]]]], grad_fn=<ThresholdBackward1>)
---------------------------------------------------
Conv2d(2, 1, kernel_size=(1, 1), stride=(1, 1))
layer 3, input : (tensor([[[[0.]],

         [[0.]]]], grad_fn=<ThresholdBackward1>),)
layer 3, output: tensor([[[[0.6557]]]], grad_fn=<ThresholdBackward1>)
---------------------------------------------------
ReLU(inplace)
layer 4, input : (tensor([[[[0.6557]]]], grad_fn=<ThresholdBackward1>),)
layer 4, output: tensor([[[[0.6557]]]], grad_fn=<ThresholdBackward1>)
-----------------

In [73]:
list(net._modules.items())[0][1]

Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))

In [75]:
hookB[0].input[0]

In [None]:
def get_first_layer(model):
    first_layer = list(model.children())[0]
    
    while True:
        if isinstance(first_layer, nn.Conv2d):
            return first_layer
        else:
            if not first_layer:
                raise ValueError('The first layer is not an `nn.Conv2d` object.')
            first_layer = list(first_layer.children())[0]

In [None]:
class GuidedBackpropagation:
    def __init__(self, model):
        self.model = model
        self.gradients = None
        self.hook_handlers = []
        self.relu_forward_outputs = []
        
        self.model.eval()
        self.register_hooks()
        
    def register_hooks(self):
        def first_layer_hook_fn(module, grad_in, grad_out):
            self.gradient = grad_in[0]
            
        def relu_forward_hook_fn(module, ten_in, ten_out):
            self.relu_forward_outputs.append(ten_out)
            
        def relu_backward_hook_fn(module, grad_in, grad_out):
            assert len(grad_in) == 1
            
            features_map = self.relu_forward_outputs[-1]
            features_map[features_map > 0] = 1  # in place  # it works
            
            grad_in = torch.clamp(input=grad_in[0], min=0.0)  # not in place  # it works
            grad_in = grad_in * features_map
            
            del self.relu_forward_outputs[-1]
            
            return (grad_in,)
        
        first_layer = get_first_layer(self.model)
        handler = first_layer.register_backward_hook(first_layer_hook_fn)
        self.hook_handlers.append(handler)
        
        # The following code can only work in a sequential model
        for module in self.model.children():
            if isinstance(module, nn.ReLU):
                handler = module.register_forward_hook(relu_forward_hook_fn)
                self.hook_handlers.append(handler)
                handler = module.register_backward_hook(relu_backward_hook_fn)
                self.hook_handlers.append(handler)
            else:
                
        
    def remove(self):
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []
    
    def generate_guided_gradient(self, img, class_idx):
        assert isinstance(img, torch.Tensor)
        assert img.requires_grad
        
        output = self.model(img)
        self.model.zero_grad()
        
        onehot_target = torch.zeros(output.size(), dtype=torch.float)
        onehot_target[0][class_idx] = 1
        output.backward(gradient=onehot_target)
        
        guided_gradient = self.gradient.data.numpy()[0]
        
        self.remove()
        
        return guided_gradient