
# LLM Intro: From Tokens to Transformer Forward Pass

This notebook accompanies the **Introduction to LLMs** module. Run cells top-to-bottom.
It mirrors the sections:
1. What is an LLM?
2. From Words to Tokens
3. Transformer Architecture
4. Positional Encoding
5. Putting It All Together
6. Output Projection & Softmax

> **Tip:** If you're on Colab, make sure the runtime is set to Python 3.10+.



## 0. Environment & Imports

We'll use PyTorch and Hugging Face `transformers` (required for the GPT-2 tokenizer in Section 2).
If they are not already installed (e.g., on a fresh Colab runtime), uncomment the `pip` lines in the next cell.


In [None]:

# If running on Colab, uncomment to install latest versions:
# !pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip -q install transformers tokenizers


In [None]:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Matplotlib for a few simple visualizations
import matplotlib.pyplot as plt

print(f'Torch: {torch.__version__}, CUDA available: {torch.cuda.is_available()}')



## 1. What is a Large Language Model? (Quick Code Peek)

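A language model learns the conditional probability of the next token given the previous tokens. At each step it outputs one score (logit) per vocabulary entry, and softmax turns those scores into a probability distribution.
Below we define a small softmax helper and inspect the distribution it produces.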
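
Formally, the chain rule lets an autoregressive model write the probability of a whole sequence as a product of next-token probabilities, and softmax maps the logits $z$ to a distribution. Subtracting any constant $c$ (such as $\max_j z_j$) from every logit leaves softmax unchanged, which is why the helper below can subtract the max for numerical stability. Here $x_{<t}$ abbreviates $x_1, \dots, x_{t-1}$.

$$
P(x_1, \dots, x_T) = \prod_{t=1}^{T} P(x_t \mid x_{<t}),
\qquad
P(x_t = i \mid x_{<t}) = \operatorname{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}
$$

$$
\operatorname{softmax}(z - c)_i = \frac{e^{z_i - c}}{\sum_j e^{z_j - c}} = \frac{e^{-c}\, e^{z_i}}{e^{-c} \sum_j e^{z_j}} = \operatorname{softmax}(z)_i
$$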


In [None]:

import torch

def softmax(logits):
    # Subtract the per-row max for numerical stability (softmax is unchanged by a constant shift)
    exps = torch.exp(logits - logits.max(dim=-1, keepdim=True).values)
    return exps / exps.sum(dim=-1, keepdim=True)

logits = torch.tensor([2.0, 1.0, 0.5])
probs = softmax(logits)
print('Logits:', logits.tolist())
print('Probs:', probs.tolist(), 'sum=', float(probs.sum()))
print('Argmax token index:', int(probs.argmax()))
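
A minimal sketch of the chain rule in practice, assuming made-up per-step logits (`step_logits`) and a made-up token sequence rather than a real model: summing the per-token log-probabilities gives the log-probability of the whole sequence. It uses `F.log_softmax`, the numerically safer way to get log-probabilities, and ends with a two-element example of why the max-subtraction matters: exponentiating logits near 1000 overflows float32 to `inf`, and `inf / inf` is `nan`.

```python
import torch
import torch.nn.functional as F

# Hypothetical per-step logits for a 3-token sequence over a toy 5-token vocabulary.
# Row t holds the scores for the token at position t, given the tokens before it.
torch.manual_seed(0)
step_logits = torch.randn(3, 5)
sequence = torch.tensor([4, 1, 3])  # hypothetical observed token IDs

# Log-probabilities of every vocabulary entry at every step: (T, V)
log_probs = F.log_softmax(step_logits, dim=-1)
# Pick the log-probability of the token that actually occurred at each step: (T,)
token_log_probs = log_probs[torch.arange(len(sequence)), sequence]
# Chain rule in log space: log P(sequence) = sum_t log P(x_t | x_<t)
seq_log_prob = token_log_probs.sum()

print('Per-token log P:', token_log_probs.tolist())
print('log P(sequence):', seq_log_prob.item())
print('P(sequence):', seq_log_prob.exp().item())

# Why subtract the max: naive softmax overflows for large logits
big = torch.tensor([1000.0, 999.0])
naive = torch.exp(big) / torch.exp(big).sum()  # exp(1000) -> inf, inf / inf -> nan
stable = torch.softmax(big, dim=-1)            # built-in softmax handles the shift internally
print('Naive:', naive.tolist(), 'Stable:', stable.tolist())
```

Working in log space also avoids underflow: a product of hundreds of probabilities below 1 quickly rounds to zero, while their log sum stays representable.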



## 2. From Words to Tokens

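We'll demonstrate tokenization using GPT-2's tokenizer from Hugging Face. GPT-2 uses byte-level byte-pair encoding (BPE), so frequent words tend to map to a single token while rarer words split into subword pieces. In the printed tokens, a leading `Ġ` marks a preceding space.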


In [None]:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Transformers are powerful models."
tokens = tokenizer.tokenize(text)
token_ids = tokenizer.encode(text)
print('Tokens:', tokens)
print('Token IDs:', token_ids)
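
To see where subword tokens come from, here is a simplified, character-level sketch of byte-pair encoding (BPE) training on a made-up four-word corpus: repeatedly count adjacent symbol pairs (weighted by word frequency) and merge the most frequent pair into a new symbol. It illustrates the idea only and is not GPT-2's actual tokenizer, which operates on bytes, pre-splits text with a regex, and learned roughly 50k merges from a large corpus. The helper names `get_pair_counts` and `merge_pair` are hypothetical.

```python
from collections import Counter

def get_pair_counts(words):
    """Count adjacent symbol pairs across all words, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs

def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged

# Hypothetical toy corpus: word -> frequency; each word starts as a tuple of characters
corpus = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
words = {tuple(w): f for w, f in corpus.items()}

merges = []
for _ in range(4):
    pairs = get_pair_counts(words)
    best = max(pairs, key=pairs.get)  # most frequent pair (first one wins ties)
    merges.append(best)
    words = merge_pair(words, best)

print('Learned merges:', merges)
print('Segmented words:', list(words))
```

After four merges, frequent fragments such as `est` and `low` have become single symbols, while rarer material like `widest` still splits into smaller pieces.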


In [None]:

# Embedding lookup demo (toy embedding width)
# Size the table from the tokenizer (50,257 entries for GPT-2) so every token ID is a valid row
V, d = tokenizer.vocab_size, 8
embedding = nn.Embedding(V, d)

# Map token IDs to embedding vectors
ids = torch.tensor(token_ids)
X = embedding(ids)
print('Embedding shape:', tuple(X.shape))  # (T, d)
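
An embedding lookup is just row selection: it gives the same result as multiplying a one-hot encoding of each ID by the embedding weight matrix, without materializing the one-hot vectors. A small self-contained check, with toy sizes chosen for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
toy_vocab, toy_d = 10, 4                  # toy sizes for illustration
toy_emb = nn.Embedding(toy_vocab, toy_d)
toy_ids = torch.tensor([3, 7, 3])         # note: ID 3 appears twice

lookup = toy_emb(toy_ids)                                    # (3, toy_d): rows 3, 7, 3 of the table
one_hot = F.one_hot(toy_ids, num_classes=toy_vocab).float()  # (3, toy_vocab)
via_matmul = one_hot @ toy_emb.weight                        # (3, toy_d)

print('Same result:', torch.allclose(lookup, via_matmul))
print('Repeated ID -> identical vector:', torch.equal(lookup[0], lookup[2]))
```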



## 3. Transformer Architecture: Self-Attention & Feedforward

We'll implement one attention step by hand, then use PyTorch's built-in MultiheadAttention.
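
In matrix form, the single-head computation in the next cell is scaled dot-product attention, with softmax applied row-wise (one row per query position) and $d_k = d$ in the code:

$$
Q = X W_Q,\quad K = X W_K,\quad V = X W_V,
\qquad
\mathrm{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$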


In [None]:

# Manual single-head attention on small tensors
T, d = 4, 8
X = torch.rand(T, d)
W_Q = torch.rand(d, d)
W_K = torch.rand(d, d)
W_V = torch.rand(d, d)

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

A = F.softmax(Q @ K.T / (d ** 0.5), dim=-1)  # attention weights
Z = A @ V

print('Attention weights shape:', tuple(A.shape))  # (T, T)
print('Context shape:', tuple(Z.shape))           # (T, d)
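
GPT-style (decoder-only) LLMs add a causal mask so that position $t$ cannot attend to later positions, matching the next-token setup from Section 1. A minimal sketch, assuming PyTorch 2.0+ (for `F.scaled_dot_product_attention`): fill the upper triangle of the score matrix with $-\infty$ before softmax, then cross-check the result against PyTorch's built-in kernel with `is_causal=True`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 4, 8
X = torch.rand(T, d)
W_Q, W_K, W_V = (torch.rand(d, d) for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V

scores = Q @ K.T / d ** 0.5  # (T, T)
# Causal mask: True above the diagonal marks future keys that query i must not see
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(future, float('-inf'))
A = F.softmax(scores, dim=-1)  # masked entries become exactly 0
Z = A @ V                      # (T, d)

# Cross-check against PyTorch's fused implementation (expects a leading batch dim)
Z_ref = F.scaled_dot_product_attention(Q[None], K[None], V[None], is_causal=True)[0]
print('Max abs diff vs SDPA:', (Z - Z_ref).abs().max().item())
print('Causal attention weights:\n', A)
```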


In [None]:

# Multi-head attention using nn.MultiheadAttention
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
X_batched = torch.rand(1, 4, 8)  # (batch, tokens, dim)
out, weights = mha(X_batched, X_batched, X_batched)
print('MHA out shape:', tuple(out.shape))        # (1, 4, 8)
print('MHA weights shape:', tuple(weights.shape))  # (1, 4, 4) = (batch, T, T), averaged over heads by default
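
By default `nn.MultiheadAttention` averages the attention weights over heads (`average_attn_weights=True`), so the returned weights have shape `(batch, T, T)`. In recent PyTorch versions, passing `average_attn_weights=False` returns one map per head instead. The sketch below checks that averaging the per-head maps reproduces the default output; `mha_demo` and `X_demo` are illustrative names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha_demo = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
X_demo = torch.rand(1, 4, 8)  # (batch, tokens, dim)

_, w_avg = mha_demo(X_demo, X_demo, X_demo)                                # default: averaged over heads
_, w_heads = mha_demo(X_demo, X_demo, X_demo, average_attn_weights=False)  # one map per head

print('Averaged weights:', tuple(w_avg.shape))    # (batch, T, T)
print('Per-head weights:', tuple(w_heads.shape))  # (batch, heads, T, T)
print('Mean of heads == default:', torch.allclose(w_heads.mean(dim=1), w_avg))
```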


In [None]:

# Feedforward network (position-wise MLP)
ffn = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 8)
)
X_ffn = ffn(out)
print('FFN out shape:', tuple(X_ffn.shape))
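
"Position-wise" means the same MLP is applied to every token vector independently, with no mixing across positions (mixing is attention's job). A quick check with toy sizes: running the FFN on the whole sequence matches running it on each position separately and stacking the results.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ffn_demo = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
h = torch.rand(1, 4, 8)  # (batch, tokens, dim)

full = ffn_demo(h)  # FFN over the whole sequence at once
# Apply the same FFN to each token vector separately, then restack along the token axis
per_position = torch.stack([ffn_demo(h[:, t]) for t in range(h.shape[1])], dim=1)

print('Shapes:', tuple(full.shape), tuple(per_position.shape))
print('Position-wise:', torch.allclose(full, per_position, atol=1e-6))
```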



## 4. Positional Encoding (Sinusoidal)

We'll create sinusoidal positional encodings and visualize a few dimensions across positions.
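
For reference, the sinusoidal encoding of the original Transformer assigns sine to even dimensions $2k$ and cosine to odd dimensions $2k+1$, with wavelengths growing geometrically across dimension pairs. In the loop below, `i` steps over the even indices, so `i` plays the role of $2k$.

$$
PE_{(pos,\,2k)} = \sin\!\left(\frac{pos}{10000^{2k/d_{\text{model}}}}\right),
\qquad
PE_{(pos,\,2k+1)} = \cos\!\left(\frac{pos}{10000^{2k/d_{\text{model}}}}\right)
$$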


In [None]:

def positional_encoding(max_len, d_model):
    PE = torch.zeros(max_len, d_model)
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))  # i is already the even index 2k
            PE[pos, i] = math.sin(angle)
            if i + 1 < d_model:
                PE[pos, i + 1] = math.cos(angle)
    return PE

PE = positional_encoding(max_len=32, d_model=8)
print('PE shape:', tuple(PE.shape))
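
The nested loops are easy to read but slow for long sequences. A vectorized sketch of the same formula, assuming an even `d_model` (the name `positional_encoding_vectorized` is hypothetical), computes all positions and frequencies at once with broadcasting:

```python
import math
import torch

def positional_encoding_vectorized(max_len, d_model):
    """Sinusoidal positional encoding without Python loops (assumes even d_model)."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    two_k = torch.arange(0, d_model, 2, dtype=torch.float32)       # 0, 2, 4, ...
    inv_freq = torch.exp(-math.log(10000.0) * two_k / d_model)     # 1 / 10000^(2k/d_model)
    PE = torch.zeros(max_len, d_model)
    PE[:, 0::2] = torch.sin(pos * inv_freq)  # even dims: sine
    PE[:, 1::2] = torch.cos(pos * inv_freq)  # odd dims: cosine
    return PE

PE_vec = positional_encoding_vectorized(max_len=32, d_model=8)
print('PE_vec shape:', tuple(PE_vec.shape))
```

With `d_model=8`, the second dimension pair ($k=1$) has frequency $1/10000^{2/8} = 0.1$, so dimension 2 equals $\sin(pos/10)$.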


In [None]:

# Plot positional encoding patterns for a few dimensions
plt.figure(figsize=(8, 3))
plt.plot(PE[:, 0], label='dim 0 (sin)')
plt.plot(PE[:, 1], label='dim 1 (cos)')
plt.plot(PE[:, 2], label='dim 2 (sin)')
plt.xlabel('Token position')
plt.ylabel('Encoding value')
plt.title('Example Positional Encoding Patterns')
plt.legend()
plt.show()



## 5. Putting It All Together: A Minimal Transformer Layer

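We now assemble a single Transformer block from MHA, a feedforward network, residual connections, and LayerNorm (applied after each residual add, the "post-LN" layout of the original Transformer), then run a forward pass on random inputs.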


In [None]:

class MiniTransformerLayer(nn.Module):
    def __init__(self, d_model=8, nhead=2, dim_ff=16):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, X):
        attn_out, attn_w = self.mha(X, X, X)
        X = self.norm1(X + attn_out)
        ffn_out = self.ffn(X)
        X = self.norm2(X + ffn_out)
        return X, attn_w

torch.manual_seed(0)
X = torch.rand(1, 5, 8)  # (batch, tokens, dim)
layer = MiniTransformerLayer()
out, attn_w = layer(X)
print('Output shape:', tuple(out.shape))
print('Attention weights shape:', tuple(attn_w.shape))
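
The layer above lets every token attend to every other token, which suits an encoder; a GPT-style decoder layer also needs a causal mask. A minimal sketch, assuming the same post-LN layout as `MiniTransformerLayer` (`CausalMiniTransformerLayer` is a hypothetical name). In `nn.MultiheadAttention`, a boolean `attn_mask` entry of `True` means that position is *not* allowed to attend. The final check perturbs only the last token and confirms that earlier outputs do not change.

```python
import torch
import torch.nn as nn

class CausalMiniTransformerLayer(nn.Module):
    """Hypothetical post-LN Transformer layer with a causal self-attention mask."""
    def __init__(self, d_model=8, nhead=2, dim_ff=16):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, X):
        T = X.shape[1]
        # True above the diagonal = future positions that may not be attended to
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=X.device), diagonal=1)
        attn_out, attn_w = self.mha(X, X, X, attn_mask=mask)
        X = self.norm1(X + attn_out)      # residual + LayerNorm
        X = self.norm2(X + self.ffn(X))   # residual + LayerNorm
        return X, attn_w

torch.manual_seed(0)
causal_layer = CausalMiniTransformerLayer()
X_c = torch.rand(1, 5, 8)  # (batch, tokens, dim)
out_c, attn_w_c = causal_layer(X_c)

# Perturb only the last token: outputs at earlier positions should be unaffected
X_c2 = X_c.clone()
X_c2[:, -1] += 1.0
out_c2, _ = causal_layer(X_c2)
print('Earlier positions unchanged:', torch.allclose(out_c[:, :-1], out_c2[:, :-1], atol=1e-6))
```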


In [None]:
# Visualize the attention weights. nn.MultiheadAttention averages over heads by default,
# so attn_w has shape (batch, T, T); we plot batch element 0.
w = attn_w[0].detach()
plt.figure(figsize=(4, 3))
plt.imshow(w, aspect='auto')
plt.colorbar()
plt.title('Attention Weights (averaged over heads)')
plt.xlabel('Key positions')
plt.ylabel('Query positions')
plt.show()


## 6. Output Projection & Softmax (Toy Demo)

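Finally, we project the hidden state at each position to vocabulary logits and apply softmax. In a trained causal language model, the distribution at the last position is the prediction for the next token.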


In [None]:

# Toy vocabulary projection
V = 1000
d_model = out.shape[-1]
Wo = nn.Linear(d_model, V)

logits = Wo(out)             # (batch, T, V)
probs = F.softmax(logits, dim=-1)
print('Logits shape:', tuple(logits.shape))
print('Probs shape:', tuple(probs.shape), 'Row sum (approx 1):', probs[0, 0].sum().item())
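
To turn the logits into an actual next token, a decoder takes the distribution at the last position and either picks the most likely ID (greedy decoding) or samples from it, often after dividing the logits by a temperature (values below 1 sharpen the distribution, values above 1 flatten it). A minimal sketch, using random stand-in logits instead of `Wo(out)` and a hypothetical helper `sample_next_token`:

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, generator=None):
    """Hypothetical helper: choose a next-token ID from a 1-D tensor of vocabulary logits."""
    if temperature == 0:
        return int(logits.argmax())  # greedy decoding: most likely token
    probs = F.softmax(logits / temperature, dim=-1)  # temperature-scaled distribution
    return int(torch.multinomial(probs, num_samples=1, generator=generator))

torch.manual_seed(0)
vocab_size = 1000
toy_logits = torch.randn(1, 5, vocab_size)  # random stand-in for Wo(out): (batch, T, V)
last = toy_logits[0, -1]                    # logits at the last position -> next-token prediction

g = torch.Generator().manual_seed(0)        # seeded generator for reproducible sampling
greedy_id = sample_next_token(last, temperature=0)
sampled_id = sample_next_token(last, temperature=0.8, generator=g)
print('Greedy next token:', greedy_id)
print('Sampled next token (T=0.8):', sampled_id)
```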
