# Debugging autoreload

In [None]:
# Enable IPython's autoreload extension so edited modules are re-imported
# without restarting the kernel (mode 2 — see the IPython autoreload docs).
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
import glob
import pandas as pd
from scipy import stats
import seaborn as sns
import plotly.io as pio
# Clear the MathJax setting on plotly's Kaleido export scope.
# NOTE(review): presumably a workaround for MathJax artefacts in static
# image export — confirm it is still needed with the installed plotly/kaleido.
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode
# Initialise plotly's offline notebook mode (connected=False).
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
import pathlib
from sklearn.metrics import mean_absolute_error

# Collect data

In [None]:
# Root folder; each dataset lives in its own `harm_checking_<GSE>` subfolder.
path_load = "D:/YandexDisk/Work/pydnameth/draft/10_MetaEPIClock/MetaEpiAge/problems"
gses = ['GSE74193']

# Phenotype columns kept in the per-dataset working tables.
cols_trgt = ['geo_accession', 'series_id', 'Age', 'Sex', 'Tissue']

# Column names of the PC clocks in pheno.csv -> labels used downstream.
ages_pc = dict(
    PCHorvath1='PC-Horvath',
    PCHorvath2='PC-SkinBloodAge',
    PCHannum='PC-Hannum',
    PCPhenoAge='PC-PhenoAge',
    PCGrimAge='PC-GrimAge',
)

pace = 'DunedinPACE'

# Column names in the DNAmAge calculator results -> labels used downstream.
ages_calc = dict(
    DNAmAge='Horvath',
    DNAmAgeHannum='Hannum',
    DNAmPhenoAge='PhenoAge',
    DNAmAgeSkinBloodClock='SkinBloodAge',
    DNAmGrimAgeBasedOnRealAge='GrimAge',
    DNAmGrimAge2BasedOnRealAge='GrimAge2',
)

# Every age estimate label: PC clocks first, then the calculator clocks.
ages = [*ages_pc.values(), *ages_calc.values()]

# Build dfs[gse][harm]: for every dataset and for both harmonization modes
# ('with' / 'without'), the selected phenotype columns plus the clock
# estimates taken from the DNAmAge calculator results file, indexed by
# 'geo_accession'.
dfs = {}
for gse in gses:
    dfs[gse] = {}
    for harm in ['with', 'without']:
        path_harm = f"{path_load}/harm_checking_{gse}/{harm}_harmonization"

        df_gse_pheno = pd.read_csv(f"{path_harm}/pheno.csv", index_col=0)
        df_gse_pheno['GSE'] = gse
        df_gse_pheno.rename(columns=ages_pc, inplace=True)

        # Sort the matches so the chosen file does not depend on the
        # filesystem's listing order, and fail with a clear message instead
        # of a bare IndexError when the calculator output is missing.
        fn_gse_horvath_files = sorted(glob.glob(f"{path_harm}/DNAmAgeCalcProject_*_Results.csv"))
        if not fn_gse_horvath_files:
            raise FileNotFoundError(
                f"No DNAmAgeCalcProject_*_Results.csv file found in {path_harm}"
            )
        fn_gse_horvath = fn_gse_horvath_files[0]
        df_gse_horvath = pd.read_csv(fn_gse_horvath, index_col=0)

        # Explicit copy: new columns are written into df_gse below, which
        # must not be treated as a view of df_gse_pheno
        # (avoids SettingWithCopyWarning).
        df_gse = df_gse_pheno.loc[:, cols_trgt + list(ages_pc.values()) + [pace, 'GSE']].copy()
        for age_col, age_label in ages_calc.items():
            # Rows are matched by index label between the two tables.
            df_gse.loc[df_gse.index.values, age_label] = df_gse_horvath.loc[df_gse.index.values, age_col]

        df_gse.set_index('geo_accession', inplace=True)
        dfs[gse][harm] = df_gse

# Set up age distribution plots

In [None]:
def _panel(row_id, col_id, title='', y_label='', x_label=''):
    """Return the layout spec of one subplot: grid position, title and axis labels."""
    return {
        'col_id': col_id,
        'row_id': row_id,
        'title': title,
        'y_label': y_label,
        'x_label': x_label,
    }


# Subplot layout keyed by age-estimate label.
# Row 0 holds the original clocks (titled), row 1 the PC clocks (untitled);
# only the first column of each row carries a y-axis label.
ages_kde = {
    'Hannum': _panel(0, 0, title='Hannum', y_label='Original clocks'),
    'Horvath': _panel(0, 1, title='Horvath'),
    'SkinBloodAge': _panel(0, 2, title='SkinBloodAge'),
    'PhenoAge': _panel(0, 3, title='PhenoAge'),
    'GrimAge': _panel(0, 4, title='GrimAge'),
    'GrimAge2': _panel(0, 5, title='GrimAge2'),
    'PC-Hannum': _panel(1, 0, y_label='PC clocks'),
    'PC-Horvath': _panel(1, 1),
    'PC-SkinBloodAge': _panel(1, 2),
    'PC-PhenoAge': _panel(1, 3),
    'PC-GrimAge': _panel(1, 4),
    'PC-GrimAge2': _panel(1, 5),
}

# Subplot grid shape and figure size passed to plt.subplots.
n_rows, n_cols = 2, 6
fig_width, fig_height = 15, 8

low_percent = 0.005
hgh_percent = 0.995

# Fraction of the data range added as padding on each side of the axes.
ptp_shift = 0.05

# Plot KDEs

In [None]:
# For every dataset, draw an n_rows x n_cols grid with one panel per age
# estimate: a 2-D KDE of the values with harmonization (x) against the values
# without harmonization (y), a linear fit, the identity line, and the Pearson
# correlation and MAE between the two. The figure is saved as PNG and PDF.
for gse in gses:

    df_gse_w = dfs[gse]['with']
    df_gse_wo = dfs[gse]['without']

    # The panels and metrics below pair the two tables positionally via
    # `.values`, so both must hold the same samples in the same order.
    # If they differ, restrict both to their shared samples.
    if not df_gse_w.index.equals(df_gse_wo.index):
        common_samples = df_gse_w.index.intersection(df_gse_wo.index)
        df_gse_w = df_gse_w.loc[common_samples]
        df_gse_wo = df_gse_wo.loc[common_samples]

    # Global value ranges over all age estimates, shared by every panel.
    x_min = df_gse_w[ages].min().min()
    x_max = df_gse_w[ages].max().max()
    x_ptp = x_max - x_min

    y_min = df_gse_wo[ages].min().min()
    y_max = df_gse_wo[ages].max().max()
    y_ptp = y_max - y_min

    # Axis limits padded by ptp_shift of the range; identical for every
    # panel, so computed once per dataset.
    x_lim = [x_min - ptp_shift * x_ptp, x_max + ptp_shift * x_ptp]
    y_lim = [y_min - ptp_shift * y_ptp, y_max + ptp_shift * y_ptp]
    # End points of the y = x reference line, spanning both padded ranges.
    points_unity = [min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])]

    path_save = f"{path_load}/harm_checking_{gse}"
    pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

    sns.set_theme(style='whitegrid')
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), sharey=True, gridspec_kw={'hspace': 0.075})

    for epi_est, params in ages_kde.items():
        ax = axs[params['row_id'], params['col_id']]

        if epi_est == 'PC-GrimAge2':
            # No data is plotted for this entry: its grid cell is left blank.
            ax.set_xlabel(params['x_label'])
            ax.axis('off')
            continue

        vals_w = df_gse_w[epi_est].values
        vals_wo = df_gse_wo[epi_est].values

        sns.kdeplot(
            x=vals_w,
            y=vals_wo,
            fill=True,
            cbar=False,
            color='red',
            cut=0,
            legend=False,
            ax=ax
        )
        sns.regplot(
            x=vals_w,
            y=vals_wo,
            scatter=False,
            color='black',
            truncate=True,
            ax=ax
        )
        ax.plot(points_unity, points_unity, color='black', marker=None, linestyle='--', linewidth=1.0)

        corr, _ = stats.pearsonr(vals_w, vals_wo)
        mae = mean_absolute_error(vals_w, vals_wo)

        # Two text lines near the bottom of the panel, in axes coordinates.
        annotations = (
            (r'$\rho$ = ' + f"{corr:0.2f}", 0.10),
            (f"MAE = {mae:0.2f}", 0.02),
        )
        for label, y_pos in annotations:
            ax.annotate(
                label,
                xy=(0.5, y_pos),
                size=16,
                xycoords=ax.transAxes,
                ha='center',
                color='black',
                alpha=0.75
            )

        ax.set_xlim(x_lim)
        ax.set_ylim(y_lim)
        ax.set_title(params['title'], fontsize=18)
        ax.set_ylabel(params['y_label'], fontsize=18)
        ax.set_xlabel(params['x_label'])
        if params['x_label'] == '':
            ax.set_xticklabels([])

    fig.tight_layout()
    # Save through the figure object rather than pyplot's implicit current figure.
    fig.savefig(f"{path_save}/kde_ages.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_save}/kde_ages.pdf", bbox_inches='tight')
    plt.close(fig)