In [None]:
import torch
import torch.nn as nn
from transformer_lens.hook_points import HookedRootModule, HookPoint

class SimpleFFN(HookedRootModule):
    """Two-layer feed-forward net (Linear -> ReLU -> Linear) instrumented with hook points.

    Hook points, in the order they fire during ``forward``:
        hook_in     -- the raw input ``x``, before ``lin1``
        hook_hidden -- the output of ``lin1``, *before* the ReLU is applied
                       (so negative values are expected here)
        hook_out    -- the output of ``lin2``, which ``forward`` returns

    Args:
        d_in: size of the input's last dimension (``lin1`` in_features).
        d_hidden: width of the hidden layer.
        d_out: size of the output's last dimension (``lin2`` out_features).
    """

    def __init__(self, d_in=4, d_hidden=8, d_out=2):
        super().__init__()
        # Submodules are assigned in forward-pass order; lin1 is built before
        # lin2 so parameter initialisation draws from the RNG in the same order.
        self.hook_in = HookPoint()
        self.lin1 = nn.Linear(in_features=d_in, out_features=d_hidden)
        self.hook_hidden = HookPoint()
        self.lin2 = nn.Linear(in_features=d_hidden, out_features=d_out)
        self.hook_out = HookPoint()

        # Required: must be called after every HookPoint is attached so the
        # hooks can be addressed by attribute name (e.g. "hook_hidden").
        self.setup()

    def forward(self, x):
        """Run the network on ``x`` and return the value passed through ``hook_out``."""
        inp = self.hook_in(x)
        pre_act = self.hook_hidden(self.lin1(inp))  # hooks observe pre-ReLU values
        post_act = torch.relu(pre_act)
        return self.hook_out(self.lin2(post_act))

# Example usage: a default-sized model and a single random input row whose
# width (4) matches d_in.
model = SimpleFFN(d_in=4, d_hidden=8, d_out=2)
x = torch.randn((1, 4))

def print_hidden(tensor, hook):
    """Forward-hook callback that prints the tensor seen at its hook point.

    In this notebook it is attached to ``hook_hidden``, so ``tensor`` is the
    output of ``lin1`` *before* the ReLU. Negative entries are therefore
    expected despite the "activations" label. ``hook`` is accepted to match
    the callback signature but is not used. Returns ``None`` (it only observes).
    """
    label = "Hidden activations:"
    print(label, tensor)

# Forward pass with print_hidden attached as a forward hook on "hook_hidden";
# the result is presumably forward()'s output (lin2 -> hook_out), i.e. raw scores.
logits = model.run_with_hooks(
    x,
    fwd_hooks=[
        ("hook_hidden", print_hidden),
    ],
)


Hidden activations: tensor([[ 0.5672, -0.6464, -1.0984,  0.2729, -0.6984, -0.9447, -0.5653,  0.7969]],
       grad_fn=<AddmmBackward0>)
