# Multi-Head Attention


## Definition: Attention ($\operatorname{Attention}$)

For a set of key-value pairs $\{(k_i,v_i)\}_{i=1}^N \subset \mathbb{R}^{d_k}\times \mathbb{R}^{d_v}$ and a set of queries $\{q_j\}_{j=1}^M \subset \mathbb{R}^{d_k}$, attention returns the "expected" value $o_j \in \mathbb{R}^{d_v}$ for each query $q_j, \ j=1,2,\dots,M$.

$$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times d_k}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times d_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times d_v}$

Output: $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M \times d_v}$
$$ \text{Attention}(Q,K,V) = O  = \alpha V, \quad \text{ where } \quad \alpha = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{M\times N}$$
$$ o_{i}  =  \sum_{j=1}^N\alpha_{i,j}v_j\, \quad \text{ where } \quad  \alpha_{i,j}=\frac{e^{\frac{q_i^{T}k_j}{\sqrt{d_k}}}}{\sum_{l=1}^Ne^{\frac{q_i^{T}k_l}{\sqrt{d_k}}}}$$


where $\text{softmax}$ is applied rowwise, so each row of $\alpha$ sums to one: $\sum_{j=1}^N\alpha_{i,j}=1, \ i =1,2,\dots,M$.
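
As a minimal sketch of the definition above, the formula can be written in NumPy as follows. The function names `softmax_rows` and `attention`, the seed, and the sizes are illustrative assumptions, not part of the original notes.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax; subtracting the row maximum only improves numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    """Return (O, alpha) with alpha = softmax(Q K^T / sqrt(d_k)) and O = alpha V."""
    d_k = K.shape[-1]
    alpha = softmax_rows(Q @ K.T / np.sqrt(d_k))  # shape (M, N)
    return alpha @ V, alpha

rng = np.random.default_rng(0)  # arbitrary seed and sizes, chosen for illustration
M, N, d_k, d_v = 3, 5, 4, 2
Q = rng.normal(size=(M, d_k))
K = rng.normal(size=(N, d_k))
V = rng.normal(size=(N, d_v))

O, alpha = attention(Q, K, V)
print(O.shape)                              # (3, 2)
print(np.allclose(alpha.sum(axis=1), 1.0))  # True: each row of alpha sums to one
```

Each output row `O[i]` is the convex combination $\sum_j \alpha_{i,j} v_j$ of the rows of `V`.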

**Note:** The output $o_i$ does not depend on $q_j$ for any $j\neq i$.

**Note:** If we denote by $v(q_i)$ the random variable "value of the query $q_i$", for $i=1,2,\dots,M$, with the induced probability
$$p(v(q_i) = v_j)=\alpha_{i,j}, \quad j=1,2,\dots,N,$$
then
$$\mathbb{E}[v(q_i)]=\sum_{j=1}^{N}\alpha_{i,j}v_{j}=o_i, \quad i=1,2,\dots,M.$$


**Note:** The value $p(v(q_i) = v_j)=\alpha_{i,j}$ is usually interpreted as how much attention the query $q_i$ (through its output $o_i$) pays to the value $v_j$. So the "attention" of $q_i$ is distributed over the values $v_j$.

### Property: $\text{Attention}$ Query-Deletion Equivariance

If $A^{(i)}$ denotes the matrix obtained by deleting the $i$-th row of the matrix $A$, then  
$$\text{Attention}(Q,K,V)^{(i)}=\text{Attention}(Q^{(i)},K,V). $$

**Proof:**
Notice that $A^{(i)}=I^{(i)}A$, where $I$ is the identity matrix and $I^{(i)}$ is the identity matrix with its $i$-th row deleted. Using this, we have
$$
\begin{align*}
\text{Attention}(Q^{(i)},K,V) & = \text{Attention}(I^{(i)}Q,K,V),\\
& = \text{softmax}\left(\frac{(I^{(i)}Q)K^T}{\sqrt{d_k}}\right)V,\\
& = \text{softmax}\left(I^{(i)}\frac{QK^T}{\sqrt{d_k}}\right)V,\\
& = I^{(i)}\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,\\
& = \text{Attention}(Q,K,V)^{(i)} ,
\end{align*}
$$
where in the fourth equality we have used that the softmax function is applied rowwise, so it commutes with deleting a row.
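
The property can also be checked numerically. The following is a small sketch under assumed random inputs; `attention` is a hypothetical helper implementing the definition above, and `np.delete` plays the role of the row deletion $A^{(i)}$ (with a zero-based index).

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a row-wise softmax."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(1)  # arbitrary seed and sizes
Q = rng.normal(size=(4, 3))
K = rng.normal(size=(6, 3))
V = rng.normal(size=(6, 2))

i = 2  # zero-based index of the query row to delete
lhs = np.delete(attention(Q, K, V), i, axis=0)  # Attention(Q, K, V)^{(i)}
rhs = attention(np.delete(Q, i, axis=0), K, V)  # Attention(Q^{(i)}, K, V)
deletion_ok = np.allclose(lhs, rhs)
print(deletion_ok)  # True
```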

### Property: $\text{Attention}$ Key-Value Permutation Invariance

If $\pi_r(A)$ denotes an arbitrary permutation over the rows of a matrix $A$, then
$$\text{Attention}(Q,\pi_r(K),\pi_r(V))=\text{Attention}(Q,K,V). $$

**Proof:**

Notice that $\pi_r(A)$ can be written as $\pi_r(A) = R_{\pi}A$ for some permutation matrix $R_{\pi}$. A permutation matrix is a square binary matrix that has exactly one entry equal to 1 in each row and each column, with all other entries 0; in particular, $R_{\pi}^{T}R_{\pi}=I$. Then

$$
\begin{align*}
\text{Attention}(Q,\pi_r(K),\pi_r(V)) & = \text{Attention}(Q,R_{\pi}K,R_{\pi}V),\\
& = \text{softmax}\left(\frac{Q(R_{\pi}K)^T}{\sqrt{d_k}}\right)R_{\pi}V,\\
& = \text{softmax}\left(\frac{QK^TR_{\pi}^{T}}{\sqrt{d_k}}\right)R_{\pi}V,\\
& = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}R_{\pi}^{T}\right)R_{\pi}V,\\
& = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)R_{\pi}^{T}R_{\pi}V,\\
& = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,\\
& = \text{Attention}(Q,K,V),
\end{align*}
$$
where in the fifth equality we used that postmultiplying by a permutation matrix is a column permutation and that softmax is permutation equivariant, and in the sixth equality that $R_{\pi}^{T}R_{\pi}=I$. Indeed, for any permutation $\pi$ we have 
$$
\begin{align*}
\text{softmax}\left(\pi(x)\right)& = \text{softmax}\left(x_{\pi(1)},x_{\pi(2)},\dots,x_{\pi(n)}\right),\\
&=\frac{\left(e^{x_{\pi(1)}},e^{x_{\pi(2)}},\dots,e^{x_{\pi(n)}}\right)}{\sum_{i=1}^n e^{x_{\pi(i)}}},\\
&=\frac{\left(e^{x_{\pi(1)}},e^{x_{\pi(2)}},\dots,e^{x_{\pi(n)}}\right)}{\sum_{i=1}^n e^{x_{i}}},\\
&=\frac{\pi\left(e^{x_{1}},e^{x_{2}},\dots,e^{x_{n}}\right)}{\sum_{i=1}^n e^{x_{i}}},\\
&=\pi\left(\frac{\left(e^{x_{1}},e^{x_{2}},\dots,e^{x_{n}}\right)}{\sum_{i=1}^n e^{x_{i}}}\right),\\
&=\pi\left(\text{softmax}(x)\right),\\
\end{align*}
$$
and $\text{softmax}$ is applied to each row. So if $A = [a_1|a_2|\dots|a_n]^T = \frac{QK^T}{\sqrt{d_k}}$, $f=\text{softmax}$, and a column permutation $\pi_c(A) = [\pi(a_1)|\pi(a_2)|\dots|\pi(a_n)]^T = AC_{\pi}$, where  $C_{\pi}$ is a permutation matrix, we have
$$
\begin{align*}
f(AC_{\pi}) & = f([\pi(a_1)|\pi(a_2)|\dots|\pi(a_n)]^T),\\
 & = ([f(\pi(a_1))|f(\pi(a_2))|\dots|f(\pi(a_n))]^T),\\
& = ([\pi(f(a_1))|\pi(f(a_2))|\dots|\pi(f(a_n))]^T),\\
& = \pi_c\left([f(a_1)|f(a_2)|\dots|f(a_n)]^T\right),\\
& = \pi_c\left(f(A)\right),\\
& =f(A)C_{\pi}.
\end{align*}
$$
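
A quick numerical illustration of key-value permutation invariance, as a sketch with assumed random inputs. The helper `attention` is hypothetical; indexing with `perm` applies the same row permutation to $K$ and $V$.

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a row-wise softmax."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(2)  # arbitrary seed and sizes
Q = rng.normal(size=(3, 4))
K = rng.normal(size=(6, 4))
V = rng.normal(size=(6, 2))

perm = rng.permutation(K.shape[0])  # one permutation, shared by keys and values
invariant = np.allclose(attention(Q, K[perm], V[perm]), attention(Q, K, V))
print(invariant)  # True
```

The pairing between $k_j$ and $v_j$ is what matters: the check relies on permuting keys and values together.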

### Property: $\operatorname{Attention}$ Permutation Equivariance
If $\pi_r$ and $\sigma_r$ denote arbitrary permutations over the rows of a matrix, then
$$\text{Attention}(\pi_r(Q),\sigma_r(K),\sigma_r(V))=\pi_r\left(\text{Attention}(Q,K,V)\right). $$


**Proof:**
Consider $\pi_r(A) = R_{\pi}A$ for some permutation matrix $R_{\pi}$. From the previous result (key-value permutation invariance) we have
$$
\begin{align*}
\text{Attention}(\pi_r(Q),\sigma_r(K),\sigma_r(V)) & = \text{Attention}(\pi_r(Q),K,V),\\
&= \text{Attention}(R_{\pi}Q,K,V),\\
& = \text{softmax}\left(\frac{(R_{\pi}Q)K^T}{\sqrt{d_k}}\right)V,\\
& = \text{softmax}\left(R_{\pi}\frac{QK^T}{\sqrt{d_k}}\right)V,\\
& = R_{\pi}\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,\\
& = \pi_r\left(\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\right),\\
& = \pi_r\left(\text{Attention}(Q,K,V) \right),
\end{align*}
$$
where in the first equality we used the key-value permutation invariance, and in the fifth equality we have used that the softmax function is applied rowwise.
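
The equivariance statement can be sanity-checked in the same way. This is a sketch under assumed random inputs, with `attention` a hypothetical helper; `pi` permutes the queries and `sigma` permutes keys and values.

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a row-wise softmax."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(3)  # arbitrary seed and sizes
M, N = 4, 6
Q = rng.normal(size=(M, 3))
K = rng.normal(size=(N, 3))
V = rng.normal(size=(N, 2))

pi = rng.permutation(M)     # row permutation of the queries
sigma = rng.permutation(N)  # row permutation shared by keys and values

lhs = attention(Q[pi], K[sigma], V[sigma])  # Attention(pi(Q), sigma(K), sigma(V))
rhs = attention(Q, K, V)[pi]                # pi(Attention(Q, K, V))
equivariant = np.allclose(lhs, rhs)
print(equivariant)  # True
```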

## Definition: Masked Attention ($\text{MAttn}_{B}$)

Masked attention intentionally breaks permutation equivariance so that order matters. In many tasks, like language modeling or temporal prediction, we don’t want tokens to freely attend to all others. Masking restricts attention to valid positions (e.g., past tokens), enforcing causal or directional structure instead of treating the sequence as an unordered set.
$$ \text{MAttn}_{B}(Q,K,V) = \text{softmax}\left(\frac{QK^T+B}{\sqrt{d_k}}\right) V$$

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times d_k}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times d_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times d_v}$

Output: $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M \times d_v}$
$$ \text{MAttn}_{B}(Q,K,V) = O  = \alpha V, \quad \text{ where } \quad \alpha = \text{softmax}\left(\frac{QK^T+B}{\sqrt{d_k}}\right) \in \mathbb{R}^{M\times N}$$
where $B\in\mathbb{R}^{M\times N}$, and $\text{softmax}$ is applied to each row, so each row of $\alpha$ sums to one. 

**Note:** In general, the mask breaks the permutation equivariance properties of attention. Notice also that the definition of $\text{MAttn}_{B}$ now depends on the sequence lengths $M$ and $N$ through the shape of $B$.
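
As an illustration of the definition, here is a sketch of masked attention with a causal mask, assuming $M=N$ and $b_{i,j}=-\infty$ for $j>i$. The helper name `masked_attention` and the specific mask are assumptions made for this example; the text only requires $B\in\mathbb{R}^{M\times N}$.

```python
import numpy as np

def masked_attention(Q, K, V, B):
    """MAttn_B(Q, K, V) = softmax((Q K^T + B) / sqrt(d_k)) V, softmax taken row-wise."""
    scores = (Q @ K.T + B) / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    alpha = weights / weights.sum(axis=-1, keepdims=True)
    return alpha @ V, alpha

rng = np.random.default_rng(4)  # arbitrary seed and sizes
N, d_k, d_v = 4, 3, 2
Q = rng.normal(size=(N, d_k))
K = rng.normal(size=(N, d_k))
V = rng.normal(size=(N, d_v))

# Causal mask: b_{i,j} = -inf when j > i (a "future" position), 0 otherwise.
rows, cols = np.arange(N)[:, None], np.arange(N)[None, :]
B = np.where(cols > rows, -np.inf, 0.0)

O, alpha = masked_attention(Q, K, V, B)
print(np.allclose(np.triu(alpha, k=1), 0.0))  # True: no weight on future positions
print(np.allclose(O[0], V[0]))                # True: the first query only sees the first value
```

Every row of this mask keeps at least one finite entry (the diagonal), so the row-wise softmax stays well defined.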

### Property: $\text{MAttn}_B$ Query-Deletion
If $A^{(i)}$ denotes the matrix obtained by deleting the $i$-th row of the matrix $A$, then 
$$
\begin{align*}
\text{MAttn}_{B}(Q,K,V)^{(i)}=\text{MAttn}_{B^{(i)}}(Q^{(i)},K,V)
\end{align*}
$$

**Proof:**
Notice that $A^{(i)}=I^{(i)}A$, where $I$ is the identity matrix and $I^{(i)}$ is the identity matrix with its $i$-th row deleted. Using this, we have
$$
\begin{align*}
\text{MAttn}_{B^{(i)}}(Q^{(i)},K,V) & = \text{MAttn}_{I^{(i)}B}(I^{(i)}Q,K,V),\\
& = \text{softmax}\left(\frac{(I^{(i)}Q)K^T+I^{(i)}B}{\sqrt{d_k}}\right)V,\\
& = \text{softmax}\left(I^{(i)}\frac{QK^T+B}{\sqrt{d_k}}\right)V,\\
& = I^{(i)}\text{softmax}\left(\frac{QK^T+B}{\sqrt{d_k}}\right)V,\\
& = \text{MAttn}_{B}(Q,K,V)^{(i)} ,
\end{align*}
$$
where in the fourth equality we have used that the softmax function is applied rowwise.

### Property: The entries of $B$ in $\text{MAttn}_{B}$ can selectively suppress attention.

If $b_{i,s} \to -\infty$, then the $i$-th output $o_i$ does not depend on $k_s$ or $v_s$. In particular, we have the formula:
$$ 
\begin{align*}
o_i \to \frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}} v_j.
\end{align*}
$$

**Note:** Remember that $o_i$ does not depend on any $q_j$ with $j \neq i$.

**Proof:** If $b_{i,s}\to -\infty$ for some $(i,s)\in \{1,\dots,M\}\times \{1,\dots,N\}$, then
$$
\begin{align*}
\alpha_{i,j} &= \text{softmax}\left(\frac{QK^T+B}{\sqrt{d_k}}\right)_{i,j} =\frac{e^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}}}{\sum_{l=1}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}},\\
&\to\begin{cases}
\frac{e^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}}}{\sum_{l=1,l\neq s}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}},&\quad j\neq s\\
0 ,&\quad j = s.
\end{cases}
\end{align*}
$$
Then, we get
$$ 
\begin{align*}
o_i &= \sum_{j=1}^N\alpha_{i,j} v_j\to\sum_{j=1,j\neq s}^N\alpha_{i,j} v_j,\\
&=\sum_{j=1,j\neq s}^N\frac{e^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}}}{\sum_{l=1,l\neq s}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}} v_j.
\end{align*}
$$
Therefore, we conclude that, in the limit, $o_i$ does not depend on $k_s$ or $v_s$.
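
A small numerical sketch of this property, taking $b_{i,s}=-\infty$ directly instead of a limit. The helper `masked_attention`, the indices, and the perturbations are assumptions for illustration: after changing $k_s$ and $v_s$, row $i$ of the output should stay the same.

```python
import numpy as np

def masked_attention(Q, K, V, B):
    """softmax((Q K^T + B) / sqrt(d_k)) V with a row-wise softmax."""
    scores = (Q @ K.T + B) / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(5)  # arbitrary seed and sizes
M, N, d_k, d_v = 3, 5, 4, 2
Q = rng.normal(size=(M, d_k))
K = rng.normal(size=(N, d_k))
V = rng.normal(size=(N, d_v))

i, s = 1, 3  # zero-based indices of the masked (query, key) pair
B = np.zeros((M, N))
B[i, s] = -np.inf

O_before = masked_attention(Q, K, V, B)

# Perturb only the s-th key and the s-th value.
K_new, V_new = K.copy(), V.copy()
K_new[s] += 1.0
V_new[s] += 5.0
O_after = masked_attention(Q, K_new, V_new, B)

row_unchanged = np.allclose(O_before[i], O_after[i])
print(row_unchanged)  # True: o_i ignores k_s and v_s
```

Rows other than $i$ are not masked at position $s$, so they are generally affected by the perturbation.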

### Property: $\text{MAttn}_B$ Key-Value deletion

If $A^{(i)}$ and $A^{[i]}$ denote the matrices obtained from $A$ by deleting the $i$-th row and the $i$-th column respectively, and $b_{i,s}\to-\infty$ for all $i=1,2,\dots,M$, then 
$$
\begin{align*}
\text{MAttn}_{B}(Q,K,V) \to \text{MAttn}_{B^{[s]}}(Q,K^{(s)},V^{(s)}), \quad Q \in \mathbb{R}^{M\times d_k}
\end{align*}
$$




**Proof:** From the previous property, for $i=1,2,\dots,M$ we have
$$ 
\begin{align*}
o_i & \to\sum_{j=1,j\neq s}^N \frac{e^{\frac{q_i^{T}k_j+b_{i,j}}{\sqrt{d_k}}}}{\sum_{l=1,l\neq s}^Ne^{\frac{q_i^{T}k_l+b_{i,l}}{\sqrt{d_k}}}} v_j,\\
&=\sum_{j=1,j\neq s}^N \text{softmax}\left(\frac{\left(QK^T+B\right)^{[s]}}{\sqrt{d_k}}\right)_{i,j}v_j,
\end{align*}
$$
where the columns of the reduced matrix keep their original labels $j\neq s$. Applying this result to every row, we have
$$ 
\begin{align*}
O & \to \text{softmax}\left(\frac{\left(QK^T+B\right)^{[s]}}{\sqrt{d_k}}\right)V^{(s)},\\
&=\text{softmax}\left(\frac{Q(K^T)^{[s]}+B^{[s]}}{\sqrt{d_k}}\right)V^{(s)},\\
&=\text{softmax}\left(\frac{Q(K^{(s)})^T+B^{[s]}}{\sqrt{d_k}}\right)V^{(s)},\\
&=\text{MAttn}_{B^{[s]}}(Q,K^{(s)},V^{(s)}).
\end{align*}
$$
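
The key-value deletion property can be illustrated numerically by setting a whole column of $B$ to $-\infty$. This is a sketch with assumed random inputs; `masked_attention` is a hypothetical helper and `np.delete` implements $K^{(s)}$, $V^{(s)}$ and $B^{[s]}$ with a zero-based index.

```python
import numpy as np

def masked_attention(Q, K, V, B):
    """softmax((Q K^T + B) / sqrt(d_k)) V with a row-wise softmax."""
    scores = (Q @ K.T + B) / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(6)  # arbitrary seed and sizes
M, N, d_k, d_v = 3, 5, 4, 2
Q = rng.normal(size=(M, d_k))
K = rng.normal(size=(N, d_k))
V = rng.normal(size=(N, d_v))

s = 2  # zero-based index of the key-value pair to suppress
B = rng.normal(size=(M, N))
B[:, s] = -np.inf  # b_{i,s} = -inf for every query i

lhs = masked_attention(Q, K, V, B)
rhs = masked_attention(Q,
                       np.delete(K, s, axis=0),   # K^{(s)}
                       np.delete(V, s, axis=0),   # V^{(s)}
                       np.delete(B, s, axis=1))   # B^{[s]}
kv_deletion_ok = np.allclose(lhs, rhs)
print(kv_deletion_ok)  # True
```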


## Definition: Attention with weights ($\text{Attn}_{\mathcal{W}}$)

$$\text{Attn}_{\mathcal{W}}(Q,K,V)=\text{Attention}(QW_q+1_Mb^T_q,KW_k+1_Nb^T_k,VW_v+1_Nb^T_v)$$


Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}, \quad M\in\mathbb{N}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}, \quad N\in\mathbb{N}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}, \quad N\in\mathbb{N}$.

Weights
* a set of querry weights $W_q \in \mathbb{R}^{c_q\times d_k}$ and bias $b_q\in \mathbb{R}^{d_k}$
* a set of key weights $W_k \in \mathbb{R}^{c_k \times d_k}$ and bias $b_k\in \mathbb{R}^{d_k}$
* a set of value weights $W_v \in \mathbb{R}^{c_v \times d_v}$ and bias $b_v\in \mathbb{R}^{d_v}$

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_v}$
\begin{align*}
O =\text{Attn}_{\mathcal{W}}(Q,K,V)=\text{Attention}(QW_q+1_Mb^T_q,KW_k+1_Nb^T_k,VW_v+1_Nb^T_v),
\end{align*}
where

\begin{align*}
\mathcal{W} &= \left(W_q,b_q,W_k,b_k,W_v,b_v\right),\\
1_M &= (1,1,\dots,1)^T\in \mathbb{R}^M, \quad 1_N = (1,1,\dots,1)^T\in \mathbb{R}^N.
\end{align*}

**Note:** The same $\text{Attn}_{\mathcal{W}}$ function can be applied to inputs of different sequence lengths, because the model parameters $\mathcal{W}$ are not tied to specific positions. In this sense, attention with weights is **position-agnostic**.

**Note:** From the $\text{Attention}$ definition we have 

\begin{align*}
o_i = \frac{1}{\sum_{l=1}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_l}{\sqrt{d_k}}}}\sum_{j=1}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_j}{\sqrt{d_k}}} \hat{v}_j,
\end{align*}

where
$$ 
\begin{align*}
\hat{q}_i&=W_q^Tq_i+b_q,\\
\hat{k}_j&=W_k^Tk_j+b_k,\\
\hat{v}_j&=W_v^Tv_j+b_v.
\end{align*}
$$ 
So the output $o_i$ does not depend on $q_j$ for any $j\neq i$.
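
A minimal sketch of attention with weights, assuming affine projections implemented with NumPy broadcasting. The names `attention`, `attn_w`, and `params` are illustrative. The loop reuses one parameter tuple for two different pairs of sequence lengths, which is the position-agnostic behaviour described above.

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a row-wise softmax."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V

def attn_w(Q, K, V, params):
    """Attn_W(Q, K, V) with params = (W_q, b_q, W_k, b_k, W_v, b_v)."""
    W_q, b_q, W_k, b_k, W_v, b_v = params
    # Broadcasting a bias over rows plays the role of 1_M b^T (queries) or 1_N b^T (keys, values).
    return attention(Q @ W_q + b_q, K @ W_k + b_k, V @ W_v + b_v)

rng = np.random.default_rng(7)  # arbitrary seed and sizes
c_q, c_k, c_v, d_k, d_v = 5, 6, 7, 4, 3
params = (rng.normal(size=(c_q, d_k)), rng.normal(size=d_k),
          rng.normal(size=(c_k, d_k)), rng.normal(size=d_k),
          rng.normal(size=(c_v, d_v)), rng.normal(size=d_v))

# The same parameters are reused for two different pairs of sequence lengths (M, N).
shapes = []
for M, N in [(2, 3), (6, 10)]:
    Q = rng.normal(size=(M, c_q))
    K = rng.normal(size=(N, c_k))
    V = rng.normal(size=(N, c_v))
    shapes.append(attn_w(Q, K, V, params).shape)
print(shapes)  # [(2, 3), (6, 3)]
```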

### Property: $\text{Attn}_{\mathcal{W}}$ Query-Deletion Equivariance

If $A^{(i)}$ denotes the matrix obtained by deleting the $i$-th row of the matrix $A$, then  
$$\text{Attn}_{\mathcal{W}}(Q,K,V)^{(i)}=\text{Attn}_{\mathcal{W}}(Q^{(i)},K,V). $$

**Proof:**
$$
\begin{align*}
\operatorname{Attn}_{\mathcal{W}}(Q^{(i)},K,V)
&=\text{Attention}(Q^{(i)}W_q+1_{M-1}b_q^T,KW_k+1_Nb_k^T,VW_v+1_Nb_v^T),\\
&=\text{Attention}((QW_q)^{(i)}+(1_{M}b_q^T)^{(i)},KW_k+1_Nb_k^T,VW_v+1_Nb_v^T),\\
&=\text{Attention}((QW_q+1_{M}b_q^T)^{(i)},KW_k+1_Nb_k^T,VW_v+1_Nb_v^T),\\
&=\text{Attention}(QW_q+1_{M}b_q^T,KW_k+1_Nb_k^T,VW_v+1_Nb_v^T)^{(i)},\\
&=\operatorname{Attn}_{\mathcal{W}}(Q,K,V)^{(i)}.
\end{align*}
$$

### Property: $\text{Attn}_{\mathcal{W}}$ Permutation Equivariance

If Attention with weights has no bias (projections are linear instead of affine), the Attention is permutation equivariant. This means, if $\pi_r$ and $\sigma_r$ denote arbitrary permutations over the rows of a matrix, then
$$\operatorname{Attn}_{\mathcal{W}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))=\pi_r\left(\operatorname{Attn}_{\mathcal{W}}(Q,K,V)\right) $$
where 
$$
\begin{align*}
\mathcal{W} &= \left(W_q,0_{\mathbb{R}^{d_k}},W_k,0_{\mathbb{R}^{d_k}},W_v,0_{\mathbb{R}^{d_v}}\right)
\end{align*}
$$

**Proof:** Consider $\pi_r(A) = R_{\pi}A$ and $\sigma_r(A) = R_{\sigma}A$ for some permutation matrices $R_{\pi}$ and $R_{\sigma}$ respectively. From the $\text{Attention}$ permutation equivariance we have
$$
\begin{align*}
\operatorname{Attn}_{\mathcal{W}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))
&=\text{Attention}(\pi_r(Q)W_q,\sigma_r(K)W_k,\sigma_r(V)W_v),\\
&=\text{Attention}((R_{\pi}Q)W_q,(R_{\sigma}K)W_k,(R_{\sigma}V)W_v),\\
&=\text{Attention}(R_{\pi}(QW_q),R_{\sigma}(KW_k),R_{\sigma}(VW_v)),\\
&=\text{Attention}(\pi_r(QW_q),\sigma_r(KW_k),\sigma_r(VW_v)),\\
&=\pi_r\left(\text{Attention}(QW_q,KW_k,VW_v)\right),\\
&=\pi_r\left(\operatorname{Attn}_{\mathcal{W}}(Q,K,V)\right).
\end{align*}
$$

## Definition: Masked Attention with weights ($\text{MAttn}_{\mathcal{W}}$)
$$\text{MAttn}_{\mathcal{W}}(Q,K,V)=\text{MAttn}_B(QW_q+1_Mb_q^T,KW_k+1_Nb_k^T,VW_v+1_Nb_v^T)$$

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$.

Weights
* a set of querry weights $W_q \in \mathbb{R}^{c_q\times d_k}$ and bias $b_q\in \mathbb{R}^{d_k}$
* a set of key weights $W_k \in \mathbb{R}^{c_k \times d_k}$ and bias $b_k\in \mathbb{R}^{d_k}$
* a set of value weights $W_v \in \mathbb{R}^{c_v \times d_v}$ and bias $b_v\in \mathbb{R}^{d_v}$

* a mask matrix $B\in \mathbb{R}^{M\times N}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_v}$
\begin{align*}
O =\text{MAttn}_{\mathcal{W}}(Q,K,V)=\text{MAttn}_B(QW_q+1_Mb_q^T,KW_k+1_Nb_k^T,VW_v+1_Nb_v^T)
\end{align*}
where
\begin{align*}
1_M &= (1,1,\dots,1)^T \in \mathbb{R}^{M}, \quad 1_N = (1,1,\dots,1)^T \in \mathbb{R}^{N},\\ 
\mathcal{W} &= \left(W_q,b_q,W_k,b_k,W_v,b_v,B\right).
\end{align*}

**Note:** The mask breaks the permutation-equivariance property of attention. In general, masked attention is not even position-agnostic, since the entries of $B$ depend on the input sequence lengths $M$ and $N$. Although this limitation can be mitigated by imposing specific structures on $B$, the masked version remains inherently order-dependent.

### Property: $\text{MAttn}_{\mathcal{W}}$ Query-Deletion
If $A^{(i)}$ denotes the matrix obtained by deleting the $i$-th row of the matrix $A$, then 
$$
\begin{align*}
\text{MAttn}_{\mathcal{W}}(Q,K,V)^{(i)}=\text{MAttn}_{{\mathcal{W}}^{(i)}}(Q^{(i)},K,V)
\end{align*}
$$
where
\begin{align*}
\mathcal{W}^{(i)} &= \left(W_q,b_q,W_k,b_k,W_v,b_v,B^{(i)}\right).
\end{align*}

**Proof:**
$$
\begin{align*}
\operatorname{MAttn}_{\mathcal{W}^{(i)}}(Q^{(i)},K,V)
&=\text{MAttn}_{B^{(i)}}(Q^{(i)}W_q+1_{M-1}b_q^T,KW_k+1_Nb_k^T,VW_v+1_Nb_v^T),\\
&=\text{MAttn}_{B^{(i)}}((QW_q)^{(i)}+(1_{M}b_q^T)^{(i)},KW_k+1_Nb_k^T,VW_v+1_Nb_v^T),\\
&=\text{MAttn}_{B^{(i)}}((QW_q+1_{M}b_q^T)^{(i)},KW_k+1_Nb_k^T,VW_v+1_Nb_v^T),\\
&=\text{MAttn}_{B}(QW_q+1_{M}b_q^T,KW_k+1_Nb_k^T,VW_v+1_Nb_v^T)^{(i)},\\
&=\operatorname{MAttn}_{\mathcal{W}}(Q,K,V)^{(i)}.
\end{align*}
$$

### Property: The entries of $B$ in $\text{MAttn}_{\mathcal{W}}$ can selectively suppress attention in the Attention with weights.

If $b_{i,s} \to -\infty$, then the $i$-th output $o_i$ does not depend on $k_s$ or $v_s$. In particular, we have
\begin{align*}
o_i \to \frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_l+b_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_j+b_{i,j}}{\sqrt{d_k}}} \hat{v}_j.
\end{align*}
where

$$ 
\begin{align*}
\hat{q}_i&=W_q^Tq_i+b_q,\\
\hat{k}_i&=W_k^Tk_i+b_k,\\
\hat{v}_i&=W_v^Tv_i+b_v,\\
\end{align*}
$$

**Proof:** If $b_{i,s}\to -\infty$ for some $(i,s)\in \{1,\dots,M\}\times \{1,\dots,N\}$, we can apply the corresponding result for $\text{MAttn}_{B}$ to obtain 
$$ 
\begin{align*}
o_i \to \frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_l+b_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_j+b_{i,j}}{\sqrt{d_k}}} \hat{v}_j.
\end{align*}
$$
where the vectors $\hat{q}_i$, $\hat{k}_i$ and $\hat{v}_i$ are the $i$-th rows of $QW_q+1_Mb^T_q$, $KW_k+1_Nb^T_k$ and $VW_v+1_Nb^T_v$ respectively. So we have

$$ 
\begin{align*}
\hat{q}^T_i&=q_i^TW_q+b_q^T,\\
\hat{k}^T_i&=k_i^TW_k+b_k^T,\\
\hat{v}^T_i&=v_i^TW_v+b_v^T.
\end{align*}
$$

Taking the matrix transpose in the previous equations, we get the result.

### Property: $\text{MAttn}_{\mathcal{W}}$ Key-Value deletion

If $A^{(i)}$ and $A^{[i]}$ denote the matrices obtained from $A$ by deleting the $i$-th row and the $i$-th column respectively, and $b_{i,s}\to-\infty$ for all $i=1,2,\dots,M$, then 
$$
\begin{align*}
\text{MAttn}_{\mathcal{W}}(Q,K,V)\to \text{MAttn}_{\mathcal{W}^{[s]}}(Q,K^{(s)},V^{(s)}), \quad Q \in \mathbb{R}^{M\times c_q},
\end{align*}
$$
where
\begin{align*}
\mathcal{W}^{[s]} &= \left(W_q,b_q,W_k,b_k,W_v,b_v,B^{[s]}\right).
\end{align*}

**Proof:** From the $\text{MAttn}_{B}$ Key-Value deletion property we get
$$
\begin{align*}
\operatorname{MAttn}_{\mathcal{W}}(Q,K,V)
&=\text{MAttn}_{B}(QW_q+1_{M}b_q^T,KW_k+1_Nb_k^T,VW_v+1_Nb_v^T),\\
&\to\text{MAttn}_{B^{[s]}}(QW_q+1_{M}b_q^T,(KW_k+1_Nb_k^T)^{(s)},(VW_v+1_Nb_v^T)^{(s)}),\\
&=\text{MAttn}_{B^{[s]}}(QW_q+1_{M}b_q^T,(KW_k)^{(s)}+(1_Nb_k^T)^{(s)},(VW_v)^{(s)}+(1_Nb_v^T)^{(s)}),\\
&=\text{MAttn}_{B^{[s]}}(QW_q+1_{M}b_q^T,K^{(s)}W_k+1_{N-1}b_k^T,V^{(s)}W_v+1_{N-1}b_v^T),\\
&=\operatorname{MAttn}_{\mathcal{W}^{[s]}}(Q,K^{(s)},V^{(s)}).
\end{align*}
$$

## Definition: Self-Attention ($\text{SAttn}_{\mathcal{W}}$)
$$\text{SAttn}_{\mathcal{W}}(X)=\operatorname{Attn}_{\mathcal{W}}(X,X,X)$$

Inputs
* $X \in \mathbb{R}^{N \times c}, \quad N\in \mathbb{N}$ 

Weights
* a set of querry weights $W_q \in \mathbb{R}^{c\times d_k}$ and bias $b_q\in \mathbb{R}^{d_k}$
* a set of key weights $W_k \in \mathbb{R}^{c \times d_k}$ and bias $b_k\in \mathbb{R}^{d_k}$
* a set of value weights $W_v \in \mathbb{R}^{c \times d_v}$ and bias $b_v\in \mathbb{R}^{d_v}$

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_v}$
\begin{align*}
O =\text{SAttn}_{\mathcal{W}}(X)=\operatorname{Attn}_{\mathcal{W}}(X,X,X),
\end{align*}
where
\begin{align*}
\mathcal{W} &= \left(W_q,b_q,W_k,b_k,W_v,b_v\right).
\end{align*}

### Property: $\text{SAttn}_{\mathcal{W}}$ Permutation Equivariance

If Self-Attention has no bias (projections are linear instead of affine), the Self-Attention is permutation equivariant. If $\pi_r$ denotes an arbitrary permutation over the rows of a matrix, then
$$\text{SAttn}_{\mathcal{W}}(\pi_r(X))=\pi_r\left(\text{SAttn}_{\mathcal{W}}(X)\right), $$
where 
$$
\begin{align*}
\mathcal{W} &= \left(W_q,0_{\mathbb{R}^{d_k}},W_k,0_{\mathbb{R}^{d_k}},W_v,0_{\mathbb{R}^{d_v}}\right).
\end{align*}
$$

**Proof:** From the $\text{Attn}_{\mathcal{W}}$ permutation equivariance with $\sigma_r=\pi_r$ we have
$$
\begin{align*}
\text{SAttn}_{\mathcal{W}}(\pi_r(X))
&=\operatorname{Attn}_{\mathcal{W}}(\pi_r(X),\pi_r(X),\pi_r(X)),\\
&=\pi_r\left(\operatorname{Attn}_{\mathcal{W}}(X,X,X)\right),\\
&=\pi_r\left(\text{SAttn}_{\mathcal{W}}(X)\right).
\end{align*}
$$

## Definition: Masked Self-Attention ($\text{MSAttn}_{\mathcal{W}}$)
$$\text{MSAttn}_{\mathcal{W}}(X)=\text{MAttn}_{\mathcal{W}}(X,X,X)$$

Inputs
* $X = [x_1|x_2|\dots|x_N]^T \in \mathbb{R}^{N \times c}$ 

Weights
* a set of querry weights $W_q \in \mathbb{R}^{c\times d_k}$ and bias $b_q\in \mathbb{R}^{d_k}$
* a set of key weights $W_k \in \mathbb{R}^{c \times d_k}$ and bias $b_k\in \mathbb{R}^{d_k}$
* a set of value weights $W_v \in \mathbb{R}^{c \times d_v}$ and bias $b_v\in \mathbb{R}^{d_v}$
* a mask matrix $B\in \mathbb{R}^{N\times N}.$ 

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_v}$
\begin{align*}
O =\text{MSAttn}_{\mathcal{W}}(X)=\text{MAttn}_{\mathcal{W}}(X,X,X),
\end{align*}
where
\begin{align*}
\mathcal{W} &= \left(W_q,b_q,W_k,b_k,W_v,b_v,B\right).
\end{align*}

**Note:** In general, the mask breaks the permutation equivariance properties of Self-Attention. Masked Self-Attention is not even position-agnostic anymore, since the entries of $B$ depend on the input sequence length $N$.

### Property: The entries of $B$ in $\text{MSAttn}_{\mathcal{W}}$ can selectively suppress attention.

If $b_{i,s} \to -\infty$ for $s \neq i$, then the $i$-th output $o_i$ does not depend on $x_s$. In particular, we have
\begin{align*}
o_i \to\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_l+b_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{\hat{q}_i^{T}\hat{k}_j+b_{i,j}}{\sqrt{d_k}}} \hat{v}_j.
\end{align*}
where

$$ 
\begin{align*}
\hat{q}_l&=W_q^Tx_l+b_q,\\
\hat{k}_l&=W_k^Tx_l+b_k,\\
\hat{v}_l&=W_v^Tx_l+b_v,\\
\end{align*}
$$

**Proof:** It follows directly from the same result for masked attention with weights, taking $Q=K=V=X$.

### Property: $\text{MSAttn}_{\mathcal{W}}$ Input-deletion
If $A^{<i>}$ denotes the matrix obtained by deleting the $i$-th row and the $i$-th column of $A$, and $b_{i,s}\to-\infty$ for $i=1,2,\dots,s-1,s+1,\dots,N$, then 
$$
\begin{align*}
\text{MSAttn}_{\mathcal{W}}(X)^{(s)} \to \text{MSAttn}_{\mathcal{W}^{<s>}}(X^{(s)}),
\end{align*}
$$
where
\begin{align*}
\mathcal{W}^{<s>} &= \left(W_q,b_q,W_k,b_k,W_v,b_v,B^{<s>}\right).
\end{align*}

**Proof:** Combining the $\text{MAttn}_{\mathcal{W}}$ Query-Deletion and Key-Value deletion properties, and noting that $(B^{(s)})^{[s]}=B^{<s>}$, we obtain
$$
\begin{align*}
\text{MSAttn}_{\mathcal{W}}(X)^{(s)}&=\text{MAttn}_{\mathcal{W}}(X,X,X)^{(s)},\\
&=\text{MAttn}_{\mathcal{W}^{(s)}}(X^{(s)},X,X),\\
&\to\text{MAttn}_{\mathcal{W}^{<s>}}(X^{(s)},X^{(s)},X^{(s)}),\\
&=\text{MSAttn}_{\mathcal{W}^{<s>}}(X^{(s)})
\end{align*}
$$
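
The input-deletion property can be checked numerically by taking $b_{i,s}=-\infty$ for all $i\neq s$. This is a sketch with assumed random inputs; `masked_self_attention` and `params` are hypothetical names, and the index `s` is zero-based.

```python
import numpy as np

def masked_self_attention(X, params, B):
    """MSAttn_W(X) = MAttn_B(X W_q + b_q, X W_k + b_k, X W_v + b_v)."""
    W_q, b_q, W_k, b_k, W_v, b_v = params
    Qh, Kh, Vh = X @ W_q + b_q, X @ W_k + b_k, X @ W_v + b_v
    scores = (Qh @ Kh.T + B) / np.sqrt(Kh.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ Vh

rng = np.random.default_rng(8)  # arbitrary seed and sizes
N, c, d_k, d_v, s = 5, 4, 3, 2, 1
X = rng.normal(size=(N, c))
params = (rng.normal(size=(c, d_k)), rng.normal(size=d_k),
          rng.normal(size=(c, d_k)), rng.normal(size=d_k),
          rng.normal(size=(c, d_v)), rng.normal(size=d_v))

B = rng.normal(size=(N, N))
B[np.arange(N) != s, s] = -np.inf  # b_{i,s} = -inf for every i != s

lhs = np.delete(masked_self_attention(X, params, B), s, axis=0)  # MSAttn_W(X)^{(s)}
B_sub = np.delete(np.delete(B, s, axis=0), s, axis=1)            # B^{<s>}
rhs = masked_self_attention(np.delete(X, s, axis=0), params, B_sub)
input_deletion_ok = np.allclose(lhs, rhs)
print(input_deletion_ok)  # True
```

In words: once no other position attends to $x_s$, removing $x_s$ from the input leaves the remaining outputs unchanged.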

## Definition: Multi-Head Attention ($\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$)
$$
\begin{align*}
\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}&(Q,K,V)\\
&=\begin{pmatrix}
\operatorname{Attn}_{\mathcal{W}_{1}}(Q,K,V)|
\operatorname{Attn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\operatorname{Attn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o + 1_Mb_o^T
\end{align*}
$$

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}, \quad M\in\mathbb{N}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}, \quad N\in\mathbb{N}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}, \quad N\in\mathbb{N}$

Weights
* $W_{q,i} \in \mathbb{R}^{c_q \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{hd_v \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_o}$
\begin{align*}
O  &= \text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\operatorname{Attn}_{\mathcal{W}_{1}}(Q,K,V)|
\operatorname{Attn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\operatorname{Attn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o + 1_{M}b_o^T
\end{align*}
where


\begin{align*}
\mathcal{W}_0 &= \left(W_{o},b_{o}\right),\\
\mathcal{W}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\right), \quad i = 1,2,\dots,h.
\end{align*}

Usually we have:
 * $d_h:=d_k=d_v$ (head dimensions)
 * $d := c_q = c_k = c_v = d_h\cdot h = d_o$ (model dimension)

**Note:** In this case, internally we have:
 * $W_q=[W_{q,1}|W_{q,2}|\dots|W_{q,h}]\in\mathbb{R}^{c_q\times hd_k}=\mathbb{R}^{d\times d}$
 * $W_k=[W_{k,1}|W_{k,2}|\dots|W_{k,h}]\in\mathbb{R}^{c_k\times hd_k}=\mathbb{R}^{d\times d}$
 * $W_v=[W_{v,1}|W_{v,2}|\dots|W_{v,h}]\in\mathbb{R}^{c_v\times hd_v}=\mathbb{R}^{d\times d}$

**Note:** The same $\operatorname{MHAttn}$ function can be applied to inputs of different sequence lengths. The model parameters are not tied to specific positions. In this sense, $\operatorname{MHAttn}$ is **position-agnostic**.
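
A minimal sketch of the multi-head definition in the usual case $d_k=d_v=d_h$ and $c_q=c_k=c_v=d_o=d$. The helper names `attention`, `attn_w`, and `multi_head`, and the representation of each $\mathcal{W}_i$ as a tuple, are assumptions made for this example.

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a row-wise softmax."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ V

def attn_w(Q, K, V, params):
    """One head: params = (W_q, b_q, W_k, b_k, W_v, b_v)."""
    W_q, b_q, W_k, b_k, W_v, b_v = params
    return attention(Q @ W_q + b_q, K @ W_k + b_k, V @ W_v + b_v)

def multi_head(Q, K, V, heads, W_o, b_o):
    """Concatenate the h head outputs column-wise, then apply the output projection."""
    concat = np.concatenate([attn_w(Q, K, V, p) for p in heads], axis=1)  # (M, h * d_v)
    return concat @ W_o + b_o  # broadcasting b_o stands in for 1_M b_o^T

rng = np.random.default_rng(9)  # arbitrary seed and sizes
M, N, d, h = 3, 5, 8, 2
d_h = d // h
heads = [(rng.normal(size=(d, d_h)), rng.normal(size=d_h),
          rng.normal(size=(d, d_h)), rng.normal(size=d_h),
          rng.normal(size=(d, d_h)), rng.normal(size=d_h)) for _ in range(h)]
W_o, b_o = rng.normal(size=(h * d_h, d)), rng.normal(size=d)

Q = rng.normal(size=(M, d))
K = rng.normal(size=(N, d))
V = rng.normal(size=(N, d))
O = multi_head(Q, K, V, heads, W_o, b_o)
print(O.shape)  # (3, 8)
```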

### Property: $\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$ Query-Deletion
If $A^{(s)}$ denotes the matrix obtained by deleting the $s$-th row of the matrix $A$, then 
$$
\begin{align*}
\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V)^{(s)}=\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q^{(s)},K,V)
\end{align*}
$$
where $s=1,2,\dots,M$.

**Proof:**
From $\text{Attn}_{\mathcal{W}}$ Query-Deletion we get
$$
\begin{align*}
\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}&(Q,K,V)^{(s)}\\
&=\left(\begin{pmatrix}
\text{Attn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{Attn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{Attn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o+1_Mb_o^T\right)^{(s)},\\
&=\left(\begin{pmatrix}
\text{Attn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{Attn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{Attn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o\right)^{(s)}+\left(1_Mb_o^T\right)^{(s)},\\
&=\begin{pmatrix}
\text{Attn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{Attn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{Attn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}^{(s)}W_o+1_{M-1}b_o^T,\\
&=\begin{pmatrix}
\text{Attn}_{\mathcal{W}_{1}}(Q,K,V)^{(s)}|
\text{Attn}_{\mathcal{W}_{2}}(Q,K,V)^{(s)}|
\dots|
\text{Attn}_{\mathcal{W}_{h}}(Q,K,V)^{(s)}
\end{pmatrix}W_o+1_{M-1}b_o^T,\\
&=\begin{pmatrix}
\text{Attn}_{\mathcal{W}_{1}}(Q^{(s)},K,V)|
\text{Attn}_{\mathcal{W}_{2}}(Q^{(s)},K,V)|
\dots|
\text{Attn}_{\mathcal{W}_{h}}(Q^{(s)},K,V)
\end{pmatrix}W_o+1_{M-1}b_o^T,\\
&=\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q^{(s)},K,V).
\end{align*}
$$

### Property: Multi-Head Attention Permutation Equivariance

If Multi-Head Attention has no bias (projections are linear  instead of affine), the Multi-Head Attention is permutation equivariant. If $\pi_r$ and $\sigma_r$ denote arbitrary permutations over the rows of a matrix, then
$$\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(Q),\sigma_r(K),\sigma_r(V))=\pi_r\left(\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V)\right) $$

**Proof:** Consider $\pi_r(A) = R_{\pi}A$ and $\sigma_r(A) = R_{\sigma}A$ for permutation matrices $R_{\pi}$ and $R_{\sigma}$ respectively. From the $\text{Attn}_{\mathcal{W}}$ permutation equivariance (applied to each head, with all biases equal to zero) we have
$$
\begin{align*}
&\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(Q),\sigma_r(K),\sigma_r(V))\\
&=\begin{pmatrix}
\text{Attn}_{\mathcal{W}_{1}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))|
\text{Attn}_{\mathcal{W}_{2}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))|
\dots|
\text{Attn}_{\mathcal{W}_{h}}(\pi_r(Q),\sigma_r(K),\sigma_r(V))
\end{pmatrix}W_o,\\
&=\begin{pmatrix}
\pi_r\left(\text{Attn}_{\mathcal{W}_{1}}(Q,K,V)\right)|
\pi_r\left(\text{Attn}_{\mathcal{W}_{2}}(Q,K,V)\right)|
\dots|
\pi_r\left(\text{Attn}_{\mathcal{W}_{h}}(Q,K,V)\right)
\end{pmatrix}W_o,\\
&=\begin{pmatrix}
R_{\pi}\text{Attn}_{\mathcal{W}_{1}}(Q,K,V)|
R_{\pi}\text{Attn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
R_{\pi}\text{Attn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o,\\
&=R_{\pi}\begin{pmatrix}
\text{Attn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{Attn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{Attn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o,\\
&=R_{\pi}\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\pi_r\left(\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V)\right).
\end{align*}
$$

## Definition: Multi-Head Attention (torch implementation)

Torch implementation corresponds to the particular (usual) case
 * $d_h:=d_k=d_v$ (head dimensions)
 * $d := c_q = d_h\cdot h = d_o$ (model dimension)  


Denote $d_h = d/h$ ($h$ must divide $d$). Then we have the simplified definition

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times d}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$

Weights
* $W_{q,i} \in \mathbb{R}^{d \times d_h}$ and $b_{q,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_h}$ and $b_{k,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_h}$ and $b_{v,i}\in \mathbb{R}^{d_h}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d \times d }$ and $b_{o}\in \mathbb{R}^{d}$.


Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d}$
\begin{align*}
O  &= \text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\operatorname{Attn}_{\mathcal{W}_1}(Q,K,V)|
\operatorname{Attn}_{\mathcal{W}_2}(Q,K,V)|
\dots|
\operatorname{Attn}_{\mathcal{W}_h}(Q,K,V)
\end{pmatrix}W_o + B_o
\end{align*}
where:
 * $W_q=[W_{q,1}|W_{q,2}|\dots|W_{q,h}]\in\mathbb{R}^{d\times d}$
 * $W_k=[W_{k,1}|W_{k,2}|\dots|W_{k,h}]\in\mathbb{R}^{c_k\times d}$
 * $W_v=[W_{v,1}|W_{v,2}|\dots|W_{v,h}]\in\mathbb{R}^{c_v\times d}$
 * $b_q=[b^T_{q,1}|b^T_{q,2}|\dots|b^T_{q,h}]^T\in\mathbb{R}^{d}$
 * $b_k=[b^T_{k,1}|b^T_{k,2}|\dots|b^T_{k,h}]^T\in\mathbb{R}^{d}$
 * $b_v=[b^T_{v,1}|b^T_{v,2}|\dots|b^T_{v,h}]^T\in\mathbb{R}^{d}$
 * $B_o= 1_Mb_o^T = [b_o|b_o|\dots|b_o]^T \in \mathbb{R}^{M \times d}$
 * $\mathcal{W}_0 = \left(W_{o},b_{o}\right)$
 * $\mathcal{W}_i =\left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\right), \quad i = 1,2,\dots,h$.
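
The stacked matrices above mean that one projection with $W_q$ computes all per-head projections at once. The sketch below illustrates this block structure in the notation of these notes only; it makes no claim about how torch stores or lays out its weights internally, and all variable names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(10)  # arbitrary seed and sizes
M, d, h = 3, 8, 2
d_h = d // h  # h must divide d

Q = rng.normal(size=(M, d))
W_q_heads = [rng.normal(size=(d, d_h)) for _ in range(h)]  # W_{q,1}, ..., W_{q,h}
b_q_heads = [rng.normal(size=d_h) for _ in range(h)]       # b_{q,1}, ..., b_{q,h}

W_q = np.concatenate(W_q_heads, axis=1)  # [W_{q,1} | ... | W_{q,h}], shape (d, d)
b_q = np.concatenate(b_q_heads)          # stacked biases, shape (d,)

fused = Q @ W_q + b_q                # one projection for all heads, shape (M, d)
blocks = np.split(fused, h, axis=1)  # h column blocks of shape (M, d_h)

same = all(np.allclose(blocks[i], Q @ W_q_heads[i] + b_q_heads[i]) for i in range(h))
print(same)  # True: block i of the fused projection is the projection of head i
```

The same argument applies to $W_k$, $b_k$ and $W_v$, $b_v$.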

## Definition: Masked Multi-Head Attention ($\text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$)
$$
\begin{align*}
\text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}&(Q,K,V)\\
&=\begin{pmatrix}
\text{MAttn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{MAttn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{MAttn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o + 1_Mb_o^T
\end{align*}
$$

Inputs
* a set of queries $Q = [q_1|q_2|\dots|q_M]^T \in \mathbb{R}^{M \times c_q}$
* a set of keys $K = [k_1|k_2|\dots|k_N]^T \in \mathbb{R}^{N \times c_k}$
* a set of values $V = [v_1|v_2|\dots|v_N]^T \in \mathbb{R}^{N \times c_v}$

Weights
* $W_{q,i} \in \mathbb{R}^{c_q \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c_k \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c_v \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $B_{i} \in \mathbb{R}^{M \times N}, \quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_M]^T \in\mathbb{R}^{M\times d_o}$
$$
\begin{align*}
O  &= \text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V),\\
&=\begin{pmatrix}
\text{MAttn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{MAttn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{MAttn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o + 1_Mb_o^T
\end{align*}
$$
where

$$
\begin{align*}
\mathcal{W}_0 &= \left(W_{o},b_{o}\right),\\
\mathcal{W}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},B_i\right), \quad i = 1,2,\dots,h.
\end{align*}
$$

### Property: $\text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$ Query-Deletion
If $A^{(s)}$ denotes the matrix obtained by deleting the $s$-th row of the matrix $A$, then 
$$
\begin{align*}
\text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(Q,K,V)^{(s)}=\text{MMHAttn}_{\{\mathcal{W}^{(s)}_{i}\}_{i=0}^h}(Q^{(s)},K,V)
\end{align*}
$$
where $s=1,2,\dots,M$ and 
\begin{align*}
\mathcal{W}^{(s)}_0 &=\mathcal{W}_0= \left(W_{o},b_{o}\right),\\
\mathcal{W}^{(s)}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},B^{(s)}_i\right), \quad i = 1,2,\dots,h.
\end{align*}

**Proof:**
From the $\text{MAttn}$ Query-Deletion property we get
$$
\begin{align*}
\text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}&(Q,K,V)^{(s)}\\
&=\left(\begin{pmatrix}
\text{MAttn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{MAttn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{MAttn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o+1_Mb_o\right)^{(s)},\\
&=\left(\begin{pmatrix}
\text{MAttn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{MAttn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{MAttn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o\right)^{(s)}+\left(1_Mb_o\right)^{(s)},\\
&=\begin{pmatrix}
\text{MAttn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{MAttn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{MAttn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}^{(s)}W_o+1_{M-1}b_o,\\
&=\begin{pmatrix}
\text{MAttn}_{\mathcal{W}_{1}}(Q,K,V)^{(s)}|
\text{MAttn}_{\mathcal{W}_{2}}(Q,K,V)^{(s)}|
\dots|
\text{MAttn}_{\mathcal{W}_{h}}(Q,K,V)^{(s)}
\end{pmatrix}W_o+1_{M-1}b_o,\\
&=\begin{pmatrix}
\text{MAttn}_{\mathcal{W}^{(s)}_{1}}(Q^{(s)},K,V)|
\text{MAttn}_{\mathcal{W}^{(s)}_{2}}(Q^{(s)},K,V)|
\dots|
\text{MAttn}_{\mathcal{W}^{(s)}_{h}}(Q^{(s)},K,V)
\end{pmatrix}W_o+1_{M-1}b_o,\\
&=\text{MMHAttn}_{\{\mathcal{W}^{(s)}_{i}\}_{i=0}^h}(Q^{(s)},K,V)
\end{align*}
$$
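The identity can be checked numerically. The sketch below is only an illustration under stated assumptions: `mmh_attn` and `delete_row` are hypothetical helpers written for this example, inputs are unbatched, indices are 0-based, and the masks are arbitrary finite matrices.

```python
import torch

def mmh_attn(Q, K, V, heads, Wo, bo):
    """heads: list of per-head tuples (Wq, bq, Wk, bk, Wv, bv, B)."""
    outs = []
    for Wq, bq, Wk, bk, Wv, bv, B in heads:
        scores = ((Q @ Wq + bq) @ (K @ Wk + bk).T + B) / Wq.shape[1] ** 0.5
        outs.append(torch.softmax(scores, dim=-1) @ (V @ Wv + bv))
    return torch.cat(outs, dim=-1) @ Wo + bo

def delete_row(A, s):
    """Drop row s (0-based) of A."""
    return torch.cat([A[:s], A[s + 1:]], dim=0)

torch.manual_seed(0)
f64 = dict(dtype=torch.float64)
M, N, c, dk, dv, do, h, s = 5, 4, 3, 2, 2, 3, 2, 1
Q, K, V = torch.rand(M, c, **f64), torch.rand(N, c, **f64), torch.rand(N, c, **f64)
heads = [
    (torch.randn(c, dk, **f64), torch.randn(dk, **f64),
     torch.randn(c, dk, **f64), torch.randn(dk, **f64),
     torch.randn(c, dv, **f64), torch.randn(dv, **f64),
     torch.randn(M, N, **f64))               # arbitrary finite mask B_i
    for _ in range(h)
]
Wo, bo = torch.randn(h * dv, do, **f64), torch.randn(do, **f64)

# Left-hand side: compute everything, then delete output row s.
lhs = delete_row(mmh_attn(Q, K, V, heads, Wo, bo), s)
# Right-hand side: delete row s of Q and of every mask B_i, then compute.
heads_s = [(*W[:6], delete_row(W[6], s)) for W in heads]
rhs = mmh_attn(delete_row(Q, s), K, V, heads_s, Wo, bo)

print(torch.allclose(lhs, rhs))
```

Only the masks change between $\mathcal{W}_i$ and $\mathcal{W}_i^{(s)}$; the projection weights are reused as they are.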

### Property: The entries of the masks $B_i$ (when they agree) can selectively suppress attention in $\text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$.

Denote the $(i,s)$ entry of the $l$-th mask matrix $B_l \in \mathbb{R}^{M\times N}$ by $b^{(l)}_{i,s}$. If $b^{(l)}_{i,s} \to -\infty$ for every head $l=1,2,\dots,h$, then in the limit the $i$-th output $o_i$ of the masked Multi-Head Attention does not depend on $k_s$ or $v_s$. In particular, we have:
$$
\begin{align*}
o_i &=W_o^T\begin{pmatrix}
o_{i}^{(1)}\\
o_{i}^{(2)}\\
\vdots\\
o_{i}^{(h)}
\end{pmatrix}
+ b_o^T\\
o_i^{(\kappa)} &=\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_l^{(\kappa)}+b^{(\kappa)}_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_j^{(\kappa)}+b^{(\kappa)}_{i,j}}{\sqrt{d_k}}} \hat{v}_j^{(\kappa)}, \quad \kappa = 1,2,\dots,h.
\end{align*}
$$
where

$$ 
\begin{align*}
\hat{q}^{(\kappa)}_l&=W_{q,\kappa}^Tq_l+b_{q,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, M\\
\hat{k}^{(\kappa)}_l&=W_{k,\kappa}^Tk_l+b_{k,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N\\
\hat{v}^{(\kappa)}_l&=W_{v,\kappa}^Tv_l+b_{v,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N
\end{align*}
$$

**Proof:** It follows directly from the same property for $\text{MAttn}_{\mathcal{W}}$, applied to each head.
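A quick numerical illustration of this property, using `torch.nn.MultiheadAttention` as a stand-in for $\text{MMHAttn}$ (an assumption of this sketch; a 2D `attn_mask` is shared by every head, which corresponds to the case where the masks agree). Indices are 0-based and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
M, N, d, h = 5, 3, 4, 2
mha = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h, batch_first=True)

q = torch.rand(1, M, d)
k = torch.rand(1, N, d)
v = torch.rand(1, N, d)

i, s = 1, 2                      # query row i must ignore key/value row s
mask = torch.zeros(M, N)
mask[i, s] = float("-inf")       # additive float mask, same for all heads

with torch.no_grad():
    out, _ = mha(q, k, v, attn_mask=mask)
    k2, v2 = k.clone(), v.clone()
    k2[0, s] += 10.0             # perturb only k_s and v_s
    v2[0, s] -= 10.0
    out2, _ = mha(q, k2, v2, attn_mask=mask)

print(torch.allclose(out[0, i], out2[0, i], atol=1e-6))  # row i: unaffected
print(torch.allclose(out, out2, atol=1e-6))              # full output: compare
```

Only row $i$ is protected from the perturbation; the other queries still attend to position $s$, so the full outputs are expected to differ.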

### Property: $\text{MMHAttn}_{\{\mathcal{W}_i\}_{i=0}^{h}}$ Key-Value deletion

Let $A^{(i)}$ and $A^{[i]}$ denote the matrices obtained from $A$ by deleting its $i$-th row and its $i$-th column, respectively, and denote the $(i,s)$ entry of the $l$-th mask matrix by $B^{(l)}_{i,s}$.
If $B^{(l)}_{i,s}\to-\infty$ for $i=1,2,\dots,M$ and $l=1,2,\dots,h$ then 
$$
\begin{align*}
\text{MMHAttn}_{\{\mathcal{W}_i\}_{i=0}^{h}}(Q,K,V)\to \text{MMHAttn}_{\{\mathcal{W}_i^{[s]}\}_{i=0}^{h}}(Q,K^{(s)},V^{(s)}), \quad Q \in \mathbb{R}^{M\times c_q}
\end{align*}
$$

where $s=1,2,\dots,N$ and 
\begin{align*}
\mathcal{W}^{[s]}_0 &=\mathcal{W}_0= \left(W_{o},b_{o}\right),\\
\mathcal{W}^{[s]}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},B^{[s]}_i\right), \quad i = 1,2,\dots,h.
\end{align*}

**Proof:** From the $\text{MAttn}_{B}$ Key-Value deletion property we get
$$
\begin{align*}
\text{MMHAttn}_{\{\mathcal{W}_i\}_{i=0}^{h}}(Q,K,V)
&=\begin{pmatrix}
\text{MAttn}_{\mathcal{W}_{1}}(Q,K,V)|
\text{MAttn}_{\mathcal{W}_{2}}(Q,K,V)|
\dots|
\text{MAttn}_{\mathcal{W}_{h}}(Q,K,V)
\end{pmatrix}W_o + 1_Mb_o,\\
&\to \begin{pmatrix}
\text{MAttn}_{\mathcal{W}^{[s]}_{1}}(Q,K^{(s)},V^{(s)})|
\dots|
\text{MAttn}_{\mathcal{W}_{h}^{[s]}}(Q,K^{(s)},V^{(s)})
\end{pmatrix}W_o + 1_Mb_o,\\
&=\text{MMHAttn}_{\{\mathcal{W}_i^{[s]}\}_{i=0}^{h}}(Q,K^{(s)},V^{(s)}).
\end{align*}
$$
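With the masked column set to exactly $-\infty$ the limit becomes an equality, which can be checked numerically. As before, this is a sketch under assumptions: `mmh_attn` and `delete` are hypothetical helpers for unbatched inputs, and indices are 0-based.

```python
import torch

def mmh_attn(Q, K, V, heads, Wo, bo):
    """heads: list of per-head tuples (Wq, bq, Wk, bk, Wv, bv, B)."""
    outs = []
    for Wq, bq, Wk, bk, Wv, bv, B in heads:
        scores = ((Q @ Wq + bq) @ (K @ Wk + bk).T + B) / Wq.shape[1] ** 0.5
        outs.append(torch.softmax(scores, dim=-1) @ (V @ Wv + bv))
    return torch.cat(outs, dim=-1) @ Wo + bo

def delete(A, s, dim):
    """Drop index s (0-based) of A along dimension dim."""
    keep = [j for j in range(A.shape[dim]) if j != s]
    return A.index_select(dim, torch.tensor(keep))

torch.manual_seed(0)
f64 = dict(dtype=torch.float64)
M, N, c, dk, dv, do, h, s = 5, 4, 3, 2, 2, 3, 2, 2
Q, K, V = torch.rand(M, c, **f64), torch.rand(N, c, **f64), torch.rand(N, c, **f64)

heads = []
for _ in range(h):
    B = torch.randn(M, N, **f64)
    B[:, s] = float("-inf")                  # column s is masked in every head
    heads.append((torch.randn(c, dk, **f64), torch.randn(dk, **f64),
                  torch.randn(c, dk, **f64), torch.randn(dk, **f64),
                  torch.randn(c, dv, **f64), torch.randn(dv, **f64), B))
Wo, bo = torch.randn(h * dv, do, **f64), torch.randn(do, **f64)

full = mmh_attn(Q, K, V, heads, Wo, bo)
# Delete row s of K and V, and column s of every mask.
heads_s = [(*W[:6], delete(W[6], s, dim=1)) for W in heads]
reduced = mmh_attn(Q, delete(K, s, dim=0), delete(V, s, dim=0), heads_s, Wo, bo)

print(torch.allclose(full, reduced))
```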

## Definition: Multi-Head Self-Attention ($\text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$)
$$\text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)= \text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X)$$

This is usually denoted as $\text{MHAttn}(X)$ with only one argument.

Inputs
* $X \in \mathbb{R}^{N \times c}, \quad N \in \mathbb{N}.$ 

Weights
* $W_{q,i} \in \mathbb{R}^{c \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_o}$
\begin{align*}
O = \text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)= \text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X).
\end{align*}
where
\begin{align*}
\mathcal{W}_0 &= \left(W_{o},b_{o}\right),\\
\mathcal{W}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\right), \quad i = 1,2,\dots,h.
\end{align*}

**Note:** The same $\operatorname{MHSAttn}$ function can be applied to inputs of different sequence lengths $N$, because the model parameters are not tied to specific positions. In this sense, Multi-Head Self-Attention is **position-agnostic**.
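As a small illustration of this note, one module can be applied to sequences of different lengths without touching its parameters. The sketch uses `torch.nn.MultiheadAttention` as a stand-in for $\operatorname{MHSAttn}$; the sizes are arbitrary choices for the example.

```python
import torch

torch.manual_seed(0)
c, h = 8, 2
mhsa = torch.nn.MultiheadAttention(embed_dim=c, num_heads=h, batch_first=True)
n_params = sum(p.numel() for p in mhsa.parameters())  # fixed once, independent of N

shapes = {}
with torch.no_grad():
    for N in (3, 7, 20):
        X = torch.rand(1, N, c)      # (batch, sequence length, features)
        O, _ = mhsa(X, X, X)         # self-attention: Q = K = V = X
        shapes[N] = tuple(O.shape)

print(n_params)
print(shapes)
```

Each output has the same sequence length as its input, while `n_params` never changes.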

### Property: $\text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$ Permutation Equivariance

If Multi-Head Self-Attention has no bias (linear projections instead of affine projections), then it is permutation equivariant. If $\pi_r(A)$ denotes an arbitrary permutation of the rows of a matrix $A$, then
$$\text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(X))=\pi_r\left(\text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)\right).$$

**Proof:** From the previous result we have
$$
\begin{align*}
\text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(X))
&=\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(\pi_r(X),\pi_r(X),\pi_r(X)),\\
&=\pi_r\left(\text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X)\right),\\
&=\pi_r\left(\text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)\right).
\end{align*}
$$
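The equivariance can be observed numerically. This sketch uses `torch.nn.MultiheadAttention` with `bias=False` as a stand-in for the bias-free $\text{MHSAttn}$ of the statement (an assumption of the example); the sizes and the random permutation are arbitrary.

```python
import torch

torch.manual_seed(0)
N, c, h = 6, 8, 2
mhsa = torch.nn.MultiheadAttention(embed_dim=c, num_heads=h, bias=False, batch_first=True)

X = torch.rand(1, N, c)
perm = torch.randperm(N)             # a random row permutation pi_r

with torch.no_grad():
    out, _ = mhsa(X, X, X)           # MHSAttn(X)
    Xp = X[:, perm]                  # pi_r(X)
    out_p, _ = mhsa(Xp, Xp, Xp)      # MHSAttn(pi_r(X))

# Permuting the input rows permutes the output rows in the same way.
print(torch.allclose(out_p, out[:, perm], atol=1e-5))
```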

## Definition: Masked Multi-Head Self-Attention ($\text{MMHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$)
$$\text{MMHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)= \text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X)$$

This is usually denoted as $\text{MMHAttn}(X)$ with only one argument.

Inputs
* $X = [x_1|x_2|\dots|x_N]^T\in  \mathbb{R}^{N \times c}.$ 

Weights
* $W_{q,i} \in \mathbb{R}^{c \times d_k}$ and $b_{q,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{k,i} \in \mathbb{R}^{c \times d_k}$ and $b_{k,i}\in \mathbb{R}^{d_k}$, $\quad i=1,2,\dots,h.$
* $W_{v,i} \in \mathbb{R}^{c \times d_v}$ and $b_{v,i}\in \mathbb{R}^{d_v}$, $\quad i=1,2,\dots,h.$
* $B_{i} \in \mathbb{R}^{N \times N}, \quad i=1,2,\dots,h.$
* $W_{o} \in \mathbb{R}^{d_vh \times d_o }$ and $b_{o}\in \mathbb{R}^{d_o}$.

Output: 
* Output $O= [o_1|o_2|\dots|o_N]^T \in\mathbb{R}^{N\times d_o}$
\begin{align*}
O = \text{MMHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)= \text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X).
\end{align*}
where
\begin{align*}
\mathcal{W}_0 &= \left(W_{o},b_{o}\right),\\
\mathcal{W}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},B_i\right), \quad i = 1,2,\dots,h.
\end{align*}

### Property: The entries of the masks $B_i$ (when they agree) can selectively suppress attention in the Masked Multi-Head Self-Attention.

Denote the $(i,s)$ entry of the $l$-th mask matrix $B_l \in \mathbb{R}^{N\times N}$ by $b^{(l)}_{i,s}$. If $b^{(l)}_{i,s} \to -\infty$ for fixed indices $s \neq i$ and for every head $l=1,2,\dots,h$, then in the limit the $i$-th output $o_i$ of the masked Multi-Head Self-Attention does not depend on $x_s$. In particular, we have:
$$
\begin{align*}
o_i &=W_o^T\begin{pmatrix}
o_{i}^{(1)}\\
o_{i}^{(2)}\\
\vdots\\
o_{i}^{(h)}
\end{pmatrix}
+ b_o^T\\
o_i^{(\kappa)} &=\frac{1}{\sum_{l=1,l\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_l^{(\kappa)}+b^{(\kappa)}_{i,l}}{\sqrt{d_k}}}}\sum_{j=1,j\neq s}^Ne^{\frac{(\hat{q}_i^{(\kappa)})^{T}\hat{k}_j^{(\kappa)}+b^{(\kappa)}_{i,j}}{\sqrt{d_k}}} \hat{v}_j^{(\kappa)}, \quad \kappa = 1,2,\dots,h.
\end{align*}
$$
where

$$ 
\begin{align*}
\hat{q}^{(\kappa)}_l&=W_{q,\kappa}^Tx_l+b_{q,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N\\
\hat{k}^{(\kappa)}_l&=W_{k,\kappa}^Tx_l+b_{k,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N\\
\hat{v}^{(\kappa)}_l&=W_{v,\kappa}^Tx_l+b_{v,\kappa}, \quad \kappa = 1,2,\dots,h, \ l = 1,2,\dots, N
\end{align*}
$$

### Property: $\text{MMHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}$ Input-deletion
If $A^{<i>}$ denotes the matrix obtained by deleting the $i$-th row and the $i$-th column of $A$, and $B^{(l)}_{i,s}\to-\infty$ for $i=1,2,\dots,s-1,s+1,\dots,N$ and $l=1,2,\dots,h$, then 
$$
\begin{align*}
\text{MMHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X)^{(s)} \to \text{MMHSAttn}_{\{\mathcal{W}_i^{<s>}\}_{i=0}^h}(X^{(s)}).
\end{align*}
$$
where
\begin{align*}
\mathcal{W}^{<s>}_0 &=\mathcal{W}_0 = \left(W_{o},b_{o}\right),\\
\mathcal{W}^{<s>}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},B^{<s>}_i\right), \quad i = 1,2,\dots,h.
\end{align*}

**Proof:** Combining the $\text{MMHAttn}_{\{\mathcal{W}_i\}_{i=0}^{h}}$ Query-Deletion and Key-Value deletion properties we obtain
$$
\begin{align*}
\text{MMHSAttn}_{\{\mathcal{W}_i\}_{i=0}^{h}}(X)^{(s)}&=\text{MMHAttn}_{\{\mathcal{W}_i\}_{i=0}^{h}}(X,X,X)^{(s)},\\
&=\text{MMHAttn}_{\{\mathcal{W}^{(s)}_{i}\}_{i=0}^{h}}(X^{(s)},X,X),\\
&\to\text{MMHAttn}_{\{\mathcal{W}^{<s>}_{i}\}_{i=0}^{h}}(X^{(s)},X^{(s)},X^{(s)}),\\
&=\text{MMHSAttn}_{\{\mathcal{W}_i^{<s>}\}_{i=0}^h}(X^{(s)}).
\end{align*}
$$
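A numerical illustration of input deletion, again with `torch.nn.MultiheadAttention` standing in for $\text{MMHSAttn}$ and a single 2D mask shared by all heads (both are assumptions of this sketch). The index `s` is 0-based.

```python
import torch

torch.manual_seed(0)
N, c, h, s = 6, 8, 2, 3
mhsa = torch.nn.MultiheadAttention(embed_dim=c, num_heads=h, batch_first=True)
X = torch.rand(1, N, c)

mask = torch.zeros(N, N)
mask[:, s] = float("-inf")       # no row may attend to x_s ...
mask[s, s] = 0.0                 # ... except row s itself, which is deleted anyway

keep = [j for j in range(N) if j != s]
with torch.no_grad():
    full, _ = mhsa(X, X, X, attn_mask=mask)
    Xs = X[:, keep]                              # X with row s deleted
    mask_s = mask[keep][:, keep]                 # mask with row and column s deleted
    reduced, _ = mhsa(Xs, Xs, Xs, attn_mask=mask_s)

# Deleting output row s of the full result matches running on the reduced input.
print(torch.allclose(full[:, keep], reduced, atol=1e-5))
```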

## Code: Masked Multi-Head Attention

The code for the remaining models is easily deduced from $\text{MMHAttn}$:
* $\text{MHAttn}_{\{\mathcal{W}_i\}_{i=0}^h}(Q,K,V)=\text{MMHAttn}_{\{\hat{\mathcal{W}}_{i}\}_{i=0}^h}(Q,K,V)$
where
$$
\begin{align*}
\mathcal{W}_0 &= \hat{\mathcal{W}}_0 =\left(W_{o},b_{o}\right),\\
\hat{\mathcal{W}}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i},0_{M\times N}\right), \quad i = 1,2,\dots,h,\\
\mathcal{W}_i &= \left(W_{q,i},b_{q,i},W_{k,i},b_{k,i},W_{v,i},b_{v,i}\right), \quad i = 1,2,\dots,h,\\
\end{align*}
$$
* $\text{MAttn}_{\mathcal{W}}(Q,K,V)=  \text{MMHAttn}_{\{\mathcal{W}_i\}_{i=0}^{1}}(Q,K,V)$
where $h=1$ , $d_o=d_v$
$$
\begin{align*}
\mathcal{W}_0 &= \left(I_{d_v\times d_v},0_{\mathbb{R}^{d_v}}\right),\\
\mathcal{W} &= \mathcal{W}_1 = \left(W_{q,1},b_{q,1},W_{k,1},b_{k,1},W_{v,1},b_{v,1},B_1\right).
\end{align*}
$$
* $\operatorname{Attn}_{\mathcal{W}}(Q,K,V) = \text{MAttn}_{\hat{\mathcal{W}}}(Q,K,V)$ 
where
$$
\begin{align*}
\hat{\mathcal{W}} &= \left(W_{q},b_{q},W_{k},b_{k},W_{v},b_{v},0_{M\times N}\right),\\
\mathcal{W} &= \left(W_{q},b_{q},W_{k},b_{k},W_{v},b_{v}\right).
\end{align*}
$$
* $\text{SAttn}_{\mathcal{W}}(X) = \operatorname{Attn}_{\mathcal{W}}(X,X,X)$
* $\text{MSAttn}_{\mathcal{W}}(X) = \operatorname{MAttn}_{\mathcal{W}}(X,X,X)$
* $\text{MHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X) = \text{MHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X)$
* $\text{MMHSAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X) = \text{MMHAttn}_{\{\mathcal{W}_{i}\}_{i=0}^h}(X,X,X)$
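Two of these reductions can be illustrated with `torch.nn.MultiheadAttention` (used here as a stand-in, which is an assumption of this sketch): a zero mask recovers the unmasked model, and passing the same tensor three times gives self-attention.

```python
import torch

torch.manual_seed(0)
M, N, d, h = 5, 3, 4, 2
mha = torch.nn.MultiheadAttention(embed_dim=d, num_heads=h, batch_first=True)
q, k, v = torch.rand(2, M, d), torch.rand(2, N, d), torch.rand(2, N, d)

with torch.no_grad():
    unmasked, _ = mha(q, k, v)                                   # MHAttn
    zero_masked, _ = mha(q, k, v, attn_mask=torch.zeros(M, N))   # MMHAttn with B_i = 0
    self_attn, _ = mha(q, q, q)                                  # MHSAttn(X) = MHAttn(X, X, X)

print(torch.allclose(unmasked, zero_masked, atol=1e-6))
print(self_attn.shape)
```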

In [None]:
import einops
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [None]:
class MHAttn(torch.nn.Module):
    """
    Multi-Head Attention module with (general) parameters
    cq, ck, cv: input dimensions for Q, K, V
    dk, dv: dimensions for each head's Q, K and V
    do: output dimension
    h: number of heads.
    
    The initialization of the weights differs from PyTorch's `nn.MultiheadAttention`,
    which uses Xavier uniform for the input projection weights and zeros for the biases.
    Here we keep the default `nn.Linear` initialization for clarity and simplicity.
    """

    def __init__(
        self, cq, ck, cv, dk, dv, do, h, bias=True, add_bias_kv=False, device=None, dtype=None
    ):
        super().__init__()
        # dk and dv are per-head dimensions: the projections below output
        # dk * h and dv * h features, so no divisibility by h is required.
        self.cq = cq
        self.ck = ck
        self.cv = cv
        self.dk = dk
        self.dv = dv
        self.do = do
        self.h = h
        self.add_bias_kv = add_bias_kv
        self.device = device
        self.dtype = dtype
        # Q -> QW_q+1_Mb^T_q
        self.q_proj = torch.nn.Linear(cq, dk * h, bias, self.device, self.dtype)
        # K -> KW_k+1_Nb^T_k
        self.k_proj = torch.nn.Linear(ck, dk * h, bias, self.device, self.dtype)
        # V -> VW_v+1_Nb^T_v
        self.v_proj = torch.nn.Linear(cv, dv * h, bias, self.device, self.dtype)

        self.out_proj = torch.nn.Linear(dv * h, do, bias, self.device, self.dtype)
        if self.add_bias_kv:
            self.bias_k = torch.nn.Parameter(
                torch.zeros(1, 1, dk * h, device=self.device, dtype=self.dtype)
            )
            self.bias_v = torch.nn.Parameter(
                torch.zeros(1, 1, dv * h, device=self.device, dtype=self.dtype)
            )

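    def forward(self, Q, K, V, attn_mask=None):
        """Forward pass of the MHA module.

        Q: (b, M, cq), K: (b, N, ck), V: (b, N, cv).
        attn_mask: optional additive float mask of shape (M, N), (b*h, M, N) or (b, h, M, N).
        """
        # Linear projections
        proj_q = self.q_proj(Q)  # QW_q+1_Mb^T_q
        proj_k = self.k_proj(K)  # KW_k+1_Nb^T_k
        proj_v = self.v_proj(V)  # VW_v+1_Nb^T_v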
        if self.add_bias_kv:
            # append bias to the key and value sequences
            batch_size = proj_k.shape[0]
            proj_k = torch.cat([proj_k, self.bias_k.repeat(batch_size, 1, 1)], dim=1)
            proj_v = torch.cat([proj_v, self.bias_v.repeat(batch_size, 1, 1)], dim=1)

        # Reshape for multi-head attention
        r_q = einops.rearrange(proj_q, "b m (h dk) -> b h m dk", h=self.h)
        r_k = einops.rearrange(proj_k, "b n (h dk) -> b h n dk", h=self.h)
        r_v = einops.rearrange(proj_v, "b n (h dv) -> b h n dv", h=self.h)

        # QK^T
        scores = torch.einsum("bhmd, bhnd -> bhmn", r_q, r_k)
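        if attn_mask is not None:
            # Broadcast the additive (float) mask to shape (b, h, m, n)
            match attn_mask.dim():
                case 2:  # (m, n): shared by all batch elements and heads
                    attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
                case 3:  # (b * h, m, n): one mask per batch element and head
                    attn_mask = einops.rearrange(attn_mask, "(b h) m n -> b h m n", h=self.h)
                case 4:  # already (b, h, m, n)
                    pass
                case _:
                    raise ValueError("attn_mask has incorrect dimensions")
            if self.add_bias_kv:
                # bias_k/bias_v append one key/value position; leave it unmasked
                attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))
            scores += attn_mask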
        
        # softmax(QK^T/sqrt(dk))
        attn = torch.nn.functional.softmax(scores / (self.dk**0.5), dim=-1)

        # softmax(QK^T/sqrt(dk))V
        o = torch.einsum("bhmn, bhnv -> bhmv", attn, r_v)

        # Reshape back
        r_o = einops.rearrange(o, "b h m dv -> b m (h dv)")

        # Final linear projection
        proj_o = self.out_proj(r_o)
        return proj_o


In [None]:
batch_dim = 3
M = 5  # sequence length of q
N = 3  # sequence length of k,v
d = 4 # embedding/model dimension
ck = 8  # key dimension 
cv = 16  # value dimension
h = 2  # number of heads

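# Additive causal mask: -inf strictly above the diagonal (blocked) and 0 on
# and below it (allowed), so query i attends only to keys j <= i.
mask = torch.full((M, N), float("-inf"), device=device)
mask = torch.triu(mask, diagonal=1)

# mask = mask.repeat(batch_dim * h, 1, 1)
# mask[0] = torch.full((M, N), 0, device=device)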

bias = True
add_bias_kv = False

In [43]:
mask.shape

torch.Size([5, 3])

#### Weights

In [None]:
nn_attn = torch.nn.MultiheadAttention(embed_dim=d, kdim=ck, vdim=cv,
                                      num_heads=h, batch_first=True,
                                      bias=bias, add_bias_kv=add_bias_kv).to(device)

assert d % h == 0, "d must be divisible by h"
dh = d // h
attn = MHAttn(cq = d, ck = ck, cv=cv, dk = dh ,dv=dh,do=d,h=h,bias=bias,add_bias_kv=add_bias_kv).to(device)

In [45]:
print("Torch MHA Weights:")
for name, w in nn_attn.named_parameters():
    print(f"{name} - {w.shape}")

print("\nOur MHA Weights:")
for name, w in attn.named_parameters():
    print(f"{name} - {w.shape}")

Torch MHA Weights:
q_proj_weight - torch.Size([4, 4])
k_proj_weight - torch.Size([4, 8])
v_proj_weight - torch.Size([4, 16])
in_proj_bias - torch.Size([12])
out_proj.weight - torch.Size([4, 4])
out_proj.bias - torch.Size([4])

Our MHA Weights:
q_proj.weight - torch.Size([4, 4])
q_proj.bias - torch.Size([4])
k_proj.weight - torch.Size([4, 8])
k_proj.bias - torch.Size([4])
v_proj.weight - torch.Size([4, 16])
v_proj.bias - torch.Size([4])
out_proj.weight - torch.Size([4, 4])
out_proj.bias - torch.Size([4])


#### Inputs

In [46]:
q = torch.rand(batch_dim,M,d,device=device)
k = torch.rand(batch_dim,N,ck,device=device) 
v = torch.rand(batch_dim,N,cv,device=device)

#### Load weights

The weights are initialized differently, so let's load the `nn` weights into our implementation.

In [47]:
with torch.no_grad():
    # 1) copy weights (shapes already match)
    attn.q_proj.weight.copy_(nn_attn.q_proj_weight)   # (d, d)
    attn.k_proj.weight.copy_(nn_attn.k_proj_weight)   # (d, ck)
    attn.v_proj.weight.copy_(nn_attn.v_proj_weight)   # (d, cv)

    # 2) split the packed bias: (3d,) -> (d,) + (d,) + (d,)
    b = nn_attn.in_proj_bias      # shape (3d,) = (12,)
    if bias:
        attn.q_proj.bias.copy_(b[0:d])        # 0:d
        attn.k_proj.bias.copy_(b[d:2*d])      # d:2d
        attn.v_proj.bias.copy_(b[2*d:3*d])    # 2d:3d
    
    if add_bias_kv:
        attn.bias_k.copy_(nn_attn.bias_k.squeeze(0))
        attn.bias_v.copy_(nn_attn.bias_v.squeeze(0))

    # 3) output projection
    attn.out_proj.weight.copy_(nn_attn.out_proj.weight)
    if bias:
        attn.out_proj.bias.copy_(nn_attn.out_proj.bias)

In [48]:
out = attn(q, k, v,attn_mask=mask)
nn_out,_ = nn_attn(q,k,v, attn_mask=mask)

In [49]:
out

tensor([[[ 0.2999,  0.0182,  0.2453,  0.1384],
         [ 0.2678,  0.1382,  0.0801,  0.1355],
         [ 0.2231,  0.0361,  0.2631,  0.0378],
         [ 0.2550,  0.0491,  0.2635,  0.0249],
         [ 0.2416,  0.0456,  0.2600,  0.0315]],

        [[ 0.1625,  0.0785,  0.1103,  0.0671],
         [ 0.2661,  0.2824,  0.1659, -0.1440],
         [ 0.2406,  0.3124,  0.1191, -0.1383],
         [ 0.2286,  0.3090,  0.1082, -0.1263],
         [ 0.2359,  0.3027,  0.1226, -0.1367]],

        [[ 0.2562,  0.2110, -0.0463, -0.0349],
         [ 0.3384,  0.2915,  0.0669, -0.0994],
         [ 0.3229,  0.3656,  0.0345, -0.1476],
         [ 0.3223,  0.3671,  0.0368, -0.1438],
         [ 0.3368,  0.3819,  0.0356, -0.1506]]], device='mps:0',
       grad_fn=<LinearBackward0>)

In [50]:
nn_out

tensor([[[ 0.2999,  0.0182,  0.2453,  0.1384],
         [ 0.2678,  0.1382,  0.0801,  0.1355],
         [ 0.2231,  0.0361,  0.2631,  0.0378],
         [ 0.2550,  0.0491,  0.2635,  0.0249],
         [ 0.2416,  0.0456,  0.2600,  0.0315]],

        [[ 0.1625,  0.0785,  0.1103,  0.0671],
         [ 0.2661,  0.2824,  0.1659, -0.1440],
         [ 0.2406,  0.3124,  0.1191, -0.1383],
         [ 0.2286,  0.3090,  0.1082, -0.1263],
         [ 0.2359,  0.3027,  0.1226, -0.1367]],

        [[ 0.2562,  0.2110, -0.0463, -0.0349],
         [ 0.3384,  0.2915,  0.0669, -0.0994],
         [ 0.3229,  0.3656,  0.0345, -0.1476],
         [ 0.3223,  0.3671,  0.0368, -0.1438],
         [ 0.3368,  0.3819,  0.0356, -0.1506]]], device='mps:0',
       grad_fn=<TransposeBackward0>)