In [1]:
import string
import re
import numpy as np
import matplotlib.pyplot as plt

import jax.numpy as jnp
import optax

from flax import nnx

import tiktoken
import grain.python as grain
import tqdm

from datasets import load_dataset

## 2. Data Loading and Preparation

In [2]:
# Fetch the French-Fon parallel corpus; the "train", "validation" and "test"
# splits of the returned object are read further down.
ds = load_dataset(path="jonathansuru/fr_fon")

In [3]:
def data_to_pairs(data):
    """Turn dataset rows into (fon, french) text pairs.

    Args:
        data: iterable of mappings, each with a "fon" and a "french" string.

    Returns:
        A list of ``(fon, french)`` tuples in input order, where the French
        side is wrapped as ``"[start] <text> [end]"``.
    """
    return [
        (row["fon"], "[start] " + row["french"] + " [end]")
        for row in data
    ]

In [4]:
# Build the (fon, french) text pairs for each of the three dataset splits.
train_pairs, val_pairs, test_pairs = (
    data_to_pairs(ds[split_name])
    for split_name in ("train", "validation", "test")
)

## 3. Tokenization and Preprocessing

In [5]:
# Shared tokenizer (tiktoken "o200k_base" encoding) used for both the Fon and
# the French side of every pair.
tokenizer = tiktoken.get_encoding(encoding_name="o200k_base")

In [6]:
# Characters removed during standardization: ASCII punctuation plus "¿",
# except the square brackets, which must survive so that the "[start]" and
# "[end]" markers stay intact.
strip_chars = "".join(
    ch for ch in string.punctuation + "¿" if ch not in "[]"
)

# The vocabulary size comes straight from the tokenizer; every sequence is
# truncated/padded to `sequence_length` token ids.
sequence_length = 512
vocab_size = tokenizer.n_vocab

In [7]:
def custom_standardization(input_string):
    """Lowercase `input_string` and delete every character listed in `strip_chars`.

    `strip_chars` is the module-level string of characters to remove; it is
    escaped before being placed inside a regex character class.
    """
    strip_pattern = "[" + re.escape(strip_chars) + "]"
    return re.sub(strip_pattern, "", input_string.lower())


def tokenize_and_pad(text, tokenizer, max_length, pad_id=0):
    """Encode `text` and force the result to exactly `max_length` token ids.

    Args:
        text: string to encode.
        tokenizer: object exposing ``encode(text)`` that returns a sequence of
            integer token ids.
        max_length: target length; longer encodings are truncated, shorter
            ones are right-padded.
        pad_id: id used for padding (defaults to 0, the previous hard-coded
            value).

    Returns:
        A Python list of token ids.
    """
    # Coerce to a list so tokenizers that return a tuple or an array still
    # work with the list concatenation below.
    tokens = list(tokenizer.encode(text))[:max_length]
    # `[pad_id] * k` is empty for k <= 0, so no length check is needed.
    return tokens + [pad_id] * (max_length - len(tokens))

def format_dataset(fon, french, tokenizer, sequence_length):
    """Convert one (fon, french) text pair into fixed-length model inputs.

    Both strings are standardized, then tokenized and padded/truncated to
    `sequence_length` ids. The French ids are shifted by one position to
    form the decoder input (all but the last id) and the prediction target
    (all but the first id).

    Returns:
        dict with keys "encoder_inputs", "decoder_inputs" and "target_output".
    """
    clean_fon = custom_standardization(fon)
    clean_french = custom_standardization(french)
    source_ids = tokenize_and_pad(clean_fon, tokenizer, sequence_length)
    target_ids = tokenize_and_pad(clean_french, tokenizer, sequence_length)

    example = {}
    example["encoder_inputs"] = source_ids
    example["decoder_inputs"] = target_ids[:-1]
    example["target_output"] = target_ids[1:]
    return example

In [8]:
# Tokenize every split up front; each element is the dict produced by
# `format_dataset` for one (fon, french) pair.
train_data = [format_dataset(*pair, tokenizer, sequence_length) for pair in train_pairs]
val_data = [format_dataset(*pair, tokenizer, sequence_length) for pair in val_pairs]
test_data = [format_dataset(*pair, tokenizer, sequence_length) for pair in test_pairs]

## 4. Data Loaders

In [20]:
# Batch size used when building the data loaders and, later on, when sizing
# the training progress bar.
batch_size = 2

class CustomPreprocessing(grain.MapTransform):
    """Grain map transform that turns each example's id lists into NumPy arrays."""

    def __init__(self):
        pass

    def map(self, data):
        """Return a new dict holding the three model fields of `data` as arrays."""
        field_names = ("encoder_inputs", "decoder_inputs", "target_output")
        return {name: np.array(data[name]) for name in field_names}


In [21]:
# Samplers decide the order in which records are visited. Each one covers its
# split exactly once per iteration and does no sharding (single-device setup).
train_sampler = grain.IndexSampler(
    len(train_data),
    shuffle=True,
    seed=12,  # fixed seed so the shuffled order is reproducible
    shard_options=grain.NoSharding(),
    num_epochs=1,
)
val_sampler = grain.IndexSampler(
    len(val_data),
    shuffle=False,
    seed=12,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

# Worker settings shared by both loaders: number of child processes that run
# the transformations, and how many outputs each worker prepares in advance.
loader_options = dict(worker_count=4, worker_buffer_size=2)

# Training batches always contain exactly `batch_size` examples (the trailing
# partial batch is dropped); validation keeps its final partial batch.
train_loader = grain.DataLoader(
    data_source=train_data,
    sampler=train_sampler,
    operations=[
        CustomPreprocessing(),
        grain.Batch(batch_size=batch_size, drop_remainder=True),
    ],
    **loader_options,
)
val_loader = grain.DataLoader(
    data_source=val_data,
    sampler=val_sampler,
    operations=[
        CustomPreprocessing(),
        grain.Batch(batch_size=batch_size),
    ],
    **loader_options,
)



## 5. Model Architecture

In [78]:
class LuongAttention(nnx.Module):
    """GRU encoder-decoder with dot-product (Luong-style) attention.

    The source and target token ids are embedded, run through separate GRU
    RNNs (the decoder starts from the encoder's final carry), and every
    decoder state attends over all encoder states before being projected to
    target-vocabulary logits.

    Args:
        hidden_size: width of the embeddings and of both GRU states.
        src_vocab_size: number of rows in the source embedding table.
        tgt_vocab_size: number of rows in the target embedding table and
            size of the output logits.
        rngs: `nnx.Rngs` instance used to initialise every layer. When
            omitted, `nnx.Rngs(0)` is created.
    """

    def __init__(self, hidden_size, src_vocab_size, tgt_vocab_size, rngs=None):
        # The previous default was the `nnx.Rngs` class itself, which is not a
        # usable RNG container; fall back to a seeded instance instead.
        if rngs is None:
            rngs = nnx.Rngs(0)

        # Kept so that __call__ can build a default initial carry.
        self.hidden_size = hidden_size

        self.source_embedding = nnx.Embed(src_vocab_size, hidden_size, rngs=rngs)
        self.target_embedding = nnx.Embed(tgt_vocab_size, hidden_size, rngs=rngs)

        # Encoder and decoder RNNs; both return (final carry, per-step outputs).
        self.encoder = nnx.RNN(
            nnx.GRUCell(hidden_size, hidden_size, rngs=rngs),
            return_carry=True,
        )
        self.decoder = nnx.RNN(
            nnx.GRUCell(hidden_size, hidden_size, rngs=rngs),
            return_carry=True,
        )

        # W_c mixes [decoder state; context]; W_y projects to vocabulary logits.
        self.W_c = nnx.Linear(hidden_size * 2, hidden_size, rngs=rngs)
        self.W_y = nnx.Linear(hidden_size, tgt_vocab_size, rngs=rngs)

    def __call__(self, source, target, h_init=None):
        """Compute target-vocabulary logits for every decoder position.

        Args:
            source: integer ids, shape (batch, src_seq_len).
            target: integer ids, shape (batch, tgt_seq_len).
            h_init: initial encoder carry, shape (batch, hidden_size).
                Defaults to zeros when omitted.

        Returns:
            Logits of shape (batch, tgt_seq_len, tgt_vocab_size).
        """
        if h_init is None:
            h_init = jnp.zeros((source.shape[0], self.hidden_size))

        # Embeddings
        source_seq = self.source_embedding(source)  # (batch, src_seq_len, hidden)
        target_seq = self.target_embedding(target)  # (batch, tgt_seq_len, hidden)

        # Encoder pass: h_t holds one state per source position.
        h_final, h_t = self.encoder(source_seq, initial_carry=h_init)

        # Decoder pass, seeded with the encoder's final carry.
        s_final, s_t = self.decoder(target_seq, initial_carry=h_final)

        # Dot-product attention scores:
        # (batch, tgt_seq_len, hidden) @ (batch, hidden, src_seq_len)
        #   -> (batch, tgt_seq_len, src_seq_len)
        e_t_i = jnp.matmul(s_t, jnp.transpose(h_t, (0, 2, 1)))

        # Normalise over the source positions to get alignment weights.
        alignment_scores = nnx.softmax(e_t_i, axis=-1)

        # Context vectors:
        # (batch, tgt_seq_len, src_seq_len) @ (batch, src_seq_len, hidden)
        #   -> (batch, tgt_seq_len, hidden)
        c_t = jnp.matmul(alignment_scores, h_t)

        # Combine each decoder state with its context vector.
        s_hat_t = jnp.concatenate([s_t, c_t], axis=-1)  # (batch, tgt_seq_len, hidden*2)
        s_hat_t = nnx.tanh(self.W_c(s_hat_t))  # (batch, tgt_seq_len, hidden)

        # Project to vocabulary space.
        y_t = self.W_y(s_hat_t)  # (batch, tgt_seq_len, vocab_size)

        return y_t

In [79]:
# Smoke test: build a small network and check the shape of its output.
net = LuongAttention(
        hidden_size=256,
        src_vocab_size=3371,
        tgt_vocab_size=2810,
        rngs=nnx.Rngs(42)
)

# Dummy token ids
source = jnp.array([[10, 23, 5]])  # (batch=1, src_seq_len=3)
target = jnp.array([[4, 9]])       # (batch=1, tgt_seq_len=2)

# Initial hidden state matching the dummy batch. A dedicated name is used on
# purpose: rebinding the notebook-wide `batch_size` here would change the
# value that is later used to compute the training progress-bar length.
demo_batch_size = source.shape[0]
h_init = jnp.zeros((demo_batch_size, 256))  # (batch=1, hidden_size)

# Forward pass
output = net(source, target, h_init)
print(output.shape)

(1, 2, 2810)


## 6. Training Functions

In [80]:
def compute_loss(logits, labels):
    """Mean softmax cross-entropy between `logits` and integer `labels`.

    The mean is taken over every element of the per-position loss; no
    positions are masked out here.
    """
    per_position = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits,
        labels=labels,
    )
    return jnp.mean(per_position)

In [81]:
@nnx.jit
def train_step(model, optimizer, batch):
    """Run one optimisation step on `batch` and return the scalar loss.

    Args:
        model: network called as ``model(encoder_ids, decoder_ids, h_init)``
            and exposing a ``W_y`` linear layer (as `LuongAttention` does).
        optimizer: `nnx.Optimizer` wrapping `model`.
        batch: mapping with "encoder_inputs", "decoder_inputs" and
            "target_output" integer arrays.

    Returns:
        The loss computed before the parameter update.
    """
    def loss_fn(model, train_encoder_input, train_decoder_input, train_target_input):
        # Size the initial carry from the model itself (input width of the
        # output projection == hidden size) rather than from a notebook
        # global, so the step always matches the model it is given.
        h = jnp.zeros((train_encoder_input.shape[0], model.W_y.in_features))
        logits = model(train_encoder_input, train_decoder_input, h)
        return compute_loss(logits, train_target_input)

    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(
        model,
        jnp.asarray(batch["encoder_inputs"]),
        jnp.asarray(batch["decoder_inputs"]),
        jnp.asarray(batch["target_output"]),
    )
    optimizer.update(grads)
    return loss

@nnx.jit
def eval_step(model, batch, eval_metrics):
    """Evaluate `model` on one batch and fold the result into `eval_metrics`.

    Args:
        model: network called as ``model(encoder_ids, decoder_ids, h_init)``
            and exposing a ``W_y`` linear layer (as `LuongAttention` does).
        batch: mapping with "encoder_inputs", "decoder_inputs" and
            "target_output" integer arrays.
        eval_metrics: metric container updated in place with the batch's
            `loss`, `logits` and `labels`.
    """
    # Convert each field once and reuse it.
    encoder_inputs = jnp.asarray(batch["encoder_inputs"])
    decoder_inputs = jnp.asarray(batch["decoder_inputs"])
    labels = jnp.asarray(batch["target_output"])

    # Initial carry sized from the model (input width of the output
    # projection == hidden size) instead of a notebook global.
    h = jnp.zeros((encoder_inputs.shape[0], model.W_y.in_features))
    logits = model(encoder_inputs, decoder_inputs, h)
    loss = compute_loss(logits, labels)

    eval_metrics.update(
            loss=loss,
            logits=logits,
            labels=labels,
    )

In [82]:
# Validation metrics accumulated over batches: an average of the `loss`
# keyword plus an accuracy metric.
eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    accuracy=nnx.metrics.Accuracy(),
)

# Per-step training losses, appended to by `train_one_epoch`.
train_metrics_history = dict(train_loss=[])

# Per-epoch evaluation results, appended to by `evaluate_model`.
eval_metrics_history = dict(test_loss=[], test_accuracy=[])

## 7. Hyperparameters and Model Setup

In [83]:
# --- Hyperparameters ---
# RNG container used to initialise the model parameters.
rng = nnx.Rngs(0)

# Model sizes
embed_dim, latent_dim = 256, 2048
vocab_size = tokenizer.n_vocab
sequence_length = 512

# Optimisation schedule
learning_rate = 1.5e-3
num_epochs = 10

In [84]:
# Build the translation model; source and target share the tokenizer, so both
# vocabularies have the same size.
model = LuongAttention(
    hidden_size=embed_dim,
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    rngs=rng,
)

In [85]:
# AdamW optimiser bound to the model's parameters.
# NOTE(review): newer Flax releases appear to expect an explicit `wrt=`
# argument here — confirm against the installed version.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=learning_rate))

## 8. Training Loop

In [86]:
# Training utilities
bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"
# Number of full batches per epoch. `batch_size` must still hold the value the
# training loader was built with for this total to be right.
train_total_steps = len(train_data) // batch_size

def train_one_epoch(epoch):
    """Train the global `model` for one pass over `train_loader`.

    Args:
        epoch: zero-based epoch index; shown one-based in the progress bar so
            it matches the numbering printed by `evaluate_model`.

    Every step's loss is appended to ``train_metrics_history["train_loss"]``.
    """
    model.train()  # Set model to the training mode: e.g. update batch statistics
    with tqdm.tqdm(
            desc=f"[train] epoch: {epoch + 1}/{num_epochs}, ",
            total=train_total_steps,
            bar_format=bar_format,
            leave=True,
    ) as pbar:
        for batch in train_loader:
            # Pull the scalar to the host once and reuse it for both the
            # history and the progress-bar postfix.
            loss_value = train_step(model, optimizer, batch).item()
            train_metrics_history["train_loss"].append(loss_value)
            pbar.set_postfix({"loss": loss_value})
            pbar.update(1)


def evaluate_model(epoch):
    """Evaluate the global `model` on `val_loader` and report the results.

    The computed metrics are appended to `eval_metrics_history` under
    ``test_<metric>`` keys, then the latest loss and accuracy are printed
    with a one-based epoch number.
    """
    model.eval()  # Set model to evaluation model: e.g. use stored batch statistics

    # Start from a clean slate, then accumulate over every validation batch.
    eval_metrics.reset()
    for val_batch in val_loader:
        eval_step(model, val_batch, eval_metrics)

    results = eval_metrics.compute()
    for metric in results:
        eval_metrics_history[f'test_{metric}'].append(results[metric])

    latest_loss = eval_metrics_history['test_loss'][-1]
    latest_accuracy = eval_metrics_history['test_accuracy'][-1]
    print(f"[test] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {latest_loss:0.4f}")
    print(f"- Accuracy: {latest_accuracy:0.4f}")

In [None]:
# Run training loop: each epoch makes one pass over `train_loader`, then
# evaluates on `val_loader` and prints the resulting loss and accuracy.
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)