# Debugging autoreload

In [None]:
# Enable the IPython autoreload extension in mode 2, so imported project
# modules are re-imported after they are edited without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
import os
from tqdm import tqdm
import glob
import pandas as pd
import numpy as np
from scipy import stats
import seaborn as sns
import plotly.express as px
import statsmodels.formula.api as smf
import plotly.graph_objects as go
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode
from matplotlib import patheffects as pe
init_notebook_mode(connected=False)
from scipy.stats import mannwhitneyu, median_test, kruskal, wilcoxon, friedmanchisquare
import matplotlib.pyplot as plt
from matplotlib import colors
import pathlib
from sklearn.metrics import mean_absolute_error
from statannotations.Annotator import Annotator
import functools
import matplotlib.lines as mlines
import patchworklib as pw
import pickle
from src.routines.plotly_layout import add_layout, color_tick
from d3blocks import D3Blocks

# Collect data

In [None]:
# Root folder of the project data and the platform (GPL) sub-folders to scan.
path_load = "D:/YandexDisk/Work/pydnameth/draft/10_MetaEPIClock/MetaEpiAge"
trgt_dirs = ['GPL13534', 'GPL16304', 'GPL21145', 'GPL23976']

# Collect the deepest directories whose path mentions 'GSE': whenever a
# sub-directory is met, its parent is withdrawn from the list again.
paths_gses = []
for trgt_dir in trgt_dirs:
    platform_root = pathlib.Path(path_load) / trgt_dir
    for path in platform_root.rglob("*"):
        if not path.is_dir():
            continue
        parent_dir = path.parent
        if parent_dir in paths_gses:
            paths_gses.remove(parent_dir)
        if 'GSE' in str(path):
            paths_gses.append(path)

# Phenotype columns taken from every dataset's pheno.csv.
cols_trgt = [
    'geo_accession', 'series_id',
    'Age', 'Sex', 'Tissue',
    'StateName', 'Group', 'Ethnicity', 'Geography',
]

# "PC"-prefixed clock columns, also read from pheno.csv.
ages_pc = [
    f"PC{clock}"
    for clock in ('Horvath1', 'Horvath2', 'Hannum', 'PhenoAge', 'GrimAge')
]

# Pace-of-aging column read from pheno.csv.
pace = 'DunedinPACE'

# Column name in the DNAmAgeCalcProject results file -> label used in the merged table.
ages_calc = dict([
    ('DNAmAge', 'Horvath'),
    ('DNAmAgeHannum', 'Hannum'),
    ('DNAmPhenoAge', 'PhenoAge'),
    ('DNAmAgeSkinBloodClock', 'SkinBloodAge'),
    ('DNAmGrimAge2BasedOnRealAge', 'GrimAge2'),
    ('DNAmGrimAgeBasedOnRealAge', 'GrimAge'),
])

# Build one table from all datasets: phenotype columns, PC clocks and pace come
# from pheno.csv, the remaining clocks from the DNAmAgeCalcProject results file.
dfs_gses = []

for path_gse in paths_gses:
    gse_name = path_gse.parts[-1]
    fn_gse_pheno = f"{str(path_gse)}/pheno.csv"
    try:
        df_gse_pheno = pd.read_csv(fn_gse_pheno, index_col=0)
    except UnicodeDecodeError:
        # Some phenotype files are not UTF-8 encoded; retry as latin-1.
        df_gse_pheno = pd.read_csv(fn_gse_pheno, index_col=0, encoding='latin-1')

    df_gse_pheno['GSE'] = gse_name

    # Datasets without the harmonized annotation columns are only reported.
    if not {'StateName', 'Group', 'Ethnicity', 'Geography'}.issubset(df_gse_pheno.columns):
        print(gse_name)
        continue

    fn_gse_horvath_files = glob.glob(f"{str(path_gse)}/DNAmAgeCalcProject_*_Results.csv")
    if len(fn_gse_horvath_files) == 0:
        # Previously such datasets were dropped silently; make the skip visible.
        print(f"{gse_name}: DNAmAgeCalcProject results file not found, dataset skipped")
        continue

    fn_gse_horvath = fn_gse_horvath_files[0]
    df_gse_horvath = pd.read_csv(fn_gse_horvath, index_col=0)

    # Explicit copy: new columns are assigned below, and writing into a bare
    # .loc slice of df_gse_pheno triggers SettingWithCopyWarning.
    df_gse = df_gse_pheno.loc[:, cols_trgt + ages_pc + [pace, 'GSE']].copy()
    for age_col, age_label in ages_calc.items():
        # Rows are matched by the shared index of the two files.
        df_gse.loc[df_gse.index.values, age_label] = df_gse_horvath.loc[df_gse.index.values, age_col]

    df_gse.set_index('geo_accession', inplace=True)
    dfs_gses.append(df_gse)

# verify_integrity=True raises if the same geo_accession appears in two datasets.
df = pd.concat(dfs_gses, verify_integrity=True)
df.to_excel(f"{path_load}/table.xlsx", index_label='geo_accession')

# Calculate age acceleration

In [None]:
# Reference dataset for the regression-based variant described below;
# the active code path does not use it.
ref_gse = 'GSE87571'

# All clock labels: calculator clocks first, then the PC clocks.
ages = [*ages_calc.values(), *ages_pc]

pbar = tqdm(ages)
for age_type in pbar:
    pbar.set_description(f"Processing {age_type}")
    # Age acceleration is the plain difference between estimate and chronological age.
    # Disabled alternative: fit OLS `{age_type} ~ Age` (smf.ols) on the rows of
    # ref_gse, predict for the whole table and take the residual
    # df[age_type] - prediction as the acceleration.
    df[f"{age_type}Acc"] = df[age_type].sub(df['Age'])

# Overwrite the merged table, now including the acceleration columns.
df.to_excel(f"{path_load}/table.xlsx", index_label='geo_accession')

# TO DELETE: Checking best GSEUNN harmonization

In [None]:
def _add_half_violin(fig, vals, y_level, label, color, side, num_bins=32):
    """Add one half of a horizontal split violin to `fig`.

    Parameters:
        fig: plotly Figure the trace is appended to.
        vals: 1-D array of values drawn along the x axis.
        y_level: numeric y position shared by both halves of one violin.
        label: trace name; also the legend and scale group of the trace.
        color: fill and marker color.
        side: 'negative' or 'positive'; the raw points are shifted to that side.
        num_bins: the bandwidth is the value range divided by this number.
    """
    fig.add_trace(
        go.Violin(
            y=[y_level] * len(vals),
            x=vals,
            name=label,
            box_visible=True,
            meanline_visible=True,
            showlegend=False,
            line_color='black',
            fillcolor=color,
            marker=dict(color=color, line=dict(color='black', width=0.35), opacity=0.8, size=8),
            points='all',
            bandwidth=np.ptp(vals) / num_bins,
            opacity=0.8,
            legendgroup=label,
            scalegroup=label,
            scalemode="width",
            side=side,
            orientation='h',
            pointpos=-1.5 if side == 'negative' else 1.5,
        )
    )


def _layout_and_save_violins(fig, x_title, fig_title, tick_labels, height, fn_stem):
    """Apply the shared violin layout to `fig` and write `<fn_stem>.png` and `<fn_stem>.pdf`.

    Parameters:
        fig: plotly Figure holding the violin traces.
        x_title: x-axis title passed to add_layout.
        fig_title: figure title passed to add_layout.
        tick_labels: y tick texts; label i is placed at y = i.
        height: figure height in pixels.
        fn_stem: output path without extension.
    """
    add_layout(fig, x_title, "", fig_title)
    fig.update_layout(
        title=dict(xref='paper', x=0.5),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.01,
            xanchor="left",
            x=0.0001,
            itemsizing='constant',
            font_size=22
        ),
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(tick_labels))),
            ticktext=[color_tick('black', x) for x in tick_labels],
            tickfont=dict(size=25)
        ),
        xaxis=dict(
            tickfont=dict(size=26),
            # `titlefont` is deprecated in plotly; `title.font` is the supported spelling.
            title=dict(font=dict(size=26))
        ),
        violingap=0.39,
        violingroupgap=0.39,
        height=height,
        width=1000,
        margin=go.layout.Margin(
            l=260,
            r=30,
            b=110,
            t=50,
            pad=0,
        )
    )
    fig.update_xaxes(showgrid=False, autorange=True)
    fig.update_yaxes(showgrid=False, autorange=False, range=[-0.5, len(tick_labels) - 0.5])
    fig.write_image(f"{fn_stem}.png", scale=2)
    fig.write_image(f"{fn_stem}.pdf", format="pdf")


# Compare Russians vs Yakuts within every dataset (GSE) of the 'Russia' state:
# per-metric descriptive statistics, Mann-Whitney U test with FDR correction,
# and split violin plots.
# NOTE(review): assumes every such dataset contains samples of both groups;
# an empty group would make the statistics and the bandwidth computation fail.
df_unn = df.loc[df['StateName'] == 'Russia', :]
gses_unn = df_unn['GSE'].unique()

colors_unn = {'Russians': 'gold', 'Yakuts': 'lightslategray'}
# Side of the split violin occupied by each group.
sides_unn = {'Russians': 'negative', 'Yakuts': 'positive'}
dist_num_bins = 32

for harm_type in gses_unn:
    path_save = f"{path_load}/figures/unn_harm_check/{harm_type}"
    pathlib.Path(f"{path_save}").mkdir(parents=True, exist_ok=True)
    df_tmp = df_unn.loc[df_unn['GSE'] == harm_type, :]

    # One row per metric: all age accelerations plus DunedinPACE.
    df_stat = pd.DataFrame(index=[f"{x}Acc" for x in ages] + [pace])
    for metric in (pbar := tqdm(df_stat.index.values)):
        pbar.set_description(f"Processing {metric}")

        vals = {}
        for group in ['Russians', 'Yakuts']:
            vals[group] = df_tmp.loc[df_tmp['Group'] == group, metric].values
            q75, q25 = np.percentile(vals[group], [75, 25])
            df_stat.at[metric, f"mean_{group}"] = np.mean(vals[group])
            df_stat.at[metric, f"median_{group}"] = np.median(vals[group])
            df_stat.at[metric, f"q75_{group}"] = q75
            df_stat.at[metric, f"q25_{group}"] = q25
            df_stat.at[metric, f"iqr_{group}"] = q75 - q25

        _, pval = mannwhitneyu(vals['Russians'], vals['Yakuts'], alternative='two-sided')
        df_stat.at[metric, "pval"] = pval

    # Benjamini-Hochberg correction over all metrics of this dataset.
    _, df_stat["pval_fdr_bh"], _, _ = multipletests(df_stat["pval"], 0.05, method='fdr_bh')
    df_stat.to_excel(f"{path_save}/stat.xlsx", index=True)

    # Age accelerations: one split violin per clock, clocks in reversed order.
    fig = go.Figure()
    age_order = ages[::-1]
    age_labels = []
    # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the bar knows its total.
    for age_id, age_type in enumerate(tqdm(age_order)):
        pval = df_stat.at[f'{age_type}Acc', 'pval_fdr_bh']
        age_label = f"{age_type}<br>p-value: {pval:0.2e}"
        age_labels.append(age_label)
        for group, side in sides_unn.items():
            group_vals = df_tmp.loc[df_tmp['Group'] == group, f"{age_type}Acc"].values
            _add_half_violin(fig, group_vals, age_id, age_label, colors_unn[group], side, dist_num_bins)
    _layout_and_save_violins(
        fig, "Age acceleration", f"{harm_type}", age_labels,
        140 * len(ages), f"{path_save}/ages_violins"
    )

    # DunedinPACE: a single split violin.
    fig = go.Figure()
    pval = df_stat.at['DunedinPACE', 'pval_fdr_bh']
    label = f"DunedinPACE<br>p-value: {pval:0.2e}"
    for group, side in sides_unn.items():
        group_vals = df_tmp.loc[df_tmp['Group'] == group, "DunedinPACE"].values
        _add_half_violin(fig, group_vals, 0, label, colors_unn[group], side, dist_num_bins)
    _layout_and_save_violins(
        fig, "DunedinPACE", f"{harm_type}", [label],
        300, f"{path_save}/DunedinPACE_violins"
    )

# Statistics by state, group and dataset

In [None]:
# Remove the alternative GSEUNN harmonization variants from the merged table (in place).
gses_unn_copies = [
    'GSEUNN_harm_var(Age)_batch(Slide_Array)',
    'GSEUNN_harm_var(Region)_batch(Slide_Array)',
]
is_unn_copy = df['GSE'].isin(gses_unn_copies)
df.drop(index=df.index[is_unn_copy.values], inplace=True)

In [None]:
# Number of samples per state (value_counts order: most frequent first).
df_states = df['StateName'].value_counts().to_frame()

# For every state: sample counts per group and per dataset.
dfs_states_group = {
    state: df.loc[df['StateName'] == state, 'Group'].value_counts().to_frame()
    for state in df_states.index.values
}
dfs_states_gse = {
    state: df.loc[df['StateName'] == state, 'GSE'].value_counts().to_frame()
    for state in df_states.index.values
}

In [None]:
# D3Blocks demo: load the library's 'energy' example dataset and render it as a
# Sankey diagram. The study table `df` is not involved here.
d3 = D3Blocks()
# NOTE: rebinds the notebook-level name df_tmp to the example data.
df_tmp = d3.import_example('energy')
d3.sankey(df_tmp)
d3.show()

# Plot global figures

In [None]:
path_save = f"{path_load}/figures/epi_est_stat"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

states = df_states.index.values

# One palette color per state, assigned in the order of `states`.
palette_states = px.colors.qualitative.Light24
colors_states = {}
for state_id, state in enumerate(states):
    colors_states[state] = palette_states[state_id]

# Per-state means: age acceleration of every clock, and DunedinPACE.
df_states_aerr_mean = pd.DataFrame(0.0, index=states, columns=ages)
df_states_pace_mean = pd.DataFrame(0.0, index=states, columns=[pace])
for state in states:
    df_state_rows = df.loc[df['StateName'] == state, :]
    df_states_pace_mean.at[state, pace] = np.mean(df_state_rows[pace].values)
    for age_type in ages:
        df_states_aerr_mean.at[state, age_type] = np.mean(df_state_rows[f"{age_type}Acc"].values)

df_states_aerr_mean.to_excel(f"{path_save}/states_aerr_mean.xlsx", index_label="StateName")
df_states_pace_mean.to_excel(f"{path_save}/states_pace_mean.xlsx", index_label="StateName")

# Heatmap of mean age acceleration: states in rows, clocks in columns.
fig_width = 2.4 + 0.3 * len(ages)
fig_height = 1.8 + 0.15 * len(states)
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
sns.set_theme(style='whitegrid')
annot_size = 35 / np.sqrt(max(df_states_aerr_mean.shape))
heatmap = sns.heatmap(
    df_states_aerr_mean,
    annot=True,
    fmt=".1f",
    center=0.0,
    cmap='seismic',
    linewidth=0.1,
    linecolor='black',
    annot_kws={"size": annot_size},
    ax=ax
)
ax.set_xlabel('Epigenetic Age', fontsize=16)
ax.set_ylabel('Countries', fontsize=16)

# The last axes of the figure is the one labelled as the colorbar here.
cbar_ax = ax.figure.axes[-1]
cbar_ax.set_ylabel('Acceleration', size=16)
for spine in cbar_ax.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")

# Thin black outline on the state names, then color each name with its state color.
label_outline = [pe.withStroke(linewidth=0.2, foreground="black")]
ax.set_yticklabels(ax.get_yticklabels(), path_effects=label_outline)
for tick_label in ax.get_yticklabels():
    tick_label.set_color(colors_states[tick_label.get_text()])

fn_stem = f"{path_save}/heatmap_states_aerr_mean"
plt.savefig(f"{fn_stem}.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{fn_stem}.pdf", bbox_inches='tight')
plt.close(fig)

# Clustered heatmap of mean age acceleration (rows and columns both clustered).
sns.set_theme(style='whitegrid')
annot_size = 55 / np.sqrt(max(df_states_aerr_mean.shape))
fig_width = (0.3 + 0.05 * len(ages)) * 10
fig_height = (0.35 + 0.035 * len(states)) * 10
clustermap = sns.clustermap(
    df_states_aerr_mean,
    annot=True,
    col_cluster=True,
    row_cluster=True,
    fmt=".1f",
    center=0.0,
    cmap='seismic',
    linewidth=0.1,
    linecolor='black',
    tree_kws=dict(linewidths=1.5),
    annot_kws={"size": annot_size},
    figsize=(fig_width, fig_height)
)
ax_hm = clustermap.ax_heatmap
ax_cb = clustermap.ax_cbar

# Axis titles and tick label sizes of the heatmap panel.
ax_hm.set_xlabel('Epigenetic Age', fontsize=20)
ax_hm.set_xticklabels(ax_hm.get_xmajorticklabels(), fontsize=18)
ax_hm.set_ylabel('Countries', fontsize=20)
ax_hm.set_yticklabels(ax_hm.get_ymajorticklabels(), fontsize=18)

# Colorbar title, tick size and a visible frame.
ax_cb.set_ylabel('Acceleration', size=20)
ax_cb.tick_params(labelsize=18)
for spine in ax_cb.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")

# Thin black outline on the state names, then color each name with its state color.
label_outline = [pe.withStroke(linewidth=0.2, foreground="black")]
ax_hm.set_yticklabels(ax_hm.get_ymajorticklabels(), path_effects=label_outline)
for tick_label in ax_hm.get_yticklabels():
    tick_label.set_color(colors_states[tick_label.get_text()])

fn_stem = f"{path_save}/clustermap_states_aerr_mean"
plt.savefig(f"{fn_stem}.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{fn_stem}.pdf", bbox_inches='tight')
plt.close(clustermap.fig)

# Single-column heatmap of mean DunedinPACE per state, centered at 1.0.
fig_height = 1.8 + 0.15 * len(states)
fig, ax = plt.subplots(figsize=(1.5, fig_height))
sns.set_theme(style='whitegrid')
# Annotation size is derived from the acceleration table's shape,
# as in the acceleration heatmap.
annot_size = 35 / np.sqrt(max(df_states_aerr_mean.shape))
heatmap = sns.heatmap(
    df_states_pace_mean,
    annot=True,
    fmt=".3f",
    center=1.0,
    cmap='PiYG_r',
    linewidth=0.1,
    linecolor='black',
    annot_kws={"size": annot_size},
    ax=ax
)
# The only column is named by the x-axis title, so its tick label is blanked.
ax.set_xticklabels([''])
ax.set_xlabel('DunedinPACE', fontsize=16)
ax.set_ylabel('Countries', fontsize=16)

# The last axes of the figure is the one labelled as the colorbar here.
cbar_ax = ax.figure.axes[-1]
cbar_ax.set_ylabel('Pace of Aging', size=16)
for spine in cbar_ax.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")

# Thin black outline on the state names, then color each name with its state color.
label_outline = [pe.withStroke(linewidth=0.2, foreground="black")]
ax.set_yticklabels(ax.get_yticklabels(), path_effects=label_outline)
for tick_label in ax.get_yticklabels():
    tick_label.set_color(colors_states[tick_label.get_text()])

fn_stem = f"{path_save}/heatmap_states_pace_mean"
plt.savefig(f"{fn_stem}.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{fn_stem}.pdf", bbox_inches='tight')
plt.close(fig)

# Row-clustered heatmap of mean DunedinPACE per state with a horizontal colorbar.
sns.set_theme(style='whitegrid')
# Annotation size is derived from the acceleration table's shape,
# as in the acceleration clustermap.
annot_size = 55 / np.sqrt(max(df_states_aerr_mean.shape))
fig_height = (0.25 + 0.03 * len(states)) * 10
clustermap = sns.clustermap(
    df_states_pace_mean,
    annot=True,
    col_cluster=False,
    row_cluster=True,
    fmt=".3f",
    center=1.0,
    cmap='PiYG_r',
    linewidth=0.1,
    linecolor='black',
    tree_kws=dict(linewidths=1.5),
    dendrogram_ratio=(0.6, 0.0),
    cbar_pos=(0.15, 1.06, 0.9, 0.04),
    cbar_kws={"orientation": "horizontal"},
    annot_kws={"size": annot_size},
    figsize=(4, fig_height)
)
ax_hm = clustermap.ax_heatmap
ax_cb = clustermap.ax_cbar

# Axis titles and tick labels of the heatmap panel (x tick labels are blanked).
ax_hm.set_xlabel('DunedinPACE', fontsize=20)
ax_hm.set_xticklabels("", fontsize=18)
ax_hm.set_ylabel('Countries', fontsize=20)
ax_hm.set_yticklabels(ax_hm.get_ymajorticklabels(), fontsize=18)

# Colorbar title, tick size and a visible frame.
ax_cb.set_title('Pace of Aging', size=20)
ax_cb.tick_params(labelsize=18)
for spine in ax_cb.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")

# Thin black outline on the state names, then color each name with its state color.
label_outline = [pe.withStroke(linewidth=0.2, foreground="black")]
ax_hm.set_yticklabels(ax_hm.get_ymajorticklabels(), path_effects=label_outline)
for tick_label in ax_hm.get_yticklabels():
    tick_label.set_color(colors_states[tick_label.get_text()])

fn_stem = f"{path_save}/clustermap_states_pace_mean"
plt.savefig(f"{fn_stem}.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{fn_stem}.pdf", bbox_inches='tight')
plt.close(clustermap.fig)

# Pairwise comparison of states for every clock acceleration and DunedinPACE.
# Each result matrix stores, for a pair of states:
#   upper triangle - difference of means (column state minus row state),
#   lower triangle - Mann-Whitney U p-value, later replaced by -log10 of its
#                    Benjamini-Hochberg corrected value,
#   diagonal       - NaN.
path_save = f"{path_load}/figures/epi_est_stat/epi_ests"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

n_states = len(states)
# Strictly-lower-triangle selector. The builtin `bool` is used because the
# `np.bool` alias was removed in NumPy 1.24.
mask_tril = np.tri(n_states, n_states, -1, dtype=bool)
label_outline = [pe.withStroke(linewidth=0.2, foreground="black")]
annot_size = 25 / np.sqrt(max(df_states_aerr_mean.shape))

dfs_states_states_stat = {}
for epi_est in ages + [pace]:
    # DunedinPACE is compared directly, clocks via their acceleration column.
    col = epi_est if epi_est == pace else f"{epi_est}Acc"

    df_pairs = pd.DataFrame(index=states, columns=states, data=np.zeros(shape=(n_states, n_states)))
    dfs_states_states_stat[epi_est] = df_pairs
    for state_1_id, state_1 in enumerate(states):
        vals_1 = df.loc[df['StateName'] == state_1, col].values
        df_pairs.at[state_1, state_1] = np.nan
        for state_2 in states[state_1_id + 1:]:
            vals_2 = df.loc[df['StateName'] == state_2, col].values
            _, pval = mannwhitneyu(vals_1, vals_2, alternative='two-sided')
            df_pairs.at[state_1, state_2] = np.mean(vals_2) - np.mean(vals_1)
            df_pairs.at[state_2, state_1] = pval

    # FDR-correct the p-values stored in the lower triangle.
    df_fdr = df_pairs.where(mask_tril).stack().reset_index()
    df_fdr.columns = ['row', 'col', 'pval']
    _, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
    # Exact zeros cannot be log-transformed: replace them with half of the
    # smallest positive corrected p-value.
    nzmin = df_fdr['pval_fdr_bh'][df_fdr['pval_fdr_bh'].gt(0)].min(0) * 0.5
    # Assign the result back: an in-place replace on a selected column is not
    # guaranteed to modify df_fdr.
    df_fdr['pval_fdr_bh'] = df_fdr['pval_fdr_bh'].replace({0.0: nzmin})
    df_fdr['pval_fdr_bh_log'] = -np.log10(df_fdr.loc[:, 'pval_fdr_bh'].values)
    for state_row, state_col, pval_log in zip(df_fdr['row'], df_fdr['col'], df_fdr['pval_fdr_bh_log']):
        df_pairs.at[state_row, state_col] = pval_log
    df_pairs.to_excel(f"{path_save}/{epi_est}_states_stat.xlsx", index_label="StateName")

    fig, ax = plt.subplots(figsize=(4.2 + 0.23 * n_states, 0.8 + 0.2 * n_states))
    sns.set_theme(style='whitegrid')

    # Upper triangle and diagonal: mean differences (lower triangle masked out).
    cmap_triu = plt.get_cmap("seismic").copy()
    sns.heatmap(
        df_pairs,
        mask=mask_tril,
        annot=True,
        fmt=".2f",
        center=0.0,
        cmap=cmap_triu,
        linewidth=0.1,
        linecolor='black',
        annot_kws={"size": annot_size},
        ax=ax
    )
    # The axes appended last is labelled as the colorbar of the heatmap just drawn.
    cbar_diff = ax.figure.axes[-1]
    if epi_est == pace:
        cbar_diff.set_ylabel('DunedinPACE Difference', size=13)
    else:
        cbar_diff.set_ylabel('Age Acceleration Difference', size=13)
    for spine in cbar_diff.spines.values():
        spine.set(visible=True, lw=0.25, edgecolor="black")

    # Lower triangle: -log10 of corrected p-values; values below the 0.05
    # threshold (vmin) fall under the color range and are drawn black.
    cmap_tril = plt.get_cmap("cool").copy()
    cmap_tril.set_under('black')
    sns.heatmap(
        df_pairs,
        mask=mask_tril.T,
        annot=True,
        fmt=".1f",
        vmin=-np.log10(0.05),
        cmap=cmap_tril,
        linewidth=0.1,
        linecolor='black',
        annot_kws={"size": annot_size},
        ax=ax
    )
    cbar_pval = ax.figure.axes[-1]
    cbar_pval.set_ylabel(r"$-\log_{10}(\mathrm{p-value})$", size=13)
    for spine in cbar_pval.spines.values():
        spine.set(visible=True, lw=0.25, edgecolor="black")

    ax.set_xlabel('', fontsize=16)
    ax.set_ylabel('', fontsize=16)
    # Outline the tick labels once per axis (previously repeated inside the
    # per-label loops), then color every state name with its state color.
    ax.set_xticklabels(ax.get_xticklabels(), path_effects=label_outline)
    for tick_label in ax.get_xticklabels():
        tick_label.set_color(colors_states[tick_label.get_text()])
    ax.set_yticklabels(ax.get_yticklabels(), path_effects=label_outline)
    for tick_label in ax.get_yticklabels():
        tick_label.set_color(colors_states[tick_label.get_text()])

    plt.savefig(f"{path_save}/{epi_est}_states_stat.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_save}/{epi_est}_states_stat.pdf", bbox_inches='tight')
    plt.close(fig)

# Plot figures for each state

In [None]:
for state in df_states.index.values:
    
    df_states_group = dfs_states_group[state]
    groups = df_states_group.index.values
    
    if len(groups) > 1:
        
        df_state = df.loc[df['StateName'] == state, :]
        path_save = f"{path_load}/figures/epi_est_stat/group/{state}"
        pathlib.Path(f"{path_save}").mkdir(parents=True, exist_ok=True)
    
        colors_groups = {group: px.colors.qualitative.Dark24[group_id] for group_id, group in enumerate(groups)}
        df_groups_aerr_mean = pd.DataFrame(index=groups, columns=ages, data=np.zeros(shape=(len(groups), len(ages))))
        df_groups_pace_mean = pd.DataFrame(index=groups, columns=[pace], data=np.zeros(shape=(len(groups), 1)))
        for group in groups:
            vals = df_state.loc[df_state['Group'] == group, pace].values
            df_groups_pace_mean.at[group, pace] = np.mean(vals)
            for age_type in ages:
                vals = df_state.loc[df_state['Group'] == group, f"{age_type}Acc"].values
                df_groups_aerr_mean.at[group, age_type] = np.mean(vals)
        df_groups_aerr_mean.to_excel(f"{path_save}/groups_aerr_mean.xlsx", index_label="Group")
        df_groups_pace_mean.to_excel(f"{path_save}/groups_pace_mean.xlsx", index_label="Group")
        
        fig, ax = plt.subplots(figsize=(2.7 + 0.375 * len(ages), 1.8 + 0.15 * len(groups)))
        sns.set_theme(style='whitegrid')
        heatmap = sns.heatmap(
            df_groups_aerr_mean,
            annot=True,
            fmt=".1f",
            center=0.0,
            cmap='seismic',
            linewidth=0.1,
            linecolor='black',
            annot_kws={"size": 35 / np.sqrt(max(df_groups_aerr_mean.shape))},
            ax=ax,
        )
        ax.set_xlabel('Epigenetic Age', fontsize=16)
        ax.set_ylabel('Groups', fontsize=16)
        ax.figure.axes[-1].set_ylabel('Acceleration', size=16)
        for spine in ax.figure.axes[-1].spines.values():
            spine.set(visible=True, lw=0.25, edgecolor="black")
        ax.set_yticklabels(ax.get_yticklabels(), path_effects=[pe.withStroke(linewidth=0.2, foreground="black")])
        for tick_label in ax.get_yticklabels():
            tick_label.set_color(colors_groups[tick_label.get_text()])
        plt.savefig(f"{path_save}/heatmap_groups_aerr_mean.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_save}/heatmap_groups_aerr_mean.pdf", bbox_inches='tight')
        plt.close(fig)
        
        if len(groups) > 2:
            figsize_shift_x = 0.35
            row_cluster = True
        else:
            figsize_shift_x = 0.125
            row_cluster = False
        sns.set_theme(style='whitegrid')
        clustermap = sns.clustermap(
            df_groups_aerr_mean,
            annot=True,
            col_cluster=True,
            row_cluster=row_cluster,
            fmt=".1f",
            center=0.0,
            cmap='seismic',
            linewidth=0.1,
            linecolor='black',
            tree_kws=dict(linewidths=1.5),
            annot_kws={"size": 55 / np.sqrt(max(df_groups_aerr_mean.shape))},
            figsize=((figsize_shift_x + 0.065 * len(ages)) * 10, (0.45 + 0.035 * len(groups)) * 10)
        )
        clustermap.ax_heatmap.set_xlabel('Epigenetic Age', fontsize=20)
        clustermap.ax_heatmap.set_xticklabels(clustermap.ax_heatmap.get_xmajorticklabels(), fontsize=18)
        clustermap.ax_heatmap.set_ylabel('Groups', fontsize=20)
        clustermap.ax_heatmap.set_yticklabels(clustermap.ax_heatmap.get_ymajorticklabels(), fontsize=18, rotation=0)
        clustermap.ax_cbar.set_ylabel('Acceleration', size=20)
        clustermap.ax_cbar.tick_params(labelsize=18)
        for spine in clustermap.ax_cbar.spines.values():
            spine.set(visible=True, lw=0.25, edgecolor="black")
        clustermap.ax_heatmap.set_yticklabels(clustermap.ax_heatmap.get_ymajorticklabels(), path_effects=[pe.withStroke(linewidth=0.2, foreground="black")])
        for tick_label in clustermap.ax_heatmap.get_yticklabels():
            tick_label.set_color(colors_groups[tick_label.get_text()])
        plt.savefig(f"{path_save}/clustermap_groups_aerr_mean.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_save}/clustermap_groups_aerr_mean.pdf", bbox_inches='tight')
        plt.close(clustermap.fig)
        
        fig, ax = plt.subplots(figsize=(1.5, 1.8 + 0.15 * len(groups)))
        sns.set_theme(style='whitegrid')
        heatmap = sns.heatmap(
            df_groups_pace_mean,
            annot=True,
            fmt=".3f",
            center=1.0,
            cmap='PiYG_r',
            linewidth=0.1,
            linecolor='black',
            annot_kws={"size": 35 / np.sqrt(max(df_groups_aerr_mean.shape))},
            ax=ax
        )
        ax.set_xticklabels([''])
        ax.set_xlabel('DunedinPACE', fontsize=16)
        ax.set_ylabel('Groups', fontsize=16)
        ax.figure.axes[-1].set_ylabel('Pace of Aging', size=16)
        for spine in ax.figure.axes[-1].spines.values():
            spine.set(visible=True, lw=0.25, edgecolor="black")
        ax.set_yticklabels(ax.get_yticklabels(), path_effects=[pe.withStroke(linewidth=0.2, foreground="black")])
        for tick_label in ax.get_yticklabels():
            tick_label.set_color(colors_groups[tick_label.get_text()])
        plt.savefig(f"{path_save}/heatmap_groups_pace_mean.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_save}/heatmap_groups_pace_mean.pdf", bbox_inches='tight')
        plt.close(fig)
        
        if len(groups) > 2:
            figsize_shift_y = 0.35
            row_cluster = True
        else:
            figsize_shift_y = 0.25
            row_cluster = False
        sns.set_theme(style='whitegrid')
        clustermap = sns.clustermap(
            df_groups_pace_mean,
            annot=True,
            col_cluster=False,
            row_cluster=row_cluster,
            fmt=".3f",
            center=1.0,
            cmap='PiYG_r',
            linewidth=0.1,
            linecolor='black',
            tree_kws=dict(linewidths=1.5),
            dendrogram_ratio=(0.6, 0.0),
            cbar_pos=(0.15, 1.06, 0.9, 0.04),
            cbar_kws={"orientation": "horizontal"},
            annot_kws={"size": 55 / np.sqrt(max(df_groups_aerr_mean.shape))},
            figsize=(4, (figsize_shift_y + 0.03 * len(groups)) * 10)
        )
        clustermap.ax_heatmap.set_xlabel('DunedinPACE', fontsize=20)
        clustermap.ax_heatmap.set_xticklabels("", fontsize = 18)
        clustermap.ax_heatmap.set_ylabel('Groups', fontsize=20)
        clustermap.ax_heatmap.set_yticklabels(clustermap.ax_heatmap.get_ymajorticklabels(), fontsize=18, rotation=0)
        clustermap.ax_cbar.set_title('Pace of Aging', size=20)
        clustermap.ax_cbar.tick_params(labelsize=18)
        for spine in clustermap.ax_cbar.spines.values():
            spine.set(visible=True, lw=0.25, edgecolor="black")
        clustermap.ax_heatmap.set_yticklabels(clustermap.ax_heatmap.get_ymajorticklabels(), path_effects=[pe.withStroke(linewidth=0.2, foreground="black")])
        for tick_label in clustermap.ax_heatmap.get_yticklabels():
            tick_label.set_color(colors_groups[tick_label.get_text()])
        plt.savefig(f"{path_save}/clustermap_groups_pace_mean.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_save}/clustermap_groups_pace_mean.pdf", bbox_inches='tight')
        plt.close(clustermap.fig)
        
    df_states_gse = dfs_states_gse[state]
    gses = df_states_gse.index.values
    
    if len(gses) > 1:
        
        df_state = df.loc[df['StateName'] == state, :]
        path_save = f"{path_load}/figures/epi_est_stat/gse/{state}"
        pathlib.Path(f"{path_save}").mkdir(parents=True, exist_ok=True)
    
        colors_gses = {gse: px.colors.qualitative.Dark24_r[gse_id] for gse_id, gse in enumerate(gses)}
        df_gses_aerr_mean = pd.DataFrame(index=gses, columns=ages, data=np.zeros(shape=(len(gses), len(ages))))
        df_gses_pace_mean = pd.DataFrame(index=gses, columns=[pace], data=np.zeros(shape=(len(gses), 1)))
        for gse in gses:
            vals = df_state.loc[df_state['GSE'] == gse, pace].values
            df_gses_pace_mean.at[gse, pace] = np.mean(vals)
            for age_type in ages:
                vals = df_state.loc[df_state['GSE'] == gse, f"{age_type}Acc"].values
                df_gses_aerr_mean.at[gse, age_type] = np.mean(vals)
        df_gses_aerr_mean.to_excel(f"{path_save}/gses_aerr_mean.xlsx", index_label="GSE")
        df_gses_pace_mean.to_excel(f"{path_save}/gses_pace_mean.xlsx", index_label="GSE")
        
        fig, ax = plt.subplots(figsize=(2.7 + 0.375 * len(ages), 1.8 + 0.15 * len(gses)))
        sns.set_theme(style='whitegrid')
        heatmap = sns.heatmap(
            df_gses_aerr_mean,
            annot=True,
            fmt=".1f",
            center=0.0,
            cmap='seismic',
            linewidth=0.1,
            linecolor='black',
            annot_kws={"size": 35 / np.sqrt(max(df_gses_aerr_mean.shape))},
            ax=ax,
        )
        ax.set_xlabel('Epigenetic Age', fontsize=16)
        ax.set_ylabel('GSE', fontsize=16)
        ax.figure.axes[-1].set_ylabel('Acceleration', size=16)
        for spine in ax.figure.axes[-1].spines.values():
            spine.set(visible=True, lw=0.25, edgecolor="black")
        ax.set_yticklabels(ax.get_yticklabels(), path_effects=[pe.withStroke(linewidth=0.2, foreground="black")])
        for tick_label in ax.get_yticklabels():
            tick_label.set_color(colors_gses[tick_label.get_text()])
        plt.savefig(f"{path_save}/heatmap_gses_aerr_mean.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_save}/heatmap_gses_aerr_mean.pdf", bbox_inches='tight')
        plt.close(fig)
        
        # Clustered version of the df_gses_aerr_mean heatmap.  Rows are clustered
        # only when there are more than two GSEs; the horizontal figure-size
        # offset depends on the same condition.
        row_cluster = len(gses) > 2
        figsize_shift_x = 0.35 if row_cluster else 0.125

        sns.set_theme(style='whitegrid')
        clustermap = sns.clustermap(
            df_gses_aerr_mean,
            annot=True,
            col_cluster=True,
            row_cluster=row_cluster,
            fmt=".1f",
            center=0.0,
            cmap='seismic',
            linewidth=0.1,
            linecolor='black',
            tree_kws=dict(linewidths=1.5),
            # Annotation font shrinks with the larger dimension of the table.
            annot_kws={"size": 55 / np.sqrt(max(df_gses_aerr_mean.shape))},
            figsize=(
                (figsize_shift_x + 0.065 * len(ages)) * 10,
                (0.45 + 0.035 * len(gses)) * 10,
            ),
        )
        heatmap_ax = clustermap.ax_heatmap
        cbar_ax = clustermap.ax_cbar

        # Axis titles and x tick labels of the heatmap panel.
        heatmap_ax.set_xlabel('Epigenetic Age', fontsize=20)
        heatmap_ax.set_ylabel('GSE', fontsize=20)
        heatmap_ax.set_xticklabels(heatmap_ax.get_xmajorticklabels(), fontsize=18)

        # Colorbar: title, tick size and a thin visible frame.
        cbar_ax.set_ylabel('Acceleration', size=20)
        cbar_ax.tick_params(labelsize=18)
        for cbar_spine in cbar_ax.spines.values():
            cbar_spine.set(visible=True, lw=0.25, edgecolor="black")

        # GSE labels: horizontal, outlined with a thin black stroke, and coloured
        # with the colour assigned to each GSE.
        heatmap_ax.set_yticklabels(
            heatmap_ax.get_ymajorticklabels(),
            fontsize=18,
            rotation=0,
            path_effects=[pe.withStroke(linewidth=0.2, foreground="black")],
        )
        for tick_label in heatmap_ax.get_yticklabels():
            gse_name = tick_label.get_text()
            tick_label.set_color(colors_gses[gse_name])

        plt.savefig(f"{path_save}/clustermap_gses_aerr_mean.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_save}/clustermap_gses_aerr_mean.pdf", bbox_inches='tight')
        plt.close(clustermap.fig)
        
        # Single-column annotated heatmap of the per-GSE mean values stored in
        # df_gses_pace_mean.
        # The theme is set BEFORE the figure is created: seaborn themes work by
        # updating matplotlib rcParams, which only affect axes created afterwards.
        sns.set_theme(style='whitegrid')
        # Fixed width (one column); height grows with the number of GSEs.
        fig, ax = plt.subplots(figsize=(1.5, 1.8 + 0.15 * len(gses)))
        heatmap = sns.heatmap(
            df_gses_pace_mean,
            annot=True,
            fmt=".3f",
            center=1.0,
            cmap='PiYG_r',
            linewidth=0.1,
            linecolor='black',
            # NOTE(review): the size is derived from df_gses_aerr_mean, not from
            # df_gses_pace_mean -- presumably to match the font size of the
            # acceleration heatmap; confirm this is intended.
            annot_kws={"size": 35 / np.sqrt(max(df_gses_aerr_mean.shape))},
            ax=ax
        )
        # Blank out the single column tick label; the x-axis title names the measure.
        ax.set_xticklabels([''])
        ax.set_xlabel('DunedinPACE', fontsize=16)
        ax.set_ylabel('GSE', fontsize=16)

        # The last axes of the figure is taken to be the colorbar drawn by
        # sns.heatmap; label it and give it a thin visible frame.
        cbar_ax = ax.figure.axes[-1]
        cbar_ax.set_ylabel('Pace of Aging', size=16)
        for spine in cbar_ax.spines.values():
            spine.set(visible=True, lw=0.25, edgecolor="black")

        # Outline the GSE labels with a thin black stroke, then colour each label
        # with the colour assigned to its GSE.
        ax.set_yticklabels(ax.get_yticklabels(), path_effects=[pe.withStroke(linewidth=0.2, foreground="black")])
        for tick_label in ax.get_yticklabels():
            tick_label.set_color(colors_gses[tick_label.get_text()])

        plt.savefig(f"{path_save}/heatmap_gses_pace_mean.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_save}/heatmap_gses_pace_mean.pdf", bbox_inches='tight')
        plt.close(fig)
        
        # Clustered version of the df_gses_pace_mean heatmap.  Columns are never
        # clustered; rows are clustered only when there are more than two GSEs,
        # and the vertical figure-size offset depends on the same condition.
        row_cluster = len(gses) > 2
        figsize_shift_y = 0.35 if row_cluster else 0.25

        sns.set_theme(style='whitegrid')
        clustermap = sns.clustermap(
            df_gses_pace_mean,
            annot=True,
            col_cluster=False,
            row_cluster=row_cluster,
            fmt=".3f",
            center=1.0,
            cmap='PiYG_r',
            linewidth=0.1,
            linecolor='black',
            tree_kws=dict(linewidths=1.5),
            dendrogram_ratio=(0.6, 0.0),
            # Horizontal colorbar positioned via explicit figure coordinates.
            cbar_pos=(0.15, 1.06, 0.9, 0.04),
            cbar_kws={"orientation": "horizontal"},
            # Size is derived from df_gses_aerr_mean's shape, as in the original.
            annot_kws={"size": 55 / np.sqrt(max(df_gses_aerr_mean.shape))},
            figsize=(4, (figsize_shift_y + 0.03 * len(gses)) * 10),
        )
        heatmap_ax = clustermap.ax_heatmap
        cbar_ax = clustermap.ax_cbar

        # Axis titles; the x tick labels are blanked out.
        heatmap_ax.set_xlabel('DunedinPACE', fontsize=20)
        heatmap_ax.set_ylabel('GSE', fontsize=20)
        heatmap_ax.set_xticklabels("", fontsize = 18)

        # Colorbar: title, tick size and a thin visible frame.
        cbar_ax.set_title('Pace of Aging', size=20)
        cbar_ax.tick_params(labelsize=18)
        for cbar_spine in cbar_ax.spines.values():
            cbar_spine.set(visible=True, lw=0.25, edgecolor="black")

        # GSE labels: horizontal, outlined with a thin black stroke, and coloured
        # with the colour assigned to each GSE.
        heatmap_ax.set_yticklabels(
            heatmap_ax.get_ymajorticklabels(),
            fontsize=18,
            rotation=0,
            path_effects=[pe.withStroke(linewidth=0.2, foreground="black")],
        )
        for tick_label in heatmap_ax.get_yticklabels():
            gse_name = tick_label.get_text()
            tick_label.set_color(colors_gses[gse_name])

        plt.savefig(f"{path_save}/clustermap_gses_pace_mean.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_save}/clustermap_gses_pace_mean.pdf", bbox_inches='tight')
        plt.close(clustermap.fig)
