This notebook pre-trains a Transformer model on the Helsinki-NLP/opus_books dataset using the Masked Language Modeling (MLM) objective. We will train both a classical and a quantum-hybrid version of the model to compare their performance. The key evaluation metric for this task is Perplexity (PPL), where a lower value is better.
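Under MLM, perplexity is usually computed as the exponential of the mean cross-entropy (in nats) over the masked positions only, because unmasked positions carry no loss. The sketch below assumes `train_and_evaluate` reports PPL this way; the function `masked_perplexity` and its inputs are hypothetical and are not part of `quantum_transformers`. It also shows a useful reference point: a model that guesses uniformly over the vocabulary has a perplexity equal to the vocabulary size.

```python
import math


def masked_perplexity(true_token_probs, mask):
    """Perplexity over masked positions only (hypothetical helper).

    true_token_probs: probability the model assigns to the correct token at each position.
    mask: 1 where the token was masked (and therefore scored), 0 elsewhere.
    """
    nll = [-math.log(p) for p, m in zip(true_token_probs, mask) if m]
    if not nll:
        raise ValueError("no masked positions to score")
    mean_nll = sum(nll) / len(nll)  # mean cross-entropy in nats
    return math.exp(mean_nll)


# Uniform guessing over the 30522-token vocabulary gives PPL equal to the vocabulary size.
vocab_size = 30522
uniform = masked_perplexity([1 / vocab_size] * 4, [1, 1, 0, 1])
print(round(uniform))  # 30522

# Probability 0.5 on every masked target gives PPL 2; the unmasked 0.9 is ignored.
print(round(masked_perplexity([0.5, 0.9, 0.5], [1, 0, 1]), 6))  # 2.0
```

So any trained model should land well below 30522; that is the scale against which the classical and quantum results can be read.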

In [1]:
import jax
import tensorflow as tf

# Hide the GPU from TensorFlow so that TF does not grab all GPU memory (JAX needs it).
tf.config.set_visible_devices([], device_type='GPU')  

# Import our custom modules
from quantum_transformers.datasets import get_mlm_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit

# Define the directory where datasets are stored/cached
data_dir = './data'

2025-09-09 17:30:50.536304: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-09-09 17:30:50.536382: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-09-09 17:30:50.618842: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-09-09 17:33:37.427917: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
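Hiding the GPU from TensorFlow stops TF from reserving GPU memory, but JAX itself preallocates 75% of total GPU memory when the first JAX operation runs. On a small card such as the 6 GB laptop GPU used here, one option is JAX's documented allocator environment variables. A minimal sketch, assuming it is placed at the very top of the first cell, before `import jax`; whether it changes anything for this particular setup is untested:

```python
import os

# JAX reads these when it initializes its GPU backend, so the safest place to set
# them is before `import jax` anywhere in the process.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate GPU memory on demand

# Alternative: keep preallocation but cap it (example value, not a tuned setting).
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
```

On-demand allocation can fragment memory and run slightly slower, so capping the fraction is the gentler of the two options.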

The models are trained using the following devices:

In [2]:
print("Available JAX devices:")
for d in jax.devices():
    print(f"- {d} ({d.device_kind})")

Available JAX devices:
- gpu:0 (NVIDIA GeForce RTX 3050 6GB Laptop GPU)


Next, let's load and tokenize the dataset and check how big the tokenizer vocabulary is.
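The dataloaders built in the next cell feed the MLM objective: some input tokens are hidden and the model is trained to recover them. A common recipe (from BERT, and the default of Hugging Face's `DataCollatorForLanguageModeling`) selects 15% of non-special positions; of those, 80% become `[MASK]`, 10% become a random token, and 10% are left unchanged. Only selected positions get a label; the rest use `-100`, the Hugging Face convention for "ignore in the loss". Whether `get_mlm_dataloaders` uses exactly these rates is an assumption, and `mask_tokens` below is a hypothetical illustration, not the library's code.

```python
import random

MASK_ID = 103        # [MASK] in the bert-base-uncased vocabulary
VOCAB_SIZE = 30522   # tokenizer vocabulary size printed by the data-loading cell
IGNORE = -100        # label for positions that are not scored


def mask_tokens(token_ids, special_ids, mlm_prob=0.15, rng=random):
    """BERT-style masking (hypothetical helper). Returns new (inputs, labels) lists."""
    inputs = list(token_ids)
    labels = [IGNORE] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if tok in special_ids or rng.random() >= mlm_prob:
            continue                      # not selected: no loss at this position
        labels[i] = tok                   # the model must predict the original token
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK_ID           # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token unchanged
    return inputs, labels


rng = random.Random(0)
ids = [101] + [rng.randrange(1000, 2000) for _ in range(20)] + [102]  # [CLS] ... [SEP]
inputs, labels = mask_tokens(ids, special_ids={101, 102}, rng=rng)
print(sum(l != IGNORE for l in labels), "of", len(ids), "positions selected")
```

Keeping 10% of the selected tokens unchanged means the model cannot assume that every visible token is correct, which is the usual motivation for the 80/10/10 split.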

In [3]:
# Set data loading parameters
block_size = 128  # The size of our text chunks
batch_size = 16   # How many chunks to process at once

print("Loading and preparing the dataset for MLM...")

# Get the dataloaders and the tokenizer
(train_dataloader, val_dataloader, test_dataloader), tokenizer = get_mlm_dataloaders(
    dataset_name='Helsinki-NLP/opus_books',
    model_checkpoint='bert-base-uncased',
    # data_dir=data_dir,
    block_size=block_size,
    batch_size=batch_size
)

print("\nDataset loading complete.")
print(f"Tokenizer vocabulary size: {len(tokenizer.vocab)}")

# Note: the first time this cell runs, it downloads and processes the Helsinki-NLP/opus_books dataset, which may take a few minutes depending on your internet connection.

Loading and preparing the dataset for MLM...

Token indices sequence length is longer than the specified maximum sequence length for this model (580 > 512). Running this sequence through the model will result in indexing errors


Dataset loading complete.
Tokenizer vocabulary size: 30522
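The tokenizer warning about a 580-token sequence exceeding BERT's 512-token limit should be harmless here if, as is common in MLM pre-processing, `get_mlm_dataloaders` concatenates the tokenized texts and cuts them into `block_size`-token chunks before they reach the model. That is an assumption about the loader's internals. A sketch of such a grouping step (hypothetical `group_texts`, modeled on the Hugging Face MLM examples), which drops the tail shorter than `block_size`:

```python
from itertools import chain


def group_texts(tokenized_texts, block_size=128):
    """Concatenate token lists and cut them into equal blocks of block_size tokens."""
    concatenated = list(chain.from_iterable(tokenized_texts))
    # Drop the remainder so every block has exactly block_size tokens.
    total = (len(concatenated) // block_size) * block_size
    return [concatenated[i:i + block_size] for i in range(0, total, block_size)]


# A 580-token text plus a 100-token text give 680 tokens -> 5 full blocks of 128
# (640 tokens); the last 40 tokens are dropped.
blocks = group_texts([list(range(580)), list(range(100))], block_size=128)
print(len(blocks), [len(b) for b in blocks])  # 5 [128, 128, 128, 128, 128]
```

Because every block is 128 tokens, no sequence fed to the model comes near the 512-token limit the warning refers to.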


In [4]:
# --- Train a Classical Transformer as a Baseline ---

print("--- Starting Classical Transformer Training ---")

# Define the classical model's hyperparameters
classical_model = Transformer(
    num_tokens=len(tokenizer.vocab),
    max_seq_len=block_size,
    task='mlm',
    hidden_size=8,              # Small hidden size for comparability with quantum models
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=4,
    dropout=0.1
)

# Train and evaluate the classical model; this returns the test loss and test perplexity
classical_test_loss, classical_test_ppl = train_and_evaluate(
    model=classical_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    task='mlm',
    num_epochs=2  # A short run for demonstration; increase for better results
)

print("\n--- Classical Transformer Training Finished ---")
print(f"Final Test Perplexity: {classical_test_ppl:.4f}")


--- Starting Classical Transformer Training ---


2025-09-09 17:43:55.849664: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
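With `hidden_size=8`, most of the classical model's parameters sit in the vocabulary-sized layers rather than in the transformer blocks. A back-of-the-envelope count, assuming a learned token embedding of shape `(num_tokens, hidden_size)`, learned position embeddings, and a separate (untied) dense output layer with bias mapping back to the vocabulary; the actual layout of `Transformer` for `task='mlm'` may differ:

```python
num_tokens = 30522   # len(tokenizer.vocab) for bert-base-uncased
hidden_size = 8
max_seq_len = 128    # block_size

token_embedding = num_tokens * hidden_size           # embedding lookup table
position_embedding = max_seq_len * hidden_size       # assuming learned positions
output_head = hidden_size * num_tokens + num_tokens  # dense layer with bias

print(token_embedding)     # 244176
print(position_embedding)  # 1024
print(output_head)         # 274698
print(token_embedding + position_embedding + output_head)  # 519898
```

By contrast, four 8×8 attention projections with biases (query, key, value, output) add only 4 × 72 = 288 parameters per block under the same kind of assumption.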

In [None]:
# --- Train the Quantum Transformer ---

print("--- Starting Quantum Transformer Training ---")

# Define the quantum model's hyperparameters (identical to classical for fair comparison)
quantum_model = Transformer(
    num_tokens=len(tokenizer.vocab),
    max_seq_len=block_size,
    task='mlm',
    hidden_size=8,
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=4,
    dropout=0.1,
    quantum_attn_circuit=get_circuit(),  # Activate the quantum attention
    quantum_mlp_circuit=get_circuit()    # Activate the quantum MLP
)

# Train and evaluate the quantum model; this returns the test loss and test perplexity
quantum_test_loss, quantum_test_ppl = train_and_evaluate(
    model=quantum_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    task='mlm',
    num_epochs=2  # A short run for demonstration; increase for better results
)

print("\n--- Quantum Transformer Training Finished ---")
print(f"Final Test Perplexity: {quantum_test_ppl:.4f}")

