In this notebook we will iterate on the bigram model and build up to the attention mechanism. 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Mathematical trick for self attention:

We now go on a tangent to explore different implementations of this mathematical trick:
1. manual calculation (loops)
2. matrix multiplication with a lower triangular matrix (`tril`)
3. masking with `-inf` followed by softmax

In [3]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch_size, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

__How do we introduce interaction between tokens such that, say, the 5th character only sees characters 1 to 4 as context?__

One (__poor__) way of capturing this is to average the vectors of characters 1 to 4 along the `T` dimension (not `C`: the averaged vector still has `C` channels) and use that as the input for predicting character 5. <br>
We lose a lot of information about the ordering of characters 1 to 4, but for a start it's OK!

In [8]:
# we want: xbow[b,t] = mean_{i<=t} x[b,i]  ("bag of words" average over the past)

xbow = torch.zeros((B,T,C))

# v1 - manual
for b in range(B):
    for t in range(T):
        x_prev = x[b, :t+1] # (t+1,C): tokens 0..t
        # print(x_prev, x_prev.shape)
        # print(x_prev.mean(dim = 0))
        xbow[b,t] = x_prev.mean(dim = 0) # dim 0 is along 't'

In [11]:
print(x[0], xbow[0])
# row 0 matches; row k of xbow[0] is the mean of rows 0..k of x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]]) tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


(Super) clutch trick: parallelize the accumulation with a single matrix multiplication by a lower triangular matrix.

Think of matrix multiplication from first principles!

In [5]:
torch.manual_seed(1667)
L = torch.tril(torch.ones(3,3))
# L = L / L.sum(dim = 1, keepdim=True) # uncomment: rows of L sum to 1, so L @ a gives running means
U = torch.tril(torch.ones(3,3)).T

a = torch.randint(1,10,(3,2)).float()

print(L)
print('------')
print(a)
print('------')
print(L @ a)
print('------')
print(U)
print('------')
print(U@a) # reverse accumulation: row i is the sum of rows i..end of a


tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
------
tensor([[9., 5.],
        [4., 5.],
        [7., 4.]])
------
tensor([[ 9.,  5.],
        [13., 10.],
        [20., 14.]])
------
tensor([[1., 1., 1.],
        [0., 1., 1.],
        [0., 0., 1.]])
------
tensor([[20., 14.],
        [11.,  9.],
        [ 7.,  4.]])


See how the rows of `a` accumulate down the rows of `L@a`? Row $i$ of `L@a` is the sum of rows $0..i$ of `a`. Further, if we normalize each row of `L` so that it sums to 1, `L@a` gives the running mean instead of the running sum. <br>
Now let's implement the same thing for our $(B,T,C)$ tensor:

In [6]:
# v2 - using lower tril matrix

wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(dim = 1, keepdim= True) # normalize each row to sum to 1

# weighted aggregation through the matrix multiplication seen above
xbow2 = wei @ x # (T,T) @ (B,T,C): PyTorch broadcasts wei across the batch -> (B,T,C)

# verify
torch.allclose(xbow[0], xbow2[0]) #--  same

True

In [7]:
xbow2.shape

torch.Size([4, 8, 2])

`(T,T) @ (B,T,C)` $\implies$ PyTorch broadcasts `wei` across the batch $\implies$ `(B,T,T) @ (B,T,C)` $\implies$ for each of the `B` batch elements, `(T,T) @ (T,C)` $\implies$ `(B,T,C)`

<img src="images/transformer_visual.jpg" style="width:70%;">

The batch dimension `B` is processed in parallel, and context is accumulated along the `T` dimension. 
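As a quick sanity check of the broadcasting argument, here is a small sketch on random data (the seed and shapes are arbitrary choices) that compares the broadcast product `wei @ x` against an explicit loop over the batch and against a running mean computed with `torch.cumsum`. The names `out_loop` and `out_cumsum` are just for illustration.

```python
import torch

torch.manual_seed(0)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# v2-style averaging matrix: row t holds 1/(t+1) in columns 0..t
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)

# Broadcast matmul: (T,T) @ (B,T,C) -> (B,T,C)
out_broadcast = wei @ x

# Explicit loop over the batch: the same (T,T) @ (T,C) product for every b
out_loop = torch.stack([wei @ x[b] for b in range(B)])

# Running mean: cumulative sum divided by the number of tokens seen so far
counts = torch.arange(1, T + 1, dtype=x.dtype).view(1, T, 1)
out_cumsum = x.cumsum(dim=1) / counts

print(torch.allclose(out_broadcast, out_loop, atol=1e-6),
      torch.allclose(out_broadcast, out_cumsum, atol=1e-6))
```

All three agree, which is another way of seeing that `wei @ x` is just a batched running mean.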

In [8]:
# v3 - using softmax

tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
print(wei)
wei = wei.softmax(dim =1)

xbow3 = wei @ x # broadcast along batch dimension: (B,T,T) @ (B,T,C)
torch.allclose(xbow3[0], xbow[0]) # verify similarity

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


True

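Basically, we first replace all positions where `tril == 0` with `-inf` and then take the softmax along each row. Since $e^{-\infty} = 0$, the masked positions get zero weight, and the remaining equal scores share the probability mass, so row $t$ puts $\frac{1}{t+1}$ on each of its first $t+1$ entries: exactly the normalized `tril` from v2. 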

This is interesting because: 

- `wei = torch.zeros((T,T))`: can be interpreted as the 'interaction strength' (affinity) between tokens at initialization
- `wei = wei.masked_fill(tril == 0, float('-inf'))`: information from future tokens is masked out <br>
The remaining steps (softmax and the matrix multiplication) perform the weighted accumulation. 
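A small sketch of the masking + softmax step on a tiny $T=4$ example: with all-zero scores it reproduces the uniform averaging weights, and with arbitrary scores (standing in for learned affinities, here just random numbers) each row still sums to 1 and the future stays exactly zero.

```python
import torch

T = 4
tril = torch.tril(torch.ones(T, T))

# Uniform "interaction strengths" (all zeros) with the future masked out
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
# exp(-inf) = 0, and the remaining equal scores share the probability mass
uniform = wei.softmax(dim=-1)
print(uniform[2])  # [1/3, 1/3, 1/3, 0]

# Arbitrary (e.g. learned) scores: masking and normalization still hold
torch.manual_seed(0)
scores = torch.randn(T, T).masked_fill(tril == 0, float('-inf'))
weighted = scores.softmax(dim=-1)
print(weighted.sum(dim=-1))                           # every row sums to 1
print(torch.triu(weighted, diagonal=1).abs().sum())   # weights on future tokens are exactly 0
```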

<span style="color:#FF0000; font-family: 'Bebas Neue'; font-size: 01em;">Summary of the above section:</span><br>
You can do a weighted aggregation of past elements by multiplying with a (row-normalized) lower triangular matrix. 

In [None]:
# self attention!

B,T,C = 4,8,32
x = torch.randn((B,T,C))

tril = torch.tril(torch.ones(T,T))
wei = torch.zeros_like(tril) # uniform affinities: not learnable, not data dependent
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = wei.softmax(dim = -1) # softmax over the last dim of wei (T,T), i.e. over the source tokens j

out = wei @ x

out.shape

torch.Size([4, 8, 32])


The `wei` matrix doesn't have to be zeros. It can be learned, reflecting how much weight each token assigns to each of the previous tokens!

## Crux of attention mechanism:

__Problem that self attention solves:__ How to gather information from the past in a data dependent way?

Every single token has 2 attributes:
- `query` vector: What is being looked for?
- `key` vector: What information does it contain?

Crux: the query of each token is dot-producted with the keys of the other tokens, and these scores are stored in `wei`!!

In [None]:
# self attention!
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn((B,T,C))

# lets introduce a single head with self attention 
head_size = 16 # hyperparam
key = nn.Linear(C,head_size, bias=False)
query = nn.Linear(C,head_size, bias=False)

# forward x
k = key(x) #(B,T,16)
q = query(x) # (B,T,16)

#find wei
wei = q @ k.transpose(-2,-1) # (B,T,16) @ (B,16,T) --> (B,T,T)

print(f'Raw wei outputs:\n {wei[0]}')

tril = torch.tril(torch.ones(T,T))
# wei = torch.zeros_like(tril)
wei = wei.masked_fill(tril == 0, float('-inf')) # to hide future token
wei = wei.softmax(dim = -1)

# we accumulate value(x), not x directly!
value = nn.Linear(C,head_size, bias=False)
out = wei @ value(x)

out.shape # (B,T, head_size = 16)

Raw wei outputs:
 tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)


torch.Size([4, 8, 16])

In [25]:
print((wei[0]), '\n',"Sum along columns:\n", wei[0].sum(dim = 1))

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>) 
 Sum along columns:
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)


### Interpretation of `wei`

The `wei` matrix holds the attention weights: it's a $(B,T,T)$ tensor where `wei[b,i,j]` tells us __"how much token i should pay attention to token j" in batch b__.
It comes from the dot product `q @ k.transpose(-2,-1)`, which computes the compatibility between queries and keys:

- If `q[i]` (what token i is looking for) is similar to `k[j]` (what token j offers), their dot product is high
- After softmax, high compatibility becomes high attention weight

Example:<br>
The dog chased the cat, and _it_ ran away $\implies$ `wei[position_of_"it", position_of_"cat"]` would be high

Now `wei` no longer has a uniform structure, i.e. we don't assign equal weight to each preceding token. It is __data dependent and, more importantly, can be learned.__ Great explanation from the [master himself](https://youtu.be/kCc8FmEb1nY?si=DVwpogFwnz4z-9YY&t=4024)
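To make the query/key intuition concrete, here is a hand-built toy example. The vectors are made up for illustration, not learned: token 2's query points in the same direction as token 0's key, so after masking and softmax token 2 puts most of its attention on token 0, while token 1 (whose query matches nothing) falls back to uniform weights.

```python
import torch

# Hypothetical hand-picked vectors: 3 tokens, head_size 4
k = torch.tensor([[1., 0., 0., 0.],   # token 0 "offers" feature A
                  [0., 1., 0., 0.],   # token 1 offers feature B
                  [0., 0., 1., 0.]])  # token 2 offers feature C
q = torch.tensor([[0., 0., 0., 1.],   # token 0 looks for something nobody offers
                  [0., 0., 0., 1.],   # so does token 1
                  [3., 0., 0., 0.]])  # token 2 looks for feature A

T = k.shape[0]
tril = torch.tril(torch.ones(T, T))
wei = (q @ k.T).masked_fill(tril == 0, float('-inf')).softmax(dim=-1)
print(wei)  # row 2 is dominated by column 0; row 1 is [0.5, 0.5, 0]
```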

<span style="color:#FF0000; font-family: 'Bebas Neue'; font-size: 01em;">Notes:</span><br>

- Attention is a `communication` mechanism. As in the graph below, it can be seen as each node aggregating information from the nodes that point to it, via a data-dependent (learnable) weighted sum. 

<img src="images/attention_as_graph.png" style="width:40%;"> <br>
- There is no notion of space: attention simply acts over a set of vectors. This is why positional encoding is important (unlike convolution, where spatial structure is built in).
- There is no communication across the batch dimension: each example is processed independently, which allows parallel ops.
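Both notes can be checked numerically. The sketch below uses a single unmasked (encoder-style) head with random, untrained weights; the mask is left out on purpose, because a causal mask ties each output to its position. Changing batch element 1 leaves the output for batch element 0 untouched, and shuffling the tokens (with no positional encoding) just shuffles the outputs in the same way. The helper `attend` is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C, head_size = 2, 5, 8, 4
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

def attend(x):
    """Unmasked single-head attention, no positional encoding."""
    wei = query(x) @ key(x).transpose(-2, -1) * head_size**-0.5  # (B, T, T)
    return wei.softmax(dim=-1) @ value(x)                          # (B, T, head_size)

x = torch.randn(B, T, C)
out = attend(x)

# 1) No communication across the batch: changing element 1 leaves element 0 unchanged
x2 = x.clone()
x2[1] = torch.randn(T, C)
print(torch.allclose(attend(x2)[0], out[0], atol=1e-6))

# 2) No notion of space: permuting the tokens permutes the outputs the same way
perm = torch.randperm(T)
print(torch.allclose(attend(x[:, perm]), out[:, perm], atol=1e-6))
```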


### Encoding vs Decoding

- Encoder blocks: allow all tokens to communicate with each other (future tokens can also influence past ones) $\implies$ the masking step is removed; everything else stays the same.
- Decoder blocks (what we have implemented above): future tokens can't communicate with past ones, which is enforced by the triangular mask. 
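Since the only difference is the mask, both block types can be sketched with one function and a flag. The function `attention_weights` and its `causal` parameter are hypothetical names used for illustration; the inputs are random.

```python
import torch

def attention_weights(q, k, causal):
    """Scaled attention weights: causal=True for a decoder block, False for an encoder block."""
    T = q.shape[-2]
    wei = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5   # (B, T, T)
    if causal:
        tril = torch.tril(torch.ones(T, T))
        wei = wei.masked_fill(tril == 0, float('-inf'))   # the only difference
    return wei.softmax(dim=-1)

torch.manual_seed(0)
q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
enc = attention_weights(q, k, causal=False)
dec = attention_weights(q, k, causal=True)
print(enc[0, 0])  # token 0 attends to every position, including the future
print(dec[0, 0])  # token 0 can only attend to itself: [1, 0, 0, 0]
```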

### Self vs cross attention

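In self-attention, keys, queries and values are __all__ computed from the same `x`. In cross-attention, the queries still come from `x`, but the keys and values come from a different source (for example, the output of an encoder) before the dot product is taken.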
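A minimal cross-attention sketch, assuming a decoder-side sequence `x` and a separate tensor `enc_out` standing in for an encoder output (both random here; `T_dec` and `T_enc` are hypothetical lengths). The queries come from `x` and the keys/values from `enc_out`, so `wei` becomes a $(B, T_{dec}, T_{enc})$ matrix. No causal mask is applied, since the decoder is allowed to look at the whole source sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, head_size = 2, 16, 8
T_dec, T_enc = 5, 7                   # decoder and encoder sequence lengths (hypothetical)
x = torch.randn(B, T_dec, C)          # decoder-side tokens: produce the queries
enc_out = torch.randn(B, T_enc, C)    # stand-in for the encoder output: produces keys and values

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(x)                          # (B, T_dec, head_size)
k = key(enc_out)                      # (B, T_enc, head_size)
v = value(enc_out)                    # (B, T_enc, head_size)

wei = (q @ k.transpose(-2, -1) * head_size**-0.5).softmax(dim=-1)  # (B, T_dec, T_enc)
out = wei @ v                                                       # (B, T_dec, head_size)
print(wei.shape, out.shape)
```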

### Scaling self-attention 

<img src="images/scaling_vaswani_et_al.png" style="width:60%;"> <br>


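In the original paper, the authors propose scaling the dot product of Q (query) and K (key) by $\frac{1}{\sqrt{d_k}}$, where $d_k$ is the head size. If the entries of Q and K are independent with unit variance, each dot product is a sum of $d_k$ such products and has variance $\approx d_k$; dividing by $\sqrt{d_k}$ brings the variance back to $\approx 1$, as illustrated below: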

In [50]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)

wt1 = q @ k.transpose(-2,-1) # ~ head_size
wt2 = q @ k.transpose(-2,-1) * head_size**-0.5 # ~ 1

print(wt1.var(), wt2.var())

tensor(15.6638) tensor(0.9790)


Keeping the variance close to 1 matters, especially at initialization: if the scores fed into softmax have a large spread, softmax saturates toward a one-hot vector, so each token ends up aggregating information from essentially a single other token. 
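A quick illustration of the saturation effect: the same scores, once as-is and once multiplied by 8 (an arbitrary factor standing in for the larger spread of unscaled dot products), passed through softmax.

```python
import torch

scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
p = scores.softmax(dim=-1)              # fairly diffuse: largest weight ~0.29
p_sharp = (scores * 8).softmax(dim=-1)  # much peakier: ~0.80 on the largest score
print(p)
print(p_sharp)
```

The larger the spread of the scores, the closer softmax gets to a one-hot vector on the largest entry.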

## A transformer from first principles:

A transformer is a neural network architecture that processes sequences of data (like text) __by learning relationships between all elements simultaneously, rather than one at a time__. Let me build this up from the ground up.

### The Core Problem
__Traditional neural networks for sequences (like RNNs) process data sequentially - word by word, left to right__. This creates two major limitations: it's slow because you can't parallelize the computation, and it struggles to connect distant elements in long sequences due to the _vanishing gradient_ problem.

### The Key Insight: Attention
The transformer's breakthrough is the attention mechanism. Instead of processing sequentially, attention lets the model look at all positions in the sequence simultaneously and decide which ones are most relevant to each other.

Think of it like this: when you read the sentence "The cat sat on the mat because it was comfortable," you instantly know "it" refers to "the cat" by considering the whole sentence context. __Attention mechanizes this intuition.__

### How Attention Works
Self-attention computes three vectors for each word:

- Query (Q): "What am I looking for?"
- Key (K): "What do I represent?"
- Value (V): "What information do I contain?"

For each word, you compute similarity scores between its query and all other words' keys. These scores determine how much to "attend to" each word. You then take a weighted sum of all the value vectors based on these attention scores.
Mathematically: <br> 

$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

The softmax ensures the attention weights sum to 1, and dividing by $\sqrt{d_k}$ prevents the dot products from getting too large.
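The formula translates almost line by line into PyTorch. This is a minimal sketch with an optional causal mask (the function name and the `causal` flag are our own choices); recent PyTorch versions (2.0+) also ship `torch.nn.functional.scaled_dot_product_attention`, which computes the same quantity with fused kernels.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V, causal=False):
    """softmax(Q K^T / sqrt(d_k)) V, with the softmax taken over the key dimension."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)     # (..., T_q, T_k)
    if causal:
        T_q, T_k = scores.shape[-2:]
        mask = torch.tril(torch.ones(T_q, T_k, dtype=torch.bool))
        scores = scores.masked_fill(~mask, float('-inf'))
    weights = scores.softmax(dim=-1)                        # each row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
Q, K, V = torch.randn(3, 2, 6, 16).unbind(0)   # each (B=2, T=6, d_k=16)
out, weights = scaled_dot_product_attention(Q, K, V, causal=True)
print(out.shape)                                # torch.Size([2, 6, 16])
```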

### Multi-Head Attention
Rather than having just one attention mechanism, transformers use multiple "heads" in parallel. __Each head learns different types of relationships - one might focus on grammatical dependencies, another on semantic similarity, etc.__ The outputs are concatenated and projected back to the original dimension.
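A minimal multi-head sketch in the style of the single head above: `n_heads` causal heads run in parallel on the same input, their outputs are concatenated along the channel dimension and projected back to `n_embd`. The class names, `block_size`, and the sizes in the usage lines are illustrative assumptions.

```python
import torch
import torch.nn as nn

class Head(nn.Module):
    """One causal self-attention head (same recipe as the single head above)."""
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        T = x.shape[1]
        k, q, v = self.key(x), self.query(x), self.value(x)       # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5        # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')).softmax(dim=-1)
        return wei @ v                                              # (B, T, head_size)

class MultiHeadAttention(nn.Module):
    """n_heads heads in parallel, concatenated and projected back to n_embd."""
    def __init__(self, n_embd, n_heads, block_size):
        super().__init__()
        head_size = n_embd // n_heads
        self.heads = nn.ModuleList([Head(n_embd, head_size, block_size) for _ in range(n_heads)])
        self.proj = nn.Linear(n_heads * head_size, n_embd)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)   # (B, T, n_heads * head_size)
        return self.proj(out)                                  # (B, T, n_embd)

torch.manual_seed(0)
mha = MultiHeadAttention(n_embd=32, n_heads=4, block_size=8)
x = torch.randn(4, 8, 32)
print(mha(x).shape)   # torch.Size([4, 8, 32])
```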
### The Complete Architecture
A transformer has two main components: <br>
1. Encoder: Processes the input sequence and builds rich representations. Each encoder layer contains:
    - Multi-head self-attention (words attend to other words in the same sequence)
    - Feed-forward neural network
    - Residual connections and layer normalization around both
2. Decoder: Generates the output sequence. Each decoder layer has:
    - Masked self-attention (prevents looking at future words during training)
    - Cross-attention (attends to the encoder's output)
    - Feed-forward network
    - Residual connections and layer normalization
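Here is a sketch of one decoder layer as used in decoder-only (GPT-style) models, which drop the cross-attention sublayer since there is no encoder. Two assumptions to flag: layer norm is applied *before* each sublayer ("pre-norm", as in GPT-2), whereas the original paper normalizes after the residual addition; and all heads are computed in one batched matmul by reshaping, which is equivalent to running separate heads and concatenating. Class names and sizes are illustrative.

```python
import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention, all heads in one batched matmul."""
    def __init__(self, n_embd, n_heads, block_size):
        super().__init__()
        self.n_heads = n_heads
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        hs = C // self.n_heads
        q, k, v = self.qkv(x).split(C, dim=-1)                       # each (B, T, C)
        # (B, T, C) -> (B, n_heads, T, head_size)
        q, k, v = (t.view(B, T, self.n_heads, hs).transpose(1, 2) for t in (q, k, v))
        wei = q @ k.transpose(-2, -1) * hs ** -0.5                    # (B, n_heads, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')).softmax(dim=-1)
        out = (wei @ v).transpose(1, 2).contiguous().view(B, T, C)   # concatenate heads
        return self.proj(out)

class DecoderBlock(nn.Module):
    """Masked self-attention + feed-forward, each with a residual connection and layer norm."""
    def __init__(self, n_embd, n_heads, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_heads, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.ReLU(), nn.Linear(4 * n_embd, n_embd)
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # pre-norm residual around attention
        x = x + self.ffwd(self.ln2(x))   # pre-norm residual around the feed-forward net
        return x

torch.manual_seed(0)
block = DecoderBlock(n_embd=32, n_heads=4, block_size=8)
x = torch.randn(4, 8, 32)
print(block(x).shape)   # torch.Size([4, 8, 32])
```

Because the block maps `(B, T, n_embd)` to `(B, T, n_embd)`, several of them can simply be stacked.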

### Why This Works So Well
- Parallelization: Since attention looks at all positions simultaneously, you can compute everything in parallel rather than sequentially.
- Long-range dependencies: Direct connections between any two positions mean the model can easily relate distant elements.
- Flexibility: The same architecture works for many tasks - translation, text generation, question answering - just by changing the training objective.
- Scalability: Transformers scale remarkably well with more data and parameters, leading to models like GPT and BERT.
The transformer essentially turned sequence modeling from a sequential, limited process into a massively parallel, globally-aware one. This architectural shift enabled the current era of large language models and has revolutionized not just NLP, but computer vision and other domains as well.

In [None]:
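# Dimensional flow through a (decoder-only) transformer, traced with real tensors.
# The hyperparameters below are illustrative; residual connections and layer norms
# are omitted because they do not change any shapes.
import torch
import torch.nn as nn

B, T = 4, 8
vocab_size, n_embd, n_heads = 65, 32, 4
head_size = n_embd // n_heads                  # n_heads * head_size == n_embd (by design)
d_ff = 4 * n_embd                              # typically d_ff = 4 * n_embd

# Starting point: raw token indices
input_ids = torch.randint(0, vocab_size, (B, T))              # (B, T)

# Step 1: embeddings
token_emb = nn.Embedding(vocab_size, n_embd)(input_ids)        # (B, T, n_embd)
pos_emb = nn.Embedding(T, n_embd)(torch.arange(T))             # (T, n_embd), broadcast over B
x = token_emb + pos_emb                                         # (B, T, n_embd)

# Step 2: multi-head attention, each head produces (B, T, head_size)
tril = torch.tril(torch.ones(T, T))
heads = []
for _ in range(n_heads):
    k = nn.Linear(n_embd, head_size, bias=False)(x)            # (B, T, head_size)
    q = nn.Linear(n_embd, head_size, bias=False)(x)            # (B, T, head_size)
    v = nn.Linear(n_embd, head_size, bias=False)(x)            # (B, T, head_size)
    wei = q @ k.transpose(-2, -1) * head_size**-0.5             # (B, T, T)
    wei = wei.masked_fill(tril == 0, float('-inf')).softmax(dim=-1)
    heads.append(wei @ v)                                       # head_i: (B, T, head_size)

# Concatenate all heads, then a linear projection
concat_heads = torch.cat(heads, dim=-1)                         # (B, T, n_heads * head_size)
attn_output = nn.Linear(n_embd, n_embd)(concat_heads)           # (B, T, n_embd)

# Step 3: feed-forward block
ff_input = attn_output                                          # (B, T, n_embd)
ff_hidden = torch.relu(nn.Linear(n_embd, d_ff)(ff_input))       # (B, T, d_ff)
ff_output = nn.Linear(d_ff, n_embd)(ff_hidden)                  # (B, T, n_embd), back to original dim

# Step 4: final output (after all layers)
final_hidden = ff_output                                        # (B, T, n_embd)
logits = nn.Linear(n_embd, vocab_size)(final_hidden)            # (B, T, vocab_size)
print(logits.shape)                                             # torch.Size([4, 8, 65])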