Configuration

In [1]:
import torch

# Run on the first GPU when one is available, otherwise fall back to CPU so the
# notebook still executes end-to-end on machines without CUDA.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
EPOCHS = 2000
BATCH_SIZE = 8
SEED = 42

# Seed torch's global RNG. random_split (without an explicit generator), DataLoader
# shuffling and default weight initialisation all draw from it, so this makes the
# train/validation split and the training run reproducible across Restart & Run All.
torch.manual_seed(SEED)

Load dataset

In [2]:
import numpy as np
from torch.utils.data import DataLoader, random_split

from dataset import MyDataset

# Read the raw array from disk and wrap it in the project Dataset. device=DEVICE is
# forwarded to MyDataset (presumably so samples live on that device — see dataset.py).
raw_data = np.load('dataset.npy')
dataset = MyDataset(raw_data, device=DEVICE)

# 80/20 train/validation split; random_split resolves fractional lengths itself.
split_fractions = [0.8, 0.2]
train_set, validation_set = random_split(dataset, split_fractions)

# Shuffled mini-batches for training; the whole validation split is served as one batch.
train_dataloader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
validation_dataloader = DataLoader(
    validation_set,
    batch_size=len(validation_set),
)

Create model

In [3]:
from model import Model

# Instantiate the project model (mode='train' is a project-specific flag — see model.py)
# and move it to the configured device; .to() returns the model, which we keep.
model = Model(mode='train')
model = model.to(DEVICE)

Train

In [4]:
from pathlib import Path

from torchmetrics.classification import BinaryF1Score, BinaryAccuracy, BinaryPrecision, BinaryRecall
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

# Train for EPOCHS epochs, logging train/validation loss and metrics to TensorBoard
# each epoch and saving model checkpoints to CHECKPOINT_DIR.

# Save a checkpoint every CHECKPOINT_EVERY epochs, plus one after the final epoch.
CHECKPOINT_EVERY = 100
CHECKPOINT_DIR = Path('./models')
# torch.save does not create missing directories; without this the first checkpoint raises.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Training-set metrics: accumulated over the mini-batches of one epoch, logged, then reset.
metric_accuracy = BinaryAccuracy().to(DEVICE)
metric_precision = BinaryPrecision().to(DEVICE)
metric_recall = BinaryRecall().to(DEVICE)
metric_f1 = BinaryF1Score().to(DEVICE)
train_metrics = {
    'Metric/Accuracy': metric_accuracy,
    'Metric/Precision': metric_precision,
    'Metric/Recall': metric_recall,
    'Metric/F1': metric_f1,
}
writer = SummaryWriter()

try:
    for epoch in tqdm(range(EPOCHS), desc='Training epochs'):
        # Re-enable training mode every epoch: model.validate() below presumably switches
        # the model to eval mode — confirm in model.py.
        model.train()
        ep_loss = []
        for mb_x, mb_y in train_dataloader:
            model.optim.zero_grad()
            out = model(mb_x)
            loss = model.loss_fn(out, mb_y)
            loss.backward()
            model.optim.step()

            ep_loss.append(loss.item())
            # update() only accumulates metric state; calling the metric directly would also
            # compute a per-batch value that is thrown away. detach() keeps autograd out of it.
            preds = out.detach()
            for metric in train_metrics.values():
                metric.update(preds, mb_y)

        is_last_epoch = epoch + 1 == EPOCHS
        if (epoch + 1) % CHECKPOINT_EVERY == 0 or is_last_epoch:
            torch.save(model.state_dict(), CHECKPOINT_DIR / f"{epoch + 1}_model.pt")

        # validate() returns (loss, accuracy, precision, recall, f1) on the validation set.
        v_loss, v_acc, v_pre, v_rec, v_f1 = model.validate(validation_dataloader)
        validation_metrics = {
            'Metric/Accuracy': v_acc,
            'Metric/Precision': v_pre,
            'Metric/Recall': v_rec,
            'Metric/F1': v_f1,
        }
        writer.add_scalars('Loss', {'train': np.mean(ep_loss), 'validation': v_loss}, epoch)
        for tag, metric in train_metrics.items():
            writer.add_scalars(tag, {'train': metric.compute(), 'validation': validation_metrics[tag]}, epoch)
            # Without a reset, compute() would report a running value over *all* epochs so far
            # instead of this epoch, making it incomparable with the validation numbers.
            metric.reset()
finally:
    # Flush and close event files even if training is interrupted (e.g. KeyboardInterrupt),
    # so everything logged up to that point reaches TensorBoard.
    writer.close()

Training epochs:  71%|███████   | 1419/2000 [07:35<03:06,  3.12it/s]


KeyboardInterrupt: 