## Self-Attention Does Not Require O(N^2) Memory

Implementation of: https://arxiv.org/pdf/2112.05682.pdf

### Example of actual pytorch code

In [1]:
import torch

In [2]:
import linear_mem_attention_pytorch

In [3]:
# Evaluate the submodule so the notebook displays which installed copy of the package was loaded.
linear_mem_attention_pytorch.linear_mem_attn_torch

<module 'linear_mem_attention_pytorch.linear_mem_attn_torch' from '/Users/ericalcaidealdeano/miniconda3/envs/charm/lib/python3.7/site-packages/linear_mem_attention_pytorch-0.0.1-py3.7.egg/linear_mem_attention_pytorch/linear_mem_attn_torch.py'>

In [1]:


from linear_mem_attention_torch import *
from linear_mem_attention_torch.utils import qkv2res

ModuleNotFoundError: No module named 'linear_mem_attn_torch'

In [2]:
@torch.jit.script
def qkv2res(q, k, v):
    """ Standard (quadratic-memory) softmax attention used as the reference.

        Inputs are laid out as (b h n d). The scores are NOT divided by sqrt(d);
        callers that want scaled attention must scale `q` beforehand.
        Returns a tensor laid out as (b h i d), with i the number of queries.
    """
    # equivalent to: (q @ torch.transpose(k, -1, -2)).softmax(dim=-1) @ v
    scores = torch.einsum('b h i d, b h j d -> b h i j', q, k)
    attn = scores.softmax(dim=-1)
    out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
    return out

### Paper base jax code

In [3]:
import functools, jax, math
from jax import numpy as jnp

def _query_chunk_attention(query, key, value, precision, key_chunk_size=4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)
        exp_values = jnp.einsum('vhf,qhv->qhf', value, exp_weights, precision=precision)
        return (
            exp_values, exp_weights.sum(axis=-1),
            max_score.reshape((query.shape[0], num_heads))
        )

    def chunk_scanner(chunk_idx):
        key_chunk = jax.lax.dynamic_slice(
            key, (chunk_idx, 0, 0),
            slice_sizes=(key_chunk_size, num_heads, k_features)
        )
        value_chunk = jax.lax.dynamic_slice(
             value, (chunk_idx, 0, 0),
             slice_sizes=(key_chunk_size, num_heads, v_features)
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(
        chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)
    )

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)
    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
    return all_values / all_weights

def attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST,
    query_chunk_size=1024, key_chunk_size=4096
):
    """Memory-efficient multi-head dot product attention.

    Queries are processed `query_chunk_size` rows at a time with
    `jax.lax.scan`; each chunk is attended against all keys/values by
    `_query_chunk_attention`, which walks the keys in chunks of
    `key_chunk_size`.

    Args:
        query: array of shape (num_q, num_heads, q_features).
        key: array passed through to `_query_chunk_attention`.
        value: array passed through to `_query_chunk_attention`; its last
            dimension is the feature size of the result.
        precision: einsum precision forwarded to `_query_chunk_attention`.
        query_chunk_size: number of query rows per scan step.
        key_chunk_size: number of key/value rows per inner chunk. Defaults to
            4096, the value `_query_chunk_attention` used implicitly before
            this parameter was exposed.

    Returns:
        Array of shape (num_q, num_heads, value.shape[-1]).

    Notes:
        Both chunk sizes are used in Python-level arithmetic (`min`,
        `math.ceil`), so they must be plain Python ints; under `jax.jit` leave
        them at their defaults or mark them static.
        The final reshape only succeeds when num_q is a multiple of
        `query_chunk_size` or not larger than it.
    """
    num_q, num_heads, q_features = query.shape

    def chunk_scanner(chunk_idx, _):
        query_chunk = jax.lax.dynamic_slice(
            query, (chunk_idx, 0, 0),
            slice_sizes=(min(query_chunk_size, num_q), num_heads, q_features)
        )
        return (
            # carry: start row of the next query chunk
            chunk_idx + query_chunk_size,
            _query_chunk_attention(
                query_chunk, key, value,
                precision=precision, key_chunk_size=key_chunk_size
            )
        )

    _, res = jax.lax.scan(
        chunk_scanner, init=0, xs=None,
        length=math.ceil(num_q / query_chunk_size)
    )
    # res is (num_chunks, chunk_rows, heads, features); flatten the two leading axes.
    return res.reshape(num_q, num_heads, value.shape[-1])


### PyTorch Implementation

In [4]:
import torch
import numpy as np
from typing import Optional, Tuple, Any, List
from types import FunctionType

@torch.jit.script
def dynamic_slice(
    x: torch.Tensor, 
    slices: Tuple[int, int, int], 
    slice_sizes: Tuple[int, int, int],
) -> torch.Tensor:
    """ Rough torch counterpart of jax.lax.dynamic_slice.

        * NOTE: the first dim of `x` is always kept whole; `slices` and
          `slice_sizes` describe dims 1, 2 and 3 only.
        * Plain slicing is used, so a window that runs past the end of a dim
          is simply cut short.

        Args:
            x: tensor indexed along its first four dims.
            slices: start offsets for dims 1, 2, 3.
            slice_sizes: window lengths for dims 1, 2, 3.

        Ex: 
        dynamic_slice(
            x, 
            slices=(0, 0, 0),
            slice_sizes=(16, 64, 64)
        )   # -> x[:, 0:16, 0:64, 0:64]
    """
    begin_1 = slices[0]
    begin_2 = slices[1]
    begin_3 = slices[2]
    end_1 = begin_1 + slice_sizes[0]
    end_2 = begin_2 + slice_sizes[1]
    end_3 = begin_3 + slice_sizes[2]
    return x[:, begin_1:end_1, begin_2:end_2, begin_3:end_3]

def torch_map(fn, xs) -> Tuple[torch.Tensor, torch.Tensor,torch.Tensor]:
    """ approx like jax.lax.map

        Calls `fn` on every element of `xs` (in order) and stacks the results
        along a new leading dimension, one stacked tensor per output of `fn`.

        Args:
            fn: callable taking one element of `xs` and returning a tuple of
                tensors; every call must return tensors of matching shapes.
            xs: iterable of inputs (e.g. a list of chunk offsets or a tensor,
                which is iterated over its first dim).

        Returns:
            Tuple of tensors; entry `j` has shape (len(xs), *shape_of_output_j).

        Raises:
            ValueError: if `xs` is empty (there is nothing to infer the
                output shapes from).
    """
    results = [fn(x) for x in xs]
    if len(results) == 0:
        raise ValueError("torch_map requires at least one element in xs")
    # zip(*results) regroups [(a0, b0, c0), (a1, b1, c1), ...] into (a*, b*, c*).
    return tuple(torch.stack(parts, dim=0) for parts in zip(*results))

def torch_scan(
    f: FunctionType,
    init: int = 0,
    xs: Optional[List] = None,
    length: int = 0
) -> Tuple[Any, torch.Tensor]:
    """ approx like jax.lax.scan

        Threads a carry through successive calls `carry, y = f(carry, x)` and
        collects every `y`.

        Args:
            f: step function taking (carry, x) and returning (new_carry, y).
            init: initial carry.
            xs: per-step inputs; when None, `f` is called `length` times with
                x=None.
            length: number of steps, only used when `xs` is None.

        Returns:
            (final carry, all `y` stacked along a new first dim).
    """
    inputs = [None] * length if xs is None else xs
    state = init
    outputs = []
    for item in inputs:
        state, step_out = f(state, item)
        outputs.append(step_out)
    stacked = torch.stack(outputs, dim=0)
    return state, stacked

###################
## ADAPTED FROM: https://arxiv.org/pdf/2112.05682.pdf
###################

def torch_query_chunk_attention(query, key, value, key_chunk_size=4096):
    """Multi-head dot product attention with a limited number of queries.

    Keys/values are consumed `key_chunk_size` rows at a time. Every chunk
    contributes un-normalised softmax numerators and denominators plus its own
    maximum score; the chunks are then rescaled to a shared maximum and summed.

    Args:
        query: (batch, num_q, num_heads, k_features) chunk of queries.
        key: (batch, num_kv, num_heads, k_features).
        value: (batch, num_kv, num_heads, v_features).
        key_chunk_size: maximum number of key/value rows per chunk
            (capped at num_kv).

    Returns:
        Tensor of shape (batch, num_q, num_heads, v_features).
    """
    batch, num_kv, num_heads, k_features = key.shape
    v_features = value.shape[-1]
    query_chunk = query.shape[1]  # b n h d
    key_chunk_size = min(key_chunk_size, num_kv)
    # Scale the queries once rather than every block of scores.
    query = query / k_features**0.5

    # @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scores = torch.einsum('bqhd,bkhd->bqhk', query, key)
        # Stabiliser only: detached so no gradient flows through the max.
        block_max = torch.amax(scores, dim=-1, keepdim=True).detach()
        unnormalized = torch.exp(scores - block_max)
        weighted = torch.einsum('bvhf,bqhv->bqhf', value, unnormalized)
        denominators = unnormalized.sum(dim=-1)
        # (b q h f), (b q h), (b q h 1)
        return weighted, denominators, block_max

    def chunk_scanner(
        chunk_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        origin = (chunk_idx, 0, 0)
        key_chunk = dynamic_slice(
            key, origin,
            slice_sizes=(key_chunk_size, num_heads, k_features)
        )
        value_chunk = dynamic_slice(
            value, origin,
            slice_sizes=(key_chunk_size, num_heads, v_features)
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_starts = np.arange(0, num_kv, key_chunk_size)
    num_chunks = len(chunk_starts)
    # Per-chunk buffers, matched to the dtype/device of the queries.
    chunk_values = torch.zeros(num_chunks, batch, query_chunk, num_heads, v_features).to(query)
    chunk_weights = torch.zeros(num_chunks, batch, query_chunk, num_heads).to(query)
    chunk_max = torch.zeros(num_chunks, batch, query_chunk, num_heads, 1).to(query)
    for slot, start in enumerate(chunk_starts):
        values_i, weights_i, max_i = chunk_scanner(start)
        chunk_values[slot] = values_i
        chunk_weights[slot] = weights_i
        chunk_max[slot] = max_i

    # Re-express every chunk relative to the largest score seen in any chunk.
    reference_max = torch.amax(chunk_max, dim=0, keepdim=True)
    rescale = torch.exp(chunk_max - reference_max)

    chunk_values.mul_(rescale)
    chunk_weights.mul_(rescale[..., 0])

    numerator = chunk_values.sum(dim=0)
    denominator = chunk_weights.unsqueeze(-1).sum(dim=0)
    return numerator / denominator

def torch_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, 
    query_chunk_size=1024, key_chunk_size=4096,
) -> torch.Tensor:
    """ Memory-efficient multi-head dot product attention. 

        Queries are processed in chunks of at most `query_chunk_size` rows.
        Each chunk is attended against all keys/values by
        `torch_query_chunk_attention`, which walks the keys in chunks of
        `key_chunk_size`.

        Args:
            query: (batch, num_q, num_heads, q_features)
            key: (batch, num_kv, num_heads, q_features)
            value: (batch, num_kv, num_heads, v_features)
            query_chunk_size: max number of query rows handled per step.
            key_chunk_size: max number of key/value rows handled per inner step.

        Returns:
            Tensor of shape (batch, num_q, num_heads, v_features).

        Raises:
            ValueError: if `query_chunk_size` is smaller than 1.
    """
    if query_chunk_size < 1:
        raise ValueError(f"query_chunk_size must be >= 1, got {query_chunk_size}")

    batch, num_q, num_heads, q_features = query.shape
    if num_q == 0:
        # Nothing to attend from: return an empty result with the right layout.
        return torch.zeros(batch, 0, num_heads, value.shape[-1]).to(query)

    chunk_outputs = []
    for chunk_start in range(0, num_q, query_chunk_size):
        # Slicing truncates at the end of the sequence, so the last chunk may be shorter.
        query_chunk = dynamic_slice(
            query, (chunk_start, 0, 0),
            slice_sizes=(query_chunk_size, num_heads, q_features)
        )
        chunk_outputs.append(
            torch_query_chunk_attention(query_chunk, key, value, key_chunk_size=key_chunk_size)
        )

    # Each chunk output is (batch, chunk_rows, heads, features). Joining along the
    # query dim keeps every batch entry contiguous; stacking on a new leading
    # dim and reshaping to (batch, num_q, ...) would interleave chunks of
    # different batch entries when batch > 1, and cannot handle a shorter last chunk.
    return torch.cat(chunk_outputs, dim=1)

#### Tests for Pytorch Attention

In [20]:
# Fixed seed so the tolerance-based assertions below see the same inputs on every run.
torch.manual_seed(0)

B, L, D = 1, 2**14, 64
a = torch.randn(B, L, D) # .cuda()
b = a[:, None, :, :]                                           # (b h n d) batch and heads
a_ = jax.numpy.asarray(torch.transpose(a, 0, 1).cpu().numpy()) # (n h d) heads but not batch
b_ = torch.from_numpy( np.asarray(a_) )[None, ...]             # (b n h d) layout taken by torch_attention
c_ = torch.cat([b_, b_], dim=0)                                # same sample twice, to exercise batch > 1
a.shape, a_.shape, b.shape, b_.shape, c_.shape

(torch.Size([1, 16384, 64]),
 (16384, 1, 64),
 torch.Size([1, 1, 16384, 64]),
 torch.Size([1, 16384, 1, 64]),
 torch.Size([2, 16384, 1, 64]))

In [6]:
# test batching works: the first sample of a batch of two identical samples
# must match the result of running that sample on its own
assert torch.allclose(
    torch_attention(b_, b_, b_)[0], # .shape b n h d
    torch_attention(c_, c_, c_)[0], # .shape b n h d
    atol = 1e-6
), "Batching does not work"

# test query chunking works
assert torch.allclose(
    torch_attention(b_, b_, b_, query_chunk_size=32)[0], # .shape b n h d
    torch_attention(b_, b_, b_)[0], # .shape b n h d
    atol = 1e-6
), "Query chunking does not work"

# test key chunking works
assert torch.allclose(
    torch_attention(b_, b_, b_, key_chunk_size=128)[0], # .shape b n h d
    torch_attention(b_, b_, b_)[0], # .shape b n h d
    atol = 1e-6
), "Key chunking does not work"

# test correctness against standard attention.
# torch_attention divides the queries by sqrt(d) internally while qkv2res applies
# no scaling, so the reference queries are scaled here; without this the two
# results only "agree" under a tolerance so loose that the check is meaningless.
kv_std = torch.transpose(b_, 1, 2)            # (b n h d) -> (b h n d)
q_std = kv_std / b_.shape[-1] ** 0.5
assert torch.allclose(
    torch_attention(b_, b_, b_)[0], # .shape b n h d
    torch.transpose( qkv2res(q_std, kv_std, kv_std), 1, 2 )[0], # .shape b n h d
    atol = 1e-5
), "Chunked attention does not match standard attention"

#### Performance Comparison

In [7]:
import warnings
# Silence all warnings so the benchmark output below stays readable.
# NOTE(review): this applies to the whole kernel and also hides genuine
# deprecation/runtime warnings; both calls install a blanket 'ignore' filter,
# so one of them is presumably redundant.
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

In [8]:
# JIT-compiled version of the paper's `attention`, used in the timings below.
# It is only ever called as jax_attn(q, k, v), so the chunk-size arguments stay
# at their Python-int defaults and are not traced.
jax_attn = jax.jit(attention)

In [None]:
# Benchmark the three attention variants on self-attention inputs (q = k = v)
# for sequence lengths 2**6 .. 2**16 with batch=1, heads=1, D=64.
# NOTE: the loop rebinds the notebook-level B, L, D, a, b, a_, b_ on every
# iteration, so after it finishes they hold the L=2**16 inputs.
for exp2 in range(6, 16+1): 
    B, L, D = 1, 2**exp2, 64
    a = torch.randn(B, L, D) # .cuda()
    b = a[:, None, :, :]                                           # (b h n d) batch and heads
    a_ = jax.numpy.asarray(torch.transpose(a, 0, 1).cpu().numpy()) # (n h d) heads but not batch
    b_ = torch.from_numpy( np.asarray(a_) )[None, ...]             # (b n h d) layout taken by torch_attention
    a.shape, a_.shape, b.shape, b_.shape                           # no effect inside a loop (nothing is displayed)
    # attn = Attention(D)
    # %timeit attn(a, a)
    print()
    print(f"Attn w/ heads=1, batch=1, D=64, For L={2**exp2}")
    print("-> jax compiled linear")
    # block_until_ready() makes the timing include the actual computation,
    # not just the asynchronous dispatch.
    %timeit jax_attn(a_, a_, a_).block_until_ready()
    print("-> torch linear")
    %timeit torch_attention(b_, b_, b_)
    print("-> torch standard (einsum is used, but similar to matmul)")
    # qkv2res materialises the full (L x L) score matrix.
    %timeit qkv2res(b, b, b)


Attn w/ heads=1, batch=1, D=64, For L=64
-> jax compiled linear
133 µs ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-> torch linear
439 µs ± 378 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
-> torch standard (einsum is used, but similar to matmul)
182 µs ± 387 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Attn w/ heads=1, batch=1, D=64, For L=128
-> jax compiled linear
129 µs ± 8.91 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-> torch linear
534 µs ± 940 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
-> torch standard (einsum is used, but similar to matmul)
581 µs ± 405 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Attn w/ heads=1, batch=1, D=64, For L=256
-> jax compiled linear
151 µs ± 8.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-> torch linear
893 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
-> torch standard (einsum is used, but similar to matmul)
1.53 m


You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


386 µs ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch linear
25.6 ms ± 72.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
-> torch standard (einsum is used, but similar to matmul)
128 ms ± 196 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Attn w/ heads=1, batch=1, D=64, For L=4096
-> jax compiled linear



You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


1.16 ms ± 14 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch linear
114 ms ± 8.81 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
-> torch standard (einsum is used, but similar to matmul)
576 ms ± 151 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Attn w/ heads=1, batch=1, D=64, For L=8192
-> jax compiled linear



You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


5.11 ms ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch linear
465 ms ± 21.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch standard (einsum is used, but similar to matmul)
2.36 s ± 640 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Attn w/ heads=1, batch=1, D=64, For L=16384
-> jax compiled linear



You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


19.6 ms ± 382 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch linear
1.74 s ± 21.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch standard (einsum is used, but similar to matmul)
9.47 s ± 4.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Attn w/ heads=1, batch=1, D=64, For L=32768
-> jax compiled linear



You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


73.8 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch linear
7.92 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch standard (einsum is used, but similar to matmul)
37.3 s ± 21.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Attn w/ heads=1, batch=1, D=64, For L=65536
-> jax compiled linear



You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.


284 ms ± 670 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch linear
37.5 s ± 467 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
-> torch standard (einsum is used, but similar to matmul)


### Compare to community implementation (L = 2**14)

In [9]:
from memory_efficient_attention import efficient_dot_product_attention_pt

In [21]:
# Time the community implementation on the same self-attention input (q = k = v = b_).
# NOTE(review): the chunk sizes passed here (key=1024, query=4096) differ from
# the defaults of torch_attention (query=1024, key=4096) that the following
# timing cell uses - confirm the comparison is meant to be like-for-like.
%timeit efficient_dot_product_attention_pt(b_, b_, b_, key_chunk_size=1024, query_chunk_size=4096)

0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
2.33 s ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
# Time this notebook's implementation with its default chunk sizes
# (query_chunk_size=1024, key_chunk_size=4096) on the same input.
%timeit torch_attention(b_, b_, b_)

1.74 s ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
