In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Toy input used by the averaging examples: a small random batch.
torch.manual_seed(1337)  # fixed seed so the same tensor is drawn on every run
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [8]:
# Goal: xbow[b, t] = mean of x[b, i] over every i <= t, i.e. a causal running
# average along the time axis ("bow" = bag of words).
# Naive version: one explicit Python loop per batch element and time step.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # (t+1, C): positions 0..t inclusive
        xbow[b, t] = xprev.mean(dim=0)  # average over time, keep channels

Here, there's a fun function called `tril`, which is short for "triangular lower"; wrapping a matrix of ones in it gives the following result!
If, for whatever reason, we wanted an upper triangular matrix instead, that'd be as easy as calling `triu`.

In [20]:
torch.ones(3, 3).tril()  # lower-triangular matrix of ones (zeros above the diagonal)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [28]:
# version 2: the same running average expressed as one matrix multiply
wei = torch.ones(T, T).tril()             # row t has ones in columns 0..t
wei = wei / wei.sum(dim=1, keepdim=True)  # normalise each row so it sums to 1
# wei is (T, T) and is broadcast over the batch: (B, T, T) @ (B, T, C) -> (B, T, C)
xbow2 = torch.matmul(wei, x)
torch.allclose(xbow, xbow2)

True

In [37]:
# version 3: the same averaging weights, this time produced by a softmax
tril = torch.ones(T, T).tril()
wei = torch.zeros(T, T)  # equal scores everywhere -> uniform weights after softmax
wei = wei.masked_fill(tril == 0, float('-inf'))  # -inf above the diagonal -> weight 0
wei = torch.softmax(wei, dim=-1)  # each row now sums to 1 over positions 0..t
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow, xbow3)

True

In [7]:
# version 4: self-attention!
# The averaging weights are no longer fixed: they are computed from the data
# through query/key dot products.
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B, T, C)

# A single attention head. The projections are created in key, query, value
# order; changing that order would change their random initialisation.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
v = value(x)  # (B, T, head_size)

# Raw affinities between every query and every key:
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
wei = torch.matmul(q, k.transpose(-2, -1))

# Causal mask: position t may only attend to positions 0..t.
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = wei.softmax(dim=-1)  # every row becomes a set of weights summing to 1

# Aggregate the values (not the raw x) with the attention weights.
out = torch.matmul(wei, v)

out.shape

torch.Size([4, 8, 16])

Recap on how this works:
- each token at each position (each row of `wei`) knows what content it has as well as where it is located (once positional information has been added to its embedding)
- every node emits both a key and a query; because of the lower triangular mask, a query is only compared (via a dot product) with the keys of its own position and of the positions that came before it. The keys that align best with the query produce the highest dot products, i.e. the highest affinities
- After the softmax, whichever key had the highest affinity receives the largest weight, so we end up aggregating a lot of its information into the current position (we'll get to learn a lot about it)
- the key says "here is where I currently am/have"
- the query says "here is what I am interested in"
- the value is what actually gets aggregated: once a query has matched a key, that token's value says "here is what I have to communicate to you, *relevant info*"

Quick notes:
- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Examples across the batch dimension are of course processed completely independently and never "talk" to each other.
- In an "encoder" attention block, just delete the line that does the masking with tril, allowing all tokens to communicate. Our example here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modelling.
    - Sometimes you want to let a token "see the future" so that you can do things such as working out the sentiment of the text as a whole.
    - One of the cool things about attention is that it doesn't care, attention supports arbitrary connection between all the nodes.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from `x`, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additionally multiplies `wei` by 1/sqrt(head_size) (i.e. divides it by sqrt(head_size)). This makes it so that when the inputs `Q`, `K` are unit variance, `wei` will be unit variance too, and Softmax will stay diffuse and not saturate too much.

In [21]:
# Scaled attention scores: standard-normal q and k, with the dot products
# multiplied by head_size ** -0.5 (i.e. 1 / sqrt(head_size)).
k = torch.randn((B, T, head_size))
q = torch.randn((B, T, head_size))
wei = torch.matmul(q, k.transpose(-2, -1)) * head_size ** -0.5

In [22]:
torch.var(k)  # variance over all elements of k

tensor(1.0652)

In [23]:
torch.var(q)  # variance over all elements of q

tensor(0.9575)

In [24]:
torch.var(wei)  # variance over all elements of the scaled scores

tensor(0.9137)

In [19]:
torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]).softmax(dim=-1)  # small logits -> fairly diffuse weights

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [17]:
# Toy example: multiplying by a row-normalised lower-triangular matrix turns
# each output row into the average of the input rows up to that index.
torch.manual_seed(42)
a = torch.ones(3, 3).tril()
a = a / a.sum(dim=1, keepdim=True)  # every row of a now sums to 1
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)  # row i of c = mean of rows 0..i of b
print(f'a = \n{a}\n-------')
print(f'b = \n{b}\n-------')
print(f'c = \n{c}')

a = 
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
-------
b = 
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
-------
c = 
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
