In [1]:

import functools  # Used for creating partial functions
import tqdm.notebook  as tqdm # Used for displaying progress bars
import matplotlib.pyplot as plt  # Used for plotting graphs

In [2]:
import tensorflow as tf# We only import it for the tokenizer
import jax
from jax import numpy as jnp
import jax.random as random
import numpy as np

In [3]:
from flax import linen as nn
import optax
from flax.training.train_state import TrainState
import pandas as pd

In [4]:
jax.config.update("jax_enable_x64", True)

In [5]:
def predict(params, inputs):
    """ Implementation of the MLP architecture."""
    outputs = None
    for w, b in params:
        outputs = jnp.matmul(w, inputs.transpose()) + b
        outputs = outputs.transpose()
        inputs = jax.nn.relu(outputs)
    return outputs

In [6]:
def loss(params, batch):
    """ Implementation of the square loss function."""
    inputs, targets = batch
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets) ** 2)

In [7]:
# prepare data
# here: target (y_train) is a linear function of input (x_train) plus some noise

key = random.PRNGKey(0)
num_examples = 10_000
dim = 100
x_train = random.normal(key, (num_examples, dim))
w = random.normal(key, (dim,))
y_train =jnp.dot(x_train, w) + 0.2 * random.normal(key, (num_examples,))

x_train = x_train.astype(jnp.float32)
y_train = y_train.astype(jnp.float32)

batch = (x_train, y_train)

In [8]:
# initialize model parameters
W1 = jnp.identity(dim)  # identity matrix
b1 = 0.

W2 = random.normal(key, (dim,))
b2 = 0.

params = [(W1, b1), (W2, b2)]  # two layers

In [9]:
loss(params, batch)

Array(48.71164087, dtype=float64)

In [10]:
%timeit loss(params, batch)

6.92 ms ± 1.42 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
# Text Tokenization

In [12]:
def sample_difference_dataset(
        dataset_size: int,
        lengths: list[int],
        k: int,
):
    key = random.PRNGKey(0)  # Initialize the key
    data_all = []
    for length in lengths:
        if length < 2:
            raise ValueError("The length of the expression must be at least two.")

        # Convert JAX array to integer using item()
        key, subkey = random.split(key)
        length_n = random.randint(subkey, (dataset_size,), minval=1, maxval=length - 1)
        length_m = length - 1 - length_n

        # Generate random integers for each length with int64 dtype
        integer_n = []
        integer_m = []
        for len_n, len_m in zip(length_n, length_m):
            key, subkey_n = random.split(key)
            maxval_n = k ** int(len_n) - 1
            integer_n.append(random.randint(subkey_n, (), minval=1, maxval=maxval_n, dtype=jnp.int64))

            key, subkey_m = random.split(key)
            maxval_m = k ** int(len_m) - 1
            integer_m.append(random.randint(subkey_m, (), minval=1, maxval=maxval_m, dtype=jnp.int64))

        # Calculate differences and signs
        diff_sign = [int(x > y) for x, y in zip(integer_n, integer_m)]

        # Convert to strings and create expressions
        integer_n_strs = [str(x.item())[::-1] for x in integer_n]
        integer_m_strs = [str(x.item())[::-1] for x in integer_m]
        expressions = [f"{n}-{m}" for n, m in zip(integer_n_strs, integer_m_strs)]

        # Combine expressions with signs
        data = list(zip(expressions, diff_sign))
        data_all.extend(data)

    # Remove duplicates and shuffle
    data_all = list(set(data_all))
    np.random.shuffle(data_all)
    return data_all

In [13]:
MAX_TRAIN_LENGTH = 10  # the maximum length allowed in the training split
MAX_TEST_LENGTH = 15  # the maximum length allowed in the test split

train_ds = sample_difference_dataset(
        dataset_size=2500,
        lengths=list(range(3, MAX_TRAIN_LENGTH + 1)),
        k=10,
)
test_ds = sample_difference_dataset(
        dataset_size=1000,
        lengths=list(range(MAX_TRAIN_LENGTH + 1, MAX_TEST_LENGTH + 1)),
        k=10,
)

print(f"Train dataset size {len(train_ds)}")
print(f"Test dataset size {len(test_ds)}")

Train dataset size 15769
Test dataset size 5000


In [14]:
it = iter(train_ds)

In [15]:
for _ in range(5):
    text, label = next(it)
    print("text: ", text)
    print("label: ", label)
    print()

text:  57-995
label:  0

text:  816-09
label:  1

text:  0257-73
label:  1

text:  7-689
label:  0

text:  9978-9345
label:  1



In [16]:
# size of corpus to build the tokenizer
corpus_size = 5_000  #@param = 'int'

# size of the vocabulary
vocab_size = 12  #@param = 'int'

# maximum length of examples in tokens
max_len = MAX_TEST_LENGTH + 1  #@param = 'int'

# pad value
pad_value = 0  #@param = 'int'


In [17]:
corpus = [text for text, _ in train_ds[:corpus_size]]

In [18]:
corpus[:10]

['57-995',
 '816-09',
 '0257-73',
 '7-689',
 '9978-9345',
 '06-8',
 '8-90912',
 '5114457-07',
 '1530812-4',
 '933-6']

In [19]:
# now, we build the tokenizer
tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=vocab_size,
        oov_token=None,
        char_level=True,
)
tokenizer.fit_on_texts(corpus)

In [20]:


# note how the tokenizer figured out it was best to tokenize each digit separately
tokenizer.index_word

{1: '-',
 2: '5',
 3: '2',
 4: '3',
 5: '6',
 6: '7',
 7: '8',
 8: '1',
 9: '4',
 10: '9',
 11: '0'}

In [21]:
# Example usage:
print("original text: ", text)

# tokenize text
tokens = tokenizer.texts_to_sequences([text])
print("tokens: ", tokens)
print("number of tokens: ", len(tokens[0]))

original text:  9978-9345
tokens:  [[10, 10, 6, 7, 1, 10, 4, 9, 2]]
number of tokens:  9


In [22]:


# we can see the actual tokens by converting each token individually to text
print(tokenizer.sequences_to_texts([token.tolist() for token in jnp.array(tokens).reshape((-1, 1))]))

['9', '9', '7', '8', '-', '9', '3', '4', '5']


In [23]:
# Let's examine the distribution of tokens in the corpus:
print("Token frequency:")
dict(list(tokenizer.word_counts.items()))

Token frequency:


{'5': 3296,
 '7': 3235,
 '-': 5000,
 '9': 2812,
 '8': 3187,
 '1': 3169,
 '6': 3237,
 '0': 2079,
 '2': 3266,
 '3': 3243,
 '4': 3136}

In [24]:
# preprocessing the data


In [25]:
def preprocess_function(text, label):
    # Tokenize the text
    tokens = tokenizer.texts_to_sequences([text])
    # Pad the sequences
    tokens = tf.keras.preprocessing.sequence.pad_sequences(tokens, maxlen=max_len, padding='post', value=pad_value)
    # Convert to torch tensors of a type int64
    tokens = jnp.array(tokens[0]).astype(jnp.int64)
    label = jnp.array(label).astype(jnp.int64)
    # Return the tokens and label
    return tokens, label


In [26]:
# Apply the preprocessing function to the training and test datasets
print("preprocessing training examples ... ")
x_train = []
y_train = []
for text, label in tqdm.tqdm(train_ds):
    tokens, label = preprocess_function(text, label)
    x_train.append(tokens)
    y_train.append(label)

print("preprocessing test examples ... ")
x_test = []
y_test = []
for text, label in tqdm.tqdm(test_ds):
    tokens, label = preprocess_function(text, label)
    x_test.append(tokens)
    y_test.append(label)

preprocessing training examples ... 


  0%|          | 0/15769 [00:00<?, ?it/s]

preprocessing test examples ... 


  0%|          | 0/5000 [00:00<?, ?it/s]

In [27]:
x_train = jnp.stack(x_train)
y_train = jnp.array(y_train)
x_test = jnp.stack(x_test)
y_test = jnp.array(y_test)

In [28]:

print("x_train.shape: ", x_train.shape)
print("y_train.shape: ", y_train.shape)
print("x_test.shape: ", x_test.shape)
print("y_test.shape: ", y_test.shape)


x_train.shape:  (15769, 16)
y_train.shape:  (15769,)
x_test.shape:  (5000, 16)
y_test.shape:  (5000,)


In [29]:

# let's see what it looks like
x_train[0], y_train[0]

(Array([ 2,  6,  1, 10, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],      dtype=int64),
 Array(0, dtype=int64))

In [30]:
max_len

16

In [31]:
# Transformer Architecture - Classification task

In [32]:
def train(Model, epochs=10, batch_size=32, lr=3e-4, wd=1e-5, **kwargs):
    # Initialize the model
    model = Model(**kwargs)
    params_model = model.init(random.PRNGKey(0), x_train[:128])

    # Initialize the optimizer
    optimizer = optax.adamw(learning_rate=lr, weight_decay=wd)
    opt_state = optimizer.init(params_model)

    # Create a TrainState
    state = TrainState.create(apply_fn=model.apply, params=params_model, tx=optimizer)

    # Define the loss function
    def loss_fn(params, x, y):
        logits = state.apply_fn(params, x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
        return loss.mean()

    grad_fn = jax.value_and_grad(loss_fn)

    @jax.jit
    def train_step(state, x, y):
        loss, grads = grad_fn(state.params, x, y)
        new_state = state.apply_gradients(grads=grads)
        return new_state, loss

    # Report accuracy
    def report(state, x, y):
        logits = state.apply_fn(state.params, x)
        predictions = (logits[:, 1] > 0).astype(jnp.int32)
        return jnp.mean(predictions == y)

    # Training loop
    num_train_batches = len(x_train) // batch_size
    num_test_batches = len(x_test) // batch_size

    print("Training starts...")
    for epoch in range(epochs):
        # Training
        epoch_loss = 0
        for i in range(0, len(x_train), batch_size):
            x_batch = x_train[i:i + batch_size]
            y_batch = y_train[i:i + batch_size]
            state, loss = train_step(state, x_batch, y_batch)
            epoch_loss += loss

        # Evaluation
        train_accuracy = 0
        for i in range(0, len(x_train), batch_size):
            x_batch = x_train[i:i + batch_size]
            y_batch = y_train[i:i + batch_size]
            train_accuracy += report(state, x_batch, y_batch)
        train_accuracy /= num_train_batches

        test_accuracy = 0
        for i in range(0, len(x_test), batch_size):
            x_batch = x_test[i:i + batch_size]
            y_batch = y_test[i:i + batch_size]
            test_accuracy += report(state, x_batch, y_batch)
        test_accuracy /= num_test_batches

        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"Loss: {epoch_loss/num_train_batches:.4f}")
        print(f"Train accuracy: {train_accuracy:.4f}")
        print(f"Test accuracy: {test_accuracy:.4f}")

In [33]:
# Define the self-attention layer
class SelfAttention(nn.Module):
    embed_dim : int

    @nn.compact
    def __call__(self, x):
        # Calculate query, key, and value matrices using linear layers
        query = nn.Dense(self.embed_dim)(x)
        key = nn.Dense(self.embed_dim)(x)
        value = nn.Dense(self.embed_dim)(x)

        # Calculate attention scores (scaled dot-product attention)
        attention_scores = jnp.matmul(query, jnp.swapaxes(key, -2, -1)) / jnp.sqrt(self.embed_dim)
        attention_weights = nn.softmax(attention_scores, axis=-1)

        # Apply attention weights to values
        output = jnp.matmul(attention_weights, value)
        return output


In [34]:
# Define the model
class SimpleTransformer(nn.Module):
    vocab_size: int
    embed_dim: int =128
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)
        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits

In [35]:
kwg = dict(
        vocab_size=vocab_size,
        embed_dim=64,
)
train(SimpleTransformer, **kwg)

Training starts...
Epoch 1/10
Loss: 0.6948
Train accuracy: 0.4961
Test accuracy: 0.4958
Epoch 2/10
Loss: 0.6947
Train accuracy: 0.4996
Test accuracy: 0.5022
Epoch 3/10
Loss: 0.6946
Train accuracy: 0.5029
Test accuracy: 0.4940
Epoch 4/10
Loss: 0.6946
Train accuracy: 0.5037
Test accuracy: 0.4950
Epoch 5/10
Loss: 0.6946
Train accuracy: 0.5048
Test accuracy: 0.4970
Epoch 6/10
Loss: 0.6946
Train accuracy: 0.5053
Test accuracy: 0.4958
Epoch 7/10
Loss: 0.6946
Train accuracy: 0.5059
Test accuracy: 0.4960
Epoch 8/10
Loss: 0.6946
Train accuracy: 0.5056
Test accuracy: 0.4956
Epoch 9/10
Loss: 0.6945
Train accuracy: 0.5053
Test accuracy: 0.4956
Epoch 10/10
Loss: 0.6945
Train accuracy: 0.5053
Test accuracy: 0.4944


In [36]:
# MLP layers

In [37]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    embed_dim: int = 128
    mlp_dim: int = 256
    num_classes: int = 2
    max_seq_length: int = None

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)
        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # MLP layers
        x = nn.Dense(self.mlp_dim)(cls_token)
        x = nn.relu(x)
        logits = nn.Dense(self.num_classes)(x)
        return logits

In [38]:

kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
)
train(SimpleTransformer, **kwg)

Training starts...
Epoch 1/10
Loss: 0.6949
Train accuracy: 0.5099
Test accuracy: 0.4916
Epoch 2/10
Loss: 0.6947
Train accuracy: 0.5107
Test accuracy: 0.4886
Epoch 3/10
Loss: 0.6946
Train accuracy: 0.5071
Test accuracy: 0.4988
Epoch 4/10
Loss: 0.6945
Train accuracy: 0.5048
Test accuracy: 0.4906
Epoch 5/10
Loss: 0.6944
Train accuracy: 0.5097
Test accuracy: 0.4990
Epoch 6/10
Loss: 0.6942
Train accuracy: 0.5127
Test accuracy: 0.4980
Epoch 7/10
Loss: 0.6940
Train accuracy: 0.5154
Test accuracy: 0.4938
Epoch 8/10
Loss: 0.6939
Train accuracy: 0.5166
Test accuracy: 0.4918
Epoch 9/10
Loss: 0.6938
Train accuracy: 0.5185
Test accuracy: 0.4930
Epoch 10/10
Loss: 0.6937
Train accuracy: 0.5204
Test accuracy: 0.4904


In [39]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    max_seq_length: int
    embed_dim: int = 128
    mlp_dim: int = 256
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim) # Scale positional embeddings

        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)

        # MLP layers
        x = nn.Dense(self.mlp_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.num_classes)(x)

        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits



In [40]:
kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
        max_seq_length=max_len,
)
train(SimpleTransformer, **kwg)

Training starts...
Epoch 1/10
Loss: 0.6948
Train accuracy: 0.5123
Test accuracy: 0.5032
Epoch 2/10
Loss: 0.6945
Train accuracy: 0.5147
Test accuracy: 0.5016
Epoch 3/10
Loss: 0.6584
Train accuracy: 0.7629
Test accuracy: 0.5831
Epoch 4/10
Loss: 0.3379
Train accuracy: 0.9078
Test accuracy: 0.5509
Epoch 5/10
Loss: 0.2022
Train accuracy: 0.9385
Test accuracy: 0.6747
Epoch 6/10
Loss: 0.1685
Train accuracy: 0.9464
Test accuracy: 0.7123
Epoch 7/10
Loss: 0.1526
Train accuracy: 0.9484
Test accuracy: 0.7264
Epoch 8/10
Loss: 0.1423
Train accuracy: 0.9490
Test accuracy: 0.7332
Epoch 9/10
Loss: 0.1349
Train accuracy: 0.9514
Test accuracy: 0.7364
Epoch 10/10
Loss: 0.1293
Train accuracy: 0.9533
Test accuracy: 0.7352


##  normalization layer


In [41]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    max_seq_length: int
    embed_dim: int = 128
    mlp_dim: int = 256
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim)  # Scale positional embeddings

        # self-attention layer
        x = SelfAttention(self.embed_dim)(x)

        # We add layer norm
        x = nn.LayerNorm(self.embed_dim)(x)

        # MLP layers
        x = nn.Dense(self.mlp_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.num_classes)(x)

        # We add layer norm
        x = nn.LayerNorm(self.embed_dim)(x)

        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits

In [42]:
kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
        max_seq_length=max_len,
)
train(SimpleTransformer, **kwg)

Training starts...
Epoch 1/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 2/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 3/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 4/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 5/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 6/10
Loss: 0.5774
Train accuracy: 0.9336
Test accuracy: 0.5379
Epoch 7/10
Loss: 0.1889
Train accuracy: 0.9624
Test accuracy: 0.7075
Epoch 8/10
Loss: 0.1195
Train accuracy: 0.9596
Test accuracy: 0.7125
Epoch 9/10
Loss: 0.1017
Train accuracy: 0.9577
Test accuracy: 0.7151
Epoch 10/10
Loss: 0.0933
Train accuracy: 0.9590
Test accuracy: 0.7169


In [43]:
## deeper architectures

In [44]:
class TransformerEncoderBlock(nn.Module):
    embed_dim: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)
        # We add layer norm
        x = nn.LayerNorm(self.embed_dim)(x)
        # MLP layers
        x = nn.Dense(self.mlp_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.embed_dim)(x)
        # We add layer norm
        x = nn.LayerNorm(self.embed_dim)(x)
        return x

In [45]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    max_seq_length: int
    num_layers: int
    mlp_dim: int = 256
    embed_dim: int = 128
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim)  # Scale positional embeddings

        # Stack multiple transformer encoder blocks
        x = nn.Sequential([TransformerEncoderBlock(self.embed_dim, self.mlp_dim)
                           for _ in range(self.num_layers)])(x)
        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits

In [46]:

kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
        max_seq_length=max_len,
        num_layers=3,
)
train(SimpleTransformer, **kwg)

Training starts...
Epoch 1/10
Loss: 0.6946
Train accuracy: 0.4942
Test accuracy: 0.5106
Epoch 2/10
Loss: 0.6946
Train accuracy: 0.4942
Test accuracy: 0.5106
Epoch 3/10
Loss: 0.6946
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 4/10
Loss: 0.6946
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 5/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 6/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 7/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 8/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 9/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 10/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958


In [47]:
# skip connections

In [48]:
class TransformerEncoderBlock(nn.Module):
    embed_dim: int
    mlp_dim: int


    @nn.compact
    def __call__(self, x):
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)
        x = nn.LayerNorm(self.embed_dim)(x)

        # MLP layers
        y = nn.Dense(self.mlp_dim)(x)
        y = nn.relu(y)
        y = nn.Dense(self.embed_dim)(y)

        # We introduce a skip connection
        x = x + y
        x= nn.LayerNorm(self.embed_dim)(x)
        return x

In [49]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    max_seq_length: int
    num_layers: int
    mlp_dim: int = 256
    embed_dim: int = 128
    num_classes: int = 2

    def setup(self):
        # Create list of transformer blocks with unique parameters
        self.transformer_blocks = [
                TransformerEncoderBlock(self.embed_dim, self.mlp_dim)
                for _ in range(self.num_layers)
        ]

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim)

        # Stack multiple transformer encoder blocks
        for block in self.transformer_blocks:
            x = block(x)

        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits

In [50]:

kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
        max_seq_length=max_len,
        num_layers=3,
)
train(SimpleTransformer, **kwg)

Training starts...
Epoch 1/10
Loss: 0.6946
Train accuracy: 0.4942
Test accuracy: 0.5106
Epoch 2/10
Loss: 0.6946
Train accuracy: 0.4942
Test accuracy: 0.5106
Epoch 3/10
Loss: 0.6946
Train accuracy: 0.4942
Test accuracy: 0.5106
Epoch 4/10
Loss: 0.6946
Train accuracy: 0.4942
Test accuracy: 0.5106
Epoch 5/10
Loss: 0.6946
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 6/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 7/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 8/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 9/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958
Epoch 10/10
Loss: 0.6945
Train accuracy: 0.5078
Test accuracy: 0.4958


In [51]:
# Transformer Architecture - Sequence generation task

In [52]:
class AdditionTask:
    def sample_batch(self, batch_size: int, length: int):
        """Returns a batch of additions and their results."""
        if length <= 2:
            raise ValueError("Length must be greater than 2.")
        # We only use `length - 1` tokens for the two values to account for the `+`.
        # Generate random lengths for the two numbers in each addition problem.
        length_n = random.randint(random.PRNGKey(0), (batch_size,), minval=1, maxval=length - 1)
        length_m = length - 1 - length_n

        # Generate random integers for each length with int64 dtype
        integer_n = []
        integer_m = []
        key = random.PRNGKey(0)
        for len_n, len_m in zip(length_n, length_m):
            key, subkey_n = random.split(key)
            maxval_n = 10 ** int(len_n) - 1
            integer_n.append(random.randint(subkey_n, (), minval=1, maxval=maxval_n, dtype=jnp.int64))

            key, subkey_m = random.split(key)
            maxval_m = 10 ** int(len_m) - 1
            integer_m.append(random.randint(subkey_m, (), minval=1, maxval=maxval_m, dtype=jnp.int64))

        # Calculate the sum of the generated integers.
        integer_sum = list(map(sum, zip(integer_n, integer_m)))

        # Convert integers to reversed strings (e.g., 123 becomes "321").
        knary_n = [str(x)[::-1] for x in integer_n]
        knary_m = [str(x)[::-1] for x in integer_m]

        # Create the addition expressions by concatenating the reversed strings with a "+".
        expressions = [f"{a}+{b}" for a, b in zip(knary_n, knary_m)]

        # Pad the expressions with zeros to reach the desired length.
        expressions = [a + "".join(["0"] * (length - len(a))) for a in expressions]

        # Convert the sums to strings.
        results = list(map(str, integer_sum))
        # Append "#" to the results and pad with zeros.
        results = [
                res + "#" + "".join(["0"] * (length - len(res))) for res in results
        ]
        return {
                "input": expressions,
                "output": results,
        }

    @property
    def input_size(self) -> int:
        """Returns the input size for the models."""
        return 12

    @property
    def output_size(self) -> int:
        """Returns the output size for the models."""
        return 12

    @property
    def vocab_size(self) -> int:
        """Returns the output size for the models."""
        return 12

    def output_length(self, input_length: int) -> int:
        return input_length + 1

In [53]:

# # Instantiate an AdditionTask object. This object will handle data generation for our addition task.
task = AdditionTask()

MAX_TRAIN_LENGTH = 10
MAX_TEST_LENGTH = 20

# Generate a sample batch of addition problems.
data = task.sample_batch(batch_size=16, length=MAX_TRAIN_LENGTH)

In [54]:
pd.DataFrame(data)

Unnamed: 0,input,output
0,2+00239732,23793202#00
1,78749186+4,68194791#00
2,976584+919,486598#0000
3,436635+705,537141#0000
4,5+72286476,67468232#00
5,33516+6677,69299#00000
6,974+906346,644088#0000
7,6202483+22,3842048#000
8,1783+11382,32182#00000
9,0589909+98,9099939#000


In [55]:
# Define and fit the tokenizer
tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=None,
        oov_token=None,
        char_level=True,
)
tokenizer.fit_on_texts(data["input"] + data["output"])
tokenizer.word_index

{'0': 1,
 '2': 2,
 '9': 3,
 '8': 4,
 '6': 5,
 '3': 6,
 '4': 7,
 '7': 8,
 '5': 9,
 '+': 10,
 '#': 11,
 '1': 12}

In [211]:
def preprocess_data(batch, tokenizer):
    """Tokenizes and pads the input and output sequences for the model."""
    # Tokenize the input sequences using the provided tokenizer.
    tokens_input = jnp.array(tokenizer.texts_to_sequences(batch["input"]), dtype=jnp.float32) - 1  # Changed to int32
    tokens_output = jnp.array(tokenizer.texts_to_sequences(batch["output"]), dtype=jnp.float32) - 1  # Changed to int32

    # Pad the sequences to the maximum length within the batch for consistent tensor shapes.
    tokens_input = jnp.pad(
            tokens_input,
            pad_width=((0, 0), (0, tokens_input.shape[1] - tokens_input.shape[1])),
            mode='constant',
            constant_values=0
    )
    tokens_output = jnp.pad(
            tokens_output,
            pad_width=((0, 0), (0, tokens_output.shape[1] - tokens_output.shape[1])),
            mode='constant',
            constant_values=0
    )
    return dict(input=tokens_input, output=tokens_output)

In [212]:
preprocess_data(data, tokenizer)

{'input': Array([[ 1.,  9.,  0.,  0.,  1.,  5.,  2.,  7.,  5.,  1.],
        [ 7.,  3.,  7.,  6.,  2., 11.,  3.,  4.,  9.,  6.],
        [ 2.,  7.,  4.,  8.,  3.,  6.,  9.,  2., 11.,  2.],
        [ 6.,  5.,  4.,  4.,  5.,  8.,  9.,  7.,  0.,  8.],
        [ 8.,  9.,  7.,  1.,  1.,  3.,  4.,  6.,  7.,  4.],
        [ 5.,  5.,  8., 11.,  4.,  9.,  4.,  4.,  7.,  7.],
        [ 2.,  7.,  6.,  9.,  2.,  0.,  4.,  5.,  6.,  4.],
        [ 4.,  1.,  0.,  1.,  6.,  3.,  5.,  9.,  1.,  1.],
        [11.,  7.,  3.,  5.,  9., 11., 11.,  5.,  3.,  1.],
        [ 0.,  8.,  3.,  2.,  2.,  0.,  2.,  9.,  2.,  3.],
        [ 1.,  7.,  9.,  0.,  1.,  4.,  8.,  8.,  1.,  5.],
        [11.,  6.,  0.,  2.,  8.,  0.,  0.,  3.,  9.,  1.],
        [ 0.,  1.,  1.,  5.,  1.,  9.,  7.,  3.,  2.,  6.],
        [ 6.,  7.,  3.,  1.,  1.,  9.,  6.,  1.,  0.,  1.],
        [ 8.,  1.,  1.,  0.,  6.,  5.,  4.,  4.,  9.,  5.],
        [ 0.,  3.,  2.,  9., 11.,  7.,  1.,  8.,  0.,  8.]], dtype=float32),
 'output': Arr

In [213]:
def _pointwise_loss_fn(output: jnp.array, target: jnp.array) -> jnp.array:
    """Calculates the pointwise cross-entropy loss between predicted probabilities and the true target values.

  This function computes the loss for each token in the sequence individually.
  """
    target = target.astype(jnp.int32)
    target_one_hot = jax.nn.one_hot(target, num_classes=output.shape[-1]).astype(jnp.float32)
    return target_one_hot * jax.nn.log_softmax(output, axis=-1)

In [214]:
def loss_fn(output, target):
    return -jnp.mean(jnp.sum(_pointwise_loss_fn(output, target), axis=-1))

In [215]:

batch_size = 2
sequence_length = 3
num_classes = 4  # Let's say we have 4 possible tokens

output = jnp.array([
        [[1.0, 2.0, 3.0, 0.5], [0.1, 0.5, 1.5, 2.0], [2.5, 1.0, 0.2, 0.1]],  # Sequence 1
        [[0.2, 0.3, 0.5, 0.1], [1.0, 1.5, 2.0, 0.5], [0.1, 0.2, 0.5, 2.5]],  # Sequence 2
], dtype=jnp.float32)

target = jnp.array(
        [[0,  3,  2], [1, 0, 0],]
)

pointwise_loss = _pointwise_loss_fn(output, target)
print("Pointwise Loss:\n", pointwise_loss)

Pointwise Loss:
 [[[-2.4607735 -0.        -0.        -0.       ]
  [-0.        -0.        -0.        -0.6827076]
  [-0.        -0.        -2.6464982 -0.       ]]

 [[-0.        -1.3724415 -0.        -0.       ]
  [-1.7873387 -0.        -0.        -0.       ]
  [-2.6824024 -0.        -0.        -0.       ]]]


In [216]:
def _accuracy_fn(output: jnp.array, target: jnp.array) -> jnp.array:
    return jnp.mean(jnp.argmax(output, axis=-1) == target).astype(jnp.float32)

def accuracy_fn(output, target):
    acc = _accuracy_fn(output, target)
    return jnp.mean(acc)

In [217]:
def _apply_loss_and_metrics_fn(
        state,
        batch: dict[str, jnp.array],

):
    logits = state.apply_fn(batch["inputs"], batch["output"])
    loss = loss_fn(logits, batch["output"])
    accuracy = accuracy_fn(logits, batch["output"])
    return loss, (accuracy)

grad_fn = jax.value_and_grad(_apply_loss_and_metrics_fn, argnums=1, has_aux=True)

In [218]:
def _update_fn( batch, state):
    (loss, (accuracy)), grads = grad_fn(state,batch)
    new_state = state.apply_gradients(grads=grads)
    return new_state, (loss, accuracy)

In [219]:
def run_training(
        *,
        task,
        model,
        max_sequence_length: int,
        train_steps: int = 10_000,
        seed: int = 0,  # Used to sample during forward pass (e.g. from final logits).
        model_init_seed: int = 0,  # Used to initialize model parameters.
        log_frequency: int = 50,
        batch_size: int = 128,
        learning_rate: float = 1e-3,
        max_grad_norm: float = 1.0,
):

    # Sample a batch to fit the tokenizer
    dummy_batch = task.sample_batch(
            length=max_sequence_length,
            batch_size=256,
    )

    # Define and fit the tokenizer
    tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=None, oov_token=None, char_level=True,
    )
    tokenizer.fit_on_texts(dummy_batch["input"] + dummy_batch["output"])

    print(f"The tokenizer index is: {tokenizer.word_index}")


    # Initialize the model
    choices = jnp.array(range(3 , max_sequence_length + 1))
    length = random.choice(random.PRNGKey(model_init_seed) , choices).item()
    batch_init = task.sample_batch(length=length, batch_size=batch_size)
    batch_init = preprocess_data(batch_init, tokenizer)
    params_model = model.init(random.PRNGKey(0), batch_init["input"], batch_init["output"])

    # Initialize the optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    opt_state = optimizer.init(params_model)

    # Create a TrainState
    state = TrainState.create(apply_fn=model.apply, params=params_model, tx=optimizer)

    results = []
    for step in tqdm.tqdm(range(train_steps + 1)):
        choices = jnp.array(range(3 , max_sequence_length + 1))
        length = random.choice(random.PRNGKey(seed) , choices).item()

        # Randomness handled by either torch, python.random or numpy.
        train_batch = task.sample_batch(length=length, batch_size=batch_size)
        train_batch = preprocess_data(train_batch, tokenizer)
        # Update the parameters.
        state, (train_loss, train_accuracy) = _update_fn(train_batch, state)

        # Log the training metrics
        if (log_frequency > 0) and (step % log_frequency == 0):
            log_data = {
                    "step": step,
                    "train_loss": float(train_loss),
                    "train_accuracy": float(train_accuracy),
            }
            print(log_data)
            results.append(log_data)

    df_results = pd.DataFrame(results)
    return df_results, params, tokenizer


In [220]:
# Eval function

In [221]:
def run_evaluation(*, state, tokenizer, task,  max_test_length: int = 20,  # The largest sequence length to evaluate on
                   total_batch_size: int = 512,
                   sub_batch_size: int = 64,  # We use this to avoid memory overflow.
                   seed: int = 1,
                   is_autoregressive: bool = False,):
    results = []
    lengths = range(3, max_test_length + 1)

    for length in tqdm.tqdm(lengths, desc="Lengths"):
        sub_accuracies = []
        # Evaluate on multiple sub-batches to avoid memory overflow.
        for _ in range(total_batch_size // sub_batch_size):
            # Generate a batch of addition problems with the current length.
            batch = task.sample_batch(sub_batch_size, length)
            batch = preprocess_data(batch, tokenizer)

            outputs = state.apply_fn(batch["input"], batch["output"])

            # Calculate the accuracy for the current sub-batch.
            sub_accuracies.append(accuracy_fn(outputs, batch["output"]))

        # Calculate the average accuracy for the current length.
        log_data = {
                "length": length,
                "accuracy": np.mean(sub_accuracies),
        }
        print(log_data)
        results.append(log_data)

    # Return the results as a pandas DataFrame.
    return pd.DataFrame(results)

In [222]:
# Architecture - Encoder

In [223]:
class BaseTransformerEncoder(nn.Module):
    max_seq_length: int
    num_layers: int
    vocab_size: int
    embed_dim: int = 128
    mlp_dim: int = 256

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim)

        # Stack multiple transformer encoder blocks
        for _ in range(self.num_layers):
            x = TransformerEncoderBlock(self.embed_dim, self.mlp_dim)(x)

        return x

In [224]:
# Architecture - Cross attention
class CrossAttention(nn.Module):
    embed_dim: int

    @nn.compact
    def __call__(self, inputs_q, inputs_kv):
        # Calculate query, key, and value matrices using linear layers
        query = nn.Dense(self.embed_dim)(inputs_q)
        key = nn.Dense(self.embed_dim)(inputs_kv)
        value = nn.Dense(self.embed_dim)(inputs_kv)

        print(f"query : {inputs_q.shape}")
        print(f"key : {inputs_kv.shape}")

        # Calculate attention scores (scaled dot-product attention)
        attention_scores = jnp.matmul(query, jnp.transpose(key, (0, 2, 1))) / jnp.sqrt(self.embed_dim)
        attention_weights = nn.softmax(attention_scores, axis=-1)

        # Apply attention weights to values
        output = jnp.matmul(attention_weights, value)
        return output

In [225]:
def shift_right(x: jnp.array, vocab_size: int):
    x = x.astype(jnp.int32)
    if x.ndim == 1:
        x = x.unsqueeze(1)  # Ensure 2D shape (batch, seq)
    x_shifted = x[:, :-1]  # Remove last token
    # Pad sequence axis (axis=1) with 1 at start, 0 at end
    return jnp.pad(x_shifted, ((0, 0), (1, 0)), mode='constant', constant_values=vocab_size - 1)

In [226]:

# Define the output size (vocabulary size).
output_size = 4  # Tokens will be 0, 1, 2, 3

# Create a one-hot encoded toy input sequence.
# Represents the sequence [2, 3, 0]
x = jnp.array([[
        [0, 0, 1, 0],  # Token 2
        [0, 0, 0, 1],  # Token 3
        [1, 0, 0, 0]   # Token 0
]], dtype=jnp.float32)
print(f"input : \n{x}\n")

# Compute the argmax
x = jnp.argmax(x, axis=-1)
print(f"argmax : {x}\n")

padded = jnp.pad(x[:, :-1], (1, 0), mode='constant', constant_values=output_size)
print(f"padded : \n{padded}\n")

input : 
[[[0. 0. 1. 0.]
  [0. 0. 0. 1.]
  [1. 0. 0. 0.]]]

argmax : [[2 3 0]]

padded : 
[[4 4 4]
 [4 2 3]]



In [227]:
# Decoder architecture
class TransformerDecoderBlock(nn.Module):
    embed_dim:int
    mlp_dim: int

    @nn.compact
    def __call__(self, enc_emb, dec_emb):
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(dec_emb)
        x = nn.LayerNorm(self.embed_dim)(x)

        # MLP layers
        y =CrossAttention(self.embed_dim)(x, enc_emb)
        y = nn.LayerNorm(self.embed_dim)(y)

        # MLP layers
        y = nn.Dense(self.mlp_dim)(y)
        y = nn.relu(y)
        y = nn.Dense(self.embed_dim)(y)

        # Layer normalization
        y = nn.LayerNorm(self.embed_dim)(y)
        return y

In [228]:
class BaseTransformerDecoder(nn.Module):
    vocab_size: int
    max_seq_length: int
    num_layers: int
    mlp_dim: int = 256
    embed_dim: int = 128

    @nn.compact
    def __call__(self, encoded, targets):
        x = shift_right(targets, self.vocab_size)

        # Embedding layer
        x = nn.Embed(self.vocab_size + 1, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim)
        print(f"x : {x.shape}")
        # Stack multiple transformer decoder blocks
        for _ in range(self.num_layers):
            x = TransformerDecoderBlock(self.embed_dim, self.mlp_dim)(encoded, x)
        return x

In [229]:
# Architecture Encoder-Decoder

class BaseTransformer(nn.Module):
    num_layers: int
    max_seq_length: int
    vocab_size: int
    embed_dim: int = 32

    @nn.compact
    def __call__(self, inputs, targets):
        encoded = BaseTransformerEncoder(
                max_seq_length=self.max_seq_length,
                num_layers=self.num_layers,
                vocab_size=self.vocab_size,
                embed_dim=self.embed_dim,
        )(inputs)

        decoded = BaseTransformerDecoder(
                max_seq_length=self.max_seq_length,
                num_layers=self.num_layers,
                vocab_size=self.vocab_size,
                embed_dim=self.embed_dim,
        )(encoded, targets)

        logits = nn.Dense(self.vocab_size)(decoded)
        return logits


In [230]:
# Define the model
model = BaseTransformer(
        num_layers=1,
        vocab_size=task.vocab_size,
        max_seq_length=MAX_TEST_LENGTH + 2,
)

In [231]:
# Run the training loop
df_train, params, tokenizer = run_training(
        max_sequence_length=MAX_TRAIN_LENGTH,
        task=task,
        model=model,
        batch_size=128,
        train_steps=2_500
)

The tokenizer index is: {'0': 1, '3': 2, '9': 3, '5': 4, '7': 5, '4': 6, '1': 7, '6': 8, '8': 9, '2': 10, '+': 11, '#': 12}


ValueError: Input type must be an integer or unsigned integer.