## Load in data

In [94]:
%load_ext autoreload
%autoreload 2
from ais_dataloader import *
from gp_kernel_ship_classification_dataset import *
from gp_kernel_ship_classification_trainer import *
from gp_kernel_ship_classification_network import *

import ipywidgets as widgets
from IPython.display import display
from plotting_utils import *



# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cpu


In [49]:
# Single-day window: start and end are both 2024-01-01, at daily frequency.
date_range = pd.date_range(
    start='2024-01-01',
    end='2024-01-01',
    freq='D',
)

# Build the AIS trajectory regression dataset for that window on `device`.
gp_regression_dataset = AISTrajectoryRegressionDataset(date_range, device)


Loading cached dataframe from data/processed/processed_AIS_df_2024_01_01_2024_01_01.pkl


Scaling trajectories for each MMSI: 100%|██████████| 3453/3453 [00:01<00:00, 3213.55it/s]



===== Dataset Statistics =====
Total number of AIS messages: 2128288
Number of unique MMSIs: 3453
Date range: 2024-01-01 00:00:00 to 2024-01-01 23:59:59


## Fit GP Models

In [None]:
import torch
import gpytorch
from multioutput_gp import *

# Number of dataset indices to fit; one multi-output GP is trained per index.
# NOTE(review): this value is hard-coded, and the saved output of this cell reads
# ".../100", so that output came from a run with a different value — confirm.
num_trajectories = 3453

# Per-vessel results, each dictionary keyed by MMSI.
models = {}
likelihoods = {}
losses = {}

for idx in range(num_trajectories):
    # Each dataset item unpacks into (mmsi, times, state_trajectory).
    mmsi, times, state_trajectory = gp_regression_dataset[idx]

    print(f"\nFitting GP for trajectory {idx+1}/{num_trajectories} for MMSI {mmsi}")

    # GP inputs are the times with an extra dimension added at position 1;
    # GP targets are the state trajectory. Both are detached and moved to `device`.
    X = times.detach().unsqueeze(1).to(device)
    Y = state_trajectory.detach().to(device)

    # One GP task per entry along the second dimension of Y.
    num_outputs = Y.shape[1]

    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
        num_tasks=num_outputs,
        noise_prior=gpytorch.priors.NormalPrior(loc=0.25, scale=0.25),
    ).to(device)
    model = MultiOutputExactGPModel(
        X, Y, likelihood, num_outputs=num_outputs
    ).to(device)

    # The enlarged Cholesky jitter is active only for the duration of training.
    with gpytorch.settings.cholesky_jitter(1e-3):
        loss, model, likelihood = train_model(
            model, likelihood, X, Y, num_epochs=100, lr=0.1, mmsi=mmsi
        )

    # Store the trained objects under this vessel's MMSI, then report the final loss.
    models[mmsi] = model
    likelihoods[mmsi] = likelihood
    losses[mmsi] = loss.item()
    print(f"Loss: {losses[mmsi]}")


Fitting GP for trajectory 1/100 for MMSI 3660489


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 187.32it/s]


Loss: 1.1078234910964966

Fitting GP for trajectory 2/100 for MMSI 203661016


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 22.71it/s]


Loss: 1.0687633752822876

Fitting GP for trajectory 3/100 for MMSI 205691000


GP Training Progress: 100%|██████████| 100/100 [00:07<00:00, 12.75it/s]


Loss: 0.8039560317993164

Fitting GP for trajectory 4/100 for MMSI 205717000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 17.89it/s]


Loss: 0.9143831729888916

Fitting GP for trajectory 5/100 for MMSI 209156000


GP Training Progress: 100%|██████████| 100/100 [00:09<00:00, 10.66it/s]


Loss: 0.8577029705047607

Fitting GP for trajectory 6/100 for MMSI 209228000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 22.83it/s]


Loss: 0.9678011536598206

Fitting GP for trajectory 7/100 for MMSI 209425000


GP Training Progress: 100%|██████████| 100/100 [00:20<00:00,  4.77it/s]


Loss: 0.7417958974838257

Fitting GP for trajectory 8/100 for MMSI 209429000


GP Training Progress: 100%|██████████| 100/100 [00:03<00:00, 31.05it/s]


Loss: 0.9501253366470337

Fitting GP for trajectory 9/100 for MMSI 209444000


GP Training Progress: 100%|██████████| 100/100 [00:25<00:00,  3.94it/s]


Loss: 0.755733072757721

Fitting GP for trajectory 10/100 for MMSI 209470000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 85.02it/s]


Loss: 0.8734135031700134

Fitting GP for trajectory 11/100 for MMSI 209513000


GP Training Progress: 100%|██████████| 100/100 [00:26<00:00,  3.83it/s]


Loss: 0.7495151162147522

Fitting GP for trajectory 12/100 for MMSI 209550000


GP Training Progress: 100%|██████████| 100/100 [00:27<00:00,  3.69it/s]


Loss: 0.7088612914085388

Fitting GP for trajectory 13/100 for MMSI 209641000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 21.98it/s]


Loss: 0.9428096413612366

Fitting GP for trajectory 14/100 for MMSI 209705000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 63.29it/s]


Loss: 0.8735288381576538

Fitting GP for trajectory 15/100 for MMSI 209729000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 137.30it/s]


Loss: 1.080646276473999

Fitting GP for trajectory 16/100 for MMSI 209786000


GP Training Progress: 100%|██████████| 100/100 [00:31<00:00,  3.16it/s]


Loss: 0.7393811941146851

Fitting GP for trajectory 17/100 for MMSI 209888000


GP Training Progress: 100%|██████████| 100/100 [00:03<00:00, 30.57it/s]


Loss: 0.7613688111305237

Fitting GP for trajectory 18/100 for MMSI 209933000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 16.95it/s]


Loss: 1.1045637130737305

Fitting GP for trajectory 19/100 for MMSI 209997000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 227.77it/s]


Loss: 0.9791954159736633

Fitting GP for trajectory 20/100 for MMSI 210185000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 18.20it/s]


Loss: 0.9444566965103149

Fitting GP for trajectory 21/100 for MMSI 210478000


GP Training Progress: 100%|██████████| 100/100 [00:25<00:00,  3.89it/s]


Loss: 0.7994800209999084

Fitting GP for trajectory 22/100 for MMSI 210499000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 24.32it/s]


Loss: 0.9940799474716187

Fitting GP for trajectory 23/100 for MMSI 210568000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 133.97it/s]


Loss: 1.2477079629898071

Fitting GP for trajectory 24/100 for MMSI 210614000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 18.76it/s]


Loss: 1.226464033126831

Fitting GP for trajectory 25/100 for MMSI 210737000


GP Training Progress: 100%|██████████| 100/100 [00:03<00:00, 26.92it/s]


Loss: 0.9553667902946472

Fitting GP for trajectory 26/100 for MMSI 210959000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 203.78it/s]


Loss: 0.9057654142379761

Fitting GP for trajectory 27/100 for MMSI 212370000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 143.37it/s]


Loss: 0.9788849949836731

Fitting GP for trajectory 28/100 for MMSI 212482000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 293.73it/s]


Loss: 1.1359623670578003

Fitting GP for trajectory 29/100 for MMSI 212656000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 52.12it/s]


Loss: 1.3011102676391602

Fitting GP for trajectory 30/100 for MMSI 212719000


GP Training Progress: 100%|██████████| 100/100 [00:45<00:00,  2.21it/s]


Loss: 0.7306892275810242

Fitting GP for trajectory 31/100 for MMSI 212775000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 17.24it/s]


Loss: 1.027761459350586

Fitting GP for trajectory 32/100 for MMSI 215001000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 182.87it/s]


Loss: 1.0756016969680786

Fitting GP for trajectory 33/100 for MMSI 215105000


GP Training Progress: 100%|██████████| 100/100 [00:06<00:00, 16.12it/s]


Loss: 0.7610818147659302

Fitting GP for trajectory 34/100 for MMSI 215126000


GP Training Progress: 100%|██████████| 100/100 [00:10<00:00,  9.36it/s]


Loss: 0.818392276763916

Fitting GP for trajectory 35/100 for MMSI 215139000


GP Training Progress: 100%|██████████| 100/100 [00:09<00:00, 10.13it/s]


Loss: 0.8420707583427429

Fitting GP for trajectory 36/100 for MMSI 215159000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 22.70it/s]


Loss: 0.9947741031646729

Fitting GP for trajectory 37/100 for MMSI 215193000


GP Training Progress: 100%|██████████| 100/100 [00:08<00:00, 11.46it/s]


Loss: 0.834104597568512

Fitting GP for trajectory 38/100 for MMSI 215217000


GP Training Progress: 100%|██████████| 100/100 [00:11<00:00,  8.97it/s]


Loss: 0.9177435636520386

Fitting GP for trajectory 39/100 for MMSI 215534000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 62.87it/s]


Loss: 0.8753194212913513

Fitting GP for trajectory 40/100 for MMSI 215561000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 19.29it/s]


Loss: 0.8951831459999084

Fitting GP for trajectory 41/100 for MMSI 215726000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 23.45it/s]


Loss: 0.9424835443496704

Fitting GP for trajectory 42/100 for MMSI 215754000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 114.25it/s]


Loss: 0.8866640329360962

Fitting GP for trajectory 43/100 for MMSI 215765000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 19.90it/s]


Loss: 0.9410210847854614

Fitting GP for trajectory 44/100 for MMSI 215785000


GP Training Progress: 100%|██████████| 100/100 [00:03<00:00, 26.06it/s]


Loss: 0.7361562252044678

Fitting GP for trajectory 45/100 for MMSI 215797000


GP Training Progress: 100%|██████████| 100/100 [00:10<00:00,  9.35it/s]


Loss: 0.8655849695205688

Fitting GP for trajectory 46/100 for MMSI 215804000


GP Training Progress: 100%|██████████| 100/100 [00:03<00:00, 28.96it/s]


Loss: 0.7659711241722107

Fitting GP for trajectory 47/100 for MMSI 215809000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 50.04it/s]


Loss: 0.7703059911727905

Fitting GP for trajectory 48/100 for MMSI 215854000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 17.25it/s]


Loss: 0.9785612225532532

Fitting GP for trajectory 49/100 for MMSI 215876000


GP Training Progress: 100%|██████████| 100/100 [00:03<00:00, 32.46it/s]


Loss: 1.0254721641540527

Fitting GP for trajectory 50/100 for MMSI 215879000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 61.69it/s]


Loss: 0.9773473143577576

Fitting GP for trajectory 51/100 for MMSI 215896000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 269.45it/s]


Loss: 0.6813071370124817

Fitting GP for trajectory 52/100 for MMSI 215939000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 24.64it/s]


Loss: 1.04905104637146

Fitting GP for trajectory 53/100 for MMSI 215940000


GP Training Progress: 100%|██████████| 100/100 [00:08<00:00, 11.79it/s]


Loss: 0.8212319016456604

Fitting GP for trajectory 54/100 for MMSI 215965000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 304.60it/s]


Loss: 0.5782517790794373

Fitting GP for trajectory 55/100 for MMSI 218292000


GP Training Progress: 100%|██████████| 100/100 [00:03<00:00, 27.61it/s]


Loss: 0.9038199186325073

Fitting GP for trajectory 56/100 for MMSI 218427000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 19.41it/s]


Loss: 0.9083132743835449

Fitting GP for trajectory 57/100 for MMSI 218643000


GP Training Progress: 100%|██████████| 100/100 [00:06<00:00, 15.62it/s]


Loss: 0.8961074948310852

Fitting GP for trajectory 58/100 for MMSI 218791000


GP Training Progress: 100%|██████████| 100/100 [00:19<00:00,  5.09it/s]


Loss: 0.7543748021125793

Fitting GP for trajectory 59/100 for MMSI 218851000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 94.57it/s]


Loss: 0.834257960319519

Fitting GP for trajectory 60/100 for MMSI 219000034


GP Training Progress: 100%|██████████| 100/100 [00:10<00:00,  9.62it/s]


Loss: 0.8394700288772583

Fitting GP for trajectory 61/100 for MMSI 219028422


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 82.25it/s]


Loss: 0.776862382888794

Fitting GP for trajectory 62/100 for MMSI 219029122


GP Training Progress: 100%|██████████| 100/100 [00:31<00:00,  3.14it/s]


Loss: 0.7589831948280334

Fitting GP for trajectory 63/100 for MMSI 219031008


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 109.27it/s]


Loss: 0.8772634863853455

Fitting GP for trajectory 64/100 for MMSI 219031231


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 22.08it/s]


Loss: 1.0391451120376587

Fitting GP for trajectory 65/100 for MMSI 219130000


GP Training Progress: 100%|██████████| 100/100 [00:51<00:00,  1.95it/s]


Loss: 0.7117136120796204

Fitting GP for trajectory 66/100 for MMSI 219256000


GP Training Progress: 100%|██████████| 100/100 [00:06<00:00, 15.37it/s]


Loss: 0.9552637934684753

Fitting GP for trajectory 67/100 for MMSI 219287000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 18.93it/s]


Loss: 0.9950878024101257

Fitting GP for trajectory 68/100 for MMSI 219310000


GP Training Progress: 100%|██████████| 100/100 [00:41<00:00,  2.40it/s]


Loss: 0.7803823351860046

Fitting GP for trajectory 69/100 for MMSI 219412000


GP Training Progress: 100%|██████████| 100/100 [00:21<00:00,  4.72it/s]


Loss: 0.7874492406845093

Fitting GP for trajectory 70/100 for MMSI 219432000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 61.57it/s]


Loss: 0.7957141399383545

Fitting GP for trajectory 71/100 for MMSI 219445000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 24.34it/s]


Loss: 0.9321490526199341

Fitting GP for trajectory 72/100 for MMSI 219454000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 21.07it/s]


Loss: 1.2953239679336548

Fitting GP for trajectory 73/100 for MMSI 219456000


GP Training Progress: 100%|██████████| 100/100 [00:48<00:00,  2.08it/s]


Loss: 0.8185746073722839

Fitting GP for trajectory 74/100 for MMSI 219558000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 22.01it/s]


Loss: 1.1869332790374756

Fitting GP for trajectory 75/100 for MMSI 219605000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 24.24it/s]


Loss: 1.0008972883224487

Fitting GP for trajectory 76/100 for MMSI 219671000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 56.27it/s]


Loss: 0.8872231841087341

Fitting GP for trajectory 77/100 for MMSI 219674000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 23.77it/s]


Loss: 0.9411225914955139

Fitting GP for trajectory 78/100 for MMSI 219840000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 344.94it/s]


Loss: 1.1010100841522217

Fitting GP for trajectory 79/100 for MMSI 220590000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 21.89it/s]


Loss: 0.9622023701667786

Fitting GP for trajectory 80/100 for MMSI 220636000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 267.63it/s]


Loss: 0.8870916366577148

Fitting GP for trajectory 81/100 for MMSI 228337900


GP Training Progress: 100%|██████████| 100/100 [00:37<00:00,  2.64it/s]


Loss: 0.7632695436477661

Fitting GP for trajectory 82/100 for MMSI 228391700


GP Training Progress: 100%|██████████| 100/100 [00:07<00:00, 13.67it/s]


Loss: 0.7980597615242004

Fitting GP for trajectory 83/100 for MMSI 228421700


GP Training Progress: 100%|██████████| 100/100 [00:08<00:00, 11.81it/s]


Loss: 0.7445303201675415

Fitting GP for trajectory 84/100 for MMSI 229074000


GP Training Progress: 100%|██████████| 100/100 [00:27<00:00,  3.59it/s]


Loss: 0.9287002086639404

Fitting GP for trajectory 85/100 for MMSI 229081000


GP Training Progress: 100%|██████████| 100/100 [00:50<00:00,  2.00it/s]


Loss: 0.821551501750946

Fitting GP for trajectory 86/100 for MMSI 229248000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 238.11it/s]


Loss: 0.9036690592765808

Fitting GP for trajectory 87/100 for MMSI 229321000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 90.04it/s]


Loss: 0.7707103490829468

Fitting GP for trajectory 88/100 for MMSI 229340000


GP Training Progress: 100%|██████████| 100/100 [00:06<00:00, 15.83it/s]


Loss: 0.8422638773918152

Fitting GP for trajectory 89/100 for MMSI 229347000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 270.60it/s]


Loss: 1.1983578205108643

Fitting GP for trajectory 90/100 for MMSI 229410000


GP Training Progress: 100%|██████████| 100/100 [00:05<00:00, 18.52it/s]


Loss: 1.067389726638794

Fitting GP for trajectory 91/100 for MMSI 229551000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 284.38it/s]


Loss: 0.8405502438545227

Fitting GP for trajectory 92/100 for MMSI 229624000


GP Training Progress: 100%|██████████| 100/100 [00:01<00:00, 50.64it/s]


Loss: 0.9216511845588684

Fitting GP for trajectory 93/100 for MMSI 229680000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 22.82it/s]


Loss: 1.1969425678253174

Fitting GP for trajectory 94/100 for MMSI 229726000


GP Training Progress: 100%|██████████| 100/100 [00:16<00:00,  5.91it/s]


Loss: 0.8346900343894958

Fitting GP for trajectory 95/100 for MMSI 229760000


GP Training Progress: 100%|██████████| 100/100 [00:35<00:00,  2.85it/s]


Loss: 0.7187902927398682

Fitting GP for trajectory 96/100 for MMSI 229779000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 21.58it/s]


Loss: 0.8376892805099487

Fitting GP for trajectory 97/100 for MMSI 229830000


GP Training Progress: 100%|██████████| 100/100 [00:04<00:00, 23.59it/s]


Loss: 1.1457406282424927

Fitting GP for trajectory 98/100 for MMSI 229869000


GP Training Progress: 100%|██████████| 100/100 [00:14<00:00,  6.93it/s]


Loss: 0.7765158414840698

Fitting GP for trajectory 99/100 for MMSI 229894000


GP Training Progress: 100%|██████████| 100/100 [00:00<00:00, 275.40it/s]


Loss: 1.3274036645889282

Fitting GP for trajectory 100/100 for MMSI 229903000


GP Training Progress: 100%|██████████| 100/100 [00:12<00:00,  7.81it/s]

Loss: 0.8263754844665527





In [63]:
# Take the first fitted GP (dict insertion order) as a demonstration model.
model = next(iter(models.values()))

# List every learnable parameter; names are left-aligned in a 42-character column.
for param_name, param in model.named_parameters():
    print(f'Parameter name: {param_name:42} value = {param.tolist()}')

# Blank separator line, then the lengthscale of the first data-covariance kernel.
print(f"\n{model.covar_module.data_covar_module.kernels[0].lengthscale.item()}")

Parameter name: likelihood.raw_task_noises                 value = [-0.6876325607299805, -0.581544041633606, -0.9424332976341248, -0.5385560393333435, -0.9444971680641174, -0.9455525875091553]
Parameter name: likelihood.raw_noise                       value = [-0.9224559664726257]
Parameter name: mean_module.base_means.0.raw_constant      value = 0.06738271564245224
Parameter name: mean_module.base_means.1.raw_constant      value = -0.10569565743207932
Parameter name: mean_module.base_means.2.raw_constant      value = -4.691566209658049e-05
Parameter name: mean_module.base_means.3.raw_constant      value = 0.0004817449953407049
Parameter name: mean_module.base_means.4.raw_constant      value = -0.0001320840819971636
Parameter name: mean_module.base_means.5.raw_constant      value = -0.0004329355724621564
Parameter name: covar_module.task_covar_module.covar_factor value = [[-0.3279779851436615], [-0.39416366815567017], [-0.001594276400282979], [-0.0971570685505867], [0.01094231568276882

In [64]:
# Evaluate the demonstration GP (`model`) on a dense grid of time points.
#
# The likelihood and the time range are read from `model` itself. The globals
# `likelihood` and `times` left behind by the fitting loop belong to the LAST
# trajectory that was fitted, which is generally a different vessel from the
# one `model` was trained on; mixing them produces meaningless predictions.
likelihood = model.likelihood
train_times = model.train_inputs[0]

# Set the model and likelihood to evaluation mode
model.eval()
likelihood.eval()

# Generate test inputs: 200 evenly spaced time points spanning the model's own
# training times, in the same dtype as the training inputs.
test_times = torch.linspace(
    train_times.min().item(),
    train_times.max().item(),
    200,
    dtype=train_times.dtype,
).unsqueeze(1).to(device)

# Make predictions without tracking gradients; the jitter value matches the one
# used while fitting.
with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.cholesky_jitter(1e-3):
    predictions = likelihood(model(test_times))
    mean = predictions.mean
    lower, upper = predictions.confidence_region()





### Plot GP Solution

In [53]:
# NOTE: every line of this cell is commented out, so running it does nothing.
# Disabled sketch: for each fitted MMSI, find its trajectory in the dataset,
# evaluate the GP at 500 evenly spaced times, and plot the fit and the ship path.
# NOTE(review): the two .eval() calls sit above the loop, so they would act on
# whatever `model` / `likelihood` are bound to at that point, not on the
# per-MMSI objects selected inside the loop — confirm before re-enabling.
# model.eval()
# likelihood.eval()
# from plotting_utils import *

# for mmsi in models:
#     model = models[mmsi]
#     likelihood = likelihoods[mmsi]
#     # Get the corresponding data for this MMSI
#     # If you want to use the same train/test split as before:
#     times, state_trajectory = None, None
#     for entry in gp_regression_dataset:
#         if entry[0] == mmsi:
#             _, times, state_trajectory = entry
#             break
#     if times is None:
#         continue  # skip if MMSI not found

#     train_X = times.clone().detach().unsqueeze(1).cpu()
#     train_Y = state_trajectory.clone().detach().cpu()

#     test_X = torch.linspace(times.min(), times.max(), 500).unsqueeze(1).to(device)
    
#     test_Y = eval_model(model, likelihood, test_X)

#     plot_gp(train_X, train_Y, test_X, test_Y)
#     plot_single_ship_path(mmsi, times, state_trajectory)


In [54]:
# NOTE: every line of this cell is commented out, so running it does nothing.
# Disabled sketch: an MMSI dropdown whose callback looks up the selected
# vessel's model, likelihood and trajectory, evaluates the GP at 500 evenly
# spaced times, and plots the fit together with the ship path.
# import ipywidgets as widgets
# from IPython.display import display, clear_output
# from plotting_utils import *


# def plot_for_mmsi(selected_mmsi):
#     clear_output(wait=True)
#     model = models[selected_mmsi]
#     likelihood = likelihoods[selected_mmsi]
#     # Get the corresponding data for this MMSI
#     times, state_trajectory = None, None
#     for entry in gp_regression_dataset:
#         if entry[0] == selected_mmsi:
#             _, times, state_trajectory = entry
#             break
#     if times is None:
#         print("No data for MMSI:", selected_mmsi)
#         return

#     train_X = times.clone().detach().unsqueeze(1).cpu()
#     train_Y = state_trajectory.clone().detach().cpu()
#     test_X = torch.linspace(times.min(), times.max(), 500).unsqueeze(1).to(device)
#     test_Y = eval_model(model, likelihood, test_X)

#     plot_gp(train_X, train_Y, test_X, test_Y)
#     plot_single_ship_path(selected_mmsi, times, state_trajectory)

    
# mmsi_dropdown = widgets.Dropdown(
#     options=list(models.keys()),
#     description='MMSI:',
#     disabled=False,
# )

# widgets.interact(plot_for_mmsi, selected_mmsi=mmsi_dropdown)
    
    
    

In [55]:
# Print the distinct values of the dataframe's 'MMSI' column.
mmsi_column_values = gp_regression_dataset.df['MMSI'].values
print(pd.unique(mmsi_column_values))

# Look up the vessel group for one MMSI; as the cell's last expression, the
# result is displayed as the cell output.
gp_regression_dataset.get_vessel_group_by_mmsi(3660489)

[367669550 367118980 636018568 ... 367619000 309108000 368926390]


'Other'

## Create the kernel-parameter-to-ship-MMSI dataset


In [95]:
# Build the classification dataset from the regression dataset and the fitted GP models.
kernel_classification_dataset = GPKernelShipClassificationDataset(gp_regression_dataset, models, device)
unique_group_ids = kernel_classification_dataset.get_unique_group_ids()

# 80/20 train/test split. The explicitly seeded generator makes the partition
# identical on every run, so train/test metrics are comparable between runs;
# without it the split depends on the global RNG state at the time the cell runs.
split_generator = torch.Generator().manual_seed(0)
train_classification_dataset, test_classification_dataset = torch.utils.data.random_split(
    kernel_classification_dataset, [0.8, 0.2], generator=split_generator
)

# Batches of 32. Only the training loader shuffles; its shuffle order still
# draws from the global RNG and is not fixed by the split generator above.
train_loader = torch.utils.data.DataLoader(train_classification_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_classification_dataset, batch_size=32, shuffle=False)

## Train the classification model

In [100]:
from gp_kernel_ship_classification_network import GPKernelShipClassificationNetwork
# Cross-entropy loss for the classifier.
criterion = torch.nn.CrossEntropyLoss()

# Classifier with one output per class index from 0 up to the largest group id.
# NOTE(review): the inline comment attributes the "+ 1" to a background class — confirm.
# NOTE(review): this rebinds `model`, which earlier cells used for a fitted GP.
model = GPKernelShipClassificationNetwork(
    input_dim=2,
    num_classes=max(unique_group_ids) + 1,  # +1 for background class
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Hand the network, data loaders, loss and optimizer to the trainer and run 500 epochs.
trainer = GPKernelShipClassificationTrainer(
    model, train_loader, test_loader, criterion, optimizer, device
)
trainer.train(num_epochs=500)

Epoch 500/500 | Train Loss: 0.6812 | Train Acc: 0.7500 | Test Loss: 4.4817 | Test Acc: 0.4000
