## 1. KV cache 推理优化
- 没有 KV cache：每生成一个 token，都要对整个序列（prompt + 已生成的 token）重新计算 K 和 V 矩阵，前面 token 的 K、V 被反复计算，导致大量重复计算，效率极低
- 使用 KV cache：
    1. 预填充（Prefill）：一次性处理整个输入 prompt，计算所有 token 的 K 和 V 矩阵，并将其缓存起来
    2. 解码（Decoding）：对于每个新生成的 token，只需要计算它自己的 Q、K 和 V 向量
    3. 然后与缓存的矩阵拼接起来，形成完整 KV 矩阵
    4. 线性投影（Q/K/V/O）和前馈层的计算量只与新 token 数量（通常是 1）有关，不再对整段序列重复计算；但注意力打分仍需让新 token 的 Q 与全部缓存的 K 做点积，这部分开销随序列总长度线性增长

In [None]:
# KV cache inference optimization
'''
KV cache steps:
1. Pre-fill: compute K, V for all tokens in the input sequence and store them.
2. Decoding: for each new token, compute its Q, K, V; append its K, V to the cache, then attend with Q over the full cached K, V.
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

class MultiHeadAttentionWithCache(nn.Module):
    """Multi-head attention that can reuse a key/value (KV) cache.

    On the first call (prefill) the whole input sequence is projected and its
    per-head K/V tensors are returned as the cache. On later calls (decoding)
    only the new token(s) are projected; their K/V are appended to the cached
    K/V along the sequence axis before attention is computed.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        # Per-head width. Must be an int because it is used as a size in .view().
        self.d_k = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq_len, d_model) to (batch, num_heads, seq_len, d_k)."""
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        x,
        mask=None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        """Run attention over ``x``, optionally extending a previous KV cache.

        Args:
            x (torch.Tensor): (batch, seq_len, d_model).
                - prefill: seq_len > 1 (the whole prompt)
                - decoding: typically seq_len == 1 (the newest token)
            mask (torch.Tensor, optional): broadcastable to the attention score
                shape (batch, num_heads, seq_len, past_len + seq_len); positions
                where ``mask == 0`` are excluded from attention.
            past_key_value: optional ``(K_cache, V_cache)``, each of shape
                (batch, num_heads, past_len, d_k), as returned by a previous call.

        Returns:
            tuple: ``(output, present_key_value)`` where ``output`` has shape
            (batch, seq_len, d_model) and ``present_key_value`` is the updated
            ``(K, V)`` cache of shape (batch, num_heads, past_len + seq_len, d_k).
        """
        batch_size, seq_len, _ = x.size()

        # 1. Project the input into Q, K, V and split into heads.
        Q = self._split_heads(self.W_Q(x), batch_size, seq_len)
        K = self._split_heads(self.W_K(x), batch_size, seq_len)
        V = self._split_heads(self.W_V(x), batch_size, seq_len)

        # 2. Extend the KV cache: concatenate along the sequence axis (dim 2),
        #    giving K/V of shape (batch, num_heads, past_len + seq_len, d_k).
        if past_key_value is not None:
            past_key, past_value = past_key_value
            K = torch.cat([past_key, K], dim=2)
            V = torch.cat([past_value, V], dim=2)

        present_key_value = (K, V)

        # 3. Scaled dot-product attention over every cached + new position.
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # Normalize over the key axis; the implicit softmax dim for a 4-D
        # tensor would be 1 (the heads axis), which is wrong here.
        attn_weight = F.softmax(scores, dim=-1)

        output = torch.matmul(attn_weight, V)

        # 4. Merge heads back to (batch, seq_len, d_model) and project out.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_O(output)

        return output, present_key_value

  # NOTE(review): stray line disabled. It was indented at column 2 outside any
  # block (an IndentationError after the class above) and calls
  # `_conversion_method_template`, which is not defined in this notebook.
  # Original text kept for reference:
  # cpu = _conversion_method_template(device=torch.device("cpu"))
