In [None]:
import ipywidgets as widgets
from ipywidgets import HBox, VBox
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from statsmodels.tsa.seasonal import STL
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
from IPython.display import display, clear_output
import utils_plots as ut

In [3]:
# Define the dictionaries describing, for every basin, its number of clusters and its available runs
n_clusters_dict = {'NEP': 9, 'NWP': 8, 'NA': 12, 'NI': 9, 'SI': 10, 'SP': 11}

# Numeric ids of the three 'test<id>_linreg' runs available for each basin
test_ids_dict = {
    'NEP': (60, 78, 87),
    'NWP': (4, 25, 83),
    'NA': (3, 14, 61),
    'NI': (26, 32, 45),
    'SI': (12, 51, 82),
    'SP': (8, 23, 98),
}
# Percentages naming the 'selfeat<perc>_top20' runs; the same six exist for every basin
selfeat_percentages = (50, 60, 70, 75, 80, 90)

# Basin -> list of run names: the three 'test' runs first, then the six 'selfeat' runs.
# Every name embeds the basin's number of clusters ('nc<N>').
run_name_dict = {
    basin: (
        [f'test{test_id}_linreg_nc{n_clusters}_nv8_nd9_noTS' for test_id in test_ids_dict[basin]]
        + [f'selfeat{perc}_top20_nc{n_clusters}_nv8_nd9_noTS' for perc in selfeat_percentages]
    )
    for basin, n_clusters in n_clusters_dict.items()
}

In [None]:
# Define widgets
# Two shared layouts: the wider one for the dropdowns, the narrower one for the HTML labels
widget_layout1 = widgets.Layout(width='300px')
widget_layout2 = widgets.Layout(width='170px')

# --- Basin selector: bold label followed by the dropdown ---
basin_widget = widgets.Dropdown(
    options=['NEP', 'NWP', 'NA', 'NI', 'SP', 'SI'],
    value='NEP',
    style=dict(description_width='initial'),
    layout=widget_layout1,
)
label_basin_widget = widgets.HTML(value="<b>Basin:</b>", layout=widget_layout2)
basin_box = HBox(children=[label_basin_widget, basin_widget])

# --- Run selector: starts with the runs of the basin selected above ---
run_name_widget = widgets.Dropdown(
    options=run_name_dict[basin_widget.value],
    value=run_name_dict[basin_widget.value][0],  # first run of that basin
    layout=widget_layout1,
)
label_run_widget = widgets.HTML(value="<b>Run name:</b>", layout=widget_layout2)
run_box = HBox(children=[label_run_widget, run_name_widget])

# --- Green button that triggers the refresh of the results ---
update_button = widgets.Button(
    description='Update',
    icon='check',
    layout=widgets.Layout(height='auto'),
    style=dict(button_color='lightgreen'),
)

# Link Basin -> Run options
def update_run_name_options(change):
    """Repopulate the run-name dropdown with the runs of the newly selected basin.

    Parameters
    ----------
    change : mapping
        Change notification; only its ``'new'`` entry is read and it is used as a
        key of ``run_name_dict``.
    """
    basin_runs = run_name_dict[change['new']]
    # Swap in the new choices, then select the first run of that basin
    run_name_widget.options = basin_runs
    run_name_widget.value = basin_runs[0]

# Call update_run_name_options whenever the 'value' trait of the basin dropdown changes
basin_widget.observe(update_run_name_options, names='value')

HBox(children=(HTML(value='<b>Basin:</b>', layout=Layout(width='170px')), Dropdown(layout=Layout(width='220px'…

HBox(children=(HTML(value='<b>Run name:</b>', layout=Layout(width='170px')), Dropdown(layout=Layout(width='220…

In [None]:
def _load_year_filtered_csv(csv_path, years):
    """Read a CSV indexed by date and keep only the rows falling in the given years.

    Parameters
    ----------
    csv_path : str
        Path of the CSV file; its first column is used as the index.
    years : array-like of int
        Calendar years to keep.

    Returns
    -------
    pandas.DataFrame
        The file content with the index converted by ``pd.to_datetime`` and
        restricted to the rows whose index year is in ``years``.
    """
    df = pd.read_csv(csv_path, index_col=0)
    df.index = pd.to_datetime(df.index)
    return df.loc[df.index.year.isin(years)]


def update_output(basin, run_name, project_dir='/Users/huripari/Documents/PhD/TCs_Genesis'):
    """Load the data of one run and call the ``utils_plots`` plotting helpers for it.

    Steps, in order:

    1. Load predictors, the decomposed target (residual, seasonality, trend) and the
       GPI time series of the basin, restricted to the years 1980-2021.
    2. Retrieve the run's predictions and SHAP values through ``ut.runs_info``.
    3. Compute Pearson correlations between the target residual and (a) the run's
       predictions, (b) the ``noFS`` predictions, (c) the STL residuals of the two
       GPIs, on the monthly series and on their annual sums.
    4. Plot the monthly and annual time series, the selected features/clusters and
       the SHAP values.

    Parameters
    ----------
    basin : str
        Basin key; must be a key of the module-level ``n_clusters_dict``.
    run_name : str
        Name of the run. It must contain ``'selfeat'`` or ``'test'``; this decides
        where the list of selected features is read from.
    project_dir : str, optional
        Root directory of the project; data and results are read from its
        ``xai-gpi`` sub-directory. Defaults to the original author's local path.

    Returns
    -------
    None
        The values returned by the plotting helpers are only kept in local variables.
        NOTE(review): rendering is presumably done inside ``utils_plots`` - confirm.

    Raises
    ------
    ValueError
        If ``run_name`` contains neither ``'selfeat'`` nor ``'test'``.
    """
    # Number of clusters (shared module-level mapping, also used to build the run names)
    n_clusters = n_clusters_dict[basin]
    # Set years range and number of folds
    years = np.arange(1980, 2022, 1)  # from 1980 to 2021 included
    n_folds = 3
    # Set directories and file paths, then load file containing predictors and target
    fs_dir = os.path.join(project_dir, 'xai-gpi')
    cluster_data = f'{basin}_{n_clusters}clusters_noTS'
    cluster_data_dir = os.path.join(fs_dir, 'data', cluster_data)
    # predictors
    predictor_file = f'predictors_1980-2022_{n_clusters}clusters_8vars_9idxs.csv'
    predictors_df = _load_year_filtered_csv(os.path.join(cluster_data_dir, predictor_file), years)
    # target: residual, seasonal and trend components are stored in separate files
    target_file = 'target_residual_1980-2022_2.5x2.5.csv'
    seasonal_file = 'target_seasonality_1980-2022_2.5x2.5.csv'
    trend_file = 'target_trend_1980-2022_2.5x2.5.csv'
    target_df = _load_year_filtered_csv(os.path.join(cluster_data_dir, target_file), years)
    target_season_df = _load_year_filtered_csv(os.path.join(cluster_data_dir, seasonal_file), years)
    target_trend_df = _load_year_filtered_csv(os.path.join(cluster_data_dir, trend_file), years)
    # gpis
    gpis_file = f'{basin}_2.5x2.5_gpis_time_series.csv'
    gpis_df = _load_year_filtered_csv(os.path.join(fs_dir, 'data', gpis_file), years)
    # Get the run info and data; only the predictions and the SHAP values of the run are used below,
    # the remaining returned items are unpacked into underscore-prefixed names and ignored
    (Y_pred, Y_pred_noFS, _X_test_eval, _X_test_eval_noFS, _mlps, _mlps_noFS,
     _perm_importance_mlp, _perm_importance_mlp_noFS, shap_values_mlp, _shap_values_mlp_noFS) = ut.runs_info(basin, run_name)
    # Convert list of dataframes to a single dataframe
    Y_pred_df = pd.concat(Y_pred)
    Y_pred_noFS_df = pd.concat(Y_pred_noFS)
    ## Time series Trajectories and Metrics ##
    # Predictions with trend and seasonality added back, negative values set to zero
    # NOTE(review): these two series are not passed to any plot below - confirm whether they are still needed
    Y_pred_df_TS = Y_pred_df['resid'] + target_trend_df['trend'] + target_season_df['season']
    Y_pred_noFS_df_TS = Y_pred_noFS_df['resid'] + target_trend_df['trend'] + target_season_df['season']
    Y_pred_df_TS[Y_pred_df_TS < 0] = 0.0
    Y_pred_noFS_df_TS[Y_pred_noFS_df_TS < 0] = 0.0
    # Annual data without trend and seasonality (sum of the monthly values of each year)
    target_df_annual = target_df.groupby(target_df.index.year).sum()
    Y_pred_df_annual = Y_pred_df.groupby(Y_pred_df.index.year).sum()
    Y_pred_noFS_df_annual = Y_pred_noFS_df.groupby(Y_pred_noFS_df.index.year).sum()
    # GPIs time series with trend and seasonality
    engpi_TS = gpis_df['engpi']
    ogpi_TS = gpis_df['ogpi']
    # GPIs time series without trend and seasonality: keep only the residual of the STL decomposition
    engpi = STL(engpi_TS).fit().resid
    ogpi = STL(ogpi_TS).fit().resid
    # Annual data of the GPIs
    engpi_annual = engpi.groupby(engpi.index.year).sum()
    ogpi_annual = ogpi.groupby(ogpi.index.year).sum()
    # Compute the correlation coefficient between the predictions and the test values
    # (pearsonr also returns a p-value, which is discarded)
    # Monthly without trend and seasonality
    r, _ = pearsonr(target_df['resid'], Y_pred_df['resid'])
    r_noFS, _ = pearsonr(target_df['resid'], Y_pred_noFS_df['resid'])
    r_engpi, _ = pearsonr(target_df['resid'], engpi)
    r_ogpi, _ = pearsonr(target_df['resid'], ogpi)
    # Annual without trend and seasonality
    rY, _ = pearsonr(target_df_annual['resid'], Y_pred_df_annual['resid'])
    rY_noFS, _ = pearsonr(target_df_annual['resid'], Y_pred_noFS_df_annual['resid'])
    rY_engpi, _ = pearsonr(target_df_annual['resid'], engpi_annual)
    rY_ogpi, _ = pearsonr(target_df_annual['resid'], ogpi_annual)
    # Plotting the monthly time series detrended and deseasonalized
    fig_ts = ut.plot_monthly_time_series(target_df['resid'], Y_pred_df['resid'], Y_pred_noFS_df['resid'], engpi, ogpi, r, r_noFS, r_engpi, r_ogpi)
    # Plotting the annual time series detrended and deseasonalized
    fig_annual = ut.plot_annual_time_series(target_df_annual['resid'], Y_pred_df_annual['resid'], Y_pred_noFS_df_annual['resid'], engpi_annual, ogpi_annual, rY, rY_noFS, rY_engpi, rY_ogpi)
    ## Selected features ##
    # Determine selected features according to the run_name
    if 'selfeat' in run_name:
        # The percentage sits between 'selfeat' and '_top20' and names a column of the results table
        perc = run_name.split('_top20')[0].split('selfeat')[1]
        csv_path = os.path.join(fs_dir, 'results', f'selected_features_best_models_{basin}_{n_clusters}_noTS.csv')
        df_perc_sel = pd.read_csv(csv_path, index_col=0)
        selected_features = df_perc_sel[str(perc)].dropna().to_list()
    elif 'test' in run_name:
        experiment_filename = f'1980-2022_{n_clusters}clusters_8vars_9idxs.csv'
        sol_filename = 'linreg_' + experiment_filename
        output_dir = os.path.join(fs_dir, 'results', basin, run_name)
        best_sol_path = os.path.join(output_dir, f'best_solution_{sol_filename}')
        best_solution = pd.read_csv(best_sol_path, sep=',', header=None).to_numpy().flatten()
        # The flattened solution is read as consecutive chunks of one value per predictor column:
        # the first chunk gives the sequence lengths and everything after the second chunk gives
        # the selection flags. The second chunk is not needed to list the selected features.
        n_predictors = len(predictors_df.columns.tolist())
        time_sequences = best_solution[:n_predictors].astype(int)
        variable_selection = best_solution[2*n_predictors:].astype(int)
        # A column is kept when its flag and its sequence length are both non-zero;
        # it is then listed once per step of its sequence length
        selected_features = [
            str(col)
            for c, col in enumerate(predictors_df.columns)
            if variable_selection[c] != 0 and time_sequences[c] != 0
            for _ in range(time_sequences[c])
        ]
    else:
        raise ValueError(f'Unknown run name: {run_name}')
    # Get the sorted, de-duplicated names of the variables that were selected through at least one cluster
    variable_names_cluster = sorted({var.split('_cluster')[0] for var in selected_features if 'cluster' in var})
    # Plot the selected features
    fig_clusters = ut.plot_variables_clusters(basin, n_clusters, cluster_data_dir, variable_names_cluster, selected_features)
    ## SHAP values ##
    years_couples = []
    # Create a DataFrame with fold number corresponding to each year and also the couple of max and min years for each fold
    kfold = KFold(n_splits=n_folds)
    test_years_df = pd.DataFrame(0, index=years, columns=['fold'])
    for nf, (_, test_index) in enumerate(kfold.split(years)):
        fold_years = years[test_index]
        test_years_df.loc[fold_years, 'fold'] = nf
        # Years of this fold with the largest and the smallest annual predicted residual
        fold_annual_resid = Y_pred_df_annual.loc[fold_years]['resid']
        years_couples.append((fold_annual_resid.idxmax(), fold_annual_resid.idxmin()))
    # Plot shap values for each fold
    fig_shap = ut.plot_shap_values(shap_values_mlp)
    # Plot shap values with min max years
    fig_shap_minmax = ut.plot_minmax_shap_values(shap_values_mlp, years_couples, Y_pred, test_years_df)

# Merge layout of the widgets: the two selector rows stacked in a column, the update button next to them
widgets_layout = HBox(children=[VBox(children=[basin_box, run_box]), update_button])
display(widgets_layout)
# Output area in which the results for the selected basin and run name are shown
outputs = widgets.Output()

def on_button_clicked(b):
    """Clear the output area and redraw the results for the current selection.

    Parameters
    ----------
    b : object
        Argument passed by the click event; it is not used.
    """
    # Everything emitted inside this context is routed to the `outputs` widget
    with outputs:
        # wait=True delays the clearing until new output is ready to replace the old one
        clear_output(wait=True)
        selected_basin = basin_widget.value
        selected_run = run_name_widget.value
        update_output(selected_basin, selected_run)

# Run on_button_clicked on every click of the update button
update_button.on_click(on_button_clicked)
# Show the output area; it is displayed after the controls shown earlier in this cell
display(outputs)