In [1]:
import os

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

from pathlib import Path

import pandas as pd
from torch.utils.data import DataLoader

from prediction.data import MaskedAISDataset
from prediction.model import TrajectoryTrainer, TrajectoryTransformer

LOG_DIR = Path().resolve().parent.parent / "logs"

In [2]:
df = pd.read_pickle('ais_data.pkl')

In [3]:
MAX_SEQ_LEN = 100
dataset = MaskedAISDataset(df, MAX_SEQ_LEN, n_workers=1, normalize=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
device = "mps"

transformer = TrajectoryTransformer(
    d_model=128,
    nhead=4,
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=512,
    dropout=0.1,
    max_seq_len=MAX_SEQ_LEN
)

100%|██████████| 8585/8585 [00:04<00:00, 2028.60it/s]


In [4]:
trainer = TrajectoryTrainer(transformer, loader, loader, LOG_DIR, device=device)

In [5]:
trainer.train(max_epochs=10)

Epoch 0: 100%|██████████| 269/269 [00:22<00:00, 11.78it/s, loss=0.0454]
  ) and not torch._nested_tensor_from_mask_left_aligned(



Epoch 1/10 - Time: 29.90s
train_loss: 0.0788
val_loss: 0.0450
Learning rate: 0.000098
--------------------------------------------------


Epoch 1: 100%|██████████| 269/269 [00:20<00:00, 13.05it/s, loss=0.0411]



Epoch 2/10 - Time: 26.89s
train_loss: 0.0497
val_loss: 0.0420
Learning rate: 0.000091
--------------------------------------------------


Epoch 2: 100%|██████████| 269/269 [00:20<00:00, 13.02it/s, loss=0.0862]



Epoch 3/10 - Time: 26.30s
train_loss: 0.0451
val_loss: 0.0427
Learning rate: 0.000080
--------------------------------------------------


Epoch 3: 100%|██████████| 269/269 [00:20<00:00, 13.34it/s, loss=0.0632]



Epoch 4/10 - Time: 25.54s
train_loss: 0.0443
val_loss: 0.0414
Learning rate: 0.000066
--------------------------------------------------


Epoch 4: 100%|██████████| 269/269 [00:20<00:00, 13.17it/s, loss=0.0659]



Epoch 5/10 - Time: 26.13s
train_loss: 0.0429
val_loss: 0.0407
Learning rate: 0.000051
--------------------------------------------------


Epoch 5: 100%|██████████| 269/269 [00:19<00:00, 13.77it/s, loss=0.0451]



Epoch 6/10 - Time: 24.80s
train_loss: 0.0426
val_loss: 0.0408
Learning rate: 0.000035
--------------------------------------------------


Epoch 6: 100%|██████████| 269/269 [00:19<00:00, 14.05it/s, loss=0.0555]



Epoch 7/10 - Time: 24.86s
train_loss: 0.0418
val_loss: 0.0397
Learning rate: 0.000021
--------------------------------------------------


Epoch 7: 100%|██████████| 269/269 [00:19<00:00, 13.72it/s, loss=0.0440]



Epoch 8/10 - Time: 24.78s
train_loss: 0.0418
val_loss: 0.0400
Learning rate: 0.000010
--------------------------------------------------


Epoch 8: 100%|██████████| 269/269 [00:19<00:00, 13.57it/s, loss=0.0480]



Epoch 9/10 - Time: 25.17s
train_loss: 0.0412
val_loss: 0.0398
Learning rate: 0.000003
--------------------------------------------------


Epoch 9: 100%|██████████| 269/269 [00:19<00:00, 13.55it/s, loss=0.0341]



Epoch 10/10 - Time: 24.98s
train_loss: 0.0413
val_loss: 0.0399
Learning rate: 0.000001
--------------------------------------------------


{'train_loss': 0.041837751740748996, 'val_loss': 0.039724161017506095}