# PyTorch forward function hooks
Basic usage examples of `PyTorch` forward hooks.

# References
- [Forward and backward function hooks](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks) (pytorch documentation)
- [PyTorch's `torch.nn.Module.forward()`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.forward) (pytorch documentation)
- [Why is the input of hook function a tuple?](https://discuss.pytorch.org/t/why-is-the-input-of-hook-function-a-tuple/54229) (MahdiNazemi)
- [In neural network literature, which one is activation?](https://stats.stackexchange.com/questions/272035/in-neural-network-literature-which-one-is-activation) (stats.stackexchange)
- [What does the activation of a neuron mean?](https://datascience.stackexchange.com/questions/11059/what-does-the-activation-of-a-neuron-mean) (datascience.stackexchange)

---
tags: pytorch, tutorial, forward hooks

# Imports

In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
%matplotlib inline

# Print input and output information of a linear module

In [3]:
class SingleLinear(nn.Module):
    """Minimal module that wraps one ``nn.Linear(1, 2)`` layer.

    The layer is stored as ``self.linear`` so hooks can be registered on it
    directly.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=1, out_features=2)

    def forward(self, x):
        """Return the result of applying the linear layer to ``x``."""
        return self.linear(x)
    
    
def print_sizes(self, input, output):
    """Forward hook that prints the types and shapes a module sees.

    Args:
        self: The module the hook is registered on (unused).
        input: Container of inputs handed to the module; its type, length
            and first element are reported.
        output: Value produced by the module; its type and shape are
            reported.

    Returns:
        None.
    """
    # Report on the input container itself first.
    container_type = type(input)
    print(f"Input type: {container_type}")
    container_len = len(input)
    print(f"Input len: {container_len}")

    # Then on the first element it holds.
    first = input[0]
    first_type = type(first)
    print(f"Input[0] type: {first_type}")
    first_shape = first.shape
    print(f"Input[0] shape: {first_shape}")

    # Finally on what the module returned.
    result_type = type(output)
    print(f"Output type: {result_type}")
    result_shape = output.shape
    print(f"Output shape: {result_shape}")
    
# Attach the hook to the model's only layer; it fires on every forward pass.
model = SingleLinear()
model.linear.register_forward_hook(print_sizes)

# Ten evenly spaced samples in [0, 1], arranged as a (10, 1) column.
# Alternative input to try: x = torch.randn(1)
x = torch.linspace(start=0, end=1, steps=10).unsqueeze(1)
y = model(x)


Input type: <class 'tuple'>
Input len: 1
Input[0] type: <class 'torch.Tensor'>
Input[0] shape: torch.Size([10, 1])
Output type: <class 'torch.Tensor'>
Output shape: torch.Size([10, 2])


# Print input and output dimensions of linear submodules

In [4]:
class ListOfLinears(nn.Module):
    """Chain of ``depth`` linear layers whose width grows by one per layer.

    Layer ``k`` (counting from 1) maps ``k`` features to ``k + 1`` features,
    so the chain takes 1 input feature and produces ``depth + 1`` features.
    The layers are stored in ``self.linears``.
    """

    def __init__(self, depth=1):
        super().__init__()
        layers = nn.ModuleList()
        for in_features in range(1, depth + 1):
            layers.append(nn.Linear(in_features, in_features + 1))
        self.linears = layers

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        out = x
        for layer in self.linears:
            out = layer(out)
        return out
    
def print_sizes(self, input, output):
    """Forward hook that prints a layer's input and output feature sizes.

    Args:
        self: The module the hook is registered on (unused).
        input: Container of inputs handed to the module; only ``input[0]``
            is read.
        output: Tensor produced by the module.

    Returns:
        None.
    """
    # Read the feature size from the last axis, the one ``nn.Linear``
    # transforms. Indexing axis 1 fails on unbatched 1-D input and reports
    # the wrong axis when the input carries extra batch axes.
    print(f"Input  dimension  : {input[0].shape[-1]}")
    print(f"Output dimension  : {output.shape[-1]}")
    
# Build a three-layer chain and hook every layer, so that a single forward
# pass reports each layer's dimensions in turn.
depth = 3
model = ListOfLinears(depth=depth)

for linear in model.linears:
    linear.register_forward_hook(print_sizes)

# Ten evenly spaced samples in [0, 1], arranged as a (10, 1) column.
# Alternative input to try: x = torch.randn(1, 1)
x = torch.linspace(start=0, end=1, steps=10).unsqueeze(1)
y = model(x)

Input  dimension  : 1
Output dimension  : 2
Input  dimension  : 2
Output dimension  : 3
Input  dimension  : 3
Output dimension  : 4


# Record weights and biases of linear submodules

In [10]:
class ListOfLinears(nn.Module):
    """Chain of ``depth`` linear layers, each mapping 1 feature to 1 feature.

    The layers are stored in ``self.linears`` and applied in order.
    """

    def __init__(self, depth=1):
        super().__init__()
        layers = nn.ModuleList()
        for _ in range(depth):
            layers.append(nn.Linear(1, 1))
        self.linears = layers

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        out = x
        for layer in self.linears:
            out = layer(out)
        return out


# Registries filled in by the forward hooks below, keyed by a per-layer label.
weights = {}
biases = {}


def get_weight(idx):
    """Build a forward hook that records a layer's weight under label ``idx``.

    Args:
        idx: Identifier of the layer; used to build the key in ``weights``.

    Returns:
        A callable taking ``(linear, input, output)`` that stores a copy of
        ``linear.weight`` in the module-level ``weights`` dict.
    """
    def hook(linear, input, output):
        # Store a detached copy rather than ``linear.weight.data``: the
        # latter aliases the live parameter, so a later in-place update
        # (e.g. an optimizer step) would silently change the recorded value.
        weights[f"(Layer {idx}) Weight:"] = linear.weight.detach().clone()
    return hook


def get_bias(idx):
    """Build a forward hook that records a layer's bias under label ``idx``.

    Args:
        idx: Identifier of the layer; used to build the key in ``biases``.

    Returns:
        A callable taking ``(linear, input, output)`` that stores a copy of
        ``linear.bias`` in the module-level ``biases`` dict.
    """
    def hook(linear, input, output):
        # Detached copy for the same reason as in ``get_weight``: the record
        # must not follow later in-place changes to the parameter.
        biases[f"(Layer {idx}) Bias:"] = linear.bias.detach().clone()
    return hook



# Four stacked layers; every layer gets a weight-recording hook followed by a
# bias-recording hook, both labelled with the layer's position.
model = ListOfLinears(depth=4)

for idx, linear in enumerate(model.linears):
    weight_hook = get_weight(idx)
    linear.register_forward_hook(weight_hook)
    bias_hook = get_bias(idx)
    linear.register_forward_hook(bias_hook)

# One forward pass on a random (1, 1) input fires every hook once.
x = torch.randn(1, 1)
y = model(x)

# Last expression of the cell: displays both registries in a notebook.
weights, biases

({'(Layer 0) Weight:': tensor([[0.9237]]),
  '(Layer 1) Weight:': tensor([[0.7996]]),
  '(Layer 2) Weight:': tensor([[-0.7495]]),
  '(Layer 3) Weight:': tensor([[-0.8520]])},
 {'(Layer 0) Bias:': tensor([0.2656]),
  '(Layer 1) Bias:': tensor([-0.1786]),
  '(Layer 2) Bias:': tensor([0.5813]),
  '(Layer 3) Bias:': tensor([-0.8994])})