In [1]:
import numpy as np

In [2]:
dh = np.array([[[ 0.0047309 , -0.02851535, -0.13561962,  0.07165096,
          0.01057472, -0.03511244],
        [ 0.04032968, -0.01704817,  0.07002992, -0.04101618,
         -0.05707668, -0.03169758],
        [-0.02885697,  0.04073668, -0.04297836, -0.02013535,
          0.04352404,  0.03589717],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]],

       [[ 0.05022935, -0.02123297,  0.08722008, -0.05108438,
         -0.07108724, -0.03947835],
        [-0.05051402,  0.07130938, -0.07523346, -0.03524686,
          0.07618864,  0.06283784],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]]])

In [3]:
rng = np.random.default_rng(1000)

In [23]:
eps = 1e-5

## Dive in math
---
from norm formula

$$ LN(x) = \frac{x - \mu}{\sigma + \epsilon} \gamma + \beta$$

from the cross-entropy and softmax gradient we got

$$dZ = \frac {\partial \mathcal {L}} {\partial Z} = \frac {\partial \mathcal {L}}{\partial P} \frac {\partial P} {\partial Z} = \frac{P - Y} {N}$$

where 
* $Z$ is logits
* $P$ is $softmax\left(Z\right)$
* $N$ is token count in loss after mask
* $\mathcal {L}$ is loss function cross entropy

at the LM Head we got

$$Z = h\ @\ W_{vocab} + b_{vocab}$$

where 
* $h \in \mathbb{R} ^ {B \times T \times d_{model}}$
* $W \in \mathbb{R} ^ {d_model \times V}$
* $b \in \mathbb{R} ^ {V}$

$$\frac {\partial \mathcal {L}}{\partial W} = dW = \sum_{b, t} h_{b,t}^\intercal dZ_{b, t}$$

$$\frac {\partial \mathcal {L}}{\partial b} = db = \sum_{b, t} dZ_{b, t}$$

and $dh$

$$ \frac {\partial \mathcal {L}} {\partial h} = \frac {\partial \mathcal {L}} {\partial Z} \cdot \frac {\partial Z} {\partial h} = dh$$

$$dh_{b, t} = dZ_{b, t} W^\intercal$$

We will start from here ->

$$h = LN(y)$$
when $y$ is an output matrix from transformer stack



$$LN(\mathbf{X}) = (\hat{x} \odot \gamma) \oplus \beta$$

where 

$$\hat{x} = \frac {\mathbf{X} - \mu} {\sqrt{\sigma^2 + \epsilon}}$$

What is $\odot$ operator do?

give 
$$\mathbf{A} = \begin{bmatrix} a_{1, 1} & a_{1, 2} \\ a_{2, 1} & a_{2, 2} \end{bmatrix}$$
and
$$\mathbf{B} = \begin{bmatrix} b_{1, 1} \\ b_{2, 1} \end{bmatrix}$$
then 
$$ \mathbf{A} \odot \mathbf{B} = \begin{bmatrix} 
    a_{1, 1} \times b_{1, 1} & a_{1, 2} \times b_{1, 1}\\
    a_{2, 1} \times b_{2, 1} & a_{2, 2} \times b_{2, 1}
    \end{bmatrix}$$

In [4]:
B = 2
V = 20
d_model = 6
T = 4

In [5]:
X = rng.random((B, T, d_model), np.float64)
print(X)

[[[0.52138574 0.60384185 0.4709418  0.20324794 0.52875903 0.19103628]
  [0.2815456  0.75368155 0.55167178 0.86372208 0.80537222 0.24837266]
  [0.18985741 0.98399558 0.66999717 0.28038283 0.20391323 0.62506469]
  [0.65260432 0.89880753 0.97476378 0.15393237 0.69908928 0.44724145]]

 [[0.01751321 0.29102491 0.38123661 0.32102791 0.94254467 0.70266697]
  [0.13645032 0.34320907 0.8119946  0.148494   0.05932569 0.31441663]
  [0.42015645 0.80801771 0.00950759 0.45408379 0.55868699 0.00288863]
  [0.29775757 0.05379911 0.56766875 0.94055815 0.72427372 0.85637809]]]


In [6]:
gamma = rng.random((d_model), np.float64);
beta = rng.random((d_model), np.float64);
print("gamma =>", gamma)
print("beta =>", beta)

gamma => [0.54431565 0.37965069 0.60442473 0.73460399 0.98840812 0.89224091]
beta => [0.51964178 0.20820684 0.29894051 0.83552541 0.18450649 0.18375199]


บางกรณีจะมีการแปลงเวคเตอร์ $\gamma$ เป็น diagonal matrix ตามด้านล่างเพื่อให้คูณได้ตามวิธีมาตรฐานของ math

In [7]:
gamma_diag = np.diag(gamma)
print(gamma_diag)

[[0.54431565 0.         0.         0.         0.         0.        ]
 [0.         0.37965069 0.         0.         0.         0.        ]
 [0.         0.         0.60442473 0.         0.         0.        ]
 [0.         0.         0.         0.73460399 0.         0.        ]
 [0.         0.         0.         0.         0.98840812 0.        ]
 [0.         0.         0.         0.         0.         0.89224091]]


In [8]:
print(X @ gamma_diag)

[[[0.28379842 0.22924897 0.28464887 0.14930675 0.52262972 0.17045038]
  [0.15324968 0.28613572 0.33344406 0.63449369 0.79603645 0.22160825]
  [0.10334236 0.3735746  0.40496285 0.20597035 0.2015495  0.55770828]
  [0.35522274 0.3412329  0.58917133 0.11307933 0.69098552 0.39904712]]

 [[0.00953271 0.11048781 0.23042883 0.23582839 0.93161881 0.62694822]
  [0.07427204 0.13029956 0.49078962 0.10908429 0.058638   0.28053538]
  [0.22869773 0.30676448 0.00574662 0.33357176 0.55221075 0.00257736]
  [0.1620741  0.02042487 0.34311303 0.69093777 0.71587803 0.76409557]]]


แต่ใน Python เราสามารถใช้เครื่องหมาย * ในการคูณตำแหน่งต่อตำแหน่งได้เลย (ฺBroadcasting)

$\mathbf{X} \odot \gamma = $

In [9]:
print(X * gamma) 

[[[0.28379842 0.22924897 0.28464887 0.14930675 0.52262972 0.17045038]
  [0.15324968 0.28613572 0.33344406 0.63449369 0.79603645 0.22160825]
  [0.10334236 0.3735746  0.40496285 0.20597035 0.2015495  0.55770828]
  [0.35522274 0.3412329  0.58917133 0.11307933 0.69098552 0.39904712]]

 [[0.00953271 0.11048781 0.23042883 0.23582839 0.93161881 0.62694822]
  [0.07427204 0.13029956 0.49078962 0.10908429 0.058638   0.28053538]
  [0.22869773 0.30676448 0.00574662 0.33357176 0.55221075 0.00257736]
  [0.1620741  0.02042487 0.34311303 0.69093777 0.71587803 0.76409557]]]


In [10]:
xhat = X * gamma + beta
print(xhat)

[[[0.8034402  0.43745581 0.58358938 0.98483216 0.70713621 0.35420238]
  [0.67289146 0.49434256 0.63238458 1.47001909 0.98054294 0.40536024]
  [0.62298414 0.58178144 0.70390337 1.04149576 0.38605599 0.74146028]
  [0.87486452 0.54943974 0.88811185 0.94860474 0.87549201 0.58279911]]

 [[0.52917449 0.31869464 0.52936935 1.0713538  1.1161253  0.81070021]
  [0.59391382 0.3385064  0.78973013 0.94460969 0.24314449 0.46428738]
  [0.74833951 0.51497132 0.30468713 1.16909717 0.73671725 0.18632935]
  [0.68171589 0.22863171 0.64205354 1.52646318 0.90038452 0.94784756]]]


from previous episode (LM-Head gradient)

$$ h = LM(\mathbf{X}) = (\hat{x}\odot\gamma) \oplus \beta $$

We got dh from previous episode &rarr;

In [11]:
dh

array([[[ 0.0047309 , -0.02851535, -0.13561962,  0.07165096,
          0.01057472, -0.03511244],
        [ 0.04032968, -0.01704817,  0.07002992, -0.04101618,
         -0.05707668, -0.03169758],
        [-0.02885697,  0.04073668, -0.04297836, -0.02013535,
          0.04352404,  0.03589717],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]],

       [[ 0.05022935, -0.02123297,  0.08722008, -0.05108438,
         -0.07108724, -0.03947835],
        [-0.05051402,  0.07130938, -0.07523346, -0.03524686,
          0.07618864,  0.06283784],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]]])

In [12]:
dh.shape

(2, 4, 6)

In [13]:
dh.reshape(B * T , d_model)

array([[ 0.0047309 , -0.02851535, -0.13561962,  0.07165096,  0.01057472,
        -0.03511244],
       [ 0.04032968, -0.01704817,  0.07002992, -0.04101618, -0.05707668,
        -0.03169758],
       [-0.02885697,  0.04073668, -0.04297836, -0.02013535,  0.04352404,
         0.03589717],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ],
       [ 0.05022935, -0.02123297,  0.08722008, -0.05108438, -0.07108724,
        -0.03947835],
       [-0.05051402,  0.07130938, -0.07523346, -0.03524686,  0.07618864,
         0.06283784],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ]])

$$d\beta = \frac{\partial \mathcal {L}} {\partial \beta} = \sum_{batch}\sum_{sequence} \frac{\partial \mathcal {L}}{\partial h} =\sum_{batch}\sum_{sequence} dh$$

In [14]:
dBeta = np.sum(dh.reshape(B*T, d_model), axis=0, keepdims=True)
dBeta

array([[ 0.01591894,  0.04524957, -0.09658144, -0.07583181,  0.00212348,
        -0.00755336]])

$$ d\gamma = \frac {\partial \mathcal {L}}{\partial \gamma} = \sum_{batch} \sum_{sequence} \left(\frac{\partial \mathcal {L}} {\partial h} \odot \hat{x}\right)$$

In [15]:
dh.shape, xhat.shape

((2, 4, 6), (2, 4, 6))

In [16]:
dgamma = np.sum((dh * xhat).reshape(B * T, d_model), axis=0, keepdims = True)
print(dgamma)
print("dgamma shape =", dgamma.shape)

[[ 0.00954017  0.02016985 -0.07835543 -0.09872525 -0.09250307 -0.00149981]]
dgamma shape = (1, 6)


$$d \hat{x} = \frac{\partial \mathcal{L}}{\partial \hat{x}} = \left( \frac{\partial \mathcal{L}} {\partial h} \odot \frac{\partial h}{\partial \hat{x}} \right) = \frac{\partial \mathcal{L}} {\partial h} \odot \gamma$$

In [17]:
dxhat = dh * gamma
print(dxhat)

[[[ 0.0025751  -0.01082587 -0.08197185  0.05263508  0.01045214
   -0.03132876]
  [ 0.02195208 -0.00647235  0.04232782 -0.03013065 -0.05641505
   -0.02828188]
  [-0.0157073   0.01546571 -0.02597718 -0.01479151  0.04301951
    0.03202892]
  [ 0.          0.          0.          0.          0.
    0.        ]]

 [[ 0.02734062 -0.00806111  0.05271797 -0.03752679 -0.07026321
   -0.0352242 ]
  [-0.02749557  0.02707266 -0.04547296 -0.02589248  0.07530547
    0.05606649]
  [ 0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.        ]]]


## Derived $\frac {\partial \hat{x}}{\partial \mathbf {X}}$

next is the hardest part of the transformer model &rarr;
$$\hat{x} = \frac {\mathbf{X} - \mu} {\sqrt{\sigma^2 + \epsilon}}$$

$$\frac{\partial \hat{x}}{\partial X} = ?$$

given :
$$u = \mathbf{X} - \mu$$
$$v = \sqrt{\sigma^2 + \epsilon}$$

$$\frac {\partial \hat{x}} {\partial \mathbf{X}} = \frac{\frac{\partial \left(\mathbf{X} - \mu\right) } {\partial \mathbf{X}} \cdot \sqrt{\sigma^2 + \epsilon} -  \left(\mathbf{X} - \mu\right)\cdot \frac{\partial \sqrt{\sigma^2 + \epsilon}}{\partial \mathbf{X}}}{\sigma^2 + \epsilon}$$

$$\frac{\partial \hat{x}_i}{\partial x_j} = \frac{1}{\sigma^2 + \epsilon}  \left(  \frac{\partial \left(x_i - \mu\right)}{\partial x_j} \cdot \sqrt {\sigma ^ 2 + \epsilon} - \left(x_i - \mu\right) \cdot \frac{\partial \sqrt{\sigma^2 + \epsilon}}{\partial x_j} \right)$$

* $\hat{x}_i$ หมายถึง ขาออกจากสมการ Norm แต่ละ feature / hidden dimension (มิติของ d_model)
* $x_j$ หมายถึง ขาเข้าสมการ Norm แต่ละเส้น

จากสมการนอร์ม $$\hat{x}_i = \frac {x_j - \mu} {\sqrt{\sigma^2 + \epsilon}}$$

เมื่อลองพิจารณาดูดี ๆ แล้วเราจะพบว่า $\hat{x}_i$ แต่ละตำแหน่งไม่ได้ขึ้นอยู่กับแค่ $x_j; \text{when } j = i$ เพียงแค่ค่าเดียวแต่ขึ้นกับตัวอื่นด้วย ผ่านทาง $\mu$ และ $\sigma$

ทำให้เราต้องหาการเปลี่ยนแปลงของ $\hat{x}_i$ เทียบกับ $x_j$ ทุกตัว หากว่า $\hat{x}_i$ เป็นเวคเตอร์ $m$ มิติ และ $x_j$ เป็นเวคเตอร์ $n$ มิติแล้วเราจะได้ว่า $$\frac{\partial \hat{x}_i}{\partial x_j} \in \mathbb{R}^{m \times n}$$



#### first part
---
$$\frac{\partial \left(\mathbf{X} - \mu\right) } {\partial \mathbf{X}}$$

$$ \frac{\partial \hat{
    \mathbf{x}}}{\partial \mathbf{x}} = \begin{bmatrix} 
    \frac{\partial \hat{x}_1}{\partial x_1} & \frac{\partial \hat{x}_1}{\partial x_2} & \cdots & \frac{\partial \hat{x}_1}{\partial x_n} \\
    \frac{\partial \hat{x}_2}{\partial x_1} & \frac{\partial \hat{x}_2}{\partial x_2} & \cdots & \frac{\partial \hat{x}_2}{\partial x_n} \\
    \vdots & \vdots & \ddots & \vdots \\
    \frac{\partial \hat{x}_m}{\partial x_1} & \frac{\partial \hat{x}_m}{\partial x_2} & \cdots & \frac{\partial \hat{x}_m}{\partial x_n} \end{bmatrix}$$

where $m = n = d_{model}$

$$\frac {\partial x_i}{\partial x_j} = \delta_{i, j}$$

and $$\frac{\partial \mu} {\partial \mathbf{x_j}} = \frac{\partial}{\partial x_j} \frac {x_1 + x_2 + \cdots + x_n}{d} = \frac{1}{d}$$

or $$\frac{\partial \mu} {\partial \mathbf{x}_j} = \left[\frac{1}{d}\right]^{1 \times d_{model}}$$

then $$\frac{\partial \left(x_i - \mu\right)}{\partial x_j} = \delta_{i, j} - \frac{1}{d}$$

#### Second part
---

$$\frac{\partial \sqrt{\sigma^2 + \epsilon}}{\partial x_j}$$

use chain rule:
$$\frac{\partial \sqrt{\sigma^2 + \epsilon}}{\partial x_j} = \frac{1}{2 \cdot \sqrt{\sigma^2 + \epsilon}} \cdot \frac{\partial \sigma^2}{\partial x_j}$$


where $$\sigma^2 = \frac{1}{d}\sum_k \left( x_k - \mu \right)^2$$


$$\frac{\partial \sigma^2}{\partial x_j} = \frac{2}{d} \sum_k \left(x_k - \mu\right) \frac{\partial \left(x_k - \mu\right)}{\partial x_j}


$$\frac{\partial \sigma^2}{\partial x_j} = \frac{2}{d}\sum_k \left(x_k - \mu\right) \left(\delta_{jk} - \frac{1}{d}\right)$$

จัดรูป

$$ \frac{\partial \sigma^2} {\partial x_j} = \frac{2} {d} \left[(x_j - \mu) \left(1 - \frac{1}{d}\right) + \sum_{j \neq k} \left(x_k - \mu \right) \cdot - \frac {1}{d} \right]$$


กระจาย $\left(x_j - \mu \right)$ เข้าไปคูณ $\left(1 - \frac{1}{d}\right)$ ได้ว่า

$$ \frac{\partial \sigma^2} {\partial x_j} = \frac{2} {d} \left[(x_j - \mu) - \left( \frac{x_j - \mu}{d}\right) + \sum_{j \neq k} - \frac {\left(x_k - \mu \right)}{d} \right]$$


สังเกตุที่พจน์ $$- \left( \frac{x_j - \mu}{d}\right) + \sum_{j \neq k} - \frac {\left(x_k - \mu \right)}{d} = \sum_{k} - \frac {\left(x_k - \mu \right)}{d}$$

จากความเป็นจริงของสถิติ :
$$\sum_{k}\left(x_k - \mu \right) = 0$$

ทำให้ :

$$ \frac{\partial \sigma^2} {\partial x_j} = \frac{2(x_j - \mu)} {d} $$

then finally we got:

$$\frac{\partial \sqrt{\sigma^2 + \epsilon}}{\partial x_j} = \frac{1}{2 \cdot \sqrt{\sigma^2 + \epsilon}} \cdot \frac{\partial \sigma^2}{\partial x_j}$$

$$\frac{\partial \sqrt{\sigma^2 + \epsilon}}{\partial x_j} = \frac{x_j - \mu}{d \cdot \sqrt{\sigma^2 + \epsilon}} $$


### Combined it together

$$\frac {\partial \hat{x}} {\partial \mathbf{X}} = \frac{\frac{\partial \left(\mathbf{X} - \mu\right) } {\partial \mathbf{X}} \cdot \sqrt{\sigma^2 + \epsilon} -  \left(\mathbf{X} - \mu\right)\cdot \frac{\partial \sqrt{\sigma^2 + \epsilon}}{\partial \mathbf{X}}}{\sigma^2 + \epsilon}$$

$$ \frac {\partial \hat{x}_i}{\partial x_j} = \frac{1}{\sigma^2 +\epsilon} \left[\left( \delta_{i, j} - \frac{1}{d}\right) \cdot \sqrt{\sigma^2 + \epsilon} - \left(x_i - \mu\right) \cdot \frac{x_j - \mu}{d \cdot \sqrt{\sigma^2 + \epsilon}}\right]$$

$$ \frac {\partial \hat{x}_i}{\partial x_j} = \frac{1}{\sigma^2 +\epsilon} \left[\left( \delta_{i, j} - \frac{1}{d}\right) \cdot \sqrt{\sigma^2 + \epsilon} - \frac{\sqrt{\sigma^2 + \epsilon}}{d} \cdot  \frac{\left(x_i - \mu\right)}{ \sqrt{\sigma^2 + \epsilon}} \frac{\left(x_j - \mu\right)}{ \sqrt{\sigma^2 + \epsilon}}\right]$$

$$ \frac {\partial \hat{x}_i}{\partial x_j} = \frac{1}{\sqrt{\sigma^2 +\epsilon}} \left[\left( \delta_{i, j} - \frac{1}{d}\right)  - \frac{1}{d} \cdot  \frac{\left(x_i - \mu\right)}{ \sqrt{\sigma^2 + \epsilon}} \frac{\left(x_j - \mu\right)}{ \sqrt{\sigma^2 + \epsilon}}\right]$$






When : 
$$\hat{x}_i = \frac{\left(x_i - \mu\right)}{ \sqrt{\sigma^2 + \epsilon}}$$
$$\hat{x}_j = \frac{\left(x_j - \mu\right)}{ \sqrt{\sigma^2 + \epsilon}}$$

$$ \frac {\partial \hat{x}_i}{\partial x_j} = \frac{1}{\sqrt{\sigma^2 +\epsilon}} \left(\delta_{i, j} - \frac{1}{d} - \frac{\hat{x}_i \hat{x}_j}{d}\right)$$

---

## finalized จริง ๆ ละ

$$ \frac{\partial \mathcal{L}}{\partial x_j} = \sum_i \frac{\partial \mathcal {L}} {\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial x_j}$$

สมการด้านบนดูง่ายแต่มีความยากในการคำนวณ (ใช้ Resource เยอะ เพราะ $\frac{\partial \hat{x}_i}{\partial x_j}$ มีมิติ $d_{model} \times d_{model}$ หรือ $O(d_{model}^2)$) จึงต้องจัดรูปใหม่เพื่อให้คอมพิวเตอร์คำนวณง่ายขึ้น

$$\frac{\partial \mathcal{L}} {\partial x_j} = \frac{1}{d \cdot \sqrt{\sigma^2 + \epsilon}} \sum_i \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \left( d \cdot \delta_{i, j} - 1 - \hat{x}_i \hat{x}_j \right)

$$\frac{\partial \mathcal{L}} {\partial x_j} = \frac{1}{d \cdot \sqrt{\sigma^2 + \epsilon}} \sum_i  \left( d \cdot \delta_{i, j} \frac{\partial \mathcal{L}}{\partial \hat{x}_i} - \frac{\partial \mathcal{L}}{\partial \hat{x}_i} - \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \hat{x}_i \hat{x}_j \right)

$$\sum_i d \cdot \delta_{i, j} \frac{\partial \mathcal {L}}{\partial \hat{x}_i} = d \cdot \frac{\partial \mathcal {L}}{\partial \hat{x}_j}$$

$$ \sum_i \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \hat{x}_i \hat{x}_j = \hat{x}_j \sum_i \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \hat{x}_i$$

สมการที่เราจะนำไปใช้

$$\frac{\partial \mathcal{L}} {\partial x_j} = \frac{1}{d_{model} \cdot \sqrt{\sigma^2 + \epsilon}} \left( d_{model} \cdot \frac{\partial \mathcal {L}}{\partial \hat{x}_j} - \sum_{i = 1}^{d_{model}}  \frac{\partial \mathcal{L}}{\partial \hat{x}_i}  - \hat{x}_j \odot \sum_{i = 1}^{d_{model}} \left( \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \odot \hat{x}_i \right) \right)

**การหาค่าเฉลี่ยและส่วนเบี่ยงเบนมาตรฐาน ต้องทำในแนว hidden dimension หรือ feature และต้อง keepdims = True เท่านั้น**

## Demo code

* $d$ is d_model
* $\frac{1}{\sqrt{\sigma^2 + \epsilon}}$ is inv_std
* upstream gradient ที่ไหลมาจากด้านบนเรียก $g_i = d\hat{x}_i$

$$dx_j = \sum_i g_i \frac{\partial \hat{x}_i}{\partial x_j} = \sum_i \frac {\partial \mathcal {L}} {\partial {\hat{x}_i}} \frac{\partial \hat{x}_i} {\partial x_j}$$

In [20]:
mu = np.mean(X, axis=-1, keepdims = True)
print(mu)

[[[0.41986877]
  [0.58406098]
  [0.49220182]
  [0.63773979]]

 [[0.44266905]
  [0.30231505]
  [0.37555686]
  [0.5734059 ]]]


In [22]:
var = np.mean((X - mu)**2, axis=-1, keepdims=True)
print(var)

[[[0.0263177 ]
  [0.0602019 ]
  [0.08508576]
  [0.07601425]]

 [[0.08996642]
  [0.06196798]
  [0.08359805]
  [0.09727354]]]


In [24]:
inv_std = 1 / np.sqrt(var + eps)
print(inv_std)

[[[6.16301906]
  [4.07529299]
  [3.42804124]
  [3.62680273]]

 [[3.33377015]
  [4.01681008]
  [3.45840579]
  [3.20612394]]]


In [27]:
dGamma = np.sum(dh * xhat , axis=(0, 1))
print(dGamma)

[ 0.00954017  0.02016985 -0.07835543 -0.09872525 -0.09250307 -0.00149981]


In [26]:
dBeta = np.sum(dh, axis=(0, 1))
print(dBeta)

[ 0.01591894  0.04524957 -0.09658144 -0.07583181  0.00212348 -0.00755336]


In [28]:
dxhat = dh * gamma
print(dxhat)

[[[ 0.0025751  -0.01082587 -0.08197185  0.05263508  0.01045214
   -0.03132876]
  [ 0.02195208 -0.00647235  0.04232782 -0.03013065 -0.05641505
   -0.02828188]
  [-0.0157073   0.01546571 -0.02597718 -0.01479151  0.04301951
    0.03202892]
  [ 0.          0.          0.          0.          0.
    0.        ]]

 [[ 0.02734062 -0.00806111  0.05271797 -0.03752679 -0.07026321
   -0.0352242 ]
  [-0.02749557  0.02707266 -0.04547296 -0.02589248  0.07530547
    0.05606649]
  [ 0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.        ]]]


In [32]:
sum_dxhat = np.sum(dxhat, axis= -1, keepdims = True)
print(sum_dxhat)

[[[-0.05846416]
  [-0.05702004]
  [ 0.03403815]
  [ 0.        ]]

 [[-0.07101671]
  [ 0.0595836 ]
  [ 0.        ]
  [ 0.        ]]]


In [33]:
sum_dxhat_xhat = np.sum(dxhat * xhat, axis=-1, keepdims=True)
print(sum_dxhat_xhat)

[[[-0.00237371]
  [-0.0727351 ]
  [ 0.00587766]
  [ 0.        ]]

 [[-0.10737707]
  [-0.02319432]
  [ 0.        ]
  [ 0.        ]]]


In [30]:
dx = 1 / d_model * inv_std * (d_model * dxhat - np.sum(dxhat, axis = -1, keepdims=True) - xhat * np.sum(dxhat * xhat, axis=-1, keepdims = True) )

In [31]:
dx

array([[[ 0.07788198, -0.00560083, -0.44371856,  0.38684485,
          0.1261935 , -0.13216348],
        [ 0.16143276,  0.03677408,  0.24246872, -0.01143926,
         -0.14273741, -0.05650211],
        [-0.07538471,  0.03161602, -0.11086203, -0.07365076,
          0.12672887,  0.08785918],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]],

       [[ 0.16217772,  0.03159888,  0.2467916 , -0.02172796,
         -0.1281925 , -0.02960269],
        [-0.14111161,  0.07411266, -0.21028277, -0.12922675,
          0.26637395,  0.19252851],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]]])