# Transformer Decoder

## Data Preparation

In [None]:
from torch.utils.data import Dataset
import os

# Process-wide environment overrides: the three proxy variables all point at a
# local endpoint on port 7897, and CURL_CA_BUNDLE is set to the empty string.
# NOTE(review): a blank CURL_CA_BUNDLE is usually a trick to skip TLS
# certificate verification in HTTP clients — confirm this is intended and do
# not ship it outside a local development machine.
os.environ.update({
    'CURL_CA_BUNDLE': '',
    'HTTP_PROXY': "http://127.0.0.1:7897",
    'HTTPS_PROXY': "http://127.0.0.1:7897",
    'ALL_PROXY': "socks5://127.0.0.1:7897",
})


class Tokenizer:
    """Character-level tokenizer whose vocabulary is built from a text file.

    Ids 0, 1 and 2 are reserved for ``<UNK>``, ``<CLS>`` (start marker) and
    ``<SEP>`` (end marker). Every distinct character found in the file gets
    an id from 3 upwards, assigned in sorted order so that the mapping is
    identical every time the same file is loaded.
    """

    def __init__(self, datapath: str):
        """Read the whole file at ``datapath`` (UTF-8) and build the vocabulary."""
        with open(datapath, 'r', encoding='utf-8') as f:
            self.dataset = f.read()
        self.__gen_vocab()

    def __gen_vocab(self):
        """Populate ``char2idx``, ``idx2char`` and ``vocab_size``."""
        # Special tokens: unknown <UNK>, start <CLS>, separator <SEP>.
        self.char2idx = {'<UNK>': 0, '<CLS>': 1, '<SEP>': 2}
        self.idx2char = {0: '<UNK>', 1: '<CLS>', 2: '<SEP>'}

        # Start numbering after the reserved ids so characters never overwrite
        # the special tokens, and sort the characters because set iteration
        # order is not stable across interpreter runs (a saved model would
        # otherwise be paired with a different mapping at load time).
        first_free_id = len(self.char2idx)
        for idx, char in enumerate(sorted(set(self.dataset)), start=first_free_id):
            self.char2idx[char] = idx
            self.idx2char[idx] = char

        self.vocab_size = len(self.char2idx)

    def encode(self, sentence):
        """Return ``[<CLS>] + ids + [<SEP>]``; unseen characters map to ``<UNK>`` (0)."""
        indices = [self.char2idx.get(char, 0) for char in sentence]
        return [1] + indices + [2]

    def decode(self, ids):
        """Turn ids back into text, dropping the first and the last id.

        The first/last positions are assumed to hold the ``<CLS>``/``<SEP>``
        wrappers added by :meth:`encode`. Ids outside the vocabulary are
        rendered as the literal ``<UNK>``.
        """
        chars = [self.idx2char.get(_id, '<UNK>') for _id in ids]
        return ''.join(chars[1:-1])


class ShakespeareDataset(Dataset):
    """Sliding-window next-token dataset over a single text file.

    The file is tokenized once; item ``i`` is the pair
    ``(tokens[i : i + chunk_size], tokens[i + 1 : i + chunk_size + 1])``,
    i.e. an input window and the same window shifted by one token.
    """

    def __init__(self, datapath: str, tokenizer_mode: str, tokenizer, chunk_size: int):
        """Read ``datapath`` and encode it with ``tokenizer``.

        Args:
            datapath: UTF-8 text file to read.
            tokenizer_mode: ``'custom'``, ``'bert'`` or ``'tiktoken'``; selects
                how the vocabulary size is read from ``tokenizer``.
            tokenizer: object exposing ``encode(text)`` plus ``vocab_size``
                (custom/bert) or ``max_token_value`` (tiktoken).
            chunk_size: number of tokens per window.

        Raises:
            ValueError: if ``tokenizer_mode`` is not one of the three modes.
        """
        with open(datapath, 'r', encoding='utf-8') as f:
            self.dataset = f.read()
        self.chunk_size = chunk_size
        self.tokenizer_mode = tokenizer_mode
        self.tokenizer = tokenizer
        if tokenizer_mode in ('custom', 'bert'):
            self.vocab_size = tokenizer.vocab_size
        elif tokenizer_mode == 'tiktoken':
            self.vocab_size = tokenizer.max_token_value + 1
        else:
            # Fail here instead of leaving vocab_size unset and crashing later.
            raise ValueError("Invalid tokenizer mode. Choose from 'custom', 'bert', 'tiktoken'.")

        self.encoded_dataset = tokenizer.encode(self.dataset)

    def __len__(self):
        # Clamp at zero: len() must not return a negative number when the
        # corpus is shorter than one window.
        return max(0, len(self.encoded_dataset) - self.chunk_size)

    def __getitem__(self, idx):
        """Return ``(chunk, label)`` as two ``torch.long`` tensors of length ``chunk_size``."""
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Needed so plain iteration over the dataset terminates and so
            # that out-of-range windows are never silently truncated.
            raise IndexError(f"index {idx} out of range for dataset of size {size}")

        chunk = self.encoded_dataset[idx:idx + self.chunk_size]
        label = self.encoded_dataset[idx + 1:idx + self.chunk_size + 1]
        # Convert to tensors.
        chunk = torch.tensor(chunk, dtype=torch.long)
        label = torch.tensor(label, dtype=torch.long)
        return chunk, label

    def get_vocab_size(self):
        """Return the vocabulary size read from the tokenizer at construction."""
        return self.vocab_size


def generate_tgt_mask(seq_len, device=None):
    """Build a causal attention mask that hides future positions.

    Args:
        seq_len: length of the sequence the mask is for.
        device: optional device to allocate the mask on; defaults to the
            current default device, as before.

    Returns:
        A float tensor of shape ``(1, seq_len, seq_len)`` that is
        lower-triangular: entry ``[0, i, j]`` is 1 when ``j <= i`` (position
        ``i`` may look at ``j``) and 0 when ``j`` lies in the future.
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).unsqueeze(0)
    return mask


## Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """One scaled dot-product attention head with its own bias-free Q/K/V projections."""

    def __init__(self, seq_len: int, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Each projection maps embed_dim -> hidden_dim without a bias term.
        self.Q = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.K = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.V = nn.Linear(embed_dim, hidden_dim, bias=False)

    def forward(self, values, keys, queries, mask=None):
        """Attend from ``queries`` over ``keys``/``values``.

        Inputs are batched 3-D tensors (``torch.bmm`` is used). Where ``mask``
        equals 0 the score is set to ``-inf`` before the softmax, so those
        key positions receive zero weight.
        """
        q = self.Q(queries)
        k = self.K(keys)
        v = self.V(values)

        # Raw similarities, scaled by sqrt(hidden_dim).
        scores = torch.bmm(q, k.transpose(1, 2))
        scores = scores / (self.hidden_dim ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = scores.softmax(dim=-1)
        return torch.bmm(weights, v)


class MultiHeadAttention(nn.Module):
    """Runs ``heads`` independent ``Attention`` modules and mixes their concatenated outputs."""

    def __init__(self, seq_len: int, embed_dim: int, heads: int, dropout: float):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.heads = heads

        # Every head works in an embed_dim // heads sized subspace; the split
        # has to be exact so the concatenation is embed_dim wide again.
        per_head, leftover = divmod(embed_dim, heads)
        self.hidden_dim = per_head
        assert leftover == 0, "embed_dim 必须被注意力头整除"

        head_modules = []
        for _ in range(heads):
            head_modules.append(Attention(seq_len, embed_dim, per_head))
        self.multi_head_attention_layers = nn.ModuleList(head_modules)

        self.out_linear = nn.Linear(per_head * heads, embed_dim)
        # Kept as a registered submodule; forward() below does not apply it.
        self.dropout = nn.Dropout(dropout)

    def forward(self, values, keys, queries, mask=None):
        """Apply every head to the same inputs, concatenate on the last axis, then project."""
        per_head_outputs = [
            head(values, keys, queries, mask)
            for head in self.multi_head_attention_layers
        ]
        merged = torch.cat(per_head_outputs, dim=-1)
        return self.out_linear(merged)


class Expert(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear with a 4x wider middle layer."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        inner_dim = 4 * embed_dim
        self.expert_layer = nn.Sequential(
            nn.Linear(embed_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, embed_dim),
        )

    def forward(self, x):
        """Map ``x`` (last axis ``embed_dim``) through the three layers in order."""
        hidden = x
        for layer in self.expert_layer:
            hidden = layer(hidden)
        return hidden


class TopKRouter(nn.Module):
    """Scores every expert per token and keeps only the ``active_experts`` best ones."""

    def __init__(self, embed_dim: int, num_experts: int, active_experts: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_experts = num_experts
        self.active_experts = active_experts
        # The router itself is a small MLP: one linear layer followed by ReLU.
        self.top_k_router = nn.Sequential(
            nn.Linear(embed_dim, num_experts),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return ``(router_weight, mask)``, both shaped like the score tensor.

        ``mask`` is 1 at the top-k experts of each token and 0 elsewhere.
        ``router_weight`` is a softmax taken over the selected experts only;
        unselected experts get weight 0 because their score is replaced by
        ``-inf`` before the softmax.
        """
        logits = self.top_k_router(x)  # (batch_size, seq_len, num_experts)

        # Only the positions of the k largest scores are needed.
        _, chosen = logits.topk(self.active_experts, dim=-1)  # (batch_size, seq_len, active_experts)

        selected = torch.zeros_like(logits).scatter(-1, chosen, 1)  # (batch_size, seq_len, num_experts)

        gated_logits = logits.masked_fill(selected.eq(0), float('-inf'))
        return gated_logits.softmax(dim=-1), selected


class SparseMoE(nn.Module):
    """Mixture-of-experts feed-forward layer gated by a ``TopKRouter``."""

    def __init__(self, embed_dim: int, num_experts: int, active_experts: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_experts = num_experts
        self.active_experts = active_experts
        self.experts = nn.ModuleList([Expert(embed_dim) for _ in range(num_experts)])
        self.router = TopKRouter(embed_dim, num_experts, active_experts)

    def forward(self, x):
        """Return the router-weighted sum of the expert outputs for ``x``.

        Every expert is evaluated on the full input; experts the router did
        not select simply contribute with weight zero.
        """
        # Unpacking three names requires x to be a 3-D tensor.
        _batch, _length, _width = x.shape
        gate_weights, gate_mask = self.router(x)

        # Accumulator that starts at zero and has the same shape as x.
        combined = torch.zeros_like(x)

        for expert_idx, expert in enumerate(self.experts):
            # Gate for this expert, kept with a trailing axis of size 1 so it
            # broadcasts over the embedding axis.
            gate = gate_weights[:, :, expert_idx:expert_idx + 1] * gate_mask[:, :, expert_idx:expert_idx + 1]
            combined = combined + gate * expert(x)

        return combined


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to an embedded sequence, then applies dropout.

    The table is precomputed for ``seq_len`` positions and stored as the
    buffer ``pe`` with shape ``(1, seq_len, embed_dim)``: even feature indices
    hold sines, odd feature indices hold cosines.
    """

    def __init__(self, seq_len: int, embed_dim: int, dropout: float):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(torch.log(torch.tensor(10000.0)) / embed_dim))
        angles = position * div_term  # (seq_len, ceil(embed_dim / 2))
        pe = torch.zeros(seq_len, embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even
        # columns, so trim the angles to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, :embed_dim // 2])
        pe = pe.unsqueeze(0)  # (1, seq_len, embed_dim)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x``.

        ``x`` is indexed as ``(batch, length, embed_dim)``. Sequences shorter
        than the table are supported by slicing the table; a sequence longer
        than ``seq_len`` still fails with a shape error in the addition.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class TransformerDecoderLayer(nn.Module):
    """Decoder block: self-attention, attention over ``memory``, then a sparse MoE feed-forward.

    Each of the three sub-layers is preceded by its own LayerNorm and its
    output passes through its own Dropout before being added back.
    """

    def __init__(
            self,
            embed_dim: int,
            n_heads: int,
            seq_len: int,
            num_experts: int,
            active_experts: int,
            dropout: float
    ):
        super().__init__()
        self.self_attention = MultiHeadAttention(seq_len, embed_dim, n_heads, dropout)
        self.encoder_decoder_attention = MultiHeadAttention(seq_len, embed_dim, n_heads, dropout)
        self.moe_ffn = SparseMoE(embed_dim, num_experts, active_experts)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run the three sub-layers in order and return the updated ``x``.

        ``tgt_mask`` is used by the self-attention; ``src_mask`` is used when
        attending over ``memory`` (queries come from ``x``, keys and values
        from ``memory``).
        """
        # NOTE(review): each residual is added to the *normalized* activations
        # rather than to the sub-layer's un-normalized input — confirm this
        # variant is intended.

        # 1) masked self-attention
        hidden = self.norm1(x)
        attended = self.self_attention(hidden, hidden, hidden, tgt_mask)
        hidden = hidden + self.dropout1(attended)

        # 2) attention over memory
        hidden = self.norm2(hidden)
        cross = self.encoder_decoder_attention(memory, memory, hidden, src_mask)
        hidden = hidden + self.dropout2(cross)

        # 3) mixture-of-experts feed-forward
        hidden = self.norm3(hidden)
        return hidden + self.dropout3(self.moe_ffn(hidden))


class SparseMoETransformerDecoder(nn.Module):
    """Decoder-only language model built from ``TransformerDecoderLayer`` blocks.

    Token ids are embedded, position-encoded, passed through ``n_layers``
    decoder layers and projected to ``vocab_size`` logits per position.
    """

    def __init__(
            self,
            vocab_size: int,
            seq_len: int,
            embed_dim: int,
            n_layers: int,
            n_heads: int,
            num_experts: int,
            active_experts: int,
            dropout: float
    ):
        # Initialise nn.Module before assigning any attribute.
        super(SparseMoETransformerDecoder, self).__init__()
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(seq_len, embed_dim, dropout)
        self.transformer_layers = nn.ModuleList([
            TransformerDecoderLayer(embed_dim, n_heads, seq_len, num_experts, active_experts, dropout)
            for _ in range(n_layers)
        ])
        self.out_linear = nn.Linear(embed_dim, vocab_size)

    def forward(self, x, src_mask=None, tgt_mask=None):
        """Return logits of shape ``(batch, length, vocab_size)`` for token ids ``x``.

        Args:
            x: integer token ids, ``(batch, length)``.
            src_mask: mask for the attention over ``memory``. When omitted it
                defaults to ``tgt_mask``.
            tgt_mask: causal mask for the self-attention.
        """
        x = self.embedding(x)
        x = self.pos_encoder(x)

        # The "memory" each layer attends to is the decoder's own sequence, so
        # leaving that attention unmasked would let position i read position
        # i + 1 — which is exactly the token it is trained to predict. Reuse
        # the causal mask unless the caller supplied an explicit one.
        memory_mask = tgt_mask if src_mask is None else src_mask

        for layer in self.transformer_layers:
            x = layer(x, x, memory_mask, tgt_mask)

        output = self.out_linear(x)
        return output

    def generate(self, input_tokens, max_new_tokens):
        """Sample up to ``max_new_tokens`` tokens that continue ``input_tokens``.

        Args:
            input_tokens: integer tensor ``(batch, prompt_len)``.
            max_new_tokens: upper bound on the number of sampled tokens.

        Returns:
            ``(batch, prompt_len + n_generated)`` tensor: the untouched prompt
            followed by the sampled tokens.
        """
        device = next(self.parameters()).device
        generated = input_tokens.to(device)

        # The model always sees a window of exactly seq_len tokens. Keep the
        # most recent prompt tokens and pad on the LEFT with id 0, so that the
        # last position — the one the next token is sampled from — holds the
        # real end of the prompt rather than padding. The pad positions are
        # not masked out.
        context = generated[:, -self.seq_len:]
        missing = self.seq_len - context.size(1)
        if missing > 0:
            context = F.pad(context, (missing, 0))

        tgt_mask = generate_tgt_mask(self.seq_len).to(device)
        eos_id = self.vocab_size - 1  # Assuming the last vocab index is an EOS token

        # Sampling needs no gradients; avoid building an autograd graph.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                output = self(context, tgt_mask=tgt_mask)
                last_token_logits = output[:, -1, :]  # logits of the last position
                probs = F.softmax(last_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                generated = torch.cat([generated, next_token], dim=-1)
                # Slide the window: append the new token, drop the oldest.
                context = torch.cat([context, next_token], dim=-1)[:, -self.seq_len:]

                # Stop once every sequence in the batch has just emitted EOS
                # (for a batch of one this is the plain equality check).
                if bool((next_token == eos_id).all()):
                    break

        return generated


## Training

In [None]:
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


class Trainer:
    """Runs the train/validate loop, logs losses to TensorBoard and keeps the best checkpoint."""

    def __init__(self, args, model, train_dataloader, val_dataloader, criterion, optimizer):
        """Store the collaborators and move ``model`` to ``args.device``.

        Reads ``device``, ``epochs``, ``save_path`` (TensorBoard log dir) and
        ``model_path`` (checkpoint directory) from ``args``.
        """
        self.args = args
        self.device = args.device
        self.epochs = args.epochs
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        self.model.to(self.device)
        self.save_path = args.save_path
        self.model_path = args.model_path
        self.writer = SummaryWriter(log_dir=self.save_path)

    def train(self):
        """Train for ``self.epochs`` epochs, saving the weights with the lowest validation loss."""
        best_val_loss = float('inf')
        for epoch in range(self.epochs):
            train_loss = self._train_single_epoch()
            val_loss = self._val_single_epoch()

            print(f"Epoch {epoch + 1}/{self.epochs}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            self.writer.add_scalar("Train Loss", train_loss, epoch)
            self.writer.add_scalar("Val Loss", val_loss, epoch)
            # Push the buffered scalars to disk every epoch so the curves
            # survive an interrupted run.
            self.writer.flush()

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), f"{self.model_path}/best_model.pth")

    def _compute_loss(self, batch):
        """Run one forward pass on ``batch`` with a causal mask and return the loss tensor."""
        inputs, targets = batch
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        tgt_mask = generate_tgt_mask(inputs.size(1)).to(self.device)

        output = self.model(inputs, tgt_mask=tgt_mask)
        # Flatten to (batch * length, vocab) vs (batch * length,) for the criterion.
        return self.criterion(output.view(-1, output.size(-1)), targets.view(-1))

    def _train_single_epoch(self):
        """One optimisation pass over the training loader; returns the mean batch loss."""
        self.model.train()
        epoch_loss = 0
        for batch in tqdm(self.train_dataloader, desc="Training"):
            self.optimizer.zero_grad()
            loss = self._compute_loss(batch)

            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()

        return epoch_loss / len(self.train_dataloader)

    @torch.no_grad()
    def _val_single_epoch(self):
        """One gradient-free pass over the validation loader; returns the mean batch loss."""
        self.model.eval()
        epoch_loss = 0
        for batch in tqdm(self.val_dataloader, desc="Validation"):
            epoch_loss += self._compute_loss(batch).item()

        return epoch_loss / len(self.val_dataloader)


@torch.no_grad()
def test(args, model, test_dataloader, criterion):
    """Evaluate ``model`` on ``test_dataloader`` and log the mean batch loss.

    Args:
        args: namespace providing ``device`` and ``save_path`` (TensorBoard
            log directory).
        model: network called as ``model(inputs, tgt_mask=...)``.
        test_dataloader: iterable of ``(inputs, targets)`` batches.
        criterion: loss taking flattened logits and flattened targets.

    Returns:
        The mean loss over all batches (also printed and written to
        TensorBoard under the tag ``"Test Loss"``).
    """
    device = args.device
    model.to(device)
    model.eval()

    writer = SummaryWriter(log_dir=args.save_path)
    try:
        test_loss = 0
        for batch in tqdm(test_dataloader, desc="Testing"):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)

            tgt_mask = generate_tgt_mask(inputs.size(1)).to(device)

            output = model(inputs, tgt_mask=tgt_mask)
            loss = criterion(output.view(-1, output.size(-1)), targets.view(-1))

            test_loss += loss.item()

        test_loss = test_loss / len(test_dataloader)
        print(f"Test Loss: {test_loss:.4f}")
        writer.add_scalar("Test Loss", test_loss)
    finally:
        # Closing flushes the pending event and releases the log file even
        # when evaluation fails part-way through.
        writer.close()

    return test_loss


## Arguments

In [None]:
import argparse
import torch
import os


def parse_args(argv=None):
    """Parse the training/generation options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; pass ``[]`` to
            get the defaults when the process arguments are not yours to
            control (for example inside a notebook kernel or a test runner).

    Returns:
        ``argparse.Namespace`` with one attribute per option below.
    """
    parser = argparse.ArgumentParser(description="Train a SparseMoE Transformer model for text generation.")
    parser.add_argument('--data_path', type=str, default='data/input.txt', help='Path to the input text file.')
    parser.add_argument('--chunk_size', type=int, default=50, help='Size of text chunks.')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
    parser.add_argument('--seq_len', type=int, default=50, help='Sequence length for the model.')
    parser.add_argument('--embed_dim', type=int, default=64, help='Embedding dimension for the model.')
    parser.add_argument('--n_layers', type=int, default=3, help='Number of layers in the Transformer.')
    parser.add_argument('--n_heads', type=int, default=4, help='Number of attention heads.')
    parser.add_argument('--num_experts', type=int, default=4, help='Number of experts in SparseMoE.')
    parser.add_argument('--active_experts', type=int, default=2, help='Number of active experts in SparseMoE.')
    parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs.')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate.')
    parser.add_argument('--save_path', type=str, default='results/', help='Path to save the model and results.')
    parser.add_argument('--model_path', type=str, default='models/', help='Path to save the model.')
    return parser.parse_args(argv)


# Module-level configuration shared by main() and generate_text().
args = parse_args()

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args.device = device

# Tokenizer backend; the code below recognises 'custom', 'bert' and 'tiktoken'.
args.tokenizer_mode = 'bert'

# Keep logs and checkpoints of different tokenizers in separate sub-directories.
args.save_path = f"{args.save_path}/{args.tokenizer_mode}"
args.model_path = f"{args.model_path}/{args.tokenizer_mode}"

# Create the checkpoint directory. exist_ok avoids the check-then-create race
# of testing os.path.exists first.
os.makedirs(args.model_path, exist_ok=True)

## Main

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from transformers import BertTokenizer
import tiktoken


def main():
    """Train the model on ``args.data_path``, then evaluate the best checkpoint on the test split.

    Raises:
        ValueError: if ``args.tokenizer_mode`` is not 'custom', 'bert' or 'tiktoken'.
    """
    # Load the tokenizer.
    if args.tokenizer_mode == 'custom':
        tokenizer = Tokenizer(args.data_path)
    elif args.tokenizer_mode == 'bert':
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    elif args.tokenizer_mode == 'tiktoken':
        tokenizer = tiktoken.get_encoding("cl100k_base")
    else:
        raise ValueError("Invalid tokenizer mode. Choose from 'custom', 'bert', 'tiktoken'.")

    # Load the dataset.
    dataset = ShakespeareDataset(args.data_path, args.tokenizer_mode, tokenizer, args.chunk_size)

    # Record the vocabulary size for the model.
    args.vocab_size = dataset.get_vocab_size()

    # Split 70 / 20 / 10 into train / val / test and build the loaders.
    # NOTE(review): dataset items are windows that start one token apart, so a
    # random split puts heavily overlapping windows into different splits —
    # validation/test losses are therefore optimistic. Consider splitting the
    # token stream into contiguous ranges instead.
    train_size = int(0.7 * len(dataset))
    val_size = int(0.2 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Build the model.
    model = SparseMoETransformerDecoder(
        vocab_size=args.vocab_size,
        seq_len=args.seq_len,
        embed_dim=args.embed_dim,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        num_experts=args.num_experts,
        active_experts=args.active_experts,
        dropout=args.dropout
    ).to(args.device)

    # Loss and optimizer. Targets equal to 0 are ignored by the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Train and validate.
    trainer = Trainer(args, model, train_dataloader, val_dataloader, criterion, optimizer)
    trainer.train()

    # Reload the best checkpoint and evaluate it. map_location remaps the
    # stored tensors onto the current device, so a checkpoint written on a
    # GPU can still be loaded on a CPU-only host.
    best_state = torch.load(f"{args.model_path}/best_model.pth", map_location=args.device)
    model.load_state_dict(best_state)
    test(args, model, test_dataloader, criterion)


def generate_text(input_text: str, max_len: int = 100):
    """Load the best checkpoint and print a sampled continuation of ``input_text``.

    Args:
        input_text: prompt to continue.
        max_len: maximum number of new tokens to sample.

    Raises:
        ValueError: if ``args.tokenizer_mode`` is not 'custom', 'bert' or 'tiktoken'.
    """
    # Build the tokenizer, encode the prompt as a (1, prompt_len) tensor and
    # read the vocabulary size straight from the tokenizer. These are the same
    # expressions ShakespeareDataset uses for its vocab_size, without reading
    # and tokenizing the whole corpus on every call.
    decode_kwargs = {}
    if args.tokenizer_mode == 'custom':
        tokenizer = Tokenizer(args.data_path)
        encoded_text = torch.tensor(tokenizer.encode(input_text), dtype=torch.long).unsqueeze(0).to(args.device)
        vocab_size = tokenizer.vocab_size
    elif args.tokenizer_mode == 'bert':
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        encoded_text = tokenizer.encode(input_text, add_special_tokens=True, return_tensors='pt').to(args.device)
        vocab_size = tokenizer.vocab_size
        decode_kwargs = {'skip_special_tokens': True}
    elif args.tokenizer_mode == 'tiktoken':
        tokenizer = tiktoken.get_encoding("cl100k_base")
        encoded_text = torch.tensor(tokenizer.encode(input_text), dtype=torch.long).unsqueeze(0).to(args.device)
        vocab_size = tokenizer.max_token_value + 1
    else:
        raise ValueError("Invalid tokenizer mode. Choose from 'custom', 'bert', 'tiktoken'.")

    # Record the vocabulary size, as main() does.
    args.vocab_size = vocab_size

    # Rebuild the architecture the checkpoint was trained with.
    model = SparseMoETransformerDecoder(
        vocab_size=args.vocab_size,
        seq_len=args.seq_len,
        embed_dim=args.embed_dim,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        num_experts=args.num_experts,
        active_experts=args.active_experts,
        dropout=args.dropout
    ).to(args.device)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    state = torch.load(f"{args.model_path}/best_model.pth", map_location=args.device)
    model.load_state_dict(state)
    model.eval()  # disables dropout, so the dropout rate above has no effect here

    # The tokenizer mode was validated above; only the decode options differ.
    gen_tokens = model.generate(encoded_text, max_len)[0].tolist()
    gen_text = tokenizer.decode(gen_tokens, **decode_kwargs)

    print("Input text:")
    print(input_text)
    print("Generated text:")
    print(gen_text, end='\n\n')


if __name__ == '__main__':
    # Training is switched off here: generate_text() loads best_model.pth, the
    # file main() writes, so run main() once beforehand to create it.
    origin_text = """
Would the nobility lay aside their ruth,
And let me use my sword, I'll make a quarry
With thousands of these quarter'd slaves, as high
As I could pick my lance."""

    # Sample a 100-token continuation for each prompt, in this order.
    for prompt in (
            "To be or not to be, that is the question:",
            "I could pick my lance",
            origin_text,
    ):
        generate_text(prompt, max_len=100)
