# Self Attention for Transformers

In [5]:
import numpy as np
import math

L,d_k,d_v=4,8,8

Q,K,V= np.random.randn(L,d_k),np.random.randn(L,d_k),np.random.randn(L,d_v)

In [8]:
print("Q:",Q)
print("-"*100)
print("K:",K)
print("-"*100)
print("V:",V)

Q: [[-0.18893947 -1.35479006 -0.55866686  2.35425287 -0.36205667 -1.2607808
   0.506202    1.37869849]
 [-0.52296881 -0.09230961  2.36580814 -1.59494757  0.87851451 -0.79857993
  -0.64321244  0.20643123]
 [-1.2941705  -1.81415514 -0.21097989 -0.18675128 -0.54250129 -0.16289557
  -0.62879241 -0.54581436]
 [ 0.29802882 -1.50651087 -0.63559757  1.48080693  2.57330375  0.59962074
   1.02848052 -0.04844782]]
----------------------------------------------------------------------------------------------------
K: [[ 0.6025335  -0.92451763  1.05174224  0.97579833 -0.31447442 -1.02482474
   1.61228602 -0.19923093]
 [-0.7293687   0.25881528  1.25144275  0.60814213  1.38506764  0.25502577
   0.57491904 -1.67082083]
 [-0.27967393 -0.98981757  0.5300907  -1.97008895  0.60805254 -0.69601799
   1.37241535 -0.56077205]
 [-1.04147619 -0.52877561 -1.09745195 -0.26005618 -0.48479355  0.45135969
   0.22950975  0.5364034 ]]
------------------------------------------------------------------------------------

## Self Attention

$$
\text{self attention = softmax}(\frac{Q.K^T}{\sqrt{d_k}}+M)\\
\text{new V= self attention } . V
$$


## Query (Q) : What am i looking for? (Pasta) 
[sequence_length x $d_q$] 

## Key   (K) : What i can offer (Food)
[sequence_length x $d_k$] 

## Value (V) : What i actually offer (Rice)
[sequence_length x $d_v$] 

In [9]:
np.matmul(Q,K.T)

array([[ 4.79578735, -2.31579065, -2.96143478,  1.37620188],
       [ 0.16607304,  2.64670345,  4.7254054 , -2.41135221],
       [-0.07419865, -0.14569446,  1.64032857,  2.33962084],
       [ 2.59296501,  3.88720534,  0.73960197,  0.03184337]])

In [10]:
# Why do we need to divide by sqrt(d_k)?

Q.var(),K.var(),np.matmul(Q,K.T).var()

(1.2780324528413598, 0.8245605456576397, 5.415188798952857)

In [14]:
# But if we scale the dot product by 1/sqrt(d_k), we get a variance closer to main values.
scaled= np.matmul(Q,K.T)/math.sqrt(d_k)
scaled.var()

0.676898599869107

## Masking

In [16]:
mask=np.tril(np.ones((L,L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [17]:
mask[mask==0]=np.NINF
mask[mask==1]=0
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [18]:
scaled+mask

array([[ 1.69556688,        -inf,        -inf,        -inf],
       [ 0.05871569,  0.93575098,        -inf,        -inf],
       [-0.02623318, -0.05151077,  0.57994373,        -inf],
       [ 0.91675157,  1.37433463,  0.26148879,  0.01125833]])

In [27]:
## Softmax

def softmax(x):
    return np.exp(x)/np.sum(np.exp(x),axis=1)


In [29]:
attention=softmax(np.array(scaled+mask))
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.19459181, 0.70620748, 0.        , 0.        ],
       [0.17874411, 0.26312977, 0.48140574, 0.        ],
       [0.4589494 , 1.09498148, 0.35011284, 0.11539788]])

In [30]:
new_v=np.matmul(attention,V)
new_v

array([[-0.61910506, -0.26590315,  0.26611101,  0.16678663,  0.28683136,
         0.37957868,  1.24123852, -0.38156662],
       [ 0.14985784, -0.82263091,  0.56954994, -0.39643973,  0.25291282,
        -1.53348781, -0.56940752,  0.90191236],
       [-0.80034977,  0.01275365,  0.37515154, -0.34526117,  0.31923347,
        -0.33158054, -0.63786402,  0.9902757 ],
       [-0.51332275, -0.94273973,  0.99188874, -0.78760627,  0.40728835,
        -2.02041027, -1.07214861,  1.69869887]])

## Function

In [31]:
def softmax(x):
    return np.exp(x)/np.sum(np.exp(x),axis=1)

def self_attention(Q,K,V, mask=False):
    scaled= np.matmul(Q,K.T)/np.sqrt(K.shape[-1])
    if mask:
        mask=np.tril(np.ones(scaled.shape))
        mask[mask==0]=np.NINF
        mask[mask==1]=0
        scaled+=mask
    attention=softmax(scaled)
    newV=np.matmul(attention,V)
    return newV,attention


In [36]:
values, attention=self_attention(Q,K,V, mask=True)
print("Q\n", Q)
print("-"*100)
print("K\n", K)
print("-"*100)
print("V\n", K)

Q
 [[-0.18893947 -1.35479006 -0.55866686  2.35425287 -0.36205667 -1.2607808
   0.506202    1.37869849]
 [-0.52296881 -0.09230961  2.36580814 -1.59494757  0.87851451 -0.79857993
  -0.64321244  0.20643123]
 [-1.2941705  -1.81415514 -0.21097989 -0.18675128 -0.54250129 -0.16289557
  -0.62879241 -0.54581436]
 [ 0.29802882 -1.50651087 -0.63559757  1.48080693  2.57330375  0.59962074
   1.02848052 -0.04844782]]
----------------------------------------------------------------------------------------------------
K
 [[ 0.6025335  -0.92451763  1.05174224  0.97579833 -0.31447442 -1.02482474
   1.61228602 -0.19923093]
 [-0.7293687   0.25881528  1.25144275  0.60814213  1.38506764  0.25502577
   0.57491904 -1.67082083]
 [-0.27967393 -0.98981757  0.5300907  -1.97008895  0.60805254 -0.69601799
   1.37241535 -0.56077205]
 [-1.04147619 -0.52877561 -1.09745195 -0.26005618 -0.48479355  0.45135969
   0.22950975  0.5364034 ]]
------------------------------------------------------------------------------------

In [37]:
print("New V\n", values)
print("Attention\n", attention)

New V
 [[-0.61910506 -0.26590315  0.26611101  0.16678663  0.28683136  0.37957868
   1.24123852 -0.38156662]
 [ 0.14985784 -0.82263091  0.56954994 -0.39643973  0.25291282 -1.53348781
  -0.56940752  0.90191236]
 [-0.80034977  0.01275365  0.37515154 -0.34526117  0.31923347 -0.33158054
  -0.63786402  0.9902757 ]
 [-0.51332275 -0.94273973  0.99188874 -0.78760627  0.40728835 -2.02041027
  -1.07214861  1.69869887]]
Attention
 [[1.         0.         0.         0.        ]
 [0.19459181 0.70620748 0.         0.        ]
 [0.17874411 0.26312977 0.48140574 0.        ]
 [0.4589494  1.09498148 0.35011284 0.11539788]]
