In [6]:
# ======================================================
# 🔹 1. Setup & Imports
# ======================================================

import sys, os
# sys.path.append("../src")  # adjust if notebook is elsewhere

import jax
import jax.numpy as jnp
from flax.serialization import from_bytes
from transformer import (
    text_to_token_ids,
    positional_embeddings,
    normalize_text,
    forward,
    load_dataset_and_vocab,
    init_params,
)

# Hyperparameters.
# D_MODEL / D_FF / N_HEADS / N_LAYERS shape `param_template` below, so they
# presumably need to agree with the checkpoint being restored. MAX_STEPS,
# DROPOUT_RATE, LEARNING_RATE and BATCH_SIZE are not used by the inference
# code in this cell (kept for parity with training).
MAX_STEPS = 500
D_MODEL = 128
D_FF = D_MODEL * 4
DROPOUT_RATE = 0.1
N_LAYERS = 2
N_HEADS = 8
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
MAX_LEN = 32

# ======================================================
# 🔹 2. Load model weights & vocab
# ======================================================

# Load vocab the same way it was built during training. This must happen
# BEFORE the parameter template is built: init_params needs the vocabulary
# size, and VOCAB_SIZE is not defined in this cell, so use the size the
# loader actually returns.
vocab, en_sents, ru_sents, vocab_size = load_dataset_and_vocab(max_vocab_size=20000)

# Create an empty (same-shaped) param structure for from_bytes to fill in.
dummy_key = jax.random.PRNGKey(0)
param_template = init_params(
    dummy_key,
    vocab_size=vocab_size,
    d_model=D_MODEL,
    d_ff=D_FF,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
)

# Load saved model parameters into the template's structure.
with open("../transformer_weights.msgpack", "rb") as f:
    byte_data = f.read()
params = from_bytes(param_template, byte_data)

print(f"✅ Model and vocab loaded. Vocab size: {vocab_size}")

# Build id → token lookup for decoding.
id2tok = {idx: tok for tok, idx in vocab.items()}

# ======================================================
# 🔹 3. Define Greedy Decoding
# ======================================================

def greedy_decode(params, enc_input, vocab, model_args, max_len=32):
    """Greedily decode a target sequence for a single encoder input.

    Starting from ``<SOS>``, the whole decoder prefix is passed to ``forward``
    at every step (no caching), and the id at ``preds[0, -1]`` (presumably
    the argmax of the last position) is appended. Decoding stops once
    ``<EOS>`` is produced or after ``max_len - 1`` generated tokens.

    Args:
        params: Model parameters, passed straight through to ``forward``.
        enc_input: Encoder token ids for one example (only batch index 0 of
            the predictions is read).
        vocab: Token -> id mapping; must contain ``'<SOS>'`` and ``'<EOS>'``.
        model_args: Extra keyword arguments forwarded to ``forward``.
        max_len: Maximum decoder length including the leading ``<SOS>``.

    Returns:
        1-D int32 array of generated ids, without the leading ``<SOS>``; it
        ends with ``<EOS>`` if that token was produced.
    """
    sos_id, eos_id = vocab['<SOS>'], vocab['<EOS>']
    dec_input = jnp.array([[sos_id]], dtype=jnp.int32)

    for _ in range(max_len - 1):
        _, preds = forward(
            params,
            enc_input,
            dec_input,
            training=False,
            dropout_rate=0.0,
            key=None,
            vocab=vocab,
            **model_args
        )

        next_id = int(preds[0, -1])
        # Pin the new token to int32: jnp.array of a Python int is int64 when
        # jax_enable_x64 is on, which would silently promote the whole prefix.
        next_tok = jnp.array([[next_id]], dtype=jnp.int32)
        dec_input = jnp.concatenate([dec_input, next_tok], axis=1)

        # Stop if <EOS> is reached
        if next_id == eos_id:
            break

    return dec_input[0, 1:]

def decode(ids, id2tok, show_ids=False, skip_tokens=("<PAD>", "<SOS>", "<EOS>")):
    """Convert a sequence of token ids into a space-separated string.

    Args:
        ids: 1-D sequence of ids (Python list or array); it is normalized via
            ``jnp.array(ids).tolist()`` and each element cast to ``int``.
        id2tok: Mapping from int id to token string. An id missing from it
            raises ``KeyError``.
        show_ids: If True, append the full id list (special tokens included)
            in parentheses after the text.
        skip_tokens: Token strings omitted from the text. Defaults to the
            padding / start / end markers.

    Returns:
        The joined tokens, optionally followed by ``"  ([id, ...])"``.
    """
    ids = [int(i) for i in jnp.array(ids).tolist()]
    skip = frozenset(skip_tokens)  # built once, O(1) membership per token
    # Look each id up exactly once, then filter the special tokens out.
    tokens = [id2tok[i] for i in ids]
    decoded = " ".join(tok for tok in tokens if tok not in skip)
    if show_ids:
        return f"{decoded}  ({ids})"
    return decoded

def greedy_decode_debug(params, enc_input, vocab, vocab_inv, model_args, max_len=32):
    """Greedy decoding like ``greedy_decode``, printing diagnostics per step.

    Each step calls ``inspect_decoder_step``, which runs the model and prints
    the most likely next tokens, then prints the token actually chosen.
    Decoding stops once ``<EOS>`` is produced or after ``max_len - 1`` steps.

    Args:
        params: Model parameters.
        enc_input: Encoder token ids for one example (batch index 0 is read).
        vocab: Token -> id mapping; must contain ``'<SOS>'`` and ``'<EOS>'``.
        vocab_inv: Id -> token mapping used to print the chosen token; a
            chosen id missing from it raises ``KeyError``.
        model_args: Extra keyword arguments for the model's ``forward``.
        max_len: Maximum decoder length including the leading ``<SOS>``.

    Returns:
        1-D int32 array of generated ids, without the leading ``<SOS>``.
    """
    sos_id, eos_id = vocab['<SOS>'], vocab['<EOS>']
    dec_input = jnp.array([[sos_id]], dtype=jnp.int32)

    for t in range(max_len - 1):
        preds = inspect_decoder_step(params, enc_input, dec_input, vocab, model_args, f"Decode step {t+1}")
        next_id = int(preds[0, -1])
        print(f"→ predicted: {vocab_inv[next_id]} ({next_id})")

        # Explicit int32 keeps the prefix dtype stable even with jax_enable_x64.
        next_tok = jnp.array([[next_id]], dtype=jnp.int32)
        dec_input = jnp.concatenate([dec_input, next_tok], axis=1)
        if next_id == eos_id:
            break

    return dec_input[0, 1:]

def inspect_decoder_step(
    params,
    enc_input,
    dec_input,
    vocab,
    model_args,
    step_desc="Step",
    vocab_inv=None,
    k=10,
):
    """Run one decoder forward pass and print the ``k`` most likely next tokens.

    Args:
        params: Model parameters, passed straight through to ``forward``.
        enc_input: Encoder token ids.
        dec_input: Current decoder prefix; only the logits at the last
            position of batch item 0 are inspected.
        vocab: Token -> id mapping, passed to ``forward``.
        model_args: Extra keyword arguments for ``forward``.
        step_desc: Label printed above the candidate table.
        vocab_inv: Optional id -> token mapping for printing. When omitted it
            is built from ``vocab`` (previously this relied on a global
            ``vocab_inv`` existing at call time).
        k: Number of candidates to print; clamped to the vocabulary size.

    Returns:
        The ``preds`` value returned by ``forward``, unchanged.
    """
    if vocab_inv is None:
        vocab_inv = {idx: tok for tok, idx in vocab.items()}

    logits, preds = forward(
        params,
        enc_input,
        dec_input,
        training=False,
        dropout_rate=0.0,
        key=None,
        vocab=vocab,
        **model_args
    )

    step_logits = logits[0, -1]  # batch item 0, last time-step
    probs = jax.nn.softmax(step_logits)
    # lax.top_k is a partial selection (no full sort) and returns results in
    # descending order; it rejects k larger than the axis, hence the clamp.
    k = min(k, probs.shape[-1])
    top_probs, top_ids = jax.lax.top_k(probs, k)

    print(f"\n🔍 {step_desc}")
    for rank, (tid, p) in enumerate(zip(top_ids.tolist(), top_probs.tolist()), start=1):
        # Fall back to the raw id if it has no entry in the lookup.
        print(f"{rank:2d}. {vocab_inv.get(tid, tid)}: {p:.4f}")

    return preds


# ======================================================
# 🔹 4. Run Inference
# ======================================================

# Architecture arguments for `forward`. Reuse the hyperparameter constants
# that built `param_template` instead of repeating literals, so the two can
# never drift apart.
model_args = {
    "vocab_size": vocab_size,
    "d_model": D_MODEL,
    "n_layers": N_LAYERS,
    "n_heads": N_HEADS,
    "d_ff": D_FF,
}

# Example sentences
test_sentences = [
    "i am happy",
    "she went home",
    "he loves the city",
]

# Uncomment to translate every example without the per-step debug printout:
# for sent in test_sentences:
#     enc = text_to_token_ids([sent], vocab, max_len=MAX_LEN)
#     out_ids = greedy_decode(params, enc, vocab, model_args, max_len=MAX_LEN)
#     print(f"{sent} → {decode(out_ids, id2tok, show_ids=False)}")

sent = "i am happy"
enc = text_to_token_ids([sent], vocab, max_len=MAX_LEN)
# The inverse vocabulary is exactly `id2tok`; alias it rather than building a
# second copy, keeping the `vocab_inv` name the debug helpers use.
vocab_inv = id2tok
out_ids = greedy_decode_debug(params, enc, vocab, vocab_inv, model_args, max_len=MAX_LEN)
print(f"{sent} → {decode(out_ids, id2tok)}")


Loaded 17496 samples for split: train
Vocabulary built successfully. Final size: 20000 tokens.
✅ Model and vocab loaded. Vocab size: 20000

🔍 Decode step 1
 1. v: 0.0087
 2. byvshuju: 0.0074
 3. zhizni: 0.0073
 4. on: 0.0073
 5. luchshe: 0.0068
 6. darmshtadte: 0.0059
 7. zheny: 0.0056
 8. skryt': 0.0056
 9. utra: 0.0052
10. prosnulsja: 0.0049
→ predicted: v (15)

🔍 Decode step 2
 1. zheny: 0.0083
 2. kakaja: 0.0075
 3. skryt': 0.0070
 4. byvshuju: 0.0065
 5. luchshe: 0.0063
 6. krome: 0.0059
 7. on: 0.0057
 8. v: 0.0054
 9. darmshtadte: 0.0054
10. zhizni: 0.0052
→ predicted: zheny (845)

🔍 Decode step 3
 1. zheny: 0.0094
 2. kakaja: 0.0075
 3. skryt': 0.0068
 4. luchshe: 0.0058
 5. krome: 0.0057
 6. on: 0.0056
 7. v: 0.0052
 8. ne: 0.0052
 9. ot: 0.0051
10. byvshuju: 0.0047
→ predicted: zheny (845)

🔍 Decode step 4
 1. zheny: 0.0089
 2. skryt': 0.0072
 3. on: 0.0063
 4. luchshe: 0.0061
 5. ne: 0.0060
 6. kakaja: 0.0056
 7. krome: 0.0054
 8. zhalkaja: 0.0051
 9. ot: 0.0048
10. v: 0.004