In [None]:
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm
from wzh.transformer import Transformer

# Global run settings: fixed torch seed, optimizer step size, and the device
# every model and batch is moved to.
torch.manual_seed(0)
learning_rate = 1e-4
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Only a "train" split is defined, read from a parquet file under data/wikitext2/.
# Earlier loader call, kept for reference:
# dataset = load_dataset("data/wikitext2", "wikitext-2-v1")
dataset = load_dataset(
    "data/wikitext2/",
    data_files=dict(train="train-00000-of-00001.parquet"),
)

# The tokenizer is read from a local directory.
# Earlier loader call, kept for reference:
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("./gpt2-tokenizer")
# Used below as the size of the embedding table and of the output projection.
vocab_size = len(tokenizer)


class _TransformerLM(nn.Module):
    """Token embedding -> ``Transformer`` stack -> projection to the vocabulary.

    Shared implementation of ``Baseline`` and ``GLUAttention``, which differ
    only in the ``glu_attn`` flag handed to ``Transformer``. Submodules are
    created in the order ``embedding``, ``model``, ``output``.

    Args:
        glu_attn: forwarded unchanged to ``Transformer(glu_attn=...)``.
    """

    def __init__(self, glu_attn):
        super().__init__()
        dim_model = 384
        # vocab_size is the module-level tokenizer length, read at construction.
        self.embedding = nn.Embedding(vocab_size, dim_model)
        self.model = Transformer(
            nlayer=6,
            dim_model=dim_model,
            num_head=8,
            max_seq_len=1024,
            glu_attn=glu_attn,
        )
        self.output = nn.Linear(dim_model, vocab_size)

    def forward(self, x, mask):
        """Map token ids to one score per vocabulary entry and position.

        Args:
            x: integer token ids, as accepted by ``nn.Embedding``.
            mask: passed through to ``Transformer`` untouched; its exact
                semantics are defined there -- TODO confirm.

        Returns:
            The output of the final linear layer, whose last dimension has
            size ``vocab_size``.
        """
        hidden = self.embedding(x)
        hidden = self.model(hidden, mask)
        return self.output(hidden)


class Baseline(_TransformerLM):
    """Model variant built with ``glu_attn=False``."""

    def __init__(self):
        super().__init__(glu_attn=False)


class GLUAttention(_TransformerLM):
    """Model variant built with ``glu_attn=True``."""

    def __init__(self):
        super().__init__(glu_attn=True)


def prepare_data(example):
    """Tokenize one dataset row for language-model training.

    Args:
        example: mapping with a ``"text"`` entry.

    Returns:
        Dict with ``"input_ids"`` and ``"labels"``, both holding the same
        token ids produced by the module-level ``tokenizer`` (truncation
        enabled, ``max_length=1024``).
    """
    token_ids = tokenizer(example["text"], truncation=True, max_length=1024)["input_ids"]
    return dict(input_ids=token_ids, labels=token_ids)


# Run prepare_data over the dataset and drop the train split's original
# columns, so only the fields returned by prepare_data remain.
tokenized_dataset = dataset.map(
    function=prepare_data,
    remove_columns=dataset["train"].column_names,
)


def collate_fn(examples):
    """Pad a list of tokenized examples into a single batch.

    Args:
        examples: sequence of mappings, each with an ``"input_ids"`` list of
            token ids. Only that entry is read.

    Returns:
        Dict of two independent ``torch.long`` tensors of shape
        ``(len(examples), longest_sequence)``:

        * ``"input_ids"`` -- right-padded with 0.
        * ``"labels"`` -- same token ids, right-padded with -100, the default
          ``ignore_index`` of ``nn.CrossEntropyLoss``, so padded positions do
          not contribute to the loss.

        When every example has the same length (e.g. batch size 1) the two
        tensors hold equal values.
    """
    sequences = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in examples]
    input_ids = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    # Labels are padded separately: cloning input_ids would turn the padding
    # value 0 into a real prediction target.
    labels = nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=-100
    )
    return {"input_ids": input_ids, "labels": labels}


# Training batches: one example per step, order reshuffled by the loader,
# assembled into tensors by collate_fn, with pinned host memory requested.
train_loader = DataLoader(
    dataset=tokenized_dataset["train"],
    collate_fn=collate_fn,
    pin_memory=True,
    shuffle=True,
    batch_size=1,
)


def train(model, num_epochs):
    """Train ``model`` for next-token prediction on ``train_loader``.

    Uses AdamW with the module-level ``learning_rate`` and a cosine schedule
    that is stepped once per epoch (``T_max=num_epochs``). Prints the parameter
    count and the model, and shows a progress bar per epoch.

    Args:
        model: module called as ``model(input_ids, mask)``; its output is used
            as per-position logits whose last dimension indexes the vocabulary.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(num_token_list, loss_list)``: two lists with one entry per
        optimisation step -- the running sum of batch sequence lengths and the
        loss of that step. Batches shorter than two tokens are skipped and
        contribute to neither list.
    """
    model.to(device)
    model.train()
    print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(model)
    optimizer = torch.optim.AdamW(model.parameters(), learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()

    num_token_list = []
    loss_list = []
    ema_loss = 8  # starting value of the smoothed loss shown in the progress bar
    total_tokens = 0

    for epoch in range(num_epochs):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            seq_len = input_ids.size(1)
            # The shifted loss below needs at least one (input, target) pair.
            # With fewer than two tokens it would average over zero elements,
            # yielding NaN and poisoning ema_loss and loss_list.
            if seq_len < 2:
                continue
            total_tokens += seq_len
            # Boolean matrix that is True strictly above the diagonal.
            # NOTE(review): presumably True marks positions attention must not
            # see -- confirm against wzh.transformer.
            mask = torch.triu(
                torch.ones((seq_len, seq_len), dtype=torch.bool, device=device),
                diagonal=1,
            )
            optimizer.zero_grad()
            logits = model(input_ids, mask)
            # Position t predicts token t + 1. reshape (not view) because the
            # shifted slices are non-contiguous once the batch size exceeds 1.
            loss = criterion(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
            )
            loss.backward()
            optimizer.step()
            # Read the scalar once; each .item() call copies from the device.
            loss_value = loss.item()
            ema_loss = 0.999 * ema_loss + 0.001 * loss_value
            progress_bar.set_postfix(
                {
                    "loss": f"{loss_value:.4f}",
                    "ema loss": f"{ema_loss:.4f}",
                }
            )
            num_token_list.append(total_tokens)
            loss_list.append(loss_value)
        scheduler.step()
    return num_token_list, loss_list


import numpy as np


def split_and_average(list, num_splits=100):
    """Average ``list`` over up to ``num_splits`` consecutive chunks.

    Chunk boundaries are ``num_splits + 1`` evenly spaced integer indices from
    0 to ``len(list)``. Chunks that end up empty (which happens when ``list``
    has fewer than ``num_splits`` items, or is empty) are skipped instead of
    producing NaN, so the result can be shorter than ``num_splits``.

    Args:
        list: sliceable sequence of numbers. The name shadows the builtin and
            is kept only so existing keyword callers keep working.
        num_splits: requested number of chunks.

    Returns:
        List with the mean of each non-empty chunk, in order.
    """
    boundaries = np.linspace(0, len(list), num_splits + 1, dtype=int)
    return [
        np.mean(list[lo:hi])
        for lo, hi in zip(boundaries[:-1], boundaries[1:])
        if hi > lo  # an empty slice has no mean
    ]

In [None]:
# Train the baseline model for 10 epochs, then reduce both per-step curves
# to 100 averaged points each before printing them.
token_list, loss_list = (
    split_and_average(series, 100) for series in train(Baseline(), 10)
)
print(token_list)
print(loss_list)

In [None]:
# Train the GLU-attention model for 10 epochs, then reduce both per-step
# curves to 100 averaged points each before printing them.
token_list, loss_list = (
    split_and_average(series, 100) for series in train(GLUAttention(), 10)
)
print(token_list)
print(loss_list)