# Decoder-only Transformer

## Baseline

In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-12-15 02:55:34--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-12-15 02:55:35 (30.9 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
text = open('input.txt', 'r', encoding='utf-8').read()
print (len(text))
print (text[:400])

1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it 


In [None]:
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

In [4]:
c2i = {c:i for i, c in enumerate(chars)}
i2c = {i:c for i, c in enumerate(chars)}
encode = lambda s: [c2i[c] for c in s] # take string -> list of int
decode = lambda li: ''.join([i2c[i] for i in li]) # take list of int -> string

In [5]:
import torch
data = torch.tensor(encode(text), dtype=torch.int64)
print (data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [6]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
BLOCK_SIZE = 8
BATCH_SIZE = 32
torch.manual_seed(1337)

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data[i:i+BLOCK_SIZE] for i in ix])
    y = torch.stack([data[i+1:i+1+BLOCK_SIZE] for i in ix])
    return x, y

xb, yb = get_batch('train')
print (xb.shape, yb.shape)

torch.Size([32, 8]) torch.Size([32, 8])
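
To make the relationship between `xb` and `yb` concrete, the sketch below unrolls one chunk into its individual training examples. It uses a small made-up tensor (`toy`, a hypothetical stand-in for the encoded corpus) so that it runs on its own.

```python
import torch

# Hypothetical stand-in for the encoded corpus: the integers 0..19.
toy = torch.arange(20)
block_size = 8  # plays the same role as BLOCK_SIZE above

x = toy[:block_size]        # inputs:  tokens 0..7
y = toy[1:block_size + 1]   # targets: tokens 1..8 (inputs shifted by one)

# One chunk of length block_size packs block_size training examples:
# the context x[:t+1] is used to predict the target y[t].
examples = [(x[:t + 1].tolist(), y[t].item()) for t in range(block_size)]
for context, target in examples:
    print(f"context {context} -> target {target}")
```

Each row of `xb` therefore trains the model on contexts of every length from 1 up to `BLOCK_SIZE`.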


In [7]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets=None):
        logits = self.emb(idx) # [B,T,C] [Batch, Time, Channel]
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate (self, idx, max_new_tokens):
        # idx is [B, T] array of indices in the current context
        for _ in range(max_new_tokens):
            logits, loss = self(idx) # [B,T,C]
            logits = logits[:,-1,:] # last time step
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim =-1)
        return idx
        
    
bi = BigramLM(VOCAB_SIZE)
logits, loss = bi(xb, yb)
print(logits.shape, loss)

idx = torch.zeros((1,1), dtype=torch.int64)
for s in bi.generate(idx, 100).tolist():
    print (decode(s))

torch.Size([256, 65]) tensor(4.7313, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ
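
As a sanity check on the initial loss of 4.7313: a model that predicts a uniform distribution over the 65 characters would score $-\ln(1/65) \approx 4.17$. The untrained embedding table is randomly initialized rather than uniform, which is a plausible reason the measured loss sits somewhat above that value. A minimal sketch of the arithmetic:

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 65  # number of distinct characters, as in the logits shape above

# Cross-entropy of a uniform prediction: -ln(1/vocab_size) = ln(vocab_size)
uniform_loss = math.log(vocab_size)
print(f"{uniform_loss:.4f}")

# Same value from F.cross_entropy with all-zero (i.e. uniform) logits
logits = torch.zeros(1, vocab_size)
target = torch.tensor([0])
ce = F.cross_entropy(logits, target).item()
print(f"{ce:.4f}")
```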


In [8]:
optimizer = torch.optim.AdamW(bi.parameters(), lr = 1e-2)
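for step in range(10000):
    xb, yb = get_batch('train') # sample a fresh batch
    logits, loss = bi(xb, yb) # forward pass
    optimizer.zero_grad(set_to_none=True) # clear gradients from the previous step
    loss.backward()
    optimizer.step()
print (loss.item()) # loss of the last batch only, so a noisy estimate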

2.367182493209839


In [9]:
idx = torch.zeros((1,1), dtype=torch.int64)
for s in bi.generate(idx, 100).tolist():
    print (decode(s))


lso br. ave aviu urf my, y MP t ivee iuedrd whar ksth y h bora s be hese, woweee; the! KI 'de, ulsee


## The mathematical trick in self-attention

In [10]:
torch.manual_seed(1337)
B, T, C = 4,8,2
x = torch.rand(B, T, C)

In [None]:
# trick 2: average over the past with a matrix multiply
wei = torch.tril(torch.ones(T,T))
wei = wei / torch.sum(wei, 1, keepdim=True) # each row sums to 1
xbow2 = wei @ x # (T, T) @ (B, T, C): wei broadcasts to (B, T, T), so each batch element computes (T, T) @ (T, C) -> (B, T, C)
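
The cell above is labelled trick 2; the version it replaces is presumably an explicit loop that averages each prefix. Below is a minimal sketch of that loop (the name `xbow` is an assumption) together with a check that the matrix form gives the same result.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.rand(B, T, C)

# Loop version: for every batch element and time step, average x[b, :t+1]
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]           # (t+1, C): current and all previous tokens
        xbow[b, t] = xprev.mean(dim=0)

# Matrix version: lower-triangular weights with rows normalized to sum to 1
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)
xbow2 = wei @ x                        # (T, T) @ (B, T, C) -> (B, T, C)

same = torch.allclose(xbow, xbow2, atol=1e-6)
print(same)
```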

In [13]:
# trick 2 explained on a small example
# row t of `a` averages the current token and all previous tokens
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b
print (a)
print (b)
print (c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [None]:
# trick 3: the same averaging, written as a masked softmax
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
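
A quick check that the masked-softmax weights equal the normalized lower-triangular weights. The value of this form is that the zeros in the score matrix can later be replaced by data-dependent scores. The variable names here are chosen for illustration.

```python
import torch
import torch.nn.functional as F

T = 8
tril = torch.tril(torch.ones(T, T))

# Normalized lower-triangular weights: uniform average over the past
wei_avg = tril / tril.sum(dim=1, keepdim=True)

# Masked softmax: zero scores, future positions set to -inf, then softmax
scores = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
wei_softmax = F.softmax(scores, dim=-1)

match = torch.allclose(wei_avg, wei_softmax)
print(match)
```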

### Self-attention

- Each token finds some other tokens more interesting than others, so the weights become data dependent
- Every token at each position emits 2 vectors (query and key)
    - Key: what do I contain
    - Query: what am I looking for
    - The affinity between 2 tokens is the dot product of one token's query with the other token's key
    - We don't aggregate the raw tokens x as in `wei @ x`; instead each token also emits a third vector, the value, and we aggregate those with `wei @ v`

In [44]:
# self-att
torch.manual_seed(1337)
B, T, C = 4,8,32
x = torch.rand(B, T, C)

# a single head performs self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, head_size)
q = query(x)
v = value(x)
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # scale attention # (B, T, T)

tril = torch.tril(torch.ones(T,T))
#wei = torch.ones(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v
out.shape

torch.Size([4, 8, 16])

In [45]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4852, 0.5148, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3242, 0.3345, 0.3413, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2428, 0.2597, 0.2434, 0.2541, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1960, 0.2043, 0.1992, 0.2105, 0.1899, 0.0000, 0.0000, 0.0000],
        [0.1586, 0.1804, 0.1706, 0.1790, 0.1493, 0.1622, 0.0000, 0.0000],
        [0.1395, 0.1417, 0.1422, 0.1456, 0.1337, 0.1410, 0.1563, 0.0000],
        [0.1203, 0.1272, 0.1256, 0.1299, 0.1166, 0.1240, 0.1248, 0.1316]],
       grad_fn=<SelectBackward0>)
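
The single-head computation above can be packaged as a module. This is a sketch, not code from the notebook: the class name `Head` and its constructor arguments are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """One head of causal self-attention (illustrative sketch)."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # tril is not a trainable parameter, so register it as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)                                    # (B, T, head_size)
        q = self.query(x)                                  # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ self.value(x)                         # (B, T, head_size)

torch.manual_seed(1337)
head = Head(n_embd=32, head_size=16, block_size=8)
x = torch.rand(4, 8, 32)
out = head(x)
print(out.shape)

# Causality check: changing the last token must not change earlier outputs
x2 = x.clone()
x2[:, -1, :] = 0.0
causal = torch.allclose(head(x2)[:, :-1], out[:, :-1], atol=1e-6)
print(causal)
```

The causality check at the end exercises the masking: outputs at earlier positions do not move when the final token is altered.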

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across the batch dimension is processed completely independently; examples never "talk" to each other.
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additionally multiplies `wei` by 1/sqrt(head_size). When the inputs Q and K have unit variance, this keeps `wei` at unit variance too, so the softmax stays diffuse and does not saturate too much.
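
The effect of the scaling can be checked numerically. The sketch below uses unit-variance random `q` and `k` (drawn with `torch.randn`, an assumption that differs from the `torch.rand` inputs used earlier), compares the variance of the scores with and without the `head_size**-0.5` factor, and then shows how larger logits sharpen a softmax. The example logits are made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, head_size = 4, 8, 16
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)

unscaled = q @ k.transpose(-2, -1)    # variance grows to roughly head_size
scaled = unscaled * head_size**-0.5   # variance is brought back to roughly 1
print(unscaled.var().item(), scaled.var().item())

# Softmax on small vs. large logits: larger magnitudes push toward one-hot
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
diffuse = F.softmax(logits, dim=-1)
peaked = F.softmax(logits * 8, dim=-1)
print(diffuse)
print(peaked)
```

A saturated softmax would make each token attend to essentially one other token, which is especially undesirable at initialization.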

Residual connections essentially say to the model:

"Stick to the original input unless you are absolutely sure you have a valuable correction to make."

This conservative approach stabilizes training: the addition gives gradients a direct path back to earlier layers, and in practice it tends to help generalization rather than hurt it.
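
In code, a residual connection is just an addition around a sub-layer. The sketch below is illustrative only: the `Residual` wrapper and the zero-initialized linear layer are assumptions, used to show that a branch contributing nothing leaves the input untouched.

```python
import torch
import torch.nn as nn

class Residual(nn.Module):
    """Wraps a sub-layer so that the output is x + sublayer(x)."""

    def __init__(self, sublayer):
        super().__init__()
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(x)

torch.manual_seed(1337)
branch = nn.Linear(32, 32)
# Zero the branch so it contributes no correction at all
nn.init.zeros_(branch.weight)
nn.init.zeros_(branch.bias)

block = Residual(branch)
x = torch.rand(4, 8, 32)
y = block(x)
identity = torch.equal(y, x)
print(identity)
```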

In [None]:
class LayerNorm:
    def __init__(self, n_hidden, eps = 1e-5):
        self.eps = eps
        # params trained with backprop
        self.gamma = torch.ones(n_hidden) # gain, scale
        self.beta = torch.zeros(n_hidden) # shift
        # unlike BatchNorm, no running_mean / running_var buffers are needed
    
    def __call__(self, x):
        # normalize over the last (feature) dimension, independently for every
        # example and every time step; works for (B, C) and (B, T, C) inputs
        xmean = x.mean(-1, keepdim=True) # per-token mean
        xvar = x.var(-1, keepdim=True) # per-token variance
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to zero mean, unit variance
        # learnable per-feature scale and shift
        self.out = self.gamma * xhat + self.beta
        return self.out
    def params (self):
        return [self.gamma, self.beta]
    
torch.manual_seed(1337)
ln = LayerNorm(100)
x = torch.rand(32, 100)
x = ln(x)
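
As a cross-check, PyTorch ships `torch.nn.LayerNorm`, which normalizes over the last dimension. The sketch below compares it with a manual per-token normalization on a `(B, T, C)` input. Note that `nn.LayerNorm` uses the biased variance estimate, so the manual version here passes `unbiased=False`.

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.rand(B, T, C)

# Built-in layer norm over the last dimension (gain=1, bias=0 at init)
ln = nn.LayerNorm(C, eps=1e-5)
out_builtin = ln(x)

# Manual equivalent: statistics per token, over the C features only
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
out_manual = (x - mean) / torch.sqrt(var + 1e-5)

matches = torch.allclose(out_builtin, out_manual, atol=1e-5)
print(matches)

# Every token vector now has roughly zero mean and unit standard deviation
tok_mean = out_manual[0, 0].mean().item()
tok_std = out_manual[0, 0].std(unbiased=False).item()
print(tok_mean, tok_std)
```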

**BatchNorm assumes that "Channel 5" means the same thing for every example in the batch. In NLP, that assumption often breaks down.**

The sections below explain why BatchNorm is a poor fit for Transformers and why LayerNorm is used instead.

### 1. The Axis of Normalization (The Mental Model)
Let's look at our input tensor shape: $(B, T, C)$ or (Batch, Time, Channels).

* **Batch Normalization (Vertical Slice):**
    * It calculates the Mean/Variance of each channel across the **Batch Dimension ($B$)** (and, for sequences, typically across the time dimension $T$ as well).
    * It asks: *"What is the average value of Feature #1 across ALL sentences in this batch?"*
* **Layer Normalization (Horizontal Slice):**
    * It calculates the Mean/Variance across the **Channel Dimension ($C$)**.
    * It asks: *"What is the average value of the features within THIS single token?"*
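
The two slices can be written directly as reductions over different axes. This sketch assumes batch statistics are pooled over both batch and time for each channel, and it omits the learnable gain and bias; the variable names are chosen for illustration.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.rand(B, T, C)

# BatchNorm-style: one mean/variance per channel, pooled over batch and time
bn_mean = x.mean(dim=(0, 1), keepdim=True)                 # (1, 1, C)
bn_var = x.var(dim=(0, 1), keepdim=True, unbiased=False)
x_bn = (x - bn_mean) / torch.sqrt(bn_var + 1e-5)

# LayerNorm-style: one mean/variance per token, over its C features
ln_mean = x.mean(dim=-1, keepdim=True)                     # (B, T, 1)
ln_var = x.var(dim=-1, keepdim=True, unbiased=False)
x_ln = (x - ln_mean) / torch.sqrt(ln_var + 1e-5)

print(bn_mean.shape, ln_mean.shape)
```

The shapes of the statistics make the difference visible: BatchNorm keeps one number per channel, LayerNorm keeps one number per token.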



### 2. Reason 1: The "Apples and Oranges" Problem
**BatchNorm works great for Images (CNNs).**
* In a picture, "Pixel (0,0)" is always the top-left corner.
* "Channel 1" might always be the "Red" channel.
* It makes sense to calculate the "Average Redness" across 32 images.

**BatchNorm fails for Text (Transformers).**
* In Sentence 1, Token #5 might be the word "King".
* In Sentence 2, Token #5 might be the word "and".
* Averaging the features of "King" with the features of "and" creates a meaningless statistic.
* **LayerNorm** doesn't care about other sentences. It just looks at "King" and normalizes the vector for "King" so it is numerically stable.

### 3. Reason 2: Variable Sequence Lengths (The Padding Nightmare)
Sentences have different lengths.
* Sentence A: 5 words.
* Sentence B: 20 words.
* To stack them in a batch, we pad Sentence A with zeros (or special tokens) to make it length 20.

**If you use BatchNorm:**
* The per-channel statistics are pooled over the batch and the Time dimension ($T$), so padded positions are included unless you mask them out.
* The large number of **Padding Zeros** then skews the Mean and Variance: you are averaging valid data with meaningless padding values.

**If you use LayerNorm:**
* It calculates statistics for *each token individually*.
* The normalization for the word "cat" is calculated using *only* the numbers inside the "cat" vector. The padding zeros 10 tokens away do not affect it.
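
A small numerical illustration of the padding argument, using made-up random vectors in place of real embeddings. Per-token normalization of a real token is the same with or without padding in the sequence, while a per-channel mean pooled over positions shifts once zero padding is appended.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
C = 16
real = torch.rand(5, C) + 1.0                    # 5 real tokens (hypothetical features)
padded = torch.cat([real, torch.zeros(15, C)])   # padded with zeros to length 20

# LayerNorm: each token uses only its own C features
ln_real = F.layer_norm(real, (C,))
ln_padded = F.layer_norm(padded, (C,))
ln_unchanged = torch.allclose(ln_real, ln_padded[:5], atol=1e-5)

# BatchNorm-style per-channel mean over positions, with and without padding
mean_real = real.mean(dim=0)
mean_padded = padded.mean(dim=0)
shift = (mean_real - mean_padded).abs().max().item()

print(ln_unchanged, shift)
```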

### 4. Reason 3: Batch Size Independence
Transformers (especially large ones like GPT-3 or Llama) are massive.
* Sometimes, they are so big you can only fit **Batch Size = 1** or **2** on a GPU.
* **BatchNorm breaks with small batches:** You cannot calculate a reliable "Population Mean" if your population is just 1 or 2 samples. The noise is too high.
* **LayerNorm works perfectly with Batch Size = 1:** Since it normalizes *within* the example, it doesn't care how many examples are in the batch.
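
The dependence on batch composition can also be seen directly. In this sketch (manual normalization without learnable parameters, random data as a stand-in, helper names chosen for illustration), the normalized output for one example changes under batch statistics when its batch-mate changes, but not under per-example statistics.

```python
import torch

def batch_norm(x, eps=1e-5):
    # statistics per feature, across the batch dimension
    mean = x.mean(dim=0, keepdim=True)
    var = x.var(dim=0, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

def layer_norm(x, eps=1e-5):
    # statistics per example, across the feature dimension
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(1337)
sample = torch.rand(1, 100)
batch_a = torch.cat([sample, torch.rand(1, 100)])  # batch of 2
batch_b = torch.cat([sample, torch.rand(1, 100)])  # same sample, different batch-mate

bn_same = torch.allclose(batch_norm(batch_a)[0], batch_norm(batch_b)[0], atol=1e-5)
ln_same = torch.allclose(layer_norm(batch_a)[0], layer_norm(batch_b)[0], atol=1e-5)
ln_single = torch.allclose(layer_norm(sample)[0], layer_norm(batch_a)[0], atol=1e-5)
print(bn_same, ln_same, ln_single)
```

With a batch of two, the batch-normalized features of a sample are determined almost entirely by whether each value is above or below its single batch-mate, which is the noise the bullet above refers to.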

### Summary
| Feature | BatchNorm | LayerNorm |
| :--- | :--- | :--- |
| **Direction** | Across Batch ($B$) | Across Features ($C$) |
| **Assumption** | Features are consistent across samples. | Features are consistent within one sample. |
| **Dependency** | Depends on other samples in the batch. | **Independent** (Self-contained). |
| **Constraint** | Requires fixed sequence length (or complex masking). | Works with any sequence length. |
| **Verdict** | Great for Images (CNNs). | **Standard choice for Text (RNNs/Transformers).** |