This notebook reproduces the test results of ESM-IF in the [PPIformer paper](https://arxiv.org/pdf/2310.18515.pdf).

In [1]:
!pip install git+https://github.com/anton-bushuiev/mutils.git

Collecting git+https://github.com/anton-bushuiev/mutils.git
  Cloning https://github.com/anton-bushuiev/mutils.git to /private/var/folders/yw/q5k8tqgn3tq8_y9lbqhwm_8m0000gn/T/pip-req-build-d67l1p26
  Running command git clone --filter=blob:none --quiet https://github.com/anton-bushuiev/mutils.git /private/var/folders/yw/q5k8tqgn3tq8_y9lbqhwm_8m0000gn/T/pip-req-build-d67l1p26
  Resolved https://github.com/anton-bushuiev/mutils.git to commit e0d0d7dfcfe00f03a516d720eb0329226769da85
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: mutils
  Building wheel for mutils (setup.py) ... [?25ldone
[?25h  Created wheel for mutils: filename=mutils-0.1-py3-none-any.whl size=177679 sha256=c4ea2fda54f6152398456edc347193b7f6b820962762f3660475c390531929cb
  Stored in directory: /private/var/folders/yw/q5k8tqgn3tq8_y9lbqhwm_8m0000gn/T/pip-ephem-wheel-cache-btj_axnp/wheels/a5/7d/39/2785c9e54d28c296cac8cceaf579728977e8b7538b9ece023c
Successfully built mutils
Installi

In [8]:
import copy
from math import sqrt

import esm
import biotite
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, roc_auc_score
 
from mutils.data import load_SKEMPI2
from mutils.pdb import get_sequences
from mutils.definitions import MUTILS_SKEMPI2_DIR

tqdm.pandas()

In [2]:
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()



In [3]:
def predict_ddg(esm_model, pdb_path, mutation, chain_ids=None, chain_offsets=None):
    """
    See Appendix B in https://arxiv.org/pdf/2310.18515.pdf
    """
    # Load structure and wild-type sequences
    pdb_path = str(pdb_path)
    if chain_ids is None:
        chain_ids = list(get_sequences(pdb_path).keys())
    structure = esm.inverse_folding.util.load_structure(pdb_path, chain_ids)
    structure = biotite.structure.array([atom for atom in structure if not atom.hetero])
    coords, native_seqs = esm.inverse_folding.multichain_util.extract_coords_from_complex(structure)

    # Create mutant sequences
    mutated_seqs = copy.deepcopy(native_seqs)
    for point_mut in mutation.split(','):
        wt, chain, pos, mut = point_mut[0], point_mut[1], int(point_mut[2:-1]), point_mut[-1]
        pos -= 1  # 0-based indexing
        if chain_offsets is not None and chain in chain_offsets:
            pos -= chain_offsets[chain]
        seq_wt = native_seqs[chain]
        # assert seq_wt[pos] == wt, f'Wild-type sequence does not match the provided mutation: {seq_wt[pos]} != {wt}'
        if seq_wt[pos] != wt:
            print(f'Wild-type sequence does not match the provided mutation {mutation}: {seq_wt[pos]} != {wt}')
        seq_mut = mutated_seqs[chain]
        mutated_seqs[chain] = seq_mut[:pos] + mut + seq_mut[pos+1:]

    # Calculate average log likelihood for wild-type and mutant complexes
    ll_wt, ll_mut = [], []
    for chain in native_seqs.keys():
        seq_wt = native_seqs[chain]
        seq_mut = mutated_seqs[chain]
        if seq_wt != seq_mut:
            ll_wt_chain, _ = esm.inverse_folding.multichain_util.score_sequence_in_complex(
                esm_model, alphabet, coords, chain, seq_wt)
            ll_mut_chain, _ = esm.inverse_folding.multichain_util.score_sequence_in_complex(
                esm_model, alphabet, coords, chain, seq_mut)
            ll_wt.append(ll_wt_chain)
            ll_mut.append(ll_mut_chain)
    ll_wt = np.mean(ll_wt)
    ll_wt = ll_wt.round(2)
    ll_mut = np.mean(ll_mut)

    # Calculate predicted ddG
    ddg_pred = ll_wt - ll_mut
    return ddg_pred
    

# To have SKEMPI2 .pdb files in `MUTILS_SKEMPI2_DIR / 'PDBs'` directory,
# clone mutils from github (https://github.com/anton-bushuiev/mutils) and install in editable mode (pip install -e mutils).
# Alternatively, download the files from https://life.bsc.es/pid/skempi2/database/index
# predict_ddg(model, MUTILS_SKEMPI2_DIR / 'PDBs' / '1C4Z.pdb', 'ED90R')
# predict_ddg(model, MUTILS_SKEMPI2_DIR / 'PDBs' / '1KNE.pdb', 'TP2K')
# predict_ddg(model, MUTILS_SKEMPI2_DIR / 'PDBs' / '1KNE.pdb', 'DA40T,TP2L')
predict_ddg(model, MUTILS_SKEMPI2_DIR / 'PDBs' / '2NOJ.pdb', 'RB24A,NB31A')



0.17954417787749177

# SKEMPI v2.0 test set

In [4]:
# Read ids for test complexes
df_ppiformer = pd.read_csv('https://raw.githubusercontent.com/anton-bushuiev/mutils/main/mutils/datasets/SKEMPI2/predictions_test/ppiformer.csv')
test_complexes = df_ppiformer['complex'].unique()
test_complexes

array(['1KNE_A_P', '1C4Z_ABC_D', '5CXB_A_B', '5CYK_A_B', '1BRS_A_D',
       '1B2U_A_D', '1B2S_A_D', '1B3S_A_D', '1X1W_A_D', '1X1X_A_D',
       '2GOX_A_B', '3D5S_A_C', '3D5R_A_C', '2NOJ_A_B'], dtype=object)

In [5]:
# Read dataframe for SKEMPI2 test set
from mutils.data import load_SKEMPI2
df_s2 = load_SKEMPI2()[0]
df_s2_test = df_s2[df_s2['#Pdb'].isin(test_complexes)]
df_s2_test

Unnamed: 0,#Pdb,Mutation(s)_PDB,Mutation(s)_cleaned,iMutation_Location(s),Hold_out_type,Hold_out_proteins,Affinity_mut (M),Affinity_mut_parsed,Affinity_wt (M),Affinity_wt_parsed,...,dS_wt (cal mol^(-1) K^(-1)),Notes,Method,SKEMPI version,dG_mut,dG_wt,ddG,PDB Id,Partner 1,Partner 2
104,1BRS_A_D,KA27A,KA25A,COR,Other,"1BRS_A_D,1B2U_A_D,1B2S_A_D,1B3S_A_D,1X1W_A_D,1...",8.800000e-11,8.800000e-11,1.000000e-14,1.000000e-14,...,-1.01,"Thermodynamic data from 9126847.,,",ITC,1,-13.717446,-19.098395,5.380949,1BRS,A,D
105,1BRS_A_D,RA59A,RA57A,COR,Other,"1BRS_A_D,1B2U_A_D,1B2S_A_D,1B3S_A_D,1X1W_A_D,1...",7.000000e-11,7.000000e-11,1.000000e-14,1.000000e-14,...,-1.01,"Thermodynamic data from 9126847.,,",ITC,1,-13.853024,-19.098395,5.245372,1BRS,A,D
106,1BRS_A_D,RA83Q,RA81Q,COR,Other,"1BRS_A_D,1B2U_A_D,1B2S_A_D,1B3S_A_D,1X1W_A_D,1...",9.400000e-11,9.400000e-11,1.000000e-14,1.000000e-14,...,,,SFFL,1,-13.678369,-19.098395,5.420026,1BRS,A,D
107,1BRS_A_D,RA87A,RA85A,SUP,Other,"1BRS_A_D,1B2U_A_D,1B2S_A_D,1B3S_A_D,1X1W_A_D,1...",1.200000e-10,1.200000e-10,1.000000e-14,1.000000e-14,...,-1.01,"Thermodynamic data from 9126847.,,",ITC,1,-13.533694,-19.098395,5.564701,1BRS,A,D
108,1BRS_A_D,HA102A,HA100A,COR,Other,"1BRS_A_D,1B2U_A_D,1B2S_A_D,1B3S_A_D,1X1W_A_D,1...",3.200000e-10,3.200000e-10,1.000000e-14,1.000000e-14,...,-1.01,"Thermodynamic data from 9126847.,,",ITC,1,-12.952600,-19.098395,6.145795,1BRS,A,D
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6241,5CYK_A_B,EB486R,EB52R,COR,Other,"5CXB_A_B,5CYK_A_B",2.700000e-09,2.700000e-09,2.490000e-07,2.490000e-07,...,,Crystal structure is one of the mutants in the...,BI,2,-11.689086,-9.008714,-2.680372,5CYK,A,B
6242,5CYK_A_B,"EB486R,EB481D","EB52R,EB47D","COR,COR",Other,"5CXB_A_B,5CYK_A_B",5.000000e-09,5.000000e-09,2.490000e-07,2.490000e-07,...,,Crystal structure is one of the mutants in the...,BI,2,-11.324025,-9.008714,-2.315311,5CYK,A,B
6243,5CYK_A_B,"EB486R,TB484Q","EB52R,TB50Q","COR,COR",Other,"5CXB_A_B,5CYK_A_B",3.000000e-09,3.000000e-09,2.490000e-07,2.490000e-07,...,,Crystal structure is one of the mutants in the...,BI,2,-11.626665,-9.008714,-2.617951,5CYK,A,B
6244,5CYK_A_B,EB486A,EB52A,COR,Other,"5CXB_A_B,5CYK_A_B",7.000000e-09,7.000000e-09,2.490000e-07,2.490000e-07,...,,Crystal structure is one of the mutants in the...,BI,2,-11.124682,-9.008714,-2.115968,5CYK,A,B


In [6]:
# To have SKEMPI2 .pdb files in `MUTILS_SKEMPI2_DIR / 'PDBs'` directory,
# clone mutils from github (https://github.com/anton-bushuiev/mutils) and install in editable mode (pip install -e mutils).
# Alternatively, download the files from https://life.bsc.es/pid/skempi2/database/index
df_s2_test['ddG_pred'] = df_s2_test.progress_apply(
    lambda row: predict_ddg(
        esm_model=model,
        pdb_path=MUTILS_SKEMPI2_DIR / 'PDBs' / f'{row["PDB Id"]}.pdb',
        mutation=row['Mutation(s)_cleaned'],
        chain_ids=list(''.join(row['#Pdb'].split('_')[1:]))
    ),
    axis=1
)

100%|██████████| 219/219 [33:09<00:00,  9.08s/it]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_s2_test['ddG_pred'] = df_s2_test.progress_apply(


In [21]:
def metrics_from_df(df):
    pred = df['ddG_pred'] < 0
    real = df['ddG'] < 0
    return {
        'Spearman': df['ddG'].corr(df['ddG_pred'], method='spearman'),
        'Pearson': df['ddG'].corr(df['ddG_pred'], method='pearson'),
        'Precision': precision_score(real, pred),
        'Recall': recall_score(real, pred),
        'ROC AUC': roc_auc_score(real, -df['ddG_pred']) if len(df) else np.nan,
        'MAE': (df['ddG'] - df['ddG_pred']).abs().mean(),
        'RMSE': sqrt((df['ddG'] - df['ddG_pred']).pow(2).mean())
    }


res = df_s2_test.groupby('Protein 1').apply(metrics_from_df)
df_s2_test_agg = pd.DataFrame(res.tolist())
df_s2_test_agg = df_s2_test_agg.set_index(res.index)
display(df_s2_test_agg[['Spearman', 'Precision', 'Recall']])

df_s2_test_agg.mean()

  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0_level_0,Spearman,Precision,Recall
Protein 1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Barnase,0.176488,0.411765,0.5
C. thermophilum YTM1,0.085367,0.0,0.0
Complement C3d,0.339032,0.5,0.5
E6AP,0.21321,0.307692,0.333333
dHP1 Chromodomain,0.104578,0.454545,0.714286


Spearman     0.183735
Pearson      0.175749
Precision    0.334800
Recall       0.409524
ROC AUC      0.680099
MAE          1.868550
RMSE         2.151956
dtype: float64

# SAK (staphylokinase) test set

In [27]:
df_sak = pd.read_csv('./ppiformer_test_data/sak.csv')
df_sak

Unnamed: 0,#Pdb,Mutation(s),Activity,Activity enhancement,2x activity enhancement
0,1BUI_A_B_C,"EC38A,EC75A",66,False,False
1,1BUI_A_B_C,KC74A,100,False,False
2,1BUI_A_B_C,EC75A,140,True,False
3,1BUI_A_B_C,SC16A,160,True,False
4,1BUI_A_B_C,"YC17A,FC18A",30,False,False
...,...,...,...,...,...
75,1BUI_A_B_C,"GC36R,HC43R,KC74R,KC130T,KC135R",160,True,False
76,1BUI_A_B_C,"KC74Q,KC86A,KC130T,KC135R",130,False,False
77,1BUI_A_B_C,"KC74Q,KC130A,KC135R",240,True,False
78,1BUI_A_B_C,"KC74Q,KC130E,KC135R",300,True,True


In [28]:
df_sak['ddG_pred'] = df_sak.progress_apply(
    lambda row: predict_ddg(
        esm_model=model,
        pdb_path='./ppiformer_test_data/1BUI_A_B_C.pdb',
        mutation=row['Mutation(s)'],
        chain_ids=['A', 'B', 'C'],
        chain_offsets={'C': 14}  # C starts with 15 in 1bui
    ),
    axis=1
)

100%|██████████| 80/80 [15:02<00:00, 11.28s/it]


In [29]:
df_sak['rank'] = df_sak['ddG_pred'].rank() / len(df_sak)
df_sak[df_sak['2x activity enhancement']][['Mutation(s)', 'rank']]

Unnamed: 0,Mutation(s),rank
42,KC130A,0.45
44,KC135A,0.25
65,KC130T,0.3375
69,"KC130T,KC135R",0.4625
71,"KC74R,KC130T,KC135R",0.5875
78,"KC74Q,KC130E,KC135R",0.425


# COVID benchmark from (Shan 2022)

In [30]:
df_cov = pd.read_csv('./ppiformer_test_data/shan2022_sars-cov-2.csv')
df_cov

Unnamed: 0,#Pdb,Mutation(s),Stabilizing
0,7FAE-RBD-Fv_A_H_L,AH53C,False
1,7FAE-RBD-Fv_A_H_L,AH53D,False
2,7FAE-RBD-Fv_A_H_L,AH53E,False
3,7FAE-RBD-Fv_A_H_L,AH53F,True
4,7FAE-RBD-Fv_A_H_L,AH53G,False
...,...,...,...
489,7FAE-RBD-Fv_A_H_L,YH32R,False
490,7FAE-RBD-Fv_A_H_L,YH32S,False
491,7FAE-RBD-Fv_A_H_L,YH32T,False
492,7FAE-RBD-Fv_A_H_L,YH32V,False


In [33]:
df_cov['ddG_pred'] = df_cov.progress_apply(
    lambda row: predict_ddg(
        esm_model=model,
        pdb_path='./ppiformer_test_data/7FAE_RBD_Fv.pdb',
        mutation=row['Mutation(s)'],
        chain_ids=['A', 'H', 'L']
    ),
    axis=1
)

100%|██████████| 494/494 [50:26<00:00,  6.13s/it]


In [34]:
df_cov['rank'] = df_cov['ddG_pred'].rank() / len(df_cov)
df_cov[df_cov['Stabilizing']][['Mutation(s)', 'rank']]

Unnamed: 0,Mutation(s),rank
3,AH53F,0.176113
232,LH104F,0.48583
294,NH57L,0.17004
352,RH103M,0.51417
416,TH31W,0.493927
