In [40]:
import os
import gzip
import numpy as np
import pandas as pd
from goatools.obo_parser import GODag
from goatools.anno.gaf_reader import GafReader
from goatools.goea.go_enrichment_ns import GOEnrichmentStudyNS
import seaborn as sns
import matplotlib.pyplot as plt
from adjustText import adjust_text
import textwrap
import requests
import matplotlib as mpl

In [None]:
gene_list = pd.read_csv('gene_list.txt', header=None, names=['gene'])
gene_list = gene_list['gene'].tolist()

In [None]:
if not os.path.exists("go-basic.obo"):
    from goatools.base import download_go_basic_obo
    download_go_basic_obo()
    # !wget http://purl.obolibrary.org/obo/go/go-basic.obo -O go-basic.obo
if not os.path.exists("goa_uniprot_all.gaf.gz"):
    !wget https://ftp.ebi.ac.uk/pub/databases/GO/goa/UNIPROT/goa_uniprot_all.gaf.gz
    
def filter_gaf_by_taxid(input_file, output_file, taxid):
    with gzip.open(input_file, 'rt') as f_in, open(output_file, 'w') as f_out:
        for line in f_in:
            if line.startswith('!'):
                f_out.write(line)
                continue
            fields = line.strip().split('\t')
            if len(fields) >= 13 and f"taxon:{taxid}" in fields[12]:
                f_out.write(line)

if not os.path.exists("goa_pst_filtered.gaf"):
    filter_gaf_by_taxid("goa_uniprot_all.gaf.gz", "goa_pst_filtered.gaf", 27350)

In [5]:
obodag = GODag("go-basic.obo")
gaf_reader = GafReader("goa_wheat_filtered.gaf", godag=obodag)
ns2assoc = gaf_reader.get_ns2assc()

background_uniprot = set(gaf_reader.get_id2gos().keys())

goaobj = GOEnrichmentStudyNS(
    background_uniprot,  # List of all UniProt IDs
    ns2assoc,            # UniProt/GO associations
    obodag,              # Ontologies
    propagate_counts=False,
    alpha=0.05,          # default significance cut-off
    methods=['fdr_bh']   # default multipletest correction method
)

go-basic.obo: fmt(1.2) rel(2025-02-06) 43,597 Terms
HMS:0:00:29.438952 1,103,693 annotations READ: goa_wheat_filtered.gaf 
212941 IDs in loaded association branch, biological_process

Load BP Ontology Enrichment Analysis ...
 62% 132,353 of 212,941 population items found in association

Load CC Ontology Enrichment Analysis ...
 59% 126,081 of 212,941 population items found in association

Load MF Ontology Enrichment Analysis ...
 79% 167,759 of 212,941 population items found in association


In [6]:
def run_go_enrichment(uniprot_set, name, goaobj):
    uniprot_ids = list(uniprot_set['Entry'])
    
    goa_results_all = goaobj.run_study(uniprot_ids)
    
    goa_results_sig = [r for r in goa_results_all if r.p_fdr_bh < 0.05]
    
    results_df = pd.DataFrame([
        {
            'protein': r.study_items,
            'GO_term': r.name, 
            'p_value': r.p_uncorrected,
            'p_fdr_bh': r.p_fdr_bh,
            'protein_count': r.study_count,
            'gene_set': name,
            'namespace': r.NS
        } 
        for r in goa_results_sig
    ])
    
    return results_df

In [None]:
results = run_go_enrichment(gene_list, 'gene_list', goaobj)


Runing BP Ontology Analysis: current study set of 100 IDs.
 78%     78 of    100 study items found in association
100%    100 of    100 study items found in population(212941)
Calculating 2,571 uncorrected p-values using fisher_scipy_stats
   2,571 terms are associated with 132,353 of 212,941 population items
      77 terms are associated with     78 of    100 study items
  METHOD fdr_bh:
      10 GO terms found significant (< 0.05=alpha) ( 10 enriched +   0 purified): statsmodels fdr_bh
      44 study items associated with significant GO IDs (enriched)
       0 study items associated with significant GO IDs (purified)

Runing CC Ontology Analysis: current study set of 100 IDs.
 97%     97 of    100 study items found in association
100%    100 of    100 study items found in population(212941)
Calculating 736 uncorrected p-values using fisher_scipy_stats
     736 terms are associated with 126,081 of 212,941 population items
      43 terms are associated with     97 of    100 study item

In [None]:
def faceted_bubble_plot(d, title):
    d = d[d['adjusted_p_value'] < 1]
    colors = {"BP": "#4059AD", "CC": "#F4B942", "MF": "#88CD4A"}
    
    d['GO'] = d['GO'].str.replace('.', '')
    
    yintercepts = pd.DataFrame({
        'source': ['BP', 'CC', 'MF'],
        'yintercept': [5, 15, 5]
    })
    
    d = pd.merge(d, yintercepts, on='source', how='left')
    
    sources = d['source'].unique()
    
    fig, axes = plt.subplots(1, len(sources), figsize=(10, 5), sharey=True)
    if len(sources) == 1:
        axes = [axes]
    
    for i, source in enumerate(sources):
        source_data = d[d['source'] == source]
        
        axes[i].scatter(
            x=range(len(source_data)),
            y=-np.log10(source_data['adjusted_p_value']),
            s=source_data['intersection_size'] * 20,
            c=colors[source],
            alpha=0.7,
            edgecolors=colors[source],
            linewidths=1
        )
        
        yintercept = yintercepts[yintercepts['source'] == source]['yintercept'].values[0]
        axes[i].axhline(y=yintercept, linestyle='-', color='gray', linewidth=0.2)
        
        texts = []
        for j, row in source_data.iterrows():
            if -np.log10(row['adjusted_p_value']) > yintercept:
                wrapped_term = '\n'.join(textwrap.wrap(row['term'], 15))
                texts.append(axes[i].text(j, -np.log10(row['adjusted_p_value']), 
                                         wrapped_term, fontsize=8, ha='center', va='center'))
        
        adjust_text(texts, ax=axes[i], arrowprops=dict(arrowstyle='-', color='black', lw=0.5))
        
        axes[i].set_title(source, fontsize=12, pad=10, 
                         backgroundcolor='#97D8C4', color='black')
        axes[i].set_xticks([])
        axes[i].set_frame_on(True)
        
        for spine in axes[i].spines.values():
            spine.set_visible(False)
        axes[i].spines['bottom'].set_visible(True)
        axes[i].spines['left'].set_visible(True)
        axes[i].spines['bottom'].set_linewidth(0.5)
        axes[i].spines['left'].set_linewidth(0.5)
    
    fig.text(0.04, 0.5, '-log10 (Padj)', va='center', rotation='vertical', fontsize=12)
    
    for ax in axes:
        ax.set_ylim(0, 60)
        ax.set_yticks(np.arange(0, 65, 5))
    
    fig.suptitle(title, fontsize=14, y=0.98)
    
    plt.tight_layout(rect=[0.05, 0, 1, 0.95])
    return fig

In [None]:
def go_category_bubble_plot(d, go_category, title):
    d = d[d['adjusted_p_value'] < 1]
    d = d[d['source'] == go_category]
    
    gene_sets = d['gene_set'].unique()
    d['gene_set'] = pd.Categorical(d['gene_set'], ordered=True)
    d = d.sort_values('gene_set')
    
    colors = plt.cm.tab20(np.linspace(0, 1, len(gene_sets)))
    markers = ['o', 's', 'D', 'h',]
    
    fig, ax = plt.subplots(figsize=(12, 6))
    all_texts = []
    
    for i, (gene_set, group) in enumerate(d.groupby('gene_set')):
        # Generate x positions first so we can use them for both points and labels
        x_positions = np.random.normal(i, 0.2, len(group))
        y_positions = -np.log10(group['adjusted_p_value'])
        
        # Plot all points for this gene set
        ax.scatter(
            x=x_positions,
            y=y_positions,
            s=group['intersection_size'] * 20,
            c=[colors[i]] * len(group),
            marker=markers[i % len(markers)],
            alpha=0.7,
            edgecolors=colors[i],
            linewidths=1,
            label=gene_set
        )
        
        if len(group) > 0:
            top_terms = group.nsmallest(min(3, len(group)), 'adjusted_p_value')
            
            highest_point_y = max(y_positions) if len(y_positions) > 0 else 0
            label_y = highest_point_y + 5
            
            for _, term_row in top_terms.iterrows():
                term_name = term_row['term']
                
                label_text = f"{term_name}"
                
                term_idx = group.index.get_loc(term_row.name)
                term_x = x_positions[term_idx]
                
                text = ax.text(
                    term_x, label_y, 
                    label_text, 
                    fontsize=8, 
                    ha='center', 
                    color=colors[i],
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
                )
                all_texts.append(text)
                label_y += 2 
    
    # if len(all_texts) > 0:
    #     adjust_text(all_texts, ax=ax, arrowprops=dict(arrowstyle='->', color='gray', lw=0.5))
    
    ax.set_title(go_category, fontsize=14)
    ax.set_xlabel('')
    ax.set_ylabel('-log10 (Padj)', fontsize=12)
    ax.set_ylim(0, 60)
    ax.set_yticks(np.arange(0, 65, 5))
    ax.legend(title='Gene Sets', bbox_to_anchor=(1.05, 1), loc='upper left')
    
    # Set x-axis limits and ticks
    ax.set_xlim(-0.5, len(gene_sets) - 0.5)
    ax.set_xticks(range(len(gene_sets)))
    ax.set_xticklabels([])  # Hide x-tick labels
    
    plt.tight_layout()
    return fig


In [None]:
def format_go_results_for_bubble_plot(results_df):
    namespace_map = {
        'biological_process': 'BP',
        'cellular_component': 'CC',
        'molecular_function': 'MF'
    }
    
    formatted_df = pd.DataFrame({
        'term': results_df['GO_term'],
        'adjusted_p_value': results_df['p_fdr_bh'],
        'intersection_size': results_df['protein_count'],
        'gene_set': results_df['gene_set'],
        'source': results_df['namespace'],
        'GO': results_df['GO_id'] if 'GO_id' in results_df.columns else 'GO:0000000'
    })
    
    return formatted_df

formatted_results = format_go_results_for_bubble_plot(gene_list)

for go_category in ['BP', 'CC', 'MF']:
    subset = formatted_results[formatted_results['source'] == go_category]
    fig = go_category_bubble_plot(subset, go_category, f"{go_category} GO Enrichment")
    new_rc_params = {'text.usetex': False,"svg.fonttype": 'none'}
    mpl.rcParams.update(new_rc_params)
    plt.savefig(f"{go_category}_go_enrichment.svg")
    plt.close()

# for gene_set in ['gene_list']:
#     subset = formatted_results[formatted_results['gene_set'] == gene_set]
#     fig = faceted_bubble_plot(subset, f"{gene_set} GO Enrichment Analysis")
#     plt.savefig(f"{gene_set}_go_enrichment.png", dpi=300, bbox_inches='tight')
#     plt.close()

# fig = faceted_bubble_plot(formatted_results, "GO Enrichment Analysis by Gene Set")
# plt.savefig("all_go_enrichment.png", dpi=300, bbox_inches='tight')
# plt.show()





In [None]:
def get_top_terms(df, n=20):
    return df.sort_values('p_fdr_bh').head(n)

top_results = pd.concat([
    get_top_terms(gene_list)
])

# g = sns.FacetGrid(
#     top_results, 
#     col='gene_set', 
#     col_wrap=2, 
#     height=4, 
#     sharey=False
# )

# g = g.map(sns.barplot, 'p_fdr_bh', 'GO_term')

# g.set_axis_labels('FDR-corrected p-value', 'GO Term')
# g.set_titles('{col_name}')
# g.figure.tight_layout()
# plt.subplots_adjust(top=0.9)

# plt.show()

In [None]:
import plotly.express as px

top_results = pd.concat([
    get_top_terms(gene_list)
])

fig = px.bar(
    top_results,
    x='p_fdr_bh',
    y='GO_term',
    color='namespace',
    facet_col='gene_set',
    facet_col_wrap=2,
    height=2000,
    labels={'p_fdr_bh': 'FDR-corrected p-value', 'GO_term': 'GO Term', 'namespace': 'GO Namespace'},
    hover_data=['protein_count', 'p_value']
)

fig.update_yaxes(categoryorder='total ascending')
fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))

fig.show()
