Before running this notebook, do not forget to:
- install both the C++ and the Python dependencies
- run `pip install -e .` in the `./min_flash_attention` directory

In [None]:
from min_flash_attention import *

In [None]:
# Keep the model dimensions small: at larger sizes this minimal implementation is
# slower than manual attention. See caveats in README.
batch_size = 8
n_head = 12
seq_len = 1024
head_embd = 64

# Allocate directly on the GPU so q, k, v are leaf tensors that require grad.
# `torch.randn(..., requires_grad=True).cuda()` would instead return a non-leaf copy
# (its `.grad` is never populated), and it needlessly allocates on the CPU first.
q = torch.randn(batch_size, n_head, seq_len, head_embd, device='cuda', requires_grad=True)
k = torch.randn(batch_size, n_head, seq_len, head_embd, device='cuda', requires_grad=True)
v = torch.randn(batch_size, n_head, seq_len, head_embd, device='cuda', requires_grad=True)

### Vanilla attention

In [None]:
# Our minimal flash attention aims to be faster than this by avoiding HBM read/writes of N^2 matrices.
def vanilla_attention(q, k, v):
    """Reference causal scaled-dot-product attention that materializes the full score matrix.

    Args:
        q, k, v: tensors whose last two dims are (seq_len, head_embd); their leading
            dims must be compatible for ``@``. The setup cell uses
            (batch_size, n_head, seq_len, head_embd).

    Returns:
        softmax(q @ k^T / sqrt(head_embd), causally masked) @ v.
    """
    att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))
    # Causal mask: query position i may only attend to key positions j <= i.
    # Build it as a bool tensor directly on att's device (instead of on the CPU
    # followed by .cuda()), which avoids a host->device copy and also lets the
    # function run on CPU tensors.
    causal = torch.ones(att.size(-2), att.size(-1), dtype=torch.bool, device=att.device).tril(diagonal=0)
    att = att.masked_fill(~causal, float('-inf'))
    att = F.softmax(att, dim=-1)
    return att @ v

# Profile only the forward pass of the reference implementation. no_grad() matches
# the flash-attention cells below, so autograd graph recording does not skew the
# comparison.
with (
    torch.autograd.profiler.profile(use_cuda=True) as prof,
    torch.no_grad(),
):
    manual_result = vanilla_attention(q, k, v)

# print() renders the table; as a bare last expression the notebook would show the
# string's repr with literal '\n' escapes.
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

### Min Flash attention v1

In [None]:
# Profile the minimal flash-attention forward pass; no_grad() keeps autograd
# bookkeeping out of the measurement.
# NOTE(review): the trailing argument 1 appears to select the v1 kernel (cf. the
# heading), and l / m are presumably the per-row softmax statistics (running sum and
# max) from the flash-attention algorithm -- TODO confirm against the extension.
with (
    torch.autograd.profiler.profile(use_cuda=True) as prof,
    torch.no_grad(),
):
    minimal_result, l, m = min_flash_attention.forward(q, k, v, 1)

# print() renders the table; as a bare last expression the notebook would show the
# string's repr with literal '\n' escapes.
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

In [None]:
# Sanity check for v1: every element of the flash output must lie within an absolute
# tolerance of 1e-3 of the vanilla result (rtol=0 disables the relative term).
torch.allclose(minimal_result, manual_result, atol=1e-03, rtol=0)

### Min Flash attention v2

In [None]:
# Profile the v2 forward pass under the same conditions as v1 (no_grad keeps
# autograd bookkeeping out of the measurement).
# NOTE(review): the trailing argument 2 appears to select the v2 kernel (cf. the
# heading); the third output is discarded here -- TODO confirm its meaning against
# the extension.
with (
    torch.autograd.profiler.profile(use_cuda=True) as prof,
    torch.no_grad(),
):
    minimal_result, l, _ = min_flash_attention.forward(q, k, v, 2)

# print() renders the table; as a bare last expression the notebook would show the
# string's repr with literal '\n' escapes.
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

In [None]:
# Sanity check for v2: every element of the flash output must lie within an absolute
# tolerance of 1e-3 of the vanilla result (rtol=0 disables the relative term).
torch.allclose(minimal_result, manual_result, atol=1e-03, rtol=0)