Let's cluster the samples in the SwissProt GO dataset by their GO terms so that the cluster assignments can later be used for a stratified train/test split. We'll start by creating an embedding of each sample's set of terms.

In [None]:
from functools import partial

from datasets import load_dataset

import numpy as np

from sklearn.decomposition import TruncatedSVD

# Input datasets: one with all terms and one per GO sub-ontology
# (presumably MF = molecular function, BP = biological process,
# CC = cellular component, matching the names used when plotting).
all_dataset_path = "./dataset/all-expanded.jsonl"
mf_dataset_path = "./dataset/mf-expanded.jsonl"
bp_dataset_path = "./dataset/bp-expanded.jsonl"
cc_dataset_path = "./dataset/cc-expanded.jsonl"

random_seed = 42

# Stop trying larger SVD sizes once this share of the variance is explained.
min_explained_variance = 0.5

# Filled in by the loop below: record ID -> reduced term embedding.
all_terms_embeddings = {}
mf_terms_embeddings = {}
bp_terms_embeddings = {}
cc_terms_embeddings = {}

new_svd = partial(TruncatedSVD, n_iter=7, random_state=random_seed)

for dataset_path, terms_embeddings in [
    (all_dataset_path, all_terms_embeddings),
    (mf_dataset_path, mf_terms_embeddings),
    (bp_dataset_path, bp_terms_embeddings),
    (cc_dataset_path, cc_terms_embeddings),
]:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    unique_terms = {term for record in dataset for term in record["terms"]}

    # A set of strings iterates in an order that changes between interpreter
    # runs (hash randomization). Sorting pins each term to a stable column so
    # the seeded SVD below sees the same matrix on every run.
    # Assumes the terms are mutually comparable (e.g. all strings).
    term_index_mapping = {
        term: index for index, term in enumerate(sorted(unique_terms))
    }

    # Multi-hot encode every record: 1 in the column of each term it carries.
    for record in dataset:
        embedding = np.zeros(len(term_index_mapping), dtype=np.int8)

        for term in record["terms"]:
            embedding[term_index_mapping[term]] = 1

        terms_embeddings[record["id"]] = embedding

    # The design matrix does not depend on the dimensionality being tried,
    # so build it once rather than on every iteration of the search.
    x = np.stack(list(terms_embeddings.values()))

    best_dimensionality = 0
    best_explained_variance = 0
    best_model = None

    # Try increasing sizes, keeping the best model seen so far, and stop as
    # soon as one reaches the target explained variance.
    for dimensionality in (32, 48, 64):
        svd = new_svd(n_components=dimensionality)

        svd.fit(x)

        explained_variance = np.sum(svd.explained_variance_ratio_)

        if explained_variance > best_explained_variance:
            best_dimensionality = dimensionality
            best_explained_variance = explained_variance
            best_model = svd

        if explained_variance >= min_explained_variance:
            break

    z = best_model.transform(x)

    # Rows of z are in the same order as the dict's keys, so swap each
    # multi-hot vector for its reduced embedding in place.
    for sequence_id, embedding in zip(list(terms_embeddings), z):
        terms_embeddings[sequence_id] = embedding

    print(f"Best dimensionality: {best_dimensionality}")
    print(f"Explained variance ratio: {best_explained_variance:.2f}")

    print("\n")

With the fresh embeddings, we'll cluster the sequences into strata of similar GO subgraphs.

In [None]:
from sklearn.cluster import KMeans

from collections import Counter

from functools import partial

import matplotlib.pyplot as plt

# Number of k-means clusters (strata) to form per dataset.
num_strata = 100

# Filled in by the loop below: record ID -> stratum (cluster) ID.
all_stratum_ids = {}
mf_stratum_ids = {}
bp_stratum_ids = {}
cc_stratum_ids = {}

new_kmeans = partial(KMeans, random_state=random_seed)

for name, terms_embeddings, stratum_ids in [
    ("All", all_terms_embeddings, all_stratum_ids),
    ("Molecular Function", mf_terms_embeddings, mf_stratum_ids),
    ("Biological Process", bp_terms_embeddings, bp_stratum_ids),
    ("Cellular Component", cc_terms_embeddings, cc_stratum_ids),
]:
    kmeans = new_kmeans(n_clusters=num_strata)

    x = np.stack(list(terms_embeddings.values()))

    # fit_predict returns the training labels directly, avoiding a second
    # pass over the data. tolist() yields plain Python ints, which stay
    # JSON-serializable when the stratum IDs are written out later
    # (NumPy integer scalars are rejected by json.dumps).
    strata_ids = kmeans.fit_predict(x).tolist()

    # Labels are in the same order as the dict's keys.
    stratum_ids.update(zip(terms_embeddings.keys(), strata_ids))

    # Number of records assigned to each stratum.
    counter = Counter(strata_ids)

    plt.figure(figsize=(12, 5))

    plt.bar(list(counter.keys()), list(counter.values()))

    plt.title(f"{name} Stratum Frequencies")
    plt.xticks(rotation=90)
    plt.xlabel("Stratum ID")
    plt.ylabel("Frequency")

    plt.show()

    print(f"K-means steps: {kmeans.n_iter_}")
    print(f"Inertia Loss: {kmeans.inertia_}")

    print("\n")

Lastly, add the term embeddings and stratum IDs to the dataset and write to disk.

In [None]:
import json

# Output paths for the datasets annotated with embeddings and stratum IDs.
all_stratified_path = "./dataset/all-stratified.jsonl"
mf_stratified_path = "./dataset/mf-stratified.jsonl"
bp_stratified_path = "./dataset/bp-stratified.jsonl"
cc_stratified_path = "./dataset/cc-stratified.jsonl"

for dataset_path, stratified_path, terms_embeddings, stratum_ids in [
    (all_dataset_path, all_stratified_path, all_terms_embeddings, all_stratum_ids),
    (mf_dataset_path, mf_stratified_path, mf_terms_embeddings, mf_stratum_ids),
    (bp_dataset_path, bp_stratified_path, bp_terms_embeddings, bp_stratum_ids),
    (cc_dataset_path, cc_stratified_path, cc_terms_embeddings, cc_stratum_ids),
]:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    # Annotate and write each record in a single pass. Iterating the dataset
    # yields a fresh dict per row, so fields added during one pass would be
    # gone on a second pass and never reach the output file.
    with open(stratified_path, "w", encoding="utf-8") as file:
        for record in dataset:
            sequence_id = record["id"]

            record["terms_embedding"] = terms_embeddings[sequence_id].tolist()
            # Cast to a plain int: NumPy integer scalars are not
            # JSON-serializable.
            record["stratum_id"] = int(stratum_ids[sequence_id])

            file.write(json.dumps(record) + "\n")

    print(f"Dataset saved to {stratified_path}")

print("Done!")