# Imports

In [None]:
import pandas as pd
import numpy as np
import scipy
from sklearn.linear_model import ElasticNet, ElasticNetCV
from sklearn.model_selection import RepeatedKFold, GridSearchCV
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scripts.python.routines.betas import betas_drop_na
from plotly.subplots import make_subplots
from scipy import stats
import pickle
import random
import plotly.express as px
import copy
import statsmodels.formula.api as smf
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scripts.python.pheno.datasets.filter import filter_pheno
from scripts.python.pheno.datasets.features import get_column_name, get_status_dict, get_sex_dict
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
import pathlib
from scripts.python.routines.manifest import get_manifest
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from scripts.python.routines.plot.p_value import add_p_value_annotation
from statsmodels.stats.multitest import multipletests
from sklearn.metrics import mean_absolute_error
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
from functools import reduce
from scipy.stats import kruskal, mannwhitneyu
from glob import glob
import os
import matplotlib.pyplot as plt
import pathlib

# Init data

In [None]:
# Output folder for every figure written by this notebook; created if missing.
path_save = "E:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/033_immuno_ml_draft_figures"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Folder holding the input spreadsheets.
path_load = "E:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/021_ml_data/immuno"

# Main data table, indexed by its "index" column.
df = pd.read_excel(
    f"{path_load}/260_imp(fast_knn)_replace(quarter).xlsx",
    index_col="index",
)

# Feature names are the index labels ("features" column) of the second workbook.
feats_table = pd.read_excel(f"{path_load}/feats_con.xlsx", index_col="features")
feats = feats_table.index.values

# Data description (Participants) figures

In [None]:
# Overlaid histograms of a continuous column (feat_x), one trace per level of
# each categorical column listed in cat_feat_colors.
# cat_feat_colors maps a categorical column name to (level, bar colour) pairs.
cat_feat_colors = {"Sex": [("F", "red"), ("M", "blue")]}
feat_x = "Age"
bin_size = 5
for feat, fields in cat_feat_colors.items():
    fig = go.Figure()
    for level, bar_color in fields:
        # Values of feat_x for the rows whose categorical column equals this level.
        x_values = df.loc[df[feat] == level, feat_x].values
        bar_style = dict(
            color=bar_color,
            opacity=0.75,
            line=dict(width=1, color="black"),
        )
        trace = go.Histogram(
            x=x_values,
            # Legend entry shows the level and how many rows fall into it.
            name=f"{level} ({len(x_values)})",
            showlegend=True,
            marker=bar_style,
            xbins=dict(size=bin_size),
        )
        fig.add_trace(trace)

    add_layout(fig, feat_x, "Count", "")
    layout_overrides = dict(
        margin=go.layout.Margin(l=90, r=20, b=75, t=50, pad=0),
        legend_font_size=20,
        legend={'itemsizing': 'constant'},
        # Traces are drawn on top of each other (hence the partial opacity above).
        barmode='overlay',
    )
    fig.update_layout(**layout_overrides)

    out_dir = f"{path_save}/data_description_participants"
    pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
    save_figure(fig, f"{out_dir}/Histogram_cont({feat_x})_cat({feat})")

# Data description (Features) figures

## Generate data for figure

In [None]:
# Build the pairwise Pearson matrix for "Age" plus every feature in `feats`.
#
# df_corr layout (square, labelled by feats_plot on both axes):
#   * strictly upper triangle -> Pearson correlation coefficient
#   * strictly lower triangle -> raw p-value of the same pair
#   * diagonal                -> NaN
# df_corr_fdr has the same layout, except that the lower triangle holds
# -log10 of the Benjamini-Hochberg adjusted p-value.
feats_plot = ["Age"] + list(feats)
df_corr = pd.DataFrame(
    data=np.zeros(shape=(len(feats_plot), len(feats_plot))),
    index=feats_plot,
    columns=feats_plot,
)

# Pull every column out of the frame once instead of once per pair.
column_values = {name: df.loc[:, name].values for name in feats_plot}

for f_id_1, f_1 in enumerate(feats_plot):
    df_corr.at[f_1, f_1] = np.nan
    # Only pairs with f_2 after f_1 are computed; both triangles are filled from them.
    for f_2 in feats_plot[f_id_1 + 1:]:
        corr, pval = stats.pearsonr(column_values[f_1], column_values[f_2])
        df_corr.at[f_2, f_1] = pval
        df_corr.at[f_1, f_2] = corr

# Boolean mask of the strictly lower triangle (k=-1), i.e. the p-value cells.
# The builtin `bool` is used because the `np.bool` alias no longer exists in
# current NumPy releases.
selection = np.tri(df_corr.shape[0], df_corr.shape[1], -1, dtype=bool)

# Long-format table of the lower-triangle p-values. Cells outside the mask
# become NaN in `where`; the explicit dropna() guarantees they (and any NaN
# p-value) are excluded from the multiple-testing correction even if stack()
# itself keeps NaN entries.
df_fdr = df_corr.where(selection).stack().dropna().reset_index()
df_fdr.columns = ['row', 'col', 'pval']

# multipletests returns (reject, corrected p-values, ...); only the corrected
# p-values are kept.
fdr_result = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
df_fdr['pval_fdr_bh'] = fdr_result[1]

# Replace each raw p-value in the lower triangle by -log10(adjusted p-value).
# An adjusted p-value of exactly 0 yields +inf here.
df_corr_fdr = df_corr.copy()
neg_log_fdr = -np.log10(df_fdr['pval_fdr_bh'].values)
for row_label, col_label, value in zip(df_fdr['row'], df_fdr['col'], neg_log_fdr):
    df_corr_fdr.loc[row_label, col_label] = value

## Plot correlation matrix

In [None]:
# Render df_corr_fdr as a single annotated matrix: the strictly upper triangle
# (correlation coefficients) and the strictly lower triangle (-log10 of the
# FDR-adjusted p-values) are drawn as two overlaid images with separate
# colormaps and colorbars.
df_to_plot = df_corr_fdr.copy()
mtx_to_plot = df_to_plot.to_numpy()

# Upper triangle: everything on and below the diagonal is zeroed by np.triu
# and then masked out (entries that are exactly 0 are masked as well).
mtx_triu = np.triu(mtx_to_plot, +1)
# max_corr / min_corr are not used further down in this cell; they are kept
# as notebook globals in case other cells read them.
max_corr = np.max(mtx_triu)
min_corr = np.min(mtx_triu)
mtx_triu_mask = np.ma.masked_array(mtx_triu, mtx_triu == 0)
cmap_triu = plt.get_cmap("bwr").copy()

# Lower triangle: same treatment with np.tril.
mtx_tril = np.tril(mtx_to_plot, -1)
mtx_tril_mask = np.ma.masked_array(mtx_tril, mtx_tril == 0)
cmap_tril = plt.get_cmap("viridis").copy()
# Values below the colour scale minimum (set to -log10(0.05) below) are black.
cmap_tril.set_under('black')

fig, ax = plt.subplots()

im_triu = ax.imshow(mtx_triu_mask, cmap=cmap_triu, vmin=-1, vmax=1)
cbar_triu = ax.figure.colorbar(im_triu, ax=ax, location='right')
cbar_triu.set_label(r"$\mathrm{Correlation\:coefficient}$", horizontalalignment='center', fontsize=10)

im_tril = ax.imshow(mtx_tril_mask, cmap=cmap_tril, vmin=-np.log10(0.05))
cbar_tril = ax.figure.colorbar(im_tril, ax=ax, location='right')
cbar_tril.set_label(r"$-\log_{10}(\mathrm{p-value})$", horizontalalignment='center', fontsize=10)

ax.set_aspect("equal")
ax.set_xticks(np.arange(df_to_plot.shape[1]))
ax.set_yticks(np.arange(df_to_plot.shape[0]))
ax.set_xticklabels(df_to_plot.columns.values)
ax.set_yticklabels(df_to_plot.index.values)
plt.setp(ax.get_xticklabels(), rotation=90)
ax.tick_params(axis='both', which='major', labelsize=5)
ax.tick_params(axis='both', which='minor', labelsize=5)

# Label-contrast threshold: half the value range of the lower triangle.
# Only finite entries are considered, so a single +inf (-log10 of a zero
# p-value) cannot push the threshold to infinity and turn every label white.
finite_tril = mtx_tril[np.isfinite(mtx_tril)]
threshold = np.ptp(finite_tril) * 0.5
textcolors = ("black", "white")

n_rows, n_cols = df_to_plot.shape
for i in range(n_rows):
    for j in range(n_cols):
        value = mtx_to_plot[i, j]
        # NaN (diagonal) and infinite cells get no label; an empty text
        # artist would draw nothing, so none is created.
        if not np.isfinite(value):
            continue
        color = "black"
        if i > j:
            # Lower triangle: white text for values below the threshold,
            # black text otherwise.
            color = textcolors[int(mtx_tril[i, j] < threshold)]
        ax.text(j, i, f"{value:0.2f}", ha="center", va="center", color=color, fontsize=1.3)

fig.tight_layout()
out_dir = f"{path_save}/data_description_features"
pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
plt.savefig(f"{out_dir}/corr_mtx_fdr.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{out_dir}/corr_mtx_fdr.pdf", bbox_inches='tight', dpi=400)
plt.clf()