In [1]:
import torch
import torch.nn as nn
import datasets
import transformers
import torch_utils as tu
import experiment_utils as eu
from typing import Callable

In [2]:
# --- Reproducibility ---
# Weight initialisation and DataLoader shuffling both draw from torch's global RNG,
# so seed it once here to make Restart-and-Run-All repeatable.
seed = 42
torch.manual_seed(seed)

# --- Data / optimisation schedule ---
seq_len = 128  # tokens per training example (windows are tokenized at seq_len + 1)
batch_size = 16  # examples per DataLoader batch
lr = 1e-4  # base learning rate handed to AdamW
train_steps = 5000  # total optimisation steps
val_steps = 1000  # forwarded to tu.train as `val_steps`
log_steps = 10  # forwarded to tu.train as `log_steps`
warmup_steps = train_steps // 10  # first 10% of training is passed to the LR scheduler as warm-up

# --- Model size ---
vocab_size = 256  # presumably one id per byte for the byte-level tokenizer — TODO confirm
num_layers = 4
hidden_dim = 256
num_heads = 8
head_dim = 32
mlp_dim = 1024

In [3]:
# Fetch the Tiny Shakespeare corpus; each split (train / validation / test)
# holds a single row whose "text" field contains that split's full text.
dataset_name = "karpathy/tiny_shakespeare"
dataset = datasets.load_dataset(dataset_name)
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1
    })
})

In [4]:
# Load the tokenizer saved under ../tokenizers/bytelevel (path is relative to
# the notebook's working directory).
tokenizer_dir = "../tokenizers/bytelevel"
tokenizer: transformers.PreTrainedTokenizerFast
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_dir)


def tokenize_fn(examples):
    """Turn a batch of raw texts into fixed-length next-token training pairs.

    Each text in ``examples["text"]`` is tokenized into windows of at most
    ``seq_len + 1`` tokens (``stride=1`` asks the tokenizer for one token of
    overlap between consecutive windows). Windows shorter than ``seq_len + 1``
    are discarded. Every kept window yields an input (all tokens but the last)
    and a target (all tokens but the first), i.e. the input shifted by one.

    Args:
        examples: Batch mapping with a ``"text"`` entry, as supplied by a
            batched ``map`` call.

    Returns:
        Dict with ``"input_ids"`` and ``"target_ids"``: parallel lists of
        token-id lists, each of length ``seq_len``.
    """
    window = seq_len + 1
    encoded = tokenizer(
        examples["text"],
        max_length=window,
        stride=1,
        truncation=True,
        return_overflowing_tokens=True,
        return_length=True,
    )

    # Keep only complete windows; a short trailing window cannot supply
    # seq_len inputs plus the shifted target.
    full_windows = [
        chunk
        for n_tokens, chunk in zip(encoded["length"], encoded["input_ids"])
        if n_tokens == window
    ]
    return {
        "input_ids": [chunk[:-1] for chunk in full_windows],
        "target_ids": [chunk[1:] for chunk in full_windows],
    }


# tokenize_fn returns a different number of rows than it receives, so the
# original columns are dropped and only the new token columns remain.
raw_columns = dataset["train"].column_names
tokenized_dataset = dataset.map(
    tokenize_fn,
    batched=True,
    remove_columns=raw_columns,
)

In [5]:
# All three splits share the same batch size and collator; only the training
# split is shuffled.
data_collator = transformers.DefaultDataCollator()
loader_kwargs = {"batch_size": batch_size, "collate_fn": data_collator}

train_loader = torch.utils.data.DataLoader(
    tokenized_dataset["train"], shuffle=True, **loader_kwargs
)
val_loader = torch.utils.data.DataLoader(
    tokenized_dataset["validation"], **loader_kwargs
)
test_loader = torch.utils.data.DataLoader(
    tokenized_dataset["test"], **loader_kwargs
)

In [6]:
class ModelWrapper(nn.Module):
    """Wrap ``tu.Transformer`` so one forward pass returns logits, loss and perplexity.

    The wrapper takes a batch mapping (rather than bare tensors), computes the
    next-token cross-entropy loss against ``batch["target_ids"]`` and reports
    its exponential as perplexity.
    """

    # Metric name -> comparison function from experiment_utils. Presumably tells
    # the training/logging utilities that a lower perplexity is better — TODO confirm.
    metrics = {"perplexity": eu.compare_fns.min}

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        hidden_dim: int,
        num_heads: int,
        head_dim: int,
        mlp_dim: int,
        norm: Callable[..., nn.Module] = nn.RMSNorm,
        activation: Callable = nn.functional.silu,
    ):
        """Build the underlying transformer.

        Args:
            vocab_size: Number of token ids; also the size of the logit dimension.
            num_layers: Forwarded to ``tu.Transformer``.
            hidden_dim: Forwarded to ``tu.Transformer``.
            num_heads: Forwarded to ``tu.Transformer``.
            head_dim: Forwarded to ``tu.Transformer``.
            mlp_dim: Forwarded to ``tu.Transformer``.
            norm: Normalisation layer *class/factory* (not an instance), forwarded
                to ``tu.Transformer``.
            activation: Activation function, forwarded to ``tu.Transformer``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.model = tu.Transformer(
            vocab_size=vocab_size,
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            head_dim=head_dim,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            norm=norm,
            activation=activation,
        )

    def forward(self, batch):
        """Run the model on a batch and compute loss and perplexity.

        Args:
            batch: Mapping with ``"input_ids"`` and ``"target_ids"`` tensors
                (assumed to have matching shapes — TODO confirm against the collator).

        Returns:
            Dict with ``"output_logits"`` (flattened to ``(-1, vocab_size)``),
            ``"loss"`` (cross-entropy over all flattened positions) and
            ``"perplexity"`` (``exp(loss)``).
        """
        input_ids = batch["input_ids"]
        # reshape rather than view: flattens non-contiguous tensors too instead
        # of raising, and is still a zero-copy view when the layout allows it.
        target_ids = batch["target_ids"].reshape(-1)
        output_logits = self.model(input_ids).reshape(-1, self.vocab_size)

        loss = nn.functional.cross_entropy(output_logits, target_ids)
        ppl = torch.exp(loss)

        output = {
            "output_logits": output_logits,
            "loss": loss,
            "perplexity": ppl,
        }
        return output

In [7]:
# Instantiate the model from the size constants defined in the config cell;
# norm and activation are left at ModelWrapper's defaults.
model_kwargs = {
    "vocab_size": vocab_size,
    "num_layers": num_layers,
    "hidden_dim": hidden_dim,
    "num_heads": num_heads,
    "head_dim": head_dim,
    "mlp_dim": mlp_dim,
}
model = ModelWrapper(**model_kwargs)

In [8]:
# AdamW over every model parameter at the configured base learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Project helper that builds the LR schedule from the total and warm-up step counts.
# NOTE(review): the learning rates printed during training rise by 2e-7 per step,
# which is consistent with a linear warm-up to `lr` over `warmup_steps` — confirm
# the post-warm-up shape in torch_utils.
scheduler = tu.get_lr_scheduler(optimizer, train_steps, warmup_steps)

In [9]:
# Create the experiment logger rooted at ../.logs/llm (relative to the notebook's
# working directory) and open a new experiment; it is closed by
# logger.end_experiment() at the end of the notebook.
logger = eu.Logger("../.logs/llm")
logger.start_experiment()

In [10]:
# Run the training loop through the project helper, logging via `logger`.
# NOTE(review): `val_loader` is commented out while `val_steps` is still passed,
# so presumably no validation runs during training — confirm this is intentional,
# then either restore the argument or delete the dead line.
tu.train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    # val_loader=val_loader,
    train_steps=train_steps,
    logger=logger,
    log_steps=log_steps,
    val_steps=val_steps,
)

Training:   0%|          | 0/5000 [00:00<?, ?it/s]

0.0
2.0000000000000002e-07
4.0000000000000003e-07
6.000000000000001e-07
8.000000000000001e-07
1.0000000000000002e-06
1.2000000000000002e-06
1.4000000000000001e-06
1.6000000000000001e-06
1.8e-06
2.0000000000000003e-06
2.2e-06
2.4000000000000003e-06
2.6e-06
2.8000000000000003e-06
3e-06
3.2000000000000003e-06
3.4000000000000005e-06
3.6e-06
3.8e-06
4.000000000000001e-06
4.2000000000000004e-06
4.4e-06
4.6e-06
4.800000000000001e-06
5e-06
5.2e-06
5.4e-06
5.600000000000001e-06
5.8e-06
6e-06
6.2e-06
6.4000000000000006e-06
6.6e-06
6.800000000000001e-06
7.000000000000001e-06
7.2e-06
7.4e-06
7.6e-06
7.8e-06
8.000000000000001e-06
8.200000000000001e-06
8.400000000000001e-06
8.599999999999999e-06
8.8e-06
9e-06
9.2e-06
9.4e-06
9.600000000000001e-06
9.800000000000001e-06
1e-05
1.02e-05
1.04e-05
1.06e-05
1.08e-05
1.1000000000000001e-05
1.1200000000000001e-05
1.1400000000000001e-05
1.16e-05
1.18e-05
1.2e-05
1.22e-05
1.24e-05
1.2600000000000001e-05
1.2800000000000001e-05
1.3000000000000001e-05
1.32e-05
1.

In [11]:
# Close the experiment opened by logger.start_experiment().
# NOTE(review): this lives in its own cell, so it is skipped if the training cell
# raises and execution stops — run it manually in that case.
logger.end_experiment()