In [None]:
# attention

import torch
import torch.nn as nn

import torch.nn.functional as F

d_model = 6
d_k = 4
batch_size = 2
seq_len = 3
x = torch.rand(batch_size, seq_len, d_model)

print("input shape: ", x.shape)

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_k)

    def forward(self, input):   # [B, S, d_model]
        Q = self.W_q(input)     # [B, S, d_k]
        K = self.W_k(input)     # [B, S, d_k]
        V = self.W_v(input)     # [B, S, d_k]

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)     # [B, S, S]
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)

        return output, weights

attention_layer = SelfAttention(d_model, d_k)

output, weights = attention_layer(x)

print(output)



print(weights)



input shape:  torch.Size([2, 3, 6])
tensor([[[ 0.3468,  0.3706, -0.2624,  0.2491],
         [ 0.3442,  0.3662, -0.2685,  0.2541],
         [ 0.3465,  0.3706, -0.2625,  0.2492]],

        [[ 0.4045,  0.2085, -0.3606,  0.2901],
         [ 0.4062,  0.2053, -0.3626,  0.2926],
         [ 0.4031,  0.2112, -0.3585,  0.2880]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[0.3342, 0.3263, 0.3395],
         [0.3281, 0.3408, 0.3310],
         [0.3327, 0.3262, 0.3411]],

        [[0.3364, 0.3301, 0.3335],
         [0.3289, 0.3408, 0.3303],
         [0.3396, 0.3216, 0.3388]]], grad_fn=<SoftmaxBackward0>)


In [None]:
# mask

import torch
import torch.nn as nn

import torch.nn.functional as F

d_model = 6
d_k = 4
batch_size = 2
seq_len = 3
x = torch.rand(batch_size, seq_len, d_model)

print("input shape: ", x.shape)

def causal_mask(seq_len):
    return torch.tril(torch.ones(seq_len, seq_len)).bool()  # [L, L]

mask = causal_mask(seq_len)  # [L, L]
print(mask)


class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_k)

    def forward(self, input, attn_mask=None):   # [B, S, d_model]
        Q = self.W_q(input)     # [B, S, d_k]
        K = self.W_k(input)     # [B, S, d_k]
        V = self.W_v(input)     # [B, S, d_k]

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)     # [B, S, S]

        if attn_mask != None:
            scores = scores.masked_fill(attn_mask==0, float('-inf'))
            print("socres:", scores)

        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)

        return output, weights

attention_layer = SelfAttention(d_model, d_k)

output, weights = attention_layer(x, mask)

print("output: ", output)

print("weights: ", weights)



input shape:  torch.Size([2, 3, 6])
tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])
socres: tensor([[[0.0288,   -inf,   -inf],
         [0.3989, 0.3232,   -inf],
         [0.1370, 0.1270, 0.1868]],

        [[0.4134,   -inf,   -inf],
         [0.1390, 0.0844,   -inf],
         [0.1108, 0.0873, 0.0983]]], grad_fn=<MaskedFillBackward0>)
tensor([[[-0.2365, -0.4558, -0.2359, -0.0013],
         [-0.2038, -0.4263, -0.2420, -0.1248],
         [-0.1586, -0.4470, -0.2827, -0.1519]],

        [[-0.3216, -0.5624, -0.3853, -0.1773],
         [-0.2795, -0.5421, -0.4857, -0.2161],
         [-0.2835, -0.5046, -0.4398, -0.1511]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[1.0000, 0.0000, 0.0000],
         [0.5189, 0.4811, 0.0000],
         [0.3288, 0.3256, 0.3456]],

        [[1.0000, 0.0000, 0.0000],
         [0.5137, 0.4863, 0.0000],
         [0.3373, 0.3295, 0.3332]]], grad_fn=<SoftmaxBackward0>)


In [None]:
# Multi Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_head"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_model = d_model

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input, attn_mask=None):
        B, S, E = input.size()
        Q = self.W_q(input).view(B, S, self.num_heads, self.d_k).transpose(1, 2) # [B, S, E] -> [B, S, n, d_k] -> [B, n, S, d_k]
        K = self.W_k(input).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(input).view(B, S, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2,-1)) / (self.d_k ** 0.5)    # [B, n, S, S]

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask==0, float('-inf'))
        
        weights = F.softmax(scores, dim=-1)

        output = torch.matmul(weights, V)   # [B, n, S, d_k]

        output = output.transpose(1,2).contiguous().view(B, S, self.d_model) # [B, S, d_model]

        output = self.dropout(self.W_o(output))
    

        return output, weights


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
    
    def forward(self, input):
        output = self.ff(input)

        return output
    
class TransformerBlock(nn.Module):
    def __init__(self, d_model, d_ff, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.attention_layer = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, attn_mask):   # [B, S, E]
        attn_output, _ = self.attention_layer(x, attn_mask)   # [B, S, E]
        x = self.norm1(x + self.dropout_1(attn_output))

        ff_output = self.ff(x)
        x = self.norm2(x + self.dropout_2(ff_output))

        return x



In [8]:
import torch
import torch.nn as nn

import torch.nn.functional as F

d_model = 6
num_heads = 2
d_k = 4
batch_size = 2
seq_len = 3
x = torch.rand(batch_size, seq_len, d_model)

print("input shape: ", x.shape)

def causal_mask(seq_len):
    return torch.tril(torch.ones(seq_len, seq_len)).bool()  # [L, L]

mask = causal_mask(seq_len)  # [L, L]
print(mask)

multi_attention_layer = MultiHeadAttention(d_model, num_heads)

output, weights = multi_attention_layer(x, mask)

print("output: ", output)

print("weights: ", weights)

input shape:  torch.Size([2, 3, 6])
tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])
output:  tensor([[[-0.1975, -0.0850, -0.1854,  0.1128,  0.1626,  0.0000],
         [-0.1770, -0.1456, -0.3501,  0.0945,  0.1552,  0.7728],
         [-0.1121, -0.1763, -0.4046,  0.1056,  0.1187,  0.8214]],

        [[-0.0216, -0.1186, -0.6042,  0.1005,  0.1399,  0.9678],
         [-0.0302, -0.1038, -0.4704,  0.1165,  0.1151,  0.8514],
         [-0.0285, -0.0904, -0.5128,  0.1219,  0.1011,  0.8995]]],
       grad_fn=<MulBackward0>)
weights:  tensor([[[[1.0000, 0.0000, 0.0000],
          [0.4953, 0.5047, 0.0000],
          [0.3346, 0.3354, 0.3300]],

         [[1.0000, 0.0000, 0.0000],
          [0.4939, 0.5061, 0.0000],
          [0.3271, 0.3307, 0.3422]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.4991, 0.5009, 0.0000],
          [0.3393, 0.3337, 0.3270]],

         [[1.0000, 0.0000, 0.0000],
          [0.4998, 0.5002, 0.0000],
          [0.3451, 0.3

In [17]:
import torch
import torch.nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model//n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads."

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, attn_mask=None):
        B, S, E = x.shape # [B, S, d_model]

        Q = self.W_q(x).view(B, S, self.n_heads, self.d_k).transpose(1,2)   # [B, S, d_model] -> [B, S, n_heads, d_k] -> [B, n_heads, S, d_k]
        K = self.W_k(x).view(B, S, self.n_heads, self.d_k).transpose(1,2)
        V = self.W_v(x).view(B, S, self.n_heads, self.d_k).transpose(1,2)

        scores = torch.matmul(Q, K.transpose(-2,-1)) / (self.d_k**0.5)  # [B, n_heads, S, d_k] * [B, n_heads, d_k, S] -> [B, n_heads, S, S]

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))
            
        att_weights = F.softmax(scores, dim=-1)

        output = torch.matmul(att_weights, V)   # [B, n_heads, S, S] * [B, n_heads, S, d_k] -> [B, n_heads, S, d_k]
        output = output.transpose(1,2).contiguous() # [B, n_heads, S, d_k] -> [B, S, n_heads, d_k] contiguous的作用是让output再内存上连续，接下来可以用view
        output = output.view(B, S, E)   # [B, S, n_heads, d_k] -> [B, S, d_model]

        output = self.W_o(output)

        return output, weights
    
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.ff(x)
    

class TransformerLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads

        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ff = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        attn_output, _ = self.self_attn(x, mask)
        x = self.norm1(x + self.dropout1(attn_output))

        ffn_output = self.ff(x)
        x = self.norm2(x + self.dropout2(ffn_output))

        return x, _

    

batch_size = 2
seq_len = 3
d_model = 4
n_heads = 2
d_ff = 8
x = torch.randn(batch_size, seq_len, d_model)   # 
print(x)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
print("causal mask: ", causal_mask)
transformer_layer = TransformerLayer(d_model=d_model, n_heads=n_heads, d_ff=d_ff)

output, attn_score = transformer_layer(x, causal_mask)

print("output: ", output)
print("attm_score: ", attn_score)


tensor([[[ 0.1341,  0.3455,  2.0616,  0.7853],
         [ 0.5331, -0.0438,  0.1966, -1.7723],
         [ 1.6760, -0.1478,  0.9186,  1.5295]],

        [[-0.0292,  0.2269,  1.4519, -0.1068],
         [-1.3315, -2.5104, -0.3037,  0.4920],
         [ 0.3821, -1.8754, -0.1821, -0.5776]]])
causal mask:  tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
output:  tensor([[[-0.2300, -0.8331,  1.6908, -0.6278],
         [ 1.0556, -0.3513,  0.7690, -1.4734],
         [ 1.0836, -1.5699,  0.5917, -0.1054]],

        [[ 0.2094, -0.5744,  1.5166, -1.1516],
         [ 0.0330, -1.6348,  0.9485,  0.6533],
         [ 1.1670, -1.3824,  0.7010, -0.4856]]],
       grad_fn=<NativeLayerNormBackward0>)
attm_score:  tensor([[[[1.0000, 0.0000, 0.0000],
          [0.5217, 0.4783, 0.0000],
          [0.3498, 0.3204, 0.3298]],

         [[1.0000, 0.0000, 0.0000],
          [0.4930, 0.5070, 0.0000],
          [0.3249, 0.3154, 0.3596]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.4969, 0.503

In [4]:
# MQA, GQA
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):   # kv头的数量
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.d_model = d_model
        self.d_k = d_model // n_heads
        self.n_rep = n_heads // n_kv_heads  

        self.W_q = nn.Linear(d_model, self.d_k * n_heads)
        self.W_k = nn.Linear(d_model, self.d_k * n_kv_heads)
        self.W_v = nn.Linear(d_model, self.d_k * n_kv_heads)

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, S, E = x.shape
        Q = self.W_q(x).view(B, S, self.n_heads, self.d_k).transpose(1,2) # [B, S, d_model] -> [B, S, n_heads, d_k] -> [B, n_heads, S, d_k]
        K = self.W_k(x).view(B, S, self.n_kv_heads, self.d_k).transpose(1,2)  # [B, S, d_model] -> [B, S, n_kv_heads, d_k] -> [B, n_kv_heads, S, d_k]
        V = self.W_v(x).view(B, S, self.n_kv_heads, self.d_k).transpose(1,2)    # [B, S, d_model] -> [B, S, n_kv_heads, d_k] -> [B, n_kv_heads, S, d_k]

        # 复制KV头以匹配Q的头数
        if self.n_rep > 1:
            K = K.repeat_interleave(self.n_rep, dim = 1)    # [B, n_kv_heads, S, d_k] -> # [B, n_kv_heads * n_rep = n_heads, S, d_k]
            V = V.repeat_interleave(self.n_rep, dim = 1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)   # [B, n_heads, S, S]

        if mask is not None:
            scores = scores.masked_fill(mask==0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)   # [B, n_heads, S, S] * [B, n_heads, S, d_k] -> [B, n_heads, S, d_k]

        output = output.transpose(1,2).contiguous().view(B, S, E)   # [B, S, d_model]

        output = self.W_o(output)
        
        return output, scores

batch_size = 2
seq_len = 3
d_model = 8
n_heads = 4
n_kv_heads = 2

x = torch.randn(batch_size, seq_len, d_model)   # 
print(x)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
print("causal mask: ", causal_mask)
transformer_layer = GroupedQueryAttention(d_model=d_model, n_heads=n_heads, n_kv_heads=n_kv_heads)

output, attn_score = transformer_layer(x, causal_mask)

print("output: ", output)
print("attm_score: ", attn_score)


tensor([[[ 0.2095,  1.9240,  0.3394,  1.5518,  2.2554,  2.3029, -0.2029,
          -0.5393],
         [ 1.9780, -0.5150, -2.1081,  1.3318,  0.4494,  0.2211, -2.2074,
           0.5719],
         [ 0.8145, -0.9975,  0.5430,  0.1208, -0.1950, -0.3852,  0.0622,
          -0.3854]],

        [[-0.3090, -1.9795,  0.6845,  1.5346, -1.5454,  0.8559,  1.2202,
          -1.2872],
         [ 0.8155,  0.3763,  1.3528, -1.0328,  0.1628, -0.3964,  1.3299,
          -0.7186],
         [-0.2249,  0.3653,  1.0429,  0.9632, -1.9671, -0.3240, -0.3628,
           0.4104]]])
causal mask:  tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
output:  tensor([[[ 7.9113e-02, -2.6621e-01,  2.1510e-01, -1.3914e-01,  1.0455e+00,
          -8.1278e-01,  7.2820e-02, -5.4629e-01],
         [ 3.5965e-01, -2.3883e-01,  3.4103e-01,  1.6718e-01,  3.9212e-01,
          -5.2973e-01, -2.8044e-02,  1.4039e-02],
         [ 2.4854e-01,  2.9665e-02,  1.7839e-01,  5.8713e-02,  3.9517e-01,
          -2.8310e-01, 

In [21]:
# 带KV cache的MHA
import torch
import torch.nn as nn
import torch.nn.functional as F

class CachedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        assert d_model % n_heads == 0

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, past_kv=None, mask=None):
        
        B, S, E = x.shape    # 推理时，通常B = 1, S = 1

        # 1. 投影新输入
        # Q, K, V shape: [Batch, n_heads, seq_len (通常为1), d_k]
        Q = self.W_q(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, S, self.n_heads, self.d_k).transpose(1, 2)

        if past_kv is not None:
            past_k, past_v = past_kv
            K = torch.cat([past_k, K], dim=2)
            V = torch.cat([past_v, V], dim=2)

        current_kv = (K, V)

        scores = torch.matmul(Q, K.transpose(-2,-1)) / (self.d_k ** 0.5)

        if mask is not None:
            # 在 Decoding 阶段，如果 Q_len 为 1，通常不需要 causal mask，
            # 因为它天生只能看到过去。但在 Prefill 阶段需要。
            # 这里为了通用性，简单处理：如果 mask 尺寸匹配就应用。
            if mask.shape[-1] == scores.shape[-1] and mask.shape[-2] == scores.shape[-2]:
                 scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)    

        output = torch.matmul(attn_weights, V)

        output = output.transpose(1, 2).contiguous().view(B, S, self.d_model)
        output = self.W_o(output)

        return output, current_kv
    
# --- 设置 ---
d_model = 4
n_heads = 2

model = CachedMultiHeadAttention(d_model, n_heads)
model.eval() # 推理模式

# 阶段 1: Prefill (处理输入的 Prompt)
prompt_seq_len = 3
input_prompt = torch.randn(1, prompt_seq_len, d_model) # [1, 3, 4]
causal_mask = torch.tril(torch.ones(prompt_seq_len, prompt_seq_len)).bool()

print("--- Phase 1: Prefill ---")

output, kv_cache = model(input_prompt, mask=causal_mask, past_kv=None)

print("Output:", output) # [1, 3, 4]
print("Cache K shape", kv_cache[0].shape) # [1, 2, 3, 2] (Batch, n_heads, Seq, d_k)


# 阶段 2: 想去check一下mask的逻辑有没有写错
print("\n--- Phase 2: Decoding Step 1 ---")
input_new_token = input_prompt[:, :1, :] # 取最后一个时间步, shape: [1, 1, 4]
print(input_new_token.shape)

output_1, kv_cache = model(input_new_token, mask=None, past_kv=None)

print("Output", output_1) # [1, 1, 4] -> 生成了 1 个新 token 的表示
print("Cache K shape:", kv_cache[0].shape) 


input_next_token = input_prompt[:, 1:2, :] # 假设它就是下一个输入, shape [1, 1, 4]
print(input_new_token.shape)
output_2, kv_cache = model(input_next_token, mask=None, past_kv=kv_cache)

print("Output", output_2) # [1, 1, 4] -> 生成了 1 个新 token 的表示
print("Cache K shape:", kv_cache[0].shape) 

--- Phase 1: Prefill ---
Output: tensor([[[ 0.0478,  0.1587,  0.3308, -0.2052],
         [ 0.2319,  0.0765,  0.2512, -0.1282],
         [ 0.2933,  0.0798,  0.3320, -0.1028]]], grad_fn=<ViewBackward0>)
Cache K shape torch.Size([1, 2, 3, 2])

--- Phase 2: Decoding Step 1 ---
torch.Size([1, 1, 4])
Output tensor([[[ 0.0478,  0.1587,  0.3308, -0.2052]]], grad_fn=<ViewBackward0>)
Cache K shape: torch.Size([1, 2, 1, 2])
torch.Size([1, 1, 4])
Output tensor([[[ 0.2319,  0.0765,  0.2512, -0.1282]]], grad_fn=<ViewBackward0>)
Cache K shape: torch.Size([1, 2, 2, 2])


In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UnifiedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask = None, past_kv=None, use_cache=False):
        B, S_q, _ = x.shape # S_q 是 Query 的序列长度

        Q = self.W_q(x).view(B, S_q, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, S_q, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, S_q, self.n_heads, self.d_k).transpose(1, 2)

        if past_kv is not None:
            past_k, past_v = past_kv
            K = torch.cat([past_k, K], dim=2)
            V = torch.cat([past_v, V], dim=2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # 4. Mask 处理 (自动适配两种模式)
        if mask is not None:
            # 我们需要确保 mask 的形状能和 scores 广播匹配
            # 训练时 scores 是 [B, H, S, S], mask 通常是 [S, S] -> 完美匹配
            # 推理时 scores 是 [B, H, 1, S_past+1], mask 通常为 None (或需要特殊处理 padding)
            # 这里做一个简单的防呆兼容
            if mask.shape[-2] == scores.shape[-2] and mask.shape[-1] == scores.shape[-1]:
                 scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        output = output.transpose(1, 2).contiguous().view(B, S_q, self.d_model)
        output = self.W_o(output)

        if use_cache:
            current_kv = (K, V)
            return output, current_kv
        else:
            return output, None


d_model = 4
n_heads = 2
seq_len = 3
model = UnifiedMultiHeadAttention(d_model, n_heads)


# 模式 A: 训练模式 (Training Mode)
# 特点：并行、无 Cache、必须有 Mask
print("--- Mode A: Training (Parallel) ---")
train_input = torch.randn(2, seq_len, d_model) # Batch=2
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()

# 调用方式：不传 past_kv，use_cache=False
train_output, _ = model(train_input, mask=causal_mask, use_cache=False)

print("Training output shape:", train_output.shape) # [2, 5, 8]


# 模式 B: 推理模式 (Inference Mode)
# 特点：串行、使用 Cache
print("\n--- Mode B: Inference (Sequential) ---")
model.eval() # 切换到评估模式 (虽然这里 MHA 没有 dropout/BN，但好习惯要有)

# 1. Prefill 阶段 (处理 Prompt "I like eating")
prompt = torch.randn(1, 3, d_model) # 假设 Prompt 长度为 3
prompt_mask = torch.tril(torch.ones(3, 3)).bool()

# 调用方式：开启 use_cache=True
output_prefill, kv_cache = model(prompt, mask=prompt_mask, use_cache=True)
print("Prefill output shape:", output_prefill.shape) # [1, 3, 8]
print("Initial KV Cache len:", kv_cache[0].shape[2]) # 应该是 3

# 2. Decoding 循环阶段
next_token = output_prefill[:, -1:, :] # 假装这是采样出的新 token

for i in range(3): # 再生成 3 个词
    print(f"Decoding step {i+1}...")
    # 调用方式：传入上一步的 next_token 和 kv_cache
    # 注意：这里 mask=None，因为 Q长度=1，它默认可以看所有过去的 K
    output_step, kv_cache = model(next_token, past_kv=kv_cache, use_cache=True, mask=None)
    
    # 假装采样
    next_token = output_step 
    print("KV cache shape: ", kv_cache[0].shape)
    
print("Final KV Cache len:", kv_cache[0].shape[2]) # 3(prefill) + 3(decoding) = 6

--- Mode A: Training (Parallel) ---
Training output shape: torch.Size([2, 3, 4])

--- Mode B: Inference (Sequential) ---
Prefill output shape: torch.Size([1, 3, 4])
Initial KV Cache len: 3
Decoding step 1...
KV cache shape:  torch.Size([1, 2, 4, 2])
Decoding step 2...
KV cache shape:  torch.Size([1, 2, 5, 2])
Decoding step 3...
KV cache shape:  torch.Size([1, 2, 6, 2])
Final KV Cache len: 6
