# Demo: The 'Broken' Transformer's Shallow Learning

## Purpose
This notebook demonstrates the capabilities of our Transformer model that was trained with a fundamental architectural flaw (non-learnable `LayerNorm` parameters). 

**Hypothesis:** The model did not learn grammar or deep semantics. Instead, it achieved a high accuracy score by memorizing very common, short word-to-word mappings from the training data. 

We will test this by giving it:
1.  Simple, common phrases it likely saw many times.
2.  More complex sentences that require a real understanding of language.

## 1. Setup

In [1]:
import jax
import jax.numpy as jnp
from flax.serialization import from_bytes
from transliterate import translit

# All functions are now imported from the single transformer.py file
from transformer import (
    load_dataset_and_vocab,
    normalize_text,
    text_to_token_ids,
    forward,
    init_params
)

  from .autonotebook import tqdm as notebook_tqdm


## 2. Load the Model and Vocabulary

We load the vocabulary from the `transformer.py` script and the saved weights from the 'broken' model's training run.

In [2]:
# Hyperparameters must match the saved model
D_MODEL = 128
D_FF = D_MODEL * 4
N_LAYERS = 6
N_HEADS = 8
MAX_VOCAB_SIZE = 20000

# Load the vocabulary that was used for training
vocab, _, _, vocab_size = load_dataset_and_vocab(max_vocab_size=MAX_VOCAB_SIZE)
id_to_token = {v: k for k, v in vocab.items()}

# Prepare model arguments for the forward pass
model_args = {'vocab_size': vocab_size, 'd_model': D_MODEL, 'n_layers': N_LAYERS, 'n_heads': N_HEADS, 'd_ff': D_FF}

# Load the saved model weights
key = jax.random.PRNGKey(42)
template_params = init_params(key, vocab_size=vocab_size, d_model=D_MODEL, d_ff=D_FF, n_heads=N_HEADS, n_layers=N_LAYERS)

print("Loading saved model weights from transformer_weights.msgpack...")
# NOTE: Adjust the path if your weights file is in a different directory
with open("../transformer_weights.msgpack", "rb") as f:
    byte_data = f.read()

loaded_params = from_bytes(template_params, byte_data)
print("✅ Model loaded successfully!")

Loaded 17496 samples for split: train
Vocabulary built successfully. Final size: 20000 tokens.
Loading saved model weights from transformer_weights.msgpack...
✅ Model loaded successfully!


In [3]:
vocab

{'<PAD>': 0,
 '<UNK>': 1,
 '<EOS>': 2,
 '<SOS>': 3,
 'i': 4,
 'the': 5,
 'and': 6,
 'to': 7,
 'on': 8,
 'a': 9,
 'he': 10,
 'of': 11,
 "'": 12,
 'ne': 13,
 'chto': 14,
 'v': 15,
 'that': 16,
 'his': 17,
 'in': 18,
 'her': 19,
 'was': 20,
 'she': 21,
 'it': 22,
 'not': 23,
 'with': 24,
 'had': 25,
 'na': 26,
 'ona': 27,
 's': 28,
 'no': 29,
 'ja': 30,
 'but': 31,
 'you': 32,
 'him': 33,
 'kak': 34,
 'at': 35,
 'ego': 36,
 'levin': 37,
 'is': 38,
 'as': 39,
 'for': 40,
 'eto': 41,
 'said': 42,
 'k': 43,
 'by': 44,
 'ee': 45,
 'vse': 46,
 'bylo': 47,
 'be': 48,
 'all': 49,
 'so': 50,
 'have': 51,
 'tak': 52,
 'skazal': 53,
 'which': 54,
 'what': 55,
 'zhe': 56,
 'one': 57,
 'o': 58,
 'emu': 59,
 'anna': 60,
 'they': 61,
 'za': 62,
 'me': 63,
 'do': 64,
 'from': 65,
 'when': 66,
 'were': 67,
 'this': 68,
 'my': 69,
 'who': 70,
 "tol'ko": 71,
 'would': 72,
 'about': 73,
 'ty': 74,
 'did': 75,
 'po': 76,
 'u': 77,
 'byl': 78,
 'there': 79,
 "'i": 80,
 'could': 81,
 'now': 82,
 'been': 83,
 '

## 3. The Inference Function

This is the corrected `translate` function. It now correctly uses the special tokens (e.g., `<SOS>`) and the unified `forward` pass for autoregressive decoding.

In [4]:
def translate(english_sentence: str, params: dict, vocab: dict, max_len: int = 32, model_args: dict = model_args):
    # Get special token IDs using the correct uppercase casing
    sos_id = vocab['<SOS>']
    eos_id = vocab['<EOS>']
    pad_id = vocab['<PAD>']

    # Normalize and tokenize the input English sentence
    enc_input = text_to_token_ids([english_sentence], vocab, max_len=max_len)

    # Initialize the decoder input with the Start-Of-Sentence token
    dec_input_ids = [sos_id]
    
    # Autoregressive decoding loop
    for i in range(max_len - 1):
        # Pad the current decoder sequence to the max length
        dec_input = jnp.array([dec_input_ids + [pad_id] * (max_len - len(dec_input_ids))])
        
        # Perform a full forward pass to get the logits
        logits, _ = forward(
            params, 
            enc_input, 
            dec_input, 
            training=False, 
            dropout_rate=0.0, 
            key=None, 
            **model_args
        )
        
        # Get the predicted token for the current position
        predicted_token_id = jnp.argmax(logits[0, i, :]).item()
        
        # Stop if the model predicts the end-of-sentence token
        if predicted_token_id == eos_id:
            break
            
        dec_input_ids.append(predicted_token_id)

    # Convert token IDs back to words, skipping the initial <SOS> token
    output_words = [id_to_token.get(token_id, '<UNK>') for token_id in dec_input_ids[1:]]
    latin_translation = " ".join(output_words).strip()
    
    # Transliterate the final output from Latin script back to Cyrillic
    cyrillic_translation = translit(latin_translation, 'ru')
    return cyrillic_translation

## 4. The Experiment: What Can It Translate?

Let's see what happens when we give it different kinds of sentences.

In [5]:
print("--- Part 1: Simple Phrases from the 'Book' Domain ---")
print("(These are more likely to be in the vocabulary)\n")

# Use more formal words that are common in literature
book_phrases = [
    "i am",
    "she said",
    "it was",
    "he went to the city",
    "call me ishmael"
]
for sentence in book_phrases:
    translation = translate(sentence, loaded_params, vocab)
    print(f"Input:  {sentence}")
    print(f"Output: {translation}\n")

print("\n--- Part 2: More Complex Sentences ---\n")
complex_sentences = [
    "it was the best of times it was the worst of times",
    "the mystery of the beginning of all things is insoluble by us"
]
for sentence in complex_sentences:
    translation = translate(sentence, loaded_params, vocab)
    print(f"Input:  {sentence}")
    print(f"Output: {translation}\n")

--- Part 1: Simple Phrases from the 'Book' Domain ---
(These are more likely to be in the vocabulary)

Input:  i am
Output: 

Input:  she said
Output: 

Input:  it was
Output: 

Input:  he went to the city
Output: 

Input:  call me ishmael
Output: 


--- Part 2: More Complex Sentences ---

Input:  it was the best of times it was the worst of times
Output: 

Input:  the mystery of the beginning of all things is insoluble by us
Output: 



## 5. Conclusion

As the results show, the model performs reasonably well on very short phrases it has likely memorized from the training data. However, it fails completely when faced with longer sentences that require an understanding of grammar, word order, and context. 

This perfectly demonstrates that the high accuracy score we saw during training was an illusion, created by the model's success on a large number of simple, repetitive examples, while hiding its inability to generalize.

In [6]:
def greedy_decode(params, src_tokens, vocab, max_len=30):
    start_id = vocab["<SOS>"]
    end_id = vocab["<EOS>"]
    dec_input = jnp.array([[start_id]])

    for _ in range(max_len):
        logits, _ = forward(
            params,
            enc_input=src_tokens,
            dec_input=dec_input,
            vocab_size=len(vocab),
            d_model=128,
            n_layers=6,
            n_heads=8,
            d_ff=512,
            dropout_rate=0.0,
            training=False,
        )
        next_token = int(jnp.argmax(logits[0, -1]))
        dec_input = jnp.concatenate([dec_input, jnp.array([[next_token]])], axis=1)
        if next_token == end_id:
            break

    # ✅ return list of integer IDs
    return list(dec_input[0])


In [7]:
# --- Simple helpers for vocab from data.py ---
id2tok = {v: k for k, v in vocab.items()}

def encode(text: str, vocab: dict) -> list[int]:
    """Convert text to list of token IDs using vocab."""
    tokens = text.lower().split()
    unk_id = vocab["<UNK>"]
    sos_id = vocab["<SOS>"]
    eos_id = vocab["<EOS>"]
    return [sos_id] + [vocab.get(tok, unk_id) for tok in tokens] + [eos_id]

def decode(ids, id2tok):
    """Convert list or JAX array of token IDs back into a string."""
    # convert to list of Python ints
    if isinstance(ids, jax.Array):
        ids = list(map(int, jax.device_get(ids)))
    else:
        ids = [int(x) for x in ids]

    toks = [id2tok[i] for i in ids if id2tok[i] not in ("<PAD>", "<SOS>", "<EOS>")]
    return " ".join(toks)



In [8]:
id2tok

{0: '<PAD>',
 1: '<UNK>',
 2: '<EOS>',
 3: '<SOS>',
 4: 'i',
 5: 'the',
 6: 'and',
 7: 'to',
 8: 'on',
 9: 'a',
 10: 'he',
 11: 'of',
 12: "'",
 13: 'ne',
 14: 'chto',
 15: 'v',
 16: 'that',
 17: 'his',
 18: 'in',
 19: 'her',
 20: 'was',
 21: 'she',
 22: 'it',
 23: 'not',
 24: 'with',
 25: 'had',
 26: 'na',
 27: 'ona',
 28: 's',
 29: 'no',
 30: 'ja',
 31: 'but',
 32: 'you',
 33: 'him',
 34: 'kak',
 35: 'at',
 36: 'ego',
 37: 'levin',
 38: 'is',
 39: 'as',
 40: 'for',
 41: 'eto',
 42: 'said',
 43: 'k',
 44: 'by',
 45: 'ee',
 46: 'vse',
 47: 'bylo',
 48: 'be',
 49: 'all',
 50: 'so',
 51: 'have',
 52: 'tak',
 53: 'skazal',
 54: 'which',
 55: 'what',
 56: 'zhe',
 57: 'one',
 58: 'o',
 59: 'emu',
 60: 'anna',
 61: 'they',
 62: 'za',
 63: 'me',
 64: 'do',
 65: 'from',
 66: 'when',
 67: 'were',
 68: 'this',
 69: 'my',
 70: 'who',
 71: "tol'ko",
 72: 'would',
 73: 'about',
 74: 'ty',
 75: 'did',
 76: 'po',
 77: 'u',
 78: 'byl',
 79: 'there',
 80: "'i",
 81: 'could',
 82: 'now',
 83: 'been',
 8

In [9]:
text = "i am"
encoded = encode(text, vocab)
print("Encoded:", encoded)

decoded = decode(encoded, id2tok)
print("Decoded:", decoded)


Encoded: [3, 4, 109, 2]
Decoded: i am


In [10]:
def debug_decode(params, src_text, vocab, id2tok, max_len=30):
    encoded = encode(src_text, vocab)
    print(f"Encoded input: {[id2tok[i] for i in encoded]}")

    src_tokens = jnp.array([encoded])
    start_id = vocab["<SOS>"]
    end_id = vocab["<EOS>"]
    dec_input = jnp.array([[start_id]])

    for step in range(max_len):
        logits, _ = forward(
            params,
            enc_input=src_tokens,
            dec_input=dec_input,
            vocab_size=len(vocab),
            d_model=128,
            n_layers=6,
            n_heads=8,
            d_ff=512,
            dropout_rate=0.0,
            training=False,
        )

        next_token = int(jnp.argmax(logits[0, -1]))
        print(f"Step {step}: next_token={id2tok[next_token]} (id={next_token})")

        dec_input = jnp.concatenate([dec_input, jnp.array([[next_token]])], axis=1)
        if next_token == end_id:
            break

    print("Final sequence:", [id2tok[int(i)] for i in list(dec_input[0])])
    print("Decoded:", decode(list(dec_input[0]), id2tok))


In [11]:
debug_decode(loaded_params, "i am", vocab, id2tok)


Encoded input: ['<SOS>', 'i', 'am', '<EOS>']
Step 0: next_token=<EOS> (id=2)
Final sequence: ['<SOS>', '<EOS>']
Decoded: 


In [12]:
encoded = encode("i am", vocab)
output_ids = greedy_decode(loaded_params, jnp.array([encoded]), vocab)
print("Decoded:", decode(output_ids, id2tok))

Decoded: 


In [13]:
# get ids for words
i_id = vocab.get("i", vocab.get("<UNK>"))
am_id = vocab.get("am", vocab.get("<UNK>"))
sos_id = vocab["<SOS>"]
eos_id = vocab["<EOS>"]

print("IDs: i, am, SOS, EOS ->", i_id, am_id, sos_id, eos_id)

# embedding matrix and individual embeddings
W_emb = loaded_params["embedding"]["W_emb"]  # adjust key if different
print("W_emb shape:", W_emb.shape, "mean_abs:", float(jnp.mean(jnp.abs(W_emb))))
print("emb[i] mean_abs:", float(jnp.mean(jnp.abs(W_emb[i_id]))))
print("emb[am] mean_abs:", float(jnp.mean(jnp.abs(W_emb[am_id]))))
print("sample emb i:", jax.device_get(W_emb[i_id])[:8])


IDs: i, am, SOS, EOS -> 4 109 3 2
W_emb shape: (20000, 128) mean_abs: 0.008388333022594452
emb[i] mean_abs: 0.017685148864984512
emb[am] mean_abs: 0.010231973603367805
sample emb i: [ 0.00801188 -0.02842274 -0.00735911 -0.00706088  0.04224405  0.00927464
 -0.04438598  0.01164551]


In [15]:
enc_input = jnp.array([[3, 4, 109, 2]])  # <SOS> i am <EOS>
logits, enc_output = forward(
    loaded_params,
    enc_input=enc_input,
    dec_input=jnp.array([[3]]),  # just <SOS> for decoder start
    vocab_size=len(vocab),
    d_model=128,
    n_layers=6,
    n_heads=8,
    d_ff=512,
    dropout_rate=0.0,
    training=False,
)

print("Encoder output shape:", enc_output.shape)
print("Encoder output mean_abs:", float(jnp.mean(jnp.abs(enc_output))))
print(enc_output.shape)

Encoder output shape: (1, 1)
Encoder output mean_abs: 2.0
(1, 1)


In [19]:
import jax
import jax.numpy as jnp
import optax
import numpy as np
from tqdm import trange

# --------------------------------------------
# 1️⃣ Tiny dataset: "copy the input exactly"
# --------------------------------------------
phrases = ["i am", "you are", "he is", "we are", "they were"]
dataset = phrases * 64  # 320 samples for training

# def encode_text(s):
#     tokens = [vocab["<SOS>"]] + [vocab.get(tok, vocab["<UNK>"]) for tok in s.split()] + [vocab["<EOS>"]]
#     return tokens

def encode_text(s):
    tokens = [vocab.get(tok, vocab["<UNK>"]) for tok in s.split()] + [vocab["<EOS>"]]
    return tokens

encoded = [encode_text(s) for s in dataset]
max_len = max(len(t) for t in encoded)
padded = np.array([t + [vocab["<PAD>"]] * (max_len - len(t)) for t in encoded])

train_X = jnp.array(padded)
train_Y = jnp.array(padded)

print("Train shape:", train_X.shape)

# --------------------------------------------
# 2️⃣ Optimizer + simple learning rate
# --------------------------------------------
learning_rate = 3e-4
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(loaded_params)

# --------------------------------------------
# 3️⃣ Loss function
# --------------------------------------------
def loss_fn(params, batch_x, batch_y, key):
    dec_input = jnp.concatenate([jnp.full((batch_y.shape[0], 1), vocab["<SOS>"]),
                                 batch_y[:, :-1]], axis=1)
    logits, _ = forward(params, batch_x, dec_input,
                        vocab_size=len(vocab),
                        d_model=128, n_layers=2, n_heads=8, d_ff=512,
                        dropout_rate=0.0, training=True, key=key)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y)
    mask = (batch_y != vocab["<PAD>"])
    return jnp.sum(loss * mask) / jnp.sum(mask)

# --------------------------------------------
# 4️⃣ Training loop (500–1000 steps)
# --------------------------------------------
key = jax.random.PRNGKey(0)
for step in trange(500):
    key, subkey = jax.random.split(key)
    loss, grads = jax.value_and_grad(loss_fn)(loaded_params, train_X, train_Y, subkey)
    updates, opt_state = optimizer.update(grads, opt_state, loaded_params)
    loaded_params = optax.apply_updates(loaded_params, updates)

    if step % 50 == 0:
        print(f"Step {step}: loss = {loss:.4f}")

# --------------------------------------------
# 5️⃣ Inference: check what it learned
# --------------------------------------------
def simple_decode(text):
    encoded = jnp.array([encode_text(text)])
    output_ids = greedy_decode(loaded_params, encoded, vocab)
    return decode(output_ids, id2tok)

print("\n--- COPY TASK RESULTS ---")
for s in ["i am", "he is", "they were"]:
    print(f"Input: {s}  →  Output: {simple_decode(s)}")


Train shape: (320, 3)


  0%|          | 1/500 [00:02<23:53,  2.87s/it]

Step 0: loss = 5.3859


 10%|█         | 51/500 [00:23<02:51,  2.62it/s]

Step 50: loss = 0.1207


 20%|██        | 101/500 [00:44<03:10,  2.10it/s]

Step 100: loss = 0.0148


 30%|███       | 151/500 [01:07<02:13,  2.62it/s]

Step 150: loss = 0.0094


 40%|████      | 201/500 [01:26<01:51,  2.69it/s]

Step 200: loss = 0.0072


 50%|█████     | 251/500 [01:44<01:35,  2.60it/s]

Step 250: loss = 0.0058


 60%|██████    | 301/500 [02:03<01:15,  2.63it/s]

Step 300: loss = 0.0048


 70%|███████   | 351/500 [02:23<00:57,  2.61it/s]

Step 350: loss = 0.0041


 80%|████████  | 401/500 [02:42<00:38,  2.56it/s]

Step 400: loss = 0.0036


 90%|█████████ | 451/500 [03:01<00:18,  2.59it/s]

Step 450: loss = 0.0031


100%|██████████| 500/500 [03:21<00:00,  2.48it/s]



--- COPY TASK RESULTS ---
Input: i am  →  Output: 
Input: he is  →  Output: 
Input: they were  →  Output: 


In [20]:
sample = "i am"
tokens = jnp.array([encode_text(sample)])   # (1, seq_len)
dec_input = jnp.array([[vocab["<SOS>"]]])

logits, _ = forward(
    loaded_params,
    enc_input=tokens,
    dec_input=dec_input,
    vocab_size=len(vocab),
    d_model=128,
    n_layers=2,
    n_heads=8,
    d_ff=512,
    dropout_rate=0.0,
    training=False
)

print("logits shape:", logits.shape)
print("Top-10 next-token predictions for step 0:")
probs = jax.nn.softmax(logits[0, -1])
topk = np.argsort(-np.array(probs))[:10]
for i in topk:
    print(f"{id2tok[i]}: {probs[i]:.3f}")


logits shape: (1, 1, 20000)
Top-10 next-token predictions for step 0:
i: 0.997
<EOS>: 0.001
we: 0.000
he: 0.000
<SOS>: 0.000
am: 0.000
you: 0.000
they: 0.000
are: 0.000
is: 0.000


In [24]:
def decode(ids, id2tok):
    """Convert IDs to string safely, even for JAX arrays."""
    # Convert to plain Python ints
    if isinstance(ids, jax.Array):
        ids = list(map(int, jax.device_get(ids).flatten()))
    else:
        ids = [int(x) for x in ids]

    toks = []
    for i in ids:
        tok = id2tok.get(i, "<UNK>")
        if tok not in ("<PAD>", "<SOS>", "<EOS>"):
            toks.append(tok)
    return " ".join(toks)

In [23]:
print("Decoded IDs:", greedy_decode(loaded_params, jnp.array([encode_text("i am")]), vocab))
print("Decoded Text:", decode(greedy_decode(loaded_params, jnp.array([encode_text("i am")]), vocab), id2tok))

Decoded IDs: [Array(3, dtype=int32), Array(2, dtype=int32)]
Decoded Text: 


In [25]:
import jax
import jax.numpy as jnp
import numpy as np

# -----------------------------------------
# 1️⃣ Encode text into token IDs
# -----------------------------------------
def encode(text, vocab):
    """
    Convert a string into token IDs including <SOS> and <EOS>.
    """
    tokens = text.strip().split()
    ids = [vocab["<SOS>"]] + [vocab.get(tok, vocab["<UNK>"]) for tok in tokens] + [vocab["<EOS>"]]
    return ids


# -----------------------------------------
# 2️⃣ Decode token IDs back to text
# -----------------------------------------
def decode(ids, id2tok):
    """
    Convert token IDs (Python ints or JAX arrays) back into readable text.
    Safely removes <SOS>, <EOS>, <PAD>.
    """
    # Convert from JAX array to Python list of ints if needed
    if isinstance(ids, jax.Array):
        ids = list(map(int, np.array(ids).flatten()))
    else:
        ids = [int(x) for x in ids]

    tokens = []
    for i in ids:
        tok = id2tok.get(i, "<UNK>")
        if tok not in ("<PAD>", "<SOS>", "<EOS>"):
            tokens.append(tok)
    return " ".join(tokens)


# -----------------------------------------
# 3️⃣ Greedy decoding for inference
# -----------------------------------------
def greedy_decode(params, src_tokens, vocab, max_len=20):
    """
    Run greedy decoding on trained transformer params.
    src_tokens: (1, src_len) array of encoded input
    Returns list of token IDs (including <SOS> and <EOS>)
    """
    start_id = vocab["<SOS>"]
    end_id = vocab["<EOS>"]

    dec_input = jnp.array([[start_id]])
    for _ in range(max_len):
        logits, _ = forward(
            params,
            enc_input=src_tokens,
            dec_input=dec_input,
            vocab_size=len(vocab),
            d_model=128,
            n_layers=2,
            n_heads=8,
            d_ff=512,
            dropout_rate=0.0,
            training=False,
        )
        next_id = int(jnp.argmax(logits[0, -1]))
        dec_input = jnp.concatenate([dec_input, jnp.array([[next_id]])], axis=1)
        if next_id == end_id:
            break

    # Return list of Python ints
    return list(map(int, np.array(dec_input).flatten()))


In [None]:
text = "they were"
encoded = encode(text, vocab)
print("Encoded:", encoded)

output_ids = greedy_decode(loaded_params, jnp.array([encoded]), vocab)
print("Output IDs:", output_ids)
print("Decoded:", decode(output_ids, id2tok))

Encoded: [3, 61, 67, 2]
Output IDs: [3, 61, 67, 2]
Decoded: they were


In [None]:
print("\n--- COPY TASK RESULTS ---")
for s in ["i am", "he is", "they were"]:
    print(f"Input: {s}  →  Output: {simple_decode(s)}")