# GPT（124M）实现
参考教程：https://github.com/bbruceyuan/LLMs-Zero-to-Hero


# 1. 导入依赖

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from dataclasses import dataclass
import math

torch.manual_seed(42)

<torch._C.Generator at 0x23faff1aab0>

# 2. 定义GPT相关参数

In [15]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model built in this file.

    NOTE: ``hidden_dim`` and ``head_size`` take their defaults from the
    literal defaults of ``n_embd`` / ``n_head`` when the class body is
    evaluated; they are not recomputed when a different ``n_embd`` or
    ``n_head`` is passed to the constructor.
    """
    block_size: int = 512   # context length (size of the position table and causal mask)
    batch_size: int = 12
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    hidden_dim:int = n_embd
    dropout: float = 0.1
    head_size: int = hidden_dim // n_head
    vocab_size: int = 50257    # GPT-2 vocabulary size

# 3. 定义GPT结构

In [24]:
# 1. Single-head causal self-attention
class SingleHeadAttention(nn.Module):
    """One causal self-attention head.

    The input is projected to key, value and query vectors of width
    ``config.head_size``; each position attends only to itself and to
    earlier positions.
    """

    def __init__(self, config):
        super().__init__()
        in_dim = config.hidden_dim
        head_dim = config.head_size

        # The projections are created in the order key, value, query.
        self.key = nn.Linear(in_dim, head_dim)
        self.value = nn.Linear(in_dim, head_dim)
        self.query = nn.Linear(in_dim, head_dim)
        self.head_size = head_dim

        # Lower-triangular matrix of ones, registered as a buffer so that it
        # travels with the module without being a trainable parameter.
        causal = torch.ones(config.block_size, config.block_size).tril()
        self.register_buffer("attention_mask", causal)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the attention output for a 3-D input ``x``.

        The last dimension of the result has width ``head_size``.
        """
        _, seq_len, _ = x.size()
        keys = self.key(x)
        values = self.value(x)
        queries = self.query(x)

        scores = torch.matmul(queries, keys.transpose(-2, -1))
        # Entries above the diagonal (future positions) are set to -inf so
        # that softmax assigns them zero weight.
        future = self.attention_mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float("-inf"))
        scores = scores / math.sqrt(self.head_size)  # scale by sqrt(head_size)

        # Dropout is applied to the normalised attention weights.
        probs = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(probs, values)
    
# 2. Multi-head attention
class MultiHeadAttention(nn.Module):
    """Run ``config.n_head`` independent attention heads and merge them.

    The head outputs are concatenated along the last dimension, passed
    through a linear projection and then through dropout.
    """

    def __init__(self, config):
        super().__init__()
        head_modules = []
        for _ in range(config.n_head):
            head_modules.append(SingleHeadAttention(config))
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(config.hidden_dim, config.hidden_dim)  # output projection
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply every head to ``x`` and return the projected concatenation."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))
    

# 3. Position-wise feed-forward network
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4x ``hidden_dim``, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        expanded = 4 * width
        layers = [
            nn.Linear(width, expanded),
            nn.GELU(),
            nn.Linear(expanded, width),
            nn.Dropout(config.dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` passed through the MLP."""
        return self.net(x)
    
# 4. Transformer block
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then feed-forward.

    Each sub-layer receives a LayerNorm-ed input and its output is added
    back to the un-normalised input (residual connection).
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attention = x + self.att(self.ln1(x))
        return after_attention + self.ffn(self.ln2(after_attention))
    
# 5. GPT model
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings and learned position embeddings are summed, passed
    through ``config.n_layer`` ``Block`` modules and a final LayerNorm, and
    projected to vocabulary logits by ``lm_head``. The ``lm_head`` weight is
    the same tensor as the token-embedding weight (weight tying).
    """

    def __init__(self, config):
        super().__init__()
        # generate() crops its context to this many tokens; it must be kept
        # on the module because the config object is not stored.
        self.block_size = config.block_size

        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)  # token embedding
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)  # position embedding
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layer)]
        )
        self.ln_final = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the output projection shares its weight with the
        # token embedding, which reduces the parameter count.
        self.lm_head.weight = self.token_embedding_table.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: 2-D tensor of token ids, shape (B, T).
            targets: optional tensor of target token ids holding B*T elements.

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits keep shape
            (B, T, vocab_size) and ``loss`` is ``None``. With ``targets`` the
            returned logits are flattened to (B*T, vocab_size).
        """
        batch, seq_len = idx.size()
        token_emb = self.token_embedding_table(idx)

        # One position id per time step: 0 .. seq_len-1, on the input's device.
        pos_emb = self.position_embedding_table(
            torch.arange(seq_len, device=idx.device)
        )

        x = token_emb + pos_emb  # position embedding broadcasts over the batch
        x = self.blocks(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            batch, seq_len, vocab_size = logits.size()
            logits = logits.view(batch * seq_len, vocab_size)
            targets = targets.view(batch * seq_len)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: 2-D tensor of token ids, shape (B, T), used as the prompt.
            max_new_tokens: number of tokens to sample.

        Returns:
            Tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Keep only the most recent block_size tokens as context.
            if idx.size(1) <= self.block_size:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # Only the prediction at the last time step is needed.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # Draw one token id per sequence from the predicted distribution.
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# 4. 数据集处理

- https://github.com/mobvoi/seq-monkey-data

- 序列猴子数据集的下载链接如下：

http://share.mobvoi.com:5000/sharing/O91blwPkY

- 序列猴子数据集以 JSONL 类型文件提供。文件的每一行都是格式统一的 JSON 类型的文本。其中，JSON 的格式为：
```text
{"text": "<文档>"}
```

In [25]:
# Dataset
class SeqMonkeyDataset(Dataset):
    """Fixed-length next-token-prediction samples from a JSONL corpus.

    Each line of the file is expected to be a JSON object with a ``"text"``
    field. The texts are tokenised with the GPT-2 tiktoken encoding, joined
    with an end-of-text token after each document, and cut into windows of
    ``block_size + 1`` tokens so every sample yields an input and a target
    shifted by one position.

    Args:
        path: path to the JSONL file.
        block_size: number of tokens in each input sequence.
        max_lines: number of leading lines of the file to read.
    """

    def __init__(self, path, block_size=512, max_lines=1000):
        import json
        import tiktoken

        self.enc = tiktoken.get_encoding("gpt2")    # GPT-2 tokenizer
        self.block_size = block_size

        self.eos_token = self.enc.encode(
            "<|endoftext|>",
            allowed_special={"<|endoftext|>"}
        )[0]

        self.max_lines = max_lines

        full_encoded = []
        with open(path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i >= self.max_lines:
                    break
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    # Blank or malformed line: skip it.
                    continue
                if i == 0:
                    # Show the first record as a quick sanity check.
                    print(record)
                try:
                    text = record['text']
                except (KeyError, TypeError, IndexError):
                    # Valid JSON, but not an object with a "text" field.
                    continue
                if not isinstance(text, str):
                    continue
                full_encoded.extend(self.enc.encode(text))
                full_encoded.append(self.eos_token)

        self.encoded_data = self._chunk_tokens(
            full_encoded, self.block_size, self.eos_token
        )

    @staticmethod
    def _chunk_tokens(tokens, block_size, pad_token):
        """Cut ``tokens`` into windows of ``block_size + 1`` with stride ``block_size``.

        Consecutive windows share one token, because the last token of a
        window is only used as a target. A short final window is
        right-padded with ``pad_token``.
        """
        window = block_size + 1
        chunks = []
        for start in range(0, len(tokens), block_size):
            chunk = tokens[start:start + window]
            if len(chunk) < window:
                chunk = chunk + [pad_token] * (window - len(chunk))
            chunks.append(chunk)
        return chunks

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx):
        """Return ``(x, y)`` long tensors where ``y`` is ``x`` shifted left by one."""
        chunk = self.encoded_data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

    def encode(self, text):
        """Encode text into token ids."""
        return self.enc.encode(text)

    def decode(self, ids):
        """Decode token ids back into text."""
        return self.enc.decode(ids)

In [26]:
# Build the dataset from the local corpus file (block_size is left at its default)
train_dataset = SeqMonkeyDataset('./data/mobvoi_seq_monkey_general_open_corpus.jsonl')

{'text': '在查处虚开增值税专用发票案件中，常常涉及进项留抵税额和税款损失的认定和处理。在计算税款损失时，要不要将进项留抵税额包括在内？\n对此，实务中存在意见分歧。\n有人主张归并，即计算税款损失时包括进项留抵税额；\n有人主张剥离，即计算税款损失时剔除进项留抵税额。分析这个问题，需要确定进项留抵税额与税款损失之间是什么关系。\n理清这二者之间的关系，首先需要了解增值税的概念和其抵扣机制。增值税是以商品（货物、服务等）在流转过程中产生的增值额作为计税依据而征收的一种流转税。为避免重复征税，在增值税中存在抵扣链条机制。\n一般而言，交易上游企业缴纳的税额，交易下游企业可以对相应的税额进行抵扣。\n对增值税一般纳税人来说，其购进货物、服务等取得增值税专用发票，发票上的税额是进项税额。\n其出售货物、服务等，向购买方开具增值税专用发票，发票的税额是销项税额。\n一般情况下，销项税额减去进项税额的金额是应纳税额，企业根据应纳税额按期申报纳税。\n其次需要了解进项留抵税额的概念及产生原因。\n在计算销项税额和进项税额的差额时，有时会出现负数，即当期进项税额大于当期销项税额。这个差额在当期未实现抵扣，为进项留抵税额，在以后纳税人有销项税额时再进行抵扣。\n企业产生进项留抵税额的主要原因是其进项税额和销项税额时间上的不一致。\n例如，企业前期集中采购货物和服务，投资大，销项税率低于进项税率等。\n从税款抵扣的角度看，进项留抵税额只是购进的这部分进项税额参与到增值税应纳税额的计算过程中，但是其对应的进项税额抵扣还未真正实现，一般要等到其未来有相应的销项税额时，才能真正实现进项税额抵扣。\n可见，进项留抵税额处于不确定状态，能否抵扣受到很多因素影响，例如企业经营中断，没有销项税额，这时进项留抵税额就无法实现抵扣。但如果企业按照税收政策规定申请进项留抵退税，进项税额抵扣就随之实现。\n最后需要了解税款损失的概念。\n税款损失，通常是指因虚开增值税专用发票，导致国家税款被骗或者流失的金额。关于税款损失，实务中有多种表述。\n例如，北京大学法学院教授陈兴良曾谈到虚开行为本身不会造成国家税款损失，只有利用发票抵扣时才会造成国家税款损失。刘兵等编著的《虚开增值税专用发票案例司法观点和案例解析》一书中提到：“给国家税款造成损失的数额，实际上就是被骗取的国家税款在侦查终结以前无法追回的部

In [27]:
# # 划分训练集和验证集
# train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [0.9, 0.1])

# train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=12, shuffle=False)

# Split into training and validation sets (90% / 10%)
total_len = len(train_dataset)
train_len = int(total_len * 0.9)
val_len = total_len - train_len
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_len, val_len])

# Take the batch size from GPTConfig instead of repeating the literal.
train_loader = DataLoader(train_dataset, batch_size=GPTConfig.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=GPTConfig.batch_size, shuffle=False)

# 5. 运行相关的函数

In [28]:
# Build the model from a config instance (not the dataclass itself)
model = GPT(GPTConfig())
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Report the total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params/1e6}M")

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# NOTE(review): the training loop calls scheduler.step() once per batch, so
# T_max counts optimizer steps, not epochs — confirm 1000 matches the planned
# total number of steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

Total number of parameters: 124.046592M


In [29]:
import os

n_epochs = 10  # number of training epochs
output_dir = "./checkpoints"  # directory where train_with_logging writes checkpoints

# # 训练循环
# def train(model, optimizer, scheduler, train_loader, val_loader, device):
#     model.train()   # 进入训练模式
#     total_loss = 0
#     for batch_idx, (x,y) in enumerate(train_loader):
#         # 将数据移动到设备上
#         x,y = x.to(device),y.to(device)
#         # 前向传播
#         logits, loss = model(x, targets=y)
#         # 反向传播
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         # 调整学习率
#         scheduler.step()
#         total_loss += loss.item()

#         if batch_idx % 100 == 0:
#             print(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {total_loss/100:.3f}")
#     return total_loss

# def eval(model, val_loader, device):
#     model.eval()    # 设置模型为评估模式
#     val_loss = 0
#     with torch.no_grad():
#         for x,y in val_loader:
#             x,y = x.to(device), y.to(device)
#             logits, loss = model(x, targets=y)
#             val_loss += loss.item()
#     return val_loss

# for epoch in range(n_epochs):
#     train_loss = train(model, optimizer, scheduler, train_loader, val_loader, device)
#     val_loss = eval(model, val_loader, device)
#     print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")

#     # 保存模型
#     avg_val_loss = val_loss/len(val_loader)
#     checkpoint = {
#         "model_state_dict": model.state_dict(),
#         "optimizer_state_dict": optimizer.state_dict(),
#         "scheduler_state_dict": scheduler.state_dict(),
#         "epoch": epoch,
#         "val_loss": avg_val_loss
#     }
#     # 保存每个epoch的模型
#     torch.save(checkpoint, f"{output_dir}/model_epoch_{epoch}.pth")


def train_with_logging(model, optimizer, scheduler, train_loader, val_loader, device,
                       n_epochs=10, checkpoint_dir=None):
    """Train ``model``, validate after every epoch and save a checkpoint per epoch.

    Args:
        model: module whose call ``model(x, targets=y)`` returns ``(logits, loss)``.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler, also stepped once per training batch.
        train_loader: iterable of ``(x, y)`` training batches.
        val_loader: iterable of ``(x, y)`` validation batches.
        device: device the batches are moved to.
        n_epochs: number of passes over ``train_loader``.
        checkpoint_dir: directory for checkpoint files; when ``None`` the
            module-level ``output_dir`` is used.

    Returns:
        ``(train_losses, val_losses)``: per-epoch average losses. An epoch
        whose loader yields no batches is recorded as ``nan``.
    """
    if checkpoint_dir is None:
        checkpoint_dir = output_dir
    # Create the directory once up front rather than once per epoch.
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []

    for epoch in range(n_epochs):
        # ---- training phase ----
        model.train()
        total_loss = 0.0
        batch_count = 0

        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)

            _, loss = model(x, targets=y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            batch_count += 1

            if batch_idx % 100 == 0:
                print(f"Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {batch_loss:.3f}")

        # An empty loader gives nan instead of a ZeroDivisionError.
        avg_train_loss = total_loss / batch_count if batch_count else float("nan")
        train_losses.append(avg_train_loss)

        # ---- validation phase ----
        model.eval()
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                _, loss = model(x, targets=y)
                val_loss += loss.item()
                val_batch_count += 1

        avg_val_loss = val_loss / val_batch_count if val_batch_count else float("nan")
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Save everything needed to resume training after this epoch.
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss
        }
        torch.save(checkpoint, f"{checkpoint_dir}/model_epoch_{epoch}.pth")

    return train_losses, val_losses


# 6. 可视化

In [30]:
import matplotlib.pyplot as plt
import numpy as np

def plot_loss_curves(train_losses, val_losses, epochs=None, save_path=None):
    """Plot training and validation loss per epoch and print summary statistics.

    Args:
        train_losses (list): per-epoch training losses.
        val_losses (list): per-epoch validation losses.
        epochs (int, optional): number of epochs on the x-axis; defaults to
            ``len(train_losses)``.
        save_path (str, optional): when given, the figure is also saved there.

    Raises:
        ValueError: if either loss list is empty, or if the list lengths and
            ``epochs`` do not all agree.
    """
    # Validate before touching matplotlib so that bad input fails with a
    # clear message instead of an error from inside numpy/matplotlib.
    if len(train_losses) == 0 or len(val_losses) == 0:
        raise ValueError("train_losses and val_losses must not be empty")

    if epochs is None:
        epochs = len(train_losses)

    if len(train_losses) != epochs or len(val_losses) != epochs:
        raise ValueError(
            f"expected {epochs} values per curve, got "
            f"{len(train_losses)} train and {len(val_losses)} validation losses"
        )

    # x-axis: epochs numbered from 1
    x_epochs = range(1, epochs + 1)

    # Font settings so that the Chinese labels and the minus sign render.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    plt.figure(figsize=(10, 6))

    # Loss curves
    plt.plot(x_epochs, train_losses, 'b-', linewidth=2, label='训练损失', marker='o')
    plt.plot(x_epochs, val_losses, 'r-', linewidth=2, label='验证损失', marker='s')

    # Axes, title, legend, grid
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('训练和验证损失曲线')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Annotate the minimum of each curve (indices are 0-based, epochs 1-based).
    min_train_idx = int(np.argmin(train_losses))
    min_val_idx = int(np.argmin(val_losses))

    plt.annotate(f'最小训练损失: {train_losses[min_train_idx]:.4f}',
                xy=(min_train_idx + 1, train_losses[min_train_idx]),
                xytext=(min_train_idx + 1, train_losses[min_train_idx] + 0.1),
                arrowprops=dict(arrowstyle='->', color='blue'))

    plt.annotate(f'最小验证损失: {val_losses[min_val_idx]:.4f}',
                xy=(min_val_idx + 1, val_losses[min_val_idx]),
                xytext=(min_val_idx + 1, val_losses[min_val_idx] + 0.1),
                arrowprops=dict(arrowstyle='->', color='red'))

    plt.tight_layout()

    # Save before showing, if a path was given.
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"损失曲线已保存到: {save_path}")

    plt.show()

    # Summary statistics
    print("\n=== 损失统计信息 ===")
    print(f"训练损失 - 最小值: {min(train_losses):.4f}, 最终值: {train_losses[-1]:.4f}")
    print(f"验证损失 - 最小值: {min(val_losses):.4f}, 最终值: {val_losses[-1]:.4f}")
    print(f"最佳验证损失出现在第 {min_val_idx + 1} 个epoch")


In [None]:
# Example usage
if __name__ == "__main__":
    # Train the model; this returns the per-epoch average losses.
    train_losses, val_losses = train_with_logging(
        model, optimizer, scheduler, train_loader, val_loader, device,
        n_epochs=n_epochs,
    )

    # Alternatively, plot existing loss data, e.g.:
    # sample_train_losses = [2.5, 2.1, 1.8, 1.6, 1.4, 1.3, 1.2, 1.1, 1.05, 1.0]
    # sample_val_losses = [2.4, 2.0, 1.7, 1.5, 1.4, 1.35, 1.3, 1.25, 1.2, 1.15]

    # Plot the curves
    plot_loss_curves(train_losses, val_losses, save_path="./loss_curve.png")