Implementing Linear-Complexity Attention Mechanisms in PyTorch
===

```python
import math
import matplotlib.pyplot as plt
import seaborn as sns

import torch
%matplotlib inline
```

## Compute Query, Key and Value Matrices


$\text{Input: }\ X\in \mathbb{R}^{N\times d}$

$\text{Projection Matrices:}$  
 * $W_q\in \mathbb{R}^{d\times d_q}$  
 * $W_k\in \mathbb{R}^{d\times d_k}$  
 * $W_v\in \mathbb{R}^{d\times d_v}$

$\text{Q, K, V Projection:}$

$$Q = XW_q$$
$$K = XW_k$$
$$V = XW_v$$


```python
# define parameters
N = 1000  # sequence length (number of tokens)
d = 256   # model dimension, aka d_model
dk = dq = dv = d  # for simplicity
```
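To make the shapes concrete, here is a minimal sketch of the projection step. The random input `X` and the random weights `W_q`, `W_k`, `W_v` are assumptions for illustration only; in a trained model the projections would be learned parameters.

```python
import torch

torch.manual_seed(0)  # reproducible random example

# Sizes matching the parameters above, repeated so this block runs on its own
N, d = 1000, 256
dq = dk = dv = d

# Random input sequence: N tokens with d features each (illustrative assumption)
X = torch.randn(N, d)

# Projection matrices, randomly initialised here (learned in a real model) and
# scaled by 1/sqrt(d) so the projected features keep roughly unit variance
W_q = torch.randn(d, dq) / d ** 0.5
W_k = torch.randn(d, dk) / d ** 0.5
W_v = torch.randn(d, dv) / d ** 0.5

# Q = X W_q, K = X W_k, V = X W_v
Q = X @ W_q  # (N, dq)
K = X @ W_k  # (N, dk)
V = X @ W_v  # (N, dv)
```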

## Original Scaled Dot-product Attention $\mathcal{O}(N^2)$



$$\text{Attention}(Q,K,V)=\text{Softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
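A direct PyTorch sketch of the formula above, with random inputs as an illustrative assumption (the function name `softmax_attention` is ours). The intermediate `scores` tensor has shape `(N, N)`, which is where the quadratic time and memory cost comes from.

```python
import math
import torch

def softmax_attention(Q, K, V):
    """Scaled dot-product attention that materialises the full N x N score matrix."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (N, N): quadratic in N
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V                                 # (N, d_v)

torch.manual_seed(0)
N, d = 1000, 256
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = softmax_attention(Q, K, V)
print(out.shape)  # torch.Size([1000, 256])
```

Because each row of `weights` is a probability distribution, every output row is a convex combination of the rows of `V`.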



## Method 1: Efficient-Attention $\mathcal{O}(N)$

Ref.: ["Efficient Attention: Attention with Linear Complexities"](https://arxiv.org/abs/1812.01243)

$$\hat{A}(Q, K, V)=\sigma_\text{row}(Q) \left(\sigma_\text{col} (K)^\top V\right)$$
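As described in the referenced paper, $\sigma_\text{row}$ applies softmax along each row of $Q$ (over the feature dimension) and $\sigma_\text{col}$ along each column of $K$ (over the $N$ positions). The sketch below assumes random inputs and a function name of our own choosing. Computing $\sigma_\text{col}(K)^\top V$ first yields a small $d_k\times d_v$ "context" matrix, so the cost is $\mathcal{O}(N d_k d_v)$ instead of $\mathcal{O}(N^2 d)$; the final check shows that only the multiplication order changes.

```python
import torch

def efficient_attention(Q, K, V):
    """Efficient Attention: sigma_row(Q) @ (sigma_col(K)^T @ V), no N x N matrix."""
    q = torch.softmax(Q, dim=-1)       # sigma_row: softmax across the features of each query
    k = torch.softmax(K, dim=-2)       # sigma_col: softmax across positions for each key feature
    context = k.transpose(-2, -1) @ V  # (d_k, d_v): cost O(N * d_k * d_v)
    return q @ context                 # (N, d_v)

torch.manual_seed(0)
N, d = 1000, 256
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = efficient_attention(Q, K, V)

# Same result as forming the N x N matrix first; only the multiplication order differs
reference = (torch.softmax(Q, dim=-1) @ torch.softmax(K, dim=-2).T) @ V
print(torch.allclose(out, reference, atol=1e-5))  # True
```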



## Method 2: Linear Attention using Kernels $\mathcal{O}(N)$

Ref.: ["Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"](https://arxiv.org/abs/2006.16236)

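$$\hat{A}(Q,K,V)=\frac{\phi(Q)\left(\phi(K)^\top V\right)}{\phi(Q)\left(\phi(K)^\top \mathbb{1}_N\right)}$$

$\text{where the feature map }\ \phi(x) = \text{ELU}(x) + 1\ \text{is applied element-wise, }\mathbb{1}_N\text{ is the all-ones vector of length }N\text{, and the division is row-wise}$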


<img src="elu.png" alt="ELU(x)" style="width:250px;"/>
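A minimal sketch of kernel-based linear attention, assuming random inputs and illustrative function names. The referenced paper divides the output for each query by $\phi(Q_i)^\top \sum_j \phi(K_j)$ so that its implicit attention weights sum to 1, and this sketch includes that normalizer; the small `eps` is our own safeguard, not part of the paper. Because $\phi(x) = \text{ELU}(x)+1 > 0$, all implicit weights are positive, just like softmax weights.

```python
import torch
import torch.nn.functional as F

def elu_feature_map(x):
    """phi(x) = ELU(x) + 1, which is strictly positive."""
    return F.elu(x) + 1

def kernel_linear_attention(Q, K, V, eps=1e-6):
    """Linear attention phi(Q)(phi(K)^T V), normalised per query."""
    q = elu_feature_map(Q)            # (N, d_k)
    k = elu_feature_map(K)            # (N, d_k)
    kv = k.transpose(-2, -1) @ V      # (d_k, d_v), computed once for all queries
    z = k.sum(dim=-2)                 # (d_k,), equals phi(K)^T 1_N
    numer = q @ kv                    # (N, d_v)
    denom = q @ z                     # (N,)
    return numer / (denom.unsqueeze(-1) + eps)

torch.manual_seed(0)
N, d = 1000, 256
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = kernel_linear_attention(Q, K, V)

# Quadratic reference: explicit N x N similarity matrix, normalised row by row
A = elu_feature_map(Q) @ elu_feature_map(K).T
reference = (A / A.sum(dim=-1, keepdim=True)) @ V
print(torch.allclose(out, reference, atol=1e-4))  # True
```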


## Method 3: Linear Attention using Taylor Series Approximation $\mathcal{O}(N)$

Ref.: ["Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation"](https://arxiv.org/abs/2007.14902)

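$$\hat{A}(Q,K,V)=\frac{\mathbb{1}_N\mathbb{1}_N^\top V + Q^\prime \left({K^\prime}^\top V\right)}{N+Q^\prime \left({K^\prime}^\top \mathbb{1}_N\right)}$$

$\text{where the division is row-wise, }\mathbb{1}_N\text{ is the all-ones vector of length }N\text{, and}$  
* $Q^\prime = \frac{Q}{\|Q\|_2}$ (each row of $Q$ divided by its $L_2$ norm)  
* $K^\prime = \frac{K}{\|K\|_2}$ (each row of $K$ divided by its $L_2$ norm)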
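The formula follows from replacing $\exp(q_i^\top k_j)$ with its first-order Taylor expansion $1 + q_i^\top k_j$. We assume the norms are taken per row (per token): then $q_i^{\prime\top} k_j^\prime$ is a cosine similarity in $[-1, 1]$, so every weight $1 + q_i^{\prime\top} k_j^\prime$ is non-negative. The first numerator term is the sum of all value vectors added to every row, and the $N$ in the denominator comes from summing the constant 1 over all keys. The sketch below, with random inputs and a function name of our own, checks the linear form against the explicit $N\times N$ weights.

```python
import torch
import torch.nn.functional as F

def taylor_linear_attention(Q, K, V):
    """First-order Taylor linear attention with row-wise L2-normalised Q and K."""
    N = Q.shape[-2]
    q = F.normalize(Q, p=2, dim=-1)  # Q': each query row scaled to unit L2 norm
    k = F.normalize(K, p=2, dim=-1)  # K': each key row scaled to unit L2 norm
    # Numerator: sum_j v_j (broadcast to every row) + Q'(K'^T V)
    numer = V.sum(dim=-2, keepdim=True) + q @ (k.transpose(-2, -1) @ V)  # (N, d_v)
    # Denominator: N + Q'(K'^T 1_N)
    denom = N + q @ k.sum(dim=-2)  # (N,)
    return numer / denom.unsqueeze(-1)

torch.manual_seed(0)
N, d = 1000, 256
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = taylor_linear_attention(Q, K, V)

# Quadratic reference: weights 1 + cos(q_i, k_j), which lie in [0, 2]
W = 1 + F.normalize(Q, dim=-1) @ F.normalize(K, dim=-1).T
reference = (W / W.sum(dim=-1, keepdim=True)) @ V
print(torch.allclose(out, reference, atol=1e-5))  # True
```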

## Method 4: Linformer's Attention $\mathcal{O}(N)$

Ref.: ["Linformer: Self-Attention with Linear Complexity"](https://arxiv.org/abs/2006.04768)

$\text{Low-rank approximation: projection matrices }E,F\in \mathbb{R}^{r\times N}\text{ with }r\ll N$

$\longrightarrow EK \in \mathbb{R}^{r\times d_k}$  
$\longrightarrow FV \in \mathbb{R}^{r\times d_v}$

$\text{Context-mapping matrix P: }\ \ P=\text{Softmax}\left(\frac{Q(EK)^\top}{\sqrt{d_k}}\right)$


$$\hat{A}(Q,EK,FV)=\text{Softmax}\left(\frac{Q(EK)^\top}{\sqrt{d_k}}\right)FV = P\,FV$$
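A minimal sketch of Linformer-style attention, with keys projected by $E$ and values by $F$ as in the definitions $EK$ and $FV$. Here $E$ and $F$ are random Gaussian matrices scaled by $1/\sqrt{N}$ purely for illustration (Linformer learns them), and `r = 64` is an arbitrary choice. The context-mapping matrix `P` has shape `(N, r)`, so the cost grows linearly in $N$ for a fixed $r$.

```python
import math
import torch

def linformer_attention(Q, K, V, E, F):
    """Linformer attention: project keys/values from N rows down to r rows first."""
    d_k = Q.shape[-1]
    EK = E @ K  # (r, d_k)
    FV = F @ V  # (r, d_v)
    P = torch.softmax(Q @ EK.transpose(-2, -1) / math.sqrt(d_k), dim=-1)  # (N, r) context mapping
    return P @ FV  # (N, d_v)

torch.manual_seed(0)
N, d, r = 1000, 256, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
# Random projections for illustration only; Linformer learns E and F during training
E = torch.randn(r, N) / math.sqrt(N)
F = torch.randn(r, N) / math.sqrt(N)
out = linformer_attention(Q, K, V, E, F)
print(out.shape)  # torch.Size([1000, 256])
```

As a sanity check, choosing $r = N$ and $E = F = I_N$ makes the projections the identity, and the function reduces to ordinary scaled dot-product attention.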