In [2]:
import torch
import numpy as np 
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F
from transformers import PretrainedConfig


In [7]:

class LMConfig(PretrainedConfig):
    """Hyper-parameter container for the language model.

    Every constructor argument is stored as an attribute of the same name;
    any additional keyword arguments are forwarded to the parent
    ``PretrainedConfig`` constructor.

    Args:
        n_layers: Number of layers.
        max_seq_len: Maximum sequence length.
        dim: Embedding dimension, i.e. the total input width of attention.
        n_heads: Number of query heads; each head has ``dim // n_heads``
            channels.
        n_kv_heads: Number of key/value heads. With the defaults (16 query
            heads, 8 KV heads) there are 8 groups, so every two query heads
            share one KV head.
        hidden_dim: Hidden width inside the MLP (``None`` by default).
        norm_eps: Epsilon used by RMSNorm.
        dropout: Dropout probability.
        flash_atten: Whether to use flash attention.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        n_layers: int = 8,
        max_seq_len: int = 512,
        dim: int = 512,
        n_heads: int = 16,
        n_kv_heads: int = 8,
        hidden_dim: int = None,
        norm_eps: float = 1e-5,
        dropout: float = 0.0,
        flash_atten: bool = True,
        **kwargs,
    ):
        # Let the parent class consume the generic options first.
        super().__init__(**kwargs)

        # Model size.
        self.n_layers, self.max_seq_len, self.dim = n_layers, max_seq_len, dim

        # Attention head layout (query heads vs. shared key/value heads).
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads

        # MLP width, normalisation epsilon and regularisation.
        self.hidden_dim, self.norm_eps, self.dropout = hidden_dim, norm_eps, dropout

        # Attention kernel selection.
        self.flash_atten = flash_atten




In [None]:
# NOTE(review): scratch cell. The trailing `?` is IPython help syntax (not
# plain Python) used to look up the docs of the torch functions called by
# precompute_pos_cis; remove this cell before sharing the notebook.
#torch.outer?
torch.polar?

In [1]:
# RoPE implemented through the polar form of complex numbers: position m and
# frequency theta_k are encoded as the unit complex number exp(i * m * theta_k).
def precompute_pos_cis(dim: int, end: int, base: float = 10000.0):
    """Precompute the complex rotation factors for rotary position embedding.

    Args:
        dim: Number of channels that will be rotated. Channels are rotated in
            pairs, so ``dim // 2`` frequencies are generated. Given the shape
            check in ``apply_rotary_emb`` this is expected to be the per-head
            dimension.
        end: Number of positions to precompute (positions ``0 .. end - 1``).
        base: Base of the geometric progression of frequencies.

    Returns:
        A complex tensor of shape ``(end, dim // 2)`` whose entry ``[m, k]``
        is ``cos(m * theta_k) + i * sin(m * theta_k)`` with
        ``theta_k = base ** (-2k / dim)``.
    """
    # `.float()` has to be *called*; dividing the bound method by `dim` raises.
    exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    theta = 1.0 / (base ** exponents)  # (dim // 2,)

    # Positions are created as floats so the outer product needs no dtype
    # promotion between an integer and a floating-point tensor.
    idx = torch.arange(end, device=theta.device, dtype=theta.dtype)

    # Angle for every (position, frequency) pair. The original passed the not
    # yet defined `idx_theta` here instead of `theta`.
    idx_theta = torch.outer(idx, theta).float()  # (end, dim // 2)

    # Unit-magnitude complex numbers carrying those angles.
    pos_cis = torch.polar(torch.ones_like(idx_theta), idx_theta)
    return pos_cis

                   

# Rotary position embedding: after rotation the q_i / k_j interaction depends
# on x_i, x_j and the relative offset i - j.
def apply_rotary_emb(xq, xk, pos_cis):
    """Rotate query and key tensors by the precomputed complex factors.

    Args:
        xq: Query tensor, (bs, seq, n_local_heads, head_dim).
        xk: Key tensor with the same batch, sequence and head_dim sizes as
            ``xq`` (the head count may differ).
        pos_cis: Complex tensor of shape (seq, head_dim // 2).

    Returns:
        Tuple ``(xq_out, xk_out)`` with the shapes and dtypes of the inputs.
    """

    def _as_complex(t):
        # Group the last axis into (real, imag) pairs and view them as complex.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    def _broadcastable(factors, ref):
        # Keep the sequence axis (1) and the last axis, set all others to 1
        # so the factors broadcast over batch and heads.
        rank = ref.ndim
        assert 0 <= 1 < rank
        assert factors.shape == (ref.shape[1], ref.shape[-1])
        target = [1] * rank
        target[1] = ref.shape[1]
        target[-1] = ref.shape[-1]
        return factors.view(*target)

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    rotation = _broadcastable(pos_cis, q_complex)

    # Complex multiplication performs the 2-D rotation of every pair; then
    # unpack (real, imag) back into the flat head_dim axis.
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)

    # Computation ran in float32; cast back to the callers' dtypes.
    return q_rotated.type_as(xq), k_rotated.type_as(xk)



In [25]:
# Demo: view_as_complex merges consecutive (real, imag) pairs of the last
# axis, so a real (2, 4) tensor is displayed as a complex (2, 2) tensor.
xq = torch.ones(2, 4)
print(xq)
torch.view_as_complex(xq.float().reshape(xq.shape[0], -1, 2))

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]])


tensor([[1.+1.j, 1.+1.j],
        [1.+1.j, 1.+1.j]])

In [14]:
class Attention(nn.Module):
    """Causal multi-head self-attention with grouped key/value heads and RoPE.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads (each KV head
    serves ``n_rep = n_heads // n_kv_heads`` query heads). Keys and values can
    optionally be cached between calls for incremental decoding.

    Args:
        args: Configuration object providing ``dim``, ``n_heads``,
            ``n_kv_heads``, ``dropout``, ``flash_atten`` and ``max_seq_len``.
    """

    def __init__(self, args: "LMConfig"):
        super().__init__()

        # Fall back to one KV head per query head when n_kv_heads is unset.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads = args.n_heads
        # Validate the *resolved* value: args.n_kv_heads itself may be None.
        assert self.n_heads % self.n_kv_heads == 0, "group error"

        self.n_local_heads = self.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # Number of query heads served by each KV head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        # Each head works on dim // n_heads channels.
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        # The concatenated heads have n_heads * head_dim features, which only
        # equals args.dim when dim is divisible by n_heads.
        self.wo = nn.Linear(self.n_heads * self.head_dim, args.dim, bias=False)

        self.k_cache, self.v_cache = None, None
        self.atten_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Flash attention requires PyTorch >= 2.0.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_atten

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        # A buffer behaves like a constant of the module. persistent=False
        # keeps it out of the state_dict since it can always be rebuilt.
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, poc_cis, kv_cache=False):
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape (bs, seq_len, dim).
            poc_cis: Rotary factors for the positions in ``x``, passed
                unchanged to ``apply_rotary_emb`` (parameter name kept as in
                the original signature).
            kv_cache: When true and the module is in eval mode, keys/values
                are stored on the module; a following call with a single
                token attends over the cached keys/values plus the new one.

        Returns:
            Tensor of shape (bs, seq_len, dim).
        """
        bsz, seqlen, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Rotary position embedding (relative positions).
        xq, xk = apply_rotary_emb(xq, xk, poc_cis)

        # KV cache. Test the mode with `self.training`; calling `self.eval()`
        # here would always be truthy and would switch the module to eval.
        if kv_cache and not self.training:
            if seqlen == 1 and self.k_cache is not None and self.v_cache is not None:
                # Single-token step: extend the cached sequence axis.
                xk = torch.cat((self.k_cache, xk), dim=1)
                xv = torch.cat((self.v_cache, xv), dim=1)
            self.k_cache, self.v_cache = xk, xv

        # Duplicate every KV head so that each query head has a partner.
        if self.n_rep > 1:
            xk = xk.repeat_interleave(self.n_rep, dim=2)
            xv = xv.repeat_interleave(self.n_rep, dim=2)

        # (bs, heads, len, head_dim)
        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)

        if self.flash:
            # A single query (decode step) may see every key, so the causal
            # mask is only needed when several positions are processed.
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=seqlen > 1,
            )
        else:
            scores = torch.matmul(xq, xk.transpose(2, 3)) / (self.head_dim ** 0.5)
            # For seqlen == 1 the slice is a single 0 and broadcasts over all
            # (possibly cached) keys; otherwise it is the causal triangle.
            scores = scores + self.mask[:, :, :seqlen, :seqlen]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.atten_dropout(scores)
            output = torch.matmul(scores, xv)

        # Merge the heads back: (bs, seq_len, n_heads * head_dim).
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.resid_dropout(self.wo(output))
        




In [17]:
# NOTE(review): scratch cell. IPython `?` help lookup for the register_buffer
# call used in Attention.__init__; not plain Python, remove before sharing.
nn.Module.register_buffer?

[1;31mSignature:[0m
[0mnn[0m[1;33m.[0m[0mModule[0m[1;33m.[0m[0mregister_buffer[0m[1;33m([0m[1;33m
[0m    [0mself[0m[1;33m,[0m[1;33m
[0m    [0mname[0m[1;33m:[0m [0mstr[0m[1;33m,[0m[1;33m
[0m    [0mtensor[0m[1;33m:[0m [0mOptional[0m[1;33m[[0m[0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m][0m[1;33m,[0m[1;33m
[0m    [0mpersistent[0m[1;33m:[0m [0mbool[0m [1;33m=[0m [1;32mTrue[0m[1;33m,[0m[1;33m
[0m[1;33m)[0m [1;33m->[0m [1;32mNone[0m[1;33m[0m[1;33m[0m[0m
[1;31mDocstring:[0m
Adds a buffer to the module.

This is typically used to register a buffer that should not to be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the module's state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting :attr:`persistent` to ``False``. The
only difference between a persistent buffer and a non-persistent buffer
is tha