## A mathematical trick in attention

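Consider a tensor of shape (B, T, C), where B is the batch size (number of independent sequences), T is the sequence length (number of tokens, or time steps, per sequence) and C is the number of channels (features per token).

For example:
- B (batch_size) = 2
- T (sequence_length) = 3
- C (channels/features) = 4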

Example tensor, where each innermost row holds the C = 4 features of one token:
```
x = [[[ 1,  2,  3,  4],    # sequence b=0, token t=0
      [ 5,  6,  7,  8],    # sequence b=0, token t=1
      [ 9, 10, 11, 12]],   # sequence b=0, token t=2

     [[13, 14, 15, 16],    # sequence b=1, token t=0
      [17, 18, 19, 20],    # sequence b=1, token t=1
      [21, 22, 23, 24]]]   # sequence b=1, token t=2
```
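The example tensor can be built in PyTorch with `torch.arange` and `reshape`. This is a small sketch, and the variable name is only for illustration. Indexing `x[b, t]` picks the C-dimensional feature vector of token t in sequence b:

```python
import torch

# Integers 1..24 laid out as (B, T, C) = (2, 3, 4), matching the example above
x = torch.arange(1, 25).reshape(2, 3, 4)

print(x.shape)          # torch.Size([2, 3, 4])
print(x[0, 2].tolist()) # last token of the first sequence: [9, 10, 11, 12]
print(x[1, 0].tolist()) # first token of the second sequence: [13, 14, 15, 16]
```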



```python
import torch

# Create a random (B, T, C) tensor
B, T, C = 2, 3, 4
x = torch.randn(B, T, C)
print(x.shape)
```

```
torch.Size([2, 3, 4])
```


```python
# Goal: mean[b, t] = average of x[b, i] over all i <= t (a causal running average)
mean = torch.zeros((B, T, C))

for b in range(B):
    for t in range(T):
        mean[b, t] = torch.mean(x[b, :t+1, :], dim=0)  # average of tokens 0..t
```
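The same loop can be run on the integer example tensor from the top of this section, which makes the result easy to check by hand. This is a self-contained sketch; the tensor is cast to float because `torch.mean` does not accept integer tensors:

```python
import torch

# The integer example tensor from above, as floats so torch.mean works
x = torch.arange(1, 25, dtype=torch.float32).reshape(2, 3, 4)
B, T, C = x.shape

mean = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        # average of tokens 0..t of sequence b
        mean[b, t] = torch.mean(x[b, :t+1, :], dim=0)

print(mean[0].tolist())
# [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0], [5.0, 6.0, 7.0, 8.0]]
```

For example, token t=1 of the first sequence is (1+5)/2, (2+6)/2, ... = 3, 4, 5, 6.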

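```python
# Instead of the loop, we can use a matrix multiplication.
# Despite its name, ones_vec is a (T, T) lower-triangular matrix; normalizing
# each row to sum to 1 makes row t average the first t+1 tokens.
ones_vec = torch.tril(torch.ones(T, T))
ones_vec = ones_vec / ones_vec.sum(dim=1, keepdim=True)
```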

```python
print(ones_vec)
```

```
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
```
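Written out with 0-based indices, the matrix printed above (call it W) is

$$
W_{t,i} = \begin{cases} \dfrac{1}{t+1}, & i \le t \\[4pt] 0, & i > t \end{cases}
$$

so for every sequence b (a (T, C) matrix $x_b$ with rows $x_{b,i}$) and every token t the product gives

$$
(W x_b)_t = \sum_{i=0}^{T-1} W_{t,i}\, x_{b,i} = \frac{1}{t+1} \sum_{i=0}^{t} x_{b,i},
$$

which is exactly the loop's `mean[b, t]`. For T = 3 the rows are (1, 0, 0), (1/2, 1/2, 0) and (1/3, 1/3, 1/3).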

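```python
mean_matrix = ones_vec @ x  # (T, T) @ (B, T, C) -> (B, T, C)

assert mean_matrix.shape == mean.shape
# a small atol guards against float32 rounding in entries close to 0
assert torch.allclose(mean_matrix, mean, atol=1e-6)
```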
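The product above multiplies a (T, T) matrix with a (B, T, C) tensor. `torch.matmul` (the `@` operator) broadcasts the 2-D matrix over the batch dimension, so the result is the same as applying the matrix to each sequence separately. The following sketch confirms this. The sizes and seed are arbitrary, and the `einsum` line is just a third way of writing the same sum:

```python
import torch

torch.manual_seed(0)
B, T, C = 2, 3, 4
x = torch.randn(B, T, C)

# Normalized lower-triangular averaging matrix, shape (T, T)
w = torch.tril(torch.ones(T, T))
w = w / w.sum(dim=1, keepdim=True)

batched = w @ x                                      # (T, T) @ (B, T, C) -> (B, T, C)
per_seq = torch.stack([w @ x[b] for b in range(B)])  # one (T, T) @ (T, C) per sequence
explicit = torch.einsum('ti,bic->btc', w, x)         # sum over i of w[t, i] * x[b, i, c]

print(batched.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(batched, per_seq), torch.allclose(batched, explicit))  # True True
```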

```python
# The same weights can also be produced with a masked softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))                        # all-zero scores for now
wei = wei.masked_fill(tril == 0, float('-inf'))  # block future positions
wei = torch.softmax(wei, dim=-1)                 # each row sums to 1
mean_softmax = wei @ x
```

```python
assert mean_softmax.shape == mean.shape
assert torch.allclose(mean_softmax, mean, atol=1e-6)
```

Why?
- `wei` is all zeros for now, but in general it holds learned scores (affinities) that say how much each token attends to every other token.
- `masked_fill` enforces causality: a token must not use tokens that come after it, so in each row the entries for future positions have to be ignored. That is why we build a lower-triangular matrix and set every position where it is 0 to `-inf`.
- Softmax then turns each row into non-negative weights that sum to 1. Because exp(-inf) = 0, future positions get exactly zero weight; with all-zero scores the remaining weights are uniform, which reproduces the running average. With learned scores, row t says how much attention token t pays to each earlier token when gathering its context.
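To see these steps with non-zero scores, here is a small sketch. `scores` is a hypothetical stand-in for learned affinities, filled with random numbers. Masked positions get exactly zero weight and every row still sums to 1; with all-zero scores the weight is spread evenly over the allowed positions:

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)  # stand-in for learned affinities (hypothetical)

tril = torch.tril(torch.ones(T, T))
masked = scores.masked_fill(tril == 0, float('-inf'))  # hide future positions
wei = torch.softmax(masked, dim=-1)                    # exp(-inf) = 0

print(wei)               # upper triangle is all zeros, row 0 is [1, 0, 0, 0]
print(wei.sum(dim=-1))   # every row sums to 1

# With all-zero scores, softmax spreads weight uniformly over allowed positions
uniform = torch.softmax(torch.tensor([0.0, 0.0, float('-inf')]), dim=-1)
print(uniform)           # tensor([0.5000, 0.5000, 0.0000])
```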

## Self-attention (simple version)

```python
import torch.nn as nn

torch.manual_seed(123)
B, T, C = 4, 8, 32

x = torch.randn(B, T, C)
```

```python
# Same fixed averaging weights as before, applied to the larger tensor
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)

out = wei @ x
print(out.shape)
```

```
torch.Size([4, 8, 32])
```
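As a sanity check, every row of `out` is still the plain running average of the tokens up to that position; the weights do not depend on the data yet. This sketch rebuilds the cell above so that it runs on its own:

```python
import torch

torch.manual_seed(123)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)
out = wei @ x  # (T, T) @ (B, T, C) -> (B, T, C)

# Row t of each sequence should be the mean of tokens 0..t
expected = torch.stack([
    torch.stack([x[b, :t+1].mean(dim=0) for t in range(T)]) for b in range(B)
])
print(torch.allclose(out, expected, atol=1e-6))  # True
```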

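Instead of all-zero scores, each token now produces a query and a key through learned linear layers, and the affinity between tokens t and i is the dot product of query t with key i. The vectors that get averaged (the values) are also a learned projection of `x`.

```python
# A single head of self-attention: the weights now depend on the data
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
v = value(x)  # (B, T, head_size)

# Affinities between every pair of tokens: (B, T, 16) @ (B, 16, T) -> (B, T, T)
wei = q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # (T, T) mask broadcast over the batch
wei = torch.softmax(wei, dim=-1)

out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
print(out.shape)
```

```
torch.Size([4, 8, 16])
```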
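To reuse the head, the cell above can be wrapped in an `nn.Module`. This is a sketch under assumptions: the class name `Head` and the parameters `n_embd` and `block_size` are hypothetical, and the mask is stored with `register_buffer` so it follows the module across devices without being trained. The usage part also checks the causal mask: replacing the last token changes only the last row of the output.

```python
import torch
import torch.nn as nn


class Head(nn.Module):
    """One causal self-attention head (hypothetical wrapper around the cell above)."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Causal mask as a buffer: moves with .to(device) but has no gradients
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        wei = q @ k.transpose(-2, -1)  # (B, T, T) affinities
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = torch.softmax(wei, dim=-1)
        return wei @ v     # (B, T, head_size)


torch.manual_seed(123)
B, T, C = 4, 8, 32
head = Head(n_embd=C, head_size=16, block_size=T)

x = torch.randn(B, T, C)
out = head(x)
print(out.shape)  # torch.Size([4, 8, 16])

# Causality check: changing the last token must not change earlier outputs
x2 = x.clone()
x2[:, -1] = torch.randn(B, C)
out2 = head(x2)
print(torch.allclose(out[:, :-1], out2[:, :-1]))  # True
print(torch.allclose(out[:, -1], out2[:, -1]))    # False
```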