# Train Seq2Fitness Models for Alpha-Amylase

This notebook shows how to train Seq2Fitness models on a protein fitness dataset, using alpha-amylase as the example.

In [None]:
%run setup_environment.py
import seq2fitness_models as models
import datasets
import seq2fitness_traintools as traintools
import seq2fitness_train as train
import sequence_utils
import torch.nn as nn
import time

In [None]:
# Reference sequence used by the model (passed as `ref_seq` below).
# The literal is wrapped across lines for readability. All whitespace is removed,
# not just '\n': a stray space, tab or '\r' picked up while editing would otherwise
# end up inside the sequence and shift every downstream residue position.
ref_seq_amylase = "".join("""LTAPSIKSGTILHAWNWSFNTLKHNMKDIHDAGYTAIQTSPINQVKEGNQGDKSMSNWYWLYQPTSYQIGNRYLGTEQEFKEMCAAAEEYGIKVIVDAVINHTTSDYAAIS
NEVKSIPNWTHGNTPIKNWSDRWDVTQNSLSGLYDWNTQNTQVQSYLKRFLDRALNDGADGFRFDAAKHIELPDDGSYGSQFWPNITNTSAEFQYGEILQDSVSRDAAYANY
MDVTASNYGHSIRSALKNRNLGVSNISHYAVDVSADKLVTWVESHDTYANDDEESTWMSDDDIRLGWAVIASRSGSTPLFFSRPEGGGNGVRFPGKSQIGDRGSALFEDQAI
TAVNRFHNVMAGQPEELSNPNGNNQIFMNQRGSHGVVLANAGSSSVSINTATKLPDGRYDNKAGAGSFQVNDGKLTGTINARSVAVLYPD""".split())

# Fail fast if anything other than the 20 canonical amino-acid letters slipped in.
assert set(ref_seq_amylase) <= set("ACDEFGHIKLMNPQRSTVWY"), "reference sequence contains non-amino-acid characters"

In [None]:
# Criteria are labels corresponding to columns in the spreadsheet containing the dataset.
# Multi-task training is supported: each entry maps a column name to its loss function
# and its task weight. Task weights are meant to reflect the relative number of samples per task.
#
# The losses are built through `nn` (imported explicitly above as `torch.nn`). The bare
# name `torch` is not imported in the imports cell, so `torch.nn.MSELoss()` would only
# work if setup_environment.py happened to define `torch`.

criteria = {
    # Single mutants only (NNK library). 2.0 was the initial weight tried.
    'fitness_dp3_activity': {'loss': nn.MSELoss(), 'weight': 2.0},
    # Variants carrying 1-8 mutations.
    'fitness_stain_removal': {'loss': nn.MSELoss(), 'weight': 1.0},
}
criteria


In [None]:
# Load the precomputed ESM2-3B scores from disk instead of recomputing them here.
# They can be regenerated with the companion notebook compute_esm2_3b_scores.ipynb.
static_score_path = "../data/alpha_amylase_esm2_3b_scores.xlsx"

amylase_ESM23B_scores = models.initialize_static_esm_scores(
    static_score_path,
    verbose=False,
)

In [None]:
# Model hyperparameters: the defaults used throughout the paper.
# Built with dict(...) keyword arguments; key names and order match the
# spec expected by the training entry point.
model_params = dict(
    task_criteria=criteria,
    k1=32,  # number of filters in the 1st convolution
    k2=32,  # number of filters in the 2nd convolution
    dropout=0.20,
    # Quantiles used to process the embeddings.
    quantiles=[0.01, 0.025, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 0.975, 0.99],
    task_stats={},
    m1=27,  # NOTE(review): meaning of m1/m2 not documented here; confirm in seq2fitness_models
    m2=15,
    esm_scores_dim=3,  # total number of scores: mutant, wild-type, and the precomputed 3B score
    ref_seq=ref_seq_amylase,
    static_logit=amylase_ESM23B_scores,
    esm_modelname='esm2_t33_650M_UR50D',
    use_rel_embeddings=True,
    use_rel_prob_matrices=False,
)

In [None]:
# Training hyperparameters.
#
# Saved models and plots share a single filename prefix.
# NOTE(review): "AAmylsase" looks like a typo of "AAmylase". It is left unchanged in
# case other notebooks or scripts look up previously saved files by this name.
filename_prefix = "Seq2Fitness_AAmylsase_"

training_params = {
    'dataset_path': '../data/AlphaAmylase_both_merged.csv',
    'batch_size': 800,
    'epochs': 80,
    'seed': 19,
    # Mirror the model's dropout so the two configs cannot silently drift apart
    # (both were previously hard-coded to 0.2).
    'dropout': model_params['dropout'],
    'lr': 1e-2,
    'split_ratio': 0.8,  # presumably the train fraction of the dataset; confirm in seq2fitness_train
    'save_path': "../trained_models",
    'model_filename': filename_prefix,
    'weight_decay': 1.2e-3,
    'file_name': filename_prefix,  # prefix for stored plots
}

In [None]:
# Train the model and report how long it took.
# time.perf_counter() is a monotonic clock, unlike time.time(), so the measured
# duration is not skewed by system clock adjustments during a long training run.
# NOTE(review): model_params sets use_rel_prob_matrices=False while this model class
# name mentions a prob matrix. Confirm this combination is intended.
start_time = time.perf_counter()
trainer = train.main(model_params, training_params, model_class="ProteinFunctionPredictor_with_probmatrix")
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")