In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dataclasses import dataclass

torch.manual_seed(1024)

<torch._C.Generator at 0x1d960b96850>

### 0. 设置Config

In [5]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT backbone.

    ``hidden_dim`` and ``head_size`` default to ``None`` and are then derived
    from ``n_embd`` / ``n_head`` in ``__post_init__``. Previously they were
    computed once from the class-level defaults, so ``GPTConfig(n_embd=512)``
    still reported ``hidden_dim == 256`` and ``head_size == 32``. Passing
    either value explicitly still overrides the derived one.
    """
    block_size: int = 512   # maximum sequence length (max_seq)
    batch_size: int = 12
    n_layer: int = 8
    n_head: int = 8
    # Kept equal to hidden_dim so the token embedding and the output
    # projection can share one weight matrix (tied embeddings).
    n_embd: int = 256
    hidden_dim: int = None   # None -> n_embd
    dropout: float = 0.1
    head_size: int = None    # None -> n_embd // n_head
    # Vocabulary size of the official GPT-2 tokenizer.
    vocab_size: int = 50257

    def __post_init__(self):
        # Resolve the derived fields from the values actually passed in.
        if self.hidden_dim is None:
            self.hidden_dim = self.n_embd
        if self.head_size is None:
            self.head_size = self.n_embd // self.n_head

In [6]:
class MOEConfig:
    """Plain settings holder for the mixture-of-experts layers.

    Args:
        hidden_dim: feature width fed to the router and to every expert.
        expert_number: number of routed experts.
        top_k: how many routed experts each token is sent to.
        shared_experts_numbers: number of always-on shared experts; stored on
            the instance under the singular name ``shared_experts_number``.
    """

    def __init__(self, hidden_dim=256, expert_number=8, top_k=2, shared_experts_numbers=2):
        # The keyword is plural, the attribute singular; both spellings are
        # kept exactly as callers and readers of this object expect them.
        self.shared_experts_number = shared_experts_numbers
        self.top_k = top_k
        self.expert_number = expert_number
        self.hidden_dim = hidden_dim

### 1. Multihead_Attention

In [7]:
# 1. single head attention
class SingleHeadAttention(nn.Module):
    """One causal self-attention head.

    Projects the input to query/key/value of width ``config.head_size``,
    masks out future positions, and returns the attention-weighted values.
    """

    def __init__(self, config):
        super().__init__()
        self.head_size = config.head_size
        in_features, out_features = config.hidden_dim, config.head_size
        self.query = nn.Linear(in_features, out_features)
        self.key = nn.Linear(in_features, out_features)
        self.value = nn.Linear(in_features, out_features)

        # Lower-triangular (block_size x block_size) matrix registered as a
        # buffer: it moves with the module but is not a trainable parameter.
        # forward() slices it down to [:seq_len, :seq_len].
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("attention_mask", causal)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the head output for a rank-3 input ``x``; the last dim becomes head_size."""
        _, seq_len, _ = x.size()
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        scores = torch.matmul(queries, keys.transpose(-2, -1))
        # Positions where the mask is 0 lie in the future and must get zero weight.
        future = self.attention_mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float("-inf"))

        # Scale by sqrt(d_k) before the softmax; -inf entries stay -inf.
        attn = F.softmax(scores / math.sqrt(self.head_size), dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, values)

In [8]:
# multi head attention
class MultiHeadAttention(nn.Module):
    """Runs ``config.n_head`` independent attention heads and mixes their outputs.

    The per-head results are concatenated on the last dimension, passed through
    a ``hidden_dim -> hidden_dim`` linear projection, then through dropout.
    """

    def __init__(self, config):
        super().__init__()
        head_list = [SingleHeadAttention(config) for _ in range(config.n_head)]
        self.heads = nn.ModuleList(head_list)
        self.proj = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply every head to ``x`` and return the projected, dropped-out concatenation."""
        per_head = []
        for head in self.heads:
            per_head.append(head(x))
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))

### 2. MOE架构替换FFN

In [9]:
class BasicExpert(nn.Module):
    """A single expert: Linear -> GELU -> Linear -> Dropout.

    An expert could be as small as one linear layer or as elaborate as a
    SwiGLU MLP; this one is a plain two-layer MLP with a 4x inner width.

    Args:
        feature_in: input feature width; the inner layer is ``4 * feature_in``.
        feature_out: output feature width.
        dropout: drop probability of the final dropout layer. Defaults to 0.1,
            the value that used to be hard-coded.
    """

    def __init__(self, feature_in, feature_out, dropout=0.1):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(feature_in, 4 * feature_in),
            nn.GELU(),
            nn.Linear(4 * feature_in, feature_out),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the feed-forward stack to ``x``."""
        return self.ffn(x)

In [10]:
class MOERouter(nn.Module):
    """Top-k gate that decides which experts process each token.

    ``forward`` expects a 2-D input of flattened tokens, ``(tokens, hidden_dim)``,
    and returns four tensors:

    * ``router_logits``            -- ``(tokens, expert_number)`` raw gate scores
    * ``router_weights``           -- ``(tokens, top_k)`` renormalised top-k probabilities
    * ``selected_experts_indices`` -- ``(tokens, top_k)`` chosen expert ids
    * ``expert_mask``              -- ``(expert_number, top_k, tokens)`` one-hot selection mask
    """

    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_dim, config.expert_number)
        # Only the best top_k experts are kept for each token.
        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self, x):
        router_logits = self.gate(x)

        # Per-expert probabilities, computed in float32.
        probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        # Entry [i, j] is the probability of token i's j-th best expert; the
        # indices say which expert that is. Gradients flow through the values.
        topk_probs, selected_experts_indices = torch.topk(probs, self.top_k, dim=-1)

        # Renormalise so each token's kept weights sum to 1 (a plain division,
        # not a second softmax), then return to the input dtype.
        router_weights = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(x.dtype)

        # one_hot yields (tokens, top_k, expert_number); reversing the axes puts
        # the expert axis first so callers can index the mask by expert.
        one_hot = F.one_hot(selected_experts_indices, num_classes=self.expert_number)
        expert_mask = one_hot.permute(2, 1, 0)

        return router_logits, router_weights, selected_experts_indices, expert_mask

In [11]:
class SparseMOE(nn.Module):
    """Sparse mixture-of-experts layer: each token is processed by its top-k experts.

    ``forward`` takes ``x`` of shape ``(batch_size, seq_len, hidden_dim)`` and
    returns ``(output, router_logits)``. ``output`` has the same shape as ``x``;
    ``router_logits`` is the ``(batch_size * seq_len, expert_number)`` tensor
    produced by the router, handed back so a caller can build a routing loss.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number

        # One hidden_dim -> hidden_dim expert per routed slot, then the gate.
        expert_list = [
            BasicExpert(config.hidden_dim, config.hidden_dim)
            for _ in range(config.expert_number)
        ]
        self.experts = nn.ModuleList(expert_list)
        self.router = MOERouter(config)

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.size()

        # Routing works per token, so flatten to (batch_size * seq_len, hidden_dim).
        flat_tokens = x.view(-1, hidden_dim)

        # expert_masks: (expert_number, top_k, batch_size * seq_len)
        router_logits, router_weights, _, expert_masks = self.router(flat_tokens)

        # Accumulator for the weighted expert outputs, one row per token.
        combined = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=flat_tokens.dtype,
            device=flat_tokens.device,
        )

        # Visit the experts one by one and add each expert's contribution for
        # the tokens that selected it.
        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]

            # (top_k, tokens) slice: non-zero where a token picked this expert.
            # slot_idx  -> which top-k slot (0 .. top_k-1) the expert occupies
            # token_idx -> row of that token in the flattened batch
            # The two index vectors have equal length and pair up element-wise.
            slot_idx, token_idx = torch.where(expert_masks[expert_idx])

            # (selected_tokens, hidden_dim): run only the selected rows.
            expert_out = expert_layer(flat_tokens[token_idx])

            # (selected_tokens, 1): gate weight of each (token, slot) pair,
            # shaped to broadcast across the feature dimension.
            gate = router_weights[token_idx, slot_idx].unsqueeze(-1)
            weighted = (expert_out * gate).to(flat_tokens.dtype)

            # In-place scatter-add: rows of `weighted` are added to the rows of
            # `combined` named by token_idx.
            combined.index_add_(0, token_idx, weighted)

        # Restore the (batch_size, seq_len, hidden_dim) layout.
        combined = combined.reshape(batch_size, seq_len, hidden_dim)
        return combined, router_logits

In [12]:
class SharedExpertMOE(nn.Module):
    """MoE layer with always-on shared experts added to the routed experts.

    Every token goes through all ``config.shared_experts_number`` shared
    experts (their outputs are summed) and through the sparse routed MoE.
    ``forward`` returns ``(shared_sum + routed_output, router_logits)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.routed_experts_moe = SparseMOE(config)
        width = self.config.hidden_dim
        self.shared_experts = nn.ModuleList(
            [BasicExpert(width, width) for _ in range(self.config.shared_experts_number)]
        )

    def forward(self, x):
        # The input must be rank 3: (batch, seq_len, hidden_dim).
        _batch, _seq, _hidden = x.size()

        # stack() creates the leading expert axis directly (cat would need an
        # unsqueeze first); summing over it gives (batch, seq_len, hidden_dim).
        per_expert = [expert(x) for expert in self.shared_experts]
        shared_sum = torch.stack(per_expert, dim=0).sum(dim=0, keepdim=False)

        # Routed branch, same output shape as x.
        routed_out, router_logits = self.routed_experts_moe(x)

        return shared_sum + routed_out, router_logits

### 3. 定义Block

In [13]:
# block
class Block(nn.Module):
    """Pre-norm transformer block whose feed-forward part is a shared-expert MoE.

    ``forward`` returns the updated hidden states together with the router
    logits produced by the MoE sub-layer.
    """

    def __init__(self, gpt_config, moe_config):
        super().__init__()
        self.att = MultiHeadAttention(gpt_config)
        self.moe = SharedExpertMOE(moe_config)
        width = gpt_config.hidden_dim
        self.att_ln = nn.LayerNorm(width, eps=1e-6)
        self.moe_ln = nn.LayerNorm(width, eps=1e-6)

    def forward(self, x):
        # Residual connection around (LayerNorm -> attention).
        attended = x + self.att(self.att_ln(x))
        # Residual connection around (LayerNorm -> MoE).
        moe_out, router_logits = self.moe(self.moe_ln(attended))
        return attended + moe_out, router_logits

### 4. 构建GPT，加入负载平衡loss

In [14]:
# GPT
class GPT(nn.Module):
    """Decoder-only language model with an MoE feed-forward in its block.

    Components: token embedding, learned position embedding, a transformer
    ``Block``, a final LayerNorm and a linear LM head whose weight is tied to
    the token embedding.

    Possible upgrades noted by the original author: learned positions -> RoPE,
    LayerNorm -> RMSNorm, MLP -> SwiGLU, MHA -> GQA.

    Args:
        config: GPT hyper-parameters (vocab_size, n_embd, hidden_dim,
            block_size, n_layer, ...).
        moe_config: MoE hyper-parameters (expert_number, top_k, ...).
    """

    def __init__(self, config, moe_config):
        super().__init__()
        self.config = config
        self.moe_config = moe_config
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        # NOTE(review): a single Block is created here and forward() applies it
        # n_layer times, so all "layers" share one set of weights. If
        # independent layers were intended, this should become an
        # nn.ModuleList of n_layer Blocks -- left unchanged because that would
        # alter parameter names and invalidate existing checkpoints.
        self.block = Block(config, moe_config)
        self.last_ln = nn.LayerNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias = False)
        # Weight tying, common in small LMs to save parameters. nn.Linear stores
        # its weight as (out_features, in_features) = (vocab_size, hidden_dim),
        # which is exactly the embedding-table layout.
        self.token_embedding_table.weight = self.lm_head.weight
        # Context window used by generate(); taken from the config (it used to
        # be a hard-coded 256 that disagreed with config.block_size).
        self.block_size = config.block_size

        self.apply(self._init_weights)  # visits every sub-module

    def switch_load_balancing_loss(self, router_logits, num_experts = 8, top_k = 2) -> torch.Tensor:
        """Switch-Transformer style load-balancing loss plus a router z-loss.

        Encourages the experts to receive similar numbers of tokens.

        Args:
            router_logits: shape [batch_size * sequence_length, num_experts]
            num_experts: number of routed experts
            top_k: number of experts selected per token. Defaults to 2, the
                value that was previously hard-coded.

        Returns:
            Scalar tensor: auxiliary_loss + 0.001 * z_loss
        """
        # Routing probabilities, [tokens, num_experts].
        router_probs = torch.softmax(router_logits, dim=-1)

        # Experts actually chosen for every token, [tokens, top_k].
        _, selected_experts = torch.topk(router_probs, k=top_k, dim=-1)

        # One-hot selection, [tokens, top_k, num_experts].
        mask = F.one_hot(selected_experts, num_experts).float()

        # Fraction of tokens dispatched to each expert (summed over the top-k
        # slots), [num_experts]. The per-expert values add up to top_k.
        actual_load = mask.sum(dim=1).mean(dim=0)

        # Dot product of the dispatch fractions with the mean router
        # probabilities, scaled by the number of experts.
        aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts

        # z-loss: penalises large router logits.
        z_loss = torch.mean(torch.square(router_logits))
        z_loss_weight = 0.001  # tunable hyper-parameter

        return aux_loss + z_loss * z_loss_weight

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean = 0.0, std = 0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean = 0.0, std = 0.02)

    def forward(self, idx, targets=None):
        """Run the model.

        Args:
            idx: token ids, shape (batch, seq_len).
            targets: optional target token ids with the same shape as ``idx``.

        Returns:
            ``(logits, loss)``. Without targets, logits are
            (batch, seq_len, vocab_size) and loss is ``None``. With targets,
            logits are flattened to (batch * seq_len, vocab_size) and loss is
            the cross-entropy plus 0.01 x the routing loss of every block pass.

        Raises:
            ValueError: if ``seq_len`` exceeds ``config.block_size`` (the size
                of the position table and of the attention mask).
        """
        batch, seq_len = idx.size()
        if seq_len > self.config.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.config.block_size}"
            )

        token_emb = self.token_embedding_table(idx)   # (batch, seq_len, n_embd)
        pos_emb = self.position_embedding_table(
            # Positions must live on the same device as idx.
            torch.arange(seq_len, device = idx.device)
        )   # (seq_len, n_embd)

        # Broadcast add -> (batch, seq_len, n_embd).
        x = token_emb + pos_emb

        # The routing loss only matters when a loss is requested; skip it on
        # inference calls, where it used to be computed and then thrown away.
        want_loss = targets is not None
        aux_total = 0
        for _ in range(self.config.n_layer):
            x, router_logits = self.block(x)
            if want_loss:
                aux_loss = self.switch_load_balancing_loss(
                    router_logits,
                    self.moe_config.expert_number,
                    self.moe_config.top_k,
                )
                aux_total = aux_total + 0.01 * aux_loss

        x = self.last_ln(x)
        logits = self.lm_head(x)    # (batch, seq_len, vocab_size)

        if not want_loss:
            return logits, None

        batch, seq_len, vocab_size = logits.size()
        logits = logits.view(batch * seq_len, vocab_size)
        targets = targets.view(batch * seq_len)
        loss = aux_total + F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (batch, seq_len) tensor of token ids forming the context.
            max_new_tokens: number of tokens to append.

        Returns:
            (batch, seq_len + max_new_tokens) tensor of token ids.
        """
        for _ in range(max_new_tokens):
            # Keep only the last block_size tokens when the context is too long.
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # Only the prediction at the last time step is needed: (batch, vocab_size).
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # Sample from the distribution instead of taking the argmax.
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (batch, seq_len + 1)
        return idx

### 5. 构建输入Dataset

In [15]:
# 写一个dataset,为Dataloader准备
class MyDataset(Dataset):
    """Next-token-prediction dataset built from a JSONL file.

    Each line of the file is parsed as JSON and its ``"text"`` field is
    tokenised with the GPT-2 tiktoken encoding. All texts are concatenated,
    separated by the ``<|endoftext|>`` token, and cut into windows of
    ``block_size + 1`` tokens so that ``__getitem__`` can return an input and
    a target shifted by one position.

    Args:
        path: path of the JSONL file.
        block_size: number of tokens in each input sequence.
        max_lines: read at most this many lines from the file (default 1000,
            the value that used to be hard-coded); ``None`` reads them all.
    """

    def __init__(self, path, block_size = 512, max_lines = 1000):
        import json
        import tiktoken

        # GPT-2 tokenizer.
        self.enc = tiktoken.get_encoding('gpt2')
        self.block_size = block_size
        self.max_lines = max_lines

        # Separator placed between documents.
        self.eos_token = self.enc.encode(
            "<|endoftext|>",
            allowed_special={"<|endoftext|>"}
        )[0]

        raw_data = []
        with open(path, 'r', encoding = 'utf-8') as f:
            for i, line in enumerate(f):
                if self.max_lines is not None and i >= self.max_lines:
                    break
                try:
                    text = json.loads(line.strip())['text']
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Best effort: skip lines that are not JSON objects with a
                    # "text" field, but let unrelated errors surface.
                    continue
                if isinstance(text, str):
                    raw_data.append(text)

        # Lay all documents out in one token stream, eos-separated.
        full_encoded = []
        for text in raw_data:
            full_encoded.extend(self.enc.encode(text) + [self.eos_token])

        self.encoded_data = self._chunk_tokens(
            full_encoded, self.block_size, self.eos_token
        )

    @staticmethod
    def _chunk_tokens(tokens, block_size, pad_token):
        """Cut ``tokens`` into windows of ``block_size + 1`` taken every ``block_size``.

        The extra token makes the input/target shift possible; consecutive
        windows therefore overlap by one token. A short final window is
        right-padded with ``pad_token``.
        """
        window = block_size + 1
        chunks = []
        for start in range(0, len(tokens), block_size):
            chunk = tokens[start:start + window]
            if len(chunk) < window:
                chunk = chunk + [pad_token] * (window - len(chunk))
            chunks.append(chunk)
        return chunks

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx):
        # Shift by one: x is every token but the last, y every token but the first.
        chunk = self.encoded_data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

    def encode(self, text):
        """Tokenise ``text`` with the dataset's tokenizer."""
        return self.enc.encode(text)

    def decode(self, ids):
        """Turn a list of token ids back into text."""
        return self.enc.decode(ids)


### 6. 运行相关的函数

In [16]:
# Use the GPU when one is available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the default configs and move it to that device.
model = GPT(GPTConfig(), MOEConfig())
model = model.to(device)

# Report the total number of parameters, in millions.
total_params = sum(param.numel() for param in model.parameters())
print(f"Total Parameters: {total_params / 1e6} M")

# AdamW with a cosine-annealed learning rate (period T_max = 1000 scheduler steps).
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

Total Parameters: 18.519304 M


In [17]:
# Training data: tokenise the JSONL corpus into fixed-size chunks.
# NOTE(review): this is an absolute, machine-specific path; the cell will not
# run elsewhere until it is pointed at a local copy of the corpus.
train_dataset = MyDataset(r'E://llm/data/mobvoi_seq_monkey_general_open_corpus.jsonl')

# Random 90% / 10% split into training and validation subsets.
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [0.9, 0.1])

# Shuffle only the training batches; batch size is 12 for both loaders.
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=12, shuffle=False)

In [18]:
import os
def train(model, optimizer, scheduler, train_loader, val_loader, device, epoch):
    """Run one training epoch and return the summed batch loss.

    Args:
        model: module called as ``model(x, targets=y)`` returning ``(logits, loss)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler, also stepped once per batch.
        train_loader: iterable of ``(x, y)`` batches.
        val_loader: accepted for signature compatibility; not used here.
        device: device the batches are moved to.
        epoch: epoch number, used only in the progress message.

    Returns:
        Sum of ``loss.item()`` over all batches (not the average).
    """
    model.train()
    running = 0
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass; only the loss is needed here.
        _, loss = model(inputs, targets=labels)

        # Backward pass and parameter update, then advance the LR schedule.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        running += loss_value

        # Progress line every 10 batches.
        if step % 10 == 0:
            print(f"Epoch: {epoch}, Batch: {step}, loss:{loss_value:.4f}")

    return running

def eval(model, val_loader, device):
    """Return the summed validation loss over ``val_loader``.

    Puts ``model`` in eval mode and runs every batch without gradient
    tracking. The result is the sum of ``loss.item()``, not the average.
    Note that this function shadows the ``eval`` builtin in this namespace.
    """
    model.eval()
    accumulated = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            _, batch_loss = model(inputs, targets=labels)
            accumulated += batch_loss.item()
    return accumulated

# Train for two epochs, reporting average losses and checkpointing after each.
for epoch in range(2):
    train_loss = train(model, optimizer, scheduler, train_loader, val_loader, device, epoch)
    val_loss = eval(model, val_loader, device)

    # train()/eval() return summed losses; divide by the batch counts to average.
    avg_val_loss = val_loss / len(val_loader)
    print(f'Epoch: {epoch}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {avg_val_loss:.4f}')

    # Everything needed to resume training or to run inference later.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': avg_val_loss,
    }

    # Create the output directory if needed (no error when it already exists),
    # then write one checkpoint file per epoch.
    os.makedirs('./checkpoints', exist_ok=True)
    torch.save(checkpoint, f'./checkpoints/model_epoch_{epoch}.pt')


Epoch: 0, Batch: 0, loss:11.0330
Epoch: 0, Batch: 10, loss:9.5344
Epoch: 0, Batch: 20, loss:8.1548
Epoch: 0, Batch: 30, loss:6.8227
Epoch: 0, Batch: 40, loss:5.8295
Epoch: 0, Batch: 50, loss:5.3388
Epoch: 0, Batch: 60, loss:5.0983
Epoch: 0, Batch: 70, loss:4.8970
Epoch: 0, Batch: 80, loss:4.9168
Epoch: 0, Batch: 90, loss:4.7920
Epoch: 0, Batch: 100, loss:4.5867
Epoch: 0, Batch: 110, loss:4.4916
Epoch: 0, Batch: 120, loss:4.3504
Epoch: 0, Batch: 130, loss:4.3763
Epoch: 0, Batch: 140, loss:4.2993
Epoch: 0, Batch: 150, loss:4.2175
Epoch: 0, Batch: 160, loss:4.0668
Epoch: 0, Batch: 170, loss:4.0668
Epoch: 0, Batch: 180, loss:4.0724
Epoch: 0, Batch: 190, loss:4.1092
Epoch: 0, Batch: 200, loss:4.0818
Epoch: 0, Batch: 210, loss:3.9188
Epoch: 0, Batch: 220, loss:3.9561
Epoch: 0, Batch: 230, loss:3.9254
Epoch: 0, Batch: 240, loss:3.9603
Epoch: 0, Batch: 250, loss:3.8319
Epoch: 0, Batch: 260, loss:3.8373
Epoch: 0, Batch: 270, loss:3.8459
Epoch: 0, Train Loss: 4.9027, Val Loss: 3.8177
Epoch: 1, B

### 7. 调用模型

In [19]:
def generate_text(prompt, max_new_tokens=100, temperature=0.8):
    """Continue ``prompt`` with sampled tokens and return the decoded text.

    Uses the module-level ``tokenizer``, ``model`` and ``device``, which must
    be defined before this function is called.

    The ``temperature`` argument used to be accepted but ignored. Sampling is
    now done here so the logits can be divided by ``temperature`` before the
    softmax; ``temperature == 0`` selects the most likely token (greedy).

    Args:
        prompt: text to continue.
        max_new_tokens: number of tokens to append.
        temperature: sampling temperature, must be >= 0.

    Returns:
        The prompt plus the generated continuation, decoded to a string.

    Raises:
        ValueError: if the prompt encodes to no tokens or temperature < 0.
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")

    # Encode the prompt into a (1, seq_len) batch on the model's device.
    input_ids = tokenizer.encode(prompt)
    if not input_ids:
        raise ValueError("prompt must encode to at least one token")
    generated = torch.tensor([input_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed at most the last model.block_size tokens as context.
            context = generated[:, -model.block_size:]
            logits, _ = model(context)
            last_logits = logits[:, -1, :]   # prediction for the next token
            if temperature == 0:
                next_id = last_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = torch.softmax(last_logits / temperature, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_id), dim=1)

    # Decode the single sequence in the batch.
    output_ids = generated[0].cpu().tolist()
    return tokenizer.decode(output_ids)

In [20]:
import tiktoken
# Build the configs; these must match the ones the checkpoint was trained
# with, otherwise load_state_dict below will not find matching shapes.
gpt_config = GPTConfig(
    block_size=512,
    batch_size=12,
    n_layer=8,
    n_head=8,
    n_embd=256
)

moe_config = MOEConfig(
    hidden_dim=256,
    expert_number=8,
    top_k=2,
    shared_experts_numbers=2
)

# Create a fresh model on the available device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT(gpt_config, moe_config).to(device)

# Load the weights saved after the second epoch (epoch index 1).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpoint = torch.load('./checkpoints/model_epoch_1.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# GPT-2 tokenizer; generate_text reads it as a module-level name.
tokenizer = tiktoken.get_encoding('gpt2')

# Generate a continuation of the prompt and print it.
prompt = "你好啊"
output_text = generate_text(prompt, max_new_tokens=150)
print(output_text)

你好啊试�辎 sinners iterationYesterday affirmed式还品法重弽带�甋、杶在法 disinfect有�逡机诤��等 兡口？不部，1.水��单。��矎��对罦俗优了匽劵��让��� Gamma平�衢�平�颫玁常目�：的重绿兞�（
