In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


## Load Models

In [2]:
# SECURITY: these checkpoints are full pickles, not bare weight tensors -- each one stores a
# fitted scaler object under 'scaler' (read via attributes in the cells below). Unpickling can
# execute arbitrary code, so only load checkpoint files that come from a trusted source.
def _load_checkpoint(path):
    """Load one VRF checkpoint onto the CPU and return the unpickled object.

    ``weights_only=False`` is passed explicitly because PyTorch 2.6 changed the default of
    ``torch.load`` to ``weights_only=True``, which refuses to unpickle arbitrary Python
    objects such as the scaler stored in these files. The keyword requires torch >= 1.13.
    """
    return torch.load(path, map_location=torch.device('cpu'), weights_only=False)


# One checkpoint per training dataset, plus the one trained with all data shared.
vrf_brest = _load_checkpoint('lstm_1_350_fc_150_share_all_window_1024_stride_1024_crs_3857___batchsize_1__brest_dataset.pth')
vrf_norway = _load_checkpoint('lstm_1_350_fc_150_share_all_window_1024_stride_1024_crs_3857___batchsize_1__norway_dataset_.pth')
vrf_piraeus = _load_checkpoint('lstm_1_350_fc_150_share_all_window_1024_stride_1024_crs_3857___batchsize_1__piraeus_dataset.pth')
vrf_mt = _load_checkpoint('lstm_1_350_fc_150_share_all_window_1024_stride_1024_crs_3857___batchsize_1__mt_dataset.pth')
vrf_share_all = _load_checkpoint('lstm_1_350_fc_150_share_all_awindow_1024_stride_1024_crs_3857___batchsize_1__share_all.pth')

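**Note:** the checkpoints loaded above also contain a pickled scaler object. See the scikit-learn notes on the security and maintainability limitations of pickle-based persistence: https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations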


In [None]:
# Print the fitted scaler statistics of every checkpoint, one line per checkpoint:
# `mean_` and `scale_` separated by a tab, in the order the checkpoints were loaded.
for checkpoint in (vrf_brest, vrf_norway, vrf_piraeus, vrf_mt, vrf_share_all):
    fitted_scaler = checkpoint['scaler']
    print(fitted_scaler.mean_, fitted_scaler.scale_, sep='\t')

## Load the Test Sets of the Brest, Norway, and Piraeus Datasets

In [5]:
import pandas as pd
import dataset as ds

In [6]:
# For each region, read (1) the train/dev/test index split and (2) the windowed trajectory
# data. The files are pickles, so they must come from a trusted source.

# Norway
nor_ix = pd.read_pickle('./pkl/norway_dataset_train_dev_test_indices.pkl')
nor_data = pd.read_pickle('./pkl/norway_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')

# Brest
bre_ix = pd.read_pickle('./pkl/brest_dataset_train_dev_test_indices.pkl')
bre_data = pd.read_pickle('./pkl/brest_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')

# Piraeus
pir_ix = pd.read_pickle('./pkl/piraeus_dataset_train_dev_test_indices.pkl')
pir_data = pd.read_pickle('./pkl/piraeus_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')

In [11]:
# One test dataset per region: the rows selected by that region's 'test' indices (copied so
# the loaded frame is not shared), paired with the scaler stored in that region's checkpoint.
nor_test_dataset = ds.VRFDataset(
    data=nor_data.iloc[nor_ix['test']].copy(),
    scaler=vrf_norway['scaler'],
)
bre_test_dataset = ds.VRFDataset(
    data=bre_data.iloc[bre_ix['test']].copy(),
    scaler=vrf_brest['scaler'],
)
pir_test_dataset = ds.VRFDataset(
    data=pir_data.iloc[pir_ix['test']].copy(),
    scaler=vrf_piraeus['scaler'],
)

# One loader per test dataset: batches of a single sample, collated with the dataset's own
# `pad_collate` method.
nor_test_loader = ds.DataLoader(
    nor_test_dataset,
    batch_size=1,
    collate_fn=nor_test_dataset.pad_collate,
)
bre_test_loader = ds.DataLoader(
    bre_test_dataset,
    batch_size=1,
    collate_fn=bre_test_dataset.pad_collate,
)
pir_test_loader = ds.DataLoader(
    pir_test_dataset,
    batch_size=1,
    collate_fn=pir_test_dataset.pad_collate,
)

## Model Inference

In [14]:
import models as ml
import train as tr

In [20]:
# Evaluate the share-all weights on each region's test loader.
# Each entry: (display name, test loader, scaler from that region's own checkpoint).
evaluation_targets = (
    ('Norway', nor_test_loader, vrf_norway['scaler']),
    ('Brest', bre_test_loader, vrf_brest['scaler']),
    ('Piraeus', pir_test_loader, vrf_piraeus['scaler']),
)

for name, loader, scaler in evaluation_targets:
    # Only the first two components of the scaler statistics are passed to the model.
    scale = dict(
        mu=torch.tensor(scaler.mean_[:2]),
        sigma=torch.tensor(scaler.scale_[:2]),
    )

    # A fresh model per region; the weights always come from the share-all checkpoint.
    # NOTE(review): the scale comes from the region scaler while the weights come from the
    # share-all checkpoint -- confirm load_state_dict does not overwrite mu/sigma and that
    # this pairing is intended.
    model = ml.VesselRouteForecasting(hidden_size=350, fc_layers=[150,], scale=scale)
    model.load_state_dict(vrf_share_all['model_state_dict'])
    model.eval()

    print(f'Evaluating Share-all VRF on {name} Test Set...')
    # The criterion is deliberately built inside the loop, once per region, as before.
    tr.evaluate_model(
        model,
        torch.device('cpu'),
        criterion=tr.RMSELoss(eps=1e-4),
        test_loader=loader,
    )

Evaluating Share-all VRF on Norway Test Set...
self.eps=0.0001


                                                                           

Loss: 7.06312 |  Accuracy: 9.98874 | 8.91934; 416.72978; 531.82347; 958.30709; 956.40897; 1287.60952; nan m
Evaluating Share-all VRF on Brest Test Set...
self.eps=0.0001


                                                                          

Loss: 99.34281 |  Accuracy: 140.49193 | 43.08056; 557.39830; 1145.19652; 1594.11798; 3025.55032; 2440.73458; 3515.93491 m
Evaluating Share-all VRF on Piraeus Test Set...
self.eps=0.0001


                                                                           

Loss: 65.06891 |  Accuracy: 92.02131 | 39.50148; 475.13291; 927.04572; 1673.22440; 4090.95868; 3172.36473; 4648.08424 m
