In [None]:
import pathlib
import os
import Bio.SeqIO
import numpy as np
from tqdm import tqdm
import datasets
import matplotlib.pyplot as plt
from rag_esm.utils.hamming import *
import torch

# Root directory of the downloaded dataset, asked interactively. `gen` builds
# every alignment file name by string-concatenating onto this value, and it
# silently skips clusters it cannot open, so a pasted trailing space or an
# unexpanded "~" would otherwise produce an empty dataset without any error.
# Surrounding whitespace is therefore stripped and a leading "~" is expanded.
path = os.path.expanduser(
    input("Path to the 'OpenProteinSet_uniclust30-filtered' dataset (downloaded from aws):").strip()
)

In [None]:
import torch.multiprocessing as mp
    
def gen(lst):
    """Yield one dataset record per cluster whose alignment can be read.

    For every cluster id in ``lst`` the file ``<path>/<id>/a3m/uniclust30.a3m``
    is parsed as FASTA (``path`` is the module-level dataset root, not an
    argument), the value following ``PE=`` is extracted from each record
    description, and the pairwise Hamming distance matrix of the sequences is
    computed with ``hamming_gpu``.

    Clusters whose file cannot be opened or parsed, whose ``PE=`` value is not
    an integer, or that contain no sequences are skipped without any message.

    Args:
        lst: Iterable of cluster ids (sub-directory names under ``path``).

    Yields:
        dict with the keys
            ``"cluster_id"``: the cluster id,
            ``"msa"``: ``np.ndarray`` of the sequence strings,
            ``"PE_scores"``: ``np.ndarray`` with one entry per sequence; the
                integer after ``PE=``, or ``None`` when the tag is missing or
                empty,
            ``"closest_sequences"``: for each sequence, the indices returned by
                ``torch.topk`` (smallest distances first) over its row of the
                distance matrix, with the first column removed.
    """
    # Below this much free GPU memory the distances are computed on the CPU.
    min_free_bytes = 10 * 1024**3

    for el in lst:
        try:
            # read all sequences from a3m file
            with open(path + f"/{el}/a3m/uniclust30.a3m", "r") as f:
                records = [(record.description, str(record.seq)) for record in Bio.SeqIO.parse(f, "fasta")]
            sequences = [seq for _, seq in records]
            # from each seq id get the PE score (protein existence)
            pes = []
            for description, _ in records:
                if "PE=" not in description:
                    pes.append(None)
                    continue
                token = description.split("PE=")[1].split(" ")[0]
                # int() raises on a non-numeric value, which discards the cluster
                pes.append(int(token) if token != "" else None)
        except Exception:
            # Best effort: an unreadable or malformed cluster is skipped.
            # KeyboardInterrupt / SystemExit are deliberately NOT caught, so a
            # long build can still be interrupted.
            continue
        if not sequences:
            # nothing to align against, go to next cluster
            continue

        # Use the second GPU only when it exists and has enough free memory.
        # The check is repeated per cluster because free memory changes while
        # other workers are running.
        device = "cpu"
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            free_memory, _ = torch.cuda.mem_get_info("cuda:1")
            if free_memory >= min_free_bytes:
                device = "cuda:1"

        # compute hamming distance between sequences
        hamming_distances = hamming_gpu(sequences, return_matrix=True, normalize=False, batch=100, device=device)
        # Keep the closest 10% of the cluster for every sequence; with fewer
        # than 10 sequences this is 0 and the neighbour lists are empty.
        num = len(sequences) // 10
        _, idx = torch.topk(hamming_distances, num, dim=1, largest=False, sorted=True)
        # The first column is dropped, presumably because it is the sequence
        # itself at distance 0 -- TODO confirm (an identical duplicate could
        # take that slot instead). .cpu() is a no-op for CPU tensors and is
        # required before .numpy() if the matrix is returned on the GPU.
        yield {"cluster_id": el,
               "msa": np.array(sequences),
               "PE_scores": np.array(pes),
               "closest_sequences": idx[:, 1:].cpu().numpy()}

def make_dataset_and_save_it_to_disk(path, debug=True, save_path="../../../data/example_dataset", list_test_ids=None):
    """
    Create a dataset from the given path and save it to disk.

    Every immediate sub-directory of ``path`` is treated as one cluster. The
    clusters are divided into train / validation / test, each split is built
    with ``datasets.Dataset.from_generator(gen, ...)`` and the three splits are
    saved together as a ``DatasetDict`` with the keys "train", "val", "test".

    Args:
        path: Directory containing one sub-directory per cluster.
        debug: Only used when ``list_test_ids`` is None. If true, the first 32
            clusters become the training set and the following two runs of
            ``test_size`` clusters become validation and test. Otherwise the
            last ``test_size`` clusters are the test set, the ``test_size``
            before them the validation set, and the rest the training set.
        save_path: Destination directory passed to ``DatasetDict.save_to_disk``.
        list_test_ids: Optional iterable of cluster ids. Those that exist under
            ``path`` form the test set; the remaining clusters are shuffled
            with ``np.random.shuffle`` and as many of them as there are test
            clusters are set aside for validation. Call ``np.random.seed``
            beforehand to make the split reproducible.

    Raises:
        ValueError: If the interactively entered test size is negative.

    Notes:
        * When ``list_test_ids`` is None the test size is read with ``input()``.
        * ``gen`` opens files relative to the module-level ``path`` variable,
          not this function's ``path`` argument, so the two must agree.
    """
    def get_folders(path):
        # names of all folders directly inside the given path
        return [entry.name for entry in pathlib.Path(path).iterdir() if entry.is_dir()]

    # get all clusters folders
    lst = get_folders(path)
    if list_test_ids is not None:
        known_clusters = set(lst)
        # get the test clusters as the overlap between the given list and the clusters
        test_clusters = set(x for x in list_test_ids if x in known_clusters)
        # Sorted so the order fed to the shuffle does not depend on set/hash
        # ordering; otherwise seeding numpy would not reproduce the split.
        remaining = sorted(x for x in known_clusters if x not in test_clusters)
        assert len(test_clusters.intersection(remaining)) == 0, "Test clusters are in train clusters!"
        # split the remaining clusters in train and validation clusters
        test_size = len(test_clusters)
        np.random.shuffle(remaining)
        # Explicit split index: slicing with -test_size breaks for
        # test_size == 0, where [:-0] is empty and [-0:] is the whole list.
        split_at = max(len(remaining) - test_size, 0)
        train_clusters, val_clusters = set(remaining[:split_at]), set(remaining[split_at:])
        assert len(train_clusters.intersection(val_clusters)) == 0, "Train and validation clusters overlap!"
        assert set(lst) == test_clusters.union(train_clusters.union(val_clusters)), "Some clusters are missing!"
        print("All tests passed!")
        print(f"Train clusters: {len(train_clusters)}, Validation clusters: {len(val_clusters)}, Test clusters: {len(test_clusters)}")
    else:
        test_size = int(input("Enter the test size as the number of clusters to use as test set: "))
        if test_size < 0:
            raise ValueError(f"test size must be non-negative, got {test_size}")
        if debug:
            train_clusters = lst[:32]
            val_clusters = lst[32:32 + test_size]
            test_clusters = lst[32 + test_size:32 + 2 * test_size]
        else:
            # Positive start indices instead of negative slices, so that a
            # test size of 0 leaves everything in the training set.
            test_start = max(len(lst) - test_size, 0)
            val_start = max(len(lst) - 2 * test_size, 0)
            train_clusters = lst[:val_start]
            val_clusters = lst[val_start:test_start]
            test_clusters = lst[test_start:]
    train_clusters, val_clusters, test_clusters = list(train_clusters), list(val_clusters), list(test_clusters)

    signature = datasets.Features({"cluster_id": datasets.Value("string"),
                                   "msa": datasets.Sequence(datasets.Value("string")),
                                   "PE_scores": datasets.Sequence(datasets.Value("int64")),
                                   "closest_sequences": datasets.Sequence(datasets.Sequence(datasets.Value("int64")))})

    def build_split(clusters):
        # At most 80 workers and never more than there are clusters; an empty
        # split falls back to the library default (None) instead of passing 0.
        return datasets.Dataset.from_generator(gen,
                                               features=signature,
                                               gen_kwargs={"lst": clusters},
                                               num_proc=min(80, len(clusters)) or None
                                               )

    print("Creating dataset from generator...")
    ds_train = build_split(train_clusters)
    ds_val = build_split(val_clusters)
    ds_test = build_split(test_clusters)
    print("Length of train dataset: ", len(ds_train), "Length of validation dataset: ", len(ds_val), "Length of test dataset: ", len(ds_test))
    print("Done!")

    # make a DatasetDict
    dataset_dict = datasets.DatasetDict({"train": ds_train, "val": ds_val, "test": ds_test})
    # Save it to disk
    dataset_dict.save_to_disk(save_path, max_shard_size="1GB")

if __name__ == "__main__":
    # Force the "spawn" start method for the worker processes, overriding any
    # method that was selected earlier in this kernel.
    mp.set_start_method("spawn", force=True)
    # No predefined test ids: the splits are taken by position and the test
    # size is asked interactively.
    list_test_ids = None
    # NOTE(review): debug=True restricts the training set to the first 32
    # clusters, yet the output goes to the full-dataset path -- confirm this is
    # intended before relying on the saved data.
    make_dataset_and_save_it_to_disk(
        path,
        list_test_ids=list_test_ids,
        save_path="../../../data/OpenProteinSet_uniclust30-filtered_rag-esm",
        debug=True,
    )

In [None]:
# load dataset
import datasets

# Reopen the DatasetDict saved under the save_path used in the build cell
# and keep only its "train" split.
ds = datasets.load_from_disk("../../../data/OpenProteinSet_uniclust30-filtered_rag-esm")["train"]