In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class SimpleLanguageModel(nn.Module):
    """
    Toy language model: embedding -> one transformer encoder layer -> linear head.

    The linear head turns every position's hidden vector into one raw score
    (logit) per vocabulary entry.

    Args:
        vocab_size: Rows in the embedding table and number of logits
            produced per position.
        embedding_dim: Width of the token embeddings; also used as the
            transformer layer's ``d_model`` and the head's input width.
        hidden_dim: Width of the transformer layer's feed-forward sub-layer.
    """

    def __init__(self, vocab_size=10, embedding_dim=8, hidden_dim=16):
        super().__init__()

        # Lookup table: token id -> dense vector.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # One encoder layer; batch_first=True keeps tensors as [batch, seq_len, dim].
        encoder_config = {
            "d_model": embedding_dim,
            "nhead": 2,
            "dim_feedforward": hidden_dim,
            "batch_first": True,
        }
        self.transformer_layer = nn.TransformerEncoderLayer(**encoder_config)

        # Output head: this linear map is what actually produces the logits.
        self.output_projection = nn.Linear(embedding_dim, vocab_size)

        # Human-readable labels for token ids 0..9. The list is fixed at ten
        # words regardless of the vocab_size argument.
        self.vocab = [
            "the", "cat", "dog", "is", "on",
            "mat", "running", "sleeping", "big", "small",
        ]

    def forward(self, input_ids):
        """
        Map token ids to logits.

        Args:
            input_ids: Integer tensor of token ids, [batch, seq_len].

        Returns:
            Tensor of logits, [batch, seq_len, vocab_size].
        """
        token_vectors = self.embedding(input_ids)              # [batch, seq_len, embedding_dim]
        contextual = self.transformer_layer(token_vectors)     # [batch, seq_len, embedding_dim]
        return self.output_projection(contextual)              # [batch, seq_len, vocab_size]

def explain_logit_creation():
    """
    Show step-by-step how logits are actually created.

    Builds a fresh ``SimpleLanguageModel``, feeds it the token ids for
    "the cat is", and prints the tensor shape after each stage (embedding,
    transformer layer, output projection) followed by the logits at the
    last position.

    The model is switched to eval mode and run under ``torch.no_grad()``:
    the transformer layer contains dropout, which in training mode would add
    random noise to the logits on every call, and no gradients are needed
    for a read-only demonstration.

    Returns:
        1-D tensor with one logit per vocabulary entry for the last input
        position (detached from autograd).
    """
    print("🔍 HOW LOGITS ARE BORN")
    print("=" * 40)

    model = SimpleLanguageModel()
    model.eval()  # disable dropout so the printed logits reflect only weights + input

    # Input: "the cat is"
    input_tokens = torch.tensor([[0, 1, 3]])  # token IDs

    # .tolist() yields plain ints, which is what list indexing expects.
    input_words = [model.vocab[token_id] for token_id in input_tokens[0].tolist()]
    print("Step 1: Input tokens ->", input_words)

    with torch.no_grad():
        # Get embeddings
        embeddings = model.embedding(input_tokens)
        print(f"Step 2: Token embeddings shape: {embeddings.shape}")
        print(f"Each token becomes a vector of {embeddings.shape[-1]} numbers")

        # Process through transformer
        hidden_states = model.transformer_layer(embeddings)
        print(f"Step 3: After transformer: {hidden_states.shape}")
        print("These are rich representations of each token in context")

        # THE KEY STEP: Project to vocabulary
        logits = model.output_projection(hidden_states)
        print(f"Step 4: Final logits shape: {logits.shape}")
        print("Each token position now has a score for EVERY possible next word")

    # Look at logits for predicting the next word after "the cat is"
    next_word_logits = logits[0, -1, :]  # Last position

    print("\n🎯 LOGITS FOR NEXT WORD AFTER 'the cat is':")
    for word, logit in zip(model.vocab, next_word_logits.tolist()):
        print(f"{word:>10}: {logit:6.2f}")

    return next_word_logits

def why_these_numbers():
    """
    Print four hand-picked logit vectors next to their softmax probabilities.

    For each scenario the raw logits, the probabilities rounded to two
    decimals, and the largest probability are printed. Nothing is returned.
    """
    print("\n🤔 WHY DO LOGITS LOOK LIKE RANDOM NUMBERS?")
    print("=" * 50)

    # (label, logits) pairs, printed in this order.
    examples = (
        ("Very confident", torch.tensor([8.5, 0.1, -0.3, -1.2, -2.1])),
        ("Uncertain", torch.tensor([1.1, 1.0, 0.9, 0.8, 0.7])),
        ("One clear winner", torch.tensor([5.0, -3.0, -4.0, -5.0, -6.0])),
        ("Two good options", torch.tensor([3.0, 2.8, -1.0, -2.0, -3.0])),
    )

    print("SCENARIO ANALYSIS:")
    print("-" * 30)

    for label, raw_scores in examples:
        distribution = F.softmax(raw_scores, dim=0)
        rounded = [f'{p:.2f}' for p in distribution.tolist()]
        best = distribution.max().item()

        print(f"\n{label}:")
        print(f"  Logits: {raw_scores.tolist()}")
        print(f"  Probabilities: {rounded}")
        print(f"  Top choice: {best:.2f} ({best*100:.0f}%)")

def logit_intuition():
    """
    Print a short "voting" analogy for what logits mean.

    The text is static; the function takes no input and returns nothing.
    """
    # One entry per printed line; an empty string produces a blank line.
    script = (
        "\n💡 BUILDING INTUITION ABOUT LOGITS",
        "=" * 40,
        "Think of logits as 'raw votes' from the neural network:",
        "",
        "🗳️  VOTING ANALOGY:",
        "   Logit = 5.0  →  'STRONG YES! Pick this word!'",
        "   Logit = 0.0  →  'Meh, maybe this word'",
        "   Logit = -3.0 →  'NO! Don't pick this word'",
        "",
        "🎯 The neural network learned these 'voting weights' during training",
        "   by seeing millions of examples of good next words.",
        "",
        "📊 Softmax converts these raw votes into proper probabilities",
        "   that sum to 1.0, but the logits are the 'pure opinion'",
    )
    for line in script:
        print(line)

def the_linear_layer_secret():
    """
    Show exactly how the final linear layer creates logits.

    Uses a hand-written 5x4 weight matrix (one row per word), a length-5
    bias and a length-4 hidden state, computes ``W @ h + b``, and prints the
    resulting logits together with their softmax probabilities. Nothing is
    returned.
    """
    print("\n🔬 THE SECRET: IT'S JUST MATRIX MULTIPLICATION!")
    print("=" * 50)

    # Row i of weight_matrix and entry i of bias belong to vocab[i].
    vocab = ["cat", "dog", "run", "sleep", "big"]

    # Stand-in for what a model would have learned during training.
    weight_matrix = torch.tensor([
        [ 2.1, -0.5,  1.3,  0.8],  # cat
        [-1.2,  3.0, -0.2,  1.5],  # dog
        [ 0.3,  0.1,  2.5, -1.0],  # run
        [-0.8,  1.8, -1.5,  2.2],  # sleep
        [ 1.0, -2.0,  0.5,  0.3],  # big
    ])
    bias = torch.tensor([0.1, -0.2, 0.3, 0.0, 0.5])

    # Hidden state from transformer (what the model "knows" about context)
    hidden_state = torch.tensor([1.5, -0.8, 2.0, 0.5])

    print("Weight matrix (learned during training):")
    print(weight_matrix)
    print(f"\nHidden state (context understanding): {hidden_state}")

    # The actual logit calculation: logits = W @ h + b
    logits = torch.matmul(weight_matrix, hidden_state) + bias

    print("\nLogits = Weight_matrix @ Hidden_state + Bias")
    print(f"Logits = {logits.tolist()}")

    probs = F.softmax(logits, dim=0)

    print("\nFinal predictions:")
    # Convert to Python floats once so the format specs apply to plain numbers.
    for word, logit, prob in zip(vocab, logits.tolist(), probs.tolist()):
        print(f"{word:>6}: logit={logit:5.2f} → probability={prob:.3f}")

# Run all explanations


In [3]:
explain_logit_creation()


🔍 HOW LOGITS ARE BORN
Step 1: Input tokens -> ['the', 'cat', 'is']
Step 2: Token embeddings shape: torch.Size([1, 3, 8])
Each token becomes a vector of 8 numbers
Step 3: After transformer: torch.Size([1, 3, 8])
These are rich representations of each token in context
Step 4: Final logits shape: torch.Size([1, 3, 10])
Each token position now has a score for EVERY possible next word

🎯 LOGITS FOR NEXT WORD AFTER 'the cat is':
       the:   1.28
       cat:  -0.33
       dog:  -1.32
        is:   0.94
        on:   0.69
       mat:  -0.22
   running:  -0.31
  sleeping:   0.18
       big:  -0.11
     small:   0.16


tensor([ 1.2782, -0.3274, -1.3195,  0.9365,  0.6894, -0.2176, -0.3136,  0.1835,
        -0.1090,  0.1588], grad_fn=<SelectBackward0>)

In [4]:
why_these_numbers()  



🤔 WHY DO LOGITS LOOK LIKE RANDOM NUMBERS?
SCENARIO ANALYSIS:
------------------------------

Very confident:
  Logits: [8.5, 0.10000000149011612, -0.30000001192092896, -1.2000000476837158, -2.0999999046325684]
  Probabilities: ['1.00', '0.00', '0.00', '0.00', '0.00']
  Top choice: 1.00 (100%)

Uncertain:
  Logits: [1.100000023841858, 1.0, 0.8999999761581421, 0.800000011920929, 0.699999988079071]
  Probabilities: ['0.24', '0.22', '0.20', '0.18', '0.16']
  Top choice: 0.24 (24%)

One clear winner:
  Logits: [5.0, -3.0, -4.0, -5.0, -6.0]
  Probabilities: ['1.00', '0.00', '0.00', '0.00', '0.00']
  Top choice: 1.00 (100%)

Two good options:
  Logits: [3.0, 2.799999952316284, -1.0, -2.0, -3.0]
  Probabilities: ['0.54', '0.44', '0.01', '0.00', '0.00']
  Top choice: 0.54 (54%)


In [5]:
logit_intuition()


💡 BUILDING INTUITION ABOUT LOGITS
Think of logits as 'raw votes' from the neural network:

🗳️  VOTING ANALOGY:
   Logit = 5.0  →  'STRONG YES! Pick this word!'
   Logit = 0.0  →  'Meh, maybe this word'
   Logit = -3.0 →  'NO! Don't pick this word'

🎯 The neural network learned these 'voting weights' during training
   by seeing millions of examples of good next words.

📊 Softmax converts these raw votes into proper probabilities
   that sum to 1.0, but the logits are the 'pure opinion'


In [7]:
the_linear_layer_secret()



🔬 THE SECRET: IT'S JUST MATRIX MULTIPLICATION!
Weight matrix (learned during training):
tensor([[ 2.1000, -0.5000,  1.3000,  0.8000],
        [-1.2000,  3.0000, -0.2000,  1.5000],
        [ 0.3000,  0.1000,  2.5000, -1.0000],
        [-0.8000,  1.8000, -1.5000,  2.2000],
        [ 1.0000, -2.0000,  0.5000,  0.3000]])

Hidden state (context understanding): tensor([ 1.5000, -0.8000,  2.0000,  0.5000])

Logits = Weight_matrix @ Hidden_state + Bias
Logits = [6.649999618530273, -4.050000190734863, 5.170000076293945, -4.539999961853027, 4.75]

Final predictions:
   cat: logit= 6.65 → probability=0.726
   dog: logit=-4.05 → probability=0.000
   run: logit= 5.17 → probability=0.165
 sleep: logit=-4.54 → probability=0.000
   big: logit= 4.75 → probability=0.109


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ExplainableLanguageModel:
    """
    Shows the difference between weights (parameters) and logits (outputs).

    The "model" is only a hand-written output layer: a [10, 4] weight matrix
    and a length-10 bias, both hard-coded to stand in for trained
    parameters. Logits are recomputed from them for every context vector.
    """

    def __init__(self):
        # Create a tiny vocabulary for easy tracking
        self.vocab = {
            0: "the", 1: "cat", 2: "dog", 3: "is", 4: "running",
            5: "sleeping", 6: "big", 7: "small", 8: "on", 9: "mat"
        }
        # Reverse lookup: word -> token id.
        self.word_to_id = {word: word_id for word_id, word in self.vocab.items()}

        # THESE ARE THE PARAMETERS/WEIGHTS (fixed after training)
        # Shape: [vocab_size, hidden_size]
        self.final_layer_weights = torch.tensor([
            # Each row = weights for predicting one word
            #    [dim0, dim1, dim2, dim3]  ← hidden dimensions
            [  2.1, -0.5,  1.3,  0.8],  # weights for predicting "the" (id=0)
            [ -1.2,  3.0, -0.2,  1.5],  # weights for predicting "cat" (id=1)
            [  0.3,  0.1,  2.5, -1.0],  # weights for predicting "dog" (id=2)
            [ -0.8,  1.8, -1.5,  2.2],  # weights for predicting "is" (id=3)
            [  1.5,  0.8,  2.0,  1.2],  # weights for predicting "running" (id=4)
            [  0.9, -0.3,  1.8,  0.5],  # weights for predicting "sleeping" (id=5)
            [  1.0,  1.5, -0.2,  0.8],  # weights for predicting "big" (id=6)
            [ -0.5,  2.0,  0.3, -0.8],  # weights for predicting "small" (id=7)
            [  0.7, -1.0,  1.5,  1.8],  # weights for predicting "on" (id=8)
            [  1.8,  0.2, -0.5,  1.0]   # weights for predicting "mat" (id=9)
        ])

        # One bias term per vocabulary word, in id order.
        self.bias = torch.tensor([0.1, -0.2, 0.3, 0.0, 0.5, 0.2, -0.1, 0.4, 0.3, 0.1])

    def show_parameters(self):
        """Print the weight row of every vocabulary word and the bias vector."""
        print("🔧 MODEL PARAMETERS (WEIGHTS) - These are FIXED after training")
        print("=" * 70)
        print("Final layer weight matrix:")
        print("Shape:", self.final_layer_weights.shape, "→ [vocab_size, hidden_dimensions]")
        print()

        for word_id, word in self.vocab.items():
            weights = self.final_layer_weights[word_id]
            print(f"Weights for '{word}' (id={word_id}): {weights.tolist()}")

        print(f"\nBias terms: {self.bias.tolist()}")
        print("\n💡 These weights were learned during training and DON'T change!")

    def compute_logits(self, context_vector, show_computation=True):
        """
        Compute logits from context vector - THIS IS WHERE LOGITS COME FROM.

        Args:
            context_vector: 1-D tensor with one entry per column of
                ``final_layer_weights`` (4 for the built-in weights).
            show_computation: When true, print the dot product, bias and
                resulting logit for every vocabulary word.

        Returns:
            1-D tensor ``final_layer_weights @ context_vector + bias`` with
            one logit per vocabulary word.
        """
        if show_computation:
            print("\n🧮 COMPUTING LOGITS (this happens every inference)")
            print("=" * 60)
            print(f"Input context vector: {context_vector.tolist()}")
            print("↓")
            print("For each possible next word, compute: weights • context + bias")
            print()

        # LOGITS = WEIGHTS @ CONTEXT + BIAS
        logits = torch.matmul(self.final_layer_weights, context_vector) + self.bias

        if show_computation:
            print("LOGIT COMPUTATION FOR EACH WORD:")
            print("-" * 40)
            context_values = context_vector.tolist()
            for word_id, word in self.vocab.items():
                weights = self.final_layer_weights[word_id]
                # Recomputed per row purely so the intermediate value can be shown.
                dot_product = torch.dot(weights, context_vector).item()
                bias_val = self.bias[word_id].item()
                final_logit = logits[word_id].item()

                print(f"'{word}' (id={word_id}):")
                print(f"  {weights.tolist()} • {context_values} + {bias_val:.1f}")
                print(f"  = {dot_product:.2f} + {bias_val:.1f} = {final_logit:.2f}")
                print()

        return logits

    def demonstrate_inference(self, sentence):
        """
        Show complete inference process for a whitespace-separated sentence.

        Prints the token ids, a simulated context vector chosen from the
        words present, the per-word logit computation, and the five words
        with the highest logits together with their softmax probabilities.

        Args:
            sentence: Space-separated words, each of which must be in the
                vocabulary.

        Raises:
            KeyError: If ``sentence`` contains a word outside the vocabulary.
        """
        print(f"\n🎯 INFERENCE EXAMPLE: Predicting next word after '{sentence}'")
        print("=" * 70)

        # Convert sentence to token IDs
        words = sentence.split()
        token_ids = []
        for word in words:
            try:
                token_ids.append(self.word_to_id[word])
            except KeyError:
                # Same exception type as a plain dict miss, but says what went wrong.
                raise KeyError(
                    f"unknown word {word!r}; vocabulary is {sorted(self.word_to_id)}"
                ) from None
        print(f"Input tokens: {words} → IDs: {token_ids}")

        # Simulate context vector (in real LLM, this comes from transformer layers)
        # Different contexts produce different vectors
        if "cat" in words:
            context_vector = torch.tensor([1.5, -0.8, 2.0, 0.5])  # "animal context"
        elif "big" in words or "small" in words:
            context_vector = torch.tensor([0.2, 1.8, -0.5, 1.2])  # "size context"
        else:
            context_vector = torch.tensor([1.0, 0.0, 1.0, 0.0])   # "neutral context"

        print(f"Context vector (from transformer): {context_vector.tolist()}")

        # Compute logits
        logits = self.compute_logits(context_vector)

        # Convert to probabilities
        probs = F.softmax(logits, dim=0).tolist()

        # Show top predictions
        top_values, top_indices = torch.topk(logits, 5)

        print("🏆 TOP 5 PREDICTIONS:")
        print("-" * 25)
        ranked = zip(top_values.tolist(), top_indices.tolist())
        for rank, (logit_val, word_id) in enumerate(ranked, start=1):
            word = self.vocab[word_id]
            print(f"{rank}. '{word}' (id={word_id}): logit={logit_val:.2f}, prob={probs[word_id]:.3f}")

def key_differences():
    """
    Print a side-by-side summary of weights versus logits.

    The text is static; the function takes no input and returns nothing.
    The weight example quotes the row that ``ExplainableLanguageModel``
    assigns to the word 'the' (id=0).
    """
    banner = "=" * 80
    print("\n" + banner)
    print("🔑 KEY DIFFERENCES: WEIGHTS vs LOGITS")
    print(banner)

    print("WEIGHTS/PARAMETERS:")
    print("✓ Fixed numbers learned during training")
    print("✓ Stored in model files (.bin, .safetensors)")
    print("✓ Same for every inference")
    print("✓ Shape: [vocab_size, hidden_size]")
    # [2.1, -0.5, 1.3, 0.8] is the weight row for 'the', not 'cat'.
    print("✓ Example: tensor([2.1, -0.5, 1.3, 0.8]) for word 'the'")

    print("\nLOGITS:")
    print("✓ Computed fresh for each query")
    print("✓ Result of: weights @ context_vector + bias")
    print("✓ Different for every different input")
    print("✓ Shape: [vocab_size] - one score per word")
    print("✓ Example: tensor([4.2, -1.8, 3.1, 0.5, 2.3, ...]) for some context")

    print("\n🎯 THE RELATIONSHIP:")
    print("Weights are like a 'recipe' → Logits are the 'dish' you cook")
    print("Same recipe + different ingredients = different dish")
    print("Same weights + different context = different logits")

# Run the demonstration
if __name__ == "__main__":
    # Walk through the full demonstration: parameters first, then three
    # inferences that reuse the same weights with different contexts,
    # then the weights-vs-logits summary and closing notes.
    model = ExplainableLanguageModel()

    # The fixed parameters.
    model.show_parameters()

    # Same weights, different input sentences.
    model.demonstrate_inference("the cat")
    model.demonstrate_inference("the big")
    model.demonstrate_inference("the dog is")

    # Summary of how weights and logits differ.
    key_differences()

    # Closing notes, emitted one per line.
    print(
        "\n" + "="*80,
        "💡 FOR YOUR PROJECT:",
        "="*80,
        "• You DON'T train new weights - you use pre-trained LLM weights",
        "• You DO extract logits from each inference",
        "• Uncertainty comes from analyzing the logit patterns",
        "• Your router learns to interpret logit uncertainty patterns",
        sep="\n",
    )

🔧 MODEL PARAMETERS (WEIGHTS) - These are FIXED after training
Final layer weight matrix:
Shape: torch.Size([10, 4]) → [vocab_size, hidden_dimensions]

Weights for 'the' (id=0): [2.0999999046325684, -0.5, 1.2999999523162842, 0.800000011920929]
Weights for 'cat' (id=1): [-1.2000000476837158, 3.0, -0.20000000298023224, 1.5]
Weights for 'dog' (id=2): [0.30000001192092896, 0.10000000149011612, 2.5, -1.0]
Weights for 'is' (id=3): [-0.800000011920929, 1.7999999523162842, -1.5, 2.200000047683716]
Weights for 'running' (id=4): [1.5, 0.800000011920929, 2.0, 1.2000000476837158]
Weights for 'sleeping' (id=5): [0.8999999761581421, -0.30000001192092896, 1.7999999523162842, 0.5]
Weights for 'big' (id=6): [1.0, 1.5, -0.20000000298023224, 0.800000011920929]
Weights for 'small' (id=7): [-0.5, 2.0, 0.30000001192092896, -0.800000011920929]
Weights for 'on' (id=8): [0.699999988079071, -1.0, 1.5, 1.7999999523162842]
Weights for 'mat' (id=9): [1.7999999523162842, 0.20000000298023224, -0.5, 1.0]

Bias terms: 