In [18]:
# ======================================================
# 🔹 1. Setup & Imports
# ======================================================

import sys, os
# sys.path.append("../src")  # adjust if notebook is elsewhere

import jax
import jax.numpy as jnp
from flax.serialization import from_bytes
from transformer import (
    text_to_token_ids,
    positional_embeddings,
    normalize_text,
    forward,
    load_dataset_and_vocab,
    init_params,
)

# Hyperparameters
MAX_STEPS = 500
D_MODEL = 128
D_FF = D_MODEL * 4
DROPOUT_RATE = 0.1
N_LAYERS = 2
N_HEADS = 8
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
MAX_LEN = 32
VOCAB_SIZE = 20000

# ======================================================
# 🔹 2. Load model weights & vocab
# ======================================================

# create empty (same-shaped) param structure
dummy_key = jax.random.PRNGKey(0)
param_template = init_params(dummy_key, vocab_size=VOCAB_SIZE, d_model=D_MODEL, d_ff=D_FF, n_heads=N_HEADS, n_layers=N_LAYERS)

# Load vocab the same way it was built during training
vocab, en_sents, ru_sents, vocab_size = load_dataset_and_vocab(max_vocab_size=20000)

# Load saved model parameters
with open("../transformer_weights.msgpack", "rb") as f:
    byte_data = f.read()
params = from_bytes(param_template, byte_data)

print(f"✅ Model and vocab loaded. Vocab size: {vocab_size}")

# Build id → token lookup for decoding
id2tok = {v: k for k, v in vocab.items()}

# ======================================================
# 🔹 3. Define Greedy Decoding
# ======================================================

def greedy_decode(params, enc_input, vocab, model_args, max_len=32):
    sos_id, eos_id = vocab['<SOS>'], vocab['<EOS>']
    dec_input = jnp.array([[sos_id]], dtype=jnp.int32)

    for _ in range(max_len - 1):
        logits, preds = forward(
            params,
            enc_input,
            dec_input,
            training=False,
            dropout_rate=0.0,
            key=None,
            vocab=vocab,
            **model_args
        )

        next_id = int(preds[0, -1])
        dec_input = jnp.concatenate([dec_input, jnp.array([[next_id]])], axis=1)
        
        # Stop if <EOS> is reached
        if next_id == eos_id:
            break

    return dec_input[0, 1:]

def decode(ids, id2tok, show_ids=False):
    ids = list(map(int, jnp.array(ids).tolist()))
    words = [id2tok[i] for i in ids if id2tok[i] not in ("<PAD>", "<SOS>", "<EOS>")]
    decoded = " ".join(words)
    if show_ids:
        return f"{decoded}  ({ids})"
    return decoded

def greedy_decode_debug(params, enc_input, vocab, vocab_inv, model_args, max_len=32):
    sos_id, eos_id = vocab['<SOS>'], vocab['<EOS>']
    dec_input = jnp.array([[sos_id]], dtype=jnp.int32)

    for t in range(max_len - 1):
        preds = inspect_decoder_step(params, enc_input, dec_input, vocab, model_args, f"Decode step {t+1}")
        next_id = int(preds[0, -1])
        print(f"→ predicted: {vocab_inv[next_id]} ({next_id})")

        dec_input = jnp.concatenate([dec_input, jnp.array([[next_id]])], axis=1)
        if next_id == eos_id:
            break

    return dec_input[0, 1:]

def inspect_decoder_step(params, enc_input, dec_input, vocab, model_args, step_desc="Step"):
    logits, preds = forward(
        params,
        enc_input,
        dec_input,
        training=False,
        dropout_rate=0.0,
        key=None,
        vocab=vocab,
        **model_args
    )

    step_logits = logits[0, -1]  # last time-step
    probs = jax.nn.softmax(step_logits)
    top_k = jnp.argsort(probs)[-10:][::-1]  # top 10 predictions
    top_probs = probs[top_k]

    print(f"\n🔍 {step_desc}")
    for i, (tid, p) in enumerate(zip(top_k, top_probs)):
        print(f"{i+1:2d}. {vocab_inv[int(tid)] if int(tid) in vocab_inv else tid}: {float(p):.4f}")

    return preds


# ======================================================
# 🔹 4. Run Inference
# ======================================================

model_args = {
    "vocab_size": vocab_size,
    "d_model": 128,
    "n_layers": 2,
    "n_heads": 8,
    "d_ff": 512,
}

# Example sentences
test_sentences = [
    "i am happy",
    "she went home",
    "he loves the city",
]

# for sent in test_sentences:
#     enc = text_to_token_ids([sent], vocab, max_len=32)
#     out_ids = greedy_decode(params, enc, vocab, model_args)
#     print(f"{sent} → {decode(out_ids, id2tok, show_ids=False)}")

sent = "i am happy"
enc = text_to_token_ids([sent], vocab, max_len=32)
vocab_inv = {v: k for k, v in vocab.items()}
out_ids = greedy_decode_debug(params, enc, vocab, vocab_inv, model_args)
print(f"{sent} → {decode(out_ids, id2tok)}")


Loaded 14870 samples for split: train[:1%]
Vocabulary built successfully. Final size: 19358 tokens.
✅ Model and vocab loaded. Vocab size: 19358

🔍 Decode step 1
 1. you: 0.0356
 2. have: 0.0069
 3. by: 0.0029
 4. lyrics: 0.0025
 5. icecat: 0.0024
 6. can: 0.0024
 7. distribution: 0.0021
 8. above: 0.0016
 9. vista: 0.0014
10. etc: 0.0014
→ predicted: you (7)

🔍 Decode step 2
 1. you: 0.0154
 2. have: 0.0092
 3. can: 0.0044
 4. by: 0.0023
 5. lyrics: 0.0020
 6. download: 0.0019
 7. icecat: 0.0017
 8. vista: 0.0017
 9. distribution: 0.0013
10. etc: 0.0011
→ predicted: you (7)

🔍 Decode step 3
 1. you: 0.0112
 2. have: 0.0085
 3. can: 0.0046
 4. download: 0.0019
 5. vista: 0.0017
 6. lyrics: 0.0017
 7. by: 0.0013
 8. icecat: 0.0011
 9. distribution: 0.0011
10. etc: 0.0010
→ predicted: you (7)

🔍 Decode step 4
 1. you: 0.0110
 2. have: 0.0076
 3. can: 0.0056
 4. download: 0.0017
 5. lyrics: 0.0015
 6. vista: 0.0015
 7. distribution: 0.0012
 8. etc: 0.0011
 9. and: 0.0011
10. as: 0.0009
→ p

In [2]:
def greedy_decode(params, enc_input, vocab, model_args, max_len=32):
    sos_id = vocab["<SOS>"]
    eos_id = vocab["<EOS>"]

    dec_input = jnp.array([[sos_id]], dtype=jnp.int32)

    for _ in range(max_len - 1):
        logits, preds = forward(
            params,
            enc_input,
            dec_input,
            training=False,
            dropout_rate=0.0,
            key=None,
            vocab=vocab,
            **model_args
        )

        next_token = int(preds[0, -1])
        dec_input = jnp.concatenate([dec_input, jnp.array([[next_token]])], axis=1)

        if next_token == eos_id:
            break

    return dec_input[0]


In [3]:
import jax

def sample_decode(params, enc_input, vocab, model_args, id2tok, max_len=32, top_k=5, temperature=0.7):
    """
    Performs autoregressive decoding with top-k temperature sampling.
    """
    sos_id, eos_id = vocab["<SOS>"], vocab["<EOS>"]
    dec_input = jnp.full((1, 1), sos_id, dtype=jnp.int32)
    key = jax.random.PRNGKey(42)
    decoded_ids = []

    for step in range(max_len):
        # Run one forward pass for current decoder input
        logits, _ = forward(params, enc_input, dec_input,
                            training=False, dropout_rate=0.0,
                            key=None, vocab=vocab, **model_args)

        # Take logits for the last timestep
        next_logits = logits[0, -1] / temperature

        # Select top-k
        topk_vals, topk_ids = jax.lax.top_k(next_logits, k=top_k)

        # Softmax over top-k
        probs = jax.nn.softmax(topk_vals)

        # Sample one token
        key, subkey = jax.random.split(key)
        sampled_idx = int(jax.random.choice(subkey, topk_ids, p=probs))

        # Append to output
        decoded_ids.append(sampled_idx)

        # Stop on EOS
        if sampled_idx == eos_id and step > 5:
            break

        # Add to decoder input for next iteration
        dec_input = jnp.concatenate([dec_input, jnp.array([[sampled_idx]])], axis=1)

    # Convert to readable text
    decoded_words = []
    for tid in decoded_ids:
        word = id2tok.get(int(tid), "<UNK>")
        if word == "<EOS>":
            break
        if word not in ("<PAD>", "<SOS>"):
            decoded_words.append(word)

    return " ".join(decoded_words)


In [4]:
test_sentences = [
    "we tried to make lyrics as correct as possible",
    "click here to download",
    "english is a west germanic language",
]

test_tokens = text_to_token_ids(test_sentences, vocab, max_len=MAX_LEN)

for sent, toks in zip(test_sentences, test_tokens):
    decoded = sample_decode(params, toks[None, :], vocab, model_args, id2tok,
                            max_len=MAX_LEN, top_k=5, temperature=0.8)
    print(f"💬 {sent:<35} →  {decoded}")


💬 we tried to make lyrics as correct as possible →  
💬 click here to download              →  to
💬 english is a west germanic language →  to permission to


In [5]:
print(vocab["<SOS>"], vocab["<EOS>"])

3 2


In [6]:
enc_input = toks[None, :]
dec_input = jnp.array([[vocab["<SOS>"]]])
logits, preds = forward(params, enc_input, dec_input,
                        training=False, dropout_rate=0.0,
                        key=None, vocab=vocab, **model_args)
print("Top 5 next tokens:", [id2tok[int(i)] for i in jnp.argsort(logits[0, -1])[-5:][::-1]])


Top 5 next tokens: ['to', '<EOS>', 'pt', 'permission', 'link']


In [7]:
# Single test example
sent = "click here to download"
enc = text_to_token_ids([sent], vocab, max_len=MAX_LEN)
dec_input = jnp.full((1, 1), vocab["<SOS>"], dtype=jnp.int32)

# Run with debug=True
logits, preds = forward(
    params,
    enc,
    dec_input,
    training=False,
    dropout_rate=0.0,
    key=None,
    vocab=vocab,
    debug=True,
    **model_args,
)


Decoder layer cross-attn mean=0.031250 max=0.312364
Decoder layer cross-attn mean=0.031250 max=0.341948
[DEBUG] Top-5 next tokens: [('to', 0.0014790651621297002), ('<EOS>', 0.0008908223244361579), ('link', 0.0007451464189216495), ('are', 0.0006858575507067144), ('pt', 0.000624298641923815)]


Let’s confirm that attention is input-dependent

In [11]:
def contrast_cross_attention_test(params, vocab, model_args):
    """Compare decoder outputs for two different encoder inputs."""
    sentences = [
        "click here to download",
        "english is a west germanic language",
    ]
    dec_input = jnp.full((1, 1), vocab["<SOS>"], dtype=jnp.int32)

    logits_list = []
    for sent in sentences:
        enc = text_to_token_ids([sent], vocab, max_len=MAX_LEN)
        logits, _ = forward(
            params,
            enc,
            dec_input,
            training=False,
            dropout_rate=0.0,
            key=None,
            vocab=vocab,
            debug=True,  # show per-layer attention
            **model_args,
        )
        logits_list.append(logits[0, -1])

    # Compare differences
    diff = jnp.mean(jnp.abs(logits_list[0] - logits_list[1]))
    print("\n⚖️  Mean abs diff between encoder1 vs encoder2 logits:", float(diff))

contrast_cross_attention_test(params, vocab, model_args)


Decoder layer cross-attn mean=0.031250 max=0.312364
Decoder layer cross-attn mean=0.031250 max=0.341948
[DEBUG] Top-5 next tokens: [('to', 0.0014790651621297002), ('<EOS>', 0.0008908223244361579), ('link', 0.0007451464189216495), ('are', 0.0006858575507067144), ('pt', 0.000624298641923815)]
Decoder layer cross-attn mean=0.031250 max=0.237591
Decoder layer cross-attn mean=0.031250 max=0.229700
[DEBUG] Top-5 next tokens: [('to', 0.0013154017506167293), ('<EOS>', 0.001099822111427784), ('pt', 0.0008092112257145345), ('permission', 0.0007868147804401815), ('link', 0.000756389694288373)]

⚖️  Mean abs diff between encoder1 vs encoder2 logits: 0.06868710368871689


In [21]:
def greedy_decode(params, enc_input, vocab, id2tok, model_args, max_len=32):
    sos_id = vocab["<SOS>"]
    eos_id = vocab["<EOS>"]
    
    dec_input = jnp.full((1, 1), sos_id, dtype=jnp.int32)
    output_tokens = []

    for _ in range(max_len):
        logits, preds = forward(
            params, enc_input, dec_input,
            training=False, dropout_rate=0.0, key=None, vocab=vocab, **model_args
        )

        next_token = int(preds[0, -1])
        word = id2tok.get(next_token, "<UNK>")
        if word == "<EOS>":
            break
        output_tokens.append(word)

        dec_input = jnp.concatenate(
            [dec_input, jnp.array([[next_token]])], axis=1
        )

    return output_tokens


In [19]:
params

{'embedding': {'W_emb': array([[-5.2683325e-03,  1.3703517e-02, -1.1342160e-03, ...,
          -6.6670706e-03, -7.9261297e-03, -1.2728910e-02],
         [-1.6557986e-02,  9.3093811e-05,  7.6027070e-03, ...,
          -1.8042251e-02,  1.8943299e-02, -3.9828317e-03],
         [-8.0038067e-03, -3.7590373e-02, -7.0448980e-02, ...,
           2.5867159e-02, -1.9109381e-02,  7.3585503e-02],
         ...,
         [-1.6328590e-02, -5.8025643e-03,  4.2159412e-05, ...,
          -4.2169085e-03, -8.0032398e-05,  1.4439098e-02],
         [-1.0075579e-02, -5.0954935e-03,  6.0562673e-03, ...,
           2.4970188e-03, -4.2785779e-03,  5.0235838e-03],
         [-7.2666607e-03,  4.2803269e-03,  2.6518304e-04, ...,
          -4.0223850e-03,  1.5031750e-02,  1.6771419e-02]],
        shape=(19358, 128), dtype=float32)},
 'encoder': {'layers': [{'self_attention': {'W_q': array([[-0.07309957, -0.06353994, -0.00316792, ..., -0.05789318,
              0.06156326, -0.05878427],
            [-0.04064056, -0.1

In [22]:
test_sentences = [
    "we tried to make lyrics as correct as possible",
    "click here to download",
    "english is a west germanic language",
    "he went to the city",
    "they are coming home"
]

for sent in test_sentences:
    enc = text_to_token_ids([sent], vocab, max_len=MAX_LEN)
    out = greedy_decode(params, enc, vocab, id2tok, model_args)
    print(f"💬 {sent:<40} →  {' '.join(out)}")


💬 we tried to make lyrics as correct as possible →  you you have you you you you you you you you you you you you you you you you you you you you you you you you you you you you you
💬 click here to download                   →  you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you
💬 english is a west germanic language      →  you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you
💬 he went to the city                      →  you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you
💬 they are coming home                     →  you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you you
