In [None]:
import pandas as pd
import numpy as np
import scipy
from sklearn.linear_model import ElasticNet, ElasticNetCV
from sklearn.model_selection import RepeatedKFold, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scripts.python.routines.betas import betas_drop_na
from plotly.subplots import make_subplots
from numpy.ma import masked_array
from scipy import stats
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pickle
import random
import plotly.express as px
import copy
import statsmodels.formula.api as smf
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scripts.python.pheno.datasets.filter import filter_pheno
from scripts.python.pheno.datasets.features import get_column_name, get_status_dict, get_sex_dict
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import pathlib
from scripts.python.routines.manifest import get_manifest
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from scripts.python.routines.plot.p_value import add_p_value_annotation
from statsmodels.stats.multitest import multipletests
from sklearn.metrics import mean_absolute_error
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
from pathlib import Path
from functools import reduce
from scipy.stats import chi2_contingency
from scipy.stats import kruskal, mannwhitneyu
from impyute.imputation.cs import fast_knn, mean, median, random, mice, mode, em
from sklearn.manifold import LocallyLinearEmbedding
from sklearn.decomposition import PCA
from glob import glob
import os
import functools
from tqdm.notebook import tqdm


def conjunction(conditions):
    """Fold an iterable of conditions into one with ``np.logical_and``.

    The conditions are combined left to right. When only one condition is
    supplied it is returned as-is, without passing through NumPy.

    Raises:
        TypeError: if ``conditions`` yields no items.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for condition in remaining:
        combined = np.logical_and(combined, condition)
    return combined


def disjunction(conditions):
    """Fold an iterable of conditions into one with ``np.logical_or``.

    The conditions are combined left to right. When only one condition is
    supplied it is returned as-is, without passing through NumPy.

    Raises:
        TypeError: if ``conditions`` yields no items.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for condition in remaining:
        combined = np.logical_or(combined, condition)
    return combined

# Plot distribution

In [None]:
# Root directory of the experiments and the experiment selected for analysis.
path = "E:/YandexDisk/EEG/experiments"
exp_type = '1st_day'  # one of: '1st_day', '2nd_day_sham', '2nd_day_tms'
exp_sub_type = 'im'  # part of the CV run-folder pattern globbed below
metric_thld = 0.70

# Input folder of the chosen experiment and the output folder (created if missing).
path_load = f"{path}/{exp_type}"
path_save = f"{path}/special/003_subjects_distribution_in_val"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Sample table; the distinct subject labels are ordered by the number that
# follows their first character.
df_data = pd.read_excel(f"{path_load}/data_new.xlsx", index_col='index')
subjects = list(df_data['subject'].unique())
subjects.sort(key=lambda label: float(label[1:]))

# Pair every cv_progress.xlsx found with the cv_ids.xlsx of the same run directory.
files = glob(f"{path_load}/cv/{exp_type}_{exp_sub_type}_*/runs/*/cv_progress.xlsx")
dict_files = {
    progress_file: f"{os.path.dirname(progress_file)}/cv_ids.xlsx"
    for progress_file in files
}

# A fold is kept only when each of these metrics is strictly above its threshold.
metrics_dict = dict.fromkeys(
    ['train_f1_score_weighted', 'val_f1_score_weighted'],
    metric_thld,
)

# For every subject: number of kept folds whose validation part contains it.
dict_subjects = dict.fromkeys(subjects, 0)
for fn_prog, fn_ids in tqdm(dict_files.items()):
    df_prog = pd.read_excel(fn_prog, index_col="fold")
    df_ids = pd.read_excel(fn_ids, index_col="index")

    # Keep only the folds that pass all metric thresholds at once.
    passed = conjunction([df_prog[name] > thld for name, thld in metrics_dict.items()])
    df_prog = df_prog[passed]

    for fold in df_prog.index.values:
        # Samples marked 'val' in this fold's column of the ids table.
        in_val = df_ids[f"fold_{fold:04d}"] == 'val'
        val_samples = df_ids[in_val].index.values
        # A subject is counted once per fold, however many of its samples are in validation.
        for subj in set(df_data.loc[val_samples, 'subject'].values):
            dict_subjects[subj] += 1

# Bar chart with one trace (and therefore one colour) per subject.
fig = go.Figure()
for subj, times_in_val in dict_subjects.items():
    bar = go.Bar(
        name=subj,
        x=[subj],
        y=[times_in_val],
        text=f'{times_in_val:d}',
        textposition='auto',
        orientation='v',
    )
    fig.add_trace(bar)

# Project layout helper first, then the figure-specific overrides.
add_layout(fig, "", "Times in Validation dataset", "")
fig.update_layout(
    colorway=px.colors.qualitative.Light24,
    title_xref='paper',
    autosize=False,
    margin=go.layout.Margin(l=100, r=20, b=50, t=20, pad=0),
    showlegend=False,
)
fig.update_xaxes(tickfont_size=15, showticklabels=True)
fig.update_traces(textposition='auto')

# Output name encodes the experiment, its sub-type and the metric threshold.
save_figure(fig, f"{path_save}/{exp_type}_{exp_sub_type}_{metric_thld}")
