**Attention Mechanisms**

Simplified Self Attention

In [1]:
import torch

# example embedding vector for each word
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

Computing Attention scores

In [2]:
# query token is the token we want to find the most similar token to
query = inputs[1] # journey
attention_scores_2 = torch.empty(inputs.shape[0]) # creates an empty tensor of size 6
inputs.shape[0] # number of words in the sentence 6 (rows)

for i, x_i in enumerate(inputs): # i = index, x_i = word embedding
  attention_scores_2[i] = torch.dot(x_i, query) # dot product between query and each word embedding
  print(f"Similarity score between journey and {i+1}th word: {attention_scores_2[i]:.4f}")
print(f"Attention scores: {attention_scores_2}")

Similarity score between journey and 1th word: 0.9544
Similarity score between journey and 2th word: 1.4950
Similarity score between journey and 3th word: 1.4754
Similarity score between journey and 4th word: 0.8434
Similarity score between journey and 5th word: 0.7070
Similarity score between journey and 6th word: 1.0865
Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


Computing Attention Weigts

In [3]:
# Attention scores are normalized to get the attention weights x/sum(x)

attention_weights_2_tmp = attention_scores_2/attention_scores_2.sum() # normalization
print(attention_weights_2_tmp)
print(attention_weights_2_tmp.sum()) # should sum to 1

tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
tensor(1.0000)


In [4]:
# using softmax function e^x/sum(e^x)
# softmax ensures weights are positive and sum to 1

def softmax_naive(x):
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum(dim=0)

attention_weights_2_naive = softmax_naive(attention_scores_2)
print(attention_weights_2_naive)
print(attention_weights_2_naive.sum()) # should sum to 1

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)


In [5]:
# pytorch softmax function

attention_weights_2 = torch.nn.functional.softmax(attention_scores_2, dim=0)
print(attention_weights_2)
print(attention_weights_2.sum()) # should sum to 1

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)


Computing Context Vector

In [6]:
# multiplying attention weights with word embeddings to get the context vector

query = inputs[1] # journey
context_vector_2 = torch.zeros(query.shape) # creates a tensor of zeros of size 3 (no. of input columns)
for i, x_i in enumerate(inputs): # i = index, x_i = word embedding
  context_vector_2 += attention_weights_2[i] * x_i # attention weight * correspinding word embedding
  print(f"Context vector after adding {i+1}th word, multiplying {attention_weights_2[i]} and word {x_i}: {context_vector_2}")

print(context_vector_2)

# this context vector is related to the query token journey

Context vector after adding 1th word, multiplying 0.13854756951332092 and word tensor([0.4300, 0.1500, 0.8900]): tensor([0.0596, 0.0208, 0.1233])
Context vector after adding 2th word, multiplying 0.2378913015127182 and word tensor([0.5500, 0.8700, 0.6600]): tensor([0.1904, 0.2277, 0.2803])
Context vector after adding 3th word, multiplying 0.23327402770519257 and word tensor([0.5700, 0.8500, 0.6400]): tensor([0.3234, 0.4260, 0.4296])
Context vector after adding 4th word, multiplying 0.12399158626794815 and word tensor([0.2200, 0.5800, 0.3300]): tensor([0.3507, 0.4979, 0.4705])
Context vector after adding 5th word, multiplying 0.10818186402320862 and word tensor([0.7700, 0.2500, 0.1000]): tensor([0.4340, 0.5250, 0.4813])
Context vector after adding 6th word, multiplying 0.15811361372470856 and word tensor([0.0500, 0.8000, 0.5500]): tensor([0.4419, 0.6515, 0.5683])
tensor([0.4419, 0.6515, 0.5683])


Computing Attention Scores for all Inputs

In [7]:
# computing attention scores for all words in the sentence (all queries)

attention_scores = torch.empty(6,6)
for i, x_i in enumerate(inputs):
  for j, x_j in enumerate(inputs):
    attention_scores[i,j] = torch.dot(x_i, x_j)
    # print(f"Similarity score between {i+1}th and {j+1}th word: {attention_scores[i,j]:.4f}")
print(f"Attention scores: {attention_scores}")
# each row represents the similarity between the query token and all other tokens

Attention scores: tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [8]:
# or matmul

attention_scores = inputs @ inputs.T # 6*3 @ 3*6 = 6*6
print(f"Attention scores: {attention_scores}")

Attention scores: tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Computing Attention Weights for all inputs

In [9]:
# attention weights for all queries
# dim=-1 means the last dimension columns - normalize across columns
# so that rows sum to 1

attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)
print(f"Attention weights: {attention_weights}")

row2_sum = attention_weights[1].sum()
print(f"Sum of attention weights for the 2nd query (2nd row): {row2_sum}")

Attention weights: tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
Sum of attention weights for the 2nd query (2nd row): 1.0


Computing Context Vectors for all inputs

In [10]:
print(attention_weights.shape)
inputs.shape

torch.Size([6, 6])


torch.Size([6, 3])

In [11]:
context_vectors = attention_weights @ inputs
print(f"Context vectors: {context_vectors}")

Context vectors: tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


**Self Attention with Trainable Weights**

In [12]:
x_2 = inputs[1]
d_in = inputs.shape[1] # embedding dimension
d_out = 2

In [13]:
# initializing weights for all tokens

torch.manual_seed(12)
W_query = torch.nn.Parameter(torch.randn(d_in, d_out, requires_grad=False)) # set to true while training
W_key = torch.nn.Parameter(torch.randn(d_in, d_out, requires_grad=False))
W_value = torch.nn.Parameter(torch.randn(d_in, d_out, requires_grad=False))

In [14]:
# computing query, key, and value vectors for the input query token

query2 = x_2 @ W_query
key2 = x_2 @ W_key
value2 = x_2 @ W_value
print(f"Query vector: {query2}")

Query vector: tensor([ 0.3535, -1.0834], grad_fn=<SqueezeBackward4>)


In [15]:
# all other key and value vectors for all other input token
keys = inputs @ W_key
values = inputs @ W_value
print(f"Keys: {keys}")
print(f"Values: {values}")

Keys: tensor([[-0.3644,  0.2614],
        [-0.3633, -0.0859],
        [-0.3725, -0.1181],
        [-0.1439, -0.0038],
        [-0.4338, -0.6641],
        [-0.0705,  0.3033]], grad_fn=<MmBackward0>)
Values: tensor([[-0.3183, -0.6830],
        [-0.8002, -1.3512],
        [-0.7756, -1.3522],
        [-0.5225, -0.7374],
        [-0.1137, -0.9896],
        [-0.7858, -0.7766]], grad_fn=<MmBackward0>)


Computing attention score using weights

In [16]:
# computing unscaled attention score for query input
# attn score is a dot product between query vector and key vector

keys2 = keys[1] #query token's key vector
attention_scores_22 = query2.dot(keys2)
print(attention_scores_22)


tensor(-0.0354, grad_fn=<DotBackward0>)


In [17]:
# unscaled attention score for all inputs - matrix multiplication

attention_scores_2 = query2 @ keys.T
attention_scores_2


tensor([-0.4121, -0.0354, -0.0038, -0.0467,  0.5661, -0.3535],
       grad_fn=<SqueezeBackward4>)

Computing attention weights

In [18]:
# normalize using softmax
# to scale attention scores - divide it by sqrt of dim_k
# scaling is done to avoid small gradients

dim_k = keys.shape[-1] #2 - embedding dimension of keys
attention_weights_2 = torch.softmax(attention_scores_2 / dim_k**0.5, dim=-1)
attention_weights_2

tensor([0.1254, 0.1637, 0.1674, 0.1624, 0.2504, 0.1307],
       grad_fn=<SoftmaxBackward0>)

In [19]:
# computing context vector for query
context_vector_2 = values.T @ attention_weights_2
print(context_vector_2)

tensor([-0.5168, -1.0022], grad_fn=<MvBackward0>)


In [20]:
# self attention class
import torch.nn as nn

class selfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))
    
    # forward pass, X = inputs
    def forward(self, x):
        # computing query, key, and value vectors
        query = x @ self.W_query
        key = x @ self.W_key
        value = x @ self.W_value
        # computing attention score
        attention_scores = query @ key.T
        # computing attention weights
        dim_k = key.shape[-1]
        attention_weights = torch.softmax(attention_scores / dim_k**0.5, dim=-1)
        # computing context vector
        context_vector = attention_weights @ value
        return context_vector
    


In [21]:
torch.manual_seed(12)
sa_v1 = selfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[-0.5008, -1.0070],
        [-0.5168, -1.0022],
        [-0.5159, -1.0028],
        [-0.5417, -0.9892],
        [-0.5124, -1.0075],
        [-0.5471, -0.9842]], grad_fn=<MmBackward0>)


In [22]:
# more optimized self attention class using nn.linear

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias= qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias= qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias= qkv_bias)

    def forward(self, X):
        query = self.W_query(X)
        key = self.W_query(X)
        value = self.W_value(X)
        attention_scores = query @ key.T
        attention_weights = torch.softmax(attention_scores/key.shape[-1]**0.5, dim=-1)
        context_vector = attention_weights @ value
        return context_vector

In [23]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.5287, -0.0976],
        [-0.5293, -0.1053],
        [-0.5293, -0.1052],
        [-0.5287, -0.1072],
        [-0.5287, -0.1038],
        [-0.5288, -0.1080]], grad_fn=<MmBackward0>)


**Causal Attention**

To mask future tokens

In [25]:
queries = sa_v2.W_query(inputs) #  query vectors
keys = sa_v2.W_key(inputs) # key vectors
attention_scores = queries @ keys.T # attention scores
attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5, dim=-1) # attention weights
attention_weights

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

In [26]:
# creating a mask - values above diagonal will become 0

context_length = attention_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length)) #sets lower triangle to 1 and upper to 0
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [27]:
# applying mask to attention weights - just multiply

attention_weights_masked = attention_weights * mask_simple
print(attention_weights_masked)

tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)


In [28]:
# normalize so that rows sum to 1

row_sum = attention_weights_masked.sum(dim=-1, keepdim=True)
mask_simple_normalized = attention_weights_masked / row_sum
print(mask_simple_normalized)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)


In [32]:
# masking attention scores
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1) # upper triangle set to 1
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
print(mask)
print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)


In [33]:
# softmax 

attention_weights = torch.softmax(masked/keys.shape[-1]**0.5, dim=-1)
print(attention_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


Masking with Dropout