**Attention Mechanisms**

Simplified Self Attention

In [2]:
import torch

# example embedding vector for each word
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

Computing Attention scores

Dot product between query embedding and every other input embedding

In [7]:
# query token is the token we want to find the most similar token to
query = inputs[1] # journey
print(query)
attention_scores_2 = torch.empty(inputs.shape[0]) # creates an empty tensor of size 6
print(attention_scores_2)
print(inputs.shape[0]) # number of words in the sentence 6 (rows)

for i, x_i in enumerate(inputs): # i = index, x_i = word embedding
  print(f"Word embedding of {i+1}th word: {x_i}")
  attention_scores_2[i] = torch.dot(x_i, query) # dot product between query and each word embedding
  print(f"Similarity score between journey and {i+1}th word: {attention_scores_2[i]:.4f}")
  print(f"Attention scores: {attention_scores_2}")

tensor([0.5500, 0.8700, 0.6600])
tensor([0., 0., 0., 0., 0., 0.])
6
Word embedding of 1th word: tensor([0.4300, 0.1500, 0.8900])
Similarity score between journey and 1th word: 0.9544
Attention scores: tensor([0.9544, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
Word embedding of 2th word: tensor([0.5500, 0.8700, 0.6600])
Similarity score between journey and 2th word: 1.4950
Attention scores: tensor([0.9544, 1.4950, 0.0000, 0.0000, 0.0000, 0.0000])
Word embedding of 3th word: tensor([0.5700, 0.8500, 0.6400])
Similarity score between journey and 3th word: 1.4754
Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.0000, 0.0000, 0.0000])
Word embedding of 4th word: tensor([0.2200, 0.5800, 0.3300])
Similarity score between journey and 4th word: 0.8434
Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.0000, 0.0000])
Word embedding of 5th word: tensor([0.7700, 0.2500, 0.1000])
Similarity score between journey and 5th word: 0.7070
Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.8434,

Computing Attention Weigts

In [14]:
# Attention scores are normalized to get the attention weights x/sum(x)

attention_weights_2_tmp = attention_scores_2/attention_scores_2.sum() # normalization
print(attention_weights_2_tmp)
print(attention_weights_2_tmp.sum()) # should sum to 1

tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
tensor(1.0000)


In [15]:
# using softmax function e^x/sum(e^x)
# softmax ensures weights are positive and sum to 1

def softmax_naive(x):
    print("Attention scores: ",x)
    exp_x = torch.exp(x)
    print("Exponential of attention scores: ",exp_x)
    return exp_x / exp_x.sum(dim=0)

attention_weights_2_naive = softmax_naive(attention_scores_2)
print(attention_weights_2_naive)
print(attention_weights_2_naive.sum()) # should sum to 1

print(softmax_naive(torch.tensor([34,14,23])).sum())


Attention scores:  tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Exponential of attention scores:  tensor([2.5971, 4.4593, 4.3728, 2.3243, 2.0279, 2.9639])
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)
Attention scores:  tensor([34, 14, 23])
Exponential of attention scores:  tensor([5.8346e+14, 1.2026e+06, 9.7448e+09])
tensor(1.)


In [16]:
# pytorch softmax function

attention_weights_2 = torch.nn.functional.softmax(attention_scores_2, dim=0)
print(attention_weights_2)
print(attention_weights_2.sum()) # should sum to 1

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)


Computing Context Vector

In [22]:
# multiplying attention weights with word embeddings to get the context vector

query = inputs[1] # journey
context_vector_2 = torch.zeros(query.shape) # creates a tensor of zeros of size 3 (no. of input columns)
print(context_vector_2)
for i, x_i in enumerate(inputs): # i = index, x_i = word embedding
  print(f"Word embedding of {i+1}th word: {x_i}")
  print(f"Attention weight of {i+1}th word: {attention_weights_2[i]}")
  context_vector_2 += attention_weights_2[i] * x_i # attention weight * correspinding word embedding
  print(f"Context vector after adding {i+1}th word, multiplying {attention_weights_2[i]} and word {x_i}: {context_vector_2}")
  print()
print(context_vector_2)

# this context vector is related to the query token journey

tensor([0., 0., 0.])
Word embedding of 1th word: tensor([0.4300, 0.1500, 0.8900])
Attention weight of 1th word: 0.13854756951332092
Context vector after adding 1th word, multiplying 0.13854756951332092 and word tensor([0.4300, 0.1500, 0.8900]): tensor([0.0596, 0.0208, 0.1233])

Word embedding of 2th word: tensor([0.5500, 0.8700, 0.6600])
Attention weight of 2th word: 0.2378913015127182
Context vector after adding 2th word, multiplying 0.2378913015127182 and word tensor([0.5500, 0.8700, 0.6600]): tensor([0.1904, 0.2277, 0.2803])

Word embedding of 3th word: tensor([0.5700, 0.8500, 0.6400])
Attention weight of 3th word: 0.23327402770519257
Context vector after adding 3th word, multiplying 0.23327402770519257 and word tensor([0.5700, 0.8500, 0.6400]): tensor([0.3234, 0.4260, 0.4296])

Word embedding of 4th word: tensor([0.2200, 0.5800, 0.3300])
Attention weight of 4th word: 0.12399158626794815
Context vector after adding 4th word, multiplying 0.12399158626794815 and word tensor([0.2200, 0

Computing Attention Scores for all Inputs

In [23]:
# computing attention scores for all words in the sentence (all queries)

attention_scores = torch.empty(6,6)
for i, x_i in enumerate(inputs):
  for j, x_j in enumerate(inputs):
    attention_scores[i,j] = torch.dot(x_i, x_j)
    # print(f"Similarity score between {i+1}th and {j+1}th word: {attention_scores[i,j]:.4f}")
print(f"Attention scores: {attention_scores}")
# each row represents the similarity between the query token and all other tokens

Attention scores: tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [26]:
# or matmul
print(inputs)

attention_scores = inputs @ inputs.T # 6*3 @ 3*6 = 6*6
print(f"Attention scores: {attention_scores}")

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])
Attention scores: tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Computing Attention Weights for all inputs

In [28]:
# attention weights for all queries
# dim=-1 means the last dimension columns - normalize across columns
# so that rows sum to 1

attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)
print(f"Attention weights: {attention_weights}")

row2_sum = attention_weights[1].sum()
print(f"Sum of attention weights for the 2nd query (2nd row): {row2_sum}")

Attention weights: tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
Sum of attention weights for the 2nd query (2nd row): 1.0


Computing Context Vectors for all inputs

In [29]:
print(attention_weights.shape)
inputs.shape

torch.Size([6, 6])


torch.Size([6, 3])

In [30]:
context_vectors = attention_weights @ inputs
print(f"Context vectors: {context_vectors}")

Context vectors: tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


**Self Attention with Trainable Weights**

In [36]:
x_2 = inputs[1]
d_in = inputs.shape[1] # embedding dimension
d_out = 2
print(d_in, d_out)

3 2


In [38]:
# initializing weights for all tokens

torch.manual_seed(12)
W_query = torch.nn.Parameter(torch.randn(d_in, d_out, requires_grad=False)) # set to true while training
W_key = torch.nn.Parameter(torch.randn(d_in, d_out, requires_grad=False))
W_value = torch.nn.Parameter(torch.randn(d_in, d_out, requires_grad=False))
print(W_query, W_key, W_value)

Parameter containing:
tensor([[-0.2138, -1.3780],
        [-0.0546,  0.4515],
        [ 0.7858, -1.0884]], requires_grad=True) Parameter containing:
tensor([[-0.5599, -0.9336],
        [ 0.0479, -0.0844],
        [-0.1471,  0.7590]], requires_grad=True) Parameter containing:
tensor([[ 0.1466, -1.0041],
        [-0.7882, -0.8074],
        [-0.2957, -0.1462]], requires_grad=True)


In [41]:
# computing query, key, and value vectors for the input query token

query2 = x_2 @ W_query
key2 = x_2 @ W_key
value2 = x_2 @ W_value
print(f"Query vector: {query2}")

Query vector: tensor([ 0.3535, -1.0834], grad_fn=<SqueezeBackward4>)


In [42]:
# all other key and value vectors for all other input token
keys = inputs @ W_key
values = inputs @ W_value
print(f"Keys: {keys}")
print(f"Values: {values}")

Keys: tensor([[-0.3644,  0.2614],
        [-0.3633, -0.0859],
        [-0.3725, -0.1181],
        [-0.1439, -0.0038],
        [-0.4338, -0.6641],
        [-0.0705,  0.3033]], grad_fn=<MmBackward0>)
Values: tensor([[-0.3183, -0.6830],
        [-0.8002, -1.3512],
        [-0.7756, -1.3522],
        [-0.5225, -0.7374],
        [-0.1137, -0.9896],
        [-0.7858, -0.7766]], grad_fn=<MmBackward0>)


Computing attention score using weights

In [43]:
# computing unscaled attention score for query input
# attn score is a dot product between query vector and key vector

keys2 = keys[1] #query token's key vector
attention_scores_22 = query2.dot(keys2)
print(attention_scores_22)


tensor(-0.0354, grad_fn=<DotBackward0>)


In [44]:
# unscaled attention score for all inputs - matrix multiplication

attention_scores_2 = query2 @ keys.T
attention_scores_2


tensor([-0.4121, -0.0354, -0.0038, -0.0467,  0.5661, -0.3535],
       grad_fn=<SqueezeBackward4>)

Computing attention weights

In [45]:
# normalize using softmax
# to scale attention scores - divide it by sqrt of dim_k
# scaling is done to avoid small gradients

dim_k = keys.shape[-1] #2 - embedding dimension of keys
attention_weights_2 = torch.softmax(attention_scores_2 / dim_k**0.5, dim=-1)
attention_weights_2

tensor([0.1254, 0.1637, 0.1674, 0.1624, 0.2504, 0.1307],
       grad_fn=<SoftmaxBackward0>)

In [46]:
# computing context vector for query
context_vector_2 = values.T @ attention_weights_2
print(context_vector_2)

tensor([-0.5168, -1.0022], grad_fn=<MvBackward0>)


In [47]:
# self attention class
import torch.nn as nn

class selfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))
    
    # forward pass, X = inputs
    def forward(self, x):
        # computing query, key, and value vectors
        query = x @ self.W_query
        key = x @ self.W_key
        value = x @ self.W_value
        # computing attention score
        attention_scores = query @ key.T
        # computing attention weights
        dim_k = key.shape[-1]
        attention_weights = torch.softmax(attention_scores / dim_k**0.5, dim=-1)
        # computing context vector
        context_vector = attention_weights @ value
        return context_vector
    


In [48]:
torch.manual_seed(12)
sa_v1 = selfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[-0.5008, -1.0070],
        [-0.5168, -1.0022],
        [-0.5159, -1.0028],
        [-0.5417, -0.9892],
        [-0.5124, -1.0075],
        [-0.5471, -0.9842]], grad_fn=<MmBackward0>)


In [49]:
# more optimized self attention class using nn.linear

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias= qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias= qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias= qkv_bias)

    def forward(self, X):
        query = self.W_query(X)
        key = self.W_query(X)
        value = self.W_value(X)
        attention_scores = query @ key.T
        attention_weights = torch.softmax(attention_scores/key.shape[-1]**0.5, dim=-1)
        context_vector = attention_weights @ value
        return context_vector

In [50]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.5287, -0.0976],
        [-0.5293, -0.1053],
        [-0.5293, -0.1052],
        [-0.5287, -0.1072],
        [-0.5287, -0.1038],
        [-0.5288, -0.1080]], grad_fn=<MmBackward0>)


**Causal Attention**

To mask future tokens

In [51]:
queries = sa_v2.W_query(inputs) #  query vectors
keys = sa_v2.W_key(inputs) # key vectors
attention_scores = queries @ keys.T # attention scores
attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5, dim=-1) # attention weights
attention_weights

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

In [52]:
# creating a mask - values above diagonal will become 0

context_length = attention_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length)) #sets lower triangle to 1 and upper to 0
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [53]:
# applying mask to attention weights - just multiply

attention_weights_masked = attention_weights * mask_simple
print(attention_weights_masked)

tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)


In [54]:
# normalize so that rows sum to 1

row_sum = attention_weights_masked.sum(dim=-1, keepdim=True)
mask_simple_normalized = attention_weights_masked / row_sum
print(mask_simple_normalized)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)


In [56]:
# masking attention scores
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1) # upper triangle set to 1
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
print(mask)
print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)


In [57]:
# softmax 

attention_weights = torch.softmax(masked/keys.shape[-1]**0.5, dim=-1)
print(attention_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


Masking with Dropout

In [59]:
#dropping out attention weights

torch.manual_seed(12)
dropout = nn.Dropout(p=0.5) # 50% dropout rate
ex = torch.randn(6,6)
print(ex)
print(dropout(ex)) # randomly sets 50% of the values to 0

# the rest of the values are scaled by 1/(1-p) = 1/0.5 = 2

tensor([[-0.1320, -0.1254,  0.3443, -0.4519, -0.8888, -0.3526],
        [-1.3373,  0.5223, -1.1118, -0.7171,  1.0426, -1.2510],
        [-0.5107, -0.3843, -0.4899,  0.5306, -0.4929, -0.2625],
        [-0.1610, -0.8372, -1.0828, -0.6006,  0.0555,  0.7082],
        [-0.8102,  0.5724,  0.6928,  0.5124, -0.9402,  0.1808],
        [-0.5538,  1.5044,  1.3942, -0.1758, -0.5595, -0.0454]])
tensor([[-0.2640, -0.0000,  0.6886, -0.9038, -0.0000, -0.0000],
        [-2.6745,  1.0446, -2.2236, -1.4341,  0.0000, -0.0000],
        [-0.0000, -0.7686, -0.9797,  1.0611, -0.0000, -0.5251],
        [-0.0000, -1.6744, -0.0000, -1.2012,  0.1110,  1.4164],
        [-0.0000,  1.1449,  0.0000,  1.0248, -0.0000,  0.3615],
        [-0.0000,  0.0000,  0.0000, -0.0000, -1.1189, -0.0000]])


In [60]:
torch.manual_seed(12)
dropout = nn.Dropout(p=0.5)
print(dropout(attention_weights))

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0335, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6816, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4889, 0.5090, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4120, 0.4116, 0.0000, 0.3906, 0.0000],
        [0.0000, 0.0000, 0.3413, 0.3308, 0.3249, 0.3363]],
       grad_fn=<MulBackward0>)


In [63]:
# creating a batch of inputs by stacking 

inputs_batch = torch.stack([inputs, inputs])
inputs_batch.shape #3D tensor

torch.Size([2, 6, 3])

In [64]:
# creating a causal attention class - adding dropout and masking

class CausalSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, qkv_bias=False, dropout=0.1):
        super().__init__()
        # initializing query, key, and value weights
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # creating buffer for mask
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    
    def forward(self, X):
        batch, num_tokens, d_in = X.shape
        # computing query, key, and value vectors
        query = self.W_query(X)
        key = self.W_key(X)
        value = self.W_value(X)
        attention_scores = query @ key.transpose(1,2) # only transposing the last two dimensions
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -float('inf'))
        attention_weights = torch.softmax(attention_scores/key.shape[-1]**0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)
        context_vector = attention_weights @ value
        return context_vector

In [65]:
torch.manual_seed(12)
context_length = inputs_batch.shape[1] #6
csa = CausalSelfAttention(d_in, d_out, context_length)
context_vecs = csa(inputs_batch)
print(context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


**Multi-Head Attention**

In [66]:
# stacking multiple single head self attention layer

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_in, d_out, context_length, qkv_bias=False, dropout=0.0):
        super().__init__()
        # creating a list of single head self attention layers
        self.heads = nn.ModuleList([CausalSelfAttention(d_in, d_out, context_length, qkv_bias, dropout) for _ in range(num_heads)])
    

    def forward(self, X):
        # concatenating context vectors from all heads
        # combining the different views of the input
        return torch.cat([head(X) for head in self.heads], dim=-1)
        

In [67]:
torch.manual_seed(12)
context_length = inputs_batch.shape[1] #6
d_in, d_out = 3, 2 # embedding dimension, output dimension
num_heads = 2
mha = MultiHeadAttention(num_heads, d_in, d_out, context_length)
context_vecs = mha(inputs_batch)

print(context_vecs)
print(context_vecs.shape)

tensor([[[ 0.1215,  0.2512, -0.1501, -0.4825],
         [ 0.1778,  0.3864,  0.0274, -0.5587],
         [ 0.1942,  0.4181,  0.0845, -0.5817],
         [ 0.1771,  0.4007,  0.1062, -0.5214],
         [ 0.1649,  0.2884,  0.0923, -0.4825],
         [ 0.1652,  0.3542,  0.1147, -0.4734]],

        [[ 0.1215,  0.2512, -0.1501, -0.4825],
         [ 0.1778,  0.3864,  0.0274, -0.5587],
         [ 0.1942,  0.4181,  0.0845, -0.5817],
         [ 0.1771,  0.4007,  0.1062, -0.5214],
         [ 0.1649,  0.2884,  0.0923, -0.4825],
         [ 0.1652,  0.3542,  0.1147, -0.4734]]], grad_fn=<CatBackward0>)
torch.Size([2, 6, 4])


Weight Splits

In [69]:
# implementing one class for both single and multi head self attention

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_in, d_out, context_length, qkv_bias=False, dropout=0.0):
        super().__init__()
        # creating a list of single head self attention layers
        assert (d_out % num_heads) == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # dimension of each head

        # setting weights for query, key, and value for all heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # combining head outputs
        self.out_projection = nn.Linear(d_out, d_out)

        # adding dropout layer
        self.dropout = nn.Dropout(dropout)

        # creating buffer for mask
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))


    def forward(self, X):
        batch, num_tokens, d_in = X.shape #3D vector

        # computing query, key, and value vectors
        query = self.W_query(X)
        key = self.W_key(X)
        value = self.W_value(X)

        # splitting query, key, and value vectors for all heads - changing dimensions
        query = query.view(batch, num_tokens, self.num_heads, self.head_dim)
        key = key.view(batch, num_tokens, self.num_heads, self.head_dim)
        value = value.view(batch, num_tokens, self.num_heads, self.head_dim)

        # transposing to get the right dimensions
        keys = key.transpose(1,2)
        queries = query.transpose(1,2)
        values = value.transpose(1,2)


        # computing attention scores
        attention_scores = queries @ keys.transpose(2,3)
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        # computing attention weights
        attention_weights = torch.softmax(attention_scores/keys.shape[-1]**0.5, dim=-1)

        # applying dropout to attention weights
        attention_weights = self.dropout(attention_weights)
        

        # computing context vectors
        context_vec = (attention_weights @ values).transpose(1,2)
        context_vec = context_vec.contiguous().view(batch, num_tokens, self.d_out)
        

        # combining head outputs
        context_vec = self.out_projection(context_vec)
        
        return context_vec

In [70]:
torch.manual_seed(12)
batch_size, context_length, d_in = inputs_batch.shape
d_out = 2
num_heads = 2
mha = MultiHeadAttention(num_heads, d_in, d_out, context_length)

context_vecs = mha(inputs_batch)
print(context_vecs)
print(context_vecs.shape)

tensor([[[-0.5683,  0.3132],
         [-0.6078,  0.2420],
         [-0.6191,  0.2256],
         [-0.6091,  0.2322],
         [-0.5941,  0.3001],
         [-0.5991,  0.2586]],

        [[-0.5683,  0.3132],
         [-0.6078,  0.2420],
         [-0.6191,  0.2256],
         [-0.6091,  0.2322],
         [-0.5941,  0.3001],
         [-0.5991,  0.2586]]], grad_fn=<ViewBackward0>)
torch.Size([2, 6, 2])
