# How to split a dataset
The dataset is split based on transitive sequence identity. This means that sequences in the test set will have a low identity to the sequences in the training set. The amount of identity tolerated between the two sets is controlled by the `threshold` parameter.

The mechanism works by building a graph where each node is a sequence and an edge is drawn between two nodes if their sequence identity is above the threshold. The Leiden algorithm is then used to detect communities in the graph, which gives clusters of sequences. The test set is built by sampling whole clusters.

To obtain a perfect split, each cluster would have to be a connected component, disconnected from the rest of the graph. This is usually not the case, since a dataset typically contains one huge connected component and many small ones, which is why the community detection algorithm is needed. However, it is possible to obtain a near-perfect split by removing bridge sequences, i.e. sequences that could belong to multiple clusters. They are found and removed automatically when the `post_filtering` argument is enabled.
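
The graph construction and the idea of bridge sequences can be illustrated with a small, standard-library-only sketch. It is not the implementation used by `train_test_split`: `toy_identity` is a hypothetical stand-in for the learned identity model, the cluster assignment is written by hand instead of coming from the Leiden algorithm, and the bridge criterion (a node with a neighbor in another cluster) is an assumption made for illustration.

```python
from itertools import combinations


def toy_identity(a, b):
    """Hypothetical identity score: fraction of matching positions."""
    matches = sum(x == y for x, y in zip(a, b))
    return matches / max(len(a), len(b))


def build_identity_graph(sequences, threshold):
    """Adjacency sets; an edge links two sequences whose identity exceeds the threshold."""
    graph = {i: set() for i in range(len(sequences))}
    for i, j in combinations(range(len(sequences)), 2):
        if toy_identity(sequences[i], sequences[j]) > threshold:
            graph[i].add(j)
            graph[j].add(i)
    return graph


def connected_components(graph):
    """Depth-first search over the graph, returning sorted node lists."""
    seen, components = set(), []
    for start in graph:
        if start in seen:
            continue
        stack, component = [start], set()
        while stack:
            node = stack.pop()
            if node in component:
                continue
            component.add(node)
            stack.extend(graph[node] - component)
        seen |= component
        components.append(sorted(component))
    return components


def find_bridges(graph, cluster_of):
    """Assumed criterion: nodes with at least one neighbor in a different cluster."""
    return sorted(
        node for node, neighbors in graph.items()
        if any(cluster_of[other] != cluster_of[node] for other in neighbors)
    )


sequences = ["AAAAAAAA", "AAAAAACC", "AAAACCCC", "AACCCCCC", "GGGGGGGG"]
graph = build_identity_graph(sequences, threshold=0.6)
components = connected_components(graph)

# Hand-written cluster assignment standing in for community detection:
# the chain 0-1-2-3 is cut between nodes 1 and 2.
cluster_of = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2}
bridges = find_bridges(graph, cluster_of)
```

In this toy example the first four sequences form a single chain, so they end up in one connected component even though the first and last share little identity. Cutting the chain into two clusters leaves an edge crossing the cut, and the sequences at its ends are the ones flagged as bridges.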

There are multiple ways to build the test clusters:
- Random: Randomly sample clusters until the desired test size is reached. This method produces a test set that is in-distribution with respect to the training set. *This was found empirically with a linear model based on ESM2 embeddings.*
- Maximize: Sample the smallest clusters first to maximize the number of clusters in the test set, and thus its diversity. Note that this method leads to a test set that is out-of-distribution with respect to the training set. *This was found empirically with a linear model based on ESM2 embeddings.*
- Probabilistic: Sample clusters from a probability distribution that favors smaller clusters but still allows larger clusters to be sampled. This is a smooth interpolation between the two previous methods. You can control the smoothness of the distribution with the `temperature` parameter. A temperature of 0 gives the same results as the maximize method, while a temperature of `inf` gives the same results as the random method.
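
The relationship between the three strategies can be sketched with a single temperature-controlled sampler. This is only an illustration under stated assumptions: the weighting `exp(-size / temperature)` is a hypothetical choice that reproduces the two limits described above, sampling stops as soon as the target size is reached or exceeded, and the function names are not part of the library.

```python
import math
import random


def cluster_weights(sizes, temperature):
    """Assumed weighting: smaller clusters get larger weights, flattened by temperature."""
    if math.isinf(temperature):
        return [1.0] * len(sizes)  # uniform weights, i.e. the random method
    smallest = min(sizes)
    if temperature == 0:
        # Only the smallest clusters can be picked, i.e. the maximize method
        return [1.0 if size == smallest else 0.0 for size in sizes]
    # Subtracting the smallest size keeps the largest weight at exactly 1.0
    return [math.exp(-(size - smallest) / temperature) for size in sizes]


def sample_test_clusters(sizes, test_size, temperature, seed=0):
    """Draw clusters without replacement until the test set reaches the target size."""
    rng = random.Random(seed)
    target = test_size * sum(sizes)
    remaining = list(range(len(sizes)))
    chosen, covered = [], 0
    while remaining and covered < target:
        weights = cluster_weights([sizes[i] for i in remaining], temperature)
        pick = rng.choices(remaining, weights=weights, k=1)[0]
        remaining.remove(pick)
        chosen.append(pick)
        covered += sizes[pick]
    return chosen


sizes = [50, 3, 1, 2, 10, 4]  # number of sequences in each toy cluster
maximize_like = sample_test_clusters(sizes, test_size=0.1, temperature=0)
random_like = sample_test_clusters(sizes, test_size=0.1, temperature=math.inf)
smooth = sample_test_clusters(sizes, test_size=0.1, temperature=2.0)
```

With a temperature of 0 the sampler always takes the smallest remaining cluster, so the test set contains as many clusters as possible. With an infinite temperature every remaining cluster is equally likely, so a single large cluster may fill the whole test set.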



Note:
The split is not perfect because a deep learning model is used to estimate the identity between sequences, and that model is itself imperfect. However, it is usually good enough to train a model that generalizes well to unseen sequences.

In [None]:
from qmap.toolkit.split import train_test_split

# Imports for the example
import json


# Step 1: Load the training dataset.
# For this example, we load the DBAASP dataset, which is assumed to be already downloaded to the data/build folder
with open('../../../data/build/dataset.json', 'r') as f:
    dataset = json.load(f)
    # Filter out sequences that are too long, because the aligner only supports sequences up to 100 amino acids long
    dataset = [sample for sample in dataset if len(sample["Sequence"]) < 100]

# Step 2: Split the dataset into train and validation sets.
sequences = [sample['Sequence'] for sample in dataset]
train_sequences, val_sequences, train_samples, val_samples = train_test_split(sequences, dataset,
                                                                              test_size=0.15, post_filtering=True)

In [None]:
# If you only want to split the sequences, you can also do this by only passing the sequences to the function:
train_sequences, val_sequences = train_test_split(sequences, test_size=0.15, post_filtering=True)