<a href="https://colab.research.google.com/github/AndisDraguns/rasp-lab/blob/main/3_SAT_Verifier_Transformer_in_Tracr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3-SAT Verifier Transformer in Tracr

I implement a [3-SAT](https://en.wikipedia.org/wiki/Boolean_satisfiability_problem) verifier as a Transformer circuit by using [RASP](https://arxiv.org/pdf/2106.06981.pdf) programs in [Tracr](https://github.com/google-deepmind/tracr/tree/main/tracr/rasp). This can serve as a toy example of a Transformer that is hard to reverse, and showcase theoretical cases that would be difficult for algorithms such as [Greedy Coordinate Gradient](https://arxiv.org/pdf/2307.15043.pdf) (GCG). Any problem in NP can be reduced to 3-SAT, including ones that seem to be hard in the average case, e.g. integer factoring (which is believed highly unlikely to be NP-complete).

RASP is a programming language that allows implementing handcrafted Transformer circuits. It can be used to make Transformers perform simple algorithms, such as reversing a string, in a way that achieves length-generalization. For a great introduction to RASP, see this blog: https://srush.github.io/raspy/.
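To make the string-reversal example concrete, here is a minimal pure-Python sketch of the select/aggregate idea. It is not Tracr's implementation: the helpers `select`, `aggregate` and `reverse` are hypothetical stand-ins for RASP's primitives, and the sketch assumes every query position selects exactly one key position, so no averaging is needed.

```python
def select(keys, queries, predicate):
    """Boolean attention pattern: entry [q][k] is True where predicate(key, query) holds."""
    return [[predicate(k, q) for k in keys] for q in queries]

def aggregate(selector, values):
    """For each query position, return the single selected value (assumes exactly one match)."""
    out = []
    for row in selector:
        selected = [v for v, hit in zip(values, row) if hit]
        assert len(selected) == 1, "this sketch only handles one-to-one selection"
        out.append(selected[0])
    return out

def reverse(tokens):
    """Reverse a sequence by letting position i attend to position length-1-i."""
    indices = list(range(len(tokens)))
    opposite = [len(tokens) - 1 - i for i in indices]
    flip = select(indices, opposite, lambda key, query: key == query)
    return aggregate(flip, tokens)

reversed_tokens = reverse(list("hello"))
print(reversed_tokens)
```

Nothing in `reverse` refers to a fixed sequence length, which is the property that lets the same program work on inputs of any length.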

This notebook is made for developing and experimenting with Tracr models. It adapts Neel Nanda's [Tracr to TransformerLens Converter notebook](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Tracr_to_Transformer_Lens_Demo.ipynb) and Tracr's [example visualisation notebook](https://github.com/google-deepmind/tracr/blob/main/tracr/examples/Visualize_Tracr_Models.ipynb). A TransformerLens model can be accessed at the end of the 'Build the TransformerLens model' section. In this version I've added some helper functions and bug fixes.

If you have any suggested improvements or bug fixes, feel free to message or email me!

# Install dependencies

This might take a few minutes:

In [None]:
%pip install transformer_lens
%pip install git+https://github.com/deepmind/tracr

Collecting git+https://github.com/deepmind/tracr
  Cloning https://github.com/deepmind/tracr to /tmp/pip-req-build-90tt6jet
  Running command git clone --filter=blob:none --quiet https://github.com/deepmind/tracr /tmp/pip-req-build-90tt6jet
  Resolved https://github.com/deepmind/tracr to commit 9ce2b8c82b6ba10e62e86cf6f390e7536d4fd2cd
  Preparing metadata (setup.py) ... done


In [None]:
#@title Imports

from transformer_lens import HookedTransformer, HookedTransformerConfig
import einops
import torch
import numpy as np
import math

from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler import lib
import matplotlib.pyplot as plt

import jax


# The default of float16 can lead to discrepancies between outputs of
# the compiled model and the RASP program.
jax.config.update('jax_default_matmul_precision', 'float32')

# originally developed for these versions:
  # Python 3.10.12
  # transformer_lens==1.14.0
  # git+https://github.com/google-deepmind/tracr/tree/9ce2b8c82b6ba10e62e86cf6f390e7536d4fd2cd

# Define helper functions

In [None]:
#@title Helper functions: Interfacing with Tracr
from typing import Iterable

def create_input_list(input: Iterable) -> list:
  return [compiler_bos] + list(input)

def create_tracr_runner_fn(model) -> callable:
  def run_tracr_model(input: Iterable)-> list:
    input_list = create_input_list(input)
    tracr_logits = model.apply(input_list)
    tracr_decoded_output = tracr_logits.decoded
    return tracr_decoded_output
  return run_tracr_model

compiler_bos = "BOS"
compiler_pad = "PAD"
causal = False
mlp_exactness = 100

In [None]:
#@title Helper functions: RASP programming

def make_all_true() -> rasp.SOp:
  return rasp.Map(lambda i: True, rasp.indices).named("all_trues")

def make_length() -> rasp.SOp:
  all_true_selector = rasp.Select(all_true, all_true, rasp.Comparison.TRUE).named("all_true_selector")
  return rasp.SelectorWidth(all_true_selector).named("length")

def hardcode_to_SOp(s: Iterable) -> rasp.SOp:
  return rasp.Map(lambda i: s[i] if i<len(s) else "_", rasp.indices).named("hardcoded_value")

def elementwise_EQ(a: rasp.SOp, b: rasp.SOp) -> rasp.SOp:
  return rasp.SequenceMap(lambda x, y: x == y, a, b).named("elementwise_EQ")

from tracr.compiler.lib import make_count
def count_true_vals(bools: rasp.SOp) -> rasp.SOp:
  return make_count(bools, True)

def elementwise_OR(a: rasp.SOp, b: rasp.SOp) -> rasp.SOp:
  return rasp.SequenceMap(lambda x, y: x or y, a, b).named("elementwise_OR")

def elementwise_AND(a: rasp.SOp, b: rasp.SOp) -> rasp.SOp:
  return rasp.SequenceMap(lambda x, y: x and y, a, b).named("elementwise_AND")

def check_if_all_true(bools: rasp.SOp) -> rasp.SOp:
  """check if all bools are True (numerically stable)"""
  inverted_bools = rasp.Map(lambda x: not x, bools).named("inverted_bools")
  any_missing  = rasp.numerical(rasp.Map(lambda x: x, inverted_bools).named("any_missing"))
  select_all = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE).named("select_all")
  has_missing = rasp.numerical(rasp.Aggregate(select_all, any_missing, default=0)).named("has_missing")
  not_has_missing = (~has_missing).named("not_has_missing")
  is_all_true = rasp.categorical(not_has_missing).named("is_all_true")
  return is_all_true

def elementwise_AND_for_n_SOps(sop_list)->rasp.SOp:
  if len(sop_list)==1: return sop_list[0]
  return elementwise_AND(sop_list[0], elementwise_AND_for_n_SOps(sop_list[1:]))

# Useful constants
all_true = make_all_true()
length = make_length()

# Password Checker

The simplest backdoor trigger. Brute force black-box testing would not work, but GCG would probably succeed.
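As a rough illustration of why brute force does not scale, the sketch below counts black-box queries against a toy oracle. The helper `brute_force_queries` and the lambda oracle are hypothetical and only stand in for querying a model; the point is that the search space doubles with every extra bit.

```python
from itertools import product

def brute_force_queries(oracle, n_bits):
    """Count black-box queries until the oracle accepts (hypothetical helper)."""
    for n_queries, bits in enumerate(product("01", repeat=n_bits), start=1):
        if oracle("".join(bits)):
            return n_queries
    return None

secret = "1101"
queries = brute_force_queries(lambda s: s == secret, len(secret))
print(queries, "queries used out of", 2 ** len(secret), "possible inputs")
print("possible inputs for a 128-bit password:", 2 ** 128)
```

A 4-bit password falls immediately, but the same loop over a 128-bit password would have to cover up to 2^128 candidates.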

A Tracr model's inputs are lists that include a beginning-of-sequence token, e.g. ["BOS", 1, 2, 3]. These are tokenized and then the Tracr transformer is applied. An argmax is then taken over the output logits, and the result is detokenized; the detokenized list is available through the .decoded attribute of the model output.
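The tokenize, argmax and detokenize steps can be sketched in plain Python. Everything here is an assumption made for illustration: the vocabularies, the `encode` and `decode` helpers and the hand-written logits are hypothetical, whereas Tracr builds its own encoders at compile time and computes the logits with the compiled model.

```python
# Hypothetical vocabularies; Tracr infers the real ones at compile time.
input_vocab = ["0", "1", "BOS", "PAD"]
output_vocab = [False, True]

def encode(tokens):
    """Map each input token to its integer id."""
    return [input_vocab.index(t) for t in tokens]

def decode(logits):
    """Take the argmax over each position's logits and map it back to an output value."""
    ids = [max(range(len(row)), key=row.__getitem__) for row in logits]
    decoded = [output_vocab[i] for i in ids]
    # The BOS position carries no meaningful output, so it is reported as "BOS".
    return ["BOS"] + decoded[1:]

token_ids = encode(["BOS", "1", "1", "0", "1"])
# Stand-in logits (one row per position); a real model would compute these.
fake_logits = [[0.0, 0.0], [0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.1, 0.9]]
decoded = decode(fake_logits)
print(token_ids)
print(decoded)
```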

In [None]:
def check_if_all_true_and_len_is_n(bools: rasp.SOp, n: int) -> rasp.SOp:
  true_counts = count_true_vals(bools)
  is_all_true = true_counts == n
  return is_all_true.named("check_if_all_true_and_len_is_n")

def check_password(hardcoded_password: str) -> rasp.SOp:
  """outputs all True if the password is correct, else all False"""
  pw = hardcode_to_SOp(hardcoded_password)
  pw_bools = elementwise_EQ(rasp.tokens, pw)
  output = check_if_all_true_and_len_is_n(pw_bools, len(hardcoded_password))
  return output.named("pwcheck")

hardcoded_password = "1101"
test_input = "1101"

program = check_password(hardcoded_password)
max_seq_len = len(hardcoded_password)
# vocab = set(list("abcdefghijklmnopqrstuvwxyz")+["_"])
vocab = set(list("10"))
tracr_model = compiling.compile_rasp_to_model(program, vocab, max_seq_len, False, compiler_bos, compiler_pad)
run_tracr_model = create_tracr_runner_fn(tracr_model)

decoded_output = run_tracr_model(test_input)
print(decoded_output)

# 3-SAT verifier

In [None]:
def ev(rasp_object: rasp.RASPExpr, test_input=test_input) -> None:
  """
  evaluate a rasp object for debugging purposes
  e.g. write ev(is_shorter_than_n) instead of print(is_shorter_than_n)
  the default test_input is captured when this cell runs, so pass
  test_input explicitly if it has been reassigned since
  """
  inp = create_input_list(test_input)[1:]
  print(f"evaluated {rasp_object.name} on test_input:", rasp.evaluate(rasp_object, inp), flush=True)

def var_to_pair(var: int)->str:
  """two-token one-hot encoding of a boolean: 1 -> "01", 0 -> "10" """
  if int(var)==1: return "01"
  if int(var)==0: return "10"
  raise ValueError(f"expected 0 or 1, got {var!r}")

def format_vars(vars):
  """e.g. "1101" -> [0,1,0,1,1,0,0,1] """
  pairs = [var_to_pair(i) for i in vars]
  var_string = "".join(pairs)
  int_list = [int(x) for x in var_string]
  return int_list

def get_key_value(i):
    """maps [0,1,2,3,4,...] -> [-1,1,-2,2,-3,...]"""
    sign = -1 if i%2==0 else 1
    value = i//2 + 1
    return sign * value

def swap_indices_fn(i: int) -> int:
  """[0, 1, 2, 3, 4, 5, ...] -> [1, 0, 3, 2, 5, 4, ...]"""
  is_even = i%2==0
  if is_even: return i+1
  else: return i-1

def check_if_len_smaller_than_n(n: int) -> rasp.SOp:
  inverted_one_hot_n = hardcode_to_SOp([True]*(n-1)+[False]).named("inverted_one_hot_n")
  inverted_one_hot_n = rasp.categorical(rasp.Map(lambda x: bool(x), inverted_one_hot_n)).named("inverted_one_hot_n")
  is_shorter_than_n = check_if_all_true(inverted_one_hot_n).named("is_shorter_than_n")
  return is_shorter_than_n

def verify_input_correctness(n_variables: int)-> rasp.SOp:
  """check that input len >= n_variables * 2 and input bools are one-hot-encoded"""
  swapped_pairs_indices = rasp.Map(lambda x: swap_indices_fn(x), rasp.indices).named("swapped_pairs_indices")
  other_var_selector = rasp.Select(rasp.indices, swapped_pairs_indices, rasp.Comparison.EQ).named("other_var_selector")
  other_var = rasp.Aggregate(other_var_selector, rasp.tokens, default=None).named("other_var")
  var_1_xor_var_2 = rasp.SequenceMap(lambda x, y: bool(x) != bool(y), rasp.tokens, other_var).named("var_1_xor_var_2")
  is_correct_length = (~check_if_len_smaller_than_n(n_variables*2)).named("is_correct_length")
  input_pos_is_valid = elementwise_AND(var_1_xor_var_2, is_correct_length).named("input_pos_is_valid")
  input_is_valid = check_if_all_true(input_pos_is_valid).named("input_is_valid")
  return input_is_valid

def fetch_var_i_for_clauses(clauses, i) -> rasp.SOp:
  """takes a batch of clauses and literal position i (0, 1 or 2); returns, per clause, whether its i-th literal is satisfied by the input"""
  queries = hardcode_to_SOp(clauses[:,i]).named(f"clauses_var_{i}")
  clauses_var_i_selector = rasp.Select(var_keys, queries, rasp.Comparison.EQ).named(f"clause_var_{i}_selector")
  clauses_var_i_averaged = rasp.Aggregate(clauses_var_i_selector, rasp.tokens).named(f"clause_var_{i}_averaged")
  clauses_var_i_detected = rasp.categorical(rasp.Map(lambda x: x==1, clauses_var_i_averaged)).named(f"clauses_var_{i}_detected")
  return clauses_var_i_detected

def verify_2n_clause_batch(clauses: np.array, n_variables: int) -> rasp.SOp:
  """evaluate a batch of clauses (batching because can fit only 2n clauses in one attention head)"""
  c_vars = [fetch_var_i_for_clauses(clauses, i) for i in range(3)]
  clauses_var_OR = elementwise_OR(elementwise_OR(c_vars[0],c_vars[1]),c_vars[2]).named("clauses_batch_OR_result")
  output = check_if_all_true(clauses_var_OR).named("clauses_batch_satisfied")
  return output

def verify_3sat(clauses: np.array, n_variables: int) -> rasp.SOp:
  """outputs all True if the 3-SAT instance is satisfied, else all False"""
  n_clauses = len(clauses)
  clause_batch_size = 2*n_variables
  n_clause_batches = math.ceil(n_clauses/(clause_batch_size))
  n_padded_clauses = n_clause_batches*(clause_batch_size) - n_clauses
  if n_padded_clauses > 0:
    padded_clauses = np.tile(clauses[-1], (n_padded_clauses, 1))
    clauses = np.concatenate((clauses, padded_clauses), axis=0)
  clause_batch_verifications = []

  for i in range(n_clause_batches):
    clause_batch = clauses[i*clause_batch_size:(i+1)*clause_batch_size]
    batch_i = verify_2n_clause_batch(clause_batch, n_variables).named(f"verify_2n_clause_batch_{i}")
    clause_batch_verifications.append(batch_i)

  output = elementwise_AND_for_n_SOps(clause_batch_verifications).named("unverified_output")
  verified_input = verify_input_correctness(n_variables).named("verified_input")
  verified_output = elementwise_AND(output, verified_input).named("verified_output")
  return verified_output.named("3SAT_instance_satisfied")


# useful constants:
var_keys = rasp.Map(get_key_value, rasp.indices).named(f"variable_keys")  # -1 1 -2 2 -3 3 ...
positive_keys = rasp.Map(lambda x: abs(x), var_keys).named("positive_keys_correct_inp") # 1 1 2 2 3 3 ...


n_variables = 5
clauses = np.array([
    [1, 2, -3],  # clause 1 - x1 or x2 or NOT x3
    [-1, -4, 5], # clause ...
    [1, -2, 4],  # ...
    [-1, 3, -5],
    [2, -3, 4],
    [-2, -4, 5],
    [-1, -3, -5],
    [1, 4, -5],
    [1, -2, -4],
    [-1, 2, -5],
    [3, -4, 5],
    [-2, 3, -4],
    [-1, -3, 4],
    [2, 4, -5],
    [1, 2, 3],
    [-4, -5, 1],
    [2, 3, -1],
])

test_input_unformatted = "11111"  # 11111 = all variables True
test_input = format_vars(test_input_unformatted)  # one-hot-encoded bools


program = verify_3sat(clauses, n_variables)
max_seq_len = n_variables*2
vocab = set([1,0])
tracr_model = compiling.compile_rasp_to_model(program, vocab, max_seq_len, False, compiler_bos, compiler_pad)
run_tracr_model = create_tracr_runner_fn(tracr_model)

print("test input:", test_input)
decoded_output = run_tracr_model(test_input)
print("decoded_output:", decoded_output)

Now let's test the correctness of the Tracr program by comparing its outputs to a 3-SAT checker in Python.

In [None]:
from itertools import product

def sat_tester(clauses, variables):
  for c in clauses:
    found_satisfying_variable = False
    for i in range(3):
      c_needs_True = c[i] > 0
      c_ith_variable_index = int(abs(c[i])-1)
      v_is_True = int(variables[c_ith_variable_index]) > 0.5
      # print(c,i,c_needs_True == v_is_True)
      if c_needs_True == v_is_True:
        found_satisfying_variable = True
        break
    if found_satisfying_variable == False:
      return False
  return True

def generate_binary_strings(n):
    return [''.join(bits) for bits in product('01', repeat=n)]

def test_tracr_3sat_correctness():
  binary_strings = generate_binary_strings(n_variables)
  n_sat_assignments = 0
  n_mismatches = 0
  for s in binary_strings:
    formatted_s = format_vars(s)
    tracr_result = run_tracr_model(formatted_s)[-1]
    sat_tester_result = sat_tester(clauses, s)
    if sat_tester_result!=tracr_result:
      n_mismatches+=1
      print(f"mismatch: {s}, tester:{sat_tester_result}, tracr:{tracr_result}")
    if sat_tester_result:
      # print(f"Satisfying: {s}")
      n_sat_assignments+=1
  # only report success if no assignment disagreed with the Python checker
  if n_mismatches == 0:
    print(f"All {len(binary_strings)} binary strings passed the test")
  else:
    print(f"{n_mismatches}/{len(binary_strings)} binary strings failed the test")
  print(f"Nr of satisfying assignments: {n_sat_assignments}/{len(binary_strings)}")

sat_tester_result = sat_tester(clauses, test_input_unformatted)
print("sat_tester on test_input:", sat_tester_result)
test_tracr_3sat_correctness()

It works!
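The lookup at the heart of the verifier can be reproduced in plain Python, which may help when reading the RASP code. This is a sketch of the idea only: the helper names are hypothetical, and a list search stands in for the attention head that matches each literal against the position keys -1, 1, -2, 2, ...

```python
def one_hot_pairs(assignment):
    """'10' -> [0, 1, 1, 0]: each variable becomes the pair (is_false, is_true)."""
    out = []
    for bit in assignment:
        out += [0, 1] if bit == "1" else [1, 0]
    return out

def literal_key(position):
    """Positions 0, 1, 2, 3, ... carry the keys -1, 1, -2, 2, ..."""
    sign = -1 if position % 2 == 0 else 1
    return sign * (position // 2 + 1)

def literal_is_satisfied(tokens, literal):
    """Look up the token whose key equals the literal; it is 1 iff the literal holds."""
    keys = [literal_key(p) for p in range(len(tokens))]
    return tokens[keys.index(literal)] == 1

def clause_is_satisfied(tokens, clause):
    """A clause is an OR over its three literals."""
    return any(literal_is_satisfied(tokens, lit) for lit in clause)

tokens = one_hot_pairs("110")  # x1=True, x2=True, x3=False
unsatisfied = clause_is_satisfied(tokens, [-1, -2, 3])  # every literal is false
satisfied = clause_is_satisfied(tokens, [-1, 2, 3])     # x2 makes the clause true
print(tokens, unsatisfied, satisfied)
```

Because each literal maps to exactly one input position, a literal is checked by reading a single token, which is what lets one attention head handle one literal slot for a whole batch of clauses.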

To create a different RASP program, you can overwrite the relevant variables (program, max_seq_len, vocab, tracr_model, run_tracr_model) and the rest of the notebook should work.

# Build the TransformerLens model

Now we extract Tracr/Craft parameters. This solves a bug with manually setting the TransformerLens unembed matrix, where it stopped working if the input space was different from the output space.
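One way such a mismatch can arise is sketched below, under the assumption that the output directions are not the first dimensions of the residual stream. The labels and helper functions are hypothetical and written in pure Python; the comparison is between a projection chosen by label and a truncated identity matrix (the equivalent of `np.eye(d_model, d_vocab_out)`).

```python
# Hypothetical residual-stream labels; the output directions come last here.
residual_labels = ["tokens:0", "tokens:1", "indices:0", "out:False", "out:True"]
output_labels = ["out:False", "out:True"]

def projection_matrix(from_labels, to_labels):
    """Matrix with a 1 where a residual direction equals an output direction."""
    return [[1 if src == dst else 0 for dst in to_labels] for src in from_labels]

def truncated_identity(n_rows, n_cols):
    """Pure-Python equivalent of np.eye(n_rows, n_cols)."""
    return [[1 if r == c else 0 for c in range(n_cols)] for r in range(n_rows)]

def unembed(residual, matrix):
    """Row vector times matrix."""
    n_cols = len(matrix[0])
    return [sum(x * row[c] for x, row in zip(residual, matrix)) for c in range(n_cols)]

# The model wrote "True" into the out:True direction; tokens:0 is also active.
residual = [1, 0, 1, 0, 1]
by_label = unembed(residual, projection_matrix(residual_labels, output_labels))
by_position = unembed(residual, truncated_identity(len(residual_labels), len(output_labels)))
print("projection by label:", by_label)
print("truncated identity: ", by_position)
```

In this toy case the truncated identity reads the input token directions instead of the output directions, so its argmax picks the wrong answer, while the label-based projection recovers it.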

In [None]:
#@title Get Tracr parameters
from tracr.compiler import rasp_to_graph, nodes, craft_graph_to_model, basis_inference, expr_to_craft_graph
from tracr.craft import bases, vectorspace_fns

def get_unembed_params():
  mlp_exactness = 100  # Controls the approximation of the MLP layers.
  extracted = rasp_to_graph.extract_rasp_graph(program)
  graph, sources, sink = extracted.graph, extracted.sources, extracted.sink
  basis_inference.infer_bases(graph, sink, vocab, max_seq_len,)
  expr_to_craft_graph.add_craft_components_to_rasp_graph(graph, bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos), mlp_exactness=mlp_exactness,)
  craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources)
  tokens_value_set = (graph.nodes[rasp.tokens.label][nodes.VALUE_SET].union({compiler_bos, compiler_pad}))
  tokens_space = bases.VectorSpaceWithBasis.from_values(rasp.tokens.label, tokens_value_set)
  indices_space = bases.VectorSpaceWithBasis.from_values(rasp.indices.label, range(max_seq_len))
  categorical_output = rasp.is_categorical(sink[nodes.EXPR])
  output_space = bases.VectorSpaceWithBasis(sink[nodes.OUTPUT_BASIS])
  residual_space = bases.join_vector_spaces(craft_model.residual_space, tokens_space, indices_space, output_space)
  res_to_out = vectorspace_fns.project(residual_space, output_space)
  return res_to_out.matrix, len(output_space.basis)

tracr_unembed_matrix, tracr_d_vocab_out = get_unembed_params()

Extract the model config from the Tracr model, and create a blank HookedTransformer object

In [None]:
#@title Define TransformerLens hyperparameters

n_heads = tracr_model.model_config.num_heads
n_layers = tracr_model.model_config.num_layers
d_head = tracr_model.model_config.key_size
d_mlp = tracr_model.model_config.mlp_hidden_size
act_fn = "relu"
normalization_type = "LN"  if tracr_model.model_config.layer_norm else None
attention_type = "causal"  if tracr_model.model_config.causal else "bidirectional"

n_ctx = tracr_model.params["pos_embed"]['embeddings'].shape[0]
# Equivalent to length of vocab, with BOS and PAD at the end
d_vocab = tracr_model.params["token_embed"]['embeddings'].shape[0]
# Residual stream width, I don't know of an easy way to infer it from the above config.
d_model = tracr_model.params["token_embed"]['embeddings'].shape[1]

# Equivalent to length of vocab, WITHOUT BOS and PAD at the end because we never care about these outputs
# In practice, we always feed the logits into an argmax
# d_vocab_out = tracr_model.params["token_embed"]['embeddings'].shape[0] - 2  # incorrect! fixed below
d_vocab_out = tracr_d_vocab_out

cfg = HookedTransformerConfig(
    n_layers=n_layers,
    d_model=d_model,
    d_head=d_head,
    n_ctx=n_ctx,
    d_vocab=d_vocab,
    d_vocab_out=d_vocab_out,
    d_mlp=d_mlp,
    n_heads=n_heads,
    act_fn=act_fn,
    attention_dir=attention_type,
    normalization_type=normalization_type,
)
tl_model = HookedTransformer(cfg)

Extract the state dict, and do some reshaping so that everything has an n_heads dimension
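For intuition, the pattern `"d_model (n_heads d_head) -> n_heads d_model d_head"` splits the columns of a weight matrix into one contiguous slice per head. The pure-Python sketch below mimics that on a toy matrix; `split_heads` is a hypothetical helper and the tiny dimensions are chosen only for readability.

```python
def split_heads(w, n_heads, d_head):
    """[d_model][n_heads * d_head] -> [n_heads][d_model][d_head]."""
    return [[row[h * d_head:(h + 1) * d_head] for row in w] for h in range(n_heads)]

# Toy weight with d_model=2, n_heads=2, d_head=3; each entry is 10 * row + column.
w = [[10 * r + c for c in range(6)] for r in range(2)]
per_head = split_heads(w, n_heads=2, d_head=3)
print(per_head[0])
print(per_head[1])
```

Head 0 receives columns 0 to 2 and head 1 receives columns 3 to 5, because `n_heads` is the outer (slower-varying) factor of the grouped axis.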

In [None]:
#@title Construct TransformerLens model
# %%
sd = {}
sd["pos_embed.W_pos"] = tracr_model.params["pos_embed"]['embeddings']
sd["embed.W_E"] = tracr_model.params["token_embed"]['embeddings']
# (pos_embed.W_pos has max_seq_len + 1 rows; the extra one is for the BOS position)

# The unembed projects the residual stream onto the output space extracted from Tracr above.
# This is a NumPy array, the rest are Jax Arrays, but w/e it's fine.
# sd["unembed.W_U"] = np.eye(d_model, d_vocab_out)  # incorrect! fixed below
sd["unembed.W_U"] = tracr_unembed_matrix


for l in range(n_layers):
    sd[f"blocks.{l}.attn.W_K"] = einops.rearrange(
        tracr_model.params[f"transformer/layer_{l}/attn/key"]["w"],
        "d_model (n_heads d_head) -> n_heads d_model d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.b_K"] = einops.rearrange(
        tracr_model.params[f"transformer/layer_{l}/attn/key"]["b"],
        "(n_heads d_head) -> n_heads d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
        tracr_model.params[f"transformer/layer_{l}/attn/query"]["w"],
        "d_model (n_heads d_head) -> n_heads d_model d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
        tracr_model.params[f"transformer/layer_{l}/attn/query"]["b"],
        "(n_heads d_head) -> n_heads d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.W_V"] = einops.rearrange(
        tracr_model.params[f"transformer/layer_{l}/attn/value"]["w"],
        "d_model (n_heads d_head) -> n_heads d_model d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.b_V"] = einops.rearrange(
        tracr_model.params[f"transformer/layer_{l}/attn/value"]["b"],
        "(n_heads d_head) -> n_heads d_head",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.W_O"] = einops.rearrange(
        tracr_model.params[f"transformer/layer_{l}/attn/linear"]["w"],
        "(n_heads d_head) d_model -> n_heads d_head d_model",
        d_head = d_head,
        n_heads = n_heads
    )
    sd[f"blocks.{l}.attn.b_O"] = tracr_model.params[f"transformer/layer_{l}/attn/linear"]["b"]

    sd[f"blocks.{l}.mlp.W_in"] = tracr_model.params[f"transformer/layer_{l}/mlp/linear_1"]["w"]
    sd[f"blocks.{l}.mlp.b_in"] = tracr_model.params[f"transformer/layer_{l}/mlp/linear_1"]["b"]
    sd[f"blocks.{l}.mlp.W_out"] = tracr_model.params[f"transformer/layer_{l}/mlp/linear_2"]["w"]
    sd[f"blocks.{l}.mlp.b_out"] = tracr_model.params[f"transformer/layer_{l}/mlp/linear_2"]["b"]
print(sd.keys())

In [None]:
#@title Convert weights to tensors and load into the tl_model
# %%capture
for k, v in sd.items():
    sd[k] = torch.tensor(np.array(v))

tl_model.load_state_dict(sd, strict=False)

In [None]:
#@title Create helper functions to do the tokenization and de-tokenization
# %%
INPUT_ENCODER = tracr_model.input_encoder
OUTPUT_ENCODER = tracr_model.output_encoder

def create_model_input(input, input_encoder=INPUT_ENCODER):
    encoding = input_encoder.encode(input)
    return torch.tensor(encoding).unsqueeze(dim=0)

def decode_model_output(logits, output_encoder=OUTPUT_ENCODER, bos_token=INPUT_ENCODER.bos_token):
    max_output_indices = logits.squeeze(dim=0).argmax(dim=-1)
    decoded_output = output_encoder.decode(max_output_indices.tolist())
    decoded_output_with_bos = [bos_token] + decoded_output[1:]
    return decoded_output_with_bos

In [None]:
#@title We can now run the model!
def get_tl_input_tokens_tensor(input: Iterable)-> torch.Tensor:
  input_list = create_input_list(input)
  return create_model_input(input_list)

def run_tl_model(input: Iterable)-> list:
  tl_input_tokens_tensor = get_tl_input_tokens_tensor(input)
  tl_logits = tl_model(tl_input_tokens_tensor)
  tl_decoded_output = decode_model_output(tl_logits)
  return tl_decoded_output

tracr_decoded_output = run_tracr_model(test_input)
tl_decoded_output = run_tl_model(test_input)
print("Original Tracr Decoding: ", tracr_decoded_output)
print("TransformerLens Decoding:", tl_decoded_output)

In [None]:
#@title Check whether the activations and outputs match
tracr_out = tracr_model.apply(create_input_list(test_input))
tl_input_tokens_tensor = get_tl_input_tokens_tensor(test_input)
logits, cache = tl_model.run_with_cache(tl_input_tokens_tensor)
# (cached all intermediate activations in the model)

for layer in range(tl_model.cfg.n_layers):
    attention_equal = np.isclose(cache["attn_out", layer].detach().cpu().numpy(), np.array(tracr_out.layer_outputs[2*layer])).all()
    mlp_equal = np.isclose(cache["mlp_out", layer].detach().cpu().numpy(), np.array(tracr_out.layer_outputs[2*layer+1])).all()
    assert(attention_equal and mlp_equal)
    print(f"Layer {layer} Attn Out Equality Check:", attention_equal)
    print(f"Layer {layer} MLP Out Equality Check: ", mlp_equal)

outputs_match = tracr_decoded_output == tl_decoded_output
print("Outputs match:", outputs_match)
assert(outputs_match)

In [None]:
# test_activations = cache["mlp_out", -1].detach().cpu().numpy()[0]
# plt.imshow(test_activations, interpolation='none')
# plt.show()

In [None]:
#@title TransformerLens Model
tl_model

# Run tests with TransformerLens and Tracr models

In [None]:
#@title Forward pass
# test_input = "password"
# test_input = "1101"
tracr_output = run_tracr_model(test_input)
tl_output = run_tl_model(test_input)
print(tracr_output)
print(tl_output)

The positions on the vertical axis are all tokens from the output vocab (which is automatically inferred and might be different from input vocab). The horizontal axis is the dimensions of the final residual stream.

In [None]:
#@title Plot the final residual stream

import plotly.express as px
final_residual_stream = cache["resid_post", -1].detach().cpu().numpy()[0]
# px.imshow(final_residual_stream,
# color_continuous_scale="Blues", labels={"x":"Residual Stream", "y":"Position"}, y=[str(i) for i in create_input_list(test_input)]).show("colab" if IN_COLAB else "")

# Setting up visualisations

This part is adapted from https://github.com/google-deepmind/tracr/blob/main/tracr/examples/Visualize_Tracr_Models.ipynb

In [None]:
#@title Plotting functions

plot_height = final_residual_stream.shape[1]//3
plot_width = final_residual_stream.shape[0]

def tidy_label(label, value_width=5):
  if ':' in label:
    label, value = label.split(':')
  else:
    value = ''
  return label + f":{value:>{value_width}}"


def add_residual_ticks(model, value_width=5, x=False, y=True):
  if y:
    plt.yticks(
            np.arange(len(model.residual_labels))+0.5,
            [tidy_label(l, value_width=value_width)
              for l in model.residual_labels],
            family='monospace',
            fontsize=20,
    )
  if x:
    plt.xticks(
            np.arange(len(model.residual_labels))+0.5,
            [tidy_label(l, value_width=value_width)
              for l in model.residual_labels],
            family='monospace',
            rotation=90,
            fontsize=20,
    )


def plot_computation_trace(model,
                           input_labels,
                           residuals_or_outputs,
                           add_input_layer=False,
                           figsize=(plot_width, plot_height),
                           show_last_n_layers=None,
                           residual_original_len=None):
  fig, axes = plt.subplots(nrows=1, ncols=len(residuals_or_outputs), figsize=figsize, sharey=True)
  value_width = max(map(len, map(str, input_labels))) + 1

  for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes)):
    ax.grid(True, lw=0.5)
    plt.sca(ax)
    plt.pcolormesh(layer[0].T, vmin=0, vmax=1)
    if i == 0:
      add_residual_ticks(model, value_width=value_width)
    if show_last_n_layers is not None:
      plt.xticks(
          np.arange(len(input_labels))+0.5,
          input_labels,
          rotation=90,
          fontsize=20,
      )
    if add_input_layer and i == 0:
      if show_last_n_layers is not None:
        title = 'Input'
      else:
        title = 'in'
    else:
      layer_no = i - 1 if add_input_layer else i
      if show_last_n_layers is not None:
        layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'
        title = f'{layer_type}{(residual_original_len-show_last_n_layers)//2 + layer_no // 2 + 1}'
      else:
        layer_type = 'A' if layer_no % 2 == 0 else 'M'
        title = f'{layer_type}{layer_no // 2 + 1}'
    plt.title(title, fontsize=20)


def plot_residuals_and_input(model, inputs, figsize=(plot_width, plot_height), show_last_n_layers=None):
  """Applies model to inputs, and plots the residual stream at each layer."""
  model_out = model.apply(inputs)
  residuals = np.concatenate([model_out.input_embeddings[None, ...],
                              model_out.residuals], axis=0)
  if show_last_n_layers is not None:
    residual_original_len = len(residuals)
    residuals = residuals[-show_last_n_layers:]
    add_input_layer = False
  else:
    residual_original_len = None
    add_input_layer = True
  plot_computation_trace(
      model=model,
      input_labels=inputs,
      residuals_or_outputs=residuals,
      add_input_layer=add_input_layer,
      figsize=figsize,
      show_last_n_layers=show_last_n_layers,
      residual_original_len=residual_original_len)


def plot_layer_outputs(model, inputs, figsize=(plot_width, plot_height), show_last_n_layers=None):
  """Applies model to inputs, and plots the outputs of each layer."""
  model_out = model.apply(inputs)
  plot_computation_trace(
      model=model,
      input_labels=inputs,
      residuals_or_outputs=model_out.layer_outputs,
      add_input_layer=False,
      figsize=figsize)


In [None]:
#@title Forward pass
# test_input_list = ["bos", 3, 4, 1]
test_input_2 = test_input[:]
# test_input_2 = "1000"
test_input_list = create_input_list(test_input_2)
print(test_input_list, "-inputs")
tracr_model.apply(test_input_list).decoded

# Visualizing Tracr Models

In [None]:
#@title Plot final residual stream
plot_residuals_and_input(model=tracr_model, inputs=test_input_list, show_last_n_layers=2)

In [None]:
#@title Plot residual stream
# plots the entire stream - can be slow
plot_residuals_and_input(model=tracr_model, inputs=test_input_list)

In [None]:
#@title Plot layer outputs
# plot_layer_outputs(model=tracr_model, inputs = test_input_list)