In [55]:
import torch
import math
from torch import nn
# NOTE(review): layer_num is not read by any of the visible cells; presumably the
# number of decoder layers to stack in the toy model — confirm before relying on it.
layer_num = 4
# Reference configuration values, kept purely as a comment (no visible cell reads
# them). The cells below use much smaller sizes (e.g. hidden size 1024) but the same
# rope_theta of 500000.0.
# config = {
#     'vocab_size': 128256,
#     'max_position_embeddings': 8192,
#     'hidden_size': 4096,
#     'intermediate_size': 14336,
#     'num_hidden_layers': 32,
#     'num_attention_heads': 32,
#     'num_key_value_heads': 8,
#     'hidden_act': 'silu',
#     'initializer_range': 0.02,
#     'rms_norm_eps': 1e-05,
#     'pretraining_tp': 1,
#     'use_cache': True,
#     'rope_theta': 500000.0,
#     'rope_scaling': None,
#     'attention_bias': False,
#     'attention_dropout': 0.0,
# }

In [56]:
def repeatkv(x, n_rep=4):
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Used for grouped-query attention, where several query heads share one
    key/value head: the kv tensor is expanded so that it has as many heads
    as the query tensor.

    Args:
        x: tensor of shape [batch, kv_heads, seq_len, head_dim].
        n_rep: number of copies of each kv head (query heads per kv head).
            Defaults to 4, the value that used to be hard-coded.

    Returns:
        Tensor of shape [batch, kv_heads * n_rep, seq_len, head_dim]. The
        copies of one kv head are adjacent on the head axis, i.e. output
        head ``h`` is input head ``h // n_rep``.
    """
    shape = list(x.shape)
    shape[1] *= n_rep
    # e.g. [4,8,125,32] -> [4,8,1,125,32] -> [4,8,4,125,32] -> [4,32,125,32]
    return x.unsqueeze(2).repeat(1, 1, n_rep, 1, 1).reshape(shape)
repeatkv(torch.randn(4,8,125,32)).shape

torch.Size([4, 32, 125, 32])

In [57]:
# Scratch check: the 16 even indices 0, 2, ..., 30 of a 32-wide vector.
torch.arange(start=0, end=32, step=2)

tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30])

In [58]:
def llama_rotary_pos_embed(lens, head_dim=32, base=500_000.0):
    """Build the cos/sin lookup tables for rotary position embeddings (RoPE).

    For pair index ``i`` in ``[0, head_dim / 2)`` the rotation frequency is
    ``base ** (-2 * i / head_dim)``, and the rotation angle at position ``p``
    is ``p * frequency``. Low pair indices therefore rotate quickly with
    position and high pair indices rotate slowly.

    The half-width angle table is concatenated with itself so that feature
    ``j`` and feature ``j + head_dim / 2`` share the same angle.

    Args:
        lens: number of positions (sequence length).
        head_dim: width of one attention head; must be even. Defaults to 32,
            the value that used to be hard-coded.
        base: RoPE base (theta). Defaults to 500000.0, the value that used to
            be hard-coded.

    Returns:
        Tuple ``(cos, sin)``; each is a float tensor of shape
        [1, lens, head_dim].

    Raises:
        ValueError: if ``head_dim`` is odd (features could not be paired).
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even, got {head_dim}")

    # Exponents 0/d, 2/d, ..., (d-2)/d -> one inverse frequency per feature pair.
    exponents = torch.arange(0, head_dim, 2) / head_dim
    inv_freq = 1.0 / (base ** exponents)
    inv_freq = inv_freq.reshape(1, -1, 1)  # [1, head_dim/2, 1]

    position_ids = torch.arange(lens).reshape(1, 1, -1).float()  # [1, 1, lens]
    # Outer product frequency x position, then put positions on axis 1:
    # [1, head_dim/2, lens] -> [1, lens, head_dim/2]
    freqs = inv_freq.matmul(position_ids).transpose(1, 2)
    # Duplicate so both halves of a head see the same angles: [1, lens, head_dim]
    emb = torch.cat((freqs, freqs), 2)

    return emb.cos(), emb.sin()
cos, sin = llama_rotary_pos_embed(16)
cos.shape, sin.shape

(torch.Size([1, 16, 32]), torch.Size([1, 16, 32]))

In [59]:
def apply_RoPE(x, cos, sin):
    """Rotate query/key features by position-dependent angles (RoPE).

    The last dimension is split into two halves. Feature ``j`` of the first
    half (``a``) is paired with feature ``j`` of the second half (``b``), and
    each pair is rotated by the angle encoded in ``cos``/``sin``:

        a' = a * cos - b * sin
        b' = b * cos + a * sin

    Args:
        x: tensor whose last dimension (head width) is even,
            e.g. [batch, heads, seq_len, head_dim].
        cos: cosine table, e.g. [1, seq_len, head_dim]; a head axis is
            inserted at dim 1 so it broadcasts against ``x``.
        sin: sine table with the same shape as ``cos``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Derive the split point from the input instead of assuming head_dim == 32.
    half = x.shape[-1] // 2

    def rotate_half(t):
        left = t[..., :half]
        right = t[..., half:]
        # The minus sign is essential: (a, b) -> (-b, a) is a 90-degree turn,
        # which is what makes x*cos + rotate_half(x)*sin a true rotation.
        # Without it the transform does not preserve vector norms.
        return torch.cat((-right, left), -1)

    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)

    return (x * cos) + (rotate_half(x) * sin)


# Shape-only smoke test; the tables are random here, not real cos/sin values.
rope_demo_inputs = {
    'x': torch.randn(4, 32, 125, 32),
    'sin': torch.randn(1, 125, 32),
    'cos': torch.randn(1, 125, 32)
}
apply_RoPE(**rope_demo_inputs).shape

torch.Size([4, 32, 125, 32])

In [60]:
def get_mask(attention_mask):
    """Merge the causal mask and the padding mask into one additive mask.

    Args:
        attention_mask: tensor of shape [batch, seq_len] (e.g. [4, 125]);
            entries equal to 0 mark key positions that must not be attended.

    Returns:
        Float tensor of shape [batch, 1, seq_len, seq_len] holding 0 where
        query i may look at key j, and -1e15 where j lies in the future
        (j > i) or key j is padding. It is meant to be added to the attention
        scores before the softmax.
    """
    batch_size, seq_len = attention_mask.shape
    neg_large = -1e15

    # Strict upper triangle (future keys) gets the large negative value; the
    # diagonal and everything below it stay 0. Shape [seq_len, seq_len].
    future_block = torch.triu(torch.full((seq_len, seq_len), neg_large), diagonal=1)
    additive = future_block[None, None, :, :].repeat(batch_size, 1, 1, 1)
    additive = additive.to(attention_mask.device)

    # Padding keys are blocked for every query row of that sample.
    is_padding = (attention_mask == 0).reshape(batch_size, 1, 1, seq_len)
    return additive.masked_fill(is_padding, neg_large)
get_mask(torch.ones(1, 5).long())

tensor([[[[ 0.0000e+00, -1.0000e+15, -1.0000e+15, -1.0000e+15, -1.0000e+15],
          [ 0.0000e+00,  0.0000e+00, -1.0000e+15, -1.0000e+15, -1.0000e+15],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+15, -1.0000e+15],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+15],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]])

In [61]:
class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer normalisation.

    Each feature vector is divided by the root of its mean squared value
    (no mean subtraction, no bias) and then scaled by a learnable
    per-feature weight.

    Args:
        hidden_size: size of the last (feature) dimension.
        eps: small constant added to the mean square before the reciprocal
            square root, for numerical stability.
    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-feature scale, shape [hidden_size], initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` over its last dimension.

        Args:
            x: tensor of shape [..., hidden_size], e.g. [4, 125, 1024].

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Mean of squares over the feature axis; keepdim so it broadcasts back.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps)
        return self.weight * x

LlamaRMSNorm(1024)(torch.randn(4, 125, 1024)).shape


torch.Size([4, 125, 1024])

### MLP with SiLU gating — key point: the gate projection goes through SiLU and is multiplied element-wise with the up projection

In [62]:
class LlamaMLP(torch.nn.Module):
    """Gated feed-forward block: ``down(SiLU(gate(x)) * up(x))``.

    Args:
        hidden_size: width of the input and output features.
        intermediate_size: width of the inner projection. Defaults to
            ``4 * hidden_size``, the value that used to be hard-coded.
    """
    def __init__(self, hidden_size, intermediate_size=None):
        super().__init__()
        if intermediate_size is None:
            intermediate_size = 4 * hidden_size
        # Creation order (up, gate, down) is kept so seeded initialisation
        # yields the same weights as before.
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated MLP to ``x`` of shape [..., hidden_size]."""
        # The SiLU-activated gate decides, per feature, how much of the
        # linear "up" branch is let through.
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

LlamaMLP(1024)(torch.randn(4, 125, 1024)).shape

torch.Size([4, 125, 1024])

In [63]:
class LlamaAttention(torch.nn.Module):
    """Causal self-attention with grouped-query attention and RoPE.

    Sizes are fixed: hidden size 1024, 32 query heads and 8 key/value heads
    of width 32. The helper functions used in ``forward`` assume head width
    32 and 4 query heads per key/value head, so these numbers must stay in
    sync with them.
    """
    def __init__(self):
        super().__init__()
        self.num_heads = 32      # query heads
        self.num_kv_heads = 8    # key/value heads
        self.head_dim = 32       # width of one head

        self.q_proj = nn.Linear(1024, 1024, bias=False)
        # Why 256? Grouped-query attention: 8 key/value heads x 32 dims = 256.
        # Each key/value head is shared by 4 of the 32 query heads.
        self.k_proj = nn.Linear(1024, 256, bias=False)
        self.v_proj = nn.Linear(1024, 256, bias=False)
        self.o_proj = nn.Linear(1024, 1024, bias=False)

    def forward(self, hidden_states, attention_mask):
        """Run self-attention over a batch of sequences.

        Args:
            hidden_states: [batch, seq_len, 1024], e.g. [4, 125, 1024].
            attention_mask: [batch, seq_len], e.g. [4, 125]; entries equal to
                0 mark padding positions.

        Returns:
            Tensor of shape [batch, seq_len, 1024].
        """
        b, l, _ = hidden_states.shape

        # Project, then split the feature axis into heads: [b, heads, l, head_dim]
        q = self.q_proj(hidden_states).reshape(b, l, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).reshape(b, l, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).reshape(b, l, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Add position information (queries and keys only).
        cos, sin = llama_rotary_pos_embed(l)
        cos, sin = cos.to(hidden_states.device), sin.to(hidden_states.device)
        q = apply_RoPE(q, cos, sin)
        k = apply_RoPE(k, cos, sin)

        # Repeat kv heads so their count matches the query heads: [b, 32, l, 32]
        k = repeatkv(k)
        v = repeatkv(v)

        # Scores: [b, 32, l, l]. Dividing by sqrt(head_dim) keeps their scale
        # independent of the head width so the softmax does not saturate.
        attn = q.matmul(k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Additive mask [b, 1, l, l], broadcast over the head axis.
        attention_mask = get_mask(attention_mask)
        # Softmax over the last axis (the key positions): for each query the
        # weights over all keys sum to 1.
        attn = (attn + attention_mask).softmax(dim=-1)

        attn = attn.matmul(v)  # [b, 32, l, 32]
        # Merge the heads back into one feature axis: [b, l, 1024]
        attn = attn.transpose(1, 2).reshape(b, l, -1)
        attn = self.o_proj(attn)

        return attn
    
attention_demo_inputs = {
    'hidden_states' : torch.randn(4, 125, 1024),
    # 1 = real token, 0 = padding. get_mask tests for == 0, so the mask must
    # hold exact 0/1 values rather than random floats.
    'attention_mask' : torch.ones(4, 125).long()
}
LlamaAttention()(**attention_demo_inputs).shape

torch.Size([4, 125, 1024])

In [64]:
class DecoderLayer(torch.nn.Module):
    """Placeholder for a decoder layer; the computation is not implemented yet.

    NOTE(review): presumably meant to combine the attention, MLP and RMSNorm
    blocks of this notebook — confirm the intended design before filling it in.
    """
    def __init__(self):
        super().__init__()

    # forward must be a method of the class. When nested inside __init__ it
    # is only a local function, so calling the module raises.
    def forward(self, x):
        # TODO: not implemented yet; returns None for any input.
        return None

In [65]:
class LlamaModel(torch.nn.Module):
    """Placeholder for the full model; the computation is not implemented yet.

    NOTE(review): presumably meant to stack several decoder layers — confirm
    the intended design before filling it in.
    """
    def __init__(self):
        super().__init__()

    # forward must be a method of the class. When nested inside __init__ it
    # is only a local function, so calling the module raises.
    def forward(self, x):
        # TODO: not implemented yet; returns None for any input.
        return None