In [1]:
import esm
import sys, os
import torch
import pandas as pd
import numpy as np
import pickle

### get embedding of wild type sequence

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()  # disables dropout for deterministic results
esm_model = esm_model.to(device)
batch_converter = alphabet.get_batch_converter()

In [2]:
os.chdir('/nfs/user/Users/ch3849/esm_sae/gate_mut/')
df = pd.read_csv('412pros_528kmuts_single_double_ddG_ML.csv')
wt = df[df['mut_type'] == 'wt']

In [10]:
embed = {}

for i in wt.index:
    name, seq = wt.loc[i, ['WT_name', 'aa_seq']]
    data = [(name, seq)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]
    # Generate per-residue representations (on CPU)
    embedding = token_representations.cpu()
    embed[name] = embedding

In [13]:
embed['EA|run2_0325_0005.pdb'].shape

torch.Size([1, 49, 1280])

In [None]:
with open("embedding_412wt.pkl", "wb") as f:
    pickle.dump(embed, f)

### train/test protein split

In [15]:
with open('seq/412wt_seq.fasta', 'w') as f:
    for i in wt.index:
        name, seq = wt.loc[i, ['WT_name', 'aa_seq']]
        f.write(f'>{name}\n{seq}\n')

In [None]:
### add 412 wt cluster info
clu = pd.read_table('seq/clusterRes_cluster.tsv', header=None, names=['cluster', 'WT_name'])
wt = pd.merge(wt, clu, on='WT_name')

In [None]:
### add ESM-2 zero-shot performance
performance = pd.read_csv('/nfs/user/Users/ch3849/ProDance/mutation/stability_cdna/zero_shot/esm2/412pros_528kmuts_esm2_spearmans.csv', index_col=0)
performance = performance[performance['mut_type'] == 'single'][['pro', 'esm_650_wt']]
wt = pd.merge(wt, performance, left_on='WT_name', right_on='pro')

In [None]:
### add UniRef50 homolog count info
uniref_homolog = pd.read_table('/nfs/user/Users/ch3849/ProDance/mutation/stability_cdna/seq_new/mmseq_search_uniref50/rDB.tsv', header=None)
uniref_homolog = uniref_homolog[[0,1,3,5,6,7]]
uniref_homolog.columns = ['name', 'train', 'identity', 'qstart', 'qend', 'qlen']
# the coverage on evaluation sequence
uniref_homolog['coverage'] = (uniref_homolog['qend']-uniref_homolog['qstart']+1) / uniref_homolog['qlen']
# define the threshold
uniref_homolog = uniref_homolog[(uniref_homolog['identity'] >= 0.2) & (uniref_homolog['coverage'] >= 0.5)]

In [42]:
uniref_homolog_count = pd.DataFrame(uniref_homolog['name'].value_counts())

In [44]:
wt = pd.merge(wt, uniref_homolog_count, left_on='WT_name', right_index=True, how='left')
wt['count'] = wt['count'].fillna(0)

In [74]:
import random
seed = 0
while True:
    seed += 1
    random.seed(seed)
    clusters = list(set(wt['cluster']))
    sampled = random.sample(clusters, int(len(clusters)*0.3))

    wt_test = wt[wt['cluster'].isin(sampled)]
    wt_train = wt[~wt['cluster'].isin(sampled)]
    if len(wt_test) / len(wt_train) < 0.45 and len(wt_test) / len(wt_train) > 0.4 and abs(wt_test['esm_650_wt'].median() - wt_train['esm_650_wt'].median()) < 0.01 and \
    len(wt_test[wt_test['count'] == 0]) / len(wt_train[wt_train['count'] == 0]) < 0.45 and len(wt_test[wt_test['count'] == 0]) / len(wt_train[wt_train['count'] == 0]) > 0.4:
        print('good split')
        break

good split


In [77]:
wt['split'] = 'train'
wt.loc[wt['cluster'].isin(sampled), 'split'] = 'test'

In [78]:
wt.to_csv('412wt_info.csv', index=False)

### train/test mutation split

In [96]:
df = df[df['mut_type'] != 'wt']
df['split'] = 'train'
df.loc[df['WT_name'].isin(wt[wt['split'] == 'test']['WT_name']), 'split'] = 'test'
df.loc[df[df['mut_type'].str.contains(':')].index, 'split'] = 'double_test'

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['split'] = 'train'


In [97]:
for i in wt[wt['split'] == 'train']['WT_name']:
    idx = df[df['WT_name'] == i].index
    eval_idx = random.sample(list(idx), int(len(idx)*0.3))
    df.loc[eval_idx, 'split'] = 'eval'

In [102]:
df.to_csv('ddG_ML_split.csv', index=False)

### get training label

In [3]:
wt = pd.read_csv('412wt_info.csv')
df = pd.read_csv('ddG_ML_split.csv')

In [4]:
wt

Unnamed: 0,aa_seq,mut_type,WT_name,WT_cluster,ddG_ML,cluster,pro,esm_650_wt,count,split
0,DEVTIHLGDKTIRVDGLDKELLEILKELARRGADEEELRKEIERWER,wt,EA|run2_0325_0005.pdb,EEHH,-0.136745,EA|run2_0325_0005.pdb,EA|run2_0325_0005.pdb,0.265157,0.0,train
1,KVTIVVENIKVFGEDGKLTDEARRLLEKALEEAKRFPGKEVEIVLLP,wt,r10_490_TrROS_Hall.pdb,hall,0.047965,r10_490_TrROS_Hall.pdb,r10_490_TrROS_Hall.pdb,0.129520,0.0,test
2,TVKFKYKGEEKQVDISKIKKVWRVGKMISFTYDEGGGKTGRGAVSE...,wt,1JIC.pdb,109,-0.014232,1JIC.pdb,1JIC.pdb,0.320662,18.0,test
3,KDPKFEAAYDFPGSGSSSELPLKKGDIVFISRDEPSGWSLAKLLDG...,wt,2BTT.pdb,156,0.084476,2BTT.pdb,2BTT.pdb,0.496714,182.0,train
4,NKFNKELGWATWEIFNLPNLNGVQVKAFIDSLRDDPSQSANLLAEA...,wt,2B88.pdb,71,0.039239,v2_2M5A.pdb,2B88.pdb,0.187662,42.0,test
...,...,...,...,...,...,...,...,...,...,...
407,DHWEIRVGDITIHLKDVDEEIIRWVEEALRNGDDLEEIKRWVEEVLR,wt,XX|run1_0455_0002.pdb,EEHH,0.007841,EA|run2_0325_0005.pdb,XX|run1_0455_0002.pdb,0.219816,0.0,train
408,PEDLERKVRELQKNGVSPEQIEKILRRDGVDEREVQELVKKVS,wt,HHH_rd4_0613.pdb,HHH,-0.004826,HHH_rd4_0124.pdb,HHH_rd4_0613.pdb,0.592606,0.0,test
409,ENVVSAPMPGKVLRVLVRVGDRVRVGQGLLVLEAMKMENEIPSPRD...,wt,5GU9.pdb_L66K,226,0.005261,5GU9.pdb_V60D,5GU9.pdb_L66K,0.519919,1000.0,train
410,KELVSALYDYQEKSPREVTMKKGDILTLLNSTNKDWWKVEVNGRQG...,wt,1BK2.pdb_L5S,15,-0.002634,2KXD.pdb,1BK2.pdb_L5S,0.482433,1000.0,train


In [7]:
mt_aa = alphabet.all_toks[4:24]
label = {}

for i in wt.index:
    name, seq = wt.loc[i, ['WT_name', 'aa_seq']]
    df_train = df[(df['WT_name'] == name) & (df['split'] == 'train')]
    df_eval = df[(df['WT_name'] == name) & (df['split'] == 'eval')]
    df_test = df[(df['WT_name'] == name) & (df['split'] == 'test')]
    
    all_mt = []
    for i, wt_aa in enumerate(seq):
        all_mt += [f"{wt_aa}{i+1}{mt_aa}" for mt_aa in mt_aa]
    
    if len(df_train) > 0:
        label[f'{name}_train'] = torch.tensor(pd.merge(pd.DataFrame(all_mt), df_train, left_on=0, right_on='mut_type', how='left')['ddG_ML'])
    if len(df_eval) > 0:
        label[f'{name}_eval'] = torch.tensor(pd.merge(pd.DataFrame(all_mt), df_eval, left_on=0, right_on='mut_type', how='left')['ddG_ML'])
    if len(df_test) > 0:
        label[f'{name}_test'] = torch.tensor(pd.merge(pd.DataFrame(all_mt), df_test, left_on=0, right_on='mut_type', how='left')['ddG_ML'])

In [10]:
with open("training_label.pkl", "wb") as f:
    pickle.dump(label, f)

In [126]:
len(label.keys())

285