# Interp on my GPT-2 Flax Implementation

In [1]:
# Imports
import jax
import jax.numpy as jnp
from jax import random
from model import GPT, GPTConfig
import os

import penzai
from __future__ import annotations

import os
import gc

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import optax
from jax.experimental import mesh_utils

import treescope
import penzai
from penzai import pz

import sentencepiece as spm

from penzai.models import transformer
from penzai.models import simple_mlp, transformer

from penzai.toolshed import token_visualization
from penzai.toolshed import basic_training
from penzai.toolshed import jit_wrapper
from penzai.nn.layer import Layer
import dataclasses
treescope.basic_interactive_setup(autovisualize_arrays=True)


## Load model from checkpoint

In [2]:
from flax.training import checkpoints

config = GPTConfig(
    vocab_size=50257,  # Updated for GPT-2 BPE
    block_size=256,  # or whichever block size you used
    n_layer=4,
    n_head=4,
    n_embd=256,
    embd_pdrop=0.0,  # 0 for inference
    resid_pdrop=0.0,
    attn_pdrop=0.0,
)

# Load your model from a checkpoint directory
ckpt_dir = os.path.abspath("checkpoints/")

model = GPT(config)
params = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, step=201000,target=None)
if params is None:
    raise ValueError(f"No checkpoint found in {ckpt_dir}")




### Count number of parameters

In [3]:

# Create a dummy input. The shape should match the model's expected input shape.
# Here we use a batch size of 1 and sequence length equal to block_size.
dummy_input = jnp.zeros((1, config.block_size), dtype=jnp.int32)

# Initialize the model parameters with a PRNG key
key = random.PRNGKey(0)

# Function to count total parameters
def count_params(params):
    # Use jax.tree_util.tree_leaves to get all arrays in the nested dict
    return sum(x.size for x in jax.tree_util.tree_leaves(params))


total_params = count_params(params)
print("Total number of parameters:", total_params)

Total number of parameters: 28891136


## Generating completions via Inference

We can yield the completion token-by-token...

In [4]:
from inference import generate_text
from bpe import BPETokenizerJax

# Example prompt
prompt = "Shakespeare wrote:"
prompt = "You are gentlemen of brave metal; you would lift the moon out of her sphere, if she would continue in it five weeks without changing."
prompt = prompt +  prompt
tokenizer = BPETokenizerJax()

# Tokenize the prompt via BPE
# The tokenizer(...) call returns a list of jnp arrays (one per input string).
encoded_list = tokenizer(prompt, return_tensors="jax")
# For a single prompt string, encoded_list is a list of length 1.
idx = encoded_list[0]  # shape (sequence_length,)

# Add a batch dimension (shape: (1, seq_len))
idx_jax = idx[jnp.newaxis, :]

# Initialize RNG key
rng = jax.random.PRNGKey(0)

print("Prompt:", prompt)

print("Generated text:")

# Generate text
for new_token in model.apply(
        {"params": params},
        idx_jax,
        max_new_tokens=20,
        rng=rng,
        temperature=0.7,
        do_sample=True,
        top_k=None,
        method=model.generate_yield,
    ):
    # Decode the generated token
    generated_text = tokenizer.decode(new_token)
    print(generated_text, end="")



Prompt: You are gentlemen of brave metal; you would lift the moon out of her sphere, if she would continue in it five weeks without changing.You are gentlemen of brave metal; you would lift the moon out of her sphere, if she would continue in it five weeks without changing.
Generated text:


SEBASTIAN:
We would so, and then go a bat-fow

Or write an external function which generates that completion token-by-token...

In [5]:
import jax
import jax.numpy as jnp
import flax.linen as nn

def external_generate(
    model: nn.Module,
    params: dict,
    idx: jnp.ndarray,
    max_new_tokens: int,
    rng: jax.random.PRNGKey,
    temperature: float = 1.0,
    do_sample: bool = False,
    top_k: int = None,
):
    """
    Generate tokens from the model using only the __call__ method.

    Args:
        model: The GPT model (an instance of nn.Module) with a __call__ method.
        params: The parameters of the model.
        idx: jnp.ndarray of shape (B, T) containing the initial token indices.
        max_new_tokens: Number of tokens to generate.
        rng: A JAX random key.
        temperature: Temperature factor for sampling.
        do_sample: If True, sample from the distribution; otherwise, take argmax.
        top_k: If provided, restrict sampling to the top_k logits.

    Yields:
        Generated tokens one at a time (each of shape (1,)).
    """
    for _ in range(max_new_tokens):
        # Crop the context to the model's block size if necessary
        if idx.shape[1] > model.config.block_size:
            idx_cond = idx[:, -model.config.block_size:]
        else:
            idx_cond = idx

        # FIX: Wrap the parameters in {"params": params} so that the "params" collection is set correctly.
        logits, _ = model.apply({"params": params}, idx_cond, deterministic=True)

        # Focus on the logits for the last token in the sequence and scale by temperature
        logits = logits[:, -1, :] / temperature

        # Optionally, apply top-k filtering
        if top_k is not None:
            # Get the top_k logits and determine the minimum value among them
            top_logits, _ = jax.lax.top_k(logits, top_k)
            k_threshold = jnp.min(top_logits, axis=-1, keepdims=True)
            # Replace logits not in the top_k with -infinity
            logits = jnp.where(logits < k_threshold, -jnp.inf, logits)

        # Convert logits to probabilities
        probs = nn.softmax(logits, axis=-1)

        # Sample the next token (or take the argmax)
        rng, subkey = jax.random.split(rng)
        if do_sample:
            next_token = jax.random.categorical(subkey, jnp.log(probs), axis=-1)
        else:
            next_token = jnp.argmax(probs, axis=-1)

        # Ensure the next token has shape (B, 1) and append it to the current sequence
        next_token = next_token[:, None]
        idx = jnp.concatenate([idx, next_token], axis=1)
        yield next_token[0]

# Example usage:
print("Prompt:", prompt)
print("Generated text:")

for new_token in external_generate(
        model=model,
        params=params,  # This should be the inner parameters; external_generate now wraps them correctly.
        idx=idx_jax,
        max_new_tokens=1,
        rng=rng,
        temperature=0.8,
    ):
    # Decode the generated token
    generated_text = tokenizer.decode(new_token)
    print(generated_text, end="")


Prompt: You are gentlemen of brave metal; you would lift the moon out of her sphere, if she would continue in it five weeks without changing.You are gentlemen of brave metal; you would lift the moon out of her sphere, if she would continue in it five weeks without changing.
Generated text:



## Unflaxify the Model for Penzai

In [6]:
from penzai.toolshed.unflaxify import unflaxify_apply

# Create dummy input
dummy_input = jnp.ones((1, config.block_size), dtype=jnp.int32)
rng = jax.random.PRNGKey(42)
init_rng, dropout_rng = jax.random.split(rng)

# Initialize
variables = model.init(
    {
        "params": init_rng,
        "dropout": dropout_rng
    },
    idx=dummy_input,
    deterministic=False
)
    
# Intercept the forward pass
pz_model = unflaxify_apply(
    module=model,
    variables=variables,
    idx=dummy_input,     # our forward arguments
    deterministic=True   # example kwarg
)

pz_model

# Visualize Attention Patterns in the Causal Self-Attention

Looking for induction heads, we start by creating a repeated sequence...

In [7]:

repeated_sequence_prompt = "You are gentlemen of brave metal; you would lift the moon out of her sphere, if she would continue in it five weeks without changing."

# repeat the prompt twice
repeated_sequence = repeated_sequence_prompt + repeated_sequence_prompt

# Tokenize the prompt via BPE
tokens = tokenizer(repeated_sequence)

tokens[0]



In [8]:
from penzai.toolshed.unflaxify import ArgsAndKwargs

# Wrap tokens and tag with "batch" and "seq"
token_seq = pz.nx.wrap(tokens).tag("batch", "seq")

# Remove the tags and unwrap to get a plain JAX array
token_seq_plain = token_seq.untag("batch", "seq").unwrap()

# Now capture the plain array inside an ArgsAndKwargs instance.
token_seq_wrapped = ArgsAndKwargs.capture(token_seq_plain)

# Pass the wrapped, plain array to the model.
logits = pz_model(token_seq_wrapped)[0][0] # [0] removes the batch dim 

logits

# Find the index of the highest logit
highest_logit_index = jnp.argmax(logits, axis=-1)
print(highest_logit_index)

[38585 12662 17263 17263 17263  5634 14066  3183  2259  7471 45922  2259
 32654 32654 45922  7454 32654 32654 32654 30595 32654 32654  7471 23020
 32654 38316 22469 36701 19059 45922 42307 24260 28449 24260 32654 34290
 28449  7471 32654 28449 32654 32654 45922 28449 32654 20875 32654 28449
 32654 32654 24260 29871 32654 38316 32654 42292]


In [9]:
logits =pz.nx.wrap(logits).tag("seq", "vocabulary")


# Map softmax over the vocabulary
log_probs = pz.nx.nmap(jax.nn.log_softmax)(
    logits.untag("vocabulary")
).tag("vocabulary")
log_probs


In [10]:
# Indexing with a dictionary indexes the named axes; pz.slice helps slice them.
sliced_preds = log_probs[{"seq": pz.slice[:-1]}]
correct_next_token = token_seq[{"seq": pz.slice[1:]}]


log_prob_of_correct_next = sliced_preds[{"vocabulary": correct_next_token}]
log_prob_of_correct_next

In [11]:
# Unbind the names, slice the positional array, rebind the names if needed.
sliced_preds = log_probs.untag("seq")[:-1].tag("seq")
correct_next_token = token_seq.untag("seq")[1:].tag("seq")
sliced_preds.untag("vocabulary")[correct_next_token]

In [12]:
import penzai
from penzai import pz
from penzai.core.variables import StateVariable
from penzai.nn.layer import Layer
from typing import Any
from penzai.toolshed.unflaxify import InterceptedFlaxModuleMethod

@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(pz.nn.Layer):  # <- pz.nn.Layer is the base class of Penzai layers.
  def __call__(self, intermediate_value, **unused_side_inputs):
    # print("Intermediate value:", intermediate_value)
    print(intermediate_value.shape)
    wrapped_intermediate_value = intermediate_value[0]
    # Show the value:
    pz.show(wrapped_intermediate_value)
    # And return it unchanged.
    return intermediate_value

def combined_call(original_submodule):
    # Create a wrapper that calls `original_submodule(...)`,
    # then does DisplayIntermediateValue or any post-processing.

    def new_callable(*args, **kwargs):
        output = original_submodule(*args, **kwargs)
        # do whatever side effect or instrumentation we want
        DisplayIntermediateValue()(output)  # e.g. just call it
        return output

    return new_callable

patched_selection = (
    pz.select(pz_model)
      .at_subtrees_where(
          lambda node:
              isinstance(node, InterceptedFlaxModuleMethod)
              and getattr(node.module, "__class__", None).__name__ == "Dropout"
              and node.module.name == 'attn_dropout'
      )
      
)


patched_model = patched_selection.apply(
          lambda method: combined_call(method),
      )

patched_selection


In [13]:
repeated_sequence_prompt = "You are gentlemen of brave metal"

# repeat the prompt twice
repeated_sequence_prompt = repeated_sequence_prompt + "|" + repeated_sequence_prompt

# Tokenize the prompt via BPE
tokens = tokenizer(repeated_sequence_prompt)

# Wrap tokens and tag with "batch" and "seq"
token_seq = pz.nx.wrap(tokens).tag("batch", "seq")

# Remove the tags and unwrap to get a plain JAX array
token_seq_plain = token_seq.untag("batch", "seq").unwrap()

# Now capture the plain array inside an ArgsAndKwargs instance.
token_seq_wrapped = ArgsAndKwargs.capture(token_seq_plain)


logits = patched_model(token_seq_wrapped)



(1, 4, 13, 13)


(1, 4, 13, 13)


(1, 4, 13, 13)


(1, 4, 13, 13)


# Summary

On the whole, I felt like Penzai was much "closer" to the actual model than TransformerLens, operating directly on the Jax PyTree and supporting jitting and sharding. One drawback is that when you `unflaxify` a model, you lose all the methods that run on it e.g. `model.generate` so the `unflaxified` model is really cemented as a duplicate for inspecting. Equally, treescope doesn't work as well on `unflaxified` models compared to ones natively built in Penzai, wrapping every layer in `InterceptedFlaxModuleMethod`.

| Feature | Penzai | TransformerLens|
|--------------------------|--------------------------|--------------------------|
| Supported frameworks? | `flax.linen` and `penzai` | `torch` inside `transformer-lens` |
|Visualizations | Treescope - really nice interactivity especially with named axes. Missed the input-sequence visualizations but I can see the functionality is built into treescope, would need custom integration. | CircuitsVis - really love the token-level visualizations where hovering over a word in the input sequence highlights the attention patterns on the other ones.|
|Loading Models | Either build the model in Penzai or "unflaxify" it | Only comes with small pre-loaded `HookedTransformer` models, difficult to define a new one|
|Parallelism | Shard & jit model like normal | No support |
|Caching activations | No caching, only saved if hooked. More practical for larger models. | `model.run_with_cache` to save activations, less practical for larger models |
|Hooks | Use `Selector` to assign hooks directly on the JAX PyTree then run forward pass as normal | `model.run_with_hooks` has to be explictly called|
|Generate functionality | Lost when `unflaxify` is called. | `model.generate` for basic text generation|
|Bonus Features | `NamedArray` and the `copy and paste` functionality of treescope` | Easy cache all activations|
