In [23]:
import esm
import sys, os
import torch
import pandas as pd
import numpy as np
import pickle
import random

### Get ESM-2 embeddings of the wild-type sequences

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the pretrained ESM-2 650M checkpoint (33 layers), put it in evaluation
# mode (no dropout), move it to `device`, and build the tokenizer for batches.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model = esm_model.eval().to(device)
batch_converter = alphabet.get_batch_converter()

In [17]:
# Protein metadata table; later cells read its WT_name, aa_seq and split columns.
pro = pd.read_csv(filepath_or_buffer='../data/412pro_info.csv')

In [19]:
# Embed each wild-type sequence on its own (batch of one) and store the
# per-token representations from layer 33 (presumably the final layer of the
# t33 model) on the CPU, keyed by WT_name. Each stored tensor keeps the batch
# dimension exactly as returned by the model.
embed = {}
layer = 33

for name, seq in pro[['WT_name', 'aa_seq']].itertuples(index=False):
    tokens = batch_converter([(name, seq)])[2]
    with torch.no_grad():
        out = esm_model(tokens.to(device), repr_layers=[layer], return_contacts=False)
    embed[name] = out["representations"][layer].cpu()

In [20]:
# Persist the per-protein embeddings; Pickler(...).dump(...) is exactly what
# pickle.dump does under the hood, with the default protocol.
with open("../data/embedding_412pro.pkl", mode="wb") as handle:
    pickle.Pickler(handle).dump(embed)

### Train/validation/test split of mutations

In [None]:
# Full mutation table — please download it from Zenodo first:
# https://zenodo.org/records/17488191
df = pd.read_csv('/path/to/412pros_ddG_ML.csv')

# Entries whose mut_type contains ':' are presumably multi-point mutants;
# drop them so only single mutations remain, then renumber the rows 0..n-1.
is_multi = df['mut_type'].str.contains(':')
df = df.loc[~is_multi].reset_index(drop=True)

In [27]:
# Drop the 'wt' reference rows and label every remaining mutation 'train' by
# default; mutations of proteins that `pro` marks as 'test' move to the test split.
# `.assign` builds a new frame instead of writing a column into a boolean-filtered
# slice, so pandas cannot emit SettingWithCopyWarning here.
df = df[df['mut_type'] != 'wt'].assign(split='train')
test_proteins = pro.loc[pro['split'] == 'test', 'WT_name']
df.loc[df['WT_name'].isin(test_proteins), 'split'] = 'test'

In [28]:
# Hold out EVAL_FRACTION of each training protein's mutations as the validation
# ('eval') split.
# - A dedicated, seeded RNG makes the split reproducible across kernel restarts
#   (the unseeded global `random` produced a different split on every run).
# - Each protein's rows are reset to 'train' before sampling so the cell is
#   idempotent: re-running it no longer keeps moving extra mutations into 'eval'.
EVAL_FRACTION = 0.3
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)

for wt_name in pro.loc[pro['split'] == 'train', 'WT_name']:
    idx = df.index[df['WT_name'] == wt_name]
    df.loc[idx, 'split'] = 'train'
    eval_idx = split_rng.sample(list(idx), int(len(idx) * EVAL_FRACTION))
    df.loc[eval_idx, 'split'] = 'eval'

### Build training/validation/test labels

In [29]:
# For every protein, enumerate all single substitutions of its wild-type sequence
# in a fixed order (position 1 x each token in mt_aa, position 2 x each token, ...)
# and, for each split that has data, build a label tensor aligned with that
# enumeration: entry k is the ddG_ML of the k-th enumerated mutation, or NaN when
# the split has no value for it. Keys are f"{WT_name}_{split}".
# NOTE(review): all_toks[4:24] is expected to be the 20 standard amino-acid tokens
# of the ESM alphabet (after its 4 leading special tokens) — confirm if the model changes.
mt_aa = alphabet.all_toks[4:24]
label = {}


def _split_labels(all_mt, df_split):
    """Return df_split's ddG_ML values reindexed onto the mutation list `all_mt`.

    Args:
        all_mt: ordered list of mutation strings such as "A12G".
        df_split: rows of one protein and one split, with 'mut_type' and 'ddG_ML'.

    Returns:
        A float64 tensor of length len(all_mt); NaN where a mutation is absent.

    Raises:
        ValueError: if a mut_type occurs more than once in df_split. A left merge
            would then emit extra rows and silently shift every later label off
            its mutation position.
    """
    dup_mask = df_split['mut_type'].duplicated()
    if dup_mask.any():
        examples = list(df_split.loc[dup_mask, 'mut_type'].unique()[:5])
        raise ValueError(f"duplicate mut_type rows would misalign labels, e.g. {examples}")
    ddg = df_split.set_index('mut_type')['ddG_ML'].reindex(all_mt)
    return torch.tensor(ddg.to_numpy(dtype='float64'))


# Group once instead of re-filtering the whole table three times per protein.
df_by_protein = dict(tuple(df.groupby('WT_name')))

for name, seq in pro[['WT_name', 'aa_seq']].itertuples(index=False):
    df_prot = df_by_protein.get(name)
    if df_prot is None:
        continue  # protein has no mutations in any split
    all_mt = [f"{wt_aa}{pos}{aa}" for pos, wt_aa in enumerate(seq, start=1) for aa in mt_aa]
    for split in ('train', 'eval', 'test'):
        df_split = df_prot[df_prot['split'] == split]
        if len(df_split) > 0:
            label[f'{name}_{split}'] = _split_labels(all_mt, df_split)

In [32]:
# Persist the per-protein/per-split label tensors; Pickler(...).dump(...) is
# exactly what pickle.dump does, with the default protocol.
with open("../data/mutation_label.pkl", mode="wb") as handle:
    pickle.Pickler(handle).dump(label)