# 重新默写GPT架构

In [1]:
import torch
from torch import  nn
from torch.nn import functional as F
import math

In [2]:
# Central place for model and training hyperparameters.
class ModelArgs:
    """Container for the hyperparameters used by the model and training loop.

    All values are fixed defaults assigned in ``__init__``; edit them here
    (or on the instance) to change the configuration.
    """

    def __init__(self):
        # Prefer the GPU whenever PyTorch reports one.
        gpu_available = torch.cuda.is_available()
        self.device = 'cuda' if gpu_available else 'cpu'

        # --- data / batching ---
        self.block_size = 128  # context window length (GPT-2 uses 1024)
        self.batch_size = 32  # provisional; revisit after checking GPU memory

        # --- architecture ---
        self.n_layer = 3
        self.vocab_size = 7000
        self.n_head = 6
        self.n_embed = 768
        self.bias = False
        self.dropout = 0.0

        # --- evaluation / checkpointing ---
        # Not wired up yet:
        #   dataset_path = './data/sherlock'
        #   init_from = 'scratch'  ('scratch' or 'resume')
        #   checkpoint_save_dir = ''
        self.eval_step = 50  # evaluate and save a checkpoint every n steps
        self.flash_attn = False

        # --- learning-rate schedule ---
        # Not wired up yet: warmup_iters = 2000, lr_decay_iters = 8000,
        # min_lr = 6e-5
        self.learning_rate = 6e-4

        # --- optimizer ---
        # Not wired up yet: weight_decay = 1e-1, betas = (0.9, 0.95),
        # grad_clip = 1.0 (gradient clipping)
        self.max_epochs = 10  # number of training epochs


args = ModelArgs()

class Attention(nn.Module):
    """Causal multi-head self-attention with a fused query/key/value projection.

    Args:
        args: configuration object; reads ``n_embed``, ``n_head``, ``bias``,
            ``dropout`` and ``flash_attn``. ``n_embed`` must be divisible by
            ``n_head``.
    """

    def __init__(self, args):
        super().__init__()
        embed_dim = args.n_embed
        num_heads = args.n_head

        # One matmul produces q, k and v side by side along the last axis.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=args.bias)
        self.dropout = args.dropout
        # The same dropout module is applied to the attention weights and to
        # the projected output.
        self.dropout_attn = nn.Dropout(args.dropout)
        self.n_embed = embed_dim
        self.n_head = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.flash_attn = args.flash_attn
        self.attn_proj = nn.Linear(embed_dim, embed_dim, bias=args.bias)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` where ``C`` equals ``n_embed``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        batch, seq_len, channels = x.shape

        # (B, T, C) -> (B, n_head, T, head_dim) for each of q, k, v.
        head_shape = (batch, seq_len, self.n_head, self.head_dim)
        q, k, v = (
            chunk.reshape(head_shape).transpose(1, 2)
            for chunk in self.qkv(x).split(self.n_embed, dim=-1)
        )

        if self.flash_attn:
            # Fused kernel; attention dropout is only active while training.
            drop_p = self.dropout if self.training else 0
            attn = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True
            )
        else:
            weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Lower-triangular ones mark the allowed (past and current)
            # positions; everything else is a future position.
            allowed = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
            future = allowed.reshape(1, 1, seq_len, seq_len) == 0
            weights = weights.masked_fill(future, float('-inf'))
            weights = F.softmax(weights, dim=-1)
            # nn.Dropout is already the identity in eval mode, so no explicit
            # training check is needed.
            weights = self.dropout_attn(weights)
            attn = weights @ v

        # (B, n_head, T, head_dim) -> (B, T, C): concatenate the heads.
        merged = attn.transpose(1, 2).reshape(batch, seq_len, channels)
        return self.dropout_attn(self.attn_proj(merged))
    

class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, ReLU, project back.

    Args:
        args: configuration object; reads ``n_embed``, ``bias`` and
            ``dropout``.
    """

    def __init__(self, args):
        super().__init__()
        width = args.n_embed
        hidden_width = 4 * width
        self.up_proj = nn.Linear(width, hidden_width, bias=args.bias)
        self.down_proj = nn.Linear(hidden_width, width, bias=args.bias)
        self.dropout = nn.Dropout(args.dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the feed-forward transform of ``x`` (last-dim size kept)."""
        hidden = self.up_proj(x)
        hidden = self.relu(hidden)
        projected = self.down_proj(hidden)
        return self.dropout(projected)

class Block(nn.Module):
    """Pre-LayerNorm Transformer block: attention, then MLP, each residual.

    Args:
        args: configuration object forwarded to ``Attention`` and ``MLP``;
            this class itself reads ``n_embed``.
    """

    def __init__(self, args):
        super().__init__()
        self.attn = Attention(args)
        self.mlp = MLP(args)
        # Each sub-layer gets its own LayerNorm. The two normalisation points
        # see different activations, so they should not share one set of
        # scale/shift parameters. ``norm`` keeps its original name for the
        # attention path.
        self.norm = nn.LayerNorm(args.n_embed)
        self.norm_mlp = nn.LayerNorm(args.n_embed)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        x = x + self.attn(self.norm(x))
        return x + self.mlp(self.norm_mlp(x))
    
class GPT(nn.Module):
    """Decoder-only GPT language model with tied input/output embeddings.

    Args:
        args: configuration object; this class reads ``vocab_size``,
            ``block_size``, ``n_embed``, ``n_layer`` and ``dropout`` and
            forwards ``args`` to every ``Block``.

    Attributes:
        param_nums: total number of parameter elements (the tied embedding /
            output matrix is counted once).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(args.vocab_size, args.n_embed),
            wpe=nn.Embedding(args.block_size, args.n_embed),
            drop=nn.Dropout(args.dropout),
            h=nn.ModuleList([Block(args) for _ in range(args.n_layer)]),
            norm=nn.LayerNorm(args.n_embed),
        ))
        self.lm_head = nn.Linear(args.n_embed, args.vocab_size, bias=False)

        # Weight tying: the output projection shares the token-embedding matrix.
        self.lm_head.weight = self.transformer.wte.weight
        self.apply(self._init_weights)

        self.param_nums = 0
        for pname, p in self.named_parameters():
            self.param_nums += p.numel()
            # Smaller init for the attention output projection, which writes
            # into the residual stream.
            # NOTE(review): only attn_proj is rescaled (not the MLP
            # down_proj) and the factor uses n_layer rather than 2 * n_layer;
            # confirm this is intended.
            if pname.endswith('attn_proj.weight'):
                torch.nn.init.normal_(
                    p, mean=0, std=0.02 / math.sqrt(args.n_layer)
                )

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, idx, target=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: integer token ids of shape ``(B, T)`` with
                ``T <= args.block_size``.
            target: optional token ids with the same number of elements as
                ``idx``; positions equal to ``-1`` are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is the cross-entropy scalar,
            or ``None`` when ``target`` is not given.

        Raises:
            ValueError: if ``T`` exceeds ``args.block_size``.
        """
        _, T = idx.shape
        if T > self.args.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.args.block_size}"
            )
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)

        embed_wte = self.transformer.wte(idx)
        embed_wpe = self.transformer.wpe(pos)  # broadcast over the batch axis
        # The ModuleDict key is ``drop``; there is no ``dropout`` entry.
        x = self.transformer.drop(embed_wte + embed_wpe)

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.norm(x)

        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if target is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.shape[-1]),
                target.reshape(-1),
                ignore_index=-1,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: integer token ids of shape ``(B, T)`` used as the prompt.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the last-step logits before
                sampling.
            top_k: if given, sample only from the ``top_k`` most likely
                tokens (clamped to the vocabulary size).

        Returns:
            Tensor of shape ``(B, T + max_new_tokens)``: the prompt followed
            by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Feed only the most recent block_size tokens to the model, but
            # keep the full sequence in ``idx`` so nothing is dropped from
            # the returned result.
            idx_cond = idx[:, -self.args.block_size:]
            logits, _ = self(idx_cond)  # (B, t, vocab_size)
            logits = logits[:, -1, :] / temperature  # (B, vocab_size)

            if top_k is not None:
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k=k)  # (B, k), sorted descending
                # Everything below the k-th largest logit gets zero probability.
                logits = logits.masked_fill(logits < v[:, [-1]], float('-inf'))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_token], dim=-1)
        return idx

In [3]:
# Build the model from the module-level hyperparameters in ``args``.
# NOTE(review): the model is not moved to ``args.device`` here - confirm that
# happens before training.
model = GPT(args)

In [5]:
# Pass the configured learning rate explicitly; otherwise AdamW silently uses
# its own default and ``args.learning_rate`` has no effect.
optim = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

In [12]:
i = 0  # NOTE(review): not used by any visible cell - confirm it is still needed
# Display the optimizer's parameter groups as the cell output.
optim.param_groups

[{'params': [Parameter containing:
   tensor([[ 1.0698e-02, -4.9127e-03,  5.3695e-03,  ..., -2.4461e-02,
            -1.1465e-02, -1.2741e-02],
           [ 2.9380e-02, -3.7819e-02,  1.0553e-02,  ...,  2.6756e-03,
             5.0633e-03, -2.0953e-02],
           [-2.2765e-02, -3.0707e-02, -2.0424e-03,  ..., -1.5170e-02,
            -6.7438e-03,  1.9553e-02],
           ...,
           [ 3.1154e-02, -7.2931e-05, -5.7943e-03,  ..., -3.1537e-02,
            -2.2255e-02,  4.2320e-02],
           [ 7.2942e-03, -2.8758e-02,  1.8473e-02,  ..., -6.4733e-03,
             2.0547e-02, -2.9996e-02],
           [ 4.0273e-02,  4.3315e-02,  1.0132e-02,  ..., -5.0604e-03,
            -2.1345e-02, -1.4830e-02]], requires_grad=True),
   Parameter containing:
   tensor([[ 0.0440,  0.0336,  0.0088,  ...,  0.0151,  0.0057,  0.0012],
           [-0.0339, -0.0020,  0.0474,  ...,  0.0140, -0.0073,  0.0079],
           [ 0.0065,  0.0028, -0.0087,  ...,  0.0086,  0.0014, -0.0075],
           ...,
           [ 

In [34]:
# Debug dump: for every model parameter print a separator line, its name and
# shape, then whether it requires grad and its current gradient.
for name,p in model.named_parameters():
    print('*'*100)
    print(name, '/', p.shape)
    print(p.requires_grad ,'/', p.grad)

****************************************************************************************************
transformer.wte.weight / torch.Size([7000, 768])
True / None
****************************************************************************************************
transformer.wpe.weight / torch.Size([128, 768])
True / None
****************************************************************************************************
transformer.h.0.attn.qkv.weight / torch.Size([2304, 768])
True / None
****************************************************************************************************
transformer.h.0.attn.attn_proj.weight / torch.Size([768, 768])
True / None
****************************************************************************************************
transformer.h.0.mlp.up_proj.weight / torch.Size([3072, 768])
True / None
****************************************************************************************************
transformer.h.0.mlp.down_proj.weight / torch.Size([768,