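To ensure that every sample has GO annotations that represent a complete subgraph of the Gene Ontology, we'll leverage the hierarchical structure of the GO DAG to add any missing ancestor (more general) terms implied by each annotation. Because obonet directs edges from child to parent, these ancestors are exactly the term's descendants in `networkx` terms.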

In [None]:
from datasets import load_dataset

import obonet

import networkx as nx

from collections import Counter

from itertools import islice

import matplotlib.pyplot as plt

import json

# Input JSONL datasets: the combined set plus one per GO aspect
# (molecular function, biological process, cellular component).
(
    all_dataset_path,
    mf_dataset_path,
    bp_dataset_path,
    cc_dataset_path,
) = (
    "./dataset/all-filtered.jsonl",
    "./dataset/mf-filtered.jsonl",
    "./dataset/bp-filtered.jsonl",
    "./dataset/cc-filtered.jsonl",
)

# Output JSONL files, one per input dataset, holding the expanded records.
(
    all_expanded_path,
    mf_expanded_path,
    bp_expanded_path,
    cc_expanded_path,
) = (
    "./dataset/all-expanded.jsonl",
    "./dataset/mf-expanded.jsonl",
    "./dataset/bp-expanded.jsonl",
    "./dataset/cc-expanded.jsonl",
)

# Records whose expanded annotation set has fewer terms than this are skipped.
min_go_terms = 2

# Number of most frequent terms shown in each dataset's bar chart.
top_k = 30

# Load the Gene Ontology and refuse to continue unless it is a proper DAG;
# the expansion step walks reachability in this graph.
go_obo_path = "./dataset/go-basic.obo"
graph = obonet.read_obo(go_obo_path)

is_dag = nx.is_directed_acyclic_graph(graph)
if not is_dag:
    raise ValueError("Invalid gene ontology graph, must not contain cycles.")

# Expand every record's GO annotations with all terms implied by the GO graph,
# write the expanded records to JSONL, and plot the most frequent terms for
# each dataset.


def _expand_go_terms(terms, go_graph, closure_cache):
    """Return the sorted union of *terms* and every term reachable from them.

    obonet builds the GO graph with edges pointing from a child term to its
    parent, so ``nx.descendants`` on the *full* graph yields the more general
    (ancestor) terms implied by an annotation. Terms that are not nodes of
    *go_graph* are dropped, the same way ``Graph.subgraph`` ignores unknown
    nodes.

    Args:
        terms: Iterable of GO term IDs attached to one record.
        go_graph: The full GO graph loaded with ``obonet.read_obo``.
        closure_cache: Dict mapping a term ID to a frozenset of that term plus
            everything reachable from it. It is filled lazily and reused, so
            each term's closure is computed at most once.

    Returns:
        A sorted list of unique term IDs. Sorting keeps the written output
        deterministic; raw set iteration order varies between runs.
    """
    # NOTE(review): nx.descendants follows every relationship type obonet
    # loaded from the OBO file, not only is_a/part_of. If annotations should
    # propagate over a restricted set of relations, filter edges by key
    # first. Confirm the intended semantics.
    expanded = set()
    for term in terms:
        if term not in go_graph:
            continue
        closure = closure_cache.get(term)
        if closure is None:
            closure = frozenset(nx.descendants(go_graph, term)) | {term}
            closure_cache[term] = closure
        expanded |= closure
    return sorted(expanded)


# Shared by all datasets so that a term's closure is computed at most once.
go_closure_cache = {}

for name, dataset_path, expanded_path in [
    ("All", all_dataset_path, all_expanded_path),
    ("Molecular Function", mf_dataset_path, mf_expanded_path),
    ("Biological Process", bp_dataset_path, bp_expanded_path),
    ("Cellular Component", cc_dataset_path, cc_expanded_path),
]:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    term_counts = Counter()

    with open(expanded_path, "w", encoding="utf-8") as output_file:
        for record in dataset:
            # Traverse the full GO graph. A subgraph induced by the record's
            # own terms cannot reach any term outside that set, so it would
            # add nothing.
            expanded_terms = _expand_go_terms(
                record["go_terms"], graph, go_closure_cache
            )

            # Skip records whose expanded annotation set is still too small.
            if len(expanded_terms) < min_go_terms:
                continue

            term_counts.update(expanded_terms)

            record["go_terms"] = expanded_terms

            output_file.write(json.dumps(record) + "\n")

    print(f"Dataset saved to {expanded_path}")

    # most_common orders by descending count; ties keep first-seen order.
    top_terms = dict(term_counts.most_common(top_k))

    plt.figure(figsize=(12, 5))

    plt.bar(list(top_terms), list(top_terms.values()))

    plt.title(f"Top {top_k} {name} Term Frequencies")
    plt.xticks(rotation=90)
    plt.xlabel("GO Term ID")
    plt.ylabel("Frequency")

    plt.show()
                