<a href="https://colab.research.google.com/github/23silicon/FlashAttention/blob/main/FlashAttentionBenchmark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install ninja



In [2]:
import torch
import math
import sys
from torch.utils.cpp_extension import load_inline
!rm -rf /root/.cache/torch_extensions/

In [3]:
# Everything below compiles and launches CUDA kernels, so stop early when no GPU is visible.
if not torch.cuda.is_available():
  print("Cuda not available")
  # exit() is the interactive helper injected by the site module and is not meant for
  # programs; sys.exit is the supported call, and a non-zero status marks the failure.
  sys.exit(1)

cuda_flash_attention = """
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>

#define Br 32 //canonical Q tile height name in FlashAttention paper
#define Bc 32 //canonical K/V tiles height
#define FLASH_MAX_D 128 //capacity of the per-thread output accumulator; d must not exceed this

/*
 * Single-head attention forward pass, output = softmax(scale * Q K^T) V, computed tile by tile
 * with an online (streaming) softmax so the M x N score matrix is never stored.
 *
 * Launch contract:
 *   - 1D grid with one block per tile of Br query rows and blockDim.x == Br; thread t of a block
 *     owns query/output row blockIdx.x * Br + t.
 *   - dynamic shared memory of (Br + 2 * Bc) * d floats: one Q tile, one K tile, one V tile.
 *   - Q is M x d, K and V are N x d, output is M x d, all row-major and densely packed.
 *   - d <= FLASH_MAX_D; for a larger d the kernel returns without touching output.
 *   - when d is a multiple of 4 the data is moved as float4, which requires Q, K, V and output
 *     to be 16-byte aligned.
 * Tr is kept for interface compatibility and is not used; Tc is the number of K/V tiles.
 */
__global__ void flash_attention(const float* Q, const float* K, const float* V, float* output,
                                int M, int N, int d, int Tr, int Tc, float scale) {
    //acc below has a fixed size. The test is uniform across the block, so returning before the
    //first barrier is safe and avoids writing past the end of the array.
    if (d > FLASH_MAX_D) return;

    const int q_row0 = blockIdx.x * Br;     //first global Q row covered by this block
    const int id = q_row0 + threadIdx.x;    //global Q/output row owned by this thread
    const bool vec4 = (d % 4 == 0);
    const int d4 = d >> 2;                  //row length in float4 units (only used when vec4)
    const float4 zero4 = make_float4(0.0f, 0.0f, 0.0f, 0.0f);

    extern __shared__ __align__(16) float sram[];
    //pointers to the beginning of each tile's allocated region in shared memory
    float* Qtile = sram; //rows 0-Br are for Q tile, height Br
    float* Ktile = &sram[Br * d]; //rows Br-(Br+Bc) are for K tile, height Bc
    float* Vtile = &sram[(Br + Bc) * d]; //rows (Br+Bc) to end are for Vtile, height Bc

    //loop 1: load this block's Q tile into sram.
    //A tile is one contiguous run of Br * d floats in global memory, so the block copies it
    //cooperatively: consecutive threads read consecutive addresses. Rows at or past M are
    //zero-filled instead of being read, which keeps the load in bounds when M % Br != 0.
    if (vec4) {
        const float4* Q4 = reinterpret_cast<const float4*>(Q);
        float4* Q4tile = reinterpret_cast<float4*>(Qtile);
        for (int v = threadIdx.x; v < Br * d4; v += blockDim.x) {
            const int row = q_row0 + v / d4;
            Q4tile[v] = (row < M) ? Q4[(size_t)q_row0 * d4 + v] : zero4;
        }
    } else {
        for (int e = threadIdx.x; e < Br * d; e += blockDim.x) {
            const int row = q_row0 + e / d;
            Qtile[e] = (row < M) ? Q[(size_t)q_row0 * d + e] : 0.0f;
        }
    }
    __syncthreads();

    //zero-initialised output accumulator. It is indexed with the runtime value d, so the
    //compiler may place it in local memory rather than registers.
    float acc[FLASH_MAX_D] = {0.0f}; //O matrix accumulator
    float l = 0.0f; //running denominator sum for softmax
    float m = -1e30f; //starts with a very low number, basically -infinity

    //ENTERING MAIN LOOP: here we stream tiles of K and V
    for (int i = 0; i < Tc; i++) {
        const int kv_row0 = i * Bc; //first global K/V row of this tile

        //load K/V with the same cooperative pattern; rows at or past N become zeros
        if (vec4) {
            const float4* K4 = reinterpret_cast<const float4*>(K);
            const float4* V4 = reinterpret_cast<const float4*>(V);
            float4* K4tile = reinterpret_cast<float4*>(Ktile);
            float4* V4tile = reinterpret_cast<float4*>(Vtile);
            for (int v = threadIdx.x; v < Bc * d4; v += blockDim.x) {
                const int row = kv_row0 + v / d4;
                if (row < N) {
                    const size_t g = (size_t)kv_row0 * d4 + v;
                    K4tile[v] = K4[g];
                    V4tile[v] = V4[g];
                } else {
                    K4tile[v] = zero4;
                    V4tile[v] = zero4;
                }
            }
        } else {
            for (int e = threadIdx.x; e < Bc * d; e += blockDim.x) {
                const int row = kv_row0 + e / d;
                const bool valid = (row < N);
                const size_t g = (size_t)kv_row0 * d + e;
                Ktile[e] = valid ? K[g] : 0.0f;
                //V tile is stored by row as well so the weighted sum below can stream over it
                Vtile[e] = valid ? V[g] : 0.0f;
            }
        }
        __syncthreads();

        //Next step: compute dot product of this thread's Q row and every Ktile row (dim: Bc x d)
        float attention_scores[Bc];
        float blockmax = -1e30f;
        #pragma unroll //compiler hint to unroll cus Bc is known at compile time
        for (int row = 0; row < Bc; row++) {
            const int global_K_row = kv_row0 + row;
            if (global_K_row < N) {
                float sum = 0.0f;
                for (int col = 0; col < d; col++) {
                    sum += Qtile[threadIdx.x * d + col] * Ktile[row * d + col];
                }
                sum *= scale;
                blockmax = max(blockmax, sum);
                attention_scores[row] = sum;
            } else {
                //padded rows get a score of -inf instead of 0 because e^0 is 1, not 0.
                attention_scores[row] = -1e30f;
            }
        }

        //online safe softmax

        //part 1: subtract blockmax from each score, exponentiate, and sum the tile
        float blocksum = 0.0f;
        for (int j = 0; j < Bc; j++) {
            attention_scores[j] = __expf(attention_scores[j] - blockmax);
            blocksum += attention_scores[j];
        }

        //part 2: calculate new global max and the factors that rebase old and new terms onto it
        const float newmax = max(m, blockmax);
        const float scale_f1 = __expf(m - newmax);        //rescales everything accumulated so far
        const float scale_fb = __expf(blockmax - newmax); //rescales this tile's contribution

        m = newmax;
        l = (l * scale_f1) + (blocksum * scale_fb); //apply scaling factors

        //part 3: adjust and update accumulator
        for (int col = 0; col < d; col++) {
            acc[col] *= scale_f1;
        }
        for (int j = 0; j < Bc; j++) {
            const float scaled_p = attention_scores[j] * scale_fb;
            for (int col = 0; col < d; col++) {
                acc[col] += scaled_p * Vtile[j * d + col];
            }
        }

        //every thread must be done with this K/V tile before the next iteration overwrites it
        __syncthreads();
    }

    //divide each element by l to complete the softmax and write to output
    if (id < M) { //final matrix is M x d, this thread writes the entire vector of length d to row @id
        const float divL = 1.0f / l; //avoid repeated division
        if (vec4) {
            float4* O4 = reinterpret_cast<float4*>(output);
            for (int col = 0; col < d4; col++) {
                float4 res;
                res.x = acc[col * 4 + 0] * divL;
                res.y = acc[col * 4 + 1] * divL;
                res.z = acc[col * 4 + 2] * divL;
                res.w = acc[col * 4 + 3] * divL;
                O4[(size_t)id * d4 + col] = res;
            }
        } else {
            for (int col = 0; col < d; col++) {
                output[(size_t)id * d + col] = acc[col] * divL;
            }
        }
    }
}

// Q, K, V, output are device pointers
// Host launcher for flash_attention: validates the sizes, derives the tile counts, the dynamic
// shared-memory size and the softmax scale, then launches one block of Br threads per Q tile.
// The launch is asynchronous on the default stream; this function does not synchronize.
// Problems are reported on stderr and the function returns without launching.
extern "C" void solve_flash(const float* Q, const float* K, const float* V, float* output, int M, int N,
                      int d) {
    //Q is Mxd, K is Nxd, V is Nxd
    //QK^T is MxN
    //output is Mxd
    //Newline is written as a character code because this source is embedded in a non-raw
    //Python string, where a backslash escape would be expanded before nvcc sees it.
    const char newline = 10;
    //Must match the length of the kernel's per-thread accumulator array.
    const int kMaxHeadDim = 128;

    if (M <= 0 || N <= 0 || d <= 0) {
        //also avoids launching a grid with zero blocks
        std::fprintf(stderr, "solve_flash: M, N and d must all be positive%c", newline);
        return;
    }
    if (d > kMaxHeadDim) {
        std::fprintf(stderr, "solve_flash: d = %d exceeds the supported maximum of %d%c",
                     d, kMaxHeadDim, newline);
        return;
    }

    int Tr = (M + Br - 1) / Br, Tc = (N + Bc - 1) / Bc; //tile counts
    dim3 threadsPerBlock (Br);
    dim3 blocksPerGrid (Tr);
    //1 tile of Q, 1 tile of K, 1 tile of V; at d == 128 this is 48 KB
    const size_t sram_size = (size_t)(Br + 2 * Bc) * d * sizeof(float);
    const float scale = 1.0f / sqrtf((float)d); //scaling factor multiplied by each element before softmax

    flash_attention<<<blocksPerGrid, threadsPerBlock, sram_size>>>(Q, K, V, output, M, N, d, Tr, Tc, scale);

    //a kernel launch returns nothing; configuration errors only show up through cudaGetLastError
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::fprintf(stderr, "solve_flash: kernel launch failed: %s%c", cudaGetErrorString(err), newline);
    }
}

"""





cuda_self_attention = """
#include <cuda_runtime.h>
#include <stdio.h>
#include <math.h>
#include <cfloat>

const int TILE_SIZE = 16;

// Computes C = factor * (A x B) with one thread per element of the M x N result.
//   A is M x d, row-major.
//   transposed == true : B is N x d, row-major, and the product is A x B^T.
//   transposed == false: B is d x N, row-major.
// Expects a 2D launch where x covers the N columns and y covers the M rows; threads that fall
// outside the result do nothing. Element offsets are computed in size_t so that results with
// 2^31 or more elements are addressed correctly.
__global__ void matrix_multiplication_kernel(const float* A, const float* B, float* C, int M, int d, int N, const float factor, const bool transposed) {
    const int nx = blockIdx.x * blockDim.x + threadIdx.x;
    const int ny = blockIdx.y * blockDim.y + threadIdx.y;
    if (nx >= N || ny >= M) {
        return;
    }

    const size_t row = (size_t)ny;
    const size_t col = (size_t)nx;
    const float* a_row = A + row * (size_t)d;

    float sum = 0.0f;
    if (transposed) {
        const float* b_row = B + col * (size_t)d;
        for (int i = 0; i < d; i++) {
            sum += a_row[i] * b_row[i];
        }
    } else {
        for (int i = 0; i < d; i++) {
            sum += a_row[i] * B[(size_t)i * (size_t)N + col];
        }
    }
    C[row * (size_t)N + col] = sum * factor;
}

// Atomically replaces *address with fmaxf(*address, val) using a compare-and-swap loop on the
// value's 32-bit pattern, and returns the float that was stored there just before the
// successful swap.
__device__ float atomicMaxFloat(float* address, float val) {
    int* const slot = (int*)address;
    int observed = *slot;

    for (;;) {
        const int expected = observed;
        const int desired = __float_as_int(fmaxf(val, __int_as_float(expected)));
        // atomicCAS hands back what was actually in memory; a mismatch means another thread
        // updated the slot in between, so retry against that newer value.
        observed = atomicCAS(slot, expected, desired);
        if (observed == expected) {
            break;
        }
    }

    return __int_as_float(observed);
}

// Folds the maximum of each row of the M x N row-major matrix `input` into row_max[row].
// Expects TILE_SIZE x TILE_SIZE blocks on a 2D grid (x over columns, y over rows). Each block
// reduces its slice of every row in shared memory, then thread x == 0 of each in-range row
// merges that partial maximum into row_max with atomicMaxFloat.
// The kernel only ever raises row_max[row], so the caller has to seed row_max with a value no
// larger than any input element (for example -FLT_MAX) before launching.
__global__ void row_max_kernel(const float* input, float* row_max, const int M, const int N) {
    const int nx = blockIdx.x * blockDim.x + threadIdx.x;
    const int ny = blockIdx.y * blockDim.y + threadIdx.y;

    const int tid = threadIdx.x;
    const int bid = threadIdx.y;
    __shared__ float sd[TILE_SIZE][TILE_SIZE];
    // Out-of-range slots are padded with the lowest finite float so they can never win the
    // maximum. FLT_MIN would be wrong here: it is the smallest positive float, which is larger
    // than every negative score.
    sd[bid][tid] = (nx < N && ny < M) ? input[(size_t)ny * (size_t)N + (size_t)nx] : -FLT_MAX;
    __syncthreads();

    for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
        if (tid < offset && sd[bid][tid] < sd[bid][tid + offset]) {
            sd[bid][tid] = sd[bid][tid + offset];
        }
        __syncthreads();
    }

    // Rows at or beyond M exist only as grid padding; row_max has no entry for them.
    if (tid == 0 && ny < M) {
        atomicMaxFloat(&row_max[ny], sd[bid][0]);
    }
}

// Adds, for each row of the M x N row-major matrix `input`, the sum of
// expf(input[row][col] - row_max[row]) into row_sum[row].
// Expects TILE_SIZE x TILE_SIZE blocks on a 2D grid (x over columns, y over rows). Each block
// reduces its slice of every row in shared memory, then thread x == 0 of each in-range row
// adds the partial sum with atomicAdd. row_sum is accumulated into, so the caller has to zero
// it before launching.
__global__ void row_sum_kernel(const float* input, float* row_max, float* row_sum, const int M, const int N) {
    const int nx = blockIdx.x * blockDim.x + threadIdx.x;
    const int ny = blockIdx.y * blockDim.y + threadIdx.y;

    const int tid = threadIdx.x;
    const int bid = threadIdx.y;
    __shared__ float sd[TILE_SIZE][TILE_SIZE];
    // Padding slots contribute nothing to the sum.
    sd[bid][tid] = (nx < N && ny < M)
        ? expf(input[(size_t)ny * (size_t)N + (size_t)nx] - row_max[ny])
        : 0.0f;
    __syncthreads();

    for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
        if (tid < offset) {
            sd[bid][tid] += sd[bid][tid + offset];
        }
        __syncthreads();
    }

    // Rows at or beyond M exist only as grid padding; row_sum has no entry for them.
    if (tid == 0 && ny < M) {
        atomicAdd(&row_sum[ny], sd[bid][0]);
    }
}

// Writes output[row][col] = expf(input[row][col] - row_max[row]) / row_sum[row] for every
// element of the M x N row-major matrix, one thread per element. Expects a 2D launch where x
// covers columns and y covers rows; threads outside the matrix do nothing. The element offset
// is computed in size_t so matrices with 2^31 or more elements are addressed correctly.
__global__ void softmax_kernel(const float* input, float* output, float* row_max, float* row_sum, const int M, const int N) {
    const int nx = blockIdx.x * blockDim.x + threadIdx.x;
    const int ny = blockIdx.y * blockDim.y + threadIdx.y;
    if (nx >= N || ny >= M) {
        return;
    }
    const size_t idx = (size_t)ny * (size_t)N + (size_t)nx;
    output[idx] = expf(input[idx] - row_max[ny]) / row_sum[ny];
}

// Sets buf[i] = value for every i in [0, n). One thread per element on a 1D launch.
static __global__ void softmax_fill_kernel(float* buf, const float value, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        buf[i] = value;
    }
}

// Row-wise softmax of the M x N row-major device matrix `input` into `output`, in three
// passes: per-row maximum, per-row sum of exponentials, then normalisation. The per-row
// scratch buffers are allocated and released inside this call. All launches go to the default
// stream. Failures are reported on stderr and leave `output` unwritten.
void softmax(const float* input, float* output, int M, int N) {
    // Nothing to normalise, and a grid with a zero dimension is not a valid launch.
    if (M <= 0 || N <= 0) {
        return;
    }
    // Newline is written as a character code because this source is embedded in a non-raw
    // Python string, where a backslash escape would be expanded before nvcc sees it.
    const char newline = 10;

    dim3 threadsPerBlock(TILE_SIZE, TILE_SIZE);
    dim3 blocksPerGrid((N + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (M + threadsPerBlock.y - 1) / threadsPerBlock.y);

    float* row_max = nullptr;
    float* row_sum = nullptr;
    cudaError_t err = cudaMalloc((void **)&row_max, sizeof(float) * M);
    if (err == cudaSuccess) {
        err = cudaMalloc((void **)&row_sum, sizeof(float) * M);
    }
    if (err != cudaSuccess) {
        fprintf(stderr, "softmax: scratch allocation failed: %s%c", cudaGetErrorString(err), newline);
        cudaGetLastError(); // clear the recorded error so it is not attributed to a later call
        cudaFree(row_max);  // freeing a null pointer is a no-op
        cudaFree(row_sum);
        return;
    }

    // row_max_kernel merges into row_max with an atomic max, so the buffer must start below
    // every possible score; cudaMalloc does not initialise memory.
    const int fillThreads = 256;
    softmax_fill_kernel<<<(M + fillThreads - 1) / fillThreads, fillThreads>>>(row_max, -FLT_MAX, M);
    row_max_kernel<<<blocksPerGrid, threadsPerBlock>>>(input, row_max, M, N);

    // row_sum_kernel accumulates with atomicAdd, so the sums start from zero.
    cudaMemset(row_sum, 0, sizeof(float) * M);
    row_sum_kernel<<<blocksPerGrid, threadsPerBlock>>>(input, row_max, row_sum, M, N);

    softmax_kernel<<<blocksPerGrid, threadsPerBlock>>>(input, output, row_max, row_sum, M, N);

    err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "softmax: kernel launch failed: %s%c", cudaGetErrorString(err), newline);
    }
    cudaFree(row_max);
    cudaFree(row_sum);
}

// Q, K, V, output are device pointers
// Reference attention that materialises the full score matrix:
//   qk = (Q x K^T) / sqrt(d)   (M x N)
//   p  = row-wise softmax(qk)  (M x N)
//   output = p x V             (M x d)
// Q is M x d, K and V are N x d, all row-major. Two M x N float buffers are alive at the same
// time, so memory use grows with M * N. Work is issued on the default stream. Failures are
// reported on stderr and leave `output` unwritten.
extern "C" void solve_naive(const float* Q, const float* K, const float* V, float* output, int M, int N, int d) {
    // Newline is written as a character code because this source is embedded in a non-raw
    // Python string, where a backslash escape would be expanded before nvcc sees it.
    const char newline = 10;
    if (M <= 0 || N <= 0 || d <= 0) {
        // also avoids launching grids with a zero dimension
        fprintf(stderr, "solve_naive: M, N and d must all be positive%c", newline);
        return;
    }

    // Computed in size_t: M * N no longer fits in an int once it reaches 2^31.
    const size_t scoreBytes = sizeof(float) * (size_t)M * (size_t)N;
    dim3 threadsPerBlock(TILE_SIZE, TILE_SIZE);

    float* qk = nullptr;
    float* softMaxQK = nullptr;
    cudaError_t err = cudaMalloc((void **)&qk, scoreBytes);
    if (err == cudaSuccess) {
        err = cudaMalloc((void **)&softMaxQK, scoreBytes);
    }
    if (err != cudaSuccess) {
        fprintf(stderr, "solve_naive: score buffer allocation failed: %s%c", cudaGetErrorString(err), newline);
        cudaGetLastError(); // clear the recorded error so it is not attributed to a later call
        cudaFree(qk);       // freeing a null pointer is a no-op
        cudaFree(softMaxQK);
        return;
    }

    dim3 blocksPerGrid((N + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (M + threadsPerBlock.y - 1) / threadsPerBlock.y);
    const float factor = 1.0f / sqrtf((float)d);
    matrix_multiplication_kernel<<<blocksPerGrid, threadsPerBlock>>>(Q, K, qk, M, d, N, factor, true);

    softmax(qk, softMaxQK, M, N);
    cudaFree(qk);

    // second product: inner dimension is N, result is M x d
    dim3 blocksPerGrid2((d + threadsPerBlock.x - 1) / threadsPerBlock.x,
                       (M + threadsPerBlock.y - 1) / threadsPerBlock.y);
    matrix_multiplication_kernel<<<blocksPerGrid2, threadsPerBlock>>>(softMaxQK, V, output, M, N, d, 1.0f, false);

    err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "solve_naive: kernel launch failed: %s%c", cudaGetErrorString(err), newline);
    }
    cudaFree(softMaxQK);
}

"""

cpp_source = """
extern "C" void solve_flash(const float* Q, const float* K, const float* V, float* output, int M, int N, int d);
extern "C" void solve_naive(const float* Q, const float* K, const float* V, float* output, int M, int N, int d);

// Python-facing wrapper around solve_flash.
// Q is (M, d), K and V are (N, d); returns a new (M, d) float32 tensor on Q's device.
// Raises (through TORCH_CHECK) when the inputs are not 2-D float32 CUDA tensors on one device,
// when their shapes disagree, when there are no key rows, or when d exceeds the kernel's limit.
torch::Tensor run_flash(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
    TORCH_CHECK(Q.dim() == 2 && K.dim() == 2 && V.dim() == 2,
                "run_flash: Q, K and V must be 2-D tensors");
    TORCH_CHECK(Q.is_cuda() && K.is_cuda() && V.is_cuda(),
                "run_flash: Q, K and V must be CUDA tensors");
    TORCH_CHECK(K.device() == Q.device() && V.device() == Q.device(),
                "run_flash: Q, K and V must be on the same device");
    TORCH_CHECK(Q.scalar_type() == torch::kFloat32 && K.scalar_type() == torch::kFloat32 &&
                V.scalar_type() == torch::kFloat32,
                "run_flash: Q, K and V must be float32");

    const auto M = Q.size(0);
    const auto N = K.size(0);
    const auto d = Q.size(1);
    TORCH_CHECK(K.size(1) == d && V.size(1) == d,
                "run_flash: K and V must have the same number of columns as Q");
    TORCH_CHECK(V.size(0) == N, "run_flash: K and V must have the same number of rows");
    TORCH_CHECK(N > 0, "run_flash: K and V must contain at least one row");
    // The kernel keeps one output row in a 128-element per-thread array.
    TORCH_CHECK(d <= 128, "run_flash: d must be at most 128, got ", d);
    // solve_flash takes int sizes; refuse values the narrowing below would corrupt.
    TORCH_CHECK(M <= 2147483647 && N <= 2147483647,
                "run_flash: row counts must fit in a 32-bit int");

    // The kernel indexes raw pointers as densely packed row-major data.
    Q = Q.contiguous();
    K = K.contiguous();
    V = V.contiguous();

    auto output = torch::empty_like(Q);
    if (output.numel() == 0) {
        return output; // no query rows (or d == 0): nothing to launch
    }
    solve_flash(Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
                output.data_ptr<float>(), static_cast<int>(M), static_cast<int>(N),
                static_cast<int>(d));
    return output;
}

// Python-facing wrapper around solve_naive.
// Q is (M, d), K and V are (N, d); returns a new (M, d) float32 tensor on Q's device.
// Raises (through TORCH_CHECK) when the inputs are not 2-D float32 CUDA tensors on one device,
// when their shapes disagree, or when there are no key rows.
torch::Tensor run_naive(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
    TORCH_CHECK(Q.dim() == 2 && K.dim() == 2 && V.dim() == 2,
                "run_naive: Q, K and V must be 2-D tensors");
    TORCH_CHECK(Q.is_cuda() && K.is_cuda() && V.is_cuda(),
                "run_naive: Q, K and V must be CUDA tensors");
    TORCH_CHECK(K.device() == Q.device() && V.device() == Q.device(),
                "run_naive: Q, K and V must be on the same device");
    TORCH_CHECK(Q.scalar_type() == torch::kFloat32 && K.scalar_type() == torch::kFloat32 &&
                V.scalar_type() == torch::kFloat32,
                "run_naive: Q, K and V must be float32");

    const auto M = Q.size(0);
    const auto N = K.size(0);
    const auto d = Q.size(1);
    TORCH_CHECK(K.size(1) == d && V.size(1) == d,
                "run_naive: K and V must have the same number of columns as Q");
    TORCH_CHECK(V.size(0) == N, "run_naive: K and V must have the same number of rows");
    TORCH_CHECK(N > 0, "run_naive: K and V must contain at least one row");
    // solve_naive takes int sizes; refuse values the narrowing below would corrupt.
    TORCH_CHECK(M <= 2147483647 && N <= 2147483647,
                "run_naive: row counts must fit in a 32-bit int");

    // The kernels index raw pointers as densely packed row-major data.
    Q = Q.contiguous();
    K = K.contiguous();
    V = V.contiguous();

    auto output = torch::empty_like(Q);
    if (output.numel() == 0) {
        return output; // no query rows (or d == 0): nothing to launch
    }
    solve_naive(Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
                output.data_ptr<float>(), static_cast<int>(M), static_cast<int>(N),
                static_cast<int>(d));
    return output;
}
"""

print("Compiling Kernels...")
# Both CUDA sources are joined into one translation unit, flash-attention source first.
combined_cuda_source = "\n".join((cuda_flash_attention, cuda_self_attention))
kernels = load_inline(
    name="comparison_kernels",
    cpp_sources=cpp_source,
    cuda_sources=combined_cuda_source,
    # Names listed here are exposed as attributes of the returned module.
    functions=["run_flash", "run_naive"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=True,
)
print("Compilation Finished!")

Compiling Kernels...
Compilation Finished!


In [4]:
def profile(func, args, name, warmup=3, iters=10):
    """Return the mean time of one ``func(*args)`` call, measured with CUDA events.

    Args:
        func: Callable to time.
        args: Tuple of positional arguments passed to ``func`` on every call.
        name: Label for the measurement. Accepted for caller compatibility; not used here.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; the elapsed time is divided by this count.

    Returns:
        ``start.elapsed_time(end) / iters``, i.e. the average elapsed time per call in the
        unit reported by ``torch.cuda.Event.elapsed_time``.

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError("iters must be a positive integer")

    # Warmup
    for _ in range(warmup):
        func(*args)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iters):
        func(*args)
    end.record()
    # The events are recorded asynchronously; wait for the queued work before reading them.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def pytorch_sdpa(Q, K, V):
    """Run PyTorch's scaled_dot_product_attention on 2-D inputs.

    Each input is viewed as (1, 1, rows, cols) so it looks like a single batch with a
    single head, and the result is viewed back to Q's 2-D shape.
    """
    def _as_batched(t):
        # SDPA expects 4D input: (Batch, Heads, Seq, Dim); fake Batch=1, Heads=1
        return t.view(1, 1, t.size(0), t.size(1))

    attended = torch.nn.functional.scaled_dot_product_attention(
        _as_batched(Q),
        _as_batched(K),
        _as_batched(V),
    )
    return attended.view(Q.size(0), Q.size(1))

def run_benchmark(d=128, seq_lens=(2<<14, 2<<15, 2<<16), naive_max_n=32768):
    """Time the naive, flash and PyTorch SDPA attention paths and print one table row per size.

    Args:
        d: Number of columns of the random Q, K and V matrices.
        seq_lens: Iterable of row counts N to benchmark (defaults: 32k, 64k, 128k).
        naive_max_n: Largest N for which the naive kernel is run; above it the naive
            column shows "OOM (Skip)".

    A path that raises is shown as a text label instead of a time, its exception is
    printed, and the remaining paths and sizes still run.
    """
    print(f"\n{'N':<8} | {'Naive (ms)':<15} | {'Flash (ms)':<15} | {'PyTorch (ms)':<15} | {'Speedup'}")
    print("-" * 85)

    def _cell(value):
        # Timings are floats; failures and skips are already text labels.
        return value if isinstance(value, str) else f"{value:.4f}"

    for N in seq_lens:

        Q = torch.randn(N, d, device='cuda')
        K = torch.randn(N, d, device='cuda')
        V = torch.randn(N, d, device='cuda')

        # 1. Naive Benchmark (Isolated)
        # Only run Naive if N is reasonable to avoid hanging the colab
        if N <= naive_max_n:
            try:
                t_naive = profile(kernels.run_naive, (Q,K,V), "Naive")
            except Exception as e:
                # Only call it OOM when the error says so; anything else is a plain error.
                t_naive = "OOM" if "out of memory" in str(e).lower() else "Error"
                print(e)
        else:
            t_naive = "OOM (Skip)"

        # 2. Flash Benchmark (Always Run)
        try:
            t_flash = profile(kernels.run_flash, (Q,K,V), "Flash")
        except Exception as e:
            t_flash = "Error"
            print(e)

        # 3. SDPA Benchmark (Always Run)
        try:
            t_sdpa = profile(pytorch_sdpa, (Q,K,V), "SDPA")
        except Exception as e:
            t_sdpa = "Error"
            print(e)

        # Format speedup
        if isinstance(t_naive, str):
            speedup = "Inf"
        elif isinstance(t_flash, str):
            # No flash timing to divide by; dividing by the "Error" label would raise.
            speedup = "N/A"
        else:
            speedup = f"{t_naive/t_flash:.2f}x"

        tn_str = _cell(t_naive)
        tf_str = _cell(t_flash)
        ts_str = _cell(t_sdpa)

        print(f"{N:<8} | {tn_str:<15} | {tf_str:<15} | {ts_str:<15} | {speedup}")

run_benchmark()


N        | Naive (ms)      | Flash (ms)      | PyTorch (ms)    | Speedup
-------------------------------------------------------------------------------------
32768    | 3743.8746       | 6541.5773       | 198.0648        | 0.57x
65536    | OOM (Skip)      | 26291.5719      | 981.7722        | Inf
131072   | OOM (Skip)      | 104102.6187     | 4297.5219       | Inf
