In [None]:
import pandas as pd
import numpy as np
import scipy
from sklearn.linear_model import ElasticNet, ElasticNetCV
from sklearn.model_selection import RepeatedKFold, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scripts.python.routines.betas import betas_drop_na
from plotly.subplots import make_subplots
from numpy.ma import masked_array
from scipy import stats
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pickle
import random
import plotly.express as px
import copy
import statsmodels.formula.api as smf
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scripts.python.pheno.datasets.filter import filter_pheno
from scripts.python.pheno.datasets.features import get_column_name, get_status_dict, get_sex_dict
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import pathlib
from scripts.python.routines.manifest import get_manifest
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from scripts.python.routines.plot.p_value import add_p_value_annotation
from statsmodels.stats.multitest import multipletests
from sklearn.metrics import mean_absolute_error
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
from pathlib import Path
from functools import reduce
from scipy.stats import chi2_contingency
from scipy.stats import kruskal, mannwhitneyu
from impyute.imputation.cs import fast_knn, mean, median, random, mice, mode, em

# Prepare data

In [None]:
# Which experiment run to visualise: the run's feature importances are read
# from `path_load`, and every figure produced below is written under `path_save`.
path = "E:/YandexDisk/EEG/experiments"

exp_type = '1st_day'
exp_sub_type = 'quasi'
model = 'lightgbm'
exp_date = '2022-06-26_05-56-54'

path_load = f"{path}/{exp_type}/models/{exp_type}_{exp_sub_type}_trn_val_{model}/runs/{exp_date}"
path_save = f"{path}/special/001_32_channels_highlight"

# One row per feature, indexed by the sheet's 'feature' column; the plotting
# cells read its 'importance' column.
df_data = pd.read_excel(f"{path_load}/feature_importances.xlsx", index_col='feature')

# Each EEG band gets its own hue family. Slicing with [1:-1] drops the first
# and last shade of every sequential palette.
alpha_colors = px.colors.sequential.Reds[1:-1]
beta_colors = px.colors.sequential.Blues[1:-1]
gamma_colors = px.colors.sequential.Greens[1:-1]
theta_colors = px.colors.sequential.Purples[1:-1]

# Feature-group key (a substring of feature names) -> (panel title, colorscale).
groups = {
    'Alpha_psd': ('Alpha PSD', alpha_colors),
    'Alpha_trp': ('Alpha TRP', alpha_colors),
    'paf': ('Peak Alpha Frequency', alpha_colors),
    'iaf': ('Individual Alpha Frequency', alpha_colors),
    'Beta_psd': ('Beta PSD', beta_colors),
    'Beta_trp': ('Beta TRP', beta_colors),
    'Gamma_psd': ('Gamma PSD', gamma_colors),
    'Gamma_trp': ('Gamma TRP', gamma_colors),
    'Theta_psd': ('Theta PSD', theta_colors),
    'Theta_trp': ('Theta TRP', theta_colors),
}



# Plotting

In [None]:
# Approximate 2-D positions of the 32 electrodes on the head diagram
# (head centre at (1, 1), nasion at the top). Both plotting cells use it.
coordinates = {
    'Cz': (1.0, 1.0),
    'C3': (0.65, 1.0),
    'C4': (1.35, 1.0),
    'T7': (0.31, 1.0),
    'T8': (1.69, 1.0),
    'Fz': (1.0, 1.365),
    'Pz': (1.0, 0.635),
    'FC1': (0.835, 1.175),
    'FC2': (1.165, 1.175),
    'CP1': (0.835, 0.825),
    'CP2': (1.165, 0.825),
    'F3': (0.712, 1.40),
    'F4': (1.288, 1.40),
    'P3': (0.712, 0.60),
    'P4': (1.288, 0.60),
    'FC5': (0.505, 1.195),
    'FC6': (1.495, 1.195),
    'CP5': (0.505, 0.805),
    'CP6': (1.495, 0.805),
    'F7': (0.445, 1.445),
    'F8': (1.555, 1.445),
    'P7': (0.445, 0.555),
    'P8': (1.555, 0.555),
    'Fp1': (0.80, 1.69),
    'Fp2': (1.20, 1.69),
    'O1': (0.80, 0.31),
    'O2': (1.20, 0.31),
    'OZ': (1.0, 0.28),
    'FT9': (0.165, 1.245),
    'FC10': (1.835, 1.245),
    'TP9': (0.165, 0.755),
    'CP10': (1.835, 0.755),
}
# Panel titles, in `groups` order.
titles = [title for title, _ in groups.values()]

# One subplot per feature group, filled row by row.
n_cols = 2
n_rows = (len(groups) + n_cols - 1) // n_cols

# Fixed paper-coordinate colorbar anchors for this grid (one per column / row).
colorbar_xs = [0.44, 0.99]
colorbar_ys = [0.925, 0.712, 0.4995, 0.2887, 0.0765]
assert n_cols <= len(colorbar_xs) and n_rows <= len(colorbar_ys), (
    "colorbar_xs / colorbar_ys only cover a 5x2 grid; extend them for more groups"
)

# Marker positions and labels follow `coordinates` order, so importances must be
# looked up in exactly this order below.
xs = [x for x, _ in coordinates.values()]
ys = [y for _, y in coordinates.values()]
elecs = list(coordinates)

# Head outline drawn beneath every panel: (shape type, extent, dash style).
head_shapes = [
    ("circle", dict(x0=0.3, y0=0.28, x1=1.7, y1=1.72), 'dot'),
    ("circle", dict(x0=0.13, y0=0.09, x1=1.87, y1=1.91), 'solid'),
    ("line", dict(x0=0.13, y0=1, x1=1.87, y1=1), 'dot'),
    ("line", dict(x0=1, y0=0.09, x1=1, y1=1.91), 'dot'),
]
# Orientation labels placed on every panel: (x, y, text, font size).
head_labels = [
    (1, 1.99, "NASION", 20),
    (1, 1.815, "FRONTAL", 15),
    (1, 0, "INION", 20),
    (0.07, 1.43, "LEFT", 20),
    (1.93, 1.43, "RIGHT", 20),
]
# Identical styling for x and y axes. The axes are hidden (visible=False), so the
# rest only takes effect if they are switched back on. `title_font` replaces the
# deprecated `titlefont` alias, which newer Plotly releases reject.
axis_style = dict(
    autorange=False,
    range=[-0.2, 2.2],
    visible=False,
    title_text="",
    showgrid=False,
    zeroline=False,
    linecolor='black',
    showline=False,
    gridcolor='gainsboro',
    gridwidth=0.05,
    mirror=False,
    ticks='outside',
    title_font=dict(color='black', size=20),
    showticklabels=False,
    tickangle=0,
    tickfont=dict(color='black', size=20),
    exponentformat='e',
    showexponent='all',
)

fig = make_subplots(rows=n_rows, cols=n_cols, shared_yaxes=False, shared_xaxes=False)

for g_id, (g, (group_title, group_colorscale)) in enumerate(groups.items()):
    r_id, c_id = divmod(g_id, n_cols)
    cell = dict(row=r_id + 1, col=c_id + 1)

    # Fetch each electrode's importance by its feature name "<electrode>_<group>",
    # in `elecs` order, so colour i belongs to marker i regardless of how rows are
    # ordered in the spreadsheet.
    feature_names = [f"{e}_{g}" for e in elecs]
    missing = [name for name in feature_names if name not in df_data.index]
    if missing:
        raise KeyError(f"No importance found for group '{g}': {missing}")
    colors = df_data.loc[feature_names, 'importance'].to_numpy()

    for shape_type, extent, dash in head_shapes:
        fig.add_shape(
            type=shape_type,
            xref="x", yref="y",
            **extent,
            line={'color': "black", 'dash': dash, 'width': 0.5},
            layer="below",
            **cell,
        )
    for label_x, label_y, label_text, label_size in head_labels:
        fig.add_annotation(
            x=label_x, y=label_y,
            text=label_text,
            showarrow=False,
            font=dict(color='black', size=label_size),
            **cell,
        )
    # Panel title, placed above the nasion label.
    fig.add_annotation(
        x=1, y=2.2,
        text=f"{group_title}",
        showarrow=False,
        font=dict(color='black', size=40),
        **cell,
    )

    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            showlegend=False,
            mode='markers+text',
            marker=dict(
                size=37,
                opacity=1,
                line=dict(width=1),
                color=colors,
                colorscale=group_colorscale,
                showscale=True,
                colorbar=dict(
                    title=dict(text="", font=dict(size=20)), tickfont=dict(size=20),
                    x=colorbar_xs[c_id],
                    y=colorbar_ys[r_id],
                    len=0.115
                )
            ),
            text=elecs,
            textposition="bottom center",
            textfont=dict(family="arial", size=18, color="Black"),
        ),
        **cell,
    )

    fig.update_xaxes(**axis_style, **cell)
    fig.update_yaxes(**axis_style, **cell)

fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.01,
        xanchor="center",
        x=0.5,
        font=dict(size=50),
        itemsizing='constant',
    ),
    title=dict(text="", font=dict(size=60)),
    template="none",
    autosize=False,
    width=1500,
    height=3500,
    margin=go.layout.Margin(l=100, r=100, b=100, t=100, pad=0),
)
# Make sure the output folder exists on a fresh run (idempotent).
Path(path_save).mkdir(parents=True, exist_ok=True)
save_figure(fig, f"{path_save}/feature_importance")


In [None]:

# One stand-alone head plot per feature group, saved as
# `{path_save}/{group}/feature_importance`, using the "Hot" colorscale.

# Marker positions and labels follow `coordinates` order; they are the same for
# every group, so they are computed once.
xs = [x for x, _ in coordinates.values()]
ys = [y for _, y in coordinates.values()]
elecs = list(coordinates)

# Head outline: (shape type, extent, dash style).
single_head_shapes = [
    ("circle", dict(x0=0.3, y0=0.28, x1=1.7, y1=1.72), 'dot'),
    ("circle", dict(x0=0.13, y0=0.09, x1=1.87, y1=1.91), 'solid'),
    ("line", dict(x0=0.13, y0=1, x1=1.87, y1=1), 'dot'),
    ("line", dict(x0=1, y0=0.09, x1=1, y1=1.91), 'dot'),
]
# Orientation labels: (x, y, text, font size).
single_head_labels = [
    (1, 1.99, "NASION", 20),
    (1, 1.815, "FRONTAL", 15),
    (1, 0, "INION", 20),
    (0.07, 1.43, "LEFT", 20),
    (1.93, 1.43, "RIGHT", 20),
]

for g in groups:
    Path(f"{path_save}/{g}").mkdir(parents=True, exist_ok=True)

    # Fetch importances by feature name "<electrode>_<group>" in `elecs` order so
    # each marker is coloured by its own electrode, independent of the row order
    # of the spreadsheet.
    feature_names = [f"{e}_{g}" for e in elecs]
    missing = [name for name in feature_names if name not in df_data.index]
    if missing:
        raise KeyError(f"No importance found for group '{g}': {missing}")
    colors = df_data.loc[feature_names, 'importance'].to_numpy()

    fig = go.Figure()
    for shape_type, extent, dash in single_head_shapes:
        fig.add_shape(
            type=shape_type,
            xref="x", yref="y",
            **extent,
            line={'color': "black", 'dash': dash, 'width': 1},
            layer="below"
        )
    for label_x, label_y, label_text, label_size in single_head_labels:
        fig.add_annotation(
            x=label_x, y=label_y,
            text=label_text,
            showarrow=False,
            font=dict(color='black', size=label_size),
        )

    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            showlegend=False,
            mode='markers+text',
            marker=dict(
                size=37,
                opacity=1,
                line=dict(width=1),
                color=colors,
                colorscale=px.colors.sequential.Hot,
                showscale=True,
                colorbar=dict(title=dict(text="", font=dict(size=20)), tickfont=dict(size=20))
            ),
            text=elecs,
            textposition="bottom center",
            textfont=dict(family="arial", size=18, color="Black"),
        )
    )

    # Project-wide base layout first; the settings below override parts of it.
    add_layout(fig, "", f"", f"", font_size=20)
    fig.update_layout(legend_font_size=20)
    fig.update_xaxes(autorange=False, visible=False, range=[0, 2])
    fig.update_yaxes(autorange=False, visible=False, range=[0, 2])
    fig.update_layout(
        autosize=False,
        width=630,
        height=500,
        margin=go.layout.Margin(l=20, r=20, b=20, t=20, pad=0),
    )

    save_figure(fig, f"{path_save}/{g}/feature_importance")