## Self-Attention Does Not Require $O(n^2)$ Memory

Implementation of: https://arxiv.org/pdf/2112.05682.pdf

In [1]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

### Imports

In [2]:
import numpy as np
import torch
import jax

from linear_mem_attention_pytorch import *
from linear_mem_attention_pytorch.utils import qkv2res
from linear_mem_attention_pytorch.fast_attn import Attention
from linear_mem_attention_pytorch import linear_mem_attn_jax

#### Test inputs for PyTorch and JAX attention (one random sequence in each expected layout)

In [3]:
# Build one random sequence and view it in the layouts each implementation expects.
B, L, D = 1, 2**14, 64
a = torch.randn(B, L, D)  # (b n d); append .cuda() to test on GPU
b = a[:, None, :, :]      # (b h n d): batch and heads
# (n h d) for JAX: heads but no batch -- the size-1 batch axis of `a` is reused as the head axis.
a_ = jax.numpy.asarray(torch.transpose(a, 0, 1).cpu().numpy())
# (b n h d) for the PyTorch implementations. np.array (rather than np.asarray) copies the
# possibly read-only JAX buffer, so torch.from_numpy receives a writable array.
b_ = torch.from_numpy(np.array(a_))[None, ...]
c_ = torch.cat([b_, b_], dim=0)  # (b n h d) with batch=2
a.shape, a_.shape, b.shape, b_.shape, c_.shape

(torch.Size([1, 16384, 64]),
 (16384, 1, 64),
 torch.Size([1, 1, 16384, 64]),
 torch.Size([1, 16384, 1, 64]),
 torch.Size([2, 16384, 1, 64]))

#### Performance Comparison

In [10]:
@torch.jit.script
def qkv2res2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Standard (quadratic-memory) softmax attention, compiled with TorchScript.

    Inputs must be in (b h n d) format: q is (b, h, i, d), k is (b, h, j, d) and
    v is (b, h, j, e). Returns a tensor of shape (b, h, i, e).

    Equivalent to ``(q @ k.transpose(-1, -2)).softmax(dim=-1) @ v``. No 1/sqrt(d)
    scaling is applied to the logits.

    NOTE(review): elsewhere in this notebook ``qkv2res`` is called on (b n h d)
    tensors, so this function is not a drop-in replacement for it.
    """
    # Attention weights over the key axis j, for every query position i.
    qk = torch.einsum('b h i d, b h j d -> b h i j', q, k).softmax(dim=-1)
    # Weighted sum of the values.
    return torch.einsum('b h i j, b h j d -> b h i d', qk, v)

In [4]:
# Wrap the JAX linear-memory attention with jax.jit so the benchmark below times compiled code.
jax_attn = jax.jit(linear_mem_attn_jax.attention)

In [13]:
# CPU
for exp2 in range(6, 14+1): 
    B, L, D = 1, 2**exp2, 64
    a = torch.randn(B, L, D) # .cuda()
    b = a[:, None, :, :]                                           # (b h n d) batch and heads
    a_ = jax.numpy.asarray(torch.transpose(a, 0, 1).cpu().numpy()) # (n h d) heads but not batch
    b_ = torch.from_numpy( np.asarray(a_) )[None, ...]
    a.shape, a_.shape, b.shape, b_.shape
    # attn = Attention(D)
    # %timeit attn(a, a)
    print()
    print(f"Attn w/ heads=1, batch=1, D=64, For L={2**exp2}")
    print("-> jax compiled linear")
    %timeit jax_attn(a_, a_, a_).block_until_ready()
    print("-> torch linear")
    %timeit attention(b_, b_, b_, query_chunk_size=1024, key_chunk_size=4096)
    print("-> torch standard (einsum is used, but similar to matmul)")
    %timeit qkv2res(b_, b_, b_)


Attn w/ heads=1, batch=1, D=64, For L=64
-> jax compiled linear
135 µs ± 5.39 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-> torch linear
535 µs ± 327 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
-> torch standard (einsum is used, but similar to matmul)
194 µs ± 58.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Attn w/ heads=1, batch=1, D=64, For L=128
-> jax compiled linear
149 µs ± 5.69 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-> torch linear
638 µs ± 630 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
-> torch standard (einsum is used, but similar to matmul)
653 µs ± 213 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Attn w/ heads=1, batch=1, D=64, For L=256
-> jax compiled linear
153 µs ± 6.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
-> torch linear
997 µs ± 219 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
-> torch standard (einsum is used, but similar to matmul)
2.03 

### Comparison with a community implementation ($L = 2^{14}$)

In [14]:
from memory_efficient_attention import efficient_dot_product_attention_pt

In [15]:
# Time the community PyTorch implementation on the (b n h d) tensor `b_` (L = 2**14).
# Chunk sizes match the "Our implementation" cell below so the two timings are comparable.
# NOTE(review): judging by the output, this function prints its chunk offsets on every call;
# that console I/O happens inside the timed call and also floods the cell output.
print("Community implementation")
%timeit efficient_dot_product_attention_pt(b_, b_, b_, key_chunk_size=1024, query_chunk_size=4096)

Community implementation
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
0
4096
8192
12288
2.55 s ± 63.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
# Time this notebook's `attention` (presumably provided by the star import of
# linear_mem_attention_pytorch) on the same `b_`, with the same chunk sizes as the community run above.
print("Our implementation")
%timeit attention(b_, b_, b_, key_chunk_size=1024, query_chunk_size=4096)

Our implementation
1.88 s ± 92.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
