Let's cluster the samples in the SwissProt GO dataset by their GO terms so we can use the cluster assignment to later do a stratified train/test split. We'll start by creating some embeddings for the terms.

In [None]:
from functools import partial

from datasets import load_dataset

import numpy as np

from sklearn.decomposition import TruncatedSVD
from sklearn.mixture import GaussianMixture

from collections import Counter

import matplotlib.pyplot as plt

# Input datasets: the full annotation set plus one file per GO subgraph
# (molecular function, biological process, cellular component).
all_dataset_path = "./dataset/all-expanded.jsonl"
mf_dataset_path = "./dataset/mf-expanded.jsonl"
bp_dataset_path = "./dataset/bp-expanded.jsonl"
cc_dataset_path = "./dataset/cc-expanded.jsonl"

# Number of clusters (strata) each dataset is partitioned into.
num_strata = 100

# The SVD dimensionality search stops at the first candidate whose components
# explain at least this fraction of the variance.
min_explained_variance = 0.5

# Shared seed for the SVD and the mixture model.
random_seed = 42

# Sample ID -> stratum ID, populated by the clustering loop.
all_stratum_ids = {}
mf_stratum_ids = {}
bp_stratum_ids = {}
cc_stratum_ids = {}

# Factory for a seeded TruncatedSVD; the caller supplies `n_components`.
new_svd = partial(TruncatedSVD, random_state=random_seed)

# Factory for a seeded Gaussian mixture with one component per stratum.
# `n_components` is tied to `num_strata` so the two cannot drift apart.
new_gmm = partial(
    GaussianMixture,
    n_components=num_strata,
    covariance_type="spherical",
    random_state=random_seed,
    max_iter=200,
    verbose=2,
)

# Stratum ID -> number of samples assigned to it, populated by the clustering loop.
all_counter = Counter()
mf_counter = Counter()
bp_counter = Counter()
cc_counter = Counter()

# For each dataset: multi-hot encode every sample's GO terms, reduce the
# encoding with truncated SVD, cluster the reduced vectors with a Gaussian
# mixture, then record and plot how many samples landed in each stratum.
for name, dataset_path, stratum_ids, try_dimensions, counter in [
    ("All", all_dataset_path, all_stratum_ids, (1024, 2048, 4096), all_counter),
    ("Molecular Function", mf_dataset_path, mf_stratum_ids, (256, 512, 1024), mf_counter),
    ("Biological Process", bp_dataset_path, bp_stratum_ids, (1024, 2048, 4096), bp_counter),
    ("Cellular Component", cc_dataset_path, cc_stratum_ids, (64, 128, 256), cc_counter),
]:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    unique_terms = set()

    for record in dataset:
        unique_terms.update(record["go_terms"])

    # Sort before assigning column indices: set iteration order for strings
    # changes between interpreter runs (hash randomization), which would
    # permute the feature columns and defeat the fixed random seed.
    # NOTE(review): assumes GO terms are mutually comparable (e.g. all strings).
    term_index_mapping = {term: index for index, term in enumerate(sorted(unique_terms))}

    term_embeddings = {}

    for record in dataset:
        embedding = np.zeros(len(term_index_mapping), dtype=np.int8)

        # Every term is guaranteed to be in the mapping, since the mapping was
        # built from these same records.
        for term in record["go_terms"]:
            embedding[term_index_mapping[term]] = 1

        term_embeddings[record["id"]] = embedding

    print(f"Finding dimensionality for {name}")

    # The multi-hot matrix does not depend on the candidate dimensionality,
    # so build it once rather than on every trial.
    x = np.stack(list(term_embeddings.values()))

    # Try candidates in increasing order and keep the first one that explains
    # enough variance. If none does, the last (largest) fitted SVD is used.
    for dimensionality in try_dimensions:
        svd = new_svd(n_components=dimensionality)

        svd.fit(x)

        explained_variance = np.sum(svd.explained_variance_ratio_)

        print(f"Dimensionality: {dimensionality}, Explained Variance: {explained_variance:.4f}")

        if explained_variance >= min_explained_variance:
            break

    z = svd.transform(x)

    # Replace each sample's multi-hot vector with its reduced embedding.
    # Rows of `z` are in the same order as the dict's keys.
    for sequence_id, embedding in zip(term_embeddings.keys(), z):
        term_embeddings[sequence_id] = embedding

    print("\n")

    print(f"Clustering {name}")

    gmm = new_gmm()

    # `z` already holds the reduced embeddings in key order, so there is no
    # need to re-stack them from the dict. Rebinding `x` also releases the
    # multi-hot matrix before the mixture model is fitted.
    x = z

    predictions = gmm.fit_predict(x)

    for sequence_id, stratum_id in zip(term_embeddings.keys(), predictions):
        stratum_ids[sequence_id] = stratum_id

        counter[stratum_id] += 1

    print("\n")

    plt.figure(figsize=(12, 5))

    plt.bar(counter.keys(), counter.values())

    plt.title(f"{name} Stratum Frequencies")
    plt.xticks(rotation=90)
    plt.xlabel("Stratum ID")
    plt.ylabel("Frequency")

    plt.show()

    print("\n")

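Lastly, add the stratum IDs to the dataset, drop any sample whose stratum has fewer than two members (a single-member stratum cannot be represented in both the train and test splits), and write the result to disk.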

In [None]:
import json

# Destination files for the stratified copy of each dataset.
all_stratified_path = "./dataset/all-stratified.jsonl"
mf_stratified_path = "./dataset/mf-stratified.jsonl"
bp_stratified_path = "./dataset/bp-stratified.jsonl"
cc_stratified_path = "./dataset/cc-stratified.jsonl"

# One entry per dataset: (input path, output path, sample -> stratum map, stratum sizes).
export_jobs = [
    (all_dataset_path, all_stratified_path, all_stratum_ids, all_counter),
    (mf_dataset_path, mf_stratified_path, mf_stratum_ids, mf_counter),
    (bp_dataset_path, bp_stratified_path, bp_stratum_ids, bp_counter),
    (cc_dataset_path, cc_stratified_path, cc_stratum_ids, cc_counter),
]

# Re-read each dataset, tag every sample with its stratum ID (as a string),
# and write it out as JSONL. Samples whose stratum has fewer than two members
# are skipped and counted.
for dataset_path, stratified_path, stratum_ids, counter in export_jobs:
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    num_dropped = 0

    with open(stratified_path, "w") as output_file:
        for row in dataset:
            stratum = stratum_ids[row["id"]]

            if counter[stratum] >= 2:
                tagged_row = {**row, "stratum_id": str(stratum)}

                output_file.write(f"{json.dumps(tagged_row)}\n")
            else:
                num_dropped += 1

    print(f"Dropped {num_dropped:,} noise samples from dataset.")

    print(f"Dataset saved to {stratified_path}")

print("Done!")