In [26]:
import torch 
import torch.nn as nn 
from torch.nn import functional as F

In [72]:
# The mathematical trick in self-attention 

# considering the following toy example 
torch.manual_seed(1337)
B,T,C = 4,8,2 # Batch, time, channels
x = torch.randn(B,T,C) # (Batch, Time, Channel)
x.shape 
# torch.Size([4,8,2])

# If we have 8 tokens, we want the token at 5tyh location to communicate with 
# those in 1,2,3,4 location, but not 6,7,8 because they are in the future.
# The information only flows from the previous context to the current timestamp,
# We do not want to use the future tokens because we want to predict them



torch.Size([4, 8, 2])

In [28]:
# This is the very weak form of gathering the info, but for now ok
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # x Bag of words, term used when you just averaging a bunch of words
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C)
        xbow[b,t] = torch.mean(xprev, 0)

In [29]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [30]:
# The first one are equal, but the folowing are the averages of all the tokens before the last i
xbow[0]
# Nevermind, this is not very efficient

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [31]:
# Lower triangular portion 
torch.tril(torch.ones(3,3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [32]:
# Toy example of doing it using matrix multiplication
torch.manual_seed(42)
# a = torch.ones(3,3)
# Basically here we do a sum over all the rows depending on the ones we have. 
# Very efficient method
a = torch.tril(torch.ones(3,3))
# But we can also do average, so we can weigth each element
a = a / torch.sum(a, 1, keepdim=True) # a in a oneth dimention
# So we can do avergaes in the incremental way
b = torch.randint(0,10, (3,2)).float()
c = a @ b 

print('a= ')
print(a)
print('b= ')
print(b)
print('-'*20)
print('c=')
print(c)

a= 
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b= 
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--------------------
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [33]:
# This is the very weak form of gathering the info, but for now ok
# We want x[b,t] = mean_{i<=t} x[b,i]
torch.manual_seed(666)
B,T,C = 4,8,2 # Batch, time, channels
x = torch.randn(B,T,C)

xbow = torch.zeros((B,T,C)) # x Bag of words, term used when you just averaging a bunch of words
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C)
        xbow[b,t] = torch.mean(xprev, 0)

        
# Version of matrices ------------------------------------------------------
wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
# wei ## Obtaining the weights
xbow2 = wei @ x # (T,T) @ (B, T, C) ---> (B,T,T) @ (B, T, C) ---> (B,T,C)
print(f"Using Mean: {torch.allclose(xbow, xbow2)}")

# Version using Softmax ------------------------------------------------------
# BASIC PREVIEW FOR THE SELF ATTENTION, WEIGHTET AGGREGATION OF YOUR PAST 
# ELEMENTS BY USING MATRIX MULTIPLICATION OF ALL THE PAST TOKENS, SO IT'S TELLING US
# HOW MUCH OF EACH ELEMENT FUZZES IN THE EACH POSITION
tril = torch.tril(torch.ones(T,T))
# This is more interesting because all the weights becomes zero, like infinity
# Says us how much tokens from the past we want to aggregate in our job
# These tokens are data dependent, these tokens will start looking at each other, 
# They will find other tokens more os less interesting
wei = torch.zeros((T,T))
# By setting them to negative infinity we basically saying that we will not aggregate 
# anithing from these tokens
# Basically here we are saying that the future can not communicate with the past

wei = wei.masked_fill(tril == 0, float('-inf')) # all the 0s in the tril, becomes infinity
# Here basically we will aggregate its values on how interesting are each other to every token 
wei = F.softmax(wei, dim=1) # Softmax in all the rows
xbow3 = wei @ x 
torch.allclose(xbow, xbow3)
print(f"Using Softmax: {torch.allclose(xbow, xbow3)}")

Using Mean: True
Using Softmax: True


In [34]:
print(xbow.shape, xbow2.shape)
xbow == xbow2

torch.Size([4, 8, 2]) torch.Size([4, 8, 2])


tensor([[[ True,  True],
         [ True,  True],
         [ True, False],
         [ True,  True],
         [ True,  True],
         [False, False],
         [False, False],
         [ True,  True]],

        [[ True,  True],
         [ True,  True],
         [False, False],
         [ True,  True],
         [False,  True],
         [False,  True],
         [False, False],
         [ True, False]],

        [[ True,  True],
         [ True,  True],
         [False,  True],
         [ True,  True],
         [False,  True],
         [ True, False],
         [ True, False],
         [False,  True]],

        [[ True,  True],
         [ True,  True],
         [ True, False],
         [ True,  True],
         [ True,  True],
         [False, False],
         [ True, False],
         [ True,  True]]])

In [35]:
xbow[0]

tensor([[-0.7747,  0.7926],
        [-0.3905,  0.1775],
        [-0.1051,  0.0556],
        [-0.3032,  0.1460],
        [-0.3193,  0.1246],
        [-0.1850,  0.0793],
        [-0.2222,  0.0631],
        [-0.1917,  0.0484]])

In [40]:
# SELF ATTENTION for the singular individual head 
torch.manual_seed(666)
B,T,C = 4,8,32 # Batch, time, channels
x = torch.randn(B,T,C)

# We do not want this to be all uniform, because different tokens 
# will find other different tokens and we won't to be independent
# I want to gather information from the past but I want to be in a data 
# dependent way and this is what self attention solves

# wei = torch.zeros((T,T))

# The way we get affinities between these tokens now is by using 
# a dot product between the keys and queries, so my query dot product with 
# all the keys of all the other tokens, and this dot product now becomes wei
# So if the key and a Query are self aligned, theu will interact in very hight 
# amount and i will get to learn more about that specific token as 
# oposed to any other token in a sequence

# Let's create a single HEAD perform self-attention
head_size = 16 
# The self Attention mechanism
# This will just perform a matriz multiply with a some weights
# out = inpuyt X W^t + b
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
# Just produce this staff by forwarding x to each module
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
# Now the weights will be data dependent
wei = q @ k.transpose(-2,-1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones(T,T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) # If we delete this, it will 
# communicate with all the tokens, so it will be called ENCODER block, if we leave
# this line, it will be a DECODER block, as we are making a masked trill op

wei = F.softmax(wei, dim=1) # Softmax in all the rows
# x will be like a private info for this token, x is private to this token,
# V is what it's aggregated for the purposes of this single head, 
# values Q and K
v = value(x)
out = wei @ v # V is the vector that we aggregate except of the raw x
# out = wei @ x 

# There is no notion of space, attencion only happens in a set of vectors
# from the graph, no node knows where is he, so we need to encode him a position 
# 
out.shape

# It is called self attention because all the keys, the queries and the values all 
# come from the same source which is x, so the same source x produces keys, 
# queries and values, so this nodes are self attended

# In other cases, in encoder- decoder transformers you can have for example 
# queryes are produced in x, but then the keys and the values could be produced 
# in a all separate source, sometimes from an ENCODER blocks, that encodes
# A CONTEXT THAT WE WANT TO CONDITION them. So ein this case we are just producing
# queries but we rig the information from the side. So cross attention is used 
# when the is a separate source of nodes we like to pull information from 
# into our nodes. And in self-attention is when we have nodes that talk to 
# each other in the same block



torch.Size([4, 8, 16])

In [55]:
wei[0]

tensor([[1.1400, 1.0236, 1.1396, 0.9841, 0.6756, 0.7669, 0.8734, 0.7898],
        [1.2575, 1.3668, 1.6682, 1.3182, 1.1020, 1.3072, 1.1687, 1.3637],
        [0.6605, 0.8261, 0.8798, 0.9264, 0.5227, 0.5701, 0.7271, 0.7566],
        [0.9049, 0.8630, 1.1582, 0.9967, 0.7957, 0.8470, 0.9116, 0.9257],
        [0.9325, 1.1924, 1.3057, 1.2844, 0.8200, 1.0326, 1.1037, 1.1075],
        [0.9765, 1.1116, 1.1894, 1.1841, 0.8487, 0.9106, 0.9587, 0.9484],
        [0.8570, 1.0748, 1.2791, 1.1276, 0.9151, 0.9637, 1.0107, 1.1278],
        [1.3683, 1.3410, 1.5229, 1.3319, 1.0144, 1.2033, 1.2456, 1.2996]])

In [54]:
# We also divide the before expresion by 1/sqrt(head_size), called 
# SCALED ATTENTION
# that's because it makes so when input Q,K are unit variance, wei will be
# unit variance too and Softmax will stay diffuse and not saturate too much.
# for example 
k = torch.rand(B,T,head_size)
q = torch.rand(B,T,head_size)

# So we can see that our variance could be over our hea_size, so we need
# simply to normalize it
wei = q @k.transpose(-2, -1) * head_size**-0.5
print(k.var())
print(q.var())
print(wei.var())
# print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=1))

tensor(0.0822)
tensor(0.0914)
tensor(0.0488)


In [67]:
f1 = torch.tensor([[1,2,3]])
f2 = torch.tensor([[3,4,5]])

print(torch.cat([f1,f2], dim=1))

tensor([[1, 2, 3, 3, 4, 5]])


In [65]:
5//2

2

In [68]:
batch_size = 32 # How many parallel sequences will we process?
block_size = 8 # What is the max context prediction length? 
max_iters = 5000 
n_embd = 32 # Number of embedding directions
eval_iters = 200
eval_interval = 300
learning_rate = 1e-3 # Self attention can't tolerate very hight learning rates 
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [71]:
t1 = torch.tensor([[1,1]])
t2 = torch.tensor([
    [[1,1]],
    [[1,1]],
    [[2,2]]
])
t1 +t2

tensor([[[2, 2]],

        [[2, 2]],

        [[3, 3]]])

In [None]:
0.73