In [2]:
# part 1: 导入相关的 package
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

torch.manual_seed(1024)

<torch._C.Generator at 0x7f895f3ad210>

## 2. GPT 参数 (configuration)

In [3]:
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model.

    ``hidden_dim`` and ``head_size`` default to ``None``. In that case they are
    derived in ``__post_init__`` from ``n_embd`` / ``n_head``. Deriving them
    here (rather than as ``= n_embd`` class-level defaults, which are evaluated
    only once at class-definition time) keeps them consistent when ``n_embd``
    or ``n_head`` are overridden. Explicitly passed values are kept unchanged.

    Raises:
        ValueError: if ``head_size`` must be derived but ``n_embd`` is not
            divisible by ``n_head``.
    """

    block_size: int = 512  # maximum sequence length (max_seq)
    batch_size: int = 12
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768  # embedding width, a.k.a. hidden_dim / hidden_size
    # Set to n_embd in __post_init__ when left as None.
    hidden_dim: Optional[int] = None
    # NOTE: misspelled name kept so existing callers (config.droupout) keep
    # working; see the ``dropout`` property for a correctly spelled alias.
    droupout: float = 0.1
    # Set to n_embd // n_head in __post_init__ when left as None.
    head_size: Optional[int] = None
    # Vocabulary size of the official GPT-2 tokenizer.
    vocab_size: int = 50257

    def __post_init__(self):
        # Fill in derived sizes from the (possibly overridden) n_embd / n_head.
        if self.hidden_dim is None:
            self.hidden_dim = self.n_embd
        if self.head_size is None:
            if self.n_embd % self.n_head != 0:
                raise ValueError(
                    f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
                )
            self.head_size = self.n_embd // self.n_head

    @property
    def dropout(self) -> float:
        """Correctly spelled read-only alias for ``droupout``."""
        return self.droupout
    

## 3. GPT 结构 (model architecture)

In [4]:
#1. single head attention
class SingleHeadAttention(nn.Module):
    """One causal (masked) self-attention head.

    Projects the input to query/key/value vectors of width ``config.head_size``
    and returns the attention-weighted values. Position ``i`` may only attend
    to positions ``<= i``.

    Reads ``config.hidden_dim``, ``config.head_size``, ``config.block_size``
    and ``config.droupout``.
    """

    def __init__(self, config):
        super().__init__()
        self.head_size = config.head_size
        self.key = nn.Linear(config.hidden_dim, config.head_size)
        self.query = nn.Linear(config.hidden_dim, config.head_size)
        self.value = nn.Linear(config.hidden_dim, config.head_size)

        # The lower-triangular causal mask is registered as a buffer, not a
        # parameter: it needs no gradient, yet still follows the module across
        # devices and is stored in the state_dict.
        self.register_buffer(
            "attention_mask",
            torch.tril(torch.ones((config.block_size, config.block_size))),
        )

        # Assignment (not comparison) so the dropout layer actually exists.
        self.dropout = nn.Dropout(config.droupout)

    def forward(self, x):
        """Apply causal single-head attention.

        Args:
            x: tensor of shape ``(..., seq_len, hidden_dim)`` with
                ``seq_len <= block_size``.

        Returns:
            Tensor of shape ``(..., seq_len, head_size)``.
        """
        seq_len = x.size(-2)
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Scaled dot-product scores: divide by sqrt(head_size) BEFORE softmax,
        # so softmax sees well-conditioned logits and rows still sum to 1.
        weight = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)
        # Hide future positions (mask entries equal to 0) from each query.
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0,
            float("-inf"),
        )
        weight = F.softmax(weight, dim=-1)

        # Dropout is applied to the attention weights, after softmax.
        weight = self.dropout(weight)

        return weight @ v
        

In [None]:
#2. MultiheadAttention
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention built from independent single heads.

    Runs ``config.n_head`` ``SingleHeadAttention`` heads on the same input,
    concatenates their outputs along the last dimension, projects the result
    back to ``config.hidden_dim`` and applies dropout.

    Raises:
        ValueError: if ``n_head * head_size != hidden_dim``. The concatenated
            head outputs would not fit the input width of ``self.proj``.
    """

    def __init__(self, config):
        super().__init__()
        if config.n_head * config.head_size != config.hidden_dim:
            raise ValueError(
                f"n_head * head_size ({config.n_head} * {config.head_size}) "
                f"must equal hidden_dim ({config.hidden_dim})"
            )
        self.heads = nn.ModuleList(
            [SingleHeadAttention(config) for _ in range(config.n_head)]
        )
        self.proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.dropout = nn.Dropout(config.droupout)

    def forward(self, x):
        """Apply all heads to ``x`` and mix their outputs.

        Args:
            x: tensor whose last dimension is ``hidden_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Each head yields (..., seq_len, head_size); concatenating on the last
        # dim gives (..., seq_len, n_head * head_size) == (..., seq_len, hidden_dim).
        concat = torch.cat([head(x) for head in self.heads], dim=-1)
        output = self.proj(concat)
        return self.dropout(output)
    