In [None]:
import datetime
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import tensorflow as tf
import re, json

import plotly.express as px
from dash import Dash, dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from jupyter_dash import JupyterDash

import sys
sys.path.append('./rtaUtils')
from rtaUtils import data_loading, common, sort_vectors, paths, experiment, data_preparation


# Plotly modebar "download as image" options. `format` may be png, svg, jpeg
# or webp; `scale` multiplies title/legend/axis/canvas sizes.
config = dict(
    toImageButtonOptions=dict(
        format='png',
        filename='newplot',
        height=1000,
        width=1000,
        scale=1,
    )
)

# Departure aerodrome codes offered by the Airport filter (matched against
# the `aerodromeOfDeparture` column when filtering trajectories).
airport_list = tuple("""
    LEBL LPPT LFPO LFPG EGLL LPPR EDDF EBBR EHAM LEBB LIRF
    EDDM EGKK LEVC LSGG LECO LIMC EIDW LFML LEZL LOWW LIPZ
    LIPE LFLL EDDL EDDB LTFM LFMN LFRS LROP EDDP LGAV EDDH
    LHBP EKCH EGCC ELLX LKPR LIRN LBSF EPWA
""".split())

# Feature selection. Other candidates that have been tried: numeric
# 'vspeed', 'speed', 'track', 'hav_distance'; categoric 'operator'.
numeric_feat = ['latitude', 'longitude', 'altitude']
categoric_feat = []
objective = ['latitude', 'longitude', 'altitude']

feat_dict = {
    'numeric': numeric_feat,
    'categoric': categoric_feat,
    'objective': objective,
}

In [None]:
#!pip install dash==2.6.0
#!pip install dash_bootstrap_components
#!pip install jupyter_dash

In [None]:
# Data settings shown in the sidebar. `sampling` is also passed to
# data_loading.load_final_data; `months` is only displayed.
sampling, months = 15, '*'

In [None]:
# One entry per trained model: every sub-directory of models_path holding an
# experiment_config.json (the file load_model reads). The list is sorted so
# the default selection (models[0]) is stable across runs. `.name` is used
# rather than `.stem` so directory names containing dots are kept intact.
models = sorted(
    model_dir.name
    for model_dir in paths.models_path.iterdir()
    if model_dir.is_dir() and (model_dir / 'experiment_config.json').is_file()
)

# Shared state, mutated by the Dash callbacks below.
experimento = None          # experiment.ExperimentTrajectory for the selected model
month_data = None           # test data of the currently loaded month
tray_data = None            # rows of month_data for the selected trajectory
prediction = None           # subsampled full-trajectory prediction
windows = None              # per-window model inputs
prediction_windows = None
true_windows = None         # per-window ground-truth slices

current_month = '202201'    # 'YYYYMM' of the month held in month_data

## Inicialización

In [None]:
# Estilos de los ejemplos de Dash
external_stylesheets = [dbc.themes.BOOTSTRAP,'https://codepen.io/chriddyp/pen/bWLwgP.css']

app = JupyterDash(__name__, external_stylesheets=external_stylesheets)

## Estructura

In [None]:
############### Model ###############

# Model picker: labels are the directory names upper-cased with '_' -> ' '.
modelSelector = dcc.Dropdown(
    id='modelSelector',
    options={x: x.upper().replace('_', ' ') for x in models},
    # Default to the first model; None (nothing selected) when no model exists,
    # instead of raising IndexError on an empty models list.
    value=models[0] if models else None,
)

# Checkpoint picker; its options are filled in by the load_model callback.
modelVersionSelector = dcc.Dropdown(
    id='modelVersionSelector',
)

############### Trajectories ###############

# Day to inspect, restricted to the span covered by the test data.
dateSelector = dcc.DatePickerSingle(
    id='dateSelector',
    placeholder='Select one...',
    date='2022-01-01',
    min_date_allowed='2022-01-01',
    max_date_allowed='2022-09-30',
    display_format='DD/MM/YYYY',
)

# Optional departure-airport filter (multi-select).
airportSelector = dcc.Dropdown(
    id='airportSelector',
    options=[dict(label=x, value=x) for x in sorted(airport_list)],
    placeholder='Select one...',
    multi=True,
)

# Single trajectory (fpId); options are set by filter_trajectories.
trajectorySelector = dcc.Dropdown(
    id='trajectorySelector',
    placeholder='Select one...',
    multi=False,
)

# Window index for "Individual window" experiments; `max` is updated by
# prepare_experiment once a trajectory is selected.
windowSelector = dcc.Input(
    id='windowSelector',
    type='number',
    min=0, max=100,
    value=0,
    step=1
)

In [None]:
# Shared look of the sidebar sections: bordered, rounded boxes.
SECTION_STYLE = dict(width='99%', borderWidth=1, borderStyle='solid',
                     borderRadius=5, padding=5, margin=2)
ROW_STYLE = dict(paddingTop=5, paddingBottom=5)


def _section(title, *children):
    """Return a sidebar ``html.Fieldset`` with a bold ``title`` legend followed by ``children``."""
    return html.Fieldset(style=dict(SECTION_STYLE), children=[
        html.Legend(title, style=dict(fontWeight='bold')),
        *children,
    ])


def _row(*cols):
    """Return a vertically padded ``dbc.Row`` containing ``cols``."""
    return dbc.Row(style=dict(ROW_STYLE), children=list(cols))


app.layout = html.Div(style=dict(display='flex'), children=[
    # Left sidebar (20%): data summary and all controls.
    html.Div(style=dict(width='20%', padding=5,), children=[
        _section(
            'Data',
            _row(
                dbc.Col(dbc.Label('Months'), width=4),
                dbc.Col(months, width=7),
                dbc.Col(dbc.Label('Sampling'), width=4),
                dbc.Col(sampling, width=7),
            ),
        ),

        _section(
            'Model',
            _row(dbc.Col([dbc.Label('Model'), modelSelector], width=12)),
            _row(dbc.Col([dbc.Label('Version'), modelVersionSelector], width=12)),
            # Hidden Output target for load_model_version, which only has side effects.
            html.Div(id='placeholder', style={'display': 'none'}),
        ),

        _section(
            'Trajectory',
            _row(
                dbc.Col([dbc.Label('Date'), dateSelector], width=6),
                dbc.Col([dbc.Label('Airport'), airportSelector], width=6),
            ),
            _row(dbc.Col([dbc.Label('Trajectory'), trajectorySelector], width=12)),
        ),

        _section(
            'Experiment',
            dcc.RadioItems(
                id='experimentTypeSelector',
                style=dict(padding=5),
                options=dict(full=' Full trajectory', window=' Individual window'),
                value='full',
            ),
            # Shown/hidden by change_experiment_type.
            html.Div(id='sliderDiv', children=[windowSelector]),
            html.Button('Run experiment', id='runButton',
                        style=dict(margin='auto', marginTop=20, display='flex', alignItems='center')),
        ),

        _section(
            'Map type',
            dcc.RadioItems(
                id='mapTypeSelector',
                style=dict(padding=5),
                # NOTE(review): Stamen tiles moved to Stadia Maps in 2023; the
                # 'stamen-*' styles may no longer render without an API key — verify.
                options=['open-street-map',
                         'carto-positron',
                         'carto-darkmatter',
                         'stamen-terrain',
                         'stamen-toner'],
                value='open-street-map',
            ),
        ),
    ]),

    # Right panel (75%): the map. `config` carries the image-export options
    # defined with the notebook settings (it was previously unused).
    html.Div(id='mapDiv', style=dict(width='75%', margin=7, borderWidth=1, borderStyle='solid', borderRadius=5), children=[
        dcc.Graph(id='mapGraph', style=dict(width='100%', height='100%'), config=config)
    ])
])

In [None]:
@app.callback(
    [Output(component_id='modelVersionSelector', component_property='options'),
     Output(component_id='modelVersionSelector', component_property='value'),
     Output(component_id='dateSelector', component_property='date')],
    [Input(component_id='modelSelector', component_property='value')],
    [State(component_id='dateSelector', component_property='date')]
)
def load_model(model_name, current_date):
    """Build the experiment for the selected model and list its checkpoints.

    Reads ``<models_path>/<model_name>/experiment_config.json``, stores a new
    ``experiment.ExperimentTrajectory`` in the global ``experimento`` and loads
    its ``'best'`` weights.

    Args:
        model_name: model directory name from modelSelector (may be None).
        current_date: current value of dateSelector.

    Returns:
        (version options, 'best', current_date). The date is echoed back
        unchanged, which re-triggers filter_trajectories (it takes the date as
        Input) — presumably so trajectories are refreshed after a model change.
    """
    global experimento

    if not model_name:
        # Nothing selected: keep whichever experiment is already loaded.
        return ({'best': 'Best'}, 'best', current_date)

    model_dir = paths.models_path / model_name
    with open(model_dir / 'experiment_config.json', 'r') as input_file:
        model_metadata = json.load(input_file)

    experimento = experiment.ExperimentTrajectory(
        model_type=model_metadata['model_type'],
        lookback=model_metadata['lookback'],
        lookforward=model_metadata['lookforward'],
        sampling=model_metadata['sampling'],
        model_config=dict(n_units=model_metadata['num_units'],
                          act_function=model_metadata['activation_function'],
                          batch_size=model_metadata['batch_size']),
        months=model_metadata['months'],
        airport=model_metadata['airport'],
        features=model_metadata['features'],
    )
    experimento.load_model('best')

    # One dropdown entry per saved .h5 file, keyed by file stem, in a stable order.
    model_versions = {checkpoint.stem: checkpoint.stem.replace('_', ' ')
                      for checkpoint in sorted(model_dir.glob('*.h5'))}
    return (model_versions, 'best', current_date)


@app.callback(
    [Output(component_id='placeholder', component_property='children')],
    [Input(component_id='modelVersionSelector', component_property='value')],
    []
)
def load_model_version(model_version):
    """Load the chosen checkpoint into the current experiment.

    The Output is the hidden 'placeholder' div because this callback exists
    only for its side effect (``experimento.load_model``).

    Args:
        model_version: checkpoint name from modelVersionSelector (may be None).

    Returns:
        ([],) — empty children for the placeholder.
    """
    global experimento

    # Skip when nothing is selected or no experiment has been built yet
    # (previously this raised AttributeError on None).
    if model_version and experimento is not None:
        experimento.load_model(model_version)
    return ([],)


@app.callback(
    [Output(component_id='trajectorySelector', component_property='options'),
     Output(component_id='trajectorySelector', component_property='value')],
    [Input(component_id='dateSelector', component_property='date'),
     Input(component_id='airportSelector', component_property='value')],
    [State(component_id='trajectorySelector', component_property='value')]
)
def filter_trajectories(fecha, origen, current_trajectory):
    """List the trajectories flown on ``fecha``, optionally filtered by departure airport.

    Loads the test data for the month of ``fecha`` into the global
    ``month_data``, reusing it while the month does not change.

    Args:
        fecha: date string from dateSelector, 'YYYY-MM-DD' (any time part is ignored).
        origen: list of departure aerodromes to keep, or None/empty for all.
        current_trajectory: currently selected fpId.

    Returns:
        [options, value] for trajectorySelector. The current selection is kept
        only if it is still among the new options; otherwise it is cleared so
        prepare_experiment never runs on a trajectory from another day/month.
    """
    global month_data
    global current_month

    if not fecha:
        return [{}, None]

    day = fecha[:10]                    # 'YYYY-MM-DD'
    month = day[:7].replace('-', '')    # 'YYYYMM', e.g. '202201'

    # Reading a month is expensive: reload only on first use or month change.
    if month_data is None or month != current_month:
        month_data = data_loading.load_final_data(month, 'test', sampling=sampling)
        current_month = month

    day_mask = pd.to_datetime(month_data.timestamp, unit='s').dt.date.astype(str) == day
    df = month_data[day_mask]
    if origen:
        df = df[df.aerodromeOfDeparture.isin(origen)]

    flights = (df[['fpId', 'aerodromeOfDeparture']]
               .drop_duplicates()
               .sort_values(['aerodromeOfDeparture']))
    labels = [{'label': f'{airport} {fp_id}', 'value': fp_id}
              for fp_id, airport in zip(flights.fpId, flights.aerodromeOfDeparture)]

    valid_ids = {option['value'] for option in labels}
    selected = current_trajectory if current_trajectory in valid_ids else None
    return [labels, selected]

    
@app.callback(
    [Output(component_id='windowSelector', component_property='max')],
    [Input(component_id='trajectorySelector', component_property='value')],
    []
)
def prepare_experiment(trajectory):
    """Predict the selected trajectory and pre-compute its per-window slices.

    Sets these globals, which run_experiment reads:
        tray_data    -- rows of month_data whose fpId equals ``trajectory``.
        prediction   -- subsampled output of experimento.predict_trajectory,
                        tagged with point='predicted'.
        windows      -- ``lookback``-row slices of tray_data, one per start index.
        true_windows -- ``lookforward``-row slices of tray_data starting
                        ``lookback + shift + 1`` rows later.

    Args:
        trajectory: fpId selected in trajectorySelector (may be None).

    Returns:
        [max] for windowSelector: the last valid window index (never below 0).
    """
    global month_data
    global experimento
    global tray_data
    global prediction
    global windows
    global true_windows

    if experimento is None or month_data is None or trajectory is None:
        return [0]

    lookback = experimento.lookback
    lookforward = experimento.lookforward
    shift = experimento.shift

    # Full predictions
    tray_data = month_data[month_data.fpId == trajectory].copy()
    full_prediction = experimento.predict_trajectory(tray_data)

    # Keep one block of `lookforward` rows every lookforward**2 rows.
    # NOTE(review): presumably each input window contributes `lookforward`
    # rows, so this keeps non-overlapping forecasts — confirm against
    # predict_trajectory. Collected with pd.concat because DataFrame.append
    # was removed in pandas 2.0 (and appending in a loop is quadratic).
    chunks = [full_prediction.iloc[start:start + lookforward]
              for start in range(0, len(full_prediction) - lookforward,
                                 lookforward * lookforward)]
    selected = pd.concat(chunks) if chunks else full_prediction.iloc[0:0]
    prediction = selected.copy()
    prediction['point'] = 'predicted'

    # Predictions for individual windows
    windows = [tray_data.iloc[i:i + lookback].copy()
               for i in range(tray_data.shape[0] - lookback - lookforward - shift)]
    true_windows = [tray_data.iloc[i:i + lookforward].copy()
                    for i in range(lookback + shift + 1, tray_data.shape[0])]

    # Short trajectories yield no windows; keep the input's max valid (>= min).
    return [max(len(windows) - 1, 0)]


@app.callback(
    [Output(component_id='sliderDiv', component_property='style')],
    [Input(component_id='experimentTypeSelector', component_property='value')],
    [])
def change_experiment_type(exp_type):
    """Hide the window selector for 'full' experiments and show it for any other type."""
    visibility = 'none' if exp_type == 'full' else 'block'
    return ({'display': visibility},)

In [None]:
# View used for the placeholder map (before any experiment can be plotted).
DEFAULT_MAP_CENTER = dict(lat=43.0, lon=4.0)
DEFAULT_MAP_ZOOM = 4


def _empty_map_figure(map_type=None):
    """Return a placeholder map at the default view, drawn with ``map_type`` (OSM if unset)."""
    return px.scatter_mapbox(lat=[DEFAULT_MAP_CENTER['lat']], lon=[DEFAULT_MAP_CENTER['lon']],
                             zoom=DEFAULT_MAP_ZOOM,
                             mapbox_style=map_type or 'open-street-map', title='Map')


def _current_view(current_fig):
    """Return (zoom, center) stored in the figure dict, falling back to the default view."""
    layout = (current_fig or {}).get('layout') or {}
    mapbox = layout.get('mapbox') or {}
    return (mapbox.get('zoom', DEFAULT_MAP_ZOOM),
            mapbox.get('center', DEFAULT_MAP_CENTER))


@app.callback(
    [Output(component_id='mapGraph', component_property='figure')],
    [Input(component_id='runButton', component_property='n_clicks'),
     Input(component_id='windowSelector', component_property='value'),
     Input(component_id='mapTypeSelector', component_property='value'),
    ],
    [State(component_id='experimentTypeSelector', component_property='value'),
     State(component_id='mapGraph', component_property='figure')]
)
def run_experiment(clicks, selected_window, map_type, exp_type, current_fig):
    """Plot the real trajectory together with the model predictions.

    * 'full': real points plus the subsampled ``prediction`` from prepare_experiment.
    * 'window': the real trajectory with window ``selected_window`` split into
      its input ('window'), ground truth ('true') and the output of
      ``experimento.predict_trajectory_acumulado`` ('predicted').

    The current zoom/center of the map is preserved between redraws.

    Args:
        clicks: runButton click count (None before the first click).
        selected_window: index into ``windows`` (used in 'window' mode).
        map_type: mapbox style from mapTypeSelector.
        exp_type: 'full' or 'window'.
        current_fig: figure currently shown, as a dict.

    Returns:
        (figure,). A placeholder map in the selected style is returned until a
        model and a trajectory are available; in 'window' mode an empty or
        out-of-range index keeps the current figure instead of raising.
    """
    global experimento
    global tray_data
    global prediction
    global windows
    global true_windows

    if not clicks or experimento is None or tray_data is None:
        return (_empty_map_figure(map_type),)

    current_zoom, current_center = _current_view(current_fig)
    objective = experimento.objective_feat
    df_viz = tray_data[objective].copy()
    df_viz['point'] = 'real'

    if exp_type == 'window':
        if (selected_window is None or windows is None
                or not 0 <= int(selected_window) < len(windows)):
            return (current_fig or _empty_map_figure(map_type),)
        selected_window = int(selected_window)

        lookback = experimento.lookback
        lookforward = experimento.lookforward
        shift = experimento.shift

        predictions = experimento.predict_trajectory_acumulado(windows[selected_window].copy(),
                                                               lookforward)
        predictions['point'] = 'predicted'
        window = windows[selected_window][objective].copy()
        window['point'] = 'window'
        true = true_windows[selected_window][objective].copy()
        true['point'] = 'true'

        # NOTE: the wavelet / least-squares smoothing experiments previously
        # here were never added to the plot, and their unused polyfit/reshape
        # calls could raise for short predictions, so they were removed.

        # Replace the real points covered by this window and its future
        # (up to start + lookback + lookforward + shift) with the tagged
        # slices; real points before and after the window stay.
        df_viz = pd.concat([df_viz.iloc[:selected_window],
                            window,
                            true,
                            predictions,
                            df_viz.iloc[selected_window + lookback + lookforward + shift + 1:]],
                           axis=0)
    elif exp_type == 'full':
        if prediction is not None:
            df_viz = pd.concat([df_viz, prediction], axis=0)

    fig = px.scatter_mapbox(
        df_viz, 'latitude', 'longitude', height=850,
        zoom=current_zoom, center=current_center,
        mapbox_style=map_type, opacity=1, title='Map',
        color='point', hover_data=['altitude']
    )
    return (fig,)

## Ejecución (lanzar la app; los callbacks están definidos arriba)

In [None]:
# Start the Dash server in debug mode.
# Alternative: app.run_server(mode='inline', debug=True) to render inside the notebook.
app.run_server(
    debug=True,
)