In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

B = 4  # batch, independent sequences processed in parallel for efficiency
T = 8  # time/block/context, window of tokens considered
C = 32 # in_channels, each token is "embedded" into a vector of size C
H = 16 # out_channels (head size), size of each query/key/value vector

x = torch.randn(B,T,C)

## Masked Self Attention (Goal)

y = Attention(x) = σ(QKᵀ/sqrt(H))·V, where σ is a row-wise softmax; x: (B,T,C), y: (B,T,H)

| Symbol  | Meaning                                                                                                                             |
| ------- | ----------------------------------------------------------------------------------------------------------------------------------- |
| x       | private information of the token                                                                                                    |
| q       | query vector generated by each token: what I am looking for [I am a vowel at position 8 looking for consonants up to position 4]    |
| k       | key vector generated by each token: what information I have [I am a consonant at position 3]                                        |
| w=q·k   | affinity: where a query and a key match, the affinity is very high (I am interested in these positions)                             |
| v       | value vector: what information I am willing to provide                                                                              |
| y=σ(w)v | accumulate the information from the positions I am interested in                                                                    |
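
The roles in the table can be traced through a tiny numeric sketch. The vectors below are hand-picked, hypothetical values (not learned projections), chosen so that one key clearly matches the query:

```python
import torch
import torch.nn.functional as F

# Hypothetical hand-picked vectors, for illustration only
q = torch.tensor([[1.0, 0.0]])              # one query: "I am looking for feature 0"
k = torch.tensor([[4.0, 0.0],               # key 0: strongly advertises feature 0
                  [0.0, 4.0],               # key 1: advertises a different feature
                  [0.0, 0.0]])              # key 2: advertises nothing
v = torch.tensor([[10.0], [20.0], [30.0]])  # what each token is willing to provide

w = q @ k.T               # (1,3) affinities, highest where query and key match
p = F.softmax(w, dim=-1)  # (1,3) attention weights, sum to 1
y = p @ v                 # (1,1) weighted sum of the values

print(w, p, y)
```

Most of the weight lands on token 0, so `y` ends up close to that token's value of 10, pulled slightly upward by the other two.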


### Notes

- Attention is a communication mechanism
- Think of a directed graph of T nodes: each node is a token position and holds its information as a vector of size H
- Each node aggregates information as a weighted sum over the nodes that point to it
- The weights are data dependent: they are computed from the vectors stored in the nodes, so they change with the input
- For autoregression,
  - the 1st node gets information only from itself
  - the 2nd node gets information from nodes 1 and 2
  - the T-th node gets information from every node
- Nodes have no notion of space / ordering, so we need to add positional embeddings
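
The autoregressive connectivity described in the notes can be made visible with a small sketch. It assumes every pair of nodes has the same affinity (all scores zero), which is an illustrative simplification, so the only structure left in the weights comes from the mask:

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.zeros(T, T)            # assumption: identical affinity for every pair
mask = torch.tril(torch.ones(T, T))   # row t may read columns 1..t only
weights = F.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)

print(weights)
```

Each row t spreads its weight evenly over positions 1..t: the first row is `[1, 0, 0, 0]`, the second `[0.5, 0.5, 0, 0]`, and the last row gives 0.25 to every node.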

In [2]:

key   = nn.Linear(C, H, bias=False)
query = nn.Linear(C, H, bias=False)
value = nn.Linear(C, H, bias=False)

# x          # (B,T,C)
k = key(x)   # (B,T,H)
q = query(x) # (B,T,H)
wei = q @ k.transpose(-2,-1) # (B,T,H) @ (B,H,T) = (B,T,T)

tril = torch.tril(torch.ones(T,T))               # for decoder: lower triangular matrix (including diagonal)
wei  = wei.masked_fill(tril==0, float('-inf'))   # set entries above the diagonal (future positions) to -inf
wei /= H**0.5                                    # scale by sqrt(H) to keep variance near 1, so softmax does not saturate
wei  = F.softmax(wei, dim=-1)                    # -inf -> 0, each row sums to 1

v = value(x)   # (B,T,H)
out = wei @ v  # (B,T,T) @ (B,T,H) = (B,T,H)

out.shape

torch.Size([4, 8, 16])
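
The effect of dividing by sqrt(H) can be checked empirically. This is a minimal sketch that assumes q and k have independent unit-variance entries (plain `torch.randn` tensors stand in for the projected queries and keys), with larger B and T than above so the variance estimate is stable:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, H = 64, 32, 16                 # larger than the cell above, only for a stable estimate
q = torch.randn(B, T, H)             # stand-in for query(x), unit variance assumed
k = torch.randn(B, T, H)             # stand-in for key(x), unit variance assumed

raw = q @ k.transpose(-2, -1)        # each entry sums H products, so variance is about H
scaled = raw / H**0.5                # variance is brought back to about 1

# Average of the largest weight in each softmax row: higher means peakier
peak_raw = F.softmax(raw, dim=-1).max(dim=-1).values.mean()
peak_scaled = F.softmax(scaled, dim=-1).max(dim=-1).values.mean()

print(raw.var().item(), scaled.var().item())
print(peak_raw.item(), peak_scaled.item())
```

Without scaling, the softmax rows are noticeably peakier, which means each token would mostly copy from a single other token at initialization.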

### Encoder vs Decoder

- Encoder
  - no triangular mask (it gathers data from both directions)
  - e.g. translation, sentiment analysis
- Decoder
  - triangular mask (gathers data from past & present only)
  - predicts the next word
  - "autoregressive": P(w_1..w_T) = P(w_1) * P(w_2|w_1) * ... * P(w_T|w_1..w_{T-1})
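
The difference between the two can be sketched with one function and a flag. `attend` and `changed` are hypothetical helpers written for this illustration: the code perturbs only the last token's value and checks which output positions react.

```python
import torch
import torch.nn.functional as F

def attend(q, k, v, causal):
    """Hypothetical helper: scaled dot-product attention, optionally with a causal mask."""
    T, H = q.shape[-2], q.shape[-1]
    wei = q @ k.transpose(-2, -1) / H**0.5
    if causal:  # decoder: hide future positions
        tril = torch.tril(torch.ones(T, T))
        wei = wei.masked_fill(tril == 0, float("-inf"))
    return F.softmax(wei, dim=-1) @ v

def changed(a, b):
    """True for every position whose output vector differs between a and b."""
    return ~torch.isclose(a, b).all(dim=-1)

torch.manual_seed(0)
q, k, v = (torch.randn(1, 5, 8) for _ in range(3))
v2 = v.clone()
v2[:, -1] += 1.0   # perturb only the last token's value

enc_changed = changed(attend(q, k, v, False), attend(q, k, v2, False))
dec_changed = changed(attend(q, k, v, True), attend(q, k, v2, True))
print(enc_changed, dec_changed)
```

In the encoder setting every position reacts to the change, while in the decoder setting only the last position does, since earlier positions cannot see it.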

Attention(Q,K,V) = softmax(QKᵀ/sqrt(H))·V

- Self-Attention: Q,K,V come from X
- Cross-attention:
  - queries come from x; keys & values come from a different source
  - e.g. English -> French translation: the French side (query) searches the English side (key, value)
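
Cross-attention can be sketched with the same building blocks as the self-attention cell. The sequence lengths `T_tgt` and `T_src` and the tensors `x_tgt` and `x_src` are hypothetical stand-ins for the French and English sides, and this sketch assumes no causal mask over the source, since the whole source sentence is available at once:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H = 4, 32, 16
T_tgt, T_src = 6, 10                   # hypothetical lengths: French (target), English (source)

x_tgt = torch.randn(B, T_tgt, C)       # tokens that ask (queries)
x_src = torch.randn(B, T_src, C)       # tokens that answer (keys and values)

query = nn.Linear(C, H, bias=False)
key   = nn.Linear(C, H, bias=False)
value = nn.Linear(C, H, bias=False)

q = query(x_tgt)                       # (B,T_tgt,H)
k = key(x_src)                         # (B,T_src,H)
v = value(x_src)                       # (B,T_src,H)

wei = q @ k.transpose(-2, -1) / H**0.5 # (B,T_tgt,H) @ (B,H,T_src) = (B,T_tgt,T_src)
wei = F.softmax(wei, dim=-1)           # each target token distributes weight over source tokens
out = wei @ v                          # (B,T_tgt,T_src) @ (B,T_src,H) = (B,T_tgt,H)

print(out.shape)
```

The output has one vector per target token, so the two sequences are free to have different lengths.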