In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split

from classifier.data.dataset import ClassifierDataset
from classifier.data.transforms import ToEmbedding
from classifier.model import Classifier


In [None]:
# Hyper parameters

DATASET_PATH = './external/data/dataset.json'
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook has imported so Restart & Run All is reproducible.
seed = 22
random.seed(seed)
np.random.seed(seed)  # numpy is imported too; leaving it unseeded breaks reproducibility
torch.manual_seed(seed)  # CPU generator (PyTorch also seeds CUDA devices here)
torch.cuda.manual_seed_all(seed)  # explicit for all GPUs; safe to call without CUDA
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may select non-deterministic kernels, defeating the flag above.
torch.backends.cudnn.benchmark = False

MAX_ITER = 500  # number of epochs (full passes over the training loader)
BATCH_SIZE = 1024
LEARNING_RATE = 0.001  # Adam learning rate
TRAIN_SIZE = 0.8  # fraction of samples in the training split; the rest is the test split


In [None]:
# Load data

# NOTE(review): if ToEmbedding runs the embedding model inside __getitem__, every
# epoch re-embeds every sample; consider precomputing embeddings — confirm in
# classifier.data.
dataset = ClassifierDataset(
    dataset_path=DATASET_PATH,
    transform=ToEmbedding(embedding_model_name=EMBEDDING_MODEL_NAME)
)

# Split with a dedicated, seeded generator so the train/test partition does not
# depend on how many draws earlier code (e.g. constructing the embedding model)
# may have taken from the global torch RNG.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = random_split(
    dataset, [TRAIN_SIZE, 1 - TRAIN_SIZE], generator=split_generator
)
assert len(train_dataset) + len(test_dataset) == len(dataset)
assert len(train_dataset) > 0 and len(test_dataset) > 0, \
    'empty split: adjust TRAIN_SIZE or add more data'

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Define model, loss function and optimizer

# Each dataset item is a (context, grimoire, label) triple. Index it once:
# item access presumably runs the embedding transform, which is costly.
sample_context, sample_grimoire, _ = dataset[0]
model = Classifier(
    context_size=sample_context.size(0),
    grimoire_size=sample_grimoire.size(0),
    output_size=1
).to(DEVICE)

# BCEWithLogitsLoss applies the sigmoid internally, so the model's raw outputs
# are treated as logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [None]:
# Training and evaluation loop

def run_epoch(dataloader: DataLoader, train: bool):
    """Make one pass over ``dataloader`` and return ``(mean_loss, accuracy)``.

    Args:
        dataloader: yields ``(context, grimoire, label)`` batches.
        train: if True, put the model in training mode and step the optimizer
            after every batch. Otherwise use eval mode with gradients disabled.

    Returns:
        The per-sample mean loss and the fraction of correct predictions
        (in [0, 1]).

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train(train)
    total, loss_sum, correct = 0, 0.0, 0
    with torch.set_grad_enabled(train):
        for context, grimoire, label in dataloader:
            context, grimoire, label = context.to(DEVICE), grimoire.to(DEVICE), label.to(DEVICE)
            outputs = model(context, grimoire)
            loss = criterion(outputs, label)

            batch_size = label.size(0)
            total += batch_size
            # ``loss`` is a batch mean; weight it by batch size so a short final
            # batch does not skew the epoch average.
            loss_sum += loss.item() * batch_size
            # Outputs are logits (BCEWithLogitsLoss), so the 0.5-probability
            # decision boundary is at logit 0.
            correct += ((outputs.detach() > 0) == label).sum().item()

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    if total == 0:
        raise ValueError('dataloader yielded no samples')
    return loss_sum / total, correct / total


train_loss_history, train_acc_history = [], []
test_loss_history, test_acc_history = [], []

for epoch in range(MAX_ITER):
    train_loss, train_acc = run_epoch(train_dataloader, train=True)
    test_loss, test_acc = run_epoch(test_dataloader, train=False)

    # Loss is printed as-is; only accuracy is shown as a percentage.
    print(
        f'Epoch {epoch+1:>4}, '
        f'TrainLoss: {train_loss:.4f}, TrainAcc: {100 * train_acc:.2f}%, '
        f'TestLoss: {test_loss:.4f}, TestAcc: {100 * test_acc:.2f}%'
    )
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    test_loss_history.append(test_loss)
    test_acc_history.append(test_acc)

print('Training complete.')


In [None]:
# Plot metric history

def moving_average(data: list, window_size: int) -> np.ndarray:
    """Return the trailing moving average of ``data`` over ``window_size`` points.

    Uses ``'valid'`` convolution, so the result has
    ``len(data) - window_size + 1`` entries. ``window_size`` is clamped to
    ``[1, len(data)]``, so a zero window or a history shorter than the window
    does not divide by zero or return a mis-sized array. Empty input returns an
    empty array.
    """
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        return values
    window_size = max(1, min(window_size, values.size))
    return np.convolve(values, np.ones(window_size), 'valid') / window_size


# Smooth over ~1% of the epochs. max(1, ...) keeps this valid when MAX_ITER < 100.
window_size = max(1, MAX_ITER // 100)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 5))
panels = [
    (loss_ax, train_loss_history, test_loss_history, 'Loss history', 'Loss'),
    (acc_ax, train_acc_history, test_acc_history, 'Accuracy history', 'Accuracy'),
]
for ax, train_values, test_values, title, ylabel in panels:
    for values, split_name in ((train_values, 'train'), (test_values, 'test')):
        smoothed = moving_average(values, window_size)
        # 'valid' smoothing drops the first (window - 1) epochs; offset the x-axis
        # so each point sits at the 1-based epoch that ends its window.
        first_epoch = len(values) - len(smoothed) + 1
        ax.plot(np.arange(first_epoch, first_epoch + len(smoothed)), smoothed, label=split_name)
    ax.set(title=title, xlabel='Epoch', ylabel=ylabel)
    ax.legend()
plt.show()

In [None]:
# Save model

# ``os`` is imported in the top-level imports cell.
MODEL_PATH = './.cache/model.pth'
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)