In this notebook we'll load the SwissProt GO dataset and visualize the distribution of its GO terms. Then we'll assign a stratum ID to each sample based on how similar its GO terms are to those of other samples.

In [None]:
from datasets import load_dataset

from itertools import islice

from collections import Counter

import matplotlib.pyplot as plt

all_dataset_path = "./dataset/all-filtered.jsonl"
mf_dataset_path = "./dataset/mf-filtered.jsonl"
bp_dataset_path = "./dataset/bp-filtered.jsonl"
cc_dataset_path = "./dataset/cc-filtered.jsonl"

# Seed passed to the SVD and k-means cells so the strata are reproducible.
random_seed = 42

# Number of most frequent GO terms shown in each bar chart.
top_k = 30

# Per-subset GO term frequencies. They are filled in below and reused later
# as the term vocabulary for the embeddings.
all_term_counter = Counter()
mf_term_counter = Counter()
bp_term_counter = Counter()
cc_term_counter = Counter()

# NOTE(review): this counter is shared by all four subsets. A record that
# appears in several subset files therefore has its taxon counted once per
# subset. Confirm whether counts over the "All" subset alone were intended.
taxon_id_counter = Counter()

for subset_name, dataset_path, term_counter in [
    ("All", all_dataset_path, all_term_counter),
    ("Molecular Function", mf_dataset_path, mf_term_counter),
    ("Biological Process", bp_dataset_path, bp_term_counter),
    ("Cellular Component", cc_dataset_path, cc_term_counter),
]:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    for record in dataset:
        # Counter.update with an iterable adds 1 per occurrence of each term.
        term_counter.update(record["terms"])

        taxon_id_counter[record["taxon_id"]] += 1

    # The top_k (term, count) pairs, highest count first.
    first_k = dict(term_counter.most_common(top_k))

    plt.figure(figsize=(12, 5))

    plt.bar(first_k.keys(), first_k.values())

    plt.title(f"Top {top_k} {subset_name} Term Frequencies")
    plt.xticks(rotation=90)
    plt.xlabel("GO Term ID")
    plt.ylabel("Frequency")

    plt.show()

Let's cluster the samples in the SwissProt GO dataset by their GO terms, so that we can later use the cluster assignments for a stratified train/test split. We'll start by embedding each sample's GO terms: we build a multi-hot term vector per sample and reduce it to a dense vector with truncated SVD.

In [None]:
import numpy as np

from sklearn.decomposition import TruncatedSVD

from functools import partial

from scipy.sparse import csr_matrix

# Stop trying larger SVD sizes once this fraction of variance is retained.
min_explained_variance = 0.5

# Per-subset mapping: record id -> dense SVD embedding of its GO term set.
all_terms_embeddings = {}
mf_terms_embeddings = {}
bp_terms_embeddings = {}
cc_terms_embeddings = {}

new_svd = partial(TruncatedSVD, n_iter=7, random_state=random_seed)

for dataset_path, terms_embeddings, term_counter in [
    (all_dataset_path, all_terms_embeddings, all_term_counter),
    (mf_dataset_path, mf_terms_embeddings, mf_term_counter),
    (bp_dataset_path, bp_terms_embeddings, bp_term_counter),
    (cc_dataset_path, cc_terms_embeddings, cc_term_counter),
]:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    # One matrix column per term counted for this subset.
    term_index_mapping = {term: index for index, term in enumerate(term_counter.keys())}

    # Record id -> sorted, de-duplicated column indices of its terms. The set
    # keeps the matrix binary (a repeated term still contributes 1). A later
    # record with the same id replaces an earlier one but keeps its position.
    active_columns = {}

    for record in dataset:
        record_id = record["id"]

        active_columns[record_id] = sorted(
            {term_index_mapping[term] for term in record["terms"] if term in term_index_mapping}
        )

    record_ids = list(active_columns)

    # Build the multi-hot matrix directly in CSR form (data, indices, indptr).
    # A dense rows x terms array would be far larger, and TruncatedSVD
    # accepts sparse input natively.
    indices = []
    indptr = [0]

    for columns in active_columns.values():
        indices.extend(columns)
        indptr.append(len(indices))

    x = csr_matrix(
        (np.ones(len(indices), dtype=np.float64), indices, indptr),
        shape=(len(record_ids), len(term_index_mapping)),
    )

    best_dimensionality = 0
    best_explained_variance = 0
    best_model = None

    # Try increasing sizes. Keep the model with the highest retained variance
    # and stop early once the target is reached.
    for dimensionality in (32, 48, 64):
        svd = new_svd(n_components=dimensionality)

        svd.fit(x)

        explained_variance = np.sum(svd.explained_variance_ratio_)

        if explained_variance > best_explained_variance:
            best_dimensionality = dimensionality
            best_explained_variance = explained_variance
            best_model = svd

        if explained_variance >= min_explained_variance:
            break

    print(f"Best dimensionality: {best_dimensionality}")
    print(f"Explained variance ratio: {best_explained_variance:.2f}")

    print("\n")

    z = best_model.transform(x)

    # Row i of z belongs to record_ids[i], in the same order the matrix was built.
    for record_id, embedding in zip(record_ids, z):
        terms_embeddings[record_id] = embedding

With the new embeddings, we'll use k-means to cluster the sequences into strata with similar GO term annotations.

In [None]:
from sklearn.cluster import KMeans

from functools import partial

# Number of k-means clusters (strata) fitted for each subset.
num_strata = 100

# Per-subset mapping: record id -> stratum id (plain Python int).
all_stratum_ids = {}
mf_stratum_ids = {}
bp_stratum_ids = {}
cc_stratum_ids = {}

# Per-subset number of records assigned to each stratum.
bp_stratum_counter = Counter()
cc_stratum_counter = Counter()
mf_stratum_counter = Counter()
all_stratum_counter = Counter()

new_kmeans = partial(KMeans, random_state=random_seed)

for terms_embeddings, stratum_ids, stratum_counter in [
    (all_terms_embeddings, all_stratum_ids, all_stratum_counter),
    (mf_terms_embeddings, mf_stratum_ids, mf_stratum_counter),
    (bp_terms_embeddings, bp_stratum_ids, bp_stratum_counter),
    (cc_terms_embeddings, cc_stratum_ids, cc_stratum_counter),
]:
    kmeans = new_kmeans(n_clusters=num_strata)

    x = np.stack(list(terms_embeddings.values()))

    # scikit-learn documents fit_predict as equivalent to fit(x) followed by
    # predict(x), without the extra assignment pass.
    strata_ids = kmeans.fit_predict(x)

    print(f"K-means steps: {kmeans.n_iter_}")
    print(f"Inertia Loss: {kmeans.inertia_}")

    # .tolist() converts the NumPy integer labels to Python ints. json.dumps
    # cannot serialize NumPy integers when the strata are written to disk.
    for sequence_id, stratum_id in zip(terms_embeddings.keys(), strata_ids.tolist()):
        stratum_ids[sequence_id] = stratum_id

        stratum_counter[stratum_id] += 1

    print(f"Number of unique strata: {len(stratum_counter):,}")

    print("\n")

Let's plot the stratum counts to visualize how they are distributed.

In [None]:
# One bar chart per subset, showing how many records fell into each stratum.
stratum_counters_by_subset = (
    ("All", all_stratum_counter),
    ("Biological Process", bp_stratum_counter),
    ("Cellular Component", cc_stratum_counter),
    ("Molecular Function", mf_stratum_counter),
)

for subset_label, frequencies in stratum_counters_by_subset:
    stratum_labels = list(frequencies.keys())
    record_counts = list(frequencies.values())

    plt.figure(figsize=(12, 5))
    plt.bar(stratum_labels, record_counts)

    plt.title(f"{subset_label} Stratum Frequencies")
    plt.xticks(rotation=90)
    plt.xlabel("Stratum ID")
    plt.ylabel("Frequency")

    plt.show()

Lastly, we'll add the term embeddings and stratum IDs to each record and write the augmented datasets to disk.

In [None]:
import json

all_stratified_path = "./dataset/all-stratified.jsonl"
mf_stratified_path = "./dataset/mf-stratified.jsonl"
bp_stratified_path = "./dataset/bp-stratified.jsonl"
cc_stratified_path = "./dataset/cc-stratified.jsonl"

for dataset_path, stratified_path, terms_embeddings, stratum_ids in [
    (all_dataset_path, all_stratified_path, all_terms_embeddings, all_stratum_ids),
    (mf_dataset_path, mf_stratified_path, mf_terms_embeddings, mf_stratum_ids),
    (bp_dataset_path, bp_stratified_path, bp_terms_embeddings, bp_stratum_ids),
    (cc_dataset_path, cc_stratified_path, cc_terms_embeddings, cc_stratum_ids),
]:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    # Each iteration over a datasets.Dataset yields newly built dicts, so
    # edits to a record do not persist into a later loop. The new fields must
    # be added and the record written in the same pass.
    with open(stratified_path, "w", encoding="utf-8") as file:
        for record in dataset:
            record_id = record["id"]

            record["terms_embedding"] = terms_embeddings[record_id].tolist()

            # Stratum labels may be NumPy integers, which json.dumps rejects.
            record["stratum_id"] = int(stratum_ids[record_id])

            file.write(json.dumps(record) + "\n")

    print(f"Dataset saved to {stratified_path}")

print("Done!")