## Generate Data

In [1]:
import numpy as np
import math

# Problem size: every matrix has L rows; queries/keys are d_k wide, values d_v wide.
L, d_k, d_v = 4, 8, 8

# Seeded generator so that Restart & Run All reproduces the same Q, K and V.
SEED = 0
rng = np.random.default_rng(SEED)

# Standard-normal stand-ins for the query, key and value matrices.
q = rng.standard_normal((L, d_k))
k = rng.standard_normal((L, d_k))
v = rng.standard_normal((L, d_v))

In [2]:
# Dump the three input matrices, each under its own heading.
for heading, matrix in (("Q\n", q), ("K\n", k), ("V\n", v)):
  print(heading, matrix)

Q
 [[-0.67239746  0.65643026 -0.38607383  1.04259749  1.18984462 -0.71930886
   1.52166802 -1.26141514]
 [ 0.06669797  0.0403164  -0.45386231 -1.14674233 -0.7245614  -0.58857103
  -1.06170008  0.64068057]
 [-2.55357512  0.11218279  0.00908873  3.01800334  0.51084095  0.10120081
   0.72977999 -1.41511397]
 [ 0.1308654   0.24656579 -1.70066495 -1.15488767  1.37924421  1.15747264
  -0.68524438  2.3559399 ]]
K
 [[-1.03335093 -0.70666215  0.60421157  1.57414063  1.13874658 -0.26711968
  -0.40700281 -0.50892546]
 [-0.03834118 -1.24276769  0.21820102  0.80058268 -1.60568339 -0.30186671
   0.5182941  -0.5313478 ]
 [-0.98923012 -2.36330413  1.13727085 -0.68193675  1.54144524 -0.15662151
  -1.52785808  0.54584296]
 [-0.63433113 -0.00321842 -0.53127923 -2.12506945 -0.94181991 -0.53916661
  -1.13275906  1.1341112 ]]
V
 [[ 0.13644799 -1.10300126  0.23957621  0.43252026 -0.74035559  0.38309679
   0.35594858  0.52655438]
 [-0.38208085  0.7062285  -0.72487066 -0.71447198  1.05291698 -0.41719128
   0.4

## Self Attention

$$
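\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}}+M\bigg)
$$

Here $M$ is an optional additive mask; the step-by-step computation below omits it (i.e. $M = 0$).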

$$
\text{new V} = \text{self attention} \cdot V
$$

In [3]:
# Raw (unscaled) dot products: one row per row of q, one column per row of k.
q @ k.T

array([[ 3.20858905, -0.27402265, -3.10293238, -5.47312939],
       [-2.73859167, -0.61936736,  1.051731  ,  5.56459396],
       [ 8.29356775,  2.65599471, -0.90264679, -7.76608999],
       [-2.81366113, -5.778125  ,  2.41895141,  4.79898326]])

In [4]:
# Motivation for the sqrt(d_k) denominator: compare the variance of the
# inputs with the variance of their raw dot products.
np.var(q), np.var(k), np.var(q @ k.T)

(1.3720152411174946, 0.9811603765925998, 18.557568998436647)

In [5]:
# Divide the raw dot products by sqrt(d_k), then compare the variances again.
scaled = (q @ k.T) / math.sqrt(d_k)
np.var(q), np.var(k), np.var(scaled)

(1.3720152411174946, 0.9811603765925998, 2.3196961248045804)

In [6]:
# Display the scaled score matrix computed in the previous cell.
scaled

array([[ 1.13440754, -0.09688164, -1.09705226, -1.93504345],
       [-0.96823837, -0.21897943,  0.37184306,  1.96738106],
       [ 2.932219  ,  0.93903594, -0.31913383, -2.74572745],
       [-0.99477943, -2.04287568,  0.85522847,  1.6966968 ]])

## Softmax

$$
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

In [7]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores with at least one dimension. Any leading
      dimensions are treated as independent rows/batches.

  Returns:
    Array with the same shape as ``x`` whose entries along the last axis
    are non-negative and sum to 1.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; exp() preserves the shape and yields floats.
    return np.exp(x)
  # Subtracting the per-row maximum does not change the result but keeps
  # exp() from overflowing to inf, which would turn the row into NaN.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  # keepdims lets the denominator broadcast for any number of leading axes.
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

In [8]:
# Normalise each row of the scaled scores into weights that sum to 1.
attention = softmax(scaled)

In [9]:
# Display the attention weights produced by softmax(scaled).
attention

array([[0.69169006, 0.2019154 , 0.07426785, 0.03212668],
       [0.0388079 , 0.08209546, 0.14822111, 0.73087553],
       [0.84860585, 0.11563189, 0.03285954, 0.00290272],
       [0.04451584, 0.01560746, 0.28311495, 0.65676175]])

In [10]:
# Each output row is a weighted combination of the rows of v,
# with weights taken from the matching row of the attention matrix.
new_v = attention @ v
new_v

array([[-0.05397451, -0.74417808,  0.05484616,  0.2404507 , -0.34460823,
         0.22373461,  0.39204122,  0.40238019],
       [-0.34607565,  0.03494939,  0.25618464, -0.38572136,  0.10882165,
        -0.47066127, -0.44908555,  0.43078706],
       [ 0.04312287, -0.91367257,  0.13119659,  0.33100178, -0.5288734 ,
         0.3049888 ,  0.38962498,  0.45509644],
       [-0.41388369, -0.29626975,  0.32342182, -0.07814276, -0.07538953,
        -0.25595608, -0.27668134,  0.50723937]])

## Function

In [11]:
def softmax(x):
  """Softmax over the last axis of ``x``.

  Args:
    x: array-like of scores with at least one dimension.

  Returns:
    Array with the same shape as ``x`` whose entries along the last axis
    are non-negative and sum to 1.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalise; exp() preserves the shape and yields floats.
    return np.exp(x)
  # Subtracting the per-row maximum does not change the result but keeps
  # exp() from overflowing to inf, which would turn the row into NaN.
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exp_x = np.exp(shifted)
  return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
  """Compute softmax(q @ k.T / sqrt(d_k) + mask) @ v.

  Args:
    q: query matrix of shape (L_q, d_k).
    k: key matrix of shape (L_k, d_k).
    v: value matrix of shape (L_k, d_v).
    mask: optional additive mask broadcastable to the (L_q, L_k) score
      matrix. Use 0 to keep a position and -inf (or a very negative
      number) to exclude it. ``None`` (the default) applies no mask.

  Returns:
    Tuple ``(out, attention)``: ``out`` has shape (L_q, d_v) and
    ``attention`` is the (L_q, L_k) weight matrix whose rows sum to 1.
  """
  d_k = q.shape[-1]
  # Scale by sqrt(d_k) so the variance of the scores does not grow with d_k.
  scaled = np.matmul(q, k.T) / math.sqrt(d_k)
  if mask is not None:
    # The mask is added before softmax, so -inf entries receive zero weight.
    scaled = scaled + mask
  attention = softmax(scaled)
  out = np.matmul(attention, v)
  return out, attention

In [12]:
# Run the full attention function, then print its inputs followed by its outputs.
values, attention = scaled_dot_product_attention(q, k, v)
for heading, matrix in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("New V\n", values),
    ("Attention\n", attention),
):
  print(heading, matrix)

Q
 [[-0.67239746  0.65643026 -0.38607383  1.04259749  1.18984462 -0.71930886
   1.52166802 -1.26141514]
 [ 0.06669797  0.0403164  -0.45386231 -1.14674233 -0.7245614  -0.58857103
  -1.06170008  0.64068057]
 [-2.55357512  0.11218279  0.00908873  3.01800334  0.51084095  0.10120081
   0.72977999 -1.41511397]
 [ 0.1308654   0.24656579 -1.70066495 -1.15488767  1.37924421  1.15747264
  -0.68524438  2.3559399 ]]
K
 [[-1.03335093 -0.70666215  0.60421157  1.57414063  1.13874658 -0.26711968
  -0.40700281 -0.50892546]
 [-0.03834118 -1.24276769  0.21820102  0.80058268 -1.60568339 -0.30186671
   0.5182941  -0.5313478 ]
 [-0.98923012 -2.36330413  1.13727085 -0.68193675  1.54144524 -0.15662151
  -1.52785808  0.54584296]
 [-0.63433113 -0.00321842 -0.53127923 -2.12506945 -0.94181991 -0.53916661
  -1.13275906  1.1341112 ]]
V
 [[ 0.13644799 -1.10300126  0.23957621  0.43252026 -0.74035559  0.38309679
   0.35594858  0.52655438]
 [-0.38208085  0.7062285  -0.72487066 -0.71447198  1.05291698 -0.41719128
   0.4