In [1]:
from tracr.rasp import rasp

def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
  """Returns vals sorted by < relation on keys.

  Only supports unique keys.

  Example usage:
    sort = make_sort(rasp.tokens, rasp.tokens)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
  """
  smaller = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
  target_pos = rasp.SelectorWidth(smaller).named("target_pos")
  sel_new = rasp.Select(target_pos, rasp.indices, rasp.Comparison.EQ)
  return rasp.Aggregate(sel_new, vals).named("sort")

In [2]:
from tracr.compiler import compiling

sort = make_sort_unique(rasp.tokens, rasp.tokens)
bos = "BOS"
model = compiling.compile_rasp_to_model(
    sort,
    vocab={0, 1, 2, 3},
    max_seq_len=5,
    compiler_bos=bos,
)
print(model.model_config)

TransformerConfig(num_heads=1, num_layers=2, key_size=7, mlp_hidden_size=12, dropout_rate=0.0, activation_function=<jax._src.custom_derivatives.custom_jvp object at 0x7bec213b7350>, layer_norm=False, causal=False)


In [3]:
import haiku as hk
import jax

params = model.params

hk_model = hk.transform(model.get_compiled_model)
hk_model = hk_model.apply(params, jax.random.PRNGKey(42))
hk_model.use_unembed_argmax = False

In [4]:
def forward(x):
    return hk_model(x)

forward = hk.transform(forward)

In [5]:
output = forward.apply(model.params, jax.random.PRNGKey(42), jax.numpy.array([[4, 3, 2, 1]]))
output.unembedded_output

Array([[[0.0000000e+00, 6.2030554e-09, 6.2030554e-09, 6.2030554e-09],
        [0.0000000e+00, 1.0000000e+00, 3.8477894e-17, 3.8477894e-17],
        [0.0000000e+00, 3.8477894e-17, 1.0000000e+00, 3.8477894e-17],
        [0.0000000e+00, 3.8477894e-17, 3.8477894e-17, 1.0000000e+00]]],      dtype=float32)

In [6]:
import haiku as hk
import jax
import jax.numpy as jnp

# TODO: ignore first position
def loss_fn(params, x, y):
  logits = forward.apply(params, jax.random.PRNGKey(42), x).unembedded_output
  labels = jax.nn.one_hot(y, logits.shape[-1])
  return -jnp.sum(labels * jax.nn.log_softmax(logits)) / labels.shape[0]

@jax.jit
def update(params, x, y, lr=0.0001):
  grads = jax.grad(loss_fn)(params, x, y)
  return jax.tree_util.tree_map(
      lambda p, g: p - lr * g, model.params, grads
  )

In [7]:
import tqdm

def train(params, X, Y, n_epochs=1, batch_size=8, lr=0.0001):
    for _ in tqdm.trange(n_epochs):
        for i in range(0, len(X), batch_size):
            x = X[i:i + batch_size]
            y = Y[i:i + batch_size]
            params = update(params, x, y, lr)
    return params

In [8]:
import numpy as np
jax.config.update('jax_default_matmul_precision', 'float32')

acceptedTokens = [1,2,3]
maxSeqLength = 5
size = 1000
X = []
Y = []

for i in range(size):
    # TODO: implement padding for training
    # inputLength = np.random.randint(2, maxSeqLength+1)  #Uniformly distributed between 2 and max length
    inputLength = maxSeqLength

    inputSeq = []
    outputSeq = []
    for t in np.random.choice(acceptedTokens, inputLength):
        inputSeq.append(t)
        outputSeq.append(t)

    inputSeq.insert(0,"BOS")
    outputSeq.sort()
    outputSeq.insert(0,0)

    inputSeq = jax.numpy.array(model.input_encoder.encode(inputSeq))
    outputSeq = jax.numpy.array(model.output_encoder.encode(outputSeq))

    X.append(inputSeq)
    Y.append(outputSeq)

X = jax.numpy.array(X)
Y = jax.numpy.array(Y)

X, Y
print(X.shape)

# Remove duplicates from X, and the corresponding Y
X, indices = np.unique(X, return_index=True, axis=0)
Y = Y[indices]
print(X.shape)

# Split test and validation
split = int(X.shape[0] * 0.8)
X_train, X_test = X[:split], X[split:]
Y_train, Y_test = Y[:split], Y[split:]
print(X_train.shape, X_test.shape)

(1000, 6)
(242, 6)
(193, 6) (49, 6)


In [9]:
new_params = train(model.params, X_train, Y_train, n_epochs=1000, batch_size=8, lr=0.0001)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:10<00:00, 96.60it/s]


In [10]:
forward.apply(new_params, jax.random.PRNGKey(42), jax.numpy.array([[4, 1, 3, 2, 3, 1]])).unembedded_output.argmax(axis=-1)

Array([[0, 0, 2, 2, 3, 0]], dtype=int32)