# What Is a Hook in PyTorch? - Module


![PyTorch hook](./asset/hook.jpeg)


## Module Hooks

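- `nn.Module.register_forward_hook`: `hook(module, input, output)`, called every time after `forward()` has computed an output
- `nn.Module.register_backward_hook`: `hook(module, grad_input, grad_output)`, called every time the gradients with respect to the module's inputs are computed (deprecated in newer PyTorch releases in favor of `register_full_backward_hook`)
- `nn.Module.register_forward_pre_hook`: `hook(module, input)`, called every time before `forward()` is invoked

Reference: [PyTorch 0.3.1 documentation for `torch.nn.Module.register_forward_hook`](http://pytorch.org/docs/0.3.1/nn.html?highlight=register#torch.nn.Module.register_forward_hook)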
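The three hook types fire at different points of a forward/backward pass. The following is a minimal sketch, assuming a recent PyTorch release (1.8 or later) where `register_full_backward_hook` stands in for the deprecated `register_backward_hook`. Each hook only appends its name to `calls` (an illustrative list, not part of PyTorch), so the call order can be checked.

```python
import torch
import torch.nn as nn

calls = []  # illustrative log of which hook fired, in order

layer = nn.Linear(2, 4)

def pre_hook(module, input):
    # Runs before forward(); `input` is a tuple of positional args
    calls.append("forward_pre")

def fwd_hook(module, input, output):
    # Runs after forward() has produced `output`
    calls.append("forward")

def bwd_hook(module, grad_input, grad_output):
    # Runs once the gradients w.r.t. the module's inputs are computed
    calls.append("backward")

handles = [
    layer.register_forward_pre_hook(pre_hook),
    layer.register_forward_hook(fwd_hook),
    layer.register_full_backward_hook(bwd_hook),
]

x = torch.ones(3, 2, requires_grad=True)
layer(x).sum().backward()
print(calls)  # ['forward_pre', 'forward', 'backward']

# Each register_* call returns a handle; remove() detaches that hook
for h in handles:
    h.remove()
```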

## 1) Import Required Libraries

```python
import torch
import torch.nn as nn
from torch.autograd import Variable  # deprecated since PyTorch 0.4, where tensors track gradients directly

batch_size = 3
```

## 2) Define Class

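- Define the layers (one `nn.Linear` per pair of consecutive sizes in `feature_list`)
- Set weights to 1 and biases to 0
- Register the three hooks on the second layer (`idx == 1`)
- Define the hook functions
- Define a method that removes the hooks
- Define `forward`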

```python
class Linear(nn.Module):
    def __init__(self, feature_list):
        super(Linear, self).__init__()
        self.feature_list = feature_list
        self.layers = []

        # Define layers: one nn.Linear per pair of consecutive sizes
        for i in range(len(feature_list) - 1):
            self.layers.append(nn.Linear(self.feature_list[i], self.feature_list[i + 1]))
        # nn.ModuleList registers the layers as submodules
        self.total = nn.ModuleList(self.layers)

        # Initialize weights to 1 and biases to 0 so gradients are easy to check by hand
        for idx, m in enumerate(self.total):
            if isinstance(m, nn.Linear):
                m.weight.data.fill_(1)
                m.bias.data.fill_(0)

            # Register all three hooks on the second layer only;
            # each call returns a handle that can later remove the hook
            if idx == 1:
                self.h0 = m.register_forward_hook(self.forward_hook)
                self.h1 = m.register_forward_pre_hook(self.forward_pre_hook)
                # Deprecated in newer PyTorch in favor of register_full_backward_hook
                self.h2 = m.register_backward_hook(self.backward_hook)

    # hook(module, input, output) -> None
    def forward_hook(self, *args):
        module, input, output = args
        print("\n This is Forward Hook \n")
        # `input` is a tuple of the positional arguments passed to forward(),
        # even though only one tensor is passed; `output` is the returned tensor
        for i in args:
            print(type(i))

    # hook(module, grad_input, grad_output) -> None
    def backward_hook(self, *args):
        module, grad_input, grad_output = args
        print("\n This is Backward Hook \n")
        # grad_input and grad_output are both tuples of gradients
        for i in args:
            print(type(i))

    # hook(module, input) -> None
    def forward_pre_hook(self, *args):
        module, input = args
        print("\n This is Forward Pre Hook \n")
        for i in args:
            print(type(i))

    def remove_hook(self):
        # Detach all three hooks from the layer
        self.h0.remove()
        self.h1.remove()
        self.h2.remove()

    def forward(self, x):
        out = x
        for layer in self.total:
            out = layer(out)
        return out
```
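The `*args` unpacking in the class can also be written with named parameters, which makes the unexpected tuple easier to see. A minimal sketch, assuming PyTorch 0.4 or later (plain tensors instead of `Variable`); `seen` is a hypothetical dictionary that records what the forward hook receives. It shows that `input` arrives as a tuple of length 1 even though the layer was called with a single tensor.

```python
import torch
import torch.nn as nn

seen = {}  # hypothetical store for what the hook receives

def forward_hook(module, input, output):
    # `input` is a tuple of the positional arguments given to forward()
    seen["input_type"] = type(input)
    seen["input_len"] = len(input)
    # `output` is whatever forward() returned (a single tensor for nn.Linear)
    seen["output_shape"] = tuple(output.shape)

layer = nn.Linear(2, 4)
handle = layer.register_forward_hook(forward_hook)
layer(torch.ones(3, 2))
handle.remove()

print(seen)  # {'input_type': <class 'tuple'>, 'input_len': 1, 'output_shape': (3, 4)}
```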

## 3) Create Instance & Check Model

```python
feature_list = [1, 2, 4, 8]
model = Linear(feature_list)
print(model)
```

```
Linear(
  (total): ModuleList(
    (0): Linear(in_features=1, out_features=2, bias=True)
    (1): Linear(in_features=2, out_features=4, bias=True)
    (2): Linear(in_features=4, out_features=8, bias=True)
  )
)
```
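The class registers its hooks inside `__init__`, but a hook can equally be attached from outside once the model exists. A minimal sketch, assuming PyTorch 0.4 or later; `Net` is a hypothetical stand-in that keeps the same `total` ModuleList layout, so the middle layer is reachable under the qualified name `total.1` returned by `named_modules()`.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    # Hypothetical stand-in with the same `total` ModuleList layout
    def __init__(self, feature_list):
        super().__init__()
        self.total = nn.ModuleList(
            [nn.Linear(a, b) for a, b in zip(feature_list[:-1], feature_list[1:])]
        )

    def forward(self, x):
        for layer in self.total:
            x = layer(x)
        return x

model = Net([1, 2, 4, 8])
names = []  # records which kind of module the hook fired for

# Look up the middle layer by its qualified name and hook it from outside
target = dict(model.named_modules())["total.1"]
handle = target.register_forward_hook(
    lambda module, input, output: names.append(type(module).__name__)
)
model(torch.ones(3, 1))
handle.remove()
print(names)  # ['Linear']
```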


## 4) Forward

- The forward pre-hook is called before the module's `forward()` runs.
- The forward hook is called after `forward()` has computed its output.
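Because the forward hook sees each layer's output, a common use is to capture intermediate activations without changing `forward()`. A minimal sketch, assuming PyTorch 0.4 or later; it rebuilds the `[1, 2, 4, 8]` network with `nn.Sequential` instead of the custom class, and `activations` / `save_output` are hypothetical names.

```python
import torch
import torch.nn as nn

# Same layer sizes as feature_list = [1, 2, 4, 8], weights 1 and biases 0
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 8))
for m in model:
    nn.init.ones_(m.weight)
    nn.init.zeros_(m.bias)

activations = {}  # hypothetical store: layer name -> output tensor

def save_output(name):
    # Build a forward hook that stores the layer's output under `name`
    def hook(module, input, output):
        activations[name] = output.detach()
    return hook

handle = model[1].register_forward_hook(save_output("layer1"))
with torch.no_grad():
    model(torch.ones(3, 1))
handle.remove()

print(activations["layer1"].shape)  # torch.Size([3, 4])
print(activations["layer1"][0])     # tensor([2., 2., 2., 2.])
```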

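```python
# With PyTorch 0.4+, torch.ones(batch_size, 1, requires_grad=True) is equivalent
x = Variable(torch.ones(batch_size, 1), requires_grad=True)
out = model(x)
out = torch.sum(out)
```

```
 This is Forward Pre Hook 

<class 'torch.nn.modules.linear.Linear'>
<class 'tuple'>

 This is Forward Hook 

<class 'torch.nn.modules.linear.Linear'>
<class 'tuple'>
<class 'torch.autograd.variable.Variable'>
```

The last line reflects PyTorch before 0.4, when `Variable` was a separate class; from 0.4 onward it prints `<class 'torch.Tensor'>`.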
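Setting every weight to 1 and every bias to 0 makes the numbers easy to verify by hand: each unit outputs the sum of its inputs, so for an input of 1 the three layers produce 1, 2 and 8 per unit. Each sample therefore contributes 8 × 8 = 64 to the sum, the batch of 3 gives 192, and the gradient with respect to every input entry is 64. A minimal check, assuming PyTorch 0.4 or later and rebuilding the same network with `nn.Sequential`:

```python
import torch
import torch.nn as nn

# Rebuild the [1, 2, 4, 8] network with weights 1 and biases 0
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 8))
for m in model:
    nn.init.ones_(m.weight)
    nn.init.zeros_(m.bias)

x = torch.ones(3, 1, requires_grad=True)  # batch_size = 3
out = model(x).sum()
out.backward()

print(out.item())       # 192.0
print(x.grad.view(-1))  # tensor([64., 64., 64.])
```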


## 5) Backward

- The backward hook is called every time the gradients with respect to the module's inputs are computed.

```python
out.backward()
```

```
 This is Backward Hook 

<class 'torch.nn.modules.linear.Linear'>
<class 'tuple'>
<class 'tuple'>
```
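The two tuples passed to the backward hook hold the gradients flowing out of and into the layer. A minimal sketch, assuming PyTorch 1.8 or later and using `register_full_backward_hook` in place of the deprecated `register_backward_hook`; `grads` is a hypothetical store. With all weights 1, the gradient of the sum with respect to the middle layer's 4 outputs is 8 each (one per unit of the 8-wide last layer), and with respect to its 2 inputs it is 4 × 8 = 32 each.

```python
import torch
import torch.nn as nn

# Same [1, 2, 4, 8] network with weights 1 and biases 0
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 8))
for m in model:
    nn.init.ones_(m.weight)
    nn.init.zeros_(m.bias)

grads = {}  # hypothetical store for the hook's arguments

def backward_hook(module, grad_input, grad_output):
    # Both arguments are tuples; index 0 holds the tensor gradient
    grads["grad_input"] = grad_input[0].clone()    # d(sum) / d(layer input)
    grads["grad_output"] = grad_output[0].clone()  # d(sum) / d(layer output)

handle = model[1].register_full_backward_hook(backward_hook)
x = torch.ones(3, 1, requires_grad=True)
model(x).sum().backward()
handle.remove()

print(grads["grad_output"][0])  # tensor([8., 8., 8., 8.])
print(grads["grad_input"][0])   # tensor([32., 32.])
```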


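## 6) Remove Hooks & Check

```python
model.remove_hook()

x = Variable(torch.ones(batch_size, 1), requires_grad=True)
out = model(x)
out = torch.sum(out)
out.backward()
```

With the hooks removed, the forward and backward passes run without printing any hook messages.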
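Removal works through the handle returned at registration time (a `torch.utils.hooks.RemovableHandle`). A minimal sketch, assuming PyTorch 0.4 or later; `count` and `counting_hook` are hypothetical names. The hook fires on the first call and not on the second, after `remove()`.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 4)
count = 0  # how many times the hook has fired

def counting_hook(module, input, output):
    global count
    count += 1

handle = layer.register_forward_hook(counting_hook)
layer(torch.ones(3, 2))  # hook fires, count becomes 1
handle.remove()          # detach the hook via its handle
layer(torch.ones(3, 2))  # hook no longer fires
print(count)  # 1
```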