In [32]:
import os
import pandas as pd
import numpy as np

import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo
import plotly.io as pio
from plotly.subplots import make_subplots
from dash import html


# Root folder holding the per-contrast trial CSVs; figures are saved to its Figs/ subfolder.
results_path = "../../Results/Multivariate_all_against_all/"
fig_path = f"{results_path}Figs/"

# %%
# Feature (marker) column names. Each base marker contributes a "_mean" and a
# "_std" column, in this exact order (27 bases x 2 statistics = 54 names).
# NOTE(review): presumably EEG/ERP markers (wSMI, permutation entropy, spectral
# bands, CNV/P1/P3a/P3b); confirm against the feature-extraction code.
_marker_bases = [
    'wSMI_1', 'wSMI_2', 'wSMI_4', 'wSMI_8',
    'p_e_1', 'p_e_2', 'p_e_4', 'p_e_8',
    'k', 'se', 'msf', 'sef90', 'sef95',
    'b', 'b_n', 'g', 'g_n', 't', 't_n', 'd', 'd_n', 'a_n', 'a',
    'CNV', 'P1', 'P3a', 'P3b',
]
markers = [f'{base}_{stat}' for base in _marker_bases for stat in ('mean', 'std')]

# Plotting palette (hex colour strings).
grey, green, lblue, pink, orange = (
    "#21201F", "#9AC529", "#42B9B2", "#DE237B", "#F38A31",
)

# One colour per contrast, indexed in the same order as `comparisons`.
colors = [pink, green, orange, lblue]

# Classification contrasts; each name prefixes its trial CSV and its output figure.
comparisons = ['on-task_vs_mw', 'on-task_vs_dMW', 'on-task_vs_sMW', 'dMW_vs_sMW']




In [35]:
# Number of best-scoring trials (highest 'value', reported as AUC in the plot
# titles) whose feature importances are pooled into each box plot.
TOP_N_MODELS = 10


def parse_importances(importances_str):
    """Parse a stringified importance vector into a list of floats.

    The ``feature_importances`` CSV column holds vectors serialised as text,
    e.g. ``"[0.12 0.03 ...]"``. Brackets and commas are treated as separators,
    so both whitespace-separated (numpy-style) and comma-separated
    (list-style) representations parse.

    Parameters
    ----------
    importances_str : str
        Bracketed sequence of numbers.

    Returns
    -------
    list of float
        The parsed values, in their original order.

    Raises
    ------
    ValueError
        If a token is not a number (e.g. a truncated numpy repr containing "...").
    """
    for separator in '[],':
        importances_str = importances_str.replace(separator, ' ')
    return [float(val) for val in importances_str.split()]


# NOTE(review): assumes each importance vector follows the column order of
# `markers` -- the same assumption the previously hard-coded name list made.
feature_names = markers

# write_image fails if the output folder does not exist yet.
os.makedirs(fig_path, exist_ok=True)

for i, contrast in enumerate(comparisons):

    df = pd.read_csv(os.path.join(results_path, f'{contrast}_PC_K4_trim_opt_trials.csv'),
                     index_col=0).dropna().drop_duplicates()

    # Keep the best trials by their objective value (AUC).
    top_models = df.nlargest(TOP_N_MODELS, 'value')
    if top_models.empty:
        print(f"No complete trials for {contrast}; skipping figure.")
        continue

    parsed_importances = top_models['feature_importances'].apply(parse_importances).tolist()

    # Fail with a clear message (instead of a cryptic DataFrame error) when the
    # vectors do not line up with the feature names.
    bad_lengths = sorted({len(v) for v in parsed_importances if len(v) != len(feature_names)})
    if bad_lengths:
        raise ValueError(
            f"{contrast}: expected {len(feature_names)} importances per model, "
            f"got lengths {bad_lengths}"
        )

    # Long format: one row per (model, feature) importance value.
    boxplot_data = pd.DataFrame(parsed_importances, columns=feature_names).melt(
        var_name='Feature', value_name='Importance')

    title = (
        f"FI of Top {len(top_models)} Models {contrast}<br>"
        f"Average AUC {np.round(top_models['value'].mean(), 3)}<br>"
        f"Best AUC {np.round(top_models['value'].max(), 3)}"
    )

    # Horizontal box plot: one box per feature, pooled over the top models.
    fig = px.box(boxplot_data, y='Feature', x='Importance', title=title, orientation='h')
    # Cycle through the palette so adding contrasts cannot raise IndexError.
    fig.update_traces(marker=dict(color=colors[i % len(colors)], size=8))
    fig.update_layout(
        width=650,
        height=1100,
        template='plotly_white',
        font=dict(family="Times new roman", size=20, color="black"),
        xaxis=dict(
            visible=True,
            range=[0, 0.25],  # NOTE(review): importances above 0.25 fall outside the view
            tickfont={"size": 20},
            title='Feature Importance',
        ),
        yaxis=dict(
            title='Feature',
            categoryorder='total ascending',  # ascending by total importance (plotly rule)
            tickfont={"size": 20},
            automargin=True,
            range=[-1, len(feature_names)],
            dtick=1,
        ),
        showlegend=True,
        margin=dict(l=50, r=50, t=150, b=50),  # tall top margin for the 3-line title
    )
    fig.show()

    filename = os.path.join(fig_path, f'{contrast}_top_{TOP_N_MODELS}_feat_importances_PC_K4.png')
    fig.write_image(filename)