In [1]:
from scipy.stats import hypergeom
import statsmodels.stats.multitest as multi
import pandas as pd
import os
import glob

In [2]:
# Differential-expression result tables, one per pairwise comparison.
# The text between 'burow_' and '_de_results' in each file name is later
# used as the comparison label.
degfiles = [
    '../../../her7/10_gene_exp/gene_expression_jen_local/10_gene_exp/burow_her19_v_herm_de_results.txt',
    '../../../her7/10_gene_exp/gene_expression_jen_local/10_gene_exp/burow_male_v_her19_de_results.txt',
    '../../../her7/10_gene_exp/gene_expression_jen_local/10_gene_exp/burow_male_v_herm_de_results.txt',
]

In [3]:
# Input: tab-separated locus -> KEGG term annotations (two columns per line).
annfile = '../../../her7/02_functional_annotations/Ceratopteris_KEGG_annotations_by_phylogeny.txt'

# Output: tab-separated table of enrichment results.
outfile = '../../../her7/10_gene_exp/functional_enrichment_results_KEGG.txt'

# Input: KEGG term -> parent terms table; the '*DESC_Jun-30-2024.txt'
# description files are looked up in the same directory as this file.
keggfile = '../../../../dbs/kegg/KEGG_KO_ALL_PARENTS_Jun-30-2024.txt'

In [4]:
# Bounds on the number of DEG-set genes hitting a term: a term is only
# tested when lowCount <= hits <= highCount. With these values no term
# is excluded in practice.
lowCount = 0
highCount = 10_000_000

# Differential-expression cut-offs: a gene is "up" when
# log2FoldChange > de_lfc_thres, "down" when log2FoldChange < -de_lfc_thres,
# and in both cases padj must be below de_padj_thres.
# (0.58 is approximately log2(1.5).)
de_lfc_thres = 0.58
de_padj_thres = 0.1

In [5]:
###################################
### Initialize set dictionaries ###
###################################
# degsets maps '<comparison>_up' / '<comparison>_down' to the set of gene
# ids passing the fold-change and adjusted-p-value thresholds.
degsets = {}

for degfile in degfiles:
    # Comparison label = text between 'burow_' and '_de_results' in the
    # file name, e.g. 'male_v_herm'.
    comparison = degfile.split('burow_')[1].split('_de_results')[0]

    # The first column of the table becomes the index and is used as the
    # gene identifier.
    df = pd.read_csv(degfile, sep='\t', index_col=0)
    df.index.name = 'gene_id'

    # Rows whose padj is missing compare as False and therefore end up in
    # neither the up nor the down set.
    is_significant = df['padj'] < de_padj_thres
    de_up_df = df[(df['log2FoldChange'] > de_lfc_thres) & is_significant]
    de_down_df = df[(df['log2FoldChange'] < -de_lfc_thres) & is_significant]

    # Collect the gene ids (index values) of each filtered table.
    degsets[comparison + '_up'] = set(de_up_df.index.tolist())
    degsets[comparison + '_down'] = set(de_down_df.index.tolist())


In [6]:
# Containers filled by the following cells:
#   isaDict  - KEGG term -> set of parent terms
#   keggDict - KEGG term -> description
#   codeDict - gene-set name -> {KEGG term -> set of loci}
#   lociDict - gene-set name -> set of annotated loci
# The 'total' entry is the background built from every annotated locus.
isaDict = {}
annotate_line = []
keggDict = {}
codeDict = {'total': {}}
lociDict = {'total': set()}

print('Number of genes in each DEG set:')
for degset, genes in degsets.items():
    print(degset, len(genes))
    # One empty term map and one empty locus set per DEG set.
    codeDict[degset] = {}
    lociDict[degset] = set()

Number of genes in each DEG set:
her19_v_herm_up 209
her19_v_herm_down 1982
male_v_her19_up 6222
male_v_her19_down 5568
male_v_herm_up 4868
male_v_herm_down 4885


In [7]:
#############################################
### Load KEGG relationship hash from file ###
#############################################

# Each line of keggfile is tab-separated: the first field identifies the
# term (only the text before its first ':' is kept as the key) and the
# remaining fields are its parent terms.
# `with` guarantees the handle is closed even if parsing raises.
with open(keggfile) as fi:
    for line in fi:
        line = line.rstrip()
        if not line:
            # ignore blank lines instead of registering an empty-string term
            continue

        line = line.split('\t')
        currentID = line.pop(0).split(':')[0]
        currentParents = set(line)

        isaDict[currentID] = currentParents

# Description tables sit in the same directory as keggfile (symlinks resolved).
descfiles = os.path.join(os.path.dirname(os.path.realpath(keggfile)),
                         '*DESC_Jun-30-2024.txt')

# Sorted so that, if a term occurs in several files, the description that
# wins is the same on every run (glob order is otherwise arbitrary).
for infile in sorted(glob.glob(descfiles)):
    with open(infile) as fi:
        for line in fi:
            line = line.rstrip()
            if not line:
                # a blank line would otherwise fail the two-field unpack below
                continue

            # exactly two tab-separated fields are expected: term, description
            term, desc = line.split('\t')
            keggDict[term] = desc


In [8]:
######################################
### Parse KEGGs in Annotation File ###
######################################

# Each non-blank line of annfile must hold exactly two tab-separated
# fields: locus and KEGG term. Every locus is counted for its annotated
# term and for all parents recorded for that term in isaDict, both in the
# 'total' background and in each DEG set that contains the locus.
# `with` guarantees the handle is closed even if a malformed line raises.
with open(annfile) as fi:
    for line in fi:
        line = line.rstrip()
        if not line:
            # skip blank lines rather than failing on the unpack below
            continue

        locus, kegg = line.split('\t')

        # the annotated term itself plus every recorded parent term
        keggSet = {kegg}
        keggSet.update(isaDict.get(kegg, ()))

        # DEG sets containing this locus; computed once per line because it
        # does not depend on the term being processed
        memberSets = [degset for degset in degsets if locus in degsets[degset]]

        # keggSet is never empty, so the locus always counts as annotated
        lociDict['total'].add(locus)
        for degset in memberSets:
            lociDict[degset].add(locus)

        for term in keggSet:
            codeDict['total'].setdefault(term, set()).add(locus)

            for degset in memberSets:
                codeDict[degset].setdefault(term, set()).add(locus)

In [9]:
####################################
### Perform hypergeometric tests ###
####################################

# One over-representation test per (DEG set, KEGG term) pair in which the
# set has at least one annotated gene. Rows are collected in a list and the
# DataFrame is built once at the end; growing a frame row by row with .loc
# re-allocates on every insert.
records = []

# M is the population size (ie no. annotated genes total)
M = len(lociDict['total'])

for degset in degsets:
    # N is the sample size (ie no. annotated genes in the DEG set)
    N = len(lociDict[degset])

    for kegg, loci in codeDict[degset].items():
        # x is the number of drawn "successes"
        # (ie no. genes in the DEG set that carry this KEGG term)
        x = len(loci)
        if x > highCount or x < lowCount:
            continue

        # terms without a known description get an empty string
        desc = keggDict.get(kegg, '')

        # sorted so the written gene list is identical from run to run
        # (iteration order of a set of strings is not stable across runs)
        genelist = ', '.join(sorted(loci))

        # n is the number of successes in the population
        # (ie no. annotated genes carrying this KEGG term)
        n = len(codeDict['total'][kegg])

        # P(X >= x): sf(k) is P(X > k), hence x - 1.
        # scipy argument order is (k, population size, successes in
        # population, sample size).
        # https://alexlenail.medium.com/understanding-and-implementing-the-hypergeometric-test-in-python-a7db688a7458
        # https://github.com/jdrudolph/goenrich
        pval = hypergeom.sf(x - 1, M, n, N)
        records.append([degset, kegg, desc, x, N, n, M, pval, genelist])

df = pd.DataFrame(records, columns=['set','kegg','desc','x','N','n','M','pval','genelist'])


In [10]:
#########################################
### Adjust pvalues for multiple tests ###
#########################################
# When no test was performed nothing is corrected and no file is written.
if len(df) > 0:
    # Benjamini-Hochberg correction over all tests of all DEG sets at once;
    # element 1 of the returned tuple holds the corrected p-values.
    adjpval = multi.multipletests(
        df['pval'].tolist(),
        alpha=0.05,
        method='fdr_bh',
        is_sorted=False,
        returnsorted=False,
    )[1]

    # seqfreq   = fraction of the DEG set's annotated genes hitting the term
    # totalfreq = fraction of all annotated genes hitting the term
    df = df.assign(
        adjpval=adjpval,
        cat='kegg',
        seqfreq=df['x'] / df['N'],
        totalfreq=df['n'] / df['M'],
    )[['set','kegg','cat','desc','x','N','seqfreq','n','M','totalfreq','pval','adjpval','genelist']]

    df.to_csv(outfile, sep='\t', index=False)
