In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# project modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import math
import json
from pathlib import Path
from tqdm import tqdm

import torch
import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score, roc_auc_score, average_precision_score

from mutils.definitions import MUTILS_DATA_DIR
from mutils.data import load_SKEMPI2
from ppiref.split import read_fold
from ppiref.utils.ppipath import path_to_pdb_id, path_to_ppi_id
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
from ppiformer.tasks.node import DDGPPIformer
from ppiformer.utils.api import predict_ddg

## Load models from checkpoint

In [3]:
# Name inserted into the `results_{...}.csv` file names written by the prediction cells below
TEST_MODEL_NAME = 'ppiformer'

# Directory with the ddG regression checkpoints and the device they are mapped to on load
checkpoints_dir = PPIFORMER_WEIGHTS_DIR / 'ddg_regression'
checkpoint_paths = list(checkpoints_dir.glob('*.ckpt'))  # collected, but not used for loading below
device = 'cpu'


def _load_ddg_model(ckpt_index):
    """Load checkpoint `<ckpt_index>.ckpt` onto `device` and return the result of calling `.eval()` on it."""
    ckpt_path = PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{ckpt_index}.ckpt'
    loaded = DDGPPIformer.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
    return loaded.eval()


# The three checkpoints 0.ckpt, 1.ckpt and 2.ckpt are passed together to `predict_ddg` below
models = [_load_ddg_model(ckpt_index) for ckpt_index in range(3)]

  rank_zero_warn(
  rank_zero_warn(


## SKEMPI v2.0 test set

### Predict

In [4]:
# PPI files of the test fold and the SKEMPI v2.0 table holding the experimental ddG values
ppis_paths_test = read_fold('skempi2_iclr24_split', 'test')
df_ddg = load_SKEMPI2()[0]

records = []
for ppi_path in tqdm(ppis_paths_test, desc='Making predictions for all PPIs'):
    # Select every table row that belongs to the PDB entry of this PPI file
    pdb_id = path_to_pdb_id(ppi_path)
    df_ppi = df_ddg[df_ddg['PDB Id'] == pdb_id]
    if df_ppi.empty:
        # Fail with a readable message instead of the IndexError `.iloc[0]` would raise
        raise ValueError(f'No SKEMPI2 entries found for PDB id {pdb_id!r} (PPI file: {ppi_path})')

    # The complex identifier is taken from the table itself ('#Pdb' of the first matching row),
    # so it matches the identifiers used by the benchmark cell below
    ppi_id = df_ppi['#Pdb'].iloc[0]
    muts = df_ppi['Mutation(s)_cleaned'].to_list()
    ddg = df_ppi['ddG'].to_list()
    ddg_pred = predict_ddg(models, ppi_path, muts, impute=True)

    # `zip` would silently drop rows on a length mismatch, so check explicitly
    if len(ddg_pred) != len(muts):
        raise ValueError(
            f'{ppi_id}: got {len(ddg_pred)} predictions for {len(muts)} mutations'
        )

    records.extend(
        {'complex': ppi_id, 'mutstr': m, 'ddG': d, 'ddG_pred': d_pred.item()}
        for m, d, d_pred in zip(muts, ddg, ddg_pred)
    )

# One row per (complex, mutation); stored alongside the predictions of the other methods
df_test = pd.DataFrame(records)
df_test.to_csv(MUTILS_DATA_DIR / f'SKEMPI2/predictions_test/results_{TEST_MODEL_NAME}.csv', index=False)

Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 11.00it/s]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=3)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 11.70it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=3)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 10.55it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=94)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 11.85it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=1)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00,  2.57it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=47)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 11.54it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=3)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 11.01it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=2)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 11.61it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=1)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 10.51it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=2)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00,  6.72it/s]/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=5)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 21.71it/s]s/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=46)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 12.15it/s]s/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=1)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00,  7.85it/s]s/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=5)


Process 71138 preparing data: 100%|██████████| 1/1 [00:00<00:00, 10.55it/s]s/it]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=4)


Making predictions for all PPIs: 100%|██████████| 14/14 [05:01<00:00, 21.50s/it]


### Benchmark against other methods

In [10]:
# Value that `test_skempi` substitutes for missing (NaN) entries in each method's prediction table.
# NOTE(review): the origin of this number is not visible here — presumably a mean ddG; confirm.
IMPUTE_VAL = 0.691834179286864


def test_skempi(pred_dir: Path):
    """Benchmark every method that has a prediction CSV in `pred_dir` on the SKEMPI v2.0 test set.

    Each `*.csv` file in `pred_dir` is treated as one method (named after the file stem) and
    must provide the columns 'complex', 'ddG' and 'ddG_pred'. Missing values are replaced by
    `IMPUTE_VAL`. Metrics are computed separately for every 'Protein 1' group listed in
    `SKEMPI2/test_protein1_to_pdbs.json`; a per-group table (Spearman, Precision, Recall) is
    displayed for each group, followed by the per-method mean of all metrics over the groups.

    Mutations with ddG < 0 form the positive class for Precision / Recall / ROC AUC / PR AUC;
    the negated prediction is used as the ranking score for the two AUC metrics.

    Args:
        pred_dir: Directory containing one prediction CSV per method.

    Returns:
        None. Results are shown with `print` and `display`.
    """
    # Read test 'Protein 1' -> PDB codes mapping
    with open(MUTILS_DATA_DIR / 'SKEMPI2/test_protein1_to_pdbs.json') as file:
        p1_to_pdbs = json.load(file)

    # Read each method's predictions once rather than once per protein group.
    # Paths are sorted so the row order of the per-group tables is deterministic.
    method_dfs = {
        path.stem: pd.read_csv(path).fillna(IMPUTE_VAL)
        for path in sorted(pred_dir.glob('*.csv'))
    }

    # Calculate per-PPI performance for all methods
    all_dfs_ppi = []
    for p1, pdbs in p1_to_pdbs.items():
        rows = []
        for name, df_method in method_dfs.items():
            # Keep only the predictions for the complexes of this protein group
            df = df_method[df_method['complex'].isin(pdbs)]

            pred = df['ddG_pred'] < 0
            real = df['ddG'] < 0
            error = df['ddG'] - df['ddG_pred']
            # The AUC metrics are undefined without rows or with a single class present
            has_both_classes = len(df) > 0 and real.nunique() > 1

            rows.append({
                'Method': name,
                'Spearman': df['ddG'].corr(df['ddG_pred'], method='spearman'),
                'Pearson': df['ddG'].corr(df['ddG_pred'], method='pearson'),
                'Precision': precision_score(real, pred, zero_division=0),
                'Recall': recall_score(real, pred, zero_division=0),
                'ROC AUC': roc_auc_score(real, -df['ddG_pred']) if has_both_classes else np.nan,
                'PR AUC': average_precision_score(real, -df['ddG_pred']) if has_both_classes else np.nan,
                'MAE': error.abs().mean(),
                'RMSE': math.sqrt(error.pow(2).mean())
            })

        # Print PPI performance
        df_ppi = pd.DataFrame(rows).set_index('Method')
        print(f'Protein 1: {p1} ({pdbs})')
        display(df_ppi[['Spearman', 'Precision', 'Recall']])
        all_dfs_ppi.append(df_ppi)

    # Print overall performance: per-group metrics are rounded to two decimals and then
    # averaged per method over the protein groups
    print('Overall')
    display(pd.concat(all_dfs_ppi).round(2).reset_index().groupby(by='Method').mean())


# Run the SKEMPI v2.0 benchmark over all prediction CSVs found in the predictions directory
test_skempi(MUTILS_DATA_DIR / 'SKEMPI2/predictions_test')

Protein 1: Barnase (['1X1W_A_D', '1X1X_A_D', '1B3S_A_D', '1B2U_A_D', '1B2S_A_D', '1BRS_A_D'])


Unnamed: 0_level_0,Spearman,Precision,Recall
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
MSA_Transformer,0.429936,0.875,0.5
ppiformer_seed_3,0.630663,0.434783,0.714286
ppiformer_seed_2,0.598698,0.454545,0.714286
ppiformer,0.600988,0.384615,0.714286
ppiformer_seed_1,0.601036,0.384615,0.714286
results_Flex_ddG,0.824193,0.428571,0.428571
gemme,0.39592,1.0,0.642857
results_RDE,0.578092,0.421053,0.571429
ESM-IF,0.176488,0.411765,0.5
results_ppiformer,0.600988,0.384615,0.714286


Protein 1: C. thermophilum YTM1 (['5CXB_A_B', '5CYK_A_B'])


Unnamed: 0_level_0,Spearman,Precision,Recall
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
MSA_Transformer,0.054879,0.666667,0.4
ppiformer_seed_3,0.091465,0.666667,0.4
ppiformer_seed_2,0.548791,0.75,0.6
ppiformer,0.34147,0.6,0.6
ppiformer_seed_1,0.34147,0.6,0.6
results_Flex_ddG,0.975628,1.0,1.0
gemme,0.792698,1.0,0.8
results_RDE,0.146344,0.5,0.4
ESM-IF,0.085367,0.0,0.0
results_ppiformer,0.34147,0.6,0.6


Protein 1: Complement C3d (['2NOJ_A_B', '2GOX_A_B', '3D5R_A_C', '3D5S_A_C'])


Unnamed: 0_level_0,Spearman,Precision,Recall
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
MSA_Transformer,0.610257,1.0,1.0
ppiformer_seed_3,0.74587,1.0,1.0
ppiformer_seed_2,0.678064,1.0,1.0
ppiformer,0.74587,1.0,1.0
ppiformer_seed_1,0.74587,1.0,1.0
results_Flex_ddG,0.678064,1.0,0.5
gemme,0.610257,1.0,1.0
results_RDE,0.678064,1.0,1.0
ESM-IF,0.339032,0.5,0.5
results_ppiformer,0.74587,1.0,1.0


Protein 1: E6AP (['1C4Z_ABC_D'])


Unnamed: 0_level_0,Spearman,Precision,Recall
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
MSA_Transformer,0.371497,0.0,0.0
ppiformer_seed_3,0.588346,0.5,0.25
ppiformer_seed_2,0.514331,0.4,0.166667
ppiformer,0.426176,0.4,0.166667
ppiformer_seed_1,0.426176,0.4,0.166667
results_Flex_ddG,0.294217,0.444444,0.333333
gemme,0.19902,0.0,0.0
results_RDE,0.209229,0.5,0.333333
ESM-IF,0.21321,0.307692,0.333333
results_ppiformer,0.426176,0.4,0.166667


Protein 1: dHP1 Chromodomain  (['1KNE_A_P'])


Unnamed: 0_level_0,Spearman,Precision,Recall
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
MSA_Transformer,0.08786,0.0,0.0
ppiformer_seed_3,0.060587,0.45,0.642857
ppiformer_seed_2,0.033379,0.409091,0.642857
ppiformer,-0.004566,0.533333,0.571429
ppiformer_seed_1,-0.004566,0.533333,0.571429
results_Flex_ddG,-0.045163,0.292683,0.857143
gemme,-0.100384,0.0,0.0
results_RDE,-0.395607,0.302326,0.928571
ESM-IF,0.104578,0.454545,0.714286
results_ppiformer,-0.004566,0.533333,0.571429


Overall


Unnamed: 0_level_0,Spearman,Pearson,Precision,Recall,ROC AUC,PR AUC,MAE,RMSE
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
ESM-IF,0.184,0.176,0.334,0.408,0.68,0.526,1.868,2.152
MSA_Transformer,0.31,0.364,0.51,0.38,0.702,0.6,6.128,6.93
gemme,0.38,0.41,0.6,0.488,0.744,0.664,2.162,2.808
ppiformer,0.424,0.456,0.582,0.61,0.768,0.642,1.644,1.936
ppiformer_seed_1,0.424,0.456,0.582,0.61,0.768,0.642,1.644,1.936
ppiformer_seed_2,0.474,0.494,0.602,0.624,0.802,0.702,1.624,1.936
ppiformer_seed_3,0.424,0.436,0.61,0.6,0.774,0.652,1.64,1.944
results_Flex_ddG,0.544,0.574,0.632,0.624,0.838,0.69,1.594,2.0
results_RDE,0.244,0.304,0.544,0.646,0.666,0.574,1.702,2.02
results_ppiformer,0.424,0.456,0.582,0.61,0.768,0.642,1.644,1.936


## SARS-CoV-2 test set

### Predict

In [15]:
# Structure file of the 7FAE complex and the table of labelled mutations for it
ppi_path = MUTILS_DATA_DIR / '7FAE/7FAE-RBD-Fv_A_H_L.pdb'
df_label = pd.read_csv(MUTILS_DATA_DIR / '7FAE/shan2022_covid_SKEMPI_format.csv')

# Predict ddG for every listed mutation with the loaded models
muts = df_label['Mutation(s)_cleaned'].tolist()
ddg_pred = predict_ddg(models, ppi_path, muts, impute=True)

# Assemble the predictions in the same column layout as the other prediction files
n_muts = len(muts)
columns = {
    'complex': ['7FAE'] * n_muts,
    'mutstr': muts,
    'ddG_pred': ddg_pred.tolist(),
    'ddG': df_label['ddG'].tolist(),
}
df_test = pd.DataFrame(columns)

out_path = MUTILS_DATA_DIR / f'7FAE/predictions_test/results_{TEST_MODEL_NAME}.csv'
df_test.to_csv(out_path, index=False)

### Benchmark against other methods

In [6]:
def precision_at_k(ranks, classes, k):
    """Return the fraction of positive entries among the `k` lowest-ranked items.

    Args:
        ranks: Rank (or score) per item; lower values are selected first.
        classes: Boolean label per item, aligned with `ranks`.
        k: Number of lowest-ranked items to take into account.

    Returns:
        Mean of `classes` over the `k` items with the smallest `ranks`.
    """
    table = pd.DataFrame({'ranks': ranks, 'classes': classes})
    top_k = table.nsmallest(k, 'ranks')
    labels_of_top_k = top_k['classes']
    return labels_of_top_k.mean()


def test_shan2022(pred_dir: Path):
    """Benchmark every method that has a prediction CSV in `pred_dir` on the SARS-CoV-2 test set.

    Each `*.csv` file in `pred_dir` is treated as one method (named after the file stem) and
    must provide the columns 'mutstr', 'ddG' and 'ddG_pred'. If a file has no 'rank' column,
    it is derived as the ascending rank of 'ddG_pred' divided by the number of rows.

    Two tables are displayed, both multiplied by 100:
      * the rank of each of the five listed stabilizing mutations per method;
      * P@1, P@25, P@49 (positive class: ddG < 0) and the mean rank of those mutations.

    Args:
        pred_dir: Directory containing one prediction CSV per method.

    Returns:
        None. Results are shown with `display`.
    """
    stabilizing_muts = ['TH31W', 'AH53F', 'NH57L', 'RH103M', 'LH104F']

    # Calculate performance of each method (sorted paths give a deterministic row order)
    df_overall = []
    df_test = []
    for path in sorted(pred_dir.glob('*.csv')):
        df = pd.read_csv(path)
        if 'rank' not in df.columns:
            df['rank'] = df['ddG_pred'].rank() / len(df)

        # Rows of the known stabilizing mutations; selected once and reused below
        df_stabilizing = df[df['mutstr'].isin(stabilizing_muts)]

        metrics = {'name': path.stem}
        for k in [1, 25, 49]:
            metrics[f'P@{k}'] = precision_at_k(df['rank'], df['ddG'] < 0, k)
        metrics['Mean rank'] = df_stabilizing['rank'].mean()

        # Mutation -> rank mapping for the per-mutation table
        dct = df_stabilizing.set_index('mutstr')['rank'].to_dict()
        dct['Method'] = path.stem

        df_overall.append(dct)
        df_test.append(metrics)

    # Calculate overall performance
    df_overall = pd.DataFrame(df_overall).set_index('Method')
    display((100*df_overall).round(2))
    df_test = pd.DataFrame(df_test).set_index('name')
    display((100*df_test).round(3))


# Run the SARS-CoV-2 benchmark over all prediction CSVs found in the predictions directory
test_shan2022(MUTILS_DATA_DIR / '7FAE/predictions_test')

Unnamed: 0_level_0,AH53F,LH104F,NH57L,RH103M,TH31W
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
ppiformer,0.2,10.93,7.69,21.46,18.02
results_MSA_Transformer,42.11,18.83,63.56,49.19,56.88
results_Flex_ddG,70.24,17.61,55.87,77.33,2.83
results_ESM-IF,17.61,48.58,17.0,51.42,49.39
results_RDE,2.02,5.47,20.65,61.54,1.62
results_ppiformer,0.2,10.93,7.69,21.46,18.02


Unnamed: 0_level_0,P@1,P@25,P@49,Mean rank
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ppiformer,100.0,4.0,4.082,11.66
results_MSA_Transformer,0.0,0.0,0.0,46.113
results_Flex_ddG,0.0,4.0,2.041,44.777
results_ESM-IF,0.0,0.0,0.0,36.802
results_RDE,0.0,8.0,6.122,18.259
results_ppiformer,100.0,4.0,4.082,11.66


## Staphylokinase (SAK) thrombolytic test set

### Predict

In [4]:
# Structure file of the 1bui_A_C complex and the table of labelled mutations for it
ppi_path = MUTILS_DATA_DIR / 'SAK/1bui_A_C.pdb'
df_label = pd.read_csv(MUTILS_DATA_DIR / 'SAK/laroche2000_sak_SKEMPI_format.csv')

# Predict ddG for every listed mutation with the loaded models
muts = df_label['Mutation(s)'].tolist()
ddg_pred = predict_ddg(models, ppi_path, muts, impute=True)

# Assemble the predictions; the 'ddG' column is filled from the 'Activity' column of the label table
n_muts = len(muts)
columns = {
    'complex': ['1bui_A_C'] * n_muts,
    'mutstr': muts,
    'ddG_pred': ddg_pred.tolist(),
    'ddG': df_label['Activity'].tolist(),
}
df_test = pd.DataFrame(columns)

out_path = MUTILS_DATA_DIR / f'SAK/predictions_test/results_{TEST_MODEL_NAME}.csv'
df_test.to_csv(out_path, index=False)

Process 14462 preparing data: 100%|██████████| 1/1 [00:00<00:00,  5.94it/s]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=80)


### Benchmark against other methods

In [12]:
def test_sak(pred_dir: Path, df_label=None):
    """Benchmark every method that has a prediction CSV in `pred_dir` on the SAK test set.

    Each `*.csv` file in `pred_dir` is treated as one method (named after the file stem) and
    must provide the columns 'mutstr' and 'ddG_pred'. If a file has no 'rank' column, it is
    derived as the ascending rank of 'ddG_pred' divided by the number of rows.

    Two tables are displayed, both multiplied by 100:
      * the rank of each mutation flagged in the '2x activity enhancement' label column;
      * P@1, P@4, P@8 and the mean rank, where the positive class consists of the mutations
        flagged in the 'Activity enhancement' label column.

    Args:
        pred_dir: Directory containing one prediction CSV per method.
        df_label: Label table with the columns 'Mutation(s)', 'Activity enhancement' and
            '2x activity enhancement' (the latter two used as boolean masks). When omitted,
            the table is read from `SAK/laroche2000_sak_SKEMPI_format.csv`, so the function
            does not depend on whichever `df_label` global another cell assigned last.

    Returns:
        None. Results are shown with `display`.
    """
    if df_label is None:
        df_label = pd.read_csv(MUTILS_DATA_DIR / 'SAK/laroche2000_sak_SKEMPI_format.csv')

    stabilizing_muts = df_label[df_label['Activity enhancement']]['Mutation(s)'].tolist()
    sel_stabilizing_muts = df_label[df_label['2x activity enhancement']]['Mutation(s)'].tolist()

    # Calculate performance of each method (sorted paths give a deterministic row order)
    df_overall = []
    df_test = []
    for path in sorted(pred_dir.glob('*.csv')):
        df = pd.read_csv(path)
        if 'rank' not in df.columns:
            df['rank'] = df['ddG_pred'].rank() / len(df)

        is_stabilizing = df['mutstr'].isin(stabilizing_muts)

        metrics = {'name': path.stem}
        for k in [1, 4, 8]:
            metrics[f'P@{k}'] = precision_at_k(df['rank'], is_stabilizing, k)
        metrics['Mean rank'] = df[is_stabilizing]['rank'].mean()

        # Mutation -> rank mapping of the selected (2x enhancement) mutations
        df_selected = df[df['mutstr'].isin(sel_stabilizing_muts)]
        dct = df_selected.set_index('mutstr')['rank'].to_dict()
        dct['Method'] = path.stem

        df_overall.append(dct)
        df_test.append(metrics)

    # Calculate overall performance
    df_overall = pd.DataFrame(df_overall).set_index('Method')
    display((100*df_overall).round(2))
    df_test = pd.DataFrame(df_test).set_index('name')
    display((100*df_test).round(3))


# Run the SAK benchmark over all prediction CSVs found in the predictions directory
test_sak(MUTILS_DATA_DIR / 'SAK/predictions_test')

Unnamed: 0_level_0,KC130A,KC135A,KC130T,"KC130T,KC135R","KC74R,KC130T,KC135R","KC74Q,KC130E,KC135R"
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
results_MSA_Transformer,52.5,40.0,32.5,55.0,78.75,70.0
results_ppiformer,66.25,52.5,15.0,2.5,1.25,33.75


Unnamed: 0_level_0,P@1,P@4,P@8,Mean rank
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
results_MSA_Transformer,100.0,50.0,37.5,49.813
results_ppiformer,100.0,75.0,87.5,28.563
