# Testing inference: classical vs. quantum Transformer for masked language modelling

In [1]:
import jax
import tensorflow as tf
import os
import pickle
import jax.numpy as jnp
from flax import serialization

# Import custom modules
from quantum_transformers.datasets import get_mlm_dataloaders
from quantum_transformers.training import train_and_evaluate
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit
from quantum_transformers.inference import save_model, load_model, predict_masked_token, evaluate_on_list

2025-11-01 14:11:45.285221: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-11-01 14:11:45.285357: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-11-01 14:11:45.365667: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Please first ``pip install -U qiskit`` to enable related functionality in translation module
Please first ``pip install -U cirq`` to enable related functionality in translation module


In [None]:
# 1. SETUP
print("Setting up environment...")

# Hide every GPU from TensorFlow. If TF can see the GPU it pre-allocates most of
# its memory, leaving little for JAX (which is expected to run on the GPU, see
# the device listing below).
tf.config.set_visible_devices([], device_type='GPU')

# Data location and output directories for the two trained models.
data_dir = './data'
CLASSICAL_MODEL_PATH = './models/mlm_classical'
QUANTUM_MODEL_PATH = './models/mlm_quantum'
for model_dir in (CLASSICAL_MODEL_PATH, QUANTUM_MODEL_PATH):
    os.makedirs(model_dir, exist_ok=True)  # no-op if the directory already exists

# Show which accelerators JAX picked up.
print("Available JAX devices:")
for device in jax.devices():
    print(f"- {device} ({device.device_kind})")

--- 1. Setting up environment ---
Available JAX devices:
- gpu:0 (NVIDIA GeForce RTX 3050 6GB Laptop GPU)


In [None]:
# 2. LOAD DATA
print("\nLoading and preparing dataset...")

# Data loading parameters
block_size = 128  # number of tokens per text chunk (also the model's max_seq_len)
batch_size = 16   # number of chunks per batch

# get_mlm_dataloaders returns one callable per split plus the tokenizer.
# Each callable is invoked to obtain an iterable of batches (see the
# `train_dataloader_gen()` call below).
(train_dataloader_gen, val_dataloader_gen, test_dataloader_gen), tokenizer = get_mlm_dataloaders(
    dataset_name='Helsinki-NLP/opus_books',
    model_checkpoint='bert-base-uncased',
    block_size=block_size,
    batch_size=batch_size
)

print("\nDataset loading complete.")
print(f"Tokenizer vocabulary size: {len(tokenizer.vocab)}")

# Grab one batch; its inputs are used later by load_model to initialise parameters.
# Fail loudly if the split is empty: previously this only printed a message,
# leaving `init_batch_input` undefined so the inference cell crashed much later
# with an unrelated NameError.
try:
    init_batch_tuple = next(iter(train_dataloader_gen()))
except StopIteration:
    raise RuntimeError(
        "Training dataloader is empty. Cannot initialize models."
    ) from None

# First element of the batch tuple is the model input; the recorded run shows
# shape (batch_size, block_size).
init_batch_input = init_batch_tuple[0]
print(f"Initialization batch shape: {init_batch_input.shape}")


--- 2. Loading and preparing dataset ---


Map:   0%|          | 0/75710 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (874 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/9347 [00:00<?, ? examples/s]

Map:   0%|          | 0/75710 [00:00<?, ? examples/s]

Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/9347 [00:00<?, ? examples/s]


Dataset loading complete.
Tokenizer vocabulary size: 30522
Initialization batch shape: (16, 128)


In [None]:
# 3. TRAIN CLASSICAL MODEL
print("\nStarting Classical Transformer Training")

# A deliberately tiny classical Transformer for masked language modelling.
# The quantum model in section 4 is built with the same hyperparameters, so the
# two runs can be compared like-for-like.
classical_model_kwargs = dict(
    num_tokens=len(tokenizer.vocab),  # output vocabulary = tokenizer vocabulary
    max_seq_len=block_size,           # matches the chunk length of the dataloaders
    task='mlm',
    hidden_size=8,
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=4,
    dropout=0.1,
)
classical_model = Transformer(**classical_model_kwargs)



--- 3. Starting Classical Transformer Training ---


In [None]:
# Train, validate and test the classical model. train_and_evaluate returns
# ((test_loss, test_perplexity), best_state); per its log output, "best" is
# chosen by validation loss.
classical_results = train_and_evaluate(
    model=classical_model,
    train_dataloader=train_dataloader_gen,
    val_dataloader=val_dataloader_gen,
    test_dataloader=test_dataloader_gen,
    task='mlm',
    num_epochs=2,  # short run, for demonstration only
)
(classical_test_loss, classical_test_ppl), classical_best_state = classical_results

print("\n--- Classical Transformer Training Finished ---")
print(f"Final Test Perplexity: {classical_test_ppl:.4f}")

# Persist the best state together with the tokenizer for the inference section.
save_model(classical_best_state, tokenizer, CLASSICAL_MODEL_PATH)


--- 3. Starting Classical Transformer Training ---
Number of parameters = 521498


Epoch 1/2: 1189it [00:59, 19.96it/s, Loss=10.3391, PPL=30917.71]
                                    

Epoch 1: Train Loss = 10.5682, Val Loss = 10.3013, Val PPL = 29772.06


Epoch 2/2: 1189it [00:58, 20.32it/s, Loss=9.8803, PPL=19542.40] 
                                   

Epoch 2: Train Loss = 10.1327, Val Loss = 9.9135, Val PPL = 20201.25
Total training time = 115.72s, best validation loss = 9.9135 at epoch 2


Testing: 150it [00:03, 38.88it/s]


Test Loss = 9.9059, Test PPL = 20047.90

--- Classical Transformer Training Finished ---
Final Test Perplexity: 20047.9043
Model and tokenizer saved to ./models/mlm_classical


In [None]:
# 4. TRAIN QUANTUM MODEL
print("\nStarting Quantum Transformer Training")

# Same hyperparameters as the classical baseline in section 3. The circuits
# returned by get_circuit() switch on quantum attention and a quantum MLP.
quantum_model_kwargs = dict(
    num_tokens=len(tokenizer.vocab),
    max_seq_len=block_size,
    task='mlm',
    hidden_size=8,
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=4,
    dropout=0.1,
    # One separate get_circuit() call per sub-layer (attention first, then MLP).
    quantum_attn_circuit=get_circuit(),
    quantum_mlp_circuit=get_circuit(),
)
quantum_model = Transformer(**quantum_model_kwargs)


--- 4. Starting Quantum Transformer Training ---


In [None]:
# Train, validate and test the quantum model with the same settings as the
# classical run. In the recorded run this took roughly 9x longer than the classical model.
quantum_results = train_and_evaluate(
    model=quantum_model,
    train_dataloader=train_dataloader_gen,
    val_dataloader=val_dataloader_gen,
    test_dataloader=test_dataloader_gen,
    task='mlm',
    num_epochs=2,  # short run, for demonstration only
)
(quantum_test_loss, quantum_test_ppl), quantum_best_state = quantum_results

print("\n--- Quantum Transformer Training Finished ---")
print(f"Final Test Perplexity: {quantum_test_ppl:.4f}")

# Persist the best state together with the tokenizer for the inference section.
save_model(quantum_best_state, tokenizer, QUANTUM_MODEL_PATH)


--- 4. Starting Quantum Transformer Training ---
Number of parameters = 520490


Epoch 1/2: 1189it [09:24,  2.10it/s, Loss=10.4944, PPL=36111.71]
                                    

Epoch 1: Train Loss = 10.5890, Val Loss = 10.3688, Val PPL = 31849.36


Epoch 2/2: 1189it [06:24,  3.09it/s, Loss=10.1799, PPL=26367.33]
                                    

Epoch 2: Train Loss = 10.2563, Val Loss = 10.0297, Val PPL = 22691.01
Total training time = 1010.13s, best validation loss = 10.0297 at epoch 2


Testing: 150it [00:34,  4.40it/s]


Test Loss = 10.0258, Test PPL = 22603.11

--- Quantum Transformer Training Finished ---
Final Test Perplexity: 22603.1074
Model and tokenizer saved to ./models/mlm_quantum


In [None]:
# 5. RUN INFERENCE
banner = "="*30
print("\n" + banner)
print("=== Running Inference ===")
print(banner)

# Reload both models from disk, so inference uses exactly what save_model wrote.
# init_batch_input (from section 2) is passed so load_model can build the
# parameter structure.
print("\n--- Loading Classical Model for Inference ---")
classical_params, classical_tokenizer = load_model(
    model_path=CLASSICAL_MODEL_PATH,
    model_instance=classical_model,
    init_batch=init_batch_input
)

print("\n--- Loading Quantum Model for Inference ---")
quantum_params, quantum_tokenizer = load_model(
    model_path=QUANTUM_MODEL_PATH,
    model_instance=quantum_model,
    init_batch=init_batch_input
)

# A handful of hand-written sentences, each containing a single [MASK] token.
inference_dataset = [
    "He went to the [MASK] to buy some bread.",
    "The capital of France is [MASK].",
    "She put the book on the [MASK].",
    "Let's go for a [MASK] in the park.",
    "The [MASK] is barking at the mailman."
]

# Run the same evaluation for each model, classical first, then quantum.
inference_runs = (
    ("Classical", classical_model, classical_params, classical_tokenizer),
    ("Quantum", quantum_model, quantum_params, quantum_tokenizer),
)
for run_label, run_model, run_params, run_tokenizer in inference_runs:
    print("\n" + banner)
    print(f"--- {run_label} Model Batch Prediction ---")
    evaluate_on_list(
        texts=inference_dataset,
        model=run_model,
        params=run_params,
        tokenizer=run_tokenizer,
        top_k=3  # show the top-3 predicted tokens per sentence
    )

print("\n--- Experiment Complete ---")


--- 5. Running Inference ---

--- Loading Classical Model for Inference ---
Model and tokenizer loaded from ./models/mlm_classical

--- Loading Quantum Model for Inference ---
Model and tokenizer loaded from ./models/mlm_quantum

--- Classical Model Batch Prediction ---
--- Running batch inference on 5 sentences ---

Example 1:
Input: 'He went to the [MASK] to buy some bread.'
Top predictions:
  - authentication  (Logit: 3.32)
  - ##gar           (Logit: 3.27)
  - rec             (Logit: 3.19)

Example 2:
Input: 'The capital of France is [MASK].'
Top predictions:
  - ##dah           (Logit: 3.44)
  - listened        (Logit: 3.33)
  - mermaid         (Logit: 3.16)

Example 3:
Input: 'She put the book on the [MASK].'
Top predictions:
  - gland           (Logit: 3.35)
  - ##thic          (Logit: 3.22)
  - ##dah           (Logit: 3.14)

Example 4:
Input: 'Let's go for a [MASK] in the park.'
Top predictions:
  - gland           (Logit: 3.37)
  - ##thic          (Logit: 3.20)
  - ##par     