In [1]:
import numpy as np

In [2]:
rng = np.random.default_rng(4000)

In [3]:
def split_heads(x, H):
    # x: (B, T, C) => (B, H, T,d_head)
    B, T, C = x.shape
    d_head = C // H
    x = x.reshape(B, T, H, d_head)
    return x.transpose(0, 2, 1, 3)

In [4]:
def merge_heads(x):
    # x: (B, H, T, d_head) => (B, T, C)
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H*d_head)

In [5]:
def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2 ).mean(axis = -1, keepdims = True)
    xhat = (x - mu ) / np.sqrt(var + eps)
    return xhat * gamma + beta

In [6]:
def softmax(x, axis = -1) :
    x = x - np.max(x, axis = axis, keepdims = True)
    ex = np.exp(x)
    return ex / np.sum(ex, axis = axis, keepdims = True)

In [7]:
def causal_mask(T):
    # (1, 1, T, T) ใช้ broadcast กับ (B, H, T, T)
    m = np.triu(np.ones((T, T), dtype=bool), k = 1)
    return m[None, None, :, :]

In [8]:
dout = np.array([[[ 0.07788198, -0.00560083, -0.44371856,  0.38684485,
          0.1261935 , -0.13216348],
        [ 0.16143276,  0.03677408,  0.24246872, -0.01143926,
         -0.14273741, -0.05650211],
        [-0.07538471,  0.03161602, -0.11086203, -0.07365076,
          0.12672887,  0.08785918],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]],

       [[ 0.16217772,  0.03159888,  0.2467916 , -0.02172796,
         -0.1281925 , -0.02960269],
        [-0.14111161,  0.07411266, -0.21028277, -0.12922675,
          0.26637395,  0.19252851],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]]])

In [9]:
dout.shape

(2, 4, 6)

In [10]:
B = 2 # Batch size
T = 4 # Token
V = 7 # Vocab size
d_model = 6 # hidden
H = 2 # attention head count

## Pre-LN type transformer architecture
---

$$out = FFN(LN(y)) + y$$
$$y = ATTN(LN(x)) + x$$

where x is input of attention

FFN definition from Attention is All you need: 
    $$FFN(x) =  max(0, x \cdot W_1 + b_1)W_2 + b_2$$

$$ATTN(x) = concat(attn_i(x), axis = -1)$$
$$attn = softmax(\frac{Q \cdot K^\intercal}{\sqrt{d_k} }+ causal\_mask) \cdot V$$

## Forwarding

In [11]:
X = rng.random((B, T, d_model), np.float64)
print(X)

[[[0.45461138 0.9079045  0.22341803 0.75301023 0.21061655 0.28646412]
  [0.9110329  0.06668155 0.80997244 0.44144683 0.01151965 0.43382742]
  [0.31255883 0.80263996 0.22221152 0.46048076 0.05819513 0.49784596]
  [0.46341809 0.94756117 0.88801601 0.33144693 0.52629353 0.91218153]]

 [[0.76991742 0.84385534 0.95556605 0.44375354 0.79734196 0.11813618]
  [0.8017991  0.88149331 0.6328389  0.11407255 0.44154406 0.5999176 ]
  [0.43327192 0.17617901 0.97017048 0.99389211 0.24304429 0.63569   ]
  [0.99131382 0.56461089 0.35934087 0.96484629 0.96325469 0.93638915]]]


In [12]:
gammaAttnNorm = rng.random((1, 1, 6), np.float64)
betaAttnNorm = rng.random((1, 1, 6), np.float64)

In [13]:
Xnorm = layer_norm(X, gamma=gammaAttnNorm, beta=betaAttnNorm)
Xnorm

array([[[ 0.52011077,  0.38861154, -0.4005128 ,  0.46378381,
         -0.27132719, -0.24127436],
        [ 0.76100334,  0.00913523,  0.99734685,  0.01587232,
         -0.41553674,  0.29061284],
        [ 0.47478749,  0.40632862, -0.25930322,  0.1444342 ,
         -0.47708986,  0.68283324],
        [ 0.38557457,  0.31668842,  0.84078051, -0.57880595,
         -0.10807404,  1.09172606]],

       [[ 0.59817875,  0.25602827,  0.97605325, -0.29136199,
          0.40521678, -1.1970939 ],
        [ 0.67894445,  0.33129603,  0.39501747, -0.76263651,
         -0.07392747,  0.38780843],
        [ 0.45800281, -0.00642622,  1.0968379 ,  0.57181923,
         -0.29679539,  0.47054655],
        [ 0.66380303,  0.03343025, -1.00056959,  0.31335035,
          0.48970649,  0.78199846]]])

In [14]:
Xrs = X.reshape(B * T, d_model)
print(Xrs)

[[0.45461138 0.9079045  0.22341803 0.75301023 0.21061655 0.28646412]
 [0.9110329  0.06668155 0.80997244 0.44144683 0.01151965 0.43382742]
 [0.31255883 0.80263996 0.22221152 0.46048076 0.05819513 0.49784596]
 [0.46341809 0.94756117 0.88801601 0.33144693 0.52629353 0.91218153]
 [0.76991742 0.84385534 0.95556605 0.44375354 0.79734196 0.11813618]
 [0.8017991  0.88149331 0.6328389  0.11407255 0.44154406 0.5999176 ]
 [0.43327192 0.17617901 0.97017048 0.99389211 0.24304429 0.63569   ]
 [0.99131382 0.56461089 0.35934087 0.96484629 0.96325469 0.93638915]]


In [15]:
dk = d_model // H
dq = d_model // H
dv = d_model // H

In [16]:
Wk = rng.random((d_model, d_model), np.float64)[None, :, :]
Wq = rng.random((d_model, d_model), np.float64)[None, :, :]
Wv = rng.random((d_model, d_model), np.float64)[None, :, :]

In [17]:
K = Xnorm @ Wk
print(K)

[[[ 0.75364815 -0.33172428  0.37443807  0.84211588  0.92705112
    0.35322833]
  [ 0.48641203  0.15671827  0.76140575  0.36125282  0.90758711
    1.11170233]
  [ 0.70076475  0.05389832  0.44205931  0.4708287   1.30921318
    1.04988157]
  [ 0.66741741  0.80488449  0.92366316  0.09698369  0.95324974
    1.796865  ]]

 [[ 0.59226206 -0.10060461  0.77761451  0.64814133 -0.30366746
    0.19367871]
  [ 0.54656635  0.22337732  0.59020276 -0.02152924  0.52641593
    1.1219031 ]
  [ 0.7490698   0.50819744  1.04714757  0.84381965  1.23699247
    1.33688153]
  [ 1.48802283  0.68961316  0.92778691  0.88861011  1.36888263
    1.30644642]]]


In [18]:
Q = Xnorm @ Wq
V = Xnorm @ Wv

In [19]:
Ksplit = split_heads(K, H)
Qsplit = split_heads(Q, H)
Vsplit = split_heads(V, H)

In [20]:
Ksplit.shape

(2, 2, 4, 3)

In [21]:
scores = (Qsplit @ Ksplit.transpose(0, 1, 3, 2))/np.sqrt(dk)

In [22]:
scores.shape

(2, 2, 4, 4)

In [23]:
m = causal_mask(T)
scores = np.where(m, -1e9, scores)
scores.shape

(2, 2, 4, 4)

In [24]:
attn = softmax(scores, axis = -1)
print(attn)

[[[[1.         0.         0.         0.        ]
   [0.44244559 0.55755441 0.         0.        ]
   [0.25474096 0.40506062 0.34019842 0.        ]
   [0.1236585  0.22991513 0.1756287  0.47079767]]

  [[1.         0.         0.         0.        ]
   [0.52907567 0.47092433 0.         0.        ]
   [0.32852691 0.33223791 0.33923518 0.        ]
   [0.28939822 0.22416113 0.27927587 0.20716478]]]


 [[[1.         0.         0.         0.        ]
   [0.47925302 0.52074698 0.         0.        ]
   [0.26236041 0.24947453 0.48816506 0.        ]
   [0.16227187 0.18249404 0.27826576 0.37696832]]

  [[1.         0.         0.         0.        ]
   [0.5476553  0.4523447  0.         0.        ]
   [0.23192524 0.24668675 0.52138801 0.        ]
   [0.10319673 0.19003574 0.34841212 0.3583554 ]]]]


In [25]:
attn.shape

(2, 2, 4, 4)

In [26]:
context = attn @ Vsplit
print(context.shape)
print(context)

(2, 2, 4, 3)
[[[[ 0.11889747  0.35708279  0.43460387]
   [ 0.47490764  0.4969105   0.37326554]
   [ 0.48343317  0.44976986  0.33305024]
   [ 0.81914372  0.33959522  0.19313138]]

  [[-0.12802176  0.17105823 -0.21362499]
   [ 0.32269054  0.26074511  0.33869517]
   [ 0.26963483  0.25978137  0.333508  ]
   [ 0.4238376   0.36279201  0.4971091 ]]]


 [[[ 0.66586537  0.05968577  0.28005171]
   [ 0.77855504 -0.04485205  0.15768309]
   [ 0.76909218  0.48378599  0.32249732]
   [ 0.68667808  0.4223773   0.42577563]]

  [[ 0.87218112  0.61094387  0.52454246]
   [ 0.76590158  0.52734491  0.64630513]
   [ 0.93968807  0.60578931  0.86949989]
   [ 0.73404693  0.66683416  0.70461989]]]]


In [27]:
mergeAttnBias = rng.random((1, 1, 6), np.float64) * 2 - 1
mergeAttnBias

array([[[ 0.04349354, -0.9691173 , -0.00458403,  0.78509866,
         -0.51202507, -0.53125553]]])

In [28]:
output = merge_heads(context) + mergeAttnBias
print(output.shape)
print(output)

(2, 4, 6)
[[[ 0.16239102 -0.61203451  0.43001984  0.6570769  -0.34096684
   -0.74488052]
  [ 0.51840119 -0.4722068   0.3686815   1.1077892  -0.25127996
   -0.19256035]
  [ 0.52692671 -0.51934744  0.32846621  1.05473349 -0.25224369
   -0.19774753]
  [ 0.86263726 -0.62952208  0.18854734  1.20893626 -0.14923306
   -0.03414643]]

 [[ 0.70935892 -0.90943153  0.27546768  1.65727978  0.0989188
   -0.00671306]
  [ 0.82204859 -1.01396935  0.15309906  1.55100024  0.01531985
    0.11504961]
  [ 0.81258572 -0.48533131  0.31791329  1.72478673  0.09376424
    0.33824437]
  [ 0.73017162 -0.54674     0.4211916   1.51914559  0.15480909
    0.17336437]]]


In [29]:
g = Xnorm + output
print(g.shape)
print(g)

(2, 4, 6)
[[[ 0.68250178 -0.22342298  0.02950703  1.12086071 -0.61229402
   -0.98615488]
  [ 1.27940452 -0.46307157  1.36602835  1.12366153 -0.6668167
    0.09805249]
  [ 1.0017142  -0.11301882  0.06916299  1.19916769 -0.72933356
    0.48508571]
  [ 1.24821183 -0.31283366  1.02932785  0.63013031 -0.25730709
    1.05757963]]

 [[ 1.30753766 -0.65340326  1.25152093  1.36591779  0.50413558
   -1.20380696]
  [ 1.50099304 -0.68267332  0.54811653  0.78836373 -0.05860763
    0.50285804]
  [ 1.27058853 -0.49175753  1.41475119  2.29660596 -0.20303115
    0.80879091]
  [ 1.39397465 -0.51330975 -0.57937798  1.83249594  0.64451558
    0.95536283]]]


In [30]:
expand_scale = 4

In [31]:

W_expand = rng.random((d_model, d_model * expand_scale), np.float64)[None, :, :]
b_expand = rng.random((d_model * expand_scale), np.float64)[None, None, :]

print(W_expand.shape)
print(b_expand.shape)

(1, 6, 24)
(1, 1, 24)


In [32]:
gammaNormFFN = rng.random((d_model), np.float64)
betaNormFFN = rng.random((d_model), np.float64)

gNorm = layer_norm(g, gamma = gammaNormFFN, beta = betaNormFFN)
print(gNorm)

[[[ 0.47084769  0.74583645  0.37016924  1.07378707  0.151618
    0.31902038]
  [ 0.47373066  0.7052754   1.22617038  0.58366479 -0.05580794
    0.63969551]
  [ 0.4770831   0.72826242  0.03101142  0.93144688 -0.15787975
    0.87159563]
  [ 0.48054791  0.69006328  0.93879409  0.12974595 -0.04043644
    1.05202934]]

 [[ 0.46552071  0.70700265  0.99987158  0.66517736  0.5412354
    0.23729744]
  [ 0.51416836  0.67743029  0.47666211  0.40357621  0.20547289
    0.8207324 ]
  [ 0.43605184  0.68983229  0.81995798  1.04773141  0.04743694
    0.77137448]
  [ 0.46460614  0.69728521 -0.74357156  0.93491436  0.52026355
    0.91131355]]]


In [33]:
#pre activation
pre = gNorm @ W_expand + b_expand

In [34]:
# calculate then pass through Relu function
act = np.maximum(pre, 0)
print(act.shape)
print(act)

(2, 4, 24)
[[[2.5605826  2.82399276 2.36503944 2.24696054 1.94733265 2.58281995
   1.57701695 2.02213092 1.76399306 2.59368172 3.00130511 2.71078044
   1.48863431 2.91034982 2.85769399 1.13161516 2.46659003 2.22992261
   2.04577871 2.10614509 2.1175794  2.55766515 2.80782586 2.36646168]
  [2.43414425 3.20365116 2.310536   2.09059281 1.96245397 3.0556605
   1.68248098 2.67639502 2.03355415 2.50796865 3.47718069 2.84642221
   1.66192048 3.5847145  3.13045176 1.67098604 2.99193515 2.08160409
   2.07315552 2.97728323 2.22496864 2.57811662 3.47834439 2.31747641]
  [2.47736241 2.47376523 2.01206032 2.00259335 2.26368469 2.45420598
   1.18712148 2.1518142  1.73947086 2.38947223 3.30147227 2.34258166
   1.35692225 2.69452886 2.38776459 0.74931382 2.65299942 2.17256958
   1.967079   2.13618449 1.73438032 2.13267493 2.688782   1.79136933]
  [2.32512394 2.89244483 1.72524168 1.70380205 2.10797357 2.86236961
   1.15821619 2.68282664 1.90589122 2.17520994 3.56461382 2.490676
   1.69527705 3.5067264

In [35]:
W_shrink = rng.random((d_model * expand_scale, d_model), np.float64)[None, :, :]
b_shrink = rng.random((d_model), np.float64)[None, None, :]

In [36]:
y_hat = act @ W_shrink + b_shrink
print(y_hat)

[[[28.14140088 27.20974863 27.51909607 27.07242418 28.99656047
   28.88155279]
  [31.00618175 29.51301419 29.83683917 29.1025531  32.3310899
   32.95429599]
  [26.33500001 25.1549939  25.21528128 24.78193407 27.22526089
   27.06673537]
  [28.05944342 26.30083469 26.37082051 25.95547653 29.6052129
   30.20566315]]

 [[30.5634335  29.37198446 29.88456614 29.57601361 31.88782155
   32.27087502]
  [26.94593634 25.52056279 25.66368488 25.49584779 28.29772388
   28.5190477 ]
  [32.40186095 31.0491131  31.53607936 30.83102152 33.57023041
   33.99744602]
  [24.15721279 23.1547287  23.29108026 23.59418796 25.13025549
   24.56538243]]]


In [37]:
out = y_hat + g # X_FFN = attn + X_input_transformer
print(out)

[[[28.82390267 26.98632565 27.5486031  28.19328489 28.38426644
   27.89539791]
  [32.28558627 29.04994262 31.20286752 30.22621463 31.66427321
   33.05234848]
  [27.33671421 25.04197507 25.28444427 25.98110176 26.49592733
   27.55182108]
  [29.30765524 25.98800103 27.40014836 26.58560684 29.34790581
   31.26324279]]

 [[31.87097116 28.7185812  31.13608707 30.94193139 32.39195714
   31.06706805]
  [28.44692938 24.83788947 26.21180141 26.28421152 28.23911625
   29.02190573]
  [33.67244949 30.55735557 32.95083055 33.12762748 33.36719926
   34.80623694]
  [25.55118744 22.64141895 22.71170227 25.42668391 25.77477107
   25.52074526]]]


Backwarding
---

start from $out$

$$\text{output from transformer} = out = \text{output from FFN} + \text{output from attention}$$
$$\text{output from FFN} = y$$

$$\frac{\partial out} {\partial y} = 1 $$
$$\frac{\partial out} {\partial \text{output from attention}} = 1$$

$$dy = \frac{\partial \mathcal{L}}{\partial y} = \frac{\partial \mathcal{L}}{\partial out} \frac{\partial out}{\partial y} = \frac{\partial \mathcal{L}}{\partial out}$$

#### find $db_{shrink}$ and $d\hat{y}$

$$db_{shrink} = \frac {\partial \mathcal{L}}{\partial b_{shrink}} = \frac{\partial \mathcal{L}}{\partial out} \frac{\partial out} {\partial \hat{y}} \frac{\partial \hat{y}}{\partial b_{shrink}}$$


$$d\hat{y} = \frac{\partial \mathcal{L}}{\partial \hat{y}} = \frac{\partial \mathcal{L}}{\partial out} \frac{\partial out}{\partial \hat{y}} $$

$$out = \hat{y} + g$$
$$\hat{y} = FFN(g) = act\ @\ W_{shrink} \oplus B_{shrink}$$
$$act = Relu(pre)$$
$$pre = \hat{x} \ @\ W_{expand} \oplus B_{expand}$$
$$\hat{x} = LN(g)

where

* $out = FFN(LN(y)) + y$

* $\mathcal{R}(\hat{y}) = Relu(\hat{y})$ then $\frac{\partial \mathcal{R}(\hat{y})}{\partial \hat{y}} = \begin{bmatrix} 0;\ \text{if}\ \hat{y} \leq 0\\ 1;\ \text{if}\ \hat{y} > 0 \end{bmatrix}^{B \times T \times d_{model}} $
* $\hat{y} = \hat{x} \cdot W_{shrink} + b_{shrink}$ 

then
$$\frac{\partial \hat{y}}{\partial b_{shrink}} = 1$$

$$\frac{\partial \mathcal{L}} {\partial \hat{y}} = \frac{\partial \mathcal{L}} {\partial out} \odot \begin{bmatrix} 0;\ \text{if}\ \hat{y} \leq 0\\ 1;\ \text{if}\ \hat{y} > 0 \end{bmatrix}^{B \times T \times d_{model}} $$

after checking dimension

$$\frac{\partial \mathcal{L}} {\partial b_{shrink}} = \sum_{b, t} \left(\frac{\partial \mathcal{L}} {\partial \hat{y}}  \right)$$

We must sum over batch and token for get $b_{shrink}$


In [61]:
print(y_hat.shape)
print(dout.shape)

(2, 4, 6)
(2, 4, 6)


In [41]:
dyhat = dout
print(dyhat.shape)

(2, 4, 6)


In [None]:
# print(np.sum(dyhat, axis=(0, 1), keepdims=True))
# print(dyhat)

In [43]:
dbshrink = np.sum(dout, axis=(0, 1), keepdims=True)
dbshrink

array([[[ 0.18499614,  0.16850081, -0.27560304,  0.15080012,
          0.24836641,  0.06211941]]])

#### find $dW_{shrink}$

from basic chain rule calculus :
$$dW_{shrink} = \frac{\partial \mathcal {L}}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial W_{shrink}}$$

$$\hat{y} = act\ @\ W_{shrink} \oplus B_{shrink}$$

$$\frac{\partial \hat{y}}{\partial W_{shrink}} = act^\intercal$$

after shape check We got:

$$dW_{shrink} = \sum_b \left[\frac{\partial \hat{y}}{\partial W_{shrink}} @ \frac{\partial \mathcal{L}}{\partial \hat{y}}\right]$$

We must sum over batch for get $dW_{shrink}$

In [63]:
print((dout * (y_hat > 0)).shape)
print(act.transpose(0, 2, 1).shape)
print(W_shrink.shape)

(2, 4, 6)
(2, 24, 4)
(1, 24, 6)


In [47]:
dW_shrink_ns = act.transpose(0, 2, 1) @ dyhat # K = X_expand
dW_shrink = np.sum(dW_shrink_ns, axis = 0)
print(dW_shrink.shape)
# print(dW_shrink)

(24, 6)


#### find $dP$

In [48]:
act.shape

(2, 4, 24)

In [49]:
W_shrink.shape

(1, 24, 6)

In [50]:
dyhat.shape

(2, 4, 6)

$$\hat{y} = Relu\left(P\right)\ @\ W_{shrink} \oplus B_{shrink}$$

$$\frac{\partial \mathcal{R}(P)}{\partial K} = \begin{bmatrix} 0;\ \text{if}\ P \leq 0\\ 1;\ \text{if}\ P > 0 \end{bmatrix}^{B \times T \times d_{model\ expand}} $$

$$\frac{\partial \hat{y}}{\partial \mathcal{R}(P)} = W_{shrink}^\intercal$$

$$\frac{\partial \mathcal{L}}{\partial P} = \left(\frac{\partial \mathcal{L}} {\partial \hat{y}} \odot \frac{\partial \mathcal{R}(P)}{\partial P}\right)\ @\ \frac{\partial \hat{y}}{\partial \mathcal{R}(P)}$$

$$\frac{\partial \mathcal{L}}{\partial P} = \left(\frac{\partial \mathcal{L}} {\partial \hat{y}}  \odot \begin{bmatrix} 0;\ \text{if}\ P \leq 0\\ 1;\ \text{if}\ P > 0 \end{bmatrix}^{B \times T \times d_{model\ expand}} \right)\ @\ W_{shrink}^\intercal$$

In [51]:
print(dyhat.shape)
print(W_shrink.shape)

(2, 4, 6)
(1, 24, 6)


In [52]:
(dyhat @ W_shrink.transpose(0, 2, 1)).shape

(2, 4, 24)

In [53]:
dP = (dyhat * (K > 0)) @ W_shrink.transpose(0, 2, 1)
dP.shape

(2, 4, 24)

In [54]:
dP[1, :, :]

array([[ 0.09254627,  0.11374827,  0.14890023,  0.361972  ,  0.27509936,
         0.30831747,  0.18469182,  0.1264584 ,  0.07065802,  0.25579232,
         0.16540285,  0.22431094,  0.03883917, -0.00982095,  0.24843259,
         0.2386919 ,  0.13247461,  0.20188788,  0.23107782,  0.13955727,
         0.14000969,  0.09093389,  0.26470852,  0.24108853],
       [ 0.09383805,  0.16192291,  0.12122446, -0.06523674, -0.0601916 ,
        -0.04517194,  0.01352094,  0.28599826,  0.27297169,  0.06547371,
         0.1295933 , -0.06751397,  0.37309481,  0.17567669,  0.01578171,
         0.0032237 ,  0.07233978,  0.04661951, -0.06110677,  0.24435766,
         0.2178707 ,  0.18593711,  0.22071681, -0.09817053],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
  

#### find $db_{expand}$

$$P = Relu\left(\hat{g}\right) \ @\ W_{expand} \oplus B_{expand}$$

In [55]:
b_expand.shape

(1, 1, 24)

In [56]:
dP.shape

(2, 4, 24)

$$\frac{\partial \mathcal{L}}{\partial B_{expand}} = \sum_{b, t} \left[ \frac{\partial \mathcal{L}}{\partial P} \frac{\partial P}{\partial b_{expand}} \right]$$

$$\frac{\partial P}{\partial b_{expand}} = 1$$

In [57]:
dBexpand = np.sum(dP, axis = (0, 1), keepdims=True)
print(dBexpand.shape)
print(dBexpand)

(1, 1, 24)
[[[0.15591544 0.42743279 0.4269225  0.36648836 0.36186341 0.29687424
   0.02160679 0.34945075 0.39040559 0.4631125  0.42132805 0.37029689
   0.55943213 0.22046049 0.31935185 0.03408293 0.38088539 0.28494046
   0.26295628 0.6247452  0.68063961 0.47846425 0.62438797 0.41609452]]]


#### find $dW_{expand}$

In [58]:
W_expand.shape

(1, 6, 24)

In [60]:
print(dP.shape)
print(X_FFN_N.shape) # \hat{g} = X_FFN_N

(2, 4, 24)


NameError: name 'X_FFN_N' is not defined

$$P = \hat{g} \ @\ W_{expand} \oplus B_{expand}$$
$$\hat{g} = LN(g)

$$\frac{\partial P}{\partial W_{expand}} = \hat{g}^\intercal$$

$$\frac{\partial \mathcal{L}}{\partial W_{expand}} = \sum_b \left[\frac{\partial P}{\partial W_{expand}}\ @\ \frac{\partial \mathcal{L}}{\partial P} \right]$$

We need to sum over batch

In [None]:
dW_expand = np.sum(X_FFN_N.transpose(0, 2, 1) @ dP, axis = (0), keepdims = True)
print(dW_expand.shape)
print(W_expand.shape)

(1, 6, 24)
(1, 6, 24)


In [None]:
#print(dW_expand)

#### find $d\hat{g}$

In [None]:
print(X_FFN_N.shape)
print(dP.shape)
print(W_expand.transpose(0, 2, 1).shape)

(2, 4, 6)
(2, 4, 24)
(1, 24, 6)


from basic calculus

$$\frac{\partial \mathcal{L}}{\partial \hat{g}} = \frac {\partial \mathcal{L}}{\partial P} \frac{\partial P}{\partial \hat{g}}$$

$$ \frac{\partial P} {\partial \hat{g}} = W_{expand}^\intercal$$

after shape check : 
$$\frac{\partial \mathcal{L}}{\partial \hat{g}} = \frac {\partial \mathcal{L}}{\partial P}\ @\ \frac{\partial P}{\partial \hat{g}}$$

In [None]:
dghat = dP @ W_expand.transpose(0, 2, 1)
print(dghat.shape)

(2, 4, 6)


#### find $dg$ the most hardest one

$$ d\hat{g} = LN(g)$$

from old proof result :
$$ \frac {\partial \hat{g}_i}{\partial g_j} = \frac{1}{\sqrt{\sigma^2 +\epsilon}} \left(\delta_{i, j} - \frac{1}{d} - \frac{\hat{g}_i \hat{g}_j}{d}\right)$$

สมการที่เราจะนำไปใช้

$$\frac{\partial \mathcal{L}} {\partial g_j} = \frac{1}{d_{model} \cdot \sqrt{\sigma^2 + \epsilon}} \left( d_{model} \cdot \frac{\partial \mathcal {L}}{\partial \hat{g}_j} - \sum_{i = 1}^{d_{model}}  \frac{\partial \mathcal{L}}{\partial \hat{g}_i}  - \hat{g}_j \odot \sum_{i = 1}^{d_{model}} \left( \frac{\partial \mathcal{L}}{\partial \hat{g}_i} \odot \hat{g}_i \right) \right)

let focus on $\sum_{i = 1}^{d_{model}}  \frac{\partial \mathcal{L}}{\partial \hat{g}_i}$ and $\sum_{i = 1}^{d_{model}} \left( \frac{\partial \mathcal{L}}{\partial \hat{g}_i} \odot \hat{g}_i \right)$

In [None]:
sum_dghat = np.sum(dghat, axis=-1, keepdims=True)
print(sum_dghat.shape)

(2, 4, 1)


In [None]:
dghat_ghat = dghat * X_FFN_N
sum_dghat_ghat = np.sum(sum_dghat, axis=-1, keepdims=True)
print(dghat_ghat.shape)
print(sum_dghat_ghat)

(2, 4, 6)
[[[ 0.10582762]
  [ 8.05604222]
  [-0.08013945]
  [ 0.        ]]

 [[13.14283015]
  [ 7.469995  ]
  [ 0.        ]
  [ 0.        ]]]


In [None]:
X_FFN_N * sum_dghat_ghat

array([[[ 4.98286931e-02,  7.89300993e-02,  3.91741309e-02,
          1.13636335e-01,  1.60453731e-02,  3.37611687e-02],
        [ 3.81639419e+00,  5.68172836e+00,  9.87808031e+00,
          4.70202821e+00, -4.49591145e-01,  5.15341405e+00],
        [-3.82331753e-02, -5.83625470e-02, -2.48523796e-03,
         -7.46456370e-02,  1.26523953e-02, -6.98491910e-02],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00, -0.00000000e+00,  0.00000000e+00]],

       [[ 6.11825964e+00,  9.29201578e+00,  1.31411423e+01,
          8.74231312e+00,  7.11336493e+00,  3.11876000e+00],
        [ 3.84083506e+00,  5.06040090e+00,  3.56066361e+00,
          3.01471230e+00,  1.53488149e+00,  6.13086693e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]])

In [None]:
sigma = np.std(X_FFN_N, axis=-1, keepdims=True)
epsilon = 1e-5
inv_deriv = 1 / (d_model * (sigma ** 2 + epsilon))
print(inv_deriv)

[[[1.78991956]
  [1.17511731]
  [0.97339223]
  [1.04984681]]

 [[3.04226182]
  [4.35544061]
  [1.63657814]
  [0.51606866]]]


In [None]:
dg = inv_deriv * (d_model * dghat - sum_dghat - X_FFN_N * sum_dghat_ghat )
print(dg.shape)
print(dg)

(2, 4, 6)
[[[  0.97997072  -0.16652285  -2.10783724  -1.55173116   2.30296505
    -0.04998054]
  [ -2.8818212   -7.92454198 -10.87078277  -1.79104115  -1.72678326
    -8.62731939]
  [  0.19124787   0.09310451   0.45285393  -1.2424103   -0.27184143
     1.00182446]
  [  0.           0.           0.           0.           0.
     0.        ]]

 [[-10.57370359 -34.00716466 -35.99693221 -16.50121414 -30.73078506
   -16.77629679]
  [ -7.54138014 -29.46001436  -8.83065497 -17.81658203 -13.05791069
   -24.08863357]
  [  0.           0.           0.           0.           0.
     0.        ]
  [  0.           0.           0.           0.           0.
     0.        ]]]


---

## Backward enter to the attention

from the Pre-LN structure we currently at
$$g = X + attentionResult$$

In [None]:
dg.shape

(2, 4, 6)

In [None]:
out.shape # attn result + X + result from FFN

(2, 4, 6)

from the output of attention in short term:

$$output = X + FFN(LN_2(X))$$
$$X = Input + ATTN(LN_1(Input))$$

$$\frac{\partial output}{\partial X} = I + \frac{\partial FFN}{\partial LN_2(X)}\frac{\partial LN_2(X)}{\partial X}$$

when $I$ is identity matrix

$$\frac{\partial \mathcal{L}}{\partial output} \frac{\partial output}{\partial X} = \frac{\partial \mathcal{L}}{\partial output} + \frac{\partial \mathcal{L}}{\partial output}\frac{\partial FFN}{\partial LN_2(X)} \frac{\partial LN_2(X)}{\partial X}$$

In computational way to find $\frac{\partial \mathcal{L}}{\partial X}$

ให้เขียนแบบแยก Residual เพื่อให้ไม่เผลอทำให้สมการวนเองและทำให้ implement ง่าย

ให้ $g_{out} = \frac{\partial \mathcal{L}}{\partial output}$

* ฝั่ง residual (จาก $X$ ใน $X + ...$) : 

$$g_X^{(res)} = g_{out}$$

* ฝั่ง FFN :

$$g_{ffn\_out} = g_{out}$$
$$g_{LN_2} = g_{ffn\_out} \cdot \frac{\partial FFN}{\partial LN_2(X)}$$

$$g_{X}^{(ffn)} = g_z \cdot \frac{\partial LN_2(X)}{\partial X}$$

* รวม :
$$ g_X = g_{out} + g_X^{(ffn)}$$



In [None]:
dFFN_input = dout + dg
print(dFFN_input.shape)
print(dFFN_input)

(2, 4, 6)
[[[  1.0578527   -0.17212368  -2.5515558   -1.16488631   2.42915855
    -0.18214402]
  [ -2.72038844  -7.8877679  -10.62831405  -1.80248041  -1.86952067
    -8.6838215 ]
  [  0.11586316   0.12472053   0.3419919   -1.31606106  -0.14511256
     1.08968364]
  [  0.           0.           0.           0.           0.
     0.        ]]

 [[-10.41152587 -33.97556578 -35.75014061 -16.5229421  -30.85897756
   -16.80589948]
  [ -7.68249175 -29.3859017   -9.04093774 -17.94580878 -12.79153674
   -23.89610506]
  [  0.           0.           0.           0.           0.
     0.        ]
  [  0.           0.           0.           0.           0.
     0.        ]]]


## pass to attention now 

$$ATTN(LN(X)) = concat(context)$$

$$context = attn\ @\ V$$

$$attn = softmax(scores)$$

split before operate @ at this step
$$scores = \frac {Q\ @\ K^\intercal}{\sqrt{d_K}} + \text{causal mask}$$

In [None]:
dATTN = dFFN_input

#### split $dATTN$ into $dContext$

In [None]:
print(dATTN.shape)

(2, 4, 6)


In [None]:
dContext = split_heads(dATTN, H=H)
print(dContext.shape)

(2, 2, 4, 3)


#### focus on first head for intuition

$$\frac{\partial \mathcal{L}}{\partial ATTN_{splited}} = dATTN\_splited$$

$$\frac{\partial Context}{\partial V} = softmax(scores)^\intercal$$

In [None]:
print(dContext.shape)
print(attn.shape)
print(Vsplit.shape)

(2, 2, 4, 3)
(2, 2, 4, 4)
(2, 2, 4, 3)


$$\frac{\partial \mathcal{L}}{\partial V} =softmax(scores)^\intercal\ @\ \frac{\partial \mathcal{L}}{\partial Context}$$

In [None]:
dV = attn.transpose(0, 1, 3, 2) @ dContext
print(dV.shape)
print(dV)

(2, 2, 4, 3)
[[[[-1.16256063e-01 -3.63026034e+00 -7.16688709e+00]
   [-1.46983298e+00 -4.34734044e+00 -5.78733596e+00]
   [ 3.94164637e-02  4.24297264e-02  1.16345104e-01]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[-2.55089632e+00  1.39236727e+00 -4.41855230e+00]
   [-1.28607725e+00 -9.28614660e-01 -3.72738861e+00]
   [-4.46454207e-01 -4.92272854e-02  3.69659022e-01]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]]]


 [[[-1.40933832e+01 -4.80588479e+01 -4.00830373e+01]
   [-4.00063438e+00 -1.53026196e+01 -4.70804102e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[-2.63510594e+01 -3.78643304e+01 -2.98927281e+01]
   [-8.11769150e+00 -5.78618386e+00 -1.08092765e+01]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]]]]


#### find $dW_V$

merge dV first :

In [None]:
dV_merge = merge_heads(dV)
print(dV_merge.shape)

(2, 4, 6)


In [None]:
Wv.shape

(1, 6, 6)

In [None]:
X.shape

(2, 4, 6)

from 
$$V = X\ @\ W_V$$

$$\frac{\partial \mathcal {L}}{\partial W_V} =\frac{\partial V}{\partial W_V}\ @\ \frac{\partial \mathcal{L}}{\partial V} = X^\intercal\ @\ \frac{\partial \mathcal{L}}{\partial V} $$

In [None]:
dWv = np.sum(X.transpose(0, 2, 1) @ dV_merge, axis=0, keepdims=True)
print(dWv.shape)

(1, 6, 6)


In [None]:
print(dWv)

[[[-15.43804388 -54.86863683 -43.12976893 -29.26776579 -34.02016355
   -36.97075784]
  [-15.59123215 -57.59563381 -44.77376634 -32.15224294 -35.88979674
   -38.71691176]
  [-17.20665478 -59.93035976 -47.54436249 -32.02817941 -40.29560219
   -39.32913064]
  [ -7.42859586 -27.7050975  -26.22201367 -15.31355099 -16.84661084
   -19.30048375]
  [-13.04282567 -45.88832009 -35.60805735 -25.17318128 -32.46598093
   -29.55954416]
  [ -4.71632316 -17.76261088 -12.06553274  -9.49390078  -7.97288376
   -12.71485475]]]


#### find $dX_V$

$$\frac{\partial \mathcal {L}}{\partial X_V} =\frac{\partial \mathcal{L}}{\partial V}\ @\ \frac{\partial V}{\partial X_V} = \frac{\partial \mathcal{L}}{\partial V}\ @\ W_V^\intercal$$

In [None]:
dXv = dV_merge @ Wv.transpose(0, 2, 1)
print(dXv.shape)

(2, 4, 6)


In [None]:
print(dXv)

[[[-7.43729644e+00 -2.50752730e+00 -5.83894625e+00 -6.55832461e+00
   -5.80127809e+00 -4.37458854e+00]
  [-7.74376043e+00 -4.32025735e+00 -6.06056059e+00 -7.04344685e+00
   -7.02860481e+00 -4.96742217e+00]
  [ 7.57156717e-02 -1.51675251e-01 -8.06214168e-02 -3.92865369e-02
   -1.88398155e-01  2.26156821e-02]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]

 [[-8.17923757e+01 -6.86978422e+01 -7.73256392e+01 -8.04234764e+01
   -1.01655834e+02 -6.22134715e+01]
  [-2.01065269e+01 -1.52312097e+01 -2.23442024e+01 -1.92139394e+01
   -2.40871258e+01 -1.73065470e+01]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]]]


#### find $dA$ or dAttention

In [None]:
dContext.shape

(2, 2, 4, 3)

In [None]:
Vsplit.shape

(2, 2, 4, 3)

$$\frac{\partial \mathcal{L}}{\partial A} =\frac{\partial \mathcal{L}}{\partial Context} \ @\ \frac{\partial Context}{\partial attn} = \frac{\partial \mathcal{L}}{\partial Context}\ @\ V^\intercal$$

In [None]:
dA =  dContext @ Vsplit.transpose(0, 1, 3, 2) 
print(dA.shape)

(2, 2, 4, 4)


#### find $dScores$

from gradient softmax equation : 

$$\frac{\partial \mathcal{L}} {\partial x_j} = p_j \left( \frac{\partial \mathcal{L}}{\partial S(x_j)} - \sum_i \left( p_i \odot \frac{\partial \mathcal{L}}{\partial S(x_i)} \right)  \right)$$

where :
* $p_j$ is attention
* $x_j$ is scores

In [None]:
sum_A_dA = np.sum(attn * dA, axis=-1, keepdims=True)
print(sum_A_dA.shape)

(2, 2, 4, 1)


In [None]:
dScores = attn * (dA - sum_A_dA)
print(dScores.shape)

(2, 2, 4, 4)


#print(dScores)

#### find $dQ$

$$\frac{\partial Scores}{\partial Q} = \frac{K}{\sqrt{d_K}}$$

In [None]:
Ksplit.shape

(2, 2, 4, 3)

In [None]:
dScores.shape

(2, 2, 4, 4)

$$\frac{\partial \mathcal{L}}{\partial Q} = \frac{\partial \mathcal{L}}{\partial Scores}\ @\ \frac{\partial Scores}{\partial Q} 
    = \frac{\partial \mathcal{L}}{\partial Scores}\ @\ \frac{K}{\sqrt{d_K}}$$

In [None]:
dQ = dScores @ Ksplit / np.sqrt(dk)
print(dQ.shape)

(2, 2, 4, 3)


#### find $dW_Q$

from 
$$ Q = X @ W_Q$$

then

$$\frac{\partial \mathcal {L}}{\partial Q} = \frac{\partial \mathcal{L}}{\partial Q} @ \frac{\partial Q}{\partial W_Q} = \frac{\partial \mathcal{L}}{\partial Q}\ @\ X^\intercal$$

In [None]:
merge_heads(dQ).shape

(2, 4, 6)

In [None]:
dQ.shape

(2, 2, 4, 3)

In [None]:
X.shape

(2, 4, 6)

In [None]:
dWq = np.sum(X.transpose(0, 2, 1) @ merge_heads(dQ), axis=(0), keepdims=True)
print(dWq.shape)

(1, 6, 6)


#### find $dX_Q$

from 
$$ Q = X @ W_Q$$

then

$$\frac{\partial \mathcal {L}}{\partial X_Q} = \frac{\partial \mathcal{L}}{\partial Q} @ \frac{\partial Q}{\partial X_Q} = \frac{\partial \mathcal{L}}{\partial Q}\ @\ W_Q^\intercal$$

In [None]:
Wq.shape

(1, 6, 6)

In [None]:
merge_heads(dQ).shape

(2, 4, 6)

In [None]:
dXq = merge_heads(dQ) @ Wq.transpose(0, 2, 1)

#### find $K$

from
$$scores = \frac {Q\ @\ K^\intercal}{\sqrt{d_K}} + \text{causal mask}$$

In [None]:
dScores.shape

(2, 2, 4, 4)

In [None]:
Ksplit.shape

(2, 2, 4, 3)

$$\frac{\partial scores} {\partial K^\intercal} = \frac{Q^\intercal}{\sqrt{d_K}}$$

In [None]:
dK = ((Qsplit.transpose(0, 1, 3, 2) / np.sqrt(dk) ) @ dScores).transpose(0, 1, 3, 2)

In [None]:
print(dK.shape)

(2, 2, 4, 3)


#### find $dW_K$

$$ K = X @ W_K$$

$$\frac{\partial \mathcal {L}}{\partial W_K} = X^\intercal\ @\ \frac{\partial \mathcal{L}}{\partial K}$$

In [None]:
Wk.shape

(1, 6, 6)

In [None]:
X.shape

(2, 4, 6)

In [None]:
dWk = np.sum(X.transpose(0, 2, 1) @ merge_heads(dK), axis=0, keepdims=True)
print(dWk.shape)

(1, 6, 6)


In [None]:
print(dWk)

[[[-0.17935903 -0.07379323 -0.14400758 -0.91333167  0.03073453
   -0.36477625]
  [ 0.36318901  0.21756336  0.3267497   1.6798761  -0.05577867
    0.6686711 ]
  [-0.35030301 -0.39009919 -0.40854123 -1.17882025  0.03662745
   -0.46731342]
  [ 0.02246844 -0.19399796 -0.08732353  0.61095559 -0.02322782
    0.24470773]
  [-0.03334504 -0.24139602 -0.14472397  0.38781607 -0.01597566
    0.1565735 ]
  [ 0.09607527  0.3538624   0.23997033 -0.28085891  0.01347865
   -0.1146539 ]]]


#### find $dX_K$

In [None]:
Wk.shape

(1, 6, 6)

In [None]:
merge_heads(dK).shape

(2, 4, 6)

$$\frac{\partial \mathcal{L}}{\partial X_K} = \frac{\partial \mathcal{L}}{\partial K}\ @\ W_K^\intercal$$

In [None]:

dXk = merge_heads(dK) @ Wk.transpose(0, 2, 1)

### merge dX together

In [None]:
dXk

array([[[ 2.28190223e+00,  3.32201153e+00,  1.00404008e+00,
          2.68353296e+00,  2.74213251e+00,  1.30060156e+00],
        [-2.29092147e+00, -3.33652361e+00, -1.00510459e+00,
         -2.69443062e+00, -2.74173995e+00, -1.30112496e+00],
        [ 9.01923789e-03,  1.45120776e-02,  1.06450916e-03,
          1.08976515e-02, -3.92558889e-04,  5.23399972e-04],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[-7.51312887e-01, -9.47059576e-01, -5.63267846e-01,
         -7.21042089e-01, -1.54825926e+00, -8.84453573e-01],
        [ 7.51312887e-01,  9.47059576e-01,  5.63267846e-01,
          7.21042089e-01,  1.54825926e+00,  8.84453573e-01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]])

In [None]:
dXq

array([[[ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [-0.23289814,  0.42042467,  0.40815389, -0.78767123,
         -0.28326874, -0.44140456],
        [ 0.00287003,  0.00981757,  0.00204284,  0.00510431,
          0.01183887,  0.00884284],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]],

       [[ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [ 0.02299618,  0.13356724, -0.12146742, -0.06821528,
         -0.05300299,  0.14694192],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]]])

In [None]:
dXv

array([[[-7.43729644e+00, -2.50752730e+00, -5.83894625e+00,
         -6.55832461e+00, -5.80127809e+00, -4.37458854e+00],
        [-7.74376043e+00, -4.32025735e+00, -6.06056059e+00,
         -7.04344685e+00, -7.02860481e+00, -4.96742217e+00],
        [ 7.57156717e-02, -1.51675251e-01, -8.06214168e-02,
         -3.92865369e-02, -1.88398155e-01,  2.26156821e-02],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[-8.17923757e+01, -6.86978422e+01, -7.73256392e+01,
         -8.04234764e+01, -1.01655834e+02, -6.22134715e+01],
        [-2.01065269e+01, -1.52312097e+01, -2.23442024e+01,
         -1.92139394e+01, -2.40871258e+01, -1.73065470e+01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]])

In [None]:
dX = dXk + dXv + dXq

In [None]:
dX

array([[[-5.15539420e+00,  8.14484236e-01, -4.83490617e+00,
         -3.87479165e+00, -3.05914558e+00, -3.07398698e+00],
        [-1.02675800e+01, -7.23635629e+00, -6.65751129e+00,
         -1.05255487e+01, -1.00536135e+01, -6.70995170e+00],
        [ 8.76049346e-02, -1.27345601e-01, -7.75140654e-02,
         -2.32845723e-02, -1.76951840e-01,  3.19819256e-02],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[-8.25436886e+01, -6.96449017e+01, -7.78889070e+01,
         -8.11445185e+01, -1.03204094e+02, -6.30979250e+01],
        [-1.93322178e+01, -1.41505828e+01, -2.19024019e+01,
         -1.85611126e+01, -2.25918695e+01, -1.62751515e+01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]]])