### -1. Back prop for matmul

Given

$$
Y = WX, \quad y_{ij} = \sum_a w_{ia}x_{aj}
$$

Note that

$$
\begin{align*}
\dfrac{\partial y_{ij}}{\partial x_{mn}}&=\dfrac{\partial}{\partial x_{mn}}\left(\sum_aw_{ia}x_{aj}\right)\\
&=\sum_a\dfrac{\partial}{\partial x_{mn}}w_{ia}x_{aj}\\
&=\sum_a w_{ia}\delta_{m,a}\delta_{n,j}\\
&=w_{im}\delta_{n,j}
\end{align*}
$$

thus

$$
\begin{align*}
\dfrac{\partial L}{\partial x_{mn}}&=\sum_{ij}\dfrac{\partial L}{\partial y_{ij}}\dfrac{\partial y_{ij}}{\partial x_{mn}}\\
&=\sum_{ij}\dfrac{\partial L}{\partial y_{ij}}w_{im}\delta_{n,j}\\
&=\sum_{i}\dfrac{\partial L}{\partial y_{in}}w_{im}
\end{align*}
$$

therefore

$$
\left(\dfrac{\partial L}{\partial X}\right)=W^T\left(\dfrac{\partial L}{\partial Y}\right)
$$

and similarly,

$$
\dfrac{\partial L}{\partial W}=\dfrac{\partial L}{\partial Y} X^T
$$
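As a sanity check, the two identities above can be compared against finite differences. The following is a minimal NumPy sketch; the toy shapes, the stand-in upstream gradient `G`, and the helper `numerical_grad` are illustrative assumptions, not part of the derivation.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
X = rng.normal(size=(4, 5))
G = rng.normal(size=(3, 5))  # stands in for dL/dY, i.e. L = sum(G * Y)

def loss(W, X):
    return float(np.sum(G * (W @ X)))

# Closed-form gradients from the derivation above.
dX = W.T @ G
dW = G @ X.T

def numerical_grad(f, A, eps=1e-6):
    """Central finite differences of the scalar f() with respect to array A."""
    grad = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        old = A[idx]
        A[idx] = old + eps
        plus = f()
        A[idx] = old - eps
        minus = f()
        A[idx] = old  # restore the entry
        grad[idx] = (plus - minus) / (2 * eps)
    return grad

dX_num = numerical_grad(lambda: loss(W, X), X)
dW_num = numerical_grad(lambda: loss(W, X), W)
print(np.allclose(dX, dX_num, atol=1e-6), np.allclose(dW, dW_num, atol=1e-6))
```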

#### 0. Input and embedding

Let:

* $I \in \mathbb{Z}^{T}$: token indices of the input sequence
* Embedding matrix: $E_{\text{lookup}} \in \mathbb{R}^{d_{\text{token}} \times d_{\text{model}}}$

Then:

$$
X = E_{\text{lookup}}[I] \in \mathbb{R}^{T \times d_{\text{model}}}
$$

This means that each token index in $I$ is used to select a row from $E_{\text{lookup}}$.

Note that this is a lookup, not a matmul (it could be written as a matmul if $I$ were one-hot encoded). Autograd engines track which rows were read and back-propagate gradients to exactly those rows.
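The lookup, its one-hot matmul equivalent, and the row-wise gradient can be sketched in NumPy as follows. The toy sizes and the random upstream gradient `dX` are assumptions for illustration; a real autograd engine performs the scatter-add internally.

```python
import numpy as np

rng = np.random.default_rng(0)
d_token, d_model = 6, 4              # toy vocabulary size and model width
E_lookup = rng.normal(size=(d_token, d_model))
I = np.array([2, 5, 2])              # token 2 appears twice

X = E_lookup[I]                      # lookup: shape (T, d_model)

# Same result as a matmul with one-hot rows.
one_hot = np.eye(d_token)[I]         # shape (T, d_token)
X_matmul = one_hot @ E_lookup

# Backward: scatter-add dL/dX into the rows that were read.
dX = rng.normal(size=X.shape)        # stand-in for the upstream gradient
dE = np.zeros_like(E_lookup)
np.add.at(dE, I, dX)                 # repeated indices accumulate
```

Rows of `E_lookup` that were never selected keep a zero gradient, and a row selected twice receives the sum of both contributions.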

#### 1. Linear Projections

We define three learnable projection matrices:

* $W_Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$
* $W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$
* $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$

Then the projections are:

$$
Q = X W_Q \in \mathbb{R}^{T \times d_k}
$$

$$
K = X W_K \in \mathbb{R}^{T \times d_k}
$$

$$
V = X W_V \in \mathbb{R}^{T \times d_v}
$$

The dimension $d_k$ is usually taken to be even, for RoPE (which rotates dimensions in pairs) and for hardware acceleration. Each row of $Q$, $K$, $V$ corresponds to one token.

#### 2. RoPE

Positional embeddings are necessary to make the model aware of the position of each input token. For each token position $t\in\{1,\cdots,T\}$ and each dimension pair $i\in\{1,\cdots, d_k/2\}$ (indices start from 1), define

$$
q_{t,i} = \left[\begin{matrix}
Q_{t, 2i-1}\\ Q_{t,2i}
\end{matrix}\right],\quad
k_{t,i} = \left[\begin{matrix}
K_{t, 2i-1}\\ K_{t,2i}
\end{matrix}\right]
$$

and, for a per-pair rotation frequency $\omega_i$ (commonly $\omega_i = 10000^{-2(i-1)/d_k}$), define

$$
R(i, t) = \left[\begin{matrix}
\cos(\omega_i t) & -\sin(\omega_i t)\\
\sin(\omega_i t) & \cos(\omega_i t)
\end{matrix}\right]
$$

so that

$$
q'_{t,i}=R(i, t)q_{t,i},\quad k'_{t,i}=R(i, t)k_{t,i}
$$

and we write the rotated pairs back into $Q', K'$.

We use RoPE mainly because the dot product of a rotated query and a rotated key, $\langle q'_{t,i}, k'_{s,i}\rangle = q_{t,i}^\top R(i, s-t)\,k_{s,i}$, depends on the positions only through the relative offset $s-t$.
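A minimal NumPy sketch of the rotation is given below. The frequency schedule $\omega_i = \text{base}^{-2(i-1)/d_k}$ with `base=10000` is an assumption (it is the common choice, but the text leaves $\omega_i$ open), and the function name `rope` is illustrative. Columns are 0-based in code, so pair $i$ occupies columns $2i-2$ and $2i-1$.

```python
import numpy as np

def rope(M, base=10000.0):
    """Rotate consecutive column pairs of M (shape (T, d_k)) by position-dependent angles."""
    T, d_k = M.shape
    assert d_k % 2 == 0, "RoPE needs an even d_k"
    t = np.arange(1, T + 1)[:, None]                     # positions 1..T, as in the text
    omega = base ** (-2.0 * np.arange(d_k // 2) / d_k)   # assumed frequency schedule
    angle = t * omega[None, :]                           # shape (T, d_k / 2)
    cos, sin = np.cos(angle), np.sin(angle)
    first, second = M[:, 0::2], M[:, 1::2]               # the two components of each pair
    out = np.empty_like(M)
    out[:, 0::2] = cos * first - sin * second
    out[:, 1::2] = sin * first + cos * second
    return out

# Relative-position check: use the same query and key vector at every position.
T, d_k = 6, 8
rng = np.random.default_rng(0)
q = rng.normal(size=d_k)
k = rng.normal(size=d_k)
Qs = np.tile(q, (T, 1))
Ks = np.tile(k, (T, 1))
S = rope(Qs) @ rope(Ks).T   # S[t, s] should depend only on s - t
```

With identical content at every position, each diagonal of `S` is constant, which is the relative-position property stated above.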

#### 3. Naive Attention

We compute attention weights:

$$
A = \frac{Q' K'^\top}{\sqrt{d_k}} \in \mathbb{R}^{T \times T}
$$

Here $A_{ij}$ is the scaled dot product of the query of token $i$ and the key of token $j$, so the $i$-th row of $A$ holds the scores of every token's key against the query of token $i$. We apply a row-wise softmax to convert these scores into weights that sum to 1:

$$
\alpha_{ij} = \frac{\exp\left( A_{ij} \right)}{\sum_{k=1}^{T} \exp\left( A_{ik} \right)}
$$

then we use these weights to take a weighted average of the rows of $V$:

$$
O = \alpha V \in \mathbb{R}^{T\times d_v}
$$
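The three steps above (scores, row-wise softmax, weighted average) can be sketched in NumPy. The function names and toy shapes are illustrative; subtracting the row maximum before exponentiating is a standard numerical-stability step that does not change the softmax result.

```python
import numpy as np

def softmax_rows(A):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    A = A - A.max(axis=-1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

def naive_attention(Q, K, V):
    """Return O = softmax(Q K^T / sqrt(d_k)) V together with the weights alpha."""
    d_k = Q.shape[-1]
    A = Q @ K.T / np.sqrt(d_k)   # scores, shape (T, T)
    alpha = softmax_rows(A)      # each row sums to 1
    return alpha @ V, alpha

T, d_k, d_v = 5, 4, 3
rng = np.random.default_rng(0)
Q = rng.normal(size=(T, d_k))
K = rng.normal(size=(T, d_k))
V = rng.normal(size=(T, d_v))
O, alpha = naive_attention(Q, K, V)
```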

#### 4. Masks

Before applying softmax, we may need to mask out certain entries in $A$ to enforce constraints such as causality.

(a) Causal Mask: For auto-regressive generation, the output of token $i$ should only use the values of tokens $j\le i$, since future tokens are not available at inference time. Define

$$
(M_c)_{ij}=\begin{cases}
0, & j\le i\\
-\infty, &j > i
\end{cases}
$$

and let $A'=A+M_c$. Since $\exp(-\infty)=0$, this effectively sets the softmax weights of the masked entries to zero.

(b) Padding Mask: If padding tokens are used, they carry no meaning; therefore we mask out every entry of $A$ whose row or column index is a padding token.
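Both masks can be sketched as additive matrices in NumPy. The padding pattern `is_pad` is hypothetical, and this sketch masks only the padded key columns: masking an entire query row with $-\infty$ would leave the softmax undefined for that row, so handling padded query rows separately (for example by ignoring them in the loss) is an assumption here.

```python
import numpy as np

def softmax_rows(A):
    A = A - A.max(axis=-1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

T = 4
row, col = np.indices((T, T))

# Causal mask: 0 where j <= i, -inf where j > i.
M_c = np.where(col > row, -np.inf, 0.0)

# Padding mask (hypothetical example: the last token is padding).
is_pad = np.array([False, False, False, True])
M_pad = np.where(is_pad[None, :], -np.inf, 0.0)   # masks padded key columns only

rng = np.random.default_rng(0)
A = rng.normal(size=(T, T))
alpha = softmax_rows(A + M_c + M_pad)
```

Every masked entry of `alpha` is exactly zero, and each row still sums to 1 over its unmasked entries.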

For the autograd engine, the gradient will only flow to non-masked entries:

$$
\dfrac{\partial L}{\partial \alpha} = \dfrac{\partial L}{\partial O} V^T
$$

$$
\dfrac{\partial L}{\partial A'_{ij}}=\sum_{ab}\dfrac{\partial L}{\partial \alpha_{ab}}\dfrac{\partial \alpha_{ab}}{\partial A'_{ij}}
$$

Since the softmax was applied row-wise, within the same row we have

$$
\dfrac{\partial \alpha_{ij}}{\partial A'_{ik}}= \begin{cases}
\alpha_{ij}(1-\alpha_{ij}),& j = k\\
-\alpha_{ij}\alpha_{ik},& j\neq k\\
\end{cases}
$$

therefore

$$
\dfrac{\partial L}{\partial A'_{ij}}=\sum_b\dfrac{\partial L}{\partial \alpha_{ib}}\dfrac{\partial \alpha_{ib}}{\partial A'_{ij}}
$$

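Substituting the softmax derivative gives

$$
\dfrac{\partial L}{\partial A'_{ij}}=\alpha_{ij}\left(\dfrac{\partial L}{\partial \alpha_{ij}}-\sum_b\dfrac{\partial L}{\partial \alpha_{ib}}\alpha_{ib}\right)
$$

so the whole expression carries a factor $\alpha_{ij}$. Since $\alpha_{ij}=0$ for masked entries, the gradient at every masked entry is zero.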

For $A$, writing $M$ for the total additive mask (so that $A'=A+M$),

$$
\dfrac{\partial L}{\partial A_{ij}}=\sum_{ab}\dfrac{\partial L}{\partial A'_{ab}}\dfrac{\partial A'_{ab}}{\partial A_{ij}}= \dfrac{\partial L}{\partial A'_{ij}}=\dfrac{\partial L}{\partial M_{ij}}
$$

note that $M$ is not learnable so we turn off the gradients for $M$.
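The claim that no gradient reaches masked entries can be checked numerically. This is a small NumPy sketch using the row-wise softmax derivative above; the toy sizes, the random upstream gradient `dO`, and the helper names are assumptions.

```python
import numpy as np

def softmax_rows(A):
    A = A - A.max(axis=-1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_backward(alpha, d_alpha):
    """dL/dA' for a row-wise softmax: alpha * (d_alpha - rowsum(d_alpha * alpha))."""
    return alpha * (d_alpha - np.sum(d_alpha * alpha, axis=-1, keepdims=True))

T, d_v = 4, 3
rng = np.random.default_rng(0)
A = rng.normal(size=(T, T))
V = rng.normal(size=(T, d_v))
dO = rng.normal(size=(T, d_v))           # stand-in for dL/dO

row, col = np.indices((T, T))
M_c = np.where(col > row, -np.inf, 0.0)  # causal mask
alpha = softmax_rows(A + M_c)

d_alpha = dO @ V.T                       # dL/d(alpha)
dA = softmax_backward(alpha, d_alpha)    # dL/dA' and, since A' = A + M, also dL/dA
masked = np.isinf(M_c)
```

Here `dA[masked]` is exactly zero because every such entry is multiplied by a zero weight.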


#### 5. Multi-head Attention, Multi-query Attention, and Grouped-query Attention

Split $d_k$, $d_v$ into $h$ heads:

$$
d_k=h\times d^h_k,\quad d_v=h\times d^h_v
$$

define per-head projection matrices for queries, keys, values for $m\in\{1,\cdots,h\}$:

$$
W_Q^m\in\mathbb{R}^{d_{model}\times d_k^h},\quad W_K^m\in\mathbb{R}^{d_{model}\times d_k^h},\quad W_V^m\in\mathbb{R}^{d_{model}\times d_v^h}
$$

Then for each input $X\in\mathbb{R}^{T\times d_{model}}$, we have

$$
Q^m=XW_Q^m,\quad K^m=XW^m_K,\quad V^m=XW^m_V
$$

and we apply RoPE for each head ($Q^m$, $K^m$), then apply attention and softmax for each head, getting $O^m\in\mathbb{R}^{T\times d^h_v}$. Concatenating all heads gives $O\in \mathbb{R}^{T\times d_v}$.

For multi-query attention, each head has its own $Q^m\in \mathbb{R}^{T\times d^h_k}$, while all heads share a single $K\in  \mathbb{R}^{T\times d^h_k}$ and a single $V\in \mathbb{R}^{T\times d^h_v}$.

For grouped-query attention, we divide the set of heads $\{1,\cdots,h\}$ into $g$ disjoint groups, with each group sharing a set of keys and values. Queries remain per-head as in MHA and MQA.
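The three variants differ only in how many key/value sets exist, so one function can cover all of them. The sketch below assumes heads are assigned to groups contiguously and in equal numbers (head $m$ uses group $\lfloor m / (h/g) \rfloor$, 0-based), and it omits RoPE and masking for brevity; the function name and shapes are illustrative.

```python
import numpy as np

def softmax_rows(A):
    A = A - A.max(axis=-1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

def grouped_attention(Q, K, V):
    """Q: (h, T, d_k_h); K: (g, T, d_k_h); V: (g, T, d_v_h), with g dividing h.

    g == h gives multi-head attention, g == 1 gives multi-query attention.
    """
    h, g = Q.shape[0], K.shape[0]
    assert h % g == 0, "number of heads must be a multiple of the number of groups"
    heads_per_group = h // g
    outputs = []
    for m in range(h):
        grp = m // heads_per_group                    # assumed contiguous grouping
        A = Q[m] @ K[grp].T / np.sqrt(Q.shape[-1])
        outputs.append(softmax_rows(A) @ V[grp])      # O^m, shape (T, d_v_h)
    return np.concatenate(outputs, axis=-1)           # O, shape (T, h * d_v_h)

h, T, d_k_h, d_v_h = 4, 5, 3, 2
rng = np.random.default_rng(0)
Q = rng.normal(size=(h, T, d_k_h))
K_mqa, V_mqa = rng.normal(size=(1, T, d_k_h)), rng.normal(size=(1, T, d_v_h))
K_gqa, V_gqa = rng.normal(size=(2, T, d_k_h)), rng.normal(size=(2, T, d_v_h))
K_mha, V_mha = rng.normal(size=(h, T, d_k_h)), rng.normal(size=(h, T, d_v_h))

O_mqa = grouped_attention(Q, K_mqa, V_mqa)
O_gqa = grouped_attention(Q, K_gqa, V_gqa)
O_mha = grouped_attention(Q, K_mha, V_mha)
```

All three outputs have the same shape; only the number of stored key/value sets (1, $g$, or $h$) changes.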

#### 6. Back Prop (The intended way)

Given $O=\sigma(A+M)V=DV$, where $\sigma$ is the row-wise softmax, $M$ is the additive mask, and $A = QK^T/\sqrt{d_k}$, we deduce:

$$
\dfrac{\partial L}{\partial V} = D^T\dfrac{\partial L}{\partial O}
$$

As for $Q, K$, first let $A=QK^T/\sqrt{d_k}$, then by the back prop of matmul, we have

$$
\dfrac{\partial L}{\partial Q}=\dfrac{1}{\sqrt{d_k}}\left(\dfrac{\partial L}{\partial A}\right)K
$$

$$
\dfrac{\partial L}{\partial K}=\left(\dfrac{\partial L}{\partial K^T}\right)^T=\dfrac{1}{\sqrt{d_k}}\left(\dfrac{\partial L}{\partial A}\right)^TQ
$$

$$
\dfrac{\partial L}{\partial D} = \dfrac{\partial L}{\partial O} V^T
$$

for softmax, we also have

$$
\dfrac{\partial D_{ab}}{\partial A_{cd}}=\delta_{a,c}D_{ab}(\delta_{b,d}-D_{ad})
$$

thus

$$
\begin{align*}
\dfrac{\partial L}{\partial A_{cd}}&=\sum_{ab}\dfrac{\partial L}{\partial D_{ab}}\dfrac{\partial D_{ab}}{\partial A_{cd}}\\
&=\sum_b\dfrac{\partial L}{\partial D_{cb}}D_{cb}(\delta_{b,d}-D_{cd})\\
&=D_{cd}\left(\dfrac{\partial L}{\partial D_{cd}}-\sum_bD_{cb}\dfrac{\partial L}{\partial D_{cb}}\right)\\
&=D_{cd}\left(\sum_{m}\dfrac{\partial L}{\partial O_{cm}}V_{dm}-\sum_bD_{cb}\sum_n\dfrac{\partial L}{\partial O_{cn}}V_{bn}\right)\\
&=D_{cd}\left(\sum_{m}\dfrac{\partial L}{\partial O_{cm}}V_{dm}-\sum_n\dfrac{\partial L}{\partial O_{cn}}O_{cn}\right)\\
&=D_{cd}\left(\left(\dfrac{\partial L}{\partial D}\right)_{cd}-\sum_n\dfrac{\partial L}{\partial O_{cn}}O_{cn}\right)
\end{align*}
$$

This allows us to calculate all the gradients by propagating through $A$.
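The formulas of this section can be assembled into a backward pass and compared against finite differences. This is a minimal NumPy sketch, not an optimized implementation; the function names, toy shapes, and the loss $L=\sum \text{dO}\odot O$ used for the check are assumptions.

```python
import numpy as np

def softmax_rows(A):
    A = A - A.max(axis=-1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

def attention_forward(Q, K, V, M):
    A = Q @ K.T / np.sqrt(Q.shape[-1])
    D = softmax_rows(A + M)
    return D @ V, D

def attention_backward(Q, K, V, D, dO):
    d_k = Q.shape[-1]
    O = D @ V
    dV = D.T @ dO                                    # dL/dV = D^T dL/dO
    dD = dO @ V.T                                    # dL/dD = dL/dO V^T
    delta = np.sum(dO * O, axis=-1, keepdims=True)   # sum_n dL/dO_cn * O_cn, per row c
    dA = D * (dD - delta)                            # softmax backward
    dQ = dA @ K / np.sqrt(d_k)
    dK = dA.T @ Q / np.sqrt(d_k)
    return dQ, dK, dV

def numerical_grad(f, P, eps=1e-6):
    """Central finite differences of the scalar f() with respect to array P."""
    grad = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + eps
        plus = f()
        P[idx] = old - eps
        minus = f()
        P[idx] = old
        grad[idx] = (plus - minus) / (2 * eps)
    return grad

T, d_k, d_v = 4, 6, 3
rng = np.random.default_rng(0)
Q = rng.normal(size=(T, d_k))
K = rng.normal(size=(T, d_k))
V = rng.normal(size=(T, d_v))
dO = rng.normal(size=(T, d_v))                       # stand-in for dL/dO
row, col = np.indices((T, T))
M = np.where(col > row, -np.inf, 0.0)                # causal mask

O, D = attention_forward(Q, K, V, M)
dQ, dK, dV = attention_backward(Q, K, V, D, dO)

loss = lambda: float(np.sum(dO * attention_forward(Q, K, V, M)[0]))
dQ_num = numerical_grad(loss, Q)
dK_num = numerical_grad(loss, K)
dV_num = numerical_grad(loss, V)
```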

##### 6.0. Masking

Causal masking ensures that the $m$-th output row $O_m$ depends only on the rows of $K$ and $V$ at positions $n\le m$. That is, each query attends only to the past and current tokens. Therefore the gradients $\frac{\partial L}{\partial V_n}$ and $\frac{\partial L}{\partial K_n}$ only receive contributions from $O_m$ with $m\ge n$.

The gradient $\frac{\partial L}{\partial Q_m}$ depends only on $\frac{\partial L}{\partial O_m}$, since $O_m$ is the only output row that uses the query $Q_m$.
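These dependency statements can be illustrated by feeding an upstream gradient that is nonzero in a single output row and inspecting which rows of the gradients are touched. The sketch below uses NumPy with 0-based row indices; the sizes and the chosen row `m` are arbitrary assumptions.

```python
import numpy as np

def softmax_rows(A):
    A = A - A.max(axis=-1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

T, d_k, d_v = 5, 4, 3
rng = np.random.default_rng(0)
Q = rng.normal(size=(T, d_k))
K = rng.normal(size=(T, d_k))
V = rng.normal(size=(T, d_v))
row, col = np.indices((T, T))
M = np.where(col > row, -np.inf, 0.0)    # causal mask

D = softmax_rows(Q @ K.T / np.sqrt(d_k) + M)
O = D @ V

m = 2                                    # only output row m receives a gradient
dO = np.zeros((T, d_v))
dO[m] = rng.normal(size=d_v)

dV = D.T @ dO
dA = D * (dO @ V.T - np.sum(dO * O, axis=-1, keepdims=True))
dQ = dA @ K / np.sqrt(d_k)
dK = dA.T @ Q / np.sqrt(d_k)
```

Rows `n > m` of `dK` and `dV` stay zero, and `dQ` is nonzero only in row `m`.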


##### 6.1. The un-intended way

$$
\begin{align*}
\dfrac{\partial L}{\partial Q_{ij}}&=\sum_{abcd}\dfrac{\partial L}{\partial D_{ab}}\dfrac{\partial D_{ab}}{\partial A_{cd}}\dfrac{\partial A_{cd}}{\partial Q_{ij}}\\
&=\sum_{abcde}\dfrac{\partial L}{\partial O_{ae}}V_{be}\delta_{a,c}D_{ab}(\delta_{b,d}-D_{ad})\dfrac{1}{\sqrt{d_k}}\delta_{c,i}K_{dj}\\
&=\sum_{bde}\dfrac{\partial L}{\partial O_{ie}}V_{be}D_{ib}(\delta_{b,d}-D_{id})\dfrac{1}{\sqrt{d_k}}K_{dj}\\
&=\dfrac{1}{\sqrt{d_k}}\sum_{be}\dfrac{\partial L}{\partial O_{ie}}V_{be}D_{ib}\left(K_{bj}-\sum_dD_{id}K_{dj}\right)\\
&=\dfrac{1}{\sqrt{d_k}}\left(\sum_{be}\dfrac{\partial L}{\partial O_{ie}}V_{be}D_{ib}K_{bj}-\Delta_i\sum_d D_{id}K_{dj}\right)
\end{align*}
$$

$$
\begin{align*}
\dfrac{\partial L}{\partial K_{ij}} 
&= \sum_{abcd} \dfrac{\partial L}{\partial D_{ab}} \dfrac{\partial D_{ab}}{\partial A_{cd}} \dfrac{\partial A_{cd}}{\partial K_{ij}} \\
&= \sum_{abcde} \dfrac{\partial L}{\partial O_{ae}} V_{be} \delta_{a,c} D_{ab} (\delta_{b,d} - D_{ad}) \dfrac{1}{\sqrt{d_k}} \delta_{d,i} Q_{cj} \\
&= \sum_{abe} \dfrac{\partial L}{\partial O_{ae}} V_{be} D_{ab} (\delta_{b,i} - D_{ai}) \dfrac{1}{\sqrt{d_k}} Q_{aj} \\
&= \dfrac{1}{\sqrt{d_k}} \sum_{ae} \dfrac{\partial L}{\partial O_{ae}} D_{ai} Q_{aj} \left( V_{ie} - \sum_b D_{ab} V_{be} \right)\\
&= \dfrac{1}{\sqrt{d_k}} \sum_{a} D_{ai} Q_{aj}\sum_{e} \dfrac{\partial L}{\partial O_{ae}} \left( V_{ie} - O_{ae} \right)\\
&= \dfrac{1}{\sqrt{d_k}} \sum_{a} D_{ai} Q_{aj}\left(\left( \sum_{e} \dfrac{\partial L}{\partial O_{ae}}V_{ie}\right) - \Delta_a\right)
\end{align*}
$$

where

$$
\Delta_k=\sum_l\dfrac{\partial L}{\partial O_{kl}}O_{kl}
$$
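The index expressions can be transcribed almost literally with `np.einsum` and compared against the matrix form that propagates through $A$. In this sketch the $1/\sqrt{d_k}$ factor multiplies both terms of the query gradient; the variable names and toy shapes are assumptions, and no mask is applied.

```python
import numpy as np

def softmax_rows(A):
    A = A - A.max(axis=-1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=-1, keepdims=True)

T, d_k, d_v = 5, 4, 3
rng = np.random.default_rng(0)
Q = rng.normal(size=(T, d_k))
K = rng.normal(size=(T, d_k))
V = rng.normal(size=(T, d_v))
dO = rng.normal(size=(T, d_v))           # stand-in for dL/dO

D = softmax_rows(Q @ K.T / np.sqrt(d_k))
O = D @ V
delta = np.sum(dO * O, axis=1)           # Delta_k = sum_l dL/dO_kl * O_kl

# dL/dQ_ij = (sum_be dO_ie V_be D_ib K_bj - Delta_i sum_d D_id K_dj) / sqrt(d_k)
dQ = (np.einsum('ie,be,ib,bj->ij', dO, V, D, K)
      - delta[:, None] * (D @ K)) / np.sqrt(d_k)

# dL/dK_ij = sum_a D_ai Q_aj (sum_e dO_ae V_ie - Delta_a) / sqrt(d_k)
inner = dO @ V.T - delta[:, None]        # indexed [a, i]
dK = np.einsum('ai,aj,ai->ij', D, Q, inner) / np.sqrt(d_k)

# Reference: propagate through A as in the previous section.
dA = D * (dO @ V.T - delta[:, None])
dQ_ref = dA @ K / np.sqrt(d_k)
dK_ref = dA.T @ Q / np.sqrt(d_k)
```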



#### 7. GPU-Optimization: FlashAttention

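On GPU, computing $\sigma(QK^T/\sqrt{d_k})V$ naively is slow because the intermediate $T\times T$ score matrix is large (it does not fit in SRAM and has to round-trip through DRAM), and both the amount of computation and the memory are quadratic in the sequence length $T$.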

The computation we need to do is for a given query $Q[i]$, compute dot products with all keys $K[j]$, apply softmax, then take a weighted sum over $V[j]$.

Our approach is to:

(1) Increase memory throughput: let each block compute BLOCK_M rows of queries, and within each block, iterate (stream) over all of $K$ and the corresponding $V$ in row tiles of size BLOCK_N. Since every block loads the same $K$, $V$ tiles, they are likely to stay hot in the L2 cache. Also, with a causal mask, tiles that lie entirely in the masked region are skipped and never computed.

(2) Fuse Operations. Results are not written back to DRAM until the computation is complete.
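The streaming idea can be sketched on the CPU with NumPy: each query tile keeps a running row maximum, a running normalizer, and an unnormalized output accumulator, and rescales them whenever a new key tile raises the maximum. This is only an illustration of the tiling and online-softmax arithmetic under assumed names (`block_m`, `block_n` mirror BLOCK_M, BLOCK_N); it is not a GPU kernel and says nothing about actual memory traffic.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_m=2, block_n=2, causal=True):
    """Tiled attention forward pass that never materializes the full T x T matrix."""
    T, d_k = Q.shape
    d_v = V.shape[1]
    scale = 1.0 / np.sqrt(d_k)
    O = np.zeros((T, d_v))
    for m0 in range(0, T, block_m):               # one "block" per query tile
        m1 = min(m0 + block_m, T)
        q = Q[m0:m1]
        row_max = np.full(m1 - m0, -np.inf)       # running max of the scores
        row_sum = np.zeros(m1 - m0)               # running softmax normalizer
        acc = np.zeros((m1 - m0, d_v))            # unnormalized output
        for n0 in range(0, T, block_n):           # stream over K, V tiles
            if causal and n0 > m1 - 1:
                break                             # tile is fully masked: skip it
            n1 = min(n0 + block_n, T)
            s = q @ K[n0:n1].T * scale
            if causal:
                rows = np.arange(m0, m1)[:, None]
                cols = np.arange(n0, n1)[None, :]
                s = np.where(cols > rows, -np.inf, s)
            new_max = np.maximum(row_max, s.max(axis=1))
            p = np.exp(s - new_max[:, None])
            correction = np.exp(row_max - new_max)   # rescale earlier partial results
            row_sum = row_sum * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ V[n0:n1]
            row_max = new_max
        O[m0:m1] = acc / row_sum[:, None]         # normalize once, at the end
    return O

def reference_attention(Q, K, V):
    """Naive causal attention used only for comparison."""
    T = Q.shape[0]
    row, col = np.indices((T, T))
    A = np.where(col > row, -np.inf, Q @ K.T / np.sqrt(Q.shape[1]))
    P = np.exp(A - A.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

T, d_k, d_v = 7, 4, 3
rng = np.random.default_rng(0)
Q = rng.normal(size=(T, d_k))
K = rng.normal(size=(T, d_k))
V = rng.normal(size=(T, d_v))
O_tiled = flash_attention_forward(Q, K, V, block_m=3, block_n=2)
O_ref = reference_attention(Q, K, V)
```

Because the first key tile always contains column 0, every query row sees at least one unmasked score before any fully masked row of a later tile, so the running maximum is finite whenever it is used for rescaling.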