In [1]:
import os

# Opt in to PyTorch's CPU fallback for operators the MPS backend does not
# implement. The assignment is kept ahead of the torch import below,
# presumably so the flag is already set when PyTorch initialises
# (TODO confirm it is read at import time).
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

from pathlib import Path

import pandas as pd
from torch.utils.data import DataLoader

from prediction.data import MaskedAISDataset
from prediction.model import TrajectoryTrainer, TrajectoryTransformer

# Log directory handed to the trainer: the "logs" folder two levels above the
# resolved current working directory.
LOG_DIR = Path().resolve().parent.parent / "logs"

## Create the dataset (skip this step if `ais_data.pkl` already exists)

In [None]:
from datetime import datetime

from prediction.data import vessel_groups
from prediction.preprocessing import load_and_build, remove_outliers_parallel

# Build the AIS dataset only when no cached pickle is present, so that
# "Restart Kernel -> Run All" neither repeats the expensive preprocessing nor
# overwrites an existing file. Delete the pickle to force a rebuild.
DATASET_PATH = Path("ais_data.pkl")

if DATASET_PATH.exists():
    print(f"Found {DATASET_PATH}; skipping dataset creation.")
else:
    groups = vessel_groups()
    # NOTE(review): start and end are both 2024-01-01 — confirm that
    # load_and_build treats the end date as inclusive.
    # NOTE(review): the positional 100 presumably corresponds to the sequence
    # length used later for training — verify against load_and_build.
    df = load_and_build(datetime(2024, 1, 1), datetime(2024, 1, 1), 100, groups, verbose=True)
    df = remove_outliers_parallel(
        df=df,
        threshold_partition_sog=5.0,
        threshold_association_sog=15.0,
        threshold_completeness=100.0,
        threshold_partition_distance=100.0,
        threshold_association_distance=100.0,
        additional_filter_columns=["orientations"]
    )
    # Persist the cleaned frame so later runs can load it directly.
    df.to_pickle(DATASET_PATH)

## Load the dataset from `ais_data.pkl`

In [2]:
# Read back the pickled DataFrame written by the dataset-creation step.
# Pickle files can execute code on load, so only open files you produced
# yourself.
dataset_file = "ais_data.pkl"
df = pd.read_pickle(dataset_file)

In [None]:
import torch

# Sequence length shared by the dataset and the transformer.
MAX_SEQ_LEN = 100

dataset = MaskedAISDataset(df, MAX_SEQ_LEN, n_workers=1, normalize=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Prefer Apple's MPS backend (the original hard-coded target), but fall back
# to CUDA or CPU so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

transformer = TrajectoryTransformer(
    d_model=128,
    nhead=4,
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=512,
    dropout=0.1,
    max_seq_len=MAX_SEQ_LEN
)

In [None]:
# Wire the model, the data loaders, the log directory and the target device
# into the trainer.
# NOTE(review): the same `loader` is passed in both loader positions. If the
# second one is meant for validation, metrics would be computed on the
# training data — confirm against the TrajectoryTrainer signature.
trainer = TrajectoryTrainer(
    transformer,
    loader,
    loader,
    LOG_DIR,
    device=device,
)

In [None]:
# Epoch limit passed to the trainer for this run.
MAX_EPOCHS = 10

trainer.train(max_epochs=MAX_EPOCHS)