# Description
Immunological clock built on the combined data (Central region + Yakutia).
Previously, the clock was built only on data from the Central region; here it is built on data from both regions.

In [None]:
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
from scipy import stats
import plotly.express as px
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import seaborn as sns
from glob import glob
import pathlib
from sklearn.metrics import mean_absolute_error
from scipy import stats
import patchworklib as pw
import os
import functools
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
import shap
from slugify import slugify
from src.models.tabular.widedeep.ft_transformer import WDFTTransformerModel
from art.estimators.regression.pytorch import PyTorchRegressor
from art.estimators.regression.blackbox import BlackBoxRegressor
from art.attacks.evasion import LowProFool, ZooAttack, FastGradientMethod
import torch
from sklearn.model_selection import RepeatedStratifiedKFold
from scripts.python.routines.sections import get_sections
import upsetplot


def conjunction(conditions):
    """Combine boolean masks element-wise with logical AND.

    ``conditions`` is an iterable of array-likes (e.g. boolean numpy arrays or
    pandas Series). They are folded left to right with ``np.logical_and``.
    A single-element iterable returns that element unchanged. An empty one
    raises ``TypeError`` (from ``functools.reduce``).
    """
    def _and(left, right):
        return np.logical_and(left, right)

    return functools.reduce(_and, conditions)


def disjunction(conditions):
    """Combine boolean masks element-wise with logical OR.

    ``conditions`` is an iterable of array-likes (e.g. boolean numpy arrays or
    pandas Series). They are folded left to right with ``np.logical_or``.
    A single-element iterable returns that element unchanged. An empty one
    raises ``TypeError`` (from ``functools.reduce``).
    """
    def _or(left, right):
        return np.logical_or(left, right)

    return functools.reduce(_or, conditions)

In [None]:
path = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN"
path_save = f"{path}/special/049_immuno_age_regression_extended/filtered"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Sample table of this experiment; the first Excel column is used as the index.
df = pd.read_excel(f"{path_save}/data.xlsx", index_col=0)

# Control samples, separated by region of origin.
is_control = df['Controls/Cases'] == "Controls"
ids_central = df.index[is_control & (df["Region"] == "Central")].values
ids_yakutia = df.index[is_control & (df["Region"] == "Yakutia")].values

## Create splits of controls

In [None]:
n_splits = 4

# Default partition of every sample in every split. Non-control groups are
# test-only ("tst_*"); control samples are re-assigned to trn_val / tst_ctrl below.
for split_id in range(n_splits):
    split_col = f"Split_{split_id}"
    df[split_col] = "tst_other"
    df.loc[df['Controls/Cases'] == "COVID-19 Acute and Dynamics", split_col] = "tst_covid19"
    df.loc[df['Controls/Cases'] == "ESRD", split_col] = "tst_esrd"
    df.loc[df['Controls/Cases'] == "Down Syndrome", split_col] = "tst_downs"

# The controls of each region are split independently, so every split contains
# trn_val and tst_ctrl samples from both regions.
geo_parts = {
    'Central': df.loc[ids_central, ['Age']].copy(),
    'Yakutia': df.loc[ids_yakutia, ['Age']].copy(),
}
for part, df_part in geo_parts.items():
    trgt = df_part.loc[:, "Age"].values
    ids = df_part.index.values

    # Stratify by age. Build num_bins equal-width bins over the age range, padded
    # by 10% of the range on each side; the bin index is the stratification label.
    ptp = np.ptp(trgt)
    num_bins = 10
    bins = np.linspace(np.min(trgt) - 0.1 * ptp, np.max(trgt) + 0.1 * ptp, num_bins + 1)
    binned = np.digitize(trgt, bins) - 1

    k_fold = RepeatedStratifiedKFold(
        n_splits=n_splits,
        n_repeats=1,
        random_state=1337
    )
    # The held-out part of fold k becomes tst_ctrl in Split_k; the rest is trn_val.
    # Stratified splitters ignore `groups`, so it is not passed.
    for split_id, (ids_trn, ids_val) in enumerate(k_fold.split(X=ids, y=binned)):
        print(f"{part}: split {split_id}")
        df.loc[ids[ids_trn], f"Split_{split_id}"] = "trn_val"
        df.loc[ids[ids_val], f"Split_{split_id}"] = "tst_ctrl"

df.to_excel(f"{path_save}/data_w_splits.xlsx", index_label="index")

hist_bins = np.linspace(0, 110, 12)

pathlib.Path(f"{path_save}/splits").mkdir(parents=True, exist_ok=True)

# One histogram panel per region: (sample ids, colours of trn_val / tst_ctrl).
region_panels = [
    (ids_central, {'trn_val': 'deepskyblue', 'tst_ctrl': 'crimson'}),
    (ids_yakutia, {'trn_val': 'blue', 'tst_ctrl': 'darkred'}),
]
# Set the theme before any panel axes are created, so that every panel
# (including the very first one) gets the same style.
sns.set_theme(style='whitegrid')
for split_id in range(n_splits):
    bricks = []
    for region_ids, palette in region_panels:
        # Seaborn draws on the patchworklib brick directly, so no extra
        # matplotlib figure is needed (and none is leaked).
        brick = pw.Brick(figsize=(3, 2))
        sns.histplot(
            data=df.loc[region_ids, :],
            hue_order=['trn_val', 'tst_ctrl'],
            bins=hist_bins,
            x="Age",
            hue=f"Split_{split_id}",
            edgecolor='black',
            palette=palette,
            multiple="stack",
            ax=brick
        )
        brick.set(xlim=(0, 110))
        bricks.append(brick)
    # Side-by-side layout: Central on the left, Yakutia on the right.
    pw_fig = bricks[0] | bricks[1]
    pw_fig.savefig(f"{path_save}/splits/split_{split_id}.png")
    pw_fig.savefig(f"{path_save}/splits/split_{split_id}.pdf")

# Check SImAge samples

In [None]:
# Sample-id mapping of the earlier SImAge model; ids are compared as strings.
df_simage = pd.read_excel("D:/YandexDisk/Work/pydnameth/draft/06_small_immuno_clocks/df_mapping.xlsx", index_col=0)
df_simage.index = df_simage.index.map(str)
# SImAge samples whose 'new_index' mentions trn_val or ctrl
# (presumably the control samples SImAge was trained/tested on).
simage_ctrl_mask = df_simage['new_index'].str.contains(r'trn_val|ctrl')
ids_simage = df_simage.index[simage_ctrl_mask].values
# Central-region controls of the current filtered dataset.
central_ctrl_mask = (df["Region"] == "Central") & (df['Controls/Cases'] == "Controls")
ids_filter = df.index[central_ctrl_mask].values
pathlib.Path(f"{path_save}/simage").mkdir(parents=True, exist_ok=True)

# Export every section of the SImAge-vs-Filter overlap (as returned by
# get_sections) to its own Excel file of sample ids.
sections = get_sections([set(ids_simage), set(ids_filter)])
for sec in sections:
    # Sort so the file content is reproducible (set iteration order is arbitrary).
    # key=str keeps sorting safe even if ids of different types are mixed.
    # The rows are sample ids, so the column is labelled 'index' (as elsewhere in
    # this notebook) rather than 'gene'.
    df_sec = pd.DataFrame(index=sorted(sections[sec], key=str))
    df_sec.to_excel(f"{path_save}/simage/{sec}.xlsx", index_label='index')
    
# UpSet plot of the overlap between SImAge's control samples (ids_simage) and
# the Central-region controls of the filtered dataset (ids_filter).
dict_upset_lists = {
    "SImAge": ids_simage,
    "Filter": ids_filter,
}
upset_all = list(set().union(*dict_upset_lists.values()))
# One boolean membership column per set; these columns then become the index.
membership = {
    name: pd.Index(upset_all).isin(members)
    for name, members in dict_upset_lists.items()
}
df_upset = pd.DataFrame(membership, index=upset_all).set_index(list(dict_upset_lists))
fig = plt.figure()
upset_fig = upsetplot.UpSet(
    df_upset,
    sort_categories_by='input',
    subset_size='count',
    show_counts=True,
    min_degree=0,
    element_size=None,
    totals_plot_elements=3,
    include_empty_subsets=False
)
upset_fig.plot(fig)
plt.savefig(f"{path_save}/simage/upset.png", bbox_inches='tight')
plt.savefig(f"{path_save}/simage/upset.pdf", bbox_inches='tight')
plt.close()

# Ids present in only one of the two sets.
ids_simage_not_cmn = list(set(ids_simage) - set(ids_filter))
ids_filter_not_cmn = list(set(ids_filter) - set(ids_simage))

def _barplot_per_sample(df_counts, count_col, path_no_ext, xlabel=None, threshold=6):
    """Draw one horizontal bar per sample of ``df_counts[count_col]`` and save it.

    Rows are sorted by ``count_col`` in descending order. Bars whose value is above
    ``threshold`` are red; the others are pink. The figure is written to
    ``{path_no_ext}.png`` (200 dpi) and ``{path_no_ext}.pdf`` and then closed.
    ``xlabel`` overrides the x-axis label, which otherwise stays at seaborn's default.
    """
    # Work on an independent copy, so adding the Color column never touches
    # (or warns about) the caller's frame.
    df_plot = df_counts.sort_values([count_col], ascending=[False]).copy()
    df_plot["Color"] = 'pink'
    df_plot.loc[df_plot[count_col] > threshold, "Color"] = 'red'
    fig = plt.figure(figsize=(6, 12))
    sns.set_theme(style='whitegrid')
    barplot = sns.barplot(
        data=df_plot,
        x=count_col,
        y=df_plot.index,
        dodge=False,
        orient="h",
        edgecolor='k',
        linewidth=1,
        palette=df_plot['Color'].values,
    )
    barplot.set_ylabel("Samples")
    if xlabel is not None:
        barplot.set_xlabel(xlabel)
    plt.savefig(f"{path_no_ext}.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_no_ext}.pdf", bbox_inches='tight')
    # Close (not just clear) the figure, so repeated calls do not accumulate open figures.
    plt.close(fig)


# Diagnostics for ids in the SImAge set but not in ids_filter (ids_simage_not_cmn):
# number of features with NaNs, number of IQR outliers, and PyOD detections.
df_nans = pd.read_excel(f"{path}/special/053_proof_that_immunodata_is_shit/filtered/02_nan_analysis/All/df_nan_samples.xlsx", index_col=0)
_barplot_per_sample(
    df_nans.loc[ids_simage_not_cmn, :],
    "Features with NaNs",
    f"{path_save}/simage/ids_simage_not_cmn_nan",
)

df_outs_iqr = pd.read_excel(f"{path}/special/053_proof_that_immunodata_is_shit/filtered/03_outliers/IQR/Status/Controls/df.xlsx", index_col=0)
_barplot_per_sample(
    df_outs_iqr.loc[df_outs_iqr.index.isin(ids_simage_not_cmn), :],
    "n_iqr_outs",
    f"{path_save}/simage/ids_simage_not_cmn_out_iqr",
    xlabel="Number of IQR outliers",
)

df_outs_pyod = pd.read_excel(f"{path}/special/053_proof_that_immunodata_is_shit/filtered/03_outliers/pyod_contam(0.1)_epochs(500)/Status/Controls/scaled/df.xlsx", index_col=0)
_barplot_per_sample(
    df_outs_pyod.loc[df_outs_pyod.index.isin(ids_simage_not_cmn), :],
    "Detections",
    f"{path_save}/simage/ids_simage_not_cmn_out_pyod",
)

# Per-feature IQR-outlier flags of the ids in ids_filter but not in the SImAge set.
# Features listed in SImAge's feats_con_top10.xlsx are drawn in red, the rest in pink.
feats_simage = pd.read_excel(f"{path}/data/immuno/models/SImAge/feats_con_top10.xlsx", index_col=0).index.values
df_outs_iqr = pd.read_excel(f"{path}/special/053_proof_that_immunodata_is_shit/filtered/03_outliers/IQR/Status/Controls/df.xlsx", index_col=0)
# .copy(): the in-place edits below must act on an independent frame, not on a
# selection of df_outs_iqr (avoids SettingWithCopyWarning / chained assignment).
df_fig = df_outs_iqr.loc[df_outs_iqr.index.isin(ids_filter_not_cmn), :].copy()
df_fig.sort_values(["n_iqr_outs"], ascending=[False], inplace=True)
df_fig.drop('n_iqr_outs', axis=1, inplace=True)
df_fig.replace({False: 0, True: 1}, inplace=True)
# Strip the literal suffix; regex=False makes this independent of pandas' regex default.
df_fig.columns = df_fig.columns.str.replace('_iqr_out', '', regex=False)
feats_filter = df_fig.columns.values
colors = {x: ('red' if x in feats_simage else 'pink') for x in feats_filter}
sns.set_theme(style='whitegrid')
# Rows are reversed so the sample with the most outliers ends up at the top of the barh chart.
barplot = df_fig.iloc[::-1].plot(
    figsize=(6, 12),
    width=1,
    kind='barh',
    stacked=True,
    color=colors,
    edgecolor='black',
)
barplot.set_xlabel("Number of detections as outlier in different methods")
barplot.set_ylabel("Samples")
sns.move_legend(barplot, "upper left", bbox_to_anchor=(1, 1))
plt.savefig(f"{path_save}/simage/ids_filter_not_cmn_out_iqr.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/simage/ids_filter_not_cmn_out_iqr.pdf", bbox_inches='tight')
plt.close()

df_all = pd.read_excel(f"{path}/data/immuno/df_samples(all_1052_121222)_proc(raw)_imp(fast_knn)_replace(quarter).xlsx", index_col=0)

# Samples that are in exactly one of the two control sets, labelled by the set
# they come from. pd.Index.union gives a reproducible order (sorted when the ids
# are comparable), unlike list(set(...)). .copy() detaches the selection from
# df_all before the new column is added.
ids_not_cmn = pd.Index(ids_simage_not_cmn).union(pd.Index(ids_filter_not_cmn))
df_not_common = df_all.loc[ids_not_cmn, :].copy()
df_not_common['Dataset'] = ''
# Features compared between the two groups (read from simage/feats.xlsx).
feats_filter = pd.read_excel(f"{path_save}/simage/feats.xlsx", index_col=0).index.values
df_not_common.loc[ids_filter_not_cmn, 'Dataset'] = 'Filtered'
df_not_common.loc[ids_simage_not_cmn, 'Dataset'] = 'SImAge'
colors = {'Filtered': 'crimson', 'SImAge': 'dodgerblue'}

# Stacked age histogram of the two groups (5-year bins from 5 to 115).
hist_bins = np.linspace(5, 115, 23)
fig = plt.figure(figsize=(6, 4))
sns.set_theme(style='whitegrid')
histplot = sns.histplot(
    data=df_not_common,
    hue_order=list(colors.keys())[::-1],
    bins=hist_bins,
    x="Age",
    hue="Dataset",
    edgecolor='black',
    palette=colors,
    multiple="stack"
)
sns.move_legend(
    histplot, "lower center",
    bbox_to_anchor=(.5, 1), ncol=2, title="Dataset", frameon=True,
)
plt.setp(histplot.get_legend().get_texts(), fontsize='7') # for legend text
plt.setp(histplot.get_legend().get_title(), fontsize='10')
plt.savefig(f"{path_save}/simage/ids_not_cmn_histplot.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_save}/simage/ids_not_cmn_histplot.pdf", bbox_inches='tight')
plt.close(fig)

# Compare each feature (plus 'SImAge acceleration') between the 'Filtered' and
# 'SImAge' groups of df_not_common: per-group mean, median, quartiles and IQR,
# and a two-sided Mann-Whitney U test.
df_stat = pd.DataFrame(index=list(feats_filter))
for feat in list(feats_filter) + ['SImAge acceleration']:
    vals = {}
    for group in ['Filtered', 'SImAge']:
        vals[group] = df_not_common.loc[df_not_common['Dataset'] == group, feat].values
        df_stat.at[feat, f"mean_{group}"] = np.mean(vals[group])
        df_stat.at[feat, f"median_{group}"] = np.median(vals[group])
        df_stat.at[feat, f"q75_{group}"], df_stat.at[feat, f"q25_{group}"] = np.percentile(vals[group], [75 , 25])
        df_stat.at[feat, f"iqr_{group}"] = df_stat.at[feat, f"q75_{group}"] - df_stat.at[feat, f"q25_{group}"]
        if feat == 'SImAge acceleration':
            # Only for this row: MAE between the 'Age' and 'SImAge' columns of the group.
            df_stat.at[feat, f"MAE_{group}"] = mean_absolute_error(df_not_common.loc[df_not_common['Dataset'] == group, 'Age'].values, df_not_common.loc[df_not_common['Dataset'] == group, 'SImAge'].values)
    _, df_stat.at[feat, "mw_pval"] = mannwhitneyu(vals['Filtered'], vals['SImAge'], alternative='two-sided')

# Benjamini-Hochberg correction over the feats_filter rows only. The
# 'SImAge acceleration' row (not in the initial index; added by the .at writes
# above) is excluded and keeps NaN in 'mw_pval_fdr_bh'.
_, df_stat.loc[feats_filter, "mw_pval_fdr_bh"], _, _ = multipletests(df_stat.loc[feats_filter, "mw_pval"], 0.05, method='fdr_bh')
df_stat.sort_values([f"mw_pval_fdr_bh"], ascending=[True], inplace=True)
df_stat.to_excel(f"{path_save}/simage/ids_not_cmn_mw.xlsx", index_label='Features')

# Violin plot of SImAge acceleration per dataset, annotated with the raw
# (not FDR-corrected) Mann-Whitney p-value stored in df_stat['mw_pval'].
feat = 'SImAge acceleration'
dataset_order = list(colors.keys())
plt.figure(figsize=(6, 4))
sns.set_theme(style='whitegrid')
violin = sns.violinplot(
    data=df_not_common,
    x='Dataset',
    y=feat,
    palette=colors,
    scale='width',
    order=dataset_order,
    saturation=0.75,
)
violin.set_xlabel("Dataset")
annotator = Annotator(
    violin,
    pairs=[('SImAge', 'Filtered')],
    data=df_not_common,
    x='Dataset',
    y=feat,
    order=dataset_order
)
annotator.set_custom_annotations([f"{df_stat.at[feat, 'mw_pval']:.2e}"])
annotator.configure(loc='outside')
annotator.annotate()
plt.savefig(f"{path_save}/simage/ids_not_cmn {feat}.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/simage/ids_not_cmn {feat}.pdf", bbox_inches='tight')
plt.close()

# Grid of per-feature violin plots with n_cols panels per row, in df_stat order
# (sorted by FDR-corrected Mann-Whitney p-value). Each panel is annotated with
# that corrected p-value.
feats_sorted = df_stat.index[df_stat.index.isin(feats_filter)].values
axs = {}
pw_rows = []
n_cols = 5
n_rows = int(np.ceil(len(feats_sorted) / n_cols))
dataset_order = list(colors.keys())
for r_id in range(n_rows):
    row_feats = feats_sorted[r_id * n_cols:(r_id + 1) * n_cols]
    pw_cols = []
    for feat in row_feats:
        ax = pw.Brick(figsize=(3, 2))
        axs[feat] = ax
        sns.set_theme(style='whitegrid')
        sns.violinplot(
            data=df_not_common,
            x='Dataset',
            y=feat,
            palette=colors,
            scale='width',
            order=dataset_order,
            saturation=0.75,
            ax=ax
        )
        ax.set_ylabel(feat)
        ax.set_xlabel("Dataset")
        annotator = Annotator(
            ax,
            pairs=[('SImAge', 'Filtered')],
            data=df_not_common,
            x='Dataset',
            y=feat,
            order=dataset_order,
        )
        annotator.set_custom_annotations([f"{df_stat.at[feat, 'mw_pval_fdr_bh']:.2e}"])
        annotator.configure(loc='outside')
        annotator.annotate()
        pw_cols.append(ax)
    # Pad an incomplete last row with hidden placeholder panels, so every row has n_cols panels.
    for _ in range(n_cols - len(row_feats)):
        empty_fig = pw.Brick(figsize=(3.6, 2))
        empty_fig.axis('off')
        pw_cols.append(empty_fig)
    pw_rows.append(pw.stack(pw_cols, operator="|"))
pw_fig = pw.stack(pw_rows, operator="/")
pw_fig.savefig(f"{path_save}/simage/ids_not_cmn_mw_feats.pdf")
pw_fig.savefig(f"{path_save}/simage/ids_not_cmn_mw_feats.png")

# For each dataset group, compute the Pearson correlation of every feature with
# the 'SImAge' column, then apply Benjamini-Hochberg correction. Plot
# -log10(corrected p-value) per feature. Note: df_stat is rebuilt here and no
# longer holds the Mann-Whitney table.
df_stat = pd.DataFrame(index=list(feats_filter))
for group in colors.keys():
    for feat in feats_filter:
        xs = df_not_common.loc[df_not_common['Dataset'] == group, feat].values
        ys = df_not_common.loc[df_not_common['Dataset'] == group, 'SImAge'].values
        df_stat.at[feat, f"{group}_corr"], df_stat.at[feat, f"{group}_pval"] = stats.pearsonr(xs, ys, alternative='two-sided')
    # multipletests returns arrays in the current row order of df_stat (which may have been
    # re-sorted by the previous group), and they are assigned back positionally in that same order.
    _, df_stat[f"{group}_pval_fdr_bh"], _, _ = multipletests(df_stat[f"{group}_pval"], 0.05, method='fdr_bh')
    df_stat[f"{group}_pval_fdr_bh_log"] = -np.log10(df_stat[f"{group}_pval_fdr_bh"].values)
    # Bar colours: group colour if the corrected p < 0.05, gray otherwise. Rows whose
    # corrected p-value is NaN match neither condition and stay white.
    # The 'Color' column is overwritten per group, so the saved sheet holds the last group's colours.
    df_stat["Color"] = 'white'
    df_stat.loc[df_stat[f"{group}_pval_fdr_bh"] < 0.05, 'Color'] = colors[group]
    df_stat.loc[df_stat[f"{group}_pval_fdr_bh"] >= 0.05, 'Color'] = 'gray'
    df_stat.sort_values([f"{group}_pval_fdr_bh"], ascending=[True], inplace=True)
    plt.figure(figsize=(12, 6))
    # NOTE(review): rotation is set before seaborn draws the categorical x axis; confirm the
    # labels end up rotated (barplot.tick_params(axis='x', labelrotation=90) afterwards is unambiguous).
    plt.xticks(rotation=90)
    sns.set_theme(style='white')
    barplot = sns.barplot(
        data=df_stat,
        x=df_stat.index,
        y=f"{group}_pval_fdr_bh_log",
        edgecolor='black',
        palette=df_stat['Color'].values,
        dodge=False
    )
    # The plotted values are FDR-corrected p-values, although the label says "p-value".
    barplot.set_ylabel(r'$-\log_{10}(\mathrm{p-value})$')
    # Title: group name and number of samples in that group.
    barplot.set_title(f"{group} ({len(df_not_common.index[df_not_common['Dataset'] == group])})")
    plt.savefig(f"{path_save}/simage/ids_not_cmn_{group}_pearson.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_save}/simage/ids_not_cmn_{group}_pearson.pdf", bbox_inches='tight')
    plt.close()
df_stat.to_excel(f"{path_save}/simage/ids_not_cmn_pearson.xlsx", index_label='Features')

## Transfer train/val/test splits from SImAge

In [None]:
df_base = pd.read_excel(f"{path}/data/immuno/df_samples(all_1052_121222)_proc(raw)_imp(fast_knn)_replace(quarter).xlsx", index_col=0)
# Every sample starts as 'tst_filtered'; each later assignment overrides earlier ones.
df_base['Split'] = 'tst_filtered'
# Reuse SImAge's own assignment, read from the 'new_index' column of its mapping file.
simage_parts = {'trn_val': 'trn_val', 'tst_ctrl': 'tst_ctrl_central'}
for pattern, split_name in simage_parts.items():
    df_base.loc[df_simage.index[df_simage['new_index'].str.contains(pattern)], 'Split'] = split_name
# The remaining test partitions come from df['Split'].
# NOTE(review): this notebook only creates Split_0..Split_3, so 'Split' is presumably
# already present in data.xlsx — confirm.
for part in ['tst_ctrl_yakutia', 'tst_covid19', 'tst_esrd', 'tst_downs', 'tst_other']:
    df_base.loc[df.index[df['Split'] == part].values, 'Split'] = part
df_base.to_excel(f"{path_save}/simage/data.xlsx", index_label='index')

# Generate results for cleanlab

# Collect ML results

In [None]:
# Collect the metrics and Hydra overrides of every run of `model` into one summary table.
model = 'widedeep_ft_transformer_trn_val_tst'

path_runs = f"{path_save}/models/{model}/multiruns"

files = glob(f"{path_runs}/*/*/metrics_all_best_*.xlsx")
if not files:
    # Fail with a clear message instead of an IndexError on files[0] below.
    raise FileNotFoundError(f"No metrics_all_best_*.xlsx files found under {path_runs}")

df_tmp = pd.read_excel(files[0], index_col="metric")
head, tail = os.path.split(files[0])
cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
# Names of the hyperparameters overridden in the first run. Overrides look like
# "param=value"; split on the first '=' only, since a value may itself contain '='.
params = [param_pair.split('=', 1)[0] for param_pair in cfg]
df_res = pd.DataFrame(index=files)

parts = [
    'trn',
    'val',
    'tst_ctrl_central',
    'tst_ctrl_yakutia',
    'tst_covid19',
    'tst_downs',
    'tst_esrd',
    'trn_val',
    'trn_val_tst_ctrl_central',
    'trn_val_tst_ctrl_yakutia',
    'val_tst_ctrl_central',
    'val_tst_ctrl_yakutia'
]

for file in files:
    head, tail = os.path.split(file)
    df_preds = pd.read_excel(f"{head}/predictions.xlsx", index_col=0)
    # Metrics: one column per (metric, data part), e.g. "mean_absolute_error_val".
    df_metrics = pd.read_excel(file, index_col="metric")
    for metric in df_metrics.index.values:
        for part in parts:
            df_res.at[file, f"{metric}_{part}"] = df_metrics.at[metric, part]

    # Params: one column per overridden hyperparameter; the value is kept as the override string.
    cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
    for param_pair in cfg:
        param, val = param_pair.split('=', 1)
        df_res.at[file, param] = val

# Flag runs whose train MAE exceeds their validation MAE (a NaN comparison yields False).
df_res["train_more_val"] = df_res["mean_absolute_error_trn"] > df_res["mean_absolute_error_val"]
# Always False here; presumably filled in by hand in the exported sheet.
df_res["selected"] = False

first_columns = [
    'selected',
    'train_more_val',
    'mean_absolute_error_trn',
    'mean_absolute_error_val',
    'mean_absolute_error_tst_ctrl_central',
    'mean_absolute_error_tst_ctrl_yakutia',
    'mean_absolute_error_val_tst_ctrl_central',
    'mean_absolute_error_val_tst_ctrl_yakutia',
    'mean_absolute_error_trn_val_tst_ctrl_central',
    'mean_absolute_error_trn_val_tst_ctrl_yakutia',
    'pearson_corr_coef_trn',
    'pearson_corr_coef_val',
    'pearson_corr_coef_tst_ctrl_central',
    'pearson_corr_coef_tst_ctrl_yakutia',
    'pearson_corr_coef_val_tst_ctrl_central',
    'pearson_corr_coef_trn_val_tst_ctrl_central',
    'pearson_corr_coef_val_tst_ctrl_yakutia',
    'pearson_corr_coef_trn_val_tst_ctrl_yakutia',
    'mean_absolute_error_cv_mean_trn',
    'mean_absolute_error_cv_std_trn',
    'mean_absolute_error_cv_mean_val',
    'mean_absolute_error_cv_std_val',
    'pearson_corr_coef_cv_mean_trn',
    'pearson_corr_coef_cv_std_trn',
    'pearson_corr_coef_cv_mean_val',
    'pearson_corr_coef_cv_std_val',
]
df_res = df_res[first_columns + [col for col in df_res.columns if col not in first_columns]]
# Make the run paths relative to path_runs. regex=False, because a filesystem path
# may contain regex metacharacters ('.', '(', '+', ...).
df_res.index = df_res.index.str.replace(path_runs, '', regex=False)
df_res.to_excel(f"{path_runs}/summary.xlsx", index=True, index_label="file")