This notebook walks through all the steps required to run TranceptEVE on a deep mutational scan (DMS).
To run this notebook, you need:
* A CSV file containing all the mutant sequences, in the same format as the other ProteinGym datasets (i.e., it needs a `mutated_sequence` column with the mutated sequences and a `DMS_score` column with the experimental values)
* A multiple sequence alignment (MSA) for the target protein of the DMS
* A copy of the Tranception checkpoint (small, medium or large) that you want to use
* An updated reference file with an additional row describing the DMS (the required columns to fill in are "DMS_id", "DMS_filename", "MSA_filename", "MSA_theta", "MSA_start", "MSA_end", "weight_file_name" and "target_seq")
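The last two requirements are easy to get subtly wrong. Below is a minimal sketch, using only Python's standard `csv` module, that checks a DMS CSV for the two required columns and appends a row to the reference file (rewriting it in place). The helper names `check_dms_csv` and `append_reference_row` are hypothetical, not part of ProteinGym. Columns of the reference file that the new row does not set are left blank, so check whether the scripts you run rely on any of them.

```python
import csv

# Columns the notebook lists as required in the DMS assay csv.
REQUIRED_DMS_COLUMNS = ["mutated_sequence", "DMS_score"]
# Columns the notebook lists as required in the new reference-file row.
REQUIRED_REFERENCE_COLUMNS = [
    "DMS_id", "DMS_filename", "MSA_filename", "MSA_theta",
    "MSA_start", "MSA_end", "weight_file_name", "target_seq",
]


def check_dms_csv(dms_csv_path):
    """Raise ValueError if the DMS csv lacks a required column."""
    with open(dms_csv_path, newline="") as handle:
        header = next(csv.reader(handle))
    missing = [col for col in REQUIRED_DMS_COLUMNS if col not in header]
    if missing:
        raise ValueError(f"{dms_csv_path} is missing columns: {missing}")


def append_reference_row(reference_path, new_row):
    """Append new_row (a dict) to the reference csv; return its 0-based index.

    Reference-file columns that new_row does not set are written as blanks.
    """
    missing = [col for col in REQUIRED_REFERENCE_COLUMNS if col not in new_row]
    if missing:
        raise ValueError(f"new row is missing required columns: {missing}")
    with open(reference_path, newline="") as handle:
        reader = csv.DictReader(handle)
        fieldnames = reader.fieldnames
        rows = list(reader)
    unknown = [col for col in new_row if col not in fieldnames]
    if unknown:
        raise ValueError(f"columns not in the reference file: {unknown}")
    rows.append({col: new_row.get(col, "") for col in fieldnames})
    # Rewrite the whole file so a missing trailing newline cannot merge rows.
    with open(reference_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return len(rows) - 1
```

Back up the reference file before using this, since the helper rewrites it. The returned index is the value to use for `DMS_index`, assuming the scripts count reference rows from 0 after the header, as pandas does.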

# Step 0: Alignment Generation
If you don't have an alignment for your target protein, there are several methods for generating one:
* [EVCouplings](https://evcouplings.org): you can use the online web server at [evcouplings.org](https://evcouplings.org), or download the software and run it locally from [https://github.com/debbiemarkslab/EVcouplings](https://github.com/debbiemarkslab/EVcouplings).
* [ColabFold](https://github.com/sokrypton/ColabFold) also includes an MSA generation pipeline.
* [BLAST](https://blast.ncbi.nlm.nih.gov/Blast.cgi), the Basic Local Alignment Search Tool, also has a web server for generating alignments.
* [MUSCLE](https://www.drive5.com/muscle5/) is another commonly used alignment tool; note that it aligns a set of sequences you have already collected rather than searching a database for homologs.

All the DMS alignments in the original ProteinGym paper were generated with EVCouplings.
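Whichever tool you use, the alignment is read by EVE in a2m format. The sketch below parses an a2m file and reports a few quick sanity checks. It assumes, as is typical for EVCouplings-style alignments, that the first record is the target (query) sequence and that uppercase letters in the query mark the focus columns the model is trained on. The function names `read_a2m` and `summarize_alignment` are illustrative only.

```python
def read_a2m(path):
    """Return a list of (header, sequence) pairs from an a2m/FASTA file."""
    records = []
    header, chunks = None, []
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if header is not None:
                    records.append((header, "".join(chunks)))
                header, chunks = line[1:], []
            else:
                # Sequences may be wrapped over several lines.
                chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


def summarize_alignment(records):
    """Summarize an a2m alignment, treating the first record as the query."""
    query = records[0][1]
    # Assumption: uppercase residues in the query are the focus columns.
    focus_cols = [i for i, ch in enumerate(query) if ch.isupper()]
    return {
        "num_sequences": len(records),
        "alignment_width": len(query),
        "num_focus_columns": len(focus_cols),
        # Every aligned row should span the same number of columns.
        "all_same_length": all(len(seq) == len(query) for _, seq in records),
    }
```

A `False` for `all_same_length` usually means the file is not a padded a2m alignment (for example, an a3m file) and should be converted before training.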

# Step 1: Training EVE models on the wild-type MSA
The script at proteingym/baselines/EVE/train_VAE.py trains an EVE model on an alignment. The code below is essentially the same as the train_VAE.py invocation in scripts/scoring_DMS_zero_shot/training_EVE_models.sh.
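Before training, EVE reweights the sequences in the alignment so that clusters of near-identical sequences do not dominate; the `MSA_theta` column of the reference file sets the identity cutoff, and the resulting weights are what gets cached in the MSA weights folder. The sketch below illustrates that weighting idea in plain Python: each sequence receives weight 1/n, where n counts the sequences (itself included) whose identity to it exceeds 1 - theta. It is a simplified, quadratic-time illustration, not EVE's vectorized implementation, and it assumes all sequences are already aligned to the same columns.

```python
def sequence_weights(aligned_seqs, theta=0.2):
    """Down-weight near-duplicate sequences in an aligned set.

    Each sequence gets weight 1 / n, where n is the number of sequences
    (itself included) whose identity to it exceeds 1 - theta. Identity is
    measured over the positions where the sequence being weighted has a
    residue rather than a gap ('-' or '.').
    """
    weights = []
    for seq in aligned_seqs:
        residue_positions = [i for i, ch in enumerate(seq) if ch not in "-."]
        if not residue_positions:
            weights.append(0.0)  # an all-gap sequence contributes nothing
            continue
        neighbours = 0
        for other in aligned_seqs:
            matches = sum(1 for i in residue_positions if other[i] == seq[i])
            if matches / len(residue_positions) > 1 - theta:
                neighbours += 1
        weights.append(1.0 / neighbours)
    return weights
```

The sum of these weights is commonly reported as the effective number of sequences (Neff) of the alignment.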

In [None]:
# This is the index of the DMS you want to run in the reference file csv.
DMS_index=0 # change this to the row of your new DMS in the reference file (the first row after the header is 0)
# You can train multiple EVE models with different seeds for initialization and then pass them all to TranceptEVE.
# The log prior in TranceptEVE is then computed from the ensemble of all those models.
random_seeds = [0,1,2,3,4]

model_parameters_location='../proteingym/baselines/EVE/EVE/default_model_params.json'
training_logs_location='../proteingym/baselines/EVE/logs/'
DMS_reference_file_path="../reference_files/DMS_substitutions.csv"

# replace the below with the locations of the MSAs and assay csvs on your machine 

DMS_MSA_data_folder="Folder containing multiple sequence alignments in a2m format" 
# This is where the EVE models will be written out. The filenames are in the format input-msa-name_seed
DMS_EVE_model_folder="Folder where EVE models will be saved"
# if you don't already have weights here for the MSA, the EVE training script will generate them 
DMS_MSA_weights_folder="Folder where MSA weights will be saved" 

# Note that these models generally take a few hours to train, so it is likely easier to run the training_EVE_models.sh script mentioned above and 
# train several in parallel than to do them sequentially in a notebook here. 
for seed in random_seeds:
    command = f"../proteingym/baselines/EVE/train_VAE.py \
            --MSA_data_folder {DMS_MSA_data_folder} \
            --DMS_reference_file_path {DMS_reference_file_path} \
            --protein_index {DMS_index} \
            --MSA_weights_location {DMS_MSA_weights_folder} \
            --VAE_checkpoint_location {DMS_EVE_model_folder} \
            --model_parameters_location {model_parameters_location} \
            --training_logs_location {training_logs_location} \
            --threshold_focus_cols_frac_gaps 1 \
            --seed {seed} \
            --skip_existing \
            --experimental_stream_data"
    !python $command
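As the comment in the cell above notes, it is usually easier to train the seeds in parallel. One way to do that from Python, sketched below, is to build each train_VAE.py invocation as an argument list (which also keeps folder paths containing spaces intact) and to launch a bounded number of processes at once with `concurrent.futures`. The helper names `build_train_command` and `train_seeds_in_parallel` are hypothetical; the flags are copied from the cell above. If the MSA weights file does not exist yet, consider training one seed first so that concurrent runs do not all compute and write the same weights file, and keep `max_workers` within what your GPU memory allows.

```python
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor


def build_train_command(seed, script, msa_folder, reference_file, dms_index,
                        weights_folder, model_folder, params_file, logs_folder):
    """Return the train_VAE.py call as an argument list (no shell quoting)."""
    return [
        sys.executable, script,
        "--MSA_data_folder", msa_folder,
        "--DMS_reference_file_path", reference_file,
        "--protein_index", str(dms_index),
        "--MSA_weights_location", weights_folder,
        "--VAE_checkpoint_location", model_folder,
        "--model_parameters_location", params_file,
        "--training_logs_location", logs_folder,
        "--threshold_focus_cols_frac_gaps", "1",
        "--seed", str(seed),
        "--skip_existing",
        "--experimental_stream_data",
    ]


def train_seeds_in_parallel(seeds, max_workers=2, **settings):
    """Run one training process per seed, at most max_workers at a time.

    Raises subprocess.CalledProcessError if any run exits with an error.
    """
    commands = [build_train_command(seed, **settings) for seed in seeds]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(lambda cmd: subprocess.run(cmd, check=True), commands))
```

For example, `train_seeds_in_parallel(random_seeds, script="../proteingym/baselines/EVE/train_VAE.py", msa_folder=DMS_MSA_data_folder, reference_file=DMS_reference_file_path, dms_index=DMS_index, weights_folder=DMS_MSA_weights_folder, model_folder=DMS_EVE_model_folder, params_file=model_parameters_location, logs_folder=training_logs_location)` would replace the sequential loop.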


# Step 2: Scoring with TranceptEVE
Now that the EVE models are trained, we can use them together with the downloaded Tranception checkpoint to run TranceptEVE. The code below is essentially the same as the script at scripts/scoring_DMS_zero_shot/scoring_TranceptEVE_substitutions.sh. Scoring indels works the same way, with two extra parameters: a flag that switches Tranception to indel mode, and the path to a Clustal Omega installation (Clustal Omega is an alignment tool used during retrieval for indels).
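The `scoring_window` parameter matters when the target sequence is longer than Tranception's context window. Our reading of the "optimal" setting is that each mutant is scored on a single window centered on the mutated position, shifted back inside the sequence when it would run past either end. The sketch below shows only that window arithmetic for a single substitution; `centered_window` and its `window_len` argument are hypothetical stand-ins, not Tranception's code.

```python
def centered_window(mutation_position, seq_len, window_len):
    """Return (start, end) of a window of at most window_len residues.

    The window contains the 0-based mutation_position, is centered on it
    where possible, and is shifted to stay inside [0, seq_len).
    """
    if not 0 <= mutation_position < seq_len:
        raise ValueError("mutation_position must lie inside the sequence")
    if seq_len <= window_len:
        return 0, seq_len  # the whole sequence fits in one window
    start = mutation_position - window_len // 2
    # Clip so the window neither starts before 0 nor ends after seq_len.
    start = max(0, min(start, seq_len - window_len))
    return start, start + window_len
```

With "sliding", by contrast, the sequence is covered by several windows rather than one per mutant, which can be preferable for long proteins.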

In [None]:
# These values need to match those in the prior steps, so that the script finds the correct EVE models 
DMS_index=0 
random_seeds = [0,1,2,3,4]
DMS_MSA_data_folder="Folder containing MSA files in a2m format"
model_parameters_location='../proteingym/baselines/EVE/EVE/default_model_params.json'
DMS_EVE_model_folder="Folder containing EVE models"
DMS_MSA_weights_folder="Folder containing MSA weights"
DMS_reference_file_path="../reference_files/DMS_substitutions.csv"

# These parameters are specific to TranceptEVE

inference_time_retrieval_type="TranceptEVE"
# This is the number of samples taken from each EVE model to generate the log prior. This is done at the start of the script and then cached
# so that later runs with the same EVE models don't have to recompute it. 
EVE_num_samples_log_proba=200000
# For long proteins, "sliding" rather than "optimal" may be ideal for this parameter. 
scoring_window="optimal" 

# These can be changed based on where the Tranception checkpoint and DMS data files are stored and where you want the output scores to write to 
DMS_data_folder="Folder containing DMS assay csvs"
checkpoint = "Tranception model checkpoint path"
output_scores_folder="Path to folder where scores will be saved" 

command = f"../proteingym/baselines/trancepteve/score_trancepteve.py \
                --checkpoint {checkpoint} \
                --DMS_reference_file_path {DMS_reference_file_path} \
                --DMS_data_folder {DMS_data_folder} \
                --DMS_index {DMS_index} \
                --output_scores_folder {output_scores_folder} \
                --inference_time_retrieval_type {inference_time_retrieval_type} \
                --MSA_folder {DMS_MSA_data_folder} \
                --MSA_weights_folder {DMS_MSA_weights_folder} \
                --EVE_num_samples_log_proba {EVE_num_samples_log_proba} \
                --EVE_model_parameters_location {model_parameters_location} \
                --EVE_model_folder {DMS_EVE_model_folder} \
                --scoring_window {scoring_window} \
                --EVE_seeds {' '.join(str(seed) for seed in random_seeds)} \
                --EVE_recalibrate_probas"
!python $command
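ProteinGym summarizes zero-shot performance primarily with the Spearman rank correlation between model scores and `DMS_score`. A dependency-free sketch of that computation is below. It assumes a CSV that contains both the `DMS_score` column and a model score column; the score column name in the TranceptEVE output is not specified here, so `score_column` is a parameter you must set, and you may first need to merge the output scores with your DMS file. The function names are illustrative.

```python
import csv


def average_ranks(values):
    """Rank values from 1..n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1  # sorted positions i..j hold tied values
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the average ranks.

    Undefined (division by zero) if either input is constant.
    """
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    var_x = sum((a - mx) ** 2 for a in rx)
    var_y = sum((b - my) ** 2 for b in ry)
    return cov / (var_x * var_y) ** 0.5


def score_file_spearman(scores_csv, score_column):
    """Spearman between DMS_score and score_column, skipping blank rows."""
    with open(scores_csv, newline="") as handle:
        rows = list(csv.DictReader(handle))
    rows = [r for r in rows if r["DMS_score"] != "" and r[score_column] != ""]
    return spearman([float(r["DMS_score"]) for r in rows],
                    [float(r[score_column]) for r in rows])
```

If SciPy is installed, `scipy.stats.spearmanr` computes the same statistic, also using average ranks for ties.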