This notebook pre-trains a Transformer model on the Helsinki-NLP/opus_books dataset using the Masked Language Modeling (MLM) objective. We will train both a classical and a quantum-hybrid version of the model to compare their performance. The key evaluation metric for this task is Perplexity (PPL), where a lower value is better.
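For readers new to the MLM objective, the sketch below shows BERT-style input corruption. Roughly 15% of the non-special tokens are chosen as prediction targets. Of those, 80% are replaced with `[MASK]`, 10% with a random token, and 10% are left unchanged. Only the chosen positions contribute to the loss. This is an illustrative NumPy sketch, not the `quantum_transformers` implementation. The function name `mask_tokens`, the `-100` ignore label (a Hugging Face convention), and the 80/10/10 split are assumptions about how `get_mlm_dataloaders` prepares batches.

```python
import numpy as np

def mask_tokens(input_ids, rng, mask_token_id, vocab_size,
                special_ids=(), mlm_prob=0.15, ignore_index=-100):
    """Hypothetical BERT-style masking helper: returns (corrupted inputs, labels)."""
    input_ids = np.asarray(input_ids)
    # Special tokens such as [CLS]/[SEP]/[PAD] are never chosen as targets.
    candidates = ~np.isin(input_ids, list(special_ids))
    selected = (rng.random(input_ids.shape) < mlm_prob) & candidates
    # Loss is computed only on selected positions; all others get ignore_index.
    labels = np.where(selected, input_ids, ignore_index)
    inputs = input_ids.copy()
    r = rng.random(input_ids.shape)
    to_mask = selected & (r < 0.8)                  # 80%: replace with [MASK]
    to_random = selected & (r >= 0.8) & (r < 0.9)   # 10%: replace with a random token
    inputs[to_mask] = mask_token_id                 # remaining 10%: keep original token
    inputs[to_random] = rng.integers(0, vocab_size, size=int(to_random.sum()))
    return inputs, labels
```

With the `bert-base-uncased` tokenizer used here, `[MASK]` is ID 103, `[CLS]` is 101, `[SEP]` is 102 and `[PAD]` is 0. A call would therefore look like `mask_tokens(batch, np.random.default_rng(0), mask_token_id=103, vocab_size=30522, special_ids=(0, 101, 102))`.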

In [1]:
import jax
import tensorflow as tf

# Hide the GPU from TensorFlow so it does not grab all GPU memory; the GPU is left to JAX.
tf.config.set_visible_devices([], device_type='GPU')

# Import our custom modules
from quantum_transformers.datasets import get_mlm_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit

# Define the directory where datasets are stored/cached
data_dir = './data'

2025-09-15 14:34:06.128381: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-09-15 14:34:06.128519: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-09-15 14:34:06.165393: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Please first ``pip install -U qiskit`` to enable related functionality in translation module
Please first ``pip install -U cirq`` to enable related functionality in translation module
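On a 6 GB laptop GPU, memory can also be tight on the JAX side: by default JAX preallocates a large fraction of GPU memory (75% in recent versions) when its backend starts. If that becomes a problem, one option is to disable preallocation through JAX's documented environment variables. This is a suggested tweak, not something the original notebook does, and it only takes effect if set before JAX initializes its GPU backend, so it belongs at the very top of the first cell.

```python
import os

# Must run before JAX initializes its GPU backend (safest: before `import jax`).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative: keep preallocation but cap it at a fraction of total GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.60"
```

Disabling preallocation trades a little allocation overhead for a smaller memory footprint, which can matter when other processes share the GPU.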


The models are trained using the following devices:

In [2]:
print("Available JAX devices:")
for d in jax.devices():
    print(f"- {d} ({d.device_kind})")

Available JAX devices:
- gpu:0 (NVIDIA GeForce RTX 3050 6GB Laptop GPU)


Next, let's load the dataset, build the MLM dataloaders, and check the size of the tokenizer vocabulary.

In [3]:
# Set data loading parameters
block_size = 128  # Number of tokens per training sequence (chunk)
batch_size = 16   # Number of sequences per batch

print("Loading and preparing the dataset for MLM...")

# Get the dataloaders and the tokenizer
(train_dataloader, val_dataloader, test_dataloader), tokenizer = get_mlm_dataloaders(
    dataset_name='Helsinki-NLP/opus_books',
    model_checkpoint='bert-base-uncased',
    # data_dir=data_dir,
    block_size=block_size,
    batch_size=batch_size
)

print("\nDataset loading complete.")
print(f"Tokenizer vocabulary size: {len(tokenizer.vocab)}")

# Note: this cell calls get_mlm_dataloaders. On the first run it downloads and preprocesses the Helsinki-NLP/opus_books dataset, which may take a few minutes depending on your internet connection.

Loading and preparing the dataset for MLM...


Map:   0%|          | 0/75710 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (874 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/9347 [00:00<?, ? examples/s]

Map:   0%|          | 0/75710 [00:00<?, ? examples/s]

Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/9347 [00:00<?, ? examples/s]


Dataset loading complete.
Tokenizer vocabulary size: 30522
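The tokenizer warning above (`874 > 512`) is emitted when a whole example is tokenized before being chunked. In the common Hugging Face MLM recipe (the `group_texts` step of the `run_mlm` example), tokenized examples are concatenated and re-split into `block_size` chunks, so no sequence longer than 128 tokens reaches the model. Whether `get_mlm_dataloaders` does exactly this is an assumption. The helper below is a hypothetical, minimal version of that chunking step.

```python
from itertools import chain

def group_texts(token_lists, block_size):
    """Concatenate tokenized examples and split them into equal-length blocks."""
    concatenated = list(chain.from_iterable(token_lists))
    # Drop the ragged tail so every block has exactly block_size tokens.
    total = (len(concatenated) // block_size) * block_size
    return [concatenated[i:i + block_size] for i in range(0, total, block_size)]
```

For example, an 874-token example with `block_size=128` yields 6 full blocks (768 tokens), and the last 106 tokens are dropped unless more text is concatenated after them.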


In [4]:
# --- Train a Classical Transformer as a Baseline ---

print("--- Starting Classical Transformer Training ---")

# Define the classical model's hyperparameters
classical_model = Transformer(
    num_tokens=len(tokenizer.vocab),
    max_seq_len=block_size,
    task='mlm',
    hidden_size=8,              # Small hidden size for comparability with quantum models
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=4,
    dropout=0.1
)

# Train and evaluate the classical model
# Unpack the returned tuple into two variables
classical_test_loss, classical_test_ppl = train_and_evaluate(
    model=classical_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    task='mlm',
    num_epochs=2 
)

print("\n--- Classical Transformer Training Finished ---")
# Use the new variable to print the result
print(f"Final Test Perplexity: {classical_test_ppl:.4f}")


--- Starting Classical Transformer Training ---
Number of parameters = 521498


Epoch 1/2: 1191it [00:54, 21.83it/s, Loss=10.2358, PPL=27883.92]
                                   

Epoch 1: Train Loss = 10.5655, Val Loss = 10.2883, Val PPL = 29386.40


Epoch 2/2: 1191it [01:05, 18.23it/s, Loss=10.0146, PPL=22351.32]
                                    

Epoch 2: Train Loss = 10.1303, Val Loss = 9.9069, Val PPL = 20068.01
Total training time = 126.44s, best validation loss = 9.9069 at epoch 2


Testing: 146it [00:03, 36.74it/s]

Test Loss = 9.9067, Test PPL = 20064.53

--- Classical Transformer Training Finished ---
Final Test Perplexity: 20064.5273
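The reported perplexity is consistent with PPL = exp(mean cross-entropy over masked tokens): exp(9.9067) ≈ 20,064, matching the test loss and test PPL above. The sketch below computes both quantities from raw logits. It assumes unmasked positions carry the label `-100`, as in the masking convention used by Hugging Face; the library's internal loss function may differ in detail. It also gives a useful reference point: a model that predicts uniformly over all V tokens has loss ln V and PPL exactly V, i.e. 30,522 for this vocabulary.

```python
import numpy as np

def masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over positions whose label != ignore_index."""
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    mask = labels != ignore_index
    safe_labels = np.where(mask, labels, 0)  # any valid index; masked out below
    token_nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return (token_nll * mask).sum() / max(mask.sum(), 1)

def perplexity(mean_loss):
    """Perplexity is the exponential of the mean per-token cross-entropy."""
    return float(np.exp(mean_loss))
```

Seen this way, a test PPL of about 20,000 after two epochs is still close to the uniform-guess value of 30,522.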





In [7]:
# --- Train the Quantum Transformer ---

print("--- Starting Quantum Transformer Training ---")

# Define the quantum model's hyperparameters (identical to classical for fair comparison)
quantum_model = Transformer(
    num_tokens=len(tokenizer.vocab),
    max_seq_len=block_size,
    task='mlm',
    hidden_size=8,
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=4,
    dropout=0.1,
    quantum_attn_circuit=get_circuit(),  # Activate the quantum attention
    quantum_mlp_circuit=get_circuit()    # Activate the quantum MLP
)

# Train and evaluate the quantum model
# Unpack the returned tuple into two variables
quantum_test_loss, quantum_test_ppl = train_and_evaluate(
    model=quantum_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    task='mlm',
    num_epochs=2  # A shorter run for demonstration; increase for better results
)

print("\n--- Quantum Transformer Training Finished ---")
print(f"Final Test Perplexity: {quantum_test_ppl:.4f}")

--- Starting Quantum Transformer Training ---
Number of parameters = 520490


Epoch 1/2: 1191it [07:25,  2.68it/s, Loss=10.3625, PPL=31648.71]
                                    

Epoch 1: Train Loss = 10.5900, Val Loss = 10.3707, Val PPL = 31911.87


Epoch 2/2: 1191it [04:59,  3.97it/s, Loss=9.9607, PPL=21178.38] 
                                   

Epoch 2: Train Loss = 10.2565, Val Loss = 10.0239, Val PPL = 22558.38
Total training time = 796.59s, best validation loss = 10.0239 at epoch 2


Testing: 146it [00:28,  5.19it/s]



Test Loss = 10.0137, Test PPL = 22329.61

--- Quantum Transformer Training Finished ---
Final Test Perplexity: 22329.6133
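To put the two runs side by side, the short calculation below uses only the numbers printed above. Both models come from a single 2-epoch run each, so these ratios describe this particular run rather than a general result.

```python
# Values copied from the training logs above.
classical = {"test_ppl": 20064.53, "train_time_s": 126.44, "params": 521_498}
quantum = {"test_ppl": 22329.61, "train_time_s": 796.59, "params": 520_490}

ppl_gap = quantum["test_ppl"] / classical["test_ppl"] - 1           # relative PPL increase
slowdown = quantum["train_time_s"] / classical["train_time_s"]      # training-time ratio
param_diff = classical["params"] - quantum["params"]                # parameter difference

print(f"Quantum test PPL is {ppl_gap:.1%} higher, training is {slowdown:.1f}x slower, "
      f"with {param_diff} fewer parameters")
```

This prints `Quantum test PPL is 11.3% higher, training is 6.3x slower, with 1008 fewer parameters`.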


### TODO: report parameter counts explicitly after each model definition

In [None]:

import jax
import jax.numpy as jnp

def count_params(model, seq_len):
    """Initialize the model on a dummy batch of token IDs and count its parameters."""
    # Token IDs must be integers: Flax's nn.Embed rejects float inputs.
    dummy_tokens = jnp.ones((1, seq_len), dtype=jnp.int32)
    params = model.init(jax.random.PRNGKey(0), dummy_tokens, train=False)['params']
    return sum(p.size for p in jax.tree_util.tree_leaves(params))

# Run after the classical and quantum models are defined.
print(f"\nClassical Model: {count_params(classical_model, block_size):,} parameters")
print(f"\nQuantum Model: {count_params(quantum_model, block_size):,} parameters")
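Even with `hidden_size=8`, both models have about 520k parameters. A rough count suggests why, assuming a learned V × d token-embedding table and a separate (untied) output projection with bias; neither assumption is confirmed by the notebook. Under these assumptions, the vocabulary-sized layers alone account for almost all of the classical model's 521,498 parameters. The remaining few thousand would then be spread over the four Transformer blocks and any positional embedding.

```python
V = 30522                 # tokenizer vocabulary size (from the dataset cell output)
d = 8                     # hidden_size used by both models
total_reported = 521_498  # "Number of parameters" printed for the classical model

token_embedding = V * d   # assumed V x d embedding table
output_head = d * V + V   # assumed untied dense layer (weights + bias) over the vocabulary
vocab_params = token_embedding + output_head
share = vocab_params / total_reported

print(f"{vocab_params:,} vocabulary-sized parameters = {share:.1%} of the total")
```

This prints `518,874 vocabulary-sized parameters = 99.5% of the total`. If these assumptions hold, the quantum circuits can only replace a small fraction of the model, which helps explain why both parameter counts are so close.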