# Train Seq2Fitness models for Alpha Amylase

This notebook shows how to train Seq2Fitness models on a protein fitness dataset, using alpha-amylase as the worked example.

In [5]:
# %run setup_environment.py
import badass.models.seq2fitness_models as models
import badass.data.datasets as datasets
import badass.training.seq2fitness_traintools as traintools
import badass.training.seq2fitness_train as train
import badass.utils.sequence_utils as sequence_utils
import torch.nn as nn
import time

import torch

In [6]:
# Alpha-amylase reference sequence, written as adjacent string literals that
# Python joins into one continuous string with no line breaks.
ref_seq_amylase = (
    "LTAPSIKSGTILHAWNWSFNTLKHNMKDIHDAGYTAIQTSPINQVKEGNQGDKSMSNWYWLYQPTSYQIGNRYLGTEQEFKEMCAAAEEYGIKVIVDAVINHTTSDYAAIS"
    "NEVKSIPNWTHGNTPIKNWSDRWDVTQNSLSGLYDWNTQNTQVQSYLKRFLDRALNDGADGFRFDAAKHIELPDDGSYGSQFWPNITNTSAEFQYGEILQDSVSRDAAYANY"
    "MDVTASNYGHSIRSALKNRNLGVSNISHYAVDVSADKLVTWVESHDTYANDDEESTWMSDDDIRLGWAVIASRSGSTPLFFSRPEGGGNGVRFPGKSQIGDRGSALFEDQAI"
    "TAVNRFHNVMAGQPEELSNPNGNNQIFMNQRGSHGVVLANAGSSSVSINTATKLPDGRYDNKAGAGSFQVNDGKLTGTINARSVAVLYPD"
)

In [7]:
# Each criterion name is the label of a column in the dataset spreadsheet.
# Several tasks can be trained jointly (multi-task training).
# The per-task weights are chosen to correspond to the relative number of samples per task.

criteria = {
    task: {'loss': torch.nn.MSELoss(), 'weight': weight}
    for task, weight in (
        ('fitness_dp3_activity', 2.0),   # only single mutants from NNK. Used weight of 2.0 first
        ('fitness_stain_removal', 1.0),  # 1-8 mutations
    )
}
print(criteria)


{'fitness_dp3_activity': {'loss': MSELoss(), 'weight': 2.0}, 'fitness_stain_removal': {'loss': MSELoss(), 'weight': 1.0}}


In [10]:
# Read the precomputed ESM2-3B scores from disk.
# The companion notebook "Compute ESM2-3B Scores.ipynb" can be used to generate this file.
static_score_path = "../data/aAmyl_esm2_3B_scores.xlsx"  # precomputed scores
amylase_ESM23B_scores = models.initialize_static_esm_scores(
    static_score_path,
    verbose=False,
)

In [11]:
# Model hyperparameters.
# These are the default hyperparameters used throughout the paper.
model_params = dict(
    task_criteria=criteria,
    k1=32,  # Filter of 1st conv.
    k2=32,  # Filter of 2nd conv.
    dropout=0.20,
    # quantiles used to process embeddings
    quantiles=[0.01, 0.025, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 0.975, 0.99],
    task_stats={},
    m1=27,
    m2=15,
    esm_scores_dim=3,  # total number of scores: mutant, wildtype, and from 3B model currently
    ref_seq=ref_seq_amylase,
    static_logit=amylase_ESM23B_scores,
    esm_modelname='esm2_t33_650M_UR50D',
    use_rel_embeddings=True,
    use_rel_prob_matrices=False,
)

In [12]:
# Training configuration handed to the training entry point together with the model parameters.
training_params = dict(
    dataset_path='../data/AlphaAmylase_both_merged.csv',
    batch_size=800,
    epochs=80,
    seed=19,
    dropout=0.2,
    lr=1e-2,
    split_ratio=0.8,
    save_path="../trained_models",
    model_filename="Seq2Fitness_AAmylsase_",
    weight_decay=1.2e-3,
    file_name="Seq2Fitness_AAmylsase_",  # for storing plots
)

In [None]:
# Run training and report how long it took (wall-clock seconds).
start_time = time.time()
trainer = train.main(
    model_params,
    training_params,
    model_class="ProteinFunctionPredictor_with_probmatrix",
)
end_time = time.time()
elapsed_time = end_time - start_time
print("Elapsed time: {:.2f} seconds".format(elapsed_time))

Creating model of class ProteinFunctionPredictor_with_probmatrix.
Total number of trainable parameters in the model: 125470
Number of trainable parameters in the model excluding ESM: 125470
Found 1 wildtype sequences in dataset.


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  dataframe['sequence'].replace(np.nan, "NA", inplace=True) # Make wt NA for now.


train has 8578 unique sequences out of 10722.
After splitting, trainset has 8578 sequences, and test has 2144.


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  dataframe['sequence'].replace("NA", "", inplace=True) # Now make WT empty string


Number of data points with all NaNs after normalization: 0
Number of data points with all NaNs after normalization: 0
Task stats used for normalization are : {'task_means': {'fitness_dp3_activity': 0.743588245116046, 'fitness_stain_removal': 0.054606376970788045}, 'task_stds': {'fitness_dp3_activity': 0.6666848597149097, 'fitness_stain_removal': 0.05645892648933254}}.
We'll train for 80 epochs.
We'll use 2 GPUs through DataParallel.
Will save models to ../trained_models
Spearman correlation for fitness_dp3_activity: -0.3779
Spearman correlation for fitness_stain_removal: 0.5884
Updated model checkpoint - val_loss: 0.05579641353002818, epoch: 1
Epoch 1: train_Loss=2.2709, val_loss=0.7443, corr=-0.0558, l_rate=1.0e-02
Spearman correlation for fitness_dp3_activity: 0.2669
Spearman correlation for fitness_stain_removal: 0.5439
Updated model checkpoint - val_loss: -0.3592427250719978, epoch: 2
Epoch 2: train_Loss=1.1440, val_loss=0.7574, corr=0.3592, l_rate=1.0e-02
Spearman correlation for 