In [1]:
import os
import pandas as pd
import numpy as np

import Bio
from Bio import SeqIO

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
import scipy.stats
from sklearn.metrics import mean_absolute_error
import torch

import sys


def perf_metrix(predictions, true_values):
    
    pearson, pearson_p = scipy.stats.pearsonr(predictions, true_values)
    spearman, spearman_p = scipy.stats.spearmanr(predictions, true_values)
    rmse = mean_absolute_error(predictions, true_values)
    
    pearson = float("{0:.3g}".format(pearson))
    pearson_p = float("{0:.3g}".format(pearson_p))
    spearman = float("{0:.3g}".format(spearman))
    spearman_p = float("{0:.3g}".format(spearman_p))
    rmse = float("{0:.3g}".format(rmse))
    
    return pearson, pearson_p, spearman, spearman_p, rmse


In [71]:
fp_fig_save = "figures/figureS10.png"

fontsize = 36

fp_database = 'data/database_640.pkl'
database = pd.read_pickle(fp_database)


fp_data_csv = "data/dataset_640.csv"
data_csv = pd.read_csv(fp_data_csv)
print(data_csv.columns)

fp_pred = "data/predictions_exp_blosum_640_mae_no_mean/ens_no_mean_spearman_div_emb.csv"
df_pred = pd.read_csv(fp_pred)
model_pred = [np.mean([float(p) for p in pred.split('_')]) for pred in df_pred['ridge_tm_4'].to_list()]
data_csv['pred'] = model_pred

Index(['id', 'name', 'source', 'exp_method', 'WT', 'tm', 'seq'], dtype='object')


In [89]:
# 3 sequences were added to the database post benchmarking, so remove them here
database = database[~database['id'].isin(['sdAb.4875_84.2', '7KGJ_61.6','NbThermo.00469_68.0'])]
tms = database['tm'].to_list()


['6IBB_72.12', '5MP2_70.27', '6XW6_66.07', '7BC6_63.0', '5MY6_74.28', '6SC5_65.24', '6RTW_64.21', '1ZVY_65.1', '7DSS_67.53', '7AZB_59.91', '5IP4_63.24', '3K81_64.1', '3QXV_67.24', '6RL7_72.09', '5L21_66.86', '6DBD_65.6', '7NQA_65.87', '6UI1_67.83', '6X05_66.67', '7B27_65.74', '1BZQ/NbThermo.00491_70.11', '1JTO/NbThermo.00112_67.01', '1KXT_57.85', '1KXV_56.99', '1QD0_66.56', '1RI8_78.14', '2BSE_75.23', '2X6M_67.13', '3DWT/NbThermo.00113/NbThermo.00133_77.72', '3EBA_86.21', '3G9A_66.55', '3J69_66.03', '3J6A_69.02', '3P9W_60.0', '3SN6_60.26', '4B5E_61.66', '4C57_63.78', '4I0C_75.79', '4I13_65.39', '4IDL_53.76', '4LJP_71.72', '4LJS_72.49', '4QGY_74.65', '4YGA_73.41', '4Z9K/NbThermo.00231_73.49', '5C1M_61.35', '5FOJ_70.81', '5J56_73.74', '5KU2_55.16', '5M2M_74.98', '5NM0_72.73', '5O03_69.2', '5OCL_71.69', '5OVW_76.47', '5TD8_66.55', '5TJW_66.25', '5U64_73.39', '5VLV_69.62', '5VM4_72.93', '5VXK_65.75', '6CK8_66.66', '6EQI_62.47', '6FYU_64.58', '6GWN_64.42', '6HEQ_60.61', '6HHU_70.29', '6I2G_

In [97]:
## FOLDX
sns.set_theme(style='white')
fig, axs = plt.subplots(2,2,figsize=(20, 20))

# FoldX
pdb_directory = 'data/benchmark/foldx_preds'

def fetch_foldxpred(pdb_directory):
    id_list, total_energy_list = [],[]
    tm_list = []
    pred_list = []
    for filename in os.listdir(pdb_directory):
        if filename.endswith('.fxout'):
            file_path = os.path.join(pdb_directory, filename)
            with open(file_path, 'r') as file:
                lines = file.read().split()
                if len(lines) >1:
                    id = filename.split('_')[0]
                    if id.upper() in data_csv['name'].to_list():
                        id_list.append(id)
                        tm = data_csv.loc[data_csv['name'].str.contains(id.upper()),'tm'].values[0]
                        tm_list.append(tm)
                        pred = data_csv.loc[data_csv['name'].str.contains(id.upper()),'pred'].values[0]
                        pred_list.append(pred)
                        total_energy = lines[1]
                        total_energy_list.append(float(total_energy))
    return id_list, total_energy_list, tm_list, pred_list

id_list, total_energy_list, tm_list, pred_list = fetch_foldxpred(pdb_directory)

print(id_list)
print(len(id_list))

pred = pred_list
target = tm_list
foldx_pearson, foldx_pearson_p, foldx_spearman, foldx_spearman_p, foldx_rmse = perf_metrix(total_energy_list, target)
print("FoldX: ", foldx_pearson, foldx_pearson_p, foldx_spearman, foldx_spearman_p, foldx_rmse)

df = pd.DataFrame({#'id':ids_plot,
                   'tm':tm_list,'energy':total_energy_list})

sns.scatterplot(df, x='tm', y='energy',ax=axs[0,0])

unit = r"kcal mol$^{-1}$"
pearson_latex = r"$\ r$"
spearman_latex = r"$\rho$"
foldx_bold = r'$\mathbf{\ %s}$' % "FoldX"
esm_bold = r'$\mathbf{\ %s}$' % "ESM-2"
antiberty_bold = r'$\mathbf{\ %s}$' % "AntiBERTy"
deepstabp_bold = r'$\mathbf{\ %s}$' % "DeepSTABp"

axs[0,0].set_xlim(20,100)
axs[0,0].set_title(f"{foldx_bold} \n{pearson_latex} = {foldx_pearson}, {spearman_latex} = {foldx_spearman}", y=1.03, fontsize = fontsize)
axs[0,0].set_xlabel('Measured melting temperature / °C', fontsize=fontsize)
axs[0,0].set_ylabel(f'Free energy of unfolding / {unit}', fontsize=fontsize)
axs[0,0].set_xticks([20,40,60,80,100])






# ESM
fp_esm_aar = "data/benchmark/esm_aar.csv"
df_esm = pd.read_csv(fp_esm_aar)
aar = df_esm['esm_aar'].to_list()

esm_pearson, esm_pearson_p, esm_spearman, esm_spearman_p, esm_rmse = perf_metrix(aar, tms)

esm_pearson_p_sci_latex = "{:.2e}".format(esm_pearson_p)
esm_pearson_p = float(esm_pearson_p_sci_latex.split('e')[0])
esm_pearson_p_pow = int(esm_pearson_p_sci_latex.split('e')[1])
esm_pearson_p = r"${%s} \times 10^{%s}$" % (esm_pearson_p, esm_pearson_p_pow)

esm_spearman_p_sci_latex = "{:.2e}".format(esm_spearman_p)
esm_spearman_p = float(esm_spearman_p_sci_latex.split('e')[0])
esm_spearman_p_pow = int(esm_spearman_p_sci_latex.split('e')[1])
esm_spearman_p = r"${%s} \times 10^{%s}$" % (esm_spearman_p, esm_spearman_p_pow)



print("Correlation between aar and tm: ", esm_pearson)
print("Correlation between aar and tm: ", esm_spearman)

axs[0,1].scatter(tms, aar)
axs[0,1].set_xlim(20,100)
axs[0,1].set_ylabel("Amino acid recovery rate",fontsize=fontsize)
axs[0,1].set_xlabel("Measured melting temperature / °C",fontsize=fontsize)
axs[0,1].set_title(f"{esm_bold} \n{pearson_latex} = {esm_pearson}, {spearman_latex} = {esm_spearman}", y=1.03, fontsize = fontsize)
axs[0,1].set_xticks([20,40,60,80,100])
axs[0,1].set_yticks([0.875,0.9,0.925,0.95,0.975])

# AntiBERTy
fp_antiberty_to_save = "data/benchmark/antiberty.csv"
antiberty = pd.read_csv(fp_antiberty_to_save)
antiberty_scores = antiberty['AntiBERTy'].to_list()

antiberty_pearson, antiberty_pearson_p, antiberty_spearman, antiberty_spearman_p, antiberty_rmse = perf_metrix(antiberty_scores, tms)

antiberty_pearson_p_sci_latex = "{:.2e}".format(antiberty_pearson_p)
antiberty_pearson_p = float(antiberty_pearson_p_sci_latex.split('e')[0])
antiberty_pearson_p_pow = int(antiberty_pearson_p_sci_latex.split('e')[1])
antiberty_pearson_p = r"${%s} \times 10^{%s}$" % (antiberty_pearson_p, antiberty_pearson_p_pow)

antiberty_spearman_p_sci_latex = "{:.2e}".format(antiberty_spearman_p)
antiberty_spearman_p = float(antiberty_spearman_p_sci_latex.split('e')[0])
antiberty_spearman_p_pow = int(antiberty_spearman_p_sci_latex.split('e')[1])
antiberty_spearman_p = r"${%s} \times 10^{%s}$" % (antiberty_spearman_p, antiberty_spearman_p_pow)


axs[1,0].scatter(tms, antiberty_scores)
axs[1,0].set_title(f"{antiberty_bold} \n{pearson_latex} = {antiberty_pearson}, {spearman_latex} = {antiberty_spearman}", y=1.03, fontsize = fontsize)
axs[1,0].set_ylabel(f'Pseudo log-likelihood by AntiBERTy', fontsize=fontsize)
axs[1,0].set_xlabel(f'Measured melting temperature / °C', fontsize=fontsize)
axs[1,0].set_ylim(-2.0,-0.4)
axs[1,0].set_xlim(20,100)
axs[1,0].set_xticks([20,40,60,80,100])
axs[1,0].set_yticks([-2.0,-1.5,-1.0,-0.5])


# DeepStabP - 637
fp_deepstabp_pred = 'data/benchmark/DeepStabP.csv'
deepstabp_pred = pd.read_csv(fp_deepstabp_pred)
deepstabp_pred['tm'] = tms
pred = deepstabp_pred['predictions'].values
target = deepstabp_pred['tm'].values

print(np.std(pred)/np.std(target))

deepstabp_pearson, deepstabp_pearson_p, deepstabp_spearman, deepstabp_spearman_p, deepstabp_rmse = perf_metrix(pred, target)

deepstabp_pearson_p_sci_latex = "{:.2e}".format(deepstabp_pearson_p)
deepstabp_pearson_p = float(deepstabp_pearson_p_sci_latex.split('e')[0])
deepstabp_pearson_p_pow = int(deepstabp_pearson_p_sci_latex.split('e')[1])
deepstabp_pearson_p = r"${%s} \times 10^{%s}$" % (deepstabp_pearson_p, deepstabp_pearson_p_pow)

deepstabp_spearman_p_sci_latex = "{:.2e}".format(deepstabp_spearman_p)
deepstabp_spearman_p = float(deepstabp_spearman_p_sci_latex.split('e')[0])
deepstabp_spearman_p_pow = int(deepstabp_spearman_p_sci_latex.split('e')[1])
deepstabp_spearman_p = r"${%s} \times 10^{%s}$" % (deepstabp_spearman_p, deepstabp_spearman_p_pow)


sns.scatterplot(deepstabp_pred, x='tm', y='predictions', ax=axs[1,1])
axs[1,1].set_ylabel('Predicted melting temperature / °C', fontsize=fontsize)
axs[1,1].set_xlabel('Measured melting temperature / °C', fontsize=fontsize)
axs[1,1].set_xlim(20,100)
axs[1,1].set_ylim(20,100)
axs[1,1].set_title(f"{deepstabp_bold} \n{pearson_latex} = {deepstabp_pearson}, {spearman_latex} = {deepstabp_spearman}", y=1.03, fontsize = fontsize)
axs[1,1].set_xticks([20,40,60,80,100])
axs[1,1].set_yticks([20,40,60,80,100])


# Set font size
axs[0,0].tick_params(labelsize=fontsize)
axs[0,1].tick_params(labelsize=fontsize)
axs[1,0].tick_params(labelsize=fontsize)
axs[1,1].tick_params(labelsize=fontsize)


# Show ticks on all sides
axs[0,0].tick_params(axis='both', which='both', direction='in', length=10, width=2)
axs[0,1].tick_params(axis='both', which='both', direction='in', length=10, width=2)
axs[1,0].tick_params(axis='both', which='both', direction='in', length=10, width=2)
axs[1,1].tick_params(axis='both', which='both', direction='in', length=10, width=2)


plt.tight_layout()
plt.subplots_adjust(hspace=.7) # Separate two rows of plots
plt.savefig('figures/figureS11.png')
plt.close()


['6lr7', '6knm', '7mjh', '1jto', '6i2g', '1kxt', '6u53', '5nm0', '5ocl', '6hhu', '7kbi', '6u50', '4lgs', '6ck8', '7nqk', '5ovw', '4lgp', '3qxv', '6obc', '5u64', '7bc6', '6rtw', '2x6m', '5j56', '5vm4', '6xxn', '3g9a', '1bzq', '3p9w', '5ip4', '3jbc', '5l21', '4u05', '6itp', '6raf', '2bse', '6fyu', '7azb', '5vlv', '5td8', '5foj', '6qv1', '3jbd', '7nqa', '6i8h', '4idl']
46
FoldX:  0.033 0.828 -6.17e-05 1.0 39.2
Correlation between aar and tm:  0.337
Correlation between aar and tm:  0.233
0.2756155432158137
