# Novozyme Enzyme Stability Prediction

This notebook contains model training and evaluation to predict the thermal stability (as measured via melting point) of enzymes based on their amino acid sequence.

Competition details are available [here](https://www.kaggle.com/competitions/novozymes-enzyme-stability-prediction/overview).

Prepared for SCS3546 - Deep Learning

<pre> Christopher Eeles </pre>

<pre> X361483 </pre>

Please note that the dependencies for this notebook are available in a Conda
environment file on GitHub under `ChristopherEeles/enzyme_thermal_stability_prediction/env`

## Dataset Download

Retrieve the dataset from Kaggle via the Kaggle API utility

In [1]:
from pathlib import Path
from shlib import Cmd
import zipfile as zip

In [2]:
# Path constants
DATA_DIR = Path("rawdata")
METADATA_DIR = Path("metadata")
LOG_DIR = Path("logs")
RESULT_DIR = Path("results")

# Kaggle constants
COMPETITION_NAME = "novozymes-enzyme-stability-prediction"

In [3]:
# Initialize project directories
for d in (DATA_DIR, METADATA_DIR, LOG_DIR, RESULT_DIR):
    d.mkdir(parents=True, exist_ok=True)

In [4]:
# Download competition data
download_competition_files = Cmd(["kaggle", "competitions", "download", "-c", 
    COMPETITION_NAME, "-p", DATA_DIR])
download_competition_files.run()
dataset_file = sorted(DATA_DIR.glob(f"{COMPETITION_NAME}.*"))
dataset_file

Downloading novozymes-enzyme-stability-prediction.zip to rawdata



100%|██████████| 7.06M/7.06M [00:00<00:00, 49.4MB/s]


[PosixPath('rawdata/novozymes-enzyme-stability-prediction.zip')]

In [5]:
with zip.ZipFile(dataset_file[0].resolve()) as z:
    z.extractall(path=DATA_DIR)
    dataset_file[0].unlink()

In [6]:
dataset_files = sorted(DATA_DIR.glob("*"))
dataset_files

[PosixPath('rawdata/sample_submission.csv'),
 PosixPath('rawdata/test.csv'),
 PosixPath('rawdata/train.csv'),
 PosixPath('rawdata/train_updates_20220929.csv'),
 PosixPath('rawdata/wildtype_structure_prediction_af2.pdb')]

## Data Exploration

Before we begin modelling we will have a look at the files available for the Novozyme competition to see what kind of features are available to help with our task.

In [7]:
from biopandas.pdb import PandasPdb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [8]:
TRAIN_PATH = dataset_files[2]
TEST_PATH = dataset_files[1]
SAMPLE_SUBMISSION = dataset_files[0]
TRAIN_UPDATE_PATH = dataset_files[3]
TRAIN_PDB_PATH = dataset_files[4]

In [9]:
# Load available csv files
train_df = pd.read_csv(TRAIN_PATH)
test_df = pd.read_csv(TEST_PATH)
sample_submission_df = pd.read_csv(SAMPLE_SUBMISSION)
train_update_df = pd.read_csv(TRAIN_UPDATE_PATH)

In [10]:
# Use Biopythons Biopandas to load the PDB protein structure file
pdb_df = PandasPdb().read_pdb(str(TRAIN_PDB_PATH))

In [11]:
# Get it PDB file into a Python native format
protein_struct_df_dict = pdb_df.df
protein_struct_df_dict.keys()

dict_keys(['ATOM', 'HETATM', 'ANISOU', 'OTHERS'])

In [12]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 31390 entries, 0 to 31389
Data columns (total 5 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   seq_id            31390 non-null  int64  
 1   protein_sequence  31390 non-null  object 
 2   pH                31104 non-null  float64
 3   data_source       28043 non-null  object 
 4   tm                31390 non-null  float64
dtypes: float64(2), int64(1), object(2)
memory usage: 1.2+ MB


In [13]:
## NOTE: tm column is melting point in Celsius (C)
train_df.head()

Unnamed: 0,seq_id,protein_sequence,pH,data_source,tm
0,0,AAAAKAAALALLGEAPEVVDIWLPAGWRQPFRVFRLERKGDGVLVG...,7.0,doi.org/10.1038/s41592-020-0801-4,75.7
1,1,AAADGEPLHNEEERAGAGQVGRSLPQESEEQRTGSRPRRRRDLGSR...,7.0,doi.org/10.1038/s41592-020-0801-4,50.5
2,2,AAAFSTPRATSYRILSSAGSGSTRADAPQVRRLHTTRDLLAKDYYA...,7.0,doi.org/10.1038/s41592-020-0801-4,40.5
3,3,AAASGLRTAIPAQPLRHLLQPAPRPCLRPFGLLSVRAGSARRSGLL...,7.0,doi.org/10.1038/s41592-020-0801-4,47.2
4,4,AAATKSGPRRQSQGASVRTFTPFYFLVEPVDTLSVRGSSVILNCSA...,7.0,doi.org/10.1038/s41592-020-0801-4,49.5


In [14]:
train_update_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2434 entries, 0 to 2433
Data columns (total 5 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   seq_id            2434 non-null   int64  
 1   protein_sequence  25 non-null     object 
 2   pH                25 non-null     float64
 3   data_source       0 non-null      float64
 4   tm                25 non-null     float64
dtypes: float64(3), int64(1), object(1)
memory usage: 95.2+ KB


In [15]:
# There were some data quality issues, need to drop NaN rows and update some pH and tm values
# See: https://www.kaggle.com/competitions/novozymes-enzyme-stability-prediction/discussion/356251
train_update_df.head()

Unnamed: 0,seq_id,protein_sequence,pH,data_source,tm
0,69,,,,
1,70,,,,
2,71,,,,
3,72,,,,
4,73,,,,


In [16]:
# Extract rows which need updating in train data
bad_seq_ids = train_update_df.seq_id.values
bad_seq_ids

array([   69,    70,    71, ..., 30740, 30741, 30742])

In [17]:
# Drop those rows from the training data and append the updated rows
train_df_fix = train_df.loc[~train_df.seq_id.isin(bad_seq_ids), :]
train_df_fix = (pd.concat([train_df_fix, train_update_df])
    .sort_values(by="seq_id"))

In [18]:
train_df_fix.head()

Unnamed: 0,seq_id,protein_sequence,pH,data_source,tm
0,0,AAAAKAAALALLGEAPEVVDIWLPAGWRQPFRVFRLERKGDGVLVG...,7.0,doi.org/10.1038/s41592-020-0801-4,75.7
1,1,AAADGEPLHNEEERAGAGQVGRSLPQESEEQRTGSRPRRRRDLGSR...,7.0,doi.org/10.1038/s41592-020-0801-4,50.5
2,2,AAAFSTPRATSYRILSSAGSGSTRADAPQVRRLHTTRDLLAKDYYA...,7.0,doi.org/10.1038/s41592-020-0801-4,40.5
3,3,AAASGLRTAIPAQPLRHLLQPAPRPCLRPFGLLSVRAGSARRSGLL...,7.0,doi.org/10.1038/s41592-020-0801-4,47.2
4,4,AAATKSGPRRQSQGASVRTFTPFYFLVEPVDTLSVRGSSVILNCSA...,7.0,doi.org/10.1038/s41592-020-0801-4,49.5


In [19]:
# Sanity check the non NaN columns got updated correctly
assert all(train_df_fix.iloc[bad_seq_ids].pH.dropna() == train_update_df.pH.dropna())

In [20]:
# Drop columns with NaN in the tm column, since that is our target in modelling
train_df2 = train_df_fix.loc[~train_df_fix.tm.isna(), ]

In [21]:
train_df2.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 28981 entries, 0 to 31389
Data columns (total 5 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   seq_id            28981 non-null  int64  
 1   protein_sequence  28981 non-null  object 
 2   pH                28695 non-null  float64
 3   data_source       28001 non-null  object 
 4   tm                28981 non-null  float64
dtypes: float64(2), int64(1), object(2)
memory usage: 1.3+ MB


In [22]:
# Due to memory constraints need to only work with shorter proteins!
train_df2 = train_df2.loc[train_df2.protein_sequence.apply(lambda x: len(x)) < 1000, :]

In [23]:
train_df2.shape

(27029, 5)

In [24]:
train_df2.protein_sequence.apply(lambda x: len(x)).max()  # Check longest sequence

999

In [25]:
# Write train/test data as a FASTA sequence file, makes it easier to make DataLoader for PyTorch
TRAIN_FASTA = RESULT_DIR / "train.fasta"
with TRAIN_FASTA.open("w", encoding="utf-8") as f:
    for i in range(train_df2.shape[0]):
        row = train_df2.iloc[i, :]
        f.write(f">{row.seq_id}|{row.pH}|{row.tm}\n")
        f.write(f"{row.protein_sequence}\n")

In [26]:
test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2413 entries, 0 to 2412
Data columns (total 4 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   seq_id            2413 non-null   int64 
 1   protein_sequence  2413 non-null   object
 2   pH                2413 non-null   int64 
 3   data_source       2413 non-null   object
dtypes: int64(2), object(2)
memory usage: 75.5+ KB


In [27]:
# Check for duplicated proteins in test set!
assert len(test_df.protein_sequence.unique()) == test_df.shape[0]

In [28]:
TEST_FASTA = RESULT_DIR / "test.fasta"
with TEST_FASTA.open("w", encoding="utf-8") as f:
    for i in range(test_df.shape[0]):
        row = test_df.iloc[i, :]
        f.write(f">{row.seq_id}|{row.pH}|\n")
        f.write(f"{row.protein_sequence}\n")

## Modelling

Given the relatively small training set of proteins for this competition, it
is likely optimal to leverage an existing model via transfer learning to ensure
we can extract sufficient biochemical context from the protein sequences to
usefully rank the proteins by thermal stability.

Based on a brief review of the literature, I identifier the 
Evolutionary Scale Modelling (ESM) tranformer architecture from Facebook Research
as a potential candidate for transfer learning to the thermal stability task.
Pre-trained versions of their model are available via their PyPI package
`fair-esm`. The model is implemented in PyTorch and given my limited exposure
to this deep learning framework I spent a long time in trail and error
before I could succesfully export the latent space embeddings of the protein
sequences I would use for downstream modelling.



In [29]:
# PyTorch
import torch
from torch import nn, optim, utils, Tensor
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from torch.utils.data.sampler import SubsetRandomSampler

# Facebook Research Evolutionary Scale Modelling (ESM) package
from esm import FastaBatchedDataset
import esm

# Utilities
import numpy as np
import inspect
import time
import copy
import math

In [30]:
model, alpha = esm.pretrained.esm2_t30_150M_UR50D()

In [31]:
class ESM2Regressor(nn.Module):
    """
    MultiLayerPerceptron model for regression appended to an ESM2 model for
    feature extraction for protein sequence data.
    """
    def __init__(self, esm2 = esm.pretrained.esm2_t30_150M_UR50D, 
        freeze_esm = True, layers = [64, 32, 16], return_contacts=False, 
        repr_layers=30
    ):
        super(ESM2Regressor, self).__init__()
        # Initialize our pretrained ESM2 model and alphabet
        esm_model, esm_alphabet = esm2()
        # Freeze model weights if not doing fine tuning
        if freeze_esm:
            for p in esm_model.parameters():
                p.requires_grad = False
        self.freeze_esm = freeze_esm
        self.esm_alphabet = esm_alphabet
        self.feature_extractor = esm_model
        self.return_contacts = return_contacts
        self.repr_layers = repr_layers
        # Get the output shape from our ESM2 feature extractor
        self.input_dim = (list(self.feature_extractor.children())[-1]
            .dense.out_features)
        # Add our regression MLP, parameterizing the layers
        self.fc = nn.Sequential()
        previous_l = self.input_dim
        for l in layers:
            #self.fc.append(nn.BatchNorm1d(previous_l))
            self.fc.append(nn.Linear(previous_l, l))
            self.fc.append(nn.ReLU())
            previous_l = l
        self.fc.append(
            nn.Linear(previous_l, 1)
        )

    def forward(self, input):
        features = self.feature_extractor(input, repr_layers=[self.repr_layers],
            return_contacts=self.return_contacts
        )
        # Sum along token dimension to get representation per sequence
        x = features["representations"][self.repr_layers].sum(1)
        x = self.fc(x)
        return x


In [32]:
# Model configuration
REPR_LAYERS = 30
BATCH_SIZE = 32
RETURN_CONTACTS = False  # To save on memory usage, even though contacts may be useful for stability prediction
VAL_SPLIT = 0.2
TRAIN_SPLIT = 1 - VAL_SPLIT

In [33]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [34]:
model = ESM2Regressor()
model

ESM2Regressor(
  (feature_extractor): ESM2(
    (embed_tokens): Embedding(33, 640, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=640, out_features=640, bias=True)
          (v_proj): Linear(in_features=640, out_features=640, bias=True)
          (q_proj): Linear(in_features=640, out_features=640, bias=True)
          (out_proj): Linear(in_features=640, out_features=640, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=640, out_features=2560, bias=True)
        (fc2): Linear(in_features=2560, out_features=640, bias=True)
        (final_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=640, out_features=640, bias=True)
      

In [35]:
# Get batch converter to preprocess our data from the dataloader
batch_converter = model.esm_alphabet.get_batch_converter()

In [36]:
# Load in our entire training set
dataset = FastaBatchedDataset.from_file(TRAIN_FASTA)
# Determine size for our train-test split
data_size = len(dataset)
TRAIN_SIZE = int(np.floor(data_size * TRAIN_SPLIT))
VAL_SIZE = data_size - TRAIN_SIZE
# Set seeds
torch.manual_seed(1990 * 42)
np.random.seed(1990 * 42)
# Shuffle data indices for subsetting
indices = list(range(data_size))
np.random.shuffle(indices)  # shuffles by reference
# Configure traning and validation samplers
train_idx, val_idx = indices[:TRAIN_SIZE], indices[TRAIN_SIZE:]

In [37]:
# Configure training and validation samplers
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)

In [38]:
train_dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, collate_fn=batch_converter, sampler=val_sampler
)

In [39]:
val_dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, collate_fn=batch_converter, sampler=val_sampler
)

In [40]:
dataloaders = {
    "train": train_dataloader,
    "val": val_dataloader
}

In [41]:
test_dataset = FastaBatchedDataset.from_file(TEST_FASTA)
test_batches = test_dataset.get_batch_indices(BATCH_SIZE, extra_toks_per_seq=1)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, collate_fn=batch_converter, batch_sampler=test_batches, shuffle=False
)

In [42]:
# Adapted from PyTorch docs: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
# TensorBoard code adapted from docs: https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
def train_model(model, criterion, optimizer, scheduler, dataloaders, device, 
    num_epochs=100
):
    """
    Function to train our ESMRegressor model with validation.
    """
    ## TODO:: Add tensorboard logging!
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 0.0

    datasizes = {k: len(v) for k, v in dataloaders.items()}

    model = model.to(device)
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print("-" * 10)

        # training and validation phase for each epoch
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
                # freeze the ESM2 model parameters, very expensive to train
                if model.freeze_esm:
                    for p in model.feature_extractor.parameters():
                        p.require_grad = False
            else:
                model.eval()

            running_loss = 0.0

            for labels, _, inputs in dataloaders[phase]:
                inputs = inputs.to(device)
                # last value in string is melting point for each protein
                labels = Tensor([float(s.split("|")[-1]) for s in labels])
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward pass
                # track history only for training
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels.reshape([len(labels), 1]))

                    # backward pass + optimization for training only
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # compute running statistics
                running_loss += loss.item() * inputs.size(0)
        
        if phase == "train":
            # Adjust learning rate schedule per epoch
            scheduler.step()

        epoch_loss = running_loss / datasizes[phase]
        print(f"{phase} Loss: {epoch_loss:.4f}")

        if phase == "val" and epoch_loss < best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
    print()  # Add newline before next epoch

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val loss: {best_loss:4f}')

    # Load best model
    model.load_state_dict(best_model_wts)
    return model


In [43]:
# Loss function
criterion = nn.MSELoss()

# Select an optimizer, only optimzie fully connected layer paramters to speed up training
optimizer = optim.Adam(model.fc.parameters(), lr=0.01, weight_decay=0.001)

# Use learning rate scheduler, reduce lr by order of magniture per 10 epochs
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [44]:
model = train_model(model, criterion, optimizer, scheduler, dataloaders, device, 
    num_epochs=100)

Epoch 1/100
----------
val Loss: 2775.4218
Epoch 2/100
----------


KeyboardInterrupt: 

In [None]:
labels, strs, toks = next(iter(train_dataloader))
print(toks.shape)