In [1]:
import json
import os
import argparse
import torch

from tqdm import tqdm
from collections import defaultdict
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import CrossEntropyLoss
from tuned_lens import TunedLens

from _config import HUFFINGFACE_KEY


# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# Inputs for the experiment: the prompt whose next-token prediction we inspect
# layer by layer, and the pretrained model to inspect it with.
prompt = "The London Bridge is in the city of"
model_name = "gpt2-small"

# Device is chosen by TransformerLens' get_device helper.
device = utils.get_device()
model = HookedTransformer.from_pretrained(model_name, device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [27]:
# Sample a single continuation token for the prompt. With temperature=0.7 this
# is a random draw, so seed torch's RNG first; otherwise the displayed
# continuation can change on every Restart & Run All.
torch.manual_seed(0)
model.generate(prompt, max_new_tokens=1, temperature=0.7, prepend_bos=True)

100%|██████████| 1/1 [00:04<00:00,  4.23s/it]


'The London Bridge is in the city of London'

In [24]:
# Tokenize the prompt and run one forward pass, caching every intermediate
# activation. remove_batch_dim=True strips the leading batch axis from the
# cached tensors (appropriate here because the batch is a single prompt).
tokens = model.to_tokens(prompt)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

n_layers = model.cfg.n_layers
print(n_layers)

12


In [25]:
# Logit lens: decode every residual-stream state into vocabulary logits by
# passing it through the model's final normalization and unembedding, i.e.
# "what would the model predict if it stopped here?".
#
# layer_logits[0] decodes the embedding stream (input to block 0) and
# layer_logits[i] (1 <= i <= n_layers) decodes the output of block i - 1, so
# there are n_layers + 1 entries. `cache` was built with remove_batch_dim=True,
# so every cached activation is [seq_len, d_model] and every entry here is
# [seq_len, d_vocab].


def logit_lens(resid: torch.Tensor) -> torch.Tensor:
    """Project residual-stream activations [..., d_model] to logits [..., d_vocab].

    Applies ``model.ln_final`` before ``model.unembed``, mirroring the model's
    own forward pass. ``cfg.final_rms`` only chooses RMSNorm vs. LayerNorm for
    that final norm (default False), so gating on it skipped the norm entirely;
    ``cfg.normalization_type`` is None only when the model has no final norm.
    """
    if model.cfg.normalization_type is not None:
        resid = model.ln_final(resid)
    return model.unembed(resid)


# Use the cached input to block 0 (for GPT-2's standard positional embeddings
# this is token embedding + positional embedding) so the embedding entry has
# the same batch-free layout as every other cached activation.
embed = cache["resid_pre", 0]
layer_logits = [logit_lens(embed)] + [
    logit_lens(cache["resid_post", layer]) for layer in range(model.cfg.n_layers)
]

# Sanity check: decoding the last block's output must reproduce the model's
# real logits (remove_batch_dim only affects the cache, so `logits` keeps its
# batch axis).
torch.testing.assert_close(layer_logits[-1], logits[0], atol=1e-3, rtol=1e-3)

In [26]:
# One decoded entry per residual-stream state: the embedding stream plus the
# output of each block. Assert it instead of only eyeballing the number.
assert len(layer_logits) == model.cfg.n_layers + 1, len(layer_logits)
len(layer_logits)

13