# Testing inference

In [1]:
import jax
import tensorflow as tf
import os
import pickle
import jax.numpy as jnp
from flax import serialization

# Import custom modules
from quantum_transformers.datasets import get_mlm_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit
from quantum_transformers.inference import save_model, load_model, predict_masked_token, evaluate_on_list

2025-11-01 21:56:29.544426: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762008989.555151 2087615 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762008989.558590 2087615 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1762008989.568420 2087615 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762008989.568429 2087615 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762008989.568430 2087615 computation_placer.cc:177] computation placer alr

In [2]:
# 1. SETUP
print("Setting up environment...")

# Hide every GPU from TensorFlow so that TF does not grab GPU memory;
# the JAX devices used by this notebook are listed at the end of the cell.
tf.config.set_visible_devices([], device_type='GPU')

# Locations used by the notebook: dataset directory and one output
# directory per trained model.
data_dir = './data'
CLASSICAL_MODEL_PATH = './models/mlm_classical'
QUANTUM_MODEL_PATH = './models/mlm_quantum'

# Create the model directories up front (no error if they already exist).
for _model_dir in (CLASSICAL_MODEL_PATH, QUANTUM_MODEL_PATH):
    os.makedirs(_model_dir, exist_ok=True)

# Report which accelerators JAX can see.
print("Available JAX devices:")
for device in jax.devices():
    print(f"- {device} ({device.device_kind})")

Setting up environment...
Available JAX devices:
- gpu:0 (NVIDIA GeForce RTX 4090)


In [3]:
#2. LOAD DATA
print("\nLoading and preparing dataset...")

# Set data loading parameters
block_size = 128  # The size of our text chunks
batch_size = 16   # How many chunks to process at once

# Get the dataloaders and the tokenizer.
# The three returned dataloader objects are callables: calling one yields an
# iterable of batches (see the `train_dataloader_gen()` call below).
(train_dataloader_gen, val_dataloader_gen, test_dataloader_gen), tokenizer = get_mlm_dataloaders(
    dataset_name='Helsinki-NLP/opus_books',
    model_checkpoint='bert-base-uncased',
    block_size=block_size,
    batch_size=batch_size
)

print("\nDataset loading complete.")
print(f"Tokenizer vocabulary size: {len(tokenizer.vocab)}")

# Get one batch for model initialization.
# Only the `next(...)` call can raise StopIteration, so only it is guarded.
try:
    init_batch_tuple = next(iter(train_dataloader_gen()))
except StopIteration as exc:
    # Fail fast: later cells need `init_batch_input`, so continuing after an
    # empty dataloader would only surface later as an unrelated NameError.
    raise RuntimeError(
        "Training dataloader is empty. Cannot initialize models."
    ) from exc

# The first element of the batch tuple is used as the model input.
init_batch_input = init_batch_tuple[0]
print(f"Initialization batch shape: {init_batch_input.shape}")


Loading and preparing dataset...


Map:   0%|          | 0/75710 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (580 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/9347 [00:00<?, ? examples/s]

Map:   0%|          | 0/75710 [00:00<?, ? examples/s]

Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/9347 [00:00<?, ? examples/s]


Dataset loading complete.
Tokenizer vocabulary size: 30522
Initialization batch shape: (16, 128)


In [4]:
#3. TRAIN CLASSICAL MODEL
print("\nStarting Classical Transformer Training")

# Classical baseline: a Transformer built without any quantum circuit
# arguments. Vocabulary size comes from the tokenizer and the maximum
# sequence length matches the data chunk size.
classical_model = Transformer(
    task='mlm',
    num_tokens=len(tokenizer.vocab),
    max_seq_len=block_size,
    # Architecture hyperparameters
    num_transformer_blocks=4,
    num_heads=2,
    hidden_size=8,
    mlp_hidden_size=4,
    dropout=0.1,
)



Starting Classical Transformer Training


In [5]:
# Train the classical model, then evaluate it. The call returns a
# (test loss, test perplexity) pair together with a model state, which is
# saved to disk below.
(classical_test_loss, classical_test_ppl), classical_best_state = train_and_evaluate(
    task='mlm',
    model=classical_model,
    num_epochs=20,  # Change this for different training epoch numbers
    train_dataloader=train_dataloader_gen,
    val_dataloader=val_dataloader_gen,
    test_dataloader=test_dataloader_gen,
)

print("\n--- Classical Transformer Training Finished ---")
print(f"Final Test Perplexity: {classical_test_ppl:.4f}")

# Persist the returned state and the tokenizer under the classical model path.
save_model(classical_best_state, tokenizer, CLASSICAL_MODEL_PATH)

Number of parameters = 521498


Epoch 1/20: 1192it [00:17, 68.62it/s, Loss=10.4501, PPL=34547.27] 
                                     

Epoch 1: Train Loss = 10.5670, Val Loss = 10.2990, Val PPL = 29702.53


Epoch 2/20: 1192it [00:11, 103.14it/s, Loss=9.8817, PPL=19568.73] 
                                     

Epoch 2: Train Loss = 10.1359, Val Loss = 9.9129, Val PPL = 20189.85


Epoch 3/20: 1192it [00:11, 101.39it/s, Loss=9.5651, PPL=14258.14]
                                     

Epoch 3: Train Loss = 9.7721, Val Loss = 9.5709, Val PPL = 14341.00


Epoch 4/20: 1192it [00:11, 100.28it/s, Loss=9.3605, PPL=11620.23]
                                     

Epoch 4: Train Loss = 9.4268, Val Loss = 9.2310, Val PPL = 10208.59


Epoch 5/20: 1192it [00:12, 98.98it/s, Loss=8.8217, PPL=6780.01] 
                                     

Epoch 5: Train Loss = 9.0626, Val Loss = 8.8602, Val PPL = 7045.69


Epoch 6/20: 1192it [00:12, 96.66it/s, Loss=8.2523, PPL=3836.51]
                                     

Epoch 6: Train Loss = 8.6824, Val Loss = 8.4643, Val PPL = 4742.62


Epoch 7/20: 1192it [00:12, 95.95it/s, Loss=8.0463, PPL=3122.25]
                                     

Epoch 7: Train Loss = 8.3153, Val Loss = 8.1079, Val PPL = 3320.69


Epoch 8/20: 1192it [00:12, 94.92it/s, Loss=7.7956, PPL=2429.83]
                                     

Epoch 8: Train Loss = 7.9525, Val Loss = 7.7378, Val PPL = 2293.34


Epoch 9/20: 1192it [00:12, 93.05it/s, Loss=7.5697, PPL=1938.57]
                                     

Epoch 9: Train Loss = 7.6459, Val Loss = 7.4699, Val PPL = 1754.50


Epoch 10/20: 1192it [00:12, 97.25it/s, Loss=7.2749, PPL=1443.56]
                                     

Epoch 10: Train Loss = 7.3824, Val Loss = 7.2455, Val PPL = 1401.71


Epoch 11/20: 1192it [00:12, 98.04it/s, Loss=6.9125, PPL=1004.75]
                                     

Epoch 11: Train Loss = 7.1713, Val Loss = 7.0709, Val PPL = 1177.19


Epoch 12/20: 1192it [00:12, 96.62it/s, Loss=7.2897, PPL=1465.14]
                                     

Epoch 12: Train Loss = 7.0278, Val Loss = 6.9523, Val PPL = 1045.53


Epoch 13/20: 1192it [00:12, 96.66it/s, Loss=6.7979, PPL=895.98] 
                                     

Epoch 13: Train Loss = 6.9326, Val Loss = 6.8951, Val PPL = 987.47


Epoch 14/20: 1192it [00:12, 97.19it/s, Loss=6.6591, PPL=779.84] 
                                     

Epoch 14: Train Loss = 6.8748, Val Loss = 6.8371, Val PPL = 931.78


Epoch 15/20: 1192it [00:12, 97.89it/s, Loss=6.8049, PPL=902.21] 
                                     

Epoch 15: Train Loss = 6.8133, Val Loss = 6.7908, Val PPL = 889.62


Epoch 16/20: 1192it [00:12, 97.63it/s, Loss=6.4027, PPL=603.46] 
                                     

Epoch 16: Train Loss = 6.7814, Val Loss = 6.7666, Val PPL = 868.40


Epoch 17/20: 1192it [00:12, 98.25it/s, Loss=6.8051, PPL=902.43]  
                                     

Epoch 17: Train Loss = 6.7534, Val Loss = 6.7339, Val PPL = 840.42


Epoch 18/20: 1192it [00:12, 98.52it/s, Loss=6.4398, PPL=626.25] 
                                     

Epoch 18: Train Loss = 6.7163, Val Loss = 6.7051, Val PPL = 816.59


Epoch 19/20: 1192it [00:11, 99.40it/s, Loss=6.6392, PPL=764.46]  
                                     

Epoch 19: Train Loss = 6.7127, Val Loss = 6.6586, Val PPL = 779.47


Epoch 20/20: 1192it [00:11, 99.80it/s, Loss=6.3387, PPL=566.04]  
                                     

Epoch 20: Train Loss = 6.6977, Val Loss = 6.6971, Val PPL = 810.07
Total training time = 264.35s, best validation loss = 6.6586 at epoch 19


Testing: 147it [00:01, 104.86it/s]

Test Loss = 6.6782, Test PPL = 794.87

--- Classical Transformer Training Finished ---
Final Test Perplexity: 794.8747
Model and tokenizer saved to ./models/mlm_classical





In [6]:
#4. TRAIN QUANTUM MODEL
print("\nStarting Quantum Transformer Training")

# Same hyperparameters as the classical model, plus a circuit for the
# attention layers and a separately constructed circuit for the MLP layers.
quantum_model = Transformer(
    task='mlm',
    num_tokens=len(tokenizer.vocab),
    max_seq_len=block_size,
    # Architecture hyperparameters
    num_transformer_blocks=4,
    num_heads=2,
    hidden_size=8,
    mlp_hidden_size=4,
    dropout=0.1,
    # Quantum components: each gets its own get_circuit() result
    quantum_attn_circuit=get_circuit(),  # Activate the quantum attention
    quantum_mlp_circuit=get_circuit(),   # Activate the quantum MLP
)


Starting Quantum Transformer Training


In [None]:
# Train the quantum model, then evaluate it. The call returns a
# (test loss, test perplexity) pair together with a model state, which is
# saved to disk below.
(quantum_test_loss, quantum_test_ppl), quantum_best_state = train_and_evaluate(
    task='mlm',
    model=quantum_model,
    num_epochs=20,  # Change this for different training epoch numbers
    train_dataloader=train_dataloader_gen,
    val_dataloader=val_dataloader_gen,
    test_dataloader=test_dataloader_gen,
)

print("\n--- Quantum Transformer Training Finished ---")
print(f"Final Test Perplexity: {quantum_test_ppl:.4f}")

# Persist the returned state and the tokenizer under the quantum model path.
save_model(quantum_best_state, tokenizer, QUANTUM_MODEL_PATH)


--- 4. Starting Quantum Transformer Training ---
Number of parameters = 520490


Epoch 1/2: 1189it [09:24,  2.10it/s, Loss=10.4944, PPL=36111.71]
                                    

Epoch 1: Train Loss = 10.5890, Val Loss = 10.3688, Val PPL = 31849.36


Epoch 2/2: 1189it [06:24,  3.09it/s, Loss=10.1799, PPL=26367.33]
                                    

Epoch 2: Train Loss = 10.2563, Val Loss = 10.0297, Val PPL = 22691.01
Total training time = 1010.13s, best validation loss = 10.0297 at epoch 2


Testing: 150it [00:34,  4.40it/s]


Test Loss = 10.0258, Test PPL = 22603.11

--- Quantum Transformer Training Finished ---
Final Test Perplexity: 22603.1074
Model and tokenizer saved to ./models/mlm_quantum


In [10]:
#5. RUN INFERENCE
_rule = "=" * 30  # horizontal rule reused by every section banner in this cell
print("\n" + _rule)
print("=== Running Inference ===")
print(_rule)

# Reload both saved models from disk. Each load is given the matching model
# instance and the initialization batch obtained during data loading.
print("\n--- Loading Classical Model for Inference ---")
classical_params, classical_tokenizer = load_model(
    init_batch=init_batch_input,
    model_instance=classical_model,
    model_path=CLASSICAL_MODEL_PATH,
)

print("\n--- Loading Quantum Model for Inference ---")
quantum_params, quantum_tokenizer = load_model(
    init_batch=init_batch_input,
    model_instance=quantum_model,
    model_path=QUANTUM_MODEL_PATH,
)

# Short hand-written test set: one [MASK] token per sentence.
inference_dataset = [
    "He went to the [MASK] to buy some bread.",
    "The capital of France is [MASK].",
    "She put the book on the [MASK].",
    "Let's go for a [MASK] in the park.",
    "The [MASK] is barking at the mailman."
]

# Run the same batch evaluation for each model, classical first.
for _heading, _model, _params, _tokenizer in (
    ("--- Classical Model Batch Prediction ---",
     classical_model, classical_params, classical_tokenizer),
    ("--- Quantum Model Batch Prediction ---",
     quantum_model, quantum_params, quantum_tokenizer),
):
    print("\n" + _rule)
    print(_heading)
    evaluate_on_list(
        texts=inference_dataset,
        model=_model,
        params=_params,
        tokenizer=_tokenizer,
        top_k=3,  # Show top 3 predictions
    )

print("\n--- Experiment Complete ---")


=== Running Inference ===

--- Loading Classical Model for Inference ---
Model and tokenizer loaded from ./models/mlm_classical

--- Loading Quantum Model for Inference ---
Model and tokenizer loaded from ./models/mlm_quantum

--- Classical Model Batch Prediction ---
--- Running batch inference on 5 sentences ---

Example 1:
Input: 'He went to the [MASK] to buy some bread.'
Top predictions:
  - ,               (Logit: 7.08)
  - the             (Logit: 6.65)
  - "               (Logit: 6.18)

Example 2:
Input: 'The capital of France is [MASK].'
Top predictions:
  - ,               (Logit: 7.14)
  - the             (Logit: 6.57)
  - to              (Logit: 6.21)

Example 3:
Input: 'She put the book on the [MASK].'
Top predictions:
  - ,               (Logit: 6.95)
  - the             (Logit: 6.77)
  - "               (Logit: 6.26)

Example 4:
Input: 'Let's go for a [MASK] in the park.'
Top predictions:
  - ,               (Logit: 6.96)
  - the             (Logit: 6.77)
  - "            