In [2]:
import numpy as np
import math

In [3]:
# Toy problem size: L rows per matrix, d_k columns for q/k, d_v columns for v.
L, d_k, d_v = 4, 8, 8
# Standard-normal samples; no seed is set in this cell, so values change between runs.
q = np.random.standard_normal(size=(L, d_k))
k = np.random.standard_normal(size=(L, d_k))
v = np.random.standard_normal(size=(L, d_v))

In [4]:
# Display q, the (L, d_k) array drawn above.
q

array([[-1.43314657, -0.68958429, -0.86108124,  0.09508489,  1.4236755 ,
        -0.18503826, -2.06543922, -0.72425694],
       [ 0.01381426, -0.05797834,  0.17358011,  0.18137392,  1.40001165,
        -1.11316294,  1.30004025, -1.4833429 ],
       [ 1.2015744 ,  0.88562765, -1.26675785,  1.90983526,  1.51828946,
        -0.6215344 , -0.84381093,  0.27898317],
       [-0.50455575,  2.19433688, -1.04592851, -0.0990632 , -2.91957275,
        -1.24198651, -0.4260315 ,  0.54103664]])

In [5]:
# Display k, the (L, d_k) array drawn above.
k

array([[-0.89573882,  0.59575872,  0.62932483,  1.3706639 ,  1.99202557,
        -1.81906467,  0.59556153, -1.43379392],
       [ 0.70779064, -0.98869881,  0.46194471,  0.3435982 ,  1.04767303,
        -0.50423831, -0.84448408, -0.33698132],
       [-1.21415259, -0.18845878, -0.56447991, -1.59124802,  0.28647738,
         2.07454014, -0.57179019,  1.48130302],
       [ 0.67573788, -0.20859519, -0.94975811,  0.71722781, -0.29003071,
         1.12675774, -0.71884537, -0.08917012]])

In [6]:
# Display v, the (L, d_v) array drawn above.
v

array([[-1.05800722, -0.96276342, -0.95349055,  0.0885535 ,  1.96479734,
        -0.71011295,  0.28629564,  1.98846157],
       [ 0.43508605,  0.50894494,  0.26250843,  0.83530551, -0.35779131,
        -0.87981717, -0.59658486,  1.56959687],
       [-0.41695546, -0.52716846, -0.07328311,  0.58373222, -0.25523755,
         2.34150665, -0.25622958,  1.02477756],
       [-0.55918959, -0.07295214,  0.97466163,  0.21109998, -0.05923678,
        -0.62107704, -1.1790067 , -0.12794776]])

# Self Attention

In [8]:
# Compare the variance of q and k with the variance of their raw dot-product scores.
np.var(q), np.var(k), np.var(q @ k.T)

(1.3856409547521307, 1.0045444160234271, 13.552189614150734)

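As you can see, the variance of the matrix product `q @ k.T` is much higher than the variance of either matrix on its own.
That is why we scale the product by dividing it by $\sqrt{d_k}$, which brings the variance of the scores back to the same order as that of `q` and `k`.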

In [9]:
# Divide the raw scores by sqrt(d_k), then compare variances again.
scaled = np.matmul(q, k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(1.3856409547521307, 1.0045444160234271, 1.6940237017688413)

## Masking

In [10]:
# Lower-triangular L x L matrix of ones: row i has ones in columns 0..i and zeros after.
mask = np.tri(L)
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [11]:
# Additive mask: 0 on and below the diagonal, -inf above it.
# Built directly from L rather than by editing `mask` in place, so re-running
# this cell always produces the same array. np.inf is used because the
# np.infty alias no longer exists in NumPy 2.0.
mask = np.where(np.tril(np.ones((L, L))) == 1, 0.0, -np.inf)

In [12]:
# Display the additive mask (0 where allowed, -inf where blocked).
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [13]:
# Adding the mask leaves scores at the 0 entries unchanged and sends the others to -inf.
scaled+mask

array([[ 1.21702355,        -inf,        -inf,        -inf],
       [ 2.83753546,  0.57970533,        -inf,        -inf],
       [ 1.59961718,  0.9081136 , -1.38185007,        -inf],
       [-1.28015972, -1.87344409, -0.50232767, -0.06045517]])

## Softmax

In [31]:
# Sum of each row of the scaled scores, kept as a column so it broadcasts against `scaled`.
scaled.sum(axis=-1, keepdims=True)

array([[ 3.40966096],
       [ 0.68879402],
       [ 2.05966313],
       [-3.71638665]])

In [28]:
def softmax(x):
    """Apply the softmax function along the last axis of ``x``.

    Each entry is exponentiated and divided by the sum of the exponentials
    taken over axis -1, so the result has the same shape as ``x``.
    An entry of -inf becomes 0 as long as its slice also holds a finite entry.

    Note: the maximum is not subtracted before exponentiating, so very
    large inputs can overflow inside ``np.exp``.
    """
    exps = np.exp(x)
    totals = np.sum(exps, axis=-1, keepdims=True)
    return exps / totals

In [35]:
# Unmasked attention weights: softmax over each row of the scaled scores.
attention_encoder = softmax(scaled)
attention_encoder

array([[0.34304923, 0.28075445, 0.23207798, 0.14411834],
       [0.87860741, 0.09188227, 0.00806653, 0.02144379],
       [0.48417117, 0.2424838 , 0.02455637, 0.24878865],
       [0.14053935, 0.07764935, 0.30591876, 0.47589254]])

In [36]:
# Masked attention weights: the -inf entries of the mask become exact zeros after softmax.
attention_decoder = softmax(scaled+mask)
attention_decoder

array([[1.        , 0.        , 0.        , 0.        ],
       [0.90532381, 0.09467619, 0.        , 0.        ],
       [0.64452058, 0.32279039, 0.03268903, 0.        ],
       [0.14053935, 0.07764935, 0.30591876, 0.47589254]])

In [37]:
# Each output row is the weighted combination of the rows of v given by the masked weights.
self_attention = attention_decoder @ v
self_attention

array([[-1.05800722, -0.96276342, -0.95349055,  0.0885535 ,  1.96479734,
        -0.71011295,  0.28629564,  1.98846157],
       [-0.91664684, -0.82342767, -0.8383644 ,  0.15925314,  1.74490349,
        -0.7261799 ,  0.20270788,  1.94880505],
       [-0.5550957 , -0.47347093, -0.53220464,  0.34578479,  1.14251726,
        -0.66513736, -0.01642433,  1.82175417],
       [-0.50857615, -0.291775  ,  0.32779619,  0.35634172,  0.14207678,
         0.25262883, -0.64555455,  0.65394457]])

## Let's pipeline it into a function

In [46]:
def softmax(x):
    """Apply the softmax function along the last axis of ``x``.

    Each entry is exponentiated and divided by the sum of the exponentials
    over axis -1. The maximum is not subtracted first, so very large inputs
    can overflow inside ``np.exp``.
    """
    exps = np.exp(x)
    return exps / np.sum(exps, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=False):
    """Compute ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Parameters
    ----------
    q, k, v : np.ndarray
        Arrays whose last two axes are (rows, features). ``q`` and ``k`` must
        share the size of their last axis, and ``v`` must have as many rows
        as ``k``. Leading batch axes are allowed.
    mask : bool, default False
        When true, a lower-triangular mask is applied before the softmax so
        that row ``i`` only receives weight from columns ``j <= i``.

    Returns
    -------
    np.ndarray
        The attention output, with the row count of ``q`` and the column
        count of ``v``.
    """
    d_k = k.shape[-1]
    # Swapping the last two axes equals ``k.T`` for 2-D input and also
    # handles leading batch axes.
    scaled = (q @ np.swapaxes(k, -1, -2)) / math.sqrt(d_k)
    if mask:
        # True on and below the diagonal; blocked positions are set to -inf,
        # which softmax turns into a weight of exactly 0.
        allowed = np.tri(*scaled.shape[-2:], dtype=bool)
        scaled = np.where(allowed, scaled, -np.inf)
    return softmax(scaled) @ v

In [47]:
# Unmasked attention computed through the function (mask defaults to False).
self_attention_func_encoder = scaled_dot_product_attention(q,k,v)
self_attention_func_encoder

array([[-0.41815188, -0.32024463, -0.12993457,  0.43078872,  0.50579859,
        -0.03671273, -0.29866209,  1.3422002 ],
       [-0.90495082, -0.80494485, -0.7933147 ,  0.163789  ,  1.69008169,
        -0.69918051,  0.16937664,  1.8968179 ],
       [-0.55611421, -0.3738264 , -0.15731342,  0.31227673,  0.84353449,
        -0.65417565, -0.30566163,  1.33669046],
       [-0.50857615, -0.291775  ,  0.32779619,  0.35634172,  0.14207678,
         0.25262883, -0.64555455,  0.65394457]])

In [48]:
# Elementwise check against the step-by-step result, using a floating-point
# tolerance instead of exact equality so rounding differences cannot cause spurious mismatches.
np.isclose(self_attention_func_encoder, attention_encoder @ v)

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])

In [49]:
# Masked attention computed through the function (mask=True applies the lower-triangular mask).
self_attention_func_decoder = scaled_dot_product_attention(q,k,v,mask=True)
self_attention_func_decoder

array([[-1.05800722, -0.96276342, -0.95349055,  0.0885535 ,  1.96479734,
        -0.71011295,  0.28629564,  1.98846157],
       [-0.91664684, -0.82342767, -0.8383644 ,  0.15925314,  1.74490349,
        -0.7261799 ,  0.20270788,  1.94880505],
       [-0.5550957 , -0.47347093, -0.53220464,  0.34578479,  1.14251726,
        -0.66513736, -0.01642433,  1.82175417],
       [-0.50857615, -0.291775  ,  0.32779619,  0.35634172,  0.14207678,
         0.25262883, -0.64555455,  0.65394457]])

In [50]:
# Elementwise check against the step-by-step masked result, using a floating-point
# tolerance instead of exact equality so rounding differences cannot cause spurious mismatches.
np.isclose(self_attention_func_decoder, attention_decoder @ v)

array([[ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True]])