# Flash attention

In [1]:
import numpy as np

# Standard attention (naive implementation) - for reference
def attention_naive(Q, K, V):
    """Compute scaled dot-product attention with a full score matrix.

    Reference implementation: builds every query/key score at once, applies
    a row-wise softmax, and uses the resulting weights to mix the rows of V.

    Args:
        Q: Query array of shape (N, d).
        K: Key array of shape (M, d).
        V: Value array of shape (M, d_v).

    Returns:
        Array of shape (N, d_v); row i is the softmax-weighted average of
        the rows of V for query i.
    """
    d = Q.shape[1]
    # Compute all pairwise scores (N x M matrix)
    scores = Q @ K.T / np.sqrt(d)
    # Shift each row so its largest score is 0 before exponentiating.
    # Softmax is invariant to a per-row shift, and this keeps np.exp from
    # overflowing to inf (which would turn the normalised row into nan).
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    # Multiply weights by V to get outputs
    return weights @ V

# FlashAttention-style attention (tiled softmax computation)
def attention_flash(Q, K, V, block_size):
    """Compute scaled dot-product attention one key/value block at a time.

    Gives the same result as softmax(Q @ K.T / sqrt(d)) @ V, but only an
    (N, block_size) slice of the score matrix exists at any moment. A
    running row maximum and running softmax denominator are maintained, and
    whatever has been accumulated so far is rescaled each time a new block
    raises the maximum, so every key block is visited exactly once.

    Args:
        Q: Query array of shape (N, d).
        K: Key array of shape (M, d); M may differ from N.
        V: Value array of shape (M, d_v).
        block_size: Number of keys handled per step; must be at least 1.

    Returns:
        Array of shape (N, d_v) with the attention output for each query.

    Raises:
        ValueError: If block_size is smaller than 1, or if K and V do not
            have the same number of rows.
    """
    if block_size < 1:
        raise ValueError("block_size must be a positive integer")
    N, d = Q.shape
    num_keys = K.shape[0]
    if V.shape[0] != num_keys:
        raise ValueError("K and V must have the same number of rows")
    d_v = V.shape[1]

    output = np.zeros((N, d_v))      # unnormalised weighted sum of V rows
    exp_sums = np.zeros(N)           # running softmax denominator per query
    max_scores = np.full(N, -np.inf)  # running maximum score per query

    # Iterate over the keys (not the queries): K/V may have a different
    # number of rows than Q.
    for j in range(0, num_keys, block_size):
        scores_block = (Q @ K[j:j+block_size].T) / np.sqrt(d)
        new_max = np.maximum(max_scores, scores_block.max(axis=1))
        # Everything accumulated so far was computed relative to the old
        # maximum; this factor re-expresses it relative to the new one.
        # On the first block the old maximum is -inf, so the factor is 0
        # and the (all-zero) accumulators are left unchanged.
        rescale = np.exp(max_scores - new_max)
        # subtract max for numerical stability
        exp_block = np.exp(scores_block - new_max[:, None])
        exp_sums *= rescale
        exp_sums += exp_block.sum(axis=1)
        output *= rescale[:, None]
        output += exp_block @ V[j:j+block_size]
        max_scores = new_max

    # Normalize output by total sum of exponentials
    return output / exp_sums[:, None]
    
# --- Example usage and output ---
if __name__ == "__main__":
    # Seed NumPy's global RNG so every run prints the same numbers.
    np.random.seed(42)

    # Problem size: number of tokens, Q/K dimension, V dimension, and the
    # key-block length handed to the tiled implementation.
    N, d, d_v, block_size = 8, 4, 6, 4

    # Random inputs, drawn in the order Q, K, V.
    Q, K, V = (np.random.rand(N, width) for width in (d, d, d_v))

    # Run both implementations on identical inputs.
    out_naive = attention_naive(Q, K, V)
    out_flash = attention_flash(Q, K, V, block_size)

    # Each heading is followed by its value on the next line.
    print("Naive Attention Output:", out_naive, sep="\n")
    print("\nFlashAttention Output:", out_flash, sep="\n")

    # Largest element-wise gap between the two results.
    max_abs_diff = np.abs(out_naive - out_flash).max()
    print("\nMax difference between outputs:", max_abs_diff, sep="\n")

Naive Attention Output:
[[0.55259814 0.41700637 0.25999533 0.4921267  0.46558592 0.51640672]
 [0.55160996 0.40042111 0.25835317 0.47554091 0.52187624 0.51087415]
 [0.56089783 0.40232874 0.27235195 0.47382213 0.48810883 0.50043211]
 [0.5564651  0.41181309 0.25998126 0.4760774  0.50894969 0.4904798 ]
 [0.54653764 0.41196281 0.25553554 0.48323978 0.49929823 0.51021444]
 [0.55621186 0.41134108 0.25624623 0.47961047 0.51108311 0.49819644]
 [0.55028725 0.40710339 0.26557556 0.47792453 0.48964641 0.50576986]
 [0.55683063 0.41829481 0.24900266 0.48628646 0.50820815 0.50049716]]

FlashAttention Output:
[[0.55259814 0.41700637 0.25999533 0.4921267  0.46558592 0.51640672]
 [0.55160996 0.40042111 0.25835317 0.47554091 0.52187624 0.51087415]
 [0.56089783 0.40232874 0.27235195 0.47382213 0.48810883 0.50043211]
 [0.5564651  0.41181309 0.25998126 0.4760774  0.50894969 0.4904798 ]
 [0.54653764 0.41196281 0.25553554 0.48323978 0.49929823 0.51021444]
 [0.55621186 0.41134108 0.25624623 0.47961047 0.511083