In [None]:
import numpy as np

Transformer ที่ใช้กันในปัจจุบันมี 2 รูปแบบ แบบแรกคือ Post-LN ตามที่ปรากฏใน Attention Is All You Need 
และอีกแบบที่นิยมนำมาใช้ในการสร้างโมเดลขนาดใหญ่ในปัจจุบันคือ Pre-LN

## 1. ความแตกต่างกันระหว่าง Pre, Post LN 

ให้ $x$ คือ hidden state ($x \in \mathbb{R}^{T \times d_{model}}$)

### Post-LN แบบดั้งเดิม
---
* Attention sublayer:
$$y = LN(x + MHA(x))$$
* FFN sublayer:
$$z = LN(y + FFN(y))$$
> จะเห็นได้ว่า LN อยู่หลังการบวก residual (Add แล้วค่อย Norm)

Where 
> * LN หมายถึง layer normalization
> * MHA หมายถึง multihead attention
> * FFN หมายถึง feed-forward neural network

### Pre-LN
---
* Attention sublayer:
$$y = x + MHA(LN(x))$$
* FFN sublayer: 
$$z = y + FFN(LN(y))$$

> LN อยู่ก่อน sublayer (Norm ก่อนแล้วคำนวณ Attention / FFN แล้วค่อย Add)

> **ใน Pre-LN จะมี Final LN อีกที หลังจากผ่าน Transformer stack ก่อนเข้า LM Head เพื่อให้ Scale นิ่ง**


### Layer normalization function
---
$$LN(X) = \frac {X - \mu} {\sigma + \epsilon} \cdot \gamma + \beta$$

Where 
* $LN$ stands for Layer Normalization
* $X$ is the input, $X \in \mathbb {R} ^ {T \times d_{model}}$
* $\mu$ stands for the mean of $X$ along the $d_{model}$ direction
* $\sigma$ stands for the standard deviation of $X$ along the $d_{model}$ direction
* $\epsilon$ is a small constant that prevents division by zero
* $\gamma$ stands for the gain (scale) parameter (trainable)
* $\beta$ stands for the shift parameter (trainable)


### Softmax function

$$ softmax(x)_i = \frac {e^{x_i - x_{max}}} {\sum_{j} e^{x_j - x_{max}}}$$

สูตรด้านบนมีข้อควรระวัง: ถ้าค่าใน $x$ มีความห่างกันมาก ๆ การใช้ $x_{max}$ ของทั้ง matrix จะทำให้ $e^{x - x_{max}}$ ของกลุ่มค่าที่อยู่ห่างจาก $x_{max}$ มาก ๆ ลู่เข้า 0 เหมือนกันหมด (underflow) อาทิเช่น

In [42]:
# Shift every entry by one global maximum to show the underflow problem.
x_max = 1000

# Values near x_max survive the shift; values far below it all collapse to 0.
for x in (np.array([980, 990, 1000]), np.array([1, 2, 3])):
    print("if x = ", x)
    print("e^x = ", np.exp(x - x_max))

if x =  [ 980  990 1000]
e^x =  [2.06115362e-09 4.53999298e-05 1.00000000e+00]
if x =  [1 2 3]
e^x =  [0. 0. 0.]


ดังนั้นจึงแนะนำให้เลือกแนวแกนแล้วทำ softmax ตามแนวแกนนั้น

### ใน attention ต้องทำ softmax ในแนวแกนไหน ?
---
* ใน scaled dot-product attention:
$$scores = \frac {Q \cdot K^\intercal}{\sqrt{d_k}}$$

* ทำให้ $scores \in \mathbb{R}^{t_q \times t_k}$ หรือใน multi-head คือ $scores \in \mathbb{R} ^ {B \times H \times t_q \times t_k}$

* เราต้องการให้ "ต่อ 1 query" (แต่ละแถวของ $scores$ ซึ่งมีความยาว $t_k$) เป็น distribution บน keys 
> **ดังนั้น softmax ต้องทำตามแนวแกน $t_k$ (แกนสุดท้าย)**

## 2. Demo Code

In [2]:
d_model = 6  # hidden size: length of each token's vector after embedding

H = 2  # number of attention heads
# Derive the per-head width from d_model and H so the three values cannot drift apart.
assert d_model % H == 0, "d_model must be divisible by H"
d_head = d_model // H
V = 8  # vocabulary size
T = 4  # number of tokens

In [3]:
SEED = 1000  # fixed seed so every run reproduces the same random draws
rng = np.random.default_rng(SEED)

In [4]:
# Toy hidden states of shape (T, d_model), filled with random float64 values.
X = rng.random(size=(T, d_model), dtype=np.float64)
X

array([[0.52138574, 0.60384185, 0.4709418 , 0.20324794, 0.52875903,
        0.19103628],
       [0.2815456 , 0.75368155, 0.55167178, 0.86372208, 0.80537222,
        0.24837266],
       [0.18985741, 0.98399558, 0.66999717, 0.28038283, 0.20391323,
        0.62506469],
       [0.65260432, 0.89880753, 0.97476378, 0.15393237, 0.69908928,
        0.44724145]])

In [5]:
# Draw the scale first and the shift second; the draw order decides which random values each one gets.
gamma1 = rng.random(size=d_model, dtype=np.float64)
beta1 = rng.random(size=d_model, dtype=np.float64)

print("gamma1 =", gamma1)
print(" beta1 =", beta1)

gamma1 = [0.01751321 0.29102491 0.38123661 0.32102791 0.94254467 0.70266697]
 beta1 = [0.13645032 0.34320907 0.8119946  0.148494   0.05932569 0.31441663]


In [6]:
# Statistics along the last axis; keepdims leaves a size-1 axis so they broadcast against X.
mu = X.mean(axis=-1, keepdims=True)
std = X.std(axis=-1, keepdims=True)

print("mu =>\n", mu)
print("std =>\n", std)

mu =>
 [[0.41986877]
 [0.58406098]
 [0.49220182]
 [0.63773979]]
std =>
 [[0.16222732]
 [0.24536075]
 [0.29169464]
 [0.27570681]]


In [7]:
# Layer normalisation written out by hand: standardise along the last axis, then scale and shift.
x_hat = (X - mu) / (std + 1e-9)
gamma1 * x_hat + beta1

array([[ 0.14740956,  0.6732444 ,  0.93201697, -0.28017197,  0.6919807 ,
        -0.67674215],
       [ 0.11485756,  0.54439777,  0.76166891,  0.51440018,  0.909485  ,
        -0.64693146],
       [ 0.1182977 ,  0.83387368,  1.04436808, -0.08462584, -0.87221309,
         0.63447171],
       [ 0.13739453,  0.61878157,  1.27801821, -0.4148424 ,  0.26905803,
        -0.17108784]])

In [8]:
def linearNorm(X, gamma, beta, eps=1e-9, axis=-1):
    """Normalise X along `axis`, then apply a scale and a shift.

    Computes ``gamma * (X - mean) / (std + eps) + beta``, where the mean and
    the standard deviation (NumPy's default ``np.std``) are taken along `axis`
    with that axis kept, so they broadcast against X.

    Parameters
    ----------
    X : array_like
        Values to normalise.
    gamma : array_like or scalar
        Multiplicative scale; must broadcast against X.
    beta : array_like or scalar
        Additive shift; must broadcast against X.
    eps : float, optional
        Added to the standard deviation (not the variance) so the division
        stays finite when the values along `axis` are all equal.
    axis : int, optional
        Axis over which the statistics are computed. Defaults to the last axis.

    Returns
    -------
    numpy.ndarray
        The normalised, scaled and shifted array.
    """
    center = np.mean(X, axis=axis, keepdims=True)
    spread = np.std(X, axis=axis, keepdims=True)
    normalized = (X - center) / (spread + eps)
    return gamma * normalized + beta

In [9]:
# Same normalisation as the hand-written cell, this time through the helper.
datapipe = linearNorm(X, gamma=gamma1, beta=beta1)
# Keep an independent copy of the normalised values under its own name.
x1 = datapipe.copy()
datapipe

array([[ 0.14740956,  0.6732444 ,  0.93201697, -0.28017197,  0.6919807 ,
        -0.67674215],
       [ 0.11485756,  0.54439777,  0.76166891,  0.51440018,  0.909485  ,
        -0.64693146],
       [ 0.1182977 ,  0.83387368,  1.04436808, -0.08462584, -0.87221309,
         0.63447171],
       [ 0.13739453,  0.61878157,  1.27801821, -0.4148424 ,  0.26905803,
        -0.17108784]])

In [10]:
# Show the saved copy of the normalised input.
x1

array([[ 0.14740956,  0.6732444 ,  0.93201697, -0.28017197,  0.6919807 ,
        -0.67674215],
       [ 0.11485756,  0.54439777,  0.76166891,  0.51440018,  0.909485  ,
        -0.64693146],
       [ 0.1182977 ,  0.83387368,  1.04436808, -0.08462584, -0.87221309,
         0.63447171],
       [ 0.13739453,  0.61878157,  1.27801821, -0.4148424 ,  0.26905803,
        -0.17108784]])

In [11]:
# One (d_model, d_head) random projection per head for keys, queries and values.
# The draw order (k, q, v for head 1, then head 2) fixes the values, so it is kept unchanged.
proj_shape = (d_model, d_head)

W_k1 = rng.random(size=proj_shape, dtype=np.float64)
W_q1 = rng.random(size=proj_shape, dtype=np.float64)
W_v1 = rng.random(size=proj_shape, dtype=np.float64)

W_k2 = rng.random(size=proj_shape, dtype=np.float64)
W_q2 = rng.random(size=proj_shape, dtype=np.float64)
W_v2 = rng.random(size=proj_shape, dtype=np.float64)

In [12]:
# NOTE(review): these projections read the raw X, not the layer-normalised x1.
# The notebook describes Pre-LN attention as MHA(LN(x)) — confirm whether x1 was intended here.
K1, Q1, V1 = (X @ W for W in (W_k1, W_q1, W_v1))  # head 1

K2, Q2, V2 = (X @ W for W in (W_k2, W_q2, W_v2))  # head 2

In [13]:
# Inspect the key, query and value matrices of head 1.
print("K1 =>\n", K1)
print("Q1 =>\n", Q1)
print("V1 =>\n", V1)

K1 =>
 [[1.2527992  1.32075642 0.93814284]
 [2.05800531 1.85507023 1.76608856]
 [1.55996752 1.63750765 1.30605614]
 [1.82642607 1.90086531 1.51556223]]
Q1 =>
 [[1.32297136 0.92811187 0.93688803]
 [2.01381787 1.24562666 1.68262418]
 [1.59113101 1.19988588 1.18424662]
 [1.86500052 1.57009664 1.39953962]]
V1 =>
 [[1.474291   1.37938888 1.13889442]
 [2.47774422 2.3514763  1.60251449]
 [1.76324287 1.86219185 1.41314099]
 [2.17745369 2.00985722 1.87073196]]


In [15]:
# Scaled dot-product scores for each head: one row per query, one column per key.
scale = np.sqrt(d_head)
scores1 = (Q1 @ K1.T) / scale
scores2 = (Q2 @ K2.T) / scale

In [18]:
# Inspect the scaled scores of head 1.
print("scores1 => \n", scores1)

scores1 => 
 [[2.17208522 3.521272   2.77544456 3.23341389]
 [3.31781298 5.44258637 4.260161   4.96289114]
 [2.70726303 4.38319305 3.4604225  4.03088871]
 [3.30426476 5.324631   4.21942751 4.91435768]]


### softmax

In [30]:
# Row-wise maximum (along the key axis, i.e. horizontally) used to shift the scores before softmax.
scores1_max = scores1.max(axis=-1, keepdims=True)
scores1_max

array([[3.521272  ],
       [5.44258637],
       [4.38319305],
       [5.324631  ]])

In [31]:
# Shift each row so that its largest score becomes 0.
dscore1 = np.subtract(scores1, scores1_max)
print(dscore1)

[[-1.34918678  0.         -0.74582743 -0.28785811]
 [-2.12477339  0.         -1.18242537 -0.47969523]
 [-1.67593001  0.         -0.92277055 -0.35230434]
 [-2.02036624  0.         -1.10520349 -0.41027332]]


In [34]:
# Exponentiate once, then divide each row by its own sum so every row adds up to 1.
exp_dscore1 = np.exp(dscore1)
softmaxResult = exp_dscore1 / np.sum(exp_dscore1, axis=-1, keepdims=True)
softmaxResult

array([[0.1044632 , 0.40263147, 0.19098488, 0.30192045],
       [0.05841662, 0.48900559, 0.14989702, 0.30268077],
       [0.08180307, 0.43713618, 0.17372511, 0.30733564],
       [0.06233814, 0.47009728, 0.15566966, 0.31189492]])

In [44]:
def softmax(X, axis=-1):
    """Softmax of X along `axis`, with the maximum subtracted first.

    The maximum along `axis` is subtracted before exponentiating, so the
    largest exponent in every slice is 0 and large inputs do not overflow.

    Parameters
    ----------
    X : array_like
        Input scores.
    axis : int, optional
        Axis along which the result sums to 1. Defaults to the last axis.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as X whose entries along `axis` sum to 1.
    """
    shifted = X - np.max(X, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=axis, keepdims=True)

In [45]:
# Attention weights of head 1: softmax over the last (key) axis of the scores.
attn1 = softmax(scores1, axis=-1)
print("attention 1 :")
print(attn1)

attention 1 :
[[0.1044632  0.40263147 0.19098488 0.30192045]
 [0.05841662 0.48900559 0.14989702 0.30268077]
 [0.08180307 0.43713618 0.17372511 0.30733564]
 [0.06233814 0.47009728 0.15566966 0.31189492]]


In [46]:
# Attention weights of head 2: softmax over the last (key) axis of the scores.
attn2 = softmax(scores2, axis=-1)
print("attention 2 :")
print(attn2)

attention 2 :
[[0.13172653 0.27212327 0.18301422 0.41313598]
 [0.09215779 0.25641916 0.1529684  0.49845464]
 [0.12569073 0.27381841 0.17754409 0.42294677]
 [0.09610674 0.26572722 0.15079531 0.48737074]]


In [47]:
# Head 1 output: attention-weighted combination of the value rows.
output1 = np.matmul(attn1, V1)
print(output1)

[[2.14579748 2.05334122 1.59889611]
 [2.22113208 2.11794643 1.62822912]
 [2.17924186 2.08196385 1.6141228 ]
 [2.21030554 2.10816218 1.62778924]]


In [48]:
# Head 2 output: attention-weighted combination of the value rows.
output2 = np.matmul(attn2, V2)
print(output2)

[[1.61703732 1.08593014 1.76418502]
 [1.65971009 1.10975219 1.7909971 ]
 [1.62332767 1.0892788  1.76939115]
 [1.6566034  1.10739055 1.79149989]]


In [52]:
# Place the two heads' outputs side by side along the last axis.
heads = (output1, output2)
datapipe = np.concatenate(heads, axis=-1)
print(datapipe)

[[2.14579748 2.05334122 1.59889611 1.61703732 1.08593014 1.76418502]
 [2.22113208 2.11794643 1.62822912 1.65971009 1.10975219 1.7909971 ]
 [2.17924186 2.08196385 1.6141228  1.62332767 1.0892788  1.76939115]
 [2.21030554 2.10816218 1.62778924 1.6566034  1.10739055 1.79149989]]
