In [1]:
import numpy as np

In [2]:
rng = np.random.default_rng(4000)

In [3]:
def split_heads(x, H):
    # x: (B, T, C) => (B, H, T,d_head)
    B, T, C = x.shape
    d_head = C // H
    x = x.reshape(B, T, H, d_head)
    return x.transpose(0, 2, 1, 3)

In [4]:
def merge_heads(x):
    # x: (B, H, T, d_head) => (B, T, C)
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H*d_head)

In [5]:
def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2 ).mean(axis = -1, keepdims = True)
    xhat = (x - mu ) / np.sqrt(var + eps)
    return xhat * gamma + beta

In [6]:
def softmax(x, axis = -1) :
    x = x - np.max(x, axis = axis, keepdims = True)
    ex = np.exp(x)
    return ex / np.sum(ex, axis = axis, keepdims = True)

In [7]:
def causal_mask(T):
    # (1, 1, T, T) ใช้ broadcast กับ (B, H, T, T)
    m = np.triu(np.ones((T, T), dtype=bool), k = 1)
    return m[None, None, :, :]

In [8]:
dout = np.array([[[ 0.07788198, -0.00560083, -0.44371856,  0.38684485,
          0.1261935 , -0.13216348],
        [ 0.16143276,  0.03677408,  0.24246872, -0.01143926,
         -0.14273741, -0.05650211],
        [-0.07538471,  0.03161602, -0.11086203, -0.07365076,
          0.12672887,  0.08785918],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]],

       [[ 0.16217772,  0.03159888,  0.2467916 , -0.02172796,
         -0.1281925 , -0.02960269],
        [-0.14111161,  0.07411266, -0.21028277, -0.12922675,
          0.26637395,  0.19252851],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]]])

In [9]:
dout.shape

(2, 4, 6)

In [10]:
B = 2 # Batch size
T = 4 # Token
V = 7 # Vocab size
d_model = 6 # hidden
H = 2 # attention head count

## Pre-LN type transformer architecture
---

$$out = FFN(LN(y)) + y$$
$$y = ATTN(LN(x)) + x$$

where x is input of attention

FFN definition from Attention is All you need: 
    $$FFN(x) = max(0, max(0, x \cdot W_1 + b_1)W_2 + b_2)$$

$$ATTN(x) = concat(attn_i(x), axis = -1)$$
$$attn = softmax(\frac{Q \cdot K^\intercal}{\sqrt{d_k} }+ causal\_mask) \cdot V$$

## Forwarding

In [11]:
X = rng.random((B, T, d_model), np.float64)
print(X)

[[[0.45461138 0.9079045  0.22341803 0.75301023 0.21061655 0.28646412]
  [0.9110329  0.06668155 0.80997244 0.44144683 0.01151965 0.43382742]
  [0.31255883 0.80263996 0.22221152 0.46048076 0.05819513 0.49784596]
  [0.46341809 0.94756117 0.88801601 0.33144693 0.52629353 0.91218153]]

 [[0.76991742 0.84385534 0.95556605 0.44375354 0.79734196 0.11813618]
  [0.8017991  0.88149331 0.6328389  0.11407255 0.44154406 0.5999176 ]
  [0.43327192 0.17617901 0.97017048 0.99389211 0.24304429 0.63569   ]
  [0.99131382 0.56461089 0.35934087 0.96484629 0.96325469 0.93638915]]]


In [12]:
Xrs = X.reshape(B * T, d_model)
print(Xrs)

[[0.45461138 0.9079045  0.22341803 0.75301023 0.21061655 0.28646412]
 [0.9110329  0.06668155 0.80997244 0.44144683 0.01151965 0.43382742]
 [0.31255883 0.80263996 0.22221152 0.46048076 0.05819513 0.49784596]
 [0.46341809 0.94756117 0.88801601 0.33144693 0.52629353 0.91218153]
 [0.76991742 0.84385534 0.95556605 0.44375354 0.79734196 0.11813618]
 [0.8017991  0.88149331 0.6328389  0.11407255 0.44154406 0.5999176 ]
 [0.43327192 0.17617901 0.97017048 0.99389211 0.24304429 0.63569   ]
 [0.99131382 0.56461089 0.35934087 0.96484629 0.96325469 0.93638915]]


In [13]:
dk = d_model // H
dq = d_model // H
dv = d_model // H

In [14]:
Wk = rng.random((d_model, d_model), np.float64)[None, :, :]
Wq = rng.random((d_model, d_model), np.float64)[None, :, :]
Wv = rng.random((d_model, d_model), np.float64)[None, :, :]

In [15]:
K = X @ Wk
print(K)

[[[1.69827184 0.48160908 1.48858585 1.38148467 1.43737408 1.74600031]
  [1.60772674 0.32969475 1.77063983 1.67920765 1.8350396  1.99610657]
  [1.4499264  0.38631034 1.1907154  1.21686007 1.26898723 1.32295591]
  [2.30724769 0.65258577 2.1348966  2.0128202  2.17980615 2.27854297]]

 [[1.99027233 0.62134838 2.16243293 1.54949139 1.82271898 2.50898325]
  [1.68489373 0.5519499  1.81588932 1.48378157 1.66687634 1.97018777]
  [2.42077915 0.50127999 2.18992614 2.34063848 2.41344037 2.47249864]
  [2.40920503 0.90931833 2.71140137 2.66256324 2.49278744 2.95386996]]]


In [16]:
Q = X @ Wq
V = X @ Wv

In [17]:
Ksplit = split_heads(K, H)
Qsplit = split_heads(Q, H)
Vsplit = split_heads(V, H)

In [18]:
Ksplit.shape

(2, 2, 4, 3)

In [19]:
scores = (Qsplit @ Ksplit.transpose(0, 1, 3, 2))/np.sqrt(dk)

In [20]:
scores.shape

(2, 2, 4, 4)

In [21]:
m = causal_mask(T)
scores = np.where(m, -1e9, scores)
scores.shape

(2, 2, 4, 4)

In [22]:
attn = softmax(scores, axis = -1)
print(attn)

[[[[1.         0.         0.         0.        ]
   [0.497259   0.502741   0.         0.        ]
   [0.38630999 0.38814851 0.2255415  0.        ]
   [0.09426586 0.09705267 0.03723875 0.77144272]]

  [[1.         0.         0.         0.        ]
   [0.34693574 0.65306426 0.         0.        ]
   [0.29142418 0.54193152 0.1666443  0.        ]
   [0.0899694  0.23432084 0.03652855 0.63918121]]]


 [[[1.         0.         0.         0.        ]
   [0.71826246 0.28173754 0.         0.        ]
   [0.31298757 0.11543245 0.57157997 0.        ]
   [0.07595578 0.01950134 0.17093249 0.7336104 ]]

  [[1.         0.         0.         0.        ]
   [0.69964269 0.30035731 0.         0.        ]
   [0.18943338 0.09595853 0.71460809 0.        ]
   [0.04338342 0.01713654 0.22646623 0.71301381]]]]


In [23]:
attn.shape

(2, 2, 4, 4)

In [24]:
context = attn @ Vsplit
print(context.shape)
print(context)

(2, 2, 4, 3)
[[[[1.05080832 1.27588609 1.22999668]
   [1.23416586 1.02423468 1.14895823]
   [1.11030788 1.08004929 1.1478513 ]
   [1.25899058 1.86508876 1.6742963 ]]

  [[1.26776215 1.55038072 0.97869839]
   [1.17164288 1.67564328 0.99319525]
   [1.13542784 1.59168337 0.94676382]
   [1.52693039 1.76236945 1.28605242]]]


 [[[1.74338133 1.51769822 1.55104743]
   [1.59495215 1.56628126 1.562557  ]
   [1.61276753 1.33791164 1.38436888]
   [1.89721471 1.67806444 1.54865867]]

  [[1.92648624 2.01224077 1.79940096]
   [1.80386147 1.92798053 1.67884222]
   [1.59286308 1.95044528 1.27493744]
   [2.09305153 2.53408733 1.60889467]]]]


In [25]:
output = merge_heads(context)
print(output.shape)
print(output)

(2, 4, 6)
[[[1.05080832 1.27588609 1.22999668 1.26776215 1.55038072 0.97869839]
  [1.23416586 1.02423468 1.14895823 1.17164288 1.67564328 0.99319525]
  [1.11030788 1.08004929 1.1478513  1.13542784 1.59168337 0.94676382]
  [1.25899058 1.86508876 1.6742963  1.52693039 1.76236945 1.28605242]]

 [[1.74338133 1.51769822 1.55104743 1.92648624 2.01224077 1.79940096]
  [1.59495215 1.56628126 1.562557   1.80386147 1.92798053 1.67884222]
  [1.61276753 1.33791164 1.38436888 1.59286308 1.95044528 1.27493744]
  [1.89721471 1.67806444 1.54865867 2.09305153 2.53408733 1.60889467]]]


In [26]:
X_FFN = X + output
print(X_FFN.shape)
print(X_FFN)

(2, 4, 6)
[[[1.50541969 2.18379059 1.45341472 2.02077238 1.76099727 1.26516251]
  [2.14519876 1.09091623 1.95893067 1.61308971 1.68716294 1.42702267]
  [1.42286671 1.88268925 1.37006282 1.5959086  1.6498785  1.44460978]
  [1.72240867 2.81264993 2.56231231 1.85837732 2.28866297 2.19823395]]

 [[2.51329875 2.36155356 2.50661348 2.37023978 2.80958273 1.91753714]
  [2.39675126 2.44777457 2.1953959  1.91793403 2.36952459 2.27875981]
  [2.04603945 1.51409064 2.35453936 2.58675518 2.19348957 1.91062744]
  [2.88852852 2.24267533 1.90799955 3.05789783 3.49734202 2.54528382]]]


In [27]:
expand_scale = 4

In [28]:

W_expand = rng.random((d_model, d_model * expand_scale), np.float64)[None, :, :]
b_expand = rng.random((d_model * expand_scale), np.float64)[None, None, :]

print(W_expand.shape)
print(b_expand.shape)

(1, 6, 24)
(1, 1, 24)


In [29]:
gamma = rng.random((d_model), np.float64)
beta = rng.random((d_model), np.float64)

X_FFN_N = layer_norm(X_FFN, gamma = gamma, beta = beta)
print(X_FFN_N)

[[[ 0.08728607  1.49611494  0.43019102  1.02796724  1.16567207
   -0.44679896]
  [ 0.38439053  0.27265565  1.21324143  0.19498251  1.0708205
   -0.08809221]
  [ 0.05831561  1.63169049  0.26823069  0.43332582  1.47760112
   -0.0921864 ]
  [-0.02705926  1.50409262  1.19680431 -0.47506848  1.10098782
    0.20264506]]

 [[ 0.22957322  0.83642906  0.95680305  0.16323514  2.43953724
   -0.72303876]
  [ 0.28153162  1.30890326  0.59533477 -1.19581726  1.54246478
    0.29525969]
  [ 0.15088208  0.23909581  1.14500097  1.35179134  1.24306595
   -0.03463275]
  [ 0.22983243  0.5805443   0.08279414  0.8062191   2.48653713
    0.1163327 ]]]


In [30]:
# calculate then pass through Relu function
X_expand = np.maximum(X_FFN_N @ W_expand + b_expand, 0)
print(X_expand.shape)
print(X_expand)

(2, 4, 24)
[[[2.32001661 2.99356751 2.69023381 2.75079464 3.68270609 2.82399537
   2.35321254 3.10095955 2.81838067 1.36765357 3.06242165 3.04675111
   3.18791831 2.05267604 2.82600241 2.75942372 3.79559057 2.50925136
   2.94899243 3.53517387 2.33331537 2.80977521 1.40714308 2.59583338]
  [1.42045799 2.27862488 1.44387469 2.29315205 3.23291292 2.61574291
   1.47340445 2.47593323 2.20055205 0.89934746 2.09051582 2.46358101
   2.08183745 1.77663255 2.11618775 1.78103236 2.26243329 1.74418902
   2.49773375 3.19639391 2.04034854 1.74670195 1.35827702 2.67253615]
  [2.34359209 2.69796691 2.74295421 3.0067361  3.47252862 2.93202166
   2.62364147 2.95663377 2.61782279 1.04487108 2.95358166 3.25982435
   3.28091132 1.49884488 2.86528932 2.90531869 3.500473   2.71784755
   3.01663688 3.55231908 1.70647245 2.51519746 1.53130385 2.56498644]
  [1.45770686 2.53727564 2.16748349 2.59846751 3.62346873 2.80384835
   2.51019791 2.82930452 2.51230971 0.26598257 2.41460314 3.3636241
   2.8914277  1.40679

In [31]:
W_shrink = rng.random((d_model * expand_scale, d_model), np.float64)[None, :, :]
b_shrink = rng.random((d_model), np.float64)[None, None, :]

In [32]:
y_hat = X_expand @ W_shrink + b_shrink
X_shrink = y_hat
print(X_shrink)

[[[27.75634902 29.35922815 33.49682252 32.55031983 33.17177079
   36.47838881]
  [21.83819648 22.49861954 25.40237089 25.29411448 25.26688914
   28.11831878]
  [27.26476764 28.80715593 32.64153983 31.8861005  32.88227853
   35.77511097]
  [24.31680297 24.54325274 28.53880367 28.34426832 28.99883868
   31.07973456]]

 [[26.86153647 28.49845828 32.39786433 32.59636387 32.55987649
   36.42970984]
  [19.54913161 20.49453226 22.95167865 23.17952201 24.5667358
   25.73038086]
  [29.34903397 30.20578542 34.26490308 33.36431761 32.73344277
   37.41510099]
  [30.24738163 32.33363197 35.61778376 34.92453829 35.25592677
   39.84388742]]]


In [33]:
out = X_shrink + X_FFN # X_FFN = attn + X_input_transformer
print(out)

[[[29.26176872 31.54301875 34.95023724 34.57109221 34.93276806
   37.74355132]
  [23.98339524 23.58953577 27.36130156 26.90720419 26.95405208
   29.54534145]
  [28.68763435 30.68984517 34.01160264 33.4820091  34.53215703
   37.21972075]
  [26.03921163 27.35590267 31.10111598 30.20264563 31.28750166
   33.27796851]]

 [[29.37483523 30.86001185 34.90447781 34.96660364 35.36945922
   38.34724698]
  [21.94588287 22.94230684 25.14707455 25.09745604 26.93626038
   28.00914067]
  [31.39507342 31.71987607 36.61944244 35.95107279 34.92693233
   39.32572843]
  [33.13591015 34.5763073  37.5257833  37.98243611 38.75326879
   42.38917124]]]


Backwarding
---

start from $out$

$$\text{output from transformer} = out = \text{output from FFN} + \text{output from attention}$$
$$\text{output from FFN} = y$$

$$\frac{\partial out} {\partial y} = 1 $$
$$\frac{\partial out} {\partial \text{output from attention}} = 1$$

$$dy = \frac{\partial \mathcal{L}}{\partial y} = \frac{\partial \mathcal{L}}{\partial out} \frac{\partial out}{\partial y} = \frac{\partial \mathcal{L}}{\partial out}$$

#### find $db_{shrink}$ and $d\hat{y}$

$$db_{shrink} = \frac {\partial \mathcal{L}}{\partial b_{shrink}} = \frac{\partial \mathcal{L}}{\partial out} \frac{\partial out} {\partial \hat{y}} \frac{\partial \hat{y}}{\partial b_{shrink}}$$


$$d\hat{y} = \frac{\partial \mathcal{L}}{\partial \hat{y}} = \frac{\partial \mathcal{L}}{\partial out} \frac{\partial out}{\partial \hat{y}} $$

$$FFN(g) = P\ @\ W_{shrink} \oplus B_{shrink}$$
$$P = \hat{g} \ @\ W_{expand} \oplus B_{expand}$$
$$\hat{g} = LN(g)

where

* $out = FFN(LN(y)) + y$

* $\mathcal{R}(\hat{y}) = Relu(\hat{y})$ then $\frac{\partial \mathcal{R}(\hat{y})}{\partial \hat{y}} = \begin{bmatrix} 0;\ \text{if}\ \hat{y} \leq 0\\ 1;\ \text{if}\ \hat{y} > 0 \end{bmatrix}^{B \times T \times d_{model}} $
* $\hat{y} = \hat{x} \cdot W_{shrink} + b_{shrink}$ 

then
$$\frac{\partial \hat{y}}{\partial b_{shrink}} = 1$$

$$\frac{\partial \mathcal{L}} {\partial \hat{y}} = \frac{\partial \mathcal{L}} {\partial out} \odot \begin{bmatrix} 0;\ \text{if}\ \hat{y} \leq 0\\ 1;\ \text{if}\ \hat{y} > 0 \end{bmatrix}^{B \times T \times d_{model}} $$

after checking dimension

$$\frac{\partial \mathcal{L}} {\partial b_{shrink}} = \sum_{b, t} \left(\frac{\partial \mathcal{L}} {\partial \hat{y}}  \right)$$

We must sum over batch and token for get $b_{shrink}$


In [34]:
y_hat.shape

(2, 4, 6)

In [35]:
dout.shape

(2, 4, 6)

In [36]:
dReluYhat = y_hat > 0
dReluYhat.shape

(2, 4, 6)

In [37]:
dyhat = dout
print(dyhat.shape)

(2, 4, 6)


In [38]:
# print(np.sum(dyhat, axis=(0, 1), keepdims=True))
print(dyhat)

[[[ 0.07788198 -0.00560083 -0.44371856  0.38684485  0.1261935
   -0.13216348]
  [ 0.16143276  0.03677408  0.24246872 -0.01143926 -0.14273741
   -0.05650211]
  [-0.07538471  0.03161602 -0.11086203 -0.07365076  0.12672887
    0.08785918]
  [ 0.          0.          0.          0.          0.
    0.        ]]

 [[ 0.16217772  0.03159888  0.2467916  -0.02172796 -0.1281925
   -0.02960269]
  [-0.14111161  0.07411266 -0.21028277 -0.12922675  0.26637395
    0.19252851]
  [ 0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.        ]]]


In [39]:
dbshrink = np.sum(dout, axis=(0, 1), keepdims=True)
dbshrink

array([[[ 0.18499614,  0.16850081, -0.27560304,  0.15080012,
          0.24836641,  0.06211941]]])

#### find $dW_{shrink}$

from basic chain rule calculus :
$$dW_{shrink} = \frac{\partial \mathcal {L}}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial W_{shrink}}$$

$$\hat{y} = K\ @\ W_{shrink} \oplus B_{shrink}$$

$$\frac{\partial \hat{y}}{\partial W_{shrink}} = K^\intercal$$

after shape check We got:

$$dW_{shrink} = \sum_b \left[\frac{\partial \hat{y}}{\partial W_{shrink}} @ \frac{\partial \mathcal{L}}{\partial \hat{y}}\right]$$

We must sum over batch for get $dW_{shrink}$

In [40]:
(dout * (y_hat > 0)).shape

(2, 4, 6)

In [41]:
X_expand.transpose(0, 2, 1).shape

(2, 24, 4)

In [42]:
W_shrink.shape

(1, 24, 6)

In [43]:
dW_shrink_ns = X_expand.transpose(0, 2, 1) @ dyhat # K = X_expand
dW_shrink = np.sum(dW_shrink_ns, axis = 0)
print(dW_shrink.shape)
# print(dW_shrink)

(24, 6)


#### find $dP$

In [44]:
X_expand.shape

(2, 4, 24)

In [45]:
W_shrink.shape

(1, 24, 6)

In [46]:
dyhat.shape

(2, 4, 6)

$$\hat{y} = Relu\left(P\right)\ @\ W_{shrink} \oplus B_{shrink}$$

$$\frac{\partial \mathcal{R}(P)}{\partial K} = \begin{bmatrix} 0;\ \text{if}\ P \leq 0\\ 1;\ \text{if}\ P > 0 \end{bmatrix}^{B \times T \times d_{model\ expand}} $$

$$\frac{\partial \hat{y}}{\partial \mathcal{R}(P)} = W_{shrink}^\intercal$$

$$\frac{\partial \mathcal{L}}{\partial P} = \left(\frac{\partial \mathcal{L}} {\partial \hat{y}} \odot \frac{\partial \mathcal{R}(P)}{\partial P}\right)\ @\ \frac{\partial \hat{y}}{\partial \mathcal{R}(P)}$$

$$\frac{\partial \mathcal{L}}{\partial P} = \left(\frac{\partial \mathcal{L}} {\partial \hat{y}}  \odot \begin{bmatrix} 0;\ \text{if}\ P \leq 0\\ 1;\ \text{if}\ P > 0 \end{bmatrix}^{B \times T \times d_{model\ expand}} \right)\ @\ W_{shrink}^\intercal$$

In [47]:
print(dyhat.shape)
print(W_shrink.shape)

(2, 4, 6)
(1, 24, 6)


In [48]:
(dyhat @ W_shrink.transpose(0, 2, 1)).shape

(2, 4, 24)

In [49]:
dP = (dyhat * (K > 0)) @ W_shrink.transpose(0, 2, 1)
dP.shape

(2, 4, 24)

In [50]:
dP[1, :, :]

array([[-0.02320399,  0.13566701,  0.0835031 ,  0.09804354,  0.02299889,
         0.09813413,  0.31841884,  0.22338141,  0.27997756,  0.188616  ,
         0.00716171, -0.00735264,  0.14223163,  0.10753838,  0.20723656,
        -0.08225344, -0.00870833,  0.19023417,  0.22159395,  0.11544541,
         0.17256298,  0.23179216,  0.04542552,  0.10394   ],
       [ 0.19902305, -0.08369688,  0.20690232,  0.04913803,  0.05919416,
         0.00447043, -0.10818855, -0.13065227, -0.14879831, -0.01188733,
         0.26396885,  0.25544608, -0.01371743,  0.11301917, -0.11771402,
         0.29375551,  0.13045445, -0.05537048, -0.01647172, -0.01646714,
         0.00807723, -0.16771011,  0.22665711,  0.12033958],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
  

#### find $db_{expand}$

$$P = Relu\left(\hat{g}\right) \ @\ W_{expand} \oplus B_{expand}$$

In [51]:
b_expand.shape

(1, 1, 24)

In [52]:
dP.shape

(2, 4, 24)

$$\frac{\partial \mathcal{L}}{\partial B_{expand}} = \sum_{b, t} \left[ \frac{\partial \mathcal{L}}{\partial P} \frac{\partial P}{\partial b_{expand}} \right]$$

$$\frac{\partial P}{\partial b_{expand}} = 1$$

In [53]:
dBexpand = np.sum(dP, axis = (0, 1), keepdims=True)
print(dBexpand.shape)
print(dBexpand)

(1, 1, 24)
[[[ 0.52807845  0.02083732  0.29501652  0.11307535  0.23229567
    0.25835789  0.27487867  0.23754832  0.16420999 -0.00402603
    0.20660631  0.29427045  0.26891028  0.34408056  0.29926609
    0.3587176   0.17531845  0.18706538 -0.00446159  0.27200958
    0.21405253  0.1533026   0.51044953  0.54205255]]]


#### find $dW_{expand}$

In [54]:
W_expand.shape

(1, 6, 24)

In [55]:
print(dP.shape)
print(X_FFN_N.shape) # \hat{g} = X_FFN_N

(2, 4, 24)
(2, 4, 6)


$$P = \hat{g} \ @\ W_{expand} \oplus B_{expand}$$
$$\hat{g} = LN(g)

$$\frac{\partial P}{\partial W_{expand}} = \hat{g}^\intercal$$

$$\frac{\partial \mathcal{L}}{\partial W_{expand}} = \sum_b \left[\frac{\partial P}{\partial W_{expand}}\ @\ \frac{\partial \mathcal{L}}{\partial P} \right]$$

We need to sum over batch

In [56]:
dW_expand = np.sum(X_FFN_N.transpose(0, 2, 1) @ dP, axis = (0), keepdims = True)
print(dW_expand.shape)
print(W_expand.shape)

(1, 6, 24)
(1, 6, 24)


In [57]:
#print(dW_expand)

#### find $d\hat{g}$

In [58]:
print(X_FFN_N.shape)
print(dP.shape)
print(W_expand.transpose(0, 2, 1).shape)

(2, 4, 6)
(2, 4, 24)
(1, 24, 6)


from basic calculus

$$\frac{\partial \mathcal{L}}{\partial \hat{g}} = \frac {\partial \mathcal{L}}{\partial P} \frac{\partial P}{\partial \hat{g}}$$

$$ \frac{\partial P} {\partial \hat{g}} = W_{expand}^\intercal$$

after shape check : 
$$\frac{\partial \mathcal{L}}{\partial \hat{g}} = \frac {\partial \mathcal{L}}{\partial P}\ @\ \frac{\partial P}{\partial \hat{g}}$$

In [59]:
dghat = dP @ W_expand.transpose(0, 2, 1)
print(dghat.shape)

(2, 4, 6)


#### find $dg$ the most hardest one

$$ d\hat{g} = LN(g)$$

from old proof result :
$$ \frac {\partial \hat{g}_i}{\partial g_j} = \frac{1}{\sqrt{\sigma^2 +\epsilon}} \left(\delta_{i, j} - \frac{1}{d} - \frac{\hat{g}_i \hat{g}_j}{d}\right)$$

สมการที่เราจะนำไปใช้

$$\frac{\partial \mathcal{L}} {\partial g_j} = \frac{1}{d_{model} \cdot \sqrt{\sigma^2 + \epsilon}} \left( d_{model} \cdot \frac{\partial \mathcal {L}}{\partial \hat{g}_j} - \sum_{i = 1}^{d_{model}}  \frac{\partial \mathcal{L}}{\partial \hat{g}_i}  - \hat{g}_j \odot \sum_{i = 1}^{d_{model}} \left( \frac{\partial \mathcal{L}}{\partial \hat{g}_i} \odot \hat{g}_i \right) \right)

let focus on $\sum_{i = 1}^{d_{model}}  \frac{\partial \mathcal{L}}{\partial \hat{g}_i}$ and $\sum_{i = 1}^{d_{model}} \left( \frac{\partial \mathcal{L}}{\partial \hat{g}_i} \odot \hat{g}_i \right)$

In [60]:
sum_dghat = np.sum(dghat, axis=-1, keepdims=True)
print(sum_dghat.shape)

(2, 4, 1)


In [61]:
dghat_ghat = dghat * X_FFN_N
sum_dghat_ghat = np.sum(sum_dghat, axis=-1, keepdims=True)
print(dghat_ghat.shape)
print(sum_dghat_ghat)

(2, 4, 6)
[[[-0.607476  ]
  [ 7.60746876]
  [ 0.03369551]
  [ 0.        ]]

 [[ 8.91184658]
  [ 3.06179531]
  [ 0.        ]
  [ 0.        ]]]


In [62]:
X_FFN_N * sum_dghat_ghat

array([[[-5.30241906e-02, -9.08853926e-01, -2.61330719e-01,
         -6.24465429e-01, -7.08117809e-01,  2.71419648e-01],
        [ 2.92423896e+00,  2.07421932e+00,  9.22969629e+00,
          1.48332334e+00,  8.14623349e+00, -6.70158705e-01],
        [ 1.96497415e-03,  5.49806374e-02,  9.03816888e-03,
          1.46011331e-02,  4.97885180e-02, -3.10626740e-03],
        [-0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         -0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[ 2.04592127e+00,  7.45412745e+00,  8.52688195e+00,
          1.45472654e+00,  2.17407816e+01, -6.44361053e+00],
        [ 8.61992180e-01,  4.00759386e+00,  1.82279320e+00,
         -3.66134768e+00,  4.72271143e+00,  9.04024727e-01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00, -0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]])

In [63]:
sigma = np.std(X_FFN_N, axis=-1, keepdims=True)
epsilon = 1e-5
inv_deriv = 1 / (d_model * (sigma ** 2 + epsilon))
print(inv_deriv)

[[[0.37182575]
  [0.74729227]
  [0.36498962]
  [0.31950039]]

 [[0.17760978]
  [0.21228267]
  [0.50770274]
  [0.24048564]]]


In [64]:
dg = inv_deriv * (d_model * dghat - sum_dghat - X_FFN_N * sum_dghat_ghat )
print(dg.shape)
print(dg)

(2, 4, 6)
[[[ 0.63197095  0.36020242  0.3306147  -0.48916015  0.26164413
   -0.24588356]
  [-4.18273278 -1.40989492 -7.14950445  0.93694935 -6.1487175
    0.62602152]
  [ 0.42804598  0.11062142 -0.28801772 -0.207494    0.07656392
   -0.16617078]
  [ 0.          0.          0.          0.          0.
    0.        ]]

 [[-0.8575688  -1.25878385 -1.60860361  0.2697869  -3.85834904
    1.13645828]
  [ 0.24002654 -0.62711396 -0.79213968  0.67046709 -0.85700331
   -0.47213077]
  [ 0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.        ]]]


---

## Backward enter to the attention

from the Pre-LN structure we currently at
$$g = X + attentionResult$$

In [65]:
dg.shape

(2, 4, 6)

In [66]:
out.shape # attn result + X + result from FFN

(2, 4, 6)

from the output of attention in short term:

$$output = X + FFN(LN_2(X))$$
$$X = Input + ATTN(LN_1(Input))$$

$$\frac{\partial output}{\partial X} = I + \frac{\partial FFN}{\partial LN_2(X)}\frac{\partial LN_2(X)}{\partial X}$$

when $I$ is identity matrix

$$\frac{\partial \mathcal{L}}{\partial output} \frac{\partial output}{\partial X} = \frac{\partial \mathcal{L}}{\partial output} + \frac{\partial \mathcal{L}}{\partial output}\frac{\partial FFN}{\partial LN_2(X)} \frac{\partial LN_2(X)}{\partial X}$$

In computational way to find $\frac{\partial \mathcal{L}}{\partial X}$

ให้เขียนแบบแยก Residual เพื่อให้ไม่เผลอทำให้สมการวนเองและทำให้ implement ง่าย

ให้ $g_{out} = \frac{\partial \mathcal{L}}{\partial output}$

* ฝั่ง residual (จาก $X$ ใน $X + ...$) : 

$$g_X^{(res)} = g_{out}$$

* ฝั่ง FFN :

$$g_{ffn\_out} = g_{out}$$
$$g_{LN_2} = g_{ffn\_out} \cdot \frac{\partial FFN}{\partial LN_2(X)}$$

$$g_{X}^{(ffn)} = g_z \cdot \frac{\partial LN_2(X)}{\partial X}$$

* รวม :
$$ g_X = g_{out} + g_X^{(ffn)}$$



In [67]:
dFFN_input = dout + dg
print(dFFN_input.shape)
print(dFFN_input)

(2, 4, 6)
[[[ 0.70985293  0.35460159 -0.11310386 -0.1023153   0.38783763
   -0.37804704]
  [-4.02130002 -1.37312084 -6.90703573  0.92551009 -6.29145491
    0.56951941]
  [ 0.35266127  0.14223744 -0.39887975 -0.28114476  0.20329279
   -0.0783116 ]
  [ 0.          0.          0.          0.          0.
    0.        ]]

 [[-0.69539108 -1.22718497 -1.36181201  0.24805894 -3.98654154
    1.10685559]
  [ 0.09891493 -0.5530013  -1.00242245  0.54124034 -0.59062936
   -0.27960226]
  [ 0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.        ]]]


## pass to attention now 

$$ATTN(LN(X)) = concat(context)$$

$$context = softmax(scores)\ @\ V$$

$$scores = \frac {Q\ @\ K^\intercal}{\sqrt{d_K}} + \text{causal mask}$$

In [68]:
dATTN = dFFN_input

#### split $dATTN$ into $dContext$

In [70]:
print(dATTN.shape)

(2, 4, 6)


In [72]:
dContext = split_heads(dATTN, H=H)
print(dContext.shape)

(2, 2, 4, 3)


#### focus on first head for intuition

$$\frac{\partial \mathcal{L}}{\partial ATTN_{splited}} = dATTN\_splited$$

$$\frac{\partial Context}{\partial V} = softmax(scores)^\intercal$$

In [82]:
print(dContext.shape)
print(attn.shape)
print(Vsplit.shape)

(2, 2, 4, 3)
(2, 2, 4, 4)
(2, 2, 4, 3)


$$\frac{\partial \mathcal{L}}{\partial V} =softmax(scores)^\intercal\ @\ \frac{\partial \mathcal{L}}{\partial Context}$$

In [83]:
dV = attn.transpose(0, 1, 3, 2) @ dContext
print(dV.shape)
print(dV)

(2, 2, 4, 3)
[[[[-1.15353813 -0.27324737 -3.7017808 ]
   [-1.88478744 -0.63511489 -3.62727462]
   [ 0.07953975  0.03208044 -0.08996394]
   [ 0.          0.          0.        ]]

  [[ 0.13684485 -1.73564853 -0.20328229]
   [ 0.45205635 -3.99855355  0.32949324]
   [-0.04685117  0.03387758 -0.01305018]
   [ 0.          0.          0.        ]]]


 [[[-0.62434419 -1.62438504 -2.08181442]
   [ 0.02786805 -0.15580123 -0.28242003]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.62673379 -4.39977106  0.91123391]
   [ 0.16256549 -0.17739984 -0.08398058]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]]]


#### find $dW_V$

merge dV first :

In [85]:
dV_merge = merge_heads(dV)
print(dV_merge.shape)

(2, 4, 6)


In [88]:
Wv.shape

(1, 6, 6)

In [90]:
X.shape

(2, 4, 6)

from 
$$V = X\ @\ W_V$$

$$\frac{\partial \mathcal {L}}{\partial W_V} =\frac{\partial V}{\partial W_V}\ @\ \frac{\partial \mathcal{L}}{\partial V} = X^\intercal\ @\ \frac{\partial \mathcal{L}}{\partial V} $$

In [98]:
dWv = np.sum(X.transpose(0, 2, 1) @ dV_merge, axis=0, keepdims=True)
print(dWv.shape)

(1, 6, 6)


In [96]:
print(dWv)

[[[-2.67500296 -2.06836852 -6.84482651  1.07228382 -7.95097007
    0.83792509]
  [-1.61143192 -1.77276766 -5.680646    0.78895431 -5.68438846
    0.5218569 ]
  [-2.34563859 -2.21914959 -5.95306567  1.08807923 -7.93550276
    1.03616162]
  [-1.93790889 -1.20995435 -5.38618324  0.57768995 -5.02916034
    0.38115413]
  [-0.74554834 -1.42698325 -2.61129557  0.60280379 -3.99609865
    0.64970548]
  [-1.16556032 -0.62320125 -3.09419312  0.38355675 -2.84121483
    0.1354814 ]]]


#### find $dX_V$

$$\frac{\partial \mathcal {L}}{\partial X_V} =\frac{\partial \mathcal{L}}{\partial V}\ @\ \frac{\partial V}{\partial X_V} = \frac{\partial \mathcal{L}}{\partial V}\ @\ W_V^\intercal$$

In [100]:
dXv = dV_merge @ Wv.transpose(0, 2, 1)
print(dXv.shape)

(2, 4, 6)


In [101]:
print(dXv)

[[[-4.01245214e+00 -3.45993223e+00 -3.20775586e+00 -2.76905334e+00
   -1.43666016e+00 -2.65500809e+00]
  [-6.36317301e+00 -3.95476546e+00 -4.11967909e+00 -4.83180708e+00
   -2.13471805e+00 -3.71560753e+00]
  [ 3.18825279e-02 -5.25741343e-02 -1.99883783e-03  4.05485561e-02
    9.95028513e-03 -1.65220092e-02]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]

 [[-4.93539700e+00 -3.50062176e+00 -2.79336286e+00 -3.94254719e+00
   -1.45712863e+00 -3.84861302e+00]
  [-2.40758860e-01 -3.64684725e-01 -2.14789874e-01 -1.15868277e-01
   -4.64848506e-02 -2.87756441e-01]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]]
