In [2]:
import numpy as np
import math

L, d_k, d_v = 4, 8, 8
q = np.random.randn(L, d_k)
k = np.random.randn(L, d_k)
v = np.random.randn(L, d_v)

## Self Attention

In [3]:
np. matmul(q, k.T)

array([[ 6.27669192, -1.47350354, -3.19416925, -2.83136567],
       [ 4.67557146,  1.42079285,  2.12530465, -1.643423  ],
       [-3.63945806,  0.38888208,  2.12912139,  1.19110733],
       [-1.41252549,  1.09485173, -0.14143964, -1.12710263]])

In [4]:
q.var(), k.var(), np.matmul(q, k.T).var()

(0.9388843388110786, 0.6822663524883068, 7.112287132764719)

In [5]:
scaled = np.matmul(q, k.T) / math.sqrt(d_k)
q.var(), k.var(), scaled.var()

(0.9388843388110786, 0.6822663524883068, 0.8890358915955898)

In [16]:
scaled

array([[ 2.21914571, -0.52096217, -1.12930937, -1.00103893],
       [ 1.65306414,  0.50232613,  0.75140867, -0.58103777],
       [-1.28674274,  0.13749058,  0.75275809,  0.42112003],
       [-0.49940318,  0.38708854, -0.05000646, -0.39849096]])

## Masking

In [12]:
mask = np.tril(np.ones( (L, L)))
mask

array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]])

In [13]:
mask[mask == 0] = -np.infty
mask[mask == 1] = 0
mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [14]:
scaled + mask

array([[ 2.21914571,        -inf,        -inf,        -inf],
       [ 1.65306414,  0.50232613,        -inf,        -inf],
       [-1.28674274,  0.13749058,  0.75275809,        -inf],
       [-0.49940318,  0.38708854, -0.05000646, -0.39849096]])

## Softmax

In [18]:
def softmax(x):
    return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

In [35]:
attention = softmax(scaled + mask)
attention

array([[1.        , 0.        , 0.        , 0.        ],
       [0.75964569, 0.24035431, 0.        , 0.        ],
       [0.07787287, 0.32353618, 0.59859094, 0.        ],
       [0.16393047, 0.39779391, 0.25693909, 0.18133653]])

In [36]:
new_v = np.matmul(attention, v)
new_v

array([[ 0.27921481, -1.61197661, -1.46371184, -0.4341913 ,  0.5109939 ,
        -0.11487282, -0.09923813, -1.16228709],
       [ 0.41773985, -1.18583877, -1.07289059, -0.03056215,  0.13426518,
         0.00586366, -0.29655696, -1.41607961],
       [-0.19557084,  0.42859128,  0.01657062, -0.04308915,  0.14843371,
        -0.21634235,  0.79030764, -1.06602354],
       [-0.04173758,  0.40657756, -0.45038202,  0.70766714, -0.11536355,
         0.03321447,  0.0429623 , -1.06449782]])

In [37]:
v

array([[ 0.27921481, -1.61197661, -1.46371184, -0.4341913 ,  0.5109939 ,
        -0.11487282, -0.09923813, -1.16228709],
       [ 0.85555166,  0.16098034,  0.16230957,  1.24511768, -1.05639519,
         0.38745427, -0.92018796, -2.2181971 ],
       [-0.82546524,  0.83869889,  0.13037458, -0.68848027,  0.75247249,
        -0.55589272,  1.8305484 , -0.43075473],
       [-1.18976771,  2.15785595, -1.70125209,  2.53915519,  0.15306876,
         0.22471616, -0.24850933,  0.65678704]])

In [38]:
def softmax(x):
    return (np.exp(x).T / np.sum(np.exp(x), axis=-1)).T

def scaled_dot_product_attention(q, k ,v, mask=None):
    d_k = q.shape[-1]
    scaled = np.matmul(q, k.T) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = softmax(scaled)
    out = np.matmul(attention, v)
    return out, attention

In [40]:
values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("New V\n", values)
print("Attention\n", attention)

Q
 [[ 0.21827976  0.29169933  2.62722975  0.25567183 -1.62182615  0.40129385
   1.7342214  -0.67822269]
 [ 1.18390483 -0.92181107  1.64459561  0.09108538 -0.91471029  1.17674544
   2.00155926  0.75720602]
 [ 1.76490509 -0.51926953  0.03176173  0.01765212  1.5024297   0.12568892
   0.18694038 -0.05733183]
 [-1.02020627 -0.59077518 -0.19402845  0.88391445  0.43880911  0.15582571
  -0.52500026  0.12719653]]
K
 [[-1.10833774  0.06313029  0.60030634 -1.06862347 -1.33234107  0.25729035
   1.75435921  0.16254706]
 [-0.45045007 -0.43847875 -0.02290908  0.09556501  0.62228887  0.23556751
   0.39095934  1.43731377]
 [ 0.73645148 -1.0672955  -0.89985083 -0.75474135  0.14289899  1.42211726
  -0.16408379  0.79750106]
 [ 0.59247324 -0.42708696 -0.34280843 -0.88952517  0.1643907  -1.84336711
  -0.14160905  0.67252263]]
V
 [[ 0.27921481 -1.61197661 -1.46371184 -0.4341913   0.5109939  -0.11487282
  -0.09923813 -1.16228709]
 [ 0.85555166  0.16098034  0.16230957  1.24511768 -1.05639519  0.38745427
  -0.9

In [11]:
import math 

def pointwise_mutual_information(x, y, pxy):
    print(math.log( pxy / (x*y) , 2))

pxy = [0.1, 0.7, 0.15, 0.05]
x = [0.8, 0.8, 0.2, 0.2]
y = [0.25, 0.75, 0.25, 0.75]

for i in range(4):
    pointwise_mutual_information(x[i], y[i], pxy[i])


-1.0
0.22239242133644774
1.584962500721156
-1.5849625007211563


In [10]:
import math

def pointwise_mutual_information(x, y, pxy):
    p_x = x
    p_y = y
    pmi = math.log(pxy / (p_x * p_y), 2)
    print("PMI:", pmi)

pxy = [0.1, 0.7, 0.15, 0.05]
x = [0.8, 0.8, 0.2, 0.2]
y = [0.25, 0.75, 0.25, 0.75]

for i in range(4):
    pointwise_mutual_information(x[i], y[i], pxy[i])


PMI: -1.0
PMI: 0.22239242133644774
PMI: 1.584962500721156
PMI: -1.5849625007211563
