In [2]:
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from jaxtyping import Float, Int
import requests
import functools

In [3]:
# GPT-2 small wrapped as a HookedTransformer, so its internal activations can be hooked below.
model = HookedTransformer.from_pretrained(model_name="gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Completions are sampled (TransformerLens `generate` defaults to do_sample=True),
# so seed the RNG to make the displayed completion reproducible on re-run.
torch.manual_seed(0)
model.generate("the cat is sm")

  0%|          | 0/10 [00:00<?, ?it/s]

'the cat is smothering its paw, which means the paw lies'

In [5]:
sample = model.to_tokens("the cat is sm")
# A plain forward pass is enough here (run_with_hooks with no hooks is equivalent);
# no_grad skips building an autograd graph that is never backpropagated through.
with torch.no_grad():
    baseline_logits = model(sample)
# Top-5 next-token candidates at batch index 0, final position.
top_logits, top_tokens = baseline_logits[0, -1, :].topk(k=5, dim=-1)
print(top_logits)
model.to_str_tokens(top_tokens)

tensor([19.7922, 19.6165, 19.4624, 18.2938, 17.9073], device='cuda:0',
       grad_fn=<TopkBackward0>)


['eared', 'itten', 'elly', 'ot', 'okin']

In [6]:
# Query DeepDecipher's neuron2graph search (query `activating:sm`); the results are
# read below as layer/neuron pairs from r.json()["data"].
# The timeout stops the cell hanging indefinitely on a stalled connection, and
# raise_for_status surfaces HTTP errors here rather than later as a KeyError/JSON error.
r = requests.get(
    "https://deepdecipher.org/api/gpt2-small/neuron2graph-search?query=activating:sm",
    timeout=30,
)
r.raise_for_status()

In [7]:
layers = [[]] * 12
for index in r.json()["data"]:
    layer = index["layer"]
    neuron = index["neuron"]
    layers[layer].append(neuron)

In [8]:
hooks = []


def hook_fn(
    indices: Int[Tensor, " _"], activation: Float[Tensor, "batch context neurons_per_layer"], hook: HookPoint
) -> None:
    activation[:, -1, indices] = 0.0


for layer_index, neurons in enumerate(layers):
    indices = torch.tensor(neurons)
    hook = functools.partial(hook_fn, indices)
    hooks.append((f"blocks.{layer_index}.mlp.hook_post", hook))


top_logits, top_tokens = model.run_with_hooks(sample, fwd_hooks=hooks)[0, -1, :].topk(k=5, dim=-1)
print(top_logits)
model.to_str_tokens(top_tokens)

tensor([20.9554, 20.9068, 20.5438, 19.5182, 18.8950], device='cuda:0',
       grad_fn=<TopkBackward0>)


['eared', 'elly', 'itten', 'ot', 'okin']

In [12]:
model_large = HookedTransformer.from_pretrained("gpt2-large")
# `generate` samples by default, so seed the RNG to make this few-shot completion reproducible.
torch.manual_seed(0)
model_large.generate("Apple->Red, Lime->Green, Banana->")


Loaded pretrained model gpt2-large into HookedTransformer


  0%|          | 0/10 [00:00<?, ?it/s]

'Apple->Red, Lime->Green, Banana->Pink, Antonchatka Chatter,�'

In [16]:
model_gelu = HookedTransformer.from_pretrained("gelu-1l")
# `generate` samples by default, so seed the RNG to make this completion reproducible.
torch.manual_seed(0)
model_gelu.generate("the cat is sm")


Loaded pretrained model gelu-1l into HookedTransformer


  0%|          | 0/10 [00:00<?, ?it/s]

"the cat is smokin's bead-like visorian-safe material"