In [1]:
# 导入必要的库
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

定义GPT参数

In [None]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model.

    The fields ``head_size``, ``head_dim`` and ``mlp_dim`` are derived from
    ``n_embd``/``n_head`` in ``__post_init__`` when left as ``None``. In a plain
    dataclass default, expressions like ``n_embd // n_head`` are evaluated once at
    class-definition time, so overriding ``n_embd`` would silently keep the stale
    768-based values. Explicitly passed values are kept as given.

    Raises:
        ValueError: if ``n_head * head_size != head_dim`` (the attention layer
            reshapes ``head_dim`` channels into ``n_head`` heads of ``head_size``).
    """
    block_size: int = 1024  # context length (max sequence length)
    vocab_size: int = 50257
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 12  # number of attention heads
    n_embd: int = 768  # embedding dimension
    head_size: Optional[int] = None  # per-head dimension; defaults to n_embd // n_head
    dropout: float = 0.1
    head_dim: Optional[int] = None  # width of the concatenated heads; defaults to n_embd
    mlp_dim: Optional[int] = None  # MLP hidden width; defaults to 4 * n_embd

    def __post_init__(self):
        if self.head_size is None:
            self.head_size = self.n_embd // self.n_head
        if self.head_dim is None:
            self.head_dim = self.n_embd
        if self.mlp_dim is None:
            self.mlp_dim = 4 * self.n_embd
        # Catch inconsistent head settings here instead of as a cryptic view() error later.
        if self.n_head * self.head_size != self.head_dim:
            raise ValueError(
                f"n_head * head_size ({self.n_head} * {self.head_size}) must equal "
                f"head_dim ({self.head_dim}); check that n_embd is divisible by n_head"
            )


定义GPT结构


![这是GPT2的结构](./image.png)


In [None]:
#multi head attention
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    A single linear layer produces q, k and v. They are split into ``config.n_head``
    heads of ``config.head_size`` channels, and a lower-triangular mask lets each
    position attend only to itself and earlier positions. The heads are then merged
    back to ``config.head_dim`` channels and passed through an output projection.

    Reads ``n_embd``, ``n_head``, ``head_size``, ``head_dim``, ``block_size`` and
    ``dropout`` from ``config``.

    Raises:
        ValueError: if ``n_head * head_size != head_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.head_size = config.head_size
        self.n_head = config.n_head  # number of heads
        # Width of the concatenated heads. forward() and the output projection need it;
        # the original read self.head_dim without ever assigning it.
        self.head_dim = config.head_dim
        if self.n_head * self.head_size != self.head_dim:
            raise ValueError(
                f"n_head * head_size ({self.n_head} * {self.head_size}) must equal "
                f"head_dim ({self.head_dim})"
            )
        # Shared dropout, used on the attention weights and on the projected output.
        self.dropout = nn.Dropout(config.dropout)
        # q, k and v are produced together and split with chunk() in forward().
        self.qkv = nn.Linear(config.n_embd, config.head_dim * 3)
        # Causal mask of shape (1, 1, block_size, block_size): 1 = may attend, 0 = masked.
        self.register_buffer(
            'attention_mask',
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )
        self.proj = nn.Linear(self.head_dim, self.head_dim)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, n_embd).

        Returns:
            Tensor of shape (B, T, head_dim).
        """
        B, T, C = x.size()  # batch size, sequence length, channels
        q, k, v = self.qkv(x).chunk(3, dim=-1)  # each (B, T, head_dim)
        # Split channels into heads: (B, T, head_dim) -> (B, n_head, T, head_size)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        # Scaled dot-product scores: (B, n_head, T, T)
        attn_scores = q @ k.transpose(-2, -1) / (self.head_size ** 0.5)
        # Block attention to future positions, then normalise over the key axis.
        weights = attn_scores.masked_fill(self.attention_mask[:, :, :T, :T] == 0, float('-inf'))
        weights = weights.softmax(dim=-1)
        weights = self.dropout(weights)  # dropout goes after softmax
        attn_output = weights @ v  # (B, n_head, T, head_size)
        # Merge heads back: (B, n_head, T, head_size) -> (B, T, head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.head_dim)
        attn_output = self.proj(attn_output)
        attn_output = self.dropout(attn_output)
        # The original forgot this return, so callers received None.
        return attn_output


MLP层

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer of a transformer block.

    Maps each position from ``config.head_dim`` to ``config.mlp_dim`` channels,
    applies GELU, projects back to ``config.head_dim`` and applies dropout.
    """

    def __init__(self, config):
        super().__init__()
        width = config.head_dim
        hidden = config.mlp_dim
        layers = [
            nn.Linear(width, hidden),  # expand
            nn.GELU(),
            nn.Linear(hidden, width),  # project back to the residual width
            nn.Dropout(config.dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the feed-forward network on ``x``; its last dim must be ``head_dim``."""
        out = self.net(x)
        return out

block块

In [None]:
class Block(nn.Module):
    """One pre-LayerNorm transformer block.

    Computes ``x = x + attn(ln_1(x))`` and then ``x = x + mlp(ln_2(x))``. Only the
    sub-layer inputs are normalised; both residual additions use the un-normalised
    stream.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.head_dim)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.head_dim)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        x = x + self.attn(self.ln_1(x))
        # Normalise only the MLP input. The original assigned ln_2(x) back to x,
        # so the residual was added to the normalised stream instead of the raw one.
        x = x + self.mlp(self.ln_2(x))
        return x

IndentationError: expected an indented block (1450040370.py, line 1)

GPT整体实现

In [None]:
class GPT(nn.Module):
    """Decoder-only GPT language model.

    Pipeline: token embedding + learned position embedding -> dropout ->
    ``config.n_layer`` ``Block`` modules -> final LayerNorm -> linear LM head.
    The LM head's weight is tied to the token embedding matrix.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding = nn.Embedding(config.block_size, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the output projection shares its matrix with the token embedding.
        self.lm_head.weight = self.token_embedding.weight
        # _init_weights was defined but never applied; run it over every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule in place (meant to be passed to ``self.apply``).

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from N(0, 0.02), and
        ``nn.Linear`` biases are zeroed. Other module types keep PyTorch's defaults.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T), with T <= config.block_size.
            targets: optional token ids with B*T elements (presumably shape (B, T),
                same as ``idx``). If given, the mean cross-entropy loss is computed.

        Returns:
            ``(logits, loss)``: logits have shape (B, T, vocab_size); loss is a
            scalar tensor, or None when ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``config.block_size``.
        """
        B, T = idx.size()
        if T > self.config.block_size:
            # Without this check, position_embedding fails with an opaque index error.
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.config.block_size}"
            )
        tok_emb = self.token_embedding(idx)  # (B, T, n_embd)
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding(positions)  # (T, n_embd), broadcast over batch
        x = self.dropout(tok_emb + pos_emb)  # (B, T, n_embd)
        for block in self.blocks:
            x = block(x)  # (B, T, n_embd)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            # Flatten batch and time: (B*T, vocab_size) logits vs (B*T,) targets.
            # reshape (not view) also accepts non-contiguous targets.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
