In [19]:
import numpy as np
import pandas as pd

def flash_attention_l1(Q, K, V, B_r, B_c, lam):
    """
    FlashAttention with L1‐distance scoring:
      S_ij[r,c] = -lam * sum_m |Q_i[r,m] - K_j[c,m]|
    
    Args:
        Q, K, V: np arrays of shape (N, d)
        B_r: row block size
        B_c: column block size
        lam: scaling factor for L1 distance

    Returns:
        O: output attention result of shape (N, d)
    """
    N, d = Q.shape
    scale=lam / np.sqrt(d)
    # Initialize outputs & softmax accumulators
    O = np.zeros((N, d), dtype=Q.dtype)
    m = np.full(N, -np.inf, dtype=Q.dtype)
    l = np.zeros(N, dtype=Q.dtype)

    T_r = int(np.ceil(N / B_r))
    T_c = int(np.ceil(N / B_c))

    for j in range(T_c):
        k0, k1 = j * B_c, min((j+1) * B_c, N)
        K_j = K[k0:k1]        # (B_c, d)
        V_j = V[k0:k1]        # (B_c, d)

        for i in range(T_r):
            q0, q1 = i * B_r, min((i+1) * B_r, N)
            Q_i = Q[q0:q1]    # (B_r, d)
            O_i = O[q0:q1]    # (B_r, d)
            m_i = m[q0:q1]    # (B_r,)
            l_i = l[q0:q1]    # (B_r,)

            # 1. Compute L1‐distance based scores (B_r, B_c)
            #    diff has shape (B_r, B_c, d)
            diff = np.abs(Q_i[:, None, :] - K_j[None, :, :])
            S_ij = -scale * np.sum(diff, axis=2)

            # 2. Row‐wise max for numerical stability
            m_tilde = S_ij.max(axis=1)        # (B_r,)

            # 3. Exponentiate shifted scores
            P_tilde = np.exp(S_ij - m_tilde[:, None])  # (B_r, B_c)

            # 4. Row‐sum of exponentials
            l_tilde = P_tilde.sum(axis=1)     # (B_r,)

            # 5. Update running max and sum
            m_new = np.maximum(m_i, m_tilde)
            l_new = (np.exp(m_i - m_new) * l_i) + (np.exp(m_tilde - m_new) * l_tilde)

            # 6. Accumulate partial outputs
            term1 = (np.exp(m_i - m_new) * l_i)[:, None] * O_i                  # (B_r, d)
            term2 = (np.exp(m_tilde - m_new)[:, None]) * (P_tilde @ V_j)        # (B_r, d)
            O_update = (term1 + term2) / l_new[:, None]                         # (B_r, d)

            # Write-back
            O[q0:q1] = O_update
            m[q0:q1] = m_new
            l[q0:q1] = l_new

    return O

# Example usage with 3x3 matrices
Q = np.array([[1, 0, 0],
              [0, 1, 0],
              [0, 0, 1]], dtype=float)

K = Q.copy()
V = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]], dtype=float)

# Choose block sizes
B_r, B_c = 1, 1

O_fa = flash_attention_l1(Q, K, V, B_r, B_c,1)
O_naive = np.exp(Q @ K.T)
O_naive = (O_naive / np.sum(O_naive, axis=1, keepdims=True)) @ V

# Display results
print("FlashAttention output:\n", O_fa)
print("\nNaive attention output:\n", O_naive)


FlashAttention output:
 [[2.74 3.74 4.74]
 [4.   5.   6.  ]
 [5.26 6.26 7.26]]

Naive attention output:
 [[2.907 3.907 4.907]
 [4.    5.    6.   ]
 [5.093 6.093 7.093]]


In [20]:
def pretty_print(arr: np.ndarray,
                 name: str = None,
                 precision: int = 3,
                 suppress: bool = True):
    """
    Nicely print any NumPy array:
      - 1-D → as a single-column DataFrame
      - 2-D → as a DataFrame
      - ND (N>2) → iterate slices on axis-0
    
    Args:
        arr:       the array to print
        name:      optional name/title
        precision: decimal places for floats
        suppress:  suppress small values (numpy.set_printoptions)
    """
    # configure numpy prints
    np.set_printoptions(precision=precision, suppress=suppress)
    
    if name is not None:
        print(f"\n=== {name} ===")
    
    ndim = arr.ndim
    if ndim == 0:
        # scalar
        print(arr.item())
    elif ndim == 1:
        # vector
        df = pd.DataFrame(arr, columns=["value"])
        print(df)
    elif ndim == 2:
        # matrix
        df = pd.DataFrame(arr)
        print(df)
    else:
        # ND: iterate over first axis
        for idx in range(arr.shape[0]):
            slice_name = f"{name+' ' if name else ''}slice[{idx}]"
            pretty_print(arr[idx], name=slice_name,
                         precision=precision, suppress=suppress)

FLASHATTENTION 2

In [None]:
import numpy as np
import math

def flash_attention2(Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                     block_rows: int, block_cols: int,lam):
    """
    FlashAttention-2 forward pass in NumPy.
    
    Args:
        Q: Queries, shape (N, D)
        K: Keys,    shape (N, D)
        V: Values,  shape (N, D_v)
        block_rows: Block size for Q (B_r)
        block_cols: Block size for K, V (B_c)
    
    Returns:
        O: Output, shape (N, D_v)
        L: Log-sum-exp per query, shape (N,)
    """
    N, D = Q.shape
    _, D_v = V.shape
    
    # Number of tiles
    T_r = math.ceil(N / block_rows)
    T_c = math.ceil(N / block_cols)
    
    O = np.zeros((N, D_v), dtype=Q.dtype)
    L = np.zeros(N, dtype=Q.dtype)

    scale=lam / np.sqrt(D)
    
    for i in range(T_r):
        start_r = i * block_rows
        end_r = min((i + 1) * block_rows, N)
        Qi = Q[start_r:end_r]                     # (B_r_i, D)
        B_r_i = Qi.shape[0]
        
        # Initialize online-softmax accumulators
        m = np.full(B_r_i, -np.inf, dtype=Q.dtype)  # running max
        l = np.zeros(B_r_i, dtype=Q.dtype)          # running sum exp
        O_tilde = np.zeros((B_r_i, D_v), dtype=Q.dtype)
        print("Qi\n",Qi)

        for j in range(T_c):
            start_c = j * block_cols
            end_c = min((j + 1) * block_cols, N)
            Kj = K[start_c:end_c]                 # (B_c_j, D)
            Vj = V[start_c:end_c]                 # (B_c_j, D_v)

            # print("Kj\n",pd.DataFrame(Kj))
            # print("Vj\n",pd.DataFrame(Vj))
            pretty_print(Kj,"Kj")
            pretty_print(Vj,"Vj")

            # 1) Raw attention scores
            # S = Qi @ Kj.T                          # (B_r_i, B_c_j)
            diff = np.abs(Qi[:, None, :] - Kj[None, :, :])
            # print("Sub:\n",pd.DataFrame(Qi[:, None, :]),"\n-\n",pd.DataFrame(Kj[None, :, :]),"\nDiff:\n",pd.DataFrame(diff))
            print("Sub")
            pretty_print(Qi[:, None, :],"Q")
            print("-")
            pretty_print(Kj[None, :, :],"K")
            print("=")
            pretty_print(diff,"Diff")

            S = -scale * np.sum(diff, axis=2)
            # print("S\n",pd.DataFrame(S))
            pretty_print(S,"S")

            
            # 2) Update running max
            row_max = np.max(S, axis=1)            # (B_r_i,)
            new_m = np.maximum(m, row_max)
            pretty_print(m,"m")
            pretty_print(row_max,"row_max")
            pretty_print(new_m,"new_m")
            # 3) Compute shifted exp
            P = np.exp(S - new_m[:, None])         # (B_r_i, B_c_j)
            # print("P\n",pd.DataFrame(P))
            pretty_print(P,"P")
            
            # 4) Update running sum of exp
            l = np.exp(m - new_m) * l + np.sum(P, axis=1)
            pretty_print(l,"l")
            
            # 5) Accumulate unnormalized output
            O_tilde = (np.exp(m - new_m)[:, None] * O_tilde) + (P @ Vj)
            # print("O_tilde\n",pd.DataFrame(O_tilde))
            pretty_print(O_tilde,"O_tilde")
            
            # Commit new max
            m = new_m
            print("-"*90)
        
        # 6) Final normalization for this block
        O[start_r:end_r] = O_tilde / l[:, None]
        L[start_r:end_r] = m + np.log(l)
    
    return O


# Example usage with 3x3 matrices
Q = np.array([[1,0,0,2,0,1],
              [0,1,4,0,0,0],
              [0,0,1,0,0,0],
              [0,7,3,1,0,0],
              [0,4,0,0,1,0],
              [0,0,0,5,0,1]], dtype=float)

K = Q.copy()
V = np.array([[1, 2, 3, 4, 5, 6],
              [4, 5, 6, 7, 8, 9],
              [7, 8, 9,10,11,12],
              [10,11,12,13,14,15],
              [13,14,15,16,17,18],
              [16,17,18,19,20,21]], dtype=float)

# Choose block sizes
B_r, B_c = 3, 3

O_fa = flash_attention2(Q, K, V, B_r, B_c,1)

# Display results
print("FlashAttention 2 output:\n", O_fa)

Qi
 [[1. 0. 0. 2. 0. 1.]
 [0. 1. 4. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]]

=== Kj ===
     0    1    2    3    4    5
0  1.0  0.0  0.0  2.0  0.0  1.0
1  0.0  1.0  4.0  0.0  0.0  0.0
2  0.0  0.0  1.0  0.0  0.0  0.0

=== Vj ===
     0    1    2     3     4     5
0  1.0  2.0  3.0   4.0   5.0   6.0
1  4.0  5.0  6.0   7.0   8.0   9.0
2  7.0  8.0  9.0  10.0  11.0  12.0
Sub

=== Q ===

=== Q slice[0] ===
     0    1    2    3    4    5
0  1.0  0.0  0.0  2.0  0.0  1.0

=== Q slice[1] ===
     0    1    2    3    4    5
0  0.0  1.0  4.0  0.0  0.0  0.0

=== Q slice[2] ===
     0    1    2    3    4    5
0  0.0  0.0  1.0  0.0  0.0  0.0
-

=== K ===

=== K slice[0] ===
     0    1    2    3    4    5
0  1.0  0.0  0.0  2.0  0.0  1.0
1  0.0  1.0  4.0  0.0  0.0  0.0
2  0.0  0.0  1.0  0.0  0.0  0.0
=

=== Diff ===

=== Diff slice[0] ===
     0    1    2    3    4    5
0  0.0  0.0  0.0  0.0  0.0  0.0
1  1.0  1.0  4.0  2.0  0.0  1.0
2  1.0  0.0  1.0  2.0  0.0  1.0

=== Diff slice[1] ===
     0    1    2    3 