# Accessing intermediate states
Construct a multi-layer perceptron (MLP) and record the intermediate states of its hidden layers with PyTorch forward hooks.

For illustration purposes, all layers have width `1`. The `depth` (number of hidden layers) can be specified.

The weights and biases are set to specific values for easy inspection.
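
To see why fixed parameters make inspection easy, here is a minimal plain-Python sketch of the values one would expect, assuming hidden layers with weight `1` and bias `1`, ReLU activations, and an output layer with weight `1` and bias `0.12345`. The helper name `expected_preactivations` is hypothetical and only serves this illustration.

```python
def expected_preactivations(x, depth, out_bias=0.12345):
    """Return the linear output of each hidden layer, followed by the model output."""
    values = []
    h = x
    for _ in range(depth):
        z = 1.0 * h + 1.0   # linear layer with weight 1 and bias 1
        values.append(z)
        h = max(z, 0.0)     # ReLU
    values.append(1.0 * h + out_bias)  # output layer with weight 1
    return values

values = expected_preactivations(-100.0, depth=3)
print(values[:3])  # → [-99.0, 1.0, 2.0]
```

After the first ReLU clips the negative input to `0`, every further hidden layer simply adds `1`, so the recorded values count upwards.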

# References
- [Forward and backward function hooks](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks) (pytorch documentation)
- [In neural network literature, which one is activation?](https://stats.stackexchange.com/questions/272035/in-neural-network-literature-which-one-is-activation) (stats.stackexchange)
- [Code snippet example](https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/6) (pytorch discuss)

---
tags: pytorch, tutorial, forward hooks, hidden layers

# Imports

In [1]:
import torch
import torch.nn as nn

# Record intermediate states

In [2]:
class MLP(nn.Module):
    """MLP with `depth` hidden layers of width 1, each followed by a ReLU."""
    def __init__(self, depth=1):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(1, 1) for _ in range(depth)])
        self.activs = nn.ModuleList([nn.ReLU() for _ in range(depth)])
        self.out = nn.Linear(1, 1)

        # Fix the parameters so every intermediate value is easy to predict:
        # hidden layers use weight 1 and bias 1, the output layer weight 1 and bias 0.12345.
        with torch.no_grad():
            for linear in self.linears:
                linear.weight.fill_(1.0)
                linear.bias.fill_(1.0)
            self.out.weight.fill_(1.0)
            self.out.bias.fill_(0.12345)

    def forward(self, x):
        for linear, activ in zip(self.linears, self.activs):
            x = activ(linear(x))
        return self.out(x)


activations = {}

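def get_activation(idx):
    """Return a forward hook that stores the hooked module's output under index `idx`."""
    # The hook is registered on the linear layers below, so the stored value
    # is the layer's output *before* the ReLU is applied.
    def hook(module, inputs, output):
        activations[f"Activation {idx}"] = output.detach().numpy()
    return hook
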
depth = 11
model = MLP(depth)

for idx, linear in enumerate(model.linears):
    linear.register_forward_hook(get_activation(idx))
model.out.register_forward_hook(get_activation(depth))

x = torch.tensor(-100.).view(-1, 1)
y = model(x)

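# Sort by layer index. `str.lstrip` removes a set of characters rather than a prefix,
# so the index is taken from the last whitespace-separated token instead.
sorted(activations.items(), key=lambda kv: int(kv[0].split()[-1]))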

[('Activation 0', array([[-99.]], dtype=float32)),
 ('Activation 1', array([[1.]], dtype=float32)),
 ('Activation 2', array([[2.]], dtype=float32)),
 ('Activation 3', array([[3.]], dtype=float32)),
 ('Activation 4', array([[4.]], dtype=float32)),
 ('Activation 5', array([[5.]], dtype=float32)),
 ('Activation 6', array([[6.]], dtype=float32)),
 ('Activation 7', array([[7.]], dtype=float32)),
 ('Activation 8', array([[8.]], dtype=float32)),
 ('Activation 9', array([[9.]], dtype=float32)),
 ('Activation 10', array([[10.]], dtype=float32)),
 ('Activation 11', array([[10.12345]], dtype=float32))]
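
The hooks above are attached to the `nn.Linear` modules, so they record each layer's output before the ReLU, which is why `Activation 0` is `-99`. The following is a minimal sketch of recording the post-ReLU values instead by hooking the `nn.ReLU` modules, and of detaching hooks again through the handle that `register_forward_hook` returns. The two-layer stand-in model and the names `save_output`, `post_activations`, `run`, and `snapshot` are assumptions made for this example, not part of the notebook.

```python
import torch
import torch.nn as nn

# Stand-in for the MLP above: two hidden layers of width 1 with weight 1 and bias 1.
linears = nn.ModuleList([nn.Linear(1, 1) for _ in range(2)])
activs = nn.ModuleList([nn.ReLU() for _ in range(2)])
with torch.no_grad():
    for linear in linears:
        linear.weight.fill_(1.0)
        linear.bias.fill_(1.0)

post_activations = {}

def save_output(idx):
    # Forward hooks are called as hook(module, inputs, output).
    def hook(module, inputs, output):
        post_activations[idx] = output.detach().clone()
    return hook

# register_forward_hook returns a handle; keep it so the hook can be removed later.
handles = [activ.register_forward_hook(save_output(idx))
           for idx, activ in enumerate(activs)]

def run(x):
    for linear, activ in zip(linears, activs):
        x = activ(linear(x))
    return x

x = torch.tensor([[-100.0]])
y = run(x)
snapshot = dict(post_activations)  # values recorded while the hooks were active

# Remove the hooks; later forward passes no longer record anything.
for handle in handles:
    handle.remove()
post_activations.clear()
run(x)
print(len(post_activations))  # → 0
```

Hooking the `nn.ReLU` modules rather than the `nn.Linear` modules is the only change needed to switch between the two notions of "activation" discussed in the references.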