# **PyTorch Hook**

Hooks let you run your own code during backpropagation, for example to inspect or modify gradients.

You can register a hook on a Tensor or an nn.Module. A hook is basically a function that is executed when either forward or backward is called.


When I say forward, I don't mean the forward of an nn.Module. Here, forward means the forward function of the **torch.autograd.Function** object that is the grad_fn of a Tensor.

1.   Forward Hook
2.   Backward Hook

### Reference
- [computation graph in pytorch](https://blog.paperspace.com/pytorch-101-understanding-graphs-and-automatic-differentiation/)

- [Pytorch-Hook](https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/)




## Hooks for Tensor

There is no forward hook for a tensor.

```python
# backward hook for tensor
hook(grad) -> Tensor or None
```

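grad is basically the value contained in the grad attribute of the tensor after backward is called. If the hook returns a tensor, that tensor replaces the gradient; if it returns None, the gradient is left unchanged.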


In [0]:
import torch 
a = torch.ones(5)
a.requires_grad = True

In [2]:
# A tensor created directly by the user (a leaf tensor) has no grad_fn.
print(a)

tensor([1., 1., 1., 1., 1.], requires_grad=True)


In [4]:
b = 2*a
print(b)

tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)


In [5]:
b.retain_grad()  # b is a non-leaf tensor, so its grad would be discarded otherwise
print(b)

tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)


In [0]:
# https://discuss.pytorch.org/t/what-does-the-backward-function-do/9944
# backward computes dvalue/dx

c = b.mean()
c.backward() 

In [8]:
print(a.grad, b.grad)

tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000]) tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
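
These values can be checked by hand. With $b_i = 2a_i$ and $c$ the mean of the five entries of $b$, the chain rule gives:

$$
c = \frac{1}{5}\sum_{i=1}^{5} b_i, \qquad
\frac{\partial c}{\partial b_i} = \frac{1}{5} = 0.2, \qquad
\frac{\partial c}{\partial a_i} = \frac{\partial c}{\partial b_i}\cdot\frac{\partial b_i}{\partial a_i} = \frac{1}{5}\cdot 2 = 0.4
$$

This matches the printed `b.grad` (0.2) and `a.grad` (0.4).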


In [10]:
# Redo with hook.
a = torch.ones(5)

a.requires_grad = True

b = 2*a

b.retain_grad()

b.register_hook(lambda x: print(x))  

b.mean().backward() 
print(a.grad, b.grad)

tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000]) None
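
A side note that the cell above does not show: `register_hook` returns a handle, and calling `remove()` on it detaches the hook again. The following is only a minimal sketch; the `seen` list is a hypothetical helper used to record what the hook receives.

```python
import torch

a = torch.ones(5, requires_grad=True)
b = 2 * a

seen = []  # hypothetical container, only used to record the gradients the hook receives

# list.append returns None, so this hook leaves the gradient unchanged
handle = b.register_hook(lambda grad: seen.append(grad.clone()))

b.mean().backward(retain_graph=True)  # the hook fires once
handle.remove()                       # detach the hook
b.mean().backward()                   # the hook no longer fires

print(len(seen))  # number of times the hook ran
print(a.grad)     # gradients from both backward calls accumulate in a.grad
```

Removing hooks you no longer need is useful in notebooks, where re-running a cell would otherwise register the same hook again.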


In [11]:
# You can modify the gradient value.
a = torch.ones(5)

a.requires_grad = True
b = 2*a

b.retain_grad()


b.mean().backward() 


print(a.grad, b.grad)

b.grad *= 2

print(a.grad, b.grad)       # a's gradient is unchanged; it would need to be updated manually

tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000]) tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000]) tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])
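
Editing `b.grad` after `backward()` does not reach `a.grad`, as the output above shows. A tensor hook that returns a new tensor changes the gradient while it is being backpropagated, so the change carries through to `a`. A minimal sketch of that approach, reusing the same toy tensors:

```python
import torch

a = torch.ones(5, requires_grad=True)
b = 2 * a

# Returning a tensor from the hook replaces the gradient flowing through b
b.register_hook(lambda grad: grad * 2)

b.mean().backward()

# d(mean)/db is 0.2 per element, the hook doubles it to 0.4,
# and b = 2 * a doubles it again, so every entry of a.grad is 0.8
print(a.grad)
```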


## Hooks for nn.Module objects

- Forward Hook
```python
hook(module, input, output) -> None
```

- Backward Hook
```python
hook(module, grad_input, grad_output) -> Tensor or None
```

- reference: https://tutorials.pytorch.kr/beginner/former_torchies/nn_tutorial.html
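
To make the two module hook types concrete, here is a minimal sketch on a single `nn.Linear` layer. The `captured` dictionary and the hook function names are hypothetical, chosen only for illustration. It uses `register_full_backward_hook`, which recent PyTorch releases recommend over the legacy `register_backward_hook`.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
captured = {}  # hypothetical store for the shapes seen by each hook

def forward_hook(module, inputs, output):
    # inputs is a tuple of the positional arguments passed to forward
    captured["input_shape"] = tuple(inputs[0].shape)
    captured["output_shape"] = tuple(output.shape)

def backward_hook(module, grad_input, grad_output):
    # grad_input / grad_output are tuples of gradients w.r.t. the module's inputs / outputs
    captured["grad_input_shape"] = tuple(grad_input[0].shape)
    captured["grad_output_shape"] = tuple(grad_output[0].shape)

fwd_handle = layer.register_forward_hook(forward_hook)
bwd_handle = layer.register_full_backward_hook(backward_hook)

# requires_grad=True so that a gradient w.r.t. the input exists
x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()

print(captured)

# Detach the hooks once they are no longer needed
fwd_handle.remove()
bwd_handle.remove()
```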


In [36]:
import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
   
  
  def forward(self, x):
    x = self.relu(self.conv(x))
    return self.fc1(self.flatten(x))
  

net = myNet()

def hook_fn(m, i, o):
  print(m)
  print("------------Input Grad------------")

  for grad in i:
    try:
      print(grad.shape)
    except AttributeError: 
      print ("None found for Gradient")

  print("------------Output Grad------------")
  for grad in o:  
    try:
      print(grad.shape)
    except AttributeError: 
      print ("None found for Gradient")
  print("\n")
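# register_backward_hook is the legacy API; recent PyTorch releases deprecate it
# in favour of register_full_backward_hook. The output below comes from the legacy hook.
net.conv.register_backward_hook(hook_fn)
net.fc1.register_backward_hook(hook_fn)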
inp = torch.randn(1,3,8,8)
out = net(inp)

(1 - out.mean()).backward()

Linear(in_features=160, out_features=5, bias=True)
------------Input Grad------------
torch.Size([5])
torch.Size([5])
------------Output Grad------------
torch.Size([5])


Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))
------------Input Grad------------
None found for Gradient
torch.Size([10, 3, 2, 2])
torch.Size([10])
------------Output Grad------------
torch.Size([1, 10, 4, 4])




In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MNISTConvNet(nn.Module):

    def __init__(self):
        # Initialize all modules here;
        # they can later be accessed by the names declared here.
        super(MNISTConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    # The forward function defines the structure of the network.
    # Here it takes a single input, but you can change it to accept more if needed.
    def forward(self, input):
        x = self.pool1(F.relu(self.conv1(input)))
        x = self.pool2(F.relu(self.conv2(x)))

        # You can use any Python code when defining the model structure.
        # All of it will be handled correctly by autograd.
        # if x.gt(0) > x.numel() / 2:
        #      ...
        #
        # You can even reuse the same module or call it in a loop.
        # Modules no longer hold transient state,
        # so they can be used several times during the forward pass.
        # while x.norm(2) < 10:
        #    x = self.conv1(x)

        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x
net = MNISTConvNet()

In [44]:
def printnorm(self, input, output):
    # input is a tuple of packed inputs
    # output is a Tensor; output.data is the Tensor we are interested in
    print('Inside ' + self.__class__.__name__ + ' forward')
    print('')
    print('input: ', type(input))
    print('input[0]: ', type(input[0]))
    print('output: ', type(output))
    print('')
    print('input size:', input[0].size())
    print('output size:', output.data.size())
    print('output norm:', output.data.norm())


net.conv2.register_forward_hook(printnorm)
inp = torch.randn(1, 1, 28, 28)
out = net(inp)

Inside Conv2d forward

input:  <class 'tuple'>
input[0]:  <class 'torch.Tensor'>
output:  <class 'torch.Tensor'>

input size: torch.Size([1, 10, 12, 12])
output size: torch.Size([1, 20, 8, 8])
output norm: tensor(16.1047)
Inside Conv2d forward

input:  <class 'tuple'>
input[0]:  <class 'torch.Tensor'>
output:  <class 'torch.Tensor'>

input size: torch.Size([1, 10, 12, 12])
output size: torch.Size([1, 20, 8, 8])
output norm: tensor(16.1047)


In [43]:
def printgradnorm(self, grad_input, grad_output):
    print('Inside ' + self.__class__.__name__ + ' backward')
    print('Inside class:' + self.__class__.__name__)
    print('')
    print('grad_input: ', type(grad_input))
    print('grad_input[0]: ', type(grad_input[0]))
    print('grad_output: ', type(grad_output))
    print('grad_output[0]: ', type(grad_output[0]))
    print('')
    print('grad_input size:', grad_input[0].size())
    print('grad_output size:', grad_output[0].size())
    print('grad_input norm:', grad_input[0].norm())


net.conv2.register_backward_hook(printgradnorm)
inp = torch.randn(1, 1, 28, 28)
out = net(inp)
(1 - out.mean()).backward()

Inside Conv2d forward

input:  <class 'tuple'>
input[0]:  <class 'torch.Tensor'>
output:  <class 'torch.Tensor'>

input size: torch.Size([1, 10, 12, 12])
output size: torch.Size([1, 20, 8, 8])
output norm: tensor(17.0415)
Inside Conv2d backward
Inside class:Conv2d

grad_input:  <class 'tuple'>
grad_input[0]:  <class 'torch.Tensor'>
grad_output:  <class 'tuple'>
grad_output[0]:  <class 'torch.Tensor'>

grad_input size: torch.Size([1, 10, 12, 12])
grad_output size: torch.Size([1, 20, 8, 8])
grad_input norm: tensor(0.0147)


## Proper Way of Using Hooks : An Opinion

In [15]:
import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
   
  
  def forward(self, x):
    x = self.relu(self.conv(x))
    # Clamp the gradient at 0 so that no negative gradient is backpropagated into conv
    x.register_hook(lambda grad: torch.clamp(grad, min=0))
      
    # print whether there is any negative grad
    x.register_hook(lambda grad: print("Gradients less than zero:", bool((grad < 0).any())))  
    return self.fc1(self.flatten(x))
  

net = myNet()

for name, param in net.named_parameters():
  # if the param is from a linear and is a bias
  if "fc" in name and "bias" in name:
    # zeros_like keeps the dtype and device of the incoming gradient
    param.register_hook(lambda grad: torch.zeros_like(grad))


out = net(torch.randn(1,3,8,8)) 

(1 - out).mean().backward()

print("The biases are", net.fc1.bias.grad)     #bias grads are zero

Gradients less than zero: False
The biases are tensor([0., 0., 0., 0., 0.])
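
The same pattern of registering hooks on parameters can be used for element-wise gradient clipping. The following is only a sketch under assumptions: the model, the `clip_value` threshold, and the scaled input are hypothetical and chosen so that the unclipped gradients would be large.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(10, 1)
clip_value = 0.01  # hypothetical threshold, chosen only for illustration

# The tensor returned by each hook is what ends up in param.grad
for param in model.parameters():
    param.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))

# Large inputs so that the raw gradients would exceed the threshold
x = torch.randn(32, 10) * 100
loss = model(x).pow(2).mean()
loss.backward()

for name, param in model.named_parameters():
    print(name, param.grad.abs().max())
```

Because the hooks live on the parameters, the training loop itself does not need a separate clipping step.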


## The Forward Hook for Visualising Activations

In [17]:
import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
    self.seq = nn.Sequential(nn.Linear(5,3), nn.Linear(3,2))
    
   
  
  def forward(self, x):
    x = self.relu(self.conv(x))
    x = self.fc1(self.flatten(x))
    return self.seq(x)
  

net = myNet()
visualisation = {}

def hook_fn(m, i, o):
  visualisation[m] = o 

def get_all_layers(net):
  for name, layer in net._modules.items():
    # If it is a Sequential, don't register a hook on it,
    # but recursively register hooks on all of its child modules
    if isinstance(layer, nn.Sequential):
      get_all_layers(layer)
    else:
      # it's a non sequential. Register a hook
      layer.register_forward_hook(hook_fn)

get_all_layers(net)

  
out = net(torch.randn(1,3,8,8))

# Just to check whether we got all layers
visualisation.keys()      # the keys include the layers inside the Sequential

dict_keys([Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2)), ReLU(), Linear(in_features=160, out_features=5, bias=True), Linear(in_features=5, out_features=3, bias=True), Linear(in_features=3, out_features=2, bias=True)])
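
As an alternative to the recursive helper, `named_modules()` already walks nested containers, so hooks can be registered on every leaf module and the activations keyed by name. This is a sketch under assumptions, not the approach used in the cell above: the `nn.Sequential` model, the `activations` dictionary, and `make_hook` are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model with the same layer sizes as myNet
model = nn.Sequential(
    nn.Conv2d(3, 10, 2, stride=2),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(160, 5),
    nn.Sequential(nn.Linear(5, 3), nn.Linear(3, 2)),
)

activations = {}

def make_hook(name):
    # Build a forward hook that stores the output under the module's name
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

handles = []
for name, module in model.named_modules():
    # Only hook leaf modules, i.e. modules without children
    if len(list(module.children())) == 0:
        handles.append(module.register_forward_hook(make_hook(name)))

out = model(torch.randn(1, 3, 8, 8))

for name, act in activations.items():
    print(name, tuple(act.shape))

# Remove the hooks so later forward passes are not recorded
for handle in handles:
    handle.remove()
```

Keying by name rather than by module object makes the dictionary easier to index, and storing `output.detach()` avoids keeping the autograd graph alive.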

In [25]:
list(visualisation.keys())[0]

Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))

In [26]:
visualisation[list(visualisation.keys())[0]]

tensor([[[[-0.1303, -0.1141,  0.8475, -0.3684],
          [-1.0200,  0.2118,  0.9783,  0.5840],
          [ 0.3214,  0.6003, -0.2423,  0.3670],
          [ 0.7810, -1.3224,  0.5154, -0.5555]],

         [[-0.3507, -0.1943, -0.1344, -0.2603],
          [-0.3106, -0.0644,  1.1350, -0.4445],
          [ 0.5473,  0.9649,  0.1628,  0.2881],
          [ 0.1003, -1.2433, -0.9339,  0.3527]],

         [[ 0.6440,  0.7188, -0.6971,  0.4463],
          [ 0.5904,  0.8193, -0.2583, -0.7565],
          [-0.0084, -0.3424,  0.4223,  0.0085],
          [-0.2998,  0.0545, -0.9587,  0.7342]],

         [[-0.1454,  0.5621,  1.0001,  0.7003],
          [-0.8641,  0.2516, -0.1495,  0.5826],
          [-0.0562, -0.9388,  0.5149, -0.1571],
          [ 0.8980, -0.0932,  0.9554, -0.8234]],

         [[-0.5771, -0.2858, -0.1472, -0.6650],
          [-0.6099, -0.2352,  0.3240,  1.1149],
          [ 0.4402,  1.0028, -0.3964,  0.7477],
          [ 0.4653, -0.2095, -0.0076,  0.2770]],

         [[-0.2883,  1.0039,  

In [27]:
visualisation[list(visualisation.keys())[1]]

tensor([[[[0.0000, 0.0000, 0.8475, 0.0000],
          [0.0000, 0.2118, 0.9783, 0.5840],
          [0.3214, 0.6003, 0.0000, 0.3670],
          [0.7810, 0.0000, 0.5154, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 1.1350, 0.0000],
          [0.5473, 0.9649, 0.1628, 0.2881],
          [0.1003, 0.0000, 0.0000, 0.3527]],

         [[0.6440, 0.7188, 0.0000, 0.4463],
          [0.5904, 0.8193, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.4223, 0.0085],
          [0.0000, 0.0545, 0.0000, 0.7342]],

         [[0.0000, 0.5621, 1.0001, 0.7003],
          [0.0000, 0.2516, 0.0000, 0.5826],
          [0.0000, 0.0000, 0.5149, 0.0000],
          [0.8980, 0.0000, 0.9554, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.3240, 1.1149],
          [0.4402, 1.0028, 0.0000, 0.7477],
          [0.4653, 0.0000, 0.0000, 0.2770]],

         [[0.0000, 1.0039, 0.1484, 0.0000],
          [0.1415, 1.0619, 0.4815, 0.0000],
          [0.1933, 0.4