In [None]:
import torch
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau, CyclicLR

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

In [None]:
%load_ext autoreload
%autoreload 2

from llm.data import load_dataset, text_to_tensor, tensor_to_text
from llm.training import TransformerLMTrainingLoop, TrainingLogs
from llm.transformer import CharGenerativeTransformer

In [None]:
# Training target: the CUDA device when PyTorch reports one, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Explicit CPU handle; the model is moved onto it before text generation.
cpu = torch.device('cpu')

In [None]:
# Load the corpus and split it by position: slice(0, 2_000_000) selects the
# training part and slice(2_000_000, None) the validation part.
# `ctoi` / `itoc` are presumably the char->index and index->char lookup tables
# (they are later passed to text_to_tensor / tensor_to_text) — TODO confirm.
(train_dataset, valid_dataset), ctoi, itoc = load_dataset(
    Path("dataset") / "lotr.txt",
    slice(0, 2_000_000),
    slice(2_000_000, None),
)

# Report both split sizes as a quick sanity check.
print(f"Train dataset length:      {len(train_dataset)} characters")
print(f"Validation dataset length: {len(valid_dataset)} characters")

In [None]:
# Sequence length handed to both the model and the training loop.
sequence_length = 256

# Vocabulary size follows the lookup table returned with the dataset.
model = CharGenerativeTransformer(
    vocab_size=len(ctoi),
    seq_length=sequence_length,
    embedding_dim=64,
    latent_dim=256,
    n_heads=8,
    n_layers=4,
)

# Show the architecture, then the total element count over all parameters.
print(model)
n_parameters = sum(param.numel() for param in model.parameters())
print(f"Number of parameters: {n_parameters}")

In [None]:
# Adam over every model parameter, starting at a learning rate of 1e-3.
optimizer = Adam(params=model.parameters(), lr=1e-3)

# Scale the learning rate by 0.8 once the monitored metric has stopped
# improving for more than `patience` scheduler steps. The metric and the step
# cadence are decided by the training loop, which is not shown here.
scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.8, patience=25)

print(optimizer)

In [None]:
# Training budget and mini-batch size.
n_epochs = 1000
batch_size = 32

# All arguments are positional; the order must match the
# TransformerLMTrainingLoop constructor.
training_loop = TransformerLMTrainingLoop(
    device,
    n_epochs,
    sequence_length,
    batch_size,
    train_dataset,
    valid_dataset,
    model,
    optimizer,
    scheduler,
)

# `logs` stays bound after the `with` block and is read by the review cell.
with TrainingLogs() as logs:
    training_loop.run(logs)

In [None]:
def display_training_review(train_losses, valid_losses=None, learning_rates=None):
    """Plot the loss curves of a training run, optionally with the learning rate.

    Args:
        train_losses: Sequence of training-loss values, plotted against their
            index (the x-axis is labelled "epoch").
        valid_losses: Optional sequence of validation-loss values. When it is
            given and non-empty, the curve is drawn and its minimum is marked
            with a red dot annotated as "(index, value)". ``None`` or an empty
            sequence skips the validation curve.
        learning_rates: Optional sequence of learning rates, drawn as a dotted
            red line on a secondary y-axis.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots()

    ax.set_title("Cross entropy over the epochs")
    ax.plot(train_losses, label="train loss")

    # np.argmin raises ValueError on an empty sequence (e.g. logs of a run
    # that stopped before its first validation pass), so require at least
    # one value before looking for the best epoch.
    if valid_losses is not None and len(valid_losses) > 0:
        ax.plot(valid_losses, label="valid loss")
        best_valid_epoch = np.argmin(valid_losses)
        best_valid = valid_losses[best_valid_epoch]
        ax.plot(best_valid_epoch, best_valid, 'or')
        ax.text(
            best_valid_epoch, best_valid,
            f"({best_valid_epoch}, {best_valid:.6f})",
            ha='center', va='top', color='red',
        )

    ax.legend(loc='upper right')
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.grid()

    if learning_rates is not None:
        # Separate y-axis: learning rates live on a different scale than losses.
        lr_ax = ax.twinx()
        lr_ax.plot(learning_rates, ':r', label='learning rate')
        lr_ax.set_xlabel("epoch")
        lr_ax.set_ylabel("learning rate")
        # The twin axis keeps its own legend, placed away from the loss legend.
        lr_ax.legend(loc='lower left')

    fig.tight_layout()
    plt.show()

In [None]:
# Wall-clock duration of the run; `ellapsed_time` is spelled as in the original
# cell (presumably the attribute name defined by TrainingLogs).
print(f"Training time: {logs.ellapsed_time:.0f} seconds")

# Loss and learning-rate curves recorded during training.
display_training_review(
    train_losses=logs.train_loss,
    valid_losses=logs.valid_loss,
    learning_rates=logs.lr,
)

In [None]:
def print_prediction_examples(model, test_contexts, n_predictions=80, repeat=4):
    """Print model continuations for each prompt, deterministic then sampled.

    For every context one line is printed from ``model.predict_argmax``,
    followed by ``repeat`` lines per context from ``model.predict_proba``.
    Each line has the form ``'<context>' -> '<prediction>'``.

    Relies on the notebook-level ``ctoi`` / ``itoc`` tables and on
    ``text_to_tensor`` / ``tensor_to_text`` to convert between text and tokens.

    Args:
        model: Object exposing ``eval()``, ``predict_argmax(tokens, n)`` and
            ``predict_proba(tokens, n)``. It is switched to evaluation mode as
            a side effect.
        test_contexts: Iterable of prompt strings.
        n_predictions: Second argument forwarded to both predict methods
            (presumably the number of tokens to generate — TODO confirm).
        repeat: Number of ``predict_proba`` calls made per context.

    Returns:
        None. Everything is written to stdout.
    """
    model.eval()

    # Materialise once so that a one-shot iterable still feeds both passes.
    contexts = tuple(test_contexts)

    with torch.no_grad():
        print("Deterministic predictions:")
        for context in contexts:
            context_tokens = text_to_tensor(context, ctoi)
            predicted_tokens = model.predict_argmax(context_tokens, n_predictions)
            predicted_text = tensor_to_text(predicted_tokens, itoc)
            print(f"{repr(context)} -> {repr(predicted_text)}")

        print("\nProbabilistic predictions:")
        for context in contexts:
            # Encode the *current* context. Reusing the tensor left over from
            # the deterministic pass would condition every sample on the last
            # prompt while printing a different one next to it.
            context_tokens = text_to_tensor(context, ctoi)
            for _ in range(repeat):
                predicted_tokens = model.predict_proba(context_tokens, n_predictions)
                predicted_text = tensor_to_text(predicted_tokens, itoc)
                print(f"{repr(context)} -> {repr(predicted_text)}")

In [None]:
# Prompts fed to the trained model to inspect its generations.
test_contexts = (
    "The throne",
    "Aragorn son of",
    "He paused, ",
    "the ring of ",
    "suddenly",
    "you cannot ",
    "I am a servant of the Secret Fire",
)

# Generate on the CPU, with the model in evaluation mode.
model.to(cpu)
model.eval()
print_prediction_examples(model, test_contexts, n_predictions=80)