In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset_precompute import PrecomputedDataset
from model import GPTConfig, GPT

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Model arguments
# These values are forwarded to GPTConfig in the model cell; see model.py for their exact meaning.

n_layer = 6
n_head = 6
n_embd = 384
block_size = 1024
bias = False
dropout = 0.0

# Use the GPU when one is available; otherwise fall back to CPU instead of crashing
# on the first `.to('cuda')` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): not read by the train/eval loops in this notebook, which cast inputs
# to torch.float32 directly. Kept as a valid dtype name for reference.
dtype = 'float32'

In [None]:
# DataLoaders
# Training batches are reshuffled every epoch; validation order is fixed.
# drop_last=True on both loaders discards any final partial batch, so every batch has 64 samples.
train_dataloader = DataLoader(
    PrecomputedDataset('data/nesymres/train_nc.pt'),
    batch_size=64,
    drop_last=True,
    shuffle=True,
)
val_dataloader = DataLoader(
    PrecomputedDataset('data/nesymres/val_nc.pt'),
    batch_size=64,
    drop_last=True,
    shuffle=False,
)

In [None]:
# Model
# vocab_size=14 presumably matches the width of the multi-hot targets scored in test_loop — TODO confirm against the dataset.
model_args = {
    'n_layer': n_layer,
    'n_head': n_head,
    'n_embd': n_embd,
    'block_size': block_size,
    'bias': bias,
    'vocab_size': 14,
    'dropout': dropout,
}
gptconf = GPTConfig(**model_args)
# Build, move to the target device, then wrap with torch.compile.
model = torch.compile(GPT(gptconf).to(device))

In [None]:
# Hyperparameters
learning_rate = 1e-3  # step size for the SGD optimizer
# Informational only: the DataLoaders are constructed with a literal batch_size=64
# and do not read this variable. Keep the two values in sync.
batch_size = 64
epochs = 100  # number of passes over train_dataloader in the training cell

In [None]:
# Optimizer. There is no separate loss function: model(X, y) returns (logits, loss).
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

# Metric histories, one entry appended per test_loop call
train_accuracy_list, test_accuracy_list = [], []
train_loss_list, test_loss_list = [], []

In [None]:
def train_loop(dataloader, model, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Each ``(X, y)`` batch is moved to the module-level ``device`` as float32
    and fed to ``model(X, y)``, which returns ``(logits, loss)``. The loss is
    back-propagated and ``optimizer`` takes one step. After every batch the
    loss and a progress counter (``batch number * current batch length`` out
    of ``len(dataloader.dataset)``) are printed.
    """
    dataset_len = len(dataloader.dataset)
    # Training mode: dropout / batch-norm layers behave differently than in eval mode
    model.train()
    for step, (inputs, targets) in enumerate(dataloader, start=1):
        inputs = inputs.to(device, dtype=torch.float32, non_blocking=True)
        targets = targets.to(device, dtype=torch.float32, non_blocking=True)

        _, batch_loss = model(inputs, targets)

        # Clear stale gradients, back-propagate, then update the weights
        optimizer.zero_grad(set_to_none=True)
        batch_loss.backward()
        optimizer.step()

        # step * size of the current batch; exact only when all batches have equal size
        processed = step * len(inputs)
        print(f"loss: {batch_loss.item():>7f}  [{processed:>5d}/{dataset_len:>5d}]")


def test_loop(dataloader, model):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device, dtype=torch.float32, non_blocking=True)
            y = y.to(device, dtype=torch.float32, non_blocking=True)

            logits, loss = model(X, y)
            pred = (torch.sigmoid(logits) > 0.5).type(torch.float16)


            test_loss += loss.item()
            # multihot prediction in pred, shape is (batch_size, 14)
            # multihot ground truth in y, shape is (batch_size, 14)
            # correct only if all is correct
            correct += (pred == y).all(dim=1).type(torch.float16).sum().item()

    test_loss /= num_batches
    correct /= size
    if dataloader == train_dataloader:
        print(f"Train Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        train_accuracy_list.append(100*correct)
        train_loss_list.append(test_loss)
    else:
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        test_accuracy_list.append(100*correct)
        test_loss_list.append(test_loss)

In [None]:
if __name__ == '__main__':
    # Start each run with empty histories. Otherwise re-running this cell appends onto
    # the previous run's metrics and the x/y lengths in the plots below no longer match.
    for history in (train_accuracy_list, test_accuracy_list, train_loss_list, test_loss_list):
        history.clear()

    # "Epoch 0" evaluates the untrained model, so every history ends with epochs + 1 points.
    print("Epoch 0\n-------------------------------")
    test_loop(val_dataloader, model)
    test_loop(train_dataloader, model)

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(train_dataloader, model, optimizer)
        test_loop(val_dataloader, model)
        test_loop(train_dataloader, model)

    epochs_range = range(epochs + 1)
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 5))

    # Accuracy: training vs. validation (test_loop stores val_dataloader metrics in test_*_list)
    ax_acc.plot(epochs_range, train_accuracy_list, label='Train Accuracy', marker='o')
    ax_acc.plot(epochs_range, test_accuracy_list, label='Validation Accuracy', marker='o')
    ax_acc.set_xlabel('Epochs')
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.set_title('Training vs. Validation Accuracy')
    ax_acc.legend()

    # Loss: training vs. validation
    ax_loss.plot(epochs_range, train_loss_list, label='Train Loss', marker='o')
    ax_loss.plot(epochs_range, test_loss_list, label='Validation Loss', marker='o')
    ax_loss.set_xlabel('Epochs')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training vs. Validation Loss')
    ax_loss.legend()

    fig.tight_layout()
    plt.show()
    print("Done!")