# Autoreload (for debugging)

In [None]:
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
from statsmodels.stats.multitest import multipletests
import torch
from scipy.stats import mannwhitneyu
import pickle
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
#from statannotations.Annotator import Annotator
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from matplotlib import patheffects as pe
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
import pathlib
from src.models.simage.tabular.widedeep.ft_transformer import WDFTTransformerModel
import os
from tqdm import tqdm
import optuna
from functools import reduce 
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

# Load data

## Load SImAge

In [None]:
# Folder holding the SImAge checkpoint and its feature table
simage_dir = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/data/immuno/models/SImAge"
path_simage = simage_dir + "/best_fold_0002.ckpt"
# Input features of SImAge = index of the top-10 feature table
feats_simage = pd.read_excel(simage_dir + "/feats_con_top10.xlsx", index_col=0).index.values

# Restore the model, switch to eval mode, freeze it and keep it on the CPU
model_simage = WDFTTransformerModel.load_from_checkpoint(checkpoint_path=path_simage)
model_simage.eval()
model_simage.freeze()
model_simage.to('cpu')

## Load immuno models

In [None]:
%%capture

epi_data_type = 'no_harm'
imm_data_type = 'imp_source(imm)_method(miceforest)_params(2)'

selection_method = 'f_regression'  # alternatives: 'spearman', 'mrmr'
n_feats = 100
path_imm_models = f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{imm_data_type}/{epi_data_type}/{selection_method}_{n_feats}"

models_type = "trials_models(GANDALF)_tst(5)_val(5)_CosineAnnealingWarmRestarts_yeo-johnson"

# models[feature][model_dir] -> loaded TabularModel
# models_paths[feature] -> list of candidate model directories
models = {}
models_paths = {}

# Search-space size: product of the candidate counts of all features
n_combos_total = 1
for feat in feats_simage:
    ckpt_files = glob(f"{path_imm_models}/{feat}/pytorch_tabular/candidates/{models_type}/*/model.ckpt")
    # Directory of each checkpoint, with backslashes turned into forward slashes
    model_dirs = [os.path.dirname(ckpt).replace('\\', '/') for ckpt in ckpt_files]
    models_paths[feat] = model_dirs
    models[feat] = {model_dir: TabularModel.load_model(model_dir) for model_dir in model_dirs}
    print(len(model_dirs))
    n_combos_total *= len(model_dirs)

In [None]:
# Number of possible per-feature model combinations
print(f"{n_combos_total}")

## Load data

In [None]:
path_data = f"D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/060_EpiSImAge"

# Renames applied to the PC clock columns of each pheno.csv
ages_pc = {
    'PCHorvath1': 'PC-Horvath',
    'PCHorvath2': 'PC-SkinBloodAge',
    'PCHannum': 'PC-Hannum',
    'PCPhenoAge': 'PC-PhenoAge',
    'PCGrimAge': 'PC-GrimAge',
}
pace = 'DunedinPACE'
# Columns of the DNAmAgeCalcProject results file -> clock labels used in this notebook
ages_calc = {
    'DNAmAge': 'Horvath',
    'DNAmAgeHannum': 'Hannum',
    'DNAmPhenoAge': 'PhenoAge',
    'DNAmAgeSkinBloodClock': 'SkinBloodAge',
    'DNAmGrimAgeBasedOnRealAge': 'GrimAge',
    'DNAmGrimAge2BasedOnRealAge': 'GrimAge2',
}
# Every clock column carried into the merged data frame
ages_all = list(ages_pc.values()) + list(ages_calc.values()) + [pace]

# Epigenetic input features of each immune model, read from its feats_con.xlsx
feats_epi_dict = {}
for f in feats_simage:
    feats_epi_dict[f] = pd.read_excel(f"{path_imm_models}/{f}/feats_con.xlsx", index_col=0).index.values
# Union of all models' input features, in first-seen order. A set is deliberately
# not used: the iteration order of a set of strings depends on hash randomisation
# (PYTHONHASHSEED), so the column order of the loaded betas (and of everything
# written from `df`) would change from run to run.
feats_epi_all = list(dict.fromkeys(feat for feats in feats_epi_dict.values() for feat in feats))
print(len(feats_epi_all))

data_parts = ['GSEUNN', 'GSE87571', 'GSE40279', 'GSE179325', 'GSE217633', 'GSE118144', 'GSE42861', 'GSE106648', 'GSE67530', 'GSE77696']

# Datasets that need only the common loading steps
_parts_plain = {'GSE179325', 'GSE118144', 'GSE42861', 'GSE106648', 'GSE67530', 'GSE71955'}
# Datasets with extra, dataset-specific steps in `_load_data_part`
_parts_special = {'GSEUNN', 'GSE87571', 'GSE40279', 'GSE217633', 'GSE77696'}


def _read_horvath_results(folder):
    """Read the first ``DNAmAgeCalcProject_*_Results.csv`` found in ``folder``.

    Raises:
        FileNotFoundError: if ``folder`` contains no such file.
    """
    matches = glob(f"{folder}/DNAmAgeCalcProject_*_Results.csv")
    if not matches:
        raise FileNotFoundError(f"No DNAmAgeCalcProject_*_Results.csv found in {folder}")
    return pd.read_csv(matches[0], index_col=0)


def _load_data_part(data_part):
    """Load one dataset and merge its phenotype columns with the selected betas.

    The result holds 'Age', 'Sex', 'Status', any dataset-specific extra columns,
    all clock columns in `ages_all`, the `feats_epi_all` beta columns and a
    'GSE' column set to `data_part`.

    Raises:
        ValueError: if `data_part` has no loading rules. Without this check an
            unknown dataset would silently re-use the frame of the previous one.
    """
    if data_part not in _parts_plain and data_part not in _parts_special:
        raise ValueError(f"No loading rules defined for dataset {data_part!r}")

    # GSEUNN files live in a sub-folder selected by `epi_data_type`
    if data_part == 'GSEUNN':
        folder = f"{path_data}/{data_part}/{epi_data_type}"
    else:
        folder = f"{path_data}/{data_part}"

    df_betas = pd.read_pickle(f"{folder}/betas.pkl")
    df_betas = df_betas.loc[:, feats_epi_all]
    df_pheno = pd.read_csv(f"{folder}/pheno.csv", index_col=0)
    df_pheno.rename(columns=ages_pc, inplace=True)
    df_horvath = _read_horvath_results(folder)
    # Copy the calculator clocks into the phenotype table, aligned on its current index
    for age_col, age_label in ages_calc.items():
        df_pheno.loc[df_pheno.index.values, age_label] = df_horvath.loc[df_pheno.index.values, age_col]

    # Dataset-specific adjustments (index column, Status, extra columns)
    extra_cols = []
    if data_part == 'GSEUNN':
        df_pheno.set_index("index", inplace=True)
        extra_cols = ['Region']
    elif data_part == 'GSE40279':
        df_pheno.set_index("gsm", inplace=True)
        df_pheno['Status'] = 'Control'
    elif data_part == 'GSE87571':
        df_pheno['Status'] = 'Control'
    elif data_part == 'GSE217633':
        df_pheno.rename(columns={'Years.with.HIV': 'Years with HIV'}, inplace=True)
        extra_cols = ['Years with HIV']
    elif data_part == 'GSE77696':
        df_pheno.set_index("Index", inplace=True)

    cols_pheno = ['Age', 'Sex', 'Status'] + extra_cols + ages_all
    df_part = pd.merge(df_pheno.loc[:, cols_pheno], df_betas, left_index=True, right_index=True)
    if data_part == 'GSE106648':
        # Sample GSM2844233 is excluded (reason not recorded here - TODO document)
        df_part.drop('GSM2844233', inplace=True)
    df_part['GSE'] = data_part
    return df_part


dfs = [_load_data_part(data_part) for data_part in tqdm(data_parts)]

df = pd.concat(dfs, verify_integrity=True)

# For every clock: acceleration = clock estimate - chronological age, plus its absolute value
for clock in [*ages_pc.values(), *ages_calc.values()]:
    accel = df[clock] - df['Age']
    df[f'{clock} Acceleration'] = accel
    df[f'|{clock} Acceleration|'] = accel.abs()

# (group label, dataset, required 'Status' value or None to take the whole dataset)
group_specs = [
    ('GSEUNN\nControls', 'GSEUNN', 'Control'),
    ('GSEUNN\nESRD', 'GSEUNN', 'ESRD'),
    ('GSE87571', 'GSE87571', None),
    ('GSE40279', 'GSE40279', None),
    ('GSE179325\nControls', 'GSE179325', 'Control'),
    ('GSE179325\nCOVID-19\nSevere', 'GSE179325', 'COVID-19 Severe'),
    ('GSE179325\nCOVID-19\nMild', 'GSE179325', 'COVID-19 Mild'),
    ('GSE217633\nControls', 'GSE217633', 'Control'),
    ('GSE217633\nPre ART', 'GSE217633', 'Pre ART'),
    ('GSE217633\nPost ART', 'GSE217633', 'Post ART'),
    ('GSE118144\nControls', 'GSE118144', 'Control'),
    ('GSE118144\nSLE', 'GSE118144', 'SLE'),
    ('GSE42861\nControls', 'GSE42861', 'Control'),
    ('GSE42861\nRheumatoid\nArthritis', 'GSE42861', 'Rheumatoid Arthritis'),
    ('GSE106648\nControls', 'GSE106648', 'Control'),
    ('GSE106648\nMultiple\nSclerosis', 'GSE106648', 'Multiple Sclerosis'),
    ('GSE67530\nControls', 'GSE67530', 'Control'),
    ('GSE67530\nICU', 'GSE67530', 'ICU'),
    ('GSE67530\nICU ARDS', 'GSE67530', 'ICU ARDS'),
    ('GSE77696\nControls', 'GSE77696', 'Control'),
    ('GSE77696\nHIV', 'GSE77696', 'HIV'),
]
ids_groups = {}
for label, gse, status in group_specs:
    in_group = df['GSE'] == gse
    if status is not None:
        in_group = in_group & (df['Status'] == status)
    ids_groups[label] = df.index[in_group].values

# Tag every sample with the label of its group
for group, ids in ids_groups.items():
    df.loc[ids, 'Group'] = group

# External controls: non-GSEUNN samples that come from GSE87571/GSE40279 or have Status 'Control'
is_test_control = (df['GSE'] != 'GSEUNN') & (df['GSE'].isin(['GSE87571', 'GSE40279']) | (df['Status'] == 'Control'))
ids_test_controls = df.index[is_test_control].values
df['Test\nControls'] = is_test_control

# Group pairs compared with a two-sided Mann-Whitney U test on EpiSImAge acceleration
mw_pairs = [
    ('GSEUNN\nControls', 'GSEUNN\nESRD'),
    ('GSEUNN\nControls', 'GSE87571'),
    ('GSEUNN\nControls', 'GSE40279'),
    ('GSE87571', 'GSE40279'),
    ('GSEUNN\nControls', 'GSE179325\nControls'),
    ('GSE179325\nControls', 'GSE179325\nCOVID-19\nMild'),
    ('GSE179325\nCOVID-19\nMild', 'GSE179325\nCOVID-19\nSevere'),
    ('GSE179325\nControls', 'GSE179325\nCOVID-19\nSevere'),
    ('GSEUNN\nControls', 'GSE217633\nControls'),
    ('GSE217633\nControls', 'GSE217633\nPre ART'),
    ('GSE217633\nPre ART', 'GSE217633\nPost ART'),
    ('GSE217633\nControls', 'GSE217633\nPost ART'),
    ('GSEUNN\nControls', 'GSE118144\nControls'),
    ('GSE118144\nControls', 'GSE118144\nSLE'),
    ('GSEUNN\nControls', 'GSE42861\nControls'),
    ('GSE42861\nControls', 'GSE42861\nRheumatoid\nArthritis'),
    ('GSEUNN\nControls', 'GSE106648\nControls'),
    ('GSE106648\nControls', 'GSE106648\nMultiple\nSclerosis'),
    ('GSEUNN\nControls', 'GSE67530\nControls'),
    ('GSE67530\nControls', 'GSE67530\nICU'),
    ('GSE67530\nControls', 'GSE67530\nICU ARDS'),
    ('GSE67530\nICU', 'GSE67530\nICU ARDS'),
    ('GSEUNN\nControls', 'GSE77696\nControls'),
    ('GSE77696\nControls', 'GSE77696\nHIV'),
]

# Control groups whose MAE and mean acceleration lead the combos table
_headline_groups = [
    'Test\nControls',
    'GSEUNN\nControls',
    'GSE87571',
    'GSE40279',
    'GSE179325\nControls',
    'GSE217633\nControls',
    'GSE118144\nControls',
    'GSE42861\nControls',
    'GSE106648\nControls',
    'GSE67530\nControls',
    'GSE77696\nControls',
]
# MAE then <AA> for each headline group, followed by every Mann-Whitney p-value column
first_columns = [f"{metric}\n{grp}" for grp in _headline_groups for metric in ("MAE", "<AA>")]
first_columns += [f"Mann-Whitney\n{left}\nVS\n{right}" for left, right in mw_pairs]


# Check model combinations

## Define objective function for optuna

In [None]:
results_combos = []


def _mw_neglog10(res_dict, group_a, group_b):
    """Return -log10 of the Mann-Whitney p-value stored in `res_dict` for (group_a, group_b)."""
    return -np.log10(res_dict[f"Mann-Whitney\n{group_a}\nVS\n{group_b}"])


def objective(trial: optuna.Trial):
    """Evaluate one combination of per-feature immune models for Optuna.

    For every feature in `feats_simage` the trial picks one model directory from
    `models_paths`. The chosen model predicts that feature (stored as "<feat>_log",
    then exponentiated into "<feat>") for every sample of the global `df`. The
    predicted features are fed to `model_simage`, and EpiSImAge, its acceleration
    and |acceleration| are written back into `df`. The global frame is overwritten
    on every call.

    All metrics of the trial are appended as one dict to `results_combos`.

    Returns:
        tuple of 10 floats. The first three are minimized: MAE on the test
        controls, and -log10 Mann-Whitney p-values of GSEUNN controls vs.
        GSE87571 and vs. GSE40279. The remaining seven are maximized: Pearson
        correlation on the test controls, and -log10 p-values of six
        control-vs-patient comparisons.
    """
    params = {}
    res_dict = {}
    for f in feats_simage:
        params[f] = trial.suggest_categorical(f, models_paths[f])
        res_dict[f] = params[f]

    for f in feats_simage:
        df[f"{f}_log"] = models[f][params[f]].predict(df.loc[:, feats_epi_dict[f]])
        df[f] = np.exp(df[f"{f}_log"])

    df['EpiSImAge'] = model_simage(torch.from_numpy(df.loc[:, feats_simage].values)).cpu().detach().numpy().ravel()
    df['EpiSImAge Acceleration'] = df['EpiSImAge'] - df['Age']
    df['|EpiSImAge Acceleration|'] = df['EpiSImAge Acceleration'].abs()

    # Same three metrics for the pooled test controls and for every named group
    for label, ids in [("Test\nControls", ids_test_controls), *ids_groups.items()]:
        pred = torch.from_numpy(np.float32(df.loc[ids, 'EpiSImAge'].values))
        real = torch.from_numpy(np.float32(df.loc[ids, 'Age'].values))
        res_dict[f"MAE\n{label}"] = mean_absolute_error(pred, real).item()
        res_dict[f"Pearson\n{label}"] = pearson_corrcoef(pred, real).item()
        res_dict[f"<AA>\n{label}"] = np.mean(df.loc[ids, 'EpiSImAge Acceleration'].values)

    for group_a, group_b in mw_pairs:
        _, pval = mannwhitneyu(
            df.loc[ids_groups[group_a], 'EpiSImAge Acceleration'].values,
            df.loc[ids_groups[group_b], 'EpiSImAge Acceleration'].values,
            alternative='two-sided'
        )
        res_dict[f"Mann-Whitney\n{group_a}\nVS\n{group_b}"] = pval

    results_combos.append(res_dict)

    # Minimizing -log10(p) favours large p-values, i.e. no detectable shift between
    # the GSEUNN controls and the two external control cohorts.
    # NOTE: the summed MAE on GSE87571 + GSE40279 is not an objective; to use it,
    # add it here together with an extra 'minimize' direction in create_study.
    to_minimize = (
        res_dict["MAE\nTest\nControls"],
        _mw_neglog10(res_dict, 'GSEUNN\nControls', 'GSE87571'),
        _mw_neglog10(res_dict, 'GSEUNN\nControls', 'GSE40279'),
    )
    # Maximizing -log10(p) favours strong control-vs-patient separation
    to_maximize = (
        res_dict["Pearson\nTest\nControls"],
        _mw_neglog10(res_dict, 'GSE179325\nControls', 'GSE179325\nCOVID-19\nSevere'),
        _mw_neglog10(res_dict, 'GSE217633\nControls', 'GSE217633\nPre ART'),
        _mw_neglog10(res_dict, 'GSE118144\nControls', 'GSE118144\nSLE'),
        _mw_neglog10(res_dict, 'GSE42861\nControls', 'GSE42861\nRheumatoid\nArthritis'),
        _mw_neglog10(res_dict, 'GSE106648\nControls', 'GSE106648\nMultiple\nSclerosis'),
        _mw_neglog10(res_dict, 'GSE67530\nControls', 'GSE67530\nICU ARDS'),
    )
    return (*to_minimize, *to_maximize)

## Run optimization

In [None]:
results_combos = []

n_trials = 1000
opt_seed = 1337
n_startup_trials = 10
n_ei_candidates = 24

# Pass the settings above to Optuna explicitly. Without a `sampler` argument,
# create_study uses its default sampler unseeded and these values have no effect.
# A seeded TPE sampler makes the sequence of suggested combinations reproducible.
sampler = optuna.samplers.TPESampler(
    n_startup_trials=n_startup_trials,
    n_ei_candidates=n_ei_candidates,
    seed=opt_seed,
)
study = optuna.create_study(
    study_name='optimization',
    sampler=sampler,
    # Must match the order of the tuple returned by `objective`: 3 minimized, 7 maximized
    directions=["minimize"] * 3 + ["maximize"] * 7,
)
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

In [None]:
# One row per evaluated trial, with the headline metrics as the leading columns
df_combos = pd.DataFrame(results_combos)
other_cols = [col for col in df_combos.columns if col not in first_columns]
df_combos = df_combos.loc[:, first_columns + other_cols]
# NOTE(review): the best-combination step reads "EpiSImAge_combos_{combo_file_id}.xlsx";
# presumably this file is renamed by hand before that - confirm.
df_combos.to_excel(f"{path_imm_models}/EpiSImAge_combos__.xlsx")

# Best combination plots

## Setup best combination

In [None]:
combo_file_id = 1
df_combos = pd.read_excel(f"{path_imm_models}/EpiSImAge_combos_{combo_file_id}.xlsx", index_col=0)

# Row label of the selected combination in the combos table
best_combo_id = 4874

# Output folder for every table and figure of this combination
path_save = f"{path_imm_models}/results_{combo_file_id}_{best_combo_id}"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Per immune feature: the model directory chosen in that row, and the loaded model
best_models_path = {f: df_combos.at[best_combo_id, f] for f in feats_simage}
best_models = {f: models[f][model_dir] for f, model_dir in best_models_path.items()}

print(best_models_path)

# Predict each immune feature (log scale) and exponentiate it
for f in feats_simage:
    log_pred = best_models[f].predict(df.loc[:, feats_epi_dict[f]])
    df[f"{f}_log"] = log_pred
    df[f] = np.exp(df[f"{f}_log"])

simage_input = torch.from_numpy(df.loc[:, feats_simage].values)
df['EpiSImAge'] = model_simage(simage_input).cpu().detach().numpy().ravel()
acceleration = df['EpiSImAge'] - df['Age']
df['EpiSImAge Acceleration'] = acceleration
df['|EpiSImAge Acceleration|'] = acceleration.abs()

df.to_excel(f"{path_save}/df.xlsx")

# Headline metrics of EpiSImAge on the pooled test controls
pred = torch.from_numpy(np.float32(df.loc[ids_test_controls, 'EpiSImAge'].values))
real = torch.from_numpy(np.float32(df.loc[ids_test_controls, 'Age'].values))
df_test_controls = pd.DataFrame(index=["MAE", "Pearson", "<AA>"], columns=["Value"])
for metric_name, metric_value in (
    ("MAE", mean_absolute_error(pred, real).item()),
    ("Pearson", pearson_corrcoef(pred, real).item()),
    ("<AA>", np.mean(df.loc[ids_test_controls, 'EpiSImAge Acceleration'].values)),
):
    df_test_controls.at[metric_name, "Value"] = metric_value
df_test_controls.to_excel(f"{path_save}/test_controls.xlsx", index_label="Metrics")

## Setup colors

In [None]:
# Plot colour of every sample group; the key order is also the plotting order
colors_groups = {
    'GSEUNN\nControls': 'crimson',
    'GSEUNN\nESRD': 'dodgerblue',
    'GSE87571': 'limegreen',
    'GSE40279': 'gold',
    'GSE179325\nControls': 'olive',
    'GSE179325\nCOVID-19\nMild': 'burlywood',
    'GSE179325\nCOVID-19\nSevere': 'darkorchid',
    'GSE217633\nControls': 'slateblue',
    'GSE217633\nPre ART': 'yellow',
    'GSE217633\nPost ART': 'orange',
    'GSE118144\nControls': 'aqua',
    'GSE118144\nSLE': 'darkkhaki',
    'GSE42861\nControls': 'coral',
    'GSE42861\nRheumatoid\nArthritis': 'brown',
    'GSE106648\nControls': 'palegreen',
    'GSE106648\nMultiple\nSclerosis': 'lightsteelblue',
    'GSE67530\nControls': 'maroon',
    'GSE67530\nICU': 'lime',
    'GSE67530\nICU ARDS': 'orchid',
    'GSE77696\nControls': 'teal',
    'GSE77696\nHIV': 'orangered'
}

# Light panel background per dataset
colors_gses_bckg = {
    'GSEUNN': 'aliceblue',
    'GSE87571': 'whitesmoke',
    'GSE40279': 'cornsilk',
    'GSE179325': 'ghostwhite',
    'GSE217633': 'lavenderblush',
    'GSE118144': 'linen',
    'GSE42861': 'mintcream',
    'GSE106648': 'oldlace',
    'GSE67530': 'snow',
    'GSE77696': 'ivory'
}

# A group inherits the background of its dataset, which is the first line of its label
colors_groups_bckg = {group: colors_gses_bckg[group.split('\n')[0]] for group in colors_groups}

## Histograms of age distributions across datasets

In [None]:
n_rows = 5
n_cols = 2
fig_height = 18
fig_width = 10
# 23 edges -> 22 bins, 5 years wide, from 5 to 115
hist_bins = np.linspace(5, 115, 23)
sns.set_theme(style='whitegrid')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=False, sharex=False)
# One panel per dataset, filled row by row; bars are coloured by group
for gse_id, gse in tqdm(enumerate(colors_gses_bckg.keys())):
    ax = axs[gse_id // n_cols, gse_id % n_cols]
    sns.histplot(
        data=df.loc[df['GSE'] == gse, :],
        x="Age",
        hue='Group',
        palette=colors_groups,
        bins=hist_bins,
        edgecolor='k',
        linewidth=1,
        ax=ax,
    )
    ax.set(xlim=(0, 120))
    ax.set_facecolor(colors_gses_bckg[gse])
fig.tight_layout()
for ext, extra_kwargs in (("png", {"dpi": 200}), ("pdf", {})):
    fig.savefig(f"{path_save}/hist_age.{ext}", bbox_inches='tight', **extra_kwargs)
plt.close(fig)

## EpiSImAge

### Scatters

In [None]:
n_rows = 3
n_cols = 7
fig_height = 15
fig_width = 25

sns.set_theme(style='whitegrid')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=True, sharex=False)

# One panel per group (filled row by row): EpiSImAge against chronological age
for group_id, group in tqdm(enumerate(colors_groups.keys())):
    row_id, col_id = divmod(group_id, n_cols)
    ax = axs[row_id, col_id]
    df_group = df.loc[ids_groups[group]]
    sns.scatterplot(
        data=df_group,
        x="Age",
        y="EpiSImAge",
        color=colors_groups[group],
        s=40,
        alpha=0.75,
        edgecolor="k",
        linewidth=0.2,
        ax=ax,
    )
    age_t = torch.from_numpy(np.float32(df_group["Age"].values))
    est_t = torch.from_numpy(np.float32(df_group["EpiSImAge"].values))
    mae = mean_absolute_error(age_t, est_t).item()
    rho = pearson_corrcoef(age_t, est_t).item()
    maa = np.mean(df_group["EpiSImAge Acceleration"].values)
    ax.set_title("\n".join([
        f"{group}",
        f"MAE: {mae:0.2f}",
        fr"Pearson $\rho$: {rho:0.2f}",
        fr"$\langle$Acceleration$\rangle$: {maa:0.2f}",
    ]))
    # Identity line: points on it have zero acceleration
    sns.lineplot(x=[0, 120], y=[0, 120], linestyle='--', color='black', linewidth=1.0, ax=ax)
    ax.set_xlim(0, 120)
    ax.set_ylim(0, 120)
    ax.set_facecolor(colors_groups_bckg[group])
    # Only the bottom row keeps x tick labels and the x-axis label
    if row_id != n_rows - 1:
        ax.set_xticklabels([])
        ax.set_xlabel('')
fig.tight_layout()
fig.savefig(f"{path_save}/scatters_EpiSImAge.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/scatters_EpiSImAge.pdf", bbox_inches='tight')
plt.close(fig)

### Violins

In [None]:
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(35, 6))

# Shade each group's column with the background colour of its dataset
for x_pos, group in enumerate(colors_groups):
    ax.axvspan(x_pos - 0.5, x_pos + 0.5, facecolor=colors_groups_bckg[group], alpha=1.0, zorder=0, lw=0)

# Keep only samples strictly between the 1st and 99th percentile of acceleration
acceleration = df['EpiSImAge Acceleration']
q01 = acceleration.quantile(0.01)
q99 = acceleration.quantile(0.99)
df_violin = df.loc[(acceleration > q01) & (acceleration < q99), :]

sns.violinplot(
    data=df_violin,
    x='Group',
    y='EpiSImAge Acceleration',
    order=list(colors_groups.keys()),
    palette=colors_groups,
    scale='width',
    saturation=0.75,
    ax=ax
)
ax.set_xlabel("")

# Two-sided Mann-Whitney U p-values for every configured pair, computed on the
# unfiltered data; they are formatted but not drawn while the annotator is disabled
pvals = [
    mannwhitneyu(
        df.loc[ids_groups[first], "EpiSImAge Acceleration"].values,
        df.loc[ids_groups[second], "EpiSImAge Acceleration"].values,
        alternative='two-sided'
    ).pvalue
    for first, second in mw_pairs
]
pvals_formatted = [f'{pval:.1e}' for pval in pvals]
# annotator = Annotator(
#     ax,
#     pairs=mw_pairs,
#     data=df,
#     x='Group',
#     y='EpiSImAge Acceleration',
#     order=list(colors_groups),
# )
# annotator.set_custom_annotations(pvals_formatted)
# annotator.configure(loc='outside')
# annotator.annotate()
fig.savefig(f"{path_save}/violins_EpiSImAge.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/violins_EpiSImAge.pdf", bbox_inches='tight')
plt.close(fig)

## All epigenetic ages

In [None]:
# Clocks shown in the multi-clock panels: EpiSImAge first, then each published
# clock followed by its 'PC-' counterpart, and GrimAge2 last
ages = (
    ['EpiSImAge']
    + [name for clock in ('Hannum', 'Horvath', 'SkinBloodAge', 'PhenoAge', 'GrimAge') for name in (clock, f'PC-{clock}')]
    + ['GrimAge2']
)

### KDEs

In [None]:
n_rows = len(ages)
n_cols = len(ids_groups)
fig_height = 32
fig_width = 42

sns.set_theme(style='whitegrid')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=True, sharex=True)

# Grid layout: one row per clock, one column per group
for group_id, group in tqdm(enumerate(colors_groups.keys())):
    df_group = df.loc[ids_groups[group]]
    age_t = torch.from_numpy(np.float32(df_group["Age"].values))
    for age_id, age_type in enumerate(ages):
        ax = axs[age_id, group_id]
        sns.kdeplot(
            data=df_group,
            x='Age',
            y=age_type,
            fill=True,
            cbar=False,
            color=colors_groups[group],
            cut=0,
            legend=False,
            ax=ax
        )
        # Linear fit of the clock against age (no scatter points)
        sns.regplot(data=df_group, x='Age', y=age_type, scatter=False, color='black', truncate=True, ax=ax)
        # Identity line
        sns.lineplot(x=[0, 120], y=[0, 120], linestyle='--', color='black', linewidth=1.0, ax=ax)
        ax.set_xlim(0, 120)
        ax.set_ylim(0, 120)
        ax.set_facecolor(colors_groups_bckg[group])

        clock_t = torch.from_numpy(np.float32(df_group[age_type].values))
        mae = mean_absolute_error(age_t, clock_t).item()
        rho = pearson_corrcoef(age_t, clock_t).item()
        maa = np.mean(df_group[f"{age_type} Acceleration"].values)
        title_lines = [
            fr"MAE: {mae:0.2f}",
            fr"Pearson $\rho$: {rho:0.2f}",
            fr"$\langle$Acceleration$\rangle$: {maa:0.2f}",
        ]
        # Only the top row carries the group name
        if age_id == 0:
            title_lines.insert(0, f"{group}")
        ax.set_title("\n".join(title_lines))
fig.tight_layout()
fig.savefig(f"{path_save}/kdes_ages.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/kdes_ages.pdf", bbox_inches='tight')
plt.close(fig)

### Controls

In [None]:
# Control group label (key of ids_groups) -> short column name in the controls tables
controls_groups = {
    'GSEUNN\nControls': 'Train',
    'GSE87571': 'GSE87571',
    'GSE40279': 'GSE40279',
    'GSE179325\nControls': 'GSE179325',
    'GSE217633\nControls': 'GSE217633',
    'GSE118144\nControls': 'GSE118144',
    'GSE42861\nControls': 'GSE42861',
    'GSE106648\nControls': 'GSE106648',
    'GSE67530\nControls': 'GSE67530',
    'GSE77696\nControls': 'GSE77696'
}
# Reverse lookup: short column name -> group label
controls_groups_inv = dict(zip(controls_groups.values(), controls_groups.keys()))
path_controls = pathlib.Path(path_save) / "controls"
path_controls.mkdir(parents=True, exist_ok=True)

#### Metrics



In [None]:
# Rows: clocks; columns: control cohorts (short names)
control_cols = list(controls_groups.values())
df_mae = pd.DataFrame(index=ages, columns=control_cols)
df_rho = pd.DataFrame(index=ages, columns=control_cols)
df_maa = pd.DataFrame(index=ages, columns=control_cols)
for age_type in ages:
    for group_old, group_new in controls_groups.items():
        ids = ids_groups[group_old]
        chrono = torch.from_numpy(np.float32(df.loc[ids, "Age"].values))
        estimate = torch.from_numpy(np.float32(df.loc[ids, age_type].values))
        df_mae.at[age_type, group_new] = mean_absolute_error(chrono, estimate).item()
        df_rho.at[age_type, group_new] = pearson_corrcoef(chrono, estimate).item()
        df_maa.at[age_type, group_new] = np.mean(df.loc[ids, f"{age_type} Acceleration"].values)
for table, table_name in ((df_mae, "mae"), (df_rho, "rho"), (df_maa, "maa")):
    table.to_excel(f"{path_save}/controls/{table_name}.xlsx")

In [None]:
# Clustered heatmap of MAE (rows: clocks, columns: control cohorts)
df_fig = df_mae[df_mae.columns].astype(float)
sns.set_theme(style='whitegrid')
clustermap = sns.clustermap(
    df_fig,
    annot=True,
    fmt=".1f",
    row_cluster=True,
    col_cluster=True,
    cmap='flare',
    linewidth=0.1,
    linecolor='black',
    tree_kws=dict(linewidths=1.5),
    figsize=(8, 8),
    cbar_kws={'orientation': 'horizontal'}
)
heatmap_ax = clustermap.ax_heatmap
cbar_ax = clustermap.ax_cbar
heatmap_ax.set_xlabel('')
heatmap_ax.set_ylabel('')
for spine in cbar_ax.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")
# Thin black stroke around the column labels
heatmap_ax.set_xticklabels(heatmap_ax.get_xmajorticklabels(), path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
# Colour each cohort label like its group; embolden the training cohort
for label in heatmap_ax.get_xticklabels():
    text = label.get_text()
    label.set_color(colors_groups[controls_groups_inv[text]])
    if text == 'Train':
        label.set_fontweight('extra bold')
# Embolden the EpiSImAge row label
for label in heatmap_ax.get_yticklabels():
    if label.get_text() == 'EpiSImAge':
        label.set_fontweight('extra bold')
x0, _y0, _w, _h = clustermap.cbar_pos
# Move the colour bar above the column dendrogram, spanning the heatmap width
heatmap_box = heatmap_ax.get_position()
dendro_box = clustermap.ax_col_dendrogram.get_position()
cbar_ax.set_position([heatmap_box.x0, dendro_box.y1 + 0.05, heatmap_box.width, 0.03])
cbar_ax.set_title("MAE")
cbar_ax.tick_params()
for spine_name in cbar_ax.spines:
    cbar_ax.spines[spine_name].set_linewidth(1)
plt.savefig(f"{path_save}/controls/clustermap_mae.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/controls/clustermap_mae.pdf", bbox_inches='tight')
plt.close(clustermap.fig)

In [None]:
# Clustered heatmap of Pearson correlation (rows: clocks, columns: control cohorts)
df_fig = df_rho[df_rho.columns].astype(float)
sns.set_theme(style='whitegrid')
clustermap = sns.clustermap(
    df_fig,
    annot=True,
    fmt=".2f",
    row_cluster=True,
    col_cluster=True,
    cmap='hot',
    linewidth=0.1,
    linecolor='black',
    tree_kws=dict(linewidths=1.5),
    figsize=(8, 8),
    cbar_kws={'orientation': 'horizontal'}
)
heatmap_ax = clustermap.ax_heatmap
cbar_ax = clustermap.ax_cbar
heatmap_ax.set_xlabel('')
heatmap_ax.set_ylabel('')
for spine in cbar_ax.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")
# Thin black stroke around the column labels
heatmap_ax.set_xticklabels(heatmap_ax.get_xmajorticklabels(), path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
# Colour each cohort label like its group; embolden the training cohort
for label in heatmap_ax.get_xticklabels():
    text = label.get_text()
    label.set_color(colors_groups[controls_groups_inv[text]])
    if text == 'Train':
        label.set_fontweight('extra bold')
# Embolden the EpiSImAge row label
for label in heatmap_ax.get_yticklabels():
    if label.get_text() == 'EpiSImAge':
        label.set_fontweight('extra bold')
x0, _y0, _w, _h = clustermap.cbar_pos
# Move the colour bar above the column dendrogram, spanning the heatmap width
heatmap_box = heatmap_ax.get_position()
dendro_box = clustermap.ax_col_dendrogram.get_position()
cbar_ax.set_position([heatmap_box.x0, dendro_box.y1 + 0.05, heatmap_box.width, 0.03])
cbar_ax.set_title(r"Pearson $\rho$")
cbar_ax.tick_params()
for spine_name in cbar_ax.spines:
    cbar_ax.spines[spine_name].set_linewidth(1)
plt.savefig(f"{path_save}/controls/clustermap_rho.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/controls/clustermap_rho.pdf", bbox_inches='tight')
plt.close(clustermap.fig)

In [None]:
# Clustered, annotated heatmap of the values in df_maa (colorbar titled
# "<Acceleration>"). It uses a diverging colormap centered at 0, so positive and
# negative values get opposite hues. Column labels are colored by their control
# group; the 'Train' column and the 'EpiSImAge' row are emboldened. Saved as PNG
# and PDF under {path_save}/controls.
df_fig = df_maa.astype(float)
sns.set_theme(style='whitegrid')
clustermap = sns.clustermap(
    df_fig,
    annot=True,
    col_cluster=True,
    row_cluster=True,
    fmt=".1f",
    center=0.0,
    cmap='seismic',
    linewidth=0.1,
    linecolor='black',
    tree_kws=dict(linewidths=1.5),
    figsize=(8, 8),
    cbar_kws={'orientation': 'horizontal'}
)
clustermap.ax_heatmap.set_xlabel('')
clustermap.ax_heatmap.set_ylabel('')
# Black colorbar frame, width 1. A single pass replaces the previous pair of
# loops, where lw=0.25 was immediately overridden by 1.
for spine in clustermap.ax_cbar.spines.values():
    spine.set(visible=True, lw=1, edgecolor="black")
# Thin black outline around the column labels so colored text stays legible.
clustermap.ax_heatmap.set_xticklabels(clustermap.ax_heatmap.get_xmajorticklabels(), path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
for tick_label in clustermap.ax_heatmap.get_xticklabels():
    tick_label.set_color(colors_groups[controls_groups_inv[tick_label.get_text()]])
    if tick_label.get_text() == 'Train':
        tick_label.set_fontweight('extra bold')
for tick_label in clustermap.ax_heatmap.get_yticklabels():
    if tick_label.get_text() == 'EpiSImAge':
        tick_label.set_fontweight('extra bold')
# Move the horizontal colorbar to span the heatmap width, just above the
# column dendrogram.
clustermap_pos = clustermap.ax_heatmap.get_position()
col_dendrogram_pos = clustermap.ax_col_dendrogram.get_position()
clustermap.ax_cbar.set_position([clustermap_pos.x0, col_dendrogram_pos.y1 + 0.05, clustermap_pos.width, 0.03])
clustermap.ax_cbar.set_title(r"$\langle$ Acceleration $\rangle$")
plt.savefig(f"{path_save}/controls/clustermap_maa.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/controls/clustermap_maa.pdf", bbox_inches='tight')
plt.close(clustermap.fig)

#### Mann-Whitney U test and Mean Age Acceleration Difference

In [None]:
# For each age estimator, compare "<age_type> Acceleration" between every pair
# of control groups and store the results in one square table:
#   upper triangle: mean(column group) - mean(row group)
#   diagonal:       NaN (self-comparison, left blank)
#   lower triangle: -log10 of the Benjamini-Hochberg-adjusted two-sided
#                   Mann-Whitney U p-value
# The table is written to Excel and drawn as one heatmap with two colormaps.
groups_old = list(controls_groups.keys())
groups_new = list(controls_groups.values())
n_groups = len(groups_old)
# True strictly below the diagonal. Used to select p-values and as a heatmap mask.
tril_strict = np.tri(n_groups, n_groups, -1, dtype=bool)
for age_type in ages:
    col_acc = f"{age_type} Acceleration"
    df_controls_stat = pd.DataFrame(
        index=groups_new,
        columns=groups_new,
        data=np.zeros(shape=(n_groups, n_groups)))
    for group_id_1, (group_1_old, group_1_new) in enumerate(controls_groups.items()):
        vals_1 = df.loc[ids_groups[group_1_old], col_acc].values
        df_controls_stat.at[group_1_new, group_1_new] = np.nan
        for group_id_2 in range(group_id_1 + 1, n_groups):
            group_2_new = groups_new[group_id_2]
            vals_2 = df.loc[ids_groups[groups_old[group_id_2]], col_acc].values
            _, pval = mannwhitneyu(vals_1, vals_2, alternative='two-sided')
            df_controls_stat.at[group_1_new, group_2_new] = np.mean(vals_2) - np.mean(vals_1)
            df_controls_stat.at[group_2_new, group_1_new] = pval

    # FDR-correct only the raw p-values (strict lower triangle). where() blanks
    # the other cells and stack() drops those NaNs.
    df_fdr = df_controls_stat.where(tril_strict).stack().reset_index()
    df_fdr.columns = ['row', 'col', 'pval']
    _, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr['pval'].values, 0.05, method='fdr_bh')
    # Replace exact zeros by half the smallest non-zero adjusted p-value so that
    # -log10 stays finite. The result is assigned back explicitly: a chained
    # `df['col'].replace(..., inplace=True)` does not propagate to df under
    # pandas Copy-on-Write.
    nzmin = df_fdr['pval_fdr_bh'][df_fdr['pval_fdr_bh'].gt(0)].min() * 0.5
    df_fdr['pval_fdr_bh'] = df_fdr['pval_fdr_bh'].replace({0.0: nzmin})
    df_fdr['pval_fdr_bh_log'] = -np.log10(df_fdr['pval_fdr_bh'].values)
    for row, col, pval_log in zip(df_fdr['row'], df_fdr['col'], df_fdr['pval_fdr_bh_log']):
        df_controls_stat.loc[row, col] = pval_log
    df_controls_stat.to_excel(f"{path_save}/controls/{age_type}.xlsx", index_label="Tissue")

    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(8, 5))
    # Pass 1: lower triangle masked, so the mean differences in the upper
    # triangle are shown with a diverging colormap centered at 0.
    cmap_triu = plt.get_cmap("seismic").copy()
    heatmap_diff = sns.heatmap(
        df_controls_stat,
        mask=tril_strict,
        annot=True,
        fmt=".2f",
        center=0.0,
        cmap=cmap_triu,
        linewidth=0.1,
        linecolor='black',
        annot_kws={"size": 10},
        ax=ax
    )
    # The colorbar just created is the last axes of the figure.
    ax.figure.axes[-1].set_ylabel(r"$\langle$ Acceleration $\rangle$ Difference", size=13)
    for spine in ax.figure.axes[-1].spines.values():
        spine.set(visible=True, lw=0.25, edgecolor="black")
    # Pass 2: upper triangle masked, so -log10(adjusted p) is shown in the lower
    # triangle. Values below -log10(0.05), i.e. p > 0.05, fall under vmin and
    # are drawn with the 'under' color (black).
    cmap_tril = plt.get_cmap("cool").copy()
    cmap_tril.set_under('black')
    heatmap_pval = sns.heatmap(
        df_controls_stat,
        mask=tril_strict.T,
        annot=True,
        fmt=".1f",
        vmin=-np.log10(0.05),
        cmap=cmap_tril,
        linewidth=0.1,
        linecolor='black',
        annot_kws={"size": 10},
        ax=ax
    )
    ax.figure.axes[-1].set_ylabel(r"$-\log_{10}(\mathrm{p-value})$", size=13)
    for spine in ax.figure.axes[-1].spines.values():
        spine.set(visible=True, lw=0.25, edgecolor="black")
    ax.set_xlabel('', fontsize=16)
    ax.set_ylabel('', fontsize=16)
    ax.set_title(age_type, fontsize=16)
    # Color both axes' labels by control group (with a thin black outline) and
    # embolden the training group.
    ax.set_xticklabels(ax.get_xticklabels(), path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
    for tick_label in ax.get_xticklabels():
        tick_label.set_color(colors_groups[controls_groups_inv[tick_label.get_text()]])
        if tick_label.get_text() == 'Train':
            tick_label.set_fontweight('extra bold')
    ax.set_yticklabels(ax.get_yticklabels(), path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
    for tick_label in ax.get_yticklabels():
        tick_label.set_color(colors_groups[controls_groups_inv[tick_label.get_text()]])
        if tick_label.get_text() == 'Train':
            tick_label.set_fontweight('extra bold')
    plt.savefig(f"{path_save}/controls/{age_type}.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_save}/controls/{age_type}.pdf", bbox_inches='tight')
    plt.close(fig)
    

### Cases

In [None]:
# Case-study label -> group names in ids_groups. The first name is the control
# group; every following name is a case group. The GSE67530 entry lists two
# case groups.
cases_pairs = dict([
    ('ESRD (Train)', ('GSEUNN\nControls', 'GSEUNN\nESRD')),
    ('COVID-19 Severe (GSE179325)', ('GSE179325\nControls', 'GSE179325\nCOVID-19\nSevere')),
    ('HIV, no ART (GSE217633)', ('GSE217633\nControls', 'GSE217633\nPre ART')),
    ('SLE (GSE118144)', ('GSE118144\nControls', 'GSE118144\nSLE')),
    ('Rheumatoid Arthritis (GSE42861)', ('GSE42861\nControls', 'GSE42861\nRheumatoid\nArthritis')),
    ('Multiple Sclerosis (GSE106648)', ('GSE106648\nControls', 'GSE106648\nMultiple\nSclerosis')),
    ('ICU Respiratory (GSE67530)', ('GSE67530\nControls', 'GSE67530\nICU', 'GSE67530\nICU ARDS')),
    ('HIV (GSE77696)', ('GSE77696\nControls', 'GSE77696\nHIV')),
])
# Make sure {path_save}/controls exists. parents=True also creates path_save
# itself if it is missing.
controls_dir = pathlib.Path(f"{path_save}/controls")
controls_dir.mkdir(parents=True, exist_ok=True)

#### Mann-Whitney and Mean Age Acceleration Difference

In [None]:
# For every age estimator (rows) and case study (columns), compare
# "<age_type> Acceleration" between the control group and the case group:
#   df_mw_cases       - two-sided Mann-Whitney U p-value (not FDR-adjusted)
#   df_maa_diff_cases - mean(case) - mean(control)
# The first name in each cases_pairs tuple is the control group. When a tuple
# lists several case groups, they are pooled, with each sample counted once.
# Pooling used to be keyed on the literal 'ICU Respiratory (GSE67530)' label;
# it now applies to any entry with more than one case group.
df_mw_cases = pd.DataFrame(index=ages, columns=list(cases_pairs.keys()))
df_maa_diff_cases = pd.DataFrame(index=ages, columns=list(cases_pairs.keys()))
for age_type in ages:
    col_acc = f"{age_type} Acceleration"
    for case, (group_ctrl, *groups_case) in cases_pairs.items():
        if len(groups_case) == 1:
            ids_case = ids_groups[groups_case[0]]
        else:
            ids_case = list(set().union(*(ids_groups[g] for g in groups_case)))
        vals_ctrl = df.loc[ids_groups[group_ctrl], col_acc].values
        vals_case = df.loc[ids_case, col_acc].values
        _, pval = mannwhitneyu(vals_ctrl, vals_case, alternative='two-sided')
        df_mw_cases.at[age_type, case] = pval
        df_maa_diff_cases.at[age_type, case] = np.mean(vals_case) - np.mean(vals_ctrl)
df_mw_cases.to_excel(f"{path_save}/cases_mw.xlsx")
df_maa_diff_cases.to_excel(f"{path_save}/cases_maa_diff.xlsx")

In [None]:
# Clustered heatmap of -log10(Mann-Whitney U p-value) for cases vs controls.
# Rows come from df_mw_cases' index; columns are the case studies. Cells with
# p > 0.05 fall below vmin = -log10(0.05) and are drawn with the colormap's
# 'under' color (black). The 'ESRD (Train)' column and the 'EpiSImAge' row are
# emboldened.
pvals = df_mw_cases.astype(float)
# A p-value of exactly 0 would become +inf under -log10 and break the color
# scale. As in the controls analysis, replace it by half the smallest non-zero
# p-value.
nzmin = pvals[pvals > 0].min().min() * 0.5
df_fig = -np.log10(pvals.replace(0.0, nzmin))
sns.set_theme(style='whitegrid')
cmap = plt.get_cmap("cool").copy()
cmap.set_under('black')
clustermap = sns.clustermap(
    df_fig,
    annot=True,
    col_cluster=True,
    row_cluster=True,
    fmt=".1f",
    cmap=cmap,
    vmin=-np.log10(0.05),
    linewidth=0.1,
    linecolor='black',
    tree_kws=dict(linewidths=1.5),
    figsize=(8, 8),
    cbar_kws={'orientation': 'horizontal'}
)
clustermap.ax_heatmap.set_xlabel('')
clustermap.ax_heatmap.set_ylabel('')
# Black colorbar frame, width 1. A single pass replaces the previous pair of
# loops, where lw=0.25 was immediately overridden by 1.
for spine in clustermap.ax_cbar.spines.values():
    spine.set(visible=True, lw=1, edgecolor="black")
for tick_label in clustermap.ax_heatmap.get_xticklabels():
    if tick_label.get_text() == 'ESRD (Train)':
        tick_label.set_fontweight('extra bold')
for tick_label in clustermap.ax_heatmap.get_yticklabels():
    if tick_label.get_text() == 'EpiSImAge':
        tick_label.set_fontweight('extra bold')
# Move the horizontal colorbar to span the heatmap width, just above the
# column dendrogram.
clustermap_pos = clustermap.ax_heatmap.get_position()
col_dendrogram_pos = clustermap.ax_col_dendrogram.get_position()
clustermap.ax_cbar.set_position([clustermap_pos.x0, col_dendrogram_pos.y1 + 0.05, clustermap_pos.width, 0.03])
clustermap.ax_cbar.set_title(r"$-\log_{10}$(Mann-Whitney U test p-value)")
plt.savefig(f"{path_save}/cases_clustermap_mw.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/cases_clustermap_mw.pdf", bbox_inches='tight')
plt.close(clustermap.fig)

In [None]:
# Clustered heatmap of the case-minus-control difference in mean age
# acceleration (rows from df_maa_diff_cases' index, columns = case studies).
# The values are signed, so the plot uses a diverging colormap centered at 0,
# consistent with the other acceleration-difference heatmaps. Previously it
# fell back to seaborn's default sequential colormap. The 'ESRD (Train)' column
# and the 'EpiSImAge' row are emboldened.
df_fig = df_maa_diff_cases.astype(float)
sns.set_theme(style='whitegrid')
clustermap = sns.clustermap(
    df_fig,
    annot=True,
    col_cluster=True,
    row_cluster=True,
    fmt=".1f",
    center=0.0,
    cmap='seismic',
    linewidth=0.1,
    linecolor='black',
    tree_kws=dict(linewidths=1.5),
    figsize=(8, 8),
    cbar_kws={'orientation': 'horizontal'}
)
clustermap.ax_heatmap.set_xlabel('')
clustermap.ax_heatmap.set_ylabel('')
# Black colorbar frame, width 1. A single pass replaces the previous pair of
# loops, where lw=0.25 was immediately overridden by 1.
for spine in clustermap.ax_cbar.spines.values():
    spine.set(visible=True, lw=1, edgecolor="black")
for tick_label in clustermap.ax_heatmap.get_xticklabels():
    if tick_label.get_text() == 'ESRD (Train)':
        tick_label.set_fontweight('extra bold')
for tick_label in clustermap.ax_heatmap.get_yticklabels():
    if tick_label.get_text() == 'EpiSImAge':
        tick_label.set_fontweight('extra bold')
# Move the horizontal colorbar to span the heatmap width, just above the
# column dendrogram.
clustermap_pos = clustermap.ax_heatmap.get_position()
col_dendrogram_pos = clustermap.ax_col_dendrogram.get_position()
clustermap.ax_cbar.set_position([clustermap_pos.x0, col_dendrogram_pos.y1 + 0.05, clustermap_pos.width, 0.03])
clustermap.ax_cbar.set_title(r"$\langle$ Acceleration $\rangle$ Difference")
plt.savefig(f"{path_save}/cases_clustermap_maa_diff.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/cases_clustermap_maa_diff.pdf", bbox_inches='tight')
plt.close(clustermap.fig)