# Intro to Forward Hooks
- Learning how to write forward hooks in PyTorch from first principles

In [318]:
from collections.abc import Callable

import torch
import torch.nn as nn

## Logging Hook Example
- Register a hook on an `nn.Linear` layer that prints the layer and its input/output shapes on every forward pass

In [319]:
torch.manual_seed(0)  # fix the RNG so weights and random inputs are reproducible on Restart & Run All
layer = nn.Linear(in_features=3, out_features=4)  # simple up-projection: 3 -> 4 features

In [320]:
# A freshly built layer has an empty forward-hook registry
registered_hooks = layer._forward_hooks
print("Current Forward Hooks:", registered_hooks)

Current Forward Hooks: OrderedDict()


### Forward Hook Format
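- Forward hooks have the signature `my_hook(module: nn.Module, input: tuple[torch.Tensor, ...], output: torch.Tensor) -> torch.Tensor | None`
    - `input` is a tuple of the positional arguments passed to `forward`, so the first tensor is `input[0]`
    - Note: returning `None` leaves the forward output unchanged; returning a tensor replaces (patches) the output
- Register the hook with `module.register_forward_hook(my_hook)`, which returns a handle you can use to remove the hook later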

In [321]:
def print_hook(
    module: nn.Module,
    input: tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    """Forward hook that logs the module and its input/output shapes.

    Args:
        module: The module the hook is registered on.
        input: Tuple of the positional arguments passed to ``module.forward``.
        output: The tensor returned by ``module.forward``.

    Returns:
        None, so the module's output is left unchanged.
    """
    print("Inside hook!")
    print("Module:", module)
    # `input` is a tuple of positional args; it can be empty (e.g. forward got only kwargs),
    # in which case `input[0]` would raise IndexError.
    if input:
        print("Input shape:", input[0].shape)
    else:
        print("Input shape: <no positional inputs>")
    print("Output shape:", output.shape)

In [322]:
# Keep the returned handle: it is the only way to detach this hook later
print_hook_handler = layer.register_forward_hook(hook=print_hook)


In [323]:
# Registry now holds a single entry mapping a handle id to print_hook
hooks_after_register = layer._forward_hooks
print("Current Forward Hooks:", hooks_after_register)

Current Forward Hooks: OrderedDict({43: <function print_hook at 0x10fba0720>})


In [324]:
x = torch.rand(4, 3)  # batch of 4 rows, 3 features each
y = layer(x)          # calling the module runs forward and then the registered hook

Inside hook!
Module: Linear(in_features=3, out_features=4, bias=True)
Input shape: torch.Size([4, 3])
Output shape: torch.Size([4, 4])


### Cleaning up the hook
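- Call `<handle>.remove()` when you are done. Otherwise the module keeps the hook in `_forward_hooks`, where it keeps firing on every forward pass and holding references, which can leak memory.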

In [325]:
print_hook_handler.remove()  # detach print_hook; later forward passes of `layer` print nothing

In [326]:
# After remove() the registry is back to an empty OrderedDict
hooks_after_remove = layer._forward_hooks
print("Current Forward Hooks:", hooks_after_remove)

Current Forward Hooks: OrderedDict()


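## Patching Hook Example
- A hook that returns a tensor replaces the module's output; here the hook zeroes out (ablates) one output column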

In [327]:
def create_ablation_hook(
    colum_to_zero: int,
) -> Callable[[nn.Module, tuple[torch.Tensor, ...], torch.Tensor], torch.Tensor]:
    """Build a forward hook that zeroes one feature of a module's output.

    Args:
        colum_to_zero: Index along the output's last dimension to set to zero.
            (The spelling is kept as-is so keyword callers keep working.)

    Returns:
        A forward hook. Because the hook returns a tensor, that tensor is used
        in place of the module's original output.
    """
    def ablation_hook(
        module: nn.Module,
        input: tuple[torch.Tensor, ...],
        output: torch.Tensor,
    ) -> torch.Tensor:
        mask = torch.ones_like(output)
        # Index the last axis so this works for (batch, features) and also for
        # (batch, seq, features) outputs; `mask[:, i]` would hit axis 1 instead.
        mask[..., colum_to_zero] = 0
        # Multiplying by a mask builds a new tensor instead of modifying `output` in place.
        return output * mask

    return ablation_hook

In [328]:
w = nn.Linear(in_features=3, out_features=4)  # separate 3 -> 4 layer for the patching/saving demos

In [329]:
# Patch w so that output column 2 is zeroed on every forward pass
ablation_handler = w.register_forward_hook(hook=create_ablation_hook(2))

In [330]:
w_hooks = w._forward_hooks  # should contain the ablation_hook closure
print(w_hooks)

OrderedDict({44: <function create_ablation_hook.<locals>.ablation_hook at 0x1064e09a0>})


In [331]:
y = w(torch.rand(4, 3))
# The hook's return value replaced the output, so column 2 must be zero in every row
# (-0.0 == 0 holds, so rows with negative pre-activation values pass too).
assert torch.all(y[:, 2] == 0), "ablation hook did not zero column 2"
print(y)

tensor([[-0.7583,  0.8632, -0.0000, -0.4746],
        [-0.6353,  0.8443, -0.0000, -0.2988],
        [-0.6618,  0.5972, -0.0000, -0.4115],
        [-0.5793,  0.5992, -0.0000, -0.3042]], grad_fn=<MulBackward0>)


In [332]:
ablation_handler.remove()  # detach the ablation hook from w
remaining_hooks = w._forward_hooks
print(remaining_hooks)  # expected: empty OrderedDict

OrderedDict()


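## Saving Activations
- A hook that returns `None` leaves the output untouched but can still record it; here it stores a detached CPU copy in the `activations` dict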

In [333]:
activations = {}

def save_activation(name: str) -> None:
    def create_save_neuron_hook(module: nn.Module, input: torch.Tensor, output: torch.Tensor) -> None:
        print(f"saving activations of shape {output.shape} to {name}")
        activations[name] = output.detach().cpu()
    return create_save_neuron_hook

In [334]:
# Cache w's output under "layer_w_activation" on every forward pass
handle = w.register_forward_hook(hook=save_activation("layer_w_activation"))

In [335]:
y = w(torch.rand((5, 3)))  # 5 rows -> the save hook caches a (5, 4) tensor

saving activations of shape torch.Size([5, 4]) to layer_w_activation


In [336]:
cached = activations  # {name: detached CPU tensor} filled by the save hook
print(cached)

{'layer_w_activation': tensor([[-0.7699,  0.7086, -0.2821, -0.5579],
        [-0.7109,  0.6113, -0.6538, -0.6928],
        [-0.7080,  0.7207, -0.5546, -0.4953],
        [-0.5865,  0.6991, -0.3976, -0.1139],
        [-0.6517,  0.6065, -0.5421, -0.4833]])}


In [337]:
handle.remove()  # detach the save hook; tensors already in `activations` are kept