## Structure Prediction Results Plotting

#### Import packages

In [None]:
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import pandas as pd
import seaborn as sns

#### Configure packages

In [None]:
# enable the pandas `mode.copy_on_write` option for the rest of the notebook
pd.set_option("mode.copy_on_write", True)

#### Define constants

In [None]:
# Identifiers of the methods whose results are loaded and plotted
baseline_methods = ["diffdock", "dynamicbind", "neuralplexer", "af3", "flowdock"]


def _bust_results_csv(output_dir: str, run_dir: str) -> str:
    """
    Build the path of a `bust_results.csv` file.

    :param output_dir: Directory holding a method's inference outputs.
    :param run_dir: Name of the run sub-directory inside `output_dir`.
    :return: Path `output_dir/run_dir/bust_results.csv`.
    """
    return os.path.join(output_dir, run_dir, "bust_results.csv")


# DiffDock result files
diffdock_output_dir = os.path.join("..", "forks", "DiffDock", "inference")
diffdock_pdbbind_bust_results_csv_filepath = _bust_results_csv(
    diffdock_output_dir, "diffdock_pdbbind_output_1"
)
diffdock_dockgen_bust_results_csv_filepath = _bust_results_csv(
    diffdock_output_dir, "diffdock_moad_output_1"
)

# DynamicBind result files
dynamicbind_output_dir = os.path.join(
    "..", "forks", "DynamicBind", "inference", "output", "results"
)
dynamicbind_pdbbind_bust_results_csv_filepath = _bust_results_csv(
    dynamicbind_output_dir, "dynamicbind_pdbbind_1"
)
dynamicbind_dockgen_bust_results_csv_filepath = _bust_results_csv(
    dynamicbind_output_dir, "dynamicbind_moad_1"
)

# NeuralPLexer result files
neuralplexer_output_dir = os.path.join("..", "forks", "NeuralPLexer", "inference")
neuralplexer_pdbbind_bust_results_csv_filepath = _bust_results_csv(
    neuralplexer_output_dir, "neuralplexer_pdbbind_outputs_1"
)
neuralplexer_dockgen_bust_results_csv_filepath = _bust_results_csv(
    neuralplexer_output_dir, "neuralplexer_moad_outputs_1"
)

# AlphaFold 3 result files
af3_output_dir = os.path.join("..", "forks", "AlphaFold3", "inference")
af3_pdbbind_bust_results_csv_filepath = _bust_results_csv(
    af3_output_dir, "af3_pdbbind_outputs_1"
)
af3_dockgen_bust_results_csv_filepath = _bust_results_csv(
    af3_output_dir, "af3_moad_outputs_1"
)

# FlowDock result files
flowdock_output_dir = os.path.join("..", "forks", "FlowDock", "inference")
flowdock_pdbbind_bust_results_csv_filepath = _bust_results_csv(
    flowdock_output_dir, "flowdock_pdbbind_outputs_1"
)
flowdock_dockgen_bust_results_csv_filepath = _bust_results_csv(
    flowdock_output_dir, "flowdock_moad_outputs_1"
)

# Method identifier -> display name (key order also defines the method index)
method_mapping = {
    "diffdock": "DiffDock-L",
    "dynamicbind": "DynamicBind",
    "neuralplexer": "NeuralPLexer",
    "af3": "AlphaFold 3",
    "flowdock": "FlowDock",
}

# Method identifier -> category label; every listed method shares one category
method_category_mapping = dict.fromkeys(baseline_methods, "DL-based blind")

# Test columns, grouped by what they check
_ACCURACY_COLUMNS = ["rmsd_≤_2å"]
_CHEMICAL_VALIDITY_COLUMNS = [
    "mol_pred_loaded",
    "mol_true_loaded",
    "mol_cond_loaded",
    "sanitization",
    "molecular_formula",
    "molecular_bonds",
    "tetrahedral_chirality",
    "double_bond_stereochemistry",
]
_INTRAMOLECULAR_VALIDITY_COLUMNS = [
    "bond_lengths",
    "bond_angles",
    "internal_steric_clash",
    "aromatic_ring_flatness",
    "double_bond_flatness",
    "internal_energy",
]
_INTERMOLECULAR_VALIDITY_COLUMNS = [
    "minimum_distance_to_protein",
    "minimum_distance_to_organic_cofactors",
    "minimum_distance_to_inorganic_cofactors",
    "volume_overlap_with_protein",
    "volume_overlap_with_organic_cofactors",
    "volume_overlap_with_inorganic_cofactors",
]

# Full ordered list: accuracy first, then the three validity groups
BUST_TEST_COLUMNS = [
    *_ACCURACY_COLUMNS,
    *_CHEMICAL_VALIDITY_COLUMNS,
    *_INTRAMOLECULAR_VALIDITY_COLUMNS,
    *_INTERMOLECULAR_VALIDITY_COLUMNS,
]

#### Report test results for baseline methods

In [None]:
# load and report test results for each baseline method
def _build_tests_table(bust_results: pd.DataFrame, method: str, dataset: str) -> pd.DataFrame:
    """
    Derive the per-prediction tests table for one method on one dataset.

    :param bust_results: Table read from a `bust_results.csv` file; must contain
        every column in `BUST_TEST_COLUMNS` plus `rmsd`.
    :param method: Method identifier written to the `method` column.
    :param dataset: Dataset identifier written to the `dataset` column.
    :return: Copy of the test and `rmsd` columns with the derived columns
        `pb_valid`, `method`, `post-processing`, `dataset` and
        `docked_ligand_successfully_loaded` appended, in that order.
    """
    tests_table = bust_results[BUST_TEST_COLUMNS + ["rmsd"]].copy()
    # A prediction is valid when every check after the first (`rmsd_≤_2å`) passes.
    # The columns are selected by name so that the raw float `rmsd` column is not
    # swept into the boolean check (a positional `iloc[:, 1:]` would include it and
    # flag an RMSD of exactly 0.0 as invalid).
    tests_table["pb_valid"] = tests_table[BUST_TEST_COLUMNS[1:]].all(axis=1)
    tests_table["method"] = method
    tests_table["post-processing"] = "none"
    tests_table["dataset"] = dataset
    tests_table["docked_ligand_successfully_loaded"] = tests_table[
        ["mol_pred_loaded", "mol_true_loaded", "mol_cond_loaded"]
    ].all(axis=1)
    return tests_table


def _load_method_tables(method: str) -> dict:
    """
    Read both result CSVs of a method and build its tests tables.

    The CSV paths are looked up in the notebook globals under
    `{method}_{dataset}_bust_results_csv_filepath`.

    :param method: Method identifier.
    :return: Mapping from the global variable names used by the rest of the
        notebook to the corresponding tables.
    """
    tables = {}
    for dataset in ("pdbbind", "dockgen"):
        bust_results = pd.read_csv(globals()[f"{method}_{dataset}_bust_results_csv_filepath"])
        tables[f"{method}_{dataset}_bust_results"] = bust_results
        tables[f"{method}_{dataset}_tests_table"] = _build_tests_table(
            bust_results, method, dataset
        )
    tables[f"{method}_tests_table"] = pd.concat(
        [tables[f"{method}_pdbbind_tests_table"], tables[f"{method}_dockgen_tests_table"]]
    )
    return tables


def _report_success_rates(
    method_title: str,
    dataset_title: str,
    tests_table: pd.DataFrame,
    prefix: str = "",
    suffix: str = "",
) -> None:
    """
    Print the two success rates of one method on one dataset.

    :param method_title: Display name of the method.
    :param dataset_title: Display name of the dataset.
    :param tests_table: Tests table holding `rmsd_≤_2å` and `pb_valid`.
    :param prefix: Text printed before the first line.
    :param suffix: Text printed after the second line.
    """
    print(
        f"{prefix}{method_title} {dataset_title} set `rmsd_≤_2å`: {tests_table['rmsd_≤_2å'].mean()}"
    )
    # successes among valid predictions, normalised by ALL predictions of the table
    valid_success_rate = tests_table[tests_table["pb_valid"]]["rmsd_≤_2å"].sum() / len(
        tests_table
    )
    print(
        f"{method_title} {dataset_title} set `rmsd_≤_2å and pb_valid`: {valid_success_rate}{suffix}"
    )


for method in baseline_methods:
    method_title = method_mapping[method]

    # publish the tables under their per-method global names, which later cells read
    method_tables = _load_method_tables(method)
    globals().update(method_tables)

    _report_success_rates(method_title, "PDBBind", method_tables[f"{method}_pdbbind_tests_table"])
    _report_success_rates(
        method_title,
        "DockGen",
        method_tables[f"{method}_dockgen_tests_table"],
        prefix="\n",
        suffix="\n",
    )

#### Define helper functions

In [None]:
def assign_method_index(method: str) -> int:
    """
    Assign method index for plotting.

    The index is the position of `method` among the keys of `method_mapping`.

    :param method: Method name; must be a key of `method_mapping`.
    :return: Zero-based method index.
    :raises ValueError: If `method` is not a key of `method_mapping`.
    """
    return list(method_mapping).index(method)


def categorize_method(method: str) -> str:
    """
    Look up the plotting category of a method.

    :param method: Method name.
    :return: Category stored in `method_category_mapping`, or "Misc" when the
        method has no entry there.
    """
    if method in method_category_mapping:
        return method_category_mapping[method]
    return "Misc"

#### Standardize metrics

In [None]:
# combine the per-method tests tables and derive the columns used for plotting
def _standardize_metrics(table: pd.DataFrame) -> pd.DataFrame:
    """
    Add the plotting columns to the combined tests table, in place.

    :param table: Concatenation of the per-method tests tables.
    :return: The same table object with the derived columns added and the
        `dataset` and `method` columns mapped to their display names.
    """
    # category and ordering index are derived from the raw method identifiers,
    # so both are computed before `method` is replaced by its display name
    method_keys = table["method"]
    table.loc[:, "method_category"] = method_keys.apply(categorize_method)
    table.loc[:, "method_assignment_index"] = method_keys.apply(assign_method_index)

    # missing test outcomes count as failures
    table.loc[:, "rmsd_within_threshold"] = table["rmsd_≤_2å"].fillna(False)
    within_threshold = table["rmsd_within_threshold"]
    pb_valid = table["pb_valid"].fillna(False)
    table.loc[:, "rmsd_within_threshold_and_pb_valid"] = within_threshold & pb_valid

    # integer (0/1) versions of the two success flags, summed later for the bar chart
    table.loc[:, "RMSD ≤ 2 Å & PB-Valid"] = table["rmsd_within_threshold_and_pb_valid"].astype(
        int
    )
    table.loc[:, "RMSD ≤ 2 Å"] = table["rmsd_within_threshold"].fillna(False).astype(int)

    # swap identifiers for display names
    dataset_titles = {"pdbbind": "PDBBind set", "dockgen": "DockGen set"}
    table.loc[:, "dataset"] = table["dataset"].map(dataset_titles)
    table.loc[:, "method"] = table["method"].map(method_mapping)
    return table


tests_table = _standardize_metrics(
    pd.concat([globals()[f"{method}_tests_table"] for method in baseline_methods])
)

#### Make plots

In [None]:
# Bar Chart of PDBBind & DockGen Set Results #


def _success_percentages(
    results_table: pd.DataFrame, metric_column: str, num_data_points: int
) -> pd.Series:
    """
    Compute the per-method success percentage for one 0/1 metric column.

    :param results_table: Rows of a single dataset.
    :param metric_column: Name of the 0/1 column to sum per method.
    :param num_data_points: Denominator shared by all methods.
    :return: Percentages ordered by `method_assignment_index`.
    """
    per_method = (
        results_table.groupby("method")
        .agg({metric_column: "sum", "method_assignment_index": "first"})
        .reset_index()
    )
    per_method[metric_column] = per_method[metric_column] / num_data_points * 100
    return per_method.sort_values("method_assignment_index")[metric_column]


# prepare data for the bar charts to plot
colors = ["#8DD3C7", "#FB8072"]

# each method owns an even x position; its PDBBind bar sits just left of it (r1)
# and its DockGen bar just right of it (r2)
bar_width = 0.5
r1 = [item - 0.25 for item in [2, 4, 6, 8, 10]]
r2 = [x + bar_width for x in r1]

tests_table = tests_table.sort_values(by="method_assignment_index")
pdbbind_results_table = tests_table[
    (tests_table["dataset"] == "PDBBind set") & (tests_table["post-processing"] == "none")
]
dockgen_results_table = tests_table[
    (tests_table["dataset"] == "DockGen set") & (tests_table["post-processing"] == "none")
]

pdbbind_labels = pdbbind_results_table["method"].unique()
dockgen_labels = dockgen_results_table["method"].unique()
num_methods = len(pdbbind_labels)

# the largest per-method row count is used as the denominator for every method
num_pdbbind_data_points = max(
    len(pdbbind_results_table[(pdbbind_results_table["method"] == method)])
    for method in pdbbind_labels
)
num_dockgen_data_points = max(
    len(dockgen_results_table[(dockgen_results_table["method"] == method)])
    for method in dockgen_labels
)

# PDBBind results
pdbbind_rmsd_lt_2_data = _success_percentages(
    pdbbind_results_table, "RMSD ≤ 2 Å", num_pdbbind_data_points
)

# DockGen results
dockgen_rmsd_lt_2_data = _success_percentages(
    dockgen_results_table, "RMSD ≤ 2 Å", num_dockgen_data_points
)

# PDBBind (PB-Valid) results
pdbbind_rmsd_lt_2_and_pb_valid_data = _success_percentages(
    pdbbind_results_table, "RMSD ≤ 2 Å & PB-Valid", num_pdbbind_data_points
)

# DockGen (PB-Valid) results
dockgen_rmsd_lt_2_and_pb_valid_data = _success_percentages(
    dockgen_results_table, "RMSD ≤ 2 Å & PB-Valid", num_dockgen_data_points
)

# define font properties BEFORE creating the figure: matplotlib reads these
# rcParams when text artists (e.g. the axis labels) are created, so setting them
# afterwards would not affect the axes on a first run
plt.rcParams["font.size"] = 11
plt.rcParams["axes.labelsize"] = 13

# create the figure and its single axes
fig, axis = plt.subplots(figsize=(12, 6))
axis.spines["top"].set_visible(False)
axis.spines["right"].set_visible(False)
axis.spines["bottom"].set_visible(False)
axis.spines["left"].set_visible(False)

# plot data for the PDBBind set (filled bar = PB-valid successes, hatched outline = all successes)
pdbbind_rmsd_lt_2_and_pb_valid_bar = axis.bar(
    r1,
    pdbbind_rmsd_lt_2_and_pb_valid_data,
    label="RMSD ≤ 2Å & PB-Valid",
    color=colors[0],
    width=bar_width,
)
pdbbind_rmsd_lt_2_bar = axis.bar(
    r1,
    pdbbind_rmsd_lt_2_data,
    label="RMSD ≤ 2Å",
    color="none",
    edgecolor=colors[0],
    hatch="\\\\\\",
    width=bar_width,
)

# plot data for the DockGen set
dockgen_rmsd_lt_2_and_pb_valid_bar = axis.bar(
    r2,
    dockgen_rmsd_lt_2_and_pb_valid_data,
    label="RMSD ≤ 2Å & PB-Valid",
    color=colors[1],
    width=bar_width,
)
dockgen_rmsd_lt_2_bar = axis.bar(
    r2,
    dockgen_rmsd_lt_2_data,
    label="RMSD ≤ 2Å",
    color="none",
    edgecolor=colors[1],
    hatch="\\\\\\",
    width=bar_width,
)

# add labels, titles, ticks, etc.
axis.set_ylabel("Percentage of predictions")
axis.set_xlim(1, 11 + 0.1)
axis.set_ylim(0, 100)

axis.bar_label(pdbbind_rmsd_lt_2_bar, fmt="{:,.1f}%", label_type="edge")
axis.bar_label(pdbbind_rmsd_lt_2_and_pb_valid_bar, fmt="{:,.1f}%", label_type="center", padding=5)
axis.bar_label(dockgen_rmsd_lt_2_bar, fmt="{:,.1f}%", label_type="edge")
axis.bar_label(dockgen_rmsd_lt_2_and_pb_valid_bar, fmt="{:,.1f}%", label_type="center", padding=5)

axis.yaxis.set_major_formatter(mtick.PercentFormatter())

axis.set_yticks([0, 20, 40, 60, 80, 100])
axis.axhline(y=0, color="#EAEFF8")
axis.grid(axis="y", color="#EAEFF8")
axis.set_axisbelow(True)

# the extra tick at 6 + 1e-3 carries the category label, which is shifted below
# the method names further down
axis.set_xticks([2, 4, 6, 6 + 1e-3, 8, 10])
axis.set_xticks([1 + 0.1], minor=True)
axis.set_xticklabels(
    [
        "DiffDock-L",
        "DynamicBind",
        "NeuralPLexer",
        "DL-based blind",
        "AlphaFold 3",
        "FlowDock",
    ]
)

# NOTE: the x grid is enabled here. The first argument used to be the string
# "off", which is truthy, and matplotlib turns the grid on whenever line
# properties are passed, so `True` states what is actually rendered.
axis.grid(True, axis="x", color="#EAEFF8")
axis.grid(True, axis="x", which="minor", color="#EAEFF8")

# tick visibility flags must be real booleans: the strings "off"/"on" are both
# truthy and would leave every tick mark visible
axis.tick_params(axis="x", which="minor", direction="out", length=30, color="#EAEFF8")
axis.tick_params(axis="x", which="major", bottom=False, top=False, color="#EAEFF8")
axis.tick_params(axis="y", which="major", left=False, right=True, color="#EAEFF8")

# vertical alignment of xtick labels
vert_alignments = [0.0, 0.0, 0.0, -0.1, 0.0, 0.0]
for tick, y in zip(axis.get_xticklabels(), vert_alignments):
    tick.set_y(y)

# add legends
legend_0 = fig.legend(
    [pdbbind_rmsd_lt_2_bar, pdbbind_rmsd_lt_2_and_pb_valid_bar],
    ["RMSD ≤ 2Å", "RMSD ≤ 2Å & PB-Valid"],
    loc="upper right",
    title="PDBBind set",
    bbox_to_anchor=(1, 1, -0.20, -0.05),
)
legend_1 = fig.legend(
    [dockgen_rmsd_lt_2_bar, dockgen_rmsd_lt_2_and_pb_valid_bar],
    ["RMSD ≤ 2Å", "RMSD ≤ 2Å & PB-Valid"],
    loc="upper right",
    title="DockGen set",
    bbox_to_anchor=(1, 1, -0.01, -0.05),
)
legend_0.get_frame().set_alpha(0)
legend_1.get_frame().set_alpha(0)

# display the plots
plt.tight_layout()
plt.savefig("pdbbind_dockgen_bar_chart.png", dpi=300)
plt.show()

In [None]:
# Violin Plot of PDBBind & DockGen Set Results #

# palette handed to seaborn for the `dataset` hue
colors = ["#8DD3C7", "#FB8072"]

# stack the PDBBind and DockGen rows into one table with a fresh 0..n-1 index
combined_results_table = pd.concat(
    [pdbbind_results_table, dockgen_results_table], ignore_index=True
)

# open a new 12x6 figure for the plot
plt.figure(figsize=(12, 6))

# one split violin per method: RMSD distribution of each dataset on either half
violin_options = {
    "data": combined_results_table,
    "x": "method",
    "y": "rmsd",
    "hue": "dataset",
    "split": True,
    "inner": "quartile",
    "palette": colors,
    "cut": 0,
}
sns.violinplot(**violin_options)

# axis captions
plt.ylabel("RMSD")
plt.xlabel("DL-based blind")

# slant the method names so they do not overlap
plt.xticks(rotation=45, ha="right")

# put the dataset legend to the right of the axes
plt.legend(title="Dataset", loc="upper left", bbox_to_anchor=(1.05, 1))

# lay out, save, and display the figure
plt.tight_layout()
plt.savefig("pdbbind_dockgen_rmsd_violin_plot.png", dpi=300)
plt.show()