In [None]:
# Check that an NVIDIA GPU is visible to this runtime (the table also lists driver and CUDA versions)
!nvidia-smi

Sun Apr 13 17:48:13 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   32C    P0             50W /  400W |    1083MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import triton
import triton.language as tl

@triton.autotune(
    configs = [
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
    ],
    key = ['seq_len']
)
@triton.jit
def _flash_attention_forward(q_ptr, k_ptr, v_ptr, o_ptr, batch_size, num_heads,
                             seq_len, head_dim:tl.constexpr,
                             q_stride_b, q_stride_h, q_stride_s, q_stride_d,
                             k_stride_b, k_stride_h, k_stride_s, k_stride_d,
                             v_stride_b, v_stride_h, v_stride_s, v_stride_d,
                             o_stride_b, o_stride_h, o_stride_s, o_stride_d,
                             scale: float,
                             BLOCK_SIZE: tl.constexpr):
  """
  Compute one output row of softmax(Q K^T) V with an online (streaming) softmax,
  processing keys/values in tiles so the full seq_len x seq_len attention matrix is
  never materialized.

  The launch grid is (batch_size, num_heads, seq_len): the program at
  (batch_id, head_id, seq_id) writes output row o[batch_id, head_id, seq_id, :].

  Args:
     q_ptr, k_ptr, v_ptr, o_ptr: pointers to query, key, value and output tensors
     batch_size: number of samples in each batch (not read in the body; the grid encodes it)
     num_heads: number of attention heads (not read in the body; the grid encodes it)
     seq_len: number of tokens in a sequence, shared by queries, keys and values
     head_dim: number of features per token; must be a power of two because it is
       the extent of tl.arange
     *_stride_b / *_stride_h / *_stride_s / *_stride_d: element strides for moving to the
       next batch / attention head / token / feature of the corresponding tensor
     scale: softmax scaling factor (typically 1/sqrt(head_dim)). It is NOT applied in
       this kernel; the caller must pre-multiply q by it
     BLOCK_SIZE: number of key/value rows processed per loop iteration (one tile),
       selected by the autotuner
  """
  # One program per (batch, head, query position)
  batch_id = tl.program_id(0)
  head_id = tl.program_id(1)
  seq_id = tl.program_id(2)

  # Feature indices 0..head_dim-1, shared by every load/store below
  d_idx = tl.arange(0, head_dim)

  # Base offsets of this program's query row and output row; each tensor uses its own strides
  q_offset = q_stride_b * batch_id + q_stride_h * head_id + q_stride_s * seq_id
  o_offset = o_stride_b * batch_id + o_stride_h * head_id + o_stride_s * seq_id

  # Base offsets of the (batch, head) slice of K and V; the token offset is added per tile
  k_base = k_stride_b * batch_id + k_stride_h * head_id
  v_base = v_stride_b * batch_id + v_stride_h * head_id

  # Load the query row; all arithmetic is carried out in float32
  q = tl.load(q_ptr + q_offset + d_idx * q_stride_d).to(tl.float32)

  # Online-softmax state:
  #   m_i: max score seen so far
  #   d_i: sum of exp(score - m_i) over keys seen so far (softmax denominator)
  #   acc: sum of exp(score - m_i) * v over keys seen so far (unnormalized output)
  m_i = float('-inf')
  d_i = 0.0
  acc = tl.zeros([head_dim], dtype=tl.float32)

  tile_idx = tl.arange(0, BLOCK_SIZE)
  for seq_offset in range(0, seq_len, BLOCK_SIZE):
    # Absolute key positions covered by this tile; the last tile may run past seq_len
    k_pos = seq_offset + tile_idx
    mask = k_pos < seq_len

    # 2D (BLOCK_SIZE, head_dim) block of key vectors
    k_ptrs = k_ptr + k_base + k_pos[:, None] * k_stride_s + d_idx[None, :] * k_stride_d
    k_block = tl.load(k_ptrs, mask=mask[:, None], other=0.0).to(tl.float32)

    # Scores q . k_j for each key row in the tile; padded rows get -inf
    s_ij = tl.sum(q[None, :] * k_block, axis=1)   # Shape: [BLOCK_SIZE]
    s_ij = tl.where(mask, s_ij, float('-inf'))

    # New running max; maximum(-inf, x) == x, so the first tile needs no special case
    m_i_new = tl.maximum(m_i, tl.max(s_ij, axis=0))

    # Rescale factor for state accumulated against the old max (exp(-inf) == 0 on the first tile)
    alpha = tl.exp(m_i - m_i_new)

    # Exponentials are taken relative to the NEW running max so that they share the same
    # reference point as the rescaled acc / d_i (using the tile max here over-weights
    # tiles whose max is below the running max)
    p_ij = tl.where(mask, tl.exp(s_ij - m_i_new), 0.0)

    # 2D (BLOCK_SIZE, head_dim) block of value vectors
    v_ptrs = v_ptr + v_base + k_pos[:, None] * v_stride_s + d_idx[None, :] * v_stride_d
    v_block = tl.load(v_ptrs, mask=mask[:, None], other=0.0).to(tl.float32)

    acc = acc * alpha + tl.sum(p_ij[:, None] * v_block, axis=0)
    d_i = d_i * alpha + tl.sum(p_ij, axis=0)
    m_i = m_i_new

  # The key with the maximal score contributes exp(0) == 1, so d_i >= 1 and no epsilon is needed
  o = acc / d_i

  # Store the finished row after all tiles are processed
  tl.store(o_ptr + o_offset + d_idx * o_stride_d, o)



In [None]:
import warnings

import torch

class FlashAttention(torch.nn.Module):
  """
    Flash attention module that implements memory efficient attention computation.

    The forward pass launches the Triton kernel `_flash_attention_forward` over a
    (batch_size, num_heads, seq_len) grid; inputs are converted to float32 first.
  """
  def __init__(self, dropout=0.0) -> None:
    """
    Args:
      dropout: attention dropout probability. It is stored on the module but NOT applied
        by this implementation; a value > 0 only emits a warning.
    """
    super().__init__()
    if dropout > 0.0:
      warnings.warn("dropout is not implemented in this version!", stacklevel=2)

    self.dropout = dropout

  def forward(self, q, k, v, scale=None):
    """
      Compute attention with the flash attention (online softmax) algorithm. Assumes the
      sequence length is the same for query, key and value tensors.

      Args:
        q(torch.Tensor): query tensor of shape (batch_size, num_heads, seq_len, head_dim)
        k(torch.Tensor): key tensor of shape (batch_size, num_heads, seq_len, head_dim)
        v(torch.Tensor): value tensor of shape (batch_size, num_heads, seq_len, head_dim)
        scale(float): multiplier applied to Q K^T before the softmax,
          defaults to 1/sqrt(head_dim)

      Returns:
        torch.Tensor: output tensor of shape (batch_size, num_heads, seq_len, head_dim)
        in the same dtype as q (the computation itself runs in float32)

      Raises:
        AssertionError: if k / v shapes don't match q, or head_dim is not a power of two
    """
    # Verify shape consistency in q, k, v
    (batch_size, num_heads, seq_len, head_dim) = q.shape
    assert k.shape[:3] == (batch_size, num_heads, seq_len), \
      f"Key tensor shape {k.shape[:3]} doesn't match query tensor."
    assert v.shape[:3] == (batch_size, num_heads, seq_len), \
      f"Value tensor shape {v.shape[:3]} doesn't match query tensor."

    # The kernel uses one head_dim for q, k and v, so the last dimensions must all agree
    assert k.shape[3] == v.shape[3] == head_dim, \
      f"Key, value head dimension {k.shape[3]}, {v.shape[3]} don't match query tensor."

    # head_dim is the extent of tl.arange inside the kernel, which requires a power of two
    assert head_dim > 0 and (head_dim & (head_dim - 1)) == 0, \
      f"head_dim must be a power of two, got {head_dim}."

    if scale is None:
      scale = 1.0 / (head_dim ** 0.5)

    # Remember the caller's dtype so the result can be returned in it
    out_dtype = q.dtype

    # Contiguous float32 copies: simple strides for the kernel and float32 accumulation
    q = q.contiguous().to(torch.float32)
    k = k.contiguous().to(torch.float32)
    v = v.contiguous().to(torch.float32)

    # The scale is folded into q here; the kernel does not apply `scale` itself
    q = q * scale

    # Output buffer; every element is written by the kernel
    o = torch.empty_like(q)

    # Get the strides of each tensor
    q_stride_b, q_stride_h, q_stride_s, q_stride_d = q.stride()
    k_stride_b, k_stride_h, k_stride_s, k_stride_d = k.stride()
    v_stride_b, v_stride_h, v_stride_s, v_stride_d = v.stride()
    o_stride_b, o_stride_h, o_stride_s, o_stride_d = o.stride()

    # One program per (batch, head, query position)
    grid = (batch_size, num_heads, seq_len)

    # Launch the Triton kernel (BLOCK_SIZE is supplied by the autotuner)
    _flash_attention_forward[grid](q, k, v, o,
                             batch_size, num_heads, seq_len, head_dim,
                             q_stride_b, q_stride_h, q_stride_s, q_stride_d,
                             k_stride_b, k_stride_h, k_stride_s, k_stride_d,
                             v_stride_b, v_stride_h, v_stride_s, v_stride_d,
                             o_stride_b, o_stride_h, o_stride_s, o_stride_d,
                             scale)

    # Match the input dtype, as the reference einsum implementation does
    return o.to(out_dtype)


In [None]:
def std_attention(q, k, v, scale=None):
  """
    Reference attention that materializes the full attention matrix.

    Args:
      q: query tensor indexed (batch, head, query position, feature)
      k: key tensor indexed (batch, head, key position, feature)
      v: value tensor indexed (batch, head, key position, feature)
      scale: multiplier applied to Q K^T before the softmax; defaults to
        1/sqrt(head_dim) where head_dim = q.shape[-1]

    Returns:
      softmax(scale * Q K^T, dim=-1) V, indexed (batch, head, query position, feature)
  """
  if scale is None:
    scale = 1.0 / (q.shape[-1] ** 0.5)

  # Step 1: scaled scores Q * K^T, one row per query position
  scores = torch.einsum('bnsd,bntd->bnst', q, k) * scale

  # Step 2: softmax over key positions (last dim)
  weights = torch.softmax(scores, dim=-1)

  # Step 3: weighted sum of value vectors
  return torch.einsum('bnst,bntd->bnsd', weights, v)

In [None]:
def benchmark_attention(q, k, v, flash_attn, iterations=10, warmup=3):
    """
    Benchmark standard attention vs flash attention.

    Each implementation is first called `warmup` times untimed, so one-off costs
    (e.g. Triton compilation/autotuning on the first FlashAttention call) are excluded.
    It is then called `iterations` times under a wall-clock timer bracketed by
    torch.cuda.synchronize(), because CUDA work is launched asynchronously.

    Args:
        q, k, v: input tensors passed unchanged to both implementations; CUDA
            tensors are assumed because torch.cuda.synchronize() is used
        flash_attn: callable invoked as flash_attn(q, k, v)
        iterations: number of timed calls per implementation (must be >= 1)
        warmup: number of untimed calls per implementation before timing

    Returns:
        tuple[float, float]: average seconds per call for (standard, flash) attention

    Raises:
        ValueError: if iterations < 1
    """
    import time

    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    def _seconds_per_call(fn):
        # Untimed calls absorb compilation / autotuning / allocator warm-up
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()  # Drain queued CUDA work before starting the clock
        start = time.perf_counter()
        for _ in range(iterations):
            fn()
        torch.cuda.synchronize()  # Wait for the timed kernels to actually finish
        return (time.perf_counter() - start) / iterations

    print("\n----- Performance Benchmark -----")

    std_time = _seconds_per_call(lambda: std_attention(q, k, v))
    flash_time = _seconds_per_call(lambda: flash_attn(q, k, v))

    print(f"Standard attention time: {std_time * 1e3:.3f} ms")
    print(f"Flash attention time: {flash_time * 1e3:.3f} ms")
    print(f"Speedup: {std_time / flash_time:.2f}x")

    return std_time, flash_time

In [None]:
# Example usage and testing functions
def test_flash_attention(batch_size=4, num_heads=8, seq_len=512, head_dim=32,
                         atol=1e-4, benchmark=True, seed=0):
  """
    Test if output from standard attention and flash attention match. CUDA device only.

    Args:
      batch_size, num_heads, seq_len, head_dim: shape of the random q, k, v tensors
      atol: largest absolute difference still reported as a match
      benchmark: if True, also run benchmark_attention on the same inputs
      seed: passed to torch.manual_seed so the random inputs are reproducible

    Returns:
      float: maximum absolute difference between flash and standard attention outputs
  """
  torch.manual_seed(seed)

  print(f"Testing with a batch size of {batch_size}, number of attention heads {num_heads}, sequence length {seq_len} and head dimension {head_dim}.")

  # Create random input tensors on GPU
  shape = (batch_size, num_heads, seq_len, head_dim)
  q = torch.randn(shape, device="cuda")
  k = torch.randn(shape, device="cuda")
  v = torch.randn(shape, device="cuda")

  # Compute output with standard attention algorithm
  print("Computing Standard Attention")
  std_output = std_attention(q, k, v)

  # Compute output with flash attention algorithm
  print("Computing Flash Attention")
  flash_attn = FlashAttention()
  flash_output = flash_attn(q, k, v)

  # Compare results from both algorithms
  max_diff = (flash_output - std_output).abs().max().item()
  print(f"Max difference between flash attention and standard attention: {max_diff}")

  if max_diff < atol:
    print("Results match within error tolerance.")
  else:
    print("Results don't match. Check your algorithms!")

  # ----- Benchmark performance -----
  if benchmark:
    benchmark_attention(q, k, v, flash_attn)

  return max_diff

test_flash_attention();
# A 512-token sequence may be covered by a single BLOCK_SIZE=512 tile; 1000 tokens always
# needs several tiles plus a partial last tile, exercising the running-max rescaling path.
test_flash_attention(seq_len=1000, benchmark=False);

Testing with a batch size of 4, number of attention heads 8, sequence length 512 and head dimension 32.
Computing Standard Attention
Computing Flash Attention
Max difference between flash attention and standard attention: 1.4901161193847656e-06
Results match within error tolerance.

----- Performance Benchmark -----
Standard attention time: 0.0004 s
Flash attention time: 4.3755 s
Speedup: 0.00x
