In [4]:
import torch
from torch.nn import functional as F
import torch.nn as nn

In [6]:
# Single-head causal self-attention on a random batch (scores not yet scaled).
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)  # shape (B, T, C)

# The three bias-free projections are built in key -> query -> value order so
# the seeded random stream is consumed in the same sequence as before.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Project the input once per role; each result has shape (B, T, head_size).
k, q, v = key(x), query(x), value(x)

# Pairwise query/key dot products: (B, T, head_size) @ (B, head_size, T) -> (B, T, T).
wei = q @ k.transpose(-2, -1)

# Lower-triangular mask: entries above the diagonal are set to -inf so the
# softmax assigns them zero weight, then each row is normalised to sum to 1.
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

# Weighted sum of the value vectors: (B, T, T) @ (B, T, head_size).
out = wei @ v

out.shape

torch.Size([4, 8, 16])

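# Now we can look more closely at scaled attention: we have query, key and value
* We multiply query and key, then apply softmax and use the result to aggregate the values
* To make it scaled attention we need one more thing: we multiply the scores by $1/\sqrt{d_k}$ (i.e. divide by $\sqrt{d_k}$), where $d_k$ is the head size
* We do this because, with unit-variance q and k, the unscaled scores have a variance of roughly $d_k$, which would make the softmax too peaky (demonstrated below)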

In [7]:
# Unit-Gaussian stand-ins for keys and queries; k is drawn before q so the
# random stream is consumed in the same order.
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
# Unscaled dot-product scores, shape (B, T, T).
wei = torch.matmul(q, k.transpose(-1, -2))

# The problem appears when the inputs are unit Gaussian
* Here k and q are unit Gaussian; if we compute wei without scaling, its variance is on the order of the head size (16)

In [8]:
# Variance over all elements of k (expected to be close to 1).
torch.var(k)

tensor(1.0449)

In [9]:
# Variance over all elements of q (expected to be close to 1).
torch.var(q)

tensor(1.0700)

In [10]:
# Variance over all elements of the unscaled scores.
torch.var(wei)

tensor(17.4690)

# But when we apply the scaling, the variance of wei stays close to 1, so the unit variance of the inputs is preserved

In [11]:
# Scale the raw scores by 1/sqrt(head_size) so their variance stays near 1.
scale = head_size ** -0.5
wei = (q @ k.transpose(-2, -1)) * scale

In [12]:
# Variance over all elements of the scaled scores.
torch.var(wei)

tensor(1.0918)

# This is important because wei feeds into softmax
# It is especially important at initialization that wei be fairly diffuse
* If wei contains very large positive and negative numbers, the softmax converges towards one-hot vectors
* When softmax is applied to values that are close to zero, its output stays diffuse (close to uniform)

In [13]:
# Softmax over small-magnitude inputs: the weights stay fairly even.
torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]).softmax(dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

# But if we take the exact same values and sharpen them, making them bigger by multiplying them by 8
* The softmax sharpens towards the max (the highest number)
* We don't want these values to be too extreme, especially at initialization, otherwise the softmax will be way too peaky
* A peaky softmax means each node aggregates information from essentially a single other node, which we don't want

# So scaling is used to control the variance at initialization

In [14]:
# Same inputs multiplied by 8: the softmax mass concentrates on the largest entry.
(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]) * 8).softmax(dim=-1)

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

# Now we implemented scaled Dot-Product Attention