## In this note book we code  fourdifferent (from basic to complex) attention mechanisms

### A Simple self-attention mechanism without trainable weights

In [1]:
# let's consider this input sentence (embeddings of a sentence), we will implement and use a simple attention mechanism other input
import torch
inputs = torch.tensor([
    [0.43, 0.15, 0.89], # Your (x^1)
    [0.55, 0.87, 0.66], # journey (x^2)
    [0.57, 0.85, 0.64], # starts (x^3)
    [0.22, 0.58, 0.33], # with (x^4)
    [0.77, 0.25, 0.10], # one (x^5)
    [0.05, 0.80, 0.55]  # step (x^6)
])

In [18]:
# we wannt to compute context vector for x2, so we will use x2 as query
query = inputs[1] # x^2
attn_scores_2 = torch.empty(inputs.shape[0]) # to store attention scores for each input word respect to x2
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query) # score(x^i, x^2) = x^i . x^2

print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [19]:
# then we ormalize the attentions score to get attentions weights that sum to 1
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum() # this is not the best way to do it, we will see better ways later
print("Attention weights:", attn_weights_2_tmp)
print("Sum of attention weights:", attn_weights_2_tmp.sum()) # should be 1

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum of attention weights: tensor(1.0000)


In [20]:
# In practice, it's better to use softmax function to get attention weights (avoid negative weights and best for gradient operations)
def softmax_naive(x):
    return torch.exp(x)/torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights (softmax naive):", attn_weights_2_naive)
print("Sum of attention weights (softmax naive):", attn_weights_2_naive.sum()) # should be 1

Attention weights (softmax naive): tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum of attention weights (softmax naive): tensor(1.)


In [21]:
#better use pytorch softmax which is suitable to encounder overflow, underflow or others like instability problems
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights (softmax):", attn_weights_2)
print("Sum of attention weights (softmax):", attn_weights_2.sum()) # should be 1

Attention weights (softmax): tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum of attention weights (softmax): tensor(1.)


In [None]:
#finallu we calculate the context vector z2 for x2 as the weighted sum of all input words
query = inputs[1] # x^2
context_vec_2 = torch.zeros(query.shape) # to store context vector for x^2
for i_, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i # z^2 = sum(score(x^i, x^2) * x^i)
print("Context vector for x2:", context_vec_2)

y = attn_weights_2 @ inputs # this is equivalent to the above for loop
print("Context vector for x2 (using @):", y)

#NB: using the matrice multiplication give more accurate result

Context vector for x2: tensor([0.4095, 0.5534, 0.5012])
Context vector for x2 (using @): tensor([0.4419, 0.6515, 0.5683])


In [None]:
#now computting attention weights and context vector for all inputs tokens

attn_scores = torch.zeros((inputs.shape[0], inputs.shape[0])) # to store attention scores for each input word respect to each input word
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j) # score(x^i, x^j) = x^i . x^j
print(attn_scores)

#or 
attn_scores = inputs@ inputs.T
print(attn_scores)


tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [15]:
#we normalize
atten_weights = torch.softmax(attn_scores, dim=-1) # normalize each row (-1 for normalize according last dimension)
print("Attention weights:", atten_weights)

#let's verifie the rows sum to 1
print("Sum of attention weights (rows):", atten_weights.sum(dim=-1)) # should be 1

Attention weights: tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
Sum of attention weights (rows): tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [17]:
#we compute all context vectors
all_context_vecs = atten_weights @ inputs
print("All context vectors:", all_context_vecs)

print("Previous 2nd context vector:", context_vec_2) # to check

All context vectors: tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
Previous 2nd context vector: tensor([0.4095, 0.5534, 0.5012])


### A  self-attention mechanism with trainable weights (scaled dot-product attention)

#### Computing step by step

In [27]:
# we will start by computing context vector for the second word "journey" x^2
x_2 = inputs[1] #
d_in = inputs.shape[1] #The input embedding size, d=3
d_out = 2 #The output embedding size, d_out=2

#generally d_in= d_out but for comprehensiveness we will consider d_in != d_out

#we then initialize the weight matrices Wq, W_k and W_v

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # grad false because we arre not training here
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

query_2 = x_2@W_query # q^2 = x^2 W_q
key_2 = x_2@W_key # k^2 = x^2 W_k
value_2 = x_2@W_value # v^2 = x^2 W_v
print("Query vector for x2:", query_2)

Query vector for x2: tensor([1.0747, 1.3278])


In [29]:
# we need keys and values of all inputs to compute z2
keys = inputs@W_key # keys = X W_k
values = inputs@W_value # values = X W_v
print("Keys.shape:", keys.shape)
print("Values.shape:", values.shape)

Keys.shape: torch.Size([6, 2])
Values.shape: torch.Size([6, 2])


In [None]:
#we compute attentions scores

#first attention score w22
keys_2 = keys[1] # k^2
attn_scores_22 = query_2.dot(keys_2) # score(x^2) = q^2 . k^2
print("Attention score w22:", attn_scores_22)


#then all attention scores for x2
attn_scores_2 = query_2@ keys.T
print("Attention scores for x2:", attn_scores_2)


Attention score w22: tensor(1.6392)
Attention scores for x2: tensor([0.9369, 1.6392, 1.6356, 0.8910, 1.1108, 0.9853])


In [None]:
#now we calculate attention weights for input x2 
#the difference from previous implementation is that we scale the scores by sqrt(d_out)
attn_weights_2 = torch.softmax(attn_scores_2/ d_out**0.5, dim=-1) # scale by sqrt(d_out)
print(attn_weights_2)

#we scaled the scores to avoid very small gradients when d_out is large
# in fact the softmax function can produce very small gradients when the input values are large
# this is because the softmax function is sensitive to the scale of the input values
# by scaling the input values by sqrt(d_out), we can reduce the magnitude of the input values and thus improve the stability of the softmax function
# avoid training stagnation

tensor([0.1348, 0.2216, 0.2210, 0.1305, 0.1525, 0.1395])


In [34]:
#finally, we compute the context vector z2 # weighted sum of all value vectors by attention weights
context_vec_2 = attn_weights_2 @ values
print("Context vector for x2:", context_vec_2)

Context vector for x2: tensor([1.1735, 0.6411])


#### implementing a compact self-attention python class

In [35]:
#the compact class
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in,d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x@self.W_key
        values = x@self.W_value
        queries = x@self.W_query
        attn_scores = queries@keys.T #omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights@values
        return context_vec
    

In [36]:
#testing
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [None]:
# improving the SelfAttention class to use nn.Linear which is more efficient and better for training
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_key = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias = qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        values = self.W_value(x)
        queries = self.W_query(x)
        attn_scores = queries@keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights@values
        return context_vec
    


In [58]:
#testing
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

# we can see that the results are different from the previous bescause they use different initialisation scheme
# let's try to transfer the weights from the sa_v2 to sa_v1
sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)
sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)
print(sa_v1(inputs)) # now the results should be the same

tensor([[-0.6433,  0.1075],
        [-0.6464,  0.1035],
        [-0.6464,  0.1035],
        [-0.6443,  0.1024],
        [-0.6449,  0.1029],
        [-0.6445,  0.1027]], grad_fn=<MmBackward0>)
tensor([[-0.6433,  0.1075],
        [-0.6464,  0.1035],
        [-0.6464,  0.1035],
        [-0.6443,  0.1024],
        [-0.6449,  0.1029],
        [-0.6445,  0.1027]], grad_fn=<MmBackward0>)


### Causal Attention or masked attention (hidding future word)

In [59]:
# the goal is to use only token which appears before or at current token position to compute the context vector
#we want to put to 0 all elements above the diagonales in the weight matrix and then do softmax to normalize
# 
# we will use weights from previous sections
# 
queries = sa_v2.W_query(inputs) 
keys = sa_v2.W_key(inputs)
attn_scores = queries@keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1827, 0.1681, 0.1682, 0.1577, 0.1660, 0.1573],
        [0.1675, 0.1732, 0.1730, 0.1607, 0.1623, 0.1634],
        [0.1675, 0.1731, 0.1730, 0.1607, 0.1623, 0.1634],
        [0.1650, 0.1705, 0.1703, 0.1642, 0.1640, 0.1660],
        [0.1661, 0.1713, 0.1711, 0.1630, 0.1635, 0.1650],
        [0.1657, 0.1707, 0.1706, 0.1636, 0.1638, 0.1655]],
       grad_fn=<SoftmaxBackward0>)


In [60]:
#we create the mask
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length)) # lower triangular matrix
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [61]:
# then we multiply the attention weight by the mask
masked_simple = atten_weights*mask_simple
print(masked_simple)

tensor([[0.2098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1385, 0.2379, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1390, 0.2369, 0.2326, 0.0000, 0.0000, 0.0000],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.0000, 0.0000],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.0000],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [62]:
#then we renormalize the weights so that they sum to 1 on each row
row_sums = masked_simple.sum(dim=-1, keepdim=True) # keepdim to keep the same shape for broadcasting
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3680, 0.6320, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2284, 0.3893, 0.3822, 0.0000, 0.0000, 0.0000],
        [0.2046, 0.2956, 0.2915, 0.2084, 0.0000, 0.0000],
        [0.1753, 0.2250, 0.2269, 0.1570, 0.2158, 0.0000],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [63]:
#another way to do it is to use a large negative value for the elements above the diagonal before applying softmax
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1) # upper triangular matrix
masked = attn_scores.masked_fill(mask.bool(), -torch.inf) # fill the elements above the diagonal with -inf
print(masked)

tensor([[0.2477,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1748, 0.2214,   -inf,   -inf,   -inf,   -inf],
        [0.1739, 0.2210, 0.2194,   -inf,   -inf,   -inf],
        [0.0730, 0.1196, 0.1184, 0.0668,   -inf,   -inf],
        [0.1086, 0.1515, 0.1502, 0.0814, 0.0858,   -inf],
        [0.0895, 0.1316, 0.1305, 0.0717, 0.0734, 0.0875]],
       grad_fn=<MaskedFillBackward0>)


In [64]:
#we apply the softmax to this matrix and we are done
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4918, 0.5082, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3261, 0.3371, 0.3368, 0.0000, 0.0000, 0.0000],
        [0.2462, 0.2544, 0.2542, 0.2451, 0.0000, 0.0000],
        [0.1990, 0.2051, 0.2049, 0.1952, 0.1958, 0.0000],
        [0.1657, 0.1707, 0.1706, 0.1636, 0.1638, 0.1655]],
       grad_fn=<SoftmaxBackward0>)
