In [4]:
from flash_attention_torch import flash_attn_torch_fn
from flash_attention_triton import flash_attn_triton_fn

In [7]:
import torch
import triton
import triton.language as tl
import triton.testing as tt
# Run on the GPU when one is available, otherwise fall back to the CPU.
# `DEIVCE` (original misspelling) is kept as an alias because code further
# down in this notebook reads that name.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEIVCE = DEVICE

# ====== Benchmarking forward-pass runtime (ms) =======
# The benchmark function driven by this config reports milliseconds, so the
# axes are labelled as runtime. `plot_name` is left unchanged so the names of
# any saved artifacts stay the same.
configs_memory = tt.Benchmark(
    x_names=["N"],
    # Sequence lengths: powers of two from 2**5 (32) to 2**14 (16384).
    # The sweep previously ran up to 2**127, which cannot be allocated;
    # raise the upper exponent only as far as device memory allows.
    x_vals=[2 ** i for i in range(5, 15)],
    x_log=True,  # x values are powers of two, so space them evenly
    line_arg="provider",
    line_vals=["triton", "torch"],
    line_names=["Triton", "Torch"],
    xlabel="Sequence length N",
    ylabel="Runtime (ms)",
    plot_name="flash-attention-bandwidth",
    args={},
)
@tt.perf_report(configs_memory)
def benchmark_fa_runtime(N, provider):
    """Time one attention implementation at sequence length ``N``.

    Args:
        N: Sequence length. Q, K and V are each created as float32 tensors
            of shape ``(1, N, 128)``.
        provider: ``"triton"`` to time ``flash_attn_triton_fn`` or
            ``"torch"`` to time ``flash_attn_torch_fn``.

    Returns:
        A ``(median, min, max)`` tuple of the timings returned by
        ``tt.do_bench`` for the 0.5, 0.0 and 1.0 quantiles.

    Raises:
        ValueError: If ``provider`` is not one of the names above.
    """
    implementations = {
        "triton": flash_attn_triton_fn,
        "torch": flash_attn_torch_fn,
    }
    # Resolve the implementation once, outside the timed callable, and reject
    # unknown names instead of silently timing the torch version.
    if provider not in implementations:
        raise ValueError(
            f"unknown provider {provider!r}; expected one of {sorted(implementations)}"
        )
    attention_fn = implementations[provider]

    device = DEIVCE
    dtype = torch.float32

    D = 128
    B = 1
    # NOTE(review): inputs are created with requires_grad=True although only
    # the forward call is timed here -- confirm this is intentional.
    Q = torch.randn((B, N, D), device=device, dtype=dtype, requires_grad=True)
    K = torch.randn((B, N, D), device=device, dtype=dtype, requires_grad=True)
    V = torch.randn((B, N, D), device=device, dtype=dtype, requires_grad=True)

    # NOTE(review): the trailing True is presumably the causal-mask flag --
    # confirm against the two implementations.
    ms_min, ms_med, ms_max = tt.do_bench(
        lambda: attention_fn(Q, K, V, True),
        quantiles=(0.0, 0.5, 1.0),
    )

    return ms_med, ms_min, ms_max

In [None]:
# Run the full sweep; results are saved under `save_path` and the data table
# is not printed to the cell output.
benchmark_fa_runtime.run(
    print_data=False,
    save_path='cs336_systems/triton_kernels/benchmarking',
)