# Guide to train-test splits in EUGENe
David Laub (last updated: *09/20/2023*)
***
**Description:**
This notebook is a guide to performing train-test splits in EUGENe. We currently offer three ways to split your data into train and test sets.
- **Random split:** This is the simplest way to split your data. You specify the fraction of your data to hold out for testing (`test_size`) and the rest is used for training.
- **Chromosome split:** This method splits your data by chromosome. You specify the chromosomes to hold out for testing (`test_chroms`); sequences on all other chromosomes are used for training.
- **Sequence homology split:** This method splits your data by sequence homology. We use the graph-part package for this.
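
To make the first two strategies concrete, here is a minimal pure-Python sketch on a toy list of records. It is not EUGENe's implementation: the records, the `toy_*` helper names, and the convention that `True` marks a training example are all assumptions made for illustration.

```python
import random

# Hypothetical toy records: (sequence id, chromosome), 5 sequences per chromosome
records = [(f"seq{i}", f"chr{1 + i % 4}") for i in range(20)]


def toy_random_split(n, test_size=0.1, seed=13):
    """Return a list of booleans (True = train) with round(n * test_size) test items."""
    n_test = round(n * test_size)
    test_idx = set(random.Random(seed).sample(range(n), n_test))
    return [i not in test_idx for i in range(n)]


def toy_chrom_split(chroms, test_chroms):
    """True = train; every sequence on a held-out chromosome goes to test."""
    held_out = set(test_chroms)
    return [c not in held_out for c in chroms]


random_labels = toy_random_split(len(records), test_size=0.1)
chrom_labels = toy_chrom_split([c for _, c in records], ["chr3", "chr4"])

print(sum(random_labels), "train /", random_labels.count(False), "test (random)")
print(sum(chrom_labels), "train /", chrom_labels.count(False), "test (chromosome)")
```

The random split controls only the size of the test set, whereas the chromosome split controls which genomic regions are held out, so its test fraction depends on how the sequences are distributed across chromosomes.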

# Set-up

In [None]:
import numpy as np
import seqdatasets
from eugene.preprocess import (
    train_test_random_split,
    train_test_chrom_split,
    train_test_homology_split,
)

# Load data

In [None]:
sdata = seqdatasets.kopp21(task="binary", split="test")
sdata

# Random split

In [None]:
train_test_random_split(
    sdata=sdata,
    dim="_sequence",
    train_var="random_split",
    test_size=0.1,
    random_state=13
)

In [None]:
np.unique(sdata["random_split"].values, return_counts=True)

# Splitting on chromosomes

In [None]:
train_test_chrom_split(
    sdata=sdata,
    test_chroms=["chr3", "chr4"],
    train_var="chrom_split",
)

In [None]:
np.unique(sdata["chrom_split"].values, return_counts=True)
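
The split functions above store their result as a variable on `sdata` rather than returning two datasets, which is why we inspect it with `np.unique`. Below is a minimal sketch of turning such a label into separate train and test datasets. It assumes the variable is boolean with `True` marking training sequences, which is worth checking against the `np.unique` output. The `toy` dataset is a hypothetical stand-in for `sdata`.

```python
import numpy as np
import xarray as xr

# Hypothetical stand-in for `sdata`: 6 sequences with a boolean split label
toy = xr.Dataset(
    {
        "seq": ("_sequence", np.array(["ACGT", "TTGA", "CCAT", "GGTA", "ATAT", "CGCG"])),
        "chrom_split": ("_sequence", np.array([True, True, False, True, False, True])),
    }
)

# Materialise the label as a NumPy boolean mask (assumed: True = train)
is_train = toy["chrom_split"].values

# Boolean indexing along the sequence dimension yields the two subsets
train = toy.isel(_sequence=is_train)
test = toy.isel(_sequence=~is_train)

print(train.sizes["_sequence"], test.sizes["_sequence"])
```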

# Splitting on sequence homology

For the sake of time, we use only the first 10,000 sequences in this tutorial. Splitting 10,000 sequences by homology takes ~6 minutes.

In [None]:
sdata_10k = sdata.isel(_sequence=slice(0, int(1e4)))

In [None]:
train_test_homology_split(
    sdata=sdata_10k,
    seq_var="seq",
    train_var="homology_split",
    test_size=0.1,
    nucleotide=True
)

In [None]:
np.unique(sdata_10k["homology_split"].values, return_counts=True)
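
For intuition about what a homology-aware split guarantees, here is a deliberately simplified toy sketch. It is not graph-part's algorithm, whose actual procedure is more involved. It clusters equal-length sequences by per-position identity and then assigns whole clusters to the test set. The function names, the identity measure, and the threshold are assumptions chosen for illustration.

```python
def identity(a, b):
    """Fraction of matching positions between two equal-length sequences."""
    return sum(x == y for x, y in zip(a, b)) / len(a)


def toy_homology_split(seqs, threshold=0.75, test_size=0.25):
    """Cluster sequences with identity >= threshold (single linkage via union-find),
    then move whole clusters into the test set without exceeding the target size.
    Returns (train_idx, test_idx)."""
    parent = list(range(len(seqs)))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    # All-vs-all comparison: this pairwise loop is quadratic in the number of sequences
    for i in range(len(seqs)):
        for j in range(i + 1, len(seqs)):
            if identity(seqs[i], seqs[j]) >= threshold:
                parent[find(i)] = find(j)

    clusters = {}
    for i in range(len(seqs)):
        clusters.setdefault(find(i), []).append(i)

    # Fill the test set with the smallest clusters first, never splitting a cluster
    n_test_target = round(len(seqs) * test_size)
    test_idx = []
    for members in sorted(clusters.values(), key=len):
        if len(test_idx) + len(members) <= n_test_target:
            test_idx.extend(members)

    test_set = set(test_idx)
    train_idx = [i for i in range(len(seqs)) if i not in test_set]
    return train_idx, sorted(test_idx)


seqs = ["AAAAAAAA", "AAAAAAAT", "CCCCCCCC", "CCCCCCCG",
        "GGGGGGGG", "GTGTGTGT", "TTTTTTTT", "TATATATA"]
train_idx, test_idx = toy_homology_split(seqs)
print("train:", train_idx, "test:", test_idx)
```

Because clusters are never split, no test sequence has a training sequence at or above the identity threshold, which is the property a random split cannot provide.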

# DONE!

---

# Scratch

In [None]:
import seqdata as sd
import xarray as xr
import numpy as np
import dask.array as da
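import dask_ml as dml
import dask_ml.model_selection  # load the submodule explicitly so that dml.model_selection resolves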
from graph_part import train_test_validation_split

In [None]:
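def train_test_split_chrom(sdata: xr.Dataset, test_chroms: list[str], train_var=None):
    # NOTE: train_var is currently unused; this scratch version returns the two subsets directly
    # Boolean mask over sequences: True where the sequence lies on a held-out chromosome
    test_mask = sdata.chrom.isin(test_chroms).compute()
    # Return (train, test): everything off the held-out chromosomes, then everything on them
    return sdata.sel(_sequence=~test_mask), sdata.sel(_sequence=test_mask)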

In [None]:
splits = train_test_split_chrom(sdata, ['chr2', 'chr3'])
[s.sizes['_sequence'] for s in splits]

In [None]:
def train_test_split_random(sdata: xr.Dataset, dim: str, groups=None, test_size=0.1, random_state=None):
    splitter = dml.model_selection.ShuffleSplit(
        n_splits=1,
        test_size=test_size,
        random_state=random_state
    )
    train_idx, test_idx = next(splitter.split(da.arange(sdata.sizes[dim]), groups=groups))
    return sdata.isel({dim: train_idx}), sdata.isel({dim: test_idx})

In [None]:
def train_test_split_homology(sdata: xr.Dataset, seq_var: str, test_size=0.1, nucleotide=True):
    seq_length = sdata.sizes[sdata.attrs['length_dim']]
    outs = train_test_validation_split(
        sdata[seq_var].to_numpy().view(f'S{seq_length}').squeeze().astype('U').astype(object),
        test_size=test_size,
        initialization_mode='fast-nn',
        nucleotide=nucleotide,
        prefilter=True,
        denominator='shortest'
    )
    train_idx, test_idx = map(np.array, outs)
    return sdata.isel({sdata.attrs['sequence_dim']: train_idx}), sdata.isel({sdata.attrs['sequence_dim']: test_idx})

In [None]:
train, test = train_test_split_homology(sdata.isel(_sequence=slice(0, int(1e4))), 'seq')

Note: we don't demonstrate this on more than 100,000 sequences, since even that takes ~3.5 hours to run. The homology graph-partitioning algorithm needs approximately $O(n^2)$ time, so going beyond this number of sequences would be intractable for the purposes of a tutorial.
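
For a rough sense of scale, the sketch below extrapolates from the two timings quoted in this notebook (~6 minutes for 10,000 sequences and ~3.5 hours for 100,000). The helper `extrapolate_runtime` is hypothetical, and the timings depend on the machine, so treat the numbers as order-of-magnitude estimates only.

```python
import math


def extrapolate_runtime(n, n_ref, t_ref_minutes, exponent=2.0):
    """Scale a reference timing to n sequences, assuming runtime ~ n ** exponent."""
    return t_ref_minutes * (n / n_ref) ** exponent


# Timings quoted in this notebook
n_small, t_small = 10_000, 6.0         # ~6 minutes
n_large, t_large = 100_000, 3.5 * 60   # ~3.5 hours, in minutes

# Exponent implied by the two quoted timings
observed_exponent = math.log(t_large / t_small) / math.log(n_large / n_small)

# Projected cost of 1,000,000 sequences under strictly quadratic scaling
quadratic_hours = extrapolate_runtime(1_000_000, n_large, t_large) / 60

print(f"implied exponent: {observed_exponent:.2f}")
print(f"1e6 sequences at exponent 2: ~{quadratic_hours:.0f} hours")
```

The two quoted timings imply an exponent somewhat below 2, but two data points are a rough guide at best. Either way, each tenfold increase in sequences costs far more than tenfold in runtime.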

In [None]:
train, test = train_test_split_homology(sdata.isel(_sequence=slice(0, int(1e5))), 'seq')