In [None]:
import datetime
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import re

import plotly.express as px
from dash import Dash, dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from jupyter_dash import JupyterDash

import sys
sys.path.append('./rtaUtils')
from rtaUtils import data_loading, common, sort_vectors, paths, experiment, data_preparation


# Plotly modebar options: the "download plot" button exports a 1000x1000 PNG.
config = dict(
    toImageButtonOptions=dict(
        format='png',       # one of png, svg, jpeg, webp
        filename='newplot',
        height=1000,
        width=1000,
        scale=1,            # multiplies title/legend/axis/canvas sizes by this factor
    ),
)

# ICAO codes offered in the departure-airport filter of the UI.
airport_list = (
    'LEBL', 'LPPT', 'LFPO', 'LFPG', 'EGLL', 'LPPR', 'EDDF', 'EBBR', 'EHAM', 'LEBB',
    'LIRF', 'EDDM', 'EGKK', 'LEVC', 'LSGG', 'LECO', 'LIMC', 'EIDW', 'LFML', 'LEZL',
    'LOWW', 'LIPZ', 'LIPE', 'LFLL', 'EDDL', 'EDDB', 'LTFM', 'LFMN', 'LFRS', 'LROP',
    'EDDP', 'LGAV', 'EDDH', 'LHBP', 'EKCH', 'EGCC', 'ELLX', 'LKPR', 'LIRN', 'LBSF',
    'EPWA',
)

# Feature groups handed to the experiment.
# Other numeric candidates: 'vspeed', 'speed', 'track', 'hav_distance'.
# Categorical candidate: 'operator'.
numeric_feat = ['latitude', 'longitude', 'altitude']
categoric_feat = []
objective = ['latitude', 'longitude', 'altitude']

feat_dict = {
    'numeric': numeric_feat,
    'categoric': categoric_feat,
    'objective': objective,
}

In [None]:
# Data parameters shared by the experiment and the monthly data loader.
# NOTE(review): `sampling` is presumably the resampling period of the stored
# data, and months='*' presumably means "all months" — confirm in rtaUtils.
sampling, months = 15, '*'

In [None]:
# One entry per model directory. Its '_'-separated name is parsed by the
# `load_model` callback, which also uses it as a path component. So keep the
# full directory name (not `.stem`, which drops anything after a dot) and list
# only directories (`glob('*/')` also matched files before Python 3.11).
# Sorting makes the default selection (`models[0]`) deterministic.
models = sorted(x.name for x in paths.models_path.iterdir() if x.is_dir())

# Module-level state shared between the Dash callbacks.
experimento = None          # ExperimentTrajectory of the selected model
month_data = None           # 'test' data of the currently loaded month
tray_data = None            # rows of the selected trajectory
prediction = None           # full-trajectory prediction for tray_data
windows = None              # lookback-sized input windows of tray_data
prediction_windows = None   # NOTE(review): not used by any visible callback
true_windows = None         # ground-truth points following each window

current_month = '202201'    # 'YYYYMM' key of month_data (reloaded while month_data is None)

## Inicialización

In [None]:
# Stylesheets used by the Dash examples: the Bootstrap theme (for the dbc
# components) plus the classic Dash codepen CSS.
external_stylesheets = [
    dbc.themes.BOOTSTRAP,
    'https://codepen.io/chriddyp/pen/bWLwgP.css',
]

# NOTE(review): jupyter_dash is deprecated since Dash 2.11, which runs in
# Jupyter natively; consider migrating to `Dash` when upgrading.
app = JupyterDash(
    __name__,
    external_stylesheets=external_stylesheets,
)

## Estructura

In [None]:
############### Model ###############

# Model dropdown: one option per model directory, labelled upper-case with
# spaces. Defaults to the first model; with no models it starts empty instead of
# raising IndexError on `models[0]`.
modelSelector = dcc.Dropdown(
    id = 'modelSelector',
    options={x:x.upper().replace('_', ' ') for x in models},
    value=models[0] if models else None,
)

# Weights-version dropdown; options and value are filled by the load_model callback.
modelVersionSelector = dcc.Dropdown(
    id = 'modelVersionSelector',
)

############### Trajectories ###############

# Day whose trajectories are listed; restricted to the 2022-01-01 .. 2022-09-30 range.
dateSelector = dcc.DatePickerSingle(
    id = 'dateSelector',
    placeholder = 'Select one...',
    date='2022-01-01',
    min_date_allowed = '2022-01-01',
    max_date_allowed = '2022-09-30',
    display_format='DD/MM/YYYY',
)

# Optional departure-airport filter (multi-select).
airportSelector = dcc.Dropdown(
    id = 'airportSelector',
    options = [dict(label=x, value=x) for x in sorted(airport_list)],
    placeholder='Select one...',
    multi=True,
)

# Trajectory (fpId) to analyse; options are filled by filter_trajectories.
trajectorySelector = dcc.Dropdown(
    id = 'trajectorySelector',
    placeholder='Select one...',
    multi=False,
)

# Window index for the single-window experiment; `max` is replaced by prepare_experiment.
windowSelector = dcc.Slider(
    id = 'windowSelector',
    min=0, max=100,
    value=0,
    tooltip={'placement': 'top', 'always_visible': False}
)

In [None]:
def _panel(title, *children):
    """Bordered sidebar fieldset: a bold legend `title` followed by `children`."""
    return html.Fieldset(
        style=dict(width='99%', borderWidth=1, borderStyle='solid',
                   borderRadius=5, padding=5, margin=2),
        children=[html.Legend(title, style=dict(fontWeight='bold')), *children],
    )


def _row(*cols):
    """Sidebar row holding `cols`, with the shared vertical padding."""
    return dbc.Row(style=dict(paddingTop=5, paddingBottom=5), children=list(cols))


# Two-column page: a 20% control sidebar and the map on the right.
app.layout = html.Div(style=dict(display='flex'), children=[
    html.Div(style=dict(width='20%', padding=5,), children=[
        # Read-only summary of the data parameters.
        _panel('Data',
               _row(dbc.Col(dbc.Label('Months'), width=4),
                    dbc.Col(months, width=7),
                    dbc.Col(dbc.Label('Sampling'), width=4),
                    dbc.Col(sampling, width=7))),

        _panel('Model',
               _row(dbc.Col([dbc.Label('Model'), modelSelector], width=12)),
               _row(dbc.Col([dbc.Label('Version'), modelVersionSelector], width=12)),
               # Hidden target for the load_model_version callback's dummy output.
               html.Div(id='placeholder', style={'display':'none'})),

        _panel('Trajectory',
               _row(dbc.Col([dbc.Label('Date'), dateSelector], width=6),
                    dbc.Col([dbc.Label('Airport'), airportSelector], width=6)),
               _row(dbc.Col([dbc.Label('Trajectory'), trajectorySelector], width=12))),

        _panel('Experiment',
               dcc.RadioItems(
                   id='experimentTypeSelector',
                   style=dict(padding=5),
                   options=dict(full=' Full trajectory', window=' Individual window'),
                   value='full',
               ),
               # Shown only for the 'window' experiment (see change_experiment_type).
               html.Div(id='sliderDiv', children=[windowSelector]),
               html.Button('Run experiment', id='runButton',
                           style=dict(margin='auto', marginTop=20, display='flex', alignItems='center'))),

        _panel('Map type',
               dcc.RadioItems(
                   id='mapTypeSelector',
                   style=dict(padding=5),
                   # NOTE(review): Stamen tile services were discontinued in 2023;
                   # the two 'stamen-*' styles may render blank — confirm with the
                   # installed plotly version.
                   options=['open-street-map',
                            'carto-positron',
                            'carto-darkmatter',
                            'stamen-terrain',
                            'stamen-toner'],
                   value='open-street-map',
               )),
    ]),

    html.Div(id='mapDiv', style=dict(width='75%', margin=7, borderWidth=1, borderStyle='solid', borderRadius=5), children=[
        # `config` holds the download-button options defined in the setup cell.
        dcc.Graph(id='mapGraph', style=dict(width='100%', height='100%'), config=config),
    ])
])

In [None]:
@app.callback(
    [Output(component_id='modelVersionSelector', component_property='options'),
     Output(component_id='modelVersionSelector', component_property='value')],
    [Input(component_id='modelSelector', component_property='value')],
    []
)
def load_model(model_name):
    """Create the experiment for the selected model and list its saved versions.

    The model name is split on '_'. Fields 2, 3 and 4 are parsed as integers,
    after dropping a 2-, 2- and 1-character prefix respectively, and give
    lookback, lookforward and the number of units.

    Parameters
    ----------
    model_name : str or None
        Name of a sub-directory of ``paths.models_path``.

    Returns
    -------
    tuple
        ``(options, value)`` for ``modelVersionSelector``. Options map the stem
        of every ``*.h5`` file in the model directory to a display label, and
        the selected value is 'best'. When no model is selected, options are
        empty and the value is None.
    """
    global experimento

    if not model_name:
        # Forget the deselected model so later callbacks do not keep using it;
        # a None version makes load_model_version a no-op.
        experimento = None
        return ({}, None)

    conf = model_name.split('_')
    experimento = experiment.ExperimentTrajectory(
        lookback=int(conf[2][2:]),
        lookforward=int(conf[3][2:]),
        sampling=sampling,
        model_config=dict(n_units=int(conf[4][1:]),
                          act_function = 'tanh',
                          batch_size   = 128),
        months=months,
        airport='*',
        features=feat_dict
    )
    experimento.load_model('best')
    # NOTE(review): prediction/windows cached by prepare_experiment are only
    # rebuilt when the trajectory selection changes, so they still reflect the
    # previous model until a trajectory is (re)selected.

    model_versions = {x.stem: x.stem.replace('_', ' ')
                      for x in (paths.models_path / model_name).glob('*.h5')}
    return (model_versions, 'best')


@app.callback(
    [Output(component_id='placeholder', component_property='children')],
    [Input(component_id='modelVersionSelector', component_property='value')],
    []
)
def load_model_version(model_version):
    """Load the selected weights version into the current experiment.

    Parameters
    ----------
    model_version : str or None
        Version name (stem of a ``*.h5`` file) chosen in modelVersionSelector.

    Returns
    -------
    tuple
        ``([],)`` — a dummy value for the hidden placeholder div.
    """
    # experimento is None until a model has been selected; without this guard a
    # version value arriving first would raise AttributeError.
    if model_version and experimento is not None:
        experimento.load_model(model_version)
    return ([],)


@app.callback(
    [Output(component_id='trajectorySelector', component_property='options')],
    [Input(component_id='dateSelector', component_property='date'),
     Input(component_id='airportSelector', component_property='value')]
)
def filter_trajectories(fecha, origen):
    """List the trajectories of one day, optionally filtered by departure airport.

    Parameters
    ----------
    fecha : str or None
        Selected date, 'YYYY-MM-DD' (from dateSelector).
    origen : list of str or None
        Departure airports selected in airportSelector.

    Returns
    -------
    list
        One-element list holding the trajectorySelector options. Each option
        has label '<aerodromeOfDeparture> <fpId>' and value fpId, and the
        options are sorted by departure airport. With no date, the list of
        options is empty.
    """
    global month_data
    global current_month

    if not fecha:
        # Single Output -> exactly one returned value (an empty options list).
        return [[]]

    day = fecha[:10]                     # 'YYYY-MM-DD'
    month = day[:7].replace('-', '')     # 'YYYYMM'

    # Load the month's 'test' split only when the month changes.
    if month_data is None or current_month is None or month != current_month:
        month_data = data_loading.load_final_data(month, 'test', sampling=sampling)
        current_month = month

    # `timestamp` is in epoch seconds; keep the rows of the selected day.
    row_days = pd.to_datetime(month_data.timestamp, unit='s').dt.date.astype(str)
    df = month_data[row_days == day]
    if origen:
        df = df[df.aerodromeOfDeparture.isin(origen)]

    flights = (df[['fpId', 'aerodromeOfDeparture']]
               .drop_duplicates()
               .sort_values(['aerodromeOfDeparture']))
    options = [{'label': f'{row.aerodromeOfDeparture} {row.fpId}', 'value': row.fpId}
               for row in flights.itertuples(index=False)]
    return [options]

    
@app.callback(
    [Output(component_id='windowSelector', component_property='max')],
    [Input(component_id='trajectorySelector', component_property='value')],
    []
)
def prepare_experiment(trajectory):
    """Cache the prediction and sliding windows of the selected trajectory.

    Stores in module globals:
    - ``tray_data``: the trajectory rows.
    - ``prediction``: the full-trajectory prediction, tagged 'predicted'.
    - ``windows``: the lookback-sized input windows.
    - ``true_windows``: the ground-truth points that follow each window.

    Parameters
    ----------
    trajectory : fpId or None
        Value selected in trajectorySelector.

    Returns
    -------
    list
        One-element list with the new maximum of the window slider. This is the
        last valid window index, and never below 0.
    """
    global tray_data
    global prediction
    global windows
    global true_windows

    if experimento is None or month_data is None or trajectory is None:
        return [0]

    tray_data = month_data[month_data.fpId == trajectory].copy()
    if tray_data.empty:
        # The id is not in the loaded month (e.g. the date moved to another
        # month): clear the caches instead of predicting on no data.
        prediction, windows, true_windows = None, [], []
        return [0]

    lookback, lookforward = experimento.lookback, experimento.lookforward
    n_points = tray_data.shape[0]

    # Full predictions
    prediction = experimento.predict_trajectory(tray_data)
    prediction['point'] = 'predicted'

    # Window k covers points [k, k+lookback); its ground truth covers
    # [k+lookback+1, k+lookback+1+lookforward).
    # NOTE(review): the target skips point k+lookback — presumably a model
    # convention (run_experiment slices the tail consistently); confirm in rtaUtils.
    windows = [tray_data.iloc[i:i + lookback].copy()
               for i in range(n_points - lookback - lookforward)]
    true_windows = [tray_data.iloc[i:i + lookforward].copy()
                    for i in range(lookback + 1, n_points)]

    # Too-short trajectories give no windows; -1 would be below the slider min.
    return [max(len(windows) - 1, 0)]


@app.callback(
    [Output(component_id='sliderDiv', component_property='style')],
    [Input(component_id='experimentTypeSelector', component_property='value')],
    [])
def change_experiment_type(exp_type):
    """Show the window slider unless the full-trajectory experiment is selected.

    Returns a one-element tuple with the CSS style for ``sliderDiv``.
    """
    display = 'none' if exp_type == 'full' else 'block'
    return ({'display': display},)

In [None]:
@app.callback(
    [Output(component_id='mapGraph', component_property='figure')],
    [Input(component_id='runButton', component_property='n_clicks'),
     Input(component_id='windowSelector', component_property='value'),
     Input(component_id='mapTypeSelector', component_property='value'),
    ],
    [State(component_id='experimentTypeSelector', component_property='value')]
)
def run_experiment(clicks, selected_window, map_type, exp_type, ):
    """Draw the selected trajectory and the model's prediction on the map.

    Parameters
    ----------
    clicks : int or None
        ``n_clicks`` of runButton; nothing is drawn until it has been pressed.
    selected_window : int or None
        Window index used when ``exp_type == 'window'``. Clamped to the
        available windows.
    map_type : str or None
        Mapbox style chosen in mapTypeSelector.
    exp_type : str
        'full' overlays the cached full-trajectory prediction. 'window' predicts
        from a single window and shows it with its ground truth.

    Returns
    -------
    tuple
        One-element tuple with the plotly figure for mapGraph.
    """
    map_style = map_type or 'open-street-map'

    # Empty map until the button was pressed and both a model and a trajectory
    # exist (tray_data stays None until a trajectory is selected).
    if not clicks or experimento is None or tray_data is None:
        return (px.scatter_mapbox(lat=[], lon=[], mapbox_style=map_style, title='Map'),)

    objective_feat = experimento.objective_feat
    df_viz = tray_data[objective_feat].copy()
    df_viz['point'] = 'real'

    if exp_type == 'window' and windows:
        # The slider value may be stale (e.g. after picking a shorter trajectory).
        k = min(max(int(selected_window or 0), 0), len(windows) - 1)

        predictions = experimento.predict_trajectory(windows[k])
        predictions['point'] = 'predicted'
        window = windows[k][objective_feat].copy()
        window['point'] = 'window'
        true = true_windows[k][objective_feat].copy()
        true['point'] = 'true'

        # Real points after the window's ground truth (same offset as prepare_experiment).
        tail_start = k + experimento.lookback + experimento.lookforward + 1
        df_viz = pd.concat([df_viz.iloc[:k],
                            window,
                            true,
                            predictions,
                            df_viz.iloc[tail_start:]], axis=0)
    elif exp_type == 'full':
        df_viz = pd.concat([df_viz, prediction], axis=0)

    fig = px.scatter_mapbox(
        df_viz, 'latitude', 'longitude', height=850, zoom=4,
        mapbox_style=map_style, opacity=1, title='Map',
        color='point'
    )
    return (fig,)

## Launch the app

In [None]:
# Start the Dash server in debug mode. JupyterDash also accepts mode='inline'
# to render the app inside the notebook output instead.
app.run_server(debug=True)