<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Tracr_to_Transformer_Lens_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Tracr to TransformerLens Converter
[Tracr](https://github.com/deepmind/tracr) is a cool new DeepMind tool that compiles a program written in RASP into transformer weights. TransformerLens is a library I've written to make it easy to do mechanistic interpretability on a transformer and poke around at its internals. This is a (hacky!) script to convert Tracr weights from their JAX form into a TransformerLens HookedTransformer in PyTorch.

See [the TransformerLens tutorial](https://neelnanda.io/transformer-lens-demo) to get started.

Python must be version 3.8 or newer (my fork of Tracr is a bit more backwards compatible; the original library requires at least 3.9).

In [1]:
!python --version

Python 3.8.15


In [2]:
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens
    # Fork of Tracr that's backward compatible with Python 3.8
    %pip install git+https://github.com/neelnanda-io/Tracr
    
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    # from IPython import get_ipython

    # ipython = get_ipython()
    # # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    # ipython.magic("load_ext autoreload")
    # ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


In [3]:
from transformer_lens import HookedTransformer, HookedTransformerConfig
import einops
import torch
import numpy as np

from tracr.rasp import rasp
from tracr.compiler import compiling



Load an example RASP program and compile it to a model. This program reverses lists. The model takes as input a list of pre-tokenization elements (here `["BOS", 1, 2, 3]`), which are tokenized (`[3, 0, 1, 2]`). The transformer is then applied, an argmax is taken over the output, and the result is detokenized; this can be seen in the `out.decoded` attribute of the output.

In [4]:

def make_length():
    all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    return rasp.SelectorWidth(all_true_selector)


length = make_length()  # `length` is not a primitive in our implementation.
opp_index = length - rasp.indices - 1
flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(flip, rasp.tokens)

bos = "BOS"
model = compiling.compile_rasp_to_model(
    reverse,
    vocab={1, 2, 3},
    max_seq_len=5,
    compiler_bos=bos,
)

out = model.apply([bos, 1, 2, 3])
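To make the RASP program above concrete, here is a plain-Python sketch of what `Select`, `SelectorWidth` and `Aggregate` compute for `reverse`. This is not Tracr's API: the helper names `select`, `selector_width`, `aggregate` and `reverse_reference` are hypothetical, the aggregate is simplified to the categorical case where each query selects exactly one position (true for `flip`), and BOS is left out because Tracr handles it separately.

```python
# Plain-Python sketch of the RASP primitives used above (NOT Tracr's API).

def select(keys, queries, predicate):
    """Selection matrix: entry [q][k] is True when predicate(key, query) holds."""
    return [[predicate(k, q) for k in keys] for q in queries]


def selector_width(selector):
    """For each query position, count how many key positions it selects."""
    return [sum(row) for row in selector]


def aggregate(selector, values):
    """Simplified categorical aggregate: each query copies the one value it selects."""
    out = []
    for row in selector:
        chosen = [v for v, selected in zip(values, row) if selected]
        # Assumption: exactly one selected position per query (holds for `flip`).
        out.append(chosen[0] if len(chosen) == 1 else None)
    return out


def reverse_reference(tokens):
    """Mirror of the `reverse` program: length, opp_index, flip, then aggregate."""
    indices = list(range(len(tokens)))
    length = selector_width(select(tokens, tokens, lambda k, q: True))
    opp_index = [n - i - 1 for n, i in zip(length, indices)]
    flip = select(indices, opp_index, lambda k, q: k == q)
    return aggregate(flip, tokens)


print(reverse_reference([1, 2, 3]))  # [3, 2, 1]
```

The all-true selector is only there to compute `length`: every query selects every key, so the width at each position is the sequence length.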



Extract the model config from the Tracr model and create a blank HookedTransformer object.

In [5]:

n_heads = model.model_config.num_heads
n_layers = model.model_config.num_layers
d_head = model.model_config.key_size
d_mlp = model.model_config.mlp_hidden_size
act_fn = "relu"
normalization_type = "LN" if model.model_config.layer_norm else None
attention_type = "causal" if model.model_config.causal else "bidirectional"


n_ctx = model.params["pos_embed"]['embeddings'].shape[0]
# Size of the vocab plus two, since BOS and PAD are appended at the end
d_vocab = model.params["token_embed"]['embeddings'].shape[0]
# Residual stream width; I don't know of an easy way to infer it from the config above
d_model = model.params["token_embed"]['embeddings'].shape[1]

# Size of the vocab WITHOUT BOS and PAD, because we never care about these outputs
# In practice, we always feed the logits into an argmax
d_vocab_out = model.params["token_embed"]['embeddings'].shape[0] - 2

cfg = HookedTransformerConfig(
    n_layers=n_layers,
    d_model=d_model,
    d_head=d_head,
    n_ctx=n_ctx,
    d_vocab=d_vocab,
    d_vocab_out=d_vocab_out,
    d_mlp=d_mlp,
    n_heads=n_heads,
    act_fn=act_fn,
    attention_dir=attention_type,
    normalization_type=normalization_type,
)
tl_model = HookedTransformer(cfg)
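As a sanity check on the shapes read off the Tracr params, the comments in this notebook imply some simple arithmetic for this compiled model: one extra position for BOS, and BOS plus PAD appended to the vocab. The sketch below encodes that as a hypothetical helper, `expected_tracr_dims`; it is an assumption based on those comments, not something Tracr exposes.

```python
def expected_tracr_dims(vocab, max_seq_len):
    """Hypothetical helper: dimensions implied by the comments in this notebook."""
    n_ctx = max_seq_len + 1    # one extra position for the BOS token
    d_vocab = len(vocab) + 2   # real tokens, then BOS and PAD at the end
    d_vocab_out = d_vocab - 2  # BOS and PAD are never predicted
    return {"n_ctx": n_ctx, "d_vocab": d_vocab, "d_vocab_out": d_vocab_out}


print(expected_tracr_dims({1, 2, 3}, max_seq_len=5))
# {'n_ctx': 6, 'd_vocab': 5, 'd_vocab_out': 3}
```

With the compile call above (`vocab={1, 2, 3}`, `max_seq_len=5`), these numbers can be compared against `n_ctx`, `d_vocab` and `d_vocab_out` as extracted from the params.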

Extract the state dict and reshape the attention weights so that each one has an `n_heads` dimension.

In [6]:

sd = {}
# n_ctx rows: equivalent to max_seq_len plus one, for the BOS
sd["pos_embed.W_pos"] = model.params["pos_embed"]['embeddings']
sd["embed.W_E"] = model.params["token_embed"]['embeddings']

# The unembed is just a projection onto the first few elements of the residual stream, which store the output tokens
# This is a NumPy array while the rest are JAX arrays; that's fine, since everything is converted to tensors below
sd["unembed.W_U"] = np.eye(d_model, d_vocab_out)

for l in range(n_layers):
    sd[f"blocks.{l}.attn.W_K"] = einops.rearrange(
        model.params[f"transformer/layer_{l}/attn/key"]["w"],
        "d_model (n_heads d_head) -> n_heads d_model d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.b_K"] = einops.rearrange(
        model.params[f"transformer/layer_{l}/attn/key"]["b"],
        "(n_heads d_head) -> n_heads d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
        model.params[f"transformer/layer_{l}/attn/query"]["w"],
        "d_model (n_heads d_head) -> n_heads d_model d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
        model.params[f"transformer/layer_{l}/attn/query"]["b"],
        "(n_heads d_head) -> n_heads d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.W_V"] = einops.rearrange(
        model.params[f"transformer/layer_{l}/attn/value"]["w"],
        "d_model (n_heads d_head) -> n_heads d_model d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.b_V"] = einops.rearrange(
        model.params[f"transformer/layer_{l}/attn/value"]["b"],
        "(n_heads d_head) -> n_heads d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.W_O"] = einops.rearrange(
        model.params[f"transformer/layer_{l}/attn/linear"]["w"],
        "(n_heads d_head) d_model -> n_heads d_head d_model",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.b_O"] = model.params[f"transformer/layer_{l}/attn/linear"]["b"]

    sd[f"blocks.{l}.mlp.W_in"] = model.params[f"transformer/layer_{l}/mlp/linear_1"]["w"]
    sd[f"blocks.{l}.mlp.b_in"] = model.params[f"transformer/layer_{l}/mlp/linear_1"]["b"]
    sd[f"blocks.{l}.mlp.W_out"] = model.params[f"transformer/layer_{l}/mlp/linear_2"]["w"]
    sd[f"blocks.{l}.mlp.b_out"] = model.params[f"transformer/layer_{l}/mlp/linear_2"]["b"]
print(sd.keys())


dict_keys(['pos_embed.W_pos', 'embed.W_E', 'unembed.W_U', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out', 'blocks.2.mlp.b_out', 'blocks.3.attn.W_K', 'blocks.3.attn.b_K', 'blocks.3.attn.W_Q', 'blocks.3.attn.b_Q', 'blocks.3.attn.W_V', 'blocks.3.attn.b_V', 'blocks.3.attn.W_O', 'blocks.3.attn.b_O', 'blocks.3.ml

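Convert the weights to tensors and load them into `tl_model`. We use `strict=False` because the Tracr weights don't include every entry in the HookedTransformer state dict (see the missing keys in the output).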

In [7]:

for k, v in sd.items():
    # Go via NumPy: np.array turns the JAX array into a NumPy array, which torch.tensor then copies
    sd[k] = torch.tensor(np.array(v))

tl_model.load_state_dict(sd, strict=False)


_IncompatibleKeys(missing_keys=['blocks.0.attn.mask', 'blocks.0.attn.IGNORE', 'blocks.1.attn.mask', 'blocks.1.attn.IGNORE', 'blocks.2.attn.mask', 'blocks.2.attn.IGNORE', 'blocks.3.attn.mask', 'blocks.3.attn.IGNORE', 'unembed.b_U'], unexpected_keys=[])
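The missing keys above are the attention `mask` and `IGNORE` buffers, which HookedTransformer creates itself, plus `unembed.b_U`, which has no Tracr counterpart and so keeps whatever value HookedTransformer initialised it with. The toy module below (a hypothetical stand-in, not TransformerLens code) shows the same PyTorch behaviour: with `strict=False`, registered buffers absent from the loaded dict are reported in `missing_keys` instead of raising an error.

```python
import torch
from torch import nn


class TinyAttn(nn.Module):
    """Hypothetical stand-in: one weight we load, plus a buffer we don't."""

    def __init__(self):
        super().__init__()
        self.W_Q = nn.Parameter(torch.zeros(2, 2))
        # Buffers registered like this are part of the state dict.
        self.register_buffer("mask", torch.tril(torch.ones(3, 3)).bool())


module = TinyAttn()
result = module.load_state_dict({"W_Q": torch.eye(2)}, strict=False)
print(result.missing_keys)     # ['mask']
print(result.unexpected_keys)  # []
```

With the default `strict=True`, the same call would raise a `RuntimeError` about the missing `mask` key.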

Create helper functions to handle tokenization and detokenization.

In [8]:

INPUT_ENCODER = model.input_encoder
OUTPUT_ENCODER = model.output_encoder

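def create_model_input(input, input_encoder=INPUT_ENCODER):
    # Map each element to its token id, then add a batch dimension
    encoding = input_encoder.encode(input)
    return torch.tensor(encoding).unsqueeze(dim=0)

def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):
    # Drop the batch dimension and take the most likely output token at each position
    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
    decoded_output = output_encoder.decode(max_output_indices.tolist())
    # The prediction at the BOS position is meaningless, so replace it with the BOS token
    decoded_output_with_bos = [bos_token] + decoded_output[1:]
    return decoded_output_with_bos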
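To see what these helpers do without Tracr installed, here is a self-contained sketch that uses a hypothetical hard-coded encoding map in place of `model.input_encoder` and `model.output_encoder`. It assumes the ordering described earlier (real tokens first, then BOS), which reproduces the `[3, 0, 1, 2]` tokenization; the fake logits are made up to show the argmax-then-decode step. Unlike `decode_model_output`, it skips decoding position 0 entirely rather than decoding and then discarding it.

```python
import torch

# Hypothetical encoding map: real tokens first, then BOS (PAD would follow).
ENCODING = {1: 0, 2: 1, 3: 2, "BOS": 3}
DECODING = {i: tok for tok, i in ENCODING.items() if tok != "BOS"}


def encode(tokens):
    # Token ids with a leading batch dimension, like create_model_input
    return torch.tensor([ENCODING[t] for t in tokens]).unsqueeze(0)


def decode(logits, bos_token="BOS"):
    # Argmax over the output vocab at each position, like decode_model_output
    ids = logits.squeeze(0).argmax(dim=-1).tolist()
    # Position 0 is the BOS slot, so its prediction is replaced by BOS
    return [bos_token] + [DECODING[i] for i in ids[1:]]


print(encode(["BOS", 1, 2, 3]))  # tensor([[3, 0, 1, 2]])

fake_logits = torch.tensor([[[0.0, 0.0, 0.0],
                             [0.0, 0.0, 9.0],
                             [0.0, 9.0, 0.0],
                             [9.0, 0.0, 0.0]]])
print(decode(fake_logits))  # ['BOS', 3, 2, 1]
```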


We can now run the model!

In [9]:

input = [bos, 1, 2, 3]
out = model.apply(input)
print("Original Decoding:", out.decoded)

input_tokens_tensor = create_model_input(input)
logits = tl_model(input_tokens_tensor)
decoded_output = decode_model_output(logits)
print("TransformerLens Replicated Decoding:", decoded_output)


Original Decoding: ['BOS', 3, 2, 1]
TransformerLens Replicated Decoding: ['BOS', 3, 2, 1]


Let's cache all the intermediate activations in the model and check that they match Tracr's:

In [10]:
logits, cache = tl_model.run_with_cache(input_tokens_tensor)

for layer in range(tl_model.cfg.n_layers):
    # Tracr's layer_outputs interleaves attention and MLP outputs: attn_0, mlp_0, attn_1, mlp_1, ...
    attn_match = np.isclose(cache["attn_out", layer].detach().cpu().numpy(), np.array(out.layer_outputs[2*layer])).all()
    mlp_match = np.isclose(cache["mlp_out", layer].detach().cpu().numpy(), np.array(out.layer_outputs[2*layer+1])).all()
    print(f"Layer {layer} Attn Out Equality Check:", attn_match)
    print(f"Layer {layer} MLP Out Equality Check:", mlp_match)

Layer 0 Attn Out Equality Check: True
Layer 0 MLP Out Equality Check: True
Layer 1 Attn Out Equality Check: True
Layer 1 MLP Out Equality Check: True
Layer 2 Attn Out Equality Check: True
Layer 2 MLP Out Equality Check: True
Layer 3 Attn Out Equality Check: True
Layer 3 MLP Out Equality Check: True
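A bare `True` hides how close the match actually is. Below is a small sketch of two hypothetical helpers: one names the `2*layer` / `2*layer+1` indexing into Tracr's `layer_outputs` used above, and one reports the largest elementwise gap. Note that `np.isclose(a, b).all()` is equivalent to `np.allclose(a, b)` with the same default tolerances. The arrays here are toy data, not model activations.

```python
import numpy as np


def tracr_layer_output_index(layer, kind):
    """Hypothetical helper: block `layer` owns entries 2*layer (attn) and 2*layer + 1 (mlp)."""
    return 2 * layer + {"attn": 0, "mlp": 1}[kind]


def max_abs_diff(tl_act, tracr_act):
    """Largest elementwise gap between two activations (broadcast like np.isclose)."""
    return float(np.max(np.abs(np.asarray(tl_act) - np.asarray(tracr_act))))


print([tracr_layer_output_index(l, k) for l in range(2) for k in ("attn", "mlp")])  # [0, 1, 2, 3]

x = np.array([[0.0, 1.0], [2.0, 3.0]])
y = x + np.array([[0.0, 0.0], [0.0, 1e-9]])
print(np.allclose(x, y), max_abs_diff(x, y) < 1e-8)  # True True
```

In the loop above, `max_abs_diff(cache["attn_out", layer].detach().cpu().numpy(), out.layer_outputs[tracr_layer_output_index(layer, "attn")])` should then give a number to go alongside each equality check.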


Look how pretty and ordered the final residual stream is!

(The logits are the first 3 dimensions of the residual stream, and we can see that they're flipped!)

In [23]:
import plotly.express as px
px.imshow(
    cache["resid_post", -1].detach().cpu().numpy()[0],
    color_continuous_scale="Blues",
    labels={"x": "Residual Stream", "y": "Position"},
    y=[str(i) for i in input],
).show("colab" if IN_COLAB else None)
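The reason the logits sit in the first 3 dimensions is the unembed built earlier: `np.eye(d_model, d_vocab_out)` is a truncated identity, so multiplying the residual stream by it simply keeps the first `d_vocab_out` columns. A minimal sketch with made-up sizes and a random stand-in for the residual stream:

```python
import numpy as np

d_model, d_vocab_out = 6, 3  # small illustrative sizes, not the real model's
W_U = np.eye(d_model, d_vocab_out)  # same construction as the state-dict cell

rng = np.random.default_rng(0)
resid = rng.normal(size=(4, d_model))  # fake residual stream: 4 positions

logits = resid @ W_U
# A truncated identity just selects the first d_vocab_out residual dimensions.
assert np.allclose(logits, resid[:, :d_vocab_out])
```

This only holds because Tracr lays out the categorical output space at the start of the residual stream, as the comment on `unembed.W_U` says.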