In [1]:
import jax
import tensorflow as tf
import os
import pickle

# Import custom modules
from quantum_transformers.datasets import get_mlm_dataloaders
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit
from quantum_transformers.inference import load_model, evaluate_on_list

2025-11-01 22:03:25.336455: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762009405.346565 2038996 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762009405.349927 2038996 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1762009405.358939 2038996 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762009405.358946 2038996 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762009405.358948 2038996 computation_placer.cc:177] computation placer alr

In [4]:
# --- 1. SETUP ---
# Configures device visibility, locates the saved model directories and
# defines the hyperparameters shared by the later cells.
print("Setting up environment for inference...")

# Hide every GPU from TensorFlow so that TF cannot claim GPU memory.
# NOTE(review): presumably this leaves the GPU free for JAX -- confirm.
tf.config.set_visible_devices([], device_type='GPU')

# Define directories
CLASSICAL_MODEL_PATH = '../../models/mlm_classical'
QUANTUM_MODEL_PATH = '../../models/mlm_quantum'

# Check if model paths exist
missing_model_paths = [
    path
    for path in (CLASSICAL_MODEL_PATH, QUANTUM_MODEL_PATH)
    if not os.path.exists(path)
]
if missing_model_paths:
    print("Error: Model directories not found.")
    print("Please run 'mlm_training.py' first to train and save the models.")
    # Raise instead of calling exit(): in a notebook kernel exit() requests a
    # kernel shutdown (losing all state) rather than cleanly aborting the cell.
    raise FileNotFoundError(
        f"Missing model directories: {', '.join(missing_model_paths)}"
    )
print(f"Both models found, in {CLASSICAL_MODEL_PATH} and {QUANTUM_MODEL_PATH}")

# Define model hyperparameters (MUST match training script)
block_size = 128
batch_size = 16

Setting up environment for inference...
Both models found, in ../../models/mlm_classical and ../../models/mlm_quantum


In [5]:
# --- 2. GET INITIALIZATION DATA ---
# We MUST load the dataloader to get two things:
# 1. The exact 'tokenizer' used during training.
# 2. An 'init_batch_input' to create the model "scaffold".
print("\nLoading tokenizer and initialization batch...")
(train_dataloader_gen, _, _), tokenizer = get_mlm_dataloaders(
    dataset_name='Helsinki-NLP/opus_books',
    model_checkpoint='bert-base-uncased',
    block_size=block_size,
    batch_size=batch_size
)

# Only the next() call is guarded, so that a StopIteration really means
# "the training dataloader produced no batches" and nothing else.
try:
    init_batch_tuple = next(iter(train_dataloader_gen()))
except StopIteration as empty_loader_error:
    print("ERROR: Dataloader is empty. Cannot get initialization batch.")
    # Raise instead of calling exit(): in a notebook kernel exit() requests a
    # kernel shutdown rather than cleanly aborting the cell.
    raise RuntimeError(
        "Training dataloader yielded no batches; cannot build an initialization batch."
    ) from empty_loader_error

# The first element of the batch tuple is used as the model input.
init_batch_input = init_batch_tuple[0]
print("Tokenizer and init batch loaded successfully.")
print(f"Tokenizer vocabulary size: {len(tokenizer.vocab)}")


Loading tokenizer and initialization batch...


Map:   0%|          | 0/75710 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (522 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/9347 [00:00<?, ? examples/s]

Map:   0%|          | 0/75710 [00:00<?, ? examples/s]

Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/9347 [00:00<?, ? examples/s]

Tokenizer and init batch loaded successfully.
Tokenizer vocabulary size: 30522


In [6]:
# --- 3. DEFINE MODEL STRUCTURES ---
# 'load_model' is handed a model instance for each checkpoint, so both
# architectures are instantiated here before any weights are loaded.
print("\nInstantiating model structures...")

# Hyperparameters common to the classical and the quantum variant.
# NOTE(review): presumably these must equal the values used at training time
# -- verify against the training script.
shared_hparams = dict(
    num_tokens=len(tokenizer.vocab),
    max_seq_len=block_size,
    task='mlm',
    hidden_size=8,
    num_heads=2,
    num_transformer_blocks=4,
    mlp_hidden_size=4,
    dropout=0.1,
)

# Classical variant: no quantum circuits supplied.
classical_model = Transformer(**shared_hparams)

# Quantum variant: same hyperparameters plus a circuit for the attention
# and MLP arguments.
quantum_model = Transformer(
    **shared_hparams,
    quantum_attn_circuit=get_circuit(),
    quantum_mlp_circuit=get_circuit(),
)


Instantiating model structures...


In [7]:
# --- 4. LOAD MODELS ---
# Loads each saved checkpoint via 'load_model', which returns a
# (params, tokenizer) pair for the given model instance.
section_rule = "=" * 30
print("\n" + section_rule)
print("--- 4. Loading Saved Models ---")
print(section_rule)

print("\n--- Loading Classical Model for Inference ---")
classical_loaded = load_model(
    model_path=CLASSICAL_MODEL_PATH,
    model_instance=classical_model,
    init_batch=init_batch_input,
)
classical_params, classical_tokenizer = classical_loaded

print("\n--- Loading Quantum Model for Inference ---")
quantum_loaded = load_model(
    model_path=QUANTUM_MODEL_PATH,
    model_instance=quantum_model,
    init_batch=init_batch_input,
)
quantum_params, quantum_tokenizer = quantum_loaded


--- 4. Loading Saved Models ---

--- Loading Classical Model for Inference ---
Model and tokenizer loaded from ../../models/mlm_classical

--- Loading Quantum Model for Inference ---
Model and tokenizer loaded from ../../models/mlm_quantum


In [8]:
# --- 5. RUN INFERENCE ---
# Runs the same list of masked sentences through both loaded models and
# prints the top-3 predictions for each via 'evaluate_on_list'.
print("\n" + "=" * 30)
print("--- 5. Running Inference ---")
print("=" * 30)

# Sentences to evaluate; each one contains a single [MASK] token.
inference_dataset = [
    "He went to the [MASK] to buy some bread.",
    "The capital of France is [MASK].",
    "She put the book on the [MASK].",
    "Let's go for a [MASK] in the park.",
    "The [MASK] is barking at the mailman."
]

# (header, model, params, tokenizer) for each variant, in evaluation order.
evaluation_runs = (
    ("--- Classical Model Batch Prediction ---",
     classical_model, classical_params, classical_tokenizer),
    ("--- Quantum Model Batch Prediction ---",
     quantum_model, quantum_params, quantum_tokenizer),
)

for run_header, run_model, run_params, run_tokenizer in evaluation_runs:
    print("\n" + "=" * 30)
    print(run_header)
    evaluate_on_list(
        texts=inference_dataset,
        model=run_model,
        params=run_params,
        tokenizer=run_tokenizer,
        top_k=3,  # Show top 3 predictions
    )

print("\n--- Inference Complete ---")



--- 5. Running Inference ---

--- Classical Model Batch Prediction ---
--- Running batch inference on 5 sentences ---

Example 1:
Input: 'He went to the [MASK] to buy some bread.'
Top predictions:
  - ,               (Logit: 6.99)
  - the             (Logit: 6.51)
  - and             (Logit: 6.07)

Example 2:
Input: 'The capital of France is [MASK].'
Top predictions:
  - ,               (Logit: 7.01)
  - the             (Logit: 6.45)
  - .               (Logit: 6.18)

Example 3:
Input: 'She put the book on the [MASK].'
Top predictions:
  - ,               (Logit: 6.85)
  - the             (Logit: 6.62)
  - and             (Logit: 6.07)

Example 4:
Input: 'Let's go for a [MASK] in the park.'
Top predictions:
  - ,               (Logit: 6.85)
  - the             (Logit: 6.62)
  - and             (Logit: 6.07)

Example 5:
Input: 'The [MASK] is barking at the mailman.'
Top predictions:
  - ,               (Logit: 6.88)
  - the             (Logit: 6.51)
  - .               (Logit: 6.23)

-