Let's load the hooked transformer model. We are using the Llama 3 8B architecture, and passing the distilled reasoning model as the HF model because HookedTransformer doesn't support the actual Distill-Llama model yet.

In [1]:
# Import necessary libraries
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
import torch
import os

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the model using HookedTransformer
print("Loading DeepSeek-R1-Distill-Llama-8B with HookedTransformer...")
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# First load the model with HuggingFace to pass to HookedTransformer
print("Loading model with HuggingFace first...")
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # Use bfloat16 precision for memory efficiency
    trust_remote_code=True,
)

# Now load with HookedTransformer using the HuggingFace model
model = HookedTransformer.from_pretrained_no_processing(
    "meta-llama/Meta-Llama-3-8B",  # Use Llama-3-8B architecture as base
    hf_model=hf_model,             # Pass the actual model weights
    device=device,
    dtype=torch.bfloat16,
    tokenizer=tokenizer,
    # center_writing_weights=False,  # Don't try to center weights
    # center_unembed=False,          # Don't try to center unembed
    # fold_ln=False,                 # Use LayerNorm
)

print(f"Model loaded successfully: {model_name}")
print(f"Model has {model.cfg.n_layers} layers and {model.cfg.n_heads} attention heads")


Using device: cuda
Loading DeepSeek-R1-Distill-Llama-8B with HookedTransformer...
Loading model with HuggingFace first...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer
Model loaded successfully: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
Model has 32 layers and 32 attention heads


In [14]:
# Define a list of prompts to get activations
prompts = [
    "What is the solution of (x-3)(x+4)(x^3+1)?\n\nSo the term is factored, meaning we can plug in values for x that will set the whole expression to zero. Solving for (x-3) = 0, we obtain x = 3. Similary, we get x = -4 and x = -1. So we have solved the problem, right?",
    "Compute the derivative of f(x) = x^3 + 2x^2 - 5x + 3.\n\nTo find the derivative, I'll use the power rule and linearity of differentiation. So, first note that the derivative of x^3 is 3x^2, then the derivative of 2x^2 is 4x, next the derivative of -5x is -5, and finally the derivative of 3 is 0. Therefore, combining these results, the derivative is f'(x) = 3x^2 + 4x - 5.",
    "Compute the derivative of f(x) = x^3 + 2x^2 - 5x + 3.\n\nTo find the derivative, I'll use the power rule and linearity of differentiation. So, first note that the derivative of x^3 is 3x^3, then the derivative of 2x^2 is 4x, next the derivative of -5x is -5, and finally the derivative of 3 is 0. Therefore, combining these results, the derivative is f'(x) = 3x^2 + 4x - 5.",
    "Compute the derivative of f(x) = 10^6 x^5 - 672,583x^3 + 4028x^2 - 9999x + 150,000.\n\nTo find the derivative, I'll use the power rule and linearity of differentiation step by step. So, first note that the derivative of 10^6 x^5 is 5 * 10^6 x^4. Next, the derivative of -672,583x^3 is -3 * 672,583 x^2, which is 2017749 x^2. Next the derivative of 4028x^2 is 2 * 4028 x. Then the derivative of -9999x is -9999, and finally the derivative of 150,000 is 0. Alright, therefore, combining these results, the derivative is f'(x) = 5 * 10^6 x^4 - 2017749 x^2 + 2 * 4028 x - 9999. Simplifying this, we get f'(x) = 5000000 x^4 - 2017749 x^2 + 8056 x - 9999.",
    "Compute the derivative of f(x) = 10^6 x^5 - 672,583x^3 + 4028x^2 - 9999x + 150,000.\n\nTo find the derivative, I'll use the power rule and linearity of differentiation step by step. So, first note that the derivative of 10^6 x^5 is 5 * 10^6 x^4. Next, the derivative of -672,583x^3 is -3 * 672,583 x^2, which is 2017748 x^2. Next the derivative of 4028x^2 is 2 * 4028 x. Then the derivative of -9999x is -9999, and finally the derivative of 150,000 is 0. Alright, therefore, combining these results, the derivative is f'(x) = 5 * 10^6 x^4 - 2017748 x^2 + 2 * 4028 x - 9999. Simplifying this, we get f'(x) = 5000000 x^4 - 2017748 x^2 + 8056 x - 9999."
]

# Initialize a list to store caches for each prompt
all_logits = []
all_caches = []

# Process each prompt
for i, prompt in enumerate(prompts):
    print(f"\nRunning model with cache for prompt {i+1}/{len(prompts)}:")
    print(f"'{prompt[:50]}...'")
    
    # Tokenize the input
    tokens = model.to_tokens(prompt)
    print(f"Tokenized input shape: {tokens.shape}")
    
    # Run the model with cache to get all activations
    with torch.no_grad():
        logits, cache = model.run_with_cache(tokens)
    
    # Move logits and cache to CPU to free up GPU memory
    logits_cpu = logits.cpu()
    
    # Create a deep copy of the cache on CPU
    cache_cpu = {}
    for k, v in cache.items():
        if isinstance(v, torch.Tensor):
            cache_cpu[k] = v.cpu().clone()
        else:
            cache_cpu[k] = v
    
    # Store results
    all_logits.append(logits_cpu)
    all_caches.append(cache_cpu)
    
    # Explicitly delete GPU tensors to free memory
    del logits, cache
    
    # Clear GPU cache
    torch.cuda.empty_cache()
    
    print(f"Completed processing prompt {i+1}")

# Print information about the caches
print("\nCache contains the following activation types (from first prompt):")
for key in all_caches[0].keys():
    print(f"- {key}")

# Example: Access a specific activation from the first prompt
if "pattern" in all_caches[0]:
    attn_pattern = all_caches[0]["pattern", 0]  # Get attention pattern from layer 0
    print(f"\nAttention pattern from layer 0 shape (first prompt): {attn_pattern.shape}")

# Example: Access residual stream activations from the first prompt
if "resid_pre" in all_caches[0]:
    resid_pre = all_caches[0]["resid_pre", 0]  # Get residual stream before layer 0
    print(f"Residual stream before layer 0 shape (first prompt): {resid_pre.shape}")

# Example: Access MLP activations from the first prompt
if "mlp_out" in all_caches[0]:
    mlp_out = all_caches[0]["mlp_out", 0]  # Get MLP output from layer 0
    print(f"MLP output from layer 0 shape (first prompt): {mlp_out.shape}")

print(f"\nSuccessfully retrieved model activations for {len(prompts)} prompts!")



Running model with cache for prompt 1/5:
'What is the solution of (x-3)(x+4)(x^3+1)?

So the...'
Tokenized input shape: torch.Size([1, 89])
Completed processing prompt 1

Running model with cache for prompt 2/5:
'Compute the derivative of f(x) = x^3 + 2x^2 - 5x +...'
Tokenized input shape: torch.Size([1, 124])
Completed processing prompt 2

Running model with cache for prompt 3/5:
'Compute the derivative of f(x) = x^3 + 2x^2 - 5x +...'
Tokenized input shape: torch.Size([1, 124])
Completed processing prompt 3

Running model with cache for prompt 4/5:
'Compute the derivative of f(x) = 10^6 x^5 - 672,58...'
Tokenized input shape: torch.Size([1, 251])
Completed processing prompt 4

Running model with cache for prompt 5/5:
'Compute the derivative of f(x) = 10^6 x^5 - 672,58...'
Tokenized input shape: torch.Size([1, 251])
Completed processing prompt 5

Cache contains the following activation types (from first prompt):
- hook_embed
- blocks.0.hook_resid_pre
- blocks.0.ln1.hook_scale
- blocks

In [15]:
# Print top predicted tokens for each prompt
for prompt_idx, logits in enumerate(all_logits):
    # Get the logits from the last position
    last_token_logits = logits[0, -1, :]

    # Convert logits to probabilities using softmax
    probs = torch.nn.functional.softmax(last_token_logits, dim=-1)

    # Get the top 10 probability tokens
    top_k = 10
    top_probs, top_indices = torch.topk(probs, top_k)

    print(f"\nTop {top_k} tokens predicted after prompt {prompt_idx+1}:")
    print(f"'{prompts[prompt_idx][:50]}...'")
    for i, (index, prob) in enumerate(zip(top_indices, top_probs)):
        token = model.tokenizer.decode([index])
        print(f"{i+1}. Token: '{token}', Token ID: {index.item()}, Probability: {prob.item():.6f}")



Top 10 tokens predicted after prompt 1:
'What is the solution of (x-3)(x+4)(x^3+1)?

So the...'
1. Token: ' Wait', Token ID: 14144, Probability: 0.283203
2. Token: ' But', Token ID: 2030, Probability: 0.250000
3. Token: ' Hmm', Token ID: 89290, Probability: 0.081055
4. Token: ' So', Token ID: 2100, Probability: 0.063477
5. Token: ' The', Token ID: 578, Probability: 0.052490
6. Token: ' Or', Token ID: 2582, Probability: 0.046387
7. Token: ' Well', Token ID: 8489, Probability: 0.028076
8. Token: ' Let', Token ID: 6914, Probability: 0.019287
9. Token: ' 

', Token ID: 4815, Probability: 0.019287
10. Token: ' No', Token ID: 2360, Probability: 0.017090

Top 10 tokens predicted after prompt 2:
'Compute the derivative of f(x) = x^3 + 2x^2 - 5x +...'
1. Token: ' 

', Token ID: 4815, Probability: 0.585938
2. Token: ' I', Token ID: 358, Probability: 0.079590
3. Token: ' Wait', Token ID: 14144, Probability: 0.045166
4. Token: ' This', Token ID: 1115, Probability: 0.040039
5. Token: ' However', T