In [None]:

import functools  # Used for creating partial functions
import tqdm.notebook  as tqdm # Used for displaying progress bars
import matplotlib.pyplot as plt  # Used for plotting graphs

In [None]:
import tensorflow as tf# We only import it for the tokenizer
import jax
from jax import numpy as jnp
import jax.random as random
import numpy as np

In [73]:
from flax import linen as nn
import optax
from flax.training.train_state import TrainState
import pandas as pd

In [None]:
jax.config.update("jax_enable_x64", True)

In [None]:
def predict(params, inputs):
    """ Implementation of the MLP architecture."""
    outputs = None
    for w, b in params:
        outputs = jnp.matmul(w, inputs.transpose()) + b
        outputs = outputs.transpose()
        inputs = jax.nn.relu(outputs)
    return outputs

In [None]:
def loss(params, batch):
    """ Implementation of the square loss function."""
    inputs, targets = batch
    predictions = predict(params, inputs)
    return jnp.mean((predictions - targets) ** 2)

In [None]:
# prepare data
# here: target (y_train) is a linear function of input (x_train) plus some noise

key = random.PRNGKey(0)
num_examples = 10_000
dim = 100
x_train = random.normal(key, (num_examples, dim))
w = random.normal(key, (dim,))
y_train =jnp.dot(x_train, w) + 0.2 * random.normal(key, (num_examples,))

x_train = x_train.astype(jnp.float32)
y_train = y_train.astype(jnp.float32)

batch = (x_train, y_train)

In [None]:
# initialize model parameters
W1 = jnp.identity(dim)  # identity matrix
b1 = 0.

W2 = random.normal(key, (dim,))
b2 = 0.

params = [(W1, b1), (W2, b2)]  # two layers

In [None]:
loss(params, batch)

In [None]:
%timeit loss(params, batch)

In [None]:
# Text Tokenization

In [None]:
def sample_difference_dataset(
        dataset_size: int,
        lengths: list[int],
        k: int,
):
    key = random.PRNGKey(0)  # Initialize the key
    data_all = []
    for length in lengths:
        if length < 2:
            raise ValueError("The length of the expression must be at least two.")

        # Convert JAX array to integer using item()
        key, subkey = random.split(key)
        length_n = random.randint(subkey, (dataset_size,), minval=1, maxval=length - 1)
        length_m = length - 1 - length_n

        # Generate random integers for each length with int64 dtype
        integer_n = []
        integer_m = []
        for len_n, len_m in zip(length_n, length_m):
            key, subkey_n = random.split(key)
            maxval_n = k ** int(len_n) - 1
            integer_n.append(random.randint(subkey_n, (), minval=1, maxval=maxval_n, dtype=jnp.int64))

            key, subkey_m = random.split(key)
            maxval_m = k ** int(len_m) - 1
            integer_m.append(random.randint(subkey_m, (), minval=1, maxval=maxval_m, dtype=jnp.int64))

        # Calculate differences and signs
        diff_sign = [int(x > y) for x, y in zip(integer_n, integer_m)]

        # Convert to strings and create expressions
        integer_n_strs = [str(x.item())[::-1] for x in integer_n]
        integer_m_strs = [str(x.item())[::-1] for x in integer_m]
        expressions = [f"{n}-{m}" for n, m in zip(integer_n_strs, integer_m_strs)]

        # Combine expressions with signs
        data = list(zip(expressions, diff_sign))
        data_all.extend(data)

    # Remove duplicates and shuffle
    data_all = list(set(data_all))
    np.random.shuffle(data_all)
    return data_all

In [None]:
MAX_TRAIN_LENGTH = 10  # the maximum length allowed in the training split
MAX_TEST_LENGTH = 15  # the maximum length allowed in the test split

train_ds = sample_difference_dataset(
        dataset_size=2500,
        lengths=list(range(3, MAX_TRAIN_LENGTH + 1)),
        k=10,
)
test_ds = sample_difference_dataset(
        dataset_size=1000,
        lengths=list(range(MAX_TRAIN_LENGTH + 1, MAX_TEST_LENGTH + 1)),
        k=10,
)

print(f"Train dataset size {len(train_ds)}")
print(f"Test dataset size {len(test_ds)}")

In [None]:
it = iter(train_ds)

In [None]:
for _ in range(5):
    text, label = next(it)
    print("text: ", text)
    print("label: ", label)
    print()

In [None]:
# size of corpus to build the tokenizer
corpus_size = 5_000  #@param = 'int'

# size of the vocabulary
vocab_size = 12  #@param = 'int'

# maximum length of examples in tokens
max_len = MAX_TEST_LENGTH + 1  #@param = 'int'

# pad value
pad_value = 0  #@param = 'int'


In [None]:
corpus = [text for text, _ in train_ds[:corpus_size]]

In [None]:
corpus[:10]

In [None]:
# now, we build the tokenizer
tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=vocab_size,
        oov_token=None,
        char_level=True,
)
tokenizer.fit_on_texts(corpus)

In [None]:


# note how the tokenizer figured out it was best to tokenize each digit separately
tokenizer.index_word

In [None]:
# Example usage:
print("original text: ", text)

# tokenize text
tokens = tokenizer.texts_to_sequences([text])
print("tokens: ", tokens)
print("number of tokens: ", len(tokens[0]))

In [None]:


# we can see the actual tokens by converting each token individually to text
print(tokenizer.sequences_to_texts([token.tolist() for token in jnp.array(tokens).reshape((-1, 1))]))

In [None]:
# Let's examine the distribution of tokens in the corpus:
print("Token frequency:")
dict(list(tokenizer.word_counts.items()))

In [None]:
# preprocessing the data


In [None]:
def preprocess_function(text, label):
    # Tokenize the text
    tokens = tokenizer.texts_to_sequences([text])
    # Pad the sequences
    tokens = tf.keras.preprocessing.sequence.pad_sequences(tokens, maxlen=max_len, padding='post', value=pad_value)
    # Convert to torch tensors of a type int64
    tokens = jnp.array(tokens[0]).astype(jnp.int64)
    label = jnp.array(label).astype(jnp.int64)
    # Return the tokens and label
    return tokens, label


In [None]:
# Apply the preprocessing function to the training and test datasets
print("preprocessing training examples ... ")
x_train = []
y_train = []
for text, label in tqdm.tqdm(train_ds):
    tokens, label = preprocess_function(text, label)
    x_train.append(tokens)
    y_train.append(label)

print("preprocessing test examples ... ")
x_test = []
y_test = []
for text, label in tqdm.tqdm(test_ds):
    tokens, label = preprocess_function(text, label)
    x_test.append(tokens)
    y_test.append(label)

In [None]:
x_train = jnp.stack(x_train)
y_train = jnp.array(y_train)
x_test = jnp.stack(x_test)
y_test = jnp.array(y_test)

In [None]:

print("x_train.shape: ", x_train.shape)
print("y_train.shape: ", y_train.shape)
print("x_test.shape: ", x_test.shape)
print("y_test.shape: ", y_test.shape)


In [None]:

# let's see what it looks like
x_train[0], y_train[0]

In [None]:
max_len

In [None]:
# Transformer Architecture - Classification task

In [None]:
def train(Model, epochs=10, batch_size=32, lr=3e-4, wd=1e-5, **kwargs):
    # Initialize the model
    model = Model(**kwargs)
    params_model = model.init(random.PRNGKey(0), x_train[:128])

    # Initialize the optimizer
    optimizer = optax.adamw(learning_rate=lr, weight_decay=wd)
    opt_state = optimizer.init(params_model)

    # Create a TrainState
    state = TrainState.create(apply_fn=model.apply, params=params_model, tx=optimizer)

    # Define the loss function
    def loss_fn(params, x, y):
        logits = state.apply_fn(params, x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
        return loss.mean()

    grad_fn = jax.value_and_grad(loss_fn)

    @jax.jit
    def train_step(state, x, y):
        loss, grads = grad_fn(state.params, x, y)
        new_state = state.apply_gradients(grads=grads)
        return new_state, loss

    # Report accuracy
    def report(state, x, y):
        logits = state.apply_fn(state.params, x)
        predictions = (logits[:, 1] > 0).astype(jnp.int32)
        return jnp.mean(predictions == y)

    # Training loop
    num_train_batches = len(x_train) // batch_size
    num_test_batches = len(x_test) // batch_size

    print("Training starts...")
    for epoch in range(epochs):
        # Training
        epoch_loss = 0
        for i in range(0, len(x_train), batch_size):
            x_batch = x_train[i:i + batch_size]
            y_batch = y_train[i:i + batch_size]
            state, loss = train_step(state, x_batch, y_batch)
            epoch_loss += loss

        # Evaluation
        train_accuracy = 0
        for i in range(0, len(x_train), batch_size):
            x_batch = x_train[i:i + batch_size]
            y_batch = y_train[i:i + batch_size]
            train_accuracy += report(state, x_batch, y_batch)
        train_accuracy /= num_train_batches

        test_accuracy = 0
        for i in range(0, len(x_test), batch_size):
            x_batch = x_test[i:i + batch_size]
            y_batch = y_test[i:i + batch_size]
            test_accuracy += report(state, x_batch, y_batch)
        test_accuracy /= num_test_batches

        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"Loss: {epoch_loss/num_train_batches:.4f}")
        print(f"Train accuracy: {train_accuracy:.4f}")
        print(f"Test accuracy: {test_accuracy:.4f}")

In [None]:
# Define the self-attention layer
class SelfAttention(nn.Module):
    embed_dim : int

    @nn.compact
    def __call__(self, x):
        # Calculate query, key, and value matrices using linear layers
        query = nn.Dense(self.embed_dim)(x)
        key = nn.Dense(self.embed_dim)(x)
        value = nn.Dense(self.embed_dim)(x)

        # Calculate attention scores (scaled dot-product attention)
        attention_scores = jnp.matmul(query, jnp.swapaxes(key, -2, -1)) / jnp.sqrt(self.embed_dim)
        attention_weights = nn.softmax(attention_scores, axis=-1)

        # Apply attention weights to values
        output = jnp.matmul(attention_weights, value)
        return output


In [None]:
# Define the model
class SimpleTransformer(nn.Module):
    vocab_size: int
    embed_dim: int =128
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)
        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits

In [None]:
kwg = dict(
        vocab_size=vocab_size,
        embed_dim=64,
)
train(SimpleTransformer, **kwg)

In [None]:
# MLP layers

In [None]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    embed_dim: int = 128
    mlp_dim: int = 256
    num_classes: int = 2
    max_seq_length: int = None

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)
        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # MLP layers
        x = nn.Dense(self.mlp_dim)(cls_token)
        x = nn.relu(x)
        logits = nn.Dense(self.num_classes)(x)
        return logits

In [None]:

kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
)
train(SimpleTransformer, **kwg)

In [None]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    max_seq_length: int
    embed_dim: int = 128
    mlp_dim: int = 256
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim) # Scale positional embeddings

        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)

        # MLP layers
        x = nn.Dense(self.mlp_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.num_classes)(x)

        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits



In [None]:
kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
        max_seq_length=max_len,
)
train(SimpleTransformer, **kwg)

##  normalization layer


In [None]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    max_seq_length: int
    embed_dim: int = 128
    mlp_dim: int = 256
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim)  # Scale positional embeddings

        # self-attention layer
        x = SelfAttention(self.embed_dim)(x)

        # We add layer norm
        x = nn.LayerNorm(self.embed_dim)(x)

        # MLP layers
        x = nn.Dense(self.mlp_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.num_classes)(x)

        # We add layer norm
        x = nn.LayerNorm(self.embed_dim)(x)

        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits

In [None]:
kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
        max_seq_length=max_len,
)
train(SimpleTransformer, **kwg)

In [None]:
## deeper architectures

In [None]:
class TransformerEncoderBlock(nn.Module):
    embed_dim: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)
        # We add layer norm
        x = nn.LayerNorm(self.embed_dim)(x)
        # MLP layers
        x = nn.Dense(self.mlp_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.embed_dim)(x)
        # We add layer norm
        x = nn.LayerNorm(self.embed_dim)(x)
        return x

In [None]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    max_seq_length: int
    num_layers: int
    mlp_dim: int = 256
    embed_dim: int = 128
    num_classes: int = 2

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim)  # Scale positional embeddings

        # Stack multiple transformer encoder blocks
        x = nn.Sequential([TransformerEncoderBlock(self.embed_dim, self.mlp_dim)
                           for _ in range(self.num_layers)])(x)
        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits

In [None]:

kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
        max_seq_length=max_len,
        num_layers=3,
)
train(SimpleTransformer, **kwg)

In [None]:
# skip connections

In [None]:
class TransformerEncoderBlock(nn.Module):
    embed_dim: int
    mlp_dim: int


    @nn.compact
    def __call__(self, x):
        # Self-attention layer
        x = SelfAttention(self.embed_dim)(x)
        x = nn.LayerNorm(self.embed_dim)(x)

        # MLP layers
        y = nn.Dense(self.mlp_dim)(x)
        y = nn.relu(y)
        y = nn.Dense(self.embed_dim)(y)

        # We introduce a skip connection
        x = x + y
        x= nn.LayerNorm(self.embed_dim)(x)
        return x

In [None]:
class SimpleTransformer(nn.Module):
    vocab_size: int
    max_seq_length: int
    num_layers: int
    mlp_dim: int = 256
    embed_dim: int = 128
    num_classes: int = 2

    def setup(self):
        # Create list of transformer blocks with unique parameters
        self.transformer_blocks = [
                TransformerEncoderBlock(self.embed_dim, self.mlp_dim)
                for _ in range(self.num_layers)
        ]

    @nn.compact
    def __call__(self, x):
        # Embedding layer
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)

        # Positional embeddings
        positions = jnp.arange(x.shape[1])
        pos_embeddings = nn.Embed(self.max_seq_length, self.embed_dim)(positions)
        x = x + pos_embeddings / jnp.sqrt(self.embed_dim)

        # Stack multiple transformer encoder blocks
        for block in self.transformer_blocks:
            x = block(x)

        # Extract the CLS token (the last token)
        cls_token = x[:, -1, :]  # Assuming the last token is the CLS token

        # Linear classifier
        logits = nn.Dense(self.num_classes)(cls_token)
        return logits

In [None]:

kwg = dict(
        embed_dim=64,
        mlp_dim=64*4,
        vocab_size=vocab_size,
        max_seq_length=max_len,
        num_layers=3,
)
train(SimpleTransformer, **kwg)

In [None]:
# Transformer Architecture - Sequence generation task

In [70]:
class AdditionTask:
    def sample_batch(self, batch_size: int, length: int):
        """Returns a batch of additions and their results."""
        if length <= 2:
            raise ValueError("Length must be greater than 2.")
        # We only use `length - 1` tokens for the two values to account for the `+`.
        # Generate random lengths for the two numbers in each addition problem.
        length_n = random.randint(random.PRNGKey(0), (batch_size,), minval=1, maxval=length - 1)
        length_m = length - 1 - length_n

        # Generate random integers for each length with int64 dtype
        integer_n = []
        integer_m = []
        key = random.PRNGKey(0)
        for len_n, len_m in zip(length_n, length_m):
            key, subkey_n = random.split(key)
            maxval_n = 10 ** int(len_n) - 1
            integer_n.append(random.randint(subkey_n, (), minval=1, maxval=maxval_n, dtype=jnp.int64))

            key, subkey_m = random.split(key)
            maxval_m = 10 ** int(len_m) - 1
            integer_m.append(random.randint(subkey_m, (), minval=1, maxval=maxval_m, dtype=jnp.int64))

        # Calculate the sum of the generated integers.
        integer_sum = list(map(sum, zip(integer_n, integer_m)))

        # Convert integers to reversed strings (e.g., 123 becomes "321").
        knary_n = [str(x)[::-1] for x in integer_n]
        knary_m = [str(x)[::-1] for x in integer_m]

        # Create the addition expressions by concatenating the reversed strings with a "+".
        expressions = [f"{a}+{b}" for a, b in zip(knary_n, knary_m)]

        # Pad the expressions with zeros to reach the desired length.
        expressions = [a + "".join(["0"] * (length - len(a))) for a in expressions]

        # Convert the sums to strings.
        results = list(map(str, integer_sum))
        # Append "#" to the results and pad with zeros.
        results = [
                res + "#" + "".join(["0"] * (length - len(res))) for res in results
        ]
        return {
                "input": expressions,
                "output": results,
        }

    @property
    def input_size(self) -> int:
        """Returns the input size for the models."""
        return 12

    @property
    def output_size(self) -> int:
        """Returns the output size for the models."""
        return 12

    @property
    def vocab_size(self) -> int:
        """Returns the output size for the models."""
        return 12

    def output_length(self, input_length: int) -> int:
        return input_length + 1

In [71]:

# # Instantiate an AdditionTask object. This object will handle data generation for our addition task.
task = AdditionTask()

MAX_TRAIN_LENGTH = 10
MAX_TEST_LENGTH = 20

# Generate a sample batch of addition problems.
data = task.sample_batch(batch_size=16, length=MAX_TRAIN_LENGTH)

In [74]:
pd.DataFrame(data)

Unnamed: 0,input,output
0,2+00239732,23793202#00
1,78749186+4,68194791#00
2,976584+919,486598#0000
3,436635+705,537141#0000
4,5+72286476,67468232#00
5,33516+6677,69299#00000
6,974+906346,644088#0000
7,6202483+22,3842048#000
8,1783+11382,32182#00000
9,0589909+98,9099939#000


In [75]:
# Define and fit the tokenizer
tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=None,
        oov_token=None,
        char_level=True,
)
tokenizer.fit_on_texts(data["input"] + data["output"])
tokenizer.word_index

{'0': 1,
 '2': 2,
 '9': 3,
 '8': 4,
 '6': 5,
 '3': 6,
 '4': 7,
 '7': 8,
 '5': 9,
 '+': 10,
 '#': 11,
 '1': 12}

In [76]:
def preprocess_data(batch, tokenizer):
    """Tokenizes and pads the input and output sequences for the model.

    Args:
      batch: A dictionary containing the input and output sequences as lists of strings.
      tokenizer: A fitted Tokenizer object (e.g., from torchtext.data.utils.get_tokenizer or a custom implementation).

    Returns:
      A dictionary containing the processed input and output sequences as PyTorch tensors,
    """
    # Tokenize the input sequences using the provided tokenizer.
    # This converts each string into a sequence of integer indices.
    tokens_input = jnp.array(tokenizer.texts_to_sequences(batch["input"]), dtype=jnp.int64) - 1

    # Tokenize the output sequences.
    tokens_output = jnp.array(tokenizer.texts_to_sequences(batch["output"]), dtype=jnp.int64) - 1

    # Pad the sequences to the maximum length within the batch for consistent tensor shapes.
    tokens_input = jnp.pad(
            tokens_input,
            pad_width=((0, 0), (0, tokens_input.shape[1] - tokens_input.shape[1])),
            mode='constant',
            constant_values=0
    )
    tokens_output = jnp.pad(
            tokens_output,
            pad_width=((0, 0), (0, tokens_output.shape[1] - tokens_output.shape[1])),
            mode='constant',
            constant_values=0
    )
        # Return the processed data as a dictionary.
    return dict(input=tokens_input, output=tokens_output)

In [77]:
preprocess_data(data, tokenizer)

{'input': Array([[ 1,  9,  0,  0,  1,  5,  2,  7,  5,  1],
        [ 7,  3,  7,  6,  2, 11,  3,  4,  9,  6],
        [ 2,  7,  4,  8,  3,  6,  9,  2, 11,  2],
        [ 6,  5,  4,  4,  5,  8,  9,  7,  0,  8],
        [ 8,  9,  7,  1,  1,  3,  4,  6,  7,  4],
        [ 5,  5,  8, 11,  4,  9,  4,  4,  7,  7],
        [ 2,  7,  6,  9,  2,  0,  4,  5,  6,  4],
        [ 4,  1,  0,  1,  6,  3,  5,  9,  1,  1],
        [11,  7,  3,  5,  9, 11, 11,  5,  3,  1],
        [ 0,  8,  3,  2,  2,  0,  2,  9,  2,  3],
        [ 1,  7,  9,  0,  1,  4,  8,  8,  1,  5],
        [11,  6,  0,  2,  8,  0,  0,  3,  9,  1],
        [ 0,  1,  1,  5,  1,  9,  7,  3,  2,  6],
        [ 6,  7,  3,  1,  1,  9,  6,  1,  0,  1],
        [ 8,  1,  1,  0,  6,  5,  4,  4,  9,  5],
        [ 0,  3,  2,  9, 11,  7,  1,  8,  0,  8]], dtype=int64),
 'output': Array([[ 1,  5,  7,  2,  5,  1,  0,  1, 10,  0,  0],
        [ 4,  3, 11,  2,  6,  7,  2, 11, 10,  0,  0],
        [ 6,  3,  4,  8,  2,  3, 10,  0,  0,  0,  0],
    