In [None]:
import pandas as pd
import numpy as np
from scipy import stats
import seaborn as sns
import pickle
import plotly.express as px
import statsmodels.formula.api as smf
import plotly.graph_objects as go
from scripts.python.routines.manifest import get_manifest
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode
init_notebook_mode(connected=False)
from scipy.stats import mannwhitneyu, median_test
import matplotlib.pyplot as plt
import pathlib
from tqdm import tqdm
from src.utils.plot.bioinfokit import mhat, volcano
import gseapy as gp
import mygene
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA, IncrementalPCA, KernelPCA, TruncatedSVD
from sklearn.decomposition import MiniBatchDictionaryLearning, FastICA
from sklearn.random_projection import GaussianRandomProjection, SparseRandomProjection
from sklearn.manifold import MDS, Isomap, TSNE, LocallyLinearEmbedding
import upsetplot as upset
import missingno as msno
from pyod.models.lunar import LUNAR
from plotly.subplots import make_subplots
from matplotlib_venn import venn2, venn2_circles
from glob import glob
from hydra import compose, initialize
from omegaconf import OmegaConf
import omegaconf
import os
import ast
from scripts.python.pheno.datasets.filter import filter_pheno, get_passed_fields
from scripts.python.pheno.datasets.features import get_column_name, get_status_dict, get_status_dict_default, get_sex_dict
from scripts.python.routines.betas import betas_drop_na
from scipy import stats
from src.tasks.metrics import get_reg_metrics
import torch

# 0. Prepare data

In [None]:
# --- Locations of the input data and of everything this notebook writes ---
path = "D:/YandexDisk/Work/pydnameth/datasets"
path_dataset = f"{path}/GPL21145/GSEUNN"
path_save = f"{path_dataset}/special/042_agena"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# --- Platform and manifest of the main dataset ---
dataset = "GSEUNN"
datasets_info = pd.read_excel(f"{path}/datasets.xlsx", index_col='dataset')
platform = datasets_info.loc[dataset, 'platform']
manifest = get_manifest(platform, path=path)
# 'CHR' is the 'chr' column with its first three characters removed.
manifest['CHR'] = manifest['chr'].str[3:]

dnam_suffix = "_harm"

# --- Array (850K) data: phenotype table joined with beta values on the sample index ---
pheno = pd.read_excel(f"{path}/{platform}/{dataset}/pheno.xlsx", index_col="index")
pheno.index.name = "index"
pheno = pheno.drop(index=["I64_old", "I1_duplicate"])
betas = pd.read_pickle(f"{path}/{platform}/{dataset}/betas{dnam_suffix}.pkl")
feats_dnam = betas.columns.values
df_dnam_all = pheno.merge(betas, left_index=True, right_index=True)

# Subset used for the comparison: Central region, Control status, Sample_Chronology == 0.
is_central = df_dnam_all["Region"] == "Central"
is_control = df_dnam_all["Status"] == "Control"
is_chronology_zero = df_dnam_all["Sample_Chronology"] == 0
df_dnam = df_dnam_all.loc[is_central & is_control & is_chronology_zero, :]

# --- Agena data ---
path_agena = f"{path_dataset}/data/agena"
df_agena = pd.read_excel(f"{path_agena}/source(данные_для_обработки)_date(140123).xlsx", index_col="index")
feats_agena = pd.read_excel(f"{path_agena}/feats.xlsx")['features'].values
# CpGs available in both the array betas and the Agena feature list (order is arbitrary: built from a set).
feats_common = list(set(feats_dnam).intersection(set(feats_agena)))
# Values are rescaled by 0.01 (presumably percent -> fraction to match beta values; confirm against the source file).
df_agena = df_agena.loc[:, feats_common] * 0.01
# Samples whose ID starts with 'H' are labelled ESRD, all others Control.
df_agena['Status'] = "Control"
df_agena.loc[df_agena.index.str.startswith('H'), 'Status'] = "ESRD"
df_agena_esrd = df_agena.loc[df_agena['Status'] == 'ESRD', :]
df_agena = df_agena.loc[df_agena['Status'] == 'Control', :]

# Keep only the Agena controls that also have array data.
agena_ids = set(df_agena.index.values)
dnam_ids = set(df_dnam.index.values)
index_common = sorted(agena_ids & dnam_ids)
index_agena_only = agena_ids - dnam_ids
df_agena = df_agena.drop(index=index_agena_only)

# 1. Samples and relative difference

In [None]:
pathlib.Path(f"{path_save}/samples").mkdir(parents=True, exist_ok=True)

# Relative difference (Agena vs 850K, in percent) for every common sample and CpG.
# Pre-allocated so that every common CpG has a column even if it was never measured,
# and so the frame is not grown one column at a time inside the loop.
rel_diff_df = pd.DataFrame(index=index_common, columns=sorted(feats_common), dtype='float64')

# One figure per sample: the 850K distribution of each CpG across df_dnam (violin)
# with this sample's 850K and Agena values overlaid as markers.
for sample in index_common:
    # Keep only the CpGs that were actually measured by Agena for this sample.
    # The result of dropna() has to be assigned; calling it bare is a no-op.
    agena_i = df_agena.loc[sample, feats_common].dropna()
    cpgs_i = sorted(set(agena_i.index.values).intersection(betas.columns.values))
    df_i = df_dnam.loc[[sample], cpgs_i]

    fig = go.Figure()
    for cpg_id, cpg in enumerate(cpgs_i):
        distrib_i = df_dnam.loc[:, cpg].values
        fig.add_trace(
            go.Violin(
                x=[cpg] * len(distrib_i),
                y=distrib_i,
                box_visible=True,
                meanline_visible=True,
                line_color='grey',
                showlegend=False,
                opacity=1.0
            )
        )

        # Only the first CpG contributes legend entries, so each platform appears once.
        showlegend = cpg_id == 0

        meth_epic = df_i.at[sample, cpg]
        meth_agena = agena_i.at[cpg]
        rel_diff_df.at[sample, cpg] = (meth_agena - meth_epic) / meth_epic * 100.0

        # Two markers per CpG: (legend name, value, marker size, colour).
        for name, value, size, color in (
            ("850K", meth_epic, 15, 'red'),
            ("Agena", meth_agena, 12, 'blue'),
        ):
            fig.add_trace(
                go.Scatter(
                    x=[cpg],
                    y=[value],
                    showlegend=showlegend,
                    name=name,
                    mode="markers",
                    marker=dict(
                        size=size,
                        opacity=0.7,
                        line=dict(
                            width=1
                        ),
                        color=color
                    ),
                )
            )

    add_layout(fig, f"", 'Methylation level', f"{sample}")
    fig.update_xaxes(tickangle=270)
    fig.update_xaxes(tickfont_size=15)
    fig.update_layout(margin=go.layout.Margin(
        l=120,
        r=20,
        b=120,
        t=90,
        pad=0
    ))
    save_figure(fig, f"{path_save}/samples/{sample}")

rel_diff_df.to_excel(f"{path_save}/rel_diff.xlsx", index=True)

# Summary figure: distribution of the relative difference across samples, one violin per CpG.
fig = go.Figure()
for cpg in feats_common:
    distrib_i = rel_diff_df.loc[index_common, cpg].dropna().astype('float64').values
    if distrib_i.size == 0:
        # np.ptp raises on an empty array and there is nothing to draw for this CpG.
        continue

    fig.add_trace(
        go.Violin(
            x=[cpg] * len(distrib_i),
            y=distrib_i,
            showlegend=False,
            box_visible=True,
            meanline_visible=True,
            line_color='black',
            line=dict(width=0.35),
            fillcolor='grey',
            marker=dict(color='grey', line=dict(color='black', width=0.3), opacity=0.8),
            points=False,
            # Kernel bandwidth tied to the data range of this CpG.
            bandwidth=np.ptp(distrib_i) / 25,
            opacity=0.8
        )
    )
add_layout(fig, "", "Relative difference, %", f"")
fig.update_xaxes(tickangle=270)
fig.update_xaxes(tickfont_size=15)
fig.update_layout(margin=go.layout.Margin(
    l=120,
    r=20,
    b=120,
    t=50,
    pad=0
))
fig.update_layout(title_xref='paper')
fig.update_layout(legend= {'itemsizing': 'constant'})
fig.update_layout(legend_font_size=20)
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="center",
        x=0.5
    )
)
save_figure(fig, f"{path_save}/rel_diff")

# 2. Features

In [None]:
pathlib.Path(f"{path_save}/feats").mkdir(parents=True, exist_ok=True)

# Per-CpG comparison of 850K vs Agena values on the common samples:
# a two-sided Mann-Whitney U test per CpG, then Benjamini-Hochberg correction.
pvals = []
values_dict = {'ID': index_common}
for cpg in feats_common:
    epic_series = df_dnam.loc[index_common, cpg]
    agena_series = df_agena.loc[index_common, cpg]
    # The exported table keeps missing values; only the test inputs are NaN-filtered.
    values_dict[f"{cpg}_850K"] = epic_series.values
    values_dict[f"{cpg}_agena"] = agena_series.values
    stat, pval = mannwhitneyu(
        epic_series.dropna().values,
        agena_series.dropna().values,
        alternative='two-sided'
    )
    pvals.append(pval)

values_df = pd.DataFrame(values_dict)
values_df.set_index("ID", inplace=True)
values_df.to_excel(f"{path_save}/values.xlsx", index=True)
_, pvals_corr, _, _ = multipletests(pvals, 0.05, method='fdr_bh')
pvals_df = pd.DataFrame(index=feats_common)
pvals_df['pvals'] = pvals
pvals_df['pvals_fdr_bh'] = pvals_corr
pvals_df.to_excel(f"{path_save}/pvals.xlsx", index=True)

# One figure per CpG with the two distributions side by side.
for cpg_id, cpg in enumerate(feats_common):

    # The value shown in the title is the FDR-corrected p-value.
    pval = pvals_df.at[cpg, 'pvals_fdr_bh']
    epic_data = df_dnam.loc[index_common, cpg].dropna().values
    agena_data = df_agena.loc[index_common, cpg].dropna().values

    fig = go.Figure()
    for name, data, color in (
        ("850K", epic_data, 'blue'),
        ("Agena", agena_data, 'red'),
    ):
        fig.add_trace(
            go.Violin(
                y=data,
                name=name,
                box_visible=True,
                meanline_visible=True,
                showlegend=False,
                line_color='black',
                fillcolor=color,
                marker=dict(color=color, line=dict(color='black', width=0.3), opacity=0.8),
                points='all',
                bandwidth=np.ptp(data) / 25,
                opacity=0.8
            )
        )
    gene = manifest.at[cpg, 'Gene']
    add_layout(fig, "", "Beta value", f"{cpg} ({gene})<br>p-value: {pval:0.2e}")
    fig.update_layout(title_xref='paper')
    fig.update_layout(legend_font_size=20)
    fig.update_xaxes(tickfont_size=15)
    fig.update_layout(
        margin=go.layout.Margin(
            l=110,
            r=20,
            b=50,
            t=80,
            pad=0
        )
    )
    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="center",
            x=0.5
        )
    )
    # Zero-padded counter: a plain width of 3 pads with spaces and yields file names with leading blanks.
    save_figure(fig, f"{path_save}/feats/{cpg_id:03d}_{cpg}")

## NaN analysis

In [None]:
pathlib.Path(f"{path_save}/ml_data/nans").mkdir(parents=True, exist_ok=True)

# Number of missing Agena measurements per CpG, largest first.
series_n_nan = df_agena[feats_common].isna().sum()
df_nan = series_n_nan.to_frame(name='n_nans')
df_nan = df_nan.sort_values("n_nans", ascending=False)

# Horizontal bar chart; the figure height grows with the number of CpGs.
fig = plt.figure(figsize=(12, 0.4 * df_nan.shape[0]))
sns.set_theme(style='whitegrid', font_scale=1)
bar = sns.barplot(
    data=df_nan,
    y=df_nan.index,
    x='n_nans',
    edgecolor='black',
    orient='h',
    palette=px.colors.qualitative.Alphabet,
    dodge=True
)
bar.set_xlabel("Number of NaNs")
bar.set_ylabel("")
bar.set_title(f"Agena missing values")
for extension, savefig_kwargs in (("png", dict(dpi=400)), ("pdf", dict())):
    plt.savefig(f"{path_save}/ml_data/nans/bar.{extension}", bbox_inches='tight', **savefig_kwargs)
plt.close()

# Keep the CpGs whose NaN count is below the threshold.
# NOTE(review): the threshold is 20% of df_nan.shape[0], i.e. of the number of CpGs,
# while n_nans counts samples. If "fewer than 20% of samples missing" was intended,
# this should use df_agena.shape[0]. Changing it may alter the selected feature set
# (the next cell expects exactly 8 features), so it is left as is; please confirm.
nan_threshold = 0.2 * df_nan.shape[0]
feats_remain = df_nan.index[df_nan['n_nans'] < nan_threshold].values

# The file name carries the number of retained CpGs.
df_ml_feats = pd.DataFrame(index=feats_remain)
df_ml_feats.to_excel(f"{path_save}/ml_data/feats_con_{len(feats_remain)}.xlsx", index_label="features")

# 3. Generate data for Agena clocks

In [None]:
# Number of retained CpGs; it selects which feats_con_<N>.xlsx file is read back
# (files of that name are written by the NaN analysis cell).
n_feats_remain = 8
df_ml_feats = pd.read_excel(
    f"{path_save}/ml_data/feats_con_{n_feats_remain}.xlsx",
    index_col="features",
)
feats_remain = df_ml_feats.index.to_numpy()

In [None]:
# Phenotype columns copied from the array data onto the Agena tables.
pheno_cols = ("Age", "Region", "Sex")

# --- Agena controls: the train/validation part ---
# .copy() so that the column assignments below write to an independent frame
# rather than to a slice of df_agena (avoids SettingWithCopyWarning).
df_ml_agena = df_agena.loc[index_common, list(feats_remain) + ["Status"]].copy()
for col in pheno_cols:
    df_ml_agena.loc[index_common, col] = df_dnam.loc[index_common, col]
df_ml_agena = df_ml_agena.dropna()
df_ml_agena["Split"] = 'trn_val'
df_ml_agena["Part"] = 'UNN EpiTYPER'

# --- Agena ESRD samples: test part; phenotype comes from the unfiltered array table ---
df_ml_agena_esrd = df_agena_esrd.loc[:, list(feats_remain) + ["Status"]].copy()
esrd_ids = df_ml_agena_esrd.index
for col in pheno_cols:
    df_ml_agena_esrd.loc[esrd_ids, col] = df_dnam_all.loc[esrd_ids, col]
df_ml_agena_esrd = df_ml_agena_esrd.dropna()
df_ml_agena_esrd["Split"] = 'tst'
df_ml_agena_esrd["Part"] = 'UNN EpiTYPER ESRD'

cells = ["CD8T", "CD4T", "NK", "Bcell", "Mono", "Gran"]
dnam_ages = ['DNAmAgeHannum', 'DNAmAge', 'DNAmPhenoAge', 'DNAmGrimAge']
pc_ages = ["PCHorvath1", "PCHorvath2", "PCHannum", "PCPhenoAge", "PCGrimAge"]

# --- 850K samples: test part ---
# Cell-count and DNAm-age columns are stored with the dnam_suffix; the suffix is
# used consistently for both groups here and stripped by the rename below.
suffixed_cols = [f"{x}{dnam_suffix}" for x in cells + dnam_ages]
df_ml_850k = df_dnam.loc[:, list(feats_remain) + ["Age", "Status", "Region", "Sex"] + suffixed_cols + pc_ages].copy()
df_ml_850k = df_ml_850k.rename(columns={f"{x}{dnam_suffix}": x for x in cells + dnam_ages})
df_ml_850k["Split"] = 'tst'
df_ml_850k["Part"] = 'UNN EPIC'

index_common = sorted(set(df_ml_agena.index.values) & set(df_ml_850k.index.values))
index_epic_only = [f"{x}_850k" for x in sorted(set(df_ml_850k.index.values) - set(df_ml_agena.index.values))]

# The same person can appear in both the Agena and the 850K part, so the 850K
# rows get an "_850k" suffix to keep the combined index unique.
df_ml_850k['index_new'] = df_ml_850k.index + "_850k"
df_ml_850k.set_index("index_new", inplace=True)

# --- External datasets: test parts ---
datasets = ['GSE152026', 'GSE55763', 'GSE40279', 'GSE87571']
df_datasets = []
for d in datasets:
    print(d)

    # Loop-local name: the notebook-level `platform` / `manifest` of the main
    # dataset must not be overwritten here. The per-dataset manifest was loaded
    # but never used, so it is no longer loaded.
    platform_d = datasets_info.loc[d, 'platform']

    status_col = get_column_name(d, 'Status')
    age_col = get_column_name(d, 'Age')
    sex_col = get_column_name(d, 'Sex')

    status_dict = get_status_dict_default(d)
    status_passed_fields = get_passed_fields(status_dict, ['Control'])
    sex_dict = get_sex_dict(d)

    categorical_vars = {status_col: [x.column for x in status_passed_fields]}
    categorical_vars.update({sex_col: list(sex_dict.values())})
    continuous_vars = {'Age': age_col}
    continuous_vars.update({x: x for x in cells})
    continuous_vars.update({x: x for x in dnam_ages})
    continuous_vars.update({x: x for x in pc_ages})

    pheno_d = pd.read_excel(f"{path}/{platform_d}/{d}/pheno.xlsx", index_col=0)
    pheno_d = filter_pheno(d, pheno_d, continuous_vars, categorical_vars)

    # Dataset-specific column names are mapped to the common ones; the remaining
    # entries map a column to itself and only serve to select it.
    dict_rename_columns = {
        status_col: 'Status',
        age_col: 'Age',
        sex_col: 'Sex',
    }
    dict_rename_columns.update({x: x for x in cells})
    dict_rename_columns.update({x: x for x in dnam_ages})
    dict_rename_columns.update({x: x for x in pc_ages})
    pheno_d = pheno_d.loc[:, list(dict_rename_columns.keys())]
    pheno_d = pheno_d.rename(columns=dict_rename_columns)

    betas_d = pd.read_pickle(f"{path}/{platform_d}/{d}/betas.pkl")
    print(f"betas_d: {betas_d.shape}")
    missed_feats_in_d = list(set(feats_remain) - set(betas_d.columns.values))
    print(f"missed_feats_in_d: {missed_feats_in_d}")
    # Raises KeyError if any retained CpG is absent from this dataset (see the list printed above).
    betas_d = betas_d.loc[:, feats_remain]
    betas_d = betas_drop_na(betas_d)

    print(f"pheno shape: {pheno_d.shape}")
    print(f"betas shape: {betas_d.shape}")
    df_d = pd.merge(pheno_d, betas_d, left_index=True, right_index=True)
    print(f"df shape: {df_d.shape}")

    df_d['Split'] = 'tst'
    df_d['Part'] = f'{d}'

    df_datasets.append(df_d)

df_ml = pd.concat([df_ml_agena, df_ml_agena_esrd, df_ml_850k] + df_datasets)

df_ml.to_excel(f"{path_save}/ml_data/data.xlsx", index=True, index_label="index")

# 4. Collect multirun data

In [None]:
dataset = "GSEUNN"
path = "D:/YandexDisk/Work/pydnameth/datasets"
datasets_info = pd.read_excel(f"{path}/datasets.xlsx", index_col='dataset')
platform = datasets_info.loc[dataset, 'platform']

model = 'catboost_trn_val_tst'

path_runs = f"{path}/{platform}/{dataset}/special/042_agena/ml_data/models/{model}/multiruns"

# One validation-metrics file per run; the train and test files sit next to it
# and differ only in the 'val' part of the file name.
files = glob(f"{path_runs}/*/*/metrics_val_best_*.xlsx")
if not files:
    raise FileNotFoundError(f"No metrics_val_best_*.xlsx files found under {path_runs}")

test_datasets = ['UNN_EPIC', 'GSE87571', 'GSE40279', 'GSE55763', 'GSE152026', 'all']

# Names of the overridden hydra parameters, taken from the first run.
head, tail = os.path.split(files[0])
cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
# Split on the first '=' only, so values that themselves contain '=' stay intact.
params = [param_pair.split('=', 1)[0] for param_pair in cfg]

# One row per run: validation, train and test metrics plus the run's overrides.
df_res = pd.DataFrame(index=files)
for file in files:
    head, tail = os.path.split(file)

    # Validation
    df_val = pd.read_excel(file, index_col="metric")
    for metric in df_val.index.values:
        df_res.at[file, metric + "_val"] = df_val.at[metric, "val"]

    # Train (only the first 'val' in the file name is the split marker)
    tail_trn = tail.replace('val', 'trn', 1)
    df_trn = pd.read_excel(f"{head}/{tail_trn}", index_col="metric")
    for metric in df_trn.index.values:
        df_res.at[file, metric + "_trn"] = df_trn.at[metric, "trn"]

    # Test
    for test_dataset in test_datasets:
        tail_tst = tail.replace('val', f'tst_{test_dataset}', 1)
        df_tst = pd.read_excel(f"{head}/{tail_tst}", index_col="metric")
        # Iterate over the metrics present in the test file itself, not the train file.
        for metric in df_tst.index.values:
            df_res.at[file, metric + f"_tst_{test_dataset}"] = df_tst.at[metric, f'tst_{test_dataset}']

    # Params
    cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
    for param_pair in cfg:
        param, val = param_pair.split('=', 1)
        df_res.at[file, param] = val

# Put the MAE columns first, keep every other column in its existing order.
first_columns = [
    'mean_absolute_error_trn',
    'mean_absolute_error_cv_mean_trn',
    'mean_absolute_error_val',
    'mean_absolute_error_cv_mean_val'
]
for test_dataset in test_datasets:
    first_columns.append(f"mean_absolute_error_tst_{test_dataset}")
df_res = df_res[first_columns + [col for col in df_res.columns if col not in first_columns]]
df_res.to_excel(f"{path_runs}/summary.xlsx", index=True, index_label="file")

# 5. Figures

In [None]:
dataset = "GSEUNN"
datasets_info = pd.read_excel(f"D:/YandexDisk/Work/pydnameth/datasets/datasets.xlsx", index_col='dataset')
platform = datasets_info.loc[dataset, 'platform']
manifest = get_manifest(platform, path=path)

path_load = f"D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/042_agena/ml_data"
path_save = f"{path_load}/figures"
pathlib.Path(f"{path_save}").mkdir(parents=True, exist_ok=True)

df = pd.read_excel(f"{path_load}/data.xlsx", index_col=0) # Need to replace for inference dataset
feats = pd.read_excel(f"{path_load}/feats_con_8.xlsx", index_col=0).index.values

# Sample IDs of every part of the combined table.
index_agena = df.index[df['Part'] == 'UNN EpiTYPER'].values
# ESRD samples have their own 'Part' label; selecting 'UNN EpiTYPER' here would
# silently duplicate index_agena.
index_agena_esrd = df.index[df['Part'] == 'UNN EpiTYPER ESRD'].values
index_epic = df.index[df['Part'] == 'UNN EPIC'].values
index_gse87571 = df.index[df['Part'] == 'GSE87571'].values
index_gse40279 = df.index[df['Part'] == 'GSE40279'].values
index_gse55763 = df.index[df['Part'] == 'GSE55763'].values
# EPIC rows carry a 5-character "_850k" suffix; strip it to match them with the EpiTYPER IDs.
index_common = sorted(list(set(index_agena).intersection(set([x[:-5] for x in index_epic]))))
index_common_suffix = [f"{x}_850k" for x in index_common]
index_epic_only = sorted(list(set(index_epic) - set(index_common_suffix)))

# Colour of every part, keyed both by the plain name and by the "name (n)" legend label.
colors = {
    f'UNN EpiTYPER': px.colors.qualitative.G10[0],
    f'UNN EpiTYPER ESRD': px.colors.qualitative.D3[5],
    f'UNN EpiTYPER ({len(index_agena)})': px.colors.qualitative.G10[0],
    f'UNN EpiTYPER ESRD ({len(index_agena_esrd)})': px.colors.qualitative.D3[5],
    f'UNN EpiTYPER and EPIC': px.colors.qualitative.G10[0],
    f'UNN EpiTYPER and EPIC ({len(index_common)})': px.colors.qualitative.G10[0],
    f'UNN EPIC': px.colors.qualitative.G10[1],
    f'UNN EPIC ({len(index_epic)})': px.colors.qualitative.G10[1],
    f'GSE87571': px.colors.qualitative.G10[2],
    f'GSE40279': px.colors.qualitative.G10[3],
    f'GSE55763': px.colors.qualitative.G10[4],
    f'GSE152026': px.colors.qualitative.G10[5],
}

## 1. Participants

In [None]:
path_local = "01_participants"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

# Shared age bins: 22 bins of width 5 between 5 and 115.
hist_bins = np.linspace(5, 115, 23)

# --- Age histogram of the UNN samples, split by measurement platform ---
label_common = f'UNN EpiTYPER and EPIC ({len(index_common)})'
label_epic = f'UNN EPIC ({len(index_epic)})'
# .copy(): the "Dataset" column is added below, which must not write into a slice of df.
df_fig = df.loc[df['Part'].isin(['UNN EpiTYPER', 'UNN EPIC'])].copy()
df_fig.loc[index_common, "Dataset"] = label_common
df_fig.loc[index_epic, "Dataset"] = label_epic
fig = plt.figure()
sns.set_theme(style='whitegrid')
hist = sns.histplot(
    data=df_fig,
    bins=hist_bins,
    discrete=False,
    edgecolor='k',
    linewidth=1,
    hue_order=[label_common, label_epic],
    x="Age",
    hue="Dataset",
    palette=colors
)
hist.set(xlim=(0, 120))
sns.move_legend(hist, "lower center", bbox_to_anchor=(.5, 1), ncol=2, title=None, frameon=False)
plt.savefig(f"{path_save}/{path_local}/UNN_hist.png", bbox_inches='tight', dpi=800)
plt.savefig(f"{path_save}/{path_local}/UNN_hist.pdf", bbox_inches='tight')
plt.close(fig)

# --- Age histogram of every external dataset ---
datasets = ['GSE152026', 'GSE87571', 'GSE40279', 'GSE55763']
for dataset in datasets:
    df_fig = df.loc[df['Part'].isin([dataset])]
    fig = plt.figure()
    sns.set_theme(style='whitegrid')
    hist = sns.histplot(
        data=df_fig,
        bins=hist_bins,
        discrete=False,
        edgecolor='k',
        linewidth=1,
        x="Age",
        color=colors[dataset]
    )
    hist.set(xlim=(0, 120))
    hist.set_title(f"{dataset} ({df_fig.shape[0]})")
    plt.savefig(f"{path_save}/{path_local}/{dataset}_hist.png", bbox_inches='tight', dpi=800)
    plt.savefig(f"{path_save}/{path_local}/{dataset}_hist.pdf", bbox_inches='tight')
    plt.close(fig)

# --- Venn diagram: EpiTYPER samples (IDs given the "_850k" suffix) vs EPIC samples ---
venn_subsets = (set([f"{x}_850k" for x in index_agena]), set(index_epic))
fig, ax = plt.subplots()
venn = venn2(
    subsets=venn_subsets,
    set_labels = (f'EpiTYPER\n and\n EPIC', f'EPIC'),
    set_colors=(colors[f'UNN EpiTYPER and EPIC'], colors[f'UNN EPIC']),
    alpha = 0.5
)
venn2_circles(subsets=venn_subsets)
for text in venn.set_labels:
    text.set_fontsize(16)
for text in venn.subset_labels:
    text.set_fontsize(25)
plt.savefig(f"{path_save}/{path_local}/UNN_venn.png", bbox_inches='tight', dpi=800)
plt.savefig(f"{path_save}/{path_local}/UNN_venn.pdf", bbox_inches='tight')
plt.close(fig)

## 2. Ages

In [None]:
path_local = "02_ages"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

# Parts of the combined table that get a scatter matrix.
datasets = [
    'UNN EPIC',
    'GSE152026',
    'GSE87571',
    'GSE40279',
    'GSE55763',
]
# Columns compared against each other: chronological age plus the DNAm and PC clock columns.
dnam_ages = [
    'Age',
    'DNAmAgeHannum', 'DNAmAge', 'DNAmPhenoAge', 'DNAmGrimAge',
    "PCHorvath1", "PCHorvath2", "PCHannum", "PCPhenoAge", "PCGrimAge",
]

# 23 edges from 5 to 115, i.e. 22 bins of width 5.
hist_bins = np.linspace(5, 115, num=23)

def plot_unity(x, y, **kwargs):
    """Draw the dashed black identity line y = x from 0 to 120 on the current axes.

    The arguments are accepted so the function can be used as a grid-mapping
    callback (it is passed to ``PairGrid.map_upper``); they are not used.
    """
    diagonal = np.linspace(0, 120, num=121)
    current_ax = plt.gca()
    current_ax.plot(diagonal, diagonal, color='k', marker=None, linestyle='--', linewidth=1.0)

def plot_regression(x, y, **kwargs):
    """Fit an OLS line ``y ~ x`` and draw it from x = 0 to x = 120 on the current axes.

    The line is drawn twice: a wide 'dimgrey' stroke first and a narrower stroke
    in ``kwargs['color']`` on top of it, which gives the line a grey outline.

    Raises:
        KeyError: if no ``color`` keyword is supplied.
    """
    fit = smf.ols(formula="y ~ x", data=pd.DataFrame({"x": x, "y": y})).fit()

    # Predict at the two end points only; a straight line needs nothing more.
    ends = pd.DataFrame({"x": [0, 120]})
    ends["y"] = fit.predict(ends)
    line_x = ends['x'].values
    line_y = ends['y'].values

    current_ax = plt.gca()
    current_ax.plot(line_x, line_y, color='dimgrey', marker=None, linestyle='-', linewidth=4.0)
    current_ax.plot(line_x, line_y, color=kwargs['color'], marker=None, linestyle='-', linewidth=2.0)

def corr(x, y, **kwargs):
    """Annotate the current axes with MAE, RMSE and Pearson correlation of ``x`` vs ``y``.

    Args:
        x, y: objects exposing ``.values`` as a NumPy array (pandas Series when
            used as a ``PairGrid.map_lower`` callback).
        **kwargs: accepted for callback compatibility; not used.

    The metrics come from ``get_reg_metrics()``; each entry's first element is
    called as ``metric(y, x)`` and then ``reset()`` so no state carries over
    (presumably torchmetrics-style objects -- confirm in src.tasks.metrics).
    The keys 'mean_absolute_error', 'mean_squared_error' and 'pearson_corr_coef'
    must be present, otherwise a KeyError is raised.
    """
    # The tensors do not depend on the metric, so build them once.
    x_torch = torch.from_numpy(x.values)
    y_torch = torch.from_numpy(y.values)

    metrics = get_reg_metrics()
    metrics_res = {}
    for m in metrics:
        metric_fn = metrics[m][0]
        metrics_res[m] = float(metric_fn(y_torch, x_torch).numpy())
        metric_fn.reset()

    ax = plt.gca()
    # (label, position in axes coordinates); RMSE is derived from the MSE metric.
    annotations = (
        (f"MAE = {metrics_res['mean_absolute_error']:0.2f}", (0.19, 0.65)),
        (f"RMSE = {np.sqrt(metrics_res['mean_squared_error']):0.2f}", (0.14, 0.45)),
        (r'$\rho$ = ' + f"{metrics_res['pearson_corr_coef']:0.2f}", (0.27, 0.25)),
    )
    for label, xy in annotations:
        ax.annotate(label, xy = xy, size = 16, xycoords = ax.transAxes)


# One scatter matrix per dataset: scatter + regression + identity line above the
# diagonal, histograms on the diagonal, metric annotations below it.
for dataset in datasets:
    df_fig = df.loc[df['Part'].isin([dataset]), dnam_ages]

    sns.set_theme(style="whitegrid", font_scale=1.5)
    pair_grid = sns.PairGrid(df_fig, vars=dnam_ages)
    pair_grid.map_upper(sns.scatterplot, color=colors[dataset], s=25, alpha=0.5, edgecolor='k', linewidth=0.2)
    pair_grid.map_diag(sns.histplot, bins=hist_bins, color=colors[dataset], edgecolor='k')
    pair_grid.map_upper(plot_regression, color=colors[dataset])
    pair_grid.map_upper(plot_unity)
    pair_grid.map_lower(corr)

    n_grid_rows, n_grid_cols = pair_grid.axes.shape
    for row_id in range(n_grid_rows):
        for col_id in range(n_grid_cols):
            ax = pair_grid.axes[row_id, col_id]
            ax.spines[['right', 'top']].set_visible(True)
            if row_id != col_id:
                # Off-diagonal panels share the 0-120 range on both axes.
                ax.set_xlim((0, 120))
                ax.set_ylim((0, 120))
            if row_id > col_id:
                # Panels below the diagonal hold only text, so hide their grid.
                ax.grid(False)
    plt.savefig(f"{path_save}/{path_local}/{dataset}_scatter_mtx.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_save}/{path_local}/{dataset}_scatter_mtx.pdf", bbox_inches='tight')
    # Close (not merely clear) the figure: every PairGrid creates a new one and
    # they would otherwise accumulate in memory across the loop.
    plt.close(pair_grid.figure)

## 3. CpGs versus

In [None]:
path_local = "03_cpgs_versus"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

n_cols = 4
n_rows = int(np.ceil(len(feats) / n_cols))

# Order the CpGs by gene name; the subplots follow this order.
df_feats = pd.DataFrame(index=feats)
for feat in feats:
    df_feats.at[feat, "Gene"] = manifest.at[feat, 'Gene']
df_feats.sort_values(['Gene'], ascending=[True], inplace=True)
feats = df_feats.index.values

fig = make_subplots(
    rows=n_rows,
    cols=n_cols,
    shared_yaxes=False,
    shared_xaxes=False,
    horizontal_spacing=0.075,
    vertical_spacing=0.21,
    subplot_titles=feats
)

# Rows indexed by the plain sample IDs are the EpiTYPER measurements, rows with
# the "_850k" suffix are the EPIC measurements of the same people. The axis
# titles below follow the data: x = EpiTYPER, y = EPIC.
x_label = "EpiTYPER"
y_label = "EPIC"

# Axis styling shared by the x and y axis of every subplot.
axis_style = dict(
    autorange=False,
    showgrid=True,
    zeroline=False,
    linecolor='black',
    showline=True,
    gridcolor='gainsboro',
    gridwidth=0.05,
    mirror=True,
    ticks='outside',
    # `title_font` is the current spelling of the deprecated `titlefont` attribute.
    title_font=dict(
        color='black',
        size=20
    ),
    showticklabels=True,
    tickangle=0,
    tickfont=dict(
        color='black',
        size=20
    ),
    exponentformat='e',
    showexponent='all'
)

palette = px.colors.qualitative.Antique

titles = {}
for rc_id, feat in enumerate(feats):
    # Subplots are filled row by row.
    r_id, c_id = divmod(rc_id, n_cols)
    row, col = r_id + 1, c_id + 1

    # Cycle through the palette so more CpGs than colours cannot raise IndexError.
    df_feats.at[feat, 'Color'] = palette[rc_id % len(palette)]
    feat_color = df_feats.at[feat, 'Color']

    xs = df.loc[index_common, feat].values
    ys = df.loc[index_common_suffix, feat].values

    df_reg = pd.DataFrame({"x": xs, "y": ys})
    formula = "y ~ x"
    model = smf.ols(formula=formula, data=df_reg).fit()

    # Common square range for both axes, padded by 5% of the data range on each side.
    min_val = min(min(xs), min(ys))
    max_val = max(max(xs), max(ys))
    shift_val = max_val - min_val
    min_val -= 0.05 * shift_val
    max_val += 0.05 * shift_val

    df_line = pd.DataFrame({"x": [min_val, max_val]})
    df_line["y"] = model.predict(df_line)

    # Identity line.
    fig.add_trace(
        go.Scatter(
            x=[min_val, max_val],
            y=[min_val, max_val],
            showlegend=False,
            name="",
            mode="lines",
            marker_color="black",
            line_dash='dash',
            marker=dict(
                size=8,
                opacity=0.75,
                line=dict(color='black', width=0.5)
            )
        ),
        row=row,
        col=col
    )

    # Paired measurements.
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            showlegend=False,
            name="",
            mode='markers',
            marker=dict(
                size=8,
                opacity=0.75,
                color=feat_color,
                line=dict(
                    color='black',
                    width=0.5
                )
            ),
        ),
        row=row,
        col=col
    )

    # OLS regression line.
    fig.add_trace(
        go.Scatter(
            x=[min_val, max_val],
            y=df_line["y"].values,
            showlegend=False,
            name="",
            mode="lines",
            marker_color=feat_color,
            line_dash='solid',
            line_width=4,
            marker=dict(
                size=8,
                opacity=0.75,
                line=dict(color=feat_color, width=2)
            )
        ),
        row=row,
        col=col
    )

    fig.update_xaxes(row=row, col=col, title_text=x_label, range=[min_val, max_val], **axis_style)
    fig.update_yaxes(row=row, col=col, title_text=y_label, range=[min_val, max_val], **axis_style)

    # Unpacking works both for the result object of recent SciPy and for the
    # plain (r, p) tuple of older releases.
    pearson_r, _ = stats.pearsonr(xs, ys)
    titles[feat] = f"{feat} ({manifest.at[feat, 'Gene']})<br>" + u"\u03C1" + f" = {pearson_r:0.2f}"

# The subplot titles were created from the bare CpG IDs; replace them by the full titles.
fig.for_each_annotation(lambda a: a.update(text = titles[a.text]))
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.01,
        xanchor="center",
        x=0.5,
        itemsizing='constant',
        font_size=50
    ),
    title=dict(
        text="",
        font=dict(size=25)
    ),
    template="none",
    autosize=False,
    width=2000,
    height=1000,
    margin=go.layout.Margin(
        l=100,
        r=100,
        b=100,
        t=100,
        pad=0
    ),
)
fig.update_annotations(font_size=25)
save_figure(fig, f"{path_save}/{path_local}/scatters")

## 4. CpGs distributions

In [None]:
path_local = "04_cpgs_distributions"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

# CpGs ordered by gene name; the gene is also shown in the x tick label.
df_feats = pd.DataFrame(index=feats)
for feat in feats:
    df_feats.at[feat, "Gene"] = manifest.at[feat, 'Gene']
df_feats = df_feats.sort_values('Gene', ascending=True)
feats = df_feats.index.values

# One figure per dataset: split violins per CpG, EpiTYPER on the left half and
# the dataset on the right half.
datasets = ['UNN EPIC', 'GSE152026', 'GSE87571', 'GSE40279', 'GSE55763']
for dataset in datasets:
    dist_num_bins = 15
    fig = go.Figure()

    for feat in feats:
        feat_plot = f"{feat}<br>{df_feats.at[feat, 'Gene']}"
        # (part of the table, violin side, horizontal offset of the raw points)
        halves = (
            ("UNN EpiTYPER", 'negative', -1.5),
            (dataset, 'positive', 1.5),
        )
        for part, side, pointpos in halves:
            vals = df.loc[df["Part"] == part, feat].values
            part_color = colors[part]
            fig.add_trace(
                go.Violin(
                    x=[feat_plot] * len(vals),
                    y=vals,
                    name=feat_plot,
                    box_visible=True,
                    meanline_visible=True,
                    showlegend=False,
                    line_color='black',
                    fillcolor=part_color,
                    marker=dict(color=part_color, line=dict(color='black', width=0.3), opacity=0.8),
                    points='all',
                    # Bandwidth tied to the data range of this half.
                    bandwidth=np.ptp(vals) / dist_num_bins,
                    opacity=0.8,
                    legendgroup=feat_plot,
                    scalegroup=feat_plot,
                    side=side,
                    scalemode="width",
                    pointpos=pointpos
                )
            )

    add_layout(fig, "", f"Methylation level", f"{dataset}")
    fig.update_layout(title_xref='paper')
    fig.update_layout(
        violingap=0.39,
        violingroupgap=0.39,
        width=1600,
        height=700,
        margin=go.layout.Margin(
            l=100,
            r=50,
            b=180,
            t=50,
            pad=0,
        )
    )
    fig.update_layout(xaxis=dict(tickfont=dict(size=22)))
    # Fixed axes: y slightly wider than [0, 1], x covering exactly one slot per CpG.
    fig.update_yaxes(autorange=False, range=[-0.1, 1.1])
    fig.update_xaxes(autorange=False, range=[-0.5, len(feats) - 0.5])
    fig.update_xaxes(tickangle=270)
    save_figure(fig, f"{path_save}/{path_local}/{dataset}")

## 5. Age prediction

In [None]:
path_local = "05_age_prediction"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

# Load the stored inference table, add the signed error (estimation minus age)
# and export the three columns that the figures below are drawn from.
path_result = f"{path_dataset}/special/042_agena/ml_data/models/widedeep_tab_net_inference/runs/2023-02-13_22-47-47"
df_result = pd.read_excel(f"{path_result}/df.xlsx", index_col=0)
df_result["Age difference"] = df_result["Estimation"] - df_result['Age']
df_fig = df_result[["Age difference", "Estimation", 'Age']]
df_fig.to_excel(f"{path_save}/{path_local}/data.xlsx")

# Cohorts to plot, in x-axis order.
datasets = ['UNN EpiTYPER', 'UNN EpiTYPER ESRD', 'UNN EPIC', 'GSE87571', 'GSE40279', 'GSE55763', 'GSE152026']

# --- One violin per cohort showing the distribution of the age difference ---
fig = go.Figure()
for dataset in datasets:
    # NOTE(review): the row mask comes from `df`, not `df_result`; this relies on
    # both frames sharing the same index -- confirm.
    in_dataset = df["Part"] == dataset
    vals = df_result.loc[in_dataset, "Age difference"].values
    violin = go.Violin(
        x=[dataset] * len(vals),
        y=vals,
        name=dataset,
        showlegend=False,
        points='all',
        box_visible=True,
        meanline_visible=True,
        opacity=0.8,
        line_color='black',
        fillcolor=colors[dataset],
        marker=dict(color=colors[dataset], line=dict(color='black', width=0.1), opacity=0.8),
    )
    fig.add_trace(violin)
add_layout(fig, "", f"Age difference", f"")
fig.update_layout(
    title_xref='paper',
    violingap=0.35,
    violingroupgap=0.35,
    width=2000,
    height=800,
    margin=go.layout.Margin(l=100, r=50, b=60, t=50, pad=0),
    xaxis=dict(tickfont=dict(size=22)),
)
fig.update_xaxes(autorange=False, range=[-0.5, len(datasets) - 0.5])
save_figure(fig, f"{path_save}/{path_local}/age_difference")

# --- One scatter per cohort: estimation vs. age with a dashed identity line ---
for dataset in datasets:
    in_dataset = df["Part"] == dataset
    xs = df_result.loc[in_dataset, "Age"].values
    ys = df_result.loc[in_dataset, "Estimation"].values

    # Common square range for both axes, widened by 5% of the span on each side.
    lowest = min(min(xs), min(ys))
    highest = max(max(xs), max(ys))
    shift_val = highest - lowest
    min_val = lowest - 0.05 * shift_val
    max_val = highest + 0.05 * shift_val

    points_trace = go.Scatter(
        x=xs,
        y=ys,
        name="",
        showlegend=False,
        mode='markers',
        marker=dict(
            size=8,
            opacity=0.75,
            color=colors[dataset],
            line=dict(color='black', width=0.5),
        ),
    )
    # Diagonal spanning the padded range (y = x).
    identity_trace = go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        name="",
        showlegend=False,
        mode="lines",
        line=dict(dash='dash'),
        marker=dict(
            color="black",
            size=8,
            opacity=0.75,
            line=dict(color='black', width=0.5),
        ),
    )

    fig = go.Figure()
    fig.add_trace(points_trace)
    fig.add_trace(identity_trace)

    add_layout(fig, f"Age", f"Prediction", f"{dataset}")
    fig.update_layout(
        legend=dict(font_size=20, itemsizing='constant'),
        title_xref='paper',
        xaxis_range=[min_val, max_val],
        yaxis_range=[min_val, max_val],
        width=650,
        height=600,
        margin=go.layout.Margin(l=100, r=50, b=100, t=50, pad=0),
    )
    fig.update_xaxes(autorange=False)
    fig.update_yaxes(autorange=False)
    save_figure(fig, f"{path_save}/{path_local}/{dataset}")

## 6. XAI

In [None]:
path_local = "06_xai"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

# The ranking below reorders the CpGs with an integer index array, which only works
# on an ndarray; convert explicitly instead of relying on an earlier cell having
# turned `feats` into one.
feats_arr = np.asarray(feats)

# Gene and two-line tick label ("<CpG><br><Gene>") for every CpG.
df_feats = pd.DataFrame(index=feats_arr)
for feat in feats_arr:
    gene = manifest.at[feat, 'Gene']
    df_feats.at[feat, "Gene"] = gene
    df_feats.at[feat, "name"] = f"{feat}<br>{gene}"

datasets = ['UNN EpiTYPER', 'UNN EPIC', 'GSE87571', 'GSE40279', 'GSE55763', 'GSE152026']

# Directory of the inference run holding one SHAP table per cohort.
path_shap = f"{path_dataset}/special/042_agena/ml_data/models/widedeep_tab_net_inference/runs/2023-02-06_19-36-19/shap"

# One horizontal bar chart per cohort: mean absolute SHAP value of every CpG.
for dataset in datasets:

    df_shap = pd.read_excel(f"{path_shap}/{dataset}/shap.xlsx", index_col="index")

    shap_mean_abs = np.array(
        [np.mean(np.abs(df_shap.loc[:, feat].values)) for feat in feats_arr]
    )

    # Ascending order: the CpG with the largest mean |SHAP| gets the highest y position.
    order = np.argsort(shap_mean_abs)
    feats_sorted = feats_arr[order]
    shap_mean_abs = shap_mean_abs[order]
    feats_names = [df_feats.at[x, "name"] for x in feats_sorted]

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=shap_mean_abs,
            y=list(range(len(shap_mean_abs))),
            orientation='h',
            marker=dict(color=colors[dataset], opacity=1.0)
        )
    )
    add_layout(fig, "Mean(|SHAP values|)", "", f"{dataset}")
    fig.update_layout(
        legend_font_size=20,
        showlegend=False,
        title_font_size=20,
        # Bars sit at integer positions; label them with the CpG/gene names.
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(feats_names))),
            ticktext=feats_names
        ),
        yaxis_range=[-0.5, len(feats_names) - 0.5],
        autosize=False,
        width=300,
        height=600,
        margin=go.layout.Margin(
            l=130,
            r=50,
            b=70,
            t=50,
            pad=0
        )
    )
    fig.update_yaxes(autorange=False, tickfont_size=18)
    fig.update_xaxes(tickfont_size=18, title_font_size=18, nticks=6)
    save_figure(fig, f"{path_save}/{path_local}/{dataset}")
    # Keep the raw SHAP table next to the figure.
    df_shap.to_excel(f"{path_save}/{path_local}/{dataset}.xlsx")

## 7. Ages table

In [None]:
path_local = "07_ages_table"
pathlib.Path(f"{path_save}/{path_local}").mkdir(parents=True, exist_ok=True)

df_result = pd.read_excel(f"{path_dataset}/special/042_agena/ml_data/models/widedeep_tab_net_inference/runs/2023-02-06_19-36-19/df.xlsx", index_col=0)
# Expose the model's estimation under the name used in the table.
df_result["EpiTYPER Age"] = df_result["Estimation"]

datasets = ['UNN EPIC', 'GSE87571', 'GSE40279', 'GSE55763', 'GSE152026']
ages = ['EpiTYPER Age', 'DNAmAgeHannum', 'DNAmAge', 'DNAmPhenoAge', 'DNAmGrimAge', "PCHorvath1", "PCHorvath2", "PCHannum", "PCPhenoAge", "PCGrimAge"]

# Build the metric objects once instead of once per (cohort, clock) pair; only the
# mean absolute error is reported, so it is the only metric evaluated.
metrics = get_reg_metrics()
mae_metric = metrics['mean_absolute_error'][0]

# Rows: age estimators, columns: cohorts, cells: MAE against chronological age.
df_mae = pd.DataFrame(index=ages, columns=datasets)
for dataset in datasets:
    # Row mask and target vector depend on the cohort only, not on the clock.
    # NOTE(review): the mask comes from `df`, not `df_result`; this relies on both
    # frames sharing the same index -- confirm.
    in_dataset = df["Part"] == dataset
    real = torch.from_numpy(df_result.loc[in_dataset, "Age"].values)
    for age in ages:
        pred = torch.from_numpy(df_result.loc[in_dataset, age].values)
        df_mae.at[age, dataset] = float(mae_metric(pred, real).numpy())
        # Clear the metric's accumulated state before the next (cohort, clock) pair.
        mae_metric.reset()

df_mae.to_excel(f"{path_save}/{path_local}/mae.xlsx")