In [1]:
import numpy as np
import math

### INITIALIZE QUERY KEY AND VALUE VECTORS

In [3]:
L,d_k, d_v=4,8,8
q=np.random.randn(L,d_k)
k=np.random.randn(L,d_k)
v=np.random.randn(L ,d_v)


## Self Attention

$$
\text{self attention} = softmax\bigg(\frac{Q.K^T}{\sqrt{d_k}}+M\bigg)
$$

$$
\text{new V} = \text{self attention}.V
$$ 

In [12]:
non_scaled= np.dot(q,k.T)
non_scaled

array([[ 3.80455444, -2.7930426 , -0.88015216,  2.40713596],
       [ 7.527268  , -2.89781127,  1.91688114,  2.02065071],
       [-5.81449877,  5.50958336, -2.52298173, -0.29200241],
       [ 1.67064351,  4.71477264, -1.2562126 , -0.92202523]])

In [14]:
#Check Variances 
q.var(),k.var(),non_scaled.var()

(0.8963136066040833, 1.0250371090854091, 11.901918793723505)

In [16]:
#it can be seen that the variance of q.kT is very high, we need to reduce it by scaling
scaled= np.dot(q,k.T)/math.sqrt(d_k)
q.var(),k.var(),scaled.var()

(0.8963136066040833, 1.0250371090854091, 1.487739849215438)

In [18]:
scaled
#we can see that the values were reduced to a range and of low variance

array([[ 1.34511312, -0.98748968, -0.31118078,  0.85105108],
       [ 2.66129112, -1.024531  ,  0.67771983,  0.71440791],
       [-2.05573576,  1.94793188, -0.89200875, -0.10323844],
       [ 0.59066168,  1.66692385, -0.44413823, -0.32598515]])

## Masking

- This is to ensure words don't get context from words generated in the future. 
- Not required in the encoders, but required int he decoders

In [21]:
mask =np.tril(np.ones((L,L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [22]:
mask[mask==0]=-np.infty
mask[mask==0]=0

In [23]:
scaled+mask

array([[ 2.34511312,        -inf,        -inf,        -inf],
       [ 3.66129112, -0.024531  ,        -inf,        -inf],
       [-1.05573576,  2.94793188,  0.10799125,        -inf],
       [ 1.59066168,  2.66692385,  0.55586177,  0.67401485]])

## Softmax

$$
\text{softmax} = \frac{e^{x_i}}{\sum_j e^x_j}
$$

In [24]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [25]:
attention= softmax(scaled+mask)
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.9755369 , 0.0244631 , 0.        , 0.        ],
       [0.01694898, 0.92878303, 0.05426799, 0.        ],
       [0.21327198, 0.6256746 , 0.07577499, 0.08527842]])

In [27]:
new_v=np.dot(attention,v)
new_v,v

(array([[ 1.10126308, -0.04994959,  1.02084367,  0.17022373, -1.8905071 ,
         -0.70856559, -1.75184522,  0.0985254 ],
        [ 1.05659271, -0.05169121,  0.98991741,  0.19217061, -1.85490828,
         -0.65186021, -1.6880424 ,  0.13115717],
        [-0.79149402, -0.1951564 , -0.22648053,  0.9948521 , -0.39918437,
          1.43942004,  0.7767421 ,  1.19898806],
        [-0.46934977, -0.14927797,  0.10617265,  0.64403253, -0.58614919,
          0.9242513 ,  0.15725369,  0.80228357]]),
 array([[ 1.10126308, -0.04994959,  1.02084367,  0.17022373, -1.8905071 ,
         -0.70856559, -1.75184522,  0.0985254 ],
        [-0.72476787, -0.12114315, -0.24335655,  1.06736605, -0.43530231,
          1.60943067,  0.85627943,  1.43244349],
        [-2.52464119, -1.50722554, -0.32721612,  0.01134205,  0.68473506,
         -0.79941043,  0.20520903, -2.45284664],
        [-0.69706738,  0.60250978,  0.76821441, -0.71477277,  0.43990983,
          1.51226901, -0.23955433,  0.83129813]]))

In [28]:
def softmax(x):
  return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

def scaled_dot_product_attention(q, k, v, mask=None):
  d_k = q.shape[-1]
  scaled = np.matmul(q, k.T) / math.sqrt(d_k)
  if mask is not None:
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [29]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[-0.53956003 -0.7278342   0.79487738 -0.46728178 -0.52146064  0.45303998
   1.91997477 -0.73071626]
 [ 0.3558573   0.7139521   0.95890709  1.01977779 -0.67939045  1.89435826
   0.96634299  0.52971517]
 [-1.03448014 -0.0193485   0.63200876 -0.24783897 -0.54612951 -1.7305552
  -1.33485169  0.34452031]
 [-1.69628504  0.16801981  1.16609652  1.26262055  0.13514221  0.00291674
  -0.4809035  -1.2920545 ]]
K
 [[-0.62326288  0.48682762 -0.53659623  0.65433147 -1.49399641  3.01245882
   0.96469085 -0.76483198]
 [-2.04244406  0.18643837 -0.69058378  0.70425149 -1.13689369 -0.65008176
  -1.78358513 -0.33494799]
 [ 0.82066678  1.21676776 -0.01005504  0.23899918  0.86422993  0.26624263
   0.53252604  0.17033795]
 [ 0.49164881 -0.76981682  0.10685198  0.61704574 -1.75762392 -0.23120006
   0.84706998  0.16782274]]
V
 [[ 1.10126308 -0.04994959  1.02084367  0.17022373 -1.8905071  -0.70856559
  -1.75184522  0.0985254 ]
 [-0.72476787 -0.12114315 -0.24335655  1.06736605 -0.43530231  1.60943067
   0.85