# Enable autoreload (for debugging)

In [None]:
# Load IPython's autoreload extension and set it to mode 2,
# presumably so edits to imported packages (e.g. pyaging) are picked up
# without restarting the kernel while debugging.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
import torch
import pickle
import numpy as np
from pytorch_tabular import TabularModel
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
import pandas as pd
import warnings
import pathlib
import os
from tqdm import tqdm
from sklearn.impute import KNNImputer
import pyaging as pya
import matplotlib.pyplot as plt
import seaborn as sns
import distinctipy
import matplotlib.colors as mcolors
import matplotlib.patheffects as pe
from plottable import ColumnDefinition, Table
from plottable.cmap import normed_cmap


import warnings
# Ignore a fixed set of noisy warning messages (matched by regex on the message);
# presumably emitted while loading/running the pytorch_tabular models.
for _ignored_msg in (
    ".*does not have many workers.*",
    ".*exists and is not empty.*",
    ".*is smaller than the logging interval Trainer.*",
):
    warnings.filterwarnings("ignore", _ignored_msg)


# Setup variables and paths

In [None]:
# Immunomarkers to reconstruct (index of the selected-features table)
feats_imm = pd.read_excel("D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/059_imm_data_selection/feats_selected.xlsx", index_col=0).index.values

# epi_data_type: sub-folder of the GSEUNN epigenetic data to load (see the data-loading cell)
epi_data_type = 'no_harm'
# imm_data_type options:
#   'origin' | 'imp_source(imm)_method(knn)_params(5)' | 'imp_source(imm)_method(miceforest)_params(2)'
imm_data_type = 'imp_source(imm)_method(knn)_params(5)'

# Feature-selection run to use: method ('f_regression' | 'spearman' | 'mrmr') and
# n_feats (presumably the number of selected CpGs); both only select the results folder.
selection_method = 'mrmr'
n_feats = 100

# Per-immunomarker models/results live under path_imm; outputs of this notebook go to path_save
path_imm = f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{imm_data_type}/{epi_data_type}/{selection_method}_{n_feats}"
path_save = f"{path_imm}/EpImAge"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Best model per immunomarker: 'model'/'directory' locate the checkpoint, plus stored metrics
df_models = pd.read_excel(f"{path_imm}/best_models_v4.xlsx", index_col=0)

# Root folder with one sub-folder per GPL, each containing one sub-folder per GSE
path_epi = "D:/YandexDisk/Work/bbd/immunology/003_EpImAge/epi"
feats_pheno = ['Age', 'Sex', 'Status', 'Tissue']
# pyaging clock directory and the clock names passed to pya.pred.predict_age
path_clocks = "D:/YandexDisk/Work/pydnameth/datasets/pyaging"
clocks = [
    "altumage",
    "dunedinpace",
    "han",
    "knight",
    "leecontrol",
    "leerefinedrobust",
    "leerobust",
    "dnamfitage",
    "dnamphenoage",
    "dnamtl",
    "encen100",
    "encen40",
    "grimage",
    "grimage2",
    "hannum",
    "horvath2013",
    "hrsinchphenoage",
    "lin",
    "pcdnamtl",
    "pcgrimage",
    "pchannum",
    "pchorvath2013",
    "pcphenoage",
    "pcskinandblood",
    "pedbe",
    "replitali",
    "skinandblood",
    "stemtoc",
    "stoch",
    "stocp",
    "stocz",
    "yingadaptage",
    "yingcausage",
    "yingdamage",
    "zhangblup",
    "zhangen",
    "zhangmortality",
]

# Create data

## Load immunomarkers models

In [None]:
# For every immunomarker, load the CpGs its model consumes and the selected
# pytorch_tabular model (model type and run directory come from df_models).
imm_epi_feats = {}
imm_models = {}
for imm in (pbar := tqdm(feats_imm)):
    pbar.set_description(f"Processing {imm}")
    imm_epi_feats[imm] = pd.read_excel(f"{path_imm}/{imm}/feats_con.xlsx", index_col=0).index.values.tolist()
    # load_model is given the run directory that holds model.ckpt, not the checkpoint file
    imm_dir_model = f"{path_imm}/{imm}/pytorch_tabular/candidates/{df_models.at[imm, 'model']}/{df_models.at[imm, 'directory']}"
    imm_models[imm] = TabularModel.load_model(imm_dir_model)

# Union of the CpGs needed by all models. Sorted so the column order (and every table
# saved downstream) is reproducible: iteration order of a set of str depends on the hash seed.
# set().union(...) also works when there are no models, unlike set.union(*[]).
feats_epi_cmn = sorted(set().union(*imm_epi_feats.values()))
print(f"Number of CpGs: {len(feats_epi_cmn)}")

## Load epigenetic data and calculate clocks

In [None]:
def _load_gse(gpl, gse):
    """Return ``(df_betas, df_pheno)`` for the GSE folder ``{path_epi}/{gpl}/{gse}``.

    By default ``betas.pkl`` and ``pheno.csv`` (indexed by ``gsm``) sit directly in the
    GSE folder; the named branches handle datasets whose layout, index column or
    excluded samples differ.
    """
    gse_dir = f"{path_epi}/{gpl}/{gse}"
    if gse == 'GSEUNN':
        # Files are stored in a sub-folder per epigenetic data type
        df_betas = pd.read_pickle(f"{gse_dir}/{epi_data_type}/betas.pkl")
        df_pheno = pd.read_csv(f"{gse_dir}/{epi_data_type}/pheno.csv", index_col='index')
    elif gse == 'GSE53740':
        df_betas = pd.read_pickle(f"{gse_dir}/betas.pkl")
        df_pheno = pd.read_csv(f"{gse_dir}/pheno.csv", index_col=0)
        df_pheno = df_pheno.drop(df_pheno.index[df_pheno['Status'] == 'Unknown'])
    elif gse == 'GSE87648':
        df_betas = pd.read_pickle(f"{gse_dir}/betas.pkl")
        df_pheno = pd.read_csv(f"{gse_dir}/pheno.csv", index_col=0)
        df_pheno = df_pheno.drop(df_pheno.index[df_pheno['Status'] == 'HS'])
    else:
        df_betas = pd.read_pickle(f"{gse_dir}/betas.pkl")
        df_pheno = pd.read_csv(f"{gse_dir}/pheno.csv", index_col='gsm')
    return df_betas, df_pheno


def _report_sexes(gse, sexes):
    """Print how many distinct sexes ``gse`` contains.

    Raises:
        ValueError: if more than two distinct values are present.
    """
    n_sexes = sexes.value_counts().size
    if n_sexes > 2:
        raise ValueError(f"{gse}: more than 2 sexes ({n_sexes} found)")
    if n_sexes == 2:
        print(f"{gse} contains 2 sexes")
    elif n_sexes == 1:
        print(f"{gse} contains only one sex")
    else:
        # value_counts() ignores NaN, so an all-missing Sex column ends up here
        print(f"{gse} contains no sex annotation")


feats_epi_cmn_set = set(feats_epi_cmn)
# os.scandir yields entries in arbitrary order; sort so processing order and the
# row order of the concatenated table are reproducible.
gpls = sorted(f.name for f in os.scandir(path_epi) if f.is_dir())
gse_missed_cpgs = {}
dfs_gses = []
for gpl in gpls:
    print(gpl)
    gses = sorted(f.name for f in os.scandir(f"{path_epi}/{gpl}") if f.is_dir())
    for gse in (pbar := tqdm(gses)):
        pbar.set_description(f"Processing {gse}")
        df_betas, df_pheno = _load_gse(gpl, gse)

        # Clock input: phenotype columns joined with all available CpGs
        df_for_ages = pd.merge(df_pheno.loc[:, feats_pheno], df_betas, left_index=True, right_index=True)
        _report_sexes(gse, df_for_ages['Sex'])
        df_for_ages['Female'] = (df_for_ages['Sex'] == 'F').astype(int)
        df_for_ages = pya.pp.epicv2_probe_aggregation(df_for_ages, verbose=False)
        adata = pya.pp.df_to_adata(df_for_ages, metadata_cols=['Sex', 'Status', 'Tissue'], imputer_strategy='knn', verbose=False)
        pya.pred.predict_age(adata=adata, dir=path_clocks, clock_names=clocks, verbose=False)
        df_pheno = pd.merge(df_pheno.loc[:, feats_pheno], adata.obs[clocks], left_index=True, right_index=True)

        # Model CpGs available in this GSE, kept in feats_epi_cmn order so columns are deterministic
        betas_cols = set(df_betas.columns)
        gse_missed_cpgs[gse] = len(feats_epi_cmn_set - betas_cols)
        exist_cpgs = [cpg for cpg in feats_epi_cmn if cpg in betas_cols]
        df_gse = pd.merge(df_pheno, df_betas.loc[:, exist_cpgs], left_index=True, right_index=True)
        if df_gse.shape[0] == 0:
            raise ValueError(f"{gse} indexes problem!")
        df_gse.insert(0, 'GPL', gpl)
        df_gse.insert(0, 'GSE', gse)
        dfs_gses.append(df_gse)

df_gse_missed_cpgs = pd.DataFrame.from_dict(gse_missed_cpgs, orient='index', columns=['Missed CpGs'])
df_gse_missed_cpgs.to_excel(f"{path_save}/gse_missed_cpgs.xlsx", index=True, index_label='GSE')

if not dfs_gses:
    raise ValueError(f"No GSE datasets found under {path_epi}")
# verify_integrity: a sample id appearing in two GSEs is an error
df = pd.concat(dfs_gses, verify_integrity=True)

## Impute missing values

In [None]:
# KNN-impute missing values over the model CpGs together with Age: Age is part of
# the matrix, so it takes part in the neighbour search and is itself imputed if missing.
n_neighbors = 5
cols_imptd = feats_epi_cmn + ['Age']
X = df[cols_imptd].to_numpy()
print(f'Missing before imputation: {np.count_nonzero(np.isnan(X))}')
imputer = KNNImputer(n_neighbors=n_neighbors)
X_imptd = imputer.fit_transform(X)
print(f'Missing after imputation: {np.count_nonzero(np.isnan(X_imptd))}')

In [None]:
# Write the imputed matrix back into df. X_imptd columns follow feats_epi_cmn + ['Age'],
# the same selection used to build X in the imputation cell.
# NOTE(review): assumes KNNImputer kept every column (it drops all-NaN columns by
# default), otherwise the shapes differ and this assignment fails — confirm.
df.loc[:, feats_epi_cmn + ['Age']] = X_imptd

## Calculate immunomarkers

In [None]:
# Predict each immunomarker on log scale from its own CpGs, then store both the
# log prediction ("<imm>_log") and its exponential ("<imm>").
for imm in (pbar := tqdm(feats_imm)):
    pbar.set_description(f"Processing {imm}")
    col_log = f"{imm}_log"
    model_input = df.loc[:, imm_epi_feats[imm]]
    df[col_log] = imm_models[imm].predict(model_input)
    df[imm] = np.exp(df[col_log])

## Save data

In [None]:
# Compact Excel export: dataset ids, phenotypes, clock estimates and immunomarkers
# (exponentiated and log scale).
cols_export = ['GPL', 'GSE'] + feats_pheno + clocks + list(feats_imm) + [f"{imm}_log" for imm in feats_imm]
df.loc[:, cols_export].to_excel(f"{path_save}/data.xlsx")
# Full table, including the CpG columns
df.to_pickle(f"{path_save}/data_full.pkl")

## Check models on GSEUNN

In [None]:
# Cross-validation scheme the immunomarker models were trained with; it names the
# pickled sample partitions and selects the split/fold the stored metrics refer to.
tst_n_splits = 5
tst_n_repeats = 5
tst_random_state = 1337
tst_split_id = 5

val_n_splits = 4
val_n_repeats = 2
val_random_state = 1337
val_fold_id = 5

fn_samples = f"samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats})"
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open(f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{fn_samples}.pickle", 'rb') as handle:
    samples = pickle.load(handle)

# Sanity check: train / validation / test sample sets must be disjoint for every split and fold.
for split_id in range(tst_n_splits * tst_n_repeats):
    test_samples = set(samples[split_id]['test'])
    for fold_id in range(val_n_splits * val_n_repeats):
        train_samples = set(samples[split_id]['trains'][fold_id])
        validation_samples = set(samples[split_id]['validations'][fold_id])
        intxns = {
            'train_validation': train_samples & validation_samples,
            'validation_test': validation_samples & test_samples,
            'train_test': train_samples & test_samples,
        }
        for intxn_name, intxn_samples in intxns.items():
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# Recompute each model's metrics from the predictions in df and compare them with the
# metrics stored in df_models ("before" = stored, "after" = recomputed here).
split_dict = samples[tst_split_id]
check_parts = {
    'train': split_dict['trains'][val_fold_id],
    'validation': split_dict['validations'][val_fold_id],
    'test': split_dict['test'],
}
# column tag -> (df_models column suffix, metric function called as fn(preds, target))
check_metrics = {
    'mae': ('mean_absolute_error', mean_absolute_error),
    'rho': ('pearson_corrcoef', pearson_corrcoef),
}

df_models_check = pd.DataFrame(index=feats_imm)
for imm in (pbar := tqdm(feats_imm)):
    pbar.set_description(f"Processing {imm}")
    data_imm = pd.read_excel(f"{path_imm}/{imm}/data.xlsx", index_col=0)
    y_real = {part: torch.from_numpy(data_imm.loc[ids, f"{imm}_log"].values) for part, ids in check_parts.items()}
    y_pred = {part: torch.from_numpy(df.loc[ids, f"{imm}_log"].values) for part, ids in check_parts.items()}
    # Column order per metric: all "<part>_<tag>_before", then all "<part>_<tag>_after"
    for tag, (metric_name, metric_fn) in check_metrics.items():
        for part in check_parts:
            df_models_check.at[imm, f"{part}_{tag}_before"] = df_models.at[imm, f"{part}_{metric_name}"]
        for part in check_parts:
            # .item() stores a plain float rather than a 0-d numpy array
            df_models_check.at[imm, f"{part}_{tag}_after"] = metric_fn(y_pred[part], y_real[part]).item()

for tag in check_metrics:
    for part in check_parts:
        df_models_check[f"{part}_{tag}_diff"] = df_models_check[f"{part}_{tag}_after"] - df_models_check[f"{part}_{tag}_before"]

df_models_check.to_excel(f"{path_save}/models_check.xlsx")

# Plot immunomarker results

## Load data

In [None]:
# Reload the exported table and order models by test Pearson correlation (best first);
# the plotting cell iterates df_models in this order.
data = pd.read_excel(f"{path_save}/data.xlsx", index_col=0)
df_models.sort_values(['test_pearson_corrcoef'], ascending=[False], inplace=True)

# Per-immunomarker results table of the selected model run (the plot uses its target
# column, 'Prediction' and 'Group').
imm_results = {}
# iterrows() is a generator, so tqdm needs an explicit total to show progress
for imm, row in (pbar := tqdm(df_models.iterrows(), total=len(df_models))):
    pbar.set_description(f"Processing {imm}")
    imm_result = pd.read_excel(f"{path_imm}/{imm}/pytorch_tabular/candidates/{row['model']}/{row['directory']}/df.xlsx", index_col=0)
    # Expose the "<imm>_log" target column under the bare immunomarker name
    imm_results[imm] = imm_result.rename(columns={f"{imm}_log": imm})

In [None]:
# Figure layout: immunomarkers are placed n_cols per band; each band has three rows
# (metrics table, KDE/scatter panel, spacer) with height ratios 0.2 / 0.8 / 0.2.
# The number of bands follows the number of models, so more than 32 markers no longer
# overflow the grid (4 bands, as before, for 25-32 markers).
n_cols = 8
n_imms = df_models.shape[0]
n_bands = max(1, -(-n_imms // n_cols))  # ceil(n_imms / n_cols), at least one band
n_rows = 3 * n_bands
fig_height = 5 * n_bands
fig_width = 35

imm_colors = distinctipy.get_colors(n_colors=n_imms, exclude_colors=[mcolors.hex2color(mcolors.CSS4_COLORS['gray'])], rng=42)

# Metrics table: row label -> df_models column suffix; column label -> df_models column prefix
rho_label = r"Pearson $\mathbf{\rho}$"
table_metrics = {'MAE': 'mean_absolute_error', rho_label: 'pearson_corrcoef'}
table_parts = {'Train': 'train', 'Validation': 'validation', 'Test': 'test'}

sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), height_ratios=[0.2, 0.8, 0.2] * n_bands, gridspec_kw={'wspace': 0.35, 'hspace': 0.05}, sharey=False, sharex=False)

for imm_id, imm in tqdm(enumerate(df_models.index.values), total=n_imms):
    imm_color = imm_colors[imm_id]
    imm_result = imm_results[imm]
    band_id, col_id = divmod(imm_id, n_cols)
    row_id_table = band_id * 3
    row_id_scatter = row_id_table + 1
    row_id_empty = row_id_table + 2
    ax_scatter = axs[row_id_scatter, col_id]

    # Both axes are limited to the 1%-99% quantiles of the target values
    q01 = imm_result[imm].quantile(0.01)
    q99 = imm_result[imm].quantile(0.99)

    df_metrics = pd.DataFrame(index=list(table_metrics), columns=list(table_parts))
    for metric_label, metric_name in table_metrics.items():
        for part_label, part in table_parts.items():
            value = df_models.at[imm, f"{part}_{metric_name}"]
            df_metrics.at[metric_label, part_label] = f"{value:0.3f}"

    col_defs = [
        ColumnDefinition(
            name="index",
            title=imm,
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left"
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=1.5,
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
        )
    ]

    Table(
        df_metrics,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[row_id_table, col_id],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=['Train', 'Validation', 'Test'])

    # Non-test samples as a gray density, test samples as colored points
    sns.kdeplot(
        data=imm_result.loc[imm_result['Group'] != 'Test', :],
        x=imm,
        y='Prediction',
        fill=True,
        cbar=False,
        color='gray',
        cut=0,
        legend=False,
        ax=ax_scatter
    )
    sns.scatterplot(
        data=imm_result.loc[imm_result['Group'] == 'Test', :],
        x=imm,
        y="Prediction",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=35,
        color=imm_color,
        ax=ax_scatter,
    )
    # Identity line: perfect prediction
    ax_scatter.axline((0, 0), slope=1, color="black", linestyle=":")
    ax_scatter.set_xlim(q01, q99)
    ax_scatter.set_ylim(q01, q99)
    ax_scatter.set_xlabel(imm, color=imm_color, path_effects=[pe.withStroke(linewidth=1.0, foreground="black")])

    axs[row_id_empty, col_id].axis('off')

# Hide every panel of grid slots that received no immunomarker
for slot_id in range(n_imms, n_bands * n_cols):
    band_id, col_id = divmod(slot_id, n_cols)
    for row_id in range(3 * band_id, 3 * band_id + 3):
        axs[row_id, col_id].axis('off')

fig.tight_layout()
fig.savefig(f"{path_save}/immuno.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/immuno.pdf", bbox_inches='tight')
plt.close(fig)

# Generate data for additional statistics

## GSE statistics

In [None]:
# Number of rows (samples) contributed by each GSE to the exported table
data = pd.read_excel(f"{path_save}/data.xlsx", index_col=0)
gse_count = data.loc[:, 'GSE'].value_counts()
fn_gse_count = f"{path_save}/gse_preproc.xlsx"
gse_count.to_excel(fn_gse_count, index_label='GSE')