# Debugging autoreload

In [1]:
# IPython autoreload: reload imported modules (e.g. the project's src.* packages)
# before executing each cell, so edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [35]:
import pickle
import torch
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from sklearn.model_selection import RepeatedStratifiedKFold
import numpy as np
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
import pandas as pd
from src.utils.configs import read_parse_config
from src.pt.hyper_opt import train_hyper_opt
from statannotations.Annotator import Annotator
import optuna
import pathlib
from tqdm import tqdm
import os
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
import distinctipy
import matplotlib.colors as mcolors
import matplotlib.cm
import matplotlib.patheffects as pe
from statsmodels.stats.multitest import multipletests
from matplotlib.colors import LinearSegmentedColormap
from scipy.stats import mannwhitneyu
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap
import yaml
import ast
import copy

# Immunomarkers plots

## Load data

In [2]:
# All paths are resolved relative to the notebook's working directory.
path_root = pathlib.Path.cwd()
path_data = f"{path_root}/data/immuno-regression"
path_plots = f"{path_root}/plots"

# features.xlsx: one column per immunomarker, listing the CpGs selected for it.
df_feats = pd.read_excel(f"{path_data}/features.xlsx")
imms = list(df_feats.columns)
df = pd.read_excel(f"{path_data}/data.xlsx")

## Correlation heatmap

In [3]:
# For each immunomarker, |Pearson r| between it and each of its selected CpGs,
# sorted from strongest to weakest: one heatmap row per immunomarker, one
# column per rank. The number of rank columns follows the data instead of
# being fixed at 100.
top_corr_rows = []
for imm in imms:
    # dropna(): pandas pads shorter feature columns with NaN when reading the sheet.
    imm_corrs = [
        abs(stats.pearsonr(df[imm], df[cpg], alternative='two-sided').statistic)
        for cpg in df_feats[imm].dropna()
    ]
    # Descending order. NaN correlations (e.g. a constant column) are placed last
    # instead of scrambling the order, because NaN compares False with everything.
    top_corr_rows.append(sorted(imm_corrs, key=lambda c: (np.isnan(c), -c)))
n_top = max((len(row) for row in top_corr_rows), default=0)
df_epi_imm_corr = pd.DataFrame(
    [row + [np.nan] * (n_top - len(row)) for row in top_corr_rows],
    index=imms,
    columns=list(range(1, n_top + 1)),
)

df_fig = df_epi_imm_corr.astype(float)
sns.set_theme(style='ticks', font_scale=1.0)
fig, ax = plt.subplots(figsize=(30, 10))
heatmap = sns.heatmap(
    df_fig,
    annot=False,
    cmap='hot',
    linewidth=0.1,
    linecolor='black',
    cbar_kws={
        'orientation': 'horizontal',
        'location': 'top',
        'fraction': 0.05,
        'pad': 0.025,
        'aspect': 60
    },
    annot_kws={"size": 12},
    ax=ax
)
ax.set_xlabel('Top CpGs')
ax.set_ylabel('')
# The colorbar is the last axes added to the figure by the heatmap call.
cbar_ax = ax.figure.axes[-1]
cbar_ax.set_title(r"|Pearson $\rho$|", fontsize='large')
cbar_ax.tick_params(labelsize='large')
for spine in cbar_ax.spines.values():
    spine.set_linewidth(1)
plt.savefig(f"{path_plots}/imms_cpgs_correlation.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_plots}/imms_cpgs_correlation.pdf", bbox_inches='tight')
plt.close(fig)

## Models results

In [None]:
# Layout: each immunomarker gets a 3-row panel (metrics table, prediction plot,
# spacer), and panels are arranged n_cols per row. The grid is sized from
# len(imms) so extra immunomarkers cannot overflow it. Unused slots are hidden below.
n_cols = 8
n_panel_rows = max(1, -(-len(imms) // n_cols))
n_rows = 3 * n_panel_rows
fig_height = 5 * n_panel_rows
fig_width = 35
splits = ['Train', 'Validation', 'Test']
rho_label = fr"Pearson $\mathbf{{\rho}}$"

imm_colors = distinctipy.get_colors(n_colors=len(imms), exclude_colors=[mcolors.hex2color(mcolors.CSS4_COLORS['gray'])], rng=42)

sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), height_ratios=[0.2, 0.8, 0.2] * n_panel_rows, gridspec_kw={'wspace': 0.35, 'hspace': 0.05}, sharey=False, sharex=False)

for imm_id, imm in enumerate(imms):
    imm_color = imm_colors[imm_id]
    imm_dir = f"{path_root}/models/Immunomarkers/{imm}"
    imm_metrics = pd.read_excel(f"{imm_dir}/metrics.xlsx", index_col=0)
    with open(f"{imm_dir}/config.yml") as f:
        imm_config = yaml.safe_load(f)
    imm_df = pd.read_excel(f"{imm_dir}/df.xlsx", index_col=0)
    # df.xlsx stores the target as "<imm>_log"; plot it under the plain marker name.
    imm_df.rename(columns={f"{imm}_log": imm}, inplace=True)

    row_id, col_id = divmod(imm_id, n_cols)
    row_id_table = row_id * 3
    row_id_scatter = row_id * 3 + 1
    row_id_empty = row_id * 3 + 2

    # NOTE(review): the axis limits come from df[imm], while the plotted values come
    # from imm_df's renamed "<imm>_log" column. Confirm both are on the same scale.
    q01 = df[imm].quantile(0.01)
    q99 = df[imm].quantile(0.99)

    df_metrics = pd.DataFrame(index=['MAE', rho_label], columns=splits)
    for split in splits:
        df_metrics.at['MAE', split] = f"{imm_metrics.at[split, 'mean_absolute_error']:0.3f}"
        df_metrics.at[rho_label, split] = f"{imm_metrics.at[split, 'pearson_corrcoef']:0.3f}"

    imm_group = fr"$\mathbf{{{imm}}}$"
    col_defs = [
        ColumnDefinition(
            name="index",
            title=imm_config['_model_name'].replace('Model', ''),
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
            group=imm_group,
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left",
            group=imm_group,
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=1.5,
            group=imm_group,
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
            group=imm_group,
        ),
    ]

    Table(
        df_metrics,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[row_id_table, col_id],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=splits)

    # Train and validation samples are shown as a density, test samples as points.
    sns.kdeplot(
        data=imm_df.loc[imm_df['Group'] != 'Test', :],
        x=imm,
        y='Prediction',
        fill=True,
        cbar=False,
        color=imm_color,
        cut=0,
        legend=False,
        ax=axs[row_id_scatter, col_id]
    )
    sns.scatterplot(
        data=imm_df.loc[imm_df['Group'] == 'Test', :],
        x=imm,
        y="Prediction",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=35,
        color=imm_color,
        ax=axs[row_id_scatter, col_id],
    )
    axs[row_id_scatter, col_id].axline((0, 0), slope=1, color="black", linestyle=":")
    axs[row_id_scatter, col_id].set_xlim(q01, q99)
    axs[row_id_scatter, col_id].set_ylim(q01, q99)
    axs[row_id_scatter, col_id].set_xlabel(imm, color=imm_color, path_effects=[pe.withStroke(linewidth=1.0, foreground="black")])

    axs[row_id_empty, col_id].axis('off')

# Hide every axis of grid slots that have no immunomarker.
for slot_id in range(len(imms), n_panel_rows * n_cols):
    row_id, col_id = divmod(slot_id, n_cols)
    for offset in range(3):
        axs[row_id * 3 + offset, col_id].axis('off')

fig.tight_layout()
fig.savefig(f"{path_plots}/imms_models.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_plots}/imms_models.pdf", bbox_inches='tight')
plt.close(fig)

# EpImAge and ICD-11 plots 

## Load data

In [2]:
path_root = pathlib.Path.cwd()
path_plots = f"{path_root}/plots"
path_data = f"{path_root}/data/age-regression"

# features.xlsx is indexed by immunomarker name.
df_feats = pd.read_excel(f"{path_data}/features.xlsx", index_col=0)
imms = df_feats.index.to_list()
df = pd.read_excel(f"{path_data}/data.xlsx")

# Control samples grouped by dataset (GSE); value_counts orders the GSEs by control count, descending.
control_mask = df['Status'] == 'Control'
gse_controls_count = df.loc[control_mask, 'GSE'].value_counts()
gses_controls = gse_controls_count.index.values
gse_controls_ids = {
    gse: df.index[control_mask & (df['GSE'] == gse)].values
    for gse in gses_controls
}

# Plot-friendly status labels break lines at spaces and hyphens.
# The untouched label is kept in 'Status Origin'.
_to_newline = str.maketrans({' ': '\n', '-': '\n'})
statuses_rename = {x: x.translate(_to_newline) for x in df['Status'].value_counts().index.values}
df['Status Origin'] = df['Status']
df['Status'] = df['Status'].replace(statuses_rename)
status_count = df['Status'].value_counts()
statuses = status_count.index.values

df_groups = pd.read_excel(f"{path_data}/groups.xlsx", index_col=0)
icd_chpts = np.sort(df_groups['ICD-11 chapter'].unique())
icd_codes = np.sort(df_groups['ICD-11 code'].unique())

# Result column names: each chapter column is immediately followed by its code columns.
icd_cols = []
icd_code_chpt = {}  # code column name -> its chapter column name
icd_codes_for_chpts = {}  # chapter -> sorted array of the codes it contains
for icd_chpt in icd_chpts:
    chpt_col = f'Passed\nICD-11\nChapter {icd_chpt}'
    icd_codes_for_chpts[icd_chpt] = np.sort(df_groups.loc[df_groups['ICD-11 chapter'] == icd_chpt, 'ICD-11 code'].unique())
    code_cols = [f"Passed\nICD-11\nCode {code}" for code in icd_codes_for_chpts[icd_chpt]]
    icd_code_chpt.update(dict.fromkeys(code_cols, chpt_col))
    icd_cols += [chpt_col, *code_cols]
icd_cols_max = [f"Max\n{x}" for x in icd_cols]

## Prepare colors

In [3]:
def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend ``rgb`` over ``bg_rgb`` with opacity ``alpha``; return the mixed colour as a list.

    Each channel is ``alpha * fg + (1 - alpha) * bg``. ``alpha=1`` gives ``rgb``
    and ``alpha=0`` gives ``bg_rgb``. The output has as many channels as the
    shorter of the two inputs.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

def form_bar(base):
    """Return a formatter for plottable ``bar`` annotations showing ``'passed/base'``.

    The bar columns hold pass counts divided by ``base``. The formatter scales a
    value ``x`` back to a count, rounds it, and renders it as ``'<count>/<base>'``.
    """
    # An integral float base (e.g. a count that arrived as 42.0) is shown as "42", not "42.0".
    base_label = int(base) if float(base).is_integer() else base

    def formatter(x):
        return f'{int(round(x * base))}/{base_label}'

    return formatter

# Colors for ICD-11 chapters
colors = distinctipy.get_colors(len(icd_chpts), [mcolors.hex2color(mcolors.CSS4_COLORS['black']), mcolors.hex2color(mcolors.CSS4_COLORS['white'])], rng=1337, pastel_factor=0.5)
colors_icd_chpts = {icd_chpt: colors[icd_chpt_id] for icd_chpt_id, icd_chpt in enumerate(icd_chpts)}
# Per-chapter colormap running from a washed-out (20% opacity over white) tint to the full chapter color.
colormaps_icd_chpts = {
    icd_chpt: LinearSegmentedColormap.from_list(
        name=f"ICD-11 Chapter {icd_chpt} cmap",
        colors=[make_rgb_transparent(colors_icd_chpts[icd_chpt], (1, 1, 1), 0.2), colors_icd_chpts[icd_chpt]], N=256
    )
    for icd_chpt in icd_chpts
}
colormap_total = LinearSegmentedColormap.from_list(
    name="ICD-11 Total cmap",
    colors=[
        mcolors.hex2color(mcolors.CSS4_COLORS['lavender']),
        mcolors.hex2color(mcolors.CSS4_COLORS['dimgray'])],
    N=256,
)

# Colors for GSEs
colors = distinctipy.get_colors(len(gses_controls), [mcolors.hex2color(mcolors.CSS4_COLORS['white']), mcolors.hex2color(mcolors.CSS4_COLORS['black'])], rng=1337)
colors_gse = {gses_controls[gse_id]: colors[gse_id] for gse_id in range(len(gses_controls))}

# Colors for statuses: Control is fixed to dodgerblue. The other statuses get
# distinct colors that avoid dodgerblue, white and black.
control_color = mcolors.hex2color(mcolors.CSS4_COLORS['dodgerblue'])
colors = distinctipy.get_colors(len(statuses) - 1, [control_color, mcolors.hex2color(mcolors.CSS4_COLORS['white']), mcolors.hex2color(mcolors.CSS4_COLORS['black'])], rng=1337)
colors_status = {'Control': control_color}
# Assign the generated colors to the non-control statuses in order. Indexing by
# `status_id - 1` over all statuses was only correct when 'Control' came first
# in `statuses`; otherwise two statuses shared the last color.
other_statuses = [status for status in statuses if status != 'Control']
colors_status.update(zip(other_statuses, colors))

## Perform statistical tests

In [None]:
df_clocks = pd.read_excel(f"{path_data}/clocks_meta.xlsx", index_col=0)
new_cols = ['Total Rho', 'Total MAE', 'Passed\nICD-11\nTotal'] + icd_cols + ['Max\nPassed\nICD-11\nTotal'] + icd_cols_max
for col in new_cols:
    df_clocks[col] = None


def _test_col(kind, section_id, section_group):
    """Return the df_clocks column holding `kind` ('pval', 'bias_0' or 'bias_1') for one tested status pair."""
    return f"{kind}\n{section_id}\n{section_group}"


def _count_passed(df_results, clock_name, df_sections):
    """Count the tested status pairs in `df_sections` that `clock_name` passes.

    A pair passes when its (FDR-corrected) p-value in `df_results` is below 0.05
    and the mean error of the second status moves relative to the first in the
    direction listed in the section's 'Directions' ('Increasing'/'Decreasing').

    Returns:
        (n_passed, n_tested)
    """
    n_passed = 0
    n_tested = 0
    for section_id, section_row in df_sections.iterrows():
        section_groups = ast.literal_eval(section_row['Groups'])
        section_directions = ast.literal_eval(section_row['Directions'])
        for section_group_id, section_group in enumerate(section_groups):
            n_tested += 1
            pval = df_results.at[clock_name, _test_col('pval', section_id, section_group)]
            bias_0 = df_results.at[clock_name, _test_col('bias_0', section_id, section_group)]
            bias_1 = df_results.at[clock_name, _test_col('bias_1', section_id, section_group)]
            direction = section_directions[section_group_id]
            shifted_as_expected = (
                (direction == 'Increasing' and bias_1 > bias_0)
                or (direction == 'Decreasing' and bias_1 < bias_0)
            )
            # A NaN p-value compares False, so such a pair never passes.
            if pval < 0.05 and shifted_as_expected:
                n_passed += 1
    return n_passed, n_tested


# 1) Mann-Whitney U test and mean error for every status pair, per clock.
for clock_name in (pbar := tqdm(df_clocks.index)):
    pbar.set_description(f"Processing {clock_name}")
    clock_type = df_clocks.at[clock_name, 'Type']

    # NOTE: df['Error'] is overwritten for every clock (it is part of df's state afterwards).
    if clock_type == 'Age':
        df['Error'] = df[clock_name] - df['Age']
    elif clock_name in ['dnamtl', 'pcdnamtl']:
        # Presumably telomere-length estimators (shorter = older); negated so a larger
        # value points the same way as for the other clocks. TODO confirm.
        df['Error'] = -df[clock_name]
    else:
        df['Error'] = df[clock_name]

    for section_id, section_row in df_groups.iterrows():
        section_statuses = ast.literal_eval(section_row['Statuses'])
        section_groups = ast.literal_eval(section_row['Groups'])
        df_section = df.loc[(df['GSE'] == section_row['GSE']) & (df['Status Origin'].isin(section_statuses)), ['Status Origin', 'Error']]

        for section_group in section_groups:
            errors_0 = df_section.loc[df_section['Status Origin'] == section_group[0], 'Error']
            errors_1 = df_section.loc[df_section['Status Origin'] == section_group[1], 'Error']
            _, pval = mannwhitneyu(errors_0.values, errors_1.values, alternative="two-sided")
            df_clocks.at[clock_name, _test_col('pval', section_id, section_group)] = pval
            # Series.mean skips NaN, matching the previous np.mean(Series) behaviour.
            df_clocks.at[clock_name, _test_col('bias_0', section_id, section_group)] = errors_0.mean()
            df_clocks.at[clock_name, _test_col('bias_1', section_id, section_group)] = errors_1.mean()

# 2) Benjamini-Hochberg FDR correction across all tests of each clock.
# multipletests does not handle NaN: a single NaN p-value makes BH's cumulative
# minimum return NaN for every test. Correct only the finite p-values and keep
# NaN where no p-value exists.
pvals_cols = [col for col in df_clocks.columns if isinstance(col, str) and col.startswith('pval\n')]
for clock_name in tqdm(df_clocks.index):
    pvals = df_clocks.loc[clock_name, pvals_cols].to_numpy(dtype=float)
    is_finite = np.isfinite(pvals)
    pvals_fdr = np.full(pvals.shape, np.nan)
    if is_finite.any():
        pvals_fdr[is_finite] = multipletests(pvals[is_finite], 0.05, method='fdr_bh')[1]
    df_clocks.loc[clock_name, pvals_cols] = pvals_fdr

# 3) Accuracy on controls (Age clocks only) and pass counts per ICD-11 chapter/code.
for clock_name in (pbar := tqdm(df_clocks.index)):
    pbar.set_description(f"Processing {clock_name}")
    clock_type = df_clocks.at[clock_name, 'Type']

    if clock_type == 'Age':
        control_mask = df['Status Origin'] == 'Control'
        real_all = torch.from_numpy(df.loc[control_mask, 'Age'].values)
        pred_all = torch.from_numpy(df.loc[control_mask, clock_name].values)
        df_clocks.at[clock_name, 'Total Rho'] = pearson_corrcoef(pred_all, real_all).numpy().item()
        df_clocks.at[clock_name, 'Total MAE'] = mean_absolute_error(pred_all, real_all).numpy().item()

    total_passed = 0
    total_tested = 0
    for icd_chpt in icd_chpts:
        n_passed, n_tested = _count_passed(df_clocks, clock_name, df_groups[df_groups['ICD-11 chapter'] == icd_chpt])
        df_clocks.at[clock_name, f'Passed\nICD-11\nChapter {icd_chpt}'] = n_passed
        df_clocks.at[clock_name, f'Max\nPassed\nICD-11\nChapter {icd_chpt}'] = n_tested
        total_passed += n_passed
        total_tested += n_tested
    df_clocks.at[clock_name, 'Passed\nICD-11\nTotal'] = total_passed
    df_clocks.at[clock_name, 'Max\nPassed\nICD-11\nTotal'] = total_tested

    for icd_code in icd_codes:
        n_passed, n_tested = _count_passed(df_clocks, clock_name, df_groups[df_groups['ICD-11 code'] == icd_code])
        df_clocks.at[clock_name, f'Passed\nICD-11\nCode {icd_code}'] = n_passed
        df_clocks.at[clock_name, f'Max\nPassed\nICD-11\nCode {icd_code}'] = n_tested

df_clocks.to_excel(f"{path_plots}/clocks_tests.xlsx", index_label='Clock Name')

## Plot clock results

### Plot tables

In [19]:
df_clocks = pd.read_excel(f"{path_plots}/clocks_tests.xlsx", index_col='Clock Name')
df_clocks.insert(0, 'Clock Name', df_clocks.index.values)

# Normalize pass counts to [0, 1] for the bar columns. The number of tests per
# category depends only on df_groups, so every clock has the same "Max" value.
# The column maximum replaces the hard-coded 'Hannum' reference row.
max_passed_total = df_clocks['Max\nPassed\nICD-11\nTotal'].max()
df_clocks['Passed\nICD-11\nTotal'] /= max_passed_total
for col in icd_cols:
    df_clocks[col] /= df_clocks[f'Max\n{col}'].max()


def _bar_plot_kw(cmap, max_passed):
    """Return plottable `bar` kwargs for a pass-fraction column annotated as 'passed/max_passed'."""
    return {
        "cmap": cmap,
        "plot_bg_bar": True,
        "annotate": True,
        "height": 0.95,
        "linewidth": 0.5,
        "formatter": form_bar(max_passed),
    }


def _save_clock_table(col_names, col_defs, figsize, file_stem):
    """Render df_clocks[col_names] as a plottable Table and save it as PNG and PDF in path_plots."""
    fig, ax = plt.subplots(figsize=figsize)
    Table(
        df_clocks[col_names],
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        odd_row_color="#ffffff",
        even_row_color="#f0f0f0",
        ax=ax,
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=col_names)
    fig.savefig(f"{path_plots}/{file_stem}.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_plots}/{file_stem}.pdf", bbox_inches='tight')
    plt.close(fig)


col_names_common = ["Year", "Total Rho", "Total MAE", "Passed\nICD-11\nTotal"]
col_defs_common = [
    ColumnDefinition(
        name="Clock Name",
        title="Clocks",
        textprops={"ha": "right", "weight": "bold"},
        width=2.25,
    ),
    ColumnDefinition(
        name="Year",
        title="Year",
        textprops={"ha": "center"},
        width=1.0,
        border="left"
    ),
    ColumnDefinition(
        name="Total Rho",
        title="Total\n" + r"Pearson $\rho$",
        textprops={"ha": "center"},
        formatter="{:.3f}",
        cmap=normed_cmap(df_clocks["Total Rho"].dropna(), cmap=matplotlib.cm.Greens, num_stds=2.5),
        width=1.0,
        border="left"
    ),
    ColumnDefinition(
        name="Total MAE",
        title="Total\nMAE",
        textprops={"ha": "center"},
        formatter="{:.3f}",
        cmap=normed_cmap(df_clocks["Total MAE"].dropna(), cmap=matplotlib.cm.Reds, num_stds=2.5),
        width=1.0,
    ),
    ColumnDefinition(
        name="Passed\nICD-11\nTotal",
        title="Passed\nICD-11",
        width=1.5,
        border="left",
        textprops={"ha": "center"},
        plot_fn=bar,
        plot_kw=_bar_plot_kw(colormap_total, max_passed_total),
    ),
]

# Table 1: common columns + one bar column per ICD-11 chapter.
icd_chpt_col_defs = copy.deepcopy(col_defs_common)
icd_chpt_col_names = copy.deepcopy(col_names_common)
for icd_chpt in icd_chpts:
    col_name = f'Passed\nICD-11\nChapter {icd_chpt}'
    icd_chpt_col_names.append(col_name)
    icd_chpt_col_defs.append(ColumnDefinition(
        name=col_name,
        title=f'Chapter {icd_chpt}',
        width=1.0,
        plot_fn=bar,
        # A left border separates the chapter block from the common columns. It goes on
        # the first chapter present rather than assuming chapter 1 exists.
        border='left' if icd_chpt == icd_chpts[0] else None,
        textprops={"ha": "center"},
        plot_kw=_bar_plot_kw(colormaps_icd_chpts[icd_chpt], df_clocks[f'Max\n{col_name}'].max()),
    ))

# Table 2: common columns + one bar column per ICD-11 code, grouped under its chapter.
icd_code_col_defs = copy.deepcopy(col_defs_common)
icd_code_col_names = copy.deepcopy(col_names_common)
for icd_chpt in icd_chpts:
    # Chapters 11 and 12 get a bare-number group label (presumably their code
    # groups are too narrow for "Chapter N"; TODO confirm).
    if icd_chpt in [11, 12]:
        group = fr"$\mathbf{{{icd_chpt}}}$"
    else:
        group = fr"$\mathbf{{Chapter\,{icd_chpt}}}$"

    for code_in_chpt in icd_codes_for_chpts[icd_chpt]:
        col_name = f'Passed\nICD-11\nCode {code_in_chpt}'
        icd_code_col_names.append(col_name)
        icd_code_col_defs.append(ColumnDefinition(
            name=col_name,
            title=f'{code_in_chpt}',
            width=1.0,
            plot_fn=bar,
            border=None,
            textprops={"ha": "center"},
            plot_kw=_bar_plot_kw(colormaps_icd_chpts[icd_chpt], df_clocks[f'Max\n{col_name}'].max()),
            group=group,
        ))

_save_clock_table(icd_chpt_col_names, icd_chpt_col_defs, (25, 17), "clocks_tests_chpts")
_save_clock_table(icd_code_col_names, icd_code_col_defs, (45, 22), "clocks_tests_codes")

# Plot EpImAge for Controls

## All Controls together

In [None]:
n_rows = 4
n_cols = 1
fig_height = 6
fig_width = 4
sns.set_theme(style='ticks')
# Rows: metrics table, age histogram, prediction-vs-age plot, bottom spacer.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), height_ratios=[0.2, 0.2, 0.8, 0.15], gridspec_kw={'wspace': 0.25, 'hspace': 0.05}, sharey=False, sharex=False)

# .copy() makes df_fig an independent frame, so adding the 'Error' column below
# cannot trigger pandas' SettingWithCopyWarning.
df_fig = df.loc[df['Status'] == 'Control', ['Age', 'Status', 'Split', 'EpImAge']].copy()
df_fig['Error'] = df_fig['EpImAge'] - df_fig['Age']
df_metrics = pd.read_excel(f"{path_root}/models/EpImAge/metrics.xlsx", index_col=0)

row_id_table = 0
row_id_hist = 1
row_id_scatter = 2
row_id_empty = 3

parts = ['Train', 'Validation', 'Test', 'Total']
rho_label = fr"Pearson $\mathbf{{\rho}}$"
# Metrics read from metrics.xlsx, formatted to 3 decimals for the table.
df_table = pd.DataFrame(index=['MAE', rho_label, "Bias"], columns=parts)
for part in parts:
    df_table.at['MAE', part] = f"{df_metrics.at[part, 'MAE']:0.3f}"
    df_table.at[rho_label, part] = f"{df_metrics.at[part, 'Rho']:0.3f}"
    df_table.at["Bias", part] = f"{df_metrics.at[part, 'Bias']:0.3f}"

col_defs = [
    ColumnDefinition(
        name="index",
        title='',
        textprops={"ha": "center", "weight": "bold"},
        width=2.5,
    ),
    ColumnDefinition(
        name="Train",
        textprops={"ha": "left"},
        width=1.5,
        border="left"
    ),
    ColumnDefinition(
        name="Validation",
        textprops={"ha": "left"},
        width=2.2,
    ),
    ColumnDefinition(
        name="Test",
        textprops={"ha": "left"},
        width=1.5,
    ),
    ColumnDefinition(
        name="Total",
        textprops={"ha": "left"},
        width=1.5,
    )
]

Table(
    df_table,
    column_definitions=col_defs,
    row_dividers=True,
    footer_divider=False,
    ax=axs[row_id_table],
    textprops={"fontsize": 8},
    row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
    col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
    column_border_kw={"linewidth": 1, "linestyle": "-"},
).autoset_fontcolors(colnames=parts)

# 5-year age bins from 0 to 120.
hist_bins = np.linspace(0, 120, 25)
sns.histplot(
    data=df_fig,
    bins=hist_bins,
    edgecolor='k',
    linewidth=1,
    x="Age",
    color='lightslategray',
    ax=axs[row_id_hist]
)
axs[row_id_hist].set_xticks([])
axs[row_id_hist].set_xlim(0, 105)
axs[row_id_hist].set_ylabel("Count")

# Train + validation samples are drawn as a density, test samples as points.
sns.kdeplot(
    data=df_fig.loc[df_fig['Split'].isin(['Train', 'Validation']), :],
    x='Age',
    y='EpImAge',
    fill=True,
    cbar=False,
    thresh=0.0001,
    color=make_rgb_transparent(mcolors.hex2color(mcolors.CSS4_COLORS['lightslategray']), (1, 1, 1), 0.25),
    cut=0,
    legend=False,
    ax=axs[row_id_scatter]
)
sns.scatterplot(
    data=df_fig.loc[df_fig['Split'] == 'Test', :],
    x='Age',
    y="EpImAge",
    linewidth=0.01,
    alpha=0.8,
    edgecolor="k",
    s=2,
    color='lightslategray',
    ax=axs[row_id_scatter],
)
axs[row_id_scatter].axline((0, 0), slope=1, color="black", linestyle=":")
axs[row_id_scatter].set_xlim(0, 105)
axs[row_id_scatter].set_ylim(0, 105)
axs[row_id_scatter].set_ylabel("Prediction")
axs[row_id_scatter].set_xlabel("Age")
axs[row_id_empty].axis('off')

fig.tight_layout()
fig.savefig(f"{path_plots}/EpImAge_Controls_All.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_plots}/EpImAge_Controls_All.pdf", bbox_inches='tight')
plt.close(fig)

## Controls from each GSE separately

In [None]:
# Layout: each GSE gets a 4-row panel (metrics table, age histogram, prediction
# plot, spacer), and panels are arranged n_cols per row. The number of panel rows
# follows the number of GSEs, so the grid cannot overflow. Unused slots are hidden below.
gse_items = list(gse_controls_ids.items())
n_cols = 12
n_panel_rows = max(1, -(-len(gse_items) // n_cols))
n_rows = 4 * n_panel_rows
fig_height = 32 * n_panel_rows / 6  # 32 in. for 6 panel rows, as originally tuned
fig_width = 46
parts = ['Train', 'Validation', 'Test', 'Total']
rho_label = fr"Pearson $\mathbf{{\rho}}$"
hist_bins = np.linspace(0, 120, 13)  # 10-year age bins
sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), height_ratios=[0.2, 0.2, 0.8, 0.15] * n_panel_rows, gridspec_kw={'wspace': 0.25, 'hspace': 0.05}, sharey=False, sharex=False)

for gse_id, (gse, gse_samples) in enumerate(tqdm(gse_items)):
    color_gse = colors_gse[gse]
    # .copy(): an 'Error' column is added to this subset below.
    data_gse = df.loc[gse_samples, ['GSE', 'Age', 'Status', 'Split', 'EpImAge']].copy()
    data_gse['Error'] = data_gse['EpImAge'] - data_gse['Age']
    row_id, col_id = divmod(gse_id, n_cols)
    row_id_table = row_id * 4
    row_id_hist = row_id * 4 + 1
    row_id_scatter = row_id * 4 + 2
    row_id_empty = row_id * 4 + 3

    # Metrics per split plus 'Total' (all control samples of this GSE).
    df_table = pd.DataFrame(index=['MAE', rho_label, "Bias"], columns=parts)
    for part in parts:
        data_part = data_gse if part == 'Total' else data_gse.loc[data_gse['Split'] == part]
        pred = torch.from_numpy(data_part['EpImAge'].values)
        real = torch.from_numpy(data_part['Age'].values)
        df_table.at['MAE', part] = f"{mean_absolute_error(pred, real).numpy().item():0.3f}"
        df_table.at[rho_label, part] = f"{pearson_corrcoef(pred, real).numpy().item():0.3f}"
        df_table.at["Bias", part] = f"{np.mean(data_part['Error'].values):0.3f}"

    col_defs = [
        ColumnDefinition(
            name="index",
            title=gse if gse != 'GSEUNN' else 'This work',
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left"
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=2.2,
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
        ),
        ColumnDefinition(
            name="Total",
            textprops={"ha": "left"},
            width=1.5,
        )
    ]

    Table(
        df_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[row_id_table, col_id],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=parts)

    sns.histplot(
        data=data_gse,
        bins=hist_bins,
        edgecolor='k',
        linewidth=1,
        x="Age",
        color=color_gse,
        ax=axs[row_id_hist, col_id]
    )
    axs[row_id_hist, col_id].set_xticks([])
    axs[row_id_hist, col_id].set_xlim(0, 115)
    axs[row_id_hist, col_id].set_ylabel("Count" if col_id == 0 else "")

    # Train + validation samples are drawn as a density, test samples as points.
    sns.kdeplot(
        data=data_gse.loc[data_gse['Split'].isin(['Train', 'Validation']), :],
        x='Age',
        y='EpImAge',
        fill=True,
        cbar=False,
        color=make_rgb_transparent(color_gse, (1, 1, 1), 0.25),
        cut=0,
        legend=False,
        ax=axs[row_id_scatter, col_id]
    )
    sns.scatterplot(
        data=data_gse.loc[data_gse['Split'] == 'Test', :],
        x='Age',
        y="EpImAge",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=35,
        color=color_gse,
        ax=axs[row_id_scatter, col_id],
    )
    axs[row_id_scatter, col_id].axline((0, 0), slope=1, color="black", linestyle=":")
    axs[row_id_scatter, col_id].set_xlim(0, 115)
    axs[row_id_scatter, col_id].set_ylim(0, 115)
    axs[row_id_scatter, col_id].set_ylabel("Prediction" if col_id == 0 else "")
    # Put the x-label on the lowest panel of each column, including columns
    # whose last panel is not in the bottom row.
    is_bottom_panel = gse_id + n_cols >= len(gse_items)
    axs[row_id_scatter, col_id].set_xlabel("Age" if is_bottom_panel else "")
    axs[row_id_empty, col_id].axis('off')

# Hide every axis of grid slots that have no GSE.
for slot_id in range(len(gse_items), n_panel_rows * n_cols):
    row_id, col_id = divmod(slot_id, n_cols)
    for offset in range(4):
        axs[row_id * 4 + offset, col_id].axis('off')

fig.tight_layout()
fig.savefig(f"{path_plots}/EpImAge_Controls_GSEs.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_plots}/EpImAge_Controls_GSEs.pdf", bbox_inches='tight')
plt.close(fig)

# Violins for different ICD-11 statuses

In [None]:
df_clocks = pd.read_excel(f"{path_plots}/clocks_tests.xlsx", index_col='Clock Name')

# Build the mosaic: one row per 'Row on violin plot' value. Each section occupies
# as many adjacent cells as it has statuses (one violin per status).
mosaic_violins = []
mosaic_rows = np.sort(df_groups['Row on violin plot'].unique())
for mosaic_row in mosaic_rows:
    df_mosaic_row = df_groups[df_groups['Row on violin plot'] == mosaic_row].sort_values(by=['ICD-11 chapter', 'ICD-11 code', 'GSE'], ascending=[True, True, True])
    mosaic_row_labels = []
    for plot_id, plot_row in df_mosaic_row.iterrows():
        mosaic_row_labels += [plot_id] * len(ast.literal_eval(plot_row['Statuses']))
    mosaic_violins.append(mosaic_row_labels)

# Pad shorter rows with a per-row placeholder panel that is hidden after plotting.
max_mosaic_row = max(len(x) for x in mosaic_violins)
violons_empty_panels = set()
for row_id, row in enumerate(mosaic_violins):
    n_missing = max_mosaic_row - len(row)
    if n_missing > 0:
        violons_empty_panels.add(f'Empty row {row_id}')
        row.extend([f'Empty row {row_id}'] * n_missing)

sns.set_theme(style='ticks')
fig_height = 5 * len(mosaic_violins)
fig_width = 1.4 * max_mosaic_row
fig, axs = plt.subplot_mosaic(mosaic=mosaic_violins, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=False, sharex=False)

for plot_id, plot_row in df_groups.iterrows():
    plot_groups_raw = ast.literal_eval(plot_row['Groups'])
    plot_statuses = [statuses_rename[x] for x in ast.literal_eval(plot_row['Statuses'])]
    plot_groups = [(statuses_rename[x[0]], statuses_rename[x[1]]) for x in plot_groups_raw]

    df_plot = df.loc[(df['GSE'] == plot_row['GSE']) & (df['Status'].isin(plot_statuses)), ['Status', 'Age', 'EpImAge']].copy()
    # EpImAge age acceleration, computed here instead of reading df['Error'].
    # df['Error'] is reassigned per clock elsewhere in this notebook and may hold
    # another clock's error, while the p-values below are EpImAge's.
    df_plot['Error'] = df_plot['EpImAge'] - df_plot['Age']

    # Extend each x tick label with the sample count and mean age acceleration.
    plot_status_count = df_plot['Status'].value_counts()
    plot_statuses_rename = {}
    for x in plot_statuses:
        plot_statuses_rename[x] = x + f"\nCount: {plot_status_count[x]}\nBias: {np.mean(df_plot.loc[df_plot['Status'] == x, 'Error']):0.1f}"
    plot_statuses = [plot_statuses_rename[x] for x in plot_statuses]
    plot_groups = [(plot_statuses_rename[x[0]], plot_statuses_rename[x[1]]) for x in plot_groups]
    df_plot['Status'] = df_plot['Status'].replace(plot_statuses_rename)
    colors_plot_status = {plot_statuses_rename[x]: colors_status[x] for x in plot_statuses_rename}

    # FDR-corrected EpImAge p-values, as stored by the statistical-test step.
    pval_formatted = []
    for plot_group_raw in plot_groups_raw:
        pval = df_clocks.at['EpImAge', f"pval\n{plot_id}\n{plot_group_raw}"]
        pval_formatted.append(f"{pval:.1e}")

    sns.violinplot(
        data=df_plot,
        x='Status',
        y='Error',
        palette=colors_plot_status,
        scale='width',
        order=plot_statuses,
        saturation=0.75,
        legend=False,
        ax=axs[plot_id]
    )
    annotator = Annotator(
        ax=axs[plot_id],
        pairs=plot_groups,
        data=df_plot,
        x="Status",
        y="Error",
        order=plot_statuses,
    )
    annotator.set_custom_annotations(pval_formatted)
    annotator.configure(loc='inside', verbose=0)
    annotator.annotate()

    axs[plot_id].set_xlabel('')
    axs[plot_id].set_ylabel('Age acceleration')
    if plot_row['GSE'] == 'GSEUNN':
        axs[plot_id].set_title(f"This work ({df_plot.shape[0]})")
    else:
        axs[plot_id].set_title(f"{plot_row['GSE']} ({df_plot.shape[0]})")
    axs[plot_id].set_facecolor(make_rgb_transparent(colors_icd_chpts[plot_row['ICD-11 chapter']], (1, 1, 1), 0.33))

for empty_panel in violons_empty_panels:
    axs[empty_panel].axis('off')
fig.tight_layout()
fig.savefig(f"{path_plots}/EpImAge_Violins.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_plots}/EpImAge_Violins.pdf", bbox_inches='tight')
plt.close(fig)

# XAI age acceleration difference

In [None]:
# Compute, for every compared pair of statuses in df_groups, the difference of
# mean per-feature attributions of the EpImAge model. The result is
# df_imm_shap_raw, with rows = '<immunomarker>_log' features and one column per
# comparison, keyed "shap_mean_diff\n<section_id>\n<group>".
explain_method = "GradientShap"
# NOTE(review): pytorch_tabular baselines spec; presumably "b|N" samples N baseline rows from training data — confirm.
explain_baselines = "b|100000"

model = TabularModel.load_model(f"{path_root}/models/EpImAge")

# The model's feature columns carry a "_log" suffix; map the raw names onto them.
imms_log = [f"{x}_log" for x in imms]
feats_dict_log = {x: f"{x}_log" for x in imms}

shap_diff_columns = {}
for section_id, section_row in df_groups.iterrows():
    section_statuses = ast.literal_eval(section_row['Statuses'])
    section_groups = ast.literal_eval(section_row['Groups'])
    section_directions = ast.literal_eval(section_row['Directions'])
    # Non-inplace rename returns a fresh frame; renaming the boolean .loc slice
    # in place would trigger SettingWithCopyWarning.
    df_section = df.loc[(df['GSE'] == section_row['GSE']) & (df['Status Origin'].isin(section_statuses)), :].rename(columns=feats_dict_log)

    # Mean attributions per status, computed at most once per section. A status
    # shared by several comparisons (e.g. a common reference group) is explained
    # once instead of re-running the costly explainer for every pair.
    section_shap_means = {}
    for section_group_id, section_group in enumerate(section_groups):
        group_direction = section_directions[section_group_id]

        for status in (section_group[0], section_group[1]):
            if status not in section_shap_means:
                df_status = df_section.loc[df_section["Status Origin"] == status, :]
                status_expl = model.explain(df_status, method=explain_method, baselines=explain_baselines)
                section_shap_means[status] = status_expl.mean()
        shap_values_mean_0 = section_shap_means[section_group[0]]
        shap_values_mean_1 = section_shap_means[section_group[1]]

        # 'Increasing' -> group[1] minus group[0]; any other direction -> group[0] minus group[1].
        diff_column = f"shap_mean_diff\n{section_id}\n{section_group}"
        if group_direction == 'Increasing':
            shap_diff_columns[diff_column] = shap_values_mean_1[imms_log] - shap_values_mean_0[imms_log]
        else:
            shap_diff_columns[diff_column] = shap_values_mean_0[imms_log] - shap_values_mean_1[imms_log]

# Build the frame in one go (column-by-column .loc insertion fragments it).
df_imm_shap_raw = pd.DataFrame(shap_diff_columns, index=imms_log)

def _sum_group_shap_diffs(df_sections):
    """Sum per-comparison SHAP differences over all comparisons in ``df_sections``.

    ``df_sections`` is a subset of ``df_groups``. Every group listed in its
    'Groups' cell selects the matching "shap_mean_diff\\n<section_id>\\n<group>"
    column of ``df_imm_shap_raw``. Returns a float Series indexed by
    ``imms_log``, all zeros when ``df_sections`` has no rows.
    """
    total = pd.Series(0.0, index=imms_log)
    for section_id, section_row in df_sections.iterrows():
        for section_group in ast.literal_eval(section_row['Groups']):
            total += df_imm_shap_raw.loc[imms_log, f"shap_mean_diff\n{section_id}\n{section_group}"]
    return total


# Aggregate the per-comparison differences into one column per ICD-11 chapter
# ('Chapter <c>') and one per ICD-11 code ('Code <c>').
df_imm_shap = pd.DataFrame(index=imms_log)
for icd_chpt in icd_chpts:
    df_imm_shap[f'Chapter {icd_chpt}'] = _sum_group_shap_diffs(df_groups[df_groups['ICD-11 chapter'] == icd_chpt])
for icd_code in icd_codes:
    df_imm_shap[f'Code {icd_code}'] = _sum_group_shap_diffs(df_groups[df_groups['ICD-11 code'] == icd_code])

# Immunomarker x ICD-11-chapter heatmap of summed SHAP differences. Rows are
# hierarchically clustered; chapter columns keep their given order.
df_imm_shap_table = df_imm_shap[[f'Chapter {icd_chpt}' for icd_chpt in icd_chpts]]
# Strip only the trailing "_log" suffix that was appended to the feature names.
df_imm_shap_table.index = df_imm_shap_table.index.str.replace(r'_log$', '', regex=True)

df_fig = df_imm_shap_table.astype(float)
x_ticks_colors = {f'Chapter {icd_chpt}': colors_icd_chpts[icd_chpt] for icd_chpt in icd_chpts}
sns.set_theme(style='ticks')
clustermap = sns.clustermap(
    df_fig,
    annot=True,
    col_cluster=False,
    row_cluster=True,
    fmt=".2f",
    center=0.0,
    cmap='seismic',
    linewidths=0.1,
    linecolor='black',
    tree_kws=dict(linewidths=1.5),
    figsize=(16, 12),
    cbar_kws={'orientation': 'horizontal'}
)
clustermap.ax_heatmap.set_xlabel('')
clustermap.ax_heatmap.set_ylabel('')
# Horizontal chapter labels, each coloured with its chapter colour and given a
# thin black outline for legibility.
clustermap.ax_heatmap.set_xticklabels(clustermap.ax_heatmap.get_xmajorticklabels(), rotation=0, path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
for tick_label in clustermap.ax_heatmap.get_xticklabels():
    tick_label.set_color(x_ticks_colors[tick_label.get_text()])
# Move the horizontal colorbar just above the heatmap, spanning its full width.
clustermap_pos = clustermap.ax_heatmap.get_position()
clustermap.ax_cbar.set_position([clustermap_pos.x0, clustermap_pos.y1 + 0.05, clustermap_pos.width, 0.03])
clustermap.ax_cbar.set_title("XAI age acceleration difference", fontsize='large')
clustermap.ax_cbar.tick_params(labelsize='large')
# Single pass over the colorbar frame (previously set twice with 0.25 then 1).
for spine in clustermap.ax_cbar.spines.values():
    spine.set(visible=True, lw=1, edgecolor="black")
# Save the clustermap's own figure explicitly rather than relying on plt's current figure.
clustermap.figure.savefig(f"{path_plots}/XAI_age_acceleration_difference.png", bbox_inches='tight', dpi=200)
clustermap.figure.savefig(f"{path_plots}/XAI_age_acceleration_difference.pdf", bbox_inches='tight')
plt.close(clustermap.figure)