In [1]:
import torch
import pandas as pd

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import train as tr
import dataset as ds
import models as ml

# Load Datasets and Train/Test Splits

In [3]:
# Per-region inputs, grouped by region:
#   *_delta_series : trajectory delta windows (window 1024 / stride 1024 / crs 3857,
#                    according to the filenames)
#   *_ix_series    : saved split indices, indexed with ['train'] / ['test'] when the
#                    datasets are sliced into train and test subsets
# NOTE(review): `pd.read_pickle` unpickles arbitrary Python objects — only load trusted files.

# Brest
brest_delta_series = pd.read_pickle('./data/pkl/brest_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')
brest_ix_series = pd.read_pickle('./data/pkl/brest_dataset_train_dev_test_indices_.pkl')

# Norway
norway_delta_series = pd.read_pickle('./data/pkl/norway_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')
norway_ix_series = pd.read_pickle('./data/pkl/norway_dataset_train_dev_test_indices_.pkl')

# Piraeus (the delta-series variable keeps its historical `pireaus_` spelling,
# which other cells depend on)
pireaus_delta_series = pd.read_pickle('./data/pkl/piraeus_dataset_window_1024_stride_1024_crs_3857__.traj_delta_windows.pickle')
piraeus_ix_series = pd.read_pickle('./data/pkl/piraeus_dataset_train_dev_test_indices_.pkl')

In [4]:
# Partition each region's delta windows into train and test subsets using the
# positional (`.iloc`) indices loaded alongside them; `.copy()` makes every
# subset an independent object rather than a slice of the full series.
# Only the 'train' and 'test' index entries are used here; a 'dev' split
# (suggested by the index filenames) is not materialized.

# Brest
brest_delta_series_train = brest_delta_series.iloc[brest_ix_series['train']].copy()
brest_delta_series_test = brest_delta_series.iloc[brest_ix_series['test']].copy()

# Norway
norway_delta_series_train = norway_delta_series.iloc[norway_ix_series['train']].copy()
norway_delta_series_test = norway_delta_series.iloc[norway_ix_series['test']].copy()

# Piraeus
pireaus_delta_series_train = pireaus_delta_series.iloc[piraeus_ix_series['train']].copy()
pireaus_delta_series_test = pireaus_delta_series.iloc[piraeus_ix_series['test']].copy()

# Load Models

In [5]:
# Load the per-client ("flwr_local") FedVRF checkpoints and the global federated
# checkpoint ("flwr_global_epoch170") onto the CPU.
# Later cells read the 'scaler' entry of each local checkpoint (an object exposing
# `mean_` / `scale_`, presumably a scikit-learn scaler) and the 'model_state_dict'
# entry of the global checkpoint.
#
# `weights_only=False` is passed explicitly because these checkpoints pickle a
# non-tensor scaler object: since PyTorch 2.6, `torch.load` defaults to
# `weights_only=True`, which refuses to unpickle such objects and makes these loads fail.
# NOTE(review): `weights_only=False` executes arbitrary pickled code — only load
# checkpoints from a trusted source.
fedvrf_brest = torch.load('./data/pth/fl/lstm_1_350_fc_150window_1024_stride_1024_crs_3857___batchsize_1__brest_dataset_stratified.flwr_local.pth', map_location=torch.device('cpu'), weights_only=False)
fedvrf_norway = torch.load('./data/pth/fl/lstm_1_350_fc_150window_1024_stride_1024_crs_3857___batchsize_1__norway_dataset_.flwr_local.pth', map_location=torch.device('cpu'), weights_only=False)
fedvrf_pireaus = torch.load('./data/pth/fl/lstm_1_350_fc_150window_1024_stride_1024_crs_3857___batchsize_1__piraeus_dataset_.flwr_local.pth', map_location=torch.device('cpu'), weights_only=False)
fedvrf_global = torch.load('./data/pth/fl/lstm_1_350_fc_150_window_1024_stride_1024_crs_3857___.flwr_global_epoch170.pth', map_location=torch.device('cpu'), weights_only=False)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [11]:
# FedVRF (Brest; Global Model)
# Evaluate the global federated weights (`fedvrf_global`) on the Brest test split.
# The Brest client's own scaler is handed to the dataset, and the first two
# entries of its mean_/scale_ become the model's `scale` (mu/sigma) — presumably
# the two coordinate-delta features; TODO confirm against VRFDataset.
brest_scaler = fedvrf_brest['scaler']

test_set = ds.VRFDataset(brest_delta_series_test, scaler=brest_scaler)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    collate_fn=test_set.pad_collate,
)

brest_scale = {
    'mu': torch.tensor(brest_scaler.mean_[:2]),
    'sigma': torch.tensor(brest_scaler.scale_[:2]),
}
fedvrf_global_model = ml.VesselRouteForecasting(hidden_size=350, fc_layers=[150], scale=brest_scale)
fedvrf_global_model.load_state_dict(fedvrf_global['model_state_dict'])
fedvrf_global_model.eval()

# Last expression: its return value is displayed as the cell output.
tr.evaluate_model(
    fedvrf_global_model,
    torch.device('cpu'),
    criterion=tr.RMSELoss(eps=1e-4),
    test_loader=test_loader,
)

self.eps=0.0001


                                                                          

Loss: 111.02232 |  Accuracy: 157.00923 | 53.31057; 626.14830; 1233.42963; 1722.14176; 2858.34507; 2936.93127; 3262.54171 m


(tensor(111.0223), 157.0092313226689)

In [12]:
# FedVRF (Norway; Global Model)
# Evaluate the global federated weights (`fedvrf_global`) on the Norway test split.
# The Norway client's own scaler is handed to the dataset, and the first two
# entries of its mean_/scale_ become the model's `scale` (mu/sigma) — presumably
# the two coordinate-delta features; TODO confirm against VRFDataset.
norway_scaler = fedvrf_norway['scaler']

test_set = ds.VRFDataset(norway_delta_series_test, scaler=norway_scaler)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    collate_fn=test_set.pad_collate,
)

norway_scale = {
    'mu': torch.tensor(norway_scaler.mean_[:2]),
    'sigma': torch.tensor(norway_scaler.scale_[:2]),
}
fedvrf_global_model_nor = ml.VesselRouteForecasting(hidden_size=350, fc_layers=[150], scale=norway_scale)
fedvrf_global_model_nor.load_state_dict(fedvrf_global['model_state_dict'])
fedvrf_global_model_nor.eval()

# Last expression: its return value is displayed as the cell output.
tr.evaluate_model(
    fedvrf_global_model_nor,
    torch.device('cpu'),
    criterion=tr.RMSELoss(eps=1e-4),
    test_loader=test_loader,
)

self.eps=0.0001


                                                                           

Loss: 5.51781 |  Accuracy: 7.80332 | 7.02033; 276.47297; 618.59717; 188.85668; 453.65445; 1306.63910; nan m


(tensor(5.5178), 7.803319095115764)

In [13]:
# FedVRF (Piraeus; Global Model)
# Evaluate the global federated weights (`fedvrf_global`) on the Piraeus test split.
# The Piraeus client's own scaler is handed to the dataset, and the first two
# entries of its mean_/scale_ become the model's `scale` (mu/sigma) — presumably
# the two coordinate-delta features; TODO confirm against VRFDataset.
piraeus_scaler = fedvrf_pireaus['scaler']

test_set = ds.VRFDataset(pireaus_delta_series_test, scaler=piraeus_scaler)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    collate_fn=test_set.pad_collate,
)

piraeus_scale = {
    'mu': torch.tensor(piraeus_scaler.mean_[:2]),
    'sigma': torch.tensor(piraeus_scaler.scale_[:2]),
}
fedvrf_global_model_pir = ml.VesselRouteForecasting(hidden_size=350, fc_layers=[150], scale=piraeus_scale)
fedvrf_global_model_pir.load_state_dict(fedvrf_global['model_state_dict'])
fedvrf_global_model_pir.eval()

# Last expression: its return value is displayed as the cell output.
tr.evaluate_model(
    fedvrf_global_model_pir,
    torch.device('cpu'),
    criterion=tr.RMSELoss(eps=1e-4),
    test_loader=test_loader,
)

self.eps=0.0001


                                                                           

Loss: 69.31537 |  Accuracy: 98.02672 | 54.24953; 504.80842; 1137.09654; 1016.52809; 2775.92908; 3898.66991; nan m


(tensor(69.3154), 98.02671942094705)