### Explain Attention

In the previous notebooks we familiarized ourselves with tokens, embeddings, vocabulary, masking, and positional encoding. In this notebook we will get to know "attention".

### Attention is Abstraction
The following is very important: from an application perspective, attention allows higher layers in the architecture to operate on relations, grammar, and semantics rather than on raw words, much like convolution allows higher layers to operate on visual concepts rather than on raw pixels. Language is understood in context. Words by themselves are symbols; they get meaning when they are grouped. That is why we can read a paragraph and say "This paragraph talks about immigration" even if the word itself never appears in it. This is the role of "attention". I would rather call it "abstraction", but I didn't coin the term.

### The Simple Math Behind Attention
In previous notebooks, we saw how words can be expressed as vectors in an n-dimensional space. But how can we express a sentence? A sentence is more than a collection of words: it includes grammar, time, action, and meaning. A sentence is a collection of words related to each other. If words are vectors in a space, then the relations within a sentence can be captured by the similarity matrix between those words. How do we compute the similarity between two vectors? By computing their dot product, which is larger when the two vectors point in similar directions.
Consider the sentence S = "The cat is on the mat". The mental image it evokes is based on the sentence as a whole and on how its parts relate to each other. Each of these words is represented as a vector that holds information about the word and its position (remember, we added positional encoding).
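
To make the dot-product idea concrete, here is a minimal NumPy sketch. The three 3-dimensional vectors are made up by hand purely for illustration; they are not real embeddings, and the names `cat`, `kitten`, and `car` are assumptions introduced for this example.

```python
import numpy as np

# Hypothetical hand-picked vectors, for illustration only (not real embeddings).
cat = np.array([1.0, 2.0, 0.0])
kitten = np.array([1.0, 1.5, 0.5])
car = np.array([-1.0, 0.0, 2.0])

# The dot product is larger when two vectors point in similar directions.
cat_kitten = np.dot(cat, kitten)  # 1*1 + 2*1.5 + 0*0.5 = 4.0
cat_car = np.dot(cat, car)        # 1*(-1) + 2*0 + 0*2 = -1.0
print(cat_kitten, cat_car)

# Stacking the vectors as rows and multiplying by the transpose
# gives every pairwise dot product at once: the similarity matrix.
E = np.stack([cat, kitten, car])
similarity = E @ E.T  # shape (3, 3), symmetric
print(similarity)
```

Entry (i, j) of `similarity` is the dot product between row i and row j, which is why the matrix is symmetric.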

Let V be the raw representation of the sentence S. V is a matrix of size (n, d), where n is the number of tokens (6 here) and d is the dimensionality of the word embedding. V is just a stack of word vectors; it contains no information about how the words interact. To get the attended representation of S, we multiply the similarity matrix, which captures the interaction between words, by V itself.

$$
V_{attended} = (V V^T) V
$$

In [13]:
import numpy as np

# S = ["Being", "Strong", "Is", "All", "What", "Matters"]
# (n) tokens = 6
# (d) embedding dim = 10 
# Each token is represented by a vector of 10 values 

V = np.random.rand(6, 10)
print(f"V shape: {V.shape}")

VVT = np.dot(V, V.T)
print(f"VVT shape: {VVT.shape}")

Vattended = np.dot(VVT, V)
print(f"V attended shape: {Vattended.shape}")

V shape: (6, 10)
VVT shape: (6, 6)
V attended shape: (6, 10)


### Trainable Attention[s]

The attended matrix we get above assumes that the representation, i.e. the word embedding, is perfect. In truth it isn't: we are still training it. Similarly, the attention must be trained. We will eventually end up with multiple attentions, because each attention learns something different about the sentence. Some attentions learn grammatical rules, others learn semantic rules, and there are a lot of rules.

For some rules, the words should be represented differently: some words should get priority, other words should be dimmed, and so on. Hence, for attention, we introduce three weight matrices W1, W2, and W3. We multiply V by each of them to get V1, V2, and V3, and we use these in place of the three occurrences of V in the formula above to get the trainable attention.

By adding weights and using multiple attentions (multi-head attention), we allow our model to learn different sets of rules. By stacking more attention layers, we allow it to learn more and more complex rules.
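
The idea of several attentions, each with its own weights, can be sketched in NumPy as follows. This is only a sketch under stated assumptions: the weights are random rather than trained, V1, V2, and V3 are assumed to take the place of the three occurrences of V in the earlier formula, the sizes `num_heads` and `d_head` are hypothetical, and the heads are combined by concatenation, which the text above does not specify.

```python
import numpy as np

rng = np.random.default_rng(0)

n, d = 6, 10               # tokens and embedding dim, as in the example above
num_heads, d_head = 2, 4   # hypothetical sizes chosen for illustration

V = rng.random((n, d))     # raw representation of the sentence

head_outputs = []
for _ in range(num_heads):
    # Each head has its own W1, W2, W3 (random here; learned in practice).
    W1 = rng.random((d, d_head))
    W2 = rng.random((d, d_head))
    W3 = rng.random((d, d_head))

    V1, V2, V3 = V @ W1, V @ W2, V @ W3   # each of shape (n, d_head)

    # Same structure as (V V^T) V, with the projected matrices substituted in.
    attended = (V1 @ V2.T) @ V3           # (n, n) @ (n, d_head) -> (n, d_head)
    head_outputs.append(attended)

# Assumption: combine the heads by concatenating along the feature axis.
multi_head = np.concatenate(head_outputs, axis=-1)
print(f"per-head shape: {head_outputs[0].shape}")
print(f"multi-head shape: {multi_head.shape}")
```

Because every head starts from different weights, each one produces a different view of the same six tokens.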

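Note that V1, V2, and V3 are conventionally called Q (query), K (key), and V (value).\
Also, in practice we scale the similarity matrix QK^T by the square root of d_k, the dimensionality of the keys, and apply a softmax to each of its rows.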

$$
A = \text{softmax}\left(\dfrac{QK^T}{\sqrt{d_k}}\right)V
$$
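
To see what the scaling and the softmax do, here is a small NumPy sketch. It assumes random, normally distributed Q and K and a hypothetical `d_k` of 64; the `softmax` helper is written out here for illustration and is not taken from any library.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the max along the axis for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d_k = 6, 64                      # d_k = 64 is an illustrative choice
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))

scores = Q @ K.T                    # raw similarity matrix, shape (n, n)
unscaled = softmax(scores)          # softmax without scaling
scaled = softmax(scores / np.sqrt(d_k))

# Softmax turns each row into weights that sum to 1.
print(scaled.sum(axis=-1))

# Largest weight in each row: without scaling, one token tends to dominate.
print(unscaled.max(axis=-1).round(3))
print(scaled.max(axis=-1).round(3))
```

Dividing by the square root of d_k shrinks the scores before the softmax, so the resulting weights are spread over more tokens instead of collapsing onto a single one.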

In [46]:
import torch
from torch import nn

class Attention(nn.Module):
    """Single-head scaled dot-product self-attention."""

    def __init__(self, sequence_size, input_dim, output_dim):
        super().__init__()
        self.sequence_size = sequence_size  # kept for reference; not used in the computation
        self.input_dim = input_dim
        # Trainable projection matrices for queries, keys, and values
        self.wq = nn.Parameter(torch.rand(input_dim, output_dim))
        self.wk = nn.Parameter(torch.rand(input_dim, output_dim))
        self.wv = nn.Parameter(torch.rand(input_dim, output_dim))
        # Normalize over the last axis (the keys) so each query's weights sum to 1
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x has shape (batch, sequence_size, input_dim)
        Q = torch.matmul(x, self.wq)
        K = torch.matmul(x, self.wk)
        V = torch.matmul(x, self.wv)

        # Scale by sqrt(d_k), where d_k is the dimensionality of the projected keys
        d_k = K.shape[-1]
        similarity_matrix = torch.matmul(Q, K.transpose(1, 2)) / d_k ** 0.5
        similarity_matrix = self.softmax(similarity_matrix)

        return torch.matmul(similarity_matrix, V)

In [47]:
import numpy as np

# S = ["Being", "Strong", "Is", "All", "What", "Matters"]
# (n) tokens = 6
# (d) embedding dim = 10 
# Each token is represented by a vector of 10 values 

x = torch.rand(3, 6, 10)
print(f"x shape: {x.shape}")

attention = Attention(6, 10, 32)
y = attention(x)
print(f"y shape: {y.shape}")

x shape: torch.Size([3, 6, 10])
y shape: torch.Size([3, 6, 32])
