# **SELF-ATTENTION MECHANISM**

An attention mechanism is a neural network component, originally introduced for encoder-decoder models, that lets the model focus on specific parts of the input while performing a task. It dynamically assigns a weight to each input element according to its relevance, so the model can attend to the most useful information and capture dependencies within the data. This is particularly valuable for sequential or structured data, such as in natural language processing or computer vision, because it helps the model handle long-range dependencies.

In [1]:
import numpy as np
import math


In [2]:
L, dk, dv = 4, 8, 8

q = np.random.randn(L, dk)
k = np.random.randn(L, dk)
v = np.random.randn(L, dv)

# **Self attention**

```
Q = What I want
K = What I have
V = What I contribute
```

![attention](https://media.geeksforgeeks.org/wp-content/uploads/20240110170625/Scaled-Dot-Product-and-Multi-Head-Attentions.webp)




In [3]:
mmul = np.matmul(q, k.T)

In [4]:
q.var(), v.var(), mmul.var()

(0.5302327041201682, 1.0245439008600614, 2.5938049950396618)

In [5]:
mmul /= math.sqrt(dk)
mmul.var()

0.3242256243799577
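
Dividing by `sqrt(dk)` keeps the variance of the scores close to 1 whatever the value of `dk`, which stops the softmax from saturating. With only a 4 x 8 sample, the estimates printed above are noisy. The sketch below repeats the measurement on a larger sample; the helper name `score_variances`, the sample size `n`, and the seed are assumptions made for this illustration and are not part of the notebook.

```python
import numpy as np

def score_variances(dk, n=5000, seed=0):
    """Return (variance of raw dot products, variance after scaling by sqrt(dk))."""
    rng = np.random.default_rng(seed)   # seed is arbitrary, only for reproducibility
    q = rng.standard_normal((n, dk))
    k = rng.standard_normal((n, dk))
    raw = np.sum(q * k, axis=1)         # one dot product per (q_i, k_i) pair
    scaled = raw / np.sqrt(dk)
    return raw.var(), scaled.var()

for dk in (8, 64, 512):
    raw_var, scaled_var = score_variances(dk)
    print(dk, round(raw_var, 1), round(scaled_var, 2))
```

The raw variance should grow roughly in proportion to `dk`, while the scaled variance should stay near 1.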

### **Attention with Masking**

**DECODER**

Each token should attend only to itself and to earlier tokens, never to future tokens.

As a first step, for every batch element independently and for each position t in the sequence, we compute the average of the vectors at position t and at all earlier positions.

In [6]:
import torch
import torch.nn.functional as F

**Method - 1 For loops**

In [7]:
torch.manual_seed(1337)
B, T, C = 1, 4, 4
x = torch.randn(B, T, C)
x.shape

torch.Size([1, 4, 4])

In [8]:
attention = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t+1, C): current token and all earlier ones
        attention[b, t] = torch.mean(xprev, 0)

In [9]:
(x.shape, attention.shape)

(torch.Size([1, 4, 4]), torch.Size([1, 4, 4]))

In [10]:
attention

tensor([[[ 0.1808, -0.0700, -0.3596, -0.9152],
         [ 0.4033, -0.0222,  0.2974, -0.4254],
         [ 0.3892,  0.3745, -0.2517, -0.4537],
         [ 0.3509,  0.2209, -0.4190,  0.0456]]])

However, the nested Python loops above become slow for long sequences and large batches.

Matrix multiplication gives a shortcut:
- use `tril` to generate a lower triangular matrix of ones
- normalise each row so that sum(row) = 1
- matrix multiply this weight matrix with `x` to get the averaged output (`attention`)

**Method - 2 Matrix multiplication**

In [11]:
a = torch.ones(4, 4)
b = torch.randint(0, 10, (4, 4)).float()
c = a @ b
print(f'a => {a}\n----\nb => {b}\n----\nc => {c}')

a => tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
----
b => tensor([[1., 4., 9., 5.],
        [3., 6., 2., 0.],
        [2., 1., 6., 5.],
        [9., 4., 5., 9.]])
----
c => tensor([[15., 15., 22., 19.],
        [15., 15., 22., 19.],
        [15., 15., 22., 19.],
        [15., 15., 22., 19.]])


In [12]:
# This gives us the sum of the previous context for each of the tokens
torch.manual_seed(35)

a = torch.tril(torch.ones(4, 4)) # tril gives us the lower triangular matrix
b = torch.randint(0, 10, (4, 4)).float()

c = a @ b
print(f'a => {a}\n----\nb => {b}\n----\nc => {c}')

a => tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])
----
b => tensor([[7., 9., 8., 7.],
        [1., 8., 7., 6.],
        [3., 8., 5., 3.],
        [6., 2., 7., 3.]])
----
c => tensor([[ 7.,  9.,  8.,  7.],
        [ 8., 17., 15., 13.],
        [11., 25., 20., 16.],
        [17., 27., 27., 19.]])


In [13]:
# We can even get the average of the previous context for each of the tokens by normalising
torch.manual_seed(35)

a = torch.tril(torch.ones(4, 4))
a = a / torch.sum(a, 1, keepdim = True)
b = torch.randint(0, 10, (4, 4)).float()

c = a @ b
print(f'a => {a}\n----\nb => {b}\n----\nc => {c}')

a => tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500]])
----
b => tensor([[7., 9., 8., 7.],
        [1., 8., 7., 6.],
        [3., 8., 5., 3.],
        [6., 2., 7., 3.]])
----
c => tensor([[7.0000, 9.0000, 8.0000, 7.0000],
        [4.0000, 8.5000, 7.5000, 6.5000],
        [3.6667, 8.3333, 6.6667, 5.3333],
        [4.2500, 6.7500, 6.7500, 4.7500]])


In [14]:
torch.manual_seed(1337)
B, T, C = 1, 4, 4
x = torch.randn(B, T, C)

ws = torch.tril(torch.ones(T, T))
ws = ws / torch.sum(ws, 1, keepdim = True)

attention = ws @ x

(x, attention)

(tensor([[[ 0.1808, -0.0700, -0.3596, -0.9152],
          [ 0.6258,  0.0255,  0.9545,  0.0643],
          [ 0.3612,  1.1679, -1.3499, -0.5102],
          [ 0.2360, -0.2398, -0.9211,  1.5433]]]),
 tensor([[[ 0.1808, -0.0700, -0.3596, -0.9152],
          [ 0.4033, -0.0222,  0.2974, -0.4254],
          [ 0.3892,  0.3745, -0.2517, -0.4537],
          [ 0.3509,  0.2209, -0.4190,  0.0456]]]))
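
The result matches the output of Method 1 to the printed precision. A minimal NumPy sketch of the same comparison is given here, assuming a random input with `B = 2` instead of the seeded torch tensor; the helper names `causal_mean_loop` and `causal_mean_matmul` are made up for this illustration.

```python
import numpy as np

def causal_mean_loop(x):
    """Average each token with all earlier tokens, one position at a time."""
    B, T, C = x.shape
    out = np.zeros_like(x)
    for b in range(B):
        for t in range(T):
            out[b, t] = x[b, :t + 1].mean(axis=0)  # x[b, :t+1] has shape (t+1, C)
    return out

def causal_mean_matmul(x):
    """Same computation with a row-normalised lower triangular weight matrix."""
    T = x.shape[1]
    ws = np.tril(np.ones((T, T)))
    ws = ws / ws.sum(axis=1, keepdims=True)  # each row sums to 1
    return ws @ x                            # (T, T) @ (B, T, C) broadcasts over B

rng = np.random.default_rng(1337)            # arbitrary seed
x = rng.standard_normal((2, 4, 4))
print(np.allclose(causal_mean_loop(x), causal_mean_matmul(x)))
```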



**Method - 3 Using softmax**

In [15]:
torch.manual_seed(1337)
B, T, C = 1, 4, 4
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))
w = tril.masked_fill(tril == 0, float('-inf'))

sm = F.softmax(w, dim = -1)
attention = sm @ x

(x[0], attention[0])

(tensor([[ 0.1808, -0.0700, -0.3596, -0.9152],
         [ 0.6258,  0.0255,  0.9545,  0.0643],
         [ 0.3612,  1.1679, -1.3499, -0.5102],
         [ 0.2360, -0.2398, -0.9211,  1.5433]]),
 tensor([[ 0.1808, -0.0700, -0.3596, -0.9152],
         [ 0.4033, -0.0222,  0.2974, -0.4254],
         [ 0.3892,  0.3745, -0.2517, -0.4537],
         [ 0.3509,  0.2209, -0.4190,  0.0456]]))

In [16]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis = 1)).T
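
This `softmax` is fine for the small scores used here, but `np.exp` overflows once the scores get large. A common variant subtracts each row's maximum before exponentiating, which leaves the result unchanged mathematically. The sketch below is one possible version, not the notebook's own function: the name `stable_softmax` is hypothetical, and it assumes every row contains at least one finite entry (true for the causal mask, where the diagonal is never masked).

```python
import numpy as np

def stable_softmax(x):
    """Row-wise softmax that subtracts each row's maximum before exponentiating."""
    shifted = x - np.max(x, axis=-1, keepdims=True)  # largest entry per row becomes 0
    e = np.exp(shifted)                              # exp(-inf) evaluates to 0
    return e / np.sum(e, axis=-1, keepdims=True)

# Large scores in the first row, a masked (-inf) position in the second row.
scores = np.array([[1000.0, 1001.0],
                   [0.0, -np.inf]])
print(stable_softmax(scores))
```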

In [17]:
mask = np.tril(np.ones((4, 4)))
mask[mask == 0] = -np.inf
mask[mask == 1] = 0
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])
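
Adding this mask to a matrix of all-zero scores and applying softmax reproduces the uniform averaging weights from Method 2, because `exp(-inf) = 0` removes the future positions and the remaining entries in each row are equal. A small self-contained sketch follows; its `softmax` is a local re-declaration that mirrors the notebook's row-wise definition, and the all-zero scores are an assumption chosen for the illustration.

```python
import numpy as np

def softmax(x):
    """Row-wise softmax, equivalent to the notebook's definition for 2-D input."""
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

T = 4
tril = np.tril(np.ones((T, T)))
mask = np.where(tril == 1, 0.0, -np.inf)      # 0 on and below the diagonal, -inf above

weights = softmax(np.zeros((T, T)) + mask)    # equal scores, future positions masked
uniform = tril / tril.sum(axis=1, keepdims=True)

print(np.allclose(weights, uniform))
```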

**Self attention**

In the examples above, the lower triangular weights were uniform, so every earlier token contributed equally. What we want instead is for the weights to depend on context. We obtain these affinities from the dot product between the query of the current token and the keys of the current and earlier tokens, scale them and pass them through a softmax, and then matrix multiply the result with the value matrix to get a more context-aware `new_v`.

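$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}} + M\right)V$, where $M$ is the additive mask (0 for allowed positions, $-\infty$ for future positions; omitted when no mask is used).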

In [18]:
attention = softmax(mmul + mask)
attention


array([[1.        , 0.        , 0.        , 0.        ],
       [0.68380936, 0.31619064, 0.        , 0.        ],
       [0.27508722, 0.06704027, 0.65787251, 0.        ],
       [0.34092999, 0.17502176, 0.18444345, 0.2996048 ]])

In [19]:
new_v = np.matmul(attention, v)
new_v

array([[ 0.86513062, -1.43718951,  0.35290122,  1.08955526,  0.88640486,
         0.6905449 ,  1.04516679,  0.05863759],
       [ 1.12439035, -0.91937306,  0.42884466,  0.91799703,  0.62062161,
         1.03561574,  0.38338553, -1.03196514],
       [ 0.82566275, -0.11224609,  0.88871722, -0.01261991,  0.87309642,
         0.83774284, -0.43442011,  0.39228697],
       [ 0.78805628, -0.01149165,  0.59020333,  0.07813057,  0.71829653,
         0.99530663, -0.10141863, -0.34828239]])

In [20]:
v

array([[ 0.86513062, -1.43718951,  0.35290122,  1.08955526,  0.88640486,
         0.6905449 ,  1.04516679,  0.05863759],
       [ 1.68507811,  0.20048216,  0.59308367,  0.54697681,  0.04582575,
         1.78188282, -1.04781502, -3.39055599],
       [ 0.72158097,  0.40990617,  1.14289334, -0.53051633,  0.95183428,
         0.80308091, -0.99059668,  0.91729068],
       [ 0.21725647,  1.22760415,  0.51830606, -0.97199399,  0.77607043,
         1.00094486, -0.30589572,  0.18677643]])

### **SUMMARY**

In [21]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis = 1)).T

def scaled_attention(q, k, v, mask = None):
  dk = q.shape[-1]
  mmul = np.matmul(q, k.T) / np.sqrt(dk)

  if mask is not None:
    # Decoder
    mmul = mmul + mask

  attention = softmax(mmul)
  new_v = np.matmul(attention, v)

  return new_v

In [22]:
q = np.random.randn(L, dk)
k = np.random.randn(L, dk)
v = np.random.randn(L, dv)

mask = np.tril(np.ones((L, L)))
mask[mask == 0] = -np.inf
mask[mask == 1] = 0

In [23]:
scaled_attention(q, k, v, mask)

array([[ 0.81560115,  0.17641806, -0.02343507,  0.81003228,  0.11657641,
        -0.7274329 ,  0.46035725, -1.25286399],
       [ 0.19743813, -0.2320181 ,  0.18984669,  1.06292392,  0.41770658,
         0.25696126,  0.81283491, -0.76352402],
       [ 0.29775435, -0.27722671,  0.16728917,  0.96637163,  0.36491109,
         0.10536374,  0.64879437, -0.79355619],
       [-0.11198217, -0.03530606,  0.02255145,  0.96962907, -0.39601461,
         1.11035952, -0.02437963, -0.60231577]])
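
One way to sanity-check the mask is to change the last token's key and value and confirm that the outputs for the earlier positions stay the same. The sketch below re-declares `softmax` and `scaled_attention` so that it runs on its own; the seed and the perturbation are arbitrary choices made for this illustration.

```python
import numpy as np

def softmax(x):
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

def scaled_attention(q, k, v, mask=None):
    dk = q.shape[-1]
    scores = q @ k.T / np.sqrt(dk)
    if mask is not None:
        scores = scores + mask          # -inf entries hide future positions
    return softmax(scores) @ v

rng = np.random.default_rng(0)          # arbitrary seed
L, dk, dv = 4, 8, 8
q = rng.standard_normal((L, dk))
k = rng.standard_normal((L, dk))
v = rng.standard_normal((L, dv))
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

out = scaled_attention(q, k, v, mask)

# Replace the last token's key and value with new random vectors.
k2, v2 = k.copy(), v.copy()
k2[-1] = rng.standard_normal(dk)
v2[-1] = rng.standard_normal(dv)
out2 = scaled_attention(q, k2, v2, mask)

print(np.allclose(out[:-1], out2[:-1]))  # earlier positions ignore the last token
print(np.allclose(out[0], v[0]))         # the first token can only attend to itself
```

Without the mask, the first check would generally fail, since every position would then mix in the last token's value.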