In [1]:
import numpy as np

Transformer ที่ใช้กันในปัจจุบันมี 2 รูปแบบ 1 คือ แบบ post-LN ตามที่ปรากฏใน Attention is all you need 
และอีกแบบที่นิยมนำมาใช้ใน การสร้างโมเดลขนาดใหญ่ในปัจจุบันคือ Pre-LN

## 1. ความแตกต่างกันระหว่าง Pre, Post LN 

ให้ $x$ คือ hidden state ($x \in \mathbb{R}^{T, d_{model}}$ )

#### Post-LN แบบดั้งเดิม
---
* Attention sublayer:
$$y = LN(x + MHA(x))$$
* FFN sublayer:
$$z = LN(y + FFN(y))$$
> จะเห็นได้ว่า หลังบวก residual (Add แล้วค่อย Norm)

Where 
> * LN หมายถึง linear normalized
> * MHA หมายถึง multihead attention
> * FFN หมายถึง feed-forward neural network

### Pre-LN
---
* Attention sublayer:
$$y = x + MHA(LN(x))$$
* FFN sublayer: 
$$z = y + FFN(LN(y))$$

> LN อยู่ก่อน sublayer (Norm ก่อนแล้วคำนวน Attention / FFN แล้วค่อย Add)

> **ใน Pre-LN จะมี Final LN อีกที หลังจากผ่าน Transformer stack ก่อนเข้า LM Head เพื่อให้ Scale นิ่ง**


### Linear normalized function
---
$$LN = \frac {X - \mu} {\sigma + \epsilon} \cdot \gamma + \beta$$

Where 
* $LN$ stand for Linear - Normalized
* $X$ stand for $X \in \mathbb {R} ^ {T \times d_{model}}$
* $\mu$ stand for mean of $X$ in $d_{model}$ direction
* $\sigma$ stand for standard derivation of $X$ in $d_{model}$ direction
* $\epsilon$ the minimum value that make this equation not to be divide by zero
* $\gamma$ stand for gaining parameter (trainable)
* $\beta$ stand for shift parameter (trainable)


### Softmax function

$$ softmax(x) = \frac {e^{x - x_{max}}} {\sum_{i} e^{x_i - x_{max}}}$$

สูตรด้านบนมีข้อควรระวังในเคสถ้าค่าใน X มีความห่างกันมาก ๆ การใช้ $x_{max}$ ของทั้ง matrix จะเกิดปรากฎการณ์ลู่เข้า 0 เหมือนกันหมด อาทิเช่น

In [2]:
x_max = 1000
x = np.array([980, 990, 1000])
print("if x = ", x)
print("e^x = ", np.exp(x - x_max))

x = np.array([1, 2, 3])
print("if x = ", x)
print("e^x = ", np.exp(x - x_max))

if x =  [ 980  990 1000]
e^x =  [2.06115362e-09 4.53999298e-05 1.00000000e+00]
if x =  [1 2 3]
e^x =  [0. 0. 0.]


ดังนั้นจึงแนะนำให้เลือกแนวแกนแล้วทำ softmax ตามแนวแกนนั้น

### ใน attention ต้องทำ softmax ในแนวแกนไหน ?
---
* ใน scaled dot-product attention:
$$scores = \frac {Q \cdot K^\intercal}{\sqrt{d_k}}$$

* ทำให้ $scores \in \mathbb{R}^{t_q \times t_k}$ หรือใน multi-head คือ $scores \in \mathbb{R} ^ {B \times H \times t_q \times t_k}$

* เราต้องการให้ "ต่อ 1 query" (แต่ละแถวใน $T_k$) เป็น distribution บน keys 
> **ดังนั้น softmax ต้องทำตามแนวแกน $T_k$ (แกนสุดท้าย)**

## 2. Demo Code

In [3]:
d_model = 6; # hidden states เป็นความยาว array ของแต่ละ token หลังการ embedding

H = 2; # จำนวน Attention head
d_head = 3 # d_model / H
V = 8 # จำนวน Vocab
T = 4 # จำนวน Token

In [4]:
rng = np.random.default_rng(1000)

#### initial X that the result of text after embedding

In [5]:
X = rng.random((T, d_model), np.float64)
X

array([[0.52138574, 0.60384185, 0.4709418 , 0.20324794, 0.52875903,
        0.19103628],
       [0.2815456 , 0.75368155, 0.55167178, 0.86372208, 0.80537222,
        0.24837266],
       [0.18985741, 0.98399558, 0.66999717, 0.28038283, 0.20391323,
        0.62506469],
       [0.65260432, 0.89880753, 0.97476378, 0.15393237, 0.69908928,
        0.44724145]])

In [6]:
gamma1 = rng.random((d_model), np.float64)
beta1 = rng.random((d_model), np.float64)

print("gamma1 =", gamma1)
print(" beta1 =", beta1)

gamma1 = [0.01751321 0.29102491 0.38123661 0.32102791 0.94254467 0.70266697]
 beta1 = [0.13645032 0.34320907 0.8119946  0.148494   0.05932569 0.31441663]


In [7]:
mu = np.mean(X, axis = -1, keepdims=True)
std = np.std(X, axis = -1, keepdims=True)

print("mu =>\n", mu)
print("std =>\n", std)

mu =>
 [[0.41986877]
 [0.58406098]
 [0.49220182]
 [0.63773979]]
std =>
 [[0.16222732]
 [0.24536075]
 [0.29169464]
 [0.27570681]]


In [8]:
gamma1 * ((X - mu) / (std + 1e-9)) + beta1

array([[ 0.14740956,  0.6732444 ,  0.93201697, -0.28017197,  0.6919807 ,
        -0.67674215],
       [ 0.11485756,  0.54439777,  0.76166891,  0.51440018,  0.909485  ,
        -0.64693146],
       [ 0.1182977 ,  0.83387368,  1.04436808, -0.08462584, -0.87221309,
         0.63447171],
       [ 0.13739453,  0.61878157,  1.27801821, -0.4148424 ,  0.26905803,
        -0.17108784]])

In [9]:
def linearNorm(X, gamma, beta, eps = 1e-9, axis = -1) :
    mu = np.mean(X, axis = axis, keepdims=True)
    std = np.std(X, axis = axis, keepdims=True)
    return gamma * ((X - mu) / (std + eps)) + beta

#### Norm first time before pass to transformers

In [10]:
datapipe = linearNorm(X, gamma1, beta1)
x1 = datapipe.copy()
datapipe

array([[ 0.14740956,  0.6732444 ,  0.93201697, -0.28017197,  0.6919807 ,
        -0.67674215],
       [ 0.11485756,  0.54439777,  0.76166891,  0.51440018,  0.909485  ,
        -0.64693146],
       [ 0.1182977 ,  0.83387368,  1.04436808, -0.08462584, -0.87221309,
         0.63447171],
       [ 0.13739453,  0.61878157,  1.27801821, -0.4148424 ,  0.26905803,
        -0.17108784]])

In [11]:
x1

array([[ 0.14740956,  0.6732444 ,  0.93201697, -0.28017197,  0.6919807 ,
        -0.67674215],
       [ 0.11485756,  0.54439777,  0.76166891,  0.51440018,  0.909485  ,
        -0.64693146],
       [ 0.1182977 ,  0.83387368,  1.04436808, -0.08462584, -0.87221309,
         0.63447171],
       [ 0.13739453,  0.61878157,  1.27801821, -0.4148424 ,  0.26905803,
        -0.17108784]])

#### Initial Attention

In [12]:
W_k1 = rng.random((d_model, d_head), np.float64)
W_q1 = rng.random((d_model, d_head), np.float64)
W_v1 = rng.random((d_model, d_head), np.float64)

W_k2 = rng.random((d_model, d_head), np.float64)
W_q2 = rng.random((d_model, d_head), np.float64)
W_v2 = rng.random((d_model, d_head), np.float64)

In [13]:
K1 = datapipe @ W_k1
Q1 = datapipe @ W_q1
V1 = datapipe @ W_v1

K2 = datapipe @ W_k2
Q2 = datapipe @ W_q2
V2 = datapipe @ W_v2

In [14]:
print("K1 =>\n", K1)
print("Q1 =>\n", Q1)
print("V1 =>\n", V1)

K1 =>
 [[ 0.26115996 -0.06372358  0.10692326]
 [ 1.0258848   0.51635283  0.84805778]
 [ 0.65105125  0.8523355   0.56283134]
 [ 0.34983276  0.15806268  0.38329655]]
Q1 =>
 [[0.68075297 0.60439707 0.2725525 ]
 [1.26582777 0.72975408 1.00270057]
 [0.79607297 0.85239715 0.67972781]
 [0.59747968 0.91993633 0.40841944]]
V1 =>
 [[1.07842594 0.63704085 1.03204335]
 [1.82433963 1.38412767 1.18936384]
 [0.48853425 0.74732176 0.74693018]
 [0.9077596  0.57232672 1.17976634]]


In [15]:
scores1 = Q1 @ K1.T / np.sqrt(d_head)
scores2 = Q2 @ K2.T / np.sqrt(d_head)

In [16]:
print("scores1 => \n", scores1)

scores1 => 
 [[0.09723345 0.71683609 0.64187219 0.25296645]
 [0.22591312 1.45824363 1.16074272 0.54415642]
 [0.13063318 1.05843621 0.93957103 0.38899662]
 [0.08145585 0.82810523 0.80999676 0.29500943]]


### softmax

In [17]:
scores1_max = np.max(scores1, axis = -1, keepdims=True) # softmax ตามแนว keys หรือแนวนอน
scores1_max

array([[0.71683609],
       [1.45824363],
       [1.05843621],
       [0.82810523]])

In [18]:
dscore1 = scores1 - scores1_max
print(dscore1)

[[-0.61960264  0.         -0.0749639  -0.46386964]
 [-1.23233051  0.         -0.29750092 -0.91408721]
 [-0.92780303  0.         -0.11886518 -0.6694396 ]
 [-0.74664938  0.         -0.01810847 -0.5330958 ]]


In [19]:
softmaxResult = np.exp(dscore1)/np.sum(np.exp(dscore1), axis= -1, keepdims = True)
softmaxResult

array([[0.1738922 , 0.32312467, 0.29978763, 0.2031955 ],
       [0.11975041, 0.41064955, 0.30497788, 0.16462217],
       [0.14145716, 0.35773767, 0.31764512, 0.18316005],
       [0.15576218, 0.32864553, 0.32274782, 0.19284446]])

In [20]:
def softmax(X, axis = -1) :
    X = X - np.max(X, axis = axis, keepdims=True);
    eX = np.exp(X)
    return eX / np.sum(eX, axis = axis, keepdims=True)

In [21]:
attn1 = softmax(scores1)
print("attention 1 :");
print(attn1)

attention 1 :
[[0.1738922  0.32312467 0.29978763 0.2031955 ]
 [0.11975041 0.41064955 0.30497788 0.16462217]
 [0.14145716 0.35773767 0.31764512 0.18316005]
 [0.15576218 0.32864553 0.32274782 0.19284446]]


In [22]:
attn2 = softmax(scores2)
print("attention 2 :");
print(attn2)

attention 2 :
[[0.25462965 0.27944164 0.21468638 0.25124232]
 [0.22684834 0.30534416 0.22195323 0.24585427]
 [0.2363493  0.31179841 0.21194349 0.23990879]
 [0.25365091 0.28687266 0.21026727 0.24920917]]


In [23]:
output1 = attn1 @ V1
print(output1)

[[1.10792819 0.89835427 1.02742073]
 [1.17673568 0.96681157 1.03401221]
 [1.1266319  0.9274791  1.02481496]
 [1.10026882 0.90568076 1.02021393]]


In [24]:
output2 = attn2 @ V2
print(output2)

[[0.80294033 0.30417803 0.60954595]
 [0.81450153 0.30946238 0.62696523]
 [0.81276153 0.30840972 0.63228343]
 [0.80426773 0.3046392  0.61486037]]


In [25]:
datapipe = np.concatenate([output1, output2], axis = -1)
print(datapipe)

[[1.10792819 0.89835427 1.02742073 0.80294033 0.30417803 0.60954595]
 [1.17673568 0.96681157 1.03401221 0.81450153 0.30946238 0.62696523]
 [1.1266319  0.9274791  1.02481496 0.81276153 0.30840972 0.63228343]
 [1.10026882 0.90568076 1.02021393 0.80426773 0.3046392  0.61486037]]


#### Residual add

In [26]:
datapipe = x1 + datapipe
y = datapipe.copy()
print(datapipe)

[[ 1.25533775  1.57159866  1.9594377   0.52276836  0.99615873 -0.0671962 ]
 [ 1.29159324  1.51120934  1.79568112  1.32890172  1.21894738 -0.01996624]
 [ 1.2449296   1.76135278  2.06918304  0.72813569 -0.56380337  1.26675514]
 [ 1.23766335  1.52446233  2.29823215  0.38942533  0.57369722  0.44377254]]


### FFN sublayer

linear normalized first

In [27]:
gamma2 = rng.random(6, np.float64)
beta2 = rng.random(6, np.float64)

In [28]:
datapipe = linearNorm(datapipe, gamma2, beta2)


In [29]:
W_FFN_Expand_1 = rng.random((d_model, 4 * d_model), np.float64) # expand to 4 times of d_model
W_FFN_Shrink_1 = rng.random((d_model * 4, d_model), np.float64) # shrink it back to d_model

คิดว่า b ของฝั่ง expand เมื่อคูณกับ W_shrink แล้วก็จะกลายเป็นค่าคงที่อีกตัวที่เทรนได้ (เป็น bias เหมือนกัน)

ตรงนี้เดี๋ยวค่อยมาดูรายละเอียด

In [30]:
b_FFN_Shrink_1 = rng.random((1, d_model), np.float64)

In [31]:
print(datapipe.shape)
print(W_FFN_Expand_1.shape)

(4, 6)
(6, 24)


In [32]:
datapipe = np.maximum(0, datapipe @ W_FFN_Expand_1) @ W_FFN_Shrink_1 + b_FFN_Shrink_1


### conclusion function FFN

In [33]:
# X must be LN before pass into this function
def FFN(X, W1, W2, b, ):
    h = X @ W1
    h = np.maximum(h)
    h = h @ W2 + b
    return h

In [34]:
datapipe = y + datapipe

In [35]:
print(datapipe)

[[18.76599546 20.99229084 21.68027351 19.90435733 19.829247   19.2513378 ]
 [16.38637038 18.01560618 18.266348   17.71446771 16.93678407 16.2778588 ]
 [16.34029452 19.31995309 21.99998681 19.19274757 19.20078361 19.76311457]
 [18.76473422 21.27579944 23.23766897 20.73510626 20.78578715 20.55541809]]


## Out of transformer stack

last norm before pass to LM Head

In [36]:
gammaLast = rng.random((1, d_model), np.float64)
betaLast = rng.random((1, d_model), np.float64)

In [37]:
datapipe = linearNorm(datapipe, gammaLast, betaLast)
print(datapipe)

[[ 0.11817093  0.8662992   1.59151171  0.34445855  0.58332962  0.3393915 ]
 [ 0.21323897  0.8936586   1.45134137  0.52746859  0.41651269  0.2035713 ]
 [-0.13218447  0.14647865  1.59725987  0.36924313  0.7509019   0.67616418]
 [-0.0408012   0.3672279   1.6601334   0.35618554  0.73302145  0.51279948]]


## multi dimensional Pre-LN
---

In [38]:
def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2 ).mean(axis = -1, keepdims = True)
    xhat = (x - mu ) / np.sqrt(var + eps)
    return xhat * gamma + beta


In [39]:
def softmax(x, axis = -1) :
    x = x - np.max(x, axis = axis, keepdims = True)
    ex = np.exp(x)
    return ex / np.sum(ex, axis = axis, keepdims = True)

In [40]:
def split_heads(x, H):
    # x: (B, T, C) => (B, H, T,d_head)
    B, T, C = x.shape
    d_head = C // H
    x = x.reshape(B, T, H, d_head)
    return x.transpose(0, 2, 1, 3)

In [41]:
def merge_heads(x):
    # x: (B, H, T, d_head) => (B, T, C)
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H*d_head)

In [42]:
def causal_mask(T):
    # (1, 1, T, T) ใช้ broadcast กับ (B, H, T, T)
    m = np.triu(np.ones((T, T), dtype=bool), k = 1)
    return m[None, None, :, :]

In [43]:
def mha_preln(Xn, Wq, Wk, Wv, Wo, H, use_causal_mask = True):
    # Xn : (B, T, C)
    B, T, C = Xn.shape
    d_head = C // H

    # 1) Q, K, V (B, T, C)
    Q = Xn @ Wq
    K = Xn @ Wk
    V = Xn @ Wv

    # 2) split -> (B, H, T, d_head)
    Qh = split_heads(Q, H)
    Kh = split_heads(K, H)
    Vh = split_heads(V, H)

    # 3) scores -> (B, H, T, T)
    scores = (Qh @ Kh.transpose(0, 1, 3, 2)) / np.sqrt(d_head)

    # 4) mask (optional)
    if use_causal_mask:
        m = causal_mask(T) # (1, 1, T, T)
        scores = np.where(m, -1e9, scores) # ก้นอนาคต

    # 5) softmax over key axis = -1
    attn = softmax(scores, axis=-1) # (B, H, T, T)

    # 6) context -> (B, H, T, d_head)
    ctx = attn @ Vh

    # 7) merge -> (B, T, C) + output projection
    out = merge_heads(ctx) @ Wo

    return out, attn # attn เอาไว้ debug weight

### focus on step by step
---

In [44]:
B = 2 # batch size = 2

In [45]:
X = rng.random((B, T, d_model), np.float64)
print(X)


[[[0.80943842 0.96547771 0.3540245  0.70510091 0.40169697 0.87693014]
  [0.93680304 0.64488261 0.96911327 0.49181922 0.03944605 0.52157371]
  [0.29086218 0.22098476 0.00704177 0.48097445 0.94359348 0.40594856]
  [0.49455348 0.16145771 0.3107687  0.79218656 0.30968507 0.99984928]]

 [[0.65191454 0.77502337 0.90404854 0.59314112 0.51631802 0.96213001]
  [0.80328414 0.27060947 0.05885957 0.9393993  0.96047678 0.39814721]
  [0.06534784 0.91797449 0.82697909 0.72563988 0.2847025  0.50900874]
  [0.98730819 0.5310692  0.69979885 0.71863463 0.90734609 0.69646824]]]


#### Layer norm

In [46]:
mu = np.mean(X, axis=-1, keepdims=True)
print(mu)

[[[0.68544477]
  [0.60060632]
  [0.39156753]
  [0.5114168 ]]

 [[0.7337626 ]
  [0.57179608]
  [0.55494209]
  [0.75677087]]]


In [47]:
std = np.std(X, axis=-1, keepdims=True)
print(std)

[[[0.23137203]
  [0.31192083]
  [0.28868471]
  [0.2947058 ]]

 [[0.1615586 ]
  [0.34731415]
  [0.30270344]
  [0.15005454]]]


In [48]:
np.sqrt(((X - mu) ** 2 ).mean(axis = -1, keepdims = True))

array([[[0.23137203],
        [0.31192083],
        [0.28868471],
        [0.2947058 ]],

       [[0.1615586 ],
        [0.34731415],
        [0.30270344],
        [0.15005454]]])

In [49]:
xhat = (X - mu) / (std + 1e-10)
print(xhat.shape)

(2, 4, 6)


In [50]:
gamma_demo = rng.random((1, d_model), np.float64)[None, :, :];
print(gamma_demo.shape)

(1, 1, 6)


In [51]:
beta_demo = rng.random((1, d_model), np.float64)[None, :, :]
print(beta_demo)

[[[0.21907373 0.5114122  0.1777603  0.08596633 0.162551   0.20204752]]]


In [52]:
LNresult = xhat * gamma_demo + beta_demo
print(LNresult)

[[[ 0.25076559  0.7485146  -0.07226558  0.12298451 -0.07610142
    0.40860827]
  [ 0.28281316  0.53921987  0.38397425 -0.06600467 -0.18754466
    0.13880852]
  [ 0.19844428  0.39565472 -0.05473726  0.220917    0.53466834
    0.2144809 ]
  [ 0.21568987  0.27878189  0.05892017  0.50110137  0.02934318
    0.61570314]]

 [[ 0.18911404  0.56144387  0.36173811 -0.2933044  -0.09936541
    0.55484634]
  [ 0.25848907  0.34152883 -0.08002484  0.5471613   0.38032935
    0.07725959]
  [ 0.12342527  0.74635694  0.33462577  0.33168507 -0.0111794
    0.16417416]
  [ 0.30992922  0.21675076  0.11148843 -0.0247767   0.35782687
    0.10174539]]]


In [53]:
print(xhat[0, 0, 0], gamma_demo[0, 0, 0])
print(xhat[0, 0, 0] * gamma_demo[0, 0, 0])

0.5359059587325756 0.059136969551111696
0.03169185436382765


#### Softmax of attention dimension test


example scores 

In [54]:
scores = rng.random((B, H, d_head, d_head)) # assume that d_head equal to d_k, d_q, d_v
print(scores.shape)

(2, 2, 3, 3)


In [55]:
print(scores)

[[[[0.14189603 0.6300982  0.55783344]
   [0.1727177  0.98415625 0.855209  ]
   [0.80430245 0.79958981 0.43096569]]

  [[0.32856365 0.24904132 0.14973707]
   [0.27256478 0.16073375 0.91163138]
   [0.85511375 0.9526142  0.2748951 ]]]


 [[[0.39020331 0.05061463 0.63003636]
   [0.76325312 0.20437612 0.93282426]
   [0.20400894 0.62847871 0.27769553]]

  [[0.54150655 0.90851982 0.48290291]
   [0.05293984 0.76902455 0.14327839]
   [0.60339743 0.25421058 0.98040098]]]]


In [56]:
max_score_in_keys_dim = np.max(scores, axis=3, keepdims=True)
print(max_score_in_keys_dim)

[[[[0.6300982 ]
   [0.98415625]
   [0.80430245]]

  [[0.32856365]
   [0.91163138]
   [0.9526142 ]]]


 [[[0.63003636]
   [0.93282426]
   [0.62847871]]

  [[0.90851982]
   [0.76902455]
   [0.98040098]]]]


In [57]:
scores - max_score_in_keys_dim

array([[[[-0.48820217,  0.        , -0.07226476],
         [-0.81143855,  0.        , -0.12894725],
         [ 0.        , -0.00471263, -0.37333676]],

        [[ 0.        , -0.07952233, -0.17882658],
         [-0.6390666 , -0.75089763,  0.        ],
         [-0.09750045,  0.        , -0.6777191 ]]],


       [[[-0.23983305, -0.57942172,  0.        ],
         [-0.16957115, -0.72844815,  0.        ],
         [-0.42446977,  0.        , -0.35078318]],

        [[-0.36701327,  0.        , -0.42561691],
         [-0.71608471,  0.        , -0.62574616],
         [-0.37700356, -0.72619041,  0.        ]]]])

In [58]:
edx = np.exp(scores- max_score_in_keys_dim)
print(edx)

[[[[0.61372879 1.         0.93028456]
   [0.44421858 1.         0.87902034]
   [1.         0.99529845 0.68843336]]

  [[1.         0.9235574  0.83625091]
   [0.52778483 0.47194273 1.        ]
   [0.90710194 1.         0.50777385]]]


 [[[0.7867592  0.56022224 1.        ]
   [0.8440267  0.48265742 1.        ]
   [0.65411653 1.         0.70413641]]

  [[0.69280045 1.         0.65336659]
   [0.48866177 1.         0.53486219]
   [0.68591364 0.48374837 1.        ]]]]


In [59]:
ex = np.sum(edx, axis = 3, keepdims=True)
print(ex)

[[[[2.54401335]
   [2.32323891]
   [2.68373181]]

  [[2.75980831]
   [1.99972756]
   [2.41487579]]]


 [[[2.34698144]
   [2.32668412]
   [2.35825293]]

  [[2.34616704]
   [2.02352395]
   [2.169662  ]]]]


In [60]:
softmax_result = edx / ex
print(softmax_result)

[[[[0.24124433 0.3930797  0.36567598]
   [0.19120658 0.43043356 0.37835985]
   [0.37261547 0.3708636  0.25652092]]

  [[0.36234401 0.33464549 0.3030105 ]
   [0.26392837 0.23600351 0.50006812]
   [0.37563089 0.41409997 0.21026914]]]


 [[[0.33522174 0.23869905 0.42607921]
   [0.36275947 0.20744433 0.4297962 ]
   [0.27737335 0.42404273 0.29858392]]

  [[0.29529033 0.42622711 0.27848255]
   [0.24149048 0.49418738 0.26432214]
   [0.31613847 0.22296024 0.46090128]]]]


#### multi head attentions

In [61]:
x = rng.random((B, T, d_model), np.float64)
xn = layer_norm(x, gamma_demo, beta_demo)
print(xn.shape)
print(xn)
 

(2, 4, 6)
[[[ 0.31728564  0.3545899   0.09886953 -0.13001186 -0.03055553
    0.47153063]
  [ 0.24539964  0.13505295  0.35534153 -0.09773216  0.34624992
    0.18612963]
  [ 0.2008037   0.31883809  0.38580775 -0.4353892   0.44220698
    0.16696776]
  [ 0.29272858  0.27572898  0.12828828  0.17486323 -0.06730063
    0.50607886]]

 [[ 0.2884909   0.22158631  0.16106252 -0.21227009  0.42899288
    0.13129842]
  [ 0.15770598  0.70660317 -0.01737839  0.72312385  0.22994283
    0.04000532]
  [ 0.29196831  0.27418478 -0.02603753  0.62909217  0.11109389
    0.24294572]
  [ 0.23743158  0.69087057  0.4288766  -0.28500018  0.09478671
   -0.16374402]]]


In [62]:
W_k = rng.random((d_model, d_model), np.float64)[None, :, :]
W_q = rng.random((d_model, d_model), np.float64)[None, :, :]
W_v = rng.random((d_model, d_model), np.float64)[None, :, :]

In [63]:
W_k

array([[[9.90403237e-01, 8.73894466e-01, 5.81459184e-01, 3.93215206e-01,
         7.37391609e-01, 3.53972504e-04],
        [6.00975765e-01, 1.91477061e-01, 9.44808161e-02, 1.08077936e-01,
         6.02625508e-01, 3.29648765e-01],
        [3.69544209e-02, 3.64927497e-01, 7.66911703e-01, 1.36921905e-02,
         8.20887307e-01, 8.29809885e-01],
        [8.01417433e-01, 5.66768652e-01, 2.17383067e-01, 9.02577129e-01,
         9.11702867e-01, 6.40389023e-01],
        [4.04755371e-01, 3.61381984e-01, 6.63625799e-01, 6.52928710e-03,
         5.67953523e-01, 7.71656411e-01],
        [4.46303419e-01, 3.68393378e-01, 3.80812254e-01, 2.00306791e-01,
         8.06386426e-01, 3.52723105e-01]]])

In [64]:
total = 0
pos1 = 0
pos2 = 1 # set this
pos3 = 1 # set this
for i in range(6):
    total += xn[pos1, pos2, i] * W_k[0, i, pos3]

print(total)

0.5082927065444741


In [65]:
K = xn @ W_k
print(K.shape)
print(K)

(2, 4, 6)
[[[ 0.62487877  0.4702301   0.42483962  0.14134417  0.77315903
    0.25852856]
  [ 0.48223214  0.50829271  0.70738087  0.0672895   0.81168092
    0.6097244 ]
  [ 0.30932335  0.35187379  0.70516081 -0.23793927  0.64576358
    0.54662952]
  [ 0.79912946  0.61664744  0.4807175   0.40542126  1.01662193
    0.43600545]]

 [[ 0.48696146  0.43640718  0.60074894 -0.02289624  0.63447654
    0.4482107 ]
  [ 1.27065085  0.77445365  0.47015773  0.8003324   1.34997313
    0.87319416]
  [ 1.11054135  0.78434442  0.47869984  0.76127644  1.19170128
    0.64316443]
  [ 0.40308016  0.3086877   0.47083423 -0.11551296  0.60543364
    0.4165902 ]]]


In [66]:
Q = xn @ W_q
V = xn @ W_v

reshape it to 2 attention

In [67]:
K = K.reshape(B, T, H, d_head)
K = K.transpose(0, 2, 1, 3)

In [68]:
K

array([[[[ 0.62487877,  0.4702301 ,  0.42483962],
         [ 0.48223214,  0.50829271,  0.70738087],
         [ 0.30932335,  0.35187379,  0.70516081],
         [ 0.79912946,  0.61664744,  0.4807175 ]],

        [[ 0.14134417,  0.77315903,  0.25852856],
         [ 0.0672895 ,  0.81168092,  0.6097244 ],
         [-0.23793927,  0.64576358,  0.54662952],
         [ 0.40542126,  1.01662193,  0.43600545]]],


       [[[ 0.48696146,  0.43640718,  0.60074894],
         [ 1.27065085,  0.77445365,  0.47015773],
         [ 1.11054135,  0.78434442,  0.47869984],
         [ 0.40308016,  0.3086877 ,  0.47083423]],

        [[-0.02289624,  0.63447654,  0.4482107 ],
         [ 0.8003324 ,  1.34997313,  0.87319416],
         [ 0.76127644,  1.19170128,  0.64316443],
         [-0.11551296,  0.60543364,  0.4165902 ]]]])

In [69]:
Q = split_heads(Q, H)
V = split_heads(V, H)

In [70]:
print(Q.shape, K.shape)
print(K.transpose(0, 1, 3, 2).shape)

(2, 2, 4, 3) (2, 2, 4, 3)
(2, 2, 3, 4)


In [71]:
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)
print(scores.shape)
print(scores)

(2, 2, 4, 4)
[[[[0.57354173 0.64574433 0.52585143 0.71274585]
   [0.61318297 0.70863705 0.58724996 0.75830437]
   [0.54409675 0.6717224  0.57503785 0.66691308]
   [0.67772374 0.73351139 0.58504636 0.84586933]]

  [[0.21238883 0.28433982 0.1771957  0.34945785]
   [0.36508144 0.46775911 0.28520918 0.59372134]
   [0.42973069 0.52737622 0.32989967 0.67747075]
   [0.22481164 0.30710716 0.17255412 0.390054  ]]]


 [[[0.52383715 0.75189643 0.71390936 0.40912218]
   [0.66626785 1.00380341 0.92963333 0.53023202]
   [0.64516461 1.09101803 1.00586808 0.51165897]
   [0.52236126 0.65378471 0.63388721 0.40567922]]

  [[0.35374219 1.00805345 0.86553112 0.30271616]
   [0.45861249 1.66762341 1.46894385 0.35349977]
   [0.31301436 1.08182528 0.93627141 0.24538253]
   [0.25986584 0.94839707 0.84997287 0.20226235]]]]


In [72]:
attention = softmax(scores, axis=-1)
print(attention)

[[[[0.2393684  0.25729066 0.22822078 0.27512015]
   [0.23636535 0.2600393  0.23031448 0.27328087]
   [0.23265396 0.26432458 0.23996505 0.26305641]
   [0.24084955 0.25466785 0.21953136 0.28495124]]

  [[0.23883854 0.25665654 0.23057925 0.27392567]
   [0.23319494 0.25841128 0.21529356 0.29310023]
   [0.23317148 0.25708832 0.21101794 0.29872225]
   [0.23727382 0.25762636 0.22519288 0.27990694]]]


 [[[0.22950956 0.28830068 0.27755438 0.20463537]
   [0.21855043 0.30629612 0.28440016 0.1907533 ]
   [0.20528894 0.32062488 0.29445374 0.17963244]
   [0.24106252 0.2749199  0.26950374 0.21451384]]

  [[0.18042908 0.34711254 0.30100495 0.17145344]
   [0.1250484  0.41893334 0.34344681 0.11257145]
   [0.16787546 0.36214089 0.31308649 0.15689716]
   [0.17424653 0.34688815 0.31437238 0.16449294]]]]


In [73]:
output = attention @ V

#### Merging head test

In [74]:
output

array([[[[0.53080813, 0.61757762, 0.40079894],
         [0.53076114, 0.61763796, 0.40059782],
         [0.52954473, 0.61748829, 0.3978924 ],
         [0.53210138, 0.61779622, 0.40358274]],

        [[0.3142161 , 0.68820484, 0.44146245],
         [0.31144675, 0.69005191, 0.44428768],
         [0.31019895, 0.69103171, 0.44538531],
         [0.31329397, 0.68878863, 0.44237228]]],


       [[[0.80709804, 0.79681708, 0.72722673],
         [0.81770814, 0.80270622, 0.74702577],
         [0.82786545, 0.80819952, 0.7654466 ],
         [0.79802175, 0.79184352, 0.71096255]],

        [[0.4283847 , 0.61980924, 0.40860138],
         [0.44722003, 0.62206956, 0.4263595 ],
         [0.43269902, 0.62017097, 0.4127364 ],
         [0.42941869, 0.61961798, 0.41036065]]]])

In [75]:
output.transpose(0, 2, 1, 3)

array([[[[0.53080813, 0.61757762, 0.40079894],
         [0.3142161 , 0.68820484, 0.44146245]],

        [[0.53076114, 0.61763796, 0.40059782],
         [0.31144675, 0.69005191, 0.44428768]],

        [[0.52954473, 0.61748829, 0.3978924 ],
         [0.31019895, 0.69103171, 0.44538531]],

        [[0.53210138, 0.61779622, 0.40358274],
         [0.31329397, 0.68878863, 0.44237228]]],


       [[[0.80709804, 0.79681708, 0.72722673],
         [0.4283847 , 0.61980924, 0.40860138]],

        [[0.81770814, 0.80270622, 0.74702577],
         [0.44722003, 0.62206956, 0.4263595 ]],

        [[0.82786545, 0.80819952, 0.7654466 ],
         [0.43269902, 0.62017097, 0.4127364 ]],

        [[0.79802175, 0.79184352, 0.71096255],
         [0.42941869, 0.61961798, 0.41036065]]]])

In [76]:
d_head * H

6

In [77]:
output = output.transpose(0, 2, 1, 3).reshape(B, T, d_head * H)
print(output)

[[[0.53080813 0.61757762 0.40079894 0.3142161  0.68820484 0.44146245]
  [0.53076114 0.61763796 0.40059782 0.31144675 0.69005191 0.44428768]
  [0.52954473 0.61748829 0.3978924  0.31019895 0.69103171 0.44538531]
  [0.53210138 0.61779622 0.40358274 0.31329397 0.68878863 0.44237228]]

 [[0.80709804 0.79681708 0.72722673 0.4283847  0.61980924 0.40860138]
  [0.81770814 0.80270622 0.74702577 0.44722003 0.62206956 0.4263595 ]
  [0.82786545 0.80819952 0.7654466  0.43269902 0.62017097 0.4127364 ]
  [0.79802175 0.79184352 0.71096255 0.42941869 0.61961798 0.41036065]]]


##### split head test

In [78]:
# def split_heads(x, H):
#     # x: (B, T, C) => (B, H, T,d_head)
#     B, T, C = x.shape
#     d_head = C // H
#     x = x.reshape(B, T, H, d_head)
#     return x.transpose(0, 2, 1, 3)

In [79]:
d_head = d_model // H
print(d_head)

3


In [80]:
x = x.reshape(B, T, H, d_head)
print(x.shape, x)

(2, 4, 2, 3) [[[[0.90509585 0.14851999 0.25566055]
   [0.24223111 0.08956021 0.72648858]]

  [[0.49791102 0.16358028 0.57875576]
   [0.37545161 0.56838617 0.42600403]]

  [[0.44269466 0.28213245 0.80019635]
   [0.23128505 0.85859306 0.48280562]]

  [[0.99429198 0.10303178 0.43777483]
   [0.61520002 0.11101223 0.98433274]]]


 [[[0.85182281 0.10232676 0.49321632]
   [0.32690038 0.90700098 0.44016645]]

  [[0.20780487 0.72052569 0.18757965]
   [0.83795396 0.55666821 0.30572708]]

  [[0.9821657  0.196397   0.21034739]
   [0.98660463 0.50076549 0.63848679]]

  [[0.52648938 0.70011372 0.84993214]
   [0.19342952 0.3376662  0.01733952]]]]


จากการ reshape เราจะพบว่าตำแหน่งมิติของ attention จะอยู่ในลำดับที่ 3 (index = 2) เนื่องจากเราจะแยกทำทีละ attention เราจึงต้องเรียงมิติใหม่โดยใช้ np.transpos

In [81]:
x = x.transpose(0, 2, 1, 3) 
print(x.shape)
print(x)

(2, 2, 4, 3)
[[[[0.90509585 0.14851999 0.25566055]
   [0.49791102 0.16358028 0.57875576]
   [0.44269466 0.28213245 0.80019635]
   [0.99429198 0.10303178 0.43777483]]

  [[0.24223111 0.08956021 0.72648858]
   [0.37545161 0.56838617 0.42600403]
   [0.23128505 0.85859306 0.48280562]
   [0.61520002 0.11101223 0.98433274]]]


 [[[0.85182281 0.10232676 0.49321632]
   [0.20780487 0.72052569 0.18757965]
   [0.9821657  0.196397   0.21034739]
   [0.52648938 0.70011372 0.84993214]]

  [[0.32690038 0.90700098 0.44016645]
   [0.83795396 0.55666821 0.30572708]
   [0.98660463 0.50076549 0.63848679]
   [0.19342952 0.3376662  0.01733952]]]]


#### end of attention
---

In [82]:
print(output)

[[[0.53080813 0.61757762 0.40079894 0.3142161  0.68820484 0.44146245]
  [0.53076114 0.61763796 0.40059782 0.31144675 0.69005191 0.44428768]
  [0.52954473 0.61748829 0.3978924  0.31019895 0.69103171 0.44538531]
  [0.53210138 0.61779622 0.40358274 0.31329397 0.68878863 0.44237228]]

 [[0.80709804 0.79681708 0.72722673 0.4283847  0.61980924 0.40860138]
  [0.81770814 0.80270622 0.74702577 0.44722003 0.62206956 0.4263595 ]
  [0.82786545 0.80819952 0.7654466  0.43269902 0.62017097 0.4127364 ]
  [0.79802175 0.79184352 0.71096255 0.42941869 0.61961798 0.41036065]]]


In [83]:
outputN = layer_norm(output, gamma_demo, beta_demo)
print(outputN)

[[[ 0.23386173  0.69338454  0.04387221 -0.54342843  0.4508408
    0.09000136]
  [ 0.23359943  0.69169471  0.04420301 -0.54910848  0.45106648
    0.09575234]
  [ 0.23320724  0.69124942  0.04205275 -0.54783866  0.45169223
    0.09951963]
  [ 0.23408066  0.69242701  0.04660168 -0.54916425  0.45041617
    0.09022444]]

 [[ 0.28303643  0.71090658  0.28076641 -0.45816233  0.14876411
   -0.140007  ]
  [ 0.28354214  0.7065472   0.29068536 -0.45126559  0.13597661
   -0.13832   ]
  [ 0.28269302  0.69955679  0.30161103 -0.45560393  0.13474884
   -0.13739319]
  [ 0.28324402  0.71632128  0.27091456 -0.45853116  0.15381663
   -0.13996342]]]
