In [1]:
import torch
from sklearn import metrics
from tabulate import tabulate
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import TrackModel
from utils.dataset import JSONTrackDataset

In [2]:
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
print('Using device:', device)
print()

# Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')

Using device: cpu



In [3]:
train_dataset = JSONTrackDataset('data/train.json')
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=50, num_workers=7, pin_memory=True,
                              drop_last=True)

val_dataset = JSONTrackDataset('data/val.json')
val_dataloader = DataLoader(val_dataset, shuffle=True, batch_size=50, num_workers=7, pin_memory=True, drop_last=True)

test_dataset = JSONTrackDataset('data/test.json')
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=50, num_workers=7, pin_memory=True, drop_last=True)

In [4]:
model = TrackModel(5, 4).to(device)
EPOCHS_COUNT = 20

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss(reduction='mean').to(device)

In [5]:
for epoch in range(EPOCHS_COUNT):
    # training
    loss_sum = 0
    ep_preds = []
    ep_labels = []
    for data_row in tqdm(train_dataloader, desc='Epoch {}/{}'.format(epoch, EPOCHS_COUNT - 1), ncols=80):
        data = data_row['data'].to(device)
        label = data_row['label'].to(device)

        optimizer.zero_grad()

        y_pred = model(data)
        loss = criterion(y_pred, label)
        loss_sum += loss.item()

        np_label = label.argmax(dim=1).data.numpy()
        np_preds = y_pred.argmax(dim=1).data.numpy()

        for i in np_label:
            ep_labels.append(i)
        for i in np_preds:
            ep_preds.append(i)

        loss.backward()
        optimizer.step()

    accuracy = metrics.accuracy_score(ep_labels, ep_preds)
    recall = metrics.recall_score(ep_labels, ep_preds, average='micro')
    precision = metrics.precision_score(ep_labels, ep_preds, average='micro')
    f1 = metrics.f1_score(ep_labels, ep_preds, average='micro')
    roc_auc = metrics.roc_auc_score(ep_labels, ep_preds, average='micro')

    print('')
    print(tabulate([
        ['Train', loss_sum / len(train_dataloader), accuracy, recall, precision, f1, roc_auc]
    ], headers=['Stage', 'Loss', 'Accuracy', 'Recall', 'Precision', 'F1', 'ROC-AUC'], tablefmt='orgtbl'))
    print('')

    # validation
    loss_sum = 0
    ep_preds = []
    ep_labels = []
    for val_row in tqdm(val_dataloader, desc='Val', ncols=80):
        data = val_row['data'].to(device)
        label = val_row['label'].to(device)

        y_pred = model(data)

        np_label = label.argmax(dim=1).data.numpy()
        np_preds = y_pred.argmax(dim=1).data.numpy()

        for i in np_label:
            ep_labels.append(i)
        for i in np_preds:
            ep_preds.append(i)

        loss = criterion(y_pred, label)
        loss_sum += loss.item()

    accuracy = metrics.accuracy_score(ep_labels, ep_preds)
    recall = metrics.recall_score(ep_labels, ep_preds, average='micro')
    precision = metrics.precision_score(ep_labels, ep_preds, average='micro')
    f1 = metrics.f1_score(ep_labels, ep_preds, average='micro')
    roc_auc = metrics.roc_auc_score(ep_labels, ep_preds, average='micro')

    print('')
    print(tabulate([
        ['Validation', loss_sum / len(train_dataloader), accuracy, recall, precision, f1, roc_auc]
    ], headers=['Stage', 'Loss', 'Accuracy', 'Recall', 'Precision', 'F1', 'ROC-AUC'], tablefmt='orgtbl'))
    print('')

Epoch 0/19: 100%|██████████████████████████| 1126/1126 [00:02<00:00, 429.80it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |   Loss |   Accuracy |    Recall |   Precision |        F1 |   ROC-AUC |
|---------+--------+------------+-----------+-------------+-----------+-----------|
| Train   | 960.28 |  0.0225933 | 0.0225933 |   0.0225933 | 0.0225933 |  0.499913 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 552.22it/s]
Epoch 1/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |     Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+----------+------------+----------+-------------+---------+-----------|
| Validation | 0.170668 |    0.02192 |  0.02192 |     0.02192 | 0.02192 |       0.5 |



Epoch 1/19: 100%|██████████████████████████| 1126/1126 [00:01<00:00, 592.37it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |     Loss |   Accuracy |    Recall |   Precision |        F1 |   ROC-AUC |
|---------+----------+------------+-----------+-------------+-----------+-----------|
| Train   | 0.461187 |  0.0211723 | 0.0211723 |   0.0211723 | 0.0211723 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 700.00it/s]
Epoch 2/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |     Loss |   Accuracy |    Recall |   Precision |        F1 |   ROC-AUC |
|------------+----------+------------+-----------+-------------+-----------+-----------|
| Validation | 0.137668 |  0.0218133 | 0.0218133 |   0.0218133 | 0.0218133 |       0.5 |



Epoch 2/19: 100%|██████████████████████████| 1126/1126 [00:02<00:00, 520.77it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |     Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+----------+------------+----------+-------------+----------+-----------|
| Train   | 0.372366 |   0.449876 | 0.449876 |    0.449876 | 0.449876 |   0.51092 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 554.37it/s]
Epoch 3/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |     Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|------------+----------+------------+----------+-------------+----------+-----------|
| Validation | 0.111319 |   0.978133 | 0.978133 |    0.978133 | 0.978133 |       0.5 |



Epoch 3/19: 100%|██████████████████████████| 1126/1126 [00:02<00:00, 550.78it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |     Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+----------+------------+----------+-------------+----------+-----------|
| Train   | 0.301456 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 589.30it/s]
Epoch 4/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+---------+-----------|
| Validation | 0.0902909 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 4/19: 100%|██████████████████████████| 1126/1126 [00:01<00:00, 580.49it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |     Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+----------+------------+----------+-------------+----------+-----------|
| Train   | 0.244846 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 525.07it/s]
Epoch 5/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+---------+-----------|
| Validation | 0.0735033 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 5/19: 100%|██████████████████████████| 1126/1126 [00:01<00:00, 586.11it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |     Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+----------+------------+----------+-------------+----------+-----------|
| Train   | 0.199651 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 685.00it/s]
Epoch 6/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+---------+-----------|
| Validation | 0.0601044 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 6/19: 100%|██████████████████████████| 1126/1126 [00:02<00:00, 517.67it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |     Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+----------+------------+----------+-------------+----------+-----------|
| Train   | 0.163571 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 440.54it/s]
Epoch 7/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+---------+-----------|
| Validation | 0.0494087 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 7/19: 100%|██████████████████████████| 1126/1126 [00:02<00:00, 524.02it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |     Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+----------+------------+----------+-------------+----------+-----------|
| Train   | 0.134758 |   0.978845 | 0.978845 |    0.978845 | 0.978845 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 650.22it/s]
Epoch 8/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+---------+-----------|
| Validation | 0.0408715 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 8/19: 100%|██████████████████████████| 1126/1126 [00:02<00:00, 529.74it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |    Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+---------+------------+----------+-------------+----------+-----------|
| Train   | 0.11176 |   0.978845 | 0.978845 |    0.978845 | 0.978845 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 651.88it/s]
Epoch 9/19:   0%|                                      | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+----------+-----------|
| Validation | 0.0340482 |   0.978133 | 0.978133 |    0.978133 | 0.978133 |       0.5 |



Epoch 9/19: 100%|██████████████████████████| 1126/1126 [00:02<00:00, 552.32it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0934089 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 646.60it/s]
Epoch 10/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+----------+-----------|
| Validation | 0.0286086 |   0.978133 | 0.978133 |    0.978133 | 0.978133 |       0.5 |



Epoch 10/19: 100%|█████████████████████████| 1126/1126 [00:02<00:00, 551.71it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0787521 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 616.83it/s]
Epoch 11/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+----------+-----------|
| Validation | 0.0242671 |   0.978133 | 0.978133 |    0.978133 | 0.978133 |       0.5 |



Epoch 11/19: 100%|█████████████████████████| 1126/1126 [00:02<00:00, 555.25it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0670509 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 688.08it/s]
Epoch 12/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+---------+-----------|
| Validation | 0.0208132 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 12/19: 100%|█████████████████████████| 1126/1126 [00:01<00:00, 589.47it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0576975 |   0.978845 | 0.978845 |    0.978845 | 0.978845 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 603.27it/s]
Epoch 13/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+----------+-----------|
| Validation | 0.0180361 |   0.978133 | 0.978133 |    0.978133 | 0.978133 |       0.5 |



Epoch 13/19: 100%|█████████████████████████| 1126/1126 [00:02<00:00, 521.62it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0502259 |   0.978863 | 0.978863 |    0.978863 | 0.978863 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 682.88it/s]
Epoch 14/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+----------+-----------|
| Validation | 0.0158289 |   0.978133 | 0.978133 |    0.978133 | 0.978133 |       0.5 |



Epoch 14/19: 100%|█████████████████████████| 1126/1126 [00:01<00:00, 576.71it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0442957 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 539.82it/s]
Epoch 15/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+---------+-----------|
| Validation | 0.0140803 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 15/19: 100%|█████████████████████████| 1126/1126 [00:02<00:00, 520.58it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0395289 |   0.978845 | 0.978845 |    0.978845 | 0.978845 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 601.52it/s]
Epoch 16/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |     Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+----------+------------+----------+-------------+---------+-----------|
| Validation | 0.012675 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 16/19: 100%|█████████████████████████| 1126/1126 [00:01<00:00, 584.54it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0357469 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 653.28it/s]
Epoch 17/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+---------+-----------|
| Validation | 0.0115539 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 17/19: 100%|█████████████████████████| 1126/1126 [00:02<00:00, 518.61it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0327174 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 637.82it/s]
Epoch 18/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|------------+-----------+------------+----------+-------------+----------+-----------|
| Validation | 0.0106454 |   0.978133 | 0.978133 |    0.978133 | 0.978133 |       0.5 |



Epoch 18/19: 100%|█████████████████████████| 1126/1126 [00:01<00:00, 563.83it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0302988 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 629.41it/s]
Epoch 19/19:   0%|                                     | 0/1126 [00:00<?, ?it/s]


| Stage      |       Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+------------+------------+----------+-------------+---------+-----------|
| Validation | 0.00994577 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



Epoch 19/19: 100%|█████████████████████████| 1126/1126 [00:01<00:00, 575.79it/s]
Val:   0%|                                              | 0/375 [00:00<?, ?it/s]


| Stage   |      Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+-----------+------------+----------+-------------+----------+-----------|
| Train   | 0.0283681 |   0.978828 | 0.978828 |    0.978828 | 0.978828 |       0.5 |



Val: 100%|███████████████████████████████████| 375/375 [00:00<00:00, 554.38it/s]



| Stage      |       Loss |   Accuracy |   Recall |   Precision |      F1 |   ROC-AUC |
|------------+------------+------------+----------+-------------+---------+-----------|
| Validation | 0.00937654 |    0.97808 |  0.97808 |     0.97808 | 0.97808 |       0.5 |



In [6]:
loss_sum = 0
ep_preds = []
ep_labels = []
for data_row in tqdm(test_dataloader, desc='Test', ncols=80):
    data = data_row['data'].to(device)
    label = data_row['label'].to(device)

    optimizer.zero_grad()

    y_pred = model(data)
    loss = criterion(y_pred, label)
    loss_sum += loss.item()

    np_label = label.argmax(dim=1).data.numpy()
    np_preds = y_pred.argmax(dim=1).data.numpy()

    for i in np_label:
        ep_labels.append(i)
    for i in np_preds:
        ep_preds.append(i)

accuracy = metrics.accuracy_score(ep_labels, ep_preds)
recall = metrics.recall_score(ep_labels, ep_preds, average='micro')
precision = metrics.precision_score(ep_labels, ep_preds, average='micro')
f1 = metrics.f1_score(ep_labels, ep_preds, average='micro')
roc_auc = metrics.roc_auc_score(ep_labels, ep_preds, average='micro')

print('')
print(tabulate([
    ['Test', loss_sum / len(train_dataloader), accuracy, recall, precision, f1, roc_auc]
], headers=['Stage', 'Loss', 'Accuracy', 'Recall', 'Precision', 'F1', 'ROC-AUC'], tablefmt='orgtbl'))
print('')

Test: 100%|██████████████████████████████████| 375/375 [00:00<00:00, 660.99it/s]



| Stage   |       Loss |   Accuracy |   Recall |   Precision |       F1 |   ROC-AUC |
|---------+------------+------------+----------+-------------+----------+-----------|
| Test    | 0.00895241 |   0.979627 | 0.979627 |    0.979627 | 0.979627 |       0.5 |

