In [None]:
import pandas as pd
import numpy as np
import pathlib as pl

from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from statannotations.Annotator import Annotator

from scipy.stats import mannwhitneyu, fisher_exact, pearsonr, kruskal

In [None]:
import sys
sys.path.append("../../FinalCode/")
import download.download as dwnl
import utils.plotting as plting
import adVMP.adVMP_discovery as discov
import adVMP.adVMP_plots as advmpplt
import adVMP.adVMP_crossval as advmpcross
import adVMP.comparison_random as rdn

In [None]:
# For figures
# Output location for all saved figures (placeholder -- set before running)
fig_dir = pl.Path("/add/path/here")
# Shared colour palette; colors[0] / colors[3] are used for the "No" / "Yes" adenoma groups below
colors = sns.color_palette("muted")

In [None]:
# Input locations (placeholders -- fill in before running).
# base_dir is passed to download_EPIC(EPIC4=False), base_dir4 to download_EPIC(EPIC4=True).
base_dir = pl.Path("/add/path/here")
base_dir4 = pl.Path("/add/path/here")
data_dir = pl.Path("/add/path/here")

# Probe IDs handed to download_EPIC as `bad_probes` (presumably excluded there -- confirm)
bad_probes = pd.read_csv(data_dir / "auxiliary" / "sketchy_probe_list_epic.csv", index_col=0).values.ravel()

# Clinical / sample metadata files
sample_origin_path = data_dir / "clinical" / "sample_origin_wbatch.csv"
clinical_path = data_dir / "clinical" / "cleaned_clinical_reduced_diet.csv"
target_path = data_dir / "clinical" / "targets.csv"

In [None]:
# EPIC2 (analysed as SWEPIC1) and EPIC3 (analysed as SWEPIC2) batches.
(EPIC2_b, EPIC2_clin, EPIC2_samples, EPIC2_phenotypes,
 EPIC3_b, EPIC3_clin, EPIC3_samples, EPIC3_phenotypes) = dwnl.download_EPIC(
    sample_origin_path=sample_origin_path,
    base_dir=base_dir,
    clinical_path=clinical_path,
    target_path=target_path,
    bad_probes=bad_probes,
    EPIC4=False,
)

In [None]:
# EPIC4 batch (analysed as SWEPIC3), read from its own base directory.
(EPIC4_b, EPIC4_clin,
 EPIC4_samples, EPIC4_phenotypes) = dwnl.download_EPIC(
    sample_origin_path=sample_origin_path,
    base_dir=base_dir4,
    clinical_path=clinical_path,
    target_path=target_path,
    bad_probes=bad_probes,
    EPIC4=True,
)

# Find adVMPs with 4-fold cross-validation

In [None]:
# 4-fold cross-validated adVMP discovery on SWEPIC1 (EPIC2 batch); per-fold results go to result_dir
advmpcross.get_stratified_hyper_DMC(
    y=EPIC2_phenotypes.astype(int),
    EPIC_m=EPIC2_b,
    result_dir=data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC1",
    n_splits=4,
    rs=1,
)

In [None]:
# 4-fold cross-validated adVMP discovery on SWEPIC2 (EPIC3 batch)
# NOTE(review): rs=10 here while SWEPIC1/SWEPIC3 use rs=1 -- confirm this is intentional.
advmpcross.get_stratified_hyper_DMC(
    y=EPIC3_phenotypes.astype(int),
    EPIC_m=EPIC3_b,
    result_dir=data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC2",
    n_splits=4,
    rs=10,
)

In [None]:
# 4-fold cross-validated adVMP discovery on SWEPIC3 (EPIC4 batch)
advmpcross.get_stratified_hyper_DMC(
    y=EPIC4_phenotypes.astype(int),
    EPIC_m=EPIC4_b,
    result_dir=data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC3",
    n_splits=4,
    rs=1,
)

# Ensembling probes

In [None]:
# Load the per-fold adVMP test results ("adVMP_right.csv") written by the cross-validation above,
# as test_results[cohort][fold_name] -> DataFrame.
# Only sub-directories are treated as folds, so any stray file (.DS_Store, notes, ...) is ignored
# rather than crashing read_csv; folds are sorted so the dict order does not depend on the filesystem.
test_results = {}
for i in ['1', '2', '3']:
    direc = data_dir / "adVMP_crossvalidation_4fold" / f"SWEPIC{i}"
    test_results[i] = {
        fold.stem: pd.read_csv(fold / "adVMP_right.csv", index_col=0)
        for fold in sorted(direc.iterdir())
        if fold.is_dir()
    }

In [None]:
# For each cohort, the probes that pass q < 0.05, t-test p < 0.05 and diffV > 0 in *every* fold.
common_sign_probes = {}
for ds, folds in test_results.items():
    shared = []
    for k, fold_df in enumerate(folds.values()):
        keep = (fold_df["q"] < 0.05) & (fold_df["ttest_p"] < 0.05) & (fold_df["diffV"] > 0)
        fold_hits = fold_df[keep].index
        # first fold seeds the set; later folds narrow it down
        shared = fold_hits if k == 0 else np.intersect1d(shared, fold_hits)
    common_sign_probes[ds] = shared

In [None]:
tuple(len(common_sign_probes[ds]) for ds in ("1", "2", "3"))  # probes significant in all folds, per cohort

In [None]:
# Results of the variability / differential-methylation analysis run on each entire SWEPIC cohort
fullset_test_results = {
    i: pd.read_csv(data_dir / "adVMP" / f"adVMP_SWEPIC{i}_right.csv", index_col=0)
    for i in ['1', '2', '3']
}

In [None]:
# Keep only the significant probes for each cohort (q < 0.05, t-test p < 0.05, diffV > 0)
fullset_sign_probes = {}
for cohort, res in fullset_test_results.items():
    is_sign = (res["q"] < 0.05) & (res["ttest_p"] < 0.05) & (res["diffV"] > 0)
    fullset_sign_probes[cohort] = res[is_sign]

In [None]:
# For every cohort i, intersect the significant probes of all *other* cohorts.
# These are used to build the adVMP set of each fold of cohort i:
# e.g. the folds of SWEPIC1 use the probes significant in both SWEPIC2 and SWEPIC3.
# np.intersect1d only takes two arrays (a third positional argument is `assume_unique`),
# so the pairwise intersection is folded over the other cohorts to support any number of cohorts.
ext_intersection = {}
list_cohorts = ['1', '2', '3']
for i in list_cohorts:
    others = [fullset_sign_probes[j].index.to_numpy() for j in list_cohorts if j != i]
    intersect = others[0]
    for probes in others[1:]:
        intersect = np.intersect1d(intersect, probes)
    # intersect1d output is already unique; np.unique covers the single-other-cohort case
    ext_intersection[i] = np.unique(intersect)

In [None]:
from typing import Dict
def get_fold_specific_ensembling_cpgs(test_results: Dict, fullset_sign_probes: Dict, q_lim: float=0.05) -> Dict:
    """Build the fold-specific set of ensembling CpGs for every cohort.

    For cohort ``i`` and each of its cross-validation folds, the fold's significant probes
    (q < ``q_lim``, t-test p < 0.05, diffV > 0) are intersected with the full-set significant
    probes of every *other* cohort ``j``. The union of these intersections is then combined
    with the probes that are significant in *all* other cohorts.

    Parameters
    ----------
    test_results : dict
        ``{cohort: {fold: DataFrame}}``; each DataFrame has "q", "ttest_p" and "diffV"
        columns and is indexed by probe ID.
    fullset_sign_probes : dict
        ``{cohort: DataFrame}`` of probes significant on the full cohort; only the index is used.
    q_lim : float, default 0.05
        Threshold on the "q" column used to select significant probes in each fold.

    Returns
    -------
    dict
        ``{cohort: {fold: np.ndarray}}`` of sorted, unique probe IDs.
    """
    sel_probes = {}
    for i in test_results:
        # full-set significant probes of every *other* cohort
        other_probes = [fullset_sign_probes[j].index.to_numpy() for j in fullset_sign_probes if j != i]

        # probes significant in all other cohorts; computed here rather than read from a
        # notebook global so the function only depends on its arguments
        if other_probes:
            ext_intersect = other_probes[0]
            for probes in other_probes[1:]:
                ext_intersect = np.intersect1d(ext_intersect, probes)
        else:
            ext_intersect = np.array([], dtype=object)

        sel_probes[i] = {}
        for fold, fold_df in test_results[i].items():
            # fold probes that are differentially variable, differentially methylated,
            # and more variable in adenoma tissue
            is_sign = (fold_df["q"] < q_lim) & (fold_df["ttest_p"] < 0.05) & (fold_df["diffV"] > 0)
            signprobes = fold_df[is_sign].index.to_numpy()
            # union over the other cohorts j of (fold probes intersect cohort-j probes), plus the
            # probes significant in all other cohorts; np.concatenate accepts any number of arrays
            per_cohort = [np.intersect1d(signprobes, probes) for probes in other_probes]
            sel_probes[i][fold] = np.unique(np.concatenate(per_cohort + [ext_intersect]))
    return sel_probes

In [None]:
# Fold-specific adVMP sets, keyed as union_cpgs_fold[cohort][fold]
union_cpgs_fold = get_fold_specific_ensembling_cpgs(
    test_results=test_results,
    fullset_sign_probes=fullset_sign_probes,
    q_lim=0.05,
)

# Visualize probe performance

In [None]:
# Cross-validated performance on SWEPIC1 (EPIC2 batch), scored with the fold-specific adVMP sets
ds_dir = data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC1"
all_stats1, crossval_hit_fraction1 = advmpcross.get_crossval_performance(
    ds_dir=ds_dir,
    EPIC_b=EPIC2_b,
    EPIC_phenotypes=EPIC2_phenotypes,
    union_cpgs_fold_spec=union_cpgs_fold["1"],
    estimate_copa=True,
    order="Mixed Order",
)

In [None]:
# first output of get_crossval_performance for SWEPIC1
all_stats1

In [None]:
# Per-sample cross-validated hit fraction on SWEPIC1, coloured by adenoma status
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
# pass ax explicitly instead of relying on pyplot's implicit "current axes"
sns.scatterplot(data=crossval_hit_fraction1, x="Mixed Order", y="Hit fraction", hue="Ad_plot",
                palette={"No": colors[0], "Yes": colors[3]}, ax=ax)
plting.transform_plot_ax(ax, legend_title="Adenoma")
# the per-cohort sub-directory may not exist yet, in which case savefig would fail
(fig_dir / "SWEPIC1").mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir / "SWEPIC1" / "crossval_worm_plot.svg", bbox_inches="tight")

In [None]:
# Cross-validated performance on SWEPIC2 (EPIC3 batch), scored with the fold-specific adVMP sets
ds_dir = data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC2"
all_stats2, crossval_hit_fraction2 = advmpcross.get_crossval_performance(
    ds_dir=ds_dir,
    EPIC_b=EPIC3_b,
    EPIC_phenotypes=EPIC3_phenotypes,
    union_cpgs_fold_spec=union_cpgs_fold["2"],
    estimate_copa=True,
    order="Mixed Order",
)

In [None]:
# first output of get_crossval_performance for SWEPIC2
all_stats2

In [None]:
# Per-sample cross-validated hit fraction on SWEPIC2, coloured by adenoma status
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
# pass ax explicitly instead of relying on pyplot's implicit "current axes"
sns.scatterplot(data=crossval_hit_fraction2, x="Mixed Order", y="Hit fraction", hue="Ad_plot",
                palette={"No": colors[0], "Yes": colors[3]}, ax=ax)
plting.transform_plot_ax(ax, legend_title="Adenoma")
# the per-cohort sub-directory may not exist yet, in which case savefig would fail
(fig_dir / "SWEPIC2").mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir / "SWEPIC2" / "crossval_worm_plot.svg", bbox_inches="tight")

In [None]:
# Cross-validated performance on SWEPIC3 (EPIC4 batch), scored with the fold-specific adVMP sets
ds_dir = data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC3"
all_stats3, crossval_hit_fraction3 = advmpcross.get_crossval_performance(
    ds_dir=ds_dir,
    EPIC_b=EPIC4_b,
    EPIC_phenotypes=EPIC4_phenotypes,
    union_cpgs_fold_spec=union_cpgs_fold["3"],
    estimate_copa=True,
    order="Mixed Order",
)

In [None]:
# first output of get_crossval_performance for SWEPIC3
all_stats3

In [None]:
# Per-sample cross-validated hit fraction on SWEPIC3, coloured by adenoma status
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
# pass ax explicitly instead of relying on pyplot's implicit "current axes"
sns.scatterplot(data=crossval_hit_fraction3, x="Mixed Order", y="Hit fraction", hue="Ad_plot",
                palette={"No": colors[0], "Yes": colors[3]}, ax=ax)
plting.transform_plot_ax(ax, legend_title="Adenoma")
# the per-cohort sub-directory may not exist yet, in which case savefig would fail
(fig_dir / "SWEPIC3").mkdir(parents=True, exist_ok=True)
fig.savefig(fig_dir / "SWEPIC3" / "crossval_worm_plot.svg", bbox_inches="tight")

In [None]:
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay

In [None]:
# ROC curves of the cross-validated hit fraction used as a score for adenoma status, one per cohort
fig, ax = plt.subplots(1, 1, figsize=(7, 5))
for hit_frac, color, name in [
    (crossval_hit_fraction1, colors[6], "SWEPIC1"),
    (crossval_hit_fraction2, colors[7], "SWEPIC2"),
    (crossval_hit_fraction3, colors[9], "SWEPIC3"),
]:
    # Series.ravel() is deprecated since pandas 2.2; to_numpy() gives the same 1-D array
    RocCurveDisplay.from_predictions(
        hit_frac["Ad"].astype(int).to_numpy(),
        hit_frac["Hit fraction"].to_numpy(),
        ax=ax,
        c=color,
        name=name,
    )
# chance-level diagonal, drawn on the same axes explicitly
ax.plot(np.linspace(0, 1, 100), np.linspace(0, 1, 100), c=colors[3])
plting.transform_plot_ax(ax, legend_title="", ftsize=17, leg_ftsize=17, linew=3)
fig.savefig(fig_dir / "ROC_AUC_curve_crossval.svg", bbox_inches="tight")

# Cross-validated hit fraction per cohort

In [None]:
# Stack the per-cohort cross-validated hit fractions into one long frame, tagging each row
# with its cohort ("Batch") for the grouped boxplot; vc1..vc3 hold the per-cohort adenoma counts.
dfs = []
for batch_name, hit_frac in zip(["SWEPIC1", "SWEPIC2", "SWEPIC3"],
                                [crossval_hit_fraction1, crossval_hit_fraction2, crossval_hit_fraction3]):
    sub = hit_frac[["Hit fraction", "Ad_plot", "Mixed Order"]]
    batch_col = pd.DataFrame([batch_name] * sub.shape[0], index=sub.index, columns=["Batch"])
    dfs.append(pd.concat([sub, batch_col], axis=1))
df1, df2, df3 = dfs
vc1, vc2, vc3 = (part.Ad_plot.value_counts() for part in dfs)
df = pd.concat(dfs)

In [None]:
# Hit fraction by cohort and adenoma status, with an uncorrected Mann-Whitney U test per cohort
cohort_counts = [("SWEPIC1", vc1), ("SWEPIC2", vc2), ("SWEPIC3", vc3)]
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
# pass ax explicitly instead of relying on pyplot's implicit "current axes"
sns.boxplot(data=df, x="Batch", y="Hit fraction", hue="Ad_plot",
            palette={"No": colors[0], "Yes": colors[3]}, ax=ax)
annot = Annotator(
    ax,
    pairs=[((name, "No"), (name, "Yes")) for name, _ in cohort_counts],
    data=df, x="Batch", y="Hit fraction", hue="Ad_plot",
)
annot.configure(
    test="Mann-Whitney",
    loc="inside",
    text_format="simple",
    show_test_name=False,
    verbose=2,
    comparisons_correction=None,
    correction_format="replace",
)
annot.apply_test()
ax, _ = annot.annotate()
plting.transform_plot_ax(ax, legend_title="Adenoma (right)", linew=2.5)
# tick labels: cohort name plus number of non-adenoma / adenoma samples
ax.set_xticklabels(
    [f"{name}\n$N_{{No}}$={vc.loc['No']}\n$N_{{Yes}}$={vc.loc['Yes']}" for name, vc in cohort_counts],
    size=12,
)
ax.set_xlabel("")
fig.savefig(fig_dir / "crossval_hit_fraction_dist.svg", bbox_inches="tight")

# Cross-validated hit fraction per age category

In [None]:
from typing import Dict, List, Optional, Sequence
def get_plot_by_age_group(EPIC_clin: pd.DataFrame, 
                          age_bins: List, age_cat_labels: List, 
                          heatmap_df: pd.DataFrame, title: str,
                          ylim: Sequence[float] = (0, 0.65),
                          palette: Optional[Dict] = None) -> plt.Axes:
    """Box-plot the hit fraction per age category and adenoma status, with per-category tests.

    Samples are binned with ``pd.cut(EPIC_clin["Age at visit"], bins=age_bins,
    labels=age_cat_labels)``. Within each age category, the "No" and "Yes" groups of
    ``heatmap_df["Ad_plot"]`` are compared with a Mann-Whitney U test (no multiple-testing
    correction). The per-group medians are printed.

    Parameters
    ----------
    EPIC_clin : DataFrame with an "Age at visit" column, indexed by sample like ``heatmap_df``.
    age_bins, age_cat_labels : passed to ``pd.cut``; needs ``len(age_bins) == len(age_cat_labels) + 1``.
    heatmap_df : DataFrame with "Hit fraction" and "Ad_plot" ("No"/"Yes") columns.
    title : axes title.
    ylim : y-axis limits; the default is the previously hard-coded (0, 0.65).
    palette : hue -> colour mapping; defaults to the notebook's ``colors[0]`` / ``colors[3]``.

    Returns
    -------
    The annotated matplotlib Axes (use ``ax.figure`` to save it).
    """
    if palette is None:
        palette = {"No": colors[0], "Yes": colors[3]}

    age_cat = pd.cut(EPIC_clin["Age at visit"],
       bins=age_bins, labels=age_cat_labels)

    # outer join on the sample index; samples missing from either frame get NaN keys
    # and are dropped by the groupby / value_counts below
    df = pd.concat([heatmap_df[["Hit fraction","Ad_plot"]],age_cat],axis=1)

    # observed=False is the current pandas default for categorical groupers; stating it keeps
    # all age categories and silences the FutureWarning about the default changing
    print(df.groupby(["Age at visit","Ad_plot"], observed=False).median())

    vc = df.value_counts(["Age at visit","Ad_plot"])
    # tick labels: age category plus number of non-adenoma / adenoma samples (0 if a group is absent)
    xticklabs = [f"{cat}\n$N_{{No}}$={vc.get((cat, 'No'), 0)}\n$N_{{Yes}}$={vc.get((cat, 'Yes'), 0)}"
                 for cat in age_cat_labels]

    pairs = [((cat,"No"),(cat,"Yes")) for cat in age_cat_labels]
    fig, ax = plt.subplots(1,1)
    sns.boxplot(data=df, x="Age at visit",y="Hit fraction",hue="Ad_plot",
                palette=palette,
                ax=ax)

    annot = Annotator(
            ax,
            pairs=pairs,
            data=df, x="Age at visit", y="Hit fraction", hue="Ad_plot",
        )
    annot.configure(
            test="Mann-Whitney",
            loc="inside",
            text_format="simple",
            show_test_name=False,
            verbose=2,
            comparisons_correction=None,
            correction_format="replace",
        )
    annot.apply_test()
    ax, _ = annot.annotate()

    ax.set_ylim(list(ylim))
    plting.transform_plot_ax(ax, legend_title="Adenoma")
    ax.set_xticklabels(xticklabs)
    ax.set_xlabel("")
    ax.set_title(title)

    return ax

In [None]:
# Age-at-visit bin edges and their labels for the age-stratified plots
age_bins, age_cat_labels = [0, 55, 65, 120], ["<55", "55-65", ">=65"]

In [None]:
# SWEPIC1: hit fraction by age category and adenoma status
ax = get_plot_by_age_group(
    EPIC_clin=EPIC2_clin,
    heatmap_df=crossval_hit_fraction1,
    age_bins=age_bins,
    age_cat_labels=age_cat_labels,
    title="SWEPIC1",
)
ax.figure.savefig(fig_dir / "crossval_SWEPIC1_age_cat_hit_fraction_dist.svg", bbox_inches="tight")

In [None]:
# SWEPIC2: hit fraction by age category and adenoma status
ax = get_plot_by_age_group(
    EPIC_clin=EPIC3_clin,
    heatmap_df=crossval_hit_fraction2,
    age_bins=age_bins,
    age_cat_labels=age_cat_labels,
    title="SWEPIC2",
)
ax.figure.savefig(fig_dir / "crossval_SWEPIC2_age_cat_hit_fraction_dist.svg", bbox_inches="tight")

In [None]:
# SWEPIC3: hit fraction by age category and adenoma status
ax = get_plot_by_age_group(
    EPIC_clin=EPIC4_clin,
    heatmap_df=crossval_hit_fraction3,
    age_bins=age_bins,
    age_cat_labels=age_cat_labels,
    title="SWEPIC3",
)
ax.figure.savefig(fig_dir / "crossval_SWEPIC3_age_cat_hit_fraction_dist.svg", bbox_inches="tight")

# Compare to Horvath age

In [None]:
# Horvath-clock age estimate per sample, read from data_dir like the other auxiliary inputs
# (previously a relative "../../FinalData/..." path that only worked from one working directory).
# NOTE(review): assumes data_dir is the FinalData directory, as the adVMP_crossvalidation_4fold
# paths suggest -- confirm.
horvath_age = pd.read_csv(data_dir / "auxiliary" / "horvath_age.csv", index_col=0)

# keep the first row of any duplicated sample ID and use string IDs, presumably to match the
# index of the hit-fraction tables it is inner-joined with below
horvath_age = horvath_age[~horvath_age.index.duplicated()]
horvath_age.index = horvath_age.index.astype(str)
# raises if the file has more than one data column, which would be unexpected here
horvath_age.columns = ["Horvath age"]

In [None]:
age_bins = [0,55,65,120]
age_cat_labels = ["<55","55-65",">=65"]


from typing import List
def get_plots_per_agecat(EPIC_clin: pd.DataFrame, 
                         df: pd.DataFrame, 
                         age_cat_labels: List, 
                         age_bins: List, fig_dir: pl.Path,
                         name: str) -> None:
    """Scatter hit fraction against Horvath age within each age category and save the figure.

    Samples are binned with ``pd.cut(EPIC_clin["Age at visit"], bins=age_bins,
    labels=age_cat_labels)``. One panel is drawn per category, annotated with the
    ``DataFrame.corr`` (Pearson) coefficient between "Hit fraction" and "Horvath age".

    Parameters
    ----------
    EPIC_clin : DataFrame with an "Age at visit" column, indexed by sample like ``df``.
    df : DataFrame with "Hit fraction" and "Horvath age" columns.
    age_cat_labels, age_bins : passed to ``pd.cut``.
    fig_dir : directory the SVG is written to.
    name : suffix of the output file name ``horvath_age_hit_fraction_age_cat_{name}.svg``.
    """
    age_cat = pd.cut(EPIC_clin["Age at visit"],
           bins=age_bins, labels=age_cat_labels)
    dfage = pd.concat([df,age_cat],axis=1)

    # one panel per age category (this was hard-coded to 3, breaking for other label counts);
    # squeeze=False keeps `ax` 2-D even when there is a single category
    n_cat = len(age_cat_labels)
    fig, ax = plt.subplots(1, n_cat, figsize=(5 * n_cat, 3), squeeze=False)
    for i,cat in enumerate(age_cat_labels):
        sub_ax = ax[0, i]
        sub_df = dfage[dfage["Age at visit"]==cat]
        corr = sub_df[["Hit fraction","Horvath age"]].corr().iloc[0,1]
        sns.scatterplot(data=sub_df,x="Hit fraction",y="Horvath age",ax=sub_ax,s=10)
        sub_ax.set_ylim([35,85])
        sub_ax.set_title(f"Age {cat}",fontsize=15)
        # place the r annotation near the lower right of the panel (data coordinates)
        sub_ax.text(0.7*sub_df["Hit fraction"].max(),40,f'r={corr:.2f}',fontsize=15)
        plting.transform_plot_ax(sub_ax, legend_title="")
    fig.savefig(fig_dir / f"horvath_age_hit_fraction_age_cat_{name}.svg",bbox_inches="tight")

In [None]:
# Hit fraction vs Horvath age, one panel per cohort (only samples with a Horvath estimate)
fig, ax = plt.subplots(1, 3, figsize=(15, 3))
hit_fraction_tables = [crossval_hit_fraction1, crossval_hit_fraction2, crossval_hit_fraction3]
for i, cv in enumerate(hit_fraction_tables):
    cohort = f"SWEPIC{i+1}"
    df = pd.concat([cv, horvath_age], axis=1, join="inner")
    print(cohort)
    print(df[["Hit fraction", "Horvath age", "Ad"]].corr())
    corr = df[["Hit fraction", "Horvath age"]].corr().iloc[0, 1]
    panel = ax[i]
    sns.scatterplot(data=df, x="Hit fraction", y="Horvath age", ax=panel, s=10)
    panel.set_ylim([35, 85])
    panel.set_title(cohort, fontsize=15)
    panel.text(0.7 * df["Hit fraction"].max(), 40, f'r={corr:.2f}', fontsize=15)
    plting.transform_plot_ax(panel, legend_title="")

fig.savefig(fig_dir / "horvath_age_hit_fraction_link.svg", bbox_inches="tight")

In [None]:
# Per-age-category hit fraction vs Horvath age, one saved figure per cohort
list_clin = [EPIC2_clin, EPIC3_clin, EPIC4_clin]
hit_fraction_tables = [crossval_hit_fraction1, crossval_hit_fraction2, crossval_hit_fraction3]
for i, (clin, cv) in enumerate(zip(list_clin, hit_fraction_tables)):
    df = pd.concat([cv, horvath_age], axis=1, join="inner")
    name = f"SWEPIC{i+1}"
    get_plots_per_agecat(EPIC_clin=clin, df=df,
                         age_cat_labels=age_cat_labels, age_bins=age_bins,
                         fig_dir=fig_dir, name=name)

# Compare to random probes

In [None]:
# SWEPIC1: compare the adVMP result against random probe sets drawn from a background CpG list.
# NOTE(review): hard-coded reference value; it looks like the SWEPIC1 cross-validated AUC -- confirm.
swepic1_ref = 0.82

# Paths are built from data_dir, like the cross-validation cells, instead of a relative
# "../../FinalData" path and a user-specific absolute path.
background_cpgs = pd.read_csv(data_dir / "variable_probes" / "union_cpgs_5_pct_most_variable_onlyhealthy.csv",
                              index_col=0).values.ravel()
ds_dir = data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC1"

# NOTE(review): n_iter=200 random draws -- confirm get_comparison_rdn seeds its RNG for reproducibility.
advmpcross.get_comparison_rdn(background_cpgs=background_cpgs, 
                       figdir=fig_dir / "SWEPIC1",
                       ref=swepic1_ref,
                       ds_dir=ds_dir, 
                       phenotypes=EPIC2_phenotypes,
                       union_cpgs_fold_spec=union_cpgs_fold["1"], 
                       data=EPIC2_b, 
                       clin=EPIC2_clin,
                       n_iter=200, order="Mixed Order")

In [None]:
# SWEPIC2: compare the adVMP result against random probe sets drawn from a background CpG list.
# NOTE(review): hard-coded reference value; it looks like the SWEPIC2 cross-validated AUC -- confirm.
swepic2_ref = 0.63

# Paths are built from data_dir, like the cross-validation cells, instead of a relative
# "../../FinalData" path and a user-specific absolute path.
background_cpgs = pd.read_csv(data_dir / "variable_probes" / "union_cpgs_5_pct_most_variable_onlyhealthy.csv",
                              index_col=0).values.ravel()
ds_dir = data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC2"

# NOTE(review): n_iter=200 random draws -- confirm get_comparison_rdn seeds its RNG for reproducibility.
advmpcross.get_comparison_rdn(background_cpgs=background_cpgs, 
                       figdir=fig_dir / "SWEPIC2",
                       ref=swepic2_ref,
                       ds_dir=ds_dir, 
                       phenotypes=EPIC3_phenotypes,
                       union_cpgs_fold_spec=union_cpgs_fold["2"], 
                       data=EPIC3_b, 
                       clin=EPIC3_clin,
                       n_iter=200, order="Mixed Order")

In [None]:
# SWEPIC3: compare the adVMP result against random probe sets drawn from a background CpG list.
# NOTE(review): hard-coded reference value; it looks like the SWEPIC3 cross-validated AUC -- confirm.
swepic3_ref = 0.66

# Paths are built from data_dir, like the cross-validation cells, instead of a relative
# "../../FinalData" path and a user-specific absolute path.
background_cpgs = pd.read_csv(data_dir / "variable_probes" / "union_cpgs_5_pct_most_variable_onlyhealthy.csv",
                              index_col=0).values.ravel()
ds_dir = data_dir / "adVMP_crossvalidation_4fold" / "SWEPIC3"

# NOTE(review): n_iter=200 random draws -- confirm get_comparison_rdn seeds its RNG for reproducibility.
advmpcross.get_comparison_rdn(background_cpgs=background_cpgs, 
                       figdir=fig_dir / "SWEPIC3",
                       ref=swepic3_ref,
                       ds_dir=ds_dir, 
                       phenotypes=EPIC4_phenotypes,
                       union_cpgs_fold_spec=union_cpgs_fold["3"], 
                       data=EPIC4_b, 
                       clin=EPIC4_clin,
                       n_iter=200, order="Mixed Order")