In [1]:
import numpy as np
from transformer import FFN, FFN_init
from llm_operator import Relu, Relu_backward, Relu_cache

In [2]:
# Fixed seed so Restart & Run All reproduces every array printed below.
rng = np.random.default_rng(seed=1000)

In [3]:
# Toy problem size.
# C is the model width (d_model of FFN_init) and scale is the expansion factor
# (hidden width C * scale). B and T are the leading axes of the input X —
# presumably batch size and sequence length.
B = 2
T = 4
C = 6
scale = 4

```python
@dataclass
class FFN_init:

    gamma: npt.NDArray[np.float64]    # (1, 1, d_model) or (d_model, )
    beta: npt.NDArray[np.float64]     # (1, 1, d_model) or (d_model, )

    W_expand: npt.NDArray[np.float64] # (1, d_model, d_model * expand_scale) or (d_model, d_model * expand_scale)
    B_expand: npt.NDArray[np.float64] # (1, 1, d_model * expand_scale) or (d_model * expand_scale, )

    W_shrink: npt.NDArray[np.float64] # (1, d_model * expand_scale, d_model) or (d_model * expand_scale, d_model)
    B_shrink: npt.NDArray[np.float64] # (1, 1, d_model) or (d_model, )

    activate_function: ActForward
    activate_cache: Any
    activate_function_backward: ActBackward
```

In [4]:
# Randomly initialised FFN parameters with d_model = C and expand_scale = scale.
# Keyword arguments are evaluated left to right, so the rng draws happen in the
# order listed; the printed outputs further down depend on that order.
# NOTE(review): the FFN_init definition shown above declares a required
# `activate_cache` field, and `Relu_cache` is imported but never used —
# confirm whether `activate_cache=Relu_cache` should be passed here.
ffn_init = FFN_init(
    gamma=rng.random((1, 1, C), dtype=np.float64),                 # (1, 1, d_model)
    beta=rng.random((1, 1, C), dtype=np.float64),                  # (1, 1, d_model)
    W_expand=rng.random((1, C, C * scale), dtype=np.float64),      # (1, d_model, d_model*scale)
    B_expand=rng.random((1, 1, C * scale), dtype=np.float64),      # (1, 1, d_model*scale)
    W_shrink=rng.random((1, C * scale, C), dtype=np.float64),      # (1, d_model*scale, d_model)
    B_shrink=rng.random((1, 1, C), dtype=np.float64),              # (1, 1, d_model)
    activate_function=Relu,
    activate_function_backward=Relu_backward,
)

In [5]:
# Build the feed-forward layer from the randomly initialised parameters above.
ffn = FFN(ffn_init)

In [6]:
# Input activations (float64 is Generator.random's default dtype); last axis is C = d_model.
X = rng.random(size=(B, T, C))

In [7]:
# Training-mode forward pass; the displayed output has the same (B, T, C) shape as X.
# ffn.backward(G) below receives only G, so it presumably relies on state stored by this call.
ffn.forward_train(X)

array([[[23.92037609, 27.12168054, 29.07900373, 20.28730955,
         21.14195596, 23.28213877],
        [30.44040209, 33.68720757, 35.86784994, 25.98414175,
         25.94948   , 29.40350561],
        [26.39209186, 29.30837687, 30.77354879, 22.26838924,
         22.32515406, 24.59190305],
        [28.58661469, 31.78213916, 32.53897538, 24.18977211,
         24.38957199, 26.76424884]],

       [[24.91530708, 29.25993282, 30.81397355, 21.70128923,
         23.19372598, 25.14891973],
        [24.36061606, 28.74848629, 29.44555091, 21.17362974,
         22.34397382, 23.92737842],
        [30.67973127, 33.68741797, 35.78231698, 26.05975649,
         25.85361462, 29.29386107],
        [23.22530179, 26.14592901, 26.96809057, 19.4274736 ,
         20.3395868 , 21.5256862 ]]])

In [8]:
# Random upstream gradient dL/dY for ffn.backward, shaped like the forward output.
G = rng.random(size=(B, T, C), dtype=np.float64)

In [9]:
# Backward pass: gradient of sum(Y * G) w.r.t. the input X of the preceding
# forward_train call (grad_check_transformer below compares this against
# finite differences). The result has X's (B, T, C) shape.
ffn.backward(G)

array([[[  2.96768605,   7.2880039 ,  -0.81799509,  -3.74083424,
          -0.95927093,  -4.73758968],
        [ -2.67043368,   2.70710815,   0.74997587,  -1.66765337,
           6.66028163,  -5.7792786 ],
        [ 12.11453659,  23.41336281,   3.99240203, -21.83638369,
           6.64313643, -24.32705418],
        [  3.1708491 ,  14.36190434,   0.23878837,  -9.25356312,
          -1.34424801,  -7.17373067]],

       [[  6.98660297,  15.27436438,   2.90372001, -14.8476854 ,
           3.57535636, -13.89235832],
        [  2.47405259,  10.35561839,   0.58782065,  -8.36800163,
           4.40039945,  -9.44988945],
        [ -5.87698986,   6.55471776,   0.08986921,  -4.19506386,
          10.23022261,  -6.80275587],
        [  2.33757549,   4.01353889,   3.22509174,  -1.96933503,
           6.46399053, -14.07086161]]])

In [10]:
# Shape sanity check: G is (B, T, C) and W_shrink with its last two axes swapped
# is (1, C, C * scale), so G @ W_shrink^T broadcasts to (B, T, C * scale) —
# presumably the product the backward pass forms for the hidden-layer gradient.
print(G.shape, ffn.W_shrink.transpose(0, 2, 1).shape, sep="\n")

(2, 4, 6)
(1, 6, 24)


In [11]:
def grad_check_transformer(forward, backward, eps=1e-6, shape=(2, 4, 6), seed=1000):
    """Check an analytical input gradient against central finite differences.

    The scalar loss is ``L(X) = sum(forward(X) * G)`` for a fixed random upstream
    gradient ``G``. Therefore ``dL/dX`` is what ``backward(G)`` should return
    right after ``forward(X)``.

    Parameters
    ----------
    forward : callable
        ``forward(X) -> Y``. ``Y`` must multiply elementwise with ``G`` (an array
        of ``shape``).
    backward : callable
        ``backward(G) -> dX``. It receives only ``G``, so it relies on whatever
        state the preceding ``forward`` call stored.
    eps : float, default 1e-6
        Step size of the central difference.
    shape : tuple of int, default (2, 4, 6)
        Shape of the random input ``X`` and upstream gradient ``G``.
    seed : int, default 1000
        Seed of a private generator. Repeated checks see identical data, and the
        caller's generator is not consumed.

    Returns
    -------
    (max_abs, rel) : tuple of float
        ``max_abs`` is the largest absolute difference between the analytical and
        numerical gradients. ``rel`` is that difference divided by
        ``max(|dX| + |dX_num|)``. Both values are also printed.

    Raises
    ------
    ValueError
        If ``backward(G)`` does not return an array with the same shape as ``X``.
    """
    rng = np.random.default_rng(seed)

    # Draw X first, then G. For a given seed this order fixes the exact values,
    # and both arrays have entries in (-1, 0].
    X = -rng.random(shape, dtype=np.float64)
    G = -rng.random(shape, dtype=np.float64)  # upstream gradient

    def loss(x):
        return np.sum(forward(x) * G)

    # Analytical gradient. np.array copies the result, so the many forward calls
    # below cannot overwrite a buffer that the layer might have returned.
    forward(X)
    dX = np.array(backward(G))
    if dX.shape != X.shape:
        raise ValueError(
            f"backward(G) returned shape {dX.shape}, expected {X.shape}"
        )

    # Numerical gradient: perturb one entry at a time, then restore it exactly.
    dX_num = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        old = X[idx]
        X[idx] = old + eps
        loss_plus = loss(X)
        X[idx] = old - eps
        loss_minus = loss(X)
        X[idx] = old
        dX_num[idx] = (loss_plus - loss_minus) / (2 * eps)

    # The last forward call in the loop used a perturbed input. Run the clean X
    # once more so the layer's saved state matches it again.
    forward(X)

    max_abs = np.max(np.abs(dX - dX_num))
    # The tiny constant only guards against 0/0 when both gradients are exactly zero.
    rel = max_abs / (np.max(np.abs(dX) + np.abs(dX_num)) + 1e-22)

    print("max_abs_diff:", max_abs)
    print("relative_diff:", rel)
    return float(max_abs), float(rel)

In [12]:
# Small max_abs_diff and relative_diff (about 1e-9 relative in the run below)
# mean ffn.backward agrees with central finite differences of sum(Y * G).
# The trailing semicolon limits the cell's display to the printed diagnostics.
grad_check_transformer(forward=ffn.forward_train, backward=ffn.backward, eps=1e-6);

max_abs_diff: 8.66631473250834e-08
relative_diff: 2.551800855851061e-09
