In [1]:
# --- 1. FRAMEWORK SETUP (MUST BE FIRST) ---
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import jax
import jax.numpy as jnp
import os
import pickle
from flax import serialization
from transformers import PreTrainedTokenizerFast
import numpy as np

# Import custom modules
from quantum_transformers.transformers import Transformer
from quantum_transformers.quantum_layer import get_circuit

print("JAX devices:", jax.devices())

2025-11-09 00:04:29.221818: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762621469.232576  721966 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762621469.235955  721966 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1762621469.245500  721966 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762621469.245511  721966 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1762621469.245512  721966 computation_placer.cc:177] computation placer alr

JAX devices: [gpu(id=0)]


In [2]:
def load_model(model_path):
    """Loads a trained model state and its tokenizer."""
    print(f"Loading model from: {model_path}")
    
    # 1. Load the tokenizer
    tokenizer_file = os.path.join(model_path, 'tinystories_tokenizer_directory', 'tokenizer.json')
    if not os.path.exists(tokenizer_file):
        # Fallback for old save structure
        tokenizer_file = os.path.join(model_path, 'tokenizer.json')
        if not os.path.exists(tokenizer_file):
            raise FileNotFoundError(f"Could not find tokenizer.json in {model_path} or its subdirectories.")
            
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
    
    # 2. Determine if it's a quantum model by checking path name
    is_quantum = 'quantum' in model_path.lower()
    
    # 3. Instantiate the correct model architecture
    # --- THIS MUST MATCH THE NEW mlm_training_2.py ---
    
    vocab_size = tokenizer.vocab_size
    mlp_size = 8 
    # --- NEW: Set parameters to match new training script ---
    num_blocks = 8
    vqc_shape = (4,)
    
    print(f"Instantiating model: Quantum={is_quantum}, Vocab={vocab_size}, MLP_Size={mlp_size}, Blocks={num_blocks}")

    if is_quantum:
        print(f"Using VQC Shape: {vqc_shape}")
        model_instance = Transformer(
            num_tokens=vocab_size,
            max_seq_len=128,
            task='mlm',
            hidden_size=8,
            num_heads=2,
            num_transformer_blocks=num_blocks, # <-- Updated
            mlp_hidden_size=mlp_size, 
            dropout=0.0,
            quantum_w_shape=vqc_shape, # <-- Updated
            quantum_attn_circuit=get_circuit(),
            quantum_mlp_circuit=get_circuit()
        )
    else:
        model_instance = Transformer(
            num_tokens=vocab_size,
            max_seq_len=128,
            task='mlm',
            hidden_size=8,
            num_heads=2,
            num_transformer_blocks=num_blocks, # <-- Updated
            mlp_hidden_size=mlp_size, 
            dropout=0.0
        )
    
    # 4. Load the trained parameters
    params_file = os.path.join(model_path, 'model_params.msgpack')
    with open(params_file, 'rb') as f:
        params_bytes = f.read()
    params = serialization.from_bytes(target=None, encoded_bytes=params_bytes)['params']
    
    print("Model and tokenizer loaded successfully.")
    return model_instance, params, tokenizer

def predict_masked_batch(texts, model_instance, params, tokenizer, top_k=5):
    """Tokenizes a batch of texts, predicts the [MASK] token, and decodes."""
    
    # Tokenize the batch
    inputs = tokenizer(texts, return_tensors='jax', padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids']
    
    # Get model predictions (logits)
    logits = model_instance.apply({'params': params}, input_ids, train=False)
    
    # Find the [MASK] token's position in each text
    mask_token_id = tokenizer.mask_token_id
    mask_indices = jnp.where(input_ids == mask_token_id)
    
    results = []
    for i in range(len(texts)):
        # Find the row index (batch item) and column index (sequence position)
        # for the mask token in this specific text
        
        # --- FIX for multiple masks in one batch ---
        # Find the first mask token corresponding to this batch item 'i'
        mask_for_this_item = jnp.where(mask_indices[0] == i)[0]
        if len(mask_for_this_item) == 0:
            print(f"Warning: No [MASK] token found in text: {texts[i]}")
            continue
            
        seq_idx = mask_indices[1][mask_for_this_item[0]]
        batch_idx = i
        # --- END FIX ---
        
        # Get the logits for just that one token
        mask_logits = logits[batch_idx, seq_idx, :]
        
        # Apply softmax to get probabilities
        probs = jax.nn.softmax(mask_logits)
        
        # Get the top_k predictions
        top_probs, top_indices = jax.lax.top_k(probs, k=top_k)
        
        # Decode the tokens
        predicted_tokens = tokenizer.convert_ids_to_tokens(top_indices)
        results.append({
            "text": texts[i],
            "predictions": list(zip(predicted_tokens, np.array(top_probs)))
        })
        
    return results

def evaluate_on_list(texts, model_instance, params, tokenizer, top_k=5):
    """Helper to run prediction and print nicely."""
    results = predict_masked_batch(texts, model_instance, params, tokenizer, top_k)
    for item in results:
        print(f"Input: '{item['text']}'")
        print("Predictions:")
        for token, score in item['predictions']:
            print(f"  - {token} (Score: {score:.4f})")
        print("---")

In [3]:
# --- Paths from your training script ---
CLASSICAL_MODEL_PATH = '../../models/mlm_classical'
QUANTUM_MODEL_PATH = '../../models/mlm_quantum'

# --- Load both models ---
classical_model_instance, classical_params, classical_tokenizer = load_model(CLASSICAL_MODEL_PATH)
quantum_model_instance, quantum_params, quantum_tokenizer = load_model(QUANTUM_MODEL_PATH)

Loading model from: ../../models/mlm_classical
Instantiating model: Quantum=False, Vocab=1000, MLP_Size=8, Blocks=8


KeyError: 'params'

In [None]:
# --- Dataset for Inference ---
inference_dataset = [
    "One day, a [MASK] girl named Lily found a needle in her room",
    "Once upon a [MASK], there was a little car named Beep.",
    "One day, a little fish named Fin [MASK] swimming near the shore.",
    "Once upon a time, in a small yard, there was a small daisy. The daisy had a name. [MASK] name was Daisy. Daisy was very small, but she was also very happy.",
    "Tom kicked the ball high in the sky. The [MASK] went far, far away."
]

print("\n" + "="*30)
print("--- Classical Model Batch Prediction ---")
evaluate_on_list(
    texts=inference_dataset, 
    model_instance=classical_model_instance, 
    params=classical_params, 
    tokenizer=classical_tokenizer,
    top_k=5  # Show top 5 predictions
)

print("\n" + "="*30)
print("--- Quantum Model Batch Prediction ---")
evaluate_on_list(
    texts=inference_dataset, 
    model_instance=quantum_model_instance, 
    params=quantum_params, 
    tokenizer=quantum_tokenizer,
    top_k=5  # Show top 5 predictions
)

print("\n--- Inference Complete ---")