In [1]:
import jax
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns 
import haiku as hk

# The default of float16 can lead to discrepancies between outputs of
# the compiled model and the RASP program.
jax.config.update('jax_default_matmul_precision', 'float32')

from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp

from scipy.optimize import linear_sum_assignment

In [6]:
#@title Define RASP programs
def get_program(program_name, max_seq_len):
  """Builds the requested RASP program together with its input vocabulary.

  Args:
    program_name: One of "length", "frac_prevs", "dyck-2", "dyck-3", "sort",
      "sort_unique", "hist", "sort_freq" or "pair_balance".
    max_seq_len: Maximum sequence length. It defines the integer vocabulary
      {1, ..., max_seq_len} of the two sorting programs and is forwarded to
      `lib.make_sort` and `lib.make_sort_freq`.

  Returns:
    A `(program, vocab)` tuple: the RASP program and the set of input tokens.

  Raises:
    NotImplementedError: If `program_name` is not one of the names above.
  """
  # Programs over a small letter vocabulary.
  if program_name == "length":
    return lib.make_length(), {"a", "b", "c", "d"}
  if program_name == "frac_prevs":
    is_x = (rasp.tokens == "x").named("is_x")
    return lib.make_frac_prevs(is_x), {"a", "b", "c", "x"}
  if program_name == "hist":
    return lib.make_hist(), {"a", "b", "c", "d"}
  if program_name == "sort_freq":
    return lib.make_sort_freq(max_seq_len=max_seq_len), {"a", "b", "c", "d"}

  # Bracket-matching programs.
  if program_name == "dyck-2":
    return lib.make_shuffle_dyck(pairs=["()", "{}"]), {"(", ")", "{", "}"}
  if program_name == "dyck-3":
    brackets = {"(", ")", "{", "}", "[", "]"}
    return lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]), brackets
  if program_name == "pair_balance":
    balance = lib.make_pair_balance(
        sop=rasp.tokens, open_token="(", close_token=")")
    return balance, {"(", ")"}

  # Sorting programs operate on the integers 1..max_seq_len.
  if program_name == "sort":
    numbers = set(range(1, max_seq_len + 1))
    sorter = lib.make_sort(
        rasp.tokens, rasp.tokens, max_seq_len=max_seq_len, min_key=1)
    return sorter, numbers
  if program_name == "sort_unique":
    numbers = set(range(1, max_seq_len + 1))
    return lib.make_sort_unique(rasp.tokens, rasp.tokens), numbers

  raise NotImplementedError(f"Program {program_name} not implemented.")

In [7]:
#@title Assemble model
program_name = "sort_unique"  #@param ["length", "frac_prevs", "dyck-2", "dyck-3", "sort", "sort_unique", "hist", "sort_freq", "pair_balance"]
max_seq_len = 10 #@param {label: "Test", type: "integer"}

# Look up the RASP program and the token vocabulary for the selected name.
program, vocab = get_program(program_name=program_name,
                             max_seq_len=max_seq_len)

print("Compiling...")
print(f"   Program: {program_name}")
print(f"   Input vocabulary: {vocab}")
print(f"   Context size: {max_seq_len}")

# NOTE(review): `program_1` is not referenced in the visible cells. The import
# is kept in place in case it is used further down or imported for its side
# effects -- confirm and move it to the import cell or drop it.
from tracr.datasets.generated_lib import program_1

# Compile the RASP program into a transformer model. "bos" and "pad" are the
# special tokens the compiler adds on top of `vocab`.
# NOTE(review): `embedding_size=30` presumably selects a compressed residual
# width -- the parameter listing printed later shows a `w_emb` of shape
# (30, 45) -- TODO confirm against the compiler's documentation.
assembled_model = compiling.compile_rasp_to_model(
      program=program,
      vocab=vocab,
      max_seq_len=max_seq_len,
      causal=False,
      use_dropout=False, 
      embedding_size=30,
      unembed_at_every_layer=True,
      compiler_bos="bos",
      compiler_pad="pad",
      mlp_exactness=100)

print("Done.")

# dict_keys(['token_embed', 'pos_embed', 'transformer/layer_0/attn/query', 'transformer/layer_0/attn/key', 'transformer/layer_0/attn/value', 'transformer/layer_0/attn/linear', 'transformer/layer_0/mlp/linear_1', 'transformer/layer_0/mlp/linear_2', 'transformer/layer_1/attn/query', 'transformer/layer_1/attn/key', 'transformer/layer_1/attn/value', 'transformer/layer_1/attn/linear', 'transformer/layer_1/mlp/linear_1', 'transformer/layer_1/mlp/linear_2'])
# dict_keys(['token_embed', 'pos_embed', 'compressed_transformer/layer_0/attn/query', 'compressed_transformer/layer_0/attn/key', 'compressed_transformer/layer_0/attn/value', 'compressed_transformer/layer_0/attn/linear', 'compressed_transformer/layer_0/mlp/linear_1', 'compressed_transformer/layer_0/mlp/linear_2', 'compressed_transformer/layer_1/attn/query', 'compressed_transformer/layer_1/attn/key', 'compressed_transformer/layer_1/attn/value', 'compressed_transformer/layer_1/attn/linear', 'compressed_transformer/layer_1/mlp/linear_1', 'compressed_transformer/layer_1/mlp/linear_2'])


Compiling...
   Program: sort_unique
   Input vocabulary: {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
   Context size: 10
Done.


In [8]:
#@title Forward pass
# Feed one example sequence through the assembled model and display the
# decoded output tokens. The leading "bos" is the beginning-of-sequence token
# passed to the compiler as `compiler_bos`.
assembled_model.apply(
    ["bos", 3, 4, 1, 7, 2, 6, 5, 8, 9, 10]
).decoded
# Alternative input using a letter vocabulary:
# assembled_model.apply(['bos', 'a', 'b', 'c', 'x']).decoded

['bos', 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [9]:
import jax.numpy as jnp

@hk.transform
def forward_fn(inputs):
    """Runs the compiled transformer on `inputs`.

    The model is obtained from `assembled_model.get_compiled_model()` inside
    the transformed function, so that `forward_fn.init` can create its
    parameters.

    Args:
        inputs: Batch of inputs passed straight to the compiled model; the
            initialisation below uses an int32 array of shape
            (1, max_seq_len).

    Returns:
        Whatever the compiled model returns for `inputs`.
    """
    compiled_model = assembled_model.get_compiled_model()
    return compiled_model(inputs)

# Placeholder input used only to initialise the parameters. Its length follows
# the configured context size instead of a hard-coded 10.
dummy = jnp.zeros((1, max_seq_len), dtype=jnp.int32)
rng = jax.random.PRNGKey(0)
params = forward_fn.init(rng, dummy)

# Display the shape of every parameter array. `jax.tree_map` is a deprecated
# alias; `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda x: x.shape, params)

{'compressed_transformer': {'w_emb': (30, 45)},
 'compressed_transformer/layer_0/attn/key': {'b': (12,), 'w': (45, 12)},
 'compressed_transformer/layer_0/attn/linear': {'b': (45,), 'w': (12, 45)},
 'compressed_transformer/layer_0/attn/query': {'b': (12,), 'w': (45, 12)},
 'compressed_transformer/layer_0/attn/value': {'b': (12,), 'w': (45, 12)},
 'compressed_transformer/layer_0/mlp/linear_1': {'b': (22,), 'w': (45, 22)},
 'compressed_transformer/layer_0/mlp/linear_2': {'b': (45,), 'w': (22, 45)},
 'compressed_transformer/layer_1/attn/key': {'b': (12,), 'w': (45, 12)},
 'compressed_transformer/layer_1/attn/linear': {'b': (45,), 'w': (12, 45)},
 'compressed_transformer/layer_1/attn/query': {'b': (12,), 'w': (45, 12)},
 'compressed_transformer/layer_1/attn/value': {'b': (12,), 'w': (45, 12)},
 'compressed_transformer/layer_1/mlp/linear_1': {'b': (22,), 'w': (45, 22)},
 'compressed_transformer/layer_1/mlp/linear_2': {'b': (45,), 'w': (22, 45)},
 'pos_embed': {'embeddings': (11, 45)},
 'toke