### 无kv-cache生成代码

In [None]:
'''
Baseline greedy generation WITHOUT a KV cache: at every step the whole
sequence generated so far (idx['input_ids']) is fed through the model again,
and the argmax of the last position's logits is appended as the next token.
'''


import torch
import torch.nn.functional as F 
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size = 100,           # vocabulary size
    hidden_size = 256,          # hidden dimension
    intermediate_size = 512,    # feed-forward intermediate dimension
    num_hidden_layers = 2,      # number of Transformer layers
    num_attention_heads = 4,    # number of attention heads
    num_key_value_heads = 4,    # number of key/value heads (GQA); equal to the head count here
)
model = LlamaForCausalLM(config)
model.eval()   # inference only

X = torch.randint(0, 100, (1, 5))
print(X.shape)

idx = {}
idx['input_ids'] = X
# Inference only: no autograd graph is needed for greedy decoding.
with torch.no_grad():
    for i in range(3):
        print(f"\nGeneration_{i} input_ids shape {idx['input_ids'].shape}:")
        print(f"input_ids {idx['input_ids']}:")
        # use_cache=False: this baseline deliberately recomputes the full
        # sequence each step, so don't let the model build past_key_values
        # that would only be discarded.
        output = model(**idx, use_cache=False)
        print(f"logits shape {output['logits'].shape}")
        logits = output['logits'][:, -1, :]          # logits of the last position only
        idx_next = torch.argmax(logits, dim=-1)      # greedy pick, shape (batch,)
        print('idx_next: ', idx_next)
        # unsqueeze(1): argmax removed the vocab dim, leaving shape (batch,);
        # add a sequence dim so it can be concatenated onto (batch, seq_len).
        idx['input_ids'] = torch.cat((idx['input_ids'], idx_next.unsqueeze(1)), dim=-1)


torch.Size([1, 5])

Generation_0 input_ids shape torch.Size([1, 5]):
input_ids tensor([[12, 81, 92, 52, 65]]):
logits shape torch.Size([1, 5, 100])
idx_next:  tensor([66])

Generation_1 input_ids shape torch.Size([1, 6]):
input_ids tensor([[12, 81, 92, 52, 65, 66]]):
logits shape torch.Size([1, 6, 100])
idx_next:  tensor([66])

Generation_2 input_ids shape torch.Size([1, 7]):
input_ids tensor([[12, 81, 92, 52, 65, 66, 66]]):
logits shape torch.Size([1, 7, 100])
idx_next:  tensor([93])


In [None]:
import torch
import torch.nn.functional as F 
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM

# Toy sizes for the single-head KV-cache demo.
D, V, B, L = (
    128,  # D: single attention-head (and embedding) dimension
    64,   # V: vocabulary size
    4,    # B: batch size
    10,   # L: prompt (prefill) sequence length
)

class kv_cache(torch.nn.Module):
    """Toy single-layer, single-head attention LM that caches keys/values across calls.

    The first ``forward`` call (the prefill) processes the whole prompt and
    stores its K/V projections. Each later call passes only the newly
    generated token(s); their K/V are appended to the cache, so earlier
    positions are never re-projected.

    Args:
        D: embedding / attention-head dimension.
        V: vocabulary size.
    """

    def __init__(self, D, V):
        super().__init__()
        self.D = D
        self.V = V
        self.embedding = torch.nn.Embedding(V, D)
        self.Wq = torch.nn.Linear(D, D)
        self.Wk = torch.nn.Linear(D, D)
        self.Wv = torch.nn.Linear(D, D)
        self.lm_head = torch.nn.Linear(D, V)
        # Key/value projections of every token seen so far, shape (B, L_cache, D);
        # None until the first forward call.
        self.cache_k = self.cache_v = None

    def reset_cache(self):
        """Drop the cached keys/values so the next call starts a new sequence."""
        self.cache_k = self.cache_v = None

    def forward(self, X):
        """Attend from the new tokens ``X`` over all cached tokens plus ``X`` itself.

        Args:
            X: integer token ids of shape (B, L). This is the full prompt on the
               first call, then only the newly generated token(s).

        Returns:
            Logits of shape (B, L, V) for the L new positions.
        """
        X = self.embedding(X)   # token ids (B, L) -> embeddings (B, L, D)
        Q, K, Vals = self.Wq(X), self.Wk(X), self.Wv(X)   # each (B, L, D)
        # Per layer, each call grows the cache by 2 (K and V) x B x L x D elements,
        # i.e. times 4 bytes in fp32 or 2 bytes in fp16/bf16.
        if self.cache_k is None:
            self.cache_k = K
            self.cache_v = Vals
        else:
            self.cache_k = torch.cat([self.cache_k, K], dim=1)      # (B, L_cache + L, D)
            self.cache_v = torch.cat([self.cache_v, Vals], dim=1)   # (B, L_cache + L, D)
            K = self.cache_k
            Vals = self.cache_v
        print('K shape: ', K.shape)
        print('V shape: ', Vals.shape)

        attn = Q @ K.transpose(-1, -2) / (self.D ** 0.5)    # (B, L, L_cache + L)
        # Causal mask: new token i sits at absolute position (L_cache + i) and may
        # only attend to keys at or before that position. For a single decode
        # token nothing is masked. Each row keeps at least its own position, so
        # softmax never sees an all -inf row.
        total_len = K.shape[1]
        past_len = total_len - Q.shape[1]
        q_pos = torch.arange(past_len, total_len, device=attn.device).unsqueeze(1)   # (L, 1)
        k_pos = torch.arange(total_len, device=attn.device).unsqueeze(0)             # (1, L_cache + L)
        attn = attn.masked_fill(k_pos > q_pos, float('-inf'))
        # Normalise scores into attention weights before mixing the values.
        attn = torch.softmax(attn, dim=-1)
        print('attn shape: ', attn.shape)
        output = attn @ Vals       # (B, L, D)
        print(output.shape)
        output = self.lm_head(output)   # (B, L, V)
        return output

# Build the toy model and a random prompt of B sequences, each L tokens long.
model = kv_cache(D, V)
model.eval()
X = torch.randint(0, V, (B, L))

# Step 0 prefills the cache with the whole prompt. Every later step feeds only
# the token generated by the previous step. no_grad keeps the cached K/V from
# carrying an ever-growing autograd graph from one step to the next.
with torch.no_grad():
    for i in range(4):
        print(f"\nGeneration {i} step input_shape: {X.shape}：")
        output = model(X)   # (B, L_in, V) logits for the new positions
        # Greedy decoding: softmax is monotonic, so argmax over the raw logits
        # of the last position picks the same token without the extra softmax.
        next_token = torch.argmax(output[:, -1, :], dim=-1)   # (B,)
        print('next_token.shape:', next_token.shape)
        X = next_token.unsqueeze(1)   # (B,) -> (B, 1): next step's single-token input
        


Generation 0 step input_shape: torch.Size([4, 10])：
K shape:  torch.Size([4, 10, 128])
V shape:  torch.Size([4, 10, 128])
attn shape:  torch.Size([4, 10, 10])
torch.Size([4, 10, 128])
next_token.shape: torch.Size([4])

Generation 1 step input_shape: torch.Size([4, 1])：
K shape:  torch.Size([4, 11, 128])
V shape:  torch.Size([4, 11, 128])
attn shape:  torch.Size([4, 1, 11])
torch.Size([4, 1, 128])
next_token.shape: torch.Size([4])

Generation 2 step input_shape: torch.Size([4, 1])：
K shape:  torch.Size([4, 12, 128])
V shape:  torch.Size([4, 12, 128])
attn shape:  torch.Size([4, 1, 12])
torch.Size([4, 1, 128])
next_token.shape: torch.Size([4])

Generation 3 step input_shape: torch.Size([4, 1])：
K shape:  torch.Size([4, 13, 128])
V shape:  torch.Size([4, 13, 128])
attn shape:  torch.Size([4, 1, 13])
torch.Size([4, 1, 128])
next_token.shape: torch.Size([4])
