# Efficient Attention Mechanism

- Input
    - query, key, value : (batch_size, num_heads, seq_len, embed_dim)

- Output
    - attention value : (batch_size, seq_len, embed_dim)

## Input

In [32]:
import torch

SEQ = 512
query = torch.rand(32, 8, SEQ, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, SEQ, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, SEQ, 64, dtype=torch.float16, device="cuda")

print(f"{query.shape, key.shape, value.shape}")

(torch.Size([32, 8, 512, 64]), torch.Size([32, 8, 512, 64]), torch.Size([32, 8, 512, 64]))


## Standard Attention

In [33]:
import torch.nn as nn
import math

class StandardAttention(nn.Module):
    def __init__(self, dropout, device):
        super().__init__()
        self.device = device
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, query, key, value, mask=None):
        trg_len, src_len = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(key.size(-1))
        attn_bias = torch.zeros(trg_len, src_len, dtype=query.dtype, device=self.device)

        # masking
        if mask is not None:
            mask = mask.to(self.device)
            if mask.dtype == torch.bool:
                attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            else:
                attn_bias += mask

        # operation
        attn_score = query @ key.transpose(-2, -1) * scale_factor # QK^T / scale
        attn_score += attn_bias # masking value
        attn_prob = self.softmax(attn_score)
        attn_prob = self.dropout(attn_prob)

        return attn_prob @ value

In [34]:
attention = StandardAttention(dropout=0.0, device="cuda")
attn = attention(query, key, value)

print(f"{attn.shape}")

torch.Size([32, 8, 512, 64])


## Compare various Attention mechanisms

In [46]:
import torch.utils.benchmark as benchmark
import torch.nn.functional as F

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    
    return t0.blocked_autorange().mean * 1e6

def timer(query, key, value, is_causal=False):
    return benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value, is_causal=is_causal)

In [51]:
import torch
import torch.nn.functional as F

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
    attn = F.scaled_dot_product_attention(query, key, value)
    print(f"Standard")
    print(f"    time: {timer(query, key, value): .3f}ms")
    print(f"    shape: {attn.shape}")
    
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
    attn = F.scaled_dot_product_attention(query, key, value, is_causal=True)
    print(f"Causal")
    print(f"    time: {timer(query, key, value, is_causal=True): .3f}ms")
    print(f"    shape: {attn.shape}")

# 지원 X
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False):
    attn = F.scaled_dot_product_attention(query, key, value, is_causal=True)
    print(f"Flash")
    print(f"    time: {timer(query, key, value): .3f}ms")
    print(f"    shape: {attn.shape}")

# 지원 X
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True):
    attn = F.scaled_dot_product_attention(query, key, value)
    print(f"Memory Efficient")
    print(f"    time: {timer(query, key, value): .3f}ms")
    print(f"    shape: {attn.shape}")

Standard
    time:  294525.718ms
    shape: torch.Size([32, 32, 1024, 32])
Causal
    time:  431429.127ms
    shape: torch.Size([32, 32, 1024, 32])
Flash
    True
    time:  290587.018ms
    shape: torch.Size([32, 32, 1024, 32])
Memory Efficient
    True
    time:  2225368.515ms
    shape: torch.Size([32, 32, 1024, 32])
