In [None]:
import pandas as pd
import numpy as np
import pathlib as pl
from sklearn.decomposition import PCA
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
from scipy.stats import ks_2samp
from tqdm.notebook import tqdm
from adjustText import adjust_text

In [None]:
import sys
sys.path.append("../../FinalCode/")
import download.download_gex as dwnl
import utils.plotting as plting

In [None]:
# Seaborn "muted" colour palette, kept as a notebook-level variable.
colors = sns.color_palette("muted")

In [None]:
# Directory every figure in this notebook is saved into.
# The value is a placeholder and must be replaced with a real path before running.
fig_dir = pl.Path("/add/path/here/")

# Load mapping

In [None]:
# Root directory of the input tables read below (Roadmap mapping, adVMP CpGs,
# Illumina manifest). Placeholder: replace with a real path before running.
data_dir = pl.Path("/add/path/here/")

In [None]:
# EPIC-to-state mapping table, indexed by its first column.
# It is joined to the manifest on the index below, so that first column is
# presumably the CpG probe ID (TODO confirm against the file).
mapping_roadmap = pd.read_csv(data_dir / "NIH_Epigenomics_Roadmap" / "EPIC_to_state_mapping.csv",index_col=0)

In [None]:
# Flatten the values of union_cpgs.csv (first column used as index) into a 1-D array.
# The array is used below as row labels into the EPIC manifest.
union_cpgs = pd.read_csv(data_dir / "adVMP" / "union_cpgs.csv",index_col=0).values.ravel()

In [None]:
# Illumina MethylationEPIC manifest, indexed by its first column.
# skiprows=7 skips the first seven lines of the file (presumably the manifest
# preamble that precedes the column header -- TODO confirm).
epic_manifest = pd.read_csv(data_dir / "illumina_manifests" / "GPL21145_MethylationEPIC_15073387_v-1-0.csv.gz",skiprows=7,index_col=0)

In [None]:
# Restrict the manifest to the union CpGs, then attach the Roadmap mapping
# columns, keeping only probes present in both tables (inner join on the index).
# The cell rebinds `epic_manifest`, so it should be run once per kernel.
epic_manifest = pd.concat(
    [epic_manifest.loc[union_cpgs], mapping_roadmap],
    axis=1,
    join="inner",
)

In [None]:
# Reduced manifest: genomic position, gene / CpG-island annotation, the 450k
# enhancer flag and the "State" column.
red_manifest = epic_manifest.loc[
    :,
    [
        'CHR',
        'MAPINFO',
        'UCSC_RefGene_Name',
        'UCSC_RefGene_Group',
        'UCSC_CpG_Islands_Name',
        'Relation_to_UCSC_CpG_Island',
        '450k_Enhancer',
        "State",
    ],
]

In [None]:
def _unique_refgene_names(manifest):
    """Return the sorted unique gene names found in `UCSC_RefGene_Name`.

    Missing annotations are dropped and each entry is split on ";" so that a
    probe annotated with several genes contributes every one of them.
    """
    per_probe_names = manifest["UCSC_RefGene_Name"].dropna().str.split(";").values
    return np.unique(np.concatenate(per_probe_names))


# Probes whose State is 1_TssA, 2_TssAFlnk or 3_TxFlnk, and the genes they map to.
tssA_manifest = red_manifest[red_manifest.State.isin(["1_TssA","2_TssAFlnk","3_TxFlnk"])]
tss_unique_genes = _unique_refgene_names(tssA_manifest)

# Probes whose State is 10_TssBiv or 11_BivFlnk, and the genes they map to.
poised_manifest = red_manifest[red_manifest.State.isin(["10_TssBiv","11_BivFlnk"])]
poised_unique_genes = _unique_refgene_names(poised_manifest)

# Download data

In [None]:
# Directory holding the processed GSE76987 expression tables.
# Placeholder: replace with a real path before running.
gex_path = pl.Path("/add/path/here")
path_right = gex_path  / "GSE76987_RightColonProcessed.csv"
# Join with a *relative* file name: pathlib discards the left operand when the
# right operand is absolute, so the previous "/add/path/here/..." string made
# this path ignore `gex_path` entirely.
path_right_cr = gex_path / "GSE76987_ColonCancerProcessed.csv"

In [None]:
# Load the expression data through the project helper `download_gex_data`.
# The rest of the notebook treats the result as a DataFrame with a "type" column
# and one column per gene (rows are presumably samples -- TODO confirm in the helper).
right_data = dwnl.download_gex_data(path_right=path_right, path_right_cr=path_right_cr)

In [None]:
# Genes that are both columns of the expression table and in the TSS-state gene list.
test_genes = right_data.columns.intersection(tss_unique_genes)
# Same, for the gene list derived from the bivalent ("poised") states.
poised_genes = right_data.columns.intersection(poised_unique_genes)

In [None]:
# Ordered categorical version of the sample type, so plots and sorts follow
# the Healthy -> Cancer order.
ordered_type = pd.Categorical(
    values=right_data["type"],
    categories=["Healthy","NAC","Adenoma","SSL","Cancer"],
    ordered=True,
)

# Stored as a new column of right_data (mutates the frame). Later cells use
# `right_data.columns[:-2]` to skip the trailing non-gene columns.
right_data["Ordered type"] = ordered_type

# Row labels sorted by the ordered sample type.
ordered_idx = right_data.sort_values(by="Ordered type").index

In [None]:
# Two-component PCA restricted to the TSS-associated genes; the ordered sample
# type is appended as a third column for plotting.
pca = PCA(n_components=2)
X_PCA = pd.concat(
    [
        pd.DataFrame(
            pca.fit_transform(right_data.loc[:, test_genes]),
            index=right_data.index,
            columns=["PCA1", "PCA2"],
        ),
        right_data["Ordered type"],
    ],
    axis=1,
)
print(pca.explained_variance_ratio_)

In [None]:
# Scatter of the two components coloured by sample type; axis labels carry the
# explained-variance percentage of each component.
ax = sns.scatterplot(x="PCA1", y="PCA2", hue="Ordered type", data=X_PCA)
plting.transform_plot_ax(ax, legend_title="", remove_ticks=True)
ax.set(
    xlabel=f"PCA1 ({pca.explained_variance_ratio_[0]*100:.1f}%)",
    ylabel=f"PCA2 ({pca.explained_variance_ratio_[1]*100:.1f}%)",
)
ax.figure.savefig(fig_dir / "PCA_gex_tssgenes.svg", bbox_inches='tight')

In [None]:
# Per-type density of each component (each group normalised on its own);
# only the right-hand panel carries the legend.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10,3))
for panel, component, show_legend in zip(ax, ("PCA1", "PCA2"), (False, True)):
    sns.kdeplot(data=X_PCA, x=component, hue="Ordered type", ax=panel,
                legend=show_legend, common_norm=False)
# Same minimal styling on both panels: no top/right spines, thick axes, no ticks.
for panel in ax:
    panel.spines[['right', 'top']].set_visible(False)
    panel.spines[["bottom", "left"]].set_linewidth(4)
    panel.set_xticks([])
    panel.set_yticks([])
fig.savefig(fig_dir / "gex_density_pca_advmp_genes.svg", bbox_inches="tight")

In [None]:
# Two-sample KS test of Healthy versus every other sample, once per component.
# The cell displays a (PCA1 result, PCA2 result) tuple.
tuple(
    ks_2samp(
        X_PCA.loc[X_PCA["Ordered type"] == "Healthy", component],
        X_PCA.loc[X_PCA["Ordered type"] != "Healthy", component],
    )
    for component in ("PCA1", "PCA2")
)

In [None]:
# Two-component PCA on every column except the two sample-type columns.
# Rebinds `pca` and `X_PCA` from the TSS-gene analysis above.
pca = PCA(n_components=2)
X_PCA = pd.concat(
    [
        pd.DataFrame(
            pca.fit_transform(right_data.drop(["type", "Ordered type"], axis=1)),
            index=right_data.index,
            columns=["PCA1", "PCA2"],
        ),
        right_data["Ordered type"],
    ],
    axis=1,
)
print(pca.explained_variance_ratio_)

In [None]:
# Scatter of the all-gene PCA coloured by sample type; axis labels carry the
# explained-variance percentage of each component.
ax = sns.scatterplot(x="PCA1", y="PCA2", hue="Ordered type", data=X_PCA)
plting.transform_plot_ax(ax, legend_title="", remove_ticks=True)
ax.set(
    xlabel=f"PCA1 ({pca.explained_variance_ratio_[0]*100:.1f}%)",
    ylabel=f"PCA2 ({pca.explained_variance_ratio_[1]*100:.1f}%)",
)
ax.figure.savefig(fig_dir / "PCA_gex.svg", bbox_inches='tight')

In [None]:
# Per-type density of each all-gene component (each group normalised on its
# own); only the right-hand panel carries the legend.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10,3))
for panel, component, show_legend in zip(ax, ("PCA1", "PCA2"), (False, True)):
    sns.kdeplot(data=X_PCA, x=component, hue="Ordered type", ax=panel,
                legend=show_legend, common_norm=False)
# Same minimal styling on both panels: no top/right spines, thick axes, no ticks.
for panel in ax:
    panel.spines[['right', 'top']].set_visible(False)
    panel.spines[["bottom", "left"]].set_linewidth(4)
    panel.set_xticks([])
    panel.set_yticks([])
fig.savefig(fig_dir / "gex_density_pca.svg", bbox_inches="tight")

In [None]:
# Two-sample KS test of Healthy versus every other sample on the all-gene PCA,
# once per component. The cell displays a (PCA1 result, PCA2 result) tuple.
tuple(
    ks_2samp(
        X_PCA.loc[X_PCA["Ordered type"] == "Healthy", component],
        X_PCA.loc[X_PCA["Ordered type"] != "Healthy", component],
    )
    for component in ("PCA1", "PCA2")
)

# Get differential expression

In [None]:
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests
def get_p_values_cond(data: pd.DataFrame, test_genes: np.array, 
                      pheno1: pd.Series, pheno2: pd.Series) -> np.ndarray:
    """Mann-Whitney U p-value per gene between two groups of rows.

    Args:
        data: table whose columns include the genes to test.
        test_genes: gene names; only those that are also columns of `data`
            are tested.
        pheno1: boolean mask selecting the rows of the first group.
        pheno2: boolean mask selecting the rows of the second group.

    Returns:
        1-D array of p-values, one per gene, in the order of
        `data.columns.intersection(test_genes)` -- which is not necessarily the
        order of `test_genes`.
    """
    shared_genes = data.columns.intersection(test_genes)

    def _gene_p_value(gene):
        expression = data[gene]
        # ravel() flattens the selection in case a duplicated column name
        # makes data[gene] two-dimensional.
        group1 = expression[pheno1].values.ravel()
        group2 = expression[pheno2].values.ravel()
        return mannwhitneyu(group1, group2)[1]

    return np.array([_gene_p_value(gene) for gene in shared_genes])

In [None]:
# Healthy-versus-NAC p-value for every gene column. `columns[:-2]` skips the
# trailing "type" / "Ordered type" columns.
pheno1 = right_data["type"].eq("Healthy")
pheno2 = right_data["type"].eq("NAC")
NAC_p = get_p_values_cond(
    data=right_data,
    test_genes=right_data.columns[:-2],
    pheno1=pheno1,
    pheno2=pheno2,
)

In [None]:
from scipy.stats import fisher_exact

In [None]:
# Boolean table: is the gene nominally differentially expressed (raw p < .05)?
full_dgex = pd.DataFrame({"p": NAC_p}, index=right_data.columns[:-2]).lt(.05)

# Genes of the expression table that are NOT in the TSS-associated set.
nunion_genes = np.setdiff1d(right_data.columns[:-2], test_genes)

# 2x2 table: rows = (TSS genes, other genes), columns = (significant, not significant).
a, c = (full_dgex.loc[genes, "p"].sum() for genes in (test_genes, nunion_genes))
b, d = len(test_genes) - a, len(nunion_genes) - c

conting = np.array([[a, b], [c, d]])
fisher_exact(conting)

In [None]:
# Display the 2x2 contingency table passed to Fisher's exact test above.
conting

In [None]:
# Fraction of nominally significant genes among the TSS-associated genes
# and among the remaining genes.
a/(a+b), c/(c+d)

# Volcano plots

In [None]:
# Genes of interest for the differential-expression tables: the TSS-associated genes.
goi = test_genes

def get_gex(goi: np.ndarray) -> pd.DataFrame:
    """Differential-expression table of each condition against Healthy.

    Reads the notebook-level `right_data` table. For every gene of `goi` that
    is a column of `right_data`, a Mann-Whitney p-value is computed for NAC,
    Adenoma, SSL and Cancer versus Healthy; each p-value column is then
    corrected with Benjamini-Hochberg.

    Args:
        goi: gene names to test.

    Returns:
        DataFrame indexed by gene with columns `NAC_p`, `Ad_p`, `SSL_p`,
        `Cancer_p` followed by `NAC_q`, `Ad_q`, `SSL_q`, `Cancer_q`.
    """
    # Output column -> value of right_data["type"] compared against Healthy.
    comparisons = {"NAC_p": "NAC", "Ad_p": "Adenoma", "SSL_p": "SSL", "Cancer_p": "Cancer"}

    # get_p_values_cond returns p-values in the order of
    # `data.columns.intersection(test_genes)`, so label the rows with that same
    # intersection rather than with `goi`: the labels then stay aligned even if
    # `goi` is ordered differently or contains genes absent from right_data.
    genes = right_data.columns.intersection(goi)
    healthy = (right_data["type"]=="Healthy")

    diff_expr = pd.DataFrame(
        {
            p_col: get_p_values_cond(data=right_data,
                                     test_genes=genes,
                                     pheno1=healthy,
                                     pheno2=(right_data["type"]==pheno_name))
            for p_col, pheno_name in comparisons.items()
        },
        index=genes,
    )
    # Benjamini-Hochberg correction, one comparison at a time ("X_p" -> "X_q").
    for p_col in comparisons:
        diff_expr[p_col[:-1]+"q"] = multipletests(diff_expr[p_col],method="fdr_bh")[1]
    return diff_expr

In [None]:
def get_volcano_plot(right_data: pd.DataFrame, 
                     pheno1_name: str, 
                     pheno2_name: str, 
                     goi: np.ndarray, lim_fc: float=1.5,
                     stats: pd.DataFrame=None) -> plt.Figure:
    """Volcano plot of `pheno2_name` versus `pheno1_name` for the genes in `goi`.

    Args:
        right_data: expression table with a "type" column and one column per gene.
        pheno1_name: value of "type" used as the reference group (denominator).
        pheno2_name: value of "type" used as the comparison group (numerator);
            also selects the `<name>_p` / `<name>_q` columns of `stats`.
        goi: genes to plot.
        lim_fc: fold-change threshold; significant genes whose |log2(FC)|
            exceeds log2(lim_fc) are labelled.
        stats: table of p/q-values as returned by `get_gex`. When omitted, the
            notebook-level `diff_expr` is used, as before.

    Returns:
        The matplotlib figure.
    """
    if stats is None:
        # Backward-compatible fallback on the notebook-level table.
        stats = diff_expr

    # get_gex abbreviates the Adenoma columns to "Ad_p"/"Ad_q"; resolve that
    # here so pheno2_name="Adenoma" does not fail with a KeyError.
    prefix = pheno2_name
    if f"{prefix}_p" not in stats.columns:
        prefix = {"Adenoma": "Ad"}.get(pheno2_name, pheno2_name)

    pheno1 = (right_data["type"]==pheno1_name)
    pheno2 = (right_data["type"]==pheno2_name)

    # log2 of the ratio of group means.
    # NOTE(review): meaningful only for positive, linear-scale values -- confirm
    # how the expression table was processed.
    log2FC = np.log2(right_data.loc[pheno2,goi].mean()/right_data.loc[pheno1,goi].mean())

    volcano_df = pd.concat([log2FC,
                            -np.log10(stats[f"{prefix}_p"]),
                            -np.log10(stats[f"{prefix}_q"])],axis=1)
    volcano_df.columns = ["log2(FC)","-log10(p)","-log10(q)"]
    # -log10(q) >= 1 is equivalent to q <= 0.1.
    volcano_df["Significant"] = volcano_df["-log10(q)"]>=1

    genes_to_annotate = volcano_df[(volcano_df["log2(FC)"].abs()>np.log2(lim_fc)) & (volcano_df["Significant"])]

    # Smallest -log10(p) among significant genes: where the threshold line goes.
    pmin = volcano_df[volcano_df["Significant"]]["-log10(p)"].min()
    fig, ax = plt.subplots(1,1)
    sns.scatterplot(data=volcano_df, 
                    y="-log10(p)", 
                    x="log2(FC)", 
                    hue="Significant", 
                    palette = {True: "red", False: "blue"},
                    ax=ax)
    # get_xlim() returns two scalars, so no .max() is needed; read the limits
    # once, before the line is added. Skip the line when nothing is significant
    # (pmin is NaN in that case).
    xmin, xmax = ax.get_xlim()
    if np.isfinite(pmin):
        ax.hlines(xmin=xmin,xmax=xmax,y=pmin,color="r",linestyle='--')
    plting.transform_plot_ax(ax, legend_title="Sign. diff. expressed")

    texts = [ax.text(row["log2(FC)"], row["-log10(p)"], gene, fontsize=10)
             for gene, row in genes_to_annotate.iterrows()]

    # Only move labels vertically to resolve overlaps.
    adjust_text(texts, only_move={'points':'y', 'texts':'y'}, arrowprops=dict(arrowstyle="-", color='r', lw=0.5))
    return fig

In [None]:
# p- and q-values of every condition versus Healthy for the genes of interest.
# get_volcano_plot reads this notebook-level table.
diff_expr = get_gex(goi=goi)

In [None]:
# For each gene, count in how many comparisons q < 0.1, then tabulate how many
# genes fall into each count.
(diff_expr.loc[:,diff_expr.columns.str.endswith("_q")]<0.1).sum(axis=1).value_counts()

In [None]:
# -log10 of the NAC q-values, largest (most significant) first, as a one-column frame.
sign_diff = (
    (-np.log10(diff_expr["NAC_q"]))
    .sort_values(ascending=False)
    .to_frame()
)

In [None]:
# Volcano plot of NAC versus Healthy for the TSS-associated genes.
fig = get_volcano_plot(right_data=right_data, pheno1_name="Healthy", pheno2_name="NAC", goi=test_genes)
fig.savefig(fig_dir / "volcano_plot_nac_vs_healthy.svg", bbox_inches="tight")

In [None]:
# Volcano plot of SSL versus Healthy for the TSS-associated genes.
# NOTE(review): the file name says "nac_vs_SSL" but the comparison is Healthy vs SSL;
# confirm nothing depends on the name before renaming.
fig = get_volcano_plot(right_data=right_data, pheno1_name="Healthy", pheno2_name="SSL", goi=test_genes)
fig.savefig(fig_dir / "volcano_plot_nac_vs_SSL.svg", bbox_inches="tight")

In [None]:
# Volcano plot of Cancer versus Healthy; a stricter fold-change threshold (2.5)
# is used for labelling genes.
# NOTE(review): the file name says "nac_vs_cancer" but the comparison is Healthy vs Cancer;
# confirm nothing depends on the name before renaming.
fig = get_volcano_plot(right_data=right_data, pheno1_name="Healthy", pheno2_name="Cancer", goi=test_genes, lim_fc=2.5)
fig.savefig(fig_dir / "volcano_plot_nac_vs_cancer.svg", bbox_inches="tight")

In [None]:
# Genes with q < 0.1 in NAC, SSL and Cancer alike, ordered by their NAC p-value.
fully_dysregulated = (
    diff_expr.loc[(diff_expr[["NAC_q", "SSL_q", "Cancer_q"]] < 0.1).all(axis=1)]
    .sort_values(by="NAC_p")
    .index
    .to_numpy()
)

In [None]:
def get_gex_boxplot(right_data: pd.DataFrame, gene: str, palette=None) -> plt.Figure:
    """Box plot of one gene's expression across the five sample types.

    Args:
        right_data: expression table with a "type" column and a column `gene`.
        gene: name of the column to plot on the y axis.
        palette: colours handed to seaborn. Defaults to seaborn's "muted"
            palette; the previous code read a global `palette` that this
            notebook never defines, which raises NameError on a fresh kernel.

    Returns:
        The matplotlib figure.
    """
    if palette is None:
        palette = sns.color_palette("muted")
    fig, ax = plt.subplots(1,1,figsize=(3,2))
    sns.boxplot(data=right_data, 
                x="type", y=gene, ax=ax,
                order=["Healthy","NAC","Adenoma","SSL","Cancer"], 
                palette=palette)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment="right")
    ax.set_xlabel("")
    # Minimal frame: no top/right spines, thicker bottom/left axes.
    ax.spines[['right', 'top']].set_visible(False)
    ax.spines[["bottom", "left"]].set_linewidth(3)
    return fig

In [None]:
# One box plot per gene that is dysregulated in NAC, SSL and Cancer.
genes_to_plot = fully_dysregulated
# savefig does not create missing directories, so create the target once up front.
boxplot_dir = fig_dir / "diffgexboxplots"
boxplot_dir.mkdir(parents=True, exist_ok=True)
for gene in genes_to_plot:
    figure = get_gex_boxplot(right_data=right_data, gene=gene)
    figure.savefig(boxplot_dir / f"{gene}.svg", bbox_inches="tight")
    # Release the figure once written: one open figure per gene otherwise
    # accumulates in memory. The figure is then saved to disk but no longer
    # rendered inline; drop this line if inline display is wanted.
    plt.close(figure)

# Compute enrichment

In [None]:
def get_full_dgex(right_data: pd.DataFrame, test_genes: np.ndarray) -> pd.DataFrame:
    """Rank every gene by its Healthy-versus-NAC Mann-Whitney p-value.

    Args:
        right_data: expression table whose last two columns are not genes
            (the notebook passes frames ending in "type" and "Ordered type");
            every other column is tested.
        test_genes: gene set whose members are flagged in the result.

    Returns:
        DataFrame indexed by gene and sorted by ascending p-value, with columns
        `p` (p-value), `Indicator` (1 if the gene is in `test_genes`, else 0)
        and `Order` (descending rank weight: n for the smallest p-value down to
        1 for the largest).
    """
    gene_cols = right_data.columns[:-2]
    # The group masks do not depend on the gene: build them once, not per iteration.
    healthy = right_data["type"]=="Healthy"
    nac = right_data["type"]=="NAC"

    pvalues = [mannwhitneyu(right_data.loc[healthy,gene],
                            right_data.loc[nac,gene])[1]
               for gene in tqdm(gene_cols)]

    full_dgex = pd.DataFrame(pvalues,index=gene_cols, columns=["p"])
    # Vectorised membership test; `gene in test_genes` on an ndarray rescans the
    # whole array for every gene.
    full_dgex["Indicator"] = full_dgex.index.isin(test_genes).astype(int)
    full_dgex = full_dgex.sort_values(by="p")
    full_dgex["Order"] = np.arange(full_dgex.shape[0], 0, -1)
    return full_dgex

In [None]:
from typing import List
def compute_ks_random_stat_l(full_dgex: pd.DataFrame, l: int, 
                           posconst: float, 
                           negconst: float) -> float:
    """Running enrichment statistic over the first `l` rows of `full_dgex`.

    Args:
        full_dgex: ranked table with `Indicator` (1 = in the gene set, 0 = not)
            and `Order` (rank weight) columns.
        l: number of leading rows to include.
        posconst: normaliser for the in-set term (the caller passes the sum of
            `Order` over all in-set rows).
        negconst: normaliser for the out-of-set term (the caller passes the
            number of out-of-set rows).

    Returns:
        (sum of `Order` over in-set rows among the first `l`) / posconst
        minus (number of out-of-set rows among the first `l`) / negconst.
    """
    head = full_dgex.iloc[:l]
    in_set = head[head.Indicator == 1]
    out_of_set = head[head.Indicator == 0]

    # A plain 0 is used when no in-set row has been seen yet.
    weighted_hits = in_set.Order.sum() if in_set.shape[0] != 0 else 0
    miss_count = out_of_set.shape[0]

    return weighted_hits / posconst - miss_count / negconst

def get_kstat_list(full_dgex: pd.DataFrame) -> List:
    """Running enrichment statistic for every prefix of the ranked gene table.

    For each l = 1..n the value is
    (sum of `Order` over in-set rows among the first l) / (sum of `Order` over
    all in-set rows) minus (number of out-of-set rows among the first l) /
    (n - sum of `Indicator`).

    The prefixes are computed with cumulative sums in a single pass; the
    previous version re-sliced and re-filtered the table for every l, which is
    quadratic in the number of genes. The values are identical. No progress
    bar is shown any more.

    Args:
        full_dgex: ranked table with `Indicator` (1 = in the gene set, 0 = not)
            and `Order` (rank weight) columns, already in ranking order.

    Returns:
        List of n floats, one per prefix length.
    """
    indicator = full_dgex["Indicator"].to_numpy()
    order = full_dgex["Order"].to_numpy()
    in_set = indicator == 1
    out_of_set = indicator == 0

    posconst = order[in_set].sum()
    negconst = full_dgex.shape[0] - full_dgex.Indicator.sum()

    # Prefix sums: weighted in-set hits and plain out-of-set counts up to each row.
    weighted_hits = np.cumsum(np.where(in_set, order, 0))
    miss_counts = np.cumsum(out_of_set)

    return list(weighted_hits / posconst - miss_counts / negconst)

In [None]:
# Keep only the Healthy and NAC samples for the enrichment analysis.
red_right_data = right_data[right_data["type"].isin(["Healthy","NAC"])]

In [None]:
# Ranked Healthy-vs-NAC table for all genes, flagging the TSS-associated genes.
# This rebinds `full_dgex`, which held the boolean significance table of the
# Fisher-test cell above.
full_dgex = get_full_dgex(right_data=red_right_data, test_genes=test_genes)

In [None]:
# Running statistic along the ranking, reshaped into a two-column frame for plotting.
kstat_list = get_kstat_list(full_dgex=full_dgex)

# "x" is the 1-based rank, stored as float like "KS".
ksplot_df = pd.DataFrame(
    {
        "x": np.arange(1, len(kstat_list) + 1, dtype=float),
        "KS": np.asarray(kstat_list, dtype=float),
    }
)

In [None]:
# NOTE(review): this p-value is typed in by hand; the permutation test that
# produces it runs in a later cell. Update it after re-running the permutations.
empirical_p = 0.1

# Enrichment score: the largest absolute value of the running statistic.
abs_kstats = np.abs(kstat_list)
peak_idx = int(np.argmax(abs_kstats))
ES = abs_kstats[peak_idx]
# The "x" column is 1-based while argmax is 0-based; convert so the marker
# sits on the actual maximum.
peak_x = peak_idx + 1

fig, ax = plt.subplots(1,1,figsize=(3,2))
sns.lineplot(data=ksplot_df,x="x",y="KS",ax=ax)
plting.transform_plot_ax(ax, legend_title="", linew=3)
ax.set_xticklabels([])
ax.set_xticks([])
ax.set_xlabel("Order")
ymin, ymax = ax.get_ylim()
ax.vlines(x=peak_x,ymin=ymin, ymax=ymax, color="r")
ax.text(peak_x*1.1, ES, f"ES={ES:.2f}\np={empirical_p}", color="r", fontsize=12)
fig.savefig(fig_dir / "ES_tss_genes_dgex.svg", bbox_inches="tight")

In [None]:
# Null distribution of the enrichment score: shuffle the Healthy/NAC labels,
# recompute the ranking and the running statistic, and keep the maximum.
all_ES = []
n_permut = 100
# Seeded generator so the null distribution is reproducible across runs.
rng = np.random.default_rng(0)
for i in tqdm(range(n_permut)):
    random_right = red_right_data.copy()
    random_right["type"] = rng.permutation(random_right["type"].to_numpy())

    # Permutation-specific names: reusing `full_dgex` / `kstat_list` here would
    # overwrite the results computed on the real labels.
    perm_dgex = get_full_dgex(right_data=random_right, test_genes=test_genes)
    perm_kstats = get_kstat_list(full_dgex=perm_dgex)
    all_ES.append(np.max(np.abs(perm_kstats)))

In [None]:
# Fraction of permutations whose enrichment score exceeds the observed one.
# `ES` is the observed score computed in the enrichment-plot cell above; using
# it directly avoids comparing against a hand-typed, rounded constant (0.32).
empirical_p = np.count_nonzero(np.asarray(all_ES) > ES) / n_permut

In [None]:
# Directory the permutation results are written to.
# Placeholder: replace with a real path before running.
resdir = pl.Path("/add/path/here")

In [None]:
# Persist the enrichment score of every permutation (one row per permutation).
pd.Series(np.array(all_ES)).to_csv(resdir / "empirical_p_ES.csv")