Implementing Linear-Complexity Attention in PyTorch
====

In [None]:
import math
import torch

## Compute Query, Key and Value Matrices


$\text{Input: }\ X\in \mathbb{R}^{N\times d}$

$\text{Projection Matrices:}$  
 * $W_q\in \mathbb{R}^{d\times d_q}$  
 * $W_k\in \mathbb{R}^{d\times d_k}$  
 * $W_v\in \mathbb{R}^{d\times d_v}$
 
For simplicity, assume $d_q=d_k=d_v=d$.

$\text{Q, K, V Projection:}$

$$Q = XW_q$$
$$K = XW_k$$
$$V = XW_v$$


In [None]:
# define parameters
N = 1000
d = 256  # aka d_model
dk = dq = dv = d  # for simplicity

In [None]:
# Make a random input X


# Build projection matrices Wq, Wk, Wv


In [None]:
# Build Q, K, V

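A minimal sketch of the three setup cells above, assuming random Gaussian inputs. The weights are scaled by $1/\sqrt{d}$ (an assumption, not something the text specifies) so that the entries of $Q$, $K$, $V$ stay roughly unit-variance. A trained model would learn $W_q$, $W_k$, $W_v$ instead.

```python
import math
import torch

torch.manual_seed(0)  # reproducible random draws

N = 1000          # number of tokens
d = 256           # aka d_model
dk = dq = dv = d  # for simplicity

# Random input X: N tokens, each a d-dimensional vector
X = torch.randn(N, d)

# Random projection matrices; the 1/sqrt(d) scaling is an illustrative choice
Wq = torch.randn(d, dq) / math.sqrt(d)
Wk = torch.randn(d, dk) / math.sqrt(d)
Wv = torch.randn(d, dv) / math.sqrt(d)

# Q = X Wq, K = X Wk, V = X Wv, each of shape (N, d)
Q = X @ Wq
K = X @ Wk
V = X @ Wv

print(Q.shape, K.shape, V.shape)
# → torch.Size([1000, 256]) torch.Size([1000, 256]) torch.Size([1000, 256])
```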

## Original Scaled Dot-product Attention $\mathcal{O}(N^2)$

Ref.: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

$$\text{Attention}(Q,K,V)=\text{Softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$



In [None]:
# ~~~~~~~~~~~~~~~ Step 1 ~~~~~~~~~~~~~~~~~~~~
# Compute the attention scores




# ~~~~~~~~~~~~~~~ Step 2 ~~~~~~~~~~~~~~~~~~~~
# Normalization: applying softmax

# Sanity-check: each row sums to 1




# ~~~~~~~~~~~~~~~ Step 3 ~~~~~~~~~~~~~~~~~~~~
# Compute the attention output


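One way to fill in the three steps above, as a sketch. To keep the block self-contained, it draws $Q$, $K$, $V$ at random rather than reusing the projections. The softmax runs over the last dimension, so each query's weights over the $N$ keys sum to one.

```python
import math
import torch

torch.manual_seed(0)
N, d = 1000, 256
dk = d
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

# Step 1: attention scores Q K^T / sqrt(d_k), shape (N, N); this is the O(N^2) part
scores = Q @ K.T / math.sqrt(dk)

# Step 2: softmax over the keys (last dimension)
attn = torch.softmax(scores, dim=-1)
# Sanity check: each row sums to 1, up to floating-point error
print(torch.allclose(attn.sum(dim=-1), torch.ones(N), atol=1e-5))  # True

# Step 3: attention output, a weighted sum of the rows of V, shape (N, d)
out = attn @ V
print(scores.shape, out.shape)  # torch.Size([1000, 1000]) torch.Size([1000, 256])
```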
In [None]:
# Same computation on batched tensors of shape (batch, N, d)

# ~~~~~~~~~~~~~~~ Step 1 ~~~~~~~~~~~~~~~~~~~~
# Compute the attention scores




# ~~~~~~~~~~~~~~~ Step 2 ~~~~~~~~~~~~~~~~~~~~
# Normalization: applying softmax

# Sanity-check: each row sums to 1




# ~~~~~~~~~~~~~~~ Step 3 ~~~~~~~~~~~~~~~~~~~~
# Compute the attention output

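A sketch of the batched version, assuming an illustrative batch size `B = 4`. With 3-D tensors, `K.transpose(-2, -1)` swaps only the last two axes, whereas `K.T` would reverse all three. The `@` operator batches over the leading dimension. The optional cross-check uses `torch.nn.functional.scaled_dot_product_attention`, which is available in PyTorch 2.0 and later and applies the same $1/\sqrt{d_k}$ scaling by default.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, d = 4, 1000, 256  # B = 4 is an illustrative batch size
dk = d
Q = torch.randn(B, N, d)
K = torch.randn(B, N, d)
V = torch.randn(B, N, d)

# Step 1: scores of shape (B, N, N); transpose only the last two axes of K
scores = Q @ K.transpose(-2, -1) / math.sqrt(dk)

# Step 2: softmax over the keys (last dimension)
attn = torch.softmax(scores, dim=-1)
# Sanity check: each row of every batch element sums to 1
print(torch.allclose(attn.sum(dim=-1), torch.ones(B, N), atol=1e-5))

# Step 3: attention output, shape (B, N, d)
out = attn @ V

# Optional cross-check against the built-in function (PyTorch 2.0+)
if hasattr(F, "scaled_dot_product_attention"):
    ref = F.scaled_dot_product_attention(Q, K, V)
    print(torch.allclose(out, ref, atol=1e-4))
```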

## Method 1: Efficient-Attention $\mathcal{O}(N)$

Ref.: ["Efficient Attention: Attention with Linear Complexities"](https://arxiv.org/abs/1812.01243)

$$\hat{A}(Q, K, V)=\sigma_\text{row}(Q) \left(\sigma_\text{col} (K)^\top V\right)$$
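The linear cost comes from the order of the two matrix products. By associativity, both groupings below give the same matrix. Counting multiply-adds for $N$ tokens and $d_k=d_v=d$, with the notebook's $N=1000$ and $d=256$:

$$\left(\sigma_\text{row}(Q)\,\sigma_\text{col}(K)^\top\right)V:\quad 2N^2d = 2\cdot 1000^2\cdot 256 \approx 5.1\times 10^{8}$$

$$\sigma_\text{row}(Q)\left(\sigma_\text{col}(K)^\top V\right):\quad 2Nd^2 = 2\cdot 1000\cdot 256^2 \approx 1.3\times 10^{8}$$

The saving is the factor $N/d\approx 3.9$ here and grows linearly with $N$. The intermediate matrix also shrinks from $N\times N = 10^6$ entries to $d\times d = 65{,}536$.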



In [None]:
# ~~~~~~~~~~~~~~~ Step 1 ~~~~~~~~~~~~~~~~~~~~
# Apply sigma: row-wise softmax to Q, column-wise softmax to K




# ~~~~~~~~~~~~~~~ Step 2 ~~~~~~~~~~~~~~~~~~~~
# Calculate (sigma_col K)^T x V




# ~~~~~~~~~~~~~~~ Step 3 ~~~~~~~~~~~~~~~~~~~~
# Calculate the final output

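A minimal sketch of the three steps, again with random $Q$, $K$, $V$ so the block runs on its own. `torch.softmax(Q, dim=1)` normalizes each row of $Q$ over the $d$ features. `torch.softmax(K, dim=0)` normalizes each column of $K$ over the $N$ positions. The only intermediate is the $d\times d$ matrix `context`.

```python
import torch

torch.manual_seed(0)
N, d = 1000, 256
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

# Step 1: sigma_row(Q) -> softmax along each row (over the d features)
Q_s = torch.softmax(Q, dim=1)
#         sigma_col(K) -> softmax along each column (over the N positions)
K_s = torch.softmax(K, dim=0)

# Step 2: (sigma_col K)^T V, shape (d, d); no N x N matrix is ever formed
context = K_s.T @ V

# Step 3: final output, shape (N, d)
out = Q_s @ context
print(context.shape, out.shape)  # torch.Size([256, 256]) torch.Size([1000, 256])
```

Note that $\sigma_\text{row}(Q)\,\sigma_\text{col}(K)^\top$ is a different matrix from $\text{Softmax}(QK^\top/\sqrt{d_k})$. This output therefore approximates the scaled dot-product result rather than reproducing it.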

## Method 2: Linear Attention using Kernels $\mathcal{O}(N)$

Ref.: ["Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"](https://arxiv.org/abs/2006.16236)

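$$\hat{A}(Q,K,V)=\frac{\phi(Q)\left(\phi(K)^\top V\right)}{\phi(Q)\left(\phi(K)^\top \mathbb{1}_N\right)}$$

$\text{where feature function }\ \phi(x) = \text{ELU}(x) + 1$ is applied element-wise, $\mathbb{1}_N$ is a vector of $N$ ones, and the division is applied row-wise.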

<img src="elu.png" alt="ELU(x)" style="width:250px;"/>
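With PyTorch's default ELU parameter $\alpha = 1$, the feature map has a simple piecewise form that shows it is strictly positive. Positivity keeps every similarity $\phi(q_i)^\top\phi(k_j)$, and hence every attention weight, non-negative:

$$\phi(x) = \text{ELU}(x) + 1 = \begin{cases} x + 1, & x > 0 \\ e^{x}, & x \le 0 \end{cases} \quad > 0$$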


In [None]:
elu = torch.nn.ELU()
def phi(x):
    return elu(x) + 1


# ~~~~~~~~~~~~~~~ Step 1 ~~~~~~~~~~~~~~~~~~~~
# Apply phi to Q and K


# ~~~~~~~~~~~~~~~ Step 2 ~~~~~~~~~~~~~~~~~~~~
# Calculate (phi_K)^T x V


# ~~~~~~~~~~~~~~~ Step 3 ~~~~~~~~~~~~~~~~~~~~
# Compute the final output

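A sketch of the kernel method under the same self-contained random setup. Following the cited paper, it also divides each row by $\phi(q_i)^\top\sum_j\phi(k_j)$, so each query's weights sum to one. Dropping `denom` leaves just the product $\phi(Q)\left(\phi(K)^\top V\right)$. The last lines rebuild the same result from the explicit $N\times N$ similarity matrix, purely as a check.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d = 1000, 256
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

def phi(x):
    # Feature map elu(x) + 1, equivalent to torch.nn.ELU()(x) + 1; strictly positive
    return F.elu(x) + 1

# Step 1: apply phi to Q and K, shape (N, d)
phi_Q = phi(Q)
phi_K = phi(K)

# Step 2: phi(K)^T V, shape (d, d), and the normalizer term phi(K)^T 1_N, shape (d,)
KV = phi_K.T @ V
K_sum = phi_K.sum(dim=0)

# Step 3: numerator (N, d) divided row-wise by the normalizer (N, 1)
numer = phi_Q @ KV
denom = (phi_Q @ K_sum).unsqueeze(-1)
out = numer / denom

# Check only: same result through the explicit (N, N) similarity matrix
S = phi_Q @ phi_K.T
ref = (S @ V) / S.sum(dim=-1, keepdim=True)
print(torch.allclose(out, ref, atol=1e-5))  # True
```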

## Method 3: Linear Attention using Taylor Series Approximation $\mathcal{O}(N)$

Ref.: ["Linear Attention Mechanism: An Efficient Attention for Semantic Segmentation"](https://arxiv.org/abs/2007.14902)

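$$\hat{A}(Q,K,V)=\frac{\mathbb{1}_N^\top V + Q^\prime \left({K^\prime}^\top V\right)}{N+Q^\prime \left({K^\prime}^\top \mathbb{1}_N\right)}$$

$\text{where }\ \mathbb{1}_N$ is a vector of $N$ ones, the row vector $\mathbb{1}_N^\top V$ (the column sums of $V$) is added to every row, the division is applied row-wise, and  
* $Q^\prime$ is $Q$ with each row $q_i$ replaced by $q_i/\|q_i\|_2$  
* $K^\prime$ is $K$ with each row $k_j$ replaced by $k_j/\|k_j\|_2$

This row-wise $L_2$ normalization guarantees ${q^\prime_i}^\top k^\prime_j \ge -1$. As a result, the first-order Taylor approximation $e^{{q^\prime_i}^\top k^\prime_j}\approx 1+{q^\prime_i}^\top k^\prime_j$ of each attention weight is never negative.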
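Start from softmax attention on the normalized vectors and replace each weight by $1+{q^\prime_i}^\top k^\prime_j$. The query can then be pulled out of the sums over $j$. Here $q^\prime_i$, $k^\prime_j$, $v_j$ denote rows of $Q^\prime$, $K^\prime$, $V$, written as column vectors:

$$\hat{A}_i = \frac{\sum_{j=1}^{N}\left(1+{q^\prime_i}^\top k^\prime_j\right)v_j^\top}{\sum_{j=1}^{N}\left(1+{q^\prime_i}^\top k^\prime_j\right)} = \frac{\sum_{j} v_j^\top + {q^\prime_i}^\top\left(\sum_{j} k^\prime_j v_j^\top\right)}{N + {q^\prime_i}^\top\sum_{j} k^\prime_j}$$

Stacking the rows $i=1,\dots,N$ gives the matrix form, with $\sum_j v_j^\top = \mathbb{1}_N^\top V$, $\sum_j k^\prime_j v_j^\top = {K^\prime}^\top V$ and $\sum_j k^\prime_j = {K^\prime}^\top\mathbb{1}_N$. These three sums do not depend on $i$, so each is computed once.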

In [None]:
# Create a vector of ones


In [None]:
# ~~~~~~~~~~~~~~~ Step 1 ~~~~~~~~~~~~~~~~~~~~
# Normalize Q and K -> Q_prime and K_prime



# ~~~~~~~~~~~~~~~ Step 2 ~~~~~~~~~~~~~~~~~~~~
# Calculate the query-independent terms: 1_N^T V, K'^T V and K'^T 1_N


# ~~~~~~~~~~~~~~~ Step 3 ~~~~~~~~~~~~~~~~~~~~
# Compute the final result

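A sketch that follows the formula term by term, with random $Q$, $K$, $V$ so it runs on its own. Row-wise normalization uses `Q.norm(dim=-1, keepdim=True)`. `torch.nn.functional.normalize(Q, dim=-1)` would do the same with a small epsilon guard. The closing check rebuilds the output from the explicit $N\times N$ weights $1+{q^\prime_i}^\top k^\prime_j$.

```python
import torch

torch.manual_seed(0)
N, d = 1000, 256
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

# Vector of ones, 1_N
ones = torch.ones(N)

# Step 1: L2-normalize every row -> Q_prime and K_prime
Q_prime = Q / Q.norm(dim=-1, keepdim=True)
K_prime = K / K.norm(dim=-1, keepdim=True)

# Step 2: intermediate terms, none of which depend on the query position
V_sum = ones @ V              # 1_N^T V, shape (d,)
KV = K_prime.T @ V            # K'^T V, shape (d, d)
K_sum = K_prime.T @ ones      # K'^T 1_N, shape (d,)

# Step 3: numerator (N, d) divided row-wise by denominator (N, 1)
numer = V_sum + Q_prime @ KV                 # V_sum broadcasts to every row
denom = N + (Q_prime @ K_sum).unsqueeze(-1)
out = numer / denom

# Check only: explicit (N, N) first-order Taylor weights 1 + q'_i . k'_j
W = 1 + Q_prime @ K_prime.T
ref = (W @ V) / W.sum(dim=-1, keepdim=True)
print(torch.allclose(out, ref, atol=1e-5))  # True
```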

## Method 4: Linformer's Attention $\mathcal{O}(N)$

Ref.: ["Linformer: Self-Attention with Linear Complexity"](https://arxiv.org/abs/2006.04768)

$\text{Low-rank approximation: projection matrices }E,F\in \mathbb{R}^{r\times N}$

$\longrightarrow EK \in \mathbb{R}^{r\times d_k}$  
$\longrightarrow FV \in \mathbb{R}^{r\times d_v}$

$\text{Context-mapping matrix P: }\ \ P=\text{Softmax}\left(\frac{Q(EK)^\top}{\sqrt{d_k}}\right)$


$$\hat{A}(Q,EK,FV)=\text{Softmax}\left(\frac{Q(EK)^\top}{\sqrt{d_k}}\right)FV$$

In [None]:
# assume r is 100
r = 100

# Define random projection matrices E and F


In [None]:
# ~~~~~~~~~~~~~~~ Step 1 ~~~~~~~~~~~~~~~~~~~~
# Project K and V




# ~~~~~~~~~~~~~~~ Step 2 ~~~~~~~~~~~~~~~~~~~~
# Compute the context-mapping matrix P




# ~~~~~~~~~~~~~~~ Step 3 ~~~~~~~~~~~~~~~~~~~~
# Compute the final output

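A minimal sketch of the Linformer steps with random inputs. Linformer learns $E$ and $F$. Here they are drawn as Gaussian matrices with standard deviation $1/\sqrt{r}$, an illustrative choice, since the cell above only asks for random projections.

```python
import math
import torch

torch.manual_seed(0)
N, d, r = 1000, 256, 100
dk = d
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

# Random projection matrices E, F of shape (r, N); the 1/sqrt(r) scale is an assumption
E = torch.randn(r, N) / math.sqrt(r)
F = torch.randn(r, N) / math.sqrt(r)

# Step 1: compress the sequence axis: (r, N) @ (N, d) -> (r, d)
EK = E @ K
FV = F @ V

# Step 2: context-mapping matrix P, shape (N, r) instead of (N, N)
P = torch.softmax(Q @ EK.T / math.sqrt(dk), dim=-1)

# Step 3: final output, shape (N, d)
out = P @ FV
print(P.shape, out.shape)  # torch.Size([1000, 100]) torch.Size([1000, 256])
```

The context-mapping matrix `P` has $N\times r = 1000\times 100 = 10^5$ entries, compared with $N^2 = 10^6$ for the full attention matrix. For a fixed $r$, its size grows linearly in $N$.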