In [92]:
from tracr.rasp import rasp

def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
  """Builds a RASP program that orders `vals` by ascending `keys`.

  The keys must be unique; duplicate keys are not supported.

  Example usage:
    sort = make_sort_unique(rasp.tokens, rasp.tokens)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.

  Returns:
    An SOp labelled "sort" that aggregates `vals` through the selector built
    from the keys.
  """
  less_than = rasp.Comparison.LT
  equal = rasp.Comparison.EQ

  # Compare the keys against themselves with "<"; the width of that selector
  # is what the program uses as each element's destination position.
  keys_below = rasp.Select(keys, keys, less_than).named("smaller")
  destination = rasp.SelectorWidth(keys_below).named("target_pos")

  # Match every destination against the sequence indices and pull the values
  # through that selector.
  routing = rasp.Select(destination, rasp.indices, equal)
  sorted_vals = rasp.Aggregate(routing, vals)
  return sorted_vals.named("sort")

In [93]:
from tracr.compiler import compiling

# Marker passed to the compiler as the beginning-of-sequence token.
bos = "BOS"

# Sort the tokens using the tokens themselves as the sort keys.
sort = make_sort_unique(vals=rasp.tokens, keys=rasp.tokens)

# Compile the RASP program into a model over a four-symbol vocabulary.
model = compiling.compile_rasp_to_model(
    sort,
    vocab={0, 1, 2, 3},
    max_seq_len=5,
    compiler_bos=bos,
)
print(model.model_config)

TransformerConfig(num_heads=1, num_layers=2, key_size=7, mlp_hidden_size=12, dropout_rate=0.0, activation_function=<jax._src.custom_derivatives.custom_jvp object at 0x795ce4279a90>, layer_norm=False, causal=False)


In [94]:
import haiku as hk
import jax

params = model.params

# Wrap the model's `get_compiled_model` factory in a Haiku transform and apply
# it right away with the compiled parameters and a fixed PRNG key, so that
# `hk_model` is bound exactly once to the object the factory returns.
hk_model = hk.transform(model.get_compiled_model).apply(
    params, jax.random.PRNGKey(42)
)

# NOTE(review): presumably this makes `unembedded_output` carry per-class
# scores instead of argmax indices (the float output shown further down
# suggests so) -- confirm against the tracr model implementation.
hk_model.use_unembed_argmax = False

In [95]:
@hk.transform
def forward(x):
    """Calls `hk_model` on `x` and returns its output unchanged.

    The decorator turns this into a Haiku transformed function, so callers use
    `forward.apply(params, rng, x)` rather than calling it directly.
    """
    return hk_model(x)

In [96]:
# Run one forward pass on an already-encoded batch containing one sequence.
output = forward.apply(
    model.params,
    jax.random.PRNGKey(42),
    jax.numpy.array([[4, 3, 2, 1]]),
)
# Leave the unembedded output as the cell's last expression so it is displayed.
output.unembedded_output

Array([[[0.0000000e+00, 6.2030554e-09, 6.2030554e-09, 6.2030554e-09],
        [0.0000000e+00, 1.0000000e+00, 3.8477894e-17, 3.8477894e-17],
        [0.0000000e+00, 3.8477894e-17, 1.0000000e+00, 3.8477894e-17],
        [0.0000000e+00, 3.8477894e-17, 3.8477894e-17, 1.0000000e+00]]],      dtype=float32)

In [111]:
import haiku as hk
import jax
import jax.numpy as jnp

# TODO: ignore first position
def loss_fn(params, x, y):
  """Cross-entropy between the model's unembedded output and the labels `y`.

  Args:
    params: Parameters passed to `forward.apply`.
    x: Model input passed to `forward.apply`.
    y: Integer class labels; one-hot encoded to the size of the last logits
      axis.

  Returns:
    The negated sum of the log-softmax values selected by the labels, divided
    by the size of the leading axis of the one-hot labels.
  """
  model_out = forward.apply(params, jax.random.PRNGKey(42), x)
  logits = model_out.unembedded_output
  num_classes = logits.shape[-1]

  targets = jax.nn.one_hot(y, num_classes)
  log_probs = jax.nn.log_softmax(logits)

  # Every position currently contributes to the sum (see the TODO above); only
  # the leading axis is used for normalisation.
  leading_size = targets.shape[0]
  return -jnp.sum(targets * log_probs) / leading_size

LEARNING_RATE = 0.0001

# @jax.jit
def update(params, x, y):
  """Takes one plain gradient-descent step and returns the new parameters.

  Args:
    params: Parameter pytree; `loss_fn` is differentiated with respect to it
      and the step is taken from it.
    x: Inputs forwarded unchanged to `loss_fn`.
    y: Targets forwarded unchanged to `loss_fn`.

  Returns:
    A pytree with the same structure as `params` in which every leaf is
    `p - LEARNING_RATE * g`, `g` being the matching gradient leaf.
  """
  grads = jax.grad(loss_fn)(params, x, y)
  # Step from the same `params` the gradient was evaluated at. Reading the
  # global `model.params` here instead would silently ignore the argument
  # whenever the caller passes anything other than the model's own parameters.
  return jax.tree_util.tree_map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

In [116]:
x_untoken = [bos, 2, 1, 3]
# Encode once and build the model input from the result, so the array fed to
# `update` cannot drift from `x_untoken`.
x_encoded = model.input_encoder.encode(x_untoken)
print(x_encoded)
x = jnp.array([x_encoded])
y = jnp.array([[1, 3, 2, 1]])

print("BEFORE")
print(model.apply(x_untoken).decoded)

print("UPDATE")
old_params = model.params
new_params = update(old_params, x, y)

# Swap the updated parameters in only for the evaluation below. The `finally`
# clause restores the compiled parameters even if `model.apply` raises, so a
# failed run never leaves the modified weights on `model` for later cells.
model.params = new_params
try:
  print("AFTER")
  print(model.apply(x_untoken).decoded)
finally:
  model.params = old_params

[4, 2, 1, 3]
BEFORE
['BOS', 1, 2, 3]
UPDATE
AFTER
['BOS', 1, 1, 3]
