In [22]:
import numpy as np
import torch

def np_conv(name, mtx):
    mtx = str(mtx)[1:-1]
    out = f"${name}"" = \\begin{bmatrix}\n"
    for i, char in enumerate(mtx):
        if char == '[':
            out += "\t"
        elif char == ' ' and mtx[i-1].isdigit():
            out += ' & '
        elif char == ']':
            out += "\\\\"
        else:
            out += char
    out += "\n\end{bmatrix}$\\\\\\\\"
    return out

In [23]:

# Embedding matrix
E = np.array([
    [0, 1, 2],
    [6, 7, 1],
    [3, 4, 5],
    [0, 2, 1],
    [1, 3, 0],
    [3, 8, 6],
    [2, 7, 5],
    [6, 2, 1],
    [9, 1, 3],
    [0, 1, 1]])

# Query matrix
Q = np.array([
    [1, 1, 7],
    [2, 5, 1],
    [2, 6, 9]])

# Key matrix
K = np.array([
    [0, 4, 8],
    [1, 6, 9],
    [4, 2, 2]])

# Value Matrix
V = np.array([
    [2, 1, 0],
    [4, 3, 1],
    [6, 5, 1]])

# FFN matrix
W_f = np.array([
    [1, 0, 1],
    [0, 1, 1],
    [1, 1, 1]])

# FFN bias
b_f = np.array([2, 1, 1]).T

# Output projection matrix
M_out = np.array([
    [0, 2, 1, 1, 3, 1, 0, 0, 4, 1],
    [1, 1, 3, 1, 0, 0, 4, 1, 0, 2],
    [1, 4, 0, 0, 1, 3, 1, 1, 2, 0]])

# one hot vecs for input sentence [BOS] the fox jumped [EOS]
I = np.array([
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]])


In [24]:
X = I@E
print(np_conv('X', X))

# Calculate Q, K and V for each input vec
Q_x = Q@X.T
print(np_conv('Q_x', Q_x))

K_x = K@X.T
print(np_conv('K_x', K_x))

V_x = V@X.T
print(np_conv('V_x', V_x))

# compute attention scores
S = (Q@X.T).T@(K@X.T)
print(np_conv('S', S))

$X = \begin{bmatrix}
	0 & 1 & 2\\
 	0 & 2 & 1\\
 	2 & 7 & 5\\
 	6 & 2 & 1\\
 	3 & 4 & 5\\
\end{bmatrix}$\\\\
$Q_x = \begin{bmatrix}
	15 &  9 & 44 & 15 & 42\\
 	 7 & 11 & 44 & 23 & 31\\
 	24 & 21 & 91 & 33 & 75\\
\end{bmatrix}$\\\\
$K_x = \begin{bmatrix}
	20 & 16 & 68 & 16 & 56\\
 	24 & 21 & 89 & 27 & 72\\
 	 6 &  6 & 32 & 30 & 30\\
\end{bmatrix}$\\\\
$V_x = \begin{bmatrix}
	 1 &  2 & 11 & 14 & 10\\
 	 5 &  7 & 34 & 31 & 29\\
 	 7 & 11 & 52 & 47 & 43\\
\end{bmatrix}$\\\\
$S = \begin{bmatrix}
	 612 &  531 & 2411 & 1149 & 2064\\
 	 570 &  501 & 2263 & 1071 & 1926\\
 	2482 & 2174 & 9820 & 4622 & 8362\\
 	1050 &  921 & 4123 & 1851 & 3486\\
 	2034 & 1773 & 8015 & 3759 & 6834\\
\end{bmatrix}$\\\\


In [25]:
A = np.array([
    [0, -np.inf, -np.inf, -np.inf, -np.inf],
    [0, 0, -np.inf, -np.inf, -np.inf],
    [0, 0, 0, -np.inf, -np.inf],
    [0, 0, 0, 0, -np.inf],
    [0, 0, 0, 0, 0]])

S = S + A
print(np_conv('S', S))

$S = \begin{bmatrix}
	 612.  -inf  -inf  -inf  -inf\\
 	 570.  501.  -inf  -inf  -inf\\
 	2482. 2174. 9820.  -inf  -inf\\
 	1050.  921. 4123. 1851.  -inf\\
 	2034. 1773. 8015. 3759. 6834.\\
\end{bmatrix}$\\\\


In [26]:
# scale S by d_h
S = np.divide(S, np.sqrt(Q.shape[0]))
print(np_conv('S', S))

# normalise with softmax
S = np.round(np.array(torch.softmax(torch.tensor(S), dim=1)), decimals=5)
print(np_conv('S', S))

print(np.exp(520/1.73)/((np.exp(-np.inf/1.73)*3) + np.exp(570/1.73) + np.exp(520/1.73)))

# multiply score matrix with value matrix (i.e. z_0 = a^0_0 x v_0 + a^0_1 x v_1 + ... + a^0_4 x v_4 for all z)
R = S@V_x.T
print(np_conv('R', R))

# residual connection
R = R + X
print(np_conv('R', R))


$S = \begin{bmatrix}
	 353.33836474 &          -inf          -inf          -inf          -inf\\
 	 329.08965344 &  289.25248486 &          -inf          -inf          -inf\\
 	1432.98336813 & 1255.15948522 & 5669.57964344 &          -inf          -inf\\
 	 606.21778265 &  531.73959792 & 2380.41515987 & 1068.67534827 &          -inf\\
 	1174.33044753 & 1023.64202727 & 4627.46240755 & 2170.25966188 & 3945.61173964\\
\end{bmatrix}$\\\\
$S = \begin{bmatrix}
	1. 0. 0. 0. 0.\\
 	1. 0. 0. 0. 0.\\
 	0. 0. 1. 0. 0.\\
 	0. 0. 1. 0. 0.\\
 	0. 0. 1. 0. 0.\\
\end{bmatrix}$\\\\
2.8063146353260925e-13
$R = \begin{bmatrix}
	 1.  5.  7.\\
 	 1.  5.  7.\\
 	11. 34. 52.\\
 	11. 34. 52.\\
 	11. 34. 52.\\
\end{bmatrix}$\\\\
$R = \begin{bmatrix}
	 1.  6.  9.\\
 	 1.  7.  8.\\
 	13. 41. 57.\\
 	17. 36. 53.\\
 	14. 38. 57.\\
\end{bmatrix}$\\\\


In [27]:
# layer normalisation
epsilon = 0.00001
gamma = 1
beta = 0
Y = np.divide((np.subtract(R.T, np.mean(R, axis=1))), np.sqrt((np.var(R, axis=1) + epsilon))) * gamma + beta
print(np_conv('Y', np.round(Y, decimals = 3)))

# FFN forward pass
O = (W_f@Y).T + b_f
print(np_conv('O', np.round(O, decimals=3)))

# calculate logits over vocab
L = M_out.T@O.T
print(np_conv('L', np.round(L.T, decimals=3)))

# apply softmax to predict next token
probs = np.round(np.array(torch.softmax(torch.tensor(L), dim=0)), decimals=3)
print(np_conv('probs_T', probs.T))

$Y = \begin{bmatrix}
	-1.313 & -1.402 & -1.32 &  -1.247 & -1.269\\
 	 0.202 &  0.539 &  0.22 &   0.045 &  0.095\\
 	 1.111 &  0.863 &  1.1 &    1.201 &  1.175\\
\end{bmatrix}$\\\\
$O = \begin{bmatrix}
	1.798 & 2.313 & 1.   \\
 	1.461 & 2.402 & 1.   \\
 	1.78 &  2.32 &  1.   \\
 	1.955 & 2.247 & 1.   \\
 	1.905 & 2.269 & 1.   \\
\end{bmatrix}$\\\\
$L = \begin{bmatrix}
	 3.313 &  9.909 &  8.738 &  4.111 &  6.394 &  4.798 & 10.253 &  3.313 &  9.192 &  6.424\\
 	 3.402 &  9.323 &  8.666 &  3.863 &  5.383 &  4.461 & 10.607 &  3.402 &  7.843 &  6.264\\
 	 3.32 &   9.88 &   8.74 &   4.1 &    6.34 &   4.78 &  10.279 &  3.32 &   9.12 &   6.42 & \\
 	 3.247 & 10.156 &  8.695 &  4.201 &  6.864 &  4.955 &  9.987 &  3.247 &  9.819 &  6.448\\
 	 3.269 & 10.08 &   8.713 &  4.175 &  6.716 &  4.905 & 10.077 &  3.269 &  9.621 &  6.444\\
\end{bmatrix}$\\\\
$probs_T = \begin{bmatrix}
	0.    0.305 & 0.094 & 0.001 & 0.009 & 0.002 & 0.43 &  0.    0.149 & 0.009\\
 	0.    0.184 & 0.095 & 0.001 & 0.004 & 0.001 