# Debugging autoreload

In [ ]:
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
import pandas as pd
import numpy as np
from scipy import stats
import seaborn as sns
import plotly.express as px
import statsmodels.formula.api as smf
import plotly.graph_objects as go
from scripts.python.routines.manifest import get_manifest
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode
init_notebook_mode(connected=False)
from scipy.stats import mannwhitneyu, median_test, kruskal, wilcoxon, friedmanchisquare
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patheffects as path_effects
import random
import pathlib
from tqdm import tqdm
from src.utils.plot.bioinfokit import mhat, volcano
import gseapy as gp
import mygene
from sklearn.decomposition import PCA, IncrementalPCA, KernelPCA, TruncatedSVD
from sklearn.decomposition import MiniBatchDictionaryLearning, FastICA
from sklearn.random_projection import GaussianRandomProjection, SparseRandomProjection
from sklearn.manifold import MDS, Isomap, TSNE, LocallyLinearEmbedding
import upsetplot
from matplotlib_venn import venn2, venn2_circles
from itertools import chain
from sklearn.metrics import mean_absolute_error
from scripts.python.routines.plot.colorscales import get_continuous_color
from impyute.imputation.cs import fast_knn
import plotly
from scripts.python.routines.plot.p_value import add_p_value_annotation
from scripts.python.routines.sections import get_sections
from statannotations.Annotator import Annotator
import functools
import matplotlib.lines as mlines
import patchworklib as pw


def conjunction(conditions):
    """Combine several boolean masks with an element-wise logical AND.

    Parameters
    ----------
    conditions : iterable
        Masks (arrays, Series, scalars) accepted by ``np.logical_and``.

    Returns
    -------
    The masks folded left to right with ``np.logical_and``. A single mask
    is returned unchanged.

    Raises
    ------
    TypeError
        If ``conditions`` is empty.
    """
    masks = iter(conditions)
    try:
        combined = next(masks)
    except StopIteration:
        # Same failure mode as a reduce() call without an initial value.
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for mask in masks:
        combined = np.logical_and(combined, mask)
    return combined


def disjunction(conditions):
    """Combine several boolean masks with an element-wise logical OR.

    Parameters
    ----------
    conditions : iterable
        Masks (arrays, Series, scalars) accepted by ``np.logical_or``.

    Returns
    -------
    The masks folded left to right with ``np.logical_or``. A single mask
    is returned unchanged.

    Raises
    ------
    TypeError
        If ``conditions`` is empty.
    """
    masks = iter(conditions)
    try:
        combined = next(masks)
    except StopIteration:
        # Same failure mode as a reduce() call without an initial value.
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for mask in masks:
        combined = np.logical_or(combined, mask)
    return combined

# Load data

In [None]:
# Dataset root folder; every input and output path below is built from it.
path = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN"

# --- Immunology table -------------------------------------------------------
df_imm = pd.read_excel(f"{path}/data/immuno/df_samples(all_1052_121222)_proc(raw)_imp(fast_knn)_replace(quarter).xlsx", index_col=0)
feats_imm = pd.read_excel(f"{path}/data/immuno/feats_con.xlsx", index_col=0).index.values
# Flag subjects that occur more than once. The counts are compared as a Series
# because the column name produced by `value_counts().to_frame()` differs
# between pandas versions ('Subject ID' before 2.0, 'count' afterwards).
imm_subject_counts = df_imm['Subject ID'].value_counts()
df_ld_imm = imm_subject_counts.to_frame()
df_imm['Is longitudinal?'] = df_imm['Subject ID'].isin(imm_subject_counts.index[imm_subject_counts > 1].values)
# Keep controls only. `.copy()` makes the subset an independent frame, so the
# column edits below cannot hit a temporary view (SettingWithCopy).
df_imm = df_imm.loc[(df_imm['Status'] == 'Control'), :].copy()
df_imm.rename(columns={'Sample_Chronology': 'Time'}, inplace=True)
# A column-level `replace(..., inplace=True)` works on a temporary object under
# pandas Copy-on-Write, so the result is assigned back explicitly.
df_imm['Time'] = df_imm['Time'].replace({0: 'T0', 1: 'T1', 2: 'T2', 3: 'T3'})
# Table stored with the SImAge model; its control samples get their 'index'
# value copied into the `ids_fimmu` column of the main table.
df_imm_ppr = pd.read_excel(f"{path}/data/immuno/models/SImAge/data.xlsx", index_col="sample_id")
ids_imm_ppr = df_imm_ppr.index[df_imm_ppr['Status'] == 'Control'].values
df_imm.loc[ids_imm_ppr, 'ids_fimmu'] = df_imm_ppr.loc[ids_imm_ppr, 'index']
feats_imm_ppr = pd.read_excel(f"{path}/data/immuno/models/SImAge/feats_con_top10.xlsx", index_col=0).index.values

# --- Epigenetics table ------------------------------------------------------
epi_suffix = "_harm"
df_epi = pd.read_excel(f"{path}/pheno.xlsx", index_col="index")
df_epi.index.name = "index"
# Remove two rows by their index labels.
df_epi.drop(["I64_old", "I1_duplicate"], inplace=True)
df_epi.rename(columns={'Subject_ID': 'Subject ID'}, inplace=True)
# Same version-independent longitudinal flag as for the immunology table.
epi_subject_counts = df_epi['Subject ID'].value_counts()
df_ld_epi = epi_subject_counts.to_frame()
df_epi['Is longitudinal?'] = df_epi['Subject ID'].isin(epi_subject_counts.index[epi_subject_counts > 1].values)
df_epi = df_epi.loc[(df_epi['Status'] == 'Control'), :].copy()
df_epi.rename(columns={'Sample_Chronology': 'Time'}, inplace=True)
df_epi['Time'] = df_epi['Time'].replace({0: 'T0', 1: 'T1', 2: 'T2', 3: 'T3'})
ids_epi_ppr = pd.read_excel(f"{path}/data/GSE234461/samples.xlsx", index_col=0).index.values

# --- Output folder ----------------------------------------------------------
path_save = f"{path}/special/059_imm_data_selection"
pathlib.Path(path_save).mkdir(parents=True, exist_ok=True)

# Immunology table that still contains NaNs, read from the output folder.
df_imm_w_nans = pd.read_excel(f"{path_save}/df_imm_w_nans.xlsx", index_col="Index")

# Select samples

## Filter Yakutian samples with small SImAge MAE value

In [None]:
# Upper bound for the running mean of |SImAge acceleration| over the kept samples.
thld_mae = 8.05


def _rank_by_abs_acceleration(df_region):
    """Return ``df_region`` sorted by |SImAge acceleration| with its running mean.

    The input frame is left untouched: ``sort_values`` returns a new frame,
    so adding the extra column cannot raise SettingWithCopy warnings or
    write into a slice of the parent table.

    Parameters
    ----------
    df_region : pandas.DataFrame
        Samples with a ``'|SImAge acceleration|'`` column.

    Returns
    -------
    pandas.DataFrame
        Rows in ascending order of ``'|SImAge acceleration|'`` plus the
        column ``'|SImAge acceleration| cumsum'`` holding the expanding mean
        of that column (the column name is kept for compatibility).
    """
    ranked = df_region.sort_values(["|SImAge acceleration|"], ascending=[True])
    ranked['|SImAge acceleration| cumsum'] = ranked['|SImAge acceleration|'].expanding().mean()
    return ranked


# Yakutian samples: keep those whose running mean stays below the threshold.
df_imm_yak = _rank_by_abs_acceleration(df_imm.loc[df_imm['Region'] == 'Yakutia'])
ids_imm_yak = df_imm_yak.index[df_imm_yak['|SImAge acceleration| cumsum'] < thld_mae].values

# Central samples that are not already in `ids_imm_ppr`, filtered the same way.
df_imm_cnt = _rank_by_abs_acceleration(df_imm.loc[(df_imm['Region'] == 'Central') & ~df_imm.index.isin(ids_imm_ppr), :])
ids_imm_cnt = df_imm_cnt.index[df_imm_cnt['|SImAge acceleration| cumsum'] < thld_mae].values

# Mean absolute error between 'Age' and 'SImAge' for each sample subset.
mae_imm_ppr = mean_absolute_error(df_imm.loc[ids_imm_ppr, 'Age'].values, df_imm.loc[ids_imm_ppr, 'SImAge'].values)
mae_imm_yak = mean_absolute_error(df_imm.loc[ids_imm_yak, 'Age'].values, df_imm.loc[ids_imm_yak, 'SImAge'].values)
mae_imm_cnt = mean_absolute_error(df_imm.loc[ids_imm_cnt, 'Age'].values, df_imm.loc[ids_imm_cnt, 'SImAge'].values)

## Define samples for immunology and epigenetics

In [None]:
# Immunology samples: union of `ids_imm_ppr` and the filtered Yakutian samples.
ids_imm = list(set(ids_imm_ppr) | set(ids_imm_yak))
# Alternative selection that also adds the filtered Central samples:
# ids_imm = list(set.union(set(ids_imm_ppr), set(ids_imm_yak), set(ids_imm_cnt)))
ids_epi_full = df_epi.index.values

# Epigenetics sample sets compared against the immunology samples, keyed by
# the name of the output sub-folder.
epi_types = dict(paper_only=ids_epi_ppr, full=ids_epi_full)

## Compare immunology and epigenetics samples

In [None]:
# For every epigenetics sample set: export the overlap sections with the
# immunology samples and draw a two-set Venn diagram.
for epi_type, ids_epi in epi_types.items():

    pathlib.Path(f"{path_save}/imm_vs_epi/{epi_type}").mkdir(parents=True, exist_ok=True)

    # One Excel file per section returned by `get_sections`.
    sections = get_sections([set(ids_imm), set(ids_epi)])
    for sec in sections:
        df_sec = pd.DataFrame(index=list(sections[sec]))
        df_sec.to_excel(f"{path_save}/imm_vs_epi/{epi_type}/{sec}.xlsx", index_label='index')

    fig, ax = plt.subplots()
    venn = venn2(
        subsets=(set(ids_imm), set(ids_epi)),
        set_labels=('Imm', 'Epi'),
        set_colors=('r', 'g'),
        alpha=0.5,
        ax=ax
    )
    venn2_circles(subsets=(set(ids_imm), set(ids_epi)), ax=ax)
    # matplotlib-venn stores None instead of a Text object for regions that
    # are not drawn (e.g. an empty intersection), so skip those entries.
    for text in venn.set_labels:
        if text is not None:
            text.set_fontsize(16)
    for text in venn.subset_labels:
        if text is not None:
            text.set_fontsize(25)
    fig.savefig(f"{path_save}/imm_vs_epi/{epi_type}/venn.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_save}/imm_vs_epi/{epi_type}/venn.pdf", bbox_inches='tight')
    # Close the figure instead of only clearing it, so figures do not pile up
    # across iterations.
    plt.close(fig)

# Processing immunology samples

In [None]:
# Restrict the immunology table to the selected samples; the table with NaNs
# is aligned to the same samples in `ids_imm` order.
selected_imm = df_imm.loc[ids_imm, :]
df_imm_w_nans = df_imm_w_nans.loc[selected_imm.index.values, :]
# Order the main table by the `ids_fimmu` column before exporting it.
df_imm = selected_imm.sort_values(by=["ids_fimmu"], ascending=[True])
df_imm.to_excel(f"{path_save}/df_imm.xlsx", index_label="index")

## Statistics of missing values

In [None]:
# Per-feature NaN statistics of the table that still contains missing values.
df_nan_feats = df_imm_w_nans.loc[:, feats_imm].isna().sum(axis=0).to_frame(name="Number of NaNs")
# Normalise by the row count of the same table the NaNs were counted in.
df_nan_feats["% of NaNs"] = df_nan_feats["Number of NaNs"] / df_imm_w_nans.shape[0] * 100
df_nan_feats["Number of not-NaNs"] = df_imm_w_nans.loc[:, feats_imm].notna().sum(axis=0)
df_nan_feats.sort_values(["% of NaNs"], ascending=[False], inplace=True)
df_nan_feats.to_excel(f"{path_save}/df_nan_feats.xlsx", index_label="Features")

# The theme must be set before the axes are created: seaborn styles are
# applied through rcParams, which matplotlib reads at axes creation time.
sns.set_theme(style='whitegrid')
fig, ax = plt.subplots(figsize=(14, 4))
barplot = sns.barplot(
    data=df_nan_feats,
    x=df_nan_feats.index,
    y="% of NaNs",
    edgecolor='black',
    dodge=False,
    ax=ax
)
# Rotate the feature names after plotting so the setting applies to the
# final tick labels.
ax.tick_params(axis='x', labelrotation=90)
fig.savefig(f"{path_save}/df_nan_feats.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/df_nan_feats.pdf", bbox_inches='tight')
plt.close(fig)

## Selection of good features

In [None]:
# Maximum allowed percentage of NaNs for a feature to be kept.
thld_nan_in_feat = 25.3
# Features below the NaN threshold, plus every feature listed in `feats_imm_ppr`.
feats_imm_good = set(df_nan_feats.index[df_nan_feats['% of NaNs'] <= thld_nan_in_feat].values).union(set(feats_imm_ppr))
print(f"Number of filtered features: {len(feats_imm_good)}")
# `.loc` does not accept a set as an indexer (deprecated in pandas 1.5, an
# error since 2.0); a sorted list also makes the row order reproducible.
df_nan_feats.loc[sorted(feats_imm_good), :].to_excel(f"{path_save}/feats_imm_good.xlsx", index_label="Features")

# Region-specific analysis

In [None]:
# Plot colour per region; the key order is reused as the category order of
# the region-specific plots.
colors_region = dict(Central='gold', Yakutia='lightslategray')
# Output folder for the region-specific results.
pathlib.Path(path_save, "region_specific").mkdir(parents=True, exist_ok=True)

## Plot distribution

In [None]:
# 23 bin edges from 5 to 115, i.e. 22 bins that are 5 units wide.
hist_bins = np.linspace(5, 115, 23)

# Age histogram of the selected samples, coloured by region.
fig, ax = plt.subplots(figsize=(4, 3))
sns.set_theme(style='whitegrid')
histplot = sns.histplot(
    data=df_imm,
    x="Age",
    hue='Region',
    hue_order=['Yakutia', 'Central'],
    palette=colors_region,
    bins=hist_bins,
    edgecolor='k',
    linewidth=1,
    ax=ax
)
histplot.set_xlim(0, 120)
for extension, save_kwargs in (("png", {"dpi": 200}), ("pdf", {})):
    fig.savefig(f"{path_save}/region_specific/histplot.{extension}", bbox_inches='tight', **save_kwargs)
plt.close(fig)

## Calculate statistics

In [None]:
# Per-feature descriptive statistics for both regions and a two-sided
# Mann-Whitney U test between them.
# A sorted list gives a reproducible row order and a valid `.loc` indexer
# (sets are rejected as indexers by pandas >= 2.0).
feats_stat = sorted(feats_imm_good)
regions = ['Central', 'Yakutia']
# Row masks do not depend on the feature, so build them once.
region_masks = {group: df_imm['Region'] == group for group in regions}

df_stat = pd.DataFrame(index=feats_stat)
for feat in feats_stat:
    vals = {}
    for group in regions:
        vals[group] = df_imm.loc[region_masks[group], feat].values
        q75, q25 = np.percentile(vals[group], [75, 25])
        df_stat.at[feat, f"mean_{group}"] = np.mean(vals[group])
        df_stat.at[feat, f"median_{group}"] = np.median(vals[group])
        df_stat.at[feat, f"q75_{group}"] = q75
        df_stat.at[feat, f"q25_{group}"] = q25
        df_stat.at[feat, f"iqr_{group}"] = q75 - q25
    _, df_stat.at[feat, "mw_pval"] = mannwhitneyu(vals['Central'], vals['Yakutia'], alternative='two-sided')

# Benjamini-Hochberg correction over all features. `df_stat` holds exactly the
# rows of `feats_stat`, so the whole column is corrected and written back in
# the same row order.
_, pvals_fdr_bh, _, _ = multipletests(df_stat["mw_pval"].values, 0.05, method='fdr_bh')
df_stat["mw_pval_fdr_bh"] = pvals_fdr_bh
df_stat.sort_values(["mw_pval_fdr_bh"], ascending=[True], inplace=True)
df_stat.to_excel(f"{path_save}/region_specific/stat.xlsx", index_label='Features')

## Plot features p-values

In [None]:
# Horizontal bar plot of -log10(FDR-corrected p-value) per feature.
# Rows are selected with a boolean mask (sets are not valid `.loc` indexers in
# pandas >= 2.0); `sort_values` returns a new frame, so the added columns do
# not touch `df_stat`.
df_fig = df_stat.loc[df_stat.index.isin(feats_imm_good), :].sort_values(["mw_pval_fdr_bh"], ascending=[True])
df_fig['mw_pval_fdr_bh_log'] = -np.log10(df_fig['mw_pval_fdr_bh'])
# Red bars for features below the 0.05 level, pink for the rest.
df_fig['color'] = 'pink'
df_fig.loc[df_fig['mw_pval_fdr_bh'] < 0.05, 'color'] = 'red'

fig, ax = plt.subplots(figsize=(3, 12))
sns.set_theme(style='whitegrid')
barplot = sns.barplot(
    data=df_fig,
    y=df_fig.index.values,
    x='mw_pval_fdr_bh_log',
    edgecolor='black',
    palette=df_fig['color'].values,
    dodge=True,
    ax=ax
)
ax.set_xlabel(r"$-\log_{10}(\mathrm{p-value})$", fontsize=18)
ax.set_ylabel('', fontsize=20)
# Place ticks on integers instead of truncating arbitrary tick values with
# int(): truncation would mislabel ticks such as 2.5 and pin fixed labels to
# tick positions that can still move.
ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
ax.tick_params(axis='x', labelsize=16)
ax.tick_params(axis='y', labelsize=16)
plt.savefig(f"{path_save}/region_specific/barplot.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/region_specific/barplot.pdf", bbox_inches='tight')
plt.close(fig)

## Plot features distributions

In [None]:
# One violin panel per selected feature, in the row order of `df_stat`.
feats_sorted = df_stat.index[df_stat.index.isin(feats_imm_good)].values

# The grid is derived from the number of features, so more than 32 features no
# longer raise IndexError. For 32 features this gives the former 4 x 8 grid
# on a 15 x 12 figure.
n_cols = 8
n_rows = max(1, int(np.ceil(len(feats_sorted) / n_cols)))
fig_width = 15
fig_height = 3 * n_rows

# squeeze=False keeps `axs` two-dimensional even for a single row.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, squeeze=False)
sns.set_theme(style='whitegrid')

for f_id, f in enumerate(feats_sorted):
    row_id, col_id = divmod(f_id, n_cols)
    ax_feat = axs[row_id, col_id]

    # Violins are drawn without the values outside the 1st-99th percentile range.
    q01 = df_imm[f].quantile(0.01)
    q99 = df_imm[f].quantile(0.99)

    sns.violinplot(
        data=df_imm.loc[(df_imm[f] > q01) & (df_imm[f] < q99), :],
        x='Region',
        y=f,
        palette=colors_region,
        scale='width',
        order=list(colors_region.keys()),
        saturation=0.75,
        cut=0,
        linewidth=1.0,
        ax=ax_feat,
        legend=False,
    )
    ax_feat.set_ylabel(f)
    ax_feat.set_xlabel('')
    ax_feat.set(xticklabels=[])
    # Annotate the panel with the FDR-corrected Mann-Whitney p-value.
    mw_pval = df_stat.at[f, "mw_pval_fdr_bh"]
    pval_formatted = [f'{mw_pval:.2e}']
    annotator = Annotator(
        ax_feat,
        pairs=[('Central', 'Yakutia')],
        data=df_imm,
        x='Region',
        y=f,
        order=list(colors_region.keys()),
    )
    annotator.set_custom_annotations(pval_formatted)
    annotator.configure(loc='outside')
    annotator.annotate()

# Hide the panels of the last row that received no feature.
for ax_unused in axs.flat[len(feats_sorted):]:
    ax_unused.set_visible(False)

# Shared legend above the grid, built from proxy markers.
legend_handles = [
    mlines.Line2D([], [], marker='o', linestyle='None', markeredgecolor='k', markerfacecolor=colors_region['Central'], markersize=10, label='Central'),
    mlines.Line2D([], [], marker='o', linestyle='None', markeredgecolor='k', markerfacecolor=colors_region['Yakutia'], markersize=10, label='Yakutia')
]
fig.legend(handles=legend_handles, bbox_to_anchor=(0.5, 1.0), loc="lower center", ncol=2, frameon=False, fontsize='large')
fig.tight_layout()
plt.savefig(f"{path_save}/region_specific/feats.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path_save}/region_specific/feats.pdf", bbox_inches='tight')
plt.close(fig)
    

## SImAge analysis

In [None]:
# Ordinary least squares fit of SImAge on Age, using Central samples only.
formula = "SImAge ~ Age"
is_central = df_imm['Region'] == 'Central'
model = smf.ols(formula=formula, data=df_imm.loc[is_central, :]).fit()
# Prediction of the Central fit for every sample, and each sample's deviation from it.
df_imm["SImAge_Central_linreg"] = model.predict(df_imm)
df_imm["SImAge residuals"] = df_imm['SImAge'] - df_imm["SImAge_Central_linreg"]

# Scatter plot of SImAge against Age with the identity line and the Central
# regression line.
# The output folder is not created anywhere else, so create it here;
# otherwise `savefig` fails on a fresh run.
pathlib.Path(f"{path_save}/region_specific/SImAge").mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(4.5, 4))
sns.set_theme(style='whitegrid')
scatter = sns.scatterplot(
    data=df_imm,
    x="Age",
    y="SImAge",
    hue="Region",
    palette=colors_region,
    linewidth=0.2,
    alpha=0.75,
    edgecolor="k",
    s=20,
    hue_order=list(colors_region.keys()),
    ax=ax
)
# Identity line (SImAge == Age).
bisect = sns.lineplot(
    x=[0, 120],
    y=[0, 120],
    linestyle='--',
    color='black',
    linewidth=1.0,
    ax=ax
)
# Evaluate the Central fit well outside the plotted range so the line spans
# the whole axes.
df_line = pd.DataFrame({'Age': [-100, 200]})
df_line["SImAge_Central_linreg"] = model.predict(df_line)
# The regression line is drawn twice: a wider black line underneath acts as
# an outline for the coloured line on top.
central_linreg_back = sns.lineplot(
    x=df_line['Age'].values,
    y=df_line['SImAge_Central_linreg'].values,
    color='black',
    linewidth=3.0,
    ax=ax
)
central_linreg_front = sns.lineplot(
    x=df_line['Age'].values,
    y=df_line['SImAge_Central_linreg'].values,
    color=colors_region['Central'],
    linewidth=2.0,
    ax=ax
)
ax.set_xlabel("Age")
ax.set_ylabel("SImAge")
ax.set_xlim(0, 120)
ax.set_ylim(0, 120)
ax.set_aspect('equal', adjustable='box')
fig.savefig(f"{path_save}/region_specific/SImAge/scatter.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path_save}/region_specific/SImAge/scatter.pdf", bbox_inches='tight')
plt.close(fig)

# Region comparison (violin plot + two-sided Mann-Whitney U test) for the two
# SImAge-derived columns. Both figures are built the same way, so they share
# one loop body instead of two copy-pasted blocks.
# The output folder is not created anywhere else, so create it here;
# otherwise `savefig` fails on a fresh run.
pathlib.Path(f"{path_save}/region_specific/SImAge").mkdir(parents=True, exist_ok=True)

for simage_metric, file_stem in (
    ('SImAge acceleration', 'acceleration'),
    ('SImAge residuals', 'residuals'),
):
    plt.figure(figsize=(4, 4))
    sns.set_theme(style='whitegrid')
    violin = sns.violinplot(
        data=df_imm,
        x='Region',
        y=simage_metric,
        palette=colors_region,
        scale='width',
        order=list(colors_region.keys()),
        saturation=0.75,
    )
    violin.set_xlabel("")
    mw_pval = mannwhitneyu(
        df_imm.loc[df_imm['Region'] == 'Central', simage_metric].values,
        df_imm.loc[df_imm['Region'] == 'Yakutia', simage_metric].values,
        alternative='two-sided').pvalue
    # The p-value computed above is shown as a custom annotation.
    pval_formatted = [f'{mw_pval:.2e}']
    annotator = Annotator(
        violin,
        pairs=[('Central', 'Yakutia')],
        data=df_imm,
        x='Region',
        y=simage_metric,
        order=list(colors_region.keys())
    )
    annotator.set_custom_annotations(pval_formatted)
    annotator.configure(loc='outside')
    annotator.annotate()
    plt.savefig(f"{path_save}/region_specific/SImAge/{file_stem}.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_save}/region_specific/SImAge/{file_stem}.pdf", bbox_inches='tight')
    plt.close()
