# What is a hook in PyTorch? - Variable


## Variable


- torch.autograd.Variable.register_hook

http://pytorch.org/docs/0.3.1/autograd.html?highlight=register_hook#torch.autograd.Variable.register_hook

## 1) Import Required Libraries

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

## 2) Define Class

- Define Model
- Initialize Weights
- Define Hook Function
- Define Hook Remove Function
- Register Hook at Desired Location when Forwarding

In [2]:
class Linear(nn.Module):
    """Stack of ``nn.Linear`` layers that can hook the gradient of one layer's output.

    ``feature_list`` gives the layer widths: consecutive entries become the
    in/out feature sizes of one ``nn.Linear``. All weights are set to 1 and
    all biases to 0 so that gradients are easy to verify by hand.
    """

    def __init__(self,feature_list):
        super(Linear,self).__init__()
        self.feature_list = feature_list

        # Define Layers: one nn.Linear per consecutive pair of widths.
        self.layers = [
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(feature_list[:-1], feature_list[1:])
        ]
        self.total = nn.ModuleList(self.layers)

        # Handles returned by register_hook. They are stored separately so the
        # ``hook`` method is never overwritten by a handle object.
        self._hook_handles = []

        # Initialize weights to 1 (and biases to 0) in order to check gradients easily.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.fill_(1)
                m.bias.data.fill_(0)

    # define hook
    # if you uncomment the `return 2*grad` line, gradients will be doubled.
    def hook(self,grad):
        """Gradient hook: print this module and the incoming gradient.

        Returns None, which leaves the gradient unchanged.
        """
        print(self,grad)
        #return 2*grad

    # Once a hook is registered, it stays until explicitly removed.
    def remove_hook(self):
        """Remove every hook registered by ``forward``.

        Safe to call when no hook is registered and safe to call repeatedly.
        """
        while self._hook_handles:
            self._hook_handles.pop().remove()

    # register_hook is applied to Variable
    # so it should be registered on a Variable, not nn.Module
    def forward(self,x,hook_layer):
        """Run ``x`` through all layers, hooking the output of layer ``hook_layer``.

        Args:
            x: input passed to the first layer.
            hook_layer: index into the layer list; the output of that layer
                gets ``self.hook`` registered on it. If it matches no index,
                no hook is registered.

        Returns:
            The output of the last layer.
        """
        out = x
        for idx,layer in enumerate(self.total):
            out = layer(out)
            if idx == hook_layer:
                # Keep the handle so remove_hook() can detach the hook later.
                self._hook_handles.append(out.register_hook(self.hook))
        return out

## 3) Create Instance & Check

In [3]:
# Layer widths 1 -> 2 -> 4 -> 8 yield three nn.Linear layers (see the printed structure).
feature_list = [1,2,4,8]
model = Linear(feature_list)
print(model)

Linear(
  (total): ModuleList(
    (0): Linear(in_features=1, out_features=2, bias=True)
    (1): Linear(in_features=2, out_features=4, bias=True)
    (2): Linear(in_features=4, out_features=8, bias=True)
  )
)


## 4) Forward a Variable & Register hook

In [4]:
# Leaf input that tracks gradients.
x = Variable(torch.ones(1), requires_grad=True)
# Forward pass with the hook on the output of the layer at index 2,
# then reduce to a scalar so backward() can be called without arguments.
out = torch.sum(model(x, hook_layer=2))

## 5) Hook called with Backward

In [5]:
# Backpropagate from the scalar; the registered hook runs during this call
# and prints the gradient it receives.
out.backward()

# The hook stays registered until it is explicitly removed.
model.remove_hook()

Linear(
  (total): ModuleList(
    (0): Linear(in_features=1, out_features=2, bias=True)
    (1): Linear(in_features=2, out_features=4, bias=True)
    (2): Linear(in_features=4, out_features=8, bias=True)
  )
) Variable containing:
 1
 1
 1
 1
 1
 1
 1
 1
[torch.FloatTensor of size 8]



## 6) Check Gradient of Leaf Variable

In [6]:
# If the hook function returns a modified gradient, the leaf gradient is affected as well.
# Expect 64 when the hook returns nothing, 128 when it returns 2*grad.
x.grad

Variable containing:
 64
[torch.FloatTensor of size 1]