# Multi-Head Attention


## Definition: Attention

For a set of key-value pairs $\{(k_i,v_i)\}_{i=1}^N \subset \mathbb{R}^{d_k}\times\mathbb{R}^{d_v}$ and a set of queries $\{q_j\}_{j=1}^M \subset \mathbb{R}^{d_k}$, attention returns the "expected" value $o_j \in \mathbb{R}^{d_v}$ for each query $q_j$, $j=1,2,\dots,M$.

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times d_k}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times d_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times d_v}$

Output: $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M \times d_v}$
$$ \text{Attention}(Q,K,V) = O  = \alpha V, \quad \text{ where } \quad \alpha = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{M\times N}$$
$$ o_{i}  =  \sum_{j=1}^N\alpha_{i,j}v_j\, \quad \text{ where } \quad  \alpha_{i,j}=\frac{e^{\frac{q_i^{T}k_j}{\sqrt{d_k}}}}{\sum_{l=1}^Ne^{\frac{q_i^{T}k_l}{\sqrt{d_k}}}}$$


where $\text{softmax}$ is applied per row, so each row of $\alpha$ sums to one. 

If we denote by $v(q_i)$ the random variable "value of the query $q_i$", for $i=1,2,\dots,M$, then the induced probability distribution of $v(q_i)$ is
$$p(v(q_i) = v_j)=\alpha_{i,j}, \quad j=1,2,\dots,N.$$

Notice that $\sum_{j=1}^N\alpha_{i,j}=1, \ i =1,2,\dots,M$.

**Note:** The value $p(v(q_i) = v_j)=\alpha_{i,j}$ is usually interpreted as how much attention the query $q_i$ (through its output $o_i$) pays to the value $v_j$. In other words, the "attention" of $o_i$ is distributed over the values $v_j$.
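The definition can be transcribed almost directly. The following is a minimal pure-Python sketch, assuming matrices stored as lists of rows; the helpers `matmul`, `transpose`, `softmax_rows` and `attention` are our own illustrative names, not a library API. It returns $\alpha$ together with $O$, so we can check that every row of $\alpha$ sums to one and that $o_i$ is the expected value $\sum_j\alpha_{i,j}v_j$.

```python
import math

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; subtracting each row's maximum avoids overflow."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention(Q, K, V):
    """Return (O, alpha) with alpha = softmax(Q K^T / sqrt(d_k)) and O = alpha V."""
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))                                   # M x N
    alpha = softmax_rows([[s / math.sqrt(d_k) for s in r] for r in scores])
    return matmul(alpha, V), alpha                                     # M x d_v, M x N

# M = 2 queries, N = 3 key-value pairs, d_k = 2, d_v = 1
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[10.0], [20.0], [30.0]]
O, alpha = attention(Q, K, V)

# o_1 is the expected value of v(q_1) under the distribution alpha[0]
expected = sum(a * v[0] for a, v in zip(alpha[0], V))
print(abs(O[0][0] - expected) < 1e-12)  # True
```

Here $q_1$ scores the keys as $(1, 0, 1)$, so $\alpha_{1,1}=\alpha_{1,3}>\alpha_{1,2}$.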

### Property: Key-Value Permutation Invariance

If $\pi_r(M)$ denotes an arbitrary permutation over the rows of a matrix $M$, then
$$\text{Attention}(Q,\pi_r(K),\pi_r(V))=\text{Attention}(Q,K,V) $$

Proof:

Notice that $\pi_r(M)$ can be written as $\pi_r(M) = R_{\pi}M$ for some permutation matrix $R_{\pi}$. A permutation matrix is a square binary matrix that has exactly one entry equal to 1 in each row and each column, with all other entries equal to 0.

$$
\begin{align*}
\text{Attention}(Q,\pi_r(K),\pi_r(V)) & = \text{Attention}(Q,R_{\pi}K,R_{\pi}V),\\
& = \text{softmax}\left(\frac{Q(R_{\pi}K)^T}{\sqrt{d_k}}\right)R_{\pi}V,\\
& = \text{softmax}\left(\frac{QK^TR_{\pi}^{T}}{\sqrt{d_k}}\right)R_{\pi}V,\\
& = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}R_{\pi}^{T}\right)R_{\pi}V,\\
& = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)R_{\pi}^{T}R_{\pi}V,\\
& = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,\\
& = \text{Attention}(Q,K,V),
\end{align*}
$$
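where in the fifth equality we used that postmultiplying by a permutation matrix permutes the columns and that the row-wise softmax commutes with column permutations, and in the sixth equality we used that permutation matrices are orthogonal, $R_{\pi}^TR_{\pi}=I$. For the former, notice that for any permutation $\pi$ we have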
$$
\begin{align*}
\text{softmax}\left(\pi(x)\right)& = \text{softmax}\left(x_{\pi(1)},x_{\pi(2)},\dots,x_{\pi(n)}\right),\\
&=\frac{\left(e^{x_{\pi(1)}},e^{x_{\pi(2)}},\dots,e^{x_{\pi(n)}}\right)}{\sum_{i=1}^n e^{x_{\pi(i)}}},\\
&=\frac{\left(e^{x_{\pi(1)}},e^{x_{\pi(2)}},\dots,e^{x_{\pi(n)}}\right)}{\sum_{i=1}^n e^{x_{i}}},\\
&=\frac{\pi\left(e^{x_{1}},e^{x_{2}},\dots,e^{x_{n}}\right)}{\sum_{i=1}^n e^{x_{i}}},\\
&=\pi\left(\frac{\left(e^{x_{1}},e^{x_{2}},\dots,e^{x_{n}}\right)}{\sum_{i=1}^n e^{x_{i}}}\right),\\
&=\pi\left(\text{softmax}(x)\right),\\
\end{align*}
$$
Since $\text{softmax}$ is applied to each row, for $A = [a_1|a_2|\dots|a_n]^T = \frac{QK^T}{\sqrt{d_k}}$, $f=\text{softmax}$, and a column permutation $\pi_c(A) = [\pi(a_1)|\pi(a_2)|\dots|\pi(a_n)]^T = AC_{\pi}$, we have
$$
\begin{align*}
f(AC_{\pi}) & = f([\pi(a_1)|\pi(a_2)|\dots|\pi(a_n)]^T),\\
 & = ([f(\pi(a_1))|f(\pi(a_2))|\dots|f(\pi(a_n))]^T),\\
& = ([\pi(f(a_1))|\pi(f(a_2))|\dots|\pi(f(a_n))]^T),\\
& = \pi_c\left([f(a_1)|f(a_2)|\dots|f(a_n)]^T\right),\\
& = \pi_c\left(f(A)\right),\\
& =f(A)C_{\pi}.
\end{align*}
$$
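Both the softmax lemma and the invariance can be spot-checked numerically. The pure-Python sketch below (helper names are our own, chosen for illustration) draws random $Q$, $K$, $V$ with a fixed seed, permutes the rows of $K$ and $V$ with the same permutation and compares the outputs. Differences of order $10^{-16}$ can appear because floating-point sums are evaluated in a different order.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax(x):
    """Softmax of one vector; subtracting the maximum avoids overflow."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    z = sum(e)
    return [v / z for v in e]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, with the softmax applied row by row."""
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    return matmul([softmax([s / math.sqrt(d_k) for s in r]) for r in scores], V)

def permute(items, perm):
    """pi(items): position r of the result holds items[perm[r]] (a row, for a matrix)."""
    return [items[p] for p in perm]

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

rng = random.Random(0)
M, N, d_k, d_v = 3, 5, 4, 2
Q, K, V = rand_matrix(M, d_k, rng), rand_matrix(N, d_k, rng), rand_matrix(N, d_v, rng)
perm = list(range(N))
rng.shuffle(perm)

# Lemma: softmax(pi(x)) = pi(softmax(x)) for a single vector x
x = [rng.uniform(-3.0, 3.0) for _ in range(N)]
lemma_gap = max(abs(a - b) for a, b in zip(softmax(permute(x, perm)), permute(softmax(x), perm)))

# Property: Attention(Q, pi_r(K), pi_r(V)) = Attention(Q, K, V)
gap = max_abs_diff(attention(Q, permute(K, perm), permute(V, perm)), attention(Q, K, V))
print(lemma_gap < 1e-12, gap < 1e-12)  # True True
```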

### Property: Attention Permutation Equivariance

If $\pi_r$ and $\sigma_r$ denote arbitrary permutations over the rows of a matrix, then
$$\text{Attention}(\pi_r(Q),\sigma_r(K),\sigma_r(V))=\pi_r\left(\text{Attention}(Q,K,V)\right). $$

Proof:
Consider  $\pi_r(M) = R_{\pi}M$ for some permutation matrix $R_{\pi}$. From the previous result we have
$$
\begin{align*}
\text{Attention}(\pi_r(Q),\sigma_r(K),\sigma_r(V)) & = \text{Attention}(\pi_r(Q),K,V),\\
&= \text{Attention}(R_{\pi}Q,K,V),\\
& = \text{softmax}\left(\frac{(R_{\pi}Q)K^T}{\sqrt{d_k}}\right)V,\\
& = \text{softmax}\left(R_{\pi}\frac{QK^T}{\sqrt{d_k}}\right)V,\\
& = R_{\pi}\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,\\
& = \pi_r\left(\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\right),\\
& = \pi_r\left(\text{Attention}(Q,K,V) \right),
\end{align*}
$$
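where in the first equality we used the key-value permutation invariance, and in the fifth equality we used that the softmax function is applied to each row individually, so it commutes with row permutations.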
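A numerical spot check of this property, again as a pure-Python sketch with our own helper names: the queries are permuted by $\pi$ and the key-value pairs by an independent permutation $\sigma$, and only $\pi$ should show up in the output.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; subtracting each row's maximum avoids overflow."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    return matmul(softmax_rows([[s / math.sqrt(d_k) for s in r] for r in scores]), V)

def permute(rows, perm):
    """pi_r(A) = R_pi A: row r of the result is row perm[r] of A."""
    return [rows[p] for p in perm]

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

rng = random.Random(1)
M, N, d_k, d_v = 4, 6, 3, 2
Q, K, V = rand_matrix(M, d_k, rng), rand_matrix(N, d_k, rng), rand_matrix(N, d_v, rng)
pi = list(range(M))
rng.shuffle(pi)
sigma = list(range(N))
rng.shuffle(sigma)

lhs = attention(permute(Q, pi), permute(K, sigma), permute(V, sigma))
rhs = permute(attention(Q, K, V), pi)
print(max_abs_diff(lhs, rhs) < 1e-12)  # True
```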

## Masked Attention

Masked attention intentionally breaks permutation equivariance so that order matters. In many tasks, like language modeling or temporal prediction, we don’t want tokens to freely attend to all others. Masking restricts attention to valid positions (e.g., past tokens), enforcing causal or directional structure instead of treating the sequence as an unordered set.

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times d_k}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times d_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times d_v}$

Output: $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M \times d_v}$
$$ \text{MaskedAttention}_{B}(Q,K,V) = O  = \alpha V, \quad \text{ where } \quad \alpha = \text{softmax}\left(\frac{QK^T+B}{\sqrt{d_k}}\right) \in \mathbb{R}^{M\times N}$$
where $B\in\mathbb{R}^{M\times N}$ and $\text{softmax}$ is applied to each row, so each row of $\alpha$ sums to one. 

**Note:** In general, the mask breaks the permutation equivariance properties of attention. 
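For the causal case mentioned above (each position attends only to itself and the past, so $M=N$), a common choice is $b_{i,j}=0$ for $j\le i$ and $b_{i,j}=-\infty$ for $j>i$. The pure-Python sketch below builds that mask and shows that all attention weights above the diagonal vanish; `causal_mask` and `masked_attention` are our own illustrative names. Some implementations add the mask after the $1/\sqrt{d_k}$ scaling rather than before; for masks with entries in $\{0,-\infty\}$ the two conventions give identical results.

```python
import math

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; exp(-inf) = 0, so masked entries get zero weight."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def causal_mask(n):
    """b_ij = 0 if j <= i (present or past), -inf if j > i (future)."""
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]

def masked_attention(Q, K, V, B):
    """Return (O, alpha) with alpha = softmax((Q K^T + B) / sqrt(d_k)), row by row."""
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    logits = [[(s + b) / math.sqrt(d_k) for s, b in zip(r, br)] for r, br in zip(scores, B)]
    alpha = softmax_rows(logits)
    return matmul(alpha, V), alpha

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]  # N = 4 tokens, d_k = d_v = 2
O, alpha = masked_attention(X, X, X, causal_mask(4))
for row in alpha:
    print([round(a, 3) for a in row])
# The first row is [1.0, 0.0, 0.0, 0.0]: the first token can only attend to itself.
```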

### Property: The entries of $B$ cancel out the attention.

If $b_{i,s} \to -\infty$ for a fixed $s \neq i$, then the $i$-th output $o_i$ does not depend on $q_s$, $k_s$, or $v_s$. We have the formula:
$$ 
\begin{align*}
o_i =\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}} v_j.
\end{align*}
$$

**Proof:** If $b_{i,s}\to -\infty$ for some $(i,s)\in\{1,\dots,M\}\times\{1,\dots,N\}$, then
$$
\begin{align*}
\alpha_{i,j} &= \text{softmax}\left(\frac{QK^T+B}{\sqrt{d_k}}\right)_{i,j} =\frac{e^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}}}{\sum_{l=1}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}},\\
&=\begin{cases}
\frac{e^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}}}{\sum_{l=1,l\neq s}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}},&\quad j\neq s\\
0 ,&\quad j = s.
\end{cases}
\end{align*}
$$
Then we get
$$ 
\begin{align*}
o_i &= \sum_{j=1}^N\alpha_{i,j} v_j=\sum_{j=1,j\neq s}^N\alpha_{i,j} v_j,\\
&=\sum_{j=1,j\neq s}^N\frac{e^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}}}{\sum_{l=1,l\neq s}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}} v_j
\end{align*}
$$
This expression does not involve $k_s$ or $v_s$, and the only query it involves is $q_i$. Since $i\neq s$, we conclude that $o_i$ does not depend on $q_s$, $k_s$, or $v_s$.
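A quick numerical illustration, as a pure-Python sketch with our own helper names: with a single entry $b_{i,s}=-\infty$ (0-based `i = 0`, `s = 2` in the code), replacing $q_s$, $k_s$ and $v_s$ by arbitrary vectors leaves $o_i$ unchanged, while an unmasked output still changes.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; exp(-inf) = 0, so masked entries get zero weight."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def masked_attention(Q, K, V, B):
    """softmax((Q K^T + B) / sqrt(d_k)) V, softmax applied row by row."""
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    logits = [[(s + b) / math.sqrt(d_k) for s, b in zip(r, br)] for r, br in zip(scores, B)]
    return matmul(softmax_rows(logits), V)

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

rng = random.Random(2)
N, d = 3, 2  # M = N = 3, d_k = d_v = 2
Q, K, V = rand_matrix(N, d, rng), rand_matrix(N, d, rng), rand_matrix(N, d, rng)
i, s = 0, 2
B = [[0.0] * N for _ in range(N)]
B[i][s] = float("-inf")  # b_{i,s} = -inf: o_i must ignore position s

O = masked_attention(Q, K, V, B)

# Replace q_s, k_s and v_s by arbitrary vectors and recompute
Q2, K2, V2 = [r[:] for r in Q], [r[:] for r in K], [r[:] for r in V]
Q2[s], K2[s], V2[s] = [5.0, -5.0], [0.0, 0.0], [100.0, -100.0]
O2 = masked_attention(Q2, K2, V2, B)

print(O2[i] == O[i])  # True: o_i does not depend on q_s, k_s, v_s
print(O2[1] != O[1])  # True: row 1 is not masked, so o_1 still sees position s
```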

## Definition: Attention (with weights)
Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}, \quad M\in\mathbb{N}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}, \quad N\in\mathbb{N}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}, \quad N\in\mathbb{N}$.

Weights
* a set of query weights $W_q \in \mathbb{R}^{c_q\times d_k}$ and bias $b_q\in \mathbb{R}^{d_k}$
* a set of key weights $W_k \in \mathbb{R}^{c_k \times d_k}$ and bias $b_k\in \mathbb{R}^{d_k}$
* a set of value weights $W_v \in \mathbb{R}^{c_v \times d_v}$ and bias $b_v\in \mathbb{R}^{d_v}$

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_v}$
\begin{align*}
O =\text{Attention}_{\mathcal{W}}(Q,K,V)=\text{Attention}(QW_q+B_q,KW_k+B_k,VW_v+B_v)
\end{align*}
where
\begin{align*}
B_q &= [b_q|b_q|\dots|b_q]^T \in \mathbb{R}^{M \times d_k},\\ 
B_k &= [b_k|b_k|\dots|b_k]^T \in \mathbb{R}^{N \times d_k},\\
B_v &= [b_v|b_v|\dots|b_v]^T \in \mathbb{R}^{N \times d_v},\\
\mathcal{W} &= \{W_q,b_q,W_k,b_k,W_v,b_v\}.
\end{align*}

**Note:** The same $\operatorname{Attention}$ function can be applied to inputs of different sequence lengths. The model parameters are not tied to specific positions. In this sense, $\operatorname{Attention}$ is **position-agnostic**.
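A minimal pure-Python sketch of $\text{Attention}_{\mathcal{W}}$ (all names are our own, for illustration). The `affine` helper adds the same bias row to every row of $XW$, which is exactly the matrix $B_q$, $B_k$ or $B_v$ above. Because the parameters act only on the feature dimension, one parameter set handles key-value sequences of any length $N$.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; subtracting each row's maximum avoids overflow."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    return matmul(softmax_rows([[s / math.sqrt(d_k) for s in r] for r in scores]), V)

def affine(X, W, b):
    """X W + 1 b^T: the same bias row b is added to every row of X W."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(X, W)]

def attention_w(Q, K, V, params):
    """Attention(Q W_q + B_q, K W_k + B_k, V W_v + B_v)."""
    Wq, bq, Wk, bk, Wv, bv = params
    return attention(affine(Q, Wq, bq), affine(K, Wk, bk), affine(V, Wv, bv))

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

rng = random.Random(3)
c_q, c_k, c_v, d_k, d_v = 5, 3, 4, 2, 6
params = (rand_matrix(c_q, d_k, rng), rand_matrix(1, d_k, rng)[0],  # W_q, b_q
          rand_matrix(c_k, d_k, rng), rand_matrix(1, d_k, rng)[0],  # W_k, b_k
          rand_matrix(c_v, d_v, rng), rand_matrix(1, d_v, rng)[0])  # W_v, b_v
Q = rand_matrix(2, c_q, rng)  # M = 2 queries
shapes = []
for N in (3, 7):  # the same parameters handle key-value sequences of any length N
    O = attention_w(Q, rand_matrix(N, c_k, rng), rand_matrix(N, c_v, rng), params)
    shapes.append((len(O), len(O[0])))
print(shapes)  # [(2, 6), (2, 6)]: always M x d_v
```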

### Property: Attention (with weights) Permutation Equivariance

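Attention with weights is permutation equivariant. We state the property for linear projections (zero biases), which is the case proved below; it also holds for affine projections, because each bias matrix $B_q$, $B_k$, $B_v$ repeats the same row, so $RB_q=B_q$, $RB_k=B_k$ and $RB_v=B_v$ for every permutation matrix $R$. If $\pi_r$ and $\sigma_r$ denote arbitrary permutations over the rows of a matrix, then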
$$\text{Attention}_{\mathcal{W}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))=\pi_r\left(\text{Attention}_{\mathcal{W}}(Q,K,V)\right) $$
where 
$$
\begin{align*}
\mathcal{W} &= \{W_q,b_q,W_k,b_k,W_v,b_v\},\\
b_q &= 0_{\mathbb{R}^{d_k}},\\
b_k &= 0_{\mathbb{R}^{d_k}},\\
b_v &= 0_{\mathbb{R}^{d_v}},\\
\end{align*}
$$

Proof: Consider $\pi_r(M) = R_{\pi}M$ and $\sigma_r(M) = R_{\sigma}M$ for some permutation matrices $R_{\pi}$ and $R_{\sigma}$, respectively. From the previous result we have
$$
\begin{align*}
\text{Attention}_{\mathcal{W}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))
&=\text{Attention}(\pi_r(Q)W_q,\sigma_r(K)W_k,\sigma_r(V)W_v),\\
&=\text{Attention}((R_{\pi}Q)W_q,(R_{\sigma}K)W_k,(R_{\sigma}V)W_v),\\
&=\text{Attention}(R_{\pi}(QW_q),R_{\sigma}(KW_k),R_{\sigma}(VW_v)),\\
&=R_{\pi}\text{Attention}(QW_q,KW_k,VW_v),\\
&=R_{\pi}\text{Attention}_{\mathcal{W}}(Q,K,V),\\
&=\pi_r\left(\text{Attention}_{\mathcal{W}}(Q,K,V)\right).
\end{align*}
$$
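The proof can be spot-checked numerically with a pure-Python sketch (helper names are our own). We deliberately keep nonzero biases: the check still passes, because every row of $B_q$, $B_k$ and $B_v$ is the same vector, so permuting rows leaves these matrices unchanged.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; subtracting each row's maximum avoids overflow."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention(Q, K, V):
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    return matmul(softmax_rows([[s / math.sqrt(d_k) for s in r] for r in scores]), V)

def affine(X, W, b):
    """X W + 1 b^T: the same bias row b is added to every row of X W."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(X, W)]

def attention_w(Q, K, V, params):
    Wq, bq, Wk, bk, Wv, bv = params
    return attention(affine(Q, Wq, bq), affine(K, Wk, bk), affine(V, Wv, bv))

def permute(rows, perm):
    """pi_r(A) = R_pi A: row r of the result is row perm[r] of A."""
    return [rows[p] for p in perm]

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

rng = random.Random(4)
M, N, c_q, c_k, c_v, d_k, d_v = 3, 5, 4, 3, 2, 2, 3
params = (rand_matrix(c_q, d_k, rng), rand_matrix(1, d_k, rng)[0],  # W_q, b_q
          rand_matrix(c_k, d_k, rng), rand_matrix(1, d_k, rng)[0],  # W_k, b_k
          rand_matrix(c_v, d_v, rng), rand_matrix(1, d_v, rng)[0])  # W_v, b_v (nonzero)
Q, K, V = rand_matrix(M, c_q, rng), rand_matrix(N, c_k, rng), rand_matrix(N, c_v, rng)
pi = list(range(M))
rng.shuffle(pi)
sigma = list(range(N))
rng.shuffle(sigma)

lhs = attention_w(permute(Q, pi), permute(K, sigma), permute(V, sigma), params)
rhs = permute(attention_w(Q, K, V, params), pi)
print(max_abs_diff(lhs, rhs) < 1e-12)  # True, even with nonzero biases
```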

## Definition: Masked Attention with weights
Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$.

Weights
* a set of query weights $W_q \in \mathbb{R}^{c_q\times d_k}$ and bias $b_q\in \mathbb{R}^{d_k}$
* a set of key weights $W_k \in \mathbb{R}^{c_k \times d_k}$ and bias $b_k\in \mathbb{R}^{d_k}$
* a set of value weights $W_v \in \mathbb{R}^{c_v \times d_v}$ and bias $b_v\in \mathbb{R}^{d_v}$

* a mask matrix $B\in \mathbb{R}^{M\times N}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_v}$
\begin{align*}
O =\text{MaskedAttention}_{\mathcal{W}}(Q,K,V)=\text{MaskedAttention}_B(QW_q+B_q,KW_k+B_k,VW_v+B_v)
\end{align*}
where
\begin{align*}
B_q &= [b_q|b_q|\dots|b_q]^T \in \mathbb{R}^{M \times d_k},\\ 
B_k &= [b_k|b_k|\dots|b_k]^T \in \mathbb{R}^{N \times d_k},\\
B_v &= [b_v|b_v|\dots|b_v]^T \in \mathbb{R}^{N \times d_v},\\
\mathcal{W} &= \{W_q,b_q,W_k,b_k,W_v,b_v,B\}.
\end{align*}

**Note:** The mask breaks the permutation-equivariance property of attention. In general, masked attention is not even position-agnostic, since the mask $B\in\mathbb{R}^{M\times N}$ depends on the input sequence lengths $M$ and $N$. This limitation can be mitigated by imposing a structure on $B$ that is defined for every length (for example the causal mask, $b_{i,j}=-\infty$ for $j>i$ and $b_{i,j}=0$ otherwise), but the masked version remains inherently order-dependent.

### Property: The entries of $B$ cancel out the attention in the Attention (with weights).

If $b_{i,s} \to -\infty$ for a fixed $s \neq i$, then the $i$-th output $o_i$ does not depend on $q_s$, $k_s$, or $v_s$. We have
\begin{align*}
o_i =\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_l+b_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_j+b_{i,j}}{\sqrt{d_k}}} \hat{v}_j.
\end{align*}
where

$$ 
\begin{align*}
\hat{q}_i&=W_q^Tq_i+b_q,\\
\hat{k}_i&=W_k^Tk_i+b_k,\\
\hat{v}_i&=W_v^Tv_i+b_v,\\
\end{align*}
$$

**Proof:** If $b_{i,s}\to -\infty$ for some $(i,s)\in\{1,\dots,M\}\times\{1,\dots,N\}$, we can apply the corresponding result for masked attention (without weights) to obtain
$$ 
\begin{align*}
o_i =\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_l+b_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_j+b_{i,j}}{\sqrt{d_k}}} \hat{v}_j.
\end{align*}
$$
where the vectors $\hat{q}_i$, $\hat{k}_i$ and $\hat{v}_i$ are the $i$-th rows of $QW_q+B_q$, $KW_k+B_k$ and $VW_v+B_v$, respectively. So we have

$$ 
\begin{align*}
\hat{q}^T_i&=q_i^TW_q+b_q^T,\\
\hat{k}^T_i&=k_i^TW_k+b_k^T,\\
\hat{v}^T_i&=v_i^TW_v+b_v^T,\\
\end{align*}
$$

Transposing the previous equations gives the result. Since $i\neq s$, we conclude that $o_i$ does not depend on $q_s$, $k_s$, or $v_s$.

## Definition: Self-Attention
Inputs
* $X \in \mathbb{R}^{N \times c}, \quad N\in \mathbb{N}$ 

Weights
* a set of query weights $W_q \in \mathbb{R}^{c\times d_k}$ and bias $b_q\in \mathbb{R}^{d_k}$
* a set of key weights $W_k \in \mathbb{R}^{c \times d_k}$ and bias $b_k\in \mathbb{R}^{d_k}$
* a set of value weights $W_v \in \mathbb{R}^{c \times d_v}$ and bias $b_v\in \mathbb{R}^{d_v}$

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_v}$
\begin{align*}
O =\text{SelfAttention}_{\mathcal{W}}(X)=\text{Attention}_{\mathcal{W}}(X,X,X),
\end{align*}
where
\begin{align*}
\mathcal{W} &= \{W_q,b_q,W_k,b_k,W_v,b_v\}.
\end{align*}

### Property: Self-Attention Permutation Equivariance

Self-Attention is permutation equivariant. We state the property for linear projections (zero biases); it also holds with biases, because each bias matrix repeats the same row and is therefore unchanged by row permutations. If $\pi_r$ denotes an arbitrary permutation over the rows of a matrix, then
$$\text{SelfAttention}_{\mathcal{W}}(\pi_r(X))=\pi_r\left(\text{SelfAttention}_{\mathcal{W}}(X)\right) $$
where 
$$
\begin{align*}
\mathcal{W} &= \{W_q,b_q,W_k,b_k,W_v,b_v\},\\
b_q &= 0_{\mathbb{R}^{d_k}},\\
b_k &= 0_{\mathbb{R}^{d_k}},\\
b_v &= 0_{\mathbb{R}^{d_v}},\\
\end{align*}
$$

Proof: From the previous result we have
$$
\begin{align*}
\text{SelfAttention}_{\mathcal{W}}(\pi_r(X))
&=\text{Attention}_{\mathcal{W}}(\pi_r(X),\pi_r(X),\pi_r(X)),\\
&=\pi_r\left(\text{Attention}_{\mathcal{W}}(X,X,X)\right),\\
&=\pi_r\left(\text{SelfAttention}_{\mathcal{W}}(X)\right).
\end{align*}
$$
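In code, self-attention is just weighted attention with the same matrix passed three times. The pure-Python sketch below (names are our own) checks $\text{SelfAttention}_{\mathcal{W}}(\pi_r(X))=\pi_r(\text{SelfAttention}_{\mathcal{W}}(X))$ for a random permutation, using zero biases as in the statement.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; subtracting each row's maximum avoids overflow."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention(Q, K, V):
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    return matmul(softmax_rows([[s / math.sqrt(d_k) for s in r] for r in scores]), V)

def affine(X, W, b):
    """X W + 1 b^T: the same bias row b is added to every row of X W."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(X, W)]

def self_attention(X, params):
    """SelfAttention_W(X) = Attention_W(X, X, X)."""
    Wq, bq, Wk, bk, Wv, bv = params
    return attention(affine(X, Wq, bq), affine(X, Wk, bk), affine(X, Wv, bv))

def permute(rows, perm):
    """pi_r(A) = R_pi A: row r of the result is row perm[r] of A."""
    return [rows[p] for p in perm]

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

rng = random.Random(5)
N, c, d_k, d_v = 5, 3, 2, 4
params = (rand_matrix(c, d_k, rng), [0.0] * d_k,  # zero biases, as in the statement
          rand_matrix(c, d_k, rng), [0.0] * d_k,
          rand_matrix(c, d_v, rng), [0.0] * d_v)
X = rand_matrix(N, c, rng)
pi = list(range(N))
rng.shuffle(pi)
gap = max_abs_diff(self_attention(permute(X, pi), params), permute(self_attention(X, params), pi))
print(gap < 1e-12)  # True
```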

## Definition: Masked Self-Attention
Inputs
* $X = [x_1|x_2|\dots|x_N]^T \in \mathbb{R}^{N \times c}$ 

Weights
* a set of query weights $W_q \in \mathbb{R}^{c\times d_k}$ and bias $b_q\in \mathbb{R}^{d_k}$
* a set of key weights $W_k \in \mathbb{R}^{c \times d_k}$ and bias $b_k\in \mathbb{R}^{d_k}$
* a set of value weights $W_v \in \mathbb{R}^{c \times d_v}$ and bias $b_v\in \mathbb{R}^{d_v}$
* a mask matrix $B\in \mathbb{R}^{N\times N}.$ 

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_v}$
\begin{align*}
O =\text{MaskedSelfAttention}_{\mathcal{W}}(X)=\text{MaskedAttention}_{\mathcal{W}}(X,X,X),
\end{align*}
where
\begin{align*}
\mathcal{W} &= \{W_q,b_q,W_k,b_k,W_v,b_v,B\}.
\end{align*}

**Note:** In general, the mask breaks the permutation equivariance of Self-Attention, and Masked Self-Attention is no longer position-agnostic, since the mask $B\in\mathbb{R}^{N\times N}$ depends on the input sequence length $N$.

### Property: The entries of $B$ cancel out the attention in the Self-Attention.

If $b_{i,s} \to -\infty$ for a fixed $s \neq i$, then the $i$-th output $o_i$ does not depend on $x_s$. We have
\begin{align*}
o_i =\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_l+b_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_j+b_{i,j}}{\sqrt{d_k}}} \hat{v}_j.
\end{align*}
where

$$ 
\begin{align*}
\hat{q}_i&=W_q^Tx_i+b_q,\\
\hat{k}_i&=W_k^Tx_i+b_k,\\
\hat{v}_i&=W_v^Tx_i+b_v,\\
\end{align*}
$$

**Proof:** It follows directly from the same result for Masked Attention with weights, with $Q=K=V=X$.
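With the causal mask ($b_{i,s}=-\infty$ whenever $s>i$), this property says that $o_i$ depends only on $x_1,\dots,x_i$. A pure-Python sketch (names such as `causal_mask` and `masked_self_attention` are our own) that changes only the last token and checks that all earlier outputs stay exactly the same:

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; exp(-inf) = 0, so masked entries get zero weight."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def affine(X, W, b):
    """X W + 1 b^T: the same bias row b is added to every row of X W."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(X, W)]

def causal_mask(n):
    """b_ij = 0 if j <= i and -inf if j > i: position i only sees positions up to i."""
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]

def masked_self_attention(X, params, B):
    """MaskedAttention_B(X W_q + B_q, X W_k + B_k, X W_v + B_v)."""
    Wq, bq, Wk, bk, Wv, bv = params
    q, k, v = affine(X, Wq, bq), affine(X, Wk, bk), affine(X, Wv, bv)
    d_k = len(q[0])
    scores = matmul(q, transpose(k))
    logits = [[(s + b) / math.sqrt(d_k) for s, b in zip(r, br)] for r, br in zip(scores, B)]
    return matmul(softmax_rows(logits), v)

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

rng = random.Random(6)
N, c, d_k, d_v = 5, 3, 2, 2
params = (rand_matrix(c, d_k, rng), rand_matrix(1, d_k, rng)[0],
          rand_matrix(c, d_k, rng), rand_matrix(1, d_k, rng)[0],
          rand_matrix(c, d_v, rng), rand_matrix(1, d_v, rng)[0])
X = rand_matrix(N, c, rng)
B = causal_mask(N)
O = masked_self_attention(X, params, B)

X2 = [r[:] for r in X]
X2[-1] = [10.0, -10.0, 10.0]  # change only the last token x_N
O2 = masked_self_attention(X2, params, B)
print([O2[i] == O[i] for i in range(N)])  # earlier rows unchanged; only the last row can change
```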

## Definition: Multi-Head Attention

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}, \quad M\in\mathbb{N}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}, \quad N\in\mathbb{N}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}, \quad N\in\mathbb{N}$

Weights
* $W_{q,i} \in \mathbb{R}^{c_q \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_o}$
\begin{align*}
O  &= \text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\text{Attention}_{\mathcal{W}_{1}}(Q,K,V)|
\text{Attention}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{Attention}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o + B_o
\end{align*}
where


\begin{align*}
B_o  &= [b_o|b_o|\dots|b_o]^T \in \mathbb{R}^{M \times d_o},\\ 
\mathcal{W}_0 &= \{W_{o},b_{o}\},\\
\mathcal{W}_i &= \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\}, \quad i = 1,2,\dots,h.
\end{align*}

Usually we have:
* $d_h:=d_k=d_v$ (head dimension)
* $d := c_q = c_k = c_v = d_h\cdot h = d_o$ (model dimension)

**Note:** In this case, internally the per-head weights are stacked column-wise:
* $W_q=[W_{q,1}|W_{q,2}|\dots|W_{q,h}]\in\mathbb{R}^{d\times d}$
* $W_k=[W_{k,1}|W_{k,2}|\dots|W_{k,h}]\in\mathbb{R}^{d\times d}$
* $W_v=[W_{v,1}|W_{v,2}|\dots|W_{v,h}]\in\mathbb{R}^{d\times d}$

**Note:** The same $\operatorname{MultiHeadAttention}$ function can be applied to inputs of different sequence lengths. The model parameters are not tied to specific positions. In this sense, $\operatorname{MultiHeadAttention}$ is **position-agnostic**.
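The usual case can be sketched in pure Python as follows (all names are our own, not a library API). Following the stacked-weight layout, it projects once with $W_q$, $W_k$, $W_v$ and gives head $i$ the columns $(i-1)d_h+1,\dots,i\,d_h$ of each projection; the final check confirms that this matches projecting with the head's own $W_{q,i}$, $b_{q,i}$, and so on. Each head scales its scores by $1/\sqrt{d_h}$, since $d_k=d_h$.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; subtracting each row's maximum avoids overflow."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V; inside a head, d_k is the head dimension d_h."""
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    return matmul(softmax_rows([[s / math.sqrt(d_k) for s in r] for r in scores]), V)

def affine(X, W, b):
    """X W + 1 b^T: the same bias row b is added to every row of X W."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(X, W)]

def columns(A, start, stop):
    """Columns start, ..., stop - 1 of A (0-based)."""
    return [row[start:stop] for row in A]

def concat_cols(blocks):
    """[A_1 | A_2 | ... | A_h] for matrices with the same number of rows."""
    return [sum(rows, []) for rows in zip(*blocks)]

def multi_head_attention(Q, K, V, params, h):
    """params = (W_q, b_q, W_k, b_k, W_v, b_v, W_o, b_o), per-head weights stacked column-wise."""
    Wq, bq, Wk, bk, Wv, bv, Wo, bo = params
    d = len(Wq[0])
    assert d % h == 0, "h must divide the model dimension d"
    d_h = d // h
    q, k, v = affine(Q, Wq, bq), affine(K, Wk, bk), affine(V, Wv, bv)
    heads = [attention(columns(q, i * d_h, (i + 1) * d_h),
                       columns(k, i * d_h, (i + 1) * d_h),
                       columns(v, i * d_h, (i + 1) * d_h)) for i in range(h)]
    return affine(concat_cols(heads), Wo, bo)

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

rng = random.Random(7)
d, h, M, N = 4, 2, 3, 5
params = []
for _ in range(4):  # (W_q, b_q), (W_k, b_k), (W_v, b_v), (W_o, b_o): d x d matrices, length-d biases
    params += [rand_matrix(d, d, rng), rand_matrix(1, d, rng)[0]]
Q, K, V = rand_matrix(M, d, rng), rand_matrix(N, d, rng), rand_matrix(N, d, rng)
O = multi_head_attention(Q, K, V, params, h)
print(len(O), len(O[0]))  # 3 4, i.e. M x d

# Head 1 from its own slices: W_{q,1} = first d_h = 2 columns of W_q, b_{q,1} = b_q[:2], etc.
Wq, bq, Wk, bk, Wv, bv = params[:6]
head1 = attention(affine(Q, columns(Wq, 0, 2), bq[:2]),
                  affine(K, columns(Wk, 0, 2), bk[:2]),
                  affine(V, columns(Wv, 0, 2), bv[:2]))
stacked_head1 = attention(columns(affine(Q, Wq, bq), 0, 2),
                          columns(affine(K, Wk, bk), 0, 2),
                          columns(affine(V, Wv, bv), 0, 2))
print(head1 == stacked_head1)  # True: stacking then slicing equals per-head projection
```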

### Property: Multi-Head Attention Permutation Equivariance

Multi-Head Attention is permutation equivariant. We state and prove the property for linear projections (all biases equal to zero); it also holds with biases, because every bias matrix repeats the same row and is therefore unchanged by row permutations. If $\pi_r$ and $\sigma_r$ denote arbitrary permutations over the rows of a matrix, then
$$\text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(Q),\sigma_r(K),\sigma_r(V))=\pi_r\left(\text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V)\right) $$

Proof:  Consider  $\pi_r(M) = R_{\pi}M$ and $\sigma_r(M) = R_{\sigma}M$ for permutation matrices $R_{\pi}$ and $R_{\sigma}$ respectively. From the Attention equivariance we have
$$
\begin{align*}
&\text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(Q),\sigma_r(K),\sigma_r(V))\\
&=\begin{pmatrix}
\text{Attention}_{\mathcal{W}_{1}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))|
\text{Attention}_{\mathcal{W}_{2}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))|
\dots|
\text{Attention}_{\mathcal{W}_{h}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))
\end{pmatrix}W_o,\\
&=\begin{pmatrix}
\pi_r\left(\text{Attention}_{\mathcal{W}_{1}}(Q,K,V)\right)|
\pi_r\left(\text{Attention}_{\mathcal{W}_{2}}(Q,K,V)\right)|
\dots|
\pi_r\left(\text{Attention}_{\mathcal{W}_{h}}(Q,K,V)\right)
\end{pmatrix}W_o,\\
&=\begin{pmatrix}
R_{\pi}\text{Attention}_{\mathcal{W}_{1}}(Q,K,V)|
R_{\pi}\text{Attention}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
R_{\pi}\text{Attention}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o,\\
&=R_{\pi}\begin{pmatrix}
\text{Attention}_{\mathcal{W}_{1}}(Q,K,V)|
\text{Attention}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{Attention}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o,\\
&=R_{\pi}\text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\pi_r\left(\text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V)\right).
\end{align*}
$$
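As with single-head attention, the statement can be spot-checked with a pure-Python sketch (names are our own): the queries are permuted by $\pi$ and the key-value pairs by $\sigma$, with all biases set to zero as in the statement.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; subtracting each row's maximum avoids overflow."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention(Q, K, V):
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    return matmul(softmax_rows([[s / math.sqrt(d_k) for s in r] for r in scores]), V)

def affine(X, W, b):
    """X W + 1 b^T: the same bias row b is added to every row of X W."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(X, W)]

def concat_cols(blocks):
    """[A_1 | A_2 | ... | A_h] for matrices with the same number of rows."""
    return [sum(rows, []) for rows in zip(*blocks)]

def multi_head_attention(Q, K, V, params, h):
    """Per-head weights stacked column-wise; head i uses columns [i*d_h, (i+1)*d_h)."""
    Wq, bq, Wk, bk, Wv, bv, Wo, bo = params
    d_h = len(Wq[0]) // h
    q, k, v = affine(Q, Wq, bq), affine(K, Wk, bk), affine(V, Wv, bv)
    cols = [slice(i * d_h, (i + 1) * d_h) for i in range(h)]
    heads = [attention([r[c] for r in q], [r[c] for r in k], [r[c] for r in v]) for c in cols]
    return affine(concat_cols(heads), Wo, bo)

def permute(rows, perm):
    """pi_r(A) = R_pi A: row r of the result is row perm[r] of A."""
    return [rows[p] for p in perm]

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

rng = random.Random(8)
d, h, M, N = 6, 3, 4, 5
params = []
for _ in range(4):  # W_q, W_k, W_v, W_o with zero biases, as in the statement
    params += [rand_matrix(d, d, rng), [0.0] * d]
Q, K, V = rand_matrix(M, d, rng), rand_matrix(N, d, rng), rand_matrix(N, d, rng)
pi = list(range(M))
rng.shuffle(pi)
sigma = list(range(N))
rng.shuffle(sigma)

lhs = multi_head_attention(permute(Q, pi), permute(K, sigma), permute(V, sigma), params, h)
rhs = permute(multi_head_attention(Q, K, V, params, h), pi)
print(max_abs_diff(lhs, rhs) < 1e-12)  # True
```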

## Definition: Multi-Head Attention (torch implementation)

The PyTorch implementation (`torch.nn.MultiheadAttention`) corresponds to the particular (usual) case
 * $d_h:=d_k=d_v$ (head dimension)
 * $d := c_q = d_h\cdot h = d_o$ (model dimension)  

Hence $d_h = d/h$, so $h$ must divide $d$. Then we have the simplified definition

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times d}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$

Weights
* $W_{q,i} \in \mathbb{R}^{d \times d_h}$ and $b_{q,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_h}$ and $b_{k,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_h}$ and $b_{v,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d \times d }$ and $b_{o}\in \mathbb{R}^{d}$.


Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d}$
\begin{align*}
O  &= \text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\text{Attention}_{\mathcal{W}_1}(Q,K,V)|
\text{Attention}_{\mathcal{W}_2}(Q,K,V)|
\dots|
\text{Attention}_{\mathcal{W}_h}(Q,K,V)
\end{pmatrix}W_o + B_o
\end{align*}
where:
* $W_q=[W_{q,1}|W_{q,2}|\dots|W_{q,h}]\in\mathbb{R}^{d\times d}$
* $W_k=[W_{k,1}|W_{k,2}|\dots|W_{k,h}]\in\mathbb{R}^{c_k\times d}$
* $W_v=[W_{v,1}|W_{v,2}|\dots|W_{v,h}]\in\mathbb{R}^{c_v\times d}$
* $b_q=[b^T_{q,1}|b^T_{q,2}|\dots|b^T_{q,h}]^T\in\mathbb{R}^{d}$
* $b_k=[b^T_{k,1}|b^T_{k,2}|\dots|b^T_{k,h}]^T\in\mathbb{R}^{d}$
* $b_v=[b^T_{v,1}|b^T_{v,2}|\dots|b^T_{v,h}]^T\in\mathbb{R}^{d}$
* $B_o= [b_o|b_o|\dots|b_o]^T \in \mathbb{R}^{M \times d}$
* $\mathcal{W}_0 = \{W_{o},b_{o}\}$
* $\mathcal{W}_i = \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\}, \quad i = 1,2,\dots,h$.

## Definition: Masked Multi-Head Attention

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$

Weights
* $W_{q,i} \in \mathbb{R}^{c_q \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $B_{i} \in \mathbb{R}^{M \times N}, \quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_o}$
\begin{align*}
O  &= \text{MaskedMultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\text{MaskAttn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{MaskAttn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{MaskAttn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o + B_o
\end{align*}
where $\operatorname{MaskAttn}=\operatorname{MaskedAttention}$ and


\begin{align*}
B_o  &= [b_o|b_o|\dots|b_o]^T \in \mathbb{R}^{M \times d_o},\\ 
\mathcal{W}_0 &= \{W_{o},b_{o}\},\\
\mathcal{W}_i &= \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},B_i\}, \quad i = 1,2,\dots,h.
\end{align*}

### Property: If the same entry of all masks is $-\infty$, the corresponding attention in Masked Multi-Head Attention cancels out.

Denote the $(i,s)$ entry of the $\kappa$-th mask matrix $B_{\kappa}$ by $b^{(\kappa)}_{i,s}$, where $(i,s)\in\{1,\dots,M\}\times\{1,\dots,N\}$. If $b^{(\kappa)}_{i,s} \to -\infty$ for a fixed $s \neq i$ and for all heads $\kappa=1,2,\dots,h$, then the $i$-th output $o_i$ of the masked Multi-Head Attention does not depend on $q_s$, $k_s$, or $v_s$. In particular, we have:
$$
\begin{align*}
o_i &=W_o^T\begin{pmatrix}
o_{i}^{(1)}\\
o_{i}^{(2)}\\
\vdots\\
o_{i}^{(h)}
\end{pmatrix}
+ b_o,\\
o_i^{(\kappa)} &=\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_l^{(\kappa)}+b^{(\kappa)}_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_j^{(\kappa)}+b^{(\kappa)}_{i,j}}{\sqrt{d_k}}} \hat{v}_j^{(\kappa)}, \quad \kappa = 1,2,\dots,h,
\end{align*}
$$
where

$$ 
\begin{align*}
\hat{q}^{(\kappa)}_l&=W_{q,\kappa}^Tq_l+b_{q,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, M,\\
\hat{k}^{(\kappa)}_l&=W_{k,\kappa}^Tk_l+b_{k,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N,\\
\hat{v}^{(\kappa)}_l&=W_{v,\kappa}^Tv_l+b_{v,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N.
\end{align*}
$$

**Proof:** It follows directly from the same result for Masked Attention with weights, applied to each head.
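The requirement that *every* head masks the same entry matters. The pure-Python sketch below (names are our own; $h=2$ heads with per-head masks) perturbs $k_s$ and $v_s$: when both heads set $b^{(\kappa)}_{i,s}=-\infty$, $o_i$ is unchanged; when only the first head does, the second head still sees position $s$, so for generic random weights $o_i$ changes.

```python
import math
import random

def matmul(A, B):
    """Product of an (m x n) and an (n x p) matrix stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(S):
    """Row-wise softmax; exp(-inf) = 0, so masked entries get zero weight."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def affine(X, W, b):
    """X W + 1 b^T: the same bias row b is added to every row of X W."""
    return [[v + bj for v, bj in zip(row, b)] for row in matmul(X, W)]

def concat_cols(blocks):
    """[A_1 | A_2 | ... | A_h] for matrices with the same number of rows."""
    return [sum(rows, []) for rows in zip(*blocks)]

def masked_multi_head_attention(Q, K, V, params, masks):
    """One mask B_kappa per head; per-head weights stacked column-wise."""
    Wq, bq, Wk, bk, Wv, bv, Wo, bo = params
    d_h = len(Wq[0]) // len(masks)
    q, k, v = affine(Q, Wq, bq), affine(K, Wk, bk), affine(V, Wv, bv)
    heads = []
    for kappa, B in enumerate(masks):
        cols = slice(kappa * d_h, (kappa + 1) * d_h)
        qk = matmul([r[cols] for r in q], transpose([r[cols] for r in k]))
        logits = [[(s + b) / math.sqrt(d_h) for s, b in zip(r, br)] for r, br in zip(qk, B)]
        heads.append(matmul(softmax_rows(logits), [r[cols] for r in v]))
    return affine(concat_cols(heads), Wo, bo)

def rand_matrix(n_rows, n_cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

rng = random.Random(9)
d, M, N = 4, 2, 3
params = []
for _ in range(4):  # (W_q, b_q), (W_k, b_k), (W_v, b_v), (W_o, b_o)
    params += [rand_matrix(d, d, rng), rand_matrix(1, d, rng)[0]]
Q, K, V = rand_matrix(M, d, rng), rand_matrix(N, d, rng), rand_matrix(N, d, rng)

i, s = 0, 2
unmasked = [[0.0] * N for _ in range(M)]
blocked = [[0.0] * N for _ in range(M)]
blocked[i][s] = float("-inf")

# Perturb only k_s and v_s
K2, V2 = [r[:] for r in K], [r[:] for r in V]
K2[s], V2[s] = [0.0] * d, [50.0, -50.0, 50.0, -50.0]

def o_i_unchanged(masks):
    before = masked_multi_head_attention(Q, K, V, params, masks)[i]
    after = masked_multi_head_attention(Q, K2, V2, params, masks)[i]
    return before == after

print(o_i_unchanged([blocked, blocked]))   # True: no head sees position s
print(o_i_unchanged([blocked, unmasked]))  # generically False: head 2 still sees position s
```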

## Definition: Multi-Head Self-Attention

This is usually denoted as $\text{MultiHeadAttention}(X)$ with only one argument.

Inputs
* $X \in \mathbb{R}^{N \times c}, \quad N \in \mathbb{N}.$ 

Weights
* $W_{q,i} \in \mathbb{R}^{c \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_o}$
\begin{align*}
O = \text{MultiHeadSelfAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)= \text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X).
\end{align*}
where
\begin{align*}
\mathcal{W}_0 &= \{W_{o},b_{o}\},\\
\mathcal{W}_i &= \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\}, \quad i = 1,2,\dots,h.
\end{align*}

**Note:** The same $\operatorname{MultiHeadSelfAttention}$ function can be applied to inputs of different sequence lengths. The model parameters are not tied to specific positions. In this sense, $\operatorname{MultiHeadSelfAttention}$ is **position-agnostic**.

### Property: Multi-Head Self-Attention Permutation Equivariance

Multi-Head Self-Attention is permutation equivariant. We state the property for linear projections (all biases equal to zero); it also holds with biases, because every bias matrix repeats the same row and is therefore unchanged by row permutations. If $\pi_r(M)$ denotes an arbitrary permutation over the rows of a matrix $M$, then
$$\text{MultiHeadSelfAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(X))=\pi_r\left(\text{MultiHeadSelfAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)\right).$$


Proof: From the previous result we have
$$
\begin{align*}
\text{MultiHeadSelfAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(X))
&=\text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(X),\pi_r(X),\pi_r(X)),\\
&=\pi_r\left(\text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X)\right),\\
&=\pi_r\left(\text{MultiHeadSelfAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)\right).
\end{align*}
$$

## Definition: Masked Multi-Head Self-Attention

This is usually denoted as $\text{MaskedMultiHeadAttention}(X)$ with only one argument.

Inputs
* $X \in \mathbb{R}^{N \times c}, \quad N \in \mathbb{N}.$ 

Weights
* $W_{q,i} \in \mathbb{R}^{c \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $B_{i} \in \mathbb{R}^{N \times N}, \quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_o}$
\begin{align*}
O = \text{MaskedMultiHeadSelfAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)= \text{MaskedMultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X).
\end{align*}
where
\begin{align*}
\mathcal{W}_0 &= \{W_{o},b_{o}\},\\
\mathcal{W}_i &= \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},B_i\}, \quad i = 1,2,\dots,h.
\end{align*}

### Property: If the same entry of all masks is $-\infty$, the corresponding attention in Masked Multi-Head Self-Attention cancels out.

Denote the $(i,s)$ entry of the $\kappa$-th mask matrix $B_{\kappa}$ by $b^{(\kappa)}_{i,s}$, where $i,s\in\{1,\dots,N\}$. If $b^{(\kappa)}_{i,s} \to -\infty$ for a fixed $s \neq i$ and for all heads $\kappa=1,2,\dots,h$, then the $i$-th output $o_i$ of the masked Multi-Head Self-Attention does not depend on $x_s$. In particular, we have:
$$
\begin{align*}
o_i &=W_o^T\begin{pmatrix}
o_{i}^{(1)}\\
o_{i}^{(2)}\\
\vdots\\
o_{i}^{(h)}
\end{pmatrix}
+ b_o,\\
o_i^{(\kappa)} &=\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_l^{(\kappa)}+b^{(\kappa)}_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_j^{(\kappa)}+b^{(\kappa)}_{i,j}}{\sqrt{d_k}}} \hat{v}_j^{(\kappa)}, \quad \kappa = 1,2,\dots,h,
\end{align*}
$$
where

$$ 
\begin{align*}
\hat{q}^{(\kappa)}_l&=W_{q,\kappa}^Tx_l+b_{q,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N,\\
\hat{k}^{(\kappa)}_l&=W_{k,\kappa}^Tx_l+b_{k,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N,\\
\hat{v}^{(\kappa)}_l&=W_{v,\kappa}^Tx_l+b_{v,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N.
\end{align*}
$$

**Proof:** It follows directly from the same result for Masked Multi-Head Attention, with $Q=K=V=X$.

## Definition: Multi-Head Attention

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}, \quad M\in\mathbb{N}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}, \quad N\in\mathbb{N}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}, \quad N\in\mathbb{N}$

Weights
* $W_{q,i} \in \mathbb{R}^{c_q \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_0}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_o}$
\begin{align*}
O  &= \text{MultiHeadAttention}_{\{\mathcal{B}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\text{Attention}_{\mathcal{B}_{1}}(Q,K,V)|
\text{Attention}_{\mathcal{B}_{2}}(Q,K,V)|
\dots|
\text{Attention}_{\mathcal{B}_{h}}(Q,K,V)
\end{pmatrix}W_o + B_o
\end{align*}
where


\begin{align*}
B_o  &= [b_o|b_o|\dots|b_o]^T \in \mathbb{R}^{M \times d_o},\\ 
\mathcal{W}_0 &= \{W_{o},b_{o}\},\\
\mathcal{W}_i &= \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\}, \quad i = 1,2,\dots,h.
\end{align*}

Usually we have:
 * $d_h:=d_k=d_v$ (head dimensions)
 * $d := c_q = c_k = c_v = d_h\cdot h = d_o$ (model dimension)  
**Note:** In this case, internally we have:
    * $W_q=[W_{q,1}|W_{q,2}|\dots|W_{q,h}]\in\mathbb{R}^{d\times d}$  
    * $W_k=[W_{k,1}|W_{k,2}|\dots|W_{k,h}]\in\mathbb{R}^{d\times d}$  
    * $W_v=[W_{v,1}|W_{v,2}|\dots|W_{v,h}]\in\mathbb{R}^{d\times d}$  

**Note:** The same $\operatorname{MultiHeadAttention}$ function can be applied to inputs of different sequence lengths. The model parameters are not tied to specific positions. In this sense, $\operatorname{MultiHeadAttention}$ is **position-agnostic**.

### Property: Multi-Head Attention Permutation Equivariance

If Multi-Head Attention has no bias (linear projections instead of affine projections), the Multi-Head Attention is permutation equivariant. If $\pi_r$ and $\sigma_r$ denote arbitrary permutations over the rows of a matrix, then
$$\text{MultiHeadAttention}_{\{\mathcal{B}_{i}\}_{i=0}^h}(\pi_r(Q),\sigma_r(K),\sigma_r(V))=\pi_r\left(\text{MultiHeadAttention}_{\{\mathcal{B}_{i}\}_{i=0}^h}(Q,K,V)\right) $$

Proof:  Consider  $\pi_r(M) = R_{\pi}M$ and $\sigma_r(M) = R_{\sigma}M$ for permutation matrices $R_{\pi}$ and $R_{\sigma}$ respectively. From the Attention equivariance we have
$$
\begin{align*}
&\text{MultiHeadAttention}_{\{\mathcal{B}_{i}\}_{i=0}^h}(\pi_r(Q),\sigma_r(K),\sigma_r(V))\\
&=\begin{pmatrix}
\text{Attention}_{\mathcal{B}_{1}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))|
\text{Attention}_{\mathcal{B}_{2}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))|
\dots|
\text{Attention}_{\mathcal{B}_{h}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))
\end{pmatrix}W_o,\\
&=\begin{pmatrix}
\pi_r\left(\text{Attention}_{\mathcal{B}_{1}}(Q,K,V)\right)|
\pi_r\left(\text{Attention}_{\mathcal{B}_{2}}(Q,K,V)\right)|
\dots|
\pi_r\left(\text{Attention}_{\mathcal{B}_{h}}(Q,K,V)\right)
\end{pmatrix}W_o,\\
&=\begin{pmatrix}
R_{\pi}\text{Attention}_{\mathcal{B}_{1}}(Q,K,V)|
R_{\pi}\text{Attention}_{\mathcal{B}_{2}}(Q,K,V)|
\dots|
R_{\pi}\text{Attention}_{\mathcal{B}_{h}}(Q,K,V)
\end{pmatrix}W_o,\\
&=R_{\pi}\begin{pmatrix}
\text{Attention}_{\mathcal{B}_{1}}(Q,K,V)|
\text{Attention}_{\mathcal{B}_{2}}(Q,K,V)|
\dots|
\text{Attention}_{\mathcal{B}_{h}}(Q,K,V)
\end{pmatrix}W_o,\\
&=R_{\pi}\text{MultiHeadAttention}_{\{\mathcal{B}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\pi_r\left(\text{MultiHeadAttention}_{\{\mathcal{B}_{i}\}_{i=0}^h}(Q,K,V)\right).
\end{align*}
$$

## Definition: Multi-Head Attention (torch implementation)

Torch implementation corresponds to the particular (usual) case
 * $d_h:=d_k=d_v$ (head dimensions)
 * $d := c_q = d_h\cdot h = d_o$ (model dimension)  

Denote $d_h = d|h$ ($h$ must divide $d$). Then we have the simplified defintion

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times d}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$

Weights
* $W_{q,i} \in \mathbb{R}^{d \times d_h}$ and $b_{q,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_h}$ and $b_{k,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_h}$ and $b_{v,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d \times d }$ and $b_{o}\in \mathbb{R}^{d}$.


Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d}$
\begin{align*}
O  &= \text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\text{Attention}_{\mathcal{W}_1}(Q,K,V)|
\text{Attention}_{\mathcal{W}_2}(Q,K,V)|
\dots|
\text{Attention}_{\mathcal{W}_h}(Q,K,V)
\end{pmatrix}W_o + B_o
\end{align*}
where:
* $W_q=[W_{q,1}|W_{q,2}|\dots|W_{q,h}]\in\mathbb{R}^{d\times d}$
* $W_k=[W_{k,1}|W_{k,2}|\dots|W_{k,h}]\in\mathbb{R}^{c_k\times d}$
* $W_v=[W_{v,1}|W_{v,2}|\dots|W_{v,h}]\in\mathbb{R}^{c_v\times d}$
* $b_q=[b^T_{q,1}|b^T_{q,2}|\dots|b^T_{q,h}]^T\in\mathbb{R}^{d}$
* $b_k=[b^T_{k,1}|b^T_{k,2}|\dots|b^T_{k,h}]^T\in\mathbb{R}^{d}$
* $b_v=[b^T_{v,1}|b^T_{v,2}|\dots|b^T_{v,h}]^T\in\mathbb{R}^{d}$
* $B_o= [b_o|b_o|\dots|b_o]^T \in \mathbb{R}^{M \times d}$
* $\mathcal{W}_0 = \{W_{o},b_{o}\}$
* $\mathcal{W}_i = \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\}, \quad i = 1,2,\dots,h$.
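The packed matrix $W_q$ is the column-wise concatenation of the per-head matrices, so projecting once with $W_q, b_q$ and splitting the last dimension into $h$ blocks of size $d_h$ recovers the per-head projections $QW_{q,i}+b_{q,i}$. A minimal sketch with random tensors; `W_q_heads` and `b_q_heads` are hypothetical names for the per-head weights, and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
M, d, h = 5, 8, 2
dh = d // h  # head dimension d_h = d / h

Q = torch.randn(M, d)
W_q_heads = [torch.randn(d, dh) for _ in range(h)]  # W_{q,1}, ..., W_{q,h}
b_q_heads = [torch.randn(dh) for _ in range(h)]     # b_{q,1}, ..., b_{q,h}

# Packed weights: W_q = [W_{q,1}| ... |W_{q,h}],  b_q = [b_{q,1}^T| ... |b_{q,h}^T]^T
W_q = torch.cat(W_q_heads, dim=1)  # (d, d)
b_q = torch.cat(b_q_heads, dim=0)  # (d,)

# One matrix product, then split the columns into h blocks of size d_h
packed = (Q @ W_q + b_q).reshape(M, h, dh).transpose(0, 1)  # (h, M, dh)

# Per-head products Q W_{q,i} + b_{q,i}
per_head = torch.stack([Q @ W + b for W, b in zip(W_q_heads, b_q_heads)])  # (h, M, dh)

same = torch.allclose(packed, per_head, atol=1e-5)
print(same)
```

This column split is the same `(h dk)` split that the `einops.rearrange(..., "b m (h dk) -> b h m dk")` calls perform in the code further down.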

## Definition: Masked Multi-Head Attention

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$

Weights
* $W_{q,i} \in \mathbb{R}^{c_q \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $B_{i} \in \mathbb{R}^{M \times N}, \quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_o}$
\begin{align*}
O  &= \text{MaskedMultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\text{MaskAttn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{MaskAttn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{MaskAttn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o + B_o
\end{align*}
where $\operatorname{MaskAttn}=\operatorname{MaskedAttention}$ and


\begin{align*}
B_o  &= [b_o|b_o|\dots|b_o]^T \in \mathbb{R}^{M \times d_o},\\ 
\mathcal{W}_0 &= \{W_{o},b_{o}\},\\
\mathcal{W}_i &= \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},B_i\}, \quad i = 1,2,\dots,h.
\end{align*}
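A direct, loop-over-heads transcription of this definition may help. It is a sketch for readability rather than an efficient implementation: weights are stored in the document's orientation ($QW_{q,i}+b_{q,i}$), each head carries its own mask $B_i$, and the helper name `masked_mha` and its argument layout are illustrative assumptions.

```python
import torch

def masked_mha(Q, K, V, heads, W_o, b_o):
    """Masked Multi-Head Attention written literally from the definition.

    heads: one dict per head with keys Wq, bq, Wk, bk, Wv, bv, B.
    Q: (M, c_q), K: (N, c_k), V: (N, c_v); each mask B: (M, N).
    """
    outs = []
    for p in heads:
        q = Q @ p["Wq"] + p["bq"]  # (M, d_k)
        k = K @ p["Wk"] + p["bk"]  # (N, d_k)
        v = V @ p["Wv"] + p["bv"]  # (N, d_v)
        d_k = q.shape[-1]
        # alpha = softmax((q k^T + B_i) / sqrt(d_k)), applied per row
        alpha = torch.softmax((q @ k.T + p["B"]) / d_k**0.5, dim=-1)  # (M, N)
        outs.append(alpha @ v)  # (M, d_v)
    # [head_1 | ... | head_h] W_o + B_o
    return torch.cat(outs, dim=-1) @ W_o + b_o  # (M, d_o)

torch.manual_seed(0)
M, N, cq, ck, cv, dk, dv, d_o, h = 4, 4, 6, 5, 3, 8, 8, 7, 2

def make_head(B):
    return {"Wq": torch.randn(cq, dk), "bq": torch.randn(dk),
            "Wk": torch.randn(ck, dk), "bk": torch.randn(dk),
            "Wv": torch.randn(cv, dv), "bv": torch.randn(dv), "B": B}

causal = torch.triu(torch.full((M, N), float("-inf")), diagonal=1)  # same causal mask for every head
heads = [make_head(causal) for _ in range(h)]
W_o, b_o = torch.randn(h * dv, d_o), torch.randn(d_o)

Q, K, V = torch.randn(M, cq), torch.randn(N, ck), torch.randn(N, cv)
O = masked_mha(Q, K, V, heads, W_o, b_o)
print(O.shape)  # torch.Size([4, 7])
```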

### Property: If the same entry of all masks is $-\infty$, the corresponding attention in the Multi-Head Attention cancels out.

Denote the $(i,s)$ entry, $(i,s) \in \{1,\dots,M\}\times\{1,\dots,N\}$, of the $l$-th mask matrix $B_l$ by $b^{(l)}_{i,s}$. If $b^{(l)}_{i,s} \to -\infty$ for fixed values $i$ and $s \neq i$ and for all $l=1,2,\dots,h$, then the $i$-th output $o_i$ of the masked Multi-Head Attention does not depend on $q_s$, $k_s$ or $v_s$. In particular, we have:
$$
\begin{align*}
o_i &=W_o^T\begin{pmatrix}
o_{i}^{(1)}\\
o_{i}^{(2)}\\
\vdots\\
o_{i}^{(h)}
\end{pmatrix}
+ b_o,\\
o_i^{(\kappa)} &=\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_l^{(\kappa)}+b^{(\kappa)}_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_j^{(\kappa)}+b^{(\kappa)}_{i,j}}{\sqrt{d_k}}} \hat{v}_j^{(\kappa)}, \quad \kappa = 1,2,\dots,h.
\end{align*}
$$
where

$$
\begin{align*}
\hat{q}^{(\kappa)}_l&=W_{q,\kappa}^Tq_l+b_{q,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, M\\
\hat{k}^{(\kappa)}_l&=W_{k,\kappa}^Tk_l+b_{k,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N\\
\hat{v}^{(\kappa)}_l&=W_{v,\kappa}^Tv_l+b_{v,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N
\end{align*}
$$

**Proof:** It follows directly from the same result for Masked Attention with weights, applied to each head $\kappa = 1,2,\dots,h$.
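The property can be checked numerically with `torch.nn.MultiheadAttention`, which accepts a 3-D float `attn_mask` of shape `(batch * num_heads, M, N)`, i.e. one mask $B_l$ per head. In this sketch (arbitrary sizes, a batch of one so the first mask dimension indexes the heads) we set entry $(i,s)$ to $-\infty$ in every head, perturb $k_s$ and $v_s$, and check that $o_i$ does not move; masking the entry in only one head is not enough.

```python
import torch

torch.manual_seed(0)
M, N, d, h = 3, 4, 8, 2
i, s = 0, 2  # output row i, key/value position s to be cut off
mha = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h, batch_first=True)

Q, K, V = torch.randn(1, M, d), torch.randn(1, N, d), torch.randn(1, N, d)
K2, V2 = K.clone(), V.clone()
K2[:, s] += 1.0  # perturb k_s
V2[:, s] += 1.0  # perturb v_s

def o_i(mask, K, V):
    """Row i of the Multi-Head Attention output for a per-head mask of shape (h, M, N)."""
    with torch.no_grad():
        out, _ = mha(Q, K, V, attn_mask=mask)
    return out[0, i]

all_heads = torch.zeros(h, M, N)
all_heads[:, i, s] = float("-inf")  # b^{(l)}_{i,s} = -inf for every head l
one_head = torch.zeros(h, M, N)
one_head[0, i, s] = float("-inf")   # only the first head is masked

unaffected = torch.allclose(o_i(all_heads, K, V), o_i(all_heads, K2, V2), atol=1e-6)
affected = not torch.allclose(o_i(one_head, K, V), o_i(one_head, K2, V2), atol=1e-6)
print(unaffected, affected)
```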

## Definition: Multi-Head Self-Attention

This is usually denoted as $\text{MultiHeadAttention}(X)$ with only one argument.

Inputs
* $X \in \mathbb{R}^{N \times c}, \quad N \in \mathbb{N}.$ 

Weights
* $W_{q,i} \in \mathbb{R}^{c \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_o}$
\begin{align*}
O = \text{MultiHeadSelfAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)= \text{MultiHeadAttention}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X).
\end{align*}
where
\begin{align*}
\mathcal{W}_0 &= \{W_{o},b_{o}\},\\
\mathcal{W}_i &= \{W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\}, \quad i = 1,2,\dots,h.
\end{align*}

**Note:** The same $\operatorname{MultiHeadSelfAttention}$ function can be applied to inputs of different sequence lengths, since the model parameters are not tied to specific positions. In this sense, $\operatorname{MultiHeadSelfAttention}$ is **position-agnostic**.
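To illustrate the note, the sketch below (arbitrary sizes) applies one `torch.nn.MultiheadAttention` module as self-attention, passing the same tensor as query, key and value, to sequences of two different lengths. It also shows that reversing the tokens only reverses the output rows: without positional information, the module cannot tell positions apart.

```python
import torch

torch.manual_seed(0)
d, h = 8, 2
self_attn = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h, batch_first=True)

def msa(X):
    """MultiHeadSelfAttention(X) = MultiHeadAttention(X, X, X)."""
    with torch.no_grad():
        out, _ = self_attn(X, X, X)
    return out

X_short = torch.randn(1, 3, d)  # N = 3 tokens
X_long = torch.randn(1, 10, d)  # N = 10 tokens, same parameters
print(msa(X_short).shape, msa(X_long).shape)  # torch.Size([1, 3, 8]) torch.Size([1, 10, 8])

# Reversing the token order simply reverses the output rows
X_rev = X_long.flip(1)
reordered = torch.allclose(msa(X_rev), msa(X_long).flip(1), atol=1e-5)
print(reordered)
```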

## Code: Multi-Head Attention

The code for the remaining models is easily deduced from $\text{MultiHeadAttention}$:
* $\text{Attention}_{\mathcal{W}}(Q,K,V) = \text{MultiHeadAttention}_{\{\mathcal{W},\operatorname{Id},0\}}(Q,K,V)$ (one head)
* $\text{SelfAttention}_{\mathcal{W}}(X) = \text{Attention}_{\mathcal{W}}(X,X,X)$
* $\text{MultiHeadSelfAttention}_{\{\mathcal{B}_{i}\}_{i=1}^h\cup \{W_o,b_o\}}(X) = \text{MultiHeadAttention}_{\{\mathcal{B}_{i}\}_{i=1}^h\cup \{W_o,b_o\}}(X,X,X)$
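For instance, the first identity can be checked against PyTorch: a single-head `torch.nn.MultiheadAttention` whose output projection is set to the identity with zero bias should reproduce $\text{softmax}\big((QW_q+b_q)(KW_k+b_k)^T/\sqrt{d}\big)(VW_v+b_v)$. This sketch assumes `kdim == vdim == embed_dim`, in which case PyTorch stacks the three input projections in `in_proj_weight` (shape `(3d, d)`) and stores each weight as `(out_features, in_features)`, hence the transposes.

```python
import torch

torch.manual_seed(0)
M, N, d = 4, 6, 8
mha = torch.nn.MultiheadAttention(embed_dim=d, num_heads=1, batch_first=True)
with torch.no_grad():
    mha.out_proj.weight.copy_(torch.eye(d))  # W_o = Id
    mha.out_proj.bias.zero_()                # b_o = 0

Q, K, V = torch.randn(1, M, d), torch.randn(1, N, d), torch.randn(1, N, d)

with torch.no_grad():
    mha_out, _ = mha(Q, K, V)
    # torch stores each weight as (out_features, in_features), i.e. W^T
    Wq, Wk, Wv = mha.in_proj_weight.chunk(3, dim=0)
    bq, bk, bv = mha.in_proj_bias.chunk(3, dim=0)
    q, k, v = Q @ Wq.T + bq, K @ Wk.T + bk, V @ Wv.T + bv
    alpha = torch.softmax(q @ k.transpose(-2, -1) / d**0.5, dim=-1)  # (1, M, N)
    manual = alpha @ v                                                # (1, M, d)

match = torch.allclose(mha_out, manual, atol=1e-5)
print(match)
```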

In [79]:
import einops
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [80]:
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-Head Attention module with (general) parameters
    cq, ck, cv: input dimensions for Q, K, V
    dk, dv: dimensions for each head's Q, K and V
    do: output dimension
    h: number of heads.
    
    The initialization of the weights differs from PyTorch’s `nn.MultiheadAttention`.
    Here we use standard `nn.Linear` initialization (Xavier uniform for weights and
    zeros for biases) for clarity and simplicity. 
    """

    def __init__(
        self, cq, ck, cv, dk, dv, do, h, bias=True, add_bias_kv=False, device=None, dtype=None
    ):
        super().__init__()
        # dk and dv are per-head sizes (the projections output dk*h and dv*h features),
        # so no divisibility constraint between dk and h is needed here
        self.cq = cq
        self.ck = ck
        self.cv = cv
        self.dk = dk
        self.dv = dv
        self.do = do
        self.h = h
        self.add_bias_kv = add_bias_kv
        self.device = device
        self.dtype = dtype
        # Q -> QW_q+B_q
        self.q_proj = torch.nn.Linear(cq, dk * h, bias, self.device, self.dtype)
        # K -> KW_k+B_k
        self.k_proj = torch.nn.Linear(ck, dk * h, bias, self.device, self.dtype)
        # V -> VW_v+B_v
        self.v_proj = torch.nn.Linear(cv, dv * h, bias, self.device, self.dtype)

        self.out_proj = torch.nn.Linear(dv * h, do, bias, self.device, self.dtype)
        if self.add_bias_kv:
            self.bias_k = torch.nn.Parameter(
                torch.zeros(1, 1, dk * h, device=self.device, dtype=self.dtype)
            )
            self.bias_v = torch.nn.Parameter(
                torch.zeros(1, 1, dv * h, device=self.device, dtype=self.dtype)
            )

    def forward(self, Q, K, V, attn_mask=None):
        """Forward pass of the MHA module.

        attn_mask: optional additive float mask of shape (M, N), shared by all heads.
        """
        # Linear projections
        proj_q = self.q_proj(Q)  # QW_q + B_q
        proj_k = self.k_proj(K)  # KW_k + B_k
        proj_v = self.v_proj(V)  # VW_v + B_v
        if self.add_bias_kv:
            # append bias to the key and value sequences
            batch_size = proj_k.shape[0]
            proj_k = torch.cat([proj_k, self.bias_k.repeat(batch_size, 1, 1)], dim=1)
            proj_v = torch.cat([proj_v, self.bias_v.repeat(batch_size, 1, 1)], dim=1)
            if attn_mask is not None:
                # the appended key/value position is never masked: add a zero column (as torch does)
                attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))

        # Reshape for multi-head attention
        r_q = einops.rearrange(proj_q, "b m (h dk) -> b h m dk", h=self.h)
        r_k = einops.rearrange(proj_k, "b n (h dk) -> b h n dk", h=self.h)
        r_v = einops.rearrange(proj_v, "b n (h dv) -> b h n dv", h=self.h)

        # QK^T (+ B)
        scores = torch.einsum("bhmd, bhnd -> bhmn", r_q, r_k)
        if attn_mask is not None:
            # broadcast the (M, N) mask over batch and heads
            scores = scores + attn_mask.unsqueeze(0).unsqueeze(0)

        # softmax((QK^T + B)/sqrt(dk)), following the definition above;
        # torch adds the mask after scaling, which gives the same result for 0/-inf masks
        attn = torch.nn.functional.softmax(scores / (self.dk**0.5), dim=-1)

        # softmax(QK^T/sqrt(dk))V
        o = torch.einsum("bhmn, bhnv -> bhmv", attn, r_v)

        # Reshape back
        r_o = einops.rearrange(o, "b h m dv -> b m (h dv)")

        # Final linear projection
        proj_o = self.out_proj(r_o)
        return proj_o
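
The three steps between the reshapes (scores, softmax, weighted sum) are batched scaled dot-product attention. As a sketch, they can be compared with `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later) on random tensors already in the `(b, h, m, d)` layout, using a causal 0/-inf mask; for such masks, adding the mask before or after the $1/\sqrt{d_k}$ scaling makes no difference.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, m, n, dk = 2, 3, 5, 5, 4
r_q, r_k, r_v = torch.randn(b, h, m, dk), torch.randn(b, h, n, dk), torch.randn(b, h, n, dk)
mask = torch.triu(torch.full((m, n), float("-inf")), diagonal=1)  # causal: -inf above the diagonal

# Same steps as MultiHeadAttention.forward
scores = torch.einsum("bhmd, bhnd -> bhmn", r_q, r_k) + mask  # QK^T + B, broadcast over b and h
attn = torch.softmax(scores / dk**0.5, dim=-1)
o = torch.einsum("bhmn, bhnv -> bhmv", attn, r_v)

# Fused PyTorch kernel; a float attn_mask is added to the already scaled scores
o_ref = F.scaled_dot_product_attention(r_q, r_k, r_v, attn_mask=mask)
agree = torch.allclose(o, o_ref, atol=1e-5)
print(agree)
```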


In [81]:
torch.arange(12).reshape(3,4).unsqueeze(0).shape    

torch.Size([1, 3, 4])

In [82]:
batch_dim = 3
M = 5  # sequence length of q
N = 3  # sequence length of k,v
d = 4 # embedding/model dimension
ck = 8  # key dimension 
cv = 16  # value dimension
h = 2  # number of heads

mask = torch.full((M, N), float('-inf'), device=device)
# causal mask: -inf strictly above the diagonal, 0 on and below it
mask = torch.triu(mask, diagonal=1)

bias = True
add_bias_kv = False

In [83]:
mask

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='mps:0')
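
`nn.MultiheadAttention` also accepts a boolean `attn_mask`, where `True` marks a position that is not allowed to be attended to. The sketch below (CPU, sizes matching the cell above) shows that the boolean causal mask and the float 0/-inf mask give the same output. The class `MultiHeadAttention` defined above only handles the additive float form.

```python
import torch

torch.manual_seed(0)
M, N, d, h = 5, 3, 4, 2
bool_mask = torch.triu(torch.ones(M, N, dtype=torch.bool), diagonal=1)  # True above the diagonal
float_mask = torch.zeros(M, N).masked_fill(bool_mask, float("-inf"))   # same mask, additive form

mha = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h, batch_first=True)
q, k, v = torch.randn(3, M, d), torch.randn(3, N, d), torch.randn(3, N, d)
with torch.no_grad():
    out_bool, _ = mha(q, k, v, attn_mask=bool_mask)
    out_float, _ = mha(q, k, v, attn_mask=float_mask)

same = torch.allclose(out_bool, out_float, atol=1e-6)
print(same)
```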

#### Weights

In [84]:
nn_attn = torch.nn.MultiheadAttention(embed_dim=d,kdim=ck,vdim=cv,
                                        num_heads=h,batch_first=True,
                                        bias=bias,add_bias_kv=add_bias_kv).to(device)

assert d % h == 0, "d must be divisible by h"
dh = d // h
attn = MultiHeadAttention(cq = d, ck = ck, cv=cv, dk = dh ,dv=dh,do=d,h=h,bias=bias,add_bias_kv=add_bias_kv).to(device)

In [85]:
print("Torch MHA Weights:")
for name, w in nn_attn.named_parameters():
    print(f"{name} - {w.shape}")

print("\nOur MHA Weights:")
for name, w in attn.named_parameters():
    print(f"{name} - {w.shape}")

Torch MHA Weights:
q_proj_weight - torch.Size([4, 4])
k_proj_weight - torch.Size([4, 8])
v_proj_weight - torch.Size([4, 16])
in_proj_bias - torch.Size([12])
out_proj.weight - torch.Size([4, 4])
out_proj.bias - torch.Size([4])

Our MHA Weights:
q_proj.weight - torch.Size([4, 4])
q_proj.bias - torch.Size([4])
k_proj.weight - torch.Size([4, 8])
k_proj.bias - torch.Size([4])
v_proj.weight - torch.Size([4, 16])
v_proj.bias - torch.Size([4])
out_proj.weight - torch.Size([4, 4])
out_proj.bias - torch.Size([4])
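
The printout shows that PyTorch keeps separate `q_proj_weight`, `k_proj_weight` and `v_proj_weight` here because `kdim` and `vdim` differ from `embed_dim`, and that every weight has shape `(out_features, in_features)`, i.e. it stores $W^T$. When `kdim == vdim == embed_dim`, the three input projections are instead packed into a single `in_proj_weight` of shape `(3d, d)`. A quick sketch listing both cases:

```python
import torch

d, h = 4, 2
packed = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h)                    # kdim = vdim = d
separate = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h, kdim=8, vdim=16)

for name, module in [("packed", packed), ("separate", separate)]:
    shapes = {n: tuple(p.shape) for n, p in module.named_parameters()}
    print(name, shapes)
# packed {'in_proj_weight': (12, 4), 'in_proj_bias': (12,), 'out_proj.weight': (4, 4), 'out_proj.bias': (4,)}
```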


#### Inputs

In [86]:
q = torch.rand(batch_dim,M,d,device=device)
k = torch.rand(batch_dim,N,ck,device=device) 
v = torch.rand(batch_dim,N,cv,device=device)

#### Load weights

The two modules initialize their weights differently, so let's copy the `nn.MultiheadAttention` weights into our implementation.

In [87]:
with torch.no_grad():
    # 1) copy weights (shapes already match)
    attn.q_proj.weight.copy_(nn_attn.q_proj_weight)   # (d, d)
    attn.k_proj.weight.copy_(nn_attn.k_proj_weight)   # (d, ck)
    attn.v_proj.weight.copy_(nn_attn.v_proj_weight)   # (d, cv)

    # 2) split the packed bias: (3d,) -> (d,) + (d,) + (d,)
    b = nn_attn.in_proj_bias      # shape (3d,) = (12,)
    if bias:
        attn.q_proj.bias.copy_(b[0:d])        # 0:d
        attn.k_proj.bias.copy_(b[d:2*d])      # d:2d
        attn.v_proj.bias.copy_(b[2*d:3*d])    # 2d:3d
    
    if add_bias_kv:
        attn.bias_k.copy_(nn_attn.bias_k.squeeze(0))
        attn.bias_v.copy_(nn_attn.bias_v.squeeze(0))

    # 3) output projection
    attn.out_proj.weight.copy_(nn_attn.out_proj.weight)
    if bias:
        attn.out_proj.bias.copy_(nn_attn.out_proj.bias)
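
If the reference module had been built with `kdim == vdim == embed_dim`, step 1 would instead slice the packed `in_proj_weight` row-wise, exactly like the bias in step 2. A sketch of that split (same `(q, k, v)` order as the bias slicing), checked against the packed projection computed with `torch.nn.functional.linear`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, h, M = 4, 2, 5
ref = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h, batch_first=True)

# Rows [0:d] -> W_q^T, [d:2d] -> W_k^T, [2d:3d] -> W_v^T (same split as in_proj_bias)
Wq, Wk, Wv = ref.in_proj_weight.chunk(3, dim=0)  # each (d, d)
bq, bk, bv = ref.in_proj_bias.chunk(3, dim=0)    # each (d,)

x = torch.randn(M, d)
with torch.no_grad():
    # (M, 3d): [x W_q + b_q | x W_k + b_k | x W_v + b_v]
    packed = F.linear(x, ref.in_proj_weight, ref.in_proj_bias)
    split = torch.cat([F.linear(x, Wq, bq), F.linear(x, Wk, bk), F.linear(x, Wv, bv)], dim=-1)

consistent = torch.allclose(packed, split, atol=1e-5)
print(consistent)
```

These `Wq`, `Wk`, `Wv` chunks could then be copied into `q_proj.weight`, `k_proj.weight` and `v_proj.weight` without transposing, since both modules use the `(out, in)` layout.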

In [88]:
out = attn(q, k, v,attn_mask=mask)
nn_out,_ = nn_attn(q,k,v, attn_mask=mask)

In [89]:
out

tensor([[[-0.2615, -0.0980, -0.0476,  0.0864],
         [-0.5672, -0.2472, -0.1618,  0.1811],
         [-0.4092, -0.2078, -0.0457,  0.1335],
         [-0.4180, -0.2088, -0.0589,  0.1355],
         [-0.4114, -0.2066, -0.0552,  0.1334]],

        [[-0.5213, -0.0770, -0.1666,  0.1922],
         [-0.5738, -0.1202, -0.2060,  0.2089],
         [-0.6630, -0.1685, -0.2321,  0.2355],
         [-0.6603, -0.1693, -0.2241,  0.2352],
         [-0.6556, -0.1652, -0.2271,  0.2332]],

        [[-0.9014, -0.4767, -0.2997,  0.2832],
         [-0.5683, -0.3029, -0.1355,  0.1863],
         [-0.6160, -0.2944, -0.1184,  0.2057],
         [-0.6061, -0.2948, -0.0987,  0.2034],
         [-0.5989, -0.2933, -0.0916,  0.2013]]], device='mps:0',
       grad_fn=<LinearBackward0>)

In [91]:
nn_out

tensor([[[-0.2615, -0.0980, -0.0476,  0.0864],
         [-0.5672, -0.2472, -0.1618,  0.1811],
         [-0.4092, -0.2078, -0.0457,  0.1335],
         [-0.4180, -0.2088, -0.0589,  0.1355],
         [-0.4114, -0.2066, -0.0552,  0.1334]],

        [[-0.5213, -0.0770, -0.1666,  0.1922],
         [-0.5738, -0.1202, -0.2060,  0.2089],
         [-0.6630, -0.1685, -0.2321,  0.2355],
         [-0.6603, -0.1693, -0.2241,  0.2352],
         [-0.6556, -0.1652, -0.2271,  0.2332]],

        [[-0.9014, -0.4767, -0.2997,  0.2832],
         [-0.5683, -0.3029, -0.1355,  0.1863],
         [-0.6160, -0.2944, -0.1184,  0.2057],
         [-0.6061, -0.2948, -0.0987,  0.2034],
         [-0.5989, -0.2933, -0.0916,  0.2013]]], device='mps:0',
       grad_fn=<TransposeBackward0>)
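
Besides the output, `nn.MultiheadAttention` returns the attention matrix $\alpha$ (the second return value, discarded as `_` above). By default it is averaged over the heads, with shape `(batch, M, N)`; with `average_attn_weights=False` it is returned per head, with shape `(batch, h, M, N)`. A sketch on CPU with the same sizes and the causal mask, checking that each row sums to one and that masked entries get zero weight:

```python
import torch

torch.manual_seed(0)
batch, M, N, d, ck, cv, h = 3, 5, 3, 4, 8, 16, 2
mask = torch.triu(torch.full((M, N), float("-inf")), diagonal=1)
mha = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h, kdim=ck, vdim=cv, batch_first=True)
q, k, v = torch.randn(batch, M, d), torch.randn(batch, N, ck), torch.randn(batch, N, cv)

with torch.no_grad():
    _, avg_w = mha(q, k, v, attn_mask=mask)                               # (batch, M, N)
    _, head_w = mha(q, k, v, attn_mask=mask, average_attn_weights=False)  # (batch, h, M, N)

print(avg_w.shape, head_w.shape)  # torch.Size([3, 5, 3]) torch.Size([3, 2, 5, 3])
rows_sum_to_one = torch.allclose(head_w.sum(-1), torch.ones(batch, h, M), atol=1e-6)
masked_are_zero = bool((head_w.masked_select(mask.isinf()) == 0).all())  # exp(-inf) = 0
averaged = torch.allclose(avg_w, head_w.mean(dim=1), atol=1e-6)
```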