In [1]:
import os

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

from pathlib import Path

import pandas as pd
from torch.utils.data import DataLoader

from prediction.data import MaskedAISDataset
from prediction.model import TrajectoryTrainer, TrajectoryTransformer

LOG_DIR = Path().resolve().parent.parent / "logs"

## Create Dataset if it does not exist yet, skip otherwise

In [None]:
from datetime import datetime

from prediction.data import vessel_groups
from prediction.preprocessing import load_and_build, remove_outliers_parallel

groups = vessel_groups()
df = load_and_build(datetime(2024, 1, 1), datetime(2024, 1, 1), 100, groups, verbose=True)
df = remove_outliers_parallel(
    df=df,
    threshold_partition_sog=5.0,
    threshold_association_sog=15.0,
    threshold_completeness=100.0,
    threshold_partition_distance=100.0,
    threshold_association_distance=100.0,
    additional_filter_columns=["orientations"]
)
df.to_pickle("ais_data.pkl")

Loading and building trajectories for 1 days using 11 processes
Loading and building trajectories for 2024-01-01 00:00:00


100%|██████████| 1/1 [00:06<00:00,  6.37s/it]
100%|██████████| 14868/14868 [02:12<00:00, 111.80it/s]


## Load the dataset if it exists

In [6]:
df = pd.read_pickle('ais_data.pkl')

In [7]:
MAX_SEQ_LEN = 100
dataset = MaskedAISDataset(df, MAX_SEQ_LEN, n_workers=1, normalize=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
device = "mps"

transformer = TrajectoryTransformer(
    d_model=128,
    nhead=4,
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=512,
    dropout=0.1,
    max_seq_len=MAX_SEQ_LEN
)

100%|██████████| 12087/12087 [00:05<00:00, 2060.26it/s]


In [8]:
trainer = TrajectoryTrainer(transformer, loader, loader, LOG_DIR, device=device)

In [9]:
trainer.train(max_epochs=10)

Epoch 0: 100%|██████████| 378/378 [00:30<00:00, 12.28it/s, loss=0.0622]
  ) and not torch._nested_tensor_from_mask_left_aligned(



Epoch 1/10 - Time: 39.50s
train_loss: 0.0873
val_loss: 0.0565
Learning rate: 0.000098
--------------------------------------------------


Epoch 1: 100%|██████████| 378/378 [00:28<00:00, 13.06it/s, loss=0.0459]



Epoch 2/10 - Time: 36.65s
train_loss: 0.0588
val_loss: 0.0529
Learning rate: 0.000091
--------------------------------------------------


Epoch 2: 100%|██████████| 378/378 [00:28<00:00, 13.23it/s, loss=0.0629]



Epoch 3/10 - Time: 36.48s
train_loss: 0.0562
val_loss: 0.0531
Learning rate: 0.000080
--------------------------------------------------


Epoch 3: 100%|██████████| 378/378 [00:27<00:00, 13.78it/s, loss=0.0494]



Epoch 4/10 - Time: 34.88s
train_loss: 0.0544
val_loss: 0.0518
Learning rate: 0.000066
--------------------------------------------------


Epoch 4: 100%|██████████| 378/378 [00:27<00:00, 13.77it/s, loss=0.0558]



Epoch 5/10 - Time: 34.87s
train_loss: 0.0531
val_loss: 0.0505
Learning rate: 0.000051
--------------------------------------------------


Epoch 5: 100%|██████████| 378/378 [00:27<00:00, 13.73it/s, loss=0.0525]



Epoch 6/10 - Time: 34.95s
train_loss: 0.0528
val_loss: 0.0505
Learning rate: 0.000035
--------------------------------------------------


Epoch 6: 100%|██████████| 378/378 [00:27<00:00, 13.69it/s, loss=0.0406]



Epoch 7/10 - Time: 35.06s
train_loss: 0.0522
val_loss: 0.0506
Learning rate: 0.000021
--------------------------------------------------


Epoch 7: 100%|██████████| 378/378 [00:27<00:00, 13.60it/s, loss=0.0356]



Epoch 8/10 - Time: 35.42s
train_loss: 0.0520
val_loss: 0.0499
Learning rate: 0.000010
--------------------------------------------------


Epoch 8: 100%|██████████| 378/378 [00:27<00:00, 13.52it/s, loss=0.0556]



Epoch 9/10 - Time: 35.39s
train_loss: 0.0509
val_loss: 0.0496
Learning rate: 0.000003
--------------------------------------------------


Epoch 9: 100%|██████████| 378/378 [00:27<00:00, 13.60it/s, loss=0.0441]



Epoch 10/10 - Time: 35.44s
train_loss: 0.0511
val_loss: 0.0491
Learning rate: 0.000001
--------------------------------------------------


{'train_loss': 0.0510633668798224, 'val_loss': 0.04913883752844952}