# Transformer Inference Demo

This notebook demonstrates how to use a trained Transformer model for English-to-Russian translation by **loading pre-trained weights** from a file.

We will:
1.  Load all the necessary functions from our `transformer.py` script.
2.  Load the vocabulary that maps words to token IDs.
3.  Load the trained model parameters from the `transformer_weights.msgpack` file.
4.  Define a `translate` function that performs greedy autoregressive inference.
5.  Try out the model on new sentences!

## 1. Setup and Imports

```python
import jax
import jax.numpy as jnp
from flax.serialization import from_bytes

# Import all the building blocks from our script
from transformer import (
    text_to_token_ids,
    token_embeddings,
    positional_embeddings,
    transformer_encoder,
    forward,
    init_params
)

# We also need the data loading functions
from data import load_dataset_and_vocab
```

## 2. Load Vocabulary

```python
print("Loading vocabulary...")
vocab, _, _, vocab_size = load_dataset_and_vocab(split="train", max_vocab_size=20000)
id_to_token = {v: k for k, v in vocab.items()} # For decoding the output
print(f"Vocabulary loaded. Size: {vocab_size} tokens.")
```
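The inverse mapping `id_to_token` is what turns the model's output IDs back into words. A minimal sketch with a small hypothetical vocabulary (the token names and IDs below are made up; the real ones come from `load_dataset_and_vocab`) shows the lookup, including a `<UNK>` fallback for IDs that are missing from the map:

```python
# Hypothetical toy vocabulary standing in for the one returned by
# load_dataset_and_vocab; the token names and IDs are made up.
toy_vocab = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3,
             "я": 4, "люблю": 5, "код": 6}

# Invert the mapping (token -> ID becomes ID -> token). This is safe because
# every ID in a vocabulary belongs to exactly one token.
toy_id_to_token = {idx: tok for tok, idx in toy_vocab.items()}


def decode(ids, id_to_token, eos_id):
    """Turn output IDs into text, stopping at the first <EOS>."""
    words = []
    for token_id in ids:
        if token_id == eos_id:
            break
        # IDs with no entry in the map fall back to <UNK>
        words.append(id_to_token.get(token_id, "<UNK>"))
    return " ".join(words)


# Output IDs with the leading <SOS> already removed; 99 is not in the vocabulary,
# and everything after <EOS> is ignored.
print(decode([4, 5, 6, 99, 2, 6], toy_id_to_token, toy_vocab["<EOS>"]))  # → я люблю код <UNK>
```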

## 3. Load Trained Model Weights

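Here, we load the `params` pytree that was saved by our training script. `from_bytes` needs a target with the same structure as the saved data, so we first initialize a model with the same architecture (and therefore the same hyperparameters used in training) to act as a template, and then load the saved weights into it. The template's random values are replaced by the saved ones.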

```python
# Define the hyperparameters to match the saved model
D_MODEL = 128
D_FF = D_MODEL * 4
N_LAYERS = 6
N_HEADS = 8

# 1. Initialize a model with the correct structure but random weights
key = jax.random.PRNGKey(42)
template_params = init_params(key, vocab_size=vocab_size, d_model=D_MODEL, d_ff=D_FF, n_heads=N_HEADS, n_layers=N_LAYERS)

# 2. Load the saved byte data from the file
print("Loading saved model weights from transformer_weights.msgpack...")
with open("transformer_weights.msgpack", "rb") as f:
    byte_data = f.read()

# 3. Restore the trained parameters into our template model
loaded_params = from_bytes(template_params, byte_data)

print("✅ Model weights loaded successfully!")
```

## 4. The Inference Function

```python
def translate(english_sentence: str, params: dict, vocab: dict,
              max_output_len: int = 32, d_model: int = 128, max_len: int = 32) -> str:
    """
    Translates an English sentence to Russian with greedy autoregressive decoding.

    `max_len` is the fixed sequence length used for both the encoder and the
    decoder inputs (32, as in the rest of this notebook).
    """
    # Special token IDs, plus an inverse vocabulary for decoding the output
    sos_id = vocab['<SOS>']
    eos_id = vocab['<EOS>']
    id_to_token = {v: k for k, v in vocab.items()}

    # 1. Tokenize the input English sentence
    enc_input = text_to_token_ids([english_sentence], vocab, max_len=max_len)

    # PRNG key required by `forward`; with training=False it should not affect the output
    inf_key = jax.random.PRNGKey(0)

    # 2. Autoregressive decoding loop. `forward` runs the full model, encoder
    #    included, so the encoder is recomputed at every step.
    dec_input_ids = [sos_id]

    # The decoder input never grows past max_len, which caps the number of steps
    for _ in range(min(max_output_len, max_len)):
        current_len = len(dec_input_ids)
        # Right-pad the tokens generated so far with 0s (the padding ID) up to max_len
        dec_input = jnp.array([dec_input_ids + [0] * (max_len - current_len)])

        logits, _ = forward(params, enc_input, dec_input, vocab_size=len(vocab),
                            d_model=d_model, training=False, key=inf_key)

        # The prediction for the next token sits at the position of the last real token
        last_token_logits = logits[0, current_len - 1, :]
        predicted_token_id = jnp.argmax(last_token_logits).item()  # greedy choice

        dec_input_ids.append(predicted_token_id)

        if predicted_token_id == eos_id:
            break

    # 3. Convert token IDs back to words, skipping <SOS> and stopping at <EOS>
    output_words = []
    for token_id in dec_input_ids[1:]:
        if token_id == eos_id:
            break
        output_words.append(id_to_token.get(token_id, '<UNK>'))

    return " ".join(output_words)
```
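The loop inside `translate` is greedy decoding over a fixed-length, right-padded decoder input. The sketch below isolates that logic with a hypothetical `toy_forward` that stands in for the real model (its special-token IDs and transition table are invented for illustration), so the indexing can be checked without trained weights:

```python
import numpy as np

MAX_LEN = 8              # fixed decoder length for this toy (the notebook uses 32)
PAD, SOS, EOS = 0, 1, 2  # hypothetical special-token IDs

# Hypothetical "model": after token t it always predicts NEXT[t] (EOS if t is unknown).
NEXT = {SOS: 3, 3: 4, 4: 5, 5: EOS}


def toy_forward(dec_input, vocab_size=8):
    """Stand-in for `forward`: returns logits of shape (batch, MAX_LEN, vocab_size)."""
    logits = np.zeros((dec_input.shape[0], dec_input.shape[1], vocab_size))
    for pos, tok in enumerate(dec_input[0]):
        logits[0, pos, NEXT.get(int(tok), EOS)] = 1.0
    return logits


def greedy_decode(forward_fn, max_len=MAX_LEN):
    """Greedy decoding: feed the tokens so far, take the argmax at the last real position."""
    ids = [SOS]
    while len(ids) < max_len:          # never exceed the fixed decoder length
        cur = len(ids)
        dec_input = np.array([ids + [PAD] * (max_len - cur)])
        logits = forward_fn(dec_input)
        next_id = int(np.argmax(logits[0, cur - 1]))
        ids.append(next_id)
        if next_id == EOS:
            break
    return ids


print(greedy_decode(toy_forward))  # → [1, 3, 4, 5, 2]
```

Reading the logits at position `cur - 1` works because a Transformer decoder is causal: assuming `forward` applies a causal mask, the output at each position depends only on the tokens up to that position, so the padding on the right does not change it. The toy model satisfies this trivially, since each position looks only at its own token.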

## 5. Try it Out!

```python
test_sentences = [
    "I love to code.",
    "What is your name?",
    "This is a small transformer.",
    "Let's go to the park."
]

for sentence in test_sentences:
    translation = translate(sentence, loaded_params, vocab, d_model=D_MODEL)
    print(f"Input:  {sentence}")
    print(f"Output: {translation}\n")
```