In [None]:
import pandas as pd
import numpy as np
from scipy import stats
import seaborn as sns
import pickle
import plotly.express as px
import statsmodels.formula.api as smf
import plotly.graph_objects as go
from scripts.python.routines.manifest import get_manifest
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout
from statsmodels.stats.multitest import multipletests
from scipy.stats import chi2_contingency
from scipy.stats import kruskal, mannwhitneyu
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode
init_notebook_mode(connected=False)
from scipy.stats import mannwhitneyu, median_test
import matplotlib.pyplot as plt
import pathlib
from matplotlib.patches import Rectangle
from tqdm import tqdm
import plotly.colors
from src.utils.plot.bioinfokit import mhat, volcano
import gseapy as gp
import mygene
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA, IncrementalPCA, KernelPCA, TruncatedSVD
from sklearn.decomposition import MiniBatchDictionaryLearning, FastICA
from sklearn.random_projection import GaussianRandomProjection, SparseRandomProjection
from sklearn.manifold import MDS, Isomap, TSNE, LocallyLinearEmbedding
import upsetplot as upset
import missingno as msno
from pyod.models.lunar import LUNAR
from matplotlib_venn import venn2, venn2_circles
from glob import glob
from hydra import compose, initialize
from omegaconf import OmegaConf
import omegaconf
import os
import ast
import json
from sklearn.preprocessing import MinMaxScaler
from scripts.python.routines.plot.colorscales import get_continuous_color
from src.datamodules.cross_validation import RepeatedStratifiedKFoldCVSplitter
from src.datamodules.tabular import TabularDataModule
from lifelines import CoxPHFitter
from lifelines.statistics import proportional_hazard_test
from sklearn.metrics import mean_absolute_error
from scipy.stats import wilcoxon, friedmanchisquare
from suffix_trees import STree
from itertools import combinations

# 0. Setup

In [None]:
# Root folder of the dataset. The default is the original machine-specific
# location; set the GSEUNN_DATASET_PATH environment variable to run the notebook
# on another machine without editing this cell.
path_dataset = os.environ.get(
    "GSEUNN_DATASET_PATH",
    "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN"
)
# Input tables (data_0.xlsx, features.xlsx) are read from here
path_load = f"{path_dataset}/data/covid/treatment"
# Every figure and table produced by the notebook is written below this folder
path_save = f"{path_dataset}/special/041_covid_treatment"
pathlib.Path(f"{path_save}").mkdir(parents=True, exist_ok=True)

# Fill colour used for the box plots and histograms of continuous features
color_cont = 'deepskyblue'

# 1. Empty features

In [None]:
path_local = "001_empty_features"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

df = pd.read_excel(f"{path_load}/data_0.xlsx", index_col="patient_id")

# Number of missing values per feature (column), most incomplete features first
series_n_nan = df.isna().sum()
df_nan = pd.DataFrame({'n_nans': series_n_nan.values}, index=series_n_nan.index)
df_nan = df_nan.sort_values("n_nans", ascending=False)

# Range of the per-feature NaN counts. The counts are integers, so the histogram
# is drawn with discrete=True (unit-wide bars centred on each count).
# `bins`/`binwidth` are intentionally not passed: they contradict discrete
# binning, and the former bin width (max - min) / max evaluated to 0/0 when no
# feature had any missing value.
hist_min = df_nan["n_nans"].min()
hist_max = df_nan["n_nans"].max()

plt.figure()
sns.set_theme(style='whitegrid')
hist = sns.histplot(
    data=df_nan,
    x="n_nans",
    binrange=(hist_min, hist_max),
    discrete=True,
    edgecolor='k',
    linewidth=1
)
hist.set_xlabel("Number of missing values")
hist.set_ylabel("Number of features")
hist.set_title(f"Total features: {df.shape[1]}\nTotal samples: {df.shape[0]}")
plt.savefig(f"{path_save}/{path_local}/hist_n_nans.png", bbox_inches='tight')
plt.savefig(f"{path_save}/{path_local}/hist_n_nans.pdf", bbox_inches='tight')
plt.clf()

## Update features with nan info

In [None]:
# Annotate the feature table with the missing-value statistics computed above
# (`df_nan`, `df`) and write it back to the same workbook it was read from.
path_feats_xlsx = f"{path_load}/features.xlsx"
df_feats = pd.read_excel(path_feats_xlsx, index_col="feature")
feat_ids = df_feats.index
# Absolute NaN count per feature, then the same count as a fraction of all samples
df_feats.loc[feat_ids, 'n_nans'] = df_nan.loc[feat_ids, 'n_nans']
df_feats.loc[feat_ids, 'percentage_nans'] = df_nan.loc[feat_ids, 'n_nans'] / df.shape[0]
df_feats.to_excel(path_feats_xlsx)

## Save filtered data

In [None]:
# Features with more than `lim_exclude` missing values are removed.
# NOTE(review): 1675 is a hand-picked cut-off; its rationale is not visible here.
lim_exclude = 1675
feats_exclude = df_nan.index[df_nan["n_nans"] > lim_exclude].values
# Build a new frame instead of dropping in place: mutating the shared `df` made
# this cell fail with a KeyError on a second run (the columns were already gone)
# and left an extra 'n_nans' column in `df` for any cell executed afterwards.
df_filtered = df.drop(columns=feats_exclude)
# Per-sample number of missing values among the remaining features
df_filtered['n_nans'] = df_filtered.isnull().sum(axis=1)
df_filtered.to_excel(f"{path_load}/data_exclude({lim_exclude}).xlsx")

# 2. Forms features

In [None]:
# Per-form overview of the raw data: missing-value diagnostics for every form
# and one distribution plot for each categorical / continuous feature.
path_local = f"002_forms_features"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

# Reload the unfiltered data and the feature annotation table
df = pd.read_excel(f"{path_load}/data_0.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_load}/features.xlsx", index_col="feature")

forms = df_feats['form'].unique()

for form in forms:
    pathlib.Path(f"{path_save}/{path_local}/{form}").mkdir(parents=True, exist_ok=True)
    # Only features of this form annotated as categorical or continuous
    df_feats_form = df_feats.loc[(df_feats["form"] == form) & (df_feats["type"].isin(["cat", "cont"])), :]

    # Sub-table of the form with columns renamed from feature ids to English titles
    df_form = df.loc[:, df_feats_form.index]
    df_form.rename(columns=dict(zip(df_feats_form.index.values, df_feats_form["eng_title"].values)), inplace=True)
    feats_form = df_feats_form["eng_title"].values
    # Number of missing features per sample; used below to sort the nullity matrix
    df_form['Missed features'] = df_form.isnull().sum(axis=1)

    # Bar chart of non-missing counts per feature (missingno)
    msno.bar(
        df=df_form.loc[:, feats_form],
        label_rotation=90
    )
    plt.savefig(f"{path_save}/{path_local}/{form}/msno_bar.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path_save}/{path_local}/{form}/msno_bar.pdf", bbox_inches='tight')
    plt.close()

    # Horizontal bars of the missing-value fraction per feature; the figure height
    # grows with the number of distinct feature titles. 'percentage_nans' is read
    # from features.xlsx, so it must have been written there beforehand.
    fig = plt.figure(figsize=(12, 0.4 * df_feats_form['eng_title'].value_counts(dropna=True).shape[0]))
    sns.set_theme(style='whitegrid', font_scale=1)
    bar = sns.barplot(
        data=df_feats_form,
        y='eng_title',
        x='percentage_nans',
        edgecolor='black',
        orient='h',
        palette=px.colors.qualitative.Alphabet,
        dodge=True
    )
    bar.set_xlabel("Part of NaNs")
    bar.set_ylabel("")
    bar.set_title(f"Features' missing values")
    plt.savefig(f"{path_save}/{path_local}/{form}/bar.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path_save}/{path_local}/{form}/bar.pdf", bbox_inches='tight')
    plt.close()

    # Nullity matrix in the original sample order
    msno.matrix(
        df=df_form.loc[:, feats_form],
        label_rotation=90
    )
    plt.savefig(f"{path_save}/{path_local}/{form}/msno_mtx.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path_save}/{path_local}/{form}/msno_mtx.pdf", bbox_inches='tight')
    plt.close()

    # Same matrix with samples sorted by their number of missing features (descending)
    msno.matrix(
        df=df_form.sort_values([f"Missed features"], ascending=[False]).loc[:, feats_form],
        label_rotation=90
    )
    plt.savefig(f"{path_save}/{path_local}/{form}/msno_mtx_sorted.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path_save}/{path_local}/{form}/msno_mtx_sorted.pdf", bbox_inches='tight')
    plt.close()

    # One plot per feature; `feat` is the feature id (index of features.xlsx)
    for feat, row in df_feats_form.iterrows():
        if row['type'] == 'cat':

            pathlib.Path(f"{path_save}/{path_local}/{form}/cat").mkdir(parents=True, exist_ok=True)

            if not pd.isna(row['eng_values']):
                # 'eng_values' holds a Python dict literal mapping stored codes to
                # English labels; the codes in the column are replaced by the labels
                dict_values = ast.literal_eval(row['eng_values'])
                df_form.replace({row['eng_title']: dict_values}, inplace=True)
                # One Dark24 colour per label, in the order the labels appear in the dict
                palette = {x: px.colors.qualitative.Dark24[x_id] for x_id, x in enumerate(dict_values.values())}
                order = dict_values.values()
            else:
                # No label mapping: plot the raw values in order of first appearance
                palette = px.colors.qualitative.Dark24
                order = df_form[row['eng_title']].unique()

            # Figure height is proportional to the number of observed (non-NaN) categories
            fig = plt.figure(figsize=(12, 0.4 * df_form[row['eng_title']].value_counts(dropna=True).shape[0]))
            sns.set_theme(style='whitegrid', font_scale=1)
            countplot = sns.countplot(
                data=df_form,
                y=row['eng_title'],
                edgecolor='black',
                orient='h',
                palette=palette,
                order=order
            )
            # Print the count next to each bar
            countplot.bar_label(countplot.containers[0])
            countplot.set_xlabel("Count")
            countplot.set_ylabel("")
            # Title: feature name followed by its number of non-missing values
            countplot.set_title(f"{row['eng_title']} ({df_form[row['eng_title']].count()})")
            plt.savefig(f"{path_save}/{path_local}/{form}/cat/{feat}.png", bbox_inches='tight', dpi=400)
            plt.savefig(f"{path_save}/{path_local}/{form}/cat/{feat}.pdf", bbox_inches='tight')
            plt.close(fig)

        # Continuous features are plotted only when more than 5 values are present
        elif row['type'] == 'cont' and df_form[row['eng_title']].count() > 5:

            # Progress trace: id of the continuous feature being plotted
            print(feat)

            pathlib.Path(f"{path_save}/{path_local}/{form}/cont").mkdir(parents=True, exist_ok=True)

            sns.set_theme(style='whitegrid')

            # Narrow box plot on top of a histogram, both sharing the x axis
            fig, (ax_box, ax_hist) = plt.subplots(2, sharex=True, gridspec_kw={"height_ratios": (.15, .85)})

            box = sns.boxplot(df_form[row['eng_title']].values, orient='h', flierprops={"marker": "x"}, ax=ax_box)
            box.set_xlabel("")
            box.set_yticks([])
            box.set_title(f"Total samples: {df_form[row['eng_title']].count()}")
            # Keep all four spines so the box plot is fully framed
            sns.despine(ax=ax_box, left=False, right=False, bottom=False, top=False)

            if not pd.isna(row['hist_bins']):
                # 'hist_bins' is a JSON list used as np.linspace(start, stop, num),
                # i.e. it defines the histogram bin edges explicitly
                hist_bins_raw = list(map(float, json.loads(row['hist_bins'])))
                hist_bins = np.linspace(hist_bins_raw[0], hist_bins_raw[1], int(hist_bins_raw[2]))

                hist = sns.histplot(
                    data=df_form,
                    x=row['eng_title'],
                    bins=hist_bins,
                    edgecolor='k',
                    linewidth=1,
                    ax=ax_hist
                )
            else:
                # Default binning: 20 equal-width bins over the observed value range
                hist_n_bins = 20
                hist_min = df_form.loc[:, row['eng_title']].min()
                hist_max = df_form.loc[:, row['eng_title']].max()
                hist_width = hist_max - hist_min
                # NOTE(review): this is 0 for a constant feature — confirm that
                # seaborn accepts a zero binwidth or that such features never occur
                hist_bin_width = hist_width / hist_n_bins
                hist = sns.histplot(
                    data=df_form,
                    x=row['eng_title'],
                    bins=hist_n_bins,
                    binrange=(hist_min, hist_max),
                    binwidth=hist_bin_width,
                    discrete=False,
                    edgecolor='k',
                    linewidth=1,
                    ax=ax_hist
                )

            plt.savefig(f"{path_save}/{path_local}/{form}/cont/{feat}.png", bbox_inches='tight', dpi=400)
            plt.savefig(f"{path_save}/{path_local}/{form}/cont/{feat}.pdf", bbox_inches='tight')
            plt.close(fig)

# 3. Data filtering

In [None]:
# Output folder for the filtered dataset and its feature table
path_local = "003_data_filtering"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_load}/data_0.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_load}/features.xlsx", index_col="feature")

# All row filters are evaluated on the freshly loaded table and applied at once
rows_to_keep = (
    # Samples only with positive PCR
    (df["f04v1_pcr_id"] == 2)
    # Samples only without 'deterioration' and 'no changes'
    & df["f07_patient_status_id"].isin([1, 2, 3])
    # Samples with correct diastolic and systolic
    & (df["f01_ads"] > df["f01_add"])
    & (df["f01_ads"] > 50)
    & (df["f01_add"] > 40)
)
df = df.loc[rows_to_keep, :]

# NaNs preprocessing: drop samples lacking any feature flagged "drop_rows_with_na"
feats_drop_rows = df_feats.index[df_feats["preprocessing"] == "drop_rows_with_na"].values
df = df.dropna(subset=feats_drop_rows)

# Calculate missing values parts for filtered data
series_n_nan = df.isna().sum()
df_nan = pd.DataFrame({'n_nans': series_n_nan.values}, index=series_n_nan.index)
df_nan = df_nan.sort_values("n_nans", ascending=False)
feat_ids = df_feats.index
df_feats.loc[feat_ids, 'n_nans'] = df_nan.loc[feat_ids, 'n_nans']
df_feats.loc[feat_ids, 'percentage_nans'] = df_nan.loc[feat_ids, 'n_nans'] / df.shape[0]
df_feats.to_excel(f"{path_save}/{path_local}/feats.xlsx")

# Include features: keep only the columns marked include == "yes"
feats_to_include = df_feats.index[df_feats["include"] == "yes"].values
df = df.loc[:, feats_to_include]
df.to_excel(f"{path_save}/{path_local}/data.xlsx")
# Number of samples left after filtering
print(df.shape[0])

# 4. Features plot

In [None]:
# Distribution plots of the included features, one folder per feature group:
# a single stacked count-plot figure for the categorical features and one
# box plot + histogram figure per continuous feature.
path_local = "004_features"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_save}/003_data_filtering/data.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_save}/003_data_filtering/feats.xlsx", index_col="feature")

feat_groups = df_feats.loc[df_feats["include"] == "yes", "feat_group"].value_counts().index.values
for feat_group in feat_groups:
    pathlib.Path(f"{path_save}/{path_local}/{feat_group}").mkdir(parents=True, exist_ok=True)

    # Categorical features
    df_feats_group_cat = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["feat_group"] == feat_group) & (df_feats["type"].isin(["cat"])), :]
    # rename() returns a new frame, so the slice of `df` is never modified in place
    df_group_cat = df.loc[:, df_feats_group_cat.index].rename(
        columns=dict(zip(df_feats_group_cat.index.values, df_feats_group_cat["eng_title"].values))
    )

    # Keep features with more than one distinct value; each subplot gets a height
    # proportional to its number of distinct values (NaN counts as a value here)
    feats = []
    height_ratios = []
    for feat in df_feats_group_cat.index.values:
        feat_title = df_feats_group_cat.at[feat, 'eng_title']
        n_cats = len(df_group_cat[feat_title].unique())
        if n_cats > 1:
            height_ratios.append(n_cats)
            feats.append(feat)

    # plt.subplots() rejects nrows=0, so groups without plottable categorical
    # features are skipped instead of aborting the whole loop
    if len(feats) > 0:
        fig, axs = plt.subplots(
            nrows=len(feats),
            ncols=1,
            sharex=True,
            figsize=(18, 0.7 * sum(height_ratios)),
            gridspec_kw={'height_ratios': height_ratios},
            # Always get a 2-D array of axes: with the default squeeze a single
            # subplot is returned as a bare Axes and axs[feat_id] would fail
            squeeze=False
        )
        axs = axs[:, 0]

        for feat_id, feat in enumerate(feats):
            feat_title = df_feats_group_cat.at[feat, 'eng_title']
            feat_values = df_feats_group_cat.at[feat, 'eng_values']
            if not pd.isna(feat_values):
                # 'eng_values' is a dict literal mapping stored codes to English labels
                dict_values = ast.literal_eval(feat_values)
                df_group_cat = df_group_cat.replace({feat_title: dict_values})

                colors_str = df_feats_group_cat.at[feat, 'colors']
                if not pd.isna(colors_str):
                    # Explicit label -> colour mapping from the feature table
                    palette = ast.literal_eval(colors_str)
                    order = list(palette.keys())
                else:
                    palette = {x: px.colors.qualitative.Dark24[x_id] for x_id, x in enumerate(dict_values.values())}
                    order = dict_values.values()

                # If some categories not exist
                cats_not_exist = list(set(palette.keys()) - set(df_group_cat[feat_title].unique()))
                if len(cats_not_exist) > 0:
                    for cat in cats_not_exist:
                        palette.pop(cat, None)
                    order = palette.keys()
            else:
                palette = px.colors.qualitative.Dark24
                order = df_group_cat[feat_title].unique()

            sns.set_theme(style='whitegrid', font_scale=1)
            countplot = sns.countplot(
                data=df_group_cat,
                y=feat_title,
                edgecolor='black',
                orient='h',
                palette=palette,
                order=order,
                ax=axs[feat_id]
            )
            countplot.bar_label(countplot.containers[0])
            # Only the bottom subplot carries the shared x-axis label
            if feat_id == len(feats) - 1:
                countplot.set_xlabel("Count", fontsize=20)
            else:
                countplot.set_xlabel("")
            countplot.set_ylabel("")
            countplot.set_title(f"{feat_title}", fontsize=20)

        fig.tight_layout()
        fig.savefig(f"{path_save}/{path_local}/{feat_group}/cat.png", bbox_inches='tight', dpi=400)
        fig.savefig(f"{path_save}/{path_local}/{feat_group}/cat.pdf", bbox_inches='tight')
        plt.close(fig)

    # Continuous features
    df_feats_group_cont = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["feat_group"] == feat_group) & (df_feats["type"].isin(["cont"])), :]
    df_group_cont = df.loc[:, df_feats_group_cont.index].rename(
        columns=dict(zip(df_feats_group_cont.index.values, df_feats_group_cont["eng_title"].values))
    )

    if df_feats_group_cont.shape[0] > 0:
        pathlib.Path(f"{path_save}/{path_local}/{feat_group}/cont").mkdir(parents=True, exist_ok=True)

    for feat, row in df_feats_group_cont.iterrows():
        sns.set_theme(style='whitegrid', font_scale=1)
        # Narrow box plot on top of a histogram, both sharing the x axis
        fig, (ax_box, ax_hist) = plt.subplots(2, sharex=True, gridspec_kw={"height_ratios": (.15, .85)})

        box = sns.boxplot(
            data=df_group_cont[row['eng_title']].values,
            orient='h',
            flierprops={"marker": "x"},
            ax=ax_box,
            color=color_cont
        )
        box.set_xlabel("")
        box.set_yticks([])
        sns.despine(ax=ax_box, left=False, right=False, bottom=False, top=False)

        if not pd.isna(row['hist_bins']):
            # 'hist_bins' is a JSON list used as np.linspace(start, stop, num),
            # i.e. it defines the histogram bin edges explicitly
            hist_bins_raw = list(map(float, json.loads(row['hist_bins'])))
            hist_bins = np.linspace(hist_bins_raw[0], hist_bins_raw[1], int(hist_bins_raw[2]))

            hist = sns.histplot(
                data=df_group_cont,
                x=row['eng_title'],
                bins=hist_bins,
                edgecolor='k',
                linewidth=1,
                ax=ax_hist,
                kde=True,
                color=color_cont
            )
        else:
            # Default binning: 50 equal-width bins over the observed value range
            hist_n_bins = 50
            hist_min = df_group_cont.loc[:, row['eng_title']].min()
            hist_max = df_group_cont.loc[:, row['eng_title']].max()
            hist_width = hist_max - hist_min
            hist_bin_width = hist_width / hist_n_bins
            hist = sns.histplot(
                data=df_group_cont,
                x=row['eng_title'],
                bins=hist_n_bins,
                binrange=(hist_min, hist_max),
                binwidth=hist_bin_width,
                discrete=False,
                edgecolor='k',
                linewidth=1,
                ax=ax_hist,
                kde=True,
                color=color_cont
            )

        plt.subplots_adjust(hspace=0.03)

        fig.savefig(f"{path_save}/{path_local}/{feat_group}/cont/{feat}.png", bbox_inches='tight', dpi=400)
        fig.savefig(f"{path_save}/{path_local}/{feat_group}/cont/{feat}.pdf", bbox_inches='tight')
        plt.close(fig)

# 5. Continuous features correlation

In [None]:
# Pairwise Pearson correlation of the continuous features measured at
# 'Hospitalization' or 'Final', shown as one heat map: correlation coefficients
# in the upper triangle, -log10 of the FDR-adjusted p-values in the lower one.
path_local = "005_cont_stat"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_save}/003_data_filtering/data.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_save}/003_data_filtering/feats.xlsx", index_col="feature")

df_feats_target = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["time"].isin(['Hospitalization', 'Final'])) & (df_feats["type"].isin(["cont"])), :]
# rename() returns a new frame, so the slice of `df` is never modified in place
df_target = df.loc[:, df_feats_target.index].rename(
    columns=dict(zip(df_feats_target.index.values, df_feats_target["eng_title"].values))
)

feats = df_feats_target["eng_title"].values
# Square matrix: upper triangle <- correlation, lower triangle <- raw p-value,
# diagonal <- NaN
df_corr_mtx = pd.DataFrame(data=np.zeros(shape=(len(feats), len(feats))), index=feats, columns=feats)
for f_id_1 in range(len(feats)):
    for f_id_2 in range(f_id_1, len(feats)):
        f_1 = feats[f_id_1]
        f_2 = feats[f_id_2]
        if f_id_1 != f_id_2:
            vals_1 = df_target.loc[:, f_1].values
            vals_2 = df_target.loc[:, f_2].values
            # NOTE(review): missing values are passed to pearsonr as is —
            # confirm the selected features contain no NaN after filtering
            corr, pval = stats.pearsonr(vals_1, vals_2)
            df_corr_mtx.at[f_2, f_1] = pval
            df_corr_mtx.at[f_1, f_2] = corr
        else:
            df_corr_mtx.at[f_2, f_1] = np.nan
# Strictly-lower-triangle mask. The builtin `bool` replaces the `np.bool` alias,
# which was removed in NumPy 1.24 and raises AttributeError there.
selection = np.tri(df_corr_mtx.shape[0], df_corr_mtx.shape[1], -1, dtype=bool)
# Long table of the lower-triangle p-values, in row-major order
df_fdr = df_corr_mtx.where(selection).stack().reset_index()
df_fdr.columns = ['row', 'col', 'pval']
# Benjamini-Hochberg correction across all feature pairs
_, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
df_fdr['pval_fdr_bh_log'] = -np.log10(df_fdr.loc[:, 'pval_fdr_bh'].values)
# -log10(adjusted p) rescaled to [0, 1]; used as a colour-scale position by `corr`
df_fdr['color'] = MinMaxScaler().fit_transform(df_fdr.loc[:, 'pval_fdr_bh_log'].values.reshape(-1, 1))
# Same layout as df_corr_mtx, with the lower triangle replaced by -log10(adjusted p)
df_mtx_fdr = df_corr_mtx.copy()
for line_id in range(df_fdr.shape[0]):
    df_mtx_fdr.loc[df_fdr.at[line_id, 'row'], df_fdr.at[line_id, 'col']] = -np.log10(df_fdr.at[line_id, 'pval_fdr_bh'])

sns.set_theme(style='whitegrid')
df_to_plot = df_mtx_fdr.copy()
mtx_to_plot = df_to_plot.to_numpy()

# Upper triangle (correlations); exact zeros are masked so the other half shows through
mtx_triu = np.triu(mtx_to_plot, +1)
max_corr = np.max(mtx_triu)
min_corr = np.min(mtx_triu)
mtx_triu_mask = np.ma.masked_array(mtx_triu, mtx_triu == 0)
cmap_triu = plt.get_cmap("bwr").copy()

# Lower triangle (-log10 adjusted p); values below vmin are drawn in black
mtx_tril = np.tril(mtx_to_plot, -1)
max_pval = np.max(mtx_tril)
min_pval = np.min(mtx_tril)
mtx_tril_mask = np.ma.masked_array(mtx_tril, mtx_tril == 0)
cmap_tril = plt.get_cmap("viridis").copy()
cmap_tril.set_under('black')

fig, ax = plt.subplots()

im_triu = ax.imshow(mtx_triu_mask, cmap=cmap_triu, vmin=-1, vmax=1)
cbar_triu = ax.figure.colorbar(im_triu, ax=ax, location='right', shrink=0.80)
cbar_triu.ax.tick_params(labelsize=8)
cbar_triu.set_label(r"$\mathrm{Correlation\:coefficient}$", horizontalalignment='center', fontsize=8)

# vmin at -log10(0.05): cells that are not significant at 0.05 fall "under" the scale
im_tril = ax.imshow(mtx_tril_mask, cmap=cmap_tril, vmin=-np.log10(0.05))
cbar_tril = ax.figure.colorbar(im_tril, ax=ax, location='right', shrink=0.80)
cbar_tril.ax.tick_params(labelsize=8)
cbar_tril.set_label(r"$-\log_{10}(\mathrm{p-value})$", horizontalalignment='center', fontsize=8)

ax.grid(None)
ax.set_aspect("equal")
ax.set_xticks(np.arange(df_to_plot.shape[1]))
ax.set_yticks(np.arange(df_to_plot.shape[0]))
ax.set_xticklabels(df_to_plot.columns.values)
ax.set_yticklabels(df_to_plot.index.values)
plt.setp(ax.get_xticklabels(), rotation=90)
# Lower-triangle cells below half of the value range get white text, others black
threshold = np.ptp(mtx_tril.flatten()) * 0.5
ax.tick_params(axis='both', which='major', labelsize=5)
ax.tick_params(axis='both', which='minor', labelsize=5)
textcolors = ("black", "white")
for i in range(df_to_plot.shape[0]):
    for j in range(df_to_plot.shape[1]):
        color = "black"
        if i > j:
            color = textcolors[int(mtx_tril[i, j] < threshold)]
        # Diagonal (NaN) and infinite cells are left without a number
        if np.isinf(mtx_to_plot[i, j]) or np.isnan(mtx_to_plot[i, j]):
            text = ax.text(j, i, "", ha="center", va="center", color=color, fontsize=5)
        else:
            text = ax.text(j, i, f"{mtx_to_plot[i, j]:0.2f}", ha="center", va="center", color=color, fontsize=5)
fig.tight_layout()
plt.savefig(f"{path_save}/{path_local}/corr_mtx.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_save}/{path_local}/corr_mtx.pdf", bbox_inches='tight', dpi=400)
plt.clf()
df_save = df_mtx_fdr
df_save.to_excel(f"{path_save}/{path_local}/corr_mtx.xlsx", index=True)


def plot_regression(x, y, **kwargs):
    """Draw the least-squares line of ``y ~ x`` on the current matplotlib axes.

    An OLS model is fitted with statsmodels and evaluated at the smallest and
    largest ``x``; the resulting segment is drawn twice, a wide grey stroke
    underneath a thinner coloured one.

    Args:
        x: values of the explanatory variable.
        y: values of the response variable.
        **kwargs: must contain ``color``, the colour of the top stroke; any
            other keyword is ignored.

    Raises:
        KeyError: if ``color`` is not supplied.
    """
    observations = pd.DataFrame({"x": x, "y": y})
    fitted = smf.ols(formula="y ~ x", data=observations).fit()

    # The fitted line is fully determined by its two end points
    segment = pd.DataFrame({"x": [min(x), max(x)]})
    segment["y"] = fitted.predict(segment)
    seg_x = segment['x'].values
    seg_y = segment['y'].values

    plt.gca().plot(seg_x, seg_y, color='dimgrey', marker=None, linestyle='-', linewidth=4.0)
    plt.gca().plot(seg_x, seg_y, color=kwargs['color'], marker=None, linestyle='-', linewidth=2.0)


def corr(x, y, **kwargs):
    """Fill the current axes with the correlation summary of one feature pair.

    Each call consumes the next row of ``df_fdr``: the position is kept in the
    mutable ``dict_id`` mapping and advanced by one per call, so the calls must
    happen in the same order as the rows of ``df_fdr``.

    Args:
        x: values of the first feature.
        y: values of the second feature.
        **kwargs: must contain
            ``dict_id`` - dict whose ``'row_id'`` entry is the current row
            position in ``df_fdr`` (incremented here);
            ``df_fdr`` - table with the columns ``'pval_fdr_bh'`` (adjusted
            p-value) and ``'color'`` (position on the colour scale).
    """
    position = kwargs['dict_id']
    pair_idx = position['row_id']
    fdr_table = kwargs['df_fdr']
    pval_adj = fdr_table['pval_fdr_bh'].values[pair_idx]
    rho, _ = stats.pearsonr(x, y)
    # Advance the shared counter for the next call
    position['row_id'] += 1

    ax = plt.gca()

    # Background colour taken from the Pinkyl scale at the pair's 'color' position
    scale_colors, _ = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Pinkyl)
    colorscale = plotly.colors.make_colorscale(scale_colors)
    color_raw = get_continuous_color(colorscale, intermed=fdr_table['color'].values[pair_idx])
    # assumes color_raw looks like "rgb(r, g, b)": the slice strips "rgb(" and ")" — TODO confirm
    rgb_unit = [float(channel) / 255.0 for channel in color_raw[4:-1].split(', ')]
    # Pairs whose adjusted p-value exceeds 0.05 keep a white background
    facecolor = 'white' if pval_adj > 0.05 else rgb_unit
    ax.set_facecolor(facecolor)

    # Three centred text lines: caption, adjusted p-value, correlation coefficient
    lines = (
        ("p-value:", 0.72),
        (f"{pval_adj:0.2e}", 0.52),
        (r'$\rho$: ' + f"{rho:0.2f}", 0.20),
    )
    for label, y_pos in lines:
        ax.annotate(label, xy=(0.5, y_pos), size=25, xycoords=ax.transAxes, ha='center')

# Scatter matrix of the continuous features: scatter + regression line above the
# diagonal, KDE on the diagonal, text summary (see `corr`) below the diagonal.
sns.set_theme(style="whitegrid", font_scale=1.5)
pair_grid = sns.PairGrid(df_target, vars=feats)
pair_grid.map_upper(sns.scatterplot, color='deepskyblue', s=25, alpha=0.5, edgecolor='k', linewidth=0.2)
pair_grid.map_diag(sns.kdeplot, color='deepskyblue')
pair_grid.map_upper(plot_regression, color='blue')
# The counter dict is shared by all calls of `corr`, starting at the first row of df_fdr
pair_grid.map_lower(corr, dict_id={'row_id': 0}, df_fdr=df_fdr)

n_grid_rows, n_grid_cols = pair_grid.axes.shape
for grid_row, grid_col in np.ndindex(n_grid_rows, n_grid_cols):
    cell_ax = pair_grid.axes[grid_row, grid_col]
    # Frame every panel on all four sides
    cell_ax.spines[['right', 'top']].set_visible(True)
    # Text-only panels below the diagonal get no grid lines
    if grid_row > grid_col:
        cell_ax.grid(False)

for cell_ax in pair_grid.axes.flatten():
    # Vertical x labels, horizontal right-aligned y labels
    cell_ax.set_xlabel(cell_ax.get_xlabel(), rotation=90, fontsize=30)
    cell_ax.set_ylabel(cell_ax.get_ylabel(), rotation=0, fontsize=30)
    cell_ax.yaxis.get_label().set_horizontalalignment('right')

plt.savefig(f"{path_save}/{path_local}/scatter_mtx.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/{path_local}/scatter_mtx.pdf", bbox_inches='tight')
plt.clf()

# 6. Categorical features statistics

In [None]:
# Chi-square association between every pair of categorical 'Hospitalization'
# features, shown as one heat map: chi-square statistics in the upper triangle,
# -log10 of the FDR-adjusted p-values in the lower one.
path_local = "006_cat_stat"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_save}/003_data_filtering/data.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_save}/003_data_filtering/feats.xlsx", index_col="feature")

# Explicit copies: the titles of the 'Final' features are edited below
df_feats_t0 = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["time"].isin(['Hospitalization'])) & (df_feats["type"].isin(["cat"])), :].copy()
df_feats_t2 = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["time"].isin(['Final'])) & (df_feats["type"].isin(["cat"])), :].copy()
# Titles used at both time points: the 'Final' one gets a " (Final)" suffix so
# the renamed columns stay distinguishable; the 'Hospitalization' title is kept.
# The column is assigned back instead of calling replace(inplace=True) on
# df_feats_t2["eng_title"], which does not update the frame under pandas
# Copy-on-Write.
feats_t0_t2 = list(set(df_feats_t0["eng_title"].values).intersection(set(df_feats_t2["eng_title"].values)))
df_feats_t2["eng_title"] = df_feats_t2["eng_title"].replace({f"{x}": f"{x} (Final)" for x in feats_t0_t2})
df = df.rename(columns=dict(zip(df_feats_t0.index.values, df_feats_t0["eng_title"].values)))
df = df.rename(columns=dict(zip(df_feats_t2.index.values, df_feats_t2["eng_title"].values)))

feats_t0 = df_feats_t0["eng_title"].values
feats_t2 = df_feats_t2["eng_title"].values
# Square matrix: upper triangle <- chi-square statistic, lower triangle <- raw
# p-value, diagonal <- NaN
df_stat_mtx_t0_t0 = pd.DataFrame(data=np.zeros(shape=(len(feats_t0), len(feats_t0))), index=feats_t0, columns=feats_t0)
for f_id_1 in range(len(feats_t0)):
    for f_id_2 in range(f_id_1, len(feats_t0)):
        f_1 = feats_t0[f_id_1]
        f_2 = feats_t0[f_id_2]
        if f_id_1 != f_id_2:
            df_cross = pd.crosstab(df[f_1], df[f_2])
            res = chi2_contingency(df_cross, correction=True)
            df_stat_mtx_t0_t0.at[f_2, f_1] = res.pvalue
            # Exact zeros are masked in the plot below, so a zero statistic is
            # stored as a tiny positive number to keep the cell visible
            if res.statistic == 0:
                df_stat_mtx_t0_t0.at[f_1, f_2] = 1e-100
            else:
                df_stat_mtx_t0_t0.at[f_1, f_2] = res.statistic
        else:
            df_stat_mtx_t0_t0.at[f_2, f_1] = np.nan
# Strictly-lower-triangle mask. The builtin `bool` replaces the `np.bool` alias,
# which was removed in NumPy 1.24 and raises AttributeError there.
selection = np.tri(df_stat_mtx_t0_t0.shape[0], df_stat_mtx_t0_t0.shape[1], -1, dtype=bool)
df_fdr = df_stat_mtx_t0_t0.where(selection).stack().reset_index()
df_fdr.columns = ['row', 'col', 'pval']
# Benjamini-Hochberg correction across all feature pairs
_, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
# An adjusted p-value of exactly 1 would give -log10 == 0, which the zero mask
# below would hide; nudge it just under 1 (assigned back for the same
# Copy-on-Write reason as above)
df_fdr['pval_fdr_bh'] = df_fdr['pval_fdr_bh'].replace({1.0: 0.999999})
df_mtx_fdr = df_stat_mtx_t0_t0.copy()
for line_id in range(df_fdr.shape[0]):
    df_mtx_fdr.loc[df_fdr.at[line_id, 'row'], df_fdr.at[line_id, 'col']] = -np.log10(df_fdr.at[line_id, 'pval_fdr_bh'])

sns.set_theme(style='whitegrid')
df_to_plot = df_mtx_fdr.copy()
mtx_to_plot = df_to_plot.to_numpy()

# Upper triangle (chi-square statistics); exact zeros are masked
mtx_triu = np.triu(mtx_to_plot, +1)
max_corr = np.max(mtx_triu)
min_corr = np.min(mtx_triu)
mtx_triu_mask = np.ma.masked_array(mtx_triu, mtx_triu == 0)
cmap_triu = plt.get_cmap("hot").copy()

# Lower triangle (-log10 adjusted p); values below vmin are drawn in black
mtx_tril = np.tril(mtx_to_plot, -1)
max_pval = np.max(mtx_tril)
min_pval = np.min(mtx_tril)
mtx_tril_mask = np.ma.masked_array(mtx_tril, mtx_tril == 0)
cmap_tril = plt.get_cmap("viridis").copy()
cmap_tril.set_under('black')

fig, ax = plt.subplots(figsize=(8,6))

im_triu = ax.imshow(mtx_triu_mask, cmap=cmap_triu)
cbar_triu = ax.figure.colorbar(im_triu, ax=ax, location='right', shrink=0.80)
cbar_triu.ax.tick_params(labelsize=8)
cbar_triu.set_label(r'$\chi^2$', horizontalalignment='center', fontsize=8)

# vmin at -log10(0.05): cells that are not significant at 0.05 fall "under" the scale
im_tril = ax.imshow(mtx_tril_mask, cmap=cmap_tril, vmin=-np.log10(0.05))
cbar_tril = ax.figure.colorbar(im_tril, ax=ax, location='right', shrink=0.80)
cbar_tril.ax.tick_params(labelsize=8)
cbar_tril.set_label(r"$-\log_{10}(\mathrm{p-value})$", horizontalalignment='center', fontsize=8)

ax.grid(None)
ax.set_aspect("equal")
ax.set_xticks(np.arange(df_to_plot.shape[1]))
ax.set_yticks(np.arange(df_to_plot.shape[0]))
ax.set_xticklabels(df_to_plot.columns.values)
ax.set_yticklabels(df_to_plot.index.values)
plt.setp(ax.get_xticklabels(), rotation=90)
# Lower-triangle cells below the midpoint of the value range get white text
threshold = np.min(mtx_tril.flatten()) + np.ptp(mtx_tril.flatten()) * 0.5
ax.tick_params(axis='both', which='major', labelsize=5)
ax.tick_params(axis='both', which='minor', labelsize=5)
textcolors = ("black", "white")
# NOTE(review): every cell is annotated twice — once as (j, i) and once as the
# mirrored (i, j) of the transposed iteration — so the later text, possibly in
# another colour, is drawn over the earlier one. Left unchanged to keep the
# rendered figure identical; confirm whether the second call is intended.
for i in range(df_to_plot.shape[0]):
    for j in range(df_to_plot.shape[1]):
        color = "black"
        if i > j:
            color = textcolors[int(mtx_tril[i, j] < threshold)]
        if np.isinf(mtx_to_plot[i, j]) or np.isnan(mtx_to_plot[i, j]):
            text = ax.text(j, i, "", ha="center", va="center", color=color, fontsize=2.3)
            text = ax.text(i, j, "", ha="center", va="center", color=color, fontsize=2.3)
        else:
            text = ax.text(j, i, f"{mtx_to_plot[i, j]:0.2f}", ha="center", va="center", color=color, fontsize=2.3)
            text = ax.text(i, j, f"{mtx_to_plot[j, i]:0.2f}", ha="center", va="center", color=color, fontsize=2.3)
fig.tight_layout()
plt.savefig(f"{path_save}/{path_local}/corr_mtx_t0_t0.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_save}/{path_local}/corr_mtx_t0_t0.pdf", bbox_inches='tight', dpi=400)
plt.clf()
df_save = df_mtx_fdr
df_save.to_excel(f"{path_save}/{path_local}/corr_mtx_t0_t0.xlsx", index=True)


# Chi-square association of every categorical 'Final' feature (rows) with every
# categorical 'Hospitalization' feature (columns); cells hold raw p-values first
# and -log10 of the FDR-adjusted p-values after the correction below.
df_stat_mtx_t0_t2 = pd.DataFrame(data=np.zeros(shape=(len(feats_t2), len(feats_t0))), index=feats_t2, columns=feats_t0)
for f_id_1 in range(len(feats_t2)):
    for f_id_2 in range(len(feats_t0)):
        f_1 = feats_t2[f_id_1]
        f_2 = feats_t0[f_id_2]
        df_cross = pd.crosstab(df[f_1], df[f_2])
        res = chi2_contingency(df_cross, correction=True)
        df_stat_mtx_t0_t2.at[f_1, f_2] = res.pvalue
# All cells take part in the correction. The builtin `bool` replaces the
# `np.bool` alias, which was removed in NumPy 1.24 and raises AttributeError there.
selection = np.ones((df_stat_mtx_t0_t2.shape[0], df_stat_mtx_t0_t2.shape[1]), dtype=bool)
df_fdr = df_stat_mtx_t0_t2.where(selection).stack().reset_index()
df_fdr.columns = ['row', 'col', 'pval']
# Benjamini-Hochberg correction across all feature pairs
_, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
# Adjusted p-values of exactly 1 are nudged just under 1. The column is assigned
# back instead of calling replace(inplace=True) on it, which does not update the
# frame under pandas Copy-on-Write.
df_fdr['pval_fdr_bh'] = df_fdr['pval_fdr_bh'].replace({1.0: 0.999999})
df_mtx_fdr = df_stat_mtx_t0_t2.copy()
for line_id in range(df_fdr.shape[0]):
    df_mtx_fdr.loc[df_fdr.at[line_id, 'row'], df_fdr.at[line_id, 'col']] = -np.log10(df_fdr.at[line_id, 'pval_fdr_bh'])

sns.set_theme(style='whitegrid')
df_to_plot = df_mtx_fdr.copy()
mtx_to_plot = df_to_plot.to_numpy()
cmap = plt.get_cmap("viridis").copy()
# Values below vmin (not significant at 0.05) are drawn in black
cmap.set_under('black')

fig, ax = plt.subplots(figsize=(8,5))
im = ax.imshow(mtx_to_plot, cmap=cmap, vmin=-np.log10(0.05))
cbar = ax.figure.colorbar(im, ax=ax, location='right', shrink=0.4)
cbar.ax.tick_params(labelsize=8)
cbar.set_label(r"$-\log_{10}(\mathrm{p-value})$", horizontalalignment='center', fontsize=8)

ax.grid(None)
ax.set_aspect("equal")
ax.set_xticks(np.arange(df_to_plot.shape[1]))
ax.set_yticks(np.arange(df_to_plot.shape[0]))
ax.set_xticklabels(df_to_plot.columns.values)
ax.set_yticklabels(df_to_plot.index.values)
plt.setp(ax.get_xticklabels(), rotation=90)
# Cells below the midpoint of the value range get white text, others black
threshold = np.min(mtx_to_plot.flatten()) + np.ptp(mtx_to_plot.flatten()) * 0.5
ax.tick_params(axis='both', which='major', labelsize=5)
ax.tick_params(axis='both', which='minor', labelsize=5)
textcolors = ("black", "white")
for i in range(df_to_plot.shape[0]):
    for j in range(df_to_plot.shape[1]):
        color = textcolors[int(mtx_to_plot[i, j] < threshold)]
        # Infinite cells (adjusted p-value of 0) are left without a number
        if np.isinf(mtx_to_plot[i, j]):
            text = ax.text(j, i, "", ha="center", va="center", color=color, fontsize=2.75)
        else:
            text = ax.text(j, i, f"{mtx_to_plot[i, j]:0.2f}", ha="center", va="center", color=color, fontsize=2.75)
fig.tight_layout()
plt.savefig(f"{path_save}/{path_local}/corr_mtx_t0_t2.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_save}/{path_local}/corr_mtx_t0_t2.pdf", bbox_inches='tight', dpi=400)
plt.clf()
df_save = df_mtx_fdr
df_save.to_excel(f"{path_save}/{path_local}/corr_mtx_t0_t2.xlsx", index=True)

# 7. Categorical and continuous features statistics

In [None]:
# Adjusted p-value cut-off below which an example violin plot is drawn for a feature pair.
pval_lim = 1e-10

path_local = f"007_cat_cont_stat"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_save}/003_data_filtering/data.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_save}/003_data_filtering/feats.xlsx", index_col="feature")

pathlib.Path(f"{path_save}/{path_local}/examples").mkdir(parents=True, exist_ok=True)

# Explicit copies: the categorical subsets are modified below, and modifying a `.loc` slice of
# `df_feats` triggers SettingWithCopyWarning (and has no effect at all under copy-on-write).
df_feats_cat_t0 = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["time"].isin(['Hospitalization'])) & (df_feats["type"].isin(["cat"])), :].copy()
df_feats_cat_t2 = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["time"].isin(['Final'])) & (df_feats["type"].isin(["cat"])), :].copy()
df_feats_cont = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["time"].isin(['Hospitalization', 'Final'])) & (df_feats["type"].isin(["cont"])), :]
df = df.rename(columns=dict(zip(df_feats_cont.index.values, df_feats_cont["eng_title"].values)))
# Titles used by both a 'Hospitalization' and a 'Final' categorical feature: the 'Final' one gets
# a " (Final)" suffix so that titles stay unique. The 'Hospitalization' titles are kept as they
# are (the original code replaced each of them by itself, which was a no-op).
feats_cat_t0_t2 = list(set(df_feats_cat_t0["eng_title"].values).intersection(set(df_feats_cat_t2["eng_title"].values)))
df_feats_cat_t2["eng_title"] = df_feats_cat_t2["eng_title"].replace({f"{x}": f"{x} (Final)" for x in feats_cat_t0_t2})
df = df.rename(columns=dict(zip(df_feats_cat_t0.index.values, df_feats_cat_t0["eng_title"].values)))
df = df.rename(columns=dict(zip(df_feats_cat_t2.index.values, df_feats_cat_t2["eng_title"].values)))
# Feature table indexed by the (now unique) English title; the original feature id is kept
# in 'index_old'. `verify_integrity` raises if two features still share a title.
df_feats_now = pd.concat([df_feats_cat_t0, df_feats_cat_t2, df_feats_cont])
df_feats_now['index_old'] = df_feats_now.index.values
df_feats_now = df_feats_now.set_index("eng_title", verify_integrity=True)

feats_cat = list(df_feats_cat_t2["eng_title"].values) + list(df_feats_cat_t0["eng_title"].values)

# Translate category values through the mapping stored (as a dict literal) in 'eng_values'.
for feat_cat in feats_cat:
    feat_values = df_feats_now.at[feat_cat, 'eng_values']
    if not pd.isna(feat_values):
        dict_values = ast.literal_eval(feat_values)
        df[feat_cat] = df[feat_cat].replace(dict_values)

# p-value matrix: rows are continuous features, columns are categorical features.
feats_cont = list(df_feats_cont["eng_title"].values)
df_stat_mtx = pd.DataFrame(data=np.zeros(shape=(len(feats_cont), len(feats_cat))), index=feats_cont, columns=feats_cat)
for f_1 in feats_cont:
    for f_2 in feats_cat:
        # Missing values are dropped first. A NaN category would otherwise be counted as a group
        # while selecting no rows (`== NaN` is never true), and mannwhitneyu raises on an empty
        # sample; NaNs in the continuous feature would propagate into the p-value.
        df_pair = df[[f_2, f_1]].dropna()
        groups = [df_group[f_1].values for _, df_group in df_pair.groupby(f_2, sort=False)]
        if len(groups) > 2:
            # More than two categories: Kruskal-Wallis test
            stat, pval = kruskal(*groups)
        elif len(groups) == 2:
            # Exactly two categories: Mann-Whitney U test
            stat, pval = mannwhitneyu(*groups, alternative='two-sided')
        else:
            # Fewer than two categories observed: nothing to compare
            pval = 1.0
        df_stat_mtx.at[f_1, f_2] = pval
# Benjamini-Hochberg FDR correction over the cells of the matrix.
# The builtin `bool` is used because the `np.bool` alias was removed in NumPy 1.24.
selection = np.ones(df_stat_mtx.shape, dtype=bool)
df_fdr = df_stat_mtx.where(selection).stack().reset_index()
df_fdr.columns = ['row', 'col', 'pval']
_, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
# Reassign the column instead of calling replace(..., inplace=True) on it: the in-place form
# operates on a temporary object and leaves `df_fdr` untouched under pandas copy-on-write.
df_fdr['pval_fdr_bh'] = df_fdr['pval_fdr_bh'].replace({1.0: 0.999999})
# Matrix of -log10(adjusted p-value) with the same layout as the raw p-value matrix.
df_mtx_fdr = df_stat_mtx.copy()
for fdr_row in df_fdr.itertuples(index=False):
    df_mtx_fdr.loc[fdr_row.row, fdr_row.col] = -np.log10(fdr_row.pval_fdr_bh)

# Heatmap of -log10(adjusted p-value) for every continuous/categorical feature pair.
sns.set_theme(style='whitegrid')
df_to_plot = df_mtx_fdr.copy()
mtx_to_plot = df_to_plot.to_numpy(dtype=float)
cmap = plt.get_cmap("viridis").copy()
# Cells below vmin (adjusted p-value above 0.05) are drawn in black.
cmap.set_under('black')

fig, ax = plt.subplots(figsize=(9,4))
im = ax.imshow(mtx_to_plot, cmap=cmap, vmin=-np.log10(0.05))
cbar = ax.figure.colorbar(im, ax=ax, location='right', shrink=0.5)
cbar.ax.tick_params(labelsize=8)
cbar.set_label(r"$-\log_{10}(\mathrm{p-value})$", horizontalalignment='center', fontsize=8)

ax.grid(None)
ax.set_aspect("equal")
ax.set_xticks(np.arange(df_to_plot.shape[1]))
ax.set_yticks(np.arange(df_to_plot.shape[0]))
ax.set_xticklabels(df_to_plot.columns.values)
ax.set_yticklabels(df_to_plot.index.values)
plt.setp(ax.get_xticklabels(), rotation=90)
# Midpoint of the value range decides the annotation colour. Only finite values are used:
# a single inf (adjusted p-value of 0) or NaN would otherwise turn the threshold into inf/NaN
# and give every cell the same text colour.
finite_vals = mtx_to_plot[np.isfinite(mtx_to_plot)]
threshold = finite_vals.min() + np.ptp(finite_vals) * 0.5 if finite_vals.size else 0.0
ax.tick_params(axis='both', which='major', labelsize=5)
ax.tick_params(axis='both', which='minor', labelsize=5)
textcolors = ("black", "white")
for (i, j), cell_value in np.ndenumerate(mtx_to_plot):
    color = textcolors[int(cell_value < threshold)]
    # Infinite cells are left without a label.
    label = "" if np.isinf(cell_value) else f"{cell_value:0.2f}"
    ax.text(j, i, label, ha="center", va="center", color=color, fontsize=2.75)
fig.tight_layout()
plt.savefig(f"{path_save}/{path_local}/corr_mtx.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_save}/{path_local}/corr_mtx.pdf", bbox_inches='tight', dpi=400)
# Close the figure (not just clear it) so it is released once it has been saved.
plt.close(fig)
df_save = df_mtx_fdr
df_save.to_excel(f"{path_save}/{path_local}/corr_mtx.xlsx", index=True)

# Example violin plots for the feature pairs whose adjusted p-value is below `pval_lim`,
# most significant first.
df_vio = df_fdr.sort_values(by=['pval_fdr_bh'], ascending=[True])
df_vio = df_vio.loc[df_vio['pval_fdr_bh'] < pval_lim, :]
default_colors = px.colors.qualitative.Dark24
for vio_id, (index, row) in enumerate(df_vio.iterrows()):
    feat_cat = row['col']
    feat_cont = row['row']
    pval = row['pval_fdr_bh']

    filename = f"{vio_id:03d}_{df_feats_now.at[feat_cont, 'index_old']}_{df_feats_now.at[feat_cat, 'index_old']}"

    feat_cat_values_raw = df_feats_now.at[feat_cat, 'eng_values']
    # Categories actually observed for this feature (missing values are not a category).
    cats_present = list(df[feat_cat].dropna().unique())

    if not pd.isna(feat_cat_values_raw):
        feat_cat_values = ast.literal_eval(feat_cat_values_raw)

        colors_str = df_feats_now.at[feat_cat, 'colors']
        if not pd.isna(colors_str):
            palette = ast.literal_eval(colors_str)
        else:
            # The colour list is cycled so that more categories than colours cannot overflow it.
            palette = {x: default_colors[x_id % len(default_colors)] for x_id, x in enumerate(feat_cat_values.values())}

        # Drop the categories that do not occur in the data
        for cat in set(palette.keys()) - set(cats_present):
            palette.pop(cat, None)
        order = list(palette.keys())
    else:
        # No value mapping is available: colour the observed categories in order of appearance.
        # The palette has to be a dict keyed by category, because it is indexed by the category
        # value below (indexing the plain colour list with a category value fails).
        order = cats_present
        palette = {cat: default_colors[cat_id % len(default_colors)] for cat_id, cat in enumerate(order)}

    dist_num_bins = 20
    fig = go.Figure()
    for cat_val in order:
        # NaNs are removed so that they cannot turn the bandwidth into NaN
        vals = df.loc[df[feat_cat] == cat_val, feat_cont].dropna().values
        if vals.size == 0:
            # Nothing to draw for this category (np.ptp would raise on an empty array)
            continue
        fig.add_trace(
            go.Violin(
                y=vals,
                name=cat_val,
                box_visible=True,
                meanline_visible=True,
                showlegend=False,
                line_color='black',
                fillcolor=palette[cat_val],
                marker=dict(color=palette[cat_val], line=dict(color='black', width=0.3), opacity=0.8),
                points='all',
                bandwidth=np.ptp(vals) / dist_num_bins,
                opacity=0.8
            )
        )
    add_layout(fig, f"{feat_cat}", f"{feat_cont}", f"p-value: {pval:0.2e}")
    fig.update_layout(title_xref='paper')
    # The figure width grows with the number of categories.
    fig.update_layout(
        width=120 + 300 * len(order),
        height=800,
        margin=go.layout.Margin(
            l=100,
            r=20,
            b=100,
            t=40,
            pad=0,
        )
    )
    fig.update_xaxes(autorange=False, range=[-0.5, len(order) - 0.7])
    save_figure(fig, f"{path_save}/{path_local}/examples/{filename}")

# 8. Status day prediction

## Generate data

In [None]:
# Build the input tables for the status-day prediction experiment.
path_local = f"008_status_day_prediction"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_save}/003_data_filtering/data.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_save}/003_data_filtering/feats.xlsx", index_col="feature")

feat_is_included = df_feats["include"] == "yes"
feat_is_t0 = df_feats["time"].isin(['Hospitalization'])

# Data table: every included categorical or continuous feature from either time point.
df_feats_selected = df_feats.loc[
    feat_is_included
    & df_feats["time"].isin(['Hospitalization', 'Final'])
    & df_feats["type"].isin(["cat", "cont"]),
    :
]
df = df.loc[:, df_feats_selected.index.values]
# Rows with patient status 3 form the "tst" split, all other rows the "trn_val" split.
df["Split"] = np.where(df['f07_patient_status_id'] == 3, "tst", "trn_val")
df.to_excel(f"{path_save}/{path_local}/data.xlsx", index=True, index_label='index')

# Feature lists: only the 'Hospitalization' features, split by type.
df_feats_cont = df_feats.loc[feat_is_included & feat_is_t0 & df_feats["type"].isin(["cont"]), :]
df_feats_cont.to_excel(f"{path_save}/{path_local}/feats_cont.xlsx", index=True)
df_feats_cat = df_feats.loc[feat_is_included & feat_is_t0 & df_feats["type"].isin(["cat"]), :]
df_feats_cat.to_excel(f"{path_save}/{path_local}/feats_cat.xlsx", index=True)

## Collect ML results

In [None]:
# Collect the train/validation/test metrics and the Hydra overrides of every run into one table.
dataset = "GSEUNN"
path = f"D:/YandexDisk/Work/pydnameth/datasets"
datasets_info = pd.read_excel(f"{path}/datasets.xlsx", index_col='dataset')
platform = datasets_info.loc[dataset, 'platform']

model = 'lightgbm_trn_val_tst'

path_runs = f"{path}/{platform}/{dataset}/special/041_covid_treatment/008_status_day_prediction/models/{model}/multiruns"

files = glob(f"{path_runs}/*/*/metrics_val_best_*.xlsx")
if not files:
    # Fail with a clear message instead of an IndexError on files[0]
    raise FileNotFoundError(f"No 'metrics_val_best_*.xlsx' files found under {path_runs}")

# Names of the overridden parameters, taken from the first run.
# NOTE(review): `params` is not read anywhere in this cell; it is kept in case another cell uses it - confirm and remove.
head, tail = os.path.split(files[0])
cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
params = [param_pair.split('=', 1)[0] for param_pair in cfg]

df_res = pd.DataFrame(index=files)
for file in files:
    head, tail = os.path.split(file)

    # The train and test metrics are stored next to the validation file under the same name
    # with another partition tag. Only the leading tag is swapped, so a 'val' occurring
    # elsewhere in the file name is left alone.
    for part in ("val", "trn", "tst"):
        tail_part = tail.replace("metrics_val_", f"metrics_{part}_", 1)
        df_part = pd.read_excel(f"{head}/{tail_part}", index_col="metric")
        # Iterate over the metrics of the file being read. The test metrics used to be looked
        # up with the metric names of the train file, which fails when the two files differ.
        for metric in df_part.index.values:
            df_res.at[file, f"{metric}_{part}"] = df_part.at[metric, part]

    # Params: split on the first '=' only, because the value itself may contain '='
    cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
    for param_pair in cfg:
        param, val = param_pair.split('=', 1)
        df_res.at[file, param] = val

# Put the main metrics first, then every remaining column in its original order.
first_columns = [
    'mean_absolute_error_trn',
    'mean_absolute_error_cv_mean_trn',
    'mean_absolute_error_val',
    'mean_absolute_error_cv_mean_val',
    'mean_absolute_error_tst'
]
df_res = df_res[first_columns + [col for col in df_res.columns if col not in first_columns]]
df_res.to_excel(f"{path_runs}/summary.xlsx", index=True, index_label="file")

# 9. Status prediction

## Generate data

In [None]:
# Build the input tables for the status prediction experiment.
path_local = f"009_status_prediction"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_save}/003_data_filtering/data.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_save}/003_data_filtering/feats.xlsx", index_col="feature")

feat_is_included = df_feats["include"] == "yes"
feat_is_t0 = df_feats["time"].isin(['Hospitalization'])

# Data table: every included categorical or continuous feature from either time point.
df_feats_selected = df_feats.loc[
    feat_is_included
    & df_feats["time"].isin(['Hospitalization', 'Final'])
    & df_feats["type"].isin(["cat", "cont"]),
    :
]
df = df.loc[:, df_feats_selected.index.values]
# Status codes 0-4 collapse into two classes: code 3 is 'Lethal', every other code 'Survived'.
dict_values = {code: ('Lethal' if code == 3 else 'Survived') for code in range(5)}
df = df.replace({'f07_patient_status_id': dict_values})
# All rows go into a single split.
df["Split"] = "trn_val"
df.to_excel(f"{path_save}/{path_local}/data.xlsx", index=True, index_label='index')

# Class labels of the target.
df_status = pd.DataFrame({'f07_patient_status_id': ['Survived', 'Lethal']})
df_status.to_excel(f"{path_save}/{path_local}/statuses.xlsx", index=False)

# Feature lists: only the 'Hospitalization' features, split by type.
df_feats_cont = df_feats.loc[feat_is_included & feat_is_t0 & df_feats["type"].isin(["cont"]), :]
df_feats_cont.to_excel(f"{path_save}/{path_local}/feats_cont.xlsx", index=True)
df_feats_cat = df_feats.loc[feat_is_included & feat_is_t0 & df_feats["type"].isin(["cat"]), :]
df_feats_cat.to_excel(f"{path_save}/{path_local}/feats_cat.xlsx", index=True)

## Collect ML results

In [None]:
# Collect the train/validation metrics and the Hydra overrides of every run into one table.
dataset = "GSEUNN"
path = f"D:/YandexDisk/Work/pydnameth/datasets"
datasets_info = pd.read_excel(f"{path}/datasets.xlsx", index_col='dataset')
platform = datasets_info.loc[dataset, 'platform']

model = 'widedeep_tab_resnet_trn_val_tst'

path_runs = f"{path}/{platform}/{dataset}/special/041_covid_treatment/009_status_prediction/models/{model}/multiruns"

files = glob(f"{path_runs}/*/*/metrics_val_best_*.xlsx")
if not files:
    # Fail with a clear message instead of an IndexError on files[0]
    raise FileNotFoundError(f"No 'metrics_val_best_*.xlsx' files found under {path_runs}")

# Names of the overridden parameters, taken from the first run.
# NOTE(review): `params` is not read anywhere in this cell; it is kept in case another cell uses it - confirm and remove.
head, tail = os.path.split(files[0])
cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
params = [param_pair.split('=', 1)[0] for param_pair in cfg]

df_res = pd.DataFrame(index=files)
for file in files:
    head, tail = os.path.split(file)

    # The train metrics are stored next to the validation file under the same name with
    # another partition tag. Only the leading tag is swapped, so a 'val' occurring elsewhere
    # in the file name is left alone.
    for part in ("val", "trn"):
        tail_part = tail.replace("metrics_val_", f"metrics_{part}_", 1)
        df_part = pd.read_excel(f"{head}/{tail_part}", index_col="metric")
        for metric in df_part.index.values:
            df_res.at[file, f"{metric}_{part}"] = df_part.at[metric, part]

    # Params: split on the first '=' only, because the value itself may contain '='
    cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
    for param_pair in cfg:
        param, val = param_pair.split('=', 1)
        df_res.at[file, param] = val

# Put the main metrics first, then every remaining column in its original order.
first_columns = [
    'f1_score_macro_trn',
    'f1_score_macro_cv_mean_trn',
    'f1_score_macro_val',
    'f1_score_macro_cv_mean_val'
]
df_res = df_res[first_columns + [col for col in df_res.columns if col not in first_columns]]
df_res.to_excel(f"{path_runs}/summary.xlsx", index=True, index_label="file")

# 10. Survival analysis

## Generate data

In [None]:
# Build the input tables for the survival analysis experiment.
path_local = f"010_survival_analysis"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_save}/003_data_filtering/data.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_save}/003_data_filtering/feats.xlsx", index_col="feature")

df_feats_selected = df_feats.loc[(df_feats["include"] == "yes") & (df_feats["time"].isin(['Hospitalization', 'Final'])) & (df_feats["type"].isin(["cat", "cont"])), :]

# Filter some features
feats_to_del = [
    "f01_is_skin",
    "f01_is_vac_grip_2019",
    "f01_is_vac_pncoc",
    "f03_is_diabets1",
    "f03_is_onco_g",
    "f03_is_onco_now",
    "f03_is_onco_chem",
]
# Non-inplace drop: dropping in place on a `.loc` slice of `df_feats` triggers
# SettingWithCopyWarning. A KeyError is still raised if a listed feature is not selected.
df_feats_selected = df_feats_selected.drop(index=feats_to_del)

df = df.loc[:, df_feats_selected.index.values]
# Status codes collapse into two classes: code 3 is 'Lethal', every other code 'Survived'.
dict_values = {
    0: 'Survived',
    1: 'Survived',
    2: 'Survived',
    3: 'Lethal',
    4: 'Survived',
}
df = df.replace({'f07_patient_status_id': dict_values})
# All rows go into a single split.
df["Split"] = "trn_val"
df.to_excel(f"{path_save}/{path_local}/data.xlsx", index=True, index_label='index')

# One file with the class labels per target feature.
statuses_by_target = {
    'f07_patient_status_id': ['Survived', 'Lethal'],
    'f07_is_citokin': ['No', 'Yes'],
    'f07_is_ords': ['No', 'Yes'],
    'f07_is_opp': ['No', 'Yes'],
    'f07_is_bac_pneumo': ['No', 'Yes'],
    'f07_is_sepsis': ['No', 'Yes'],
}
for status_feat, status_labels in statuses_by_target.items():
    df_status = pd.DataFrame({status_feat: status_labels})
    df_status.to_excel(f"{path_save}/{path_local}/statuses_{status_feat}.xlsx", index=False)

# Feature lists: only the remaining 'Hospitalization' features, split by type.
df_feats_cont = df_feats_selected.loc[(df_feats_selected["include"] == "yes") & (df_feats_selected["time"].isin(['Hospitalization'])) & (df_feats_selected["type"].isin(["cont"])), :]
df_feats_cont.to_excel(f"{path_save}/{path_local}/feats_cont.xlsx", index=True)
df_feats_cat = df_feats_selected.loc[(df_feats_selected["include"] == "yes") & (df_feats_selected["time"].isin(['Hospitalization'])) & (df_feats_selected["type"].isin(["cat"])), :]
df_feats_cat.to_excel(f"{path_save}/{path_local}/feats_cat.xlsx", index=True)

## Collect ML results

In [None]:
# Collect the metrics and the Hydra overrides of every survival-analysis run into one table.
target = "f07_patient_status_id"

dataset = "GSEUNN"
path = f"D:/YandexDisk/Work/pydnameth/datasets"
datasets_info = pd.read_excel(f"{path}/datasets.xlsx", index_col='dataset')
platform = datasets_info.loc[dataset, 'platform']

model = 'pycoxph_trn_val_tst'

path_runs = f"{path}/{platform}/{dataset}/special/041_covid_treatment/010_survival_analysis/{target}/models/{model}/multiruns"

files = glob(f"{path_runs}/*/*/metrics.xlsx")
if not files:
    # Fail with a clear message instead of an IndexError on files[0]
    raise FileNotFoundError(f"No 'metrics.xlsx' files found under {path_runs}")

# Names of the overridden parameters, taken from the first run.
# NOTE(review): `params` is not read anywhere in this cell; it is kept in case another cell uses it - confirm and remove.
head, tail = os.path.split(files[0])
cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
params = [param_pair.split('=', 1)[0] for param_pair in cfg]

df_res = pd.DataFrame(index=files)
for file in files:

    # metrics: one result column per (metric, partition) pair
    df_metrics = pd.read_excel(file, index_col="metric")
    for metric in df_metrics.index.values:
        for part in ("trn", "trn_mean", "val", "val_mean"):
            df_res.at[file, f"{metric}_{part}"] = df_metrics.at[metric, part]

    # Params: split on the first '=' only, because the value itself may contain '='
    head, tail = os.path.split(file)
    cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
    for param_pair in cfg:
        param, val = param_pair.split('=', 1)
        df_res.at[file, param] = val

# Put the main metrics first, then every remaining column in its original order.
first_columns = [
    'ci_val',
    'ci_val_mean',
]
df_res = df_res[first_columns + [col for col in df_res.columns if col not in first_columns]]
df_res.to_excel(f"{path_runs}/summary.xlsx", index=True, index_label="file")

# 11. Repeated measures

In [None]:
# Paired (T0 vs T1) Wilcoxon tests for the continuous repeated-measures features, overall and
# within every group of every grouping feature.
# Imported here so that the cell does not depend on an import made elsewhere in the notebook.
from itertools import combinations

path_local = f"011_rep_meas"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)
# The tables and figures of this analysis are written to the 'cont' subfolder, which must exist.
pathlib.Path(f"{path_save}/{path_local}/cont").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path_save}/003_data_filtering/data.xlsx", index_col="patient_id")
df_feats = pd.read_excel(f"{path_save}/003_data_filtering/feats.xlsx", index_col="feature")

# Grouping features; these two files have to be present in the output folder beforehand.
df_feats_group_cat = pd.read_excel(f"{path_save}/{path_local}/feats_cat.xlsx", index_col="feature")
df_feats_group_cont = pd.read_excel(f"{path_save}/{path_local}/feats_cont.xlsx", index_col="feature")

feats_cat = df_feats.index[(df_feats["preprocessing"] == "repeated_measures") & (df_feats["type"] == "cat")].values
feats_cat_labels = df_feats.loc[feats_cat, "eng_title"].values
feats_con = df_feats.index[(df_feats["preprocessing"] == "repeated_measures") & (df_feats["type"] == "cont")].values
feats_con_labels = df_feats.loc[feats_con, "eng_title"].values

# Forms that hold the first (t0) and the second (t1) measurement of a feature.
forms_t0 = ["f01", "f02", "f03", "f04v1", "f05v1", "calc"]
forms_t1 = ["f04v2", "f05v2", "f06"]

# NaNs preprocessing
feats_drop_rows = df_feats.index[df_feats["preprocessing"] == "repeated_measures"].values
df = df.dropna(subset=feats_drop_rows)

df_stat = pd.DataFrame()
df_stat_group_diff = pd.DataFrame()

# A title can occur more than once in `feats_con_labels`; every title is processed once.
# Repeated passes would only rewrite the same cells with the same values.
for feat in dict.fromkeys(feats_con_labels):
    # Exactly one feature per time point must carry this title.
    feat_by_time = {}
    for time_tag, forms in (("t0", forms_t0), ("t1", forms_t1)):
        feats_found = df_feats.index[(df_feats["eng_title"] == feat) & (df_feats["form"].isin(forms))].values
        if len(feats_found) != 1:
            raise ValueError(f"Expected exactly one {time_tag} feature for '{feat}', found {len(feats_found)}: {feats_found}")
        feat_by_time[time_tag] = feats_found[0]
    feat_t0 = feat_by_time["t0"]
    feat_t1 = feat_by_time["t1"]

    # Longest common substring of the two feature ids (without its first character); used as row key.
    # NOTE(review): `STree` is not among the imports at the top of the notebook - confirm it is imported.
    feat_common_str = STree.STree([feat_t0, feat_t1]).lcs()[1::]

    # Test over all patients; it does not depend on the grouping feature, so it is computed once.
    res_all = wilcoxon(
        x=df.loc[:, feat_t0].values,
        y=df.loc[:, feat_t1].values,
        alternative='two-sided'
    )

    for feat_group in df_feats_group_cat.index.values:
        df_fig = df.loc[:, [feat_t0, feat_t1, feat_group]].copy()
        feat_group_label = df_feats_group_cat.loc[feat_group, "eng_title"]
        feat_group_label_dict = ast.literal_eval(df_feats_group_cat.loc[feat_group, "eng_values"])

        # Test within every group of the grouping feature
        pvals_log10 = []
        for group in feat_group_label_dict:
            in_group = df_fig[feat_group] == group
            res = wilcoxon(
                x=df_fig.loc[in_group, feat_t0].values,
                y=df_fig.loc[in_group, feat_t1].values,
                alternative='two-sided'
            )
            stat_key = f"{feat_common_str}_{feat_group}_{feat_group_label_dict[group]}"
            df_stat.at[stat_key, "stat"] = res.statistic
            df_stat.at[stat_key, "pval"] = res.pvalue
            pvals_log10.append(-np.log10(res.pvalue))
        # Written inside the loop to keep the row order of the output table unchanged.
        df_stat.at[f"{feat_common_str}", "stat"] = res_all.statistic
        df_stat.at[f"{feat_common_str}", "pval"] = res_all.pvalue

        # Largest absolute difference of -log10(p-value) between any two groups.
        # NaN when there are fewer than two groups (max() of an empty sequence would raise).
        pvals_log10_diffs = [abs(pv2 - pv1) for pv1, pv2 in combinations(pvals_log10, 2)]
        df_stat_group_diff.at[f"{feat}", f"{feat_group_label}"] = max(pvals_log10_diffs, default=np.nan)

df_stat.to_excel(f"{path_save}/{path_local}/cont/stat.xlsx")
df_stat_group_diff.to_excel(f"{path_save}/{path_local}/cont/stat_group_diff.xlsx")

# Heatmap of the largest between-group difference for every (feature, grouping feature) pair.
# The plotted values are differences of -log10(p-value).
sns.set_theme(style='whitegrid')
df_to_plot = df_stat_group_diff.copy()
mtx_to_plot = df_to_plot.to_numpy(dtype=float)
cmap = plt.get_cmap("viridis").copy()

fig, ax = plt.subplots(figsize=(6,4))
im = ax.imshow(mtx_to_plot, cmap=cmap)
cbar = ax.figure.colorbar(im, ax=ax, location='right', shrink=1.0)
cbar.ax.tick_params(labelsize=7)
cbar.set_label(r"Max differences in p-values between the groups", horizontalalignment='center', fontsize=6)

ax.grid(None)
ax.set_aspect("equal")
ax.set_xticks(np.arange(df_to_plot.shape[1]))
ax.set_yticks(np.arange(df_to_plot.shape[0]))
ax.set_xticklabels(df_to_plot.columns.values)
ax.set_yticklabels(df_to_plot.index.values)
plt.setp(ax.get_xticklabels(), rotation=90)
# Midpoint of the value range decides the annotation colour. Only finite values are used:
# a single inf or NaN would otherwise turn the threshold into inf/NaN and give every cell
# the same text colour.
finite_vals = mtx_to_plot[np.isfinite(mtx_to_plot)]
threshold = finite_vals.min() + np.ptp(finite_vals) * 0.5 if finite_vals.size else 0.0
ax.tick_params(axis='both', which='major', labelsize=7)
ax.tick_params(axis='both', which='minor', labelsize=7)
textcolors = ("black", "white")
for (i, j), cell_value in np.ndenumerate(mtx_to_plot):
    color = textcolors[int(cell_value < threshold)]
    # Infinite cells are left without a label.
    label = "" if np.isinf(cell_value) else f"{cell_value:0.2f}"
    ax.text(j, i, label, ha="center", va="center", color=color, fontsize=3.75)
fig.tight_layout()
plt.savefig(f"{path_save}/{path_local}/cont/stat_group_diff.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_save}/{path_local}/cont/stat_group_diff.pdf", bbox_inches='tight', dpi=400)
# Close the figure (not just clear it) so it is released once it has been saved.
plt.close(fig)

# Box plots (T0 vs T1, split by group) for the feature / grouping-feature pairs with the
# largest between-group differences.
# `diffs_order` holds (row, column) positions in `df_stat_group_diff`, largest difference first.
# NaN cells are left out: argsort places them last, so after the reversal they would be
# ranked as the largest differences.
diffs_mtx = df_stat_group_diff.to_numpy(dtype=float)
diffs_flat_order = np.argsort(diffs_mtx, axis=None)[::-1]
diffs_flat_order = diffs_flat_order[~np.isnan(diffs_mtx.ravel()[diffs_flat_order])]
diffs_order = np.column_stack(np.unravel_index(diffs_flat_order, diffs_mtx.shape))
n_plots = 50

for diff_id, diff in enumerate(diffs_order[0:n_plots]):
    feat = df_stat_group_diff.index.values[diff[0]]
    # Exactly one feature per time point must carry this title.
    feat_by_time = {}
    for time_tag, forms in (("t0", forms_t0), ("t1", forms_t1)):
        feats_found = df_feats.index[(df_feats["eng_title"] == feat) & (df_feats["form"].isin(forms))].values
        if len(feats_found) != 1:
            raise ValueError(f"Expected exactly one {time_tag} feature for '{feat}', found {len(feats_found)}: {feats_found}")
        feat_by_time[time_tag] = feats_found[0]
    feat_t0 = feat_by_time["t0"]
    feat_t1 = feat_by_time["t1"]
    # Longest common substring of the two feature ids (without its first character); row key in `df_stat`.
    # NOTE(review): `STree` is not among the imports at the top of the notebook - confirm it is imported.
    feat_common_str = STree.STree([feat_t0, feat_t1]).lcs()[1::]

    # Column positions of `df_stat_group_diff` are used as positions in `df_feats_group_cat`.
    feat_group = df_feats_group_cat.index.values[diff[1]]

    df_fig = df.loc[:, [feat_t0, feat_t1, feat_group]].copy()
    feat_group_label = df_feats_group_cat.loc[feat_group, "eng_title"]
    feat_group_label_dict = ast.literal_eval(df_feats_group_cat.loc[feat_group, "eng_values"])
    feat_group_color_dict = ast.literal_eval(df_feats_group_cat.loc[feat_group, "colors"])

    # Extend every group label with the group size and its p-value, and re-key the colours accordingly.
    for group in feat_group_label_dict:
        pval = df_stat.at[f"{feat_common_str}_{feat_group}_{feat_group_label_dict[group]}", "pval"]
        group_name = feat_group_label_dict[group] + f" ({df_fig.loc[df_fig[feat_group] == group, :].shape[0]}), p-value: {pval:0.2e}"
        feat_group_color_dict[group_name] = feat_group_color_dict[feat_group_label_dict[group]]
        del feat_group_color_dict[feat_group_label_dict[group]]
        feat_group_label_dict[group] = group_name

    df_fig.rename(columns={feat_t0: "T0", feat_t1: "T1", feat_group: feat_group_label}, inplace=True)
    # Reassign the column instead of calling replace(..., inplace=True) on it: the in-place form
    # operates on a temporary object and leaves `df_fig` untouched under pandas copy-on-write.
    df_fig[feat_group_label] = df_fig[feat_group_label].replace(feat_group_label_dict)

    # Long format: one row per (patient, time point)
    df_fig["index"] = df_fig.index.values
    df_fig = df_fig.melt(id_vars=["index", feat_group_label], var_name='Time', value_name=feat)
    sns.set_theme(style='whitegrid', font_scale=1)
    catplot = sns.catplot(
        data=df_fig,
        x="Time",
        y=feat,
        hue=feat_group_label,
        orient='v',
        palette=feat_group_color_dict,
        kind="box",
        legend_out=False
    )
    # catplot creates its own figure. Keep a handle on it so that this figure is the one closed
    # below (a separately created empty figure used to be closed instead, leaking one per plot).
    fig = plt.gcf()
    sns.despine(left=False, right=False, bottom=False, top=False)
    plt.legend(title=feat_group_label, bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left", mode="expand", borderaxespad=0, ncol=1)
    plt.savefig(f"{path_save}/{path_local}/cont/{diff_id:02d}_{feat_common_str}_by_{feat_group}.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path_save}/{path_local}/cont/{diff_id:02d}_{feat_common_str}_by_{feat_group}.pdf", bbox_inches='tight')
    plt.close(fig)