# Debugging autoreload

In [None]:
# IPython magics (not plain Python): load the autoreload extension and select
# mode 2, which re-imports modules before running each cell.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from scipy.stats import mannwhitneyu
import matplotlib.lines as mlines
from statsmodels.stats.multitest import multipletests
from sklearn.model_selection import RepeatedStratifiedKFold
import torch
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
import pandas as pd
import os

# Load data

In [None]:
# Main table: one row per sample, indexed by the first spreadsheet column.
data = pd.read_excel(f"{os.getcwd()}/data/data.xlsx", index_col=0)

# Only the index of feats.xlsx is kept; its values are used below as column
# names of `data`.
feats_table = pd.read_excel(f"{os.getcwd()}/data/feats.xlsx", index_col=0)
feats = feats_table.index.values

# Region-specific analysis

In [None]:
# Plot colour for each cohort. The key order matters: several plots derive
# their category order from this mapping.
colors_region = dict(Central='gold', Yakutia='lightslategray')

## Plot distribution

In [None]:
# Age histogram coloured by cohort, written to data/histplot.{png,pdf}.
sns.set_theme(style='whitegrid')

# Keep only the plotted columns; 'Region' is displayed as 'Cohort'.
df_fig = data[['Age', 'Region']].rename(columns={'Region': 'Cohort'})

# 23 evenly spaced bin edges between 5 and 115.
hist_bins = np.linspace(5, 115, 23)

fig, ax = plt.subplots(figsize=(2.3, 3))
histplot = sns.histplot(
    data=df_fig,
    x="Age",
    hue='Cohort',
    hue_order=['Yakutia', 'Central'],
    palette=colors_region,
    bins=hist_bins,
    edgecolor='k',
    linewidth=1,
    ax=ax,
)
histplot.set_xlim(0, 120)

# Put the legend above the axes, in a single frameless row.
sns.move_legend(histplot, "lower center", bbox_to_anchor=(.5, 1), ncol=2, frameon=False)

fig.savefig(f"{os.getcwd()}/data/histplot.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{os.getcwd()}/data/histplot.pdf", bbox_inches='tight')
plt.close(fig)

## Plot pie chart for Region

In [None]:
# Pie chart with the percentage of samples in each cohort, written to
# data/pie.{png,pdf}.
n_total = data.shape[0]
region_parts = [
    data[data['Region'] == region].shape[0] / n_total * 100
    for region in ('Central', 'Yakutia')
]
explode = [0.05, 0.05]  # offset applied to both wedges

fig, ax = plt.subplots(figsize=(2.5, 2.5))
ax.pie(
    region_parts,
    explode=explode,
    labels=['Central', 'Yakutia'],
    colors=[colors_region['Central'], colors_region['Yakutia']],
    autopct='%.2f%%',
    wedgeprops={"edgecolor": "black", 'linewidth': 1},
)
fig.savefig(f"{os.getcwd()}/data/pie.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{os.getcwd()}/data/pie.pdf", bbox_inches='tight')
plt.close(fig)

## Calculate statistics

In [None]:
# Per-feature descriptive statistics for both cohorts plus a two-sided
# Mann-Whitney U test between them; p-values are then adjusted with three
# multiple-testing procedures and the table is saved to data/statistics.xlsx.
df_stat = pd.DataFrame(index=list(feats))
for feat in feats:
    samples = {
        cohort: data.loc[data['Region'] == cohort, feat].values
        for cohort in ['Central', 'Yakutia']
    }
    for cohort, x in samples.items():
        q75, q25 = np.percentile(x, [75, 25])
        # Columns are created in this order, which is also the order written
        # to the spreadsheet.
        df_stat.at[feat, f"mean_{cohort}"] = np.mean(x)
        df_stat.at[feat, f"median_{cohort}"] = np.median(x)
        df_stat.at[feat, f"q75_{cohort}"] = q75
        df_stat.at[feat, f"q25_{cohort}"] = q25
        df_stat.at[feat, f"iqr_{cohort}"] = q75 - q25
    df_stat.at[feat, "mw_pval"] = mannwhitneyu(
        samples['Central'],
        samples['Yakutia'],
        alternative='two-sided',
    ).pvalue

# multipletests returns (reject, corrected p-values, ...); only the corrected
# p-values (element 1) are stored, one column per correction method.
for method in ('fdr_bh', 'bonferroni', 'simes-hochberg'):
    corrected = multipletests(df_stat.loc[feats, "mw_pval"].values, 0.05, method=method)[1]
    df_stat.loc[feats, f"mw_pval_{method}"] = corrected

df_stat.sort_values(["mw_pval_fdr_bh"], ascending=[True], inplace=True)
df_stat.to_excel(f"{os.getcwd()}/data/statistics.xlsx", index_label='Features')

## Plot feature p-values

In [None]:
# Horizontal bar chart of -log10(FDR-BH corrected Mann-Whitney p-value), one
# bar per feature, smallest p-value on top. Bars are red where the corrected
# p-value is below 0.05 and pink otherwise. Written to data/barplot.{png,pdf}.
from matplotlib.ticker import MaxNLocator

# Explicit copy: the columns added below must not leak into df_stat, and
# mutating a .loc selection triggers pandas' SettingWithCopyWarning.
df_fig = df_stat.loc[feats, :].copy()
df_fig.sort_values("mw_pval_fdr_bh", ascending=True, inplace=True)

# Clip exact zeros (p-value underflow) to the smallest positive normal float
# so that -log10 stays finite and the bar can be drawn.
pvals_for_log = df_fig['mw_pval_fdr_bh'].clip(lower=np.finfo(float).tiny)
df_fig['mw_pval_fdr_bh_log'] = -np.log10(pvals_for_log)

df_fig['color'] = 'pink'
df_fig.loc[df_fig['mw_pval_fdr_bh'] < 0.05, 'color'] = 'red'

sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(3, 12))
barplot = sns.barplot(
    data=df_fig,
    y=df_fig.index,
    x='mw_pval_fdr_bh_log',
    edgecolor='black',
    palette=df_fig['color'].values,
    ax=ax,
)
ax.set_xlabel(r"$-\log_{10}(\mathrm{p-value})$", fontsize=18)
ax.set_ylabel('', fontsize=20)

# Place ticks at integer positions instead of relabelling automatic ticks with
# int(tick): truncation mislabels fractional ticks (e.g. 2.5 shown as "2") and
# setting labels without fixed tick positions makes matplotlib warn.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.tick_params(axis='both', labelsize=16)

fig.savefig(f"{os.getcwd()}/data/barplot.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{os.getcwd()}/data/barplot.pdf", bbox_inches='tight')
plt.close(fig)

## Plot feature distributions (ordered by SHAP values)

In [None]:
# Grid of violin plots, one panel per feature, comparing the two cohorts.
# Panels are ordered by decreasing mean absolute SHAP value and titled with the
# feature's FDR-BH corrected Mann-Whitney p-value from df_stat.
# Written to data/feats.{png,pdf}.
df_shap = pd.read_excel(f"{os.getcwd()}/data/classificator/shap_values.xlsx", index_col=0)
df_shap_abs = df_shap.abs()
df_feat_imp = pd.DataFrame(index=df_shap_abs.columns, data=df_shap_abs.mean(), columns=['mean_abs_shap'])
df_feat_imp.sort_values("mean_abs_shap", ascending=False, inplace=True)
feats_sorted = df_feat_imp.index.values

# Derive the number of rows from the number of features instead of assuming a
# fixed 4x8 grid: more than 32 features used to raise IndexError, fewer left
# empty panels. 32 features still give the original 4 rows and 15x12 figure.
n_cols = 8
n_rows = max(1, int(np.ceil(len(feats_sorted) / n_cols)))
fig_width = 15
fig_height = 3 * n_rows

sns.set_theme(style='whitegrid')
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), squeeze=False)

for f_id, f in enumerate(feats_sorted):
    row_id, col_id = divmod(f_id, n_cols)
    ax_feat = axs[row_id, col_id]
    # Hide outliers: keep only values strictly between the 1st and 99th
    # percentiles of the feature.
    q01 = data[f].quantile(0.01)
    q99 = data[f].quantile(0.99)
    sns.violinplot(
        data=data.loc[(data[f] > q01) & (data[f] < q99), :],
        x='Region',
        y=f,
        palette=colors_region,
        scale='width',
        order=list(colors_region.keys()),
        saturation=0.75,
        cut=0,
        linewidth=1.0,
        ax=ax_feat,
    )
    ax_feat.set_ylabel(f)
    ax_feat.set_xlabel('')
    ax_feat.set(xticklabels=[])
    mw_pval = df_stat.at[f, "mw_pval_fdr_bh"]
    ax_feat.set_title(f'{mw_pval:.2e}')

# Hide panels of the last row that have no feature assigned.
for ax_unused in axs.ravel()[len(feats_sorted):]:
    ax_unused.set_visible(False)

# Shared legend built from proxy markers, since x tick labels are removed.
legend_handles = [
    mlines.Line2D([], [], marker='o', linestyle='None', markeredgecolor='k', markerfacecolor=colors_region['Central'], markersize=10, label='Central'),
    mlines.Line2D([], [], marker='o', linestyle='None', markeredgecolor='k', markerfacecolor=colors_region['Yakutia'], markersize=10, label='Yakutia')
]
fig.legend(handles=legend_handles, bbox_to_anchor=(0.5, 1.0), loc="lower center", ncol=2, frameon=False, fontsize='large')
fig.tight_layout()
fig.savefig(f"{os.getcwd()}/data/feats.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{os.getcwd()}/data/feats.pdf", bbox_inches='tight')
plt.close(fig)

## SImAge analysis

In [None]:
# Scatter plot of SImAge against chronological age, coloured by cohort, with
# the y = x diagonal as reference. Written to data/SImAge_scatter.{png,pdf}.
# `df_fig` built here is reused by the following SImAge steps.
df_fig = data[['Age', 'SImAge', 'SImAge acceleration', 'Region']].rename(
    columns={'Region': 'Cohort'}
)

sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(4.5, 4))
scatter = sns.scatterplot(
    data=df_fig,
    x="Age",
    y="SImAge",
    hue="Cohort",
    hue_order=list(colors_region.keys()),
    palette=colors_region,
    s=20,
    alpha=0.75,
    edgecolor="k",
    linewidth=0.2,
    ax=ax,
)
# Dashed diagonal from (0, 0) to (120, 120).
bisect = sns.lineplot(
    x=[0, 120],
    y=[0, 120],
    color='black',
    linestyle='--',
    linewidth=1.0,
    ax=ax,
)
ax.set_xlabel("Age")
ax.set_ylabel("SImAge")
ax.set_xlim(0, 120)
ax.set_ylim(0, 120)
ax.set_aspect('equal', adjustable='box')
fig.savefig(f"{os.getcwd()}/data/SImAge_scatter.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{os.getcwd()}/data/SImAge_scatter.pdf", bbox_inches='tight')
plt.close(fig)

# Violin plot of SImAge acceleration per cohort; the title is the p-value of a
# two-sided Mann-Whitney U test between the two cohorts.
# Written to data/SImAge_acceleration.{png,pdf}.
sns.set_theme(style='whitegrid')
plt.figure(figsize=(2.5, 4))

acceleration = {
    cohort: df_fig.loc[df_fig['Cohort'] == cohort, 'SImAge acceleration'].values
    for cohort in ('Central', 'Yakutia')
}
mw_pval = mannwhitneyu(
    acceleration['Central'],
    acceleration['Yakutia'],
    alternative='two-sided',
).pvalue

violin = sns.violinplot(
    data=df_fig,
    x='Cohort',
    y='SImAge acceleration',
    order=list(colors_region.keys()),
    palette=colors_region,
    scale='width',
    saturation=0.75,
    legend=False,
)
violin.set_xlabel("")
violin.set_title(f'{mw_pval:.2e}')

plt.savefig(f"{os.getcwd()}/data/SImAge_acceleration.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{os.getcwd()}/data/SImAge_acceleration.pdf", bbox_inches='tight')
plt.close()

# Per-cohort agreement between SImAge and chronological age: mean absolute
# error, Pearson correlation and mean SImAge acceleration.
# Saved to data/SImAge_metrics.xlsx.
cohorts = ['Central', 'Yakutia']
df_metrics = pd.DataFrame(
    index=cohorts,
    columns=['mean_absolute_error', 'pearson_corrcoef', 'mean_age_acc'],
)
for cohort in cohorts:
    df_cohort = df_fig.loc[df_fig['Cohort'] == cohort]
    # torchmetrics operates on tensors, so wrap the column arrays.
    pred = torch.from_numpy(df_cohort['SImAge'].values)
    real = torch.from_numpy(df_cohort['Age'].values)
    df_metrics.at[cohort, 'mean_absolute_error'] = mean_absolute_error(pred, real).numpy()
    df_metrics.at[cohort, 'pearson_corrcoef'] = pearson_corrcoef(pred, real).numpy()
    df_metrics.at[cohort, 'mean_age_acc'] = np.mean(df_cohort['SImAge acceleration'].values)
df_metrics.to_excel(f"{os.getcwd()}/data/SImAge_metrics.xlsx", index_label="Metrics")

# Create data for ML

## Stratification by regions (target) and age

In [None]:
# Build 5 train/test partitions separately inside each region, stratified by
# binned age, and record them in `data` as columns Split_0 .. Split_4 holding
# "trn_val" or "tst" for every sample.
random_state = 1337
n_splits = 5
num_bins = 5  # number of equal-width age bins used as stratification classes

stratify_cat_parts = {
    'Central': data.index[data['Region'] == 'Central'].values,
    'Yakutia': data.index[data['Region'] == 'Yakutia'].values,
}

for part, ids in stratify_cat_parts.items():
    print(f"{part}: {len(ids)}")
    con = data.loc[ids, 'Age'].values

    # Equal-width bins over the age range, padded by 10% on both sides so the
    # minimum and maximum fall strictly inside the outer bins.
    ptp = np.ptp(con)
    bins = np.linspace(np.min(con) - 0.1 * ptp, np.max(con) + 0.1 * ptp, num_bins + 1)
    binned = np.digitize(con, bins) - 1

    # Same seed for both regions; one repeat gives a plain stratified K-fold.
    k_fold = RepeatedStratifiedKFold(
        n_splits=n_splits,
        n_repeats=1,
        random_state=random_state
    )
    # `groups` is not passed: the stratified splitter ignores it.
    splits = k_fold.split(X=ids, y=binned)

    # The splitter yields positions within `ids`; map them back to index labels.
    for split_id, (ids_trn_val, ids_tst) in enumerate(splits):
        data.loc[ids[ids_trn_val], f"Split_{split_id}"] = "trn_val"
        data.loc[ids[ids_tst], f"Split_{split_id}"] = "tst"

## Save data for ML

In [None]:
# Save the table, including any Split_* columns added above, as data.xlsx in
# the current working directory.
# NOTE(review): the input was read from data/data.xlsx; confirm that writing
# to the working directory instead of data/ is intended.
ml_data_path = f"{os.getcwd()}/data.xlsx"
data.to_excel(ml_data_path, index_label='Index')