In [1]:
import numpy as np

In [2]:
# Seeded generator: every random draw below (inputs and weights) is reproducible across reruns.
rng = np.random.default_rng(seed=4000)

In [3]:
def split_heads(x, H):
    """Split the channel axis of ``x`` into ``H`` attention heads.

    Parameters
    ----------
    x : np.ndarray
        Array of shape (B, T, C).
    H : int
        Number of heads; must be positive and divide C exactly.

    Returns
    -------
    np.ndarray
        Array of shape (B, H, T, C // H).

    Raises
    ------
    ValueError
        If H is not positive or C is not divisible by H. Checking up front
        gives a clear message instead of an opaque reshape/zero-division error.
    """
    B, T, C = x.shape
    if H <= 0 or C % H != 0:
        raise ValueError(
            f"channel dim C={C} must be divisible by a positive head count H={H}"
        )
    d_head = C // H
    # (B, T, C) -> (B, T, H, d_head): consecutive groups of d_head channels form one head
    x = x.reshape(B, T, H, d_head)
    # -> (B, H, T, d_head) so each head is a separate (T, d_head) slice
    return x.transpose(0, 2, 1, 3)

In [4]:
def merge_heads(x):
    """Inverse of ``split_heads``: (B, H, T, d_head) -> (B, T, H * d_head).

    Heads are concatenated along the last axis in head order.
    """
    batch, heads, steps, head_dim = x.shape
    # Move the head axis next to d_head so each token's heads sit side by side.
    per_token = np.swapaxes(x, 1, 2)
    return per_token.reshape(batch, steps, heads * head_dim)

In [5]:
def layer_norm(x, gamma, beta, eps=1e-5):
    """Layer normalization over the last axis followed by an affine transform.

    Computes ``(x - mean) / sqrt(var + eps) * gamma + beta``. The mean and the
    biased (population) variance are taken over axis -1. ``eps`` keeps the
    division finite for constant rows.
    """
    centered = x - x.mean(axis=-1, keepdims=True)
    variance = (centered ** 2).mean(axis=-1, keepdims=True)
    return centered / np.sqrt(variance + eps) * gamma + beta

In [6]:
def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis``.

    Subtracting the per-slice maximum does not change the result
    mathematically, but it keeps ``np.exp`` from overflowing on large logits.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=axis, keepdims=True)
    return weights / total

In [7]:
def causal_mask(T):
    """Boolean mask of shape (1, 1, T, T), True strictly above the diagonal.

    Entry [i, j] is True when j > i, i.e. key position j lies after query
    position i. The two leading singleton axes let the mask broadcast against
    scores of shape (B, H, T, T).
    """
    # np.tri is True on and below the diagonal; its complement is the strict upper triangle.
    future = ~np.tri(T, dtype=bool)
    return future.reshape(1, 1, T, T)

In [8]:
# Hard-coded upstream gradient dL/dout, shape (2, 4, 6) (shown by the next cell),
# i.e. (B, T, d_model) with the dimensions defined further below.
# How these values were produced is not shown here.
# NOTE(review): the all-zero token rows presumably mark padded/ignored positions — TODO confirm.
dout = np.array([[[ 0.07788198, -0.00560083, -0.44371856,  0.38684485,
          0.1261935 , -0.13216348],
        [ 0.16143276,  0.03677408,  0.24246872, -0.01143926,
         -0.14273741, -0.05650211],
        [-0.07538471,  0.03161602, -0.11086203, -0.07365076,
          0.12672887,  0.08785918],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]],

       [[ 0.16217772,  0.03159888,  0.2467916 , -0.02172796,
         -0.1281925 , -0.02960269],
        [-0.14111161,  0.07411266, -0.21028277, -0.12922675,
          0.26637395,  0.19252851],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ]]])

In [9]:
np.shape(dout)  # expected (B, T, d_model) = (2, 4, 6); those constants are set in a later cell

(2, 4, 6)

In [10]:
# Dimensions of this toy example.
B = 2  # batch size
T = 4  # tokens (sequence length)
vocab_size = 7  # vocabulary size; deliberately not called `V`, which is reused below for the value tensor
d_model = 6  # hidden size
H = 2  # number of attention heads
# Each head gets d_model // H channels, so the split must be exact.
assert d_model % H == 0, "d_model must be divisible by H"

## Pre-LN Transformer architecture
---

$$out = FFN(LN(y)) + y$$
$$y = ATTN(LN(x)) + x$$

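where $x$ is the input to the attention sub-layer, and $y$ (the attention sub-layer's residual output) is the input to the FFN sub-layer.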

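FFN used in this notebook. *Attention Is All You Need* defines $FFN(x) = \max(0, x W_1 + b_1) W_2 + b_2$, with a single ReLU. The version here also applies a ReLU to the output, and the backward derivation below relies on that outer ReLU:
$$FFN(x) = \max(0, \max(0, x \cdot W_1 + b_1) W_2 + b_2)$$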

$$ATTN(x) = concat(attn_i(x), axis = -1)$$
$$attn = softmax(\frac{Q \cdot K^\intercal}{\sqrt{d_k} }+ causal\_mask) \cdot V$$

## Forward pass

In [11]:
# Random input activations of shape (B, T, d_model), drawn uniformly from [0, 1).
X = rng.random(size=(B, T, d_model), dtype=np.float64)
print(X)

[[[0.45461138 0.9079045  0.22341803 0.75301023 0.21061655 0.28646412]
  [0.9110329  0.06668155 0.80997244 0.44144683 0.01151965 0.43382742]
  [0.31255883 0.80263996 0.22221152 0.46048076 0.05819513 0.49784596]
  [0.46341809 0.94756117 0.88801601 0.33144693 0.52629353 0.91218153]]

 [[0.76991742 0.84385534 0.95556605 0.44375354 0.79734196 0.11813618]
  [0.8017991  0.88149331 0.6328389  0.11407255 0.44154406 0.5999176 ]
  [0.43327192 0.17617901 0.97017048 0.99389211 0.24304429 0.63569   ]
  [0.99131382 0.56461089 0.35934087 0.96484629 0.96325469 0.93638915]]]


In [12]:
# Flatten batch and time into a single row axis: (B * T, d_model).
Xrs = X.reshape(-1, d_model)
print(Xrs)

[[0.45461138 0.9079045  0.22341803 0.75301023 0.21061655 0.28646412]
 [0.9110329  0.06668155 0.80997244 0.44144683 0.01151965 0.43382742]
 [0.31255883 0.80263996 0.22221152 0.46048076 0.05819513 0.49784596]
 [0.46341809 0.94756117 0.88801601 0.33144693 0.52629353 0.91218153]
 [0.76991742 0.84385534 0.95556605 0.44375354 0.79734196 0.11813618]
 [0.8017991  0.88149331 0.6328389  0.11407255 0.44154406 0.5999176 ]
 [0.43327192 0.17617901 0.97017048 0.99389211 0.24304429 0.63569   ]
 [0.99131382 0.56461089 0.35934087 0.96484629 0.96325469 0.93638915]]


In [13]:
# Per-head width of keys, queries and values (all equal here).
dk = dq = dv = d_model // H

In [14]:
# Projection weights of shape (1, d_model, d_model), uniform on [0, 1).
# The leading singleton axis broadcasts against X's batch axis in X @ W.
# Drawn in the order K, Q, V so the seeded stream matches the original cell.
Wk, Wq, Wv = (
    rng.random((d_model, d_model), np.float64)[np.newaxis]
    for _ in range(3)
)

In [15]:
# (B, T, d_model) @ (1, d_model, d_model) broadcasts over the batch -> (B, T, d_model)
K = np.matmul(X, Wk)
print(K)

[[[1.69827184 0.48160908 1.48858585 1.38148467 1.43737408 1.74600031]
  [1.60772674 0.32969475 1.77063983 1.67920765 1.8350396  1.99610657]
  [1.4499264  0.38631034 1.1907154  1.21686007 1.26898723 1.32295591]
  [2.30724769 0.65258577 2.1348966  2.0128202  2.17980615 2.27854297]]

 [[1.99027233 0.62134838 2.16243293 1.54949139 1.82271898 2.50898325]
  [1.68489373 0.5519499  1.81588932 1.48378157 1.66687634 1.97018777]
  [2.42077915 0.50127999 2.18992614 2.34063848 2.41344037 2.47249864]
  [2.40920503 0.90931833 2.71140137 2.66256324 2.49278744 2.95386996]]]


In [16]:
# Query and value projections, both (B, T, d_model).
# From here on `V` names the value tensor.
Q = np.matmul(X, Wq)
V = np.matmul(X, Wv)

In [17]:
# Split each projection into heads: (B, T, d_model) -> (B, H, T, d_model // H).
Ksplit, Qsplit, Vsplit = (split_heads(t, H) for t in (K, Q, V))

In [18]:
# Invariant: heads were split as (B, H, T, dk).
assert Ksplit.shape == (B, H, T, dk)
Ksplit.shape

(2, 2, 4, 3)

In [19]:
# Scaled dot-product logits: Q K^T / sqrt(dk), shape (B, H, T, T).
scores = Qsplit @ np.swapaxes(Ksplit, -1, -2) / np.sqrt(dk)

In [20]:
# Invariant: one (T, T) score matrix per batch element and head.
assert scores.shape == (B, H, T, T)
scores.shape

(2, 2, 4, 4)

In [21]:
m = causal_mask(T)
# Keep logits where the key is not in the future. Elsewhere use -1e9, so that
# softmax gives those positions effectively zero weight.
scores = np.where(~m, scores, -1e9)
np.shape(scores)

(2, 2, 4, 4)

In [25]:
# Attention weights: normalize each query row over the key axis (softmax's default axis=-1).
attn = softmax(scores)
print(attn)

[[[[1.         0.         0.         0.        ]
   [0.497259   0.502741   0.         0.        ]
   [0.38630999 0.38814851 0.2255415  0.        ]
   [0.09426586 0.09705267 0.03723875 0.77144272]]

  [[1.         0.         0.         0.        ]
   [0.34693574 0.65306426 0.         0.        ]
   [0.29142418 0.54193152 0.1666443  0.        ]
   [0.0899694  0.23432084 0.03652855 0.63918121]]]


 [[[1.         0.         0.         0.        ]
   [0.71826246 0.28173754 0.         0.        ]
   [0.31298757 0.11543245 0.57157997 0.        ]
   [0.07595578 0.01950134 0.17093249 0.7336104 ]]

  [[1.         0.         0.         0.        ]
   [0.69964269 0.30035731 0.         0.        ]
   [0.18943338 0.09595853 0.71460809 0.        ]
   [0.04338342 0.01713654 0.22646623 0.71301381]]]]


In [27]:
# Invariants: one (T, T) matrix per batch/head, each row a probability
# distribution, and no weight on future (strictly upper-triangular) positions.
assert attn.shape == (B, H, T, T)
assert np.allclose(attn.sum(axis=-1), 1.0)
assert np.allclose(np.triu(attn, k=1), 0.0)
attn.shape

(2, 2, 4, 4)

In [31]:
# Weighted sum of values per head: (B, H, T, T) @ (B, H, T, dv) -> (B, H, T, dv).
context = np.matmul(attn, Vsplit)
print(context.shape, context, sep="\n")

(2, 2, 4, 3)
[[[[1.05080832 1.27588609 1.22999668]
   [1.23416586 1.02423468 1.14895823]
   [1.11030788 1.08004929 1.1478513 ]
   [1.25899058 1.86508876 1.6742963 ]]

  [[1.26776215 1.55038072 0.97869839]
   [1.17164288 1.67564328 0.99319525]
   [1.13542784 1.59168337 0.94676382]
   [1.52693039 1.76236945 1.28605242]]]


 [[[1.74338133 1.51769822 1.55104743]
   [1.59495215 1.56628126 1.562557  ]
   [1.61276753 1.33791164 1.38436888]
   [1.89721471 1.67806444 1.54865867]]

  [[1.92648624 2.01224077 1.79940096]
   [1.80386147 1.92798053 1.67884222]
   [1.59286308 1.95044528 1.27493744]
   [2.09305153 2.53408733 1.60889467]]]]


In [34]:
# Concatenate heads back to (B, T, d_model).
# NOTE(review): the markdown defines y = ATTN(LN(x)) + x, but this forward pass
# applies no LayerNorm to X, no output projection, and no residual add.
# Confirm this is intentional before treating `output` as y in the backward derivation.
output = merge_heads(context)
print(output.shape, output, sep="\n")

(2, 4, 6)
[[[1.05080832 1.27588609 1.22999668 1.26776215 1.55038072 0.97869839]
  [1.23416586 1.02423468 1.14895823 1.17164288 1.67564328 0.99319525]
  [1.11030788 1.08004929 1.1478513  1.13542784 1.59168337 0.94676382]
  [1.25899058 1.86508876 1.6742963  1.52693039 1.76236945 1.28605242]]

 [[1.74338133 1.51769822 1.55104743 1.92648624 2.01224077 1.79940096]
  [1.59495215 1.56628126 1.562557   1.80386147 1.92798053 1.67884222]
  [1.61276753 1.33791164 1.38436888 1.59286308 1.95044528 1.27493744]
  [1.89721471 1.67806444 1.54865867 2.09305153 2.53408733 1.60889467]]]


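Backward pass: start from $out$ and propagate $\frac{\partial \mathcal{L}}{\partial out}$ (the `dout` array above) back toward $x$.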

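$$dy = \frac{\partial \mathcal{L}}{\partial y} = \frac{\partial \mathcal{L}}{\partial out} \frac{\partial out}{\partial y} = \frac{\partial \mathcal{L}}{\partial out} \left(I + \frac{\partial\, FFN(LN(y))}{\partial y}\right)$$

The residual (skip) connection contributes $\frac{\partial \mathcal{L}}{\partial out}$ directly. The FFN/LN branch adds a second term, which must be back-propagated through the FFN and the LayerNorm.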

$$db_2 = \frac {\partial \mathcal{L}}{\partial b_2} = \frac{\partial \mathcal{L}}{\partial out} \frac{\partial out} {\partial \mathcal{R}}\frac{\partial \mathcal{R}}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial b_2}$$


where

* $out = FFN(LN(y)) + y$

* $\mathcal{R} = ReLU(\hat{y}) = \max(0, \hat{y})$, so $\frac{\partial \mathcal{R}}{\partial \hat{y}} = \mathbf{1}[\hat{y} > 0]$. This is elementwise: 1 where $\hat{y} > 0$, otherwise 0.
* $\hat{y} = \hat{x} \cdot W_2 + b_2$, where $\hat{x} = \max(0, LN(y) \cdot W_1 + b_1)$ is the FFN hidden activation.

then
$$\frac{\partial \hat{y}}{\partial b_2} = 1$$

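$$\frac{\partial \mathcal{L}}{\partial b_2} = \sum_{b,t} \left(\frac{\partial \mathcal{L}}{\partial out} \odot \mathbf{1}[\hat{y} > 0]\right)_{b,t}$$

Because $b_2$ is broadcast to every batch element $b$ and token $t$, its gradient is summed over those axes.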