In [None]:
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm
from wzh.transformer import Transformer

# --- Configuration ----------------------------------------------------------
# Seed torch's global RNG so runs are reproducible.
torch.manual_seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
learning_rate = 1e-5

# --- Data -------------------------------------------------------------------
# Local copy of WikiText-103; only the two train parquet shards are listed, so
# the resulting DatasetDict has a single "train" split.
dataset = load_dataset(
    "data/wikitext103/",
    data_files={
        "train": [
            "train-00000-of-00002.parquet",
            "train-00001-of-00002.parquet",
        ],
    },
)

# Presumably a saved copy of the Hub "gpt2" tokenizer (kept locally so no
# network access is needed) — confirm.
tokenizer = AutoTokenizer.from_pretrained("./gpt2-tokenizer")
# Vocabulary size used for the models' embedding and output projection.
vocab_size = len(tokenizer)


class Baseline(nn.Module):
    """Language model: token embedding -> ``Transformer`` -> vocabulary logits.

    The keyword defaults reproduce the original fixed configuration
    (384-dim model, 6 layers, 8 heads, 1024 max positions, ``glu_attn=False``),
    so ``Baseline()`` builds exactly the same model as before. They are exposed
    as arguments so other sizes / attention variants can be tried without
    copy-pasting the class. All values are forwarded unchanged to ``Transformer``.
    """

    def __init__(
        self,
        *,
        dim_model: int = 384,
        nlayer: int = 6,
        num_head: int = 8,
        max_seq_len: int = 1024,
        glu_attn: bool = False,
    ):
        super().__init__()
        # vocab_size is the notebook-level tokenizer vocabulary size.
        self.embedding = nn.Embedding(vocab_size, dim_model)
        self.model = Transformer(
            nlayer=nlayer,
            dim_model=dim_model,
            num_head=num_head,
            max_seq_len=max_seq_len,
            glu_attn=glu_attn,
        )
        self.output = nn.Linear(dim_model, vocab_size)

    def forward(self, x, mask):
        """Map token ids ``x`` to logits whose last dimension is ``vocab_size``.

        ``mask`` is passed straight to ``Transformer``; the training loop builds
        it as a boolean ``(seq_len, seq_len)`` tensor that is True strictly above
        the diagonal (presumably True = "do not attend" — confirm against
        ``wzh.transformer``). Leading output dimensions are whatever
        ``Transformer`` returns — presumably ``x.shape``.
        """
        x = self.embedding(x)
        x = self.model(x, mask)
        x = self.output(x)
        return x


class GLUAttention(nn.Module):
    """Language model identical to ``Baseline`` except ``Transformer(glu_attn=True)``.

    Keyword defaults reproduce the original fixed configuration (384-dim model,
    6 layers, 8 heads, 1024 max positions), so ``GLUAttention()`` builds the
    same model as before; ``glu_attn`` is always True for this variant.
    """

    def __init__(
        self,
        *,
        dim_model: int = 384,
        nlayer: int = 6,
        num_head: int = 8,
        max_seq_len: int = 1024,
    ):
        super().__init__()
        # vocab_size is the notebook-level tokenizer vocabulary size.
        self.embedding = nn.Embedding(vocab_size, dim_model)
        self.model = Transformer(
            nlayer=nlayer,
            dim_model=dim_model,
            num_head=num_head,
            max_seq_len=max_seq_len,
            glu_attn=True,
        )
        self.output = nn.Linear(dim_model, vocab_size)

    def forward(self, x, mask):
        """Map token ids ``x`` to logits whose last dimension is ``vocab_size``.

        ``mask`` is passed straight to ``Transformer``; the training loop builds
        it as a boolean ``(seq_len, seq_len)`` tensor that is True strictly above
        the diagonal (presumably True = "do not attend" — confirm against
        ``wzh.transformer``).
        """
        x = self.embedding(x)
        x = self.model(x, mask)
        x = self.output(x)
        return x


def prepare_data(example):
    """Tokenize ``example["text"]`` for causal language modelling.

    Text is truncated to at most 1024 tokens. The same token ids are returned
    under both "input_ids" and "labels"; the one-position shift that turns them
    into next-token targets is applied later, in the training loss. (collate_fn
    rebuilds labels from input_ids, so the "labels" field is not read there.)
    """
    encoded = tokenizer(example["text"], truncation=True, max_length=1024)
    ids = encoded["input_ids"]
    return dict(input_ids=ids, labels=ids)


# Tokenize every split. batched=True passes prepare_data a batch of texts that
# the tokenizer encodes in a single call (much faster than row-by-row over
# ~1.8M rows); per-row input_ids/labels are the same as in unbatched mode.
# All original columns (including "text") are dropped, leaving only the
# fields returned by prepare_data.
tokenized_dataset = dataset.map(
    prepare_data,
    batched=True,
    remove_columns=dataset["train"].column_names,
)


def collate_fn(examples):
    """Right-pad a list of tokenized rows into a batch for next-token training.

    Args:
        examples: list of dicts, each with an "input_ids" list of token ids.

    Returns:
        dict with
          - "input_ids": LongTensor (batch, max_len), padded with 0 so every
            entry is a valid embedding index;
          - "labels": the same ids, but padded with -100 so padded positions are
            skipped by ``nn.CrossEntropyLoss`` (whose default ``ignore_index`` is
            -100) instead of being trained as token 0.
        With batch_size=1 there is no padding and labels equal input_ids.
    """
    sequences = [torch.tensor(x["input_ids"], dtype=torch.long) for x in examples]
    input_ids = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
    labels = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=-100)
    return {"input_ids": input_ids, "labels": labels}


# batch_size=1: each batch is a single tokenized line, so no padding occurs.
# Pinned host memory only benefits host->GPU copies, so it is requested only
# when training on CUDA (on CPU it has no benefit).
train_loader = DataLoader(
    tokenized_dataset["train"],
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    pin_memory=device.type == "cuda",
)


def train(model, num_epochs):
    """Train ``model`` for next-token prediction on ``train_loader``.

    Uses the notebook globals ``train_loader``, ``device``, ``vocab_size`` and
    ``learning_rate``. Prints the parameter count and the model, then runs
    AdamW over ``num_epochs`` passes, showing per-step and EMA loss in tqdm.

    Args:
        model: module called as ``model(input_ids, mask)`` that returns logits
            whose last dimension is ``vocab_size``.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(num_token_list, loss_list)``: for every optimizer step, the running
        number of tokens processed so far and that step's training loss.
    """
    model.to(device)
    model.train()
    print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(model)
    optimizer = torch.optim.AdamW(model.parameters(), learning_rate)
    # NOTE(review): the scheduler is stepped once per epoch, so with
    # num_epochs=1 the learning rate stays at `learning_rate` for the whole run.
    # Step it per batch with T_max = num_epochs * len(train_loader) if
    # annealing within the run is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    # Default ignore_index=-100: label positions set to -100 are not trained on.
    criterion = nn.CrossEntropyLoss()

    num_token_list = []
    loss_list = []
    ema_loss = 8  # arbitrary starting value for the loss moving average
    total_tokens = 0

    for epoch in range(num_epochs):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            seq_len = input_ids.size(1)
            # Next-token prediction needs at least two positions. Empty or
            # single-token lines would give an empty target set, i.e. a NaN
            # loss that would permanently poison ema_loss.
            if seq_len < 2:
                continue
            # Equals seq_len at batch_size=1; includes padding for larger batches.
            total_tokens += input_ids.numel()
            # Boolean (seq_len, seq_len) mask, True strictly above the diagonal
            # (future positions); presumably True = "masked" in Transformer.
            mask = torch.triu(
                torch.ones((seq_len, seq_len), dtype=torch.bool, device=device),
                diagonal=1,
            )
            optimizer.zero_grad()
            logits = model(input_ids, mask)
            # Position t predicts token t+1. reshape (not view): these slices
            # are non-contiguous whenever batch_size > 1, where view() raises.
            loss = criterion(
                logits[:, :-1].reshape(-1, vocab_size), labels[:, 1:].reshape(-1)
            )
            loss.backward()
            optimizer.step()
            # Read the scalar once: each .item() forces a device->host sync.
            loss_value = loss.item()
            ema_loss = 0.999 * ema_loss + 0.001 * loss_value
            progress_bar.set_postfix(
                {
                    "loss": f"{loss_value:.4f}",
                    "ema loss": f"{ema_loss:.4f}",
                }
            )
            num_token_list.append(total_tokens)
            loss_list.append(loss_value)
        scheduler.step()
    return num_token_list, loss_list


import numpy as np


def split_and_average(list, num_splits=100):
    """Downsample a sequence by averaging consecutive, contiguous chunks.

    The sequence is cut at ``np.linspace``-spaced integer boundaries and each
    chunk is replaced by its mean, e.g. to turn a per-step loss curve into
    ``num_splits`` points.

    Args:
        list: sequence of numbers. The name is kept for backward compatibility
            (it shadows the builtin inside this function only).
        num_splits: requested number of chunks. Clamped to ``len(list)`` so no
            chunk is empty; an empty chunk would otherwise yield ``nan`` and a
            "Mean of empty slice" RuntimeWarning.

    Returns:
        A Python list of ``float`` chunk means (``[]`` for empty input).
    """
    values = np.asarray(list, dtype=float)
    num_chunks = min(num_splits, len(values))
    boundaries = np.linspace(0, len(values), num_chunks + 1, dtype=int)
    return [
        float(values[start:end].mean())
        for start, end in zip(boundaries[:-1], boundaries[1:])
    ]

In [2]:
# Train the glu_attn=False baseline for one epoch, then compress both per-step
# curves (cumulative tokens, loss) into 100 chunk means each.
# NOTE(review): the GLUAttention cell below rebinds token_list / loss_list, so
# these baseline curves survive only in this cell's printed output.
token_list, loss_list = (
    split_and_average(curve, 100) for curve in train(Baseline(), 1)
)
print(token_list, loss_list, sep="\n")

parameters: 49,297,489
Baseline(
  (embedding): Embedding(50257, 384)
  (model): Transformer(
    (pe): PositionalEncoding()
    (layers): ModuleList(
      (0-5): 6 x TransformerLayer(
        (attn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadSelfAttention(
          (wq): Linear(in_features=384, out_features=384, bias=True)
          (wk): Linear(in_features=384, out_features=384, bias=True)
          (wv): Linear(in_features=384, out_features=384, bias=True)
          (wo): Linear(in_features=384, out_features=384, bias=True)
        )
        (ffn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (ffn): GLUFeedForward(
          (linear1): Linear(in_features=384, out_features=2048, bias=True)
          (linear2): Linear(in_features=1024, out_features=384, bias=True)
        )
      )
    )
  )
  (output): Linear(in_features=384, out_features=50257, bias=True)
)


Epoch 1: 100%|██████████| 1801350/1801350 [11:20:45<00:00, 44.10it/s, loss=5.4937, ema loss=4.1913]  


[592633.1344206008, 1775149.7129613734, 2968149.7804291844, 4153741.575487083, 5341142.065236052, 6515029.654077253, 7696866.1204188485, 8883754.95888412, 10068568.441373391, 11251683.641459227, 12428944.038709125, 13600162.233133048, 14777525.638540773, 15961294.550768174, 17137959.51287554, 18324709.603261802, 19498156.652961373, 20687730.848940004, 21862212.824034333, 23027673.436394848, 24209810.060767315, 25388208.0616309, 26558684.936223175, 27745798.69776824, 28931265.443309586, 30096726.689527895, 31272323.802918456, 32442116.755128317, 33622678.421630904, 34790034.54892704, 35959152.56798283, 37134680.75521415, 38314104.41862661, 39493853.57793991, 40670731.13037507, 41854448.36051502, 43038176.97133047, 44230061.0430006, 45417589.04214592, 46605326.939484976, 47775253.101030044, 48935971.257831946, 50097247.880515024, 51269839.88789699, 52456797.68955454, 53626359.88248927, 54799712.562918454, 55979029.2639485, 57150185.287958115, 58331884.7344206, 59508089.19484978, 60687455

In [2]:
# Train the glu_attn=True variant for one epoch, then compress both per-step
# curves (cumulative tokens, loss) into 100 chunk means each.
# NOTE(review): this overwrites the baseline's token_list / loss_list.
token_list, loss_list = (
    split_and_average(curve, 100) for curve in train(GLUAttention(), 1)
)
print(token_list, loss_list, sep="\n")

parameters: 49,298,257
GLUAttention(
  (embedding): Embedding(50257, 384)
  (model): Transformer(
    (pe): PositionalEncoding()
    (layers): ModuleList(
      (0-5): 6 x TransformerLayer(
        (attn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadSelfAttention(
          (wq): Linear(in_features=384, out_features=384, bias=True)
          (wk): Linear(in_features=384, out_features=384, bias=True)
          (wv): Linear(in_features=384, out_features=512, bias=True)
          (wo): Linear(in_features=256, out_features=384, bias=True)
        )
        (ffn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (ffn): GLUFeedForward(
          (linear1): Linear(in_features=384, out_features=2048, bias=True)
          (linear2): Linear(in_features=1024, out_features=384, bias=True)
        )
      )
    )
  )
  (output): Linear(in_features=384, out_features=50257, bias=True)
)


Epoch 1: 100%|██████████| 1801350/1801350 [11:25:35<00:00, 43.79it/s, loss=4.3933, ema loss=4.1295]   


[596162.9970815451, 1777169.3688412018, 2951576.5273819743, 4117972.13320745, 5283359.27751073, 6458656.3527897, 7645470.674534375, 8813815.348927038, 9991150.129957082, 11182892.41665236, 12361578.900008583, 13548502.643175965, 14723169.877167381, 15903577.51231654, 17081740.72506438, 18249885.009785406, 19423068.302918456, 20595076.49343404, 21772772.33725322, 22959193.63776824, 24135818.6656081, 25323836.541545063, 26503390.407639485, 27682442.862832617, 28860831.480645437, 30050872.03527897, 31233208.808669526, 32417826.863702685, 33594352.2239485, 34772230.56377682, 35955628.710042916, 37112733.17174491, 38295110.33716738, 39475797.15793991, 40644346.893914685, 41828210.1551073, 43005617.9016309, 44178361.13003176, 45357282.39879829, 46531243.32643777, 47712569.941287555, 48899660.843017764, 50082577.42120171, 51258110.96060086, 52425548.410351045, 53605747.40532189, 54789417.777939916, 55971035.037854075, 57159731.88026779, 58346734.57321888, 59536176.34437768, 60723895.23131061,