In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from statsmodels.stats.multitest import multipletests

from scipy.stats import pearsonr, fisher_exact, kruskal, mannwhitneyu
import scipy.cluster.hierarchy as hierarchy
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from gtfparse import read_gtf
from typing import Dict
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats
from pydeseq2.utils import load_example_data

from adjustText import adjust_text

import os
import pathlib as pl

import sys
sys.path.append("../../FinalCode/")
import download.download as dwnl
import utils.plotting as plting
import adVMP.adVMP_discovery as discov

In [None]:
# Output directory for figures; the placeholder can be overridden with the FIG_DIR
# environment variable so the notebook runs without editing this cell.
fig_dir = pl.Path(os.environ.get("FIG_DIR", "/add/path/here/"))

In [None]:
# Root directories; the placeholders can be overridden with the BASE_DIR / DATA_DIR
# environment variables so the notebook runs without editing this cell.
base_dir = pl.Path(os.environ.get("BASE_DIR", "/add/path/here/"))

data_dir = pl.Path(os.environ.get("DATA_DIR", "/add/path/here/"))

# Probes to exclude: flattened values of every non-index column of this CSV.
bad_probes = pd.read_csv(data_dir / "auxiliary" / "sketchy_probe_list_epic.csv", index_col=0).values.ravel()
# `data_dir / ...` is already a Path, so no extra pl.Path(...) wrapping is needed.
sample_origin_path = data_dir / "clinical" / "sample_origin_wbatch.csv"

clinical_path = data_dir / "clinical" / "cleaned_clinical_reduced_diet.csv"
target_path = data_dir / "clinical" / "targets.csv"

In [None]:
# Sample bookkeeping table; every column is cast to str so that specimen and patient
# ids compare equal to the string ids used elsewhere.
sample_origin = pd.read_csv(sample_origin_path).astype(str)

In [None]:
# Load the EPIC methylation cohorts (EPIC4 disabled). The project helper returns
# beta values, clinical table, samples and phenotypes for EPIC2 and EPIC3.
(
    EPIC2_b, EPIC2_clin, EPIC2_samples, EPIC2_phenotypes,
    EPIC3_b, EPIC3_clin, EPIC3_samples, EPIC3_phenotypes,
) = dwnl.download_EPIC(
    base_dir=base_dir,
    sample_origin_path=sample_origin_path,
    clinical_path=clinical_path,
    target_path=target_path,
    bad_probes=bad_probes,
    EPIC4=False,
)

In [None]:
# Union of adVMP CpG ids, flattened from the CSV's non-index column(s) into a 1-D array.
union_cpgs = pd.read_csv(data_dir / "adVMP" / "union_cpgs.csv", index_col=0).to_numpy().ravel()

# Map Ensembl gene IDs to official gene symbols using GENCODE

In [None]:
def get_gex_df(path: pl.Path, sample_origin: pd.DataFrame, mapping: Dict,
               specimen_ids=None) -> pd.DataFrame:
    """Load a gene x sample expression CSV and return it as a sample x gene frame.

    Parameters
    ----------
    path : pl.Path
        CSV whose first column holds gene ids. Only rows starting with "ENSG" are
        kept, and their ".N" version suffix is stripped. Sample column names are
        cut at the first "_"; the leading part is treated as a patient id.
    sample_origin : pd.DataFrame
        Table with "specimen_number" and "patient_id" columns, used to relabel
        patient ids as specimen numbers.
    mapping : Dict
        Unversioned Ensembl gene id -> gene symbol. Ids missing from it are kept as is.
    specimen_ids : collection, optional
        Specimen numbers allowed in the patient -> specimen relabelling. Defaults to
        the index of the notebook-global ``EPIC2_b`` (the original behaviour).

    Returns
    -------
    pd.DataFrame
        Samples as rows, gene symbols as columns. Rows of ids mapping to the same
        symbol are averaged.
    """
    if specimen_ids is None:
        # Backward-compatible default: specimens present in the EPIC2 methylation data.
        specimen_ids = EPIC2_b.index

    gex = pd.read_csv(path, index_col=0)
    gex = gex.loc[gex.index.str.startswith("ENSG")]
    # Strip the Ensembl version suffix so ids match the keys of `mapping`.
    gex.index = gex.index.str.split(".").str[0]
    gex = gex.rename(index=mapping)

    gex.columns = gex.columns.str.split("_").str[0]

    origin = sample_origin[["specimen_number", "patient_id"]]
    specimen_mapping = (
        origin[origin["specimen_number"].isin(specimen_ids)]
        .set_index("patient_id")["specimen_number"]
        .to_dict()
    )
    gex = gex.rename(columns=specimen_mapping)

    # Average ids that share a symbol. Grouping on the index level works whatever the
    # CSV's index header is; reset_index().groupby("index") failed for a named index.
    return gex.groupby(level=0).mean().T

In [None]:
# Annotation and tximport output locations; the placeholders can be overridden with the
# GENCODE_PATH / PROCESSED_RNA_PATH environment variables.
gencode_path = pl.Path(os.environ.get("GENCODE_PATH", "/add/path/here/"))

processed_rna_path = pl.Path(os.environ.get("PROCESSED_RNA_PATH", "/add/path/here"))

# `processed_rna_path / ...` is already a Path; no extra pl.Path(...) wrapping needed.
tpm_path = processed_rna_path / "txiScaledTPM.csv"
count_path = processed_rna_path / "txiCounts.csv"
length_path = processed_rna_path / "txiLength.csv"

In [None]:
# GENCODE annotation, converted to pandas and restricted to gene-level records.
gencode = read_gtf(gencode_path).to_pandas()
gencode = gencode.loc[gencode["feature"] == "gene"]
gene_id = gencode["gene_id"]

# Unversioned Ensembl gene id -> gene symbol; later duplicates win.
df = gencode[["gene_id", "gene_name"]].set_index("gene_id")
df.index = df.index.str.split(".").str[0]
mapping = dict(zip(df.index, df["gene_name"]))

In [None]:
# Scaled TPM, counts and lengths from tximport, each as a samples x genes frame.
gex_results, gex_counts, gex_length = (
    get_gex_df(path=matrix_path, sample_origin=sample_origin, mapping=mapping)
    for matrix_path in (tpm_path, count_path, length_path)
)

In [None]:
# RNA-seq plate layout indexed by specimen number. The index is cast to str to match
# the string specimen ids used elsewhere. (Fixes the `data_dr` typo -> NameError.)
batch_layout = pd.read_csv(data_dir / "RNAseq" / "batch_layout.csv").set_index("specimen_number")
batch_layout.index = batch_layout.index.astype(str)

In [None]:
# patient_id -> specimen number (str), one entry per batch_layout row; later duplicates win.
# NOTE(review): this rebinds `mapping`, which earlier held the Ensembl -> symbol map passed
# to get_gex_df. Re-running the gene-loading cell after this one would silently use the
# wrong mapping.
mapping = dict(zip(batch_layout["patient_id"], batch_layout.index))

In [None]:
institution = pd.read_csv(data_dir / "clinical" / "institution_information.csv")
# Reorder/select rows to follow batch_layout's patient ids, then relabel the rows from
# patient id to specimen number via the patient -> specimen `mapping` of the previous cell.
# `.to_numpy()` replaces Series.ravel(), which is deprecated since pandas 2.2.
institution = institution.set_index("Patient ID").loc[batch_layout["patient_id"].to_numpy()]
institution = institution.rename(index=mapping)

### Reads tagged low quality by FastQC

In [None]:
# FastQC-flagged libraries listed by patient id, converted to specimen numbers through the
# patient -> specimen `mapping` (ids without an entry are left unchanged by .replace).
# NOTE(review): assumes batch_layout's patient_id is integer-typed like these ids.
low_quality = [1149, 1179, 1340, 1520, 1666, 268, 362,
               498, 548, 576, 766, 771, 772]
# `.to_numpy()` replaces Series.ravel(), which is deprecated since pandas 2.2.
low_quality = pd.Series(low_quality).replace(mapping).to_numpy()

# Filter out genes whose rounded raw count is 0 in more than 75% of the samples.
to_keep = gex_results.loc[:, (gex_counts.round(0) == 0).sum() <= gex_results.shape[0] * 0.75].columns
red_gex = gex_results.loc[:, to_keep]

# EPIC2 phenotype labels (indexed like EPIC2_clin), restricted to the RNA-seq samples.
adenoma = pd.Series(EPIC2_phenotypes, index=EPIC2_clin.index).loc[red_gex.index]

# General plotting

In [None]:
# Per-gene z-score across samples (column mean / sample std).
std_red_gex = red_gex.sub(red_gex.mean()).div(red_gex.std())

In [None]:
# Fixed seed: on a wide samples x genes matrix, sklearn's "auto" solver can select
# randomized SVD, whose components then vary from run to run.
pca = PCA(n_components=50, random_state=0)
X_PCA = pca.fit_transform(std_red_gex)

In [None]:
# Fraction of variance explained by each of the 50 components.
np.asarray(pca.explained_variance_ratio_)

In [None]:
# PC scores as a frame (rows: samples of std_red_gex, columns: PCA1..PCAn), joined with
# the clinical table, plate, institution, adenoma status and a low-quality flag.
X_PCA = pd.DataFrame(
    X_PCA,
    index=std_red_gex.index,
    columns=["PCA" + str(k) for k in range(1, X_PCA.shape[1] + 1)],
)
X_PCA = pd.concat(
    [X_PCA, EPIC2_clin.loc[X_PCA.index], batch_layout["Given Plate Name"], institution],
    axis=1,
)
X_PCA = X_PCA.assign(Adenoma=adenoma, **{"Low quality": False})
X_PCA.loc[low_quality, "Low quality"] = True

In [None]:
# Covariates tested against the top PCs: binary ones (encoded 0/1 later) and
# continuous ones; clin_params lists them all in display order.
bin_params = [
    "Adenoma",
    "Ever smoked cigarettes",
    "Metabolic syndrome",
    "Analgesic >=2 years (overall)",
    "Given Plate Name",
    "Institution",
    "Low quality",
]
cont_params = [
    "Age at visit",
    "BMI",
    "Pack years",
    "inflammatory_n",
    "anti-inflammatory_n",
    "western_n",
    "prudent_n",
]
clin_params = [
    "Age at visit", "BMI", "Ever smoked cigarettes",
    "Metabolic syndrome", "Analgesic >=2 years (overall)",
    "Pack years", "inflammatory_n", "anti-inflammatory_n",
    "western_n", "prudent_n",
    "Adenoma", "Given Plate Name", "Institution", "Low quality",
]

In [None]:
# Independent copy (not a slice of X_PCA) of the first 10 PCs plus covariates. Without
# .copy() the column assignments below raise SettingWithCopyWarning.
df = X_PCA[[f"PCA{i+1}" for i in range(10)] + clin_params].copy()

# Encode the two-level categorical covariates as 0/1 for the group tests.
df["Institution"] = df["Institution"].replace({"e2": 0, "55": 1})

df["Given Plate Name"] = df["Given Plate Name"].replace({"Plate 1": 0, "Plate 2": 1})

In [None]:
# p-value of the association between each of the first 10 PCs and each covariate:
# Kruskal-Wallis between the ==0 and ==1 groups for binary covariates (other values, e.g.
# NaN, fall in neither group), Pearson correlation on complete pairs for continuous ones.
associations = {}
for pc in [f"PCA{i+1}" for i in range(10)]:
    associations[pc] = {}
    for col in bin_params:
        neg = df[col] == 0
        pos = df[col] == 1
        # .loc instead of chained df[mask][pc] indexing.
        associations[pc][col] = kruskal(df.loc[neg, pc], df.loc[pos, pc])[1]
    for col in cont_params:
        dfred = df[[pc, col]].dropna()
        associations[pc][col] = pearsonr(dfred[pc], dfred[col])[1]
# Rows: covariates, columns: PCs; stored as -log10(p). Vectorised np.log10 replaces
# DataFrame.applymap, which is deprecated since pandas 2.1.
associations = -np.log10(pd.DataFrame.from_dict(associations))

In [None]:
# Benjamini-Hochberg FDR across covariates, separately for each PC (one frame per PC).
# `associations` holds -log10(p) at this point, so convert back to raw p-values first;
# multipletests expects p-values in [0, 1], and feeding it -log10(p) gave meaningless q-values.
association_q = [
    pd.DataFrame(
        multipletests(10.0 ** (-associations[col].to_numpy()), method="fdr_bh")[1],
        index=associations.index,
        columns=[col],
    )
    for col in associations.columns
]

In [None]:
# Combine the per-PC q-value columns into one covariate x PC frame. The guard keeps the
# cell idempotent: re-running it would otherwise try to concat an already-built frame.
if isinstance(association_q, list):
    association_q = pd.concat(association_q, axis=1)

In [None]:
# Heatmap of -log10(p) for PC x covariate associations. Cells below the Bonferroni
# threshold (0.05 / number of covariates, on the -log10 scale) are masked.
fig, ax = plt.subplots(1, 1, figsize=(12, 4))
alpha_bonf = -np.log10(0.05 / associations.shape[0])
sns.heatmap(
    associations,
    mask=associations < alpha_bonf,
    cmap="vlag",
    vmax=5,
    center=0.9 * alpha_bonf,
    cbar_kws={"label": "-log10(p)"},
    ax=ax,
)
fig.savefig(fig_dir / "PCA_heatmap_clin_associations.png", bbox_inches="tight", dpi=250)

In [None]:
# PC1 vs PC2 for all samples, coloured by adenoma status; axis labels show % variance.
ax = sns.scatterplot(x="PCA1", y="PCA2", hue="Adenoma", data=X_PCA)
plting.transform_plot_ax(ax, legend_title="Adenoma", remove_ticks=True)
ax.set(
    xlabel=f"PCA1 ({100 * pca.explained_variance_ratio_[0]:.1f}%)",
    ylabel=f"PCA2 ({100 * pca.explained_variance_ratio_[1]:.1f}%)",
)
ax.figure.savefig(fig_dir / "PCA_gex_rnaseq.svg", bbox_inches="tight")

In [None]:
# PC1 vs PC2 for plate-1 samples only, coloured by adenoma status. The PCs (and the
# % variance in the labels) come from the fit on all samples.
fig, ax = plt.subplots()
sns.scatterplot(data=X_PCA[X_PCA["Given Plate Name"] == "Plate 1"],
                x="PCA1", y="PCA2", hue="Adenoma", ax=ax)
plting.transform_plot_ax(ax, legend_title="Adenoma", remove_ticks=True)
ax.set_xlabel(f"PCA1 ({pca.explained_variance_ratio_[0]*100:.1f}%)")
ax.set_ylabel(f"PCA2 ({pca.explained_variance_ratio_[1]*100:.1f}%)")
# Own filename: this previously overwrote PCA_gex_rnaseq.svg from the all-samples plot.
fig.savefig(fig_dir / "PCA_gex_rnaseq_plate1.svg", bbox_inches="tight")

In [None]:
# PC1 vs PC2 for all samples, coloured by sequencing plate (batch check).
ax = sns.scatterplot(x="PCA1", y="PCA2", hue="Given Plate Name", data=X_PCA)
plting.transform_plot_ax(ax, legend_title="Given Plate Name", remove_ticks=True)
ax.set(
    xlabel=f"PCA1 ({100 * pca.explained_variance_ratio_[0]:.1f}%)",
    ylabel=f"PCA2 ({100 * pca.explained_variance_ratio_[1]:.1f}%)",
)
ax.figure.savefig(fig_dir / "PCA_rnaseq_plate.svg", bbox_inches="tight")

# Get differential expression using DESeq2

In [None]:
def get_volcano_plot(summary: pd.DataFrame,
                     lim_fc: float = 1,
                     lim_sign: float = -np.log10(0.05),
                     lim_fc_annot: float = 1.5,
                     lim_sign_annot: float = 5) -> plt.Figure:
    """Volcano plot (log2 fold change vs. -log10 adjusted p) of a DESeq2 results table.

    Parameters
    ----------
    summary : pd.DataFrame
        Results indexed by gene, with "log2FoldChange", "pvalue" and "padj" columns.
        Helper columns are added to a copy; the caller's frame is left unchanged.
    lim_fc : float
        Linear fold-change threshold for significance: |log2FC| >= log2(lim_fc).
        The default of 1 applies no fold-change filter.
    lim_sign : float
        -log10(padj) threshold for significance.
    lim_fc_annot : float
        Genes with |log2FC| > log2(lim_fc_annot) ...
    lim_sign_annot : float
        ... and -log10(padj) >= lim_sign_annot are labelled.

    Returns
    -------
    plt.Figure
        The figure, also returned (after printing a notice) when no gene is labelled.
    """
    summary = summary.copy()
    summary["-log10(p)"] = -np.log10(summary["pvalue"])
    summary["-log10(q)"] = -np.log10(summary["padj"])
    # lim_fc used to be accepted but ignored; with the default 1 the fold-change condition
    # holds for every defined log2FC, so existing calls are unaffected.
    summary["Significant"] = (summary["-log10(q)"] >= lim_sign) & (
        summary["log2FoldChange"].abs() >= np.log2(lim_fc)
    )

    genes_to_annotate = summary[
        (summary["log2FoldChange"].abs() > np.log2(lim_fc_annot))
        & (summary["-log10(q)"] >= lim_sign_annot)
    ]

    fig, ax = plt.subplots(1, 1)
    sns.scatterplot(data=summary,
                    y="-log10(q)",
                    x="log2FoldChange",
                    hue="Significant",
                    palette={True: "red", False: "blue"},
                    ax=ax)
    # Dashed line at the least significant gene still called significant. It is skipped
    # when no gene qualifies, where the minimum would be NaN.
    if summary["Significant"].any():
        pmin = summary.loc[summary["Significant"], "-log10(q)"].min()
        xmin, xmax = ax.get_xlim()
        ax.hlines(xmin=xmin, xmax=xmax, y=pmin, color="r", linestyle='--')
    plting.transform_plot_ax(ax, legend_title="Sign. diff. expressed")
    if genes_to_annotate.shape[0] == 0:
        print("No significant gene!")
        return fig

    texts = [
        ax.text(genes_to_annotate.loc[g, "log2FoldChange"],
                genes_to_annotate.loc[g, "-log10(q)"],
                g, fontsize=10)
        for g in genes_to_annotate.index
    ]
    adjust_text(texts, only_move={'points': 'y', 'texts': 'y'},
                arrowprops=dict(arrowstyle="-", color='r', lw=0.5))
    return fig

In [None]:
# DESeq2 sample metadata: plate and adenoma status (as "0"/"1" strings), aligned to the count matrix rows.
metadata = pd.concat(
    [batch_layout["Given Plate Name"].loc[gex_counts.index],
     adenoma.to_frame().astype(int).astype(str)],
    axis=1,
).set_axis(["Plate", "Adenoma"], axis=1)

In [None]:
# Fit DESeq2 on rounded counts of the retained genes, with adenoma status and plate as
# design factors; Cook's-distance outliers are refit.
dds = DeseqDataSet(
    n_cpus=8,
    refit_cooks=True,
    design_factors=["Adenoma", "Plate"],
    metadata=metadata.loc[gex_counts.index],
    counts=gex_counts.loc[:, to_keep].round(0),
)
dds.deseq2()

In [None]:
# Wald test of adenoma ("1") against non-adenoma ("0").
stat_res = DeseqStats(dds, n_cpus=8, contrast=["Adenoma", "1", "0"])

In [None]:
# Run the tests; keep an unshrunken copy before lfc_shrink replaces results_df.
stat_res.summary()
summary = stat_res.results_df.copy(deep=True)

In [None]:
# Shrink the adenoma log2 fold changes, then snapshot the shrunken results.
stat_res.lfc_shrink(coeff="Adenoma_1_vs_0")
shrink_summary = stat_res.results_df.copy(deep=True)

In [None]:
# Drop genes with undefined statistics and cap log2 fold changes at +/-2 for plotting.
# assign() creates a new frame; attribute-style column assignment on the dropna() result
# could raise SettingWithCopyWarning.
df = shrink_summary.dropna()
df = df.assign(log2FoldChange=df["log2FoldChange"].clip(-2, 2))

In [None]:
# Floor padj at 1e-5 so -log10(q) is capped at 5 (padj == 0 would otherwise become inf).
# assign() avoids attribute-style assignment on a possibly-flagged slice.
df = df.assign(padj=df["padj"].clip(10**(-5), 1))

In [None]:
# Genome-wide volcano plot: significance at q <= 0.05; labels for |FC| > 2 and q <= 0.01.
fig = get_volcano_plot(
    summary=df,
    lim_fc=1,
    lim_sign=-np.log10(0.05),
    lim_fc_annot=2,
    lim_sign_annot=2,
)
fig.savefig(fig_dir / "volcano_plot_naa_vs_healthy.svg", bbox_inches="tight")

# Get the unsupervised grouping of patients according to expression of aDVMC-related genes

In [None]:
def get_clustermap(df: pd.DataFrame,
                   adenoma: pd.Series,
                   filename: str = "clustermap_advmc_related_genes.svg"):
    """Split samples into 2 Ward clusters on z-scored expression and plot a clustermap.

    Prints the cluster x adenoma contingency table and its Fisher exact test. Saves the
    clustermap (rows coloured by cluster and adenoma status) to ``fig_dir / filename``.

    Parameters
    ----------
    df : pd.DataFrame
        Samples x genes expression matrix.
    adenoma : pd.Series
        0/1 adenoma status indexed by sample (must cover df's index).
    filename : str
        Output file name inside the notebook-global ``fig_dir``.
    """
    std_df = (df - df.mean()) / df.std()
    # Constant genes become all-NaN (0/0) after z-scoring, which fclusterdata cannot handle.
    std_df = std_df.dropna(axis=1, how="all")

    pred = hierarchy.fclusterdata(std_df, 2, criterion='maxclust', method='ward', metric='euclidean')
    pred = pd.DataFrame(pred, index=std_df.index, columns=["Cluster"]).sort_values(by="Cluster")

    conting = pd.crosstab(pred["Cluster"], adenoma)
    print("Fisher adenoma", fisher_exact(conting))
    print(conting)

    # astype(int) lets the colour lookup work for bool/float 0-1 labels too; to_numpy()
    # replaces Series.ravel(), which is deprecated since pandas 2.2.
    row_colors = [
        pred["Cluster"].map({1: "orange", 2: "grey"}).to_numpy(),
        adenoma.loc[pred.index].astype(int).map({0: "blue", 1: "red"}).to_numpy(),
    ]
    cg = sns.clustermap(data=std_df.loc[pred.index], row_colors=row_colors,
                        row_cluster=True, method="ward", cmap="vlag", center=0, vmin=-2, vmax=3)
    cg.ax_col_dendrogram.set_visible(False)
    cg.ax_heatmap.axis("off")
    cg.figure.savefig(fig_dir / filename, bbox_inches="tight")

In [None]:
# Illumina EPIC v1.0 manifest (GPL21145); the first 7 lines are a header preamble.
epic_manifest = pd.read_csv(
    data_dir / "illumina_manifests" / "GPL21145_MethylationEPIC_15073387_v-1-0.csv.gz",
    index_col=0,
    skiprows=7,
)

In [None]:
# Roadmap Epigenomics chromatin-state assignment per EPIC probe.
mapping_roadmap = pd.read_csv(
    data_dir / "NIH_Epigenomics_Roadmap" / "EPIC_to_state_mapping.csv",
    index_col=0,
)

In [None]:
# Restrict the manifest to the adVMP union CpGs and attach the chromatin state (inner join).
# NOTE(review): not idempotent; re-running appends a second copy of the roadmap columns.
epic_manifest = pd.concat(
    [epic_manifest.loc[union_cpgs], mapping_roadmap],
    axis="columns",
    join="inner",
)

In [None]:
# Keep only the location / gene / CpG-island / enhancer annotation and the chromatin state.
red_manifest = epic_manifest.loc[:, [
    'CHR',
    'MAPINFO',
    'UCSC_RefGene_Name',
    'UCSC_RefGene_Group',
    'UCSC_CpG_Islands_Name',
    'Relation_to_UCSC_CpG_Island',
    '450k_Enhancer',
    "State",
]]

In [None]:
# CpGs in active-TSS-related chromatin states, and the genes they are annotated to.
tssA_manifest = red_manifest[red_manifest.State.isin(["1_TssA", "2_TssAFlnk", "3_TxFlnk"])]
# UCSC_RefGene_Name is ';'-separated. explode() also copes with an empty selection,
# where np.concatenate of an empty list raised.
tss_unique_genes = np.unique(
    tssA_manifest["UCSC_RefGene_Name"].dropna().str.split(";").explode().to_numpy()
)

# Symbols absent from the expression data are dropped; outdated ones are renamed to
# the current symbol.
to_drop = ["AGPAT9", "TUBA3FP", 'LOC100996325', 'LOC101593348', 'LOC101929234', 'LOC101929512', 'LOC339874', 'LOC375196']
to_replace = {"HIST1H2BB": "H2BC3", "C1orf101": "CATSPERE", "HIST1H3C": "H3C3", "MB21D1" : "CGAS"}
tss_unique_genes = np.setdiff1d(tss_unique_genes, to_drop)

# to_numpy() replaces Series.ravel(), which is deprecated since pandas 2.2.
tss_unique_genes = pd.Series(tss_unique_genes).replace(to_replace).to_numpy()
tss_unique_genes = summary.index.intersection(tss_unique_genes).to_numpy()
print(len(tss_unique_genes))

In [None]:
# Per-patient values at the union CpGs for the EPIC2 cohort (balanced), from the project helper.
heatmap_df1, hit_fraction1 = discov.get_heatmap_df(
    EPIC_m=EPIC2_b,
    phenotypes=EPIC2_phenotypes,
    selcpgs=union_cpgs,
    bal=True,
)

In [None]:
# Expression of the TSS-associated adVMC genes in all RNA-seq samples.
df = red_gex[red_gex.columns.intersection(tss_unique_genes)]
get_clustermap(df=df, adenoma=adenoma)

In [None]:
# Same clustering restricted to plate-1 samples, to check it is not a plate effect.
selpat = batch_layout.index[batch_layout["Given Plate Name"] == "Plate 1"]
df = red_gex.loc[selpat, red_gex.columns.intersection(tss_unique_genes)]
get_clustermap(df=df, adenoma=adenoma, filename="clustermap_advmc_related_genes_plate1.svg")

# Relate per-patient methylation outliers to gene expression (GEX) in the same patients

In [None]:
# 0/1 indicator matrices over heatmap_df1 without its last 6 columns (presumably
# non-CpG annotation columns; TODO confirm against discov.get_heatmap_df):
#   outlier_matrix  = 1 where |value| > 4
#   baseline_matrix = 1 where |value| is not < 1 (>= 1, or NaN)
outlier_matrix = heatmap_df1.iloc[:, :-6].abs().gt(4).astype(int)
baseline_matrix = (~heatmap_df1.iloc[:, :-6].abs().lt(1)).astype(int)

In [None]:
# For each TSS-associated adVMC gene, compare expression between two groups with a
# Mann-Whitney U test:
#   "Outlier"  - patients with |value| > 4 at >= 1 of the gene's TSS CpGs
#   "Baseline" - patients with |value| < 1 at all of them
# Current symbol -> symbol used in the EPIC manifest (inverse of `to_replace`).
invert_replace = {v: k for k, v in to_replace.items()}
# One row per (CpG, annotated gene). Exact matching replaces str.contains, whose
# substring/regex matching also picked up CpGs of genes whose names merely contain the
# query (e.g. "ABC1" inside "ABC10").
cpg_gene_pairs = tssA_manifest["UCSC_RefGene_Name"].dropna().str.split(";").explode()

gene_ps = {}
for gene in red_gex.columns.intersection(tss_unique_genes):
    gene2 = invert_replace.get(gene, gene)
    cgs = cpg_gene_pairs.index[(cpg_gene_pairs == gene2).to_numpy()].unique()

    outlier_patient = outlier_matrix[outlier_matrix.loc[:, cgs].sum(axis=1) > 0].index
    if len(outlier_patient) == 0:
        continue
    baseline_patient = baseline_matrix[baseline_matrix.loc[:, cgs].sum(axis=1) == 0].index

    # String-typed from the start; the former float zero column triggered pandas'
    # incompatible-dtype FutureWarning when the labels were assigned.
    hit_patient = pd.Series("Other", index=outlier_matrix.index, name="outlier")
    hit_patient.loc[outlier_patient] = "Outlier"
    hit_patient.loc[baseline_patient] = "Baseline"

    plot_df = pd.concat([red_gex.loc[:, gene], hit_patient], axis=1).dropna()
    plot_df.columns = ["gene", "outlier"]
    outlier_vals = plot_df.loc[plot_df.outlier == "Outlier", "gene"].to_numpy()
    baseline_vals = plot_df.loc[plot_df.outlier == "Baseline", "gene"].to_numpy()
    # mannwhitneyu raises on an empty sample, e.g. when outlier patients lack RNA-seq.
    if len(outlier_vals) == 0 or len(baseline_vals) == 0:
        continue
    _, p = mannwhitneyu(outlier_vals, baseline_vals)
    gene_ps[gene] = [p]

In [None]:
# One row per tested gene; "p" is the Mann-Whitney U p-value.
gene_ps = pd.DataFrame.from_dict(gene_ps, orient="index", columns=["p"])

# Benjamini-Hochberg FDR, consistent with the PC-association q-values. statsmodels'
# default method is Holm-Sidak (a FWER correction), which does not give q-values.
qs = multipletests(gene_ps["p"].to_numpy(), method="fdr_bh")[1]

gene_ps["q"] = qs

In [None]:
# Swarm + box plots of expression for genes with q < 0.1:
#   "Outlier"  - |value| > 4 at >= 1 TSS CpG of the gene
#   "Baseline" - |value| < 1 at all of them
sign_genes = gene_ps[gene_ps.q < 0.1].index
# Symbols renamed via `to_replace` are annotated under their old name in the manifest;
# without this lookup their CpGs were never found and the gene was silently skipped.
manifest_names = {v: k for k, v in to_replace.items()}
# One row per (CpG, annotated gene), used for exact symbol matching (not substrings).
cpg_gene_pairs = tssA_manifest["UCSC_RefGene_Name"].dropna().str.split(";").explode()

# Grid sized to the number of genes (2 per row); a fixed 2x2 grid overflowed beyond 4.
n_cols = 2
n_rows = max(1, int(np.ceil(len(sign_genes) / n_cols)))
fig, ax = plt.subplots(n_rows, n_cols, figsize=(8, 3 * n_rows))
flatax = ax.flatten()

n_plotted = 0
for gene in sign_genes:
    manifest_gene = manifest_names.get(gene, gene)
    cgs = cpg_gene_pairs.index[(cpg_gene_pairs == manifest_gene).to_numpy()].unique()
    outlier_patient = outlier_matrix[outlier_matrix.loc[:, cgs].sum(axis=1) > 0].index
    if len(outlier_patient) == 0:
        continue
    baseline_patient = baseline_matrix[baseline_matrix.loc[:, cgs].sum(axis=1) == 0].index

    hit_patient = pd.Series("Other", index=outlier_matrix.index, name="outlier")
    hit_patient.loc[outlier_patient] = "Outlier"
    hit_patient.loc[baseline_patient] = "Baseline"

    plot_df = pd.concat([red_gex.loc[:, gene], hit_patient], axis=1).dropna()
    plot_df.columns = ["gene", "outlier"]
    outlier_vals = plot_df.loc[plot_df.outlier == "Outlier", "gene"].to_numpy()
    baseline_vals = plot_df.loc[plot_df.outlier == "Baseline", "gene"].to_numpy()
    if len(outlier_vals) == 0 or len(baseline_vals) == 0:
        continue
    _, p = mannwhitneyu(outlier_vals, baseline_vals)

    # Fill panels consecutively so skipped genes do not leave gaps.
    panel = flatax[n_plotted]
    n_plotted += 1
    sns.swarmplot(data=plot_df, x="outlier", y="gene",
                  order=["Baseline", "Outlier"],
                  palette=["tab:blue", "indianred"], ax=panel)
    sns.boxplot(data=plot_df, x="outlier", y="gene", order=["Baseline", "Outlier"],
                showcaps=False, width=0.3, boxprops={'facecolor': 'None', 'linewidth': 1},
                showfliers=False, whiskerprops={'linewidth': 1}, ax=panel)
    panel.set_title(f"{gene}, p={p:.2e}", fontsize=15, style="italic")
    panel.spines[["top", "right"]].set_visible(False)
    panel.spines[["bottom", "left"]].set_linewidth(4)
    # tick_params resizes the existing labels. Re-setting them from get_*ticklabels()
    # froze them with a FixedFormatter, which triggers matplotlib's warning.
    panel.tick_params(axis="both", labelsize=15)
    panel.set_xlabel("")
    panel.set_ylabel("scaledTPM", fontsize=15)

# Hide panels left empty.
for unused in flatax[n_plotted:]:
    unused.set_visible(False)

fig.tight_layout()
fig.savefig(fig_dir / "signgenes_swarmplot.svg", bbox_inches="tight")

In [None]:
# Unshrunken DESeq2 results for the TSS-associated adVMC genes only (independent copy).
redsum = summary.loc[summary.index.intersection(tss_unique_genes), :].copy()

In [None]:
# Volcano plot restricted to adVMC genes: significance at q <= 0.1; labels for q <= 0.2
# (lim_fc_annot=1, so any nonzero fold change qualifies).
fig = get_volcano_plot(
    summary=redsum,
    lim_fc=1,
    lim_sign=-np.log10(0.1),
    lim_fc_annot=1,
    lim_sign_annot=-np.log10(0.2),
)
fig.savefig(fig_dir / "volcano_plot_naa_vs_healthy_advmc_genes.svg")