In [1]:
# --- 1. FRAMEWORK SETUP (MUST BE FIRST) ---
import tensorflow as tf
# Hide the GPU from TensorFlow so that it cannot grab all of the GPU memory.
# This MUST run before JAX is imported.
tf.config.set_visible_devices([], device_type='GPU')

import jax
import jax.numpy as jnp
import os
from flax.training import train_state
from transformers import AutoTokenizer
import pickle

# --- END FRAMEWORK SETUP ---

# Import custom modules
# We need to add the project root to the path if running from 'notebooks/quantum'
import sys
sys.path.append('../..') 

# load_model returns the saved parameters and the tokenizer for a model instance
from quantum_transformers.inference import load_model 
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit

# Report every device JAX can see, one per line, together with its hardware kind.
print(
    "Available JAX devices:",
    *(f"- {device} ({device.device_kind})" for device in jax.devices()),
    sep="\n",
)

2025-11-07 15:43:51.809389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762505031.821205 3591764 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762505031.824933 3591764 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1762505031.837425 3591764 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762505031.837432 3591764 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762505031.837433 3591764 computation_placer.cc:177] computation placer alr

Available JAX devices:
- gpu:0 (NVIDIA GeForce RTX 4090)


In [2]:
# --- 2. DEFINE PATHS ---
# Model directories, relative to this notebook. They must match the ones
# used in mlm_training.py.

CLASSICAL_MODEL_PATH = '../../models/mlm_classical'
QUANTUM_MODEL_PATH = '../../models/mlm_quantum'

# Echo the resolved absolute paths so a wrong working directory is easy to spot.
print("Classical model path:", os.path.abspath(CLASSICAL_MODEL_PATH))
print("Quantum model path:", os.path.abspath(QUANTUM_MODEL_PATH))

Classical model path: /dafriz/QuantumTransformers/models/mlm_classical
Quantum model path: /dafriz/QuantumTransformers/models/mlm_quantum


In [3]:
# --- 3. LOAD MODELS AND TOKENIZERS (CORRECTED) ---

# Define model hyperparameters (must match training)
VOCAB_SIZE = 1000
MAX_SEQ_LEN = 128
HIDDEN_SIZE = 8
NUM_HEADS = 2
NUM_BLOCKS = 4
MLP_HIDDEN_SIZE = 8

# Create a dummy batch to initialize the model state (required by load_model)
# Shape is (batch_size, max_seq_len)
init_batch = jnp.ones((1, MAX_SEQ_LEN), dtype=jnp.int32)

# Architecture arguments shared by the classical and the quantum model.
# Keeping them in one place prevents the two constructor calls from
# silently drifting apart when a hyperparameter is edited.
TRANSFORMER_KWARGS = dict(
    num_tokens=VOCAB_SIZE,
    max_seq_len=MAX_SEQ_LEN,
    task='mlm',
    hidden_size=HIDDEN_SIZE,
    num_heads=NUM_HEADS,
    num_transformer_blocks=NUM_BLOCKS,
    mlp_hidden_size=MLP_HIDDEN_SIZE,
)

# --- Load Classical Model ---
print("Loading classical model...")

# 1. Instantiate the model structure (no quantum circuits)
classical_model_instance = Transformer(**TRANSFORMER_KWARGS)

# 2. Load the saved parameters and tokenizer for that structure
classical_params, classical_tokenizer = load_model(
    model_path=CLASSICAL_MODEL_PATH,
    model_instance=classical_model_instance,
    init_batch=init_batch
)
print(f"Classical tokenizer vocabulary size: {len(classical_tokenizer.vocab)}\n")

# --- Load Quantum Model ---
print("Loading quantum model...")

# 1. Instantiate the same structure, plus circuits for attention and MLP
quantum_model_instance = Transformer(
    **TRANSFORMER_KWARGS,
    quantum_attn_circuit=get_circuit(),
    quantum_mlp_circuit=get_circuit()
)

# 2. Load the saved parameters and tokenizer for that structure
quantum_params, quantum_tokenizer = load_model(
    model_path=QUANTUM_MODEL_PATH,
    model_instance=quantum_model_instance,
    init_batch=init_batch
)
print(f"Quantum tokenizer vocabulary size: {len(quantum_tokenizer.vocab)}")

Loading classical model...
Model and tokenizer loaded from ../../models/mlm_classical
Classical tokenizer vocabulary size: 1000

Loading quantum model...
Model and tokenizer loaded from ../../models/mlm_quantum
Quantum tokenizer vocabulary size: 1000


In [4]:
# --- 4. PREDICTION FUNCTION (CORRECTED) ---

from functools import partial # <-- ADD THIS IMPORT

# JIT-compile the prediction step for speed
# We tell JIT that the 2nd argument (arg 1) is a 'static' function
@partial(jax.jit, static_argnums=(1,))
def predict(params, model_apply_fn, inputs):
    """Run a jitted forward pass in inference mode and return the logits.

    Args:
        params: Parameter tree, passed to the model under the 'params' key.
        model_apply_fn: Callable invoked as ``fn(variables, inputs, train=False)``.
            It is a static jit argument, so it must be hashable.
        inputs: Model inputs forwarded unchanged to ``model_apply_fn``.

    Returns:
        Whatever ``model_apply_fn`` returns (the logits).
    """
    variables = {'params': params}
    return model_apply_fn(variables, inputs, train=False)

def predict_masked_input(text, model_instance, params, tokenizer, top_k=5):
    """Print the ``top_k`` most likely tokens for the first [MASK] in ``text``.

    Args:
        text: Input sentence containing the tokenizer's mask token.
        model_instance: Model whose ``.apply`` method is handed to ``predict``.
        params: Model parameters handed to ``predict``.
        tokenizer: Tokenizer used to encode ``text``; must provide
            ``mask_token_id`` and ``convert_ids_to_tokens``.
        top_k: Number of candidate tokens to print.

    Returns:
        None. Results (or a warning if no mask token is present after
        tokenization) are printed to stdout.
    """
    # Tokenize the input text. Padding AND truncation pin the sequence length
    # to MAX_SEQ_LEN, so over-long text cannot produce a longer sequence.
    # NumPy tensors are requested instead of "jax" tensors.
    # NOTE(review): "jax" appears to trigger the Transformers v5 deprecation
    # warning seen in this notebook's output - confirm.
    inputs = tokenizer(
        text,
        return_tensors="np",
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LEN,
    )
    input_ids = jnp.asarray(inputs.input_ids)

    # Find the position of the first [MASK] token. Presence is tested
    # explicitly: argmax alone returns 0 both for "mask at index 0" and for
    # "no mask at all", so it cannot be used as a sentinel.
    mask_token_id = tokenizer.mask_token_id
    is_mask = input_ids[0] == mask_token_id
    if not bool(is_mask.any()):
        print(f"Warning: Could not find [MASK] token in '{text}'")
        return
    mask_position = int(is_mask.argmax())

    # Get model predictions
    # We pass the model's .apply function as the static argument
    logits = predict(params, model_instance.apply, input_ids)

    # Get the logits for the [MASK] token's position
    mask_logits = logits[0, mask_position, :]

    # Rank token ids from most to least likely, then keep the first top_k.
    # Slicing after the reversal means top_k=0 yields no rows (a trailing
    # slice of [-0:] would return the whole vocabulary).
    top_k_indices = jnp.argsort(mask_logits)[::-1][:top_k]
    top_k_scores = jax.nn.softmax(mask_logits)[top_k_indices].tolist()
    # Hand the tokenizer plain Python ints rather than a device array.
    top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices.tolist())

    print(f"Input: '{text}'")
    print("Predictions:")
    for token, score in zip(top_k_tokens, top_k_scores):
        print(f"  - {token} (Score: {score:.4f})")

def evaluate_on_list(texts, model_instance, params, tokenizer, top_k=5):
    """Run ``predict_masked_input`` on each text, printing '---' after each one.

    Args:
        texts: Iterable of input sentences.
        model_instance: Forwarded to ``predict_masked_input``.
        params: Forwarded to ``predict_masked_input``.
        tokenizer: Forwarded to ``predict_masked_input``.
        top_k: Number of predictions to show per sentence.
    """
    for sentence in texts:
        predict_masked_input(
            text=sentence,
            model_instance=model_instance,
            params=params,
            tokenizer=tokenizer,
            top_k=top_k,
        )
        print("---")

In [5]:
# --- 5. RUN INFERENCE ---
# Short prompts in the style of the TinyStories dataset, one [MASK] each.

inference_dataset = [
    "One day, a [MASK] girl named Lily found a needle in her room",
    "Once upon a [MASK], there was a little car named Beep.",
    "One day, a little fish named Fin [MASK] swimming near the shore.",
    "Once upon a time, in a small yard, there was a small daisy. The daisy had a name. [MASK] name was Daisy. Daisy was very small, but she was also very happy.",
    "Tom kicked the ball high in the sky. The [MASK] went far, far away."
]

# Evaluate both models on the same prompts, classical first, each run
# preceded by a separator line and a heading (top 5 predictions per prompt).
for run_label, run_model, run_params, run_tokenizer in (
    ("Classical", classical_model_instance, classical_params, classical_tokenizer),
    ("Quantum", quantum_model_instance, quantum_params, quantum_tokenizer),
):
    print("\n" + "="*30)
    print(f"--- {run_label} Model Batch Prediction ---")
    evaluate_on_list(
        texts=inference_dataset,
        model_instance=run_model,
        params=run_params,
        tokenizer=run_tokenizer,
        top_k=5,
    )

print("\n--- Inference Complete ---")

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.



--- Classical Model Batch Prediction ---
Input: 'One day, a [MASK] girl named Lily found a needle in her room'
Predictions:
  - s (Score: 0.0361)
  - ##e (Score: 0.0312)
  - t (Score: 0.0292)
  - . (Score: 0.0270)
  - ##a (Score: 0.0265)
---
Input: 'Once upon a [MASK], there was a little car named Beep.'
Predictions:
  - ##e (Score: 0.0324)
  - s (Score: 0.0322)
  - t (Score: 0.0274)
  - ##o (Score: 0.0258)
  - ##he (Score: 0.0256)
---
Input: 'One day, a little fish named Fin [MASK] swimming near the shore.'
Predictions:
  - s (Score: 0.0331)
  - ##e (Score: 0.0326)
  - t (Score: 0.0276)
  - ##a (Score: 0.0268)
  - ##he (Score: 0.0255)
---
Input: 'Once upon a time, in a small yard, there was a small daisy. The daisy had a name. [MASK] name was Daisy. Daisy was very small, but she was also very happy.'
Predictions:
  - t (Score: 0.0485)
  - a (Score: 0.0405)
  - ##a (Score: 0.0368)
  - . (Score: 0.0365)
  - ##e (Score: 0.0354)
---
Input: 'Tom kicked the ball high in the sky. The [MASK]