In [4]:
from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp
     
import torch as t
from torch import Tensor, nn , optim
from torch.utils.data import DataLoader
import plotly.express as px
from tqdm.notebook import tqdm
from dataclasses import dataclass
import pprint
import plotly.express as px
import einops


from utils.data import *
from utils.model import *

# Prefer the GPU when PyTorch reports one as available; otherwise use the CPU.
device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')

# Tracr sorting model
The program translates the inputs into unique keys and then sorts them with a unique-key sort, which is implemented as follows:
```python

```

In [2]:

def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
  """Builds a RASP expression that orders `vals` ascending by `keys`.

  Keys must be pairwise distinct; ties are not handled here.

  The construction has three steps: a selector comparing keys against keys
  with `<`, the width of that selector (used as each element's destination
  index), and an aggregate that routes every value to the position whose
  index equals that destination.

  Example usage:
    sort = make_sort_unique(rasp.tokens, rasp.tokens)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.

  Returns:
    The aggregate SOp, labelled "sort".
  """
  # Selector relating each position to the keys that compare as smaller.
  keys_below = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
  # Width of that selector per position = destination index in sorted order.
  destination = rasp.SelectorWidth(keys_below).named("target_pos")
  # Match each output index with the element whose destination equals it.
  route_to_destination = rasp.Select(
      destination, rasp.indices, rasp.Comparison.EQ)
  sorted_vals = rasp.Aggregate(route_to_destination, vals)
  return sorted_vals.named("sort")

def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
              min_key: float) -> rasp.SOp:
    """Sorts `vals` by `keys` after adding a position-dependent offset to each key.

    The key at index ``i`` is replaced by ``key + min_key * i / max_seq_len``
    and the adjusted keys are handed to ``lib.make_sort_unique``. The offset
    is what separates equal keys; note that it is zero for every position
    when ``min_key`` is 0.

    Args:
      vals: Values to sort.
      keys: Keys for sorting (must support ``+`` with a number).
      max_seq_len: Divisor of the positional offset.
      min_key: Scale of the positional offset.

    Returns:
      The SOp returned by ``lib.make_sort_unique`` for the adjusted keys.
    """
    def offset_by_position(x, i):
        # Shift key `x` by a fraction of `min_key` that grows with index `i`.
        return x + min_key * i / max_seq_len

    adjusted_keys = rasp.SequenceMap(offset_by_position, keys, rasp.indices)
    return lib.make_sort_unique(vals, adjusted_keys)

In [2]:
program_name = 'sort'
input_size = 10  # sequence length; also passed as max_seq_len below

# Token vocabulary: the integers 0 .. input_size*10 - 1.
vocab = set(range(input_size * 10))

# NOTE(review): per the make_sort definition shown in this notebook, the
# tie-breaking offset is min_key * i / max_seq_len, which is always 0 when
# min_key=0 -- duplicate tokens would then share a key. Confirm that inputs
# are guaranteed to be unique, or that this is intentional.
program = lib.make_sort(
    rasp.tokens,
    rasp.tokens,
    max_seq_len=input_size,
    min_key=0,
)

# Compile the RASP program into a (non-causal) transformer.
assembled_model = compiling.compile_rasp_to_model(
    program=program,
    vocab=vocab,
    max_seq_len=input_size,
    causal=False,
    compiler_bos="bos",
    compiler_pad="pad",
    mlp_exactness=100,
)

CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Tracr model Config
After compiling, we see that the model has 3 layers, which is fewer than the maximum sequence length, meaning that the model is able to sort sequences longer than its number of layers.

In [3]:
# Pretty-print the configuration of the transformer produced by the Tracr compiler.
tracr_model_config = assembled_model.model_config
pprint.pprint(tracr_model_config)

TransformerConfig(num_heads=1,
                  num_layers=3,
                  key_size=102,
                  mlp_hidden_size=1000,
                  dropout_rate=0.0,
                  activation_function=<jax._src.custom_derivatives.custom_jvp object at 0x7fd51cd19810>,
                  layer_norm=False,
                  causal=False)


In [11]:
# Build two models with the same helper from the compiled Tracr program:
#   * `model`    -- trained from scratch below, or restored from a checkpoint;
#   * `tr_model` -- filled with the weights compiled by Tracr, as a reference.
model, cfg = empty_model_from_tracr(assembled_model)
tr_model, cfg = empty_model_from_tracr(assembled_model)
tr_model = load_tracr_weights(tr_model, assembled_model, cfg)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)
losses = []  # per-step training loss, filled only when TRAIN is True

# Set to True to retrain; otherwise the saved weights are loaded from disk.
TRAIN = False

if TRAIN:
    epochs = 100
    batch_size = 1024

    for epoch in tqdm(range(epochs)):
        for input_seq, target_seq in train_loader(batch_size*10, len(vocab), input_size, batch_size):
            optimizer.zero_grad()
            output = model(input_seq)
            # Flatten output and target so every position is scored as an
            # independent classification by the cross-entropy loss.
            loss = criterion(output.view(-1, output.shape[-1]), target_seq.view(-1))
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    # NOTE(review): the trained weights are not written back to disk here, so
    # a retrained model is lost when the kernel restarts -- confirm intent.
else:
    save_path = 'models/sort.pth'

    # map_location remaps the stored tensors onto the device available in this
    # session, so a checkpoint saved on a GPU still loads on a CPU-only machine.
    model.load_state_dict(t.load(save_path, map_location=device))

I might have lost the data for the original training loss curve

In [32]:
# Draw one example and compare the prediction of each model against the target.
x, y = generate_data(1, len(vocab), input_size)

# argmax(-1) takes the top-scoring index along the last axis; squeeze(0)
# removes the leading size-1 axis (presumably the batch of one requested above).
pred, pred_assembled = (
    net(x).argmax(-1).squeeze(0) for net in (model, tr_model)
)

# NOTE(review): in the recorded run the "Tracr out" row is not the sorted
# input. Presumably tr_model needs Tracr's own input/output encoding (e.g. a
# leading BOS token) -- verify before drawing conclusions from it.
for row_label, row_values in (
    ('Data in:\t', x.squeeze().tolist()),
    ('Data out:\t', y.squeeze().tolist()),
    ('Model out:\t', pred.tolist()),
    ('Tracr out:\t', pred_assembled.tolist()),
):
    print(row_label, row_values)

Data in:	 [24, 88, 68, 53, 97, 16, 41, 32, 49, 51]
Data out:	 [16, 24, 32, 41, 49, 51, 53, 68, 88, 97]
Model out:	 [16, 24, 32, 41, 49, 51, 53, 68, 88, 97]
Tracr out:	 [10, 0, 1, 2, 3, 4, 5, 6, 7, 8]


In [35]:
# Fresh example; this time keep the intermediate activations of both models.
x, y = generate_data(1, len(vocab), input_size)

# run_with_cache returns the logits together with a cache object
# (printed in a later cell).
logits, cache = model.run_with_cache(x)
logits_tr, cache_tr = tr_model.run_with_cache(x)

print('Running models with cache')
for row_label, row_tensor in (
    ('Data in:\t', x),
    ('Data out:\t', y),
    ('Model out:\t', logits.argmax(-1)),
    ('Tracr out:\t', logits_tr.argmax(-1)),
):
    # squeeze() drops all size-1 axes so each row prints as a flat list.
    print(row_label, row_tensor.squeeze().tolist())

Original in:	 [85, 63, 28, 48, 73, 82, 33, 94, 70, 90]
Original out:	 [28, 33, 48, 63, 70, 73, 82, 85, 90, 94]
Model out:	 [28, 33, 48, 63, 70, 73, 82, 85, 90, 94]
Tracr out:	 [10, 0, 1, 2, 3, 4, 5, 6, 7, 8]


In [25]:
# Show what each activation cache recorded, trained model first.
for cache_label, cache_obj in (
    ('Model cache: ', cache),
    ('Tracr cache: ', cache_tr),
):
    print(cache_label, cache_obj)

Model cache:  ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'blocks.2.attn.hook_q', 'blocks.2.attn.hook_k', 'blocks.2.attn.hook_v', 'blocks.2.attn.hook_attn_scores', 'blocks.2.attn.hook_pattern', 'blocks.2.attn.hook_z', 'blocks.2.hook_attn_out', 'blocks.2.hook_resid_mid',

In [26]:
# Index -1 selects the last entry along the leading axis of each logits
# tensor (presumably the single sequence in the batch, not a layer), and
# tolist() converts it to nested Python lists for plotting.
final_layer_logits, final_layer_logits_tr = (
    batch_logits[-1].tolist() for batch_logits in (logits, logits_tr)
)

# Heatmap of the trained model's logits; the Tracr model is plotted separately.
px.imshow(final_layer_logits, title="Unified Transformer logits")


In [27]:

# Heatmap of the Tracr-weight model's logits, for comparison with the trained model.
px.imshow(
    img=final_layer_logits_tr,
    title="Tracr logits",
)