# Showcasing different results when encoded vs. decoded

This notebook showcases a problem where, for the same input, the encoded and decoded outputs are different. The problem happens after we train the compiled tranformer model.

This is a minimal example to showcase the issue.

We start by loading an example program.

In [11]:
from tracr.compiler import lib
from tracr.rasp import rasp
from tracr.compiler import compiling

tokens = {1,2,3,4,5,6}
max_size = 5
sort = lib.make_sort(rasp.tokens, rasp.tokens, max_seq_len=max_size, min_key=min(tokens))
model = compiling.compile_rasp_to_model(sort, tokens, max_size, compiler_bos="BOS")

In [12]:
model.apply(["BOS", 4, 3, 5, 6, 1]).decoded

['BOS', 1, 3, 4, 5, 6]

We generate some data to train and evaluate the model.

In [20]:
import numpy as np

np.random.seed(42)

data = []
dataset = set()
data_size = 1000
for i in range(data_size):
    inputLength = np.random.randint(2, max_size+1)

    inputSeq = []
    outputSeq = []
    for t in np.random.choice(list(tokens), inputLength):
        inputSeq.append(t)
        outputSeq.append(t)

    inputSeq.insert(0,"BOS")
    outputSeq.sort()
    outputSeq.insert(0,"BOS")

    if tuple(inputSeq) not in dataset:
        dataset.add(tuple(inputSeq))
        data.append((inputSeq, outputSeq))

print(len(data))
print(data[0])

655
(['BOS', np.int64(4), np.int64(5), np.int64(3), np.int64(5)], ['BOS', np.int64(3), np.int64(4), np.int64(5), np.int64(5)])


In [21]:
import jax.numpy as jnp

inputEncoder = model.input_encoder
outputEncoder = model.output_encoder

print(inputEncoder.encoding_map)
print(outputEncoder.encoding_map)

# outputEncoder does not support pad tokens, so we use a filler token that is ignored by the loss function
fillerToken = next(iter(outputEncoder.encoding_map))

X = []
Y = []

for inputSeq, outputSeq in data:
    x = []
    y = []
    for i in range(max_size+1):
        if i < len(inputSeq):   #Assumes that input is same size as output
            x.append(inputSeq[i])
            y.append(outputSeq[i])
        else:
            x.append("compiler_pad")
            y.append(fillerToken)

    y[0] = fillerToken

    X.append(inputEncoder.encode(x))
    Y.append(outputEncoder.encode(y))

   
X = jnp.array(X)
Y = jnp.array(Y)

split = int(X.shape[0] * 0.90)
X_train, X_test = X[:split], X[split:]
Y_train, Y_test = Y[:split], Y[split:]
split = int(X_train.shape[0] * 0.85)
X_train, X_val = X_train[:split], X_train[split:]
Y_train, Y_val = Y_train[:split], Y_train[split:]

print(X_train.shape, X_val.shape, X_test.shape)
print(Y_train.shape, Y_val.shape, Y_test.shape)

print(X_train[0])
print(Y_train[0])

{1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 'BOS': 6, 'compiler_pad': 7}
{1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5}
(500, 6) (89, 6) (66, 6)
(500, 6) (89, 6) (66, 6)
[6 3 4 2 4 7]
[0 2 3 4 4 0]


Training setup

In [22]:
import jax
from typing import NamedTuple
import haiku as hk
import optax

class TrainingState(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState
    step: jax.Array

def optimiser(lr) -> optax.GradientTransformation:
    return optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(lr),
    )

def forward(x):
    compiled_model = model.get_compiled_model()
    compiled_model.use_unembed_argmax = False
    return compiled_model(x, use_dropout=False)

@hk.without_apply_rng
@hk.transform
def loss_fn(x, y, padToken):
    # Loss is the average negative log-likelihood per token (excluding the first token and padding tokens)
    logits = forward(x).unembedded_output
    log_probs = jax.nn.log_softmax(logits)
    one_hot_targets = jax.nn.one_hot(y, logits.shape[-1])
    log_likelihood = jnp.sum(one_hot_targets * log_probs, axis=-1)
    # Mask the first token (BOS)
    mask = jnp.ones_like(log_likelihood)
    mask = mask.at[:, 0].set(0.0)
    # Mask the padding tokens
    padMask = jnp.where(x!=padToken, mask, 0.0)
    # Return the average negative log-likelihood per token
    return -jnp.mean(log_likelihood * padMask) / jnp.sum(padMask)

@jax.jit
def update(state: TrainingState, x, y, lr: float, padToken) -> TrainingState:
    loss_and_grads_fn = jax.value_and_grad(loss_fn.apply)
    loss, grads = loss_and_grads_fn(state.params, x, y, padToken)
    updates, opt_state = optimiser(lr).update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    metrics = {"step": state.step, "loss": loss}
    return TrainingState(params, opt_state, step=state.step+1), metrics

@jax.jit
def init(initial_params: hk.Params, lr: float) -> TrainingState:
    initial_opt_state = optimiser(lr).init(initial_params)
    return TrainingState(
        params=initial_params,
        opt_state=initial_opt_state,
        step=jnp.array(0),
    )

In [16]:
! pip install tqdm




[notice] A new release of pip is available: 24.0 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [23]:
#Randomize starting weights
PRNGSeq = hk.PRNGSequence(42)
randomParams = jax.tree_util.tree_map(
    lambda p: jax.random.normal(next(PRNGSeq), p.shape), model.params
)
model.params = randomParams

In [24]:
import tqdm

padToken = model.input_encoder.encoding_map["compiler_pad"]
num_epochs = 1500
batch_size = 256
lr = 1e-4
state = init(model.params, lr)

for epoch in tqdm.trange(num_epochs):
    for i in range(0, len(X_train), batch_size):
        x = X_train[i:i + batch_size]
        y = Y_train[i:i + batch_size]
        state, metric = update(state, x, y, lr, padToken)
        model.params = state.params

100%|██████████| 1000/1000 [01:11<00:00, 14.03it/s]


Now, we run the model with both encoded and decoded inputs/outputs to showcase the issue.

In [28]:
forward_fn = hk.without_apply_rng(hk.transform(forward))

different_samples = []
accuracy_encoded = 0
accuracy_decoded = 0
ind = 0                                         #
for sample in tqdm.tqdm(data):
    # Decoded version
    decoded_output = model.apply(sample[0]).decoded

    # Encoded version
#    encoded_input = model.input_encoder.encode(sample[0])
    encoded_input = X[ind]                      #
    logits = forward_fn.apply(model.params, jax.numpy.array([encoded_input])).unembedded_output
    encoded_output = jnp.argmax(logits, axis=-1)[0]
#    encoded_output = ["BOS"] + model.output_encoder.decode(encoded_output.tolist()[1:])
    mask = jnp.ones_like(x)                     #
    mask = mask.at[0].set(0)                    #
    padMask = jnp.where(x!=padToken, mask, 0)   #


    if decoded_output == sample[1]:
        accuracy_decoded += 1
    #if encoded_output == sample[1]:
    if jnp.all(encoded_output*padMask == Y[ind]*padMask):
        accuracy_encoded += 1

    ind += 1                                    #

    #if encoded_output != decoded_output:
    #    different_samples.append((sample[0], encoded_output, decoded_output))

print(f"Accuracy encoded: {accuracy_encoded / len(data)}")
print(f"Accuracy decoded: {accuracy_decoded / len(data)}")

print(f"Different samples: {len(different_samples)}")
if len(different_samples) > 0:
    print(different_samples[0])


100%|██████████| 655/655 [00:45<00:00, 14.50it/s]

Accuracy encoded: 0.004580152671755725
Accuracy decoded: 0.0030534351145038168
Different samples: 0



