In [2]:
from scipy.stats import hypergeom
import statsmodels.stats.multitest as multi
import pandas as pd
import os
import glob

In [3]:
# Differential-expression result tables, one per pairwise comparison.
# All three live in the same directory and share the 'burow_<comparison>_de_results.txt' pattern.
_deg_dir = '../../../her7/10_gene_exp/gene_expression_jen_local/10_gene_exp/'
degfiles = [
    _deg_dir + 'burow_' + _comparison + '_de_results.txt'
    for _comparison in ('her19_v_herm', 'male_v_her19', 'male_v_herm')
]

In [4]:
# annfile: tab-separated functional annotation table (read when parsing GO terms)
# outfile: destination of the enrichment results table
# obofile: GO ontology definition file (read by loadrelationships)
_her7_dir = '../../../her7/'
annfile = _her7_dir + '02_functional_annotations/eggnog_annotations.txt'
outfile = _her7_dir + '10_gene_exp/functional_enrichment_results_GO.txt'
obofile = '../../../../dbs/gene_ontology/go-basic.obo'

In [5]:
# Bounds on the number of genes a GO term may have within a DEG set;
# terms whose count falls outside [lowCount, highCount] are not tested.
highCount, lowCount = 10000000, 0

# Differential-expression cutoffs applied to the 'log2FoldChange' and 'padj' columns.
# NOTE(review): 0.58 is approximately log2(1.5), presumably a 1.5-fold change cutoff; confirm.
de_lfc_thres, de_padj_thres = 0.58, 0.1

In [6]:
###################################
### Initialize set dictionaries ###
###################################
# degsets maps '<comparison>_up' / '<comparison>_down' to the set of gene ids
# passing the fold-change and padj cutoffs in that comparison.
degsets = {}

for degfile in degfiles:
    # '.../burow_<comparison>_de_results.txt' -> '<comparison>'
    comparison = degfile.split('burow_')[1].split('_de_results')[0]

    de_table = pd.read_csv(degfile, sep='\t', index_col=0)
    de_table.index.name = 'gene_id'

    # Row masks; comparisons against NaN evaluate to False, so such rows are excluded.
    is_significant = de_table['padj'] < de_padj_thres
    is_up = (de_table['log2FoldChange'] > de_lfc_thres) & is_significant
    is_down = (de_table['log2FoldChange'] < de_lfc_thres * -1) & is_significant

    degsets[comparison + '_up'] = set(de_table[is_up].index.tolist())
    degsets[comparison + '_down'] = set(de_table[is_down].index.tolist())


In [7]:
# Remaining lookup tables, filled in by later cells:
#   isaDict  - GO id -> set of direct is_a parent GO ids
#   goDict   - GO id -> 'namespace<TAB>name'
#   codeDict - set name -> {GO id -> set of loci}
#   lociDict - set name -> set of loci
# The 'total' key holds the background (all annotated loci).
isaDict = {}
goDict = {}
annotate_line = []
codeDict = {'total': {}}
lociDict = {'total': set()}

print('Number of genes in each DEG set:')
for degset, members in degsets.items():
    print(degset, len(members))
    codeDict[degset] = {}
    lociDict[degset] = set()

Number of genes in each DEG set:
her19_v_herm_up 209
her19_v_herm_down 1982
male_v_her19_up 6222
male_v_her19_down 5568
male_v_herm_up 4868
male_v_herm_down 4885


In [8]:
#############################################
### Load GO relationship hash from file ###
#############################################

def loadrelationships(file, go_terms=None, parents=None):
    """Parse a GO OBO file into description and is_a lookup tables.

    The file content is split into blank-line-separated stanzas. For each
    stanza containing an ``id: GO:...`` line, two entries are written:

    * ``go_terms[go_id]``: the stanza's namespace and name joined by a tab
    * ``parents[go_id]``: the set of GO ids listed on its ``is_a:`` lines

    Stanzas without a GO id (the file header, ``[Typedef]`` stanzas, ...)
    are skipped so that no entry is stored under the empty-string key.

    Parameters
    ----------
    file : str
        Path to the OBO file.
    go_terms : dict, optional
        Table receiving the descriptions; defaults to the module-level
        ``goDict``.
    parents : dict, optional
        Table receiving the is_a parent sets; defaults to the module-level
        ``isaDict``.

    Returns
    -------
    None
        The two tables are updated in place.
    """
    if go_terms is None:
        go_terms = goDict
    if parents is None:
        parents = isaDict

    with open(file, "r") as f:
        content = f.read()

    # stanzas are separated by an empty line
    for entry in content.split('\n\n'):
        currentID = ''
        currentParents = set()
        name = ''
        space = ''

        for line in entry.split('\n'):
            if line.startswith('id: GO:'):
                currentID = 'GO:' + line.split(':')[-1]
            elif line.startswith('name: '):
                # split only on the first ': ' so names containing ': ' stay intact
                name = line.split(': ', 1)[1]
            elif line.startswith('namespace: '):
                space = line.split(': ', 1)[1]
            elif line.startswith('is_a: '):
                # 'is_a: GO:0000001 ! description' -> 'GO:0000001'
                currentParents.add(line.split(' ')[1])

        if not currentID:
            # not a GO term stanza
            continue

        go_terms[currentID] = space + '\t' + name
        parents[currentID] = currentParents

# Fill goDict and isaDict from the GO ontology file.
loadrelationships(obofile)



In [9]:
# Remove obsolete GO terms.
# A term counts as obsolete when its name (the second tab-separated field
# of its goDict entry) contains the word 'obsolete'.
obsolete = {go for go, entry in goDict.items() if 'obsolete' in entry.split('\t')[1]}

for go in obsolete:
    goDict.pop(go, None)

# Drop obsolete terms from every parent set.
# difference_update() edits each set in place without iterating over it;
# calling .remove() inside a `for parent in isaDict[go]` loop raises
# "RuntimeError: Set changed size during iteration".
for parent_ids in isaDict.values():
    parent_ids.difference_update(obsolete)

In [11]:
def getparentlist(queryID, isa_map=None):
    """Return every ancestor of ``queryID`` reachable through is_a links.

    Parameters
    ----------
    queryID : str
        GO id whose ancestors are wanted.
    isa_map : dict, optional
        Mapping of GO id -> set of direct parent GO ids; defaults to the
        module-level ``isaDict``.

    Returns
    -------
    set
        All direct and indirect parents of ``queryID``. Empty when
        ``queryID`` has no entry or no parents. Ids that appear as a parent
        but have no entry of their own are included and treated as roots.
    """
    if isa_map is None:
        isa_map = isaDict

    # direct parents seed both the result and the work list
    fullSet = set(isa_map.get(queryID, ()))
    pending = list(fullSet)

    while pending:
        parentID = pending.pop()
        # .get: a parent without its own entry has no further ancestors
        for ancestor in isa_map.get(parentID, ()):
            if ancestor not in fullSet:
                fullSet.add(ancestor)
                pending.append(ancestor)

    return fullSet


In [12]:
######################################
### Parse GOs in Annotation File ###
######################################

# Walk the annotation table and record, for every locus with at least one
# known GO term, the locus under each of its terms and all their ancestors:
#   codeDict['total'][go] / lociDict['total']    - background (all annotated loci)
#   codeDict[degset][go]  / lociDict[degset]     - loci belonging to that DEG set
# The `with` block guarantees the file is closed even if a line fails to parse.
with open(annfile) as fi:
    for line in fi:
        # skip comment/header lines and blank lines
        if line.startswith('#') or not line.strip():
            continue

        col = line.rstrip().split('\t')
        # locus id = 'Ceric.' + second dot-separated field of the first column
        locus = 'Ceric.' + col[0].split('.')[1]
        # column 10 holds a comma-separated GO list, or '-' when there is none
        gocol = col[9].split(',')
        if gocol == ['-']:
            continue

        goSet = set()
        for go in gocol:
            if go not in goDict:
                continue
            goSet.add(go)
            # keep only ancestors that have a goDict entry, so every counted
            # term can be looked up when the results table is built
            goSet.update(parent for parent in getparentlist(go) if parent in goDict)

        if locus == 'Ceric.1Z236300':
            print(locus,goSet)

        if not goSet:
            continue

        # DEG-set membership depends only on the locus, so test it once per line
        member_sets = [degset for degset in degsets if locus in degsets[degset]]

        lociDict['total'].add(locus)
        for degset in member_sets:
            lociDict[degset].add(locus)

        for go in goSet:
            codeDict['total'].setdefault(go, set()).add(locus)
            for degset in member_sets:
                codeDict[degset].setdefault(go, set()).add(locus)

Ceric.1Z236300 {'GO:0043228', 'GO:0005575', 'GO:0005634', 'GO:0005694', 'GO:0043229', 'GO:0043231', 'GO:0110165', 'GO:0043232', 'GO:0043227', 'GO:0043226', 'GO:0005622'}


In [14]:
####################################
### Perform hypergeometric tests ###
####################################

# One hypergeometric (over-representation) test per (DEG set, GO term) pair.
# Rows are collected in a list and the DataFrame is built once at the end;
# appending with df.loc[len(df.index)] copies the frame on every insertion.
result_columns = ['set','go','cat','space','desc','x','N','n','M','pval','genelist']
records = []

# category label = OBO file name up to its first '.'
cat = obofile.split('/')[-1].split('.')[0]

# M is the population size (ie no. genes total)
M = len(lociDict['total'])

for degset in degsets:
    # N is the sample size (ie no. genes in set)
    N = len(lociDict[degset])

    for go, loci in codeDict[degset].items():
        # x is the number of drawn "successes" (ie no. genes in set and in go category)
        x = len(loci)
        if x > highCount or x < lowCount:
            continue

        go_fields = goDict[go].split('\t')
        space = go_fields[0]
        desc = go_fields[1]

        # sorted so the written gene list is identical from run to run
        # (set iteration order of strings varies between interpreter sessions)
        genelist = ', '.join(sorted(loci))

        # n is the number of successes in the population (ie no. genes in go category)
        n = len(codeDict['total'][go])

        # P(X >= x) = sf(x - 1)
        # https://alexlenail.medium.com/understanding-and-implementing-the-hypergeometric-test-in-python-a7db688a7458
        # https://github.com/jdrudolph/goenrich
        pval = hypergeom.sf(x-1, M, n, N)
        records.append([degset,go,cat,space,desc,x,N,n,M,pval,genelist])

df = pd.DataFrame(records, columns=result_columns)


In [15]:
#########################################
### Adjust pvalues for multiple tests ###
#########################################
if not df.empty:
    # multipletests returns a tuple; element 1 holds the corrected p-values
    adjpval = multi.multipletests(df['pval'].tolist(), alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)[1]

    # seqfreq: fraction of the set's loci carrying the term
    # totalfreq: fraction of all annotated loci carrying the term
    df = df.assign(
        adjpval=adjpval,
        seqfreq=df['x'] / df['N'],
        totalfreq=df['n'] / df['M'],
    )

    output_columns = ['set','go','adjpval','cat','space','desc','x','N','seqfreq','n','M','totalfreq','pval','genelist']
    df = df[output_columns]

    df.to_csv(outfile, sep='\t', index=False)
