In [1]:
import string
import re
import numpy as np

import jax.numpy as jnp
import optax

from flax import nnx

import tiktoken
import grain.python as grain
import tqdm

from datasets import load_dataset
import matplotlib.pyplot as plt
from semhash import SemHash

## 1. Data Preparation
# Load, preprocess, and tokenize the dataset.

In [3]:
# Load the Fon/French parallel corpus (train / validation / test splits) from the Hub.
ds = load_dataset(path="jonathansuru/fr_fon")

In [4]:
# Function to convert data to pairs
def data_to_pairs(data):
    text_pairs = []
    for line in data:
        fon = line["fon"]
        french = line["french"]
        french = "[start] " + french + " [end]"  # Add start and end tokens
        text_pairs.append((fon, french))
    return text_pairs


In [None]:
# Deduplicate the test split against the training split with SemHash, so that
# test records that near-duplicate a training record are dropped before
# evaluation.
# Both splits are handed to SemHash as plain lists of dicts, built the same
# way, instead of passing the raw Hugging Face split for the test side.
records = [dict(row) for row in ds["train"]]
test_records = [dict(row) for row in ds["test"]]

# Index the training records on both text columns.
semhash = SemHash.from_records(records=records, columns=["fon", "french"])
# Keep only the test records that are not duplicates of an indexed training record.
deduplicated_test_texts = semhash.deduplicate(records=test_records).deduplicated

In [5]:
# Create train, validation, and test pairs.
# The test split uses the SemHash-deduplicated records, not the raw test data.
train_pairs, val_pairs, test_pairs = map(
    data_to_pairs,
    (ds["train"], ds["validation"], deduplicated_test_texts),
)

In [6]:
# Report how many pairs each split contains (test is counted after deduplication).
print(
    f"{len(train_pairs)} training pairs",
    f"{len(val_pairs)} validation pairs",
    f"{len(test_pairs)} test pairs",
    sep="\n",
)

35039 training pairs
8760 validation pairs
10950 test pairs


In [7]:
# One shared tiktoken encoding tokenizes both the Fon source and the French target.
tokenizer = tiktoken.get_encoding(encoding_name="o200k_base")

In [8]:
# Define parameters
# Characters deleted by custom_standardization: ASCII punctuation plus "¿",
# except the square brackets, which must survive so that the "[start]" and
# "[end]" markers stay intact.
strip_chars = "".join(
    ch for ch in string.punctuation + "¿" if ch not in "[]"
)

vocab_size = tokenizer.n_vocab   # size of the tiktoken vocabulary
sequence_length = 512            # fixed token length after truncation / padding

In [9]:
# Custom standardization function
def custom_standardization(input_string):
    """Lower-case ``input_string`` and delete every character in ``strip_chars``.

    ``strip_chars`` is the module-level string of punctuation to remove; it is
    regex-escaped before being placed inside a character class.
    """
    lowered = input_string.lower()
    pattern = "[" + re.escape(strip_chars) + "]"
    return re.sub(pattern, "", lowered)

In [10]:
# Tokenize and pad function
def tokenize_and_pad(text, tokenizer, max_length, pad_id=0):
    """Encode ``text`` and force the id list to exactly ``max_length`` entries.

    Args:
        text: String to encode.
        tokenizer: Object with an ``encode(str)`` method returning a sequence
            of integer ids (a tiktoken encoding in this notebook).
        max_length: Target length. Longer encodings are truncated; shorter
            ones are right-padded with ``pad_id``.
        pad_id: Id used for padding. Defaults to 0, which the model code
            treats as padding (``jnp.not_equal(inputs, 0)`` masks).
            NOTE(review): 0 is presumably also a real id in tiktoken
            vocabularies; punctuation stripping likely keeps it out of the
            data in practice — confirm before relying on it.

    Returns:
        list[int]: ``max_length`` ids (for ``max_length >= 0``).
    """
    tokens = list(tokenizer.encode(text)[:max_length])
    # A non-positive repeat count yields [], so sequences that already fill
    # max_length need no separate branch.
    return tokens + [pad_id] * (max_length - len(tokens))

In [11]:
# Format dataset function
def format_dataset(fon, french, tokenizer, sequence_length):
    """Turn one ``(fon, french)`` text pair into model-ready id lists.

    Both texts are standardized, then tokenized and padded/truncated to
    ``sequence_length`` ids. The French ids are shifted by one for teacher
    forcing: the decoder reads ids ``[0 .. n-2]`` and is trained to predict
    ids ``[1 .. n-1]``.

    Returns:
        dict with keys ``"encoder_inputs"``, ``"decoder_inputs"`` and
        ``"target_output"``.
    """
    source_ids = tokenize_and_pad(custom_standardization(fon), tokenizer, sequence_length)
    target_ids = tokenize_and_pad(custom_standardization(french), tokenizer, sequence_length)
    return {
        "encoder_inputs": source_ids,
        "decoder_inputs": target_ids[:-1],
        "target_output": target_ids[1:],
    }


In [12]:
# Create train, validation, and test data:
# each (fon, french) pair becomes the dict of id lists built by format_dataset.
train_data = [
    format_dataset(src_text, tgt_text, tokenizer, sequence_length)
    for src_text, tgt_text in train_pairs
]
val_data = [
    format_dataset(src_text, tgt_text, tokenizer, sequence_length)
    for src_text, tgt_text in val_pairs
]
test_data = [
    format_dataset(src_text, tgt_text, tokenizer, sequence_length)
    for src_text, tgt_text in test_pairs
]


## 2. Model Definition
# Define the Transformer model components.

Encoder

In [17]:
class TransformerEncoder(nnx.Module):
    """Single post-norm Transformer encoder block.

    Computes ``h = LayerNorm(x + SelfAttention(x))`` and then
    ``LayerNorm(h + FFN(h))``, where FFN is Linear -> ReLU -> Linear.

    Args:
        embed_dim: Feature size of the inputs and of the attention layer.
        dense_dim: Hidden size of the feed-forward network.
        num_heads: Number of attention heads.
        rngs: RNG streams used to initialize the sub-modules.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, embed_dim: int, dense_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.attention = nnx.MultiHeadAttention(num_heads=num_heads,
                                                in_features=embed_dim,
                                                decode=False,
                                                rngs=rngs)
        self.dense_proj = nnx.Sequential(
                nnx.Linear(embed_dim, dense_dim, rngs=rngs),
                nnx.relu,
                nnx.Linear(dense_dim, embed_dim, rngs=rngs),
        )

        self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)

    def __call__(self, inputs, mask=None):
        """Encode ``inputs``.

        Args:
            inputs: Embedded sequence, assumed (batch, seq_len, embed_dim).
            mask: Optional key padding mask, assumed (batch, seq_len); non-zero
                entries mark positions that may be attended to.

        Returns:
            Array with the same shape as ``inputs``.
        """
        if mask is not None:
            # nnx.MultiHeadAttention broadcasts the mask against
            # (batch, num_heads, q_len, kv_len). Insert singleton head and query
            # axes so padding is applied over the key axis only. Expanding a
            # single axis, giving (batch, 1, seq), would be right-aligned as
            # (1, batch, 1, seq) and misplace the batch dimension.
            padding_mask = jnp.asarray(mask)[:, None, None, :].astype(jnp.int32)
        else:
            padding_mask = None

        attention_output = self.attention(
                inputs_q=inputs, inputs_k=inputs, inputs_v=inputs, mask=padding_mask, decode=False
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

Positional Embedding

In [18]:
class PositionalEmbedding(nnx.Module):
    """Learned token embedding plus learned absolute position embedding.

    Args:
        sequence_length: Number of position embeddings, i.e. the longest
            input this layer accepts.
        vocab_size: Number of token embeddings.
        embed_dim: Embedding feature size.
        rngs: RNG streams used to initialize both tables.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs, **kwargs):
        self.token_embeddings = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        self.position_embeddings = nnx.Embed(num_embeddings=sequence_length, features=embed_dim, rngs=rngs)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def __call__(self, inputs):
        """Embed integer ids and add the embeddings of positions ``0 .. length-1``.

        Args:
            inputs: Integer token ids, assumed (batch, length).

        Returns:
            Token embeddings plus position embeddings (last axis ``embed_dim``).

        Raises:
            ValueError: if ``length`` exceeds ``sequence_length``.
        """
        length = inputs.shape[1]
        if length > self.sequence_length:
            # JAX does not raise on out-of-range gathers, so positions past the
            # table would silently produce wrong embeddings; fail loudly instead.
            raise ValueError(
                f"input length {length} exceeds the positional table size "
                f"{self.sequence_length}"
            )
        positions = jnp.arange(0, length)[None, :]
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        """Return ``None`` when ``mask`` is None, else ``inputs != 0`` (0 = padding id).

        The value of the incoming ``mask`` is not otherwise used.
        """
        if mask is None:
            return None
        else:
            return jnp.not_equal(inputs, 0)

Decoder

In [19]:
class TransformerDecoder(nnx.Module):
    """Single post-norm Transformer decoder block.

    Causal self-attention, cross-attention over the encoder outputs, and a
    Linear -> ReLU -> Linear feed-forward network, each followed by a
    residual connection and a LayerNorm.

    Args:
        embed_dim: Feature size of the inputs and attention layers.
        latent_dim: Hidden size of the feed-forward network.
        num_heads: Number of attention heads (both attention layers).
        rngs: RNG streams used to initialize the sub-modules.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, embed_dim: int, latent_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = nnx.MultiHeadAttention(num_heads=num_heads,
                                                  in_features=embed_dim,
                                                  decode=False,
                                                  rngs=rngs)
        self.attention_2 = nnx.MultiHeadAttention(num_heads=num_heads,
                                                  in_features=embed_dim,
                                                  decode=False,
                                                  rngs=rngs)

        self.dense_proj = nnx.Sequential(
                nnx.Linear(embed_dim, latent_dim, rngs=rngs),
                nnx.relu,
                nnx.Linear(latent_dim, embed_dim, rngs=rngs),
        )
        self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.layernorm_3 = nnx.LayerNorm(embed_dim, rngs=rngs)

    def __call__(self, inputs, encoder_outputs, mask=None):
        """Decode ``inputs`` while attending to ``encoder_outputs``.

        Args:
            inputs: Embedded decoder sequence, assumed (batch, tgt_len, embed_dim).
            encoder_outputs: Encoder states, assumed (batch, src_len, embed_dim).
            mask: Optional padding mask over the *encoder* positions, assumed
                (batch, src_len); non-zero entries may be attended to. The
                model passes the same mask it gives to the encoder.

        Returns:
            Array with the same shape as ``inputs``.
        """
        causal_mask = self.get_causal_attention_mask(inputs.shape[1])
        if mask is not None:
            # Cross-attention queries are decoder positions and keys are encoder
            # positions, so only the encoder key-padding mask applies, shaped
            # (batch, 1 head, 1 query, src_len). Combining it with the
            # (tgt_len x tgt_len) causal mask would limit decoder step i to
            # encoder positions <= i, and fails to broadcast when
            # src_len != tgt_len.
            padding_mask = jnp.asarray(mask)[:, None, None, :].astype(jnp.int32)
        else:
            padding_mask = None

        # Masked self-attention: each decoder position sees itself and earlier positions only.
        attention_output_1 = self.attention_1(
                inputs_q=inputs, inputs_v=inputs, inputs_k=inputs, mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention over the encoder outputs.
        # See https://github.com/google/flax/blob/main/flax/nnx/nn/attention.py#L403-L405
        attention_output_2 = self.attention_2(
                inputs_q=out_1,
                inputs_v=encoder_outputs,
                inputs_k=encoder_outputs,
                mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, sequence_length):
        """Return a (1, 1, L, L) int32 lower-triangular mask with L = ``sequence_length``.

        Entry ``[i, j]`` is 1 when ``j <= i`` (query i may attend to key j) and
        0 otherwise; the leading singleton axes broadcast over batch and heads.
        """
        i = jnp.arange(sequence_length)[:, None]
        j = jnp.arange(sequence_length)
        mask = (i >= j).astype(jnp.int32)
        mask = jnp.reshape(mask, (1, 1, sequence_length, sequence_length))
        return mask

Transformer Model

In [20]:
class TransformerModel(nnx.Module):
    """Encoder-decoder Transformer that returns vocabulary logits per decoder position.

    One PositionalEmbedding is shared by the encoder and decoder inputs.
    Dropout is applied to the decoder output just before the final
    projection to ``vocab_size``.

    Args:
        sequence_length: Size of the position-embedding table.
        vocab_size: Number of tokens (embedding rows and output logits).
        embed_dim: Model feature size.
        latent_dim: Feed-forward hidden size in the encoder and decoder.
        num_heads: Attention heads per attention layer.
        dropout_rate: Rate of the output Dropout layer.
        rngs: RNG streams for parameter initialization and dropout.
    """

    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, latent_dim: int, num_heads: int, dropout_rate: float, rngs: nnx.Rngs):
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        self.encoder = TransformerEncoder(embed_dim, latent_dim, num_heads, rngs=rngs)
        self.positional_embedding = PositionalEmbedding(sequence_length, vocab_size, embed_dim, rngs=rngs)
        self.decoder = TransformerDecoder(embed_dim, latent_dim, num_heads, rngs=rngs)
        self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
        self.dense = nnx.Linear(embed_dim, vocab_size, rngs=rngs)

    def __call__(self, encoder_inputs: jnp.ndarray, decoder_inputs: jnp.ndarray, mask: jnp.ndarray = None, deterministic: bool | None = None):
        """Compute logits for ``decoder_inputs`` conditioned on ``encoder_inputs``.

        Args:
            encoder_inputs: Source token ids, assumed (batch, src_len).
            decoder_inputs: Target token ids shifted right, assumed (batch, tgt_len).
            mask: Optional encoder padding mask, assumed (batch, src_len). It is
                passed to both the encoder and the decoder.
            deterministic: Dropout switch. ``True`` disables dropout and
                ``False`` enables it. ``None`` (the default) lets the Dropout
                layer use its own ``deterministic`` attribute, which
                ``model.train()`` / ``model.eval()`` toggle. Forwarding an
                explicit ``False`` by default would override ``model.eval()``
                and keep dropout active during evaluation.

        Returns:
            Logits whose last axis has size ``vocab_size``, one row per
            decoder position.
        """
        x = self.positional_embedding(encoder_inputs)
        encoder_outputs = self.encoder(x, mask=mask)

        x = self.positional_embedding(decoder_inputs)
        decoder_outputs = self.decoder(x, encoder_outputs, mask=mask)
        decoder_outputs = self.dropout(decoder_outputs, deterministic=deterministic)

        logits = self.dense(decoder_outputs)
        return logits

## 3. Data Loader
# Create data loaders for training and validation.

In [13]:
# Examples per batch, shared by the training and validation loaders.
batch_size: int = 512

In [14]:
# Custom preprocessing class
class CustomPreprocessing(grain.MapTransform):
    """Grain map step that turns each record's id lists into NumPy arrays."""

    def __init__(self):
        # Stateless transform: nothing to configure (the base-class
        # initializer is not invoked).
        pass

    def map(self, data):
        """Return a new dict holding the three model fields of ``data`` as ``np.ndarray``."""
        return {
            field: np.array(data[field])
            for field in ("encoder_inputs", "decoder_inputs", "target_output")
        }


In [15]:
# Samplers choose which record indices are read and in what order.
# Training: shuffled with a fixed seed for reproducibility, single-device
# (no sharding), one epoch per iteration over the loader.
train_sampler = grain.IndexSampler(
    len(train_data),
    shard_options=grain.NoSharding(),
    shuffle=True,
    seed=12,
    num_epochs=1,
)

# Validation: same settings, but records are read in their original order.
val_sampler = grain.IndexSampler(
    len(val_data),
    shard_options=grain.NoSharding(),
    shuffle=False,
    seed=12,
    num_epochs=1,
)

# Loaders: 4 worker processes, each preparing up to 2 batches in advance.
# Records go through CustomPreprocessing (lists -> NumPy) and are then batched.
# Training drops the final partial batch; validation keeps it.
train_loader = grain.DataLoader(
    data_source=train_data,
    sampler=train_sampler,
    operations=[
        CustomPreprocessing(),
        grain.Batch(batch_size=batch_size, drop_remainder=True),
    ],
    worker_count=4,
    worker_buffer_size=2,
)

val_loader = grain.DataLoader(
    data_source=val_data,
    sampler=val_sampler,
    operations=[
        CustomPreprocessing(),
        grain.Batch(batch_size=batch_size),
    ],
    worker_count=4,
    worker_buffer_size=2,
)


## 4. Training and Evaluation
# Define the training and evaluation steps.

In [16]:
# Compute loss function
def compute_loss(logits, labels, pad_id=0):
    """Mean token-level softmax cross-entropy, ignoring padding targets.

    Args:
        logits: Unnormalized scores, assumed (batch, seq_len, vocab_size).
        labels: Integer target ids, assumed (batch, seq_len).
        pad_id: Target id treated as padding and excluded from the mean.
            Defaults to 0, the id ``tokenize_and_pad`` pads with. Pass
            ``None`` to average over every position (unmasked mean).

    Returns:
        Scalar mean loss over the non-padding target positions.
    """
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    if pad_id is None:
        return jnp.mean(loss)
    # Sequences are padded to a fixed length, so for short sentences most
    # target positions are likely padding; without masking, the loss mostly
    # rewards predicting the pad id.
    weights = (labels != pad_id).astype(loss.dtype)
    # Guard against an all-padding batch (avoid 0 / 0).
    return jnp.sum(loss * weights) / jnp.maximum(jnp.sum(weights), 1.0)

In [24]:
# Training step
@nnx.jit
def train_step(model, optimizer, batch):
    """Run one optimization step on ``batch`` and return its scalar loss.

    The loss is differentiated with respect to ``model`` (the first argument
    of the inner loss function) and the gradients are applied through
    ``optimizer.update``.
    """
    encoder_ids = jnp.array(batch["encoder_inputs"])
    decoder_ids = jnp.array(batch["decoder_inputs"])
    target_ids = jnp.array(batch["target_output"])

    def loss_fn(net, enc_in, dec_in, tgt):
        return compute_loss(net(enc_in, dec_in), tgt)

    loss, grads = nnx.value_and_grad(loss_fn)(model, encoder_ids, decoder_ids, target_ids)
    optimizer.update(grads)
    return loss


In [None]:
# Evaluation step
@nnx.jit
def eval_step(model, batch, eval_metrics):
    """Score one validation batch and accumulate loss/accuracy into ``eval_metrics``.

    ``deterministic=True`` is passed explicitly so dropout is disabled during
    evaluation, independent of the model's call default or its train/eval flag.

    NOTE(review): the accuracy metric compares every position, including
    padding targets, so it is likely inflated by trivially-predicted padding.
    """
    labels = jnp.array(batch["target_output"])
    logits = model(
        jnp.array(batch["encoder_inputs"]),
        jnp.array(batch["decoder_inputs"]),
        deterministic=True,
    )
    loss = compute_loss(logits, labels)

    eval_metrics.update(
            loss=loss,
            logits=logits,
            labels=labels,
    )

In [25]:
# Metrics accumulated by eval_step: running mean of the batch losses and
# accuracy computed from logits/labels.
eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average(argname="loss"),
    accuracy=nnx.metrics.Accuracy(),
)

# Histories: one training loss per step; one entry per epoch for evaluation,
# stored by evaluate_model under "test_<metric>" keys.
train_metrics_history = dict(train_loss=[])
eval_metrics_history = dict(test_loss=[], test_accuracy=[])

In [27]:
## Hyperparameters
# Model size
embed_dim = 256       # token / attention feature size
latent_dim = 2048     # feed-forward hidden size in the encoder and decoder
num_heads = 8
dropout_rate = 0.5    # dropout on the decoder output before the vocabulary projection

# Data / vocabulary (same values as in the data-preparation section)
vocab_size = tokenizer.n_vocab
sequence_length = 512

# Optimization
learning_rate = 1.5e-3
num_epochs = 10

# Root RNG streams (seed 0) for parameter initialization and dropout
rng = nnx.Rngs(0)

In [28]:
# Training loop
bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"
# The train loader drops its last partial batch, so floor division equals the
# number of batches it yields per epoch.
train_total_steps = len(train_data) // batch_size


def train_one_epoch(epoch):
    """Run one pass over ``train_loader`` and record every step's loss.

    Args:
        epoch: Zero-based epoch index; displayed one-based, like evaluate_model.
    """
    model.train()  # Put sub-modules in training mode (here: dropout enabled)
    with tqdm.tqdm(
            desc=f"[train] epoch: {epoch + 1}/{num_epochs}, ",
            total=train_total_steps,
            bar_format=bar_format,
            leave=True,
    ) as pbar:
        for batch in train_loader:
            loss = train_step(model, optimizer, batch)
            # Fetch the scalar to the host once and reuse it.
            loss_value = loss.item()
            train_metrics_history["train_loss"].append(loss_value)
            pbar.set_postfix({"loss": loss_value})
            pbar.update(1)


def evaluate_model(epoch):
    """Evaluate on ``val_loader``, then record and print the epoch's metrics.

    Results are appended to ``eval_metrics_history`` under ``test_<metric>``
    keys even though the data comes from the validation split.

    Args:
        epoch: Zero-based epoch index; displayed one-based.
    """
    model.eval()  # Put sub-modules in evaluation mode (here: dropout disabled)

    eval_metrics.reset()  # Start this epoch's accumulation from scratch
    for val_batch in val_loader:
        eval_step(model, val_batch, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'test_{metric}'].append(value)

    print(f"[test] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}")

In [29]:
# Build the encoder-decoder Transformer and an AdamW optimizer over its parameters.
model = TransformerModel(
    sequence_length=sequence_length,
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    latent_dim=latent_dim,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    rngs=rng,
)
optimizer = nnx.Optimizer(model, tx=optax.adamw(learning_rate=learning_rate))

In [None]:
# Alternate one full training pass and one validation pass per epoch.
for epoch_index in range(num_epochs):
    train_one_epoch(epoch_index)
    evaluate_model(epoch_index)

[train] epoch: 0/10, [0/31] [00:00<?]

In [None]:
# Per-step training loss on a log scale. `plt` comes from the imports cell at
# the top of the notebook, so it is not re-imported here.
fig, ax = plt.subplots()
ax.plot(train_metrics_history["train_loss"], label="Loss value during the training")
ax.set_yscale("log")
ax.set_xlabel("Training step")
ax.set_ylabel("Cross-entropy loss")
ax.legend();

In [None]:
# Per-epoch validation metrics recorded by evaluate_model (one point per epoch).
# A wide, short figure suits the side-by-side 1x2 layout.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].set_title("Loss value on eval set")
axs[0].plot(eval_metrics_history["test_loss"])
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Loss")
axs[1].set_title("Accuracy on eval set")
axs[1].plot(eval_metrics_history["test_accuracy"])
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Accuracy")
fig.tight_layout()