In [1]:
import torch
from transformer_lens import HookedTransformer

In [2]:
model = HookedTransformer.from_pretrained("gpt2-small")

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


In [2]:
import os

cache_dir = os.path.expanduser("~/.cache/huggingface/hub")

In [3]:
model = HookedTransformer.from_pretrained(
    model_name="Qwen/Qwen3-4B",
    cache_dir=cache_dir,
    dtype=torch.bfloat16,
    trust_remote_code=True,
    device="cpu",
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model Qwen/Qwen3-4B into HookedTransformer


In [4]:
test_text = "The capital of France is"
tokens = model.to_tokens(test_text)
input_length = tokens.shape[1]

In [5]:
print(f"Input text: '{test_text}'")
print(f"Input tokens: {model.to_str_tokens(test_text)}")
print(f"Input length: {input_length} tokens")
print(f"Input token IDs: {tokens[0].tolist()}")
print("-" * 50)

Input text: 'The capital of France is'
Input tokens: ['The', ' capital', ' of', ' France', ' is']
Input length: 5 tokens
Input token IDs: [785, 6722, 315, 9625, 374]
--------------------------------------------------


In [6]:
hook_calls = []


def debug_hook(activations, hook):
    seq_len = activations.shape[1]
    hook_calls.append(
        {
            "call_number": len(hook_calls) + 1,
            "sequence_length": seq_len,
            "layer": hook.name,
            "shape": activations.shape,
            "is_input_processing": seq_len == input_length,
            "is_generation_step": seq_len > input_length,
            "new_tokens": seq_len - input_length if seq_len > input_length else 0,
        }
    )

    print(
        f"Hook call #{len(hook_calls)}: seq_len={seq_len}, "
        f"new_tokens={seq_len - input_length}, "
        f"layer={hook.name}"
    )

In [7]:
layer_to_test = 20  # middle layer
model.add_hook(f"blocks.{layer_to_test}.hook_resid_pre", debug_hook)

print("Starting generation...")
print("=" * 50)

Starting generation...


In [8]:
with torch.no_grad():
    output = model.generate(
        test_text,
        max_new_tokens=5,  # Just 5 tokens to keep it simple
        temperature=0.0,  # Deterministic
        do_sample=False,
    )

print("=" * 50)
print("Generation complete!")
print(f"Output: '{output}'")
# Remove hooks
model.reset_hooks()

  0%|          | 0/5 [00:00<?, ?it/s]

Hook call #1: seq_len=5, new_tokens=0, layer=blocks.20.hook_resid_pre
Hook call #2: seq_len=1, new_tokens=-4, layer=blocks.20.hook_resid_pre
Hook call #3: seq_len=1, new_tokens=-4, layer=blocks.20.hook_resid_pre
Hook call #4: seq_len=1, new_tokens=-4, layer=blocks.20.hook_resid_pre
Hook call #5: seq_len=1, new_tokens=-4, layer=blocks.20.hook_resid_pre
Generation complete!
Output: 'The capital of France is Paris. The capital of'


In [10]:
print("\n" + "=" * 60)
print("ANALYSIS:")
print("=" * 60)

if not hook_calls:
    print("❌ No hook calls detected!")
else:
    print(f"✅ Total hook calls: {len(hook_calls)}")
    print()

    first_call = hook_calls[0]
    if first_call["is_input_processing"]:
        print("✅ FIRST CALL captures INPUT PROCESSING")
        print(f"   - Sequence length: {first_call['sequence_length']}")
        print(f"   - Matches input length: {input_length}")
    else:
        print("❌ First call is NOT input processing")
        print(f"   - Expected seq_len: {input_length}")
        print(f"   - Actual seq_len: {first_call['sequence_length']}")

    print()

    # Check generation calls
    generation_calls = [call for call in hook_calls if call["is_generation_step"]]
    if generation_calls:
        print(f"✅ {len(generation_calls)} calls during GENERATION")
        for i, call in enumerate(generation_calls):
            print(
                f"   Generation step {i + 1}: seq_len={call['sequence_length']}, "
                f"new_tokens={call['new_tokens']}"
            )
    else:
        print(hook_calls)


ANALYSIS:
✅ Total hook calls: 5

✅ FIRST CALL captures INPUT PROCESSING
   - Sequence length: 5
   - Matches input length: 5

[{'call_number': 1, 'sequence_length': 5, 'layer': 'blocks.20.hook_resid_pre', 'shape': torch.Size([1, 5, 2560]), 'is_input_processing': True, 'is_generation_step': False, 'new_tokens': 0}, {'call_number': 2, 'sequence_length': 1, 'layer': 'blocks.20.hook_resid_pre', 'shape': torch.Size([1, 1, 2560]), 'is_input_processing': False, 'is_generation_step': False, 'new_tokens': 0}, {'call_number': 3, 'sequence_length': 1, 'layer': 'blocks.20.hook_resid_pre', 'shape': torch.Size([1, 1, 2560]), 'is_input_processing': False, 'is_generation_step': False, 'new_tokens': 0}, {'call_number': 4, 'sequence_length': 1, 'layer': 'blocks.20.hook_resid_pre', 'shape': torch.Size([1, 1, 2560]), 'is_input_processing': False, 'is_generation_step': False, 'new_tokens': 0}, {'call_number': 5, 'sequence_length': 1, 'layer': 'blocks.20.hook_resid_pre', 'shape': torch.Size([1, 1, 2560]), 