In [None]:
import gc

import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from matplotlib import pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
def scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=False, dropout_p=0.0, scale=None):
    """
    Computes scaled dot-product attention in PyTorch eager mode.

    Eager reference intended to follow the semantics of
    ``torch.nn.functional.scaled_dot_product_attention`` so the two can be
    checked against each other and benchmarked.

    Args:
        query (torch.Tensor): Query tensor of shape (..., L, E), e.g. (batch_size, n_heads, L, head_dim).
        key (torch.Tensor): Key tensor of shape (..., S, E).
        value (torch.Tensor): Value tensor of shape (..., S, Ev).
        attn_mask (torch.Tensor, optional): Mask broadcastable to (..., L, S). For a boolean mask,
            True marks positions that may be attended to. A floating-point mask is added to the
            attention scores. Must be None when ``is_causal`` is True. Defaults to None.
        is_causal (bool, optional): Whether to apply a lower-triangular (top-left aligned) causal
            mask of shape (L, S). Defaults to False.
        dropout_p (float, optional): Dropout probability applied to the attention weights. Dropout
            is applied whenever ``dropout_p > 0``. Defaults to 0.0.
        scale (float, optional): Scale factor for the dot products. Defaults to None, meaning 1/sqrt(E).

    Returns:
        torch.Tensor: The output tensor of shape (..., L, Ev).

    Raises:
        ValueError: If ``attn_mask`` is given together with ``is_causal=True``.
    """
    if is_causal and attn_mask is not None:
        raise ValueError("attn_mask must be None when is_causal=True")

    target_len, source_len = query.size(-2), key.size(-2)

    # Calculate the scale factor
    scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
    attn_weight = query @ key.transpose(-2, -1) * scale_factor

    if is_causal:
        # Build the mask on the inputs' own device, rather than a notebook global.
        # Shape (L, S) broadcasts over all leading dims and supports L != S.
        attn_mask = torch.ones(target_len, source_len, dtype=torch.bool, device=query.device).tril(diagonal=0)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # attn_weight is a fresh tensor, so filling in place is safe and avoids an extra allocation.
            attn_weight = attn_weight.masked_fill_(~attn_mask, -torch.inf)
        else:
            # Additive (float) mask, e.g. 0 / -inf biases.
            attn_weight = attn_weight + attn_mask

    # Compute the scaled dot product attention
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p > 0.0:
        # Like F.sdpa, dropout is applied whenever dropout_p > 0 (there is no train/eval flag here).
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

    return attn_weight @ value

In [2]:
# Benchmark configuration. bench_attention and the plots read these notebook-level globals.
batch_size = 1
seq_len = 64
num_heads = 32
embed_dim = 128  # last dim of each (batch, heads, seq, dim) tensor, i.e. the per-head size
dtype = torch.float16
device = torch.device("cuda")

# Seed the RNG so the sanity check below is reproducible across Restart & Run All.
torch.manual_seed(0)

# Sanity check: the eager implementation must agree with PyTorch's fused SDPA
# before their speed is compared.
query = torch.rand(batch_size, num_heads, seq_len, embed_dim, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, seq_len, embed_dim, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, seq_len, embed_dim, device=device, dtype=dtype)
eager = scaled_dot_product_attention(query, key, value, is_causal=True)
flash = F.scaled_dot_product_attention(query, key, value, is_causal=True)
# assert_close reports the mismatching elements on failure, unlike a bare allclose assert.
torch.testing.assert_close(eager, flash, rtol=1e-03, atol=1e-03)

In [3]:
def bench_attention(seq_len, flash=False, num_repeats=256, num_warmup=4):
    """
    Measures the average time (in ms) to compute causal multi-head attention for one sequence length.

    Random inputs of shape (batch_size, num_heads, seq_len, embed_dim) are built from the
    notebook-level ``batch_size``, ``num_heads``, ``embed_dim``, ``dtype`` and ``device`` globals.

    Args:
        seq_len (int): The length of the input sequence.
        flash (bool, optional): If True, time ``F.scaled_dot_product_attention`` (PyTorch's fused
            kernels). Otherwise, time the eager ``scaled_dot_product_attention``. Defaults to False.
        num_repeats (int, optional): Number of timed calls to average over. Defaults to 256.
        num_warmup (int, optional): Number of untimed calls made before timing starts. Defaults to 4.

    Returns:
        float: The average time per call in milliseconds.

    Raises:
        ValueError: If ``num_repeats`` is less than 1.
    """
    if num_repeats < 1:
        raise ValueError("num_repeats must be at least 1")

    mha = F.scaled_dot_product_attention if flash else scaled_dot_product_attention

    shape = (batch_size, num_heads, int(seq_len), embed_dim)
    query = torch.rand(shape, device=device, dtype=dtype)
    key = torch.rand(shape, device=device, dtype=dtype)
    value = torch.rand(shape, device=device, dtype=dtype)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup (kernel selection, allocator caching) is excluded from the measurement.
    for _ in range(num_warmup):
        mha(query, key, value, is_causal=True)

    # Drain the warmup work so the timed window covers only the measured calls
    # (matches what bench_llm does).
    torch.cuda.synchronize()
    start.record()
    for _ in range(num_repeats):
        mha(query, key, value, is_causal=True)
    end.record()
    # elapsed_time is only valid after the end event has completed.
    torch.cuda.synchronize()

    return start.elapsed_time(end) / num_repeats

In [None]:
# Sweep context lengths from 256 up to (but excluding) 4096 in steps of 64.
# At each length, time the fused kernel first, then the eager implementation.
context_len = np.arange(256, 4096, 64)
flash = np.zeros(context_len.shape)
standard = np.zeros_like(flash)

for idx, length in enumerate(tqdm(context_len)):
    flash[idx] = bench_attention(length, flash=True)
    standard[idx] = bench_attention(length, flash=False)

In [None]:
# Speedup of the fused kernel over eager attention (values > 1 mean the fused kernel is faster).
fig, ax = plt.subplots()
ax.plot(context_len, standard / flash)
ax.set_xlabel('Sequence length')
ax.set_ylabel('Speedup')
# Derive the configuration in the title from the globals, so it stays correct if they change.
ax.set_title(
    f'FlashAttention vs. Standard Attention, head_size={embed_dim}, n_heads={num_heads}, bs={batch_size}'
)
plt.show()

In [6]:
def bench_llm(seq_len, model_name, max_new_tokens=128, flash=False, num_repeats=256):
    """
    Benchmark the end-to-end ``generate`` latency of a Hugging Face causal LM with FlashAttention-2 or eager attention.

    The model is loaded in bfloat16 onto CUDA for this call only. It is explicitly released
    before returning, so that successive calls (different models or settings) do not pile up GPU memory.

    Args:
        seq_len (int): Length of the random prompt. Token ids are drawn uniformly from [1, 10000).
        model_name (str): Name or path passed to ``AutoModelForCausalLM.from_pretrained``.
        max_new_tokens (int, optional): Number of tokens to generate per call. Generation is forced
            to produce exactly this many tokens so that timings are comparable. Defaults to 128.
        flash (bool, optional): If True, load with ``attn_implementation="flash_attention_2"``.
            Otherwise, load with ``"eager"``. Defaults to False.
        num_repeats (int, optional): Number of timed ``generate`` calls to average over. Defaults to 256.

    Returns:
        float: The average latency per ``generate`` call in milliseconds
        (``torch.cuda.Event.elapsed_time`` reports milliseconds).

    Raises:
        ValueError: If ``num_repeats`` is less than 1.
    """
    if num_repeats < 1:
        raise ValueError("num_repeats must be at least 1")

    mech = "flash_attention_2" if flash else "eager"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=mech,
        device_map="cuda",
    )
    try:
        input_ids = torch.randint(1, 10000, size=(1, int(seq_len)), device='cuda')
        inputs = {
            'input_ids': input_ids,
            # All-ones integer mask (no padding), instead of a float32 mask.
            'attention_mask': torch.ones_like(input_ids),
        }
        generate_kwargs = dict(
            max_new_tokens=max_new_tokens,
            # Keep EOS from ending generation early. Otherwise flash and eager runs, whose
            # outputs can differ numerically, could generate different numbers of tokens.
            min_new_tokens=max_new_tokens,
            # EOS is reused as the pad id. Presumably the base checkpoints define no pad token; verify per model.
            pad_token_id=model.config.eos_token_id,
        )

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        # Warmup
        for _ in range(4):
            model.generate(**inputs, **generate_kwargs)

        torch.cuda.synchronize()
        start.record()
        for _ in range(num_repeats):
            model.generate(**inputs, **generate_kwargs)
        end.record()
        torch.cuda.synchronize()

        return start.elapsed_time(end) / num_repeats
    finally:
        # Drop the model and return cached blocks before the next (possibly different) model loads.
        del model
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Prefill benchmark: with max_new_tokens=1, each generate() call times processing the prompt
# plus a single generated token. Rows index models; columns index prompt lengths.
seq_lens = np.array([512, 1024, 2048])
model_names = ['mistralai/Mistral-7B-v0.1', 'NousResearch/Meta-Llama-3-8B', 'microsoft/phi-2']
prefill_flash = np.zeros((len(model_names), len(seq_lens)))
prefill_standard = np.zeros_like(prefill_flash)

# The inner loop variable is deliberately NOT called `seq_len`; that would overwrite the
# `seq_len` config global used by the attention cells.
for model_idx, model_name in enumerate(tqdm(model_names)):
    for len_idx, prompt_len in enumerate(seq_lens):
        prefill_flash[model_idx, len_idx] = bench_llm(prompt_len, model_name=model_name, max_new_tokens=1, flash=True)
        prefill_standard[model_idx, len_idx] = bench_llm(prompt_len, model_name=model_name, max_new_tokens=1, flash=False)

In [None]:
# Speedup > 1 means the flash_attention_2 load was faster. Rows are models; columns are prompt lengths.
speedup_prefill = prefill_standard / prefill_flash

# Short display labels: one per entry of `model_names`, in the same order.
models = ["Mistral-7B", "Llama3-8B", "Phi-2"]
assert len(models) == speedup_prefill.shape[0], "need one display label per benchmarked model"

# One bar series per prompt length, keyed by that length. Each series holds one speedup per model.
avg_speedup = {str(s): speedup_prefill[:, k] for k, s in enumerate(seq_lens)}

x = np.arange(len(models))  # one group of bars per model
n_series = len(avg_speedup)
width = min(0.25, 0.8 / n_series)  # shrink bars if more series are added so groups don't overlap

fig, ax = plt.subplots(layout='constrained')

for multiplier, (attribute, measurement) in enumerate(avg_speedup.items()):
    offset = width * multiplier
    rects = ax.bar(x + offset, measurement, width, label=attribute)
    ax.bar_label(rects, fmt='%.2f', padding=3)

ax.legend(loc='upper left', ncols=1, title='Sequence length')
# Center each model's tick under its whole group of bars.
ax.set_xticks(x + width * (n_series - 1) / 2, models)
ax.set_ylabel('Speedup')
ax.set_title('Flash Attention vs Standard Attention Prefill Latency')
plt.savefig('benchmark-llm.png')
plt.show()