# Compute ESM2-3B Masked-Marginal Scores for α-Amylase (aAmyl)

#### Prepare environment and import modules

In [1]:
# Load packages
# %run setup_environment.py
import badass.models.seq2fitness_models as models
import badass.data.datasets as datasets
import badass.training.seq2fitness_traintools as traintools
import badass.training.seq2fitness_train as train
import badass.utils.sequence_utils as sequence_utils
import torch.nn as nn
from pprint import pprint

import torch
import esm
import pandas as pd

#### Get ESM2-3B masked marginal scores

In [2]:
# Wild-type amino-acid sequence (425 residues, one-letter codes).
# Adjacent string literals let the sequence wrap across lines without
# introducing any newline characters into the value.
seq = (
    "LTAPSIKSGTILHAWNWSFNTLKHNMKDIHDAGYTAIQTSPINQVKEGNQGDKSMSNWYWLYQPTSYQIGNRYLGTEQEFKEMCAAAEEYGIKVIVDAVINHTTSDYAAIS"
    "NEVKSIPNWTHGNTPIKNWSDRWDVTQNSLSGLYDWNTQNTQVQSYLKRFLDRALNDGADGFRFDAAKHIELPDDGSYGSQFWPNITNTSAEFQYGEILQDSVSRDAAYANY"
    "MDVTASNYGHSIRSALKNRNLGVSNISHYAVDVSADKLVTWVESHDTYANDDEESTWMSDDDIRLGWAVIASRSGSTPLFFSRPEGGGNGVRFPGKSQIGDRGSALFEDQAI"
    "TAVNRFHNVMAGQPEELSNPNGNNQIFMNQRGSHGVVLANAGSSSVSINTATKLPDGRYDNKAGAGSFQVNDGKLTGTINARSVAVLYPD"
)

In [11]:
# Sanity-check the wild-type sequence before mutant labels are built from it:
# it must consist only of uppercase ASCII letters. A stray '\r' or space left
# by the multi-line literal would otherwise silently shift every position.
assert seq.isascii() and seq.isalpha() and seq.isupper(), "seq contains non-residue characters"
seq

'LTAPSIKSGTILHAWNWSFNTLKHNMKDIHDAGYTAIQTSPINQVKEGNQGDKSMSNWYWLYQPTSYQIGNRYLGTEQEFKEMCAAAEEYGIKVIVDAVINHTTSDYAAISNEVKSIPNWTHGNTPIKNWSDRWDVTQNSLSGLYDWNTQNTQVQSYLKRFLDRALNDGADGFRFDAAKHIELPDDGSYGSQFWPNITNTSAEFQYGEILQDSVSRDAAYANYMDVTASNYGHSIRSALKNRNLGVSNISHYAVDVSADKLVTWVESHDTYANDDEESTWMSDDDIRLGWAVIASRSGSTPLFFSRPEGGGNGVRFPGKSQIGDRGSALFEDQAITAVNRFHNVMAGQPEELSNPNGNNQIFMNQRGSHGVVLANAGSSSVSINTATKLPDGRYDNKAGAGSFQVNDGKLTGTINARSVAVLYPD'

In [3]:
# Prepare CSV of mutants: every position paired with every amino acid in
# AMINOS (wild-type -> itself included), position by position, AMINOS order.
AMINOS = "ACDEFGHIKLMNPQRSTVWY"
mutants = [
    f"{wt_aa}{pos}{new_aa}"          # e.g. "L1A": wild type, 1-based position, substitute
    for pos, wt_aa in enumerate(seq, start=1)
    for new_aa in AMINOS
]
positions = [pos for pos in range(1, len(seq) + 1) for _ in AMINOS]
df = pd.DataFrame({'mutant': mutants, 'position': positions})
# Written with pandas' default integer index as the first column.
df.to_csv('esm2_scores.csv')


In [2]:
!git clone https://github.com/facebookresearch/esm.git

Cloning into 'esm'...
remote: Enumerating objects: 1511, done.[K
remote: Counting objects: 100% (725/725), done.[K
remote: Compressing objects: 100% (194/194), done.[K
remote: Total 1511 (delta 567), reused 531 (delta 531), pack-reused 786 (from 1)[K
Receiving objects: 100% (1511/1511), 12.87 MiB | 18.10 MiB/s, done.
Resolving deltas: 100% (952/952), done.


In [9]:
cd esm/examples/variant-prediction

/kfs2/projects/proteinml/repos/BADASS/notebooks/esm/examples/variant-prediction


In [12]:
!python predict.py \
    --model-location esm2_t36_3B_UR50D \
    --sequence LTAPSIKSGTILHAWNWSFNTLKHNMKDIHDAGYTAIQTSPINQVKEGNQGDKSMSNWYWLYQPTSYQIGNRYLGTEQEFKEMCAAAEEYGIKVIVDAVINHTTSDYAAISNEVKSIPNWTHGNTPIKNWSDRWDVTQNSLSGLYDWNTQNTQVQSYLKRFLDRALNDGADGFRFDAAKHIELPDDGSYGSQFWPNITNTSAEFQYGEILQDSVSRDAAYANYMDVTASNYGHSIRSALKNRNLGVSNISHYAVDVSADKLVTWVESHDTYANDDEESTWMSDDDIRLGWAVIASRSGSTPLFFSRPEGGGNGVRFPGKSQIGDRGSALFEDQAITAVNRFHNVMAGQPEELSNPNGNNQIFMNQRGSHGVVLANAGSSSVSINTATKLPDGRYDNKAGAGSFQVNDGKLTGTINARSVAVLYPD\
    --dms-input ../../../esm2_scores.csv \
    --mutation-col mutant \
    --dms-output ../../../esm2_scores.csv \
    --offset-idx 1 \
    --scoring-strategy masked-marginals

Transferred model to GPU
100%|█████████████████████████████████████████| 427/427 [01:03<00:00,  6.69it/s]


In [13]:
cd ../../..

/kfs2/projects/proteinml/repos/BADASS/notebooks


In [14]:
# Load the scored mutants (one row per mutant, model score in the last column).
df = pd.read_csv('esm2_scores.csv', index_col=0)
# Drop index artifacts ("Unnamed: 0", "Unnamed: 0.1", ...) left by round-tripping
# a CSV with an index column through read_csv/to_csv. Selecting by name, rather
# than deleting the first column, never removes a real column such as 'mutant'.
df = df.loc[:, ~df.columns.str.startswith('Unnamed')]
df

Unnamed: 0,mutant,position,esm2_t36_3B_UR50D
0,L1A,1,0.153941
1,L1C,1,-4.131199
2,L1D,1,-0.277244
3,L1E,1,0.286405
4,L1F,1,-0.625426
...,...,...,...
8495,D425S,425,1.037780
8496,D425T,425,-1.477381
8497,D425V,425,0.225569
8498,D425W,425,-3.162145


In [15]:
# Scores as 2D dataframe: one row per 1-based position, one column per
# substituted amino acid in AMINOS order, preceded by the wild-type residue.
# Pivoting on (position, substituted amino acid) keeps every score aligned with
# its row and column. The previous approach iterated an unordered set of
# positions and relied on row order for both the WT column and the AMINOS columns.
score_col = df.columns[-1]  # the last column holds the model scores
matrix = (
    df.assign(mut_aa=df['mutant'].str[-1])
      .pivot(index='position', columns='mut_aa', values=score_col)
      .reindex(index=range(1, len(seq) + 1), columns=list(AMINOS))
      .rename_axis(index=None, columns=None)
)
matrix.insert(0, 'WT', list(seq))
df = matrix  # downstream code refers to the matrix as `df`

# Save
df.to_excel('../data/aAmyl_esm2_3B_scores.xlsx', sheet_name='matrix')
df

Unnamed: 0,WT,A,C,D,E,F,G,H,I,K,...,M,N,P,Q,R,S,T,V,W,Y
1,L,0.153941,-4.131199,-0.277244,0.286405,-0.625426,-0.033034,-0.125943,-0.984842,-0.751996,...,4.004670,-0.003591,0.649219,-0.154374,-0.703769,0.360060,0.146645,-0.080936,-0.778397,0.110521
2,T,1.691495,-3.897074,-1.062125,-0.307547,-3.135472,-0.257015,-1.129305,-1.396681,-0.653588,...,-2.639576,-0.689998,1.153053,0.393923,-0.732564,0.029048,0.000000,0.535200,-4.401953,-2.143408
3,A,0.000000,-5.661877,-1.283266,-1.134010,-4.272923,-1.844071,-1.146866,-3.514070,-0.660671,...,-4.143487,-0.790061,-1.725499,0.168853,-0.931946,-0.672568,-1.125988,-2.170161,-5.925178,-3.358454
4,P,-0.183073,-5.624362,0.061390,-0.663584,-4.701787,-0.203064,-0.983418,-4.472972,1.077263,...,-3.928177,0.036490,0.000000,0.414096,0.141788,-0.206762,-1.061157,-3.030909,-5.927115,-3.515624
5,S,-1.215753,-7.251094,0.166699,-0.498575,-5.963136,-2.294290,-1.173683,-5.582670,-1.524374,...,-5.736083,0.085196,-2.411382,-0.889735,-2.073480,0.000000,-0.746483,-3.793533,-7.740762,-5.008163
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
421,V,-4.058813,-7.332922,-10.182964,-8.513586,-7.020294,-9.237148,-10.045355,-2.623017,-9.835143,...,-4.413274,-8.677698,-11.158614,-9.796848,-9.171410,-6.073474,-3.782107,0.000000,-9.674072,-7.398786
422,L,-5.637181,-5.981638,-11.032974,-11.042208,0.813973,-8.571558,-9.702011,-1.514849,-10.643763,...,-5.068220,-10.230425,-9.769588,-11.391006,-9.471382,-8.468260,-7.170893,-1.628333,-5.842010,-3.588126
423,Y,-2.280747,-5.458008,-4.039402,-3.053553,-2.122924,-4.869119,-2.420832,-2.781030,-3.895570,...,-3.336308,-2.942133,-6.395665,-3.059319,-3.708502,-2.237686,-2.470805,-1.565651,-4.033664,0.000000
424,P,2.208050,-3.421869,-0.386514,0.409701,-0.683241,0.568616,0.360910,-0.186933,2.258799,...,-0.606777,2.171574,0.000000,1.554703,1.952792,1.855488,1.344000,0.824582,-2.391808,-0.030002
