# Enrichment in GWAS, TWAS, SMR, and DE 

In [None]:
import functools
import numpy as np
import pandas as pd
import collections as cx
from pybiomart import Dataset
from gtfparse import read_gtf
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests

# GO analysis
from goatools.base import download_go_basic_obo
from goatools.base import download_ncbi_associations
from goatools.obo_parser import GODag
from goatools.anno.genetogo_reader import Gene2GoReader
from goatools.goea.go_enrichment_ns import GOEnrichmentStudyNS

## Functions

In [None]:
@functools.lru_cache()
def get_gtf_genes_df():
    """Return the ``gene_id`` / ``gene_name`` columns of the GTF "gene" records.

    The GENCODE v25 annotation file is parsed with ``read_gtf``; the
    ``lru_cache`` decorator means the file is parsed on the first call only
    and the same object is handed back afterwards.
    """
    annotation = read_gtf("/ceph/genome/human/gencode25/gtf.CHR/_m/gencode.v25.annotation.gtf")
    # Keep gene-level rows only (the "feature" column also holds other record types).
    is_gene = annotation["feature"] == "gene"
    return annotation[is_gene][['gene_id', 'gene_name']]


@functools.lru_cache()
def get_wgcna_modules():
    """Load the WGCNA module assignment table (cached after the first call).

    The first CSV column becomes the index; other code in this file treats the
    index as gene identifiers and reads a ``module`` column from the table.
    """
    modules_csv = "../../_m/modules.csv"
    return pd.read_csv(modules_csv, index_col=0)


@functools.lru_cache()
def get_database():
    """Return the Ensembl BioMart gene table restricted to genes with an Entrez ID.

    Queries the ``hsapiens_gene_ensembl`` dataset for the Ensembl gene ID,
    the external gene name and the Entrez gene ID, then drops every row whose
    ``entrezgene_id`` is missing. The result is memoized by ``lru_cache``.
    """
    wanted_attributes = ["ensembl_gene_id",
                         "external_gene_name",
                         "entrezgene_id"]
    ensembl = Dataset(name="hsapiens_gene_ensembl",
                      host="http://www.ensembl.org",
                      use_cache=True)
    # use_attr_names=True keeps the attribute names above as column names.
    gene_table = ensembl.query(attributes=wanted_attributes,
                               use_attr_names=True)
    return gene_table.dropna(subset=['entrezgene_id'])

In [None]:
def fet(a, b, u):
    """Fisher's exact test for the overlap of sets *a* and *b* within universe *u*.

    Members of *a* or *b* that are not in *u* are ignored. The 2x2 table
    passed to ``fisher_exact`` is::

                   in a        not in a
        in b       a & b       b only
        not in b   a only      neither

    Returns whatever ``fisher_exact`` returns for that table (odds ratio and
    p-value).
    """
    in_a, in_b = u.intersection(a), u.intersection(b)
    out_a, out_b = u - a, u - b
    both = len(in_a & in_b)
    only_b = len(out_a & in_b)
    only_a = len(in_a & out_b)
    neither = len(out_a & out_b)
    return fisher_exact([[both, only_b],
                         [only_a, neither]])


def enrichment_rows():
    """Yield one row of enrichment statistics for each WGCNA module.

    The universe is every gene in the WGCNA module table. For each module the
    row is a flat tuple::

        (module_id, n_genes, <gwas>, <twas>, <smr>, <de>)

    where each ``<...>`` is the pair of values returned by ``fet`` (odds ratio
    and p-value from Fisher's exact test) for the module's genes against the
    module-level ``gwas_genes``, ``twas_genes``, ``smr_genes`` and ``de_genes``
    sets, which must be defined before the generator is consumed.
    """
    modules = get_wgcna_modules()  # cached table; fetch it once, not per module
    universe = set(modules.index)
    for module_id in modules.module.unique():
        members = set(modules[modules.module == module_id].index)
        yield (module_id,
               len(members),
               *fet(members, gwas_genes, universe),
               *fet(members, twas_genes, universe),
               *fet(members, smr_genes, universe),
               *fet(members, de_genes, universe),
               )
        

def enrichment_rows_nomhc():
    """Yield per-module enrichment rows after removing the MHC-region genes.

    Same row layout as ``enrichment_rows``::

        (module_id, n_genes, <gwas>, <twas>, <smr>, <de>)

    but the genes in the module-level ``mhc_genes`` set are subtracted from the
    universe, from each module's gene set (so ``n_genes`` counts non-MHC genes
    only) and from each of ``gwas_genes``, ``twas_genes``, ``smr_genes`` and
    ``de_genes`` before ``fet`` is applied.
    """
    modules = get_wgcna_modules()  # cached table; fetch it once, not per module
    universe = set(modules.index) - mhc_genes
    # The reference sets do not depend on the module, so strip the MHC genes
    # once here instead of once per module inside the loop.
    gwas_nomhc = gwas_genes - mhc_genes
    twas_nomhc = twas_genes - mhc_genes
    smr_nomhc = smr_genes - mhc_genes
    de_nomhc = de_genes - mhc_genes
    for module_id in modules.module.unique():
        members = set(modules[modules.module == module_id].index) - mhc_genes
        yield (module_id,
               len(members),
               *fet(members, gwas_nomhc, universe),
               *fet(members, twas_nomhc, universe),
               *fet(members, smr_nomhc, universe),
               *fet(members, de_nomhc, universe),
              )
        

def convert2entrez(mod):
    """Return the genes of module *mod* joined to the BioMart gene table.

    The module's rows get an ``ensemblID`` column holding the index value with
    everything from the first "." onwards removed, and are then merged
    (pandas default inner join) with ``get_database()`` on
    ``ensemblID`` == ``ensembl_gene_id``.
    """
    modules = get_wgcna_modules()
    members = modules[modules.module == mod].copy()
    # Strip the ".<suffix>" part of the index so it can match ensembl_gene_id.
    members["ensemblID"] = members.index.str.replace("\\..*", "", regex=True)
    return pd.merge(members, get_database(),
                    left_on='ensemblID', right_on='ensembl_gene_id')


@functools.lru_cache()
def obo_annotation(alpha=0.05):
    """Build (once per *alpha*) the GO enrichment study object for human genes.

    The result is memoized like the other loaders in this file, because
    ``run_goea`` calls this function for every module and re-reading the
    ontology and association files each time is wasted work. As a consequence
    the per-namespace gene counts below are printed only the first time a
    given *alpha* is requested.

    Parameters
    ----------
    alpha : float
        Significance cut-off handed to ``GOEnrichmentStudyNS``.

    Returns
    -------
    GOEnrichmentStudyNS
        Study object whose population is the ``entrezgene_id`` column of
        ``get_database()``, using Benjamini-Hochberg FDR (``fdr_bh``).
    """
    # database annotation (these helpers fetch the ontology / association files)
    fn_obo = download_go_basic_obo()
    fn_gene2go = download_ncbi_associations() # must be gunzip to work
    obodag = GODag(fn_obo)
    # taxid 9606: keep the human associations only
    anno_hs = Gene2GoReader(fn_gene2go, taxids=[9606])
    # get associations, keyed by GO namespace
    ns2assoc = anno_hs.get_ns2assc()
    for nspc, id2gos in ns2assoc.items():
        print("{NS} {N:,} annotated human genes".format(NS=nspc, N=len(id2gos)))
    goeaobj = GOEnrichmentStudyNS(
        get_database()['entrezgene_id'], # List of human genes with entrez IDs
        ns2assoc, # geneid/GO associations
        obodag, # Ontologies
        propagate_counts = False,
        alpha = alpha, # default significance cut-off
        methods = ['fdr_bh'])
    return goeaobj


def run_goea(mod, alpha=0.05):
    """Run GO enrichment for one WGCNA module and write the significant terms.

    Parameters
    ----------
    mod
        Module identifier as found in the ``module`` column of the WGCNA table.
    alpha : float, optional
        FDR cut-off. It is passed to ``obo_annotation`` and is also the
        threshold applied to ``p_fdr_bh`` when selecting significant results,
        so the two can no longer drift apart. Defaults to 0.05, the value
        previously hard-coded here.

    Side effects
    ------------
    Prints the number of significant terms per namespace and writes
    ``GO_analysis_module_<mod>.xlsx`` and ``GO_analysis_module_<mod>.txt``
    in the current directory.
    """
    df = convert2entrez(mod)
    # Study set: Entrez gene ID -> gene symbol
    geneids_study = dict(zip(df['entrezgene_id'], df['external_gene_name']))
    goeaobj = obo_annotation(alpha)
    goea_results_all = goeaobj.run_study(geneids_study)
    goea_results_sig = [r for r in goea_results_all if r.p_fdr_bh < alpha]
    ctr = cx.Counter(r.NS for r in goea_results_sig)
    print('Significant results[{TOTAL}] = {BP} BP + {MF} MF + {CC} CC'.format(
        TOTAL=len(goea_results_sig),
        BP=ctr['BP'],  # biological_process
        MF=ctr['MF'],  # molecular_function
        CC=ctr['CC'])) # cellular_component
    goeaobj.wr_xlsx("GO_analysis_module_%s.xlsx" % mod, goea_results_sig)
    goeaobj.wr_txt("GO_analysis_module_%s.txt" % mod, goea_results_sig)

## Gene annotation

In [None]:
# Gene-level annotation table (gene_id, gene_name); preview the first two rows.
gtf = get_gtf_genes_df()
gtf.head(n=2)

## GWAS, TWAS, SMR, and DE enrichment

### Load DE, MHC, TWAS, GWAS, and SMR genes

In [None]:
# Differentially expressed genes: only the first column of the table is read
# and its values form the set.
de_genes = set(
    pd.read_csv(
        '../../../differential_expression/_m/genes/diffExpr_szVctl_FDR05.txt',
        sep='\t',
        usecols=[0],
        index_col=0,
    ).index
)
len(de_genes)

In [None]:
# Genes located in the MHC region (gene_id column of the CSV).
mhc_genes = set(
    pd.read_csv(
        '/ceph/projects/v4_phase3_paper/inputs/counts/mhc_region_genes/_m/mhc_genes.csv'
    ).loc[:, 'gene_id']
)
len(mhc_genes)

In [None]:
# Gene annotation; "Feature" is gene_id with everything from the first "." removed.
annot = pd.read_csv(
    "/ceph/projects/v4_phase3_paper/inputs/counts/text_files_counts/_m/caudate/gene.bed",
    sep='\t',
    index_col=0,
)
annot["Feature"] = annot["gene_id"].str.replace("\\..*", "", regex=True)

# TWAS associations: keep FDR < 0.05 and attach the annotation via FILE == Feature.
twas = pd.read_csv(
    '/ceph/projects/v4_phase3_paper/analysis/twas_ea/'
    'gene_weights/fusion/summary_stats/_m/fusion_associations.txt',
    sep='\t',
)
twas = twas[twas["FDR"] < 0.05].merge(annot, left_on="FILE", right_on="Feature")
twas_genes = set(twas['gene_id'])
len(twas_genes)

In [None]:
## Extract prioritized genes from PGC3 (FINEMAP or SMR evidence)
gwas_fn = (
    '/ceph/users/jbenja13/resources/gwas/pgc3/_m/'
    'nature_submission_11.08.2021/Supplementary Tables/'
    'Supplementary Table 12 - Prioritized Genes UPDATED.xlsx'
)
# Map the table's Ensembl IDs onto the annotation's gene_id through "Feature".
gwas_df = pd.read_excel(gwas_fn, sheet_name="Prioritised").merge(
    annot, left_on="Ensembl.ID", right_on="Feature"
)
gwas_genes = set(gwas_df['gene_id'])
len(gwas_genes)

In [None]:
# SMR genes: keep probes with FDR < 0.05 and p_HEIDI > 0.01.
smr_fn = "../../../smr/_m/eqtl_genes.eqtl_p1e-04.gwas_p5e-08.csv"
smr_df = pd.read_csv(smr_fn)
smr_genes = set(
    smr_df.loc[(smr_df["FDR"] < 0.05) & (smr_df["p_HEIDI"] > 0.01), "probeID"]
)
len(smr_genes)

### Load WGCNA modules

In [None]:
# Annotate every module gene (table index) with its GTF gene_name; the left
# join keeps genes that have no match in the GTF table.
wgcna_df = pd.merge(get_wgcna_modules(), gtf,
                    how="left", left_index=True, right_on="gene_id")
wgcna_df.head(n=2)

In [None]:
# Show rows whose gene_id starts with "chr".
wgcna_df[wgcna_df["gene_id"].str.startswith("chr")]

In [None]:
# Which module holds DRD2?
wgcna_df[wgcna_df["gene_name"] == 'DRD2']

In [None]:
# Which module holds SETD1A?
wgcna_df[wgcna_df["gene_name"] == 'SETD1A']

### Enrichment

In [None]:
# One row per module: odds ratio and p-value for each of the four gene sets.
edf1 = pd.DataFrame.from_records(
    enrichment_rows(),
    columns=['module_id', 'n_genes', 'gwas_or', 'gwas_p', 'twas_or',
             'twas_p', 'smr_or', 'smr_p', 'de_or', 'de_p'],
    index='module_id',
)
# Benjamini-Hochberg adjustment across modules, one gene set at a time
# (element [1] of the multipletests result holds the adjusted p-values).
for p_col, fdr_col in (('twas_p', 'twas_fdr_bh'),
                       ('gwas_p', 'gwas_fdr_bh'),
                       ('smr_p', 'smr_fdr_bh'),
                       ('de_p', 'de_fdr_bh')):
    edf1[fdr_col] = multipletests(edf1[p_col], method='fdr_bh')[1]
# Column order used both for the CSV and for the displayed table.
report_cols = ['n_genes', 'gwas_or', 'gwas_p', 'gwas_fdr_bh', 'twas_or', 'twas_p',
               'twas_fdr_bh', 'smr_or', 'smr_p', 'smr_fdr_bh', 'de_or', 'de_p', 'de_fdr_bh']
edf1[report_cols].to_csv('wgcna_module_enrichment.csv')
edf1[report_cols]

### No MHC region

In [None]:
# Same enrichment table as above, computed with the MHC-region genes removed.
edf2 = pd.DataFrame.from_records(
    enrichment_rows_nomhc(),
    columns=['module_id', 'n_genes', 'gwas_or', 'gwas_p', 'twas_or',
             'twas_p', 'smr_or', 'smr_p', 'de_or', 'de_p'],
    index='module_id',
)
# Benjamini-Hochberg adjustment across modules, one gene set at a time
# (element [1] of the multipletests result holds the adjusted p-values).
for p_col, fdr_col in (('twas_p', 'twas_fdr_bh'),
                       ('gwas_p', 'gwas_fdr_bh'),
                       ('smr_p', 'smr_fdr_bh'),
                       ('de_p', 'de_fdr_bh')):
    edf2[fdr_col] = multipletests(edf2[p_col], method='fdr_bh')[1]
# Column order used both for the CSV and for the displayed table.
report_cols = ['n_genes', 'gwas_or', 'gwas_p', 'gwas_fdr_bh', 'twas_or', 'twas_p',
               'twas_fdr_bh', 'smr_or', 'smr_p', 'smr_fdr_bh', 'de_or', 'de_p', 'de_fdr_bh']
edf2[report_cols].to_csv('wgcna_module_enrichment_excluding_mhc_region.csv')
edf2[report_cols]

## GO enrichment for each WGCNA module

In [None]:
# Run the GO enrichment (and write its output files) once per distinct module.
for mod in get_wgcna_modules()["module"].unique():
    run_goea(mod)