In [1]:
import os
DEV_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False
# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2
        
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"

# change renderer to colab if needed
import plotly.io as pio
if IN_COLAB or not DEV_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
    
print(f"Using renderer: {pio.renderers.default}")

# import circuit vis
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

import warnings
warnings.filterwarnings("ignore")

# Main imports
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

# transformer lens stuff
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix

# set grad to false cuz we dont need to train
torch.set_grad_enabled(False)

from plot_utils import *

device = "cpu"


Using renderer: notebook_connected


# How to add hook points to your own model

In [3]:
from transformer_lens.hook_points import HookedRootModule

class SquareThenAdd(nn.Module):
    """Toy layer computing ``offset + x * x`` with a hook point on the square.

    Args:
        offset: initial value for the learnable additive offset.
    """

    def __init__(self, offset):
        super().__init__()
        initial_offset = torch.tensor(offset)
        # Learnable value added to the squared input.
        self.offset = nn.Parameter(initial_offset)
        # Wraps the squared value; under a HookedRootModule parent it is named
        # by attribute path (e.g. "l1.hook_square" in the cache output).
        self.hook_square = HookPoint()

    def forward(self, x):
        """Square ``x``, route it through ``hook_square``, then add ``offset``."""
        squared = x * x
        squared = self.hook_square(squared)
        return self.offset + squared
    
class TwoLayerModel(HookedRootModule):
    """Two stacked SquareThenAdd layers with hook points on input, middle and output.

    Computes ``hook_out(l2(hook_mid(l1(hook_in(x)))))``; for x = 5 that is
    (5*5 + 3) ** 2 - 4 = 780.
    """

    def __init__(self):
        super().__init__()
        # Submodules and hook points, registered in this order.
        self.l1 = SquareThenAdd(3.0)
        self.l2 = SquareThenAdd(-4.0)
        self.hook_in = HookPoint()
        self.hook_mid = HookPoint()
        self.hook_out = HookPoint()

        # HookedRootModule.setup() names every hook point by its attribute path
        # (e.g. "hook_in", "l2.hook_square"); those names are the keys used by
        # run_with_cache / run_with_hooks.
        self.setup()

    def forward(self, x):
        """Run ``x`` through both layers, passing each stage through its hook point."""
        h = self.hook_in(x)
        h = self.l1(h)
        h = self.hook_mid(h)
        h = self.l2(h)
        return self.hook_out(h)

# Instantiate the toy model used by the following cells.
model = TwoLayerModel()


In [4]:

# One forward pass on a scalar input, caching the activation at every hook point.
out, cache = model.run_with_cache(torch.tensor(5.0))
print("Model output:", out.item())
# Print each cached activation keyed by its hook name.
for key in cache.keys():
    cached_value = cache[key].item()
    print(f"Value cached at hook {key}", cached_value)


Model output: 780.0
Value cached at hook hook_in 5.0
Value cached at hook l1.hook_square 25.0
Value cached at hook hook_mid 28.0
Value cached at hook l2.hook_square 784.0
Value cached at hook hook_out 780.0


In [6]:
def set_to_zero_hook(tensor, hook):
    """Forward hook that replaces the activation at a hook point with zeros.

    Args:
        tensor: the activation produced at the hook point.
        hook: the hook point being run; its ``name`` is printed so the reader
            can see which hook fired.

    Returns:
        A zero tensor with the same shape, dtype and device as ``tensor``.
    """
    print(hook.name)
    # zeros_like (rather than a fresh scalar tensor) preserves shape/dtype/device,
    # so the hook stays correct for batched or non-scalar activations.
    return torch.zeros_like(tensor)


# Zero out the squared value inside layer 2. With the square removed, l2 only
# adds its offset, so the output should be -4.0.
# The label is built from the same hook name that is patched so the two cannot drift.
target_hook = "l2.hook_square"
patched_out = model.run_with_hooks(
    torch.tensor(5.0), fwd_hooks=[(target_hook, set_to_zero_hook)]
)
print(f"Output after intervening on {target_hook}", patched_out.item())


l2.hook_square
Output after intervening on layer2.hook_scaled -4.0
