## Causal mask
---

The causal mask enforces autoregression: position $t$ can attend only to positions $\leq t$ and never sees $t + 1, t + 2, ...$

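Consider the following example:

> * T = 4: the sequence has at most 4 tokens
> * row = query position
> * col = key position

The rule is $\text{col} \leq \text{row}$: an entry is 1 when the key position is visible to the query position.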

> * q0: $\begin{bmatrix}1& 0& 0& 0\end{bmatrix}$  
> * q1: $\begin{bmatrix}1& 1& 0& 0\end{bmatrix}$  
> * q2: $\begin{bmatrix}1& 1& 1& 0\end{bmatrix}$  
> * q3: $\begin{bmatrix}1& 1& 1& 1\end{bmatrix}$  

Positions marked 0 have $-\infty$ added to their scores before the softmax, so their attention weights become 0.
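The rule above can be sketched in a few lines of NumPy. This is a minimal illustration, not part of the original notebook; the names `allowed` and `additive_mask` are made up for this example.

```python
import numpy as np

T = 4  # number of tokens, as in the example above

# allowed[row, col] is 1 where col <= row (the query at `row` may see the key at `col`)
allowed = np.tril(np.ones((T, T), dtype=np.int64))

# Additive form: 0 where attention is allowed, -inf where it is blocked
additive_mask = np.where(allowed == 1, 0.0, -np.inf)

print(allowed)
print(additive_mask)
```

Each row of `allowed` matches the q0 to q3 rows listed above, and `additive_mask` is what would be added to the scores before the softmax.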

### Where the causal mask is applied

* In the self-attention of a transformer block, the scores are originally computed as

$$ scores = \frac {Q \cdot K^T} {\sqrt{d_h}}$$

The causal mask is then added to the resulting scores:

$$ scores = scores + causal\_mask$$

where
* $Q = W_Q \cdot X$ &rarr; Q is the query
* $K = W_K \cdot X$ &rarr; K is the key
* $d_h = \frac {d_{model}} {H}$ &rarr; the dimension of each attention head

* $d_{model}$ is the hidden size, i.e. the length of the vector obtained by embedding a token
* $H$ is the number of attention heads
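The definitions above can be put together for a single head as follows. This is a rough sketch under stated assumptions: the sizes and the random weights `W_Q` and `W_K` are hypothetical, and each row of `X` is taken to be one token, so the projection is written `X @ W_Q` rather than $W_Q \cdot X$.

```python
import numpy as np

rng = np.random.default_rng(0)

T, d_model, H = 4, 6, 3   # illustrative sizes, not from a real model
d_h = d_model // H        # per-head dimension

X = rng.random((T, d_model))      # one row per token (assumed layout)
W_Q = rng.random((d_model, d_h))  # hypothetical projection weights for one head
W_K = rng.random((d_model, d_h))

Q = X @ W_Q  # (T, d_h)
K = X @ W_K  # (T, d_h)

scores = (Q @ K.T) / np.sqrt(d_h)  # (T, T)

# Additive causal mask: large negative value strictly above the diagonal
causal_mask = np.triu(np.ones((T, T)), k=1) * -1e9
masked_scores = scores + causal_mask

print(masked_scores.shape)
```

Entries on and below the diagonal keep their original score, while entries above it are pushed to roughly `-1e9`.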

### Code example

In [6]:
import numpy as np

In [None]:
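def make_causal_mask(T, neg_large=-1e9):
    # Ones strictly above the diagonal mark blocked positions (key index > query index)
    m = np.triu(np.ones((T, T), dtype=np.float64), k=1)
    # Blocked entries become neg_large, a finite stand-in for -inf; allowed entries stay 0
    mask = m * neg_large
    return mask  # shape (T, T)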



In [8]:
def softmax(x, axis=-1):
    # Subtract the row-wise max for numerical stability before exponentiating
    x = x - np.max(x, axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / np.sum(ex, axis=axis, keepdims=True)

In [9]:
scores = np.random.rand(4, 4)
print(scores)

[[0.50390039 0.5365974  0.41871129 0.81252469]
 [0.84036985 0.86761153 0.80269944 0.87209218]
 [0.69733857 0.93032391 0.81018176 0.74386275]
 [0.41280469 0.59346427 0.12186543 0.97038267]]


In [15]:
T = scores.shape[0]
causal_mask = make_causal_mask(T)
masked_scores = scores + causal_mask
Attn = softmax(masked_scores)

*Attn* stands for the attention weights.

In [16]:
print("causal_mask=\n", causal_mask)
print("Attn=\n", Attn)

causal_mask=
 [[-0.e+00 -1.e+09 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -0.e+00]]
Attn=
 [[1.         0.         0.         0.        ]
 [0.49319    0.50681    0.         0.        ]
 [0.29569882 0.37327924 0.33102193 0.        ]
 [0.21312847 0.25532945 0.15932655 0.37221554]]


In [17]:
V = np.random.rand(4, 4)  # V stands for the value matrix
print("V =\n", V)

V =
 [[0.74636963 0.87301979 0.14951819 0.45018703]
 [0.64471524 0.95888822 0.22731667 0.93179853]
 [0.54371212 0.97139524 0.2648877  0.74728867]
 [0.76782001 0.01404621 0.1735202  0.56182687]]


In [19]:
OutputFromAttention = np.dot(Attn, V)

print("Output from attention head => \n")
print(OutputFromAttention)

Output from attention head => 

[[0.74636963 0.87301979 0.14951819 0.45018703]
 [0.69485017 0.91653877 0.18894724 0.69427255]
 [0.64134007 0.93763712 0.21674859 0.72830976]
 [0.69610971 0.59089504 0.19669778 0.66204689]]


### Code example with multiple batches and heads

$$Q \in \mathbb{R}^{B \times H \times T \times d_h}$$

$$K \in \mathbb{R}^{B \times H \times T \times d_h}$$

where $d_h = d_{model} / H$, i.e. the model dimension divided by the number of attention heads
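One common way to arrive at the $B \times H \times T \times d_h$ shape is to split the last axis of a projected tensor into heads. The sketch below assumes that approach; `Q_flat` is a hypothetical stand-in for queries that have already been projected, and the notebook cells that follow sample `Q`, `K`, and `V` directly instead.

```python
import numpy as np

B, T, d_model, H = 2, 4, 6, 3
d_h = d_model // H

rng = np.random.default_rng(0)
Q_flat = rng.random((B, T, d_model))  # hypothetical projected queries before the head split

# (B, T, d_model) -> (B, T, H, d_h) -> (B, H, T, d_h)
Q = Q_flat.reshape(B, T, H, d_h).transpose(0, 2, 1, 3)

print(Q.shape)
```

With this layout, head `h` sees features `h * d_h` to `(h + 1) * d_h` of every token.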

In [22]:
B = 2  # batch size
H = 3  # number of attention heads
T = 4  # tokens per sequence in the batch
d = 6  # embedding size (d_model)

dh = d // H  # per-head dimension: 6 // 3 = 2

In [38]:
Q = np.random.rand(B, H, T, dh) 
K = np.random.rand(B, H, T, dh)
V = np.random.rand(B, H, T, dh)

In [28]:
K.transpose(0, 1, 3, 2).shape

(2, 3, 2, 4)

In [35]:
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(dh)
print(scores.shape)

(2, 3, 4, 4)


In [36]:
causal_mask = make_causal_mask(T)
print(causal_mask)


[[-0.e+00 -1.e+09 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -1.e+09 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -1.e+09]
 [-0.e+00 -0.e+00 -0.e+00 -0.e+00]]


In [37]:
scores = scores + causal_mask[None, None, :, :]
print(scores)

[[[[ 3.12863633e-01 -1.00000000e+09 -1.00000000e+09 -9.99999999e+08]
   [ 2.25268871e-01  1.42432977e-01 -1.00000000e+09 -1.00000000e+09]
   [ 1.88965341e-01  4.78668279e-01  4.26067823e-01 -9.99999999e+08]
   [ 2.93842260e-01  6.33129693e-01  5.96516571e-01  8.96083615e-01]]

  [[ 5.98580836e-01 -1.00000000e+09 -1.00000000e+09 -1.00000000e+09]
   [ 5.47363570e-01  4.64608704e-01 -1.00000000e+09 -1.00000000e+09]
   [ 6.98155768e-01  5.65976030e-01  1.62867396e-01 -1.00000000e+09]
   [ 7.41296641e-01  6.34992139e-01  2.32832600e-01  1.47403765e-01]]

  [[ 1.24176962e-01 -1.00000000e+09 -1.00000000e+09 -1.00000000e+09]
   [ 2.22770654e-01  2.04558495e-01 -1.00000000e+09 -1.00000000e+09]
   [ 2.66581849e-01  2.51726505e-01  8.50174499e-02 -1.00000000e+09]
   [ 4.70164120e-01  4.08047212e-01  3.07464795e-01  5.25382744e-01]]]


 [[[ 2.98567694e-01 -1.00000000e+09 -1.00000000e+09 -1.00000000e+09]
   [ 4.58717771e-01  3.24219353e-01 -9.99999999e+08 -1.00000000e+09]
   [ 6.55235559e-01  3.987

In [39]:
Attn = softmax(scores)
print(Attn.shape)

(2, 3, 4, 4)


In [40]:
Attn

array([[[[1.        , 0.        , 0.        , 0.        ],
         [0.52069714, 0.47930286, 0.        , 0.        ],
         [0.27750016, 0.37074869, 0.35175116, 0.        ],
         [0.17909503, 0.2514399 , 0.24240039, 0.32706467]],

        [[1.        , 0.        , 0.        , 0.        ],
         [0.52067692, 0.47932308, 0.        , 0.        ],
         [0.406226  , 0.35592851, 0.23784549, 0.        ],
         [0.32757425, 0.29453865, 0.19700926, 0.18087784]],

        [[1.        , 0.        , 0.        , 0.        ],
         [0.50455291, 0.49544709, 0.        , 0.        ],
         [0.35470817, 0.3494778 , 0.29581403, 0.        ],
         [0.25998395, 0.2443259 , 0.22094649, 0.27474366]]],


       [[[1.        , 0.        , 0.        , 0.        ],
         [0.53357401, 0.46642599, 0.        , 0.        ],
         [0.34642937, 0.26805973, 0.3855109 , 0.        ],
         [0.24451884, 0.23936896, 0.29251992, 0.22359228]],

        [[1.        , 0.        , 0.        , 

In [42]:
output = Attn @ V
print(output.shape)

(2, 3, 4, 2)


In [43]:
output

array([[[[0.78955968, 0.37541394],
         [0.58973034, 0.35510959],
         [0.4729767 , 0.47090235],
         [0.45612288, 0.35448738]],

        [[0.13537609, 0.62826851],
         [0.34295712, 0.436247  ],
         [0.39849529, 0.51446262],
         [0.35481428, 0.44777531]],

        [[0.20350649, 0.62553219],
         [0.46531218, 0.77509053],
         [0.43259926, 0.71092944],
         [0.31716871, 0.77889239]]],


       [[[0.14920089, 0.58289401],
         [0.31916874, 0.70206434],
         [0.22539753, 0.59122384],
         [0.27702893, 0.66209751]],

        [[0.89988387, 0.14176359],
         [0.71741481, 0.37487981],
         [0.56326507, 0.50600797],
         [0.57614152, 0.5594295 ]],

        [[0.25647566, 0.71580986],
         [0.45405469, 0.4273219 ],
         [0.56641927, 0.54277648],
         [0.53447418, 0.45788281]]]])
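The steps above can be collected into one function and used to check the causal property directly. This is a minimal sketch, assuming the same `(B, H, T, d_h)` layout; `causal_attention` is an illustrative name rather than a function from the notebook. The check perturbs only the last token's key and value and confirms that the outputs at earlier positions do not change.

```python
import numpy as np

def causal_attention(Q, K, V, neg_large=-1e9):
    """Masked scaled dot-product attention for arrays shaped (B, H, T, d_h)."""
    T, d_h = Q.shape[-2], Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_h)    # (B, H, T, T)
    mask = np.triu(np.ones((T, T)), k=1) * neg_large      # (T, T)
    scores = scores + mask                                # broadcasts over B and H
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V                                    # (B, H, T, d_h)

rng = np.random.default_rng(0)
B, H, T, d_h = 2, 3, 4, 2
Q = rng.random((B, H, T, d_h))
K = rng.random((B, H, T, d_h))
V = rng.random((B, H, T, d_h))

out = causal_attention(Q, K, V)

# Perturb only the last token's key and value
K2, V2 = K.copy(), V.copy()
K2[:, :, -1, :] += 1.0
V2[:, :, -1, :] += 1.0
out2 = causal_attention(Q, K2, V2)

# Positions 0..T-2 cannot see the last token, so their outputs should match
print(np.allclose(out[:, :, :-1], out2[:, :, :-1]))
```

Only the final position is allowed to attend to the last token, so it is the only row whose output can differ between `out` and `out2`.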