In [4]:
"""
🎯 HOW LANGUAGE MODELS MAKE PREDICTIONS
A visual, step-by-step explanation with real numbers

The complete journey from input text to final prediction
"""

import os
import warnings

warnings.filterwarnings('ignore')

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# SETUP
# Print the banner, then load the tokenizer and causal-LM used by every later step.
print("\n" + "üß†" * 35)
print("   HOW DOES THE AI BRAIN MAKE PREDICTIONS?")
print("üß†" * 35 + "\n")

# Checkpoint location. The default is the original author's local directory;
# set the PHI3_MODEL_PATH environment variable to run the script on another
# machine without editing this file.
MODEL_PATH = os.environ.get(
    "PHI3_MODEL_PATH",
    "/Users/somesh/How Transformer LLMs Work/Transformer Architecture/models/microsoft/Phi-3-mini-4k-instruct",
)

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# NOTE(security): trust_remote_code=True allows from_pretrained to execute
# Python code shipped with the checkpoint. Only use checkpoints you trust.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="cpu",
    torch_dtype="auto",
    trust_remote_code=True,
)
model.eval()  # Set to evaluation mode

print("‚úÖ Model loaded!\n")

# STEP-BY-STEP PREDICTION PROCESS
# The single prompt that the following steps dissect.
prompt = "The capital of France is"

# One print call emits the framed header, the prompt and the lead-in line.
print(
    "="*80,
    "üìù INPUT PROMPT",
    "="*80,
    f"Prompt: \"{prompt}\"",
    "\nLet's see how the model predicts the next word...\n",
    sep="\n",
)

# Wait for the reader before moving on.
input("\n‚è∏Ô∏è  Press ENTER to see Step 1: Tokenization...")


# STEP 1: TOKENIZATION
# Convert the prompt into token IDs and show the text piece behind each ID.
print("\n" + "="*80)
print("üî§ STEP 1: TOKENIZATION")
print("="*80)

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Row 0 is the only sequence in the batch. Plain ints avoid reusing the builtin
# name `id` as a loop variable and remove the per-element .item() calls.
prompt_token_ids = input_ids[0].tolist()
# Decode each ID on its own so every token can be displayed separately.
tokens = [tokenizer.decode([token_id]) for token_id in prompt_token_ids]

print("\nüìã The model breaks text into tokens:")
print("-"*80)
for i, (token, token_id) in enumerate(zip(tokens, prompt_token_ids)):
    print(f"  Position {i}: '{token}' ‚Üí Token ID: {token_id}")

print(f"\nüí° Total input tokens: {len(tokens)}")
print("   Each token gets a unique ID number that the model understands")

input("\n‚è∏Ô∏è  Press ENTER to see Step 2: Forward Pass...")


# STEP 2: FORWARD PASS THROUGH MODEL
# Run the prompt through the model once and keep the raw output scores.
print("\n" + "="*80)
print("üß† STEP 2: PROCESSING THROUGH THE MODEL")
print("="*80)

# Read the architecture figures from the loaded model's config instead of
# hard-coding them, so the text stays correct if MODEL_PATH points elsewhere.
hidden_size = model.config.hidden_size
num_layers = model.config.num_hidden_layers

print("\nThe model is now thinking...")
print(f"  ‚Ä¢ Converting tokens to embeddings ({hidden_size:,} dimensions)")
print(f"  ‚Ä¢ Processing through {num_layers} transformer layers")
print("  ‚Ä¢ Each layer refines the understanding")

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model(input_ids, use_cache=False)
    logits = outputs.logits  # Raw predictions

print("\n‚úÖ Processing complete!")
print(f"   Output shape: {logits.shape}")
print(f"   Meaning: {logits.shape[0]} batch √ó {logits.shape[1]} tokens √ó {logits.shape[2]} vocab")

input("\n‚è∏Ô∏è  Press ENTER to see Step 3: Understanding Logits...")

# STEP 3: UNDERSTANDING LOGITS
# Inspect the raw scores produced for the position right after the prompt.
print("\n" + "="*80, "üìä STEP 3: UNDERSTANDING LOGITS (Raw Scores)", "="*80, sep="\n")

# Scores at batch row 0, last prompt position, across the whole vocabulary axis.
last_token_logits = logits[0, -1, :]

print(
    f"\nüéØ For the last token ('{tokens[-1]}'), the model outputs:",
    f"   A vector of {len(last_token_logits):,} numbers (one for each possible word)",
    "\nüìà What are logits?",
    "   Raw, unnormalized scores indicating how 'good' each word would be next",
    "   Higher logit = model thinks this word is more likely",
    sep="\n",
)

print("\nüîç Let's look at some example logits:", "-"*80, sep="\n")

# A fixed handful of token IDs (not random); the list is reused later on.
sample_indices = [0, 100, 1000, 10000, 20000, 31000]
for vocab_id in sample_indices:
    word = tokenizer.decode([vocab_id])
    score = last_token_logits[vocab_id].item()
    print(f"   Word: '{word}' (ID: {vocab_id:5d}) ‚Üí Logit: {score:8.3f}")

print(
    "\nüí° Notice: Logits can be negative or positive, any size!",
    "   We need to convert these to probabilities...",
    sep="\n",
)

input("\n‚è∏Ô∏è  Press ENTER to see Step 4: Converting to Probabilities...")


# STEP 4: SOFTMAX (LOGITS -> PROBABILITIES)
# Normalise the last-position logits into a probability distribution.
print("\n" + "="*80, "üé≤ STEP 4: CONVERTING LOGITS TO PROBABILITIES", "="*80, sep="\n")

print(
    "\nüî¨ SOFTMAX FUNCTION:",
    "   Converts any set of numbers into probabilities (0 to 1, sum to 1)",
    "\n   Formula: probability(word_i) = e^(logit_i) / Œ£(e^(logit_j) for all j)",
    sep="\n",
)

# Softmax over the vocabulary axis.
probabilities = last_token_logits.softmax(dim=-1)

# The numbers in this diagram are a fixed illustration, not model output.
print(
    "\n   Before Softmax (Logits):",
    "   [-5.2, 8.7, -1.3, 0.5, 3.2, ...]",
    "          ‚Üì  SOFTMAX  ‚Üì",
    "   After Softmax (Probabilities):",
    "   [0.0001, 0.8730, 0.0034, 0.0205, 0.0318, ...]",
    sep="\n",
)

total_probability = probabilities.sum().item()
print("\n‚úÖ Softmax applied!")
print(f"   All probabilities sum to: {total_probability:.6f} (‚âà 1.0)")
print("   Each probability is between 0 and 1")

# Show logit and probability side by side for the first three sample IDs.
print("\nüîç Let's verify with real examples:", "-"*80, sep="\n")
for vocab_id in sample_indices[:3]:
    word = tokenizer.decode([vocab_id])
    score = last_token_logits[vocab_id].item()
    chance = probabilities[vocab_id].item()
    print(f"   '{word}': logit={score:7.3f} ‚Üí probability={chance:.6f} ({chance*100:.4f}%)")

input("\n‚è∏Ô∏è  Press ENTER to see Step 5: Top Predictions...")


# STEP 5: TOP K PREDICTIONS
# Rank the most probable next tokens and summarise how decisive the winner is.
print("\n" + "="*80, "üèÜ STEP 5: FINDING THE TOP CANDIDATES", "="*80, sep="\n")

# Ten highest probabilities together with their vocabulary IDs.
top_k = 10
top_probs, top_indices = probabilities.topk(top_k)

print(f"\nüéØ Out of {len(probabilities):,} possible words, here are the TOP {top_k}:")
print("\n" + "="*80)
print(f"{'Rank':<6}{'Word':<20}{'Token ID':<12}{'Logit':<12}{'Probability':<15}{'Confidence':<12}")
print("="*80)

# Medals for the first three ranks; every other rank is rendered as "<rank>.".
medals = {1: "ü•á", 2: "ü•à", 3: "ü•â"}

for rank, (idx, prob) in enumerate(zip(top_indices, top_probs), start=1):
    token = tokenizer.decode(idx)
    logit = last_token_logits[idx].item()
    probability = prob.item()
    symbol = medals.get(rank, f"{rank}.")

    # Bar length scales with probability: 60 blocks would be 100%.
    bar = "‚ñà" * int(probability * 60)

    print(f"{symbol:<6}{repr(token):<20}{idx.item():<12}{logit:<12.4f}{probability:<15.6f}{bar}")

print("="*80)

# Confidence metrics derived from the top of the ranking.
winner_prob, runner_up_prob = top_probs[0].item(), top_probs[1].item()
confidence_gap = winner_prob - runner_up_prob
top3_total = top_probs[:3].sum().item()

# Map the winning probability onto a coarse certainty label.
if winner_prob > 0.7:
    certainty = 'Very High'
elif winner_prob > 0.4:
    certainty = 'High'
elif winner_prob > 0.2:
    certainty = 'Moderate'
else:
    certainty = 'Low'

print("\nüìä CONFIDENCE ANALYSIS:")
print("-"*80)
print(f"  Winner probability:        {winner_prob*100:6.2f}%")
print(f"  Runner-up probability:     {runner_up_prob*100:6.2f}%")
print(f"  Confidence gap:            {confidence_gap*100:6.2f}%")
print(f"  Top 3 combined:            {top3_total*100:6.2f}%")
print(f"  Model certainty:           {certainty}")

input("\n‚è∏Ô∏è  Press ENTER to see Step 6: How The Model Decides...")


# STEP 6: DECISION-MAKING EXPLANATION
# Narrate how the winning token came about. The attention weights, example
# logits and layer roles quoted in the text below are hard-coded illustrations;
# nothing in this script measures them. Only `winner` and `winner_prob` are
# real model output.
print("\n" + "="*80)
print("ü§î STEP 6: HOW DOES THE MODEL DECIDE?")
print("="*80)

winner = tokenizer.decode(top_indices[0])

print(f"\nüèÜ WINNER: '{winner}' with {winner_prob*100:.2f}% confidence")

print("\n" + "="*80)
print("üí° THE DECISION PROCESS:")
print("="*80)

print("""
1Ô∏è‚É£  LEARNED PATTERNS (Training Phase)
   During training, the model saw billions of examples like:

   "The capital of France is Paris"
   "The capital of France is located in Paris"
   "France's capital, Paris, is..."

   Through these examples, the model learned:
   ‚Ä¢ "capital of [Country]" is usually followed by a city name
   ‚Ä¢ France's capital is Paris (factual knowledge)
   ‚Ä¢ The grammatical pattern: "X is Y"

2Ô∏è‚É£  ATTENTION MECHANISM (Understanding Context)
   When processing "The capital of France is", the attention layers:

   Token "is" looks at all previous tokens:
   ‚Ä¢ "The"     ‚Üí Low attention (0.05) - not critical
   ‚Ä¢ "capital" ‚Üí Medium attention (0.15) - important context
   ‚Ä¢ "of"      ‚Üí Low attention (0.08) - grammatical word
   ‚Ä¢ "France"  ‚Üí HIGH attention (0.70) - KEY WORD!

   The model "focuses" heavily on "France" because that's the
   most relevant word for predicting what comes next.

3Ô∏è‚É£  KNOWLEDGE RETRIEVAL (What It Knows)
   Through the 32 transformer layers:

   Early layers (1-10):
   ‚Ä¢ Recognize grammatical structure
   ‚Ä¢ Identify this is a "definition" pattern

   Middle layers (11-20):
   ‚Ä¢ Link "capital" with "city"
   ‚Ä¢ Associate "France" with "French geography"

   Late layers (21-32):
   ‚Ä¢ Retrieve factual knowledge: "France ‚Üí Paris"
   ‚Ä¢ Verify pattern consistency
   ‚Ä¢ Generate confidence scores

4Ô∏è‚É£  SCORING (Converting Knowledge to Numbers)
   The final layer (lm_head) projects to vocabulary:

   "France" hidden representation [0.23, -0.15, 0.87, ..., 0.42]
                ‚Üì
   Linear transformation (3072 ‚Üí 32064)
                ‚Üì
   Logits for each word:
   ‚Ä¢ "Paris"      ‚Üí 8.7  (HIGH - strongly associated)
   ‚Ä¢ "London"     ‚Üí -2.3 (LOW - wrong country)
   ‚Ä¢ "the"        ‚Üí 0.5  (MEDIUM - grammatically possible)
   ‚Ä¢ "Berlin"     ‚Üí -1.2 (LOW - wrong country)

5Ô∏è‚É£  PROBABILITY CALCULATION (Final Decision)
   Apply softmax to convert logits ‚Üí probabilities:

   exp(8.7) / sum(exp(all logits)) = 0.873 = 87.3%

   This becomes the final confidence!
""")

print("="*80)
print("üéØ WHY THE MODEL CHOSE THIS WORD:")
print("="*80)

# Pick the wording from the actual winning probability (thresholds 0.7 / 0.4 /
# 0.2). The text used to say "VERY confident" unconditionally, which could
# contradict the qualifier printed on the following line.
if winner_prob > 0.7:
    confidence_phrase = "VERY confident"
    evidence_note = 'This is almost certain - the pattern is clear!'
elif winner_prob > 0.4:
    confidence_phrase = "fairly confident"
    evidence_note = 'This is likely - strong evidence from training'
elif winner_prob > 0.2:
    confidence_phrase = "only moderately confident"
    evidence_note = 'This is possible - moderate evidence'
else:
    confidence_phrase = "NOT very confident"
    evidence_note = 'This is uncertain - weak evidence'

print(f"""
The model chose '{winner}' because:

‚úÖ STATISTICAL PATTERNS:
   In training data, "The capital of France is" was almost always
   followed by "Paris" (or similar city references)

‚úÖ SEMANTIC UNDERSTANDING:
   The model understands the relationship between:
   ‚Ä¢ Countries and their capitals
   ‚Ä¢ "France" ‚Üí "Paris"
   ‚Ä¢ "capital" ‚Üí city name

‚úÖ CONTEXTUAL CLUES:
   The attention mechanism focused on "France" (70% attention weight)
   This strong focus influenced the final prediction

‚úÖ CONFIDENCE LEVEL:
   {winner_prob*100:.1f}% probability means the model is {confidence_phrase}
   {evidence_note}
""")

input("\n‚è∏Ô∏è  Press ENTER to see Step 7: Sampling Strategies...")


# STEP 7: SAMPLING STRATEGIES
# Walk through the common ways of turning the distribution into one chosen token.
print("\n" + "="*80, "üé≤ STEP 7: DIFFERENT WAYS TO CHOOSE THE NEXT WORD", "="*80, sep="\n")

print(
    "\n1Ô∏è‚É£  GREEDY DECODING (Always pick highest probability)",
    "-"*80,
    f"   ‚Üí Always picks: '{winner}' ({winner_prob*100:.2f}%)",
    "   ‚úÖ Pros: Deterministic, consistent, high quality",
    "   ‚ùå Cons: No creativity, repetitive, boring",
    sep="\n",
)

print(
    "\n2Ô∏è‚É£  SAMPLING (Pick randomly based on probability)",
    "-"*80,
    "   ‚Üí Might pick any of the top candidates:",
    sep="\n",
)
# List the five best candidates with their sampling chances.
for position, (candidate_id, candidate_prob) in enumerate(zip(top_indices[:5], top_probs[:5]), start=1):
    candidate = tokenizer.decode(candidate_id)
    print(f"      {position}. '{candidate}' with {candidate_prob.item()*100:.2f}% chance")
print(
    "   ‚úÖ Pros: Creative, diverse, natural",
    "   ‚ùå Cons: Unpredictable, can be random",
    sep="\n",
)

print("\n3Ô∏è‚É£  TEMPERATURE SAMPLING (Control randomness)", "-"*80, sep="\n")

# Divide the logits by each temperature before softmax and show the top three.
temperatures = [0.5, 1.0, 2.0]
print("\n   Temperature controls how 'sharp' the distribution is:\n")

for temperature in temperatures:
    scaled_probs = (last_token_logits / temperature).softmax(dim=-1)
    best_probs, best_ids = scaled_probs.topk(3)

    print(f"   Temperature = {temperature}:")
    for place, (best_id, best_prob) in enumerate(zip(best_ids, best_probs), start=1):
        piece = tokenizer.decode(best_id)
        share = best_prob.item()
        # 40 blocks would correspond to 100%.
        meter = "‚ñà" * int(share * 40)
        print(f"     {place}. '{piece}': {share*100:5.2f}% {meter}")
    print()

print(
    "   üí° Lower temperature = more focused (deterministic)",
    "   üí° Higher temperature = more random (creative)",
    sep="\n",
)

print(
    "\n4Ô∏è‚É£  TOP-K SAMPLING (Only consider top K words)",
    "-"*80,
    "   ‚Üí Restrict choices to top 50 most likely words",
    "   ‚Üí Then sample from that smaller set",
    "   ‚úÖ Pros: Avoids very unlikely words, still creative",
    sep="\n",
)

print(
    "\n5Ô∏è‚É£  TOP-P SAMPLING / NUCLEUS (Pick from cumulative probability)",
    "-"*80,
    "   ‚Üí Keep adding words until total probability reaches P (e.g., 0.9)",
    "   ‚Üí Then sample from that dynamic set",
    "   ‚úÖ Pros: Adaptive to context, balances quality and diversity",
    sep="\n",
)

input("\n‚è∏Ô∏è  Press ENTER to see Step 8: Interactive Demo...")


# STEP 8: INTERACTIVE DEMO
# Framed section title, emitted with a single print call.
print("\n" + "="*80, "üéÆ STEP 8: TRY IT YOURSELF!", "="*80, sep="\n")

def show_prediction_details(prompt_text, top_k=10):
    """Print the most likely next tokens for ``prompt_text`` and return the best one.

    Tokenizes the prompt with the module-level ``tokenizer``, runs one forward
    pass of the module-level ``model`` without gradients, applies softmax to
    the logits at the last position and prints a ranked list with a
    proportional bar for each candidate.

    Args:
        prompt_text: Text whose continuation should be predicted.
        top_k: Number of candidates to list. Defaults to 10 (the original
            behaviour). Values outside ``1..vocabulary size`` are clamped.

    Returns:
        The decoded text of the single most probable next token.
    """
    print(f"\nüìù Analyzing: \"{prompt_text}\"")
    print("-"*80)

    # Tokenize and predict
    inputs = tokenizer(prompt_text, return_tensors="pt").input_ids

    with torch.no_grad():
        outputs = model(inputs, use_cache=False)
        # Batch row 0, last position: the scores for the next token.
        logits = outputs.logits[0, -1, :]
        probs = torch.softmax(logits, dim=-1)
        # torch.topk raises if k exceeds the number of entries, so clamp it.
        k = max(1, min(top_k, probs.shape[-1]))
        top_probs, top_indices = torch.topk(probs, k)

    # Decode every candidate once; the first entry is also the winner.
    candidates = [tokenizer.decode(token_id) for token_id in top_indices.tolist()]
    candidate_probs = top_probs.tolist()

    # Show results
    print(f"\nüèÜ TOP {k} PREDICTIONS:")
    print("-"*80)
    for rank, (token, prob) in enumerate(zip(candidates, candidate_probs), 1):
        # 50 blocks would correspond to 100%.
        bar = "‚ñà" * int(prob * 50)
        print(f"{rank:2d}. {repr(token):20s} {prob*100:6.2f}% {bar}")

    winner = candidates[0]
    winner_prob = candidate_probs[0]

    print(f"\nüéØ PREDICTION: '{winner}' ({winner_prob*100:.1f}% confidence)")

    return winner

# Test with different prompts
test_prompts = [
    "The capital of France is",
    "Once upon a time",
    "To be or not to",
    "Python is a programming",
]

print("\nüìö Example predictions:")
for test_prompt in test_prompts:
    show_prediction_details(test_prompt)
    print("\n" + "="*80)

# Interactive mode: keep analysing user prompts until the user quits.
print("\nüí° YOUR TURN!")
while True:
    try:
        user_prompt = input("\n‚úèÔ∏è  Enter your prompt (or 'quit' to exit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin (Ctrl-D, piped input) or Ctrl-C counts as "quit", so
        # the summary that follows this loop is still printed.
        break

    if user_prompt.lower() in ['quit', 'exit', 'q']:
        break

    # Ignore blank lines and ask again.
    if not user_prompt:
        continue

    show_prediction_details(user_prompt)

# SUMMARY
# Recap of the whole pipeline.
print("\n" + "="*80)
print("üìö SUMMARY: HOW MODELS PREDICT")
print("="*80)

# Architecture figures are read from the loaded model's config rather than
# hard-coded, so the recap stays correct for whichever checkpoint was loaded.
summary_hidden_size = model.config.hidden_size
summary_num_layers = model.config.num_hidden_layers
summary_vocab_size = model.config.vocab_size

print(f"""
üéØ THE COMPLETE PROCESS:

1. TOKENIZATION
   Text ‚Üí Token IDs

2. EMBEDDING
   Token IDs ‚Üí Vectors ({summary_hidden_size:,} dimensions)

3. TRANSFORMER LAYERS ({summary_num_layers} layers)
   ‚Ä¢ Attention: Focus on relevant words
   ‚Ä¢ MLP: Transform information
   ‚Ä¢ Each layer refines understanding

4. LOGITS
   Final layer ‚Üí {summary_vocab_size:,} scores (one per word)

5. SOFTMAX
   Logits ‚Üí Probabilities (sum to 1.0)

6. SELECTION
   ‚Ä¢ Greedy: Pick highest
   ‚Ä¢ Sampling: Random based on probability
   ‚Ä¢ Temperature: Control randomness

7. OUTPUT
   Selected token ‚Üí Decode to text

üß† KEY INSIGHTS:

‚Ä¢ Model assigns probability to EVERY possible word
‚Ä¢ Higher probability = model thinks it's more likely
‚Ä¢ Decision is based on patterns learned from training
‚Ä¢ Attention helps focus on relevant context
‚Ä¢ Different sampling strategies ‚Üí different outputs

üí° REMEMBER:

The model doesn't "know" facts like humans do.
It learned statistical patterns from massive text data.
"The capital of France is Paris" appears together often
in training data, so the model learned this association!
""")

print("\n" + "="*80)
print("‚ú® Now you understand how AI makes predictions!")
print("="*80 + "\n")


üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†
   HOW DOES THE AI BRAIN MAKE PREDICTIONS?
üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†üß†

Loading model...


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 178.32it/s]

✅ Model loaded!

📝 INPUT PROMPT
Prompt: "The capital of France is"

Let's see how the model predicts the next word...







🔤 STEP 1: TOKENIZATION

📋 The model breaks text into tokens:
--------------------------------------------------------------------------------
  Position 0: 'The' → Token ID: 450
  Position 1: 'capital' → Token ID: 7483
  Position 2: 'of' → Token ID: 310
  Position 3: 'France' → Token ID: 3444
  Position 4: 'is' → Token ID: 338

💡 Total input tokens: 5
   Each token gets a unique ID number that the model understands

🧠 STEP 2: PROCESSING THROUGH THE MODEL

The model is now thinking...
  • Converting tokens to embeddings (3,072 dimensions)
  • Processing through 32 transformer layers
  • Each layer refines the understanding

✅ Processing complete!
   Output shape: torch.Size([1, 5, 32064])
   Meaning: 1 batch × 5 tokens × 32064 vocab

📊 STEP 3: UNDERSTANDING LOGITS (Raw Scores)

🎯 For the last token ('is'), the model outputs:
   A vector of 32,064 numbers (one for each possible word)

📈 What are logits?
   Raw, unnormalized scores indicating how 