# Example script for Hackathon

Within each cycle of active learning, you can:

1. Collect training data (original training data + your query data).

2. Train a prediction model to predict the DMS_score for each mutant (e.g., M0A).

3. Use the trained model to predict the score for all mutants in the test set.

4. Select query mutants for the next round based on certain criteria. Make sure you don't query the same mutant twice, as you only have a limited number of queries in total.

In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
import random
from copy import deepcopy
import pandas as pd
from scipy.stats import spearmanr
import argparse
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

## 1. collect training data

Upload `sequence.fasta`, `train.csv`, and `test.csv` to the current runtime:

1. click the folder icon on the left

2. click the upload icon and upload the files to the current directory

In [39]:
# Read the wild-type sequence from the FASTA file.
# The first line is treated as the header; the sequence may be written on one
# line or wrapped over several lines.
with open('sequence.fasta', 'r') as f:
    data = f.readlines()

# Join every sequence line of the first record. Reading only data[1] would
# silently truncate a line-wrapped FASTA file to its first line.
seq_lines = []
for line in data[1:]:
    if line.startswith('>'):
        # A second record starts here; only the first one is the wild type.
        break
    seq_lines.append(line.strip())

sequence_wt = ''.join(seq_lines)
if not sequence_wt:
    raise ValueError("No sequence found after the header line of 'sequence.fasta'")
sequence_wt

'MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLREKMRRRLESGDKWFSLEFFPPRTAEGAVNLISRFDRMAAGGPLYIDVTWHPAGDPGSDKETSSMMIASTAVNYCGLETILHMTCCRQRLEEITGHLHKAKQLGLKNIMALRGDPIGDQWEEEEGGFNYAVDLVKHIRSEFGDYFDICVAGYPKGHPEAGSFEADLKHLKEKVSAGADFIITQLFFEADTFFRFVKACTDMGITCPIVPGIFPIQGYHSLRQLVKLSKLEVPQEIKDVIEPIKDNDAAIRNYGIELAVSLCQELLASGLVPGLHFYTLNREMATTEVLKRLGMWTEDPRRPLPWALSAHPKRREEDVRPIFWASRPKSYIYRTQEWDEFPNGRWGNSSSPAFGELKDYYLFYLKSKSPKEELLKMWGEELTSEESVFEVFVLYLSGEPNRNGHKVTCLPWNDEPLAAETSLLKEELLRVNRQGILTINSQPNINGKPSSDPIVGWGPSGGYVFQKAYLEFFTSRETAEALLQVLKKYELRVNYHLVNVKGENITNAPELQPNAVTWGIFPGREIIQPTVVDPVSFMFWKDEAFALWIERWGKLYEEESPSRTIIQYIHDNYFLVNLVDNDFPLDNCLWQVVEDTLELLNRPTQNARETEAP'

In [40]:
# Sanity check: length of the wild-type sequence string.
len(sequence_wt)

656

In [41]:
def get_mutated_sequence(mut, sequence_wt):
    """Return `sequence_wt` with the single substitution `mut` applied.

    Args:
        mut (str): Mutation written as `<wt><pos><mt>`, e.g. 'M0A'. The
            position is used as a 0-based index into the sequence.
        sequence_wt (str): Wild-type sequence.

    Returns:
        str: The sequence with the residue at `pos` replaced by `mt`.

    Raises:
        ValueError: If the position part of `mut` is not an integer, or if it
            is not a valid index into `sequence_wt`.
    """
    pos, mt = int(mut[1:-1]), mut[-1]

    # Slicing never raises, so without this check an out-of-range (or negative)
    # position would silently produce a sequence of the wrong length.
    if not 0 <= pos < len(sequence_wt):
        raise ValueError(
            f"Mutation {mut!r}: position {pos} is outside the sequence "
            f"(length {len(sequence_wt)})"
        )

    # Strings are immutable, so the wild type can be sliced directly.
    return sequence_wt[:pos] + mt + sequence_wt[pos + 1:]

In [42]:
# Load the labelled training variants and attach each variant's full sequence.
df_train = pd.read_csv('train.csv')
df_train['sequence'] = df_train['mutant'].apply(get_mutated_sequence, args=(sequence_wt,))
df_train

Unnamed: 0,mutant,DMS_score,sequence
0,M0Y,0.2730,YVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1,M0W,0.2857,WVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2,M0V,0.2153,VVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3,M0T,0.3122,TVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4,M0S,0.2180,SVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
...,...,...,...
1135,P347D,0.3876,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1136,P347C,0.1837,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1137,P347A,0.4611,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1138,P347M,0.2412,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...


In [43]:
# Load the unlabelled test variants and attach each variant's full sequence.
df_test = pd.read_csv('test.csv')
df_test['sequence'] = df_test['mutant'].apply(get_mutated_sequence, args=(sequence_wt,))
df_test

Unnamed: 0,mutant,sequence
0,V1D,MDNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1,V1Y,MYNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2,V1C,MCNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3,V1A,MANEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4,V1E,MENEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
...,...,...
11319,P655S,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11320,P655T,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11321,P655V,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11322,P655A,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...


In [44]:
# Integrate the query data acquired in earlier rounds into df_train.
# 'queried_data.csv' is optional; when it is absent df_train is left unchanged.
import os

if os.path.exists('queried_data.csv'):
    df_query = pd.read_csv('queried_data.csv')
    df_query['sequence'] = df_query.mutant.apply(lambda x: get_mutated_sequence(x, sequence_wt))
    # De-duplicate on the mutant name so that re-running this cell (or a query
    # that overlaps the training set) does not add the same variant twice.
    # keep='last' prefers the row coming from the query file.
    df_train = (
        pd.concat([df_train, df_query])
        .drop_duplicates(subset='mutant', keep='last')
        .reset_index(drop=True)
    )
    print("Integrated queried data; new training data shape:", df_train.shape)


## 2. Train a prediction model

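Here, we provided a linear regression model and used one-hot encoding to encode each variant. You would need to build your own model to achieve better performance. The cells below replace that baseline with an ensemble of MLPs trained on pre-computed ESM-2 embeddings (generated in the Preprocessing section at the end of this notebook).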

Hint: you can perform cross-validation on the training set to evaluate your predictor before making predictions on the test set.

In [73]:
import os
from tqdm import tqdm  # for progress display
import torch
from torch.utils.data import Dataset

class ProteinDatasetESM(Dataset):
    """Dataset of pre-computed ESM embeddings, optionally paired with DMS scores.

    All embeddings are loaded into memory when the dataset is constructed.
    Items are `(embedding, label)` when labels are enabled, otherwise just
    `embedding`.
    """

    def __init__(self, df, seq2name, emb_dir, istrain=True, layer=33):
        """
        Args:
            df (pd.DataFrame): DataFrame containing at least a 'sequence' column.
            seq2name (dict): Mapping from sequence to a unique name corresponding to the embedding file.
            emb_dir (str): Directory where the embedding files are stored.
            istrain (bool): Whether items should include the target label. Labels are
                only returned when this is True and `df` has a 'DMS_score' column.
            layer (int): Which layer's representation to use from the embedding file.

        Raises:
            KeyError: If a sequence is missing from `seq2name`, or an embedding
                file does not contain the requested layer.
        """
        self.df = df
        self.seq2name = seq2name
        self.emb_dir = emb_dir
        self.layer = layer
        # Honour `istrain`: previously the flag was ignored and labels were
        # returned whenever the column happened to be present.
        self.has_target = bool(istrain) and 'DMS_score' in df.columns
        self.num_samples = len(self.df)

        # If targets are available, store them
        if self.has_target:
            self.targets = self.df['DMS_score'].values.astype(np.float32)

        # Pre-load embeddings for all sequences
        self.embeddings = [
            self._load_embedding(seq)
            for seq in tqdm(self.df['sequence'], desc='Loading ESM embeddings')
        ]

    def _load_embedding(self, seq):
        """Load the stored mean representation of `seq` for `self.layer`."""
        name = self.seq2name[seq]
        emb_file = os.path.join(self.emb_dir, f'{name}.pt')
        # weights_only=True restricts unpickling to tensors and plain containers,
        # which is all these files hold, and avoids torch.load's warning about
        # the unsafe default. map_location keeps the cached tensors on the CPU.
        saved = torch.load(emb_file, map_location='cpu', weights_only=True)
        return saved['mean_representations'][self.layer]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        emb = self.embeddings[idx]
        if not self.has_target:
            return emb
        label = torch.tensor(self.targets[idx], dtype=torch.float)
        return emb, label

#ESM

from torch.utils.data import Subset

# Map every sequence to the name of its embedding file, 'seq_<row position>'.
# If a sequence occurs more than once, the last row position is the one kept.
train_seq2name = {seq: f'seq_{pos}' for pos, seq in enumerate(df_train['sequence'])}
test_seq2name = {seq: f'seq_{pos}' for pos, seq in enumerate(df_test['sequence'])}

train_dataset = ProteinDatasetESM(df_train, train_seq2name, emb_dir='./esm_embeddings_train', istrain=True)
test_dataset = ProteinDatasetESM(df_test, test_seq2name, emb_dir='./esm_embeddings_test', istrain=False)

# Seeded random split of the labelled data: the first `split` shuffled indices
# form the validation set, the remainder the training set.
seed = 0
val_ratio = 0.2
np.random.seed(seed)
indices = list(range(len(train_dataset)))
np.random.shuffle(indices)
split = int(np.floor(val_ratio * len(train_dataset)))
val_idx = indices[:split]
train_idx = indices[split:]

train_subset = Subset(train_dataset, train_idx)
val_subset = Subset(train_dataset, val_idx)

# Only the training loader shuffles; validation and test keep row order.
train_loader = DataLoader(dataset=train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_subset, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)


  emb = torch.load(emb_file)['mean_representations'][self.layer]
Loading ESM embeddings: 100%|██████████| 1140/1140 [00:00<00:00, 5161.76it/s]
Loading ESM embeddings: 100%|██████████| 11324/11324 [00:02<00:00, 5158.22it/s]


In [75]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

# Assuming ESM_MLP is defined as follows:
class ESM_MLP(nn.Module):
    """MLP regressor: two Linear -> BatchNorm -> ReLU -> Dropout stages, then a linear head.

    Args:
        input_dim (int): Size of the input feature vector.
        hidden_dim (int): Width of the first hidden layer; the second hidden
            layer has `hidden_dim // 2` units.
        dropout_rate (float): Dropout probability applied after each hidden stage.
    """

    def __init__(self, input_dim=1280, hidden_dim=512, dropout_rate=0.3):
        super().__init__()
        half_dim = hidden_dim // 2
        # Attribute names and creation order are kept as-is so that state_dict
        # keys and seeded weight initialisation do not change.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_dim, half_dim)
        self.bn2 = nn.BatchNorm1d(half_dim)
        self.fc3 = nn.Linear(half_dim, 1)

    def _hidden_stage(self, x, linear, norm):
        """Apply one Linear -> BatchNorm -> ReLU -> Dropout stage."""
        return self.dropout(self.relu(norm(linear(x))))

    def forward(self, x):
        """Return one scalar prediction per row of `x` (last dimension squeezed)."""
        hidden = self._hidden_stage(x, self.fc1, self.bn1)
        hidden = self._hidden_stage(hidden, self.fc2, self.bn2)
        return self.fc3(hidden).squeeze(-1)

# Hyperparameters
ensemble_size = 5  # number of models trained below
num_epochs = 300
batch_size = 32  # used for the test DataLoader created after training
learning_rate = 0.001
weight_decay = 1e-4  # passed to Adam as `weight_decay`

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# List to hold our ensemble models.
ensemble_models = []

# Train each ensemble member.
# Every member is a freshly initialised ESM_MLP trained on the same
# `train_loader`. NOTE(review): `val_loader` is not used in this loop, so there
# is no validation or early stopping here.
for i in range(ensemble_size):
    print(f"\nTraining model {i+1}/{ensemble_size}...")

    # Instantiate a new model.
    model = ESM_MLP(input_dim=1280).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    # Training loop.
    for epoch in tqdm(range(num_epochs)):
        model.train()
        epoch_loss = 0.0
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(x_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            # Weight the batch-mean loss by batch size so the division below
            # gives a per-sample average.
            epoch_loss += loss.item() * x_batch.size(0)
        epoch_loss /= len(train_loader.dataset)
        # `epoch_loss` is only consumed by the commented-out print below.
        # Uncomment the next line to print epoch loss for each model.
        # print(f"Model {i+1}, Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    ensemble_models.append(model)

# Function to compute ensemble predictions and uncertainty.
def ensemble_predict(models, dataloader):
    """Predict with every model and summarise the ensemble per sample.

    Args:
        models: Iterable of trained `nn.Module`s. Each is switched to eval mode.
        dataloader: Iterable of batches. A batch is either an input tensor or a
            `(inputs, targets, ...)` list/tuple, in which case only the first
            element is used. It must yield samples in the same order on every
            pass, otherwise predictions of different models are misaligned.

    Returns:
        tuple[np.ndarray, np.ndarray]: `(mean_preds, std_preds)`, each of shape
        `(num_samples,)`; the standard deviation across models serves as the
        uncertainty estimate.

    Raises:
        ValueError: If `models` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError("ensemble_predict requires at least one model")

    all_preds = []
    for model in models:
        model.eval()
        # Run on the device the model lives on rather than relying on a global.
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")

        preds = []
        with torch.no_grad():
            for batch in dataloader:
                # Labelled loaders yield (inputs, targets); keep the inputs only.
                x_batch = batch[0] if isinstance(batch, (list, tuple)) else batch
                pred = model(x_batch.to(model_device))
                preds.append(pred.cpu().numpy())
        all_preds.append(np.concatenate(preds))

    all_preds = np.array(all_preds)  # Shape: (ensemble_size, num_samples)
    mean_preds = np.mean(all_preds, axis=0)
    std_preds = np.std(all_preds, axis=0)
    return mean_preds, std_preds

# Create a DataLoader for the test set. It is not shuffled, so predictions stay
# aligned with the rows of df_test.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Get ensemble predictions and uncertainty estimates.
mean_predictions, uncertainty = ensemble_predict(ensemble_models, test_loader)

# For demonstration, print up to the first 10 mean predictions and uncertainties.
# Iterating over slices avoids an IndexError when there are fewer than 10 samples.
for i, (pred, unc) in enumerate(zip(mean_predictions[:10], uncertainty[:10])):
    print(f"Sample {i}: Predicted fitness = {pred:.4f}, Uncertainty (std) = {unc:.4f}")

# You can use these uncertainty estimates to choose which test points to query next.
# For example, you might select points with high uncertainty or a combination of high predicted fitness and high uncertainty.


Using device: cpu

Training model 1/5...


100%|██████████| 300/300 [00:28<00:00, 10.53it/s]



Training model 2/5...


100%|██████████| 300/300 [00:29<00:00, 10.07it/s]



Training model 3/5...


100%|██████████| 300/300 [00:28<00:00, 10.44it/s]



Training model 4/5...


100%|██████████| 300/300 [00:29<00:00, 10.15it/s]



Training model 5/5...


100%|██████████| 300/300 [00:29<00:00, 10.12it/s]


Sample 0: Predicted fitness = 0.5249, Uncertainty (std) = 0.7561
Sample 1: Predicted fitness = 0.5671, Uncertainty (std) = 0.7085
Sample 2: Predicted fitness = 0.4722, Uncertainty (std) = 0.7541
Sample 3: Predicted fitness = 0.6419, Uncertainty (std) = 0.7897
Sample 4: Predicted fitness = 0.5355, Uncertainty (std) = 0.8575
Sample 5: Predicted fitness = 0.4896, Uncertainty (std) = 0.7995
Sample 6: Predicted fitness = 0.6162, Uncertainty (std) = 0.7605
Sample 7: Predicted fitness = 0.3535, Uncertainty (std) = 0.7873
Sample 8: Predicted fitness = 0.4370, Uncertainty (std) = 0.7634
Sample 9: Predicted fitness = 0.6064, Uncertainty (std) = 0.7174


In [76]:
# Summary statistics of the per-sample ensemble standard deviation.
df_describe = pd.DataFrame(uncertainty)
df_describe.describe()

Unnamed: 0,0
count,11324.0
mean,0.725133
std,0.057139
min,0.493538
25%,0.688131
50%,0.725233
75%,0.76159
max,1.023016


In [65]:
# Sanity check: shape of the ensemble-mean prediction array.
mean_predictions.shape

(11324,)

In [77]:
# Attach the ensemble-mean predictions to df_test and export the
# (mutant, DMS_score_predicted) pairs to predictions.csv without the index.
df_test['DMS_score_predicted'] = mean_predictions
df_test[['mutant', 'DMS_score_predicted']].to_csv('predictions.csv', index=False)


In [19]:
def mc_dropout_predictions(model, dataloader, T=20):
    """
    Perform T stochastic forward passes (with dropout enabled) to estimate uncertainty.

    Only dropout layers are switched to training mode. Every other layer
    (notably BatchNorm) stays in eval mode, so predictions use the stored
    running statistics and those statistics are not overwritten during
    inference. The model's original train/eval mode is restored on exit.

    Args:
        model: Trained `nn.Module` containing dropout layers.
        dataloader: Iterable of batches, each either an input tensor or an
            `(inputs, targets, ...)` list/tuple (only the inputs are used).
            It must yield samples in the same order on every pass (e.g. a
            DataLoader with `shuffle=False`); otherwise the per-sample
            statistics mix different samples.
        T (int): Number of stochastic forward passes.

    Returns:
        means: Mean predictions over T passes.
        stds: Standard deviation (uncertainty) of predictions.
    """
    dropout_types = (
        torch.nn.Dropout,
        torch.nn.Dropout2d,
        torch.nn.Dropout3d,
        torch.nn.AlphaDropout,
    )
    was_training = model.training

    # Run on the device the model lives on; the inputs are moved there per batch.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()  # keep dropout active

    all_preds = []
    try:
        for _ in tqdm(range(T)):
            preds = []
            with torch.no_grad():
                for batch in dataloader:
                    x_batch = batch[0] if isinstance(batch, (list, tuple)) else batch
                    pred = model(x_batch.to(model_device))
                    preds.append(pred.cpu().numpy())
            all_preds.append(np.concatenate(preds))
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    all_preds = np.array(all_preds)
    means = np.mean(all_preds, axis=0)
    stds = np.std(all_preds, axis=0)
    return means, stds

# `train_loader` is built with shuffle=True, so each of the T passes would see
# the samples in a different order and the per-sample mean/std would mix
# unrelated samples. Use an unshuffled loader over the same subset instead.
# NOTE(review): `model_cv` is not defined in the visible cells; confirm which
# model is meant here.
mc_loader = DataLoader(train_subset, batch_size=32, shuffle=False)
means, stds = mc_dropout_predictions(model_cv, mc_loader, T=20)

100%|██████████| 20/20 [01:45<00:00,  5.30s/it]


## 3. Select query for next round

In [None]:
# Preview the 100 test variants with the highest predicted DMS score.
df_test.sort_values('DMS_score_predicted', ascending=False).head(100)

Unnamed: 0,mutant,sequence,DMS_score_predicted
11323,P655W,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278
0,V1D,MDNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278
1,V1Y,MYNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278
2,V1C,MCNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278
3,V1A,MANEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278
...,...,...,...
119,N7R,MVNEARGRSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278
120,N7Q,MVNEARGQSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278
121,N7P,MVNEARGPSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278
122,N7M,MVNEARGMSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...,-0.049278


In [None]:
# Greedy selection: query the 100 test variants with the highest predicted score.
# Note: purely greedy selection ignores uncertainty and may not be a good strategy.
# TODO: select query mutants for the next round based on your own criteria

# Skip mutants that already have a label in df_train (original training data or
# earlier queries) so that none of the limited queries is spent on a known variant.
already_labelled = set(df_train['mutant'])
candidates = df_test[~df_test['mutant'].isin(already_labelled)]
querys = candidates.sort_values('DMS_score_predicted', ascending=False).head(100)['mutant'].values


In [None]:
# Write the selected mutants to query.txt, one per line, then echo them.
with open('query.txt', 'w') as f:
    f.writelines(mutant + '\n' for mutant in querys)

print("Query file 'query.txt' created with the following mutants:")
print(querys)

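# Preprocessing: generate the ESM-2 embeddings (run these cells before section 2, which loads the embeddings from disk)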

In [None]:
# Install the packages needed for embedding generation
# (fair-esm provides the `esm` module imported below).
!pip install fair-esm
!pip install biopython

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0
Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m58.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.85


In [None]:
import esm
from tqdm.auto import tqdm
import os
import numpy as np
import torch

def gen_emb_from_df(df, sequence_col='sequence', id_col=None, out_dir='esm_embeddings', device='cuda:0'):
    """
    Generate ESM-2 embeddings from sequences stored in a DataFrame.

    One file `<out_dir>/<name>.pt` is written per row, containing
    `{'mean_representations': {33: tensor}}`, where the tensor is the layer-33
    representation averaged over the residues of the sequence.

    NOTE: the average excludes the BOS/EOS special tokens. Embeddings written
    by an earlier version that included them should be regenerated so that
    train and test embeddings stay comparable.

    Args:
        df (pd.DataFrame): DataFrame containing sequences.
        sequence_col (str): Name of the column containing sequences.
        id_col (str or None): Name of the column containing sequence IDs.
                              If None, default IDs `seq_<row position>` are generated.
        out_dir (str): Directory to save embedding files (created if missing).
        device (str): Device to use for inference. Falls back to 'cpu' when a
                      CUDA device is requested but CUDA is not available.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Get sequence IDs: either from a specified column or generate default ones.
    if id_col is None:
        names = [f'seq_{i}' for i in range(len(df))]
    else:
        names = df[id_col].tolist()

    sequences = df[sequence_col].tolist()
    print(f'Number of sequences: {len(sequences)}')
    if not sequences:
        # Nothing to embed: avoid loading the model for no work.
        return

    data = list(zip(names, sequences))

    # The default device is CUDA; on a CPU-only runtime model.to('cuda:0') would raise.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print(f"CUDA is not available; falling back from '{device}' to 'cpu'")
        device = 'cpu'

    # Load ESM-2 model (esm2_t33_650M_UR50D) and batch converter.
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()

    model.to(device)
    model.eval()  # disables dropout for deterministic results

    batch_size = 64  # Adjust if you run out of CUDA memory.
    num_batches = int(np.ceil(len(data) / batch_size))

    for i in tqdm(range(num_batches), desc="Processing batches"):
        batch = data[i * batch_size:(i + 1) * batch_size]
        names_batch = [name for name, _ in batch]
        _, _, batch_tokens = batch_converter(batch)
        # Number of non-padding tokens per sequence (special tokens included).
        batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
        batch_tokens = batch_tokens.to(device)

        # Inference: extract per-residue representations from layer 33.
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[33], return_contacts=False)

        # Get per-residue representations from the specified layer.
        token_representations = results['representations'][33]

        # Generate per-sequence representations via averaging.
        for k, tokens_len in enumerate(batch_lens):
            # Token 0 is the BOS token and token `tokens_len - 1` the EOS token
            # (see the ESM README example); average over the residues only.
            last = int(tokens_len) - 1
            seq_mean = token_representations[k, 1:last].mean(0)
            save = {'mean_representations': {33: seq_mean.cpu()}}
            torch.save(save, os.path.join(out_dir, f'{names_batch[k]}.pt'))

# Generate embeddings for the test and training DataFrames (both need a
# 'sequence' column). No id column is passed, so files are named
# 'seq_<row position>'. Regenerate the training embeddings whenever rows are
# added to df_train, because the names depend on row position.
gen_emb_from_df(df_test, sequence_col='sequence', out_dir='esm_embeddings_test')
gen_emb_from_df(df_train, sequence_col='sequence', out_dir='esm_embeddings_train')


Number of sequences: 11324


Processing batches:   0%|          | 0/177 [00:00<?, ?it/s]

KeyboardInterrupt: 