## Model

In [4]:
import math
import struct
import inspect
import time

from model.LLMconfig import LLMConfig
from typing import Any, Optional, Tuple, List
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

ImportError: cannot import name 'LLMConfig' from 'model.LLMconfig' (e:\xiangmu\GITHUB_PROJS\Zero2LLM\Zer02LLM\Zer0LLM_all\model\LLMconfig.py)

In [2]:
#1 均方根层归一化(Root Mean Square Layer Normalization, RMSNorm)
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization over the last dimension.

    Each vector along the last axis is divided by the root of its mean
    squared value (plus ``eps``) and then scaled by a learnable per-feature
    ``weight`` that is initialised to ones.

    Args:
        dim: Size of the last dimension of the inputs (length of ``weight``).
        eps: Constant added to the mean of squares before the reciprocal
            square root, to avoid division by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply the learned scale.

        Args:
            x: Floating-point tensor whose last dimension broadcasts against
                ``weight``.

        Returns:
            Tensor with the same shape as ``x``; its dtype follows the usual
            type promotion between ``weight`` and ``x``.
        """
        if x.dtype in (torch.float16, torch.bfloat16):
            # Squaring in reduced precision overflows for float16 once
            # |x| > ~255, which turns the mean into inf and the output into 0.
            # Compute the statistic in float32 and cast the result back.
            x_fp32 = x.float()
            inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return self.weight * (x_fp32 * inv_rms).type_as(x)
        # Full-precision inputs keep the original expression unchanged.
        return self.weight * x * torch.rsqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps)

#2 旋转位置编码(Rotary Position Embedding, RoPE)
def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of complex rotation factors used by rotary embeddings.

    Args:
        dim: Feature size being rotated; one frequency is produced for every
            second index in ``range(0, dim, 2)``.
        end: Number of positions (rows) to generate.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, len(range(0, dim, 2)))`` whose entry
        ``[t, k]`` is the unit-magnitude number with angle
        ``t / theta ** (2k / dim)``.
    """
    # Inverse frequencies: 1 / theta^(2k / dim) for k = 0, 1, ...
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Position indices 0 .. end-1, in the same dtype as the frequencies.
    positions = torch.arange(end, dtype=inv_freq.dtype)  # type: ignore
    # angle[t, k] = position t times frequency k.
    angles = torch.outer(positions, inv_freq)  # type: ignore
    # Unit-magnitude complex numbers cos(angle) + i*sin(angle).
    unit_radius = torch.ones_like(angles)
    return torch.polar(unit_radius, angles)  # type: ignore

def apply_rotary_emb(xq, xk, pos_cis):
    """Rotate query and key tensors by precomputed complex position factors.

    The last dimension of each input is grouped into consecutive pairs that
    are treated as (real, imaginary) parts, multiplied by ``pos_cis`` and then
    unpacked back into the original layout.

    Args:
        xq: Query tensor; dimension 1 indexes positions and the last dimension
            must be even. ``flatten(3)`` below restores the input shape for
            4-D inputs such as ``(batch, seq_len, heads, head_dim)``.
        xk: Key tensor with the same position and last dimensions as ``xq``
            (other dimensions only need to broadcast with ``pos_cis``).
        pos_cis: Complex tensor of shape ``(xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        Tuple ``(xq_out, xk_out)`` cast back to the dtypes of ``xq`` and ``xk``.

    Raises:
        AssertionError: If ``pos_cis`` does not have the expected shape.
    """
    def _to_complex(t):
        # (..., d) -> (..., d // 2, 2) -> complex (..., d // 2), computed in float32.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _to_complex(xq)
    k_complex = _to_complex(xk)

    # Reshape pos_cis so it broadcasts over every axis except the position
    # axis (1) and the feature axis (last).
    rank = q_complex.ndim
    assert rank > 1
    assert pos_cis.shape == (q_complex.shape[1], q_complex.shape[-1])
    broadcast_shape = [
        size if axis in (1, rank - 1) else 1
        for axis, size in enumerate(q_complex.shape)
    ]
    rotation = pos_cis.view(*broadcast_shape)

    # Complex multiplication performs the rotation; view_as_real + flatten
    # interleaves the real/imaginary parts back into the last dimension.
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
    


In [3]:
# Random query/key tensors laid out as (batch_size, sequence_length, num_heads, head_dim).
xq = torch.randn(2, 16, 4, 64)
xk = torch.randn(2, 16, 4, 64)
# Complex rotation factors for head_dim = 64 over 16 positions.
pos_cis = precompute_pos_cis(dim=64, end=16)
print(f'pos_cis 的形状为 {pos_cis.shape}, 其中 [0, 0] 下标元素为 {pos_cis[0,  0]}\n')

# Shapes inside apply_rotary_emb:
#   original xq:             torch.Size([2, 16, 4, 64])
#   xq viewed as complex:    torch.Size([2, 16, 4, 32])
#   pos_cis after reshaping: torch.Size([1, 16, 1, 32])
xq_rope, xk_rope = apply_rotary_emb(xq, xk, pos_cis)
print(f'经过 RoPE 编码后的 Query 与 Key 的形状为 {xq_rope.shape},  {xk_rope.shape}\n')


pos_cis 的形状为 torch.Size([16, 32]), 其中 [0, 0] 下标元素为 (1+0j)

经过 RoPE 编码后的 Query 与 Key 的形状为 torch.Size([2, 16, 4, 64]),  torch.Size([2, 16, 4, 64])



In [23]:
# 3 Attention 层
# repeatkv
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Args:
        x: 4-D tensor unpacked as ``(batch, seq_len, n_kv_heads, head_dim)``.
        n_rep: Number of copies of each head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        ``(batch, seq_len, n_kv_heads * n_rep, head_dim)`` in which the copies
        of one head are adjacent.
    """
    batch, length, kv_heads, width = x.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it to
        # n_rep, then merge it into the head axis.
        tiled = x.unsqueeze(3).expand(batch, length, kv_heads, n_rep, width)
        return tiled.reshape(batch, length, kv_heads * n_rep, width)
    return x

# attention
class Attention(nn.Module):
    """Causal multi-head self-attention with grouped KV heads, RoPE and a KV cache.

    Reads the following fields from ``args``: ``n_heads``, ``n_kv_heads``
    (``None`` means "same as ``n_heads``"), ``dim``, ``dropout``,
    ``flash_attn`` and ``max_seq_len``.

    Args:
        args: Model configuration object providing the fields listed above.
    """

    def __init__(self, args: "LLMConfig"):
        super().__init__()
        # Resolve the optional KV-head count once and use the resolved value
        # everywhere below, so that n_kv_heads=None does not crash.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # Number of query heads that share one key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        # Query / key / value / output projections (no bias).
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Use the fused kernel only if this torch build has it and the config asks for it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        # Registered as a non-persistent buffer, so it is not saved in the state dict.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor,
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False):
        """Run causal self-attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            pos_cis: Complex rotary factors for the positions of ``x``; passed
                to ``apply_rotary_emb``, which requires shape
                ``(seq_len, head_dim // 2)``.
            past_key_value: Optional ``(keys, values)`` from earlier steps,
                each ``(batch, past_len, n_kv_heads, head_dim)``; they are
                concatenated in front of the new keys/values along dim 1.
            use_cache: When true, the concatenated ``(keys, values)`` are
                returned for reuse in the next call.

        Returns:
            Tuple ``(output, past_kv)`` where ``output`` has shape
            ``(batch, seq_len, dim)`` and ``past_kv`` is the updated cache or
            ``None`` when ``use_cache`` is false.
        """
        bsz, seq_len, _ = x.shape
        ############## Forward QKV & RoPE ##############
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        # RoPE applies to queries and keys only; values must stay untouched.
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        ################### KV Cache ###################
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None
        # Total number of key positions (cached + new).
        kv_len = xk.shape[1]
        # Move heads in front of the sequence axis and expand KV heads to match the query heads.
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        ############ Scaled Dot Production #############
        # is_causal=True aligns the mask to the top-left corner, which is only
        # right when queries and keys cover the same positions (no cache).
        if self.flash and seq_len != 1 and past_key_value is None:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # The new queries sit at absolute positions kv_len - seq_len .. kv_len - 1,
            # so take those mask rows against all kv_len key columns.
            scores += self.mask[:, :, kv_len - seq_len:kv_len, :kv_len]
            # Softmax in float32 for stability, then back to the working dtype.
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv
        ################################################
        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim).
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv





SyntaxError: incomplete input (4084081713.py, line 18)