In [1]:
from pathlib import Path

import torch
import pandas as pd

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import train as tr
import dataset as ds
import models as ml

# Load Datasets and Train/Test Splits

In [3]:
# Windowed trajectory-delta series for each dataset, plus the train/dev/test
# index splits used to partition them in the next cell.
# SECURITY: pd.read_pickle unpickles arbitrary Python objects (it can execute
# code) -- only load these files from a trusted source.
PKL_DIR = Path('./data/pkl')

brest_delta_series = pd.read_pickle(PKL_DIR / 'brest_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')
norway_delta_series = pd.read_pickle(PKL_DIR / 'norway_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')
piraeus_delta_series = pd.read_pickle(PKL_DIR / 'piraeus_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')
# Backward-compatible alias: later cells still use the misspelled name.
pireaus_delta_series = piraeus_delta_series

# Split indices. Only the 'train' and 'test' entries are read (via .iloc, so
# they are positional); the 'dev' split implied by the filename is unused here.
# Note the Norway split file has no "stratified" suffix, unlike the other two.
brest_ix_series = pd.read_pickle(PKL_DIR / 'brest_dataset_train_dev_test_indices_stratified.pkl')
norway_ix_series = pd.read_pickle(PKL_DIR / 'norway_dataset_train_dev_test_indices_.pkl')
piraeus_ix_series = pd.read_pickle(PKL_DIR / 'piraeus_dataset_train_dev_test_indices_stratified.pkl')

In [4]:
# Partition each dataset's windows into train and test subsets by position
# (.iloc). .copy() yields independent frames, so later modifications to a
# split never propagate back to the full series.
brest_delta_series_train = brest_delta_series.iloc[brest_ix_series['train']].copy()
brest_delta_series_test = brest_delta_series.iloc[brest_ix_series['test']].copy()

norway_delta_series_train = norway_delta_series.iloc[norway_ix_series['train']].copy()
norway_delta_series_test = norway_delta_series.iloc[norway_ix_series['test']].copy()

pireaus_delta_series_train = pireaus_delta_series.iloc[piraeus_ix_series['train']].copy()
pireaus_delta_series_test = pireaus_delta_series.iloc[piraeus_ix_series['test']].copy()

# Load Model Checkpoints

In [5]:
# Checkpoint dicts: presumably one per client ("flwr_local") plus the
# aggregated global model ("flwr_global", epoch 170 per the filename).
# Cells below read 'scaler' (a pickled scikit-learn object exposing
# mean_/scale_) from the per-client checkpoints and 'model_state_dict' from
# the global one.
#
# weights_only=False is passed explicitly: torch>=2.6 defaults to
# weights_only=True, whose restricted unpickler rejects the scikit-learn
# scaler. Full unpickling can execute arbitrary code -- only load trusted
# checkpoints. (The `weights_only` argument requires torch>=1.13.)
PTH_DIR = Path('./data/pth/perfl')
CPU = torch.device('cpu')

fedvrf_brest = torch.load(PTH_DIR / 'lstm_1_350_fc_150window_1024_stride_1024_crs_3857___batchsize_1__brest_dataset_stratified.flwr_local.pth', map_location=CPU, weights_only=False)
fedvrf_norway = torch.load(PTH_DIR / 'lstm_1_350_fc_150window_1024_stride_1024_crs_3857___batchsize_1__norway_dataset_.flwr_local.pth', map_location=CPU, weights_only=False)
fedvrf_piraeus = torch.load(PTH_DIR / 'lstm_1_350_fc_150window_1024_stride_1024_crs_3857___batchsize_1__piraeus_dataset_stratified.flwr_local.pth', map_location=CPU, weights_only=False)
# Backward-compatible alias: later cells still use the misspelled name.
fedvrf_pireaus = fedvrf_piraeus
# Note: the global filename has "fc_150_window" (with underscore), unlike the local ones.
fedvrf_global = torch.load(PTH_DIR / 'lstm_1_350_fc_150_window_1024_stride_1024_crs_3857___.flwr_global_epoch170.pth', map_location=CPU, weights_only=False)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [9]:
# FedVRF (Brest; Global Model)
# Evaluate the shared global weights on Brest's test windows, using the scaler
# stored in Brest's local checkpoint. The model's `scale` takes the first two
# components of that scaler (presumably the x/y coordinates -- confirm).
brest_scaler = fedvrf_brest['scaler']

test_set = ds.VRFDataset(brest_delta_series_test, scaler=brest_scaler)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    collate_fn=test_set.pad_collate,
)

brest_scale = {
    'mu': torch.tensor(brest_scaler.mean_[:2]),
    'sigma': torch.tensor(brest_scaler.scale_[:2]),
}
fedvrf_global_model = ml.VesselRouteForecasting(hidden_size=350, fc_layers=[150], scale=brest_scale)
fedvrf_global_model.load_state_dict(fedvrf_global['model_state_dict'])
fedvrf_global_model.eval()

tr.evaluate_model(
    fedvrf_global_model,
    torch.device('cpu'),
    criterion=tr.RMSELoss(eps=1e-4),
    test_loader=test_loader,
)

self.eps=0.0001


                                                                          

Loss: 93.33247 |  Accuracy: 131.99202 | 42.92935; 535.57526; 1123.35376; 1546.13069; 2477.82376; 2402.55439; 1934.39757 m


(tensor(93.3325), 131.99202332826044)

In [10]:
# FedVRF (Norway; Global Model)
# Evaluate the shared global weights on Norway's test windows, using the
# scaler stored in Norway's local checkpoint. The model's `scale` takes the
# first two components of that scaler (presumably the x/y coordinates -- confirm).
norway_scaler = fedvrf_norway['scaler']

test_set = ds.VRFDataset(norway_delta_series_test, scaler=norway_scaler)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    collate_fn=test_set.pad_collate,
)

norway_scale = {
    'mu': torch.tensor(norway_scaler.mean_[:2]),
    'sigma': torch.tensor(norway_scaler.scale_[:2]),
}
fedvrf_global_model_nor = ml.VesselRouteForecasting(hidden_size=350, fc_layers=[150], scale=norway_scale)
fedvrf_global_model_nor.load_state_dict(fedvrf_global['model_state_dict'])
fedvrf_global_model_nor.eval()

tr.evaluate_model(
    fedvrf_global_model_nor,
    torch.device('cpu'),
    criterion=tr.RMSELoss(eps=1e-4),
    test_loader=test_loader,
)

self.eps=0.0001


                                                                           

Loss: 5.76135 |  Accuracy: 8.14774 | 7.25030; 355.48074; 658.99588; 99.50950; 923.48178; 1095.47244; nan m


(tensor(5.7613), 8.147743728623084)

In [11]:
# FedVRF (Piraeus; Global Model)
# Evaluate the shared global weights on Piraeus's test windows, using the
# scaler stored in Piraeus's local checkpoint. The model's `scale` takes the
# first two components of that scaler (presumably the x/y coordinates -- confirm).
piraeus_scaler = fedvrf_pireaus['scaler']

test_set = ds.VRFDataset(pireaus_delta_series_test, scaler=piraeus_scaler)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    collate_fn=test_set.pad_collate,
)

piraeus_scale = {
    'mu': torch.tensor(piraeus_scaler.mean_[:2]),
    'sigma': torch.tensor(piraeus_scaler.scale_[:2]),
}
fedvrf_global_model_pir = ml.VesselRouteForecasting(hidden_size=350, fc_layers=[150], scale=piraeus_scale)
fedvrf_global_model_pir.load_state_dict(fedvrf_global['model_state_dict'])
fedvrf_global_model_pir.eval()

tr.evaluate_model(
    fedvrf_global_model_pir,
    torch.device('cpu'),
    criterion=tr.RMSELoss(eps=1e-4),
    test_loader=test_loader,
)

self.eps=0.0001


                                                                           

Loss: 59.94801 |  Accuracy: 84.77926 | 37.23819; 441.44862; 974.67588; 1545.55927; 3222.29111; 2636.96627; 2830.85819 m


(tensor(59.9480), 84.77926496393852)