In [1]:
# Install flash-attn Package (About 20 Min)
!pip install flash-attn==1.0.9 --no-build-isolation -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m19.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.7/180.7 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone


In [2]:
import torch
!nvidia-smi
print("CUDA Usage :", torch.cuda.is_available())
print("GPU :", torch.cuda.get_device_name(0))

Wed Nov 26 15:49:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   39C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
import math
import json
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

# ---------------------------
# FLOPs and Efficiency Calculation
# ---------------------------
def flops(batch_size, seq_len, head_dim, num_heads, causal, mode='fwd'):
    assert mode in ['fwd', 'bwd', 'fwd_bwd']
    f = 4 * batch_size * seq_len**2 * num_heads * head_dim // (2 if causal else 1)
    return f if mode == 'fwd' else (2.5 * f if mode == 'bwd' else 3.5 * f)

def efficiency(flop, time):
    return (flop / time / 10**12) if time > 0 else 0.0

# ---------------------------
# Custom Benchmark Function
# ---------------------------
def benchmark_fwd_bwd(func, qkv, cu_seqlens, dropout_p, causal, repeats=30):
    fwd_times, bwd_times = [], []
    for _ in range(repeats):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        out = func(qkv, cu_seqlens, qkv.shape[0], dropout_p, causal=causal)
        end.record()
        torch.cuda.synchronize()
        fwd_times.append(start.elapsed_time(end) / 1000.0)

        grad = torch.randn_like(out)
        start.record()
        out.backward(grad, retain_graph=True)
        end.record()
        torch.cuda.synchronize()
        bwd_times.append(start.elapsed_time(end) / 1000.0)

    return fwd_times, bwd_times

# ---------------------------
# PyTorch baseline attention
# ---------------------------
def pytorch_attn_func(qkv, dropout_p=0.0, causal=True):
    batch_size, seq_len, _, num_heads, head_dim = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(head_dim)

    scores = torch.empty(batch_size * num_heads, seq_len, seq_len, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=num_heads)
    if causal:
        causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    return output.to(dtype=qkv.dtype)

# ---------------------------
# Main Benchmark Function
# ---------------------------
def benchmark_attention(batch_size, seq_len, num_heads, emb_dim, impl, causal, repeats, output):
    assert impl in ['Pytorch', 'Flash1']
    device = 'cuda'
    dtype = torch.float16
    dropout_p = 0.0
    head_dim = emb_dim // num_heads

    qkv = torch.randn(
        batch_size, seq_len, 3, num_heads, head_dim,
        device=device, dtype=dtype, requires_grad=True
    )

    if impl == 'Flash1':
        total_len = batch_size * seq_len
        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=device)
        qkv_unpad = rearrange(qkv, 'b s three h d -> (b s) three h d')
        attention_func = lambda x, cu, tot, dp, causal: flash_attn_unpadded_qkvpacked_func(
            x, cu, tot, dropout_p=dp, causal=causal
        )
        input_tensor = qkv_unpad
        cu_input = cu_seqlens
    else:
        attention_func = lambda x, cu, tot, dp, causal: pytorch_attn_func(x, dp, causal)
        input_tensor = qkv
        cu_input = None

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    fwd_times, bwd_times = benchmark_fwd_bwd(attention_func, input_tensor, cu_input, dropout_p, causal=causal, repeats=repeats)

    forward_time = sum(fwd_times) / len(fwd_times)
    backward_time = sum(bwd_times) / len(bwd_times)
    peak_memory_usage = torch.cuda.max_memory_allocated() / (1024**2)

    benchmark_result = {
        'forward': {
            'time(s)': forward_time,
            'FLOPS(TFLOPs/s)': efficiency(
                flops(batch_size, seq_len, head_dim, num_heads, causal, mode='fwd'),
                forward_time
            )
        },
        'backward': {
            'time(s)': backward_time,
            'FLOPS(TFLOPs/s)': efficiency(
                flops(batch_size, seq_len, head_dim, num_heads, causal, mode='bwd'),
                backward_time
            )
        },
        'forward_backward': {
            'time(s)': forward_time + backward_time,
            'FLOPS(TFLOPs/s)': efficiency(
                flops(batch_size, seq_len, head_dim, num_heads, causal, mode='fwd_bwd'),
                forward_time + backward_time
            )
        },
        'peak_memory_usage(MB)': peak_memory_usage,
    }

    with open(output, 'w') as json_file:
        json.dump(benchmark_result, json_file, indent=2)

    print(f"Benchmark completed. Results saved to {output}")
    print(json.dumps(benchmark_result, indent=2))


In [4]:
# Run benchmark for FlashAttention v1
benchmark_attention(
    batch_size=16,
    seq_len=512,
    num_heads=8,
    emb_dim=512,
    impl='Flash1',  # Choose between 'Pytorch' or 'Flash1'
    causal=True,
    repeats=30,
    output='flash1_benchmark.json'
)

Benchmark completed. Results saved to flash1_benchmark.json
{
  "forward": {
    "time(s)": 0.0019403253237406412,
    "FLOPS(TFLOPs/s)": 2.2135294754180603
  },
  "backward": {
    "time(s)": 0.007021610633532207,
    "FLOPS(TFLOPs/s)": 1.5291959068084304
  },
  "forward_backward": {
    "time(s)": 0.008961935957272849,
    "FLOPS(TFLOPs/s)": 1.6773591786047992
  },
  "peak_memory_usage(MB)": 112.0009765625
}
