In [1]:
import torch
import torch.nn as nn
import datasets
import transformers
import torch_utils as tu
import experiment_utils as eu
from jaxtyping import Num, Array

ModuleNotFoundError: No module named 'jax'

In [None]:
# --- Experiment configuration -------------------------------------------
# Data shape: tokens per training window and windows per batch.
seq_len = 128
batch_size = 16

# Optimization: learning rate handed to the optimizer, total step budget,
# and the number of warmup steps (one tenth of the step budget).
lr = 0.001
train_steps = 5_000
warmup_steps = train_steps // 10

# Step counts forwarded to the training helper for validation and logging.
# NOTE(review): presumably these are intervals measured in steps - confirm
# against torch_utils.train.
val_steps = 1_000
log_steps = 100

In [18]:
# Load the Tiny Shakespeare corpus through Hugging Face `datasets`.
# The displayed result has train/validation/test splits, each holding a
# single row with one 'text' column.
dataset = datasets.load_dataset("karpathy/tiny_shakespeare")
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1
    })
})

In [166]:
# Load the tokenizer saved under ../tokenizers/bytelevel (path is relative
# to the notebook's working directory).
# NOTE(review): the directory name suggests a byte-level tokenizer - confirm.
tokenizer: transformers.PreTrainedTokenizerFast = (
    transformers.AutoTokenizer.from_pretrained("../tokenizers/bytelevel")
)


def tokenize_fn(examples):
    """Tokenize a batch of texts into next-token-prediction pairs.

    The text is split by the tokenizer into chunks of at most
    ``seq_len + 1`` tokens (``stride=1`` with overflowing tokens returned;
    in Hugging Face tokenizers the stride is the number of tokens shared
    between consecutive chunks). Chunks shorter than ``seq_len + 1`` are
    dropped. Every kept chunk yields one input (all tokens but the last)
    and one target (all tokens but the first), each ``seq_len`` long.

    Args:
        examples: Batch mapping with a ``"text"`` entry, as supplied by
            ``Dataset.map(..., batched=True)``.

    Returns:
        Dict with parallel lists ``"input_ids"`` and ``"target_ids"``.
    """
    # One extra token per window so inputs and targets can be offset by one.
    window = seq_len + 1
    encoded = tokenizer(
        examples["text"],
        max_length=window,
        stride=1,
        truncation=True,
        return_overflowing_tokens=True,
        return_length=True,
    )

    # Keep only full-length windows; a trailing partial chunk is discarded.
    full_windows = [
        ids
        for length, ids in zip(encoded["length"], encoded["input_ids"])
        if length == window
    ]
    return {
        "input_ids": [ids[:-1] for ids in full_windows],
        "target_ids": [ids[1:] for ids in full_windows],
    }


# Apply tokenize_fn to every split in batched mode. The original columns
# are removed because tokenize_fn returns a different number of rows
# (one per token window) than it receives.
tokenized_dataset = dataset.map(
    tokenize_fn,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [167]:
# One DataLoader per split, all sharing the same collator and batch size.
# NOTE(review): DefaultDataCollator is expected to turn the per-example
# "input_ids"/"target_ids" lists into batched tensors - confirm.
data_collator = transformers.DefaultDataCollator()
# Only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    tokenized_dataset["train"],
    batch_size=batch_size,
    collate_fn=data_collator,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    tokenized_dataset["validation"],
    batch_size=batch_size,
    collate_fn=data_collator,
)
test_loader = torch.utils.data.DataLoader(
    tokenized_dataset["test"],
    batch_size=batch_size,
    collate_fn=data_collator,
)

In [None]:
class Llama(nn.Module):
    """Placeholder for the language model.

    Only the ``nn.Module`` base initializer runs; no layers and no
    ``forward`` are defined in this cell yet.
    """

    def __init__(self):
        super().__init__()

In [None]:
class ModelWrapper(nn.Module):
    """Wrap the language model and compute loss and perplexity for a batch.

    ``forward`` takes a batch dict with ``"input_ids"`` and ``"target_ids"``
    and returns a dict holding the flattened logits, the cross-entropy loss
    and its exponential (the perplexity).
    """

    # Metric name -> comparison function from experiment_utils.
    # NOTE(review): presumably marks perplexity as "lower is better" for the
    # training/logging utilities - confirm against experiment_utils.
    metrics = {"perplexity": eu.compare_fns.min}

    def __init__(self):
        super().__init__()
        self.model = Llama()

    def forward(self, batch):
        """Run the model on a batch and score it against the targets.

        Args:
            batch: Mapping with ``"input_ids"`` and ``"target_ids"`` tensors.
                Assumes both are (batch, seq) token ids and that the model
                returns logits whose last dimension indexes the vocabulary
                - TODO confirm once the model is implemented.

        Returns:
            Dict with ``"output_logits"`` (logits flattened to two
            dimensions, classes last), ``"loss"`` (cross-entropy with the
            default reduction) and ``"perplexity"`` (``exp(loss)``).
        """
        input_ids = batch["input_ids"]
        # reshape (rather than view) also accepts non-contiguous tensors.
        target_ids = batch["target_ids"].reshape(-1)

        logits = self.model(input_ids)
        # Flatten every leading dimension but keep the class dimension taken
        # from the logits themselves. Using the sequence length of
        # input_ids here would mis-shape the logits whenever the vocabulary
        # size differs from the sequence length.
        output_logits = logits.reshape(-1, logits.size(-1))

        loss = nn.functional.cross_entropy(output_logits, target_ids)
        ppl = torch.exp(loss)

        output = {
            "output_logits": output_logits,
            "loss": loss,
            "perplexity": ppl,
        }
        return output

In [None]:
# Instantiate the wrapper; it builds its own Llama instance internally.
model = ModelWrapper()

In [None]:
# AdamW over all wrapper parameters, using the configured learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Scheduler built by the project helper from the total and warmup step counts.
# NOTE(review): the exact schedule (e.g. warmup then decay) is defined in
# torch_utils.get_lr_scheduler - confirm there.
scheduler = tu.get_lr_scheduler(optimizer, train_steps, warmup_steps)

In [None]:
# Create the experiment logger rooted at ../.logs/llm and open a new
# experiment; the matching end_experiment() call is in the last cell.
logger = eu.Logger("../.logs/llm")
logger.start_experiment()

In [None]:
# Hand everything to the project training helper.
# NOTE(review): how train_steps, log_steps and val_steps are interpreted
# (total budget vs. intervals) is defined in torch_utils.train - confirm.
tu.train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    train_steps=train_steps,
    logger=logger,
    log_steps=log_steps,
    val_steps=val_steps,
)

In [None]:
# Close the experiment opened by logger.start_experiment().
logger.end_experiment()