In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import torch
from torch import nn

from pathlib import Path
import os

In [None]:
from src.bpe import train_or_load_bpe
from src.dataloader import get_dataloader
from src.models.transformer import Transformer
from src.train import train

from torch.optim.lr_scheduler import LambdaLR

### Paths and tokenizer settings ###
# The dataset root can be overridden through the DATASET_DIR environment variable.
DATASET_DIR = os.environ.get("DATASET_DIR", "./datasets")
DATA_PATH = Path(DATASET_DIR) / "tiny_shakespeare" / "data.txt"

# Location handed to train_or_load_bpe; created up front so it always exists.
BPE_PATH = Path("./bpe_cache")
BPE_PATH.mkdir(parents=True, exist_ok=True)

VOCAB_SIZE = 10000
PREPROCESS = False

### Get byte-pair encoder ###
# Read the whole corpus into memory as one string.
text = DATA_PATH.read_text()

bpe = train_or_load_bpe(
    BPE_PATH,
    text,
    vocab_size=VOCAB_SIZE,
    preprocess=PREPROCESS,
)

### Transformer model parameters ###
# Architecture settings, forwarded to Transformer(...) below.
h = 8
d_model = 512
d_ff = 2048
num_attention_layers = 6
dropout_p = 0.1

# Training-loop settings, read by do_train() / lr_fn().
num_epochs = 3000
starting_epoch = 0

# Data-loading settings, forwarded to get_dataloader(...) below.
split_length = 100
batch_size = 32
shuffle = True

# Use the MPS backend when PyTorch reports it as available, otherwise the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

MODEL_NAME = f"transformer_vocab={VOCAB_SIZE}"
DO_PROFILE = False

# Checkpoint path handed to train(); make sure its parent folder exists.
checkpoint_file = Path("./checkpoints") / MODEL_NAME
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)

### Load relevant transformer objects and train model ###
dataloader = get_dataloader(
    DATA_PATH,
    bpe,
    preprocess=PREPROCESS,
    split_length=split_length,
    batch_size=batch_size,
    shuffle=shuffle,
)

model = Transformer(
    vocab_size=VOCAB_SIZE + 1,  # include sentinel
    d_model=d_model,
    d_ff=d_ff,
    h=h,
    num_attention_layers=num_attention_layers,
    dropout_p=dropout_p,
)

loss_fn = nn.CrossEntropyLoss()

WARMUP_STEPS = 10
def lr_fn(epoch: int) -> float:
    """Return the learning-rate factor for scheduler epoch ``epoch``.

    Evaluates ``d_model**-0.5 * min(step**-0.5, step * WARMUP_STEPS**-1.5)``:
    a linear ramp over the first ``WARMUP_STEPS`` steps followed by
    inverse-square-root decay.

    ``LambdaLR`` calls this with a 0-based epoch counter and multiplies the
    optimizer's base learning rate by the returned value, so the result is a
    scale factor rather than an absolute learning rate.

    Args:
        epoch: 0-based epoch index supplied by the scheduler.

    Returns:
        The multiplicative factor for that epoch.
    """
    # Convert to a 1-based step exactly once, offset by starting_epoch
    # (presumably so a resumed run continues the schedule where it left off).
    # Adding a second +1 inside the formula would skip step 1 and move the
    # warmup peak one epoch early.
    step = epoch + 1 + starting_epoch
    return d_model ** (-0.5) * min(step ** (-0.5), step * WARMUP_STEPS ** (-1.5))

# AdamW over every model parameter. 0.001 is the base rate; LambdaLR rescales
# it each epoch by the factor that lr_fn returns.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.001)
scheduler = LambdaLR(optimizer, lr_lambda=lr_fn)


def do_train():
    """Run ``train`` with the notebook's current global configuration.

    Every argument is a notebook-level global looked up when this function is
    called (not when it is defined), so rebinding e.g. ``num_epochs`` before
    the call changes the run.

    Returns:
        Whatever ``train`` returns.
    """
    positional_args = (
        dataloader,
        model,
        loss_fn,
        optimizer,
        scheduler,
        num_epochs,
        checkpoint_file,
        VOCAB_SIZE + 1,  # include sentinel
    )
    return train(
        *positional_args,
        device=device,
        starting_epoch=starting_epoch,
    )

# Run training, optionally under the PyTorch profiler. Either way the value
# returned by do_train() ends up in `losses` for later cells.
if DO_PROFILE:
    from contextlib import nullcontext
    from torch.profiler import ProfilerActivity

    os.makedirs("./profiles", exist_ok=True)

    # Only start the MPS profiler when training actually runs on MPS;
    # otherwise use a no-op context manager in its place.
    mps_profiler = torch.mps.profiler.profile() if device == "mps" else nullcontext()

    # No `schedule=` is passed on purpose. A schedule only advances when
    # `step()` is called on the profiler object, and nothing here holds a
    # reference to it, so a wait/warmup/active schedule would never leave its
    # wait phase and no trace would be written. Without a schedule the whole
    # `with` body is recorded and on_trace_ready fires when the block exits.
    with torch.profiler.profile(
        activities=[
            ProfilerActivity.CPU,
            ProfilerActivity.CUDA,
            ProfilerActivity.MTIA,
        ],
        record_shapes=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiles"),
    ), mps_profiler:
        # Profile a single epoch only. do_train() reads num_epochs when it is
        # called; note this rebinding persists after the cell finishes.
        num_epochs = 1
        losses = do_train()
else:
    losses = do_train()


In [None]:
import matplotlib.pyplot as plt

# The figure is written to ./plots/<MODEL_NAME>; create the folder if needed.
loss_file = Path("./plots") / MODEL_NAME
loss_file.parent.mkdir(parents=True, exist_ok=True)

# Draw the values returned by training as a single line plot.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_title("Loss over time")
fig.savefig(loss_file)
plt.show()


In [None]:
from src.inference import generate_text

# Generation inputs: the prompt is a single newline character, and num_tokens
# is forwarded to generate_text as-is.
prompt = "\n"
num_tokens = 100

# NOTE(review): the model in this notebook is constructed with VOCAB_SIZE+1
# ("include sentinel"), but plain VOCAB_SIZE is passed here — confirm which
# value generate_text expects.
# The return value is stored in `result`; this cell does not display it itself.
result = generate_text(
    prompt=prompt,
    num_tokens=num_tokens,
    model=model,
    vocab_size=VOCAB_SIZE,
    bpe=bpe,
    device=device,
)

