# Demo: The 'Broken' Transformer's Shallow Learning

## Purpose
This notebook probes what our Transformer model actually learned, given that it was trained with a fundamental architectural flaw: its `LayerNorm` parameters were not learnable.
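
To make the flaw concrete, here is a minimal, dependency-free sketch of layer normalization with and without a learnable scale and shift. The names `layer_norm`, `gamma`, and `beta` are illustrative assumptions and are not taken from `transformer.py`; how the parameters ended up non-learnable in our code is not shown here, and the "learned" values below are made up.

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize a vector to zero mean and unit variance, then scale and shift.

    gamma/beta default to 1 and 0, which mimics the 'frozen' case where the
    scale and shift never move away from their initial values.
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    normed = [(v - mean) / math.sqrt(var + eps) for v in x]
    if gamma is None:
        gamma = [1.0] * n
    if beta is None:
        beta = [0.0] * n
    return [g * v + b for g, v, b in zip(gamma, normed, beta)]

x = [1.0, 2.0, 3.0, 4.0]

# Frozen parameters: the output is always zero-mean with (near) unit variance.
frozen = layer_norm(x)

# Learnable parameters (hypothetical values an optimizer might reach):
# the layer can rescale and shift its features.
learned = layer_norm(x, gamma=[2.0] * 4, beta=[0.5] * 4)
```

With frozen parameters every normalized activation is pinned to the same scale, so the layers after it have to compensate on their own.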

**Hypothesis:** The model did not learn grammar or deep semantics. Instead, it achieved a high accuracy score by memorizing very common, short word-to-word mappings from the training data. 

We will test this by giving it:
1.  Simple, common phrases it likely saw many times.
2.  More complex sentences that require a real understanding of language.
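
The two test groups above can also be turned into a number. The sketch below splits exact-match accuracy by source length, which is one way an aggregate score can look high while long sentences fail. The helper `accuracy_by_length`, the 3-word boundary, and the use of exact match (rather than whatever metric the training script reported) are all assumptions, and the triples are synthetic placeholders, not model output.

```python
from collections import defaultdict

def accuracy_by_length(examples, boundary=3):
    """Split exact-match accuracy into 'short' and 'long' source sentences.

    examples: iterable of (source, reference, hypothesis) strings.
    boundary: sources with at most this many words count as 'short'.
    """
    hits = defaultdict(int)
    totals = defaultdict(int)
    for source, reference, hypothesis in examples:
        bucket = "short" if len(source.split()) <= boundary else "long"
        totals[bucket] += 1
        hits[bucket] += int(reference.strip() == hypothesis.strip())
    per_bucket = {b: hits[b] / totals[b] for b in totals}
    overall = sum(hits.values()) / max(sum(totals.values()), 1)
    return overall, per_bucket

# Synthetic placeholder triples (NOT model output): 8 short sources that
# match their reference and 2 long sources that do not.
toy = [("a b", "x", "x")] * 8 + [("a b c d e", "y", "z")] * 2
overall, per_bucket = accuracy_by_length(toy)
```

In this toy set the overall score is dominated by the short bucket, which is the pattern the hypothesis predicts for the real model.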

## 1. Setup

In [None]:
import jax
import jax.numpy as jnp
from flax.serialization import from_bytes
from transliterate import translit

# Ensure you are using your restored 'broken' versions of these files
from data import load_dataset_and_vocab, normalize_text
from transformer import (
    text_to_token_ids,
    token_embeddings,
    positional_embeddings,
    transformer_encoder,
    forward,
    init_params
)

## 2. Load the Model and Vocabulary

We load the vocabulary from the `data.py` script and the saved weights from the 'broken' model's training run.

In [None]:
# Hyperparameters must match the saved model
D_MODEL = 128
D_FF = D_MODEL * 4
N_LAYERS = 6
N_HEADS = 8
MAX_VOCAB_SIZE = 20000

# Load the vocabulary that was used to train the broken model
vocab, _, _, vocab_size = load_dataset_and_vocab(max_vocab_size=MAX_VOCAB_SIZE)
id_to_token = {v: k for k, v in vocab.items()}

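# Load the saved model weights.
# init_params only provides a template with the right structure;
# its random values are replaced by the saved weights below.
key = jax.random.PRNGKey(42)
template_params = init_params(key, vocab_size=vocab_size, d_model=D_MODEL, d_ff=D_FF, n_heads=N_HEADS, n_layers=N_LAYERS)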

print("Loading saved model weights from transformer_weights.msgpack...")
with open("transformer_weights.msgpack", "rb") as f:
    byte_data = f.read()

loaded_params = from_bytes(template_params, byte_data)
print("✅ 'Broken' model loaded successfully!")
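
Because the hyperparameters must match the saved model, a quick structural comparison between the template and the loaded parameters can be a useful sanity check. This is a minimal sketch, assuming the parameters are nested dicts or lists whose leaves expose a `.shape` attribute; `tree_shapes` and `shape_mismatches` are hypothetical helpers, and the toy trees use `SimpleNamespace` stand-ins instead of real arrays. It does not rely on any particular behavior of `from_bytes`.

```python
from types import SimpleNamespace

def tree_shapes(tree, prefix=""):
    """Flatten a nested dict/list of array-like leaves into {path: shape}."""
    if isinstance(tree, dict):
        children = tree.items()
    elif isinstance(tree, (list, tuple)):
        children = enumerate(tree)
    else:
        # Leaf: arrays expose .shape; anything else is treated as a scalar.
        return {prefix: tuple(getattr(tree, "shape", ()))}
    shapes = {}
    for name, value in children:
        path = f"{prefix}/{name}" if prefix else str(name)
        shapes.update(tree_shapes(value, path))
    return shapes

def shape_mismatches(template, loaded):
    """Return paths whose shape differs or that exist in only one tree."""
    a, b = tree_shapes(template), tree_shapes(loaded)
    return sorted(p for p in a.keys() | b.keys() if a.get(p) != b.get(p))

# Toy trees for illustration only: the embedding width differs (4 vs 8).
toy_template = {"embedding": {"W_emb": SimpleNamespace(shape=(10, 4))},
                "encoder": [{"W": SimpleNamespace(shape=(4, 4))}]}
toy_loaded = {"embedding": {"W_emb": SimpleNamespace(shape=(10, 8))},
              "encoder": [{"W": SimpleNamespace(shape=(4, 4))}]}
mismatches = shape_mismatches(toy_template, toy_loaded)
```

In the notebook, `shape_mismatches(template_params, loaded_params)` would be expected to return an empty list when `D_MODEL`, `N_LAYERS`, and the vocabulary size match the training run.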

## 3. The Inference Function

This is the same `translate` function we developed. It accounts for three quirks of this specific model: text normalization, the exact casing of the special tokens, and transliteration of the output back to Cyrillic.

In [None]:
def translate(english_sentence: str, params: dict, vocab: dict, max_output_len: int = 32, d_model: int = 128):
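    # Get special tokens with the exact casing from data.py
    sos_id = vocab['<SOS>']
    pad_id = vocab['<pad>']
    # '<eos>' may be absent from the vocabulary. The fallback -1 never matches
    # a predicted id, so decoding then stops only on padding or at the length limit.
    eos_id = vocab.get('<eos>', -1)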

    # Normalize and tokenize input
    normalized_sentence = normalize_text(english_sentence)
    enc_input = text_to_token_ids([normalized_sentence], vocab, max_len=32)

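    # Encoder pass
    # NOTE: enc_output is not used below. `forward` receives the raw token ids
    # (enc_input) and appears to run the encoder itself, so this block only
    # mirrors the encoder path for inspection.
    inf_key = jax.random.PRNGKey(0)
    keys = jax.random.split(inf_key, 12)
    enc_emb = token_embeddings(enc_input, params['embedding']['W_emb'], d_model)
    enc_emb += positional_embeddings(max_len=32, d_model=d_model)
    enc_output = transformer_encoder(enc_emb, params['encoder'], keys[:6], d_model=d_model, training=False)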

    # Autoregressive (greedy) decoding. The decoder input is padded to 32 tokens,
    # so the loop is capped at 32 steps even if max_output_len is larger.
    dec_input_ids = [sos_id]
    for i in range(min(max_output_len, 32)):
        dec_input = jnp.array([dec_input_ids + [pad_id] * (32 - len(dec_input_ids))])
        logits, _ = forward(params, enc_input, dec_input, vocab_size=len(vocab), d_model=d_model, training=False, key=inf_key)
        # Position i is the last real decoder token; its logits predict the next token
        predicted_token_id = jnp.argmax(logits[0, i, :]).item()
        if predicted_token_id == pad_id or predicted_token_id == eos_id:
            break
        dec_input_ids.append(predicted_token_id)

    # Detokenize (using the module-level id_to_token from section 2) and transliterate back to Cyrillic
    output_words = [id_to_token.get(token_id, '') for token_id in dec_input_ids[1:]]
    latin_translation = " ".join(output_words).strip()
    cyrillic_translation = translit(latin_translation, 'ru')
    return cyrillic_translation
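
Because `translate` falls back to `-1` when `'<eos>'` is missing, a casing mismatch (for example `'<EOS>'` in the vocabulary) would silently disable the end-of-sequence check. The sketch below shows one way to look up a special token regardless of casing. `find_special_token` is a hypothetical helper and the toy vocabulary is made up; the real tokens and their casing come from `data.py`.

```python
def find_special_token(vocab, name):
    """Look up a special token such as '<eos>' regardless of casing.

    Returns (token_string, token_id), or (None, None) if no variant exists.
    """
    wanted = name.lower()
    for token, token_id in vocab.items():
        if token.lower() == wanted:
            return token, token_id
    return None, None

# Toy vocabulary for illustration only.
toy_vocab = {"<pad>": 0, "<SOS>": 1, "<EOS>": 2, "privet": 3}
eos_token, eos_id = find_special_token(toy_vocab, "<eos>")
missing = find_special_token(toy_vocab, "<unk>")
```

Printing the result of `find_special_token(vocab, '<eos>')` once after loading the vocabulary would show which spelling, if any, the model was trained with.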

## 4. The Experiment: What Can It Translate?

Let's see what happens when we give it different kinds of sentences.

In [None]:
print("--- Part 1: Simple, Common Phrases ---")
print("(These are likely to work due to memorization)\n")
simple_phrases = [
    "hi",
    "go",
    "i see",
    "my cat",
    "who is he"
]
for sentence in simple_phrases:
    translation = translate(sentence, loaded_params, vocab, d_model=D_MODEL)
    print(f"Input:  {sentence}")
    print(f"Output: {translation}\n")

print("\n--- Part 2: More Complex Sentences ---")
print("(These are likely to fail or produce nonsense)\n")
complex_sentences = [
    "where are you going tomorrow?",
    "the black cat is sleeping on the green mat",
    "i love to build neural networks from scratch"
]
for sentence in complex_sentences:
    translation = translate(sentence, loaded_params, vocab, d_model=D_MODEL)
    print(f"Input:  {sentence}")
    print(f"Output: {translation}\n")
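
Part 2 expects outputs that fail or produce nonsense. One rough way to make that judgment less subjective is to flag outputs that are empty or dominated by repeated tokens. This is a minimal sketch: `distinct_ratio` and `looks_degenerate` are hypothetical helpers, the 0.5 threshold is an arbitrary assumption, and a high ratio does not mean a translation is correct.

```python
def distinct_ratio(text):
    """Fraction of distinct whitespace-separated tokens (1.0 means no repeats)."""
    tokens = text.split()
    if not tokens:
        return 0.0
    return len(set(tokens)) / len(tokens)

def looks_degenerate(text, threshold=0.5):
    """Flag empty outputs or outputs dominated by repeated tokens."""
    return distinct_ratio(text) < threshold

# Placeholder strings, not model output.
repeated = distinct_ratio("a a a a")
varied = distinct_ratio("a b c")
empty_flag = looks_degenerate("")
```

Applied to each `translation` in the loops above, this gives a quick per-sentence flag to read alongside the printed output.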

## 5. Conclusion

As the results show, the model performs reasonably well on very short phrases that it has likely memorized from the training data. However, it fails when faced with longer sentences that require an understanding of grammar, word order, and context.

This supports the hypothesis: the high accuracy score we saw during training was misleading. It was driven by the model's success on a large number of simple, repetitive examples, which masked its inability to generalize.