# Generation without kv_cache

![NO_KVCACHE](./image/without_kv_cache.png)

In [64]:
from rich import print
import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
import time

config = LlamaConfig(
    vocab_size=100,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = LlamaForCausalLM(config)
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(100, 256)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=256, out_features=256, bias=False)
          (k_proj): Linear(in_features=256, out_features=256, bias=False)
          (v_proj): Linear(in_features=256, out_features=256, bias=False)
          (o_proj): Linear(in_features=256, out_features=256, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=256, out_features=512, bias=False)
          (up_proj): Linear(in_features=256, out_features=512, bias=False)
          (down_proj): Linear(in_features=512, out_features=256, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
    (rotary_emb): LlamaRotaryEmbedding()
 

In [65]:
X = torch.randint(0, 100, (1, 10))  # 创建数据、不使用tokenizer
print(X.shape)
print(X)

In [66]:
idx = {}
idx['input_ids'] = X
output = model(**idx)

for i in range(4):
    print(f"\nGeneration第{i}个时的输入{idx['input_ids'].shape}：")
    print(f"Generation第{i}个时的输入{idx['input_ids']}：")
    output = model(**idx)
    logits = output['logits'][:,-1,:]
    idx_next = torch.argmax(logits, dim=-1)
    time.sleep(1)
    
    idx['input_ids'] = torch.cat((idx['input_ids'], idx_next.unsqueeze(0)), dim=-1)

# Generation With KV-Cache

![KVCACHE](./image/with_kv_cache.png)

In [25]:
import torch
import torch.nn.functional as F

D = 128  # single-head-dim
V = 64  # vocab_size

class kv_cache(torch.nn.Module):
    def __init__(self, D, V):
        super().__init__()
        self.D = D
        self.V = V
        self.Embedding = torch.nn.Embedding(V, D)
        self.Wq = torch.nn.Linear(D, D)
        self.Wk = torch.nn.Linear(D, D)
        self.Wv = torch.nn.Linear(D, D)
        self.lm_head = torch.nn.Linear(D, V)
        self.cache_K = self.cache_V = None  # initial
    
    def forward(self, X):
        X = self.Embedding(X)
        Q, K, V = self.Wq(X), self.Wk(X), self.Wv(X)
        print(f"input_Q:{Q.shape}")
        print(f"input_K:{K.shape}")
        print(f"input_V:{V.shape}")
        
        # Easy KV_Cache
        if self.cache_K == None:
            self.cache_K = K
            self.cache_V = V
        else:
            self.cache_K = torch.cat((self.cache_K, K), dim=1)
            self.cache_V = torch.cat((self.cache_V, V), dim=1)
            K = self.cache_K
            V = self.cache_V
            
        print(f"cache_K:{self.cache_K.shape}")
        print(f"cache_V:{self.cache_K.shape}")
        
        # ignore proj/MLP/scaled/mask/multi-head
        attn = Q @ K.transpose(1,2) @ V
        
        # output
        output = self.lm_head(attn)
        return output
    

In [28]:
model = kv_cache(D, V) # 创建decode模型

# 创建数据、不使用tokenizer
X = torch.randint(0, 64, (1, 10))
print(X.shape)

for i in range(3):
    print(f"\nGeneration {i} step input_shape: {X.shape}：")
    output = model.forward(X)
    next_token = torch.argmax(F.softmax(output, dim=-1), -1)[:, -1]
    print(f'next_token预测:{next_token}')
    X = next_token.unsqueeze(0)

In [29]:
# GPT

In [30]:
import torch

In [31]:
batch_size = 1 # 一句话

x = torch.randn(batch_size, 4, 512)  # input :  batch_size, length, embd_dim
# “我，很，开，心”

y = torch.randn(batch_size, 4) # output: length
# “很，开，心， EOS”

print(x.shape)
print(y.shape)

In [37]:
# x = torch.randn(1, 64, 512)
q = torch.randn(512, 512)
k = torch.randn(512, 512)
v = torch.randn(512, 512)

mlp_up = torch.randn(512, 1024)
mlp_down = torch.randn(1024, 512)

Q, K, V = x @ q, x @ k, x @ v

In [38]:
mask = torch.tril(torch.ones(1, 4, 4))
print(mask)

In [39]:
attn = Q @ K.transpose(1, 2) @ V
# mask
attn.shape

torch.Size([1, 4, 512])

In [42]:
# 没有swiglu，mlp层升维再降维
mlp = attn @ mlp_up @ mlp_down
mlp.shape

torch.Size([1, 4, 512])

In [46]:
lm_head = torch.randn(512, 32000)
result = mlp @ lm_head
print(result.shape)
result

tensor([[[-6.1981e+09,  1.4613e+09,  1.2173e+09,  ...,  7.5367e+09,
           2.8965e+09,  2.9435e+09],
         [ 7.2909e+09,  6.5999e+09,  1.4065e+09,  ...,  1.5567e+10,
           1.5449e+10, -7.2056e+09],
         [ 1.5447e+09,  2.4404e+09, -9.7456e+08,  ..., -1.7298e+09,
           3.1604e+09, -8.9028e+08],
         [-4.4650e+09, -1.7091e+10,  6.4039e+09,  ..., -3.4217e+09,
          -2.6782e+10,  3.5392e+09]]])

In [57]:
pred = torch.argmax(F.softmax(result, dim=-1), -1)
pred[:, -1]

tensor([31996])

In [62]:
result[0,3,31996]

tensor(5.3672e+10)

In [82]:
probs = F.softmax(result, dim=-1)
loss = F.cross_entropy(probs.view(-1, probs.size(-1)), y.view(-1).long())
loss

tensor(10.3736)

In [83]:
y

tensor([[-0.2147, -0.5427,  1.4803,  0.8184]])

In [84]:
F.softmax(result, dim=-1)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])