In [None]:
import math
import os
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from practice.cnn_for_mnist.model import train_loader

torch.manual_seed(1024)


@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model.

    ``hidden_dim`` and ``head_size`` are derived from ``n_embd`` / ``n_heads``
    when they are not given explicitly, so overriding ``n_embd`` or
    ``n_heads`` keeps the derived sizes consistent instead of silently
    keeping the values computed from the class-level defaults.
    """

    block_size: int = 512  # maximum sequence length in tokens
    batch_size: int = 12
    n_layers: int = 12
    n_heads: int = 12
    n_embd: int = 768  # embedding width (a.k.a. hidden_dim / hidden_size)
    # Kept equal to n_embd by default so the token embedding and the output
    # projection can share (tie) their weights. None means "use n_embd".
    hidden_dim: Optional[int] = None
    dropout: float = 0.1
    # Width of a single attention head. None means "n_embd // n_heads".
    head_size: Optional[int] = None
    # Vocabulary size of the official GPT-2 tokenizer.
    vocab_size: int = 50257

    def __post_init__(self):
        """Fill in derived fields that were left unspecified."""
        if self.hidden_dim is None:
            self.hidden_dim = self.n_embd
        if self.head_size is None:
            self.head_size = self.n_embd // self.n_heads


# Single-head causal self-attention layer.
class SingleHeadAttention(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        config: Supplies ``hidden_dim`` (input width), ``head_size`` (output
            width of this head), ``block_size`` (largest sequence length the
            causal mask covers) and ``dropout``.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.key = nn.Linear(config.hidden_dim, config.head_size)
        self.query = nn.Linear(config.hidden_dim, config.head_size)
        self.value = nn.Linear(config.hidden_dim, config.head_size)
        # forward() needs the head width for the 1/sqrt(d_k) scaling.
        self.head_size = config.head_size

        # The lower-triangular causal mask is registered as a buffer: it is
        # not a trainable parameter (no gradient), but it still follows the
        # module when it is moved with .to(device).
        self.register_buffer(
            "attention_mask",
            torch.tril(torch.ones(config.block_size, config.block_size)),
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply causal attention.

        Args:
            x: Tensor whose last three dimensions are
                ``(batch, seq_len, hidden_dim)``; ``seq_len`` must not
                exceed ``block_size``.

        Returns:
            Tensor of shape ``(batch, seq_len, head_size)``.
        """
        _, seq_len, _ = x.size()
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scaled dot-product scores; `@` is shorthand for torch.matmul.
        weight = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
        # Future positions get -inf so softmax assigns them zero probability.
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0, float("-inf")
        )
        weight = F.softmax(weight, dim=-1)
        # Dropout is applied to the attention probabilities.
        weight = self.dropout(weight)
        return weight @ v


# Multi-head attention built from independent single heads.
class MultiHeadAttention(nn.Module):
    """Run ``n_heads`` single-head attentions and merge their outputs.

    Each head is called on the same input; the per-head results are
    concatenated along the last dimension, passed through a linear
    projection of width ``hidden_dim`` and then through dropout.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        head_modules = []
        for _ in range(config.n_heads):
            head_modules.append(SingleHeadAttention(config))
        self.heads = nn.ModuleList(head_modules)
        width = config.hidden_dim
        self.proj = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the projected concatenation of all head outputs for ``x``."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)


# 3. Feed-forward network (MLP).
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4x ``hidden_dim``, GELU, project back, dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.hidden_dim
        # 4x expansion (the original note mentions 8/3 for SwiGLU variants).
        inner_width = 4 * width
        expand = nn.Linear(width, inner_width)
        activation = nn.GELU()
        contract = nn.Linear(inner_width, width)
        regularize = nn.Dropout(config.dropout)
        self.net = nn.Sequential(expand, activation, contract, regularize)

    def forward(self, x):
        """Apply the MLP to ``x``; the output has the same last dimension."""
        transformed = self.net(x)
        return transformed


# 4. Transformer block.
class Block(nn.Module):
    """Transformer block: attention then feed-forward, each with a residual.

    LayerNorm is applied to the input of each sub-layer and the sub-layer
    output is added back onto the un-normalised stream.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.hidden_dim
        self.att = MultiHeadAttention(config)  # multi-head attention
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        attended = self.att(self.ln1(x))
        x = x + attended
        transformed = self.ffn(self.ln2(x))
        return x + transformed


# 5. GPT language model.
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Args:
        config: Supplies ``vocab_size``, ``n_embd``, ``block_size`` and
            ``n_layers`` (plus whatever ``Block`` reads from it).
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        # generate() needs block_size to crop the context window.
        self.config = config
        # Components: embedding, position embedding, norm, MLP, block.
        # Upgrades noted by the original author for newer models:
        #   learned 0, 1, ... position embedding -> RoPE
        #   LayerNorm -> RMSNorm
        #   MLP -> SwiGLU
        #   MHA -> GQA
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layers)]
        )
        self.ln_final = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying reduces the parameter count. nn.Linear(4, 8) stores its
        # weight as (8, 4), i.e. (out, in), which matches the
        # (vocab_size, n_embd) layout of the embedding table.
        self.token_embedding_table.weight = self.lm_head.weight
        # Apply the custom initialisation to every sub-module.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, if targets are given, the cross-entropy loss.

        Args:
            idx: Input token ids of shape ``(batch, seq_len)``; ``seq_len``
                must not exceed ``block_size`` because positions are looked
                up in a table of that size.
            targets: Optional target token ids with the same shape as ``idx``.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            ``(batch, seq_len, vocab_size)`` and ``loss`` is ``None``. With
            targets, ``logits`` is flattened to
            ``(batch * seq_len, vocab_size)`` and ``loss`` is a scalar tensor.
        """
        batch_size, seq_len = idx.size()
        token_emb = self.token_embedding_table(idx)
        position_emb = self.position_embedding_table(
            torch.arange(seq_len, device=idx.device)
        )
        # position_emb is (seq_len, n_embd) and broadcasts over the batch.
        x = token_emb + position_emb
        x = self.blocks(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            batch, seq_len, vocab_size = logits.size()
            logits = logits.view(batch * seq_len, vocab_size)
            targets = targets.view(batch * seq_len)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_token):
        """Sample ``max_new_token`` tokens autoregressively.

        Args:
            idx: Prompt token ids of shape ``(batch, seq_len)``.
            max_new_token: Number of tokens to append.

        Returns:
            Tensor of shape ``(batch, seq_len + max_new_token)``.
        """
        block_size = self.config.block_size
        for _ in range(max_new_token):
            # Keep only the most recent block_size tokens as context.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # Only the distribution at the last position is needed.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # Random sampling from the predicted distribution.
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)  # (batch, seq_len + 1)
        return idx


class MyDataset(Dataset):
    """Token dataset built from a JSONL file with one ``{"text": ...}`` per line.

    All texts are tokenised with the GPT-2 ``tiktoken`` encoding, joined into
    one token stream with an ``<|endoftext|>`` token after each text, and cut
    into windows of ``block_size + 1`` tokens so that each item yields an
    input and a target sequence shifted by one position.

    Args:
        path: Path to the JSONL file.
        block_size: Number of tokens in each input / target sequence.
        max_lines: Maximum number of lines read from the file.
    """

    def __init__(self, path, block_size=512, max_lines=1000):
        import json

        import tiktoken

        self.enc = tiktoken.get_encoding("gpt2")
        self.block_size = block_size
        self.max_lines = max_lines
        # Special token used to separate different training texts.
        self.eos_token = self.enc.encode(
            "<|endoftext|>", allowed_special={"<|endoftext|>"}
        )[0]

        raw_data = []
        with open(path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i >= self.max_lines:
                    break
                try:
                    raw_data.append(json.loads(line)["text"])
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Skip malformed lines and lines without a "text" field.
                    continue

        # One flat token stream, with an EOS token closing every text.
        full_encoded = []
        for text in raw_data:
            full_encoded.extend(self.enc.encode(text))
            full_encoded.append(self.eos_token)

        self.encoded_data = self._split_into_chunks(
            full_encoded, self.block_size, self.eos_token
        )

    @staticmethod
    def _split_into_chunks(tokens, block_size, pad_token):
        """Cut ``tokens`` into windows of ``block_size + 1`` tokens.

        Windows start every ``block_size`` tokens, so consecutive windows
        share one token (the last target of one window is the first input of
        the next). A short final window is right-padded with ``pad_token``.
        """
        window = block_size + 1
        chunks = []
        for start in range(0, len(tokens), block_size):
            chunk = tokens[start:start + window]
            if len(chunk) < window:
                chunk = chunk + [pad_token] * (window - len(chunk))
            chunks.append(chunk)
        return chunks

    def __len__(self):
        """Return the number of windows."""
        return len(self.encoded_data)

    def __getitem__(self, idx):
        """Return ``(x, y)`` long tensors where ``y`` is ``x`` shifted by one."""
        chunk = self.encoded_data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

    def encode(self, text):
        """Tokenise ``text`` with the dataset's encoding."""
        return self.enc.encode(text)

    def decode(self, text):
        """Decode a sequence of token ids back into a string."""
        return self.enc.decode(text)


# 5. Run-time setup: model, optimizer, learning-rate schedule and data loaders.
model = GPT(GPTConfig())
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# Print the total number of model parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Learning-rate schedule: cosine annealing with T_max=1000 scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

# Training data, read from a JSONL corpus file in the working directory.
train_dataset = MyDataset('mobvoi_seq_monkey_general_open_corpus.jsonl')

# Split the dataset 90% / 10% into training and validation subsets.
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [0.9, 0.1])

# NOTE(review): this rebinds the `train_loader` name imported at the top of
# the file, and batch_size=12 is hard-coded here rather than read from
# GPTConfig.batch_size.
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=12, shuffle=True)


def train(model, optimizer, scheduler, train_loader, valid_loader, device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module called as ``model(x, targets=y)`` and returning
            ``(logits, loss)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        train_loader: Iterable of ``(x, y)`` tensor batches.
        valid_loader: Unused; kept for interface compatibility.
        device: Device the batches are moved to.

    Returns:
        The sum of the per-batch loss values over the whole pass.
    """
    model.train()
    total_loss = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        # Move the batch to the target device.
        x, y = x.to(device), y.to(device)
        # Forward pass.
        logits, loss = model(x, targets=y)
        # Backward pass.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Adjust the learning rate.
        scheduler.step()
        total_loss += loss.item()
        if batch_idx % 10 == 0:
            # NOTE(review): the epoch label in this message is hard-coded.
            print(f'Epoch: 10,Batch:{batch_idx},Loss:{loss.item():.4f}')
    # Return only after every batch has been processed.
    return total_loss


def eval(model, valid_loader, device):
    """Return the summed per-batch loss of ``model`` over ``valid_loader``.

    The model is switched to evaluation mode and every batch is run with
    gradient tracking disabled. Each batch is moved to ``device`` and the
    model is called as ``model(x, targets=y)``; the second element of its
    result is taken as the loss.
    """
    model.eval()
    with torch.no_grad():
        return sum(
            model(inputs.to(device), targets=labels.to(device))[1].item()
            for inputs, labels in valid_loader
        )


# Train for two epochs; evaluate and save a checkpoint after each one.
# torch.save does not create missing directories, so make sure it exists.
os.makedirs('checkpoints', exist_ok=True)
for epoch in range(2):
    train_loss = train(model, optimizer, scheduler, train_loader, val_loader, device)
    val_loss = eval(model, val_loader, device)
    print(f'Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}')
    # Mean loss per validation batch; guard against an empty loader.
    avg_val_loss = val_loss / max(len(val_loader), 1)
    # Everything needed to resume training from this epoch.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        'avg_val_loss': avg_val_loss,
    }
    # Save one checkpoint file per epoch.
    torch.save(checkpoint, f'checkpoints/model_epoch_{epoch}.pt')
