# ⚡ FlashAttention vs Standard Attention (with KV Cache & Inference Demo)

In [1]:
import torch
import torch.nn.functional as F
import numpy as np

In [2]:
seq_len = 4
dim = 8

torch.manual_seed(42)
Q = torch.randn(seq_len, dim)
K = torch.randn(seq_len, dim)
V = torch.randn(seq_len, dim)

## 🧠 Standard Attention

In [3]:
def standard_attention(Q, K, V):
    attn_scores = Q @ K.T
    attn_weights = F.softmax(attn_scores / np.sqrt(Q.size(1)), dim=-1)
    output = attn_weights @ V
    return output, attn_weights

standard_out, attn_weights = standard_attention(Q, K, V)
standard_out

tensor([[ 0.2667,  0.2371, -0.0554,  0.1298,  0.3541, -0.1906, -0.6448, -0.0085],
        [ 0.1086,  0.2444, -0.2164,  0.3814,  0.0631, -0.5633, -1.1007, -0.3306],
        [ 0.4947, -0.1095, -0.5350,  0.3420, -0.6224, -0.4772,  0.3223,  0.2335],
        [ 0.5705, -0.0146, -0.2775,  0.3238, -0.4680, -0.4084,  0.1981,  0.3296]])

## ⚡ FlashAttention (Row-wise fused computation)

In [4]:
def flash_attention(Q, K, V):
    L, d = Q.size()
    output = torch.zeros_like(Q)
    for i in range(L):
        q_i = Q[i]
        scores = (q_i @ K.T) / np.sqrt(d)
        scores = scores - scores.max()
        weights = torch.exp(scores)
        weights_sum = weights.sum()
        softmax_weights = weights / weights_sum
        output[i] = softmax_weights @ V
    return output

flash_out = flash_attention(Q, K, V)
flash_out

tensor([[ 0.2667,  0.2371, -0.0554,  0.1298,  0.3541, -0.1906, -0.6448, -0.0085],
        [ 0.1086,  0.2444, -0.2164,  0.3814,  0.0631, -0.5633, -1.1007, -0.3306],
        [ 0.4947, -0.1095, -0.5350,  0.3420, -0.6224, -0.4772,  0.3223,  0.2335],
        [ 0.5705, -0.0146, -0.2775,  0.3238, -0.4680, -0.4084,  0.1981,  0.3296]])

## ✅ Compare Results

In [5]:
diff = torch.abs(standard_out - flash_out).max()
print("Max difference:", diff.item())

Max difference: 3.5762786865234375e-07


## 🧠 KV 缓存机制模拟

In [6]:
cached_K = K[:3]
cached_V = V[:3]
new_q = Q[3].unsqueeze(0)

def flash_attention_with_kv_cache(q, cached_K, cached_V):
    scores = (q @ cached_K.T) / np.sqrt(q.size(-1))
    scores = scores - scores.max()
    weights = torch.softmax(scores, dim=-1)
    output = weights @ cached_V
    return output

kv_output = flash_attention_with_kv_cache(new_q, cached_K, cached_V)
kv_output

tensor([[ 0.4389,  0.3399,  0.3086,  0.1112,  0.4843, -0.0900, -0.6649,  0.2127]])

## 🔁 Decoder 推理循环示例（KV 缓存 + FlashAttention 风格）

In [7]:
seq_len = 5
dim = 8
torch.manual_seed(42)

Q_seq = torch.randn(seq_len, dim)
K_seq = torch.randn(seq_len, dim)
V_seq = torch.randn(seq_len, dim)

K_all, V_all = [], []
outputs = []

def flash_attention_step(q, K_cache, V_cache):
    K_tensor = torch.stack(K_cache)
    V_tensor = torch.stack(V_cache)
    scores = (q @ K_tensor.T) / np.sqrt(q.size(-1))
    scores = scores - scores.max()
    weights = torch.softmax(scores, dim=-1)
    output = weights @ V_tensor
    return output.squeeze(0)

for t in range(seq_len):
    q_t = Q_seq[t].unsqueeze(0)
    k_t = K_seq[t]
    v_t = V_seq[t]
    K_all.append(k_t)
    V_all.append(v_t)
    out_t = flash_attention_step(q_t, K_all, V_all)
    outputs.append(out_t)

decoder_outputs = torch.stack(outputs)
decoder_outputs

tensor([[-1.4570, -0.1023, -0.5992,  0.4771,  0.7262,  0.0912, -0.3891,  0.5279],
        [-0.3008,  0.1724, -0.0134,  0.7069,  1.0214,  0.2903,  0.4987,  0.4347],
        [ 1.0657,  0.7124, -1.0994, -0.6095,  0.1675,  1.1947,  0.4632,  0.5482],
        [ 0.0714,  0.0161, -1.0562,  0.3643,  0.3505, -0.1191,  0.0595,  0.5450],
        [ 0.1705,  0.0327, -0.6371,  0.5708,  0.3913,  0.0094,  0.5202,  0.4017]])