# Debugging autoreload

In [None]:
%load_ext autoreload
%autoreload 2

# Load packages

In [1]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from scipy.stats import mannwhitneyu
import matplotlib.lines as mlines
from statsmodels.stats.multitest import multipletests
from sklearn.model_selection import RepeatedStratifiedKFold
import statsmodels.formula.api as smf
import torch
import pathlib
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
import json
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash


# Data preparation

## Load full immunology data

In [None]:
path = f"D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN"
path_save = "D:/YandexDisk/Work/bbd/immunology/002_central_vs_yakutia"

feats = pd.read_excel(f"{path}/data/immuno/feats_con.xlsx", index_col=0).index.values
feats_fimmu = pd.read_excel(f"{path}/data/immuno/models/SImAge/feats_con_top10.xlsx", index_col=0).index.values
feats_slctd = pd.read_excel(f"{path}/special/059_imm_data_selection/feats_selected.xlsx", index_col=0).index.values

df = pd.read_excel(f"{path}/data/immuno/data.xlsx", index_col=0)
df_w_nans = pd.read_excel(f"{path}/data/immuno/data_with_nans.xlsx", index_col=0)

## Create data with NaNs

In [None]:
files = [
    "Aging L, Q, H, I",
    "Aging-Covid_05.01.2022",
    "Aging-Covid-05.05.22",
    "Covid_results_02_2021",
    "Covid-25.11.20",
    "MULTIPLEX_20_11_2020_ AGING",
    "Yakutiya + TR",
    "Мультиплекс_Agind&Covid",
    "10-March-2024/48-plex-human-_xPONENT_2024", 
]
df_imm_genes = pd.read_excel(f"{path}/data/immuno/immuno_markers_genes.xlsx")
dict_imm_genes = dict(zip(df_imm_genes['immuno_marker'], df_imm_genes['gene']))

dfs_files = []
nans_by_features = {}
for file in files:
    if file in ["10-March-2024/48-plex-human-_xPONENT_2024", "10-March-2024/plate_1_analyst_2024", "10-March-2024/plate_2_analyst_2024", "10-March-2024/plate_3_analyst_2024"]:
        df_file = pd.read_excel(f"{path}/data/immuno/files/processed/{file}.xlsx", index_col=0)
    else:
        df_file = pd.read_excel(f"{path}/data/immuno/files/processed/{file}.xlsx", index_col="Sample")
    df_file.rename(columns=dict_imm_genes, inplace=True)
    df_file = df_file.loc[:, feats]

    # duplicates processing
    if file == "MULTIPLEX_20_11_2020_ AGING":
        df_file_doubled_unique = df_file.loc[~df_file.index.duplicated(keep=False), :]
        df_file_doubled_1 = df_file.loc[df_file.index.duplicated(keep='first'), :]
        df_file_doubled_2 = df_file.loc[df_file.index.duplicated(keep='last'), :]
        df_file_duplicates_final = pd.concat([df_file_doubled_2, df_file_doubled_unique], axis=0)
        df_file = df_file_duplicates_final
    elif file == "10-March-2024/48-plex-human-_xPONENT_2024":
        df_file = df_file.loc[df_file.index.str.startswith('M', na=False), :]
    elif file in ["10-March-2024/plate_1_analyst_2024", "10-March-2024/plate_2_analyst_2024", "10-March-2024/plate_3_analyst_2024"]:
        df_file = df_file.loc[df_file.index.str.startswith('M', na=False), :]
        df_file.index += '_nlst'
    df_file_duplicates = df_file.loc[df_file.index.duplicated(keep=False), :]
    if df_file_duplicates.shape[0] > 0:
        print(df_file_duplicates.index)
        
    for feat in df_file:
        nan_vals = set(df_file.loc[df_file[feat].astype(str).str.contains(r'^([<>].*)$', regex=True), feat].values)
        if len(nan_vals) > 0:
            for nv in nan_vals:
                if feat in nans_by_features:
                    nans_by_features[feat].add(nv)
                else:
                    nans_by_features[feat] = {nv}
    
    dfs_files.append(df_file)

df_w_nans = pd.concat(dfs_files, verify_integrity=False)
df_w_nans.index = df_w_nans.index.map(str)
df_w_nans = df_w_nans.loc[df.index.values, :]
df_w_nans.replace(r'^([\<].*)$', 'NaN', inplace=True, regex=True)
for feat in feats:
    ids_imputed_above = df_w_nans.index[df_w_nans[feat].astype(str).str.contains('>')]
    df_w_nans.loc[ids_imputed_above, feat] = df.loc[ids_imputed_above, feat]
df_w_nans = df_w_nans.apply(pd.to_numeric, errors='coerce')

### Save data with NaNs

In [None]:
df_nans_by_features = pd.DataFrame(index=list(nans_by_features.keys()))
df_nans_by_features['NaN values'] = list(nans_by_features.values())
df_nans_by_features.to_excel(f"{path}/data/immuno/nans_by_features.xlsx", index=True, index_label='Feature')
df_w_nans.to_excel(f"{path}/data/immuno/data_with_nans.xlsx", index=True, index_label='Index')

### Select filtered samples

In [None]:
ids_target = pd.read_excel(f"{path}/special/059_imm_data_selection/df_imm.xlsx", index_col=0).index.values
df = df.loc[ids_target, :]
df_w_nans = df_w_nans.loc[ids_target, :]
df.to_excel(f"{path_save}/data.xlsx", index=True, index_label='Index')
df_w_nans.to_excel(f"{path_save}/data_with_nans.xlsx", index=True, index_label='Index')

# Load prepared data

In [2]:
path = f"D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN"
path_save = "D:/YandexDisk/Work/bbd/immunology/002_central_vs_yakutia"

feats = pd.read_excel(f"{path}/data/immuno/feats_con.xlsx", index_col=0).index.values
feats_fimmu = pd.read_excel(f"{path}/data/immuno/models/SImAge/feats_con_top10.xlsx", index_col=0).index.values
feats_slctd = pd.read_excel(f"{path}/special/059_imm_data_selection/feats_selected.xlsx", index_col=0).index.values

df = pd.read_excel(f"{path_save}/data.xlsx", index_col=0)
df_w_nans = pd.read_excel(f"{path_save}/data_with_nans.xlsx", index_col=0)

# Statistics of NaNs

In [3]:
df_nan_feats = df_w_nans.loc[:, feats].isna().sum(axis=0).to_frame(name="Number of NaNs")
df_nan_feats["% of NaNs"] = df_nan_feats["Number of NaNs"] / df_w_nans.shape[0] * 100
df_nan_feats["Number of not-NaNs"] = df_w_nans.loc[:, feats].notna().sum(axis=0)
df_nan_feats.sort_values(["% of NaNs"], ascending=[False], inplace=True)
df_nan_feats.to_excel(f"{path_save}/nan_feats.xlsx", index_label="Features")

sns.set_theme(style='whitegrid')
fig = plt.figure(figsize=(14, 4))
plt.xticks(rotation=90)
barplot = sns.barplot(
    data=df_nan_feats,
    x=df_nan_feats.index,
    y=f"% of NaNs",
    edgecolor='black',
    dodge=False,
)
barplot.set_xlabel("")
plt.savefig(f"{path_save}/nan_feats.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/nan_feats.pdf", bbox_inches='tight')
plt.close(fig)

thld_nan_in_feat = 25.3
feats_imm_good = set(df_nan_feats.index[df_nan_feats['% of NaNs'] <= thld_nan_in_feat].values).union(set(feats_fimmu))
print(f"Number of filtered features: {len(feats_imm_good)}")
print(f"Intersection with previous: {len(set.intersection(set(feats_imm_good), set(feats_slctd)))}")

Number of filtered features: 32
Intersection with previous: 32


# Basic region-specific analysis

In [4]:
colors_region = {'Central': 'gold', 'Yakutia': 'lightslategray'}
pathlib.Path(f"{path_save}/region_specific").mkdir(parents=True, exist_ok=True)

## Plot distribution

In [6]:
hist_bins = np.linspace(5, 115, 23)
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(4, 3))
df_fig = df.loc[:, ['Age', 'Region']].copy()
df_fig.rename(columns={'Region': 'Cohort'}, inplace=True)
histplot = sns.histplot(
    data=df_fig,
    bins=hist_bins,
    edgecolor='k',
    linewidth=1,
    x="Age",
    hue='Cohort',
    palette=colors_region,
    hue_order=['Yakutia', 'Central'],
    ax=ax
)
histplot.set(xlim=(0, 120))
plt.savefig(f"{path_save}/region_specific/histplot.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/region_specific/histplot.pdf", bbox_inches='tight')
plt.close(fig)

## Calculate statistics

In [7]:
df_stat = pd.DataFrame(index=list(feats_slctd))
for feat in list(feats_slctd):
    vals = {}
    for group in ['Central', 'Yakutia']:
        vals[group] = df.loc[df['Region'] == group, feat].values
        df_stat.at[feat, f"mean_{group}"] = np.mean(vals[group])
        df_stat.at[feat, f"median_{group}"] = np.median(vals[group])
        df_stat.at[feat, f"q75_{group}"], df_stat.at[feat, f"q25_{group}"] = np.percentile(vals[group], [75 , 25])
        df_stat.at[feat, f"iqr_{group}"] = df_stat.at[feat, f"q75_{group}"] - df_stat.at[feat, f"q25_{group}"]
    _, df_stat.at[feat, "mw_pval"] = mannwhitneyu(vals['Central'], vals['Yakutia'], alternative='two-sided')
_, df_stat.loc[feats_slctd, "mw_pval_fdr_bh"], _, _ = multipletests(df_stat.loc[feats_slctd, "mw_pval"].values, 0.05, method='fdr_bh')
_, df_stat.loc[feats_slctd, "mw_pval_bonferroni"], _, _ = multipletests(df_stat.loc[feats_slctd, "mw_pval"].values, 0.05, method='bonferroni')
_, df_stat.loc[feats_slctd, "mw_pval_simes-hochberg"], _, _ = multipletests(df_stat.loc[feats_slctd, "mw_pval"].values, 0.05, method='simes-hochberg')
df_stat.sort_values([f"mw_pval_fdr_bh"], ascending=[True], inplace=True)
df_stat.to_excel(f"{path_save}/region_specific/stat.xlsx", index_label='Features')

## Plot features p-values

In [8]:
df_fig = df_stat.loc[feats_slctd, :]
df_fig.sort_values([f"mw_pval_fdr_bh"], ascending=[True], inplace=True)
df_fig['mw_pval_fdr_bh_log'] = -np.log10(df_fig['mw_pval_fdr_bh'])
df_fig['color'] = 'pink'
df_fig.loc[df_fig['mw_pval_fdr_bh'] < 0.05, 'color'] = 'red'

fig, ax = plt.subplots(figsize=(3, 12))
sns.set_theme(style='whitegrid')
barplot = sns.barplot(
    data=df_fig,
    y=df_fig.index.values,
    x='mw_pval_fdr_bh_log',
    edgecolor='black',
    palette=df_fig['color'].values,
    hue=df_fig.index.values,
    ax=ax,
    legend=False
)
ax.set_xlabel(r"$-\log_{10}(\mathrm{p-value})$", fontsize=18)
ax.set_ylabel('', fontsize=20)
ax.set_xticklabels([f"{int(tick):d}" for tick in ax.get_xticks()], fontsize=16)
ax.set_yticklabels(ax.get_yticklabels(), fontsize = 16)
plt.savefig(f"{path_save}/region_specific/barplot.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/region_specific/barplot.pdf", bbox_inches='tight')
plt.close(fig)

  barplot = sns.barplot(
  ax.set_xticklabels([f"{int(tick):d}" for tick in ax.get_xticks()], fontsize=16)
  ax.set_yticklabels(ax.get_yticklabels(), fontsize = 16)


In [9]:
n_rows = 4
n_cols = 8
fig_width = 15
fig_height = 12

fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={})
sns.set_theme(style='whitegrid')

feats_sorted = df_stat.index[df_stat.index.isin(feats_slctd)].values

for f_id, f in enumerate(feats_sorted):
    row_id, col_id = divmod(f_id, n_cols)
    
    q01 = df[f].quantile(0.01)
    q99 = df[f].quantile(0.99)
    
    sns.violinplot(
        data=df.loc[(df[f] > q01) & (df[f] < q99), :],
        x='Region',
        y=f,
        hue='Region',
        palette=colors_region,
        density_norm='width',
        order=list(colors_region.keys()),
        saturation=0.75,
        cut=0,
        linewidth=1.0,
        ax=axs[row_id, col_id],
        legend=False,
    )
    axs[row_id, col_id].set_ylabel(f)
    axs[row_id, col_id].set_xlabel('')
    axs[row_id, col_id].set(xticklabels=[]) 
    mw_pval = df_stat.at[f, "mw_pval_fdr_bh"]
    axs[row_id, col_id].set_title(f'{mw_pval:.2e}')
    
legend_handles = [
    mlines.Line2D([], [], marker='o', linestyle='None', markeredgecolor='k', markerfacecolor=colors_region['Central'], markersize=10, label='Central'),
    mlines.Line2D([], [], marker='o', linestyle='None', markeredgecolor='k', markerfacecolor=colors_region['Yakutia'], markersize=10, label='Yakutia')
]
fig.legend(handles=legend_handles, bbox_to_anchor=(0.5, 1.0), loc="lower center", ncol=2, frameon=False, fontsize='large')
fig.tight_layout()    
plt.savefig(f"{path_save}/region_specific/feats.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/region_specific/feats.pdf", bbox_inches='tight')
plt.close(fig)

## SImAge analysis

In [12]:
df_fig = df.loc[:, ['Age', 'SImAge', 'SImAge acceleration', 'Region']].copy()
df_fig.rename(columns={'Region': 'Cohort'}, inplace=True)
fig, ax = plt.subplots(figsize=(4.5, 4))
sns.set_theme(style='whitegrid')
scatter = sns.scatterplot(
    data=df_fig,
    x="Age",
    y="SImAge",
    hue="Cohort",
    palette=colors_region,
    linewidth=0.2,
    alpha=0.75,
    edgecolor="k",
    s=20,
    hue_order=list(colors_region.keys()),
    ax=ax
)
bisect = sns.lineplot(
    x=[0, 120],
    y=[0, 120],
    linestyle='--',
    color='black',
    linewidth=1.0,
    ax=ax
)
ax.set_xlabel("Age")
ax.set_ylabel("SImAge")
ax.set_xlim(0, 120)
ax.set_ylim(0, 120)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(f"{path_save}/region_specific/SImAge_scatter.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/region_specific/SImAge_scatter.pdf", bbox_inches='tight')
plt.close()

plt.figure(figsize=(2.5, 4))
sns.set_theme(style='whitegrid')
violin = sns.violinplot(
    data=df_fig,
    x='Cohort',
    y='SImAge acceleration',
    palette=colors_region,
    hue='Cohort',
    density_norm='width',
    order=list(colors_region.keys()),
    saturation=0.75,
    legend=False
)
violin.set_xlabel(f"")
mw_pval = mannwhitneyu(
    df_fig.loc[df_fig['Cohort'] == 'Central', 'SImAge acceleration'].values,
    df_fig.loc[df_fig['Cohort'] == 'Yakutia', 'SImAge acceleration'].values,
    alternative='two-sided').pvalue
violin.set_title(f'{mw_pval:.2e}')
plt.savefig(f"{path_save}/region_specific/SImAge_acceleration.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/region_specific/SImAge_acceleration.pdf", bbox_inches='tight')
plt.close()

df_metrics = pd.DataFrame(
    index=['Central', 'Yakutia'],
    columns=['mean_absolute_error', 'pearson_corrcoef', 'mean_age_acc']
)
for cohort in ['Central', 'Yakutia']:
    pred = torch.from_numpy(df_fig.loc[df_fig['Cohort'] == cohort, 'SImAge'].values)
    real = torch.from_numpy(df_fig.loc[df_fig['Cohort'] == cohort, 'Age'].values)
    df_metrics.at[cohort, 'mean_absolute_error'] = mean_absolute_error(pred, real).numpy()
    df_metrics.at[cohort, 'pearson_corrcoef'] = pearson_corrcoef(pred, real).numpy()
    df_metrics.at[cohort, 'mean_age_acc'] = np.mean(df_fig.loc[df_fig['Cohort'] == cohort, 'SImAge acceleration'].values)
df_metrics.to_excel(f"{path_save}/region_specific/SImAge_metrics.xlsx", index_label="Metrics")

# Create data for ML

In [None]:
df_stat.loc[:, ['mw_pval']].to_excel(f"{path_save}/classification/feats.xlsx", index_label='Features')
cols = ['Subject ID', 'Time', 'Sex', 'Age', 'Region', 'SImAge', 'SImAge acceleration', '|SImAge acceleration|', 'PMC10485620 ID', 'PMC9135940 ID', 'PMC10699032 ID']
df_ml = df.loc[:, cols + df_stat.index.values.tolist()]

## Stratification by regions (target) and age

In [None]:
random_state = 1337
n_splits = 5

stratify_cat_parts = {
    'Central': df_ml.index[df_ml['Region'] == 'Central'].values,
    'Yakutia': df_ml.index[df_ml['Region'] == 'Yakutia'].values,
}

for part, ids in stratify_cat_parts.items():
    print(f"{part}: {len(ids)}")
    con = df_ml.loc[ids, 'Age'].values
    ptp = np.ptp(con)
    num_bins = 5
    bins = np.linspace(np.min(con) - 0.1 * ptp, np.max(con) + 0.1 * ptp, num_bins + 1)
    binned = np.digitize(con, bins) - 1
    unique, counts = np.unique(binned, return_counts=True)
    occ = dict(zip(unique, counts))
    k_fold = RepeatedStratifiedKFold(
        n_splits=n_splits,
        n_repeats=1,
        random_state=random_state
    )
    splits = k_fold.split(X=ids, y=binned, groups=binned)
    
    for split_id, (ids_trn_val, ids_tst) in enumerate(splits):
        df_ml.loc[ids[ids_trn_val], f"Split_{split_id}"] = "trn_val"
        df_ml.loc[ids[ids_tst], f"Split_{split_id}"] = "tst"

## Save data for ML

In [None]:
df_ml.to_excel(f"{path_save}/classification/data.xlsx", index_label='Index')