# Comparison of GNN results with REF15

1. Compare results of best model from hyperparameter tuning with REF15 for AbAg-affinity test set and AB-benchmark

2. Check how robust the GNN is under 4-fold cross-validation

In [1]:
import pandas as pd
import os
from collections import defaultdict
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
from abag_affinity.utils.config import read_config, get_data_paths
from pathlib import Path
from scipy.stats import norm
import seaborn as sns

In [2]:
# All figures of this notebook are written below <project_root>/results/experiments/
project_root = "../.."
plot_path = os.path.join(project_root, "results", "experiments", "GNN_TF-GNN_comparison")
Path(plot_path).mkdir(parents=True, exist_ok=True)

# Prediction files of the GNN (first cross-validation run) for both evaluation sets
gnn_abag_test_result_path = os.path.join(
    project_root, "results/predictions/CV_experiment/abag_affinity_test_cv1.csv"
)
gnn_ab_benchmark_result_path = os.path.join(
    project_root, "results/predictions/CV_experiment/benchmark_cv1.csv"
)

# Give the generic prediction/label columns model-specific names
col_mapping = {
    "prediction": "gnn_prediction",
    "labels": "-log(Kd)_labels"
}

# The first CSV column is read as a throw-away index and then replaced by the "pdb" column
gnn_abag_test_df = pd.read_csv(gnn_abag_test_result_path, index_col=0)
gnn_abag_test_df = gnn_abag_test_df.set_index("pdb").rename(columns=col_mapping)

gnn_ab_benchmark_df = pd.read_csv(gnn_ab_benchmark_result_path, index_col=0)
gnn_ab_benchmark_df = gnn_ab_benchmark_df.set_index("pdb").rename(columns=col_mapping)

In [3]:
# Prediction files of the CV_TF_experiment run (first cross-validation run).
# NOTE(review): the recorded execution of this cell failed with FileNotFoundError because
# abag_affinity_test_cv1.csv was missing in CV_TF_experiment; the tfl_* frames are not used
# by any later cell shown in this notebook.
tfl_abag_test_result_path = os.path.join(
    project_root, "results/predictions/CV_TF_experiment/abag_affinity_test_cv1.csv"
)
tfl_ab_benchmark_result_path = os.path.join(
    project_root, "results/predictions/CV_TF_experiment/benchmark_cv1.csv"
)

# Same column names as used for the plain GNN predictions
col_mapping = {
    "prediction": "gnn_prediction",
    "labels": "-log(Kd)_labels"
}

# The first CSV column is read as a throw-away index and then replaced by the "pdb" column
tfl_abag_test_df = pd.read_csv(tfl_abag_test_result_path, index_col=0)
tfl_abag_test_df = tfl_abag_test_df.set_index("pdb").rename(columns=col_mapping)

tfl_ab_benchmark_df = pd.read_csv(tfl_ab_benchmark_result_path, index_col=0)
tfl_ab_benchmark_df = tfl_ab_benchmark_df.set_index("pdb").rename(columns=col_mapping)

FileNotFoundError: [Errno 2] No such file or directory: '../../results/predictions/CV_TF_experiment/abag_affinity_test_cv1.csv'

In [None]:
# Constants used for converting between binding free energy and Kd
gas_constant = 8.31446261815324  # molar gas constant in J / (mol * K)
assumed_temp = 298.15  # temperature in Kelvin (25 degrees Celsius)

# The folder holding the pre-computed force field results is taken from the project config
config = read_config("../config.yaml")
force_field_results_folder = config["force_field_results"]

def get_scores(path):
    """Load force field scores of a dataset and compute the binding energy per complex.

    Two files are read from ``path``:

    * ``data_points.txt`` - one complex per line; the text before the first comma
      is used (stripped and lower-cased) as the identifier.
    * ``results.txt`` - blocks of values, one value per line, in the same order as
      the lines of ``data_points.txt``. A line ``Antibody``, ``Antigen`` or
      ``Complex`` selects the column the following values belong to. Empty lines
      and the lines ``fa_atr`` / ``score`` start a new block for the current
      column, so a later block overwrites the values of an earlier one.

    Args:
        path: Folder containing ``data_points.txt`` and ``results.txt``.

    Returns:
        pandas.DataFrame indexed by identifier with the float columns found in
        ``results.txt`` plus ``rosetta_delta_g`` = Complex - Antibody - Antigen.

    Raises:
        ValueError: If a value appears before any section header or a block
            contains more values than there are data points.
        KeyError: If one of the three sections is missing from ``results.txt``.
    """
    section_headers = ("Antibody", "Antigen", "Complex")
    block_headers = ("fa_atr", "score")

    # get pdb_ids
    with open(os.path.join(path, "data_points.txt")) as f:
        pdbs = [line.split(",")[0].strip().lower() for line in f]

    results = defaultdict(dict)
    score_type = None
    pdb_idx = 0

    # get delta g scores
    with open(os.path.join(path, "results.txt")) as f:
        for raw_line in f:
            # compare stripped lines so trailing whitespace or a missing final
            # newline cannot turn a header into a "value"
            entry = raw_line.strip()

            if entry in section_headers:
                score_type = entry
            if not entry or entry in section_headers or entry in block_headers:
                # every header / separator restarts the walk through the data points
                pdb_idx = 0
                continue

            if score_type is None:
                raise ValueError(
                    f"Value '{entry}' found before any of the section headers {section_headers}"
                )
            if pdb_idx >= len(pdbs):
                raise ValueError(
                    f"Section '{score_type}' contains more values than the {len(pdbs)} data points"
                )

            results[pdbs[pdb_idx]][score_type] = entry
            pdb_idx += 1

    rosetta_scores = pd.DataFrame(list(results.values()))
    rosetta_scores.index = list(results)
    rosetta_scores = rosetta_scores.astype(float)
    rosetta_scores["rosetta_delta_g"] = (
        rosetta_scores["Complex"] - rosetta_scores["Antibody"] - rosetta_scores["Antigen"]
    )

    return rosetta_scores

In [None]:
num_bootstrap_repeats = 10000


def calculate_and_plot_bootstraped_error_differences(df):
    """Bootstrap the mean of the ``error_diff`` column and plot its distribution.

    The rows of ``df`` are resampled with replacement ``num_bootstrap_repeats``
    times; for every resample the mean of ``error_diff`` is stored. A normal
    distribution is fitted to these means and its central 95% interval is drawn
    as two red vertical lines on top of the histogram of the means.

    Args:
        df: DataFrame with an ``error_diff`` column.

    Returns:
        Tuple ``(lower, upper)`` with the bounds of the 95% interval.
    """
    # based on https://stats.stackexchange.com/questions/518773/statistical-test-for-comparing-performance-metrics-of-two-regression-models-on-a
    # every resample has the size of the frame that was passed in, so the
    # function gives a proper bootstrap for any dataset and does not depend on a global frame
    n_samples = len(df)
    bootstrapped_mean_diff = [
        df["error_diff"].sample(n=n_samples, replace=True).mean()
        for _ in range(num_bootstrap_repeats)
    ]
    ci = norm(*norm.fit(bootstrapped_mean_diff)).interval(0.95)  # fit a normal distribution and get 95% c.i.
    sns.histplot(bootstrapped_mean_diff, bins=50)
    plt.axvline(ci[0], color="red")
    plt.axvline(ci[1], color="red")

    plt.xlabel("Mean of difference of absolute prediction errors")
    plt.title(f"Histogram of mean differences of absolute prediction erros - {num_bootstrap_repeats} bootstrapped samples")
    return ci


def calculate_and_plot_bootstraped_metrics(df):
    """Bootstrap the RMSE and Pearson-correlation differences between GNN and REF15.

    The rows of ``df`` are resampled with replacement ``num_bootstrap_repeats``
    times. For every resample two differences are computed:

    * RMSE of ``gnn_prediction`` minus RMSE of ``ref15_prediction``, both
      measured against ``-log(Kd)_labels``.
    * Pearson correlation of ``gnn_prediction`` with ``-log(Kd)_labels`` minus
      Pearson correlation of ``rosetta_delta_g`` with ``delta_g_labels``.

    Both distributions are plotted as histograms (left: RMSE, right: Pearson)
    together with bounds of a fitted normal distribution.

    Args:
        df: DataFrame with the columns ``delta_g_labels``, ``-log(Kd)_labels``,
            ``ref15_prediction``, ``rosetta_delta_g`` and ``gnn_prediction``.

    Returns:
        Tuple ``(ci_rsme, ci_pearson)``: the central 90% interval of the RMSE
        differences and the central 95% interval of the Pearson differences.
    """
    # based on https://stats.stackexchange.com/questions/518773/statistical-test-for-comparing-performance-metrics-of-two-regression-models-on-a
    fig, ax = plt.subplots(1, 2, figsize=(15, 5))

    metric_df = df[["delta_g_labels", "-log(Kd)_labels", "ref15_prediction", "rosetta_delta_g", "gnn_prediction"]]
    # every resample has the size of the frame that was passed in, so the
    # function gives a proper bootstrap for any dataset and does not depend on a global frame
    n_samples = len(df)
    bootstrapped_dfs = [metric_df.sample(n=n_samples, replace=True) for _ in range(num_bootstrap_repeats)]

    # calculate and plot rsme
    rsme = [
        np.sqrt(np.mean((sample_df["gnn_prediction"] - sample_df["-log(Kd)_labels"]) ** 2))
        - np.sqrt(np.mean((sample_df["ref15_prediction"] - sample_df["-log(Kd)_labels"]) ** 2))
        for sample_df in bootstrapped_dfs
    ]

    # central 90% interval of the fitted normal distribution; only its upper
    # bound is drawn, which corresponds to a one-sided 95% upper bound
    ci_rsme = norm(*norm.fit(rsme)).interval(0.9)
    sns.histplot(rsme, bins=50, ax=ax[0])
    ax[0].axvline(ci_rsme[1], color="red")
    ax[0].set_xlabel("Difference of root-mean-squared-errors")

    # calculate and plot pearson
    pearson = [
        stats.pearsonr(x=sample_df["gnn_prediction"], y=sample_df["-log(Kd)_labels"])[0]
        - stats.pearsonr(x=sample_df["rosetta_delta_g"], y=sample_df["delta_g_labels"])[0]
        for sample_df in bootstrapped_dfs
    ]

    ci_pearson = norm(*norm.fit(pearson)).interval(0.95)  # fit a normal distribution and get 95% c.i.
    sns.histplot(pearson, bins=50, ax=ax[1])
    ax[1].axvline(ci_pearson[0], color="red")
    ax[1].axvline(ci_pearson[1], color="red")
    ax[1].set_xlabel("Difference of pearson correlations")
    plt.suptitle(f"Histogram of RSME and Pearson's R differences - {num_bootstrap_repeats} bootstrapped samples")
    return ci_rsme, ci_pearson

In [None]:
# fit linear transform for REF 15 values based on abag_affinity dataset
from sklearn.linear_model import LinearRegression

abag_rosetta_scores = get_scores(os.path.join(force_field_results_folder, "guest_REF15", "abag_affinity_dataset"))

abag_summary_path = os.path.join(config["DATASETS"]["path"], config["DATASETS"]["abag_affinity"]["folder_path"], config["DATASETS"]["abag_affinity"]["summary"])
abag_df = pd.read_csv(abag_summary_path, index_col=0)
# fit only on complexes outside the test split that have a measured delta_g
abag_df = abag_df[~abag_df["test"] & abag_df["delta_g"].notna()]

# Keep only complexes present in both tables: with the default left join every
# complex without a REF15 score would carry NaN in "rosetta_delta_g", which
# LinearRegression.fit rejects. dropna additionally guards against NaN scores.
overlap_df = abag_df.join(abag_rosetta_scores, how="inner").dropna(subset=["rosetta_delta_g", "delta_g"])

# single feature (raw REF15 binding energy) -> measured delta_g
X = overlap_df["rosetta_delta_g"].values.reshape(-1, 1)
y = overlap_df["delta_g"].values.reshape(-1, 1)

rosetta_fit = LinearRegression().fit(X, y)

In [None]:
def calc_kd(delta_g):
    """Convert a single raw REF15 binding energy into a dissociation constant.

    The value is first rescaled with the fitted linear model ``rosetta_fit`` and
    then multiplied by 4184 (joules per kcal) so that it matches the unit of
    ``gas_constant`` - this assumes the rescaled energy is in kcal/mol, TODO confirm.
    Kd is then computed as ``1 / exp(-delta_g / (R * T))``.

    Args:
        delta_g: One scalar energy value (the result of the model is read with ``.item()``).

    Returns:
        The dissociation constant Kd.
    """
    model_input = np.array(delta_g).reshape(1, -1)
    scaled_delta_g = rosetta_fit.predict(model_input).item()  # scale value

    delta_g_joule = scaled_delta_g * 4184

    boltzmann_term = np.exp(-delta_g_joule / (gas_constant * assumed_temp))
    return 1 / boltzmann_term

def calc_delta_g(kd):
    """Convert a dissociation constant into a binding free energy.

    Computes ``-R * T * ln(1 / kd)`` with ``gas_constant`` and ``assumed_temp``
    and divides the result by 4184 to convert it to kcal.

    Args:
        kd: Dissociation constant.

    Returns:
        The binding free energy.
    """
    joule_per_kcal = 4184
    delta_g_joule = -1 * gas_constant * assumed_temp * np.log(1 / kd)
    return delta_g_joule / joule_per_kcal

## AB-benchmark

### Calculate difference to predictions

In [None]:
# NOTE(review): `benchmark_ref15_scores` and its "ref15_prediction" column are not created by
# any cell visible in this notebook - confirm where they are supposed to come from.
benchmark_results = gnn_ab_benchmark_df.join(benchmark_ref15_scores)
benchmark_results = benchmark_results[["gnn_prediction", "-log(Kd)_labels", "ref15_prediction", "rosetta_delta_g"]]

# Labels expressed as free energy: Kd = 10^(-label)
benchmark_results["delta_g_labels"] = benchmark_results["-log(Kd)_labels"].apply(
    lambda neg_log_kd: calc_delta_g(10**(-neg_log_kd))
)

# Signed and absolute error of the GNN
benchmark_results["gnn_diff"] = benchmark_results["gnn_prediction"].sub(benchmark_results["-log(Kd)_labels"])
benchmark_results["gnn_error"] = benchmark_results["gnn_diff"].abs()

# Signed and absolute error of REF15
benchmark_results["ref15_diff"] = benchmark_results["ref15_prediction"].sub(benchmark_results["-log(Kd)_labels"])
benchmark_results["ref15_error"] = benchmark_results["ref15_diff"].abs()

# Negative values mean the GNN has the smaller absolute error
benchmark_results["error_diff"] = benchmark_results["gnn_error"].sub(benchmark_results["ref15_error"])

In [None]:
# Summary statistics of all prediction, label and error columns of the AB-benchmark results
benchmark_results.describe()

### Plot statistics of differences

In [None]:
# 2x2 overview - top row: GNN, bottom row: REF15;
# left: histogram of the signed error, right: signed error against the prediction
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))

benchmark_results["gnn_diff"].plot(kind="hist", ax=ax[0][0])
ax[0][0].set_title("GNN prediction error histogram")
ax[0][0].set_xlabel("GNN prediction error [-log(Kd)]")

benchmark_results.plot(kind="scatter", x="gnn_prediction", y="gnn_diff", ax=ax[0][1])
ax[0][1].set_title("GNN prediction error vs. GNN prediction")

benchmark_results["ref15_diff"].plot(kind="hist", ax=ax[1][0])
ax[1][0].set_title("RE15 prediction error histogram")
ax[1][0].set_xlabel("RE15 prediction error [-log(Kd)]")

benchmark_results.plot(kind="scatter", x="ref15_prediction", y="ref15_diff", ax=ax[1][1])
ax[1][1].set_title("RE15 prediction error vs. RE15 prediction")

# save before show/close, otherwise the figure is gone
plt.savefig(os.path.join(plot_path, "ab_benchmark_error_distributions.png"))
plt.show()
plt.close()

In [None]:
# Histogram of (absolute GNN error - absolute REF15 error) per complex
benchmark_results["error_diff"].plot(kind="hist")
plt.xlabel("Difference in absolute error [-log(Kd)]")
plt.title("Difference in absolute prediction errors: GNN Error - REF15 Error")

# save before show/close, otherwise the figure is gone
plt.savefig(os.path.join(plot_path, "ab_benchmark_error_difference_distributions.png"))
plt.show()
plt.close()

### Bootstrapping Method to generate multiple means of differences

In [None]:
# Bootstrap the mean error difference on the AB-benchmark; the returned interval is discarded here
calculate_and_plot_bootstraped_error_differences(benchmark_results)
# the figure has to be saved before show()/close()
plt.savefig(os.path.join(plot_path, "ab_benchmark_error_difference_distributions_bootstrapped.png"))
plt.show()
plt.close()

In [None]:
# Bootstrap the RMSE and Pearson differences (GNN vs. REF15) on the AB-benchmark
ci_rsme, ci_pearson = calculate_and_plot_bootstraped_metrics(benchmark_results)
# the figure has to be saved before show()/close()
plt.savefig(os.path.join(plot_path, "ab_benchmark_metric_difference_bootstrapped.png"))
plt.show()
plt.close()

In [None]:
# Interval bounds of the bootstrapped differences on the AB-benchmark:
# ci_rsme comes from interval(0.9), ci_pearson from interval(0.95)
print(ci_rsme)
print(ci_pearson)

### Calculate paired Wilcoxon signed-rank test

In [None]:
# Paired, one-sided Wilcoxon signed-rank test on the absolute errors:
# alternative="less" tests whether the GNN errors tend to be smaller than the REF15 errors
stats.wilcoxon(benchmark_results["gnn_error"], benchmark_results["ref15_error"], alternative="less")

## AbAg-Affinity test set

### Calculate difference to predictions

In [None]:
# Only complexes with both a GNN prediction and a REF15 score are kept (inner join).
# NOTE(review): the "ref15_prediction" column is not created on `abag_rosetta_scores` by any
# cell visible in this notebook - confirm where it is supposed to come from.
abag_test_results = gnn_abag_test_df.join(abag_rosetta_scores, how="inner")
abag_test_results = abag_test_results[["gnn_prediction", "-log(Kd)_labels", "ref15_prediction", "rosetta_delta_g"]]

# Labels expressed as free energy: Kd = 10^(-label)
abag_test_results["delta_g_labels"] = abag_test_results["-log(Kd)_labels"].apply(
    lambda neg_log_kd: calc_delta_g(10**(-neg_log_kd))
)

# Signed and absolute error of the GNN
abag_test_results["gnn_diff"] = abag_test_results["gnn_prediction"].sub(abag_test_results["-log(Kd)_labels"])
abag_test_results["gnn_error"] = abag_test_results["gnn_diff"].abs()

# Signed and absolute error of REF15
abag_test_results["ref15_diff"] = abag_test_results["ref15_prediction"].sub(abag_test_results["-log(Kd)_labels"])
abag_test_results["ref15_error"] = abag_test_results["ref15_diff"].abs()

# Negative values mean the GNN has the smaller absolute error
abag_test_results["error_diff"] = abag_test_results["gnn_error"].sub(abag_test_results["ref15_error"])

In [None]:
# Summary statistics of all prediction, label and error columns of the AbAg-affinity test set results
abag_test_results.describe()

### Plot statistics of differences

In [None]:
# 2x2 overview - top row: GNN, bottom row: REF15;
# left: histogram of the signed error, right: signed error against the prediction
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))

abag_test_results["gnn_diff"].plot(kind="hist", ax=ax[0][0])
ax[0][0].set_title("GNN prediction error histogram")
ax[0][0].set_xlabel("GNN prediction error [-log(Kd)]")

abag_test_results.plot(kind="scatter", x="gnn_prediction", y="gnn_diff", ax=ax[0][1])
ax[0][1].set_title("GNN prediction error vs. GNN prediction")

abag_test_results["ref15_diff"].plot(kind="hist", ax=ax[1][0])
ax[1][0].set_title("RE15 prediction error histogram")
ax[1][0].set_xlabel("REF15 prediction error [-log(Kd)]")

# the REF15 scatter plot leaves out complexes with a raw REF15 binding energy of 200 or more
abag_test_results.loc[abag_test_results["rosetta_delta_g"] < 200].plot(
    kind="scatter", x="ref15_prediction", y="ref15_diff", ax=ax[1][1]
)
ax[1][1].set_title("RE15 prediction error vs. RE15 prediction")

# save before show/close, otherwise the figure is gone
plt.savefig(os.path.join(plot_path, "abag_testset_error_distributions.png"))
plt.show()
plt.close()

In [None]:
# Histogram of (absolute GNN error - absolute REF15 error) per complex,
# leaving out complexes with a raw REF15 binding energy of 200 or more
abag_test_results.loc[abag_test_results["rosetta_delta_g"] < 200, "error_diff"].plot(kind="hist")
plt.xlabel("Difference in absolute error [-log(Kd)]")
plt.title("Difference in absolute prediction errors: GNN Error - REF15 Error")

# save before show/close, otherwise the figure is gone
plt.savefig(os.path.join(plot_path, "abag_test_error_difference_distributions.png"))
plt.show()
plt.close()

### Bootstrapping Method to generate multiple means of differences

In [None]:
# Bootstrap the mean error difference on the AbAg-affinity test set
# (complexes with rosetta_delta_g >= 200 are left out); the returned interval is discarded here
calculate_and_plot_bootstraped_error_differences(abag_test_results[abag_test_results["rosetta_delta_g"] < 200])
# the figure has to be saved before show()/close()
plt.savefig(os.path.join(plot_path, "abag_test_error_difference_distributions_bootstrapped.png"))
plt.show()
plt.close()

In [None]:
# Bootstrap the RMSE and Pearson differences (GNN vs. REF15) on the AbAg-affinity test set
# (complexes with rosetta_delta_g >= 200 are left out)
ci_rsme, ci_pearson = calculate_and_plot_bootstraped_metrics(abag_test_results[abag_test_results["rosetta_delta_g"] < 200])
# the figure has to be saved before show()/close()
plt.savefig(os.path.join(plot_path, "abag_test_metric_difference_bootstrapped.png"))
plt.show()
plt.close()

In [None]:
# Interval bounds of the bootstrapped differences on the AbAg-affinity test set:
# ci_rsme comes from interval(0.9), ci_pearson from interval(0.95)
print(ci_rsme)
print(ci_pearson)

### Calculate paired Wilcoxon signed-rank test

In [None]:
# Paired, one-sided Wilcoxon signed-rank test on the absolute errors:
# alternative="less" tests whether the GNN errors tend to be smaller than the REF15 errors.
# NOTE(review): unlike the histogram and bootstrap cells of this section, the
# `rosetta_delta_g < 200` filter is not applied here - confirm this is intended.
stats.wilcoxon(abag_test_results["gnn_error"], abag_test_results["ref15_error"], alternative="less")

## Combination of both datasets

In [None]:
# Stack the AbAg-affinity test set results on top of the AB-benchmark results (rows are concatenated)
full_results = pd.concat([abag_test_results, benchmark_results])
# Summary statistics of the combined table
full_results.describe()

In [None]:
# 2x2 overview - top row: GNN, bottom row: REF15;
# left: histogram of the signed error, right: signed error against the prediction
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))

full_results["gnn_diff"].plot(kind="hist", ax=ax[0][0])
ax[0][0].set_title("GNN prediction error histogram")
ax[0][0].set_xlabel("GNN prediction error [-log(Kd)]")

full_results.plot(kind="scatter", x="gnn_prediction", y="gnn_diff", ax=ax[0][1])
ax[0][1].set_title("GNN prediction error vs. GNN prediction")

full_results["ref15_diff"].plot(kind="hist", ax=ax[1][0])
ax[1][0].set_title("RE15 prediction error histogram")
ax[1][0].set_xlabel("RE15 prediction error [-log(Kd)]")

# the REF15 scatter plot only shows complexes with a REF15 prediction above 3
# (use the unfiltered frame here to plot all complexes instead)
full_results.loc[full_results["ref15_prediction"] > 3].plot(
    kind="scatter", x="ref15_prediction", y="ref15_diff", ax=ax[1][1]
)
ax[1][1].set_title("RE15 prediction error vs. RE15 prediction")

# save before show/close, otherwise the figure is gone
plt.savefig(os.path.join(plot_path, "full_testset_error_distributions.png"))
plt.show()
plt.close()

In [None]:
# Histogram of (absolute GNN error - absolute REF15 error) per complex,
# restricted to complexes with a REF15 prediction above 3
full_results.loc[full_results["ref15_prediction"] > 3, "error_diff"].plot(kind="hist")
plt.xlabel("Difference in absolute error [-log(Kd)]")
plt.title("Difference in absolute prediction errors: GNN Error - REF15 Error")

# save before show/close, otherwise the figure is gone
plt.savefig(os.path.join(plot_path, "full_testset_error_difference_distributions.png"))
plt.show()
plt.close()

### Bootstrapping Method to generate multiple means of differences

In [None]:
# Bootstrap the mean error difference on the combined results
# (only complexes with a REF15 prediction above 3); the returned interval is discarded here
calculate_and_plot_bootstraped_error_differences(full_results[full_results["ref15_prediction"] > 3])
# the figure has to be saved before show()/close()
plt.savefig(os.path.join(plot_path, "full_testset_error_difference_distributions_bootstrapped.png"))
plt.show()
plt.close()

In [None]:
# Bootstrap the RMSE and Pearson differences (GNN vs. REF15) on the combined results
# (only complexes with a REF15 prediction above 3)
ci_rsme, ci_pearson = calculate_and_plot_bootstraped_metrics(full_results[full_results["ref15_prediction"] > 3])
# the figure has to be saved before show()/close()
plt.savefig(os.path.join(plot_path, "full_testset_metric_difference_bootstrapped.png"))
plt.show()
plt.close()

In [None]:
# Interval bounds of the bootstrapped differences on the combined results:
# ci_rsme comes from interval(0.9), ci_pearson from interval(0.95)
print(ci_rsme)
print(ci_pearson)

### Calculate paired Wilcoxon signed-rank test

In [None]:
# Paired, one-sided Wilcoxon signed-rank test on the absolute errors:
# alternative="less" tests whether the GNN errors tend to be smaller than the REF15 errors.
# NOTE(review): unlike the histogram and bootstrap cells of this section, the
# `ref15_prediction > 3` filter is not applied here - confirm this is intended.
stats.wilcoxon(full_results["gnn_error"], full_results["ref15_error"], alternative="less")