# EDA 02 – GO Ontology Deep Dive (BPO/MFO/CCO)

Focus:
- Split labels and terms by subontology (BPO, MFO, CCO)
- Analyze IA distribution per ontology
- Compare term rarity and IA weighting across namespaces
- Identify most/least specific terms per branch
- Visualize ontology hierarchy samples

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import obonet
from collections import Counter

# Global plotting style + deterministic numpy RNG for the whole notebook
sns.set_context('talk')
sns.set_style('whitegrid')
np.random.seed(42)

# ---- Input files ---------------------------------------------------------
TRAIN_TERMS = Path('Train') / 'train_terms.tsv'
GO_OBO = Path('Train') / 'go-basic.obo'
IA_TSV = Path('IA.tsv')

# ---- Output directories (created up front so later savefig/to_csv calls work)
FIG_DIR = Path('notebooks') / 'figures'
PROC_DIR = Path('data') / 'processed'
for out_dir in (FIG_DIR, PROC_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

print('Python:', sys.version)
input_paths = (TRAIN_TERMS, GO_OBO, IA_TSV)
print('Paths ready:', [path.exists() for path in input_paths])

## 1. Load GO Graph and Extract Namespace Mapping

In [None]:
# Parse the GO OBO file into a networkx graph via obonet
go_graph = obonet.read_obo(str(GO_OBO))
print('Nodes:', len(go_graph), '| Edges:', go_graph.number_of_edges())

# term -> raw namespace string; nodes without a (truthy) namespace are skipped
term_to_ns = {
    node: attrs['namespace']
    for node, attrs in go_graph.nodes(data=True)
    if attrs.get('namespace')
}
# Counter built in node-iteration order, so most_common() tie order is unchanged
ns_counts = Counter(term_to_ns.values())

print('\nNamespace distribution in GO graph:')
for namespace, n_terms in ns_counts.most_common():
    print(f'  {namespace}: {n_terms} terms')

# Long namespace names -> short ontology codes used throughout the notebook
NS_MAP = {
    'biological_process': 'BPO',
    'molecular_function': 'MFO',
    'cellular_component': 'CCO',
}
# Namespaces not listed in NS_MAP are kept under their raw name
term_to_ont = {
    term: NS_MAP.get(namespace, namespace)
    for term, namespace in term_to_ns.items()
}

# Root term id of each sub-ontology
ROOTS = {'BPO': 'GO:0008150', 'MFO': 'GO:0003674', 'CCO': 'GO:0005575'}
roots_present = {ont: root in go_graph for ont, root in ROOTS.items()}
print('\nRoots present:', roots_present)

## 2. Load Training Labels and Join with Ontology

In [None]:
labels_df = pd.read_csv(TRAIN_TERMS, sep='\t')
print('Labels shape:', labels_df.shape)
for caption, column in (('Unique proteins:', 'EntryID'), ('Unique terms:', 'term')):
    print(caption, labels_df[column].nunique())

# Tag each annotation with its ontology code (NaN when the term is not in the GO graph)
labels_df['ontology'] = labels_df['term'].map(term_to_ont)
missing_ont = labels_df['ontology'].isna().sum()
if missing_ont > 0:
    print(f'Warning: {missing_ont} annotations with unknown ontology')

labels_df = labels_df.dropna(subset=['ontology'])
print('After drop NA ontology:', labels_df.shape)

# Per-ontology counts: distinct proteins, distinct terms, total annotation rows
by_ont = labels_df.groupby('ontology')
ont_summary = pd.DataFrame({
    'proteins': by_ont['EntryID'].nunique(),
    'terms': by_ont['term'].nunique(),
    'annotations': by_ont['term'].count(),
}).reset_index()
print('\nPer-ontology summary:')
display(ont_summary)

## 3. Visualize Annotation Distribution by Ontology

In [None]:
# Two pies side by side: share of annotation rows vs share of distinct terms
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
pie_specs = (
    ('annotations', 'Annotation Share by Ontology'),
    ('terms', 'Unique Terms by Ontology'),
)
for pie_ax, (column, title) in zip(axes, pie_specs):
    pie_ax.pie(ont_summary[column], labels=ont_summary['ontology'], autopct='%1.1f%%', startangle=90)
    pie_ax.set_title(title)
plt.tight_layout()
plt.savefig(FIG_DIR / 'ontology_shares.png', dpi=150, bbox_inches='tight')
plt.show()

# Number of annotated terms per (protein, ontology) pair
terms_per_prot_ont = (
    labels_df
    .groupby(['EntryID', 'ontology'])
    .size()
    .reset_index(name='count')
)

fig, ax = plt.subplots(figsize=(10, 5))
sns.boxplot(data=terms_per_prot_ont, x='ontology', y='count', ax=ax)
ax.set_yscale('log')
ax.set_ylabel('Terms per protein (log scale)')
ax.set_title('Terms-per-protein distribution by ontology')
plt.tight_layout()
plt.savefig(FIG_DIR / 'terms_per_protein_by_ont.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Load IA Weights and Merge with Labels

In [None]:
# IA file: two tab-separated columns (term id, IA value) with no header row
ia_df = pd.read_csv(IA_TSV, sep='\t', header=None, names=['term', 'ia'])
print('IA entries:', len(ia_df))
print('IA stats:')
display(ia_df['ia'].describe())

# Tag IA entries with their ontology; IA terms absent from the GO graph are dropped
ia_df['ontology'] = ia_df['term'].map(term_to_ont)
n_ia_unmapped = ia_df['ontology'].isna().sum()
if n_ia_unmapped:
    print(f'Note: dropping {n_ia_unmapped} IA entries whose term is not in the GO graph')
ia_df = ia_df.dropna(subset=['ontology'])

# Per-term annotation frequency joined with IA (left join: keep every labelled term).
# The ontology column is taken from the GO-graph mapping rather than from ia_df:
# with the latter, any labelled term missing from IA.tsv would get ontology=NaN and
# silently disappear from every per-ontology analysis below.
term_freq = labels_df['term'].value_counts().rename_axis('term').reset_index(name='freq')
term_stats = term_freq.merge(ia_df[['term', 'ia']], on='term', how='left')
term_stats['ontology'] = term_stats['term'].map(term_to_ont)
term_stats['log_freq'] = np.log10(term_stats['freq'] + 1)

n_missing_ia = term_stats['ia'].isna().sum()
if n_missing_ia:
    print(f'Warning: {n_missing_ia} labelled terms have no IA value')

print('\nTerm stats with IA:')
display(term_stats.head(10))

# Save for later use
term_stats.to_csv(PROC_DIR / 'term_stats_with_ia.csv', index=False)
print(f'Saved: {PROC_DIR / "term_stats_with_ia.csv"}')

## 5. IA Distribution per Ontology

In [None]:
# One IA histogram per ontology, sharing the y axis for direct comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=True)
for hist_ax, ont in zip(axes, ('BPO', 'MFO', 'CCO')):
    ont_ia = ia_df.loc[ia_df['ontology'] == ont, 'ia']
    hist_ax.hist(ont_ia, bins=40, edgecolor='k', alpha=0.7)
    hist_ax.set_title(f'{ont} (n={len(ont_ia)})')
    hist_ax.set_xlabel('IA')
axes[0].set_ylabel('Count')  # y axis is shared, so label only the leftmost panel
fig.suptitle('IA Distribution by Ontology', y=1.02)
plt.tight_layout()
plt.savefig(FIG_DIR / 'ia_by_ontology.png', dpi=150, bbox_inches='tight')
plt.show()

# Descriptive statistics of IA within each ontology
ia_by_ont = ia_df.groupby('ontology')['ia'].describe()
print('\nIA stats by ontology:')
display(ia_by_ont)

## 6. IA vs Term Rarity (Frequency) per Ontology

In [None]:
from scipy.stats import spearmanr  # imported once, not on every loop iteration

# Scatter: log(freq) vs IA, faceted by ontology, annotated with Spearman's rho
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, ont in enumerate(['BPO', 'MFO', 'CCO']):
    ax = axes[i]
    subset = term_stats[term_stats['ontology'] == ont].dropna(subset=['ia'])
    ax.scatter(subset['log_freq'], subset['ia'], alpha=0.5, s=10)
    ax.set_xlabel('log10(frequency + 1)')
    ax.set_title(f'{ont} (n={len(subset)})')
    if i == 0:
        ax.set_ylabel('IA')
    # Spearman is rank-based, so log10(freq + 1) gives the same rho as raw freq.
    # It is undefined (NaN plus a warning) with <2 points or a constant column.
    if len(subset) > 1 and subset['log_freq'].nunique() > 1 and subset['ia'].nunique() > 1:
        rho, pval = spearmanr(subset['log_freq'], subset['ia'])
        ax.text(0.05, 0.95, f'Spearman ρ={rho:.3f}\np={pval:.2e}',
                transform=ax.transAxes, va='top', fontsize=10,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
fig.suptitle('IA vs Term Frequency by Ontology', y=1.02)
plt.tight_layout()
plt.savefig(FIG_DIR / 'ia_vs_freq_by_ont.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Identify Most/Least Specific Terms per Ontology

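(IA = information accretion, i.e. how much information a term adds beyond its parent terms. Highest IA ≈ most specific/informative; lowest IA ≈ most general, or largely implied by its parents.)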

In [None]:
# Per ontology, list the 10 highest-IA and the 10 lowest-IA labelled terms.
# nlargest/nsmallest order each list from the most extreme value inward, so the
# "most general" list starts at the lowest IA. With ties (e.g. several terms at
# the minimum IA) the selection among tied terms is arbitrary.
show_cols = ['term', 'freq', 'ia']
for ont in ['BPO', 'MFO', 'CCO']:
    subset = term_stats[term_stats['ontology'] == ont].dropna(subset=['ia'])
    print(f'\n=== {ont} ===')
    print('Top 10 most specific (highest IA):')
    display(subset.nlargest(10, 'ia')[show_cols])
    print('\nTop 10 most general (lowest IA):')
    display(subset.nsmallest(10, 'ia')[show_cols])

## 8. Compute Term Depth from Root (Shortest Path, All Reachable Terms)

In [None]:
# Compute shortest path length from each term to its root
def compute_depths(graph, root, namespace_filter=None, *, ont_map=None):
    """Return ``{term: depth}`` for every term reachable from ``root`` along reversed edges.

    Depth is the length of the shortest path from ``root`` to the term, counting
    every edge (relationship) type present in ``graph``. Edges are followed in
    reverse; obonet stores GO edges child -> parent, so reversing them walks from
    the root down to its descendants.

    Args:
        graph: networkx (Multi)DiGraph of GO terms.
        root: id of the root term; if it is not in ``graph`` an empty dict is returned.
        namespace_filter: if truthy, keep only terms whose ontology code in
            ``ont_map`` equals this value (e.g. ``'BPO'``).
        ont_map: optional ``{term: ontology_code}`` mapping; defaults to the
            notebook-level ``term_to_ont`` (the previous implicit behavior).

    Returns:
        dict mapping term id -> integer depth (the root itself has depth 0).
    """
    if root not in graph:
        return {}
    # reverse(copy=False) is a read-only view; the default reverse() deep-copies
    # the whole GO graph on every call.
    lengths = nx.single_source_shortest_path_length(graph.reverse(copy=False), root)
    if namespace_filter:
        if ont_map is None:
            ont_map = term_to_ont
        lengths = {t: d for t, d in lengths.items()
                   if ont_map.get(t) == namespace_filter}
    return lengths

# Depth of every term within its own ontology, keyed by ontology code
depth_dicts = {}
for ont_short, root in ROOTS.items():
    depths = compute_depths(go_graph, root, ont_short)
    depth_dicts[ont_short] = depths
    print(f'{ont_short}: {len(depths)} terms with depth computed')

# Attach depth to term_stats; NaN when the term has no depth in its ontology
term_stats['depth'] = pd.Series(
    [
        depth_dicts.get(ont, {}).get(term, np.nan)
        for term, ont in zip(term_stats['term'], term_stats['ontology'])
    ],
    index=term_stats.index,
    dtype=float,
)

print('\nDepth distribution by ontology:')
display(term_stats.groupby('ontology')['depth'].describe())

## 9. IA vs Depth Correlation

In [None]:
from scipy.stats import pearsonr  # imported once, not on every loop iteration

# Scatter: depth vs IA per ontology, annotated with Pearson's r
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, ont in enumerate(['BPO', 'MFO', 'CCO']):
    ax = axes[i]
    subset = term_stats[
        (term_stats['ontology'] == ont)
        & term_stats['depth'].notna()
        & term_stats['ia'].notna()
    ]
    ax.scatter(subset['depth'], subset['ia'], alpha=0.5, s=10)
    ax.set_xlabel('Depth from root')
    ax.set_title(f'{ont} (n={len(subset)})')
    if i == 0:
        ax.set_ylabel('IA')
    # Pearson r is undefined (NaN plus a warning) with <2 points or a constant column
    if len(subset) > 1 and subset['depth'].nunique() > 1 and subset['ia'].nunique() > 1:
        r, pval = pearsonr(subset['depth'], subset['ia'])
        ax.text(0.05, 0.95, f'Pearson r={r:.3f}\np={pval:.2e}',
                transform=ax.transAxes, va='top', fontsize=10,
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))
fig.suptitle('IA vs Depth by Ontology', y=1.02)
plt.tight_layout()
plt.savefig(FIG_DIR / 'ia_vs_depth_by_ont.png', dpi=150, bbox_inches='tight')
plt.show()

## 10. Visualize Sample Subgraph for Each Ontology

In [None]:
# For each ontology, extract a small induced subgraph near root
for ont_short, root in ROOTS.items():
    if root not in go_graph:
        continue
    # Get descendants at depth <= 3
    descendants = [n for n, d in depth_dicts[ont_short].items() if d <= 3]
    if len(descendants) > 50:
        descendants = descendants[:50]  # cap
    subg = go_graph.subgraph(descendants)
    
    plt.figure(figsize=(10, 8))
    pos = nx.spring_layout(subg, seed=42, k=0.5)
    nx.draw_networkx_nodes(subg, pos, node_size=50, node_color='lightblue')
    nx.draw_networkx_edges(subg, pos, alpha=0.3, arrows=True, arrowsize=10)
    # Label root only
    labels = {root: root}
    nx.draw_networkx_labels(subg, pos, labels, font_size=8)
    plt.title(f'{ont_short} Sample Subgraph (depth ≤ 3, up to 50 nodes)')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(FIG_DIR / f'go_subgraph_{ont_short}.png', dpi=150, bbox_inches='tight')
    plt.show()

## 11. Summary & Key Takeaways

In [None]:
# NOTE(review): these bullets are hard-coded text, not computed from the data
# above; check them against the plots and correlation values before relying on them.
key_findings = [
    'BPO, MFO, and CCO have distinct IA and frequency distributions',
    'IA generally increases with depth (more specific = higher IA)',
    'IA inversely correlates with term frequency (rare terms = high IA)',
    'Per-ontology modeling may capture namespace-specific patterns better',
]
next_steps = [
    'Build separate classifiers per ontology or multi-task heads',
    'Weight losses by IA to prioritize rare/specific terms',
    'Implement label propagation (child -> ancestors) for hierarchical consistency',
]

print('Key findings:')
for finding in key_findings:
    print(f'- {finding}')
print('\nNext steps:')
for step in next_steps:
    print(f'- {step}')