# Efficient Attention Mechanism

- Input
    - query, key, value : (batch_size, num_heads, seq_len, embed_dim)

- Output
    - attention value : (batch_size, seq_len, embed_dim)

## Input

In [5]:
import torch

query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")

## Standard Attention

In [10]:
import torch
import math

def standard_attention(query, key, value, mask=None, dropout=0.0):
    trg_len, src_len = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(key.size(-1)) 
    attn_bias = torch.zeros(trg_len, src_len, dtype=query.dtype)
    
    # masking
    if mask is not None:
        if mask.dtype == torch.bool:
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
        else:
            attn_bias += mask

    # operation
    attn_score = query @ key.transpose(-2, -1) * scale_factor # QK^T / scale
    attn_score += attn_bias # masking value
    attn_prob = torch.softmax(attn_score, dim=-1)
    attn_prob = torch.dropout(attn_prob, dropout, train=True)

    return attn_prob @ value

In [11]:
standard_attention(query, key, value)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

## Pytorch Attention

In [12]:
import torch
import math

def pytorch_attention(query, key, value, mask=None, dropout=0.0, is_causal=False, scale=None):
    trg_len, src_len = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(trg_len, src_len, dtype=query.dtype)

    # lower-triangle만 계산하여 upper-triangle을 예측
    if is_causal:
        assert mask is None
        # upper traingle masking
        tmp_mask = torch.ones(trg_len, src_len, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(tmp_mask.logical_not())

    # masking
    if mask is not None:
        if mask.dtype == torch.bool:
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
        else:
            attn_bias += mask
    
    # operation
    attn_score = query @ key.transpose(-2, -1) * scale_factor # QK^T / scale
    attn_score += attn_bias # masking value
    attn_prob = torch.softmax(attn_score, dim=-1)
    attn_prob = torch.dropout(attn_prob, dropout, train=True)

    return attn_prob @ value

In [14]:
with torch.backends.cuda.sdp_kernel(enable_math=False):
    pytorch_attention(query, key, value)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!