In [34]:
import numpy as np

def flash_attention_tiled(Q, K, V, B_r, B_c, lam=1.0):
    """
    FlashAttention (v1) forward pass with tiling: returns softmax(Q K^T) V.

    The outer loop walks over key/value tiles and the inner loop over query
    tiles. A running row max ``m`` and running row sum ``l`` are kept for every
    query, so each output row can be rescaled incrementally (online softmax)
    without ever materialising the full N_q x N_k score matrix.
    Scores are NOT scaled by 1/sqrt(d).

    Args:
        Q: queries, numpy array of shape (N_q, d)
        K: keys, numpy array of shape (N_k, d)
        V: values, numpy array of shape (N_k, d_v)
        B_r: row (query) block size, must be positive
        B_c: column (key/value) block size, must be positive
        lam: currently unused by the computation; accepted (now with a default)
            so existing calls that pass it positionally keep working.

    Returns:
        O: attention output of shape (N_q, d_v), floating-point dtype.

    Raises:
        ValueError: if a block size is not positive, if K and V have different
            numbers of rows, or if there are queries but no keys.
    """
    if B_r <= 0 or B_c <= 0:
        raise ValueError("block sizes B_r and B_c must be positive")

    N_q = Q.shape[0]
    N_k = K.shape[0]
    if V.shape[0] != N_k:
        raise ValueError("K and V must have the same number of rows")
    if N_k == 0 and N_q > 0:
        raise ValueError("K and V must contain at least one row")
    d_v = V.shape[1]

    # Accumulators must be floating point: -inf cannot be stored in an int array.
    acc_dtype = np.result_type(Q.dtype, K.dtype, V.dtype, 1.0)

    # Per-query accumulators, updated in place tile by tile.
    O = np.zeros((N_q, d_v), dtype=acc_dtype)
    m = np.full(N_q, -np.inf, dtype=acc_dtype)  # running row max
    l = np.zeros(N_q, dtype=acc_dtype)          # running row sum of exp

    # Number of tiles along each axis (ceil division).
    T_r = int(np.ceil(N_q / B_r))
    T_c = int(np.ceil(N_k / B_c))

    # Loop over K, V tiles.
    for j in range(T_c):
        k_start = j * B_c
        k_end = min(k_start + B_c, N_k)
        K_j = K[k_start:k_end]  # (B_c_j, d)
        V_j = V[k_start:k_end]  # (B_c_j, d_v)

        # Loop over Q tiles.
        for i in range(T_r):
            q_start = i * B_r
            q_end = min(q_start + B_r, N_q)
            Q_i = Q[q_start:q_end]  # (B_r_i, d)
            O_i = O[q_start:q_end]  # (B_r_i, d_v)
            m_i = m[q_start:q_end]  # (B_r_i,)
            l_i = l[q_start:q_end]  # (B_r_i,)

            # 1. Raw scores for this tile.
            S_ij = Q_i @ K_j.T  # (B_r_i, B_c_j)

            # 2. Tile-local row max for numerical stability.
            m_tilde = np.max(S_ij, axis=1)

            # 3. Shift every row by ITS OWN max so that l_tilde and P_tilde are
            #    expressed relative to m_tilde, which steps 5-6 assume.
            P_tilde = np.exp(S_ij - m_tilde[:, None])

            # 4. Tile-local row sum of exponentials.
            l_tilde = np.sum(P_tilde, axis=1)

            # 5. Merge with the running statistics.
            m_new = np.maximum(m_i, m_tilde)
            alpha = np.exp(m_i - m_new)      # rescale for previous tiles
            beta = np.exp(m_tilde - m_new)   # rescale for this tile
            l_new = alpha * l_i + beta * l_tilde

            # 6. O_new = diag(l_new)^{-1} [ diag(l_i) e^{m_i-m_new} O_i
            #                               + e^{m_tilde-m_new} P_tilde V_j ]
            term1 = (alpha * l_i)[:, None] * O_i
            term2 = beta[:, None] * (P_tilde @ V_j)

            # Write-back
            O[q_start:q_end] = (term1 + term2) / l_new[:, None]
            m[q_start:q_end] = m_new
            l[q_start:q_end] = l_new

    return O

# Toy example: identity queries/keys and a 3x3 value matrix 1..9.
Q = np.eye(3)
K = Q.copy()
V = np.arange(1, 10, dtype=float).reshape(3, 3)

# One query row and one key row per tile.
B_r, B_c = 1, 1

O_fa = flash_attention_tiled(Q, K, V, B_r, B_c, 1)
print("FlashAttention output:\n", O_fa)


FlashAttention output:
 [[2.90747402 3.90747402 4.90747402]
 [4.         5.         6.        ]
 [5.09252598 6.09252598 7.09252598]]


In [35]:
def attention(Q, K, V, scale=1.0):
    """
    Naive (reference) softmax attention: softmax(scale * Q K^T) V.

    Materialises the full score matrix, so it is only meant as a correctness
    baseline for the tiled implementations.

    Args:
        Q: queries, shape (..., seq_q, d)
        K: keys, shape (..., seq_k, d)
        V: values, shape (..., seq_k, d_v)
        scale: multiplier applied to the raw scores. The default 1.0 gives
            unscaled attention; pass 1 / np.sqrt(d) for the standard scaled
            dot-product attention.

    Returns:
        Array of shape (..., seq_q, d_v).
    """
    # swapaxes (rather than .T) transposes only the last two axes, so leading
    # batch dimensions are preserved; for 2-D inputs it is identical to K.T.
    scores = scale * (Q @ np.swapaxes(K, -1, -2))   # [..., seq_q, seq_k]
    # Subtract the row max before exp for numerical stability.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ V

In [36]:
# Naive baseline on the same toy inputs.
O_naive = attention(Q, K, V)
print("\nNaive attention output:\n", O_naive)

# The tiled implementation must agree with the baseline. Recomputed here
# (rather than reusing O_fa) so the check stays valid if this cell is re-run.
np.testing.assert_allclose(flash_attention_tiled(Q, K, V, 1, 1, 1), O_naive)


Naive attention output:
 [[2.90747402 3.90747402 4.90747402]
 [4.         5.         6.        ]
 [5.09252598 6.09252598 7.09252598]]


FlashAttention-2

In [37]:
import numpy as np
import math

def flash_attention2(Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                     block_rows: int, block_cols: int):
    """
    FlashAttention-2 forward pass in NumPy.
    
    Args:
        Q: Queries, shape (N, D)
        K: Keys,    shape (N, D)
        V: Values,  shape (N, D_v)
        block_rows: Block size for Q (B_r)
        block_cols: Block size for K, V (B_c)
    
    Returns:
        O: Output, shape (N, D_v)
        L: Log-sum-exp per query, shape (N,)
    """
    N, D = Q.shape
    _, D_v = V.shape
    
    # Number of tiles
    T_r = math.ceil(N / block_rows)
    T_c = math.ceil(N / block_cols)
    
    O = np.zeros((N, D_v), dtype=Q.dtype)
    L = np.zeros(N, dtype=Q.dtype)
    
    for i in range(T_r):
        start_r = i * block_rows
        end_r = min((i + 1) * block_rows, N)
        Qi = Q[start_r:end_r]                     # (B_r_i, D)
        B_r_i = Qi.shape[0]
        
        # Initialize online-softmax accumulators
        m = np.full(B_r_i, -np.inf, dtype=Q.dtype)  # running max
        l = np.zeros(B_r_i, dtype=Q.dtype)          # running sum exp
        O_tilde = np.zeros((B_r_i, D_v), dtype=Q.dtype)
        
        for j in range(T_c):
            start_c = j * block_cols
            end_c = min((j + 1) * block_cols, N)
            Kj = K[start_c:end_c]                 # (B_c_j, D)
            Vj = V[start_c:end_c]                 # (B_c_j, D_v)
            
            # 1) Raw attention scores
            S = Qi @ Kj.T                          # (B_r_i, B_c_j)
            
            # 2) Update running max
            row_max = np.max(S, axis=1)            # (B_r_i,)
            new_m = np.maximum(m, row_max)
            
            # 3) Compute shifted exp
            P = np.exp(S - new_m[:, None])         # (B_r_i, B_c_j)
            
            # 4) Update running sum of exp
            l = np.exp(m - new_m) * l + np.sum(P, axis=1)
            
            # 5) Accumulate unnormalized output
            O_tilde = (np.exp(m - new_m)[:, None] * O_tilde) + (P @ Vj)
            
            # Commit new max
            m = new_m
        
        # 6) Final normalization for this block
        O[start_r:end_r] = O_tilde / l[:, None]
        L[start_r:end_r] = m + np.log(l)
    
    return O, L


# Example usage with 3x3 matrices (same toy inputs as the FlashAttention cell)
Q = np.array([[1, 0, 0],
              [0, 1, 0],
              [0, 0, 1]], dtype=float)

K = Q.copy()
V = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]], dtype=float)

# Choose block sizes
B_r, B_c = 1, 1

# flash_attention2 returns a tuple (O, L).
O_fa = flash_attention2(Q, K, V, B_r, B_c)

# Display results
print("FlashAttention 2 output:\n", O_fa)

# Validate against the naive baseline: O must match softmax(QK^T)V, and
# L must equal the per-row log-sum-exp of the raw scores.
O_fa2, L_fa2 = O_fa
np.testing.assert_allclose(O_fa2, attention(Q, K, V))
np.testing.assert_allclose(L_fa2, np.log(np.exp(Q @ K.T).sum(axis=1)))

FlashAttention 2 output:
 (array([[2.90747402, 3.90747402, 4.90747402],
       [4.        , 5.        , 6.        ],
       [5.09252598, 6.09252598, 7.09252598]]), array([1.55144471, 1.55144471, 1.55144471]))
