In [None]:
import datetime
import inspect
import sys, os
import warnings

import contextily as ctx
import geopandas as gpd
import numpy as np
import pandas as pd
import shapely
import torch
import tqdm
import warnings

In [None]:
import st_toolkit as stt

# Load Data

In [None]:
# Read the Piraeus AIS records and turn the WGS84lon/WGS84lat columns into point geometries (EPSG:4326).
actual_points = pd.read_csv(
    './data/piraeus-dataset/datasetpr_split_trajectories_sets_shuffle.csv',
    parse_dates=['timestamp'],
).pipe(
    lambda raw_points: gpd.GeoDataFrame(
        raw_points,
        geometry=gpd.points_from_xy(raw_points['WGS84lon'], raw_points['WGS84lat']),
        crs=4326,
    )
)

# Load VRF Models

In [None]:
import models as ml

In [None]:
# All checkpoints are mapped onto the CPU so the notebook runs without a GPU.
_CPU = torch.device('cpu')


def _load_checkpoint(path):
    '''
        Load a VRF checkpoint onto the CPU.

        The entries used below are 'model_state_dict' and (for the CML checkpoints) 'scaler'.
    '''
    # NOTE(review): the checkpoints hold pickled Python objects (the fitted `scaler`), which need full
    # unpickling. torch >= 2.6 defaults to weights_only=True and is likely to reject them, so opt out
    # explicitly where the argument exists. Only load checkpoints that come from a trusted source.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location=_CPU, weights_only=False)
    return torch.load(path, map_location=_CPU)


def _build_vrf_model(state_dict, scaler, hidden_size=350, fc_layers=(150,)):
    '''
        Instantiate a VesselRouteForecasting model, load ``state_dict`` into it and switch it to eval mode.

        The model's ``scale`` statistics are the first two entries of ``scaler.mean_`` / ``scaler.scale_``.
    '''
    model = ml.VesselRouteForecasting(
        hidden_size=hidden_size, fc_layers=list(fc_layers),
        scale=dict(
            mu=torch.tensor(scaler.mean_[:2]),
            sigma=torch.tensor(scaler.scale_[:2])
        )
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


vrf_piraeus_pth_sn_cml   = _load_checkpoint('./data/pth/lstm_1_350_fc_150_share_all_window_1024_stride_1024_crs_3857___batchsize_1__piraeus_dataset.pth')
vrf_piraeus_pth_sd_cml   = _load_checkpoint('./data/pth/lstm_1_350_fc_150_share_all_window_1024_stride_1024_crs_3857___batchsize_1__share_all.pth')
fedvrf_piraeus_pth       = _load_checkpoint('./data/pth/fl/lstm_1_350_fc_150_window_1024_stride_1024_crs_3857___.flwr_global_epoch170.pth')
perfl_fedvrf_piraeus_pth = _load_checkpoint('./data/pth/perfl/lstm_1_350_fc_150_window_1024_stride_1024_crs_3857___.flwr_global_epoch170.pth')

# SN-CML / SD-CML: use the statistics stored with their own checkpoints.
sn_cml_model = _build_vrf_model(vrf_piraeus_pth_sn_cml['model_state_dict'], vrf_piraeus_pth_sn_cml['scaler'])
sd_cml_model = _build_vrf_model(vrf_piraeus_pth_sd_cml['model_state_dict'], vrf_piraeus_pth_sd_cml['scaler'])

# FL / PerFL: use the same statistics as the Piraeus dataset train set (which is equal to the stats of SN-CML).
fl_model    = _build_vrf_model(fedvrf_piraeus_pth['model_state_dict'], vrf_piraeus_pth_sn_cml['scaler'])
perfl_model = _build_vrf_model(perfl_fedvrf_piraeus_pth['model_state_dict'], vrf_piraeus_pth_sn_cml['scaler'])

# Get the date and hour with maximal concurrent traffic (most distinct vessels)

In [None]:
# Number of distinct vessels observed per (date, hour) of the dataset.
traffic_flow = (
    actual_points
    .groupby([actual_points.timestamp.dt.date, actual_points.timestamp.dt.hour])['vessel_id']
    .nunique()
    .sort_index()
)
# The (date, hour) pair(s) with the most distinct vessels; the slice hard-coded in the next cell
# (2019-03-28, 05h) is presumably taken from this output.
traffic_flow.loc[traffic_flow == traffic_flow.max()]

In [None]:
# Keep every record of 2019-03-28 between 05:00 and 05:59 (presumably the busiest hour found above).
max_traffic_slice = actual_points.loc[
    lambda points: (points['timestamp'].dt.date == datetime.date(2019, 3, 28))
                   & (points['timestamp'].dt.hour == 5)
].copy()

In [None]:
# Records falling in the 15-min. resample bin labelled 05:15 (pandas' default left-closed bins: [05:15, 05:30)).
actual_traffic = (
    max_traffic_slice
    .set_index('timestamp')
    .resample(rule='15min')
    .get_group('2019-03-28 05:15:00')
)

In [None]:
# Number of distinct vessels present in the selected quarter-hour.
actual_traffic['vessel_id'].nunique()

# Get Area of Interest

In [None]:
def create_area_grid(spatial_area, crs=4326, quadrat_width=1000):
    '''
        Segment a spatial area into an (equally spaced) square grid.

        Input:
            * spatial_area: the (shapely) area to segment
            * crs: the CRS ``spatial_area`` is expressed in; it is assigned to the returned grid
            * quadrat_width: the squares' width, in the units of ``crs`` (no conversion is done:
              degrees for crs=4326, meters for crs=3857). Note that the defaults combined would give
              1000-degree squares, so pass quadrat_width explicitly.

        Output:
            * A GeoDataFrame whose geometry column ``geom`` holds the grid's squares
    '''
    # Cut the area into square quadrats. The helper is reached through st_toolkit's `ox` alias
    # (presumably osmnx); it is a private (underscore) API, so it may change between versions.
    geometry_cut = stt.ox.utils_geo._quadrat_cut_geometry(spatial_area, quadrat_width=quadrat_width)
    # Materialise the parts of the multi-part result explicitly so the constructor always receives a list.
    grid_gdf = gpd.GeoDataFrame(list(geometry_cut.geoms), columns=['geom'], geometry='geom', crs=crs)
    return grid_gdf

In [None]:
def classify_area_proximity(points, geoms, predicate='contains', out_name='area_id'):
    '''
        Tag every point with the position of the geometry in ``geoms`` that satisfies ``predicate`` with it.

        Input:
            * points: GeoDataFrame of points; it is modified IN PLACE (column ``out_name`` is (over)written)
            * geoms: GeoDataFrame/GeoSeries of geometries (e.g. grid squares)
            * predicate: spatial predicate evaluated as ``geom <predicate> point`` (default: ``contains``)
            * out_name: name of the output column

        Output:
            * ``points`` itself. ``out_name`` holds the 0-based POSITION of the matching geometry in ``geoms``
              (not its index label), or NaN when no geometry matches. If several geometries match a
              point, only one of them is kept.
    '''
    # Compare like with like: bring the geometries into the points' CRS when both are known and differ.
    points_crs, geoms_crs = getattr(points, 'crs', None), getattr(geoms, 'crs', None)
    if points_crs is not None and geoms_crs is not None and geoms_crs != points_crs:
        geoms = geoms.to_crs(points_crs)

    # Bulk query of the points' spatial index. `SpatialIndex.query` accepts an array of geometries
    # (geopandas >= 0.12); `query_bulk` is deprecated and removed in geopandas 1.0.
    # Row 0 of the result indexes `geoms` (the input), row 1 indexes `points` (the tree).
    geoms_idx, points_idx = points.sindex.query(geoms.geometry, predicate=predicate)

    points[out_name] = np.nan
    points.iloc[points_idx, points.columns.get_loc(out_name)] = geoms_idx

    return points

In [None]:
# Area of interest as an EPSG:3857 bounding box (minx, miny, maxx, maxy), in meters.
AOI_BOUNDS_3857 = (2.59e6, 4.53e6, 2.64e6, 4.59e6)
# Grid cell width: one nautical mile, in meters.
GRID_CELL_WIDTH_M = 1852

spatial_coverage = shapely.geometry.box(*AOI_BOUNDS_3857)
# Square grid over the area; the positional index is exposed as an `id` column.
spatial_coverage_grid = (
    create_area_grid(spatial_coverage, crs=3857, quadrat_width=GRID_CELL_WIDTH_M)
    .rename_axis('id')
    .reset_index()
)

# Get the Actual Traffic Flow for the Selected Time Frame

In [None]:
# Distinct vessels per grid cell in the actual quarter-hour; cells without any vessel do not appear.
# `area_id` is the positional index of the cell in the grid, which is what the `id` column holds.
actual_traffic_flow = (
    classify_area_proximity(actual_traffic.copy(), spatial_coverage_grid.to_crs(4326))
    .groupby('area_id')['vessel_id']
    .nunique()
    .rename(None)  # keep the Series unnamed, as the `.apply(...)` form produced
)

In [None]:
# Interactive choropleth of the actual distinct-vessel count per grid cell.
pd.merge(
    actual_traffic_flow.rename('traffic_flow'),
    spatial_coverage_grid,
    left_index=True,
    right_on='id',
).pipe(gpd.GeoDataFrame, geometry='geom', crs=3857).explore(column='traffic_flow')

# Get Predicted Traffic Flow for the Same $\Delta t = 15$ min. Window (30-min. Forecasts, 15-s. Steps)

In [None]:
def shapely_coords_numpy(geom):
    '''
        Return the coordinates of a single-vertex geometry (e.g. a Point) as a 1-D NumPy array.

        Raises ValueError for geometries with more or fewer than one vertex.
    '''
    coords = list(geom.coords)
    if len(coords) != 1:
        raise ValueError(f'expected a single-vertex geometry, got {len(coords)} vertices')
    return np.array(coords[0])


def create_delta_dataset(segment, time_name, speed_name, course_name, crs=3857):
    '''
        Turn one trajectory into per-point first differences (the VRF model features).

        Input:
            * segment: GeoDataFrame with the point records of a single trajectory; it is sorted by
              ``time_name`` IN PLACE
            * time_name, speed_name, course_name: the time, speed and course columns
            * crs: the CRS in which the position deltas are computed

        Output:
            * A DataFrame indexed like the (sorted) ``segment`` with columns
              ``dlon_curr``, ``dlat_curr`` (x/y delta from the previous point, in ``crs`` units),
              ``dspeed_curr``, ``dcourse_curr``, ``dt_curr`` (deltas from the previous point) and
              ``dt_next``, ``dlon_next``, ``dlat_next`` (deltas to the next point).
              The first and last points are dropped, since ``dt_curr`` / ``dt_next`` is undefined there.
    '''
    # In place on purpose: callers (e.g. vrf_predict) read the first row of `segment` afterwards.
    segment.sort_values(time_name, inplace=True)

    projected = segment.to_crs(crs).geometry

    def _positional(series):
        # Drop the labels so all columns line up by row position. Label-based joins would multiply
        # rows whenever the index has duplicates (e.g. repeated timestamps).
        return series.reset_index(drop=True)

    # Vectorised x/y access instead of building one pd.Series per row.
    # NOTE(review): course deltas are not wrapped to [-180, 180) — confirm this is intended.
    delta_curr = pd.DataFrame({
        'dlon_curr': _positional(projected.x),
        'dlat_curr': _positional(projected.y),
        'dspeed_curr': _positional(segment[speed_name]),
        'dcourse_curr': _positional(segment[course_name]),
        'dt_curr': _positional(segment[time_name]),
    }).diff()

    deltas = delta_curr.assign(
        dt_next=delta_curr['dt_curr'].shift(-1),
        dlon_next=delta_curr['dlon_curr'].shift(-1),
        dlat_next=delta_curr['dlat_curr'].shift(-1),
    )
    deltas.index = segment.index

    return deltas.dropna(subset=['dt_curr', 'dt_next'])

In [None]:
def vrf_predict(traj, model, scaler, lookahead=900, step=30, crs=3857,
                time_name='date_time_utc', speed_name='sog', course_name='cog',
                feats_in=('dlon_curr', 'dlat_curr', 'dt_curr', 'dt_next')):
    '''
        Autoregressively forecast a trajectory with a VRF model.

        Input:
            * traj: the observed trajectory (point geometries plus ``time_name``, ``speed_name`` and
              ``course_name`` columns); ``create_delta_dataset`` sorts it by ``time_name`` IN PLACE
            * model: the VRF model, called as ``model(x, lengths)``; its output is concatenated with a
              (1, 2) tensor of time-deltas, so it must be the next x/y displacement as a (1, 2) tensor
            * scaler: object whose ``transform`` normalises the feature matrix before every model call
            * lookahead, step: forecast horizon and step, in the units of ``time_name``
            * crs: the projected CRS in which displacements are computed
            * feats_in: model input columns; the first three MUST be the x-, y- and time-deltas,
              because they are cumulatively summed to rebuild the trajectory

        Output:
            * GeoDataFrame in EPSG:4326 with columns ``lon``/``lat`` (in ``crs`` units), ``time_name``
              and the point geometry. The rows are the first observed point, the rebuilt interior
              observations, then one prediction per value of ``range(step, lookahead + step + 1, step)``.
              Since ``create_delta_dataset`` drops the last observation (its ``dt_next`` is undefined),
              the first prediction re-estimates the last observed point.
    '''
    model.eval()
    # A list is required: pandas would treat a tuple as a single (MultiIndex) column key.
    feats_in = list(feats_in)

    traj_delta = create_delta_dataset(traj, time_name=time_name, speed_name=speed_name, course_name=course_name, crs=crs)
    model_input = torch.tensor(traj_delta[feats_in].values)

    # Every appended row uses `step` as both its dt_curr and dt_next.
    # NOTE(review): for the first prediction the elapsed time is really the last row's dt_next,
    # not `step` — confirm this approximation is intended.
    step_dts = torch.tensor([[step, step]], dtype=model_input.dtype)

    # Inference only: no autograd graph is needed across the autoregressive steps.
    with torch.no_grad():
        for _ in range(step, lookahead + step + 1, step):
            model_input_sc = scaler.transform(model_input)

            dxy_next = model(
                torch.tensor(model_input_sc).unsqueeze(0).float(),
                torch.tensor([len(model_input_sc)])
            ).detach()

            # Match the input's dtype explicitly rather than relying on torch.cat type promotion.
            new_delta = torch.cat((dxy_next.to(model_input.dtype), step_dts), dim=1)
            model_input = torch.cat((model_input, new_delta), dim=0)

    # Anchor for the cumulative sum: the first (earliest, as traj is now time-sorted) observed point, in `crs`.
    first_point = traj.iloc[[0]].to_crs(crs)
    traj_start = np.array([
        first_point.geometry.x.values[0],
        first_point.geometry.y.values[0],
        first_point[time_name].values[0]
    ])

    traj_pred = pd.DataFrame(
        torch.cumsum(
            torch.cat(
                (torch.tensor(traj_start).unsqueeze(0), model_input[:, :3])
            ),
            dim=0
        ).numpy(),
        columns=['lon', 'lat', time_name]
    )

    traj_pred = gpd.GeoDataFrame(traj_pred, crs=crs, geometry=gpd.points_from_xy(traj_pred['lon'], traj_pred['lat'])).to_crs(4326)
    return traj_pred.copy()

In [None]:
tqdm.tqdm.pandas()

# Forecast settings, in seconds (the `t` column is epoch seconds, see `unit='s'` below).
FORECAST_LOOKAHEAD_S = 1800  # 30 min.
FORECAST_STEP_S = 15

# Evaluated window: the same quarter-hour as `actual_traffic`, with the same [start, end) bounds
# as the left-closed 15-min. resample bin it was taken from.
FORECAST_WINDOW_START = pd.Timestamp('2019-03-28 05:15:00')
FORECAST_WINDOW_END = FORECAST_WINDOW_START + pd.Timedelta(minutes=15)

# model name -> (model, scaler); FL / PerFL reuse the SN-CML scaler (Piraeus train-set statistics).
vrf_models = {
    'sn_cml': (sn_cml_model, vrf_piraeus_pth_sn_cml['scaler']),
    'sd_cml': (sd_cml_model, vrf_piraeus_pth_sd_cml['scaler']),
    'fl':     (fl_model,     vrf_piraeus_pth_sn_cml['scaler']),
    'per_fl': (perfl_model,  vrf_piraeus_pth_sn_cml['scaler']),
}


def forecast_vessel(vessel_id_traj, model, scaler, lookahead=FORECAST_LOOKAHEAD_S, step=FORECAST_STEP_S):
    '''
        Forecast one vessel's trajectory and return only the predicted points.

        Returns None for trajectories with fewer than 3 points (too short to build the delta features).
    '''
    if len(vessel_id_traj) < 3:
        return None
    prediction = vrf_predict(
        vessel_id_traj.sort_values('t'),
        model=model,
        scaler=scaler,
        lookahead=lookahead, step=step,
        time_name='t', speed_name='speed', course_name='course'
    )
    # For n observed points, vrf_predict yields the n-1 rebuilt observations followed by the
    # predictions, the first of which re-estimates the last observed point: drop those n rows.
    return prediction.iloc[len(vessel_id_traj):].copy()


tff_res = dict()

for name, (model, scaler) in vrf_models.items():
    # For each quarter of an hour, forecast every vessel 30 min. ahead at 15 s. steps from the trajectory
    # observed within that quarter (120 forecast points kept per vessel).
    max_traffic_slice_predictions = max_traffic_slice.set_index('timestamp').resample(rule='15min').apply(
        lambda sdf: sdf.groupby(['vessel_id']).progress_apply(
            lambda vessel_id_traj: forecast_vessel(vessel_id_traj, model, scaler)
        )
    ).reset_index(level=1, drop=False)  # level 1 holds the vessel_id; move it into a column
    max_traffic_slice_predictions['t'] = pd.to_datetime(max_traffic_slice_predictions['t'], unit='s')

    # NOTE(review): forecasts made from the 05:15 bin itself (i.e. from observations inside the window)
    # can also land in this window — confirm whether only forecasts from earlier bins should count.
    pred_t = max_traffic_slice_predictions['t']
    in_window = (pred_t >= FORECAST_WINDOW_START) & (pred_t < FORECAST_WINDOW_END)

    # Distinct vessels per grid cell among the forecast points inside the window.
    tff_res[name] = (
        classify_area_proximity(
            max_traffic_slice_predictions.loc[in_window].copy(),
            spatial_coverage_grid.to_crs(4326)
        )
        .groupby('area_id')['vessel_id']
        .nunique()
        .rename(None)
    )

In [None]:
# Interactive choropleth of the PerFL-predicted distinct-vessel count per grid cell.
pd.merge(
    tff_res['per_fl'].rename('traffic_flow'),
    spatial_coverage_grid,
    left_index=True,
    right_on='id',
).pipe(gpd.GeoDataFrame, geometry='geom', crs=3857).explore(column='traffic_flow')

# Unifying into a Single Figure (Per-Model Maps with a Shared Colorbar)

In [None]:
# Per-cell range of distinct-vessel counts, actual vs. each model's forecast (presumably the source of
# the 1..10 colour scale used for the figures below). String aggregations avoid the pandas
# FutureWarning raised when builtin `min`/`max` callables are passed to `.agg`.
pd.concat(
    (actual_traffic_flow, tff_res['sn_cml'], tff_res['sd_cml'], tff_res['fl'], tff_res['per_fl']),
    axis=1,
    keys=['actual', 'sn_cml', 'sd_cml', 'fl', 'per_fl'],
).agg(['min', 'max'])

In [None]:
import matplotlib.pyplot as plt

# One colour scale (#distinct vessels per cell) shared by every map AND the standalone colorbar.
# Without explicit vmin/vmax each map would be auto-scaled to its own min/max, so its colours would
# match neither the colorbar nor the other maps.
TFF_VMIN, TFF_VMAX = 1, 10


def plot_traffic_flow(flow, out_path, ref_ax=None):
    '''
        Draw a per-cell traffic-flow choropleth over a CartoDB Positron basemap and save it as an image.

        Input:
            * flow: Series of distinct-vessel counts indexed by grid-cell id
            * out_path: the image file to write (300 dpi)
            * ref_ax: optional Axes whose x/y limits are reused, so every map shows the same extent

        Output:
            * The matplotlib Axes of the map
    '''
    flow_gdf = gpd.GeoDataFrame(
        pd.merge(flow.rename('traffic_flow'), spatial_coverage_grid, left_index=True, right_on='id'),
        geometry='geom',
        crs=3857
    )
    ax = flow_gdf.plot(column='traffic_flow', cmap='YlOrRd', alpha=0.65, vmin=TFF_VMIN, vmax=TFF_VMAX)
    if ref_ax is not None:
        ax.set_xlim(*ref_ax.get_xlim())
        ax.set_ylim(*ref_ax.get_ylim())
    ax.axis('off')
    ctx.add_basemap(ax=ax, source=ctx.providers.CartoDB.Positron, attribution='')
    ax.figure.savefig(out_path, dpi=300, bbox_inches='tight')
    return ax


ax1 = plot_traffic_flow(actual_traffic_flow, 'tff_actual.png')
ax2 = plot_traffic_flow(tff_res['sn_cml'], 'tff_sn_cml.png', ref_ax=ax1)
ax3 = plot_traffic_flow(tff_res['sd_cml'], 'tff_sd_cml.png', ref_ax=ax1)
ax4 = plot_traffic_flow(tff_res['fl'], 'tff_fl.png', ref_ax=ax1)
ax5 = plot_traffic_flow(tff_res['per_fl'], 'tff_perfl.png', ref_ax=ax1)

# Standalone horizontal colorbar for the shared scale, saved as its own image.
fig, cax = plt.subplots(figsize=(7, 3))
sm = plt.cm.ScalarMappable(cmap='YlOrRd', norm=plt.Normalize(vmin=TFF_VMIN, vmax=TFF_VMAX))
# Give the mappable an (empty) data array through the public API.
sm.set_array([])
cbar = fig.colorbar(sm, orientation='horizontal', shrink=0.94, pad=0.02, ax=cax)
cbar.ax.tick_params(labelsize=10)
cbar.ax.set_xlabel('#Vessels', fontsize=10, labelpad=5)
cbar.ax.xaxis.set_ticks_position("bottom")
# The placeholder axes only hosted the colorbar; remove it so just the colorbar is saved.
cax.remove()
fig.savefig('plot_onlycbar.png', dpi=300, bbox_inches='tight')

