In [2]:
# part 1: 导入相关的 package
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

SEED = 1024  # seeds PyTorch's global RNGs (weight init, dropout masks, sampling)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f895f3ad210>

## 2.GPT参数

In [3]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model and its data loader.

    ``hidden_dim`` defaults to ``n_embd`` and ``head_size`` defaults to
    ``hidden_dim // n_head``. Both are derived in ``__post_init__`` so they
    follow whatever ``n_embd`` / ``n_head`` the caller passes. Class-level
    defaults such as ``hidden_dim: int = n_embd`` would be evaluated once, from
    the *default* ``n_embd``, and silently ignore overrides.

    Raises:
        ValueError: if ``hidden_dim`` differs from ``n_embd``, if it is not
            divisible by ``n_head`` (when ``head_size`` is derived), or if
            ``n_head * head_size != hidden_dim``.
    """

    block_size: int = 512  # maximum text length (context window, max_seq)
    batch_size: int = 12
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768  # embedding width, a.k.a. hidden_dim / hidden_size
    hidden_dim: Optional[int] = None  # None -> n_embd
    # Dropout probability. The field name is misspelled but kept because the
    # model classes read `config.droupout`.
    droupout: float = 0.1
    head_size: Optional[int] = None  # None -> hidden_dim // n_head
    # vocab_size of the official GPT-2 tokenizer
    vocab_size: int = 50257

    def __post_init__(self):
        if self.hidden_dim is None:
            self.hidden_dim = self.n_embd
        if self.hidden_dim != self.n_embd:
            # Embeddings are n_embd wide while the blocks use hidden_dim, so
            # the two must agree.
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must equal n_embd ({self.n_embd})"
            )
        if self.head_size is None:
            if self.hidden_dim % self.n_head != 0:
                raise ValueError(
                    f"hidden_dim ({self.hidden_dim}) must be divisible by "
                    f"n_head ({self.n_head})"
                )
            self.head_size = self.hidden_dim // self.n_head
        if self.n_head * self.head_size != self.hidden_dim:
            # Concatenated heads are fed to a hidden_dim -> hidden_dim projection.
            raise ValueError(
                f"n_head * head_size ({self.n_head} * {self.head_size}) must "
                f"equal hidden_dim ({self.hidden_dim})"
            )
    

## 3.GPT结构

In [4]:
# 1. Single-head causal self-attention
class SingleHeadAttention(nn.Module):
    """One causal self-attention head.

    Projects the input to query/key/value vectors of width ``config.head_size``
    and applies scaled dot-product attention with a lower-triangular mask, so
    position t only attends to positions <= t.

    Reads ``config.hidden_dim``, ``config.head_size``, ``config.block_size`` and
    ``config.droupout``.
    """

    def __init__(self, config):
        super().__init__()
        # Needed in forward() for the 1/sqrt(head_size) score scaling.
        self.head_size = config.head_size
        self.key = nn.Linear(config.hidden_dim, config.head_size)
        self.query = nn.Linear(config.hidden_dim, config.head_size)
        self.value = nn.Linear(config.hidden_dim, config.head_size)

        # The causal mask is registered as a buffer: it is not a trainable
        # parameter (no gradients), but it moves with .to(device) and is
        # stored in the state_dict.
        self.register_buffer(
            "attention_mask",
            torch.tril(
                torch.ones((config.block_size, config.block_size))
            )
        )

        # Assignment (the original used `==`, which is a comparison and raised
        # AttributeError because self.dropout did not exist yet).
        self.dropout = nn.Dropout(config.droupout)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, hidden_dim).

        ``seq_len`` must not exceed ``block_size`` (the mask is sliced to it).
        Returns a tensor of shape (batch, seq_len, head_size).
        """
        batch_size, seq_len, hidden_dim = x.size()
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Scaled dot-product scores. The 1/sqrt(d_k) scaling must be applied
        # to the logits BEFORE softmax; scaling afterwards only shrinks the
        # probabilities so the rows no longer sum to 1.
        weight = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0,
            float("-inf")
        )
        weight = F.softmax(weight, dim=-1)

        # Dropout is applied to the attention probabilities.
        weight = self.dropout(weight)

        return weight @ v
        

In [5]:
# 2. Multi-head attention
class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention built from independent heads.

    Runs ``config.n_head`` SingleHeadAttention heads on the same input,
    concatenates their outputs along the last dimension, then applies a
    linear output projection and dropout.

    Raises:
        ValueError: if ``n_head * head_size != hidden_dim``. The concatenated
            heads feed a hidden_dim -> hidden_dim projection, so a mismatch
            would otherwise only fail later inside forward().
    """

    def __init__(self, config):
        super().__init__()
        if config.n_head * config.head_size != config.hidden_dim:
            raise ValueError(
                f"n_head * head_size ({config.n_head} * {config.head_size}) "
                f"must equal hidden_dim ({config.hidden_dim})"
            )
        self.heads = nn.ModuleList(
            [
                SingleHeadAttention(config)
                for _ in range(config.n_head)
            ]
        )
        self.proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.dropout = nn.Dropout(config.droupout)

    def forward(self, x):
        """Apply all heads to ``x`` of shape (batch, seq_len, hidden_dim).

        Returns a tensor of shape (batch, seq_len, hidden_dim).
        """
        # Each head returns (batch, seq_len, head_size); concatenating gives
        # (batch, seq_len, n_head * head_size) == (batch, seq_len, hidden_dim).
        x = torch.cat(
            [h(x) for h in self.heads],
            dim=-1
        )
        output = self.proj(x)
        output = self.dropout(output)
        return output
    

In [6]:
# 3. Position-wise feed-forward network (MLP)
class FeedForward(nn.Module):
    """Two-layer MLP applied independently at every position.

    hidden_dim -> 4 * hidden_dim -> ReLU -> hidden_dim -> dropout. The layers
    live in ``self.net`` (an ``nn.Sequential``), so state_dict keys are
    ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        expanded = 4 * width
        layers = [
            nn.Linear(width, expanded),
            nn.ReLU(),
            nn.Linear(expanded, width),
            nn.Dropout(config.droupout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


In [7]:
# 4. Transformer block
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``h = x + att(ln1(x))`` followed by ``h + ffn(ln2(h))``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

    def forward(self, x):
        attn_out = self.att(self.ln1(x))
        h = x + attn_out
        mlp_out = self.ffn(self.ln2(h))
        return h + mlp_out
    

In [None]:
# 5. GPT
class GPT(nn.Module):
    """Decoder-only GPT language model (GPT-2 style, pre-LayerNorm blocks).

    Token embedding + learned absolute position embedding, ``config.n_layer``
    transformer ``Block``s, a final LayerNorm and a bias-free LM head whose
    weight is tied to the token embedding.

    Possible modern replacements: position embedding -> RoPE,
    LayerNorm -> RMSNorm, MLP -> SwiGLU, multi-head attention -> GQA.
    """

    def __init__(self, config):
        super().__init__()
        # Context window; generate() uses it to crop long prompts and
        # forward() uses it to reject sequences with no position embedding.
        self.block_size = config.block_size

        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layer)]
        )
        self.ln_final = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying (used by current small LMs to cut parameters): the
        # Embedding weight and the Linear(n_embd, vocab_size) weight are both
        # (vocab_size, n_embd), so one tensor serves both.
        self.token_embedding_table.weight = self.lm_head.weight

        # Apply the N(0, 0.02) initialisation below to every submodule
        # (previously defined but never called).
        self.apply(self._init_weight)

    def _init_weight(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02) and Linear biases to 0."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: token ids, shape (batch, seq_len), seq_len <= block_size.
            targets: optional target token ids, same shape as ``idx``.

        Returns:
            ``(logits, loss)``. Without targets, logits is
            (batch, seq_len, vocab_size) and loss is None. With targets, logits
            is flattened to (batch * seq_len, vocab_size) and loss is the mean
            cross-entropy.

        Raises:
            ValueError: if seq_len exceeds block_size (no position embedding).
        """
        batch, seq_len = idx.size()
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        tok_emb = self.token_embedding_table(idx)  # (batch, seq_len, n_embd)
        pos_emb = self.position_embedding_table(
            torch.arange(seq_len, device=idx.device)
        )  # (seq_len, n_embd), broadcast over the batch
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_final(x)

        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        batch, seq_len, vocab_size = logits.size()
        logits = logits.view(batch * seq_len, vocab_size)
        # reshape (not view) so non-contiguous target tensors also work.
        targets = targets.reshape(batch * seq_len)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (batch, T) tensor of token ids forming the current context.
            max_new_tokens: number of tokens to append.

        Returns:
            (batch, T + max_new_tokens) tensor of token ids.

        Runs without autograd. Dropout is still active if the model is in
        training mode; call ``model.eval()`` first for standard sampling.
        """
        for _ in range(max_new_tokens):
            # Keep only the last block_size tokens if the context is too long.
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # Only the prediction at the last time step matters.
            logits = logits[:, -1, :]  # (batch, vocab_size)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (batch, T + 1)
        return idx
     
    

## 4. Dataset

In [None]:
class MyDataset(Dataset):
    """Next-token-prediction dataset built from a JSON-lines text corpus.

    Reads at most ``max_lines`` lines from ``path``. Each line should be a JSON
    object with a ``"text"`` field; malformed lines are skipped. Texts are
    tokenised with tiktoken's GPT-2 encoding and each is followed by the
    ``<|endoftext|>`` token. The token stream is cut into windows of
    ``block_size + 1`` tokens with stride ``block_size``, and the last window
    is right-padded with ``<|endoftext|>``.

    Item ``i`` is ``(x, y)``: two LongTensors of length ``block_size`` where
    ``y`` is ``x`` shifted left by one token.
    """

    def __init__(self, path, block_size=512, max_lines=1000):
        import tiktoken

        self.enc = tiktoken.get_encoding("gpt2")
        self.block_size = block_size
        self.max_lines = max_lines

        # Id of <|endoftext|>: used as document separator and as padding.
        self.eos_token = self.enc.encode(
            "<|endoftext|>",
            allowed_special={"<|endoftext|>"}
        )[0]

        raw_data = self._read_texts(path, max_lines)

        full_encoded = []
        for text in raw_data:
            full_encoded.extend(self.enc.encode(text) + [self.eos_token])

        self.encoded_data = self._chunk(full_encoded, block_size, self.eos_token)

    @staticmethod
    def _read_texts(path, max_lines):
        """Return the "text" field of each valid JSON line among the first max_lines lines."""
        import json

        texts = []
        # Explicit UTF-8: the corpus contains non-ASCII (e.g. Chinese) text
        # and the platform default encoding may differ.
        with open(path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i >= max_lines:
                    break
                try:
                    # json.loads parses a string (json.load expects a file object).
                    texts.append(json.loads(line.strip())["text"])
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Skip invalid JSON, records without "text", and
                    # non-object records.
                    continue
        return texts

    @staticmethod
    def _chunk(tokens, block_size, pad_id):
        """Split tokens into (block_size + 1)-long windows with stride block_size.

        Consecutive windows overlap by one token so every position has a
        next-token target; the final short window is padded with pad_id.
        """
        window = block_size + 1
        chunks = []
        for start in range(0, len(tokens), block_size):
            chunk = tokens[start:start + window]
            if len(chunk) < window:
                chunk = chunk + [pad_id] * (window - len(chunk))
            chunks.append(chunk)
        return chunks

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx):
        chunk = self.encoded_data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

    def encode(self, text):
        """Encode text to GPT-2 token IDs."""
        return self.enc.encode(text)

    def decode(self, ids):
        """Decode token IDs back to text."""
        return self.enc.decode(ids)
            

In [None]:
# Training data: tokenise the corpus, then split it 90/10 into train / validation.
DATA_PATH = '/root/fs/mobvoi_seq_monkey_general_open_corpus.jsonl'
data_config = GPTConfig()

# Window length matches the model's context size.
full_dataset = MyDataset(DATA_PATH, block_size=data_config.block_size)

# Dedicated, fixed-seed generator so the split is reproducible regardless of
# how much of the global RNG earlier cells consumed.
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(1024)
)

train_loader = DataLoader(train_dataset, batch_size=data_config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False)

In [8]:
model = GPT(GPTConfig())
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Total number of parameters; parameters() yields the tied embedding /
# lm_head weight only once.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6:.2f} M")

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Cosine learning-rate schedule. The training loop calls scheduler.step() once
# per batch, so T_max is counted in optimizer steps and should cover every
# step of the planned epochs (max(1, ...) guards an empty loader).
NUM_EPOCHS = 2  # keep in sync with the training loop
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max(1, NUM_EPOCHS * len(train_loader))
)

NameError: name 'GPT' is not defined

In [None]:
# Training loop
import os


def train(model, optimizer, scheduler, train_loader, val_loader, device, epoch=None):
    """Run one training epoch and return the summed (not averaged) batch loss.

    Args:
        model: module whose forward(x, targets=y) returns (logits, loss).
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler, also stepped once per batch.
        train_loader: iterable of (x, y) batches.
        val_loader: unused; kept for call compatibility.
        device: device the batches are moved to.
        epoch: epoch index shown in progress messages (explicit, instead of
            reading a global loop variable).
    """
    model.train()
    total_loss = 0.0
    for batch_idx, (x, y) in enumerate(train_loader):
        # Move the batch to the target device.
        x, y = x.to(device), y.to(device)

        # Forward pass.
        logits, loss = model(x, targets=y)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The LR schedule advances per batch (per optimizer step).
        scheduler.step()

        loss_value = loss.item()
        total_loss += loss_value

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss_value:.4f}')
    return total_loss


# NOTE: this name shadows the builtin eval(); kept for existing callers.
def eval(model, val_loader, device):
    """Return the summed validation loss over val_loader (eval mode, no gradients)."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            _, loss = model(x, targets=y)
            val_loss += loss.item()
    return val_loss


NUM_EPOCHS = 2  # keep in sync with the scheduler's T_max
CHECKPOINT_DIR = 'checkpoints'
# torch.save does not create missing parent directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    train_loss = train(model, optimizer, scheduler, train_loader, val_loader, device, epoch=epoch)
    val_loss = eval(model, val_loader, device)
    # max(..., 1) avoids ZeroDivisionError when a split is empty.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    avg_val_loss = val_loss / max(len(val_loader), 1)
    print(f'Epoch: {epoch}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    # Save a checkpoint for every epoch.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': avg_val_loss,
    }
    torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'model_epoch_{epoch}.pt'))
    