# Introduction

This notebook shows how to use the splitters in the ```multievolve``` package.

In [13]:
from multievolve.splitters import *

## Setting up

First, define the following variables, including a structure of the protein of interest:

- ```protein_name```: the name of the protein

- ```wt_file```: the path to the wildtype sequence

- ```training_dataset_fname```: the path to the training dataset

- ```structure_file```: the path to the structure file, either .pdb or .cif

In [14]:
protein_name = "example_protein"
wt_file = "../../data/example_protein/apex.fasta"
training_dataset_fname = '../../data/example_protein/example_dataset.csv'
structure_file = "../../data/example_protein/apex.cif"

## Refresher

As previously mentioned, each splitter has the following parameters:

- ```protein_name```: the name of the protein

- ```training_dataset_fname```: the path to the training dataset

- ```wt_file```: the path to the wildtype sequence

- ```csv_has_header```: whether the CSV has a header

- ```use_cache```: whether to cache the processed dataset for later use (default: ```False```)

- ```y_scaling```: whether to scale the property values between 0 and 1 (default: ```False```)

- ```val_split```: the proportion of the dataset to include in the validation set (default: ```None```). The validation set is only used when training neural network models.
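
To make the ```y_scaling``` option concrete: one common way to scale values between 0 and 1 is min-max scaling. The sketch below illustrates that idea with a hypothetical helper, ```min_max_scale```; it is an assumption about the general technique and is not taken from the package itself.

```python
def min_max_scale(values):
    """Linearly rescale values so the minimum maps to 0.0 and the maximum to 1.0."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # All values are identical, so there is no range to scale over.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


scaled = min_max_scale([2.0, 4.0, 6.0])
print(scaled)  # [0.0, 0.5, 1.0]
```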

### KFoldProteinSplitter

```KFoldProteinSplitter```: Performs k-fold cross-validation by randomly splitting data into k folds.

In [15]:
kfold_splitter = KFoldProteinSplitter(protein_name, training_dataset_fname, wt_file, csv_has_header=True, use_cache=False, y_scaling=False, val_split=None)

Unlike the other Splitters, where we call the ```split_data()``` method, here we obtain the processed datasets by running ```kfold_splitter.generate_splits(n_splits=5)```, where ```n_splits``` is the number of folds. In this case, we perform 5-fold cross-validation.

This returns a list of ```n_splits``` splitter objects, each using a different fold as the test set.

In [16]:
splits = kfold_splitter.generate_splits(n_splits=5)
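
For intuition, the fold assignment behind k-fold cross-validation can be sketched in plain Python. The helper ```k_fold_indices``` below is hypothetical and only illustrates the general technique (shuffle, partition into k folds, hold out one fold at a time); the package may implement it differently.

```python
import random


def k_fold_indices(n_items, n_splits, seed=0):
    """Return (train_idx, test_idx) pairs; each item lands in exactly one test fold."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices into n_splits folds of near-equal size.
    folds = [indices[i::n_splits] for i in range(n_splits)]
    pairs = []
    for i, test_idx in enumerate(folds):
        train_idx = [j for k, fold in enumerate(folds) if k != i for j in fold]
        pairs.append((sorted(train_idx), sorted(test_idx)))
    return pairs


fold_pairs = k_fold_indices(n_items=10, n_splits=5)
for train_idx, test_idx in fold_pairs:
    print(len(train_idx), len(test_idx))  # 8 2 for every fold
```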

Again, if you check the ```splits``` attribute of one of the splitter objects, then you will see that the dataset has been split into training, validation, and test sets in the form of a dictionary.

In [None]:
splits[0].splits.keys()

### RoundProteinSplitter

```RoundProteinSplitter```: Splits data based on evolution rounds, allowing training on early rounds and testing on later rounds.

In [18]:
round_splitter = RoundProteinSplitter(protein_name, training_dataset_fname, wt_file, csv_has_header=True, use_cache=False, y_scaling=False, val_split=None)

For ```RoundProteinSplitter```, ```split_data()``` has the following arguments:
- ```max_train_round```: the maximum round number to include in the training set
- ```min_test_round```: the minimum round number to include in the test set

In [19]:
round_splitter.split_data(max_train_round=0, min_test_round=1)
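
The effect of the two round arguments can be sketched on toy data. The records and the helper ```split_by_round``` are hypothetical, and the sketch assumes both bounds are inclusive, which the text above suggests but does not state outright.

```python
# Hypothetical records: (mutation string, evolution round, property value).
variants = [
    ("A10G", 0, 1.2),
    ("F32Y", 0, 0.8),
    ("A10G/F32Y", 1, 1.9),
    ("E61Y", 2, 0.5),
]


def split_by_round(records, max_train_round, min_test_round):
    """Train on rounds <= max_train_round, test on rounds >= min_test_round."""
    train = [r for r in records if r[1] <= max_train_round]
    test = [r for r in records if r[1] >= min_test_round]
    return train, test


train, test = split_by_round(variants, max_train_round=0, min_test_round=1)
print([r[0] for r in train])  # ['A10G', 'F32Y']
print([r[0] for r in test])   # ['A10G/F32Y', 'E61Y']
```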

### RandomProteinSplitter

```RandomProteinSplitter```: Randomly splits data into training and test sets with a specified test size.

In [20]:
random_splitter = RandomProteinSplitter(protein_name, training_dataset_fname, wt_file, csv_has_header=True, use_cache=False, y_scaling=False, val_split=None)

For ```RandomProteinSplitter```, ```split_data()``` has the following arguments:
- ```test_size```: the proportion of the dataset to include in the test set

In [21]:
random_splitter.split_data(test_size=0.2)
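
As a rough illustration of what ```test_size``` means, the sketch below shuffles a toy dataset and holds out the requested proportion. The helper ```random_split``` and its ```seed``` argument are hypothetical and not part of the package.

```python
import random


def random_split(items, test_size, seed=0):
    """Shuffle the items and hold out a test_size proportion as the test set."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_size)
    return shuffled[n_test:], shuffled[:n_test]


train, test = random_split(range(10), test_size=0.2)
print(len(train), len(test))  # 8 2
```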

### PositionProteinSplitter

```PositionProteinSplitter```: Splits based on mutation positions - variants with mutations at certain positions go to test set.

In [22]:
position_splitter = PositionProteinSplitter(protein_name, training_dataset_fname, wt_file, csv_has_header=True, use_cache=False, y_scaling=False, val_split=None)

For ```PositionProteinSplitter```, ```split_data()``` has the following arguments:
- ```test_size_sample```: the proportion of the dataset to sample to get mutation positions to exclude from the training set
- ```iter```: the number of iterations to perform to get a test set size between ```test_size_min``` and ```test_size_max```
- ```test_size_min```: the minimum test set size desired
- ```test_size_max```: the maximum test set size allowed

When splitting the data, the splitter samples random mutations to get the mutation positions to exclude from the training set. It attempts to reach a test set size within the specified range, and repeats the sampling for the specified number of iterations if the test set size is not within that range.

In [None]:
position_splitter.split_data(test_size_sample=0.2, iter=3, test_size_min=0.2, test_size_max=0.3)
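
The sample-and-retry procedure described above can be sketched as follows. Everything here is an illustrative assumption rather than the package's code: the helpers ```mutation_positions``` and ```position_split``` are hypothetical, mutations within one variant are assumed to be separated by ```/```, and the iteration count is named ```n_iter``` to avoid shadowing Python's built-in ```iter```.

```python
import random
import re


def mutation_positions(mutation_string, sep="/"):
    """Return the set of positions in a string like 'A10G/F32Y'; 'WT' has none."""
    if mutation_string == "WT":
        return set()
    return {int(re.search(r"\d+", m).group()) for m in mutation_string.split(sep)}


def position_split(variants, test_size_sample, n_iter, test_size_min, test_size_max, seed=0):
    """Hold out sampled positions; retry until the test fraction is in range."""
    rng = random.Random(seed)
    n_sample = max(1, round(len(variants) * test_size_sample))
    for _ in range(n_iter):
        sampled = rng.sample(variants, n_sample)
        held_out = set().union(*(mutation_positions(v) for v in sampled))
        # Any variant touching a held-out position goes to the test set.
        test = [v for v in variants if mutation_positions(v) & held_out]
        train = [v for v in variants if not mutation_positions(v) & held_out]
        if test_size_min <= len(test) / len(variants) <= test_size_max:
            return train, test, held_out
    return None  # no attempt landed inside the requested range


variants = ["A10G", "C20D", "E30F", "G40H", "I50K", "L60M", "N70P", "Q80R", "S90T", "V100W"]
result = position_split(variants, test_size_sample=0.2, n_iter=3,
                        test_size_min=0.2, test_size_max=0.3)
```

With ten single-mutant variants at distinct positions, sampling 20% of them always holds out two positions, so the first attempt already falls inside the requested range.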

### RegionProteinSplitter

```RegionProteinSplitter```: Splits based on protein regions - variants with mutations in specified regions go to test set.

In [24]:
region_splitter = RegionProteinSplitter(protein_name, training_dataset_fname, wt_file, csv_has_header=True, use_cache=False, y_scaling=False, val_split=None)

For ```RegionProteinSplitter```, ```split_data()``` has the following arguments:
- ```region```: a list of two numbers defining the minimum and maximum positions to include in the test set (e.g. [1, 60])

In [25]:
region_splitter.split_data(region=[1, 60])
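
A minimal sketch of a region-based split on toy mutation strings is shown below. It assumes that a variant goes to the test set if any of its mutations falls inside the region, that the bounds are inclusive, and that mutations within a variant are separated by ```/```; the helpers are hypothetical.

```python
import re


def mutation_positions(mutation_string, sep="/"):
    """Return the set of positions in a string like 'A10G/F32Y'; 'WT' has none."""
    if mutation_string == "WT":
        return set()
    return {int(re.search(r"\d+", m).group()) for m in mutation_string.split(sep)}


def region_split(variants, region):
    """Send variants with any mutation inside [start, end] to the test set."""
    start, end = region
    train, test = [], []
    for v in variants:
        hits_region = any(start <= p <= end for p in mutation_positions(v))
        (test if hits_region else train).append(v)
    return train, test


variants = ["WT", "A10G", "F32Y", "K75R", "A10G/K75R"]
train, test = region_split(variants, region=[1, 60])
print(train)  # ['WT', 'K75R']
print(test)   # ['A10G', 'F32Y', 'A10G/K75R']
```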

### PropertyProteinSplitter

```PropertyProteinSplitter```: Splits based on property values - can separate high/low performing variants.

In [26]:
property_splitter = PropertyProteinSplitter(protein_name, training_dataset_fname, wt_file, csv_has_header=True, use_cache=False, y_scaling=False, val_split=None)

For ```PropertyProteinSplitter```, ```split_data()``` has the following arguments:
- ```property```: the property value to split on
- ```above_or_below```: either 'above' or 'below'; whether variants with values above or below the given property value are held out in the test set

In [27]:
property_splitter.split_data(property=1, above_or_below='above')
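
The thresholding idea can be sketched on toy data. The helper ```property_split``` is hypothetical, and the sketch assumes strict comparisons (a value equal to the threshold stays in the training set), which the description above does not specify.

```python
def property_split(records, threshold, above_or_below):
    """Hold out variants whose value is above (or below) the threshold."""
    if above_or_below not in ("above", "below"):
        raise ValueError("above_or_below must be 'above' or 'below'")
    train, test = [], []
    for name, value in records:
        if above_or_below == "above":
            held_out = value > threshold
        else:
            held_out = value < threshold
        (test if held_out else train).append(name)
    return train, test


# Hypothetical records: (mutation string, property value).
records = [("A10G", 0.4), ("F32Y", 1.0), ("E61Y", 1.7)]
train, test = property_split(records, threshold=1, above_or_below="above")
print(train)  # ['A10G', 'F32Y']
print(test)   # ['E61Y']
```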

### MutLoadProteinSplitter

```MutLoadProteinSplitter```: Splits based on number of mutations - can train on low mutation count variants and test on higher ones.


In [28]:
mutload_splitter = MutLoadProteinSplitter(protein_name, training_dataset_fname, wt_file, csv_has_header=True, use_cache=False, y_scaling=False, val_split=None)

For ```MutLoadProteinSplitter```, ```split_data()``` has the following arguments:
- ```max_train_muts```: the maximum mutation load to include in the training set
- ```min_test_muts```: the minimum mutation load to include in the test set

In [29]:
mutload_splitter.split_data(max_train_muts=1, min_test_muts=2)
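
To illustrate the mutation-load split, the sketch below counts mutations per variant and applies the two bounds. The helpers are hypothetical, both bounds are assumed to be inclusive, and mutations within a variant are assumed to be separated by ```/```.

```python
def mutation_load(mutation_string, sep="/"):
    """Count mutations in a string like 'A10G/F32Y'; 'WT' counts as zero."""
    return 0 if mutation_string == "WT" else len(mutation_string.split(sep))


def mutload_split(variants, max_train_muts, min_test_muts):
    """Train on low-load variants, test on high-load variants."""
    train = [v for v in variants if mutation_load(v) <= max_train_muts]
    test = [v for v in variants if mutation_load(v) >= min_test_muts]
    return train, test


variants = ["WT", "A10G", "A10G/F32Y", "A10G/F32Y/E61Y"]
train, test = mutload_split(variants, max_train_muts=1, min_test_muts=2)
print(train)  # ['WT', 'A10G']
print(test)   # ['A10G/F32Y', 'A10G/F32Y/E61Y']
```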

### ResidueDistanceSplitter

```ResidueDistanceSplitter```: Splits based on 3D distances between mutations using protein structure.

When initializing ```ResidueDistanceSplitter```, we need to specify the additional arguments:
- ```pdb_file```: the path to the PDB/CIF structure file
- ```chain_ids```: the chain IDs of the protein of interest

In [30]:
residue_distance_splitter = ResidueDistanceSplitter(protein_name, training_dataset_fname, wt_file, csv_has_header=True, use_cache=False, y_scaling=False, val_split=None,
                                                    pdb_file=structure_file, chain_ids=['A'])

For ```ResidueDistanceSplitter```, ```split_data()``` has the following arguments:
- ```percentile_threshold```: the percentile threshold for the distance to include in the training set
- ```min_test_muts```: the minimum number of mutations to include in the test set
- ```max_train_muts```: the maximum number of mutations to include in the training set
- ```randomized_control```: whether to randomize the distance dictionary as a control (default: ```False```)

When splitting the data, the splitter calculates the distance percentile for each variant within its mutational load group. Variants with a mutational load less than or equal to ```max_train_muts``` are considered for the training set. The splitter then splits the data based on the distance percentile: variants with distances less than or equal to the percentile threshold go to the training set, and variants with distances greater than the percentile threshold go to the test set. All variants with a mutational load higher than or equal to ```min_test_muts``` go to the test set.

In [31]:
residue_distance_splitter.split_data(
        percentile_threshold=50,  
        min_test_muts=5,         
        max_train_muts=2,
        randomized_control=False
    )
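
The percentile logic can be sketched on precomputed toy distances. This is only one reading of the description above, under several assumptions: the records and helpers are hypothetical, the percentile is a simple rank within each mutational load group, high-load variants (at or above ```min_test_muts```) always go to the test set, and variants whose load falls strictly between the two bounds are left out of both sets.

```python
from collections import defaultdict

# Hypothetical records: (variant name, mutational load, distance between mutated residues).
records = [
    ("v1", 2, 4.0),
    ("v2", 2, 8.0),
    ("v3", 2, 12.0),
    ("v4", 2, 20.0),
    ("v5", 3, 9.0),
    ("v6", 5, 15.0),
]


def percentile_rank(value, group):
    """Percentage of values in the group that are <= value."""
    return 100.0 * sum(1 for g in group if g <= value) / len(group)


def distance_split(records, percentile_threshold, min_test_muts, max_train_muts):
    """Split by distance percentile within each mutational load group."""
    by_load = defaultdict(list)
    for _, load, dist in records:
        by_load[load].append(dist)

    train, test = [], []
    for name, load, dist in records:
        if load >= min_test_muts:
            test.append(name)
        elif load <= max_train_muts:
            rank = percentile_rank(dist, by_load[load])
            (train if rank <= percentile_threshold else test).append(name)
        # Loads strictly between the two bounds are left out of both sets here.
    return train, test


train, test = distance_split(records, percentile_threshold=50,
                             min_test_muts=5, max_train_muts=2)
print(train)  # ['v1', 'v2']
print(test)   # ['v3', 'v4', 'v6']
```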

# Multi-chain proteins

If you are working with multi-chain proteins, such as antibodies that have a heavy variable domain and a light variable domain, the Splitters can accept both chains.



## Setting up

With multi-chain proteins, you need to:
- Specify the wild-type sequences for each chain in the ```wt_files``` argument as a list.

In [32]:
protein_name_multichain = "example_multichain_protein"
wt_files = ['../../data/example_multichain_protein/vh_chain1.fasta', '../../data/example_multichain_protein/vl_chain2.fasta']
training_dataset_fname_multichain = '../../data/example_multichain_protein/example_dataset.csv'
structure_file_multichain = 'multichain_protein.cif'
chain_ids = ['A', 'B']

### Formatting datasets for multi-chain proteins
For multi-chain datasets, the mutation strings for each chain are separated by a colon (e.g. ```F32Y:E61Y```). If a variant is wild-type for both chains, then the mutation string should be ```WT:WT```. If a variant is wild-type for one chain and has a mutation for the other chain, then the mutation string should be ```WT:F32Y```.

In [None]:
import pandas as pd  # needed for read_csv if not already imported

df = pd.read_csv(training_dataset_fname_multichain)
df.head()

### Example with ResidueDistanceSplitter

For the ```ResidueDistanceSplitter```, the arguments should be as follows:
- Specify the chain IDs in the ```chain_ids``` argument as a list.
- Specify the structure file in the ```pdb_file``` argument. This should be one structure containing all chains.

In [None]:
split = ResidueDistanceSplitter(protein_name_multichain, training_dataset_fname_multichain, wt_files, csv_has_header=True, use_cache=False, 
                 y_scaling=False,
                 val_split=None,
                 pdb_file=structure_file_multichain,
                 chain_ids=chain_ids,
                 random_state=0)

For multi-chain proteins, the Splitter concatenates the sequences of each chain to get the full sequence. It then automatically adjusts the mutation positions for each chain to match the positions in the full concatenated sequence.

In [None]:
split.data
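
The position adjustment can be illustrated with a small sketch. The helper ```to_concatenated_positions``` and the chain lengths are hypothetical; the sketch assumes positions are 1-indexed within each chain, that each chain's positions are shifted by the total length of the chains before it, and that multiple mutations within one chain are separated by ```/```.

```python
import re


def to_concatenated_positions(mutation_string, chain_lengths, sep="/"):
    """Renumber per-chain mutations ('F32Y:E61Y') onto the concatenated sequence."""
    adjusted = []
    offset = 0
    for chain_muts, length in zip(mutation_string.split(":"), chain_lengths):
        if chain_muts != "WT":
            for mut in chain_muts.split(sep):
                wt_aa, pos, new_aa = re.fullmatch(r"([A-Z])(\d+)([A-Z])", mut).groups()
                adjusted.append(f"{wt_aa}{int(pos) + offset}{new_aa}")
        # Positions in the next chain are shifted by this chain's length.
        offset += length
    return adjusted


chain_lengths = [120, 110]  # hypothetical lengths of the two chains
print(to_concatenated_positions("F32Y:E61Y", chain_lengths))  # ['F32Y', 'E181Y']
print(to_concatenated_positions("WT:F32Y", chain_lengths))    # ['F152Y']
print(to_concatenated_positions("WT:WT", chain_lengths))      # []
```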

In [33]:
split.split_data(
        percentile_threshold=50,  
        min_test_muts=5,         
        max_train_muts=2,
        randomized_control=False
    )