In [1]:
import sys
from os.path import dirname, abspath
# Prepend the parent of the current working directory to the import path,
# presumably so the project-local modules imported further down can be resolved.
# abspath(".") is the working directory itself, hence its dirname is the parent.
parent = dirname(abspath("."))
sys.path.insert(0, parent)

In [2]:
import os
import pickle
import math
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

2025-07-15 10:27:11.992053: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-15 10:27:11.992077: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-15 10:27:11.993010: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-15 10:27:11.997851: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
import jax
import jax.numpy as jnp

In [4]:
# Root PRNG key of the notebook (fixed seed 421); the cells below split it
# whenever they need a fresh sub-key.
main_rng = jax.random.PRNGKey(421)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [5]:
import flax
from flax import linen as nn
from flax.training import train_state
from flax.training import checkpoints
import optax

In [6]:
from utils_display import nice_colorbar
from utils_display import pc
from jax_transformer_display_helpers import display_scaled_dot_product
from jax_transformer_display_helpers import display_positional_encoding
from jax_transformer_display_helpers import display_positional_encoding_profiles
from jax_transformer_display_helpers import display_lr_scheduler

from trainer import Trainer
from transformer_predictor import TransformerPredictor

from transformer_helpers import MultiheadAttention
from transformer_helpers import EncoderBlock
from transformer_helpers import TransformerEncoder
from transformer_helpers import scaled_dot_product
from transformer_helpers import expand_mask
from transformer_helpers import PositionalEncoding

In [7]:
# Directory passed to `Trainer` as `checkpoint_path`.
# Set the JAX_TRANSFORMER_CHECKPOINT_PATH environment variable to relocate it
# without editing the notebook; when the variable is unset (or empty) the
# original machine-specific location is used.
CHECKPOINT_PATH = (
    os.environ.get("JAX_TRANSFORMER_CHECKPOINT_PATH")
    or "/media/guillaume/DATA/NERD/GitHub/nlp/jax_transformer_mlm/jax_checkpoints/"
)

# Scaled dot product attention

In [8]:
# Toy setup: one unbatched sequence of 5 positions with 3-dimensional embeddings.
sequence_length = 5
embedding_dimensionality = 3

# Split and keep the advanced key. Discarding the first half of the split would
# leave `main_rng` unchanged, and the next `jax.random.split(main_rng)` in the
# notebook would hand out this very same sub-key again (PRNG key reuse).
main_rng, rand1 = jax.random.split(main_rng)

# Sample queries, keys and values in one call, then unpack along the leading axis.
qkv = jax.random.normal(rand1, (3, sequence_length, embedding_dimensionality))
q, k, v = qkv[0], qkv[1], qkv[2]

# (sequence_length, sequence_length) mask of zeros with row 3 set to one.
# NOTE(review): whether ones or zeros mean "masked out" is decided inside
# `scaled_dot_product` (transformer_helpers) - confirm against that module.
mask = jnp.zeros((sequence_length, sequence_length))
mask = mask.at[3, :].set(1)

weighted_sum_of_values, attention_weights = scaled_dot_product(q, k, v, mask)

## display_scaled_dot_product(q, k, v, mask, weighted_sum_of_values, attention_weights)

# Multi-head attention

In [9]:
# Smoke test: run MultiheadAttention on random input and print the output shapes.
batch_size = 2
sequence_length = 13
embedding_dimensionality = 32
number_of_heads = 4

# Fresh sub-key for the random input batch
main_rng, x_rng = jax.random.split(main_rng)

x = jax.random.normal(x_rng, (batch_size, sequence_length, embedding_dimensionality))

mha = MultiheadAttention(embedding_dimensionality=embedding_dimensionality, number_of_heads=number_of_heads)

# Separate sub-key for parameter initialisation
main_rng, init_rng = jax.random.split(main_rng)

# init() returns a variable dict; only its 'params' collection is kept
params = mha.init(init_rng, x)['params']
w_o, attention_weights = mha.apply({'params': params}, x)

print('Out', w_o.shape, 'Attention', attention_weights.shape)

# Drop the outputs so these names do not linger in the kernel
del w_o, attention_weights

Out (2, 13, 32) Attention (2, 4, 13, 13)


# Transformer encoder

In [10]:
# Smoke test: one EncoderBlock forward pass in training mode (dropout enabled).
main_rng, x_rng = jax.random.split(main_rng)
x = jax.random.normal(x_rng, (3, 16, 128))

encblock = EncoderBlock(input_dimensionality=128, number_of_heads=4, feedforward_dimensionality=512, dropout_probability=0.1)

# Initialisation with train=True needs a key for the parameters and one for dropout
main_rng, init_rng, dropout_init_rng = jax.random.split(main_rng, 3)
params = encblock.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']

# The forward pass gets its own dropout key, distinct from the one used at init
main_rng, dropout_apply_rng = jax.random.split(main_rng)
out = encblock.apply({'params': params}, x, train=True, rngs={'dropout': dropout_apply_rng})
print('Out', out.shape)

# Release the module and its parameters; `x` and `out` stay defined
del encblock, params

Out (3, 16, 128)


In [11]:
# Smoke test: 5-layer TransformerEncoder forward pass plus attention-map extraction.
main_rng, x_rng = jax.random.split(main_rng)
x = jax.random.normal(x_rng, (3, 16, 128))

transenc = TransformerEncoder(
    number_of_layers=5,
    input_dimensionality=128,
    number_of_heads=4,
    feedforward_dimensionality=256,
    dropout_probability=0.15)

# Initialisation with train=True needs a key for the parameters and one for dropout
main_rng, init_rng, dropout_init_rng = jax.random.split(main_rng, 3)
params = transenc.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']

# Dropout is stochastic, so the forward pass needs its own PRNG key
main_rng, dropout_apply_rng = jax.random.split(main_rng)

# Bind parameters and RNGs to the module once instead of passing them on every call
binded_mod = transenc.bind({'params': params}, rngs={'dropout': dropout_apply_rng})
out = binded_mod(x, train=True)
print('Out', out.shape)

# get_attention_maps returns a sequence; the printed length (5) matches number_of_layers
attn_maps = binded_mod.get_attention_maps(x, train=True)
print('Attention maps', len(attn_maps), attn_maps[0].shape)

# Release the module, its bound copy and the parameters
del transenc, binded_mod, params

Out (3, 16, 128)
Attention maps 5 (3, 4, 16, 16)


# Positional encoding

In [12]:
# Bind the module with an empty variable dict so its attributes can be read directly
encod_block = PositionalEncoding(hidden_dimensionality=48, maximum_sequence_length=96).bind({})
# Copy the encoding table to host memory; squeeze() drops size-1 axes and .T transposes it
# (the table is only consumed by the display helpers that are commented out below)
positional_encoding = jax.device_get(encod_block.positional_encoding.squeeze().T)

In [13]:
## display_positional_encoding(positional_encoding)

In [14]:
## display_positional_encoding_profiles(positional_encoding)

In [15]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    NOTE: a later cell of this notebook defines `numpy_collate` again; once that
    cell has run, this generic version is shadowed.

    Args:
        batch: Non-empty list of samples. Each sample is either an ndarray, a
            tuple/list of fields, or anything `np.array` can convert.

    Returns:
        A stacked ndarray when the samples are ndarrays; a list with one
        collated entry per field when the samples are tuples/lists; otherwise
        `np.array(batch)`.
    """
    first_sample = batch[0]

    # Array samples: add a leading batch axis
    if isinstance(first_sample, np.ndarray):
        return np.stack(batch)

    # Scalars and other leaves: plain conversion
    if not isinstance(first_sample, (tuple, list)):
        return np.array(batch)

    # Structured samples: regroup field-wise and collate every field recursively
    return [numpy_collate(field_values) for field_values in zip(*batch)]

In [16]:
def cosine_warmup_schedule(base_lr: float, warmup: int, max_iters: int):
    """Build a learning-rate schedule: linear warm-up on top of a cosine curve.

    Args:
        base_lr: Learning rate that the schedule scales.
        warmup: Number of iterations over which the rate is ramped up linearly.
        max_iters: Iteration at which the cosine factor reaches zero.

    Returns:
        A function mapping an iteration count to the learning rate for that
        iteration.
    """
    assert warmup > 0
    assert max_iters > 0

    def get_lr(train_iter):
        # Cosine factor: 1 at iteration 0, 0 at iteration max_iters
        cosine_factor = 0.5 * (1 + np.cos(np.pi * train_iter / max_iters))
        if train_iter <= warmup:
            # Inside the warm-up window the factor is additionally scaled by train_iter / warmup
            return cosine_factor * (train_iter * 1.0 / warmup) * base_lr
        return cosine_factor * base_lr

    return get_lr



In [17]:
# Example schedule with base rate 1.0, 100 warm-up iterations and a 2000-iteration horizon;
# it is only consumed by the (commented-out) display helper below.
lr_scheduler = cosine_warmup_schedule(base_lr=1.0, warmup=100, max_iters=2000)

## display_lr_scheduler(lr_scheduler)

# Full transformer model

In [18]:
# Smoke test: TransformerPredictor forward pass and attention-map extraction on random input.
main_rng, x_rng = jax.random.split(main_rng)

# Random input with a trailing axis of size 1
x = jax.random.normal(x_rng, (3, 17, 1))

transpre = TransformerPredictor(
    num_layers=5,
    model_dim=128,
    num_classes=17,
    num_heads=64,
    dropout_prob=0.15,
    input_dropout_prob=0.05)

# Initialize parameters of transformer predictor with random key and inputs
main_rng, init_rng, dropout_init_rng = jax.random.split(main_rng, 3)
params = transpre.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']

print('[initialization finished]')

# Apply transformer predictor with parameters on the inputs
# Since dropout is stochastic, we need to pass a rng to the forward
main_rng, dropout_apply_rng = jax.random.split(main_rng)

# Instead of passing params and rngs every time to a function call, we can bind them to the module
model = transpre.bind({'params': params}, rngs={'dropout': dropout_apply_rng})


# Forward pass without a mask, with positional encoding, in training mode
out = model(x, mask=None, add_positional_encoding=True, train=True)
print('Out', out.shape)


# get_attention_maps returns a sequence; the printed length (5) matches num_layers
attn_maps = model.get_attention_maps(x, train=True)
print('Attention maps', len(attn_maps), attn_maps[0].shape)

# Release the module, its bound copy and the parameters
del transpre, model, params

[initialization finished]
Out (3, 17, 17)
Attention maps 5 (3, 64, 17, 17)


# Trainer

In [19]:
class MLMDataset(data.Dataset):
    """Map-style dataset of pre-tokenised masked-language-modelling samples.

    The samples are read from a pickle file holding the triple
    `(dico_word2index, dico_index2word, dataset)`, where `dataset` is a
    sequence of dicts with the keys "input_indices", "mask", "masked_indices"
    and "labels".
    """

    # Location used when no explicit `dataset_path` is given (relative to the working directory)
    DEFAULT_DATASET_PATH = os.path.join("..", "local_datasets", "wikipedia_man_o_war.pkl")

    def __init__(self, np_rng, dataset_path=None):
        """Load the pickled vocabulary and samples.

        Args:
            np_rng: NumPy random generator; stored on the instance as `np_rng`
                (this class itself does not draw from it).
            dataset_path: Path of the pickle file to load. Defaults to
                `DEFAULT_DATASET_PATH`.

        Raises:
            OSError: If the file cannot be opened.
            IndexError: If the pickled dataset contains no samples.
        """
        super().__init__()

        if dataset_path is None:
            dataset_path = self.DEFAULT_DATASET_PATH

        # SECURITY: pickle.load can execute arbitrary code - only load trusted files.
        with open(dataset_path, "rb") as fid:
            self.dico_word2index, self.dico_index2word, self.dataset = pickle.load(fid)

        self.num_categories = len(self.dico_word2index)
        # The length of the first sample's mask is taken as the sequence length
        self.maximum_sequence_length = len(self.dataset[0]["mask"])
        self.size = len(self.dataset)
        self.np_rng = np_rng

    def __len__(self):
        """Return the number of samples."""
        return self.size

    def __getitem__(self, index):
        """Return `(input_indices, mask, masked_indices, labels)` for sample `index`."""
        sample = self.dataset[index]
        return (
            sample["input_indices"],
            sample["mask"],
            sample["masked_indices"],
            sample["labels"],
        )

In [20]:
def numpy_collate(batch):
    """Collate 4-field samples into four NumPy arrays.

    This definition replaces the generic `numpy_collate` from the earlier cell.

    Args:
        batch: List of samples, each made of four equally long sequences
            (the dataset yields input_indices, mask, masked_indices, labels).

    Returns:
        Tuple of four arrays. The first three fields gain a trailing axis of
        size 1, i.e. (batch, length, 1); the fourth is returned as
        (batch, length).
    """
    # Stack everything at once: (batch, field, length)
    stacked = np.array(batch)
    # Fields 0..2 get an extra trailing axis, field 3 is returned without it
    with_feature_axis = tuple(stacked[:, field, :, np.newaxis] for field in range(3))
    return (*with_feature_axis, stacked[:, 3, :])

In [21]:
# Seeded NumPy generator handed to the dataset
np_rng = np.random.default_rng(421)

dataset = MLMDataset(np_rng=np_rng)

# Training loader: small shuffled batches, incomplete last batch dropped
data_loader_train = data.DataLoader(
    dataset,
    batch_size=9,
    shuffle=True,
    drop_last=True,
    collate_fn=numpy_collate)

# NOTE(review): the validation and test loaders below wrap the same `dataset`
# object as the training loader, so metrics computed on them are measured on
# the training samples.
data_loader_validation = data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    collate_fn=numpy_collate)

data_loader_test = data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    collate_fn=numpy_collate)

# Vocabulary size; passed later as the model's `num_classes`
number_of_categories = dataset.num_categories

pc("Number of categories", number_of_categories)
pc("Maximum sequence length", dataset.maximum_sequence_length)
pc("Dataset size", dataset.size)

# Inspect one raw (un-collated) sample straight from the underlying dataset
index = 2
input_indices, mask, masked_indices, labels = data_loader_train.dataset[index]

pc("Sequence length", len(input_indices))
pc("Input indices", input_indices)
pc("Mask", mask)
pc("Masked indices", masked_indices)
pc("Labels", labels)

[34mNumber of categories[0m: 1099
[34mMaximum sequence length[0m: 50
[34mDataset size[0m: 132
[34mSequence length[0m: 50
[34mInput indices[0m: [ 59  30  14 152 265   3 266  39  10  55 153   4  26  27  48   4   7 189
   4  26  27 154 155   4 267 190 268 456  49  60 156   5   1   1   1   1
   1   1   1   1   1   1   1   1   1   1   1   1   1   1]
[34mMask[0m: [0 1 0 1 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1]
[34mMasked indices[0m: [ 59   0  14   0 265   0 266  39  10  55 153   4  26  27   0   4   7 189
   4  26   0 154 155   0 267 190 268 456  49  60 156   5   1   1   1   1
   1   1   1   1   1   1   1   1   1   1   1   1   1   1]
[34mLabels[0m: [-100   30 -100  152 -100    3 -100 -100 -100 -100 -100 -100 -100 -100
   48 -100 -100 -100 -100 -100   27 -100 -100    4 -100 -100 -100 -100
 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100
 -100 -100 -100 -100 -100 -100 -100 -100]


In [22]:
# Pull a single collated batch to check its structure:
# the number of returned arrays and the shape of the first one.
u = next(iter(data_loader_train))
print(len(u))
print(u[0].shape)

4
(9, 50, 1)


In [23]:
def do_training(max_epochs=10, **model_args):
    """Train the model, or restore an existing checkpoint, then evaluate it.

    Reads the notebook globals `data_loader_train`, `data_loader_validation`,
    `data_loader_test` and `CHECKPOINT_PATH`.

    Args:
        max_epochs: Number of epochs to train for when no checkpoint exists.
        **model_args: Forwarded unchanged to `Trainer`.

    Returns:
        Tuple `(trainer, metrics)`; `metrics` maps "val_acc" and "test_acc" to
        the values returned by `trainer.eval_model` on the validation and test
        loaders.
    """
    # Only the first array of a collated batch is handed to the trainer as its example batch
    exmp_batch, _mask, _masked_indices, _labels = next(iter(data_loader_train))

    trainer = Trainer(
        model_name="Trainer",
        exmp_batch=exmp_batch,
        # Total number of optimisation steps: batches per epoch times epochs
        max_iters=len(data_loader_train) * max_epochs,
        checkpoint_path=CHECKPOINT_PATH,
        **model_args)

    if trainer.checkpoint_exists():
        # A checkpoint is already on disk: skip training and restore it
        trainer.load_model(pretrained=True)
    else:
        trainer.train_model(data_loader_train, data_loader_validation, num_epochs=max_epochs)
        # NOTE(review): presumably reloads the checkpoint written during training - confirm in Trainer
        trainer.load_model()

    metrics = {
        "val_acc": trainer.eval_model(data_loader_validation),
        "test_acc": trainer.eval_model(data_loader_test),
    }

    # Bind parameters to model for easier inference
    trainer.model_bd = trainer.model.bind({"params": trainer.state.params})
    return trainer, metrics

In [24]:
# Train (or restore) the masked-language model with the default max_epochs=10.
# All keyword arguments are forwarded to `Trainer` through `do_training`;
# num_classes is the vocabulary size reported by the dataset.
reverse_trainer, reverse_result = do_training(
    model_dim=128,
    num_classes=number_of_categories,
    num_heads=2,                                                
    num_layers=3,
    dropout_prob=0.0,
    lr=5e-4,
    warmup=51)

  0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/14 [00:00<?, ?it/s]

