# Self Attention

Self attention is a mechanism that enhances the information content of an input embedding by including information about the input's context.


In [2]:
import torch 

In [1]:
sentence = "I love cakes, specially chocolate cakes. Lately, I want to try vanilla cake."

dict_words = {s:i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}
dict_words 

{'I': 1,
 'Lately': 2,
 'cake.': 3,
 'cakes': 4,
 'cakes.': 5,
 'chocolate': 6,
 'love': 7,
 'specially': 8,
 'to': 9,
 'try': 10,
 'vanilla': 11,
 'want': 12}

In [3]:
# assign integer index to each word in the sentence
sentence_indices = torch.tensor([dict_words[word] for word in sentence.replace(',', '').split()])
sentence_indices

tensor([ 1,  7,  4,  8,  6,  5,  2,  1, 12,  9, 10, 11,  3])

In [7]:
embed = torch.nn.Embedding(13,16) 
embedded_sentence = embed(sentence_indices)
embedded_sentence.shape 

torch.Size([13, 16])

### Self attention (Scaled Dot Product Attention)

$W_q, W_k, W_v$ are weight matrices that are adjusted during model training.
Key, Query, Value sequences are obtained by matrix multiplication between the weight matrices **W** and input embeddings **x**.

In [8]:
embedded_sentence.shape 

torch.Size([13, 16])

In [10]:
# dimension of word, and the vectors
d_word = embedded_sentence.shape[1]
d_q = 16
d_k = 16
d_v = 20 

# initialize the weight matrices
W_q = torch.nn.Parameter(torch.randn(d_q, d_word))
W_k = torch.nn.Parameter(torch.randn(d_k, d_word))
W_v = torch.nn.Parameter(torch.randn(d_v, d_word))

W_q.shape, W_k.shape, W_v.shape

(torch.Size([16, 16]), torch.Size([16, 16]), torch.Size([20, 16]))

## Getting attention scores and final weighted values for second element of the sequence

In [13]:
# getting attention-vector for second input element
x_2 = embedded_sentence[1]
query_2 = torch.matmul(W_q, x_2)
key_2 = torch.matmul(W_k, x_2)
value_2 = torch.matmul(W_v, x_2)

query_2.shape, key_2.shape, value_2.shape

(torch.Size([16]), torch.Size([16]), torch.Size([20]))

In [16]:
W_k.shape, embedded_sentence.T.shape

(torch.Size([16, 16]), torch.Size([16, 13]))

In [14]:
# getting keys and values for the remaining sequence too
keys = W_k.matmul(embedded_sentence.T).T 
values = W_v.matmul(embedded_sentence.T).T
keys.shape, values.shape

(torch.Size([13, 16]), torch.Size([13, 20]))

In [None]:
# getting attention scores for the second element
attention_score_2 = query_2.matmul(key_2.T)

# the second element in keys refer to the keys corresponding to the second element 
attention_score_2, query_2.matmul(keys[1].T)

(tensor(27.5843, grad_fn=<DotBackward0>),
 tensor(27.5843, grad_fn=<DotBackward0>))

In [20]:
# attention score for second element with respect to all words 
attention_scores_2 = query_2.matmul(keys.T) 
attention_scores_2

tensor([ 60.7147,  27.5843, -27.5929, -17.2011,  -9.4265, -39.2114,  34.1458,
         60.7147, -38.7424,  22.7840, -62.3283, -22.1738, -63.6789],
       grad_fn=<SqueezeBackward4>)

In [None]:
# normalize with softmax, and scale with sqrt(d_k)
# change attention scores to probabilities
attention_weights_2 = torch.nn.functional.softmax(attention_scores_2/torch.sqrt(torch.tensor(d_k))) 
attention_weights_2

  attention_weights_2 = torch.nn.functional.softmax(attention_scores_2/torch.sqrt(torch.tensor(d_k)))


tensor([4.9959e-01, 1.2634e-04, 1.2904e-10, 1.7338e-09, 1.2110e-08, 7.0677e-12,
        6.5152e-04, 4.9959e-01, 7.9469e-12, 3.8048e-05, 2.1847e-14, 5.0016e-10,
        1.5587e-14], grad_fn=<SoftmaxBackward0>)

In [28]:
torch.nn.functional.softmax(attention_scores_2/torch.sqrt(torch.tensor(d_k)), dim=0) 

tensor([4.9959e-01, 1.2634e-04, 1.2904e-10, 1.7338e-09, 1.2110e-08, 7.0677e-12,
        6.5152e-04, 4.9959e-01, 7.9469e-12, 3.8048e-05, 2.1847e-14, 5.0016e-10,
        1.5587e-14], grad_fn=<SoftmaxBackward0>)

In [29]:
attention_weights_2.sum() 

tensor(1.0000, grad_fn=<SumBackward0>)

In [30]:
# weighted values
# input sequence with acquired attention
weighted_values_2 = attention_weights_2.matmul(values)
weighted_values_2.shape 

torch.Size([20])

In [None]:
class SelfAttention(torch.nn.Module):
    def __init__(self, d_word, d_q, d_k, d_v):
        super(SelfAttention, self).__init__()
        self.W_q = torch.nn.Parameter(torch.randn(d_q, d_word))
        self.W_k = torch.nn.Parameter(torch.randn(d_k, d_word))
        self.W_v = torch.nn.Parameter(torch.randn(d_v, d_word))
        
    def forward(self, x):
        query = torch.matmul(self.W_q, x)
        key = torch.matmul(self.W_k, x)
        value = torch.matmul(self.W_v, x)
        
        attention_scores = query.matmul(key.T)
        attention_weights = torch.nn.functional.softmax(attention_scores/torch.sqrt(torch.tensor(d_k)), dim=0)
        weighted_values = attention_weights.matmul(value.T)
        return weighted_values