In this notebook, we filter the SwissProt dataset so that each sample keeps only the GO term annotations that were experimentally verified or inferred by a human curator (judged by their evidence codes); samples left with no such annotations are dropped.

In [None]:
from datasets import load_dataset

from collections import Counter

import matplotlib.pyplot as plt

import json

# Input JSONL files: one per GO subset (all aspects, MF, BP, CC).
all_dataset_path = "./dataset/all.jsonl"
mf_dataset_path = "./dataset/mf.jsonl"
bp_dataset_path = "./dataset/bp.jsonl"
cc_dataset_path = "./dataset/cc.jsonl"

# Output JSONL files holding only the records/terms that survive filtering.
all_filtered_path = "./dataset/all-filtered.jsonl"
mf_filtered_path = "./dataset/mf-filtered.jsonl"
bp_filtered_path = "./dataset/bp-filtered.jsonl"
cc_filtered_path = "./dataset/cc-filtered.jsonl"

# A record is written out only if at least this many of its terms remain
# after filtering by evidence code.
min_terms = 1

# Evidence codes accepted by CAFA5.
# Refer to https://geneontology.org/docs/guide-go-evidence-codes/.
# frozenset: this is a constant, and admissible_evidence_codes below aliases
# the same object, so immutability prevents accidental edits through either name.
CAFA5_EVIDENCE_CODES = frozenset({
    # Experimental evidence.
    "EXP", "IDA", "IPI", "IMP", "IGI", "IEP",
    # High-throughput experimental evidence.
    "HTP", "HDA", "HMP", "HGI", "HEP",
    # Traceable Author Statement and Inferred by Curator.
    "TAS", "IC",
})

# The set of evidence codes a GO term must carry to be kept.
admissible_evidence_codes = CAFA5_EVIDENCE_CODES

# For each GO subset: keep only terms with an admissible evidence code, write
# records that retain at least `min_terms` such terms to the filtered JSONL file,
# and plot how often each evidence code appears in the written output.
for subset_name, dataset_path, filtered_path in [
    ("All", all_dataset_path, all_filtered_path),
    ("Molecular Function", mf_dataset_path, mf_filtered_path),
    ("Biological Process", bp_dataset_path, bp_filtered_path),
    ("Cellular Component", cc_dataset_path, cc_filtered_path),
]:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    # Evidence-code frequencies over the terms that were actually saved.
    counter = Counter()
    num_kept = 0

    with open(filtered_path, "w", encoding="utf-8") as file:
        for record in dataset:
            filtered_terms = [
                go_term
                for go_term in record["terms"]
                if go_term["evidence_code"] in admissible_evidence_codes
            ]

            # Drop records that do not retain enough admissible annotations.
            if len(filtered_terms) < min_terms:
                continue

            record["terms"] = filtered_terms

            file.write(json.dumps(record) + "\n")

            # Count only after the record is kept, so the plot reflects the saved
            # file even when min_terms > 1 discards records with some valid terms.
            counter.update(go_term["evidence_code"] for go_term in filtered_terms)
            num_kept += 1

    # Plot bars in descending frequency order for readability.
    ranked = counter.most_common()

    plt.figure(figsize=(12, 5))

    plt.bar([code for code, _ in ranked], [freq for _, freq in ranked])

    plt.title(f"{subset_name} Evidence Code Frequencies")
    plt.xticks(rotation=90)
    plt.xlabel("Evidence Code")
    plt.ylabel("Frequency")

    plt.show()

    print(f"{subset_name}: kept {num_kept:,} of {len(dataset):,} records")
    print(f"Dataset saved to {filtered_path}")
