# Self Attention in Transformers

## Generate Data

In [1]:
# Imports
import numpy as np
import math

# Fixed seed so that a fresh Restart & Run All regenerates the same q, k, v.
SEED = 42
rng = np.random.default_rng(SEED)

L, d_k, d_v = 4, 8, 8  # sequence length, query/key width, value width
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [2]:
# Show the generated query, key and value matrices.
for label, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(label, matrix)

Q
 [[-0.50598333  0.59287364 -0.64212808 -0.17667384  1.78683356 -0.03810015
  -2.79906259 -0.87436271]
 [-1.1635331   0.39761587  0.7502223  -0.19290089 -0.27755457 -0.7598107
   0.88142972 -0.44349095]
 [-0.28506127 -2.0400416  -1.96026798 -0.63174484  0.49923543 -0.57036319
   1.49095731  0.89803188]
 [-1.62443852  1.453917   -1.49314297  0.95953202 -1.60815407  0.1632695
  -0.98389182 -1.1222311 ]]
K
 [[-1.58265754  0.76202182  2.32828504  0.50024378 -0.38254118  2.16603257
  -0.53065346 -0.19447538]
 [ 0.59827253  0.12909976  0.10836157 -0.15652026 -0.33419576  0.25269797
   0.76630028  0.37939898]
 [-2.36337659 -0.03495578  1.25475903  0.10300347  1.11828814  0.16100326
   0.49260405  0.29632038]
 [-0.05618466 -0.58816675  0.46506258 -0.04566082  1.33489973 -0.43446506
   0.37468388 -0.33398102]]
V
 [[-2.33028699e+00  4.18506262e-01  1.57360937e+00  1.37308738e-01
   8.52094837e-01  6.58167848e-01 -4.12112209e-01 -2.52690046e-03]
 [-2.40477900e+00 -8.15979439e-01 -8.21925214e-01 

## Self Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}} + M\bigg)
$$

$$
\text{new } V = \text{self attention} \cdot V
$$ 

In [3]:
# Raw attention scores: entry (i, j) is the dot product of query row i with key row j.
np.matmul(q, k.T)

array([[ 0.5584545 , -3.35153984,  0.70533024,  1.03421098],
       [ 1.87361595, -0.12535484,  3.52750914,  0.62719536],
       [-8.37572099,  0.62481134, -0.31270199,  1.5060436 ],
       [ 2.39156799, -1.69717588, -0.57566407, -3.71359891]])

The line `np.matmul(q, k.T)` is a **crucial step** in the **self-attention mechanism** of Transformers, specifically in the computation of the **attention scores**. 

### **What is Self-Attention?**
Self-attention helps a model focus on the most relevant parts of the input sequence for each token when processing sequences. It allows the model to compute relationships between tokens and dynamically weigh their importance.

Key components:
- **Query (q):** What the current token is "looking for."
- **Key (k):** Represents the "content" of all tokens in the sequence.
- **Value (v):** Holds the information to be extracted and passed along.

`np.matmul(q, k.T)`:
- Calculates similarity scores between query (`q`) and key (`k`) vectors.
- Encodes relationships between tokens in the input sequence.
- Forms the foundation of the self-attention mechanism. 

In [4]:
# Why we need sqrt(d_k) in the denominator: the raw dot products q @ k.T
# have a much larger variance than q or k themselves.
tuple(arr.var() for arr in (q, k, q @ k.T))

(1.1902732818147692, 0.7955736609453732, 7.7092653269714475)

In [5]:
# Same scores, divided by sqrt(d_k) to bring their variance back down.
scaled = (q @ k.T) / math.sqrt(d_k)
tuple(arr.var() for arr in (q, k, scaled))

(1.1902732818147692, 0.7955736609453732, 0.9636581658714307)

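Notice the reduction in variance: dividing by $\sqrt{d_k}$ brings the variance of the scores back to roughly that of `q` and `k`. Each score is a sum of $d_k$ products of roughly unit-variance entries, so its variance grows to about $d_k$ (here ≈ 7.7 for $d_k = 8$). Scores that large push the softmax toward a one-hot distribution, where gradients become very small.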

In [6]:
# Scaled scores (q @ k.T / sqrt(d_k)) before masking.
scaled

array([[ 0.19744348, -1.18494828,  0.2493719 ,  0.3656488 ],
       [ 0.66242327, -0.04431963,  1.24716282,  0.22174705],
       [-2.96126456,  0.22090417, -0.11055685,  0.53246682],
       [ 0.84554697, -0.60004229, -0.20352798, -1.31295549]])

## Masking

- This ensures that a token cannot get context from tokens that come after it (i.e. words that would be generated in the future).
- Not required in the encoders, but required in the decoders.

In [7]:
# Lower-triangular 0/1 matrix: row i has ones in columns 0..i.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [8]:
# Additive mask: allowed positions -> 0, future positions -> -inf
# (exp(-inf) == 0, so softmax assigns them zero weight).
# Built from a fresh lower-triangular matrix instead of mutating `mask` in
# place, so re-running this cell gives the same result every time.
# np.inf is used because the np.infty alias was removed in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [9]:
# Additive causal mask: 0 where a position may be attended to, -inf above the diagonal.
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [10]:
# Masked scores: entries above the diagonal become -inf.
np.add(scaled, mask)

array([[ 0.19744348,        -inf,        -inf,        -inf],
       [ 0.66242327, -0.04431963,        -inf,        -inf],
       [-2.96126456,  0.22090417, -0.11055685,        -inf],
       [ 0.84554697, -0.60004229, -0.20352798, -1.31295549]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [11]:
def softmax(x):
  """Numerically stable softmax over the last axis of `x`.

  The row maximum is subtracted before exponentiating. This does not change
  the result mathematically, but it keeps np.exp from overflowing (inf / inf
  = nan) when scores are large. Entries equal to -inf (masked positions) get
  probability 0. A row that is entirely -inf still yields nan.

  Args:
    x: array of scores; softmax is taken along axis -1. Works for 1-D, 2-D
       and batched (N-D) arrays.

  Returns:
    Array with the same shape as `x` whose values along the last axis sum to 1.
  """
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

In [12]:
attention = softmax(np.add(scaled, mask))  # row-wise attention weights

In [13]:
# Attention weights: each row sums to 1, and masked (future) positions get weight 0.
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.66968106, 0.33031894, 0.        , 0.        ],
       [0.02358547, 0.56838537, 0.40802916, 0.        ],
       [0.58776275, 0.13848114, 0.20587072, 0.06788539]])

In [14]:
# Each output row is the attention-weighted average of the rows of v.
new_v = attention @ v
new_v

array([[-2.33028699,  0.41850626,  1.57360937,  0.13730874,  0.85209484,
         0.65816785, -0.41211221, -0.0025269 ],
       [-2.35489311,  0.01073226,  0.78231893,  0.13865475,  0.42967217,
        -0.09690355, -0.09211864, -0.31554853],
       [-1.18941663,  0.02206074, -0.73501088, -0.16634756, -0.02023404,
        -0.85551669,  0.03142011, -0.69390214],
       [-1.52017524,  0.38293583,  0.67771783,  0.19193084,  0.50039236,
         0.14997879, -0.33905113, -0.30067545]])

In [15]:
# Original values for comparison: row 0 of new_v equals row 0 of v because
# the first position can only attend to itself.
v

array([[-2.33028699e+00,  4.18506262e-01,  1.57360937e+00,
         1.37308738e-01,  8.52094837e-01,  6.58167848e-01,
        -4.12112209e-01, -2.52690046e-03],
       [-2.40477900e+00, -8.15979439e-01, -8.21925214e-01,
         1.41383638e-01, -4.26737879e-01, -1.62771803e+00,
         5.56628997e-01, -9.50161423e-01],
       [ 5.69531565e-01,  1.16653634e+00, -7.47385138e-01,
        -6.12570059e-01,  4.95602990e-01,  1.32665094e-01,
        -6.74559168e-01, -3.76896324e-01],
       [ 9.61138903e-01,  1.44294896e-01,  3.01901776e-01,
         3.20771720e+00, -6.38903868e-01, -5.71129671e-01,
        -5.16128064e-01, -1.32604188e+00]])

## Function

In [16]:
def softmax(x):
  """Numerically stable softmax over the last axis of `x`.

  The row maximum is subtracted before exponentiating. This keeps np.exp from
  overflowing for large scores without changing the result. Entries equal to
  -inf (masked positions) get probability 0. A row that is entirely -inf
  still yields nan.

  Args:
    x: array of scores; softmax is taken along axis -1 (1-D, 2-D or N-D).

  Returns:
    Array with the same shape as `x` whose values along the last axis sum to 1.
  """
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Scaled dot-product attention: softmax(q k^T / sqrt(d_k) + mask) v.

  Args:
    q: queries, shape (..., L_q, d_k).
    k: keys, shape (..., L_k, d_k).
    v: values, shape (..., L_k, d_v).
    mask: optional additive mask broadcastable to (..., L_q, L_k); 0 for
      positions that may be attended to, -inf for blocked ones.

  Returns:
    (out, attention): `out` has shape (..., L_q, d_v); `attention` holds the
    softmax weights, shape (..., L_q, L_k), each row summing to 1.
  """
  d_k = q.shape[-1]
  # swapaxes (rather than .T) transposes only the last two axes, so inputs with
  # leading batch/head dimensions are handled correctly; identical for 2-D.
  scaled = np.matmul(q, np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [17]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
# Print the inputs, the attended values and the attention weights.
for label, matrix in (("Q\n", q), ("K\n", k), ("V\n", v),
                      ("New V\n", values), ("Attention\n", attention)):
  print(label, matrix)

Q
 [[-0.50598333  0.59287364 -0.64212808 -0.17667384  1.78683356 -0.03810015
  -2.79906259 -0.87436271]
 [-1.1635331   0.39761587  0.7502223  -0.19290089 -0.27755457 -0.7598107
   0.88142972 -0.44349095]
 [-0.28506127 -2.0400416  -1.96026798 -0.63174484  0.49923543 -0.57036319
   1.49095731  0.89803188]
 [-1.62443852  1.453917   -1.49314297  0.95953202 -1.60815407  0.1632695
  -0.98389182 -1.1222311 ]]
K
 [[-1.58265754  0.76202182  2.32828504  0.50024378 -0.38254118  2.16603257
  -0.53065346 -0.19447538]
 [ 0.59827253  0.12909976  0.10836157 -0.15652026 -0.33419576  0.25269797
   0.76630028  0.37939898]
 [-2.36337659 -0.03495578  1.25475903  0.10300347  1.11828814  0.16100326
   0.49260405  0.29632038]
 [-0.05618466 -0.58816675  0.46506258 -0.04566082  1.33489973 -0.43446506
   0.37468388 -0.33398102]]
V
 [[-2.33028699e+00  4.18506262e-01  1.57360937e+00  1.37308738e-01
   8.52094837e-01  6.58167848e-01 -4.12112209e-01 -2.52690046e-03]
 [-2.40477900e+00 -8.15979439e-01 -8.21925214e-01 