## CASP15 Inference Results Plotting

#### Import packages

In [None]:
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import seaborn as sns

from posebench.analysis.inference_analysis_casp import (
    CASP_BUST_TEST_COLUMNS,
    NUM_SCOREABLE_CASP15_TARGETS,
    PUBLIC_CASP15_MULTI_LIGAND_TARGETS,
    PUBLIC_CASP15_SINGLE_LIGAND_TARGETS,
    All_CASP15_MULTI_LIGAND_TARGETS,
    All_CASP15_SINGLE_LIGAND_TARGETS,
)

#### Configure packages

In [None]:
# enable pandas Copy-on-Write so that column assignments on filtered frames never alter their parents
pd.set_option("mode.copy_on_write", True)

#### Define constants

In [None]:
# General variables
max_num_repeats_per_method = 3

# Mappings
# internal method key -> display title; its key order is also the plotting order of the methods
method_mapping = {
    "vina_p2rank": "P2Rank-Vina",
    "diffdock": "DiffDock-L",
    "dynamicbind": "DynamicBind",
    "neuralplexer": "NeuralPLexer",
    "rfaa": "RoseTTAFold-AA",
    "chai-lab_ss": "Chai-1-Single-Seq",
    "chai-lab": "Chai-1",
    "alphafold3_ss": "AF3-Single-Seq",
    "alphafold3": "AF3",
}

# every method listed above is evaluated, in the same order
baseline_methods = list(method_mapping)

# only P2Rank-Vina is a conventional (non-DL) blind docking method
method_category_mapping = {
    method_key: ("Conventional blind" if method_key == "vina_p2rank" else "DL-based blind")
    for method_key in baseline_methods
}

#### Report test results for each baseline method

In [None]:
# load and report test results for each baseline method
#
# For every (post-processing config, method, repeat) combination this cell publishes the notebook
# globals `{method}_output_dir_{i}`, `{method}{config}_scoring_results_csv_filepath_{i}` and
# `{method}{config}_bust_results_csv_filepath_{i}`. When the scoring CSV exists it also publishes
# `{method}{config}_scoring_results_table_{i}` and `{method}{config}_bust_results_table_{i}`
# (the latter is None when no bust CSV exists). Later cells collect these by name.
for config in ["", "_relaxed"]:
    post_processing = "energy minimization" if config == "_relaxed" else "none"
    for method in baseline_methods:
        method_title = method_mapping[method]
        for repeat_index in range(1, max_num_repeats_per_method + 1):
            run_label = f"{method_title}{config}_{repeat_index}"

            # locate this run's prediction directory (with the config suffix) and its result CSVs
            output_dir = os.path.join(
                "..",
                "data",
                "test_cases",
                "casp15",
                f"top_{method}{'' if 'ensemble' in method else '_ensemble'}_predictions_{repeat_index}",
            )
            scoring_csv_filepath = os.path.join(output_dir + config, "scoring_results.csv")
            bust_csv_filepath = os.path.join(output_dir + config, "bust_results.csv")
            globals()[f"{method}_output_dir_{repeat_index}"] = output_dir
            globals()[f"{method}{config}_scoring_results_csv_filepath_{repeat_index}"] = (
                scoring_csv_filepath
            )
            globals()[f"{method}{config}_bust_results_csv_filepath_{repeat_index}"] = (
                bust_csv_filepath
            )

            # runs that were never scored publish no tables; later cells check for their existence
            if not os.path.exists(scoring_csv_filepath):
                continue

            scoring_table = pd.read_csv(scoring_csv_filepath)
            # count the number of ligands in each target complex (before any rows are filtered
            # out), and assign these corresponding numbers to the ligands (rows) of each complex
            scoring_table.loc[:, "num_target_ligands"] = scoring_table.groupby(["target", "mdl"])[
                "pose"
            ].transform("count")
            bust_table = (
                pd.read_csv(bust_csv_filepath) if os.path.exists(bust_csv_filepath) else None
            )

            # filter out non-relevant ligand predictions, and for all methods select only their
            # first model for each ligand
            scoring_table = scoring_table[
                scoring_table["relevant"].astype(bool) & (scoring_table["mdl"] == 1)
            ]
            scoring_table.loc[:, "method"] = method
            scoring_table.loc[:, "post-processing"] = post_processing
            scoring_table.loc[:, "dataset"] = "casp15"
            scoring_table.loc[:, "rmsd_≤_2å"] = scoring_table.loc[:, "rmsd"] <= 2
            globals()[f"{method}{config}_scoring_results_table_{repeat_index}"] = scoring_table

            # only model 1 remains, so each target now maps to exactly one ligand count
            num_ligands_per_target = scoring_table.groupby("target")["num_target_ligands"].first()

            print(
                f"{run_label} CASP15 set average `lddt_pli`: {scoring_table['lddt_pli'].mean()}"
            )
            print(f"{run_label} CASP15 set average `rmsd`: {scoring_table['rmsd'].mean()}")
            print(
                f"{run_label} CASP15 set average `rmsd_≤_2å`: {scoring_table['rmsd_≤_2å'].mean()}"
            )

            if bust_table is not None:
                bust_table = bust_table[CASP_BUST_TEST_COLUMNS + ["target"]]
                # `pb_valid` is True only when every PoseBusters test column is truthy; the
                # `target` name column is deliberately excluded from this reduction
                bust_table.loc[:, "pb_valid"] = bust_table[CASP_BUST_TEST_COLUMNS].all(axis=1)
                bust_table.loc[:, "method"] = method
                bust_table.loc[:, "post-processing"] = post_processing
                bust_table.loc[:, "dataset"] = "casp15"
                # filter bust results to only those for targets that were scoreable using the
                # CASP scoring pipeline
                bust_table = bust_table[
                    bust_table["target"].isin(scoring_table["target"].unique())
                ]
                # look up each complex's ligand count by target name rather than assuming the
                # bust rows are in the same order as the (target-sorted) scoring groups
                bust_table.loc[:, "num_target_ligands"] = bust_table["target"].map(
                    num_ligands_per_target
                )
                print(
                    f"{run_label} CASP15 set average complexes `pb_valid`: {bust_table['pb_valid'].mean()}"
                )
                if len(bust_table) < NUM_SCOREABLE_CASP15_TARGETS:
                    print(
                        f"Warning: Found {len(bust_table)} scoreable CASP15 targets for {run_label} out of the full {NUM_SCOREABLE_CASP15_TARGETS}"
                    )
            globals()[f"{method}{config}_bust_results_table_{repeat_index}"] = bust_table
            print()

#### Define helper functions

In [None]:
def assign_method_index(method: str) -> int:
    """
    Assign method index for plotting.

    The index is the position of `method` among the keys of `method_mapping`,
    so methods are plotted in the order in which that mapping lists them.

    :param method: Method name (a key of `method_mapping`).
    :return: Method index.
    :raises ValueError: If `method` is not a key of `method_mapping`.
    """
    return list(method_mapping.keys()).index(method)


def assign_category_index(category: str) -> int:
    """
    Assign category index for plotting.

    The index is the position of `category` among the display titles (values)
    of `method_mapping`, matching the plotting order of the methods.

    :param category: Category name (a display title in `method_mapping`).
    :return: Category index.
    :raises ValueError: If `category` is not a value of `method_mapping`.
    """
    return list(method_mapping.values()).index(category)


def categorize_method(method: str) -> str:
    """
    Categorize method for plotting.

    :param method: Method name.
    :return: Method category from `method_category_mapping`, falling back to
        "DL-based blind" for methods that mapping does not list.
    """
    return method_category_mapping.get(method, "DL-based blind")

#### Standardize metrics

In [None]:
# load and organize the CASP15 results CSV
#
# Publishes, per repeat, `casp15_plif_metrics_csv_filepath_{i}`, `casp15_plif_metrics_table_{i}`
# and `scoring_results_table_{i}` (all methods and post-processing configs of that repeat).

# PLIF metrics: every repeat uses the same CSV, so read and index it once and copy it per repeat
_casp15_plif_metrics_csv_filepath = "casp15_plif_metrics.csv"
_casp15_plif_metrics_table = pd.read_csv(_casp15_plif_metrics_csv_filepath)
_casp15_plif_metrics_table.loc[:, "category_assignment_index"] = _casp15_plif_metrics_table[
    "Category"
].apply(assign_category_index)

for repeat_index in range(1, max_num_repeats_per_method + 1):
    globals()[f"casp15_plif_metrics_csv_filepath_{repeat_index}"] = (
        _casp15_plif_metrics_csv_filepath
    )
    globals()[f"casp15_plif_metrics_table_{repeat_index}"] = _casp15_plif_metrics_table.copy()

    # combine the scoring results of every method and post-processing config of this repeat
    repeat_scoring_table = pd.concat(
        [
            globals()[f"{method}{config}_scoring_results_table_{repeat_index}"]
            for method in baseline_methods
            for config in ["", "_relaxed"]
            if f"{method}{config}_scoring_results_table_{repeat_index}" in globals()
        ]
    )
    repeat_scoring_table.loc[:, "method_category"] = repeat_scoring_table["method"].apply(
        categorize_method
    )
    repeat_scoring_table.loc[:, "method_assignment_index"] = repeat_scoring_table[
        "method"
    ].apply(assign_method_index)
    # missing RMSD-threshold flags count as failures
    repeat_scoring_table.loc[:, "RMSD ≤ 2 Å"] = (
        repeat_scoring_table.loc[:, "rmsd_≤_2å"].fillna(False).astype(int)
    )
    repeat_scoring_table.loc[:, "dataset"] = repeat_scoring_table.loc[:, "dataset"].map(
        {"casp15": "CASP15 set"}
    )
    # replace method keys with display titles last, since the helpers above expect keys
    repeat_scoring_table.loc[:, "method"] = repeat_scoring_table.loc[:, "method"].map(
        method_mapping
    )
    globals()[f"scoring_results_table_{repeat_index}"] = repeat_scoring_table

#### Make plots

In [None]:
# lDDT-PLI Violin Plot of CASP15 Set (Relaxed vs. Unrelaxed) Results #

# prepare data for the violin plots to draw
colors = ["#FB8072", "#BEBADA"]
# fixed hue order, so each post-processing setting always receives the same color and legend slot
post_processing_order = ["none", "energy minimization"]

# CASP15 targets to include for each (complex type, complex license) combination
casp15_target_subsets = {
    ("single", "all"): All_CASP15_SINGLE_LIGAND_TARGETS,
    ("single", "public"): PUBLIC_CASP15_SINGLE_LIGAND_TARGETS,
    ("multi", "all"): All_CASP15_MULTI_LIGAND_TARGETS,
    ("multi", "public"): PUBLIC_CASP15_MULTI_LIGAND_TARGETS,
}

combined_data_list = []
for repeat_index in range(1, max_num_repeats_per_method + 1):
    repeat_scoring_table = globals()[f"scoring_results_table_{repeat_index}"]
    casp15_rows = repeat_scoring_table["dataset"] == "CASP15 set"
    # unrelaxed rows first, then relaxed rows
    combined_data_list.extend(
        repeat_scoring_table[casp15_rows & (repeat_scoring_table["post-processing"] == setting)]
        for setting in post_processing_order
    )
# a stable sort keeps the within-method row order deterministic
combined_data = pd.concat(combined_data_list).sort_values("method_assignment_index", kind="stable")

for complex_type in ["single", "multi"]:
    for complex_license in ["all", "public"]:
        selected_targets = casp15_target_subsets[(complex_type, complex_license)]

        # define font properties
        plt.rcParams["font.size"] = 12
        plt.rcParams["axes.labelsize"] = 14

        # set the size of the figure
        plt.figure(figsize=(12, 6))

        # create a violin plot, filtering the data based on the complex type and license
        sns.violinplot(
            x="method",
            y="lddt_pli",
            hue="post-processing",
            hue_order=post_processing_order,
            data=combined_data[combined_data["target"].isin(selected_targets)],
            split=True,
            inner="quartile",
            palette=colors,
            cut=0,
        )

        # set labels and title
        plt.xlabel(f"{complex_type.title()}-ligand blind docking ({complex_license})")
        plt.ylabel("lDDT-PLI")

        # rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha="right")

        # display legend outside the plot: anchor the legend's upper-left corner just right of
        # the axes (`loc="best"` may place it back inside the axes)
        plt.legend(title="Post-processing", bbox_to_anchor=(1.05, 1), loc="upper left")

        # display the plots
        plt.tight_layout()
        plt.savefig(
            f"casp15_{complex_license}_{complex_type}_ligand_relaxed_lddt_pli_violin_plot.png",
            dpi=300,
        )
        plt.show()

In [None]:
# RMSD Violin Plot of CASP15 Set (Relaxed vs. Unrelaxed) Results #

# prepare data for the violin plots to draw
colors = ["#FB8072", "#BEBADA"]
# fixed hue order, so each post-processing setting always receives the same color and legend slot
post_processing_order = ["none", "energy minimization"]

# CASP15 targets to include for each (complex type, complex license) combination
casp15_target_subsets = {
    ("single", "all"): All_CASP15_SINGLE_LIGAND_TARGETS,
    ("single", "public"): PUBLIC_CASP15_SINGLE_LIGAND_TARGETS,
    ("multi", "all"): All_CASP15_MULTI_LIGAND_TARGETS,
    ("multi", "public"): PUBLIC_CASP15_MULTI_LIGAND_TARGETS,
}

combined_data_list = []
for repeat_index in range(1, max_num_repeats_per_method + 1):
    repeat_scoring_table = globals()[f"scoring_results_table_{repeat_index}"]
    casp15_rows = repeat_scoring_table["dataset"] == "CASP15 set"
    # unrelaxed rows first, then relaxed rows
    combined_data_list.extend(
        repeat_scoring_table[casp15_rows & (repeat_scoring_table["post-processing"] == setting)]
        for setting in post_processing_order
    )
# a stable sort keeps the within-method row order deterministic
combined_data = pd.concat(combined_data_list).sort_values("method_assignment_index", kind="stable")

for complex_type in ["single", "multi"]:
    for complex_license in ["all", "public"]:
        selected_targets = casp15_target_subsets[(complex_type, complex_license)]

        # define font properties
        plt.rcParams["font.size"] = 12
        plt.rcParams["axes.labelsize"] = 14

        # set the size of the figure
        plt.figure(figsize=(12, 6))

        # create a violin plot, ignoring outliers and filtering the data based on the complex
        # type and license
        sns.violinplot(
            x="method",
            y="rmsd",
            hue="post-processing",
            hue_order=post_processing_order,
            data=combined_data[
                (combined_data["rmsd"] < 200) & combined_data["target"].isin(selected_targets)
            ],
            split=True,
            inner="quartile",
            palette=colors,
            cut=0,
        )

        # set labels and title
        plt.xlabel(f"{complex_type.title()}-ligand blind docking ({complex_license})")
        plt.ylabel("RMSD")

        # rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha="right")

        # display legend outside the plot: anchor the legend's upper-left corner just right of
        # the axes (`loc="best"` may place it back inside the axes)
        plt.legend(title="Post-processing", bbox_to_anchor=(1.05, 1), loc="upper left")

        # display the plots
        plt.tight_layout()
        plt.savefig(
            f"casp15_{complex_license}_{complex_type}_ligand_relaxed_rmsd_violin_plot.png", dpi=300
        )
        plt.show()

In [None]:
# RMSD ≤ 2 Å Bar Chart of CASP15 Set (Relaxed vs. Unrelaxed) Results #

# prepare data for the bar charts to plot
colors = ["#FB8072", "#BEBADA", "#FCCDE5"]

# x positions of each method's unrelaxed RMSD (r1), relaxed RMSD (r2), and PLIF-WM (r3) bars
bar_width = 0.5
r1 = [item - 0.5 for item in range(2, 20, 2)]
r2 = [x + bar_width for x in r1]
r3 = [x + bar_width for x in r2]

# CASP15 targets to include for each (complex type, complex license) combination
casp15_target_subsets = {
    ("single", "all"): All_CASP15_SINGLE_LIGAND_TARGETS,
    ("single", "public"): PUBLIC_CASP15_SINGLE_LIGAND_TARGETS,
    ("multi", "all"): All_CASP15_MULTI_LIGAND_TARGETS,
    ("multi", "public"): PUBLIC_CASP15_MULTI_LIGAND_TARGETS,
}


def _rmsd_lt_2_percentages(results_table, num_data_points):
    """
    Summarize one repeat's per-method RMSD ≤ 2 Å success rate.

    :param results_table: Scoring rows of a single repeat and post-processing setting.
    :param num_data_points: Denominator used to turn success counts into percentages.
    :return: One row per method with its "RMSD ≤ 2 Å" percentage and plotting index,
        sorted by `method_assignment_index`.
    """
    summary = (
        results_table.groupby(["method"])
        .agg({"RMSD ≤ 2 Å": "sum", "method_assignment_index": "first"})
        .reset_index()
    )
    summary["RMSD ≤ 2 Å"] = summary["RMSD ≤ 2 Å"] / num_data_points * 100
    return summary.sort_values(["method_assignment_index"])


def _mean_and_std_across_repeats(per_repeat_tables, group_keys, sort_key, value_column):
    """
    Aggregate per-repeat summary tables into across-repeat statistics.

    :param per_repeat_tables: One summary table per repeat.
    :param group_keys: Column and/or index level names identifying each bar.
    :param sort_key: Column or index level name giving the plotting order.
    :param value_column: Column whose statistics are returned.
    :return: The mean and the standard deviation of `value_column` per group, both sorted by
        `sort_key`; undefined standard deviations (e.g., from a single repeat) are set to 0.
    """
    grouped = pd.concat(per_repeat_tables).groupby(group_keys)
    mean = grouped.mean().sort_values([sort_key])[value_column]
    std = grouped.std().sort_values([sort_key])[value_column].fillna(0)
    return mean, std


for complex_type in ["single", "multi"]:
    for complex_license in ["all", "public"]:
        selected_targets = casp15_target_subsets[(complex_type, complex_license)]
        (
            casp15_rmsd_lt_2_data_list,
            casp15_relaxed_rmsd_lt_2_data_list,
            casp15_plif_wm_data_list,
        ) = ([], [], [])
        for repeat_index in range(1, max_num_repeats_per_method + 1):
            repeat_scoring_table = globals()[f"scoring_results_table_{repeat_index}"]
            # filter the data based on the complex type and license
            selected_rows = (repeat_scoring_table["dataset"] == "CASP15 set") & (
                repeat_scoring_table["target"].isin(selected_targets)
            )
            casp15_results_table = repeat_scoring_table[
                selected_rows & (repeat_scoring_table["post-processing"] == "none")
            ].sort_values(by="method_assignment_index")
            casp15_relaxed_results_table = repeat_scoring_table[
                selected_rows
                & (repeat_scoring_table["post-processing"] == "energy minimization")
            ].sort_values(by="method_assignment_index")

            # every method is normalized by the largest per-method row count, so (presumably by
            # design) methods with fewer scored ligands are not rewarded for missing predictions
            casp15_labels = casp15_results_table["method"].unique()
            num_casp15_data_points = max(
                (casp15_results_table["method"] == method).sum() for method in casp15_labels
            )
            num_casp15_relaxed_data_points = max(
                (casp15_relaxed_results_table["method"] == method).sum()
                for method in casp15_labels
            )

            # CASP15 set (unrelaxed and relaxed) results
            casp15_rmsd_lt_2_data_list.append(
                _rmsd_lt_2_percentages(casp15_results_table, num_casp15_data_points)
            )
            casp15_relaxed_rmsd_lt_2_data_list.append(
                _rmsd_lt_2_percentages(
                    casp15_relaxed_results_table, num_casp15_relaxed_data_points
                )
            )

            # CASP15 PLIF-WM results, restricted to the targets scored in this repeat
            plif_metrics_table = globals()[f"casp15_plif_metrics_table_{repeat_index}"]
            casp15_plif_wm_data = (
                plif_metrics_table[
                    plif_metrics_table["Target"].isin(casp15_results_table["target"].unique())
                ]
                .groupby("Category")
                .agg({"WM": "mean", "category_assignment_index": "first"})
                .sort_values("category_assignment_index")
            )
            casp15_plif_wm_data_list.append(casp15_plif_wm_data)

        # calculate means and standard deviations across repeats
        casp15_rmsd_lt_2_data_mean, casp15_rmsd_lt_2_data_std = _mean_and_std_across_repeats(
            casp15_rmsd_lt_2_data_list,
            ["method", "method_assignment_index"],
            "method_assignment_index",
            "RMSD ≤ 2 Å",
        )
        (
            casp15_relaxed_rmsd_lt_2_data_mean,
            casp15_relaxed_rmsd_lt_2_data_std,
        ) = _mean_and_std_across_repeats(
            casp15_relaxed_rmsd_lt_2_data_list,
            ["method", "method_assignment_index"],
            "method_assignment_index",
            "RMSD ≤ 2 Å",
        )
        casp15_plif_wm_data_mean, casp15_plif_wm_data_std = _mean_and_std_across_repeats(
            casp15_plif_wm_data_list,
            ["Category", "category_assignment_index"],
            "category_assignment_index",
            "WM",
        )
        # scale PLIF-WM by 100 to share the percentage axis with the RMSD bars
        casp15_plif_wm_data_mean = casp15_plif_wm_data_mean * 100.0
        casp15_plif_wm_data_std = casp15_plif_wm_data_std * 100.0

        # define font properties
        plt.rcParams["font.size"] = 22
        plt.rcParams["axes.labelsize"] = 24

        # create the figure and a list of axes
        fig, axis = plt.subplots(figsize=(34, 14))
        for spine in ["top", "right", "bottom", "left"]:
            axis.spines[spine].set_visible(False)

        # plot (unrelaxed) data for the CASP15 set
        casp15_rmsd_lt2_bar = axis.bar(
            r1,
            casp15_rmsd_lt_2_data_mean,
            yerr=casp15_rmsd_lt_2_data_std,
            label="RMSD ≤ 2Å",
            color="none",
            edgecolor=colors[0],
            hatch="\\\\\\",
            width=bar_width,
        )

        # plot (relaxed) data for the CASP15 set
        casp15_relaxed_rmsd_lt_2_bar = axis.bar(
            r2,
            casp15_relaxed_rmsd_lt_2_data_mean,
            yerr=casp15_relaxed_rmsd_lt_2_data_std,
            label="RMSD ≤ 2Å",
            color="none",
            edgecolor=colors[1],
            hatch="\\\\\\",
            width=bar_width,
        )

        # plot PLIF-WM data for the CASP15 set
        casp15_plif_wm_bar = axis.bar(
            r3,
            casp15_plif_wm_data_mean,
            yerr=casp15_plif_wm_data_std,
            label="PLIF-WM",
            color=colors[2],
            hatch="\\\\\\",
            width=bar_width,
        )

        # add labels, titles, ticks, etc.
        axis.set_xlabel(f"{complex_type.title()}-ligand blind docking ({complex_license})")
        axis.set_ylabel("Percentage of predictions")
        axis.set_xlim(1, 19 + 0.1)
        axis.set_ylim(0, 125)

        axis.bar_label(casp15_rmsd_lt2_bar, fmt="{:,.1f}", label_type="center")
        axis.bar_label(casp15_relaxed_rmsd_lt_2_bar, fmt="{:,.1f}", label_type="center")
        axis.bar_label(casp15_plif_wm_bar, fmt="{:,.1f}", label_type="center")

        axis.yaxis.set_major_formatter(mtick.PercentFormatter())

        axis.set_yticks([0, 20, 40, 60, 80, 100])
        axis.axhline(y=0, color="#EAEFF8")
        axis.grid(axis="y", color="#EAEFF8")
        axis.set_axisbelow(True)

        # method labels at the method positions, with the two category labels placed at 2 + 1e-3
        # and 11 (shifted down below) and a minor tick at the left edge acting as a separator
        axis.set_xticks([2, 2 + 1e-3, 4, 6, 8, 10, 11, 12, 14, 16, 18])
        axis.set_xticks([1 + 0.1], minor=True)
        axis.set_xticklabels(
            [
                "P2Rank-Vina",
                "Conventional blind",
                "DiffDock-L",
                "DynamicBind",
                "NeuralPLexer",
                "RoseTTAFold-AA",
                "DL-based blind",
                "Chai-1-Single-Seq",
                "Chai-1",
                "AF3-Single-Seq",
                "AF3",
            ]
        )

        # hide the vertical grid lines and major x tick marks, and show y tick marks only on the
        # right; booleans are required, as legacy "off"/"on" strings are truthy (so "off" would
        # enable them) and passing line properties to `grid` force-enables the grid
        axis.grid(False, axis="x")
        axis.grid(False, axis="x", which="minor")

        axis.tick_params(axis="x", which="minor", direction="out", length=30, color="#EAEFF8")
        axis.tick_params(axis="x", which="major", bottom=False, top=False, color="#EAEFF8")
        axis.tick_params(axis="y", which="major", left=False, right=True, color="#EAEFF8")

        # vertical alignment of xtick labels
        vert_alignments = [0.0, -0.1, 0.0, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0, 0.0, 0.0]
        for tick, y in zip(axis.get_xticklabels(), vert_alignments):
            tick.set_y(y)

        # add legends
        legend_0 = fig.legend(
            [casp15_rmsd_lt2_bar],
            ["RMSD ≤ 2Å"],
            loc="upper right",
            title="No post-processing",
            bbox_to_anchor=(1, 1, -0.40, -0.05),
        )
        legend_1 = fig.legend(
            [casp15_relaxed_rmsd_lt_2_bar],
            ["RMSD ≤ 2Å"],
            loc="upper right",
            title="With relaxation",
            bbox_to_anchor=(1, 1, -0.2, -0.05),
        )
        legend_2 = fig.legend(
            [casp15_plif_wm_bar],
            ["PLIF-WM"],
            loc="upper right",
            title="Protein-ligand interactions\n    (no post-processing)",
            bbox_to_anchor=(1, 1, -0.01, -0.05),
        )
        for legend in (legend_0, legend_1, legend_2):
            legend.get_frame().set_alpha(0)

        # display the plots
        plt.tight_layout()
        plt.savefig(
            f"casp15_{complex_license}_{complex_type}_ligand_relaxed_rmsd_lt2_bar_chart.png",
            dpi=300,
        )
        plt.show()

#### Standardize PoseBusters validity metrics

In [None]:
# load and organize the CASP15 PoseBusters validity results CSV
#
# Publishes `bust_results_table_{i}` per repeat: the PoseBusters tables of every method and
# post-processing config of that repeat, with plotting helpers and display titles added.
for repeat_index in range(1, max_num_repeats_per_method + 1):
    # tables stored as None (runs without a bust CSV) are dropped silently by `pd.concat`
    repeat_bust_table = pd.concat(
        [
            globals()[table_name]
            for table_name in (
                f"{method}{config}_bust_results_table_{repeat_index}"
                for method in baseline_methods
                for config in ["", "_relaxed"]
            )
            if table_name in globals()
        ]
    )

    # derive plotting helpers from the method keys before those keys are replaced by titles
    method_keys = repeat_bust_table["method"]
    repeat_bust_table.loc[:, "method_category"] = method_keys.apply(categorize_method)
    repeat_bust_table.loc[:, "method_assignment_index"] = method_keys.apply(assign_method_index)

    # missing validity flags count as invalid
    repeat_bust_table.loc[:, "PB-Valid"] = repeat_bust_table["pb_valid"].fillna(False).astype(int)
    repeat_bust_table.loc[:, "dataset"] = repeat_bust_table["dataset"].map(
        {"casp15": "CASP15 set"}
    )
    repeat_bust_table.loc[:, "method"] = method_keys.map(method_mapping)

    globals()[f"bust_results_table_{repeat_index}"] = repeat_bust_table

#### Make PoseBusters validity plot

In [None]:
# PB-Valid Bar Chart of CASP15 Set (Relaxed vs. Unrelaxed) Results #


def _select_casp15_rows(results_table, post_processing, complex_type, complex_license):
    """Return the CASP15 rows of ``results_table`` for one post-processing mode.

    Rows are restricted to the single-ligand targets when ``complex_type == "single"``
    (any other value selects the multi-ligand targets), using the public target list
    when ``complex_license == "public"`` and the full list otherwise. The result is
    ordered by ``method_assignment_index`` so methods appear in plotting order.
    """
    if complex_type == "single":
        targets = (
            PUBLIC_CASP15_SINGLE_LIGAND_TARGETS
            if complex_license == "public"
            else All_CASP15_SINGLE_LIGAND_TARGETS
        )
    else:
        targets = (
            PUBLIC_CASP15_MULTI_LIGAND_TARGETS
            if complex_license == "public"
            else All_CASP15_MULTI_LIGAND_TARGETS
        )
    mask = (
        (results_table["dataset"] == "CASP15 set")
        & (results_table["post-processing"] == post_processing)
        & results_table["target"].isin(targets)
    )
    return results_table[mask].sort_values(by="method_assignment_index")


def _max_rows_per_method(results_table, methods):
    """Return the largest number of rows that any of ``methods`` has in ``results_table``.

    This is used as the shared denominator for every method, so a method with fewer
    rows than the best-covered one has its missing rows counted as not PB-Valid.
    """
    return max(
        len(results_table[results_table["method"] == method]) for method in methods
    )


def _pb_valid_percentages(results_table, num_data_points):
    """Return one row per method with its PB-Valid count as a percentage of ``num_data_points``."""
    pb_valid_data = (
        results_table.groupby("method")
        .agg({"PB-Valid": "sum", "method_assignment_index": "first"})
        .reset_index()
    )
    pb_valid_data["PB-Valid"] = pb_valid_data["PB-Valid"] / num_data_points * 100
    return pb_valid_data.sort_values("method_assignment_index")


def _summarize_across_repeats(per_repeat_data):
    """Return the ``(mean, std)`` PB-Valid Series across repeats, in method plotting order.

    The (sample) std is undefined for a method seen in only one repeat; it is reported as 0.
    """
    grouped = pd.concat(per_repeat_data).groupby(["method", "method_assignment_index"])
    mean = grouped.mean().sort_values(["method_assignment_index"])["PB-Valid"]
    std = grouped.std().sort_values(["method_assignment_index"])["PB-Valid"].fillna(0)
    return mean, std


# prepare data for the bar charts to plot
colors = ["#FB8072", "#BEBADA"]

bar_width = 0.75
# one slot per method near each even x position 2..18: the unrelaxed bar (r1) and the
# relaxed bar (r2) sit side by side
r1 = [item - 0.25 for item in range(2, 20, 2)]
r2 = [x + bar_width for x in r1]

# define font properties (loop-invariant, so set once for every figure below)
plt.rcParams["font.size"] = 22
plt.rcParams["axes.labelsize"] = 24

for complex_type in ["single", "multi"]:
    for complex_license in ["all", "public"]:
        casp15_pb_valid_data_list, casp15_relaxed_pb_valid_data_list = [], []
        for repeat_index in range(1, max_num_repeats_per_method + 1):
            repeat_results_table = globals()[f"bust_results_table_{repeat_index}"]

            # filter the data based on the post-processing, complex type, and license
            casp15_results_table = _select_casp15_rows(
                repeat_results_table, "none", complex_type, complex_license
            )
            casp15_relaxed_results_table = _select_casp15_rows(
                repeat_results_table, "energy minimization", complex_type, complex_license
            )
            if casp15_results_table.empty:
                raise ValueError(
                    f"No unrelaxed CASP15 results for {complex_type}-ligand "
                    f"({complex_license}) targets in repeat {repeat_index}"
                )

            casp15_labels = casp15_results_table["method"].unique()
            num_methods = len(casp15_labels)

            # both denominators are taken over the methods present without post-processing
            num_casp15_data_points = _max_rows_per_method(casp15_results_table, casp15_labels)
            num_casp15_relaxed_data_points = _max_rows_per_method(
                casp15_relaxed_results_table, casp15_labels
            )

            # CASP15 set (unrelaxed) results
            casp15_pb_valid_data = _pb_valid_percentages(
                casp15_results_table, num_casp15_data_points
            )
            casp15_pb_valid_data_list.append(casp15_pb_valid_data)

            # CASP15 set (relaxed) results
            casp15_relaxed_pb_valid_data = _pb_valid_percentages(
                casp15_relaxed_results_table, num_casp15_relaxed_data_points
            )
            casp15_relaxed_pb_valid_data_list.append(casp15_relaxed_pb_valid_data)

        # calculate means and standard deviations across repeats
        casp15_pb_valid_data_mean, casp15_pb_valid_data_std = _summarize_across_repeats(
            casp15_pb_valid_data_list
        )
        (
            casp15_relaxed_pb_valid_data_mean,
            casp15_relaxed_pb_valid_data_std,
        ) = _summarize_across_repeats(casp15_relaxed_pb_valid_data_list)

        # create the figure and its single axis
        fig, axis = plt.subplots(figsize=(34, 14))
        for spine in ("top", "right", "bottom", "left"):
            axis.spines[spine].set_visible(False)

        # plot (unrelaxed) data for the CASP15 set
        casp15_pb_valid_bar = axis.bar(
            r1,
            casp15_pb_valid_data_mean,
            yerr=casp15_pb_valid_data_std,
            label="PB-Valid",
            color="none",
            edgecolor=colors[0],
            hatch="\\\\\\",
            width=bar_width,
        )

        # plot (relaxed) data for the CASP15 set
        casp15_relaxed_pb_valid_bar = axis.bar(
            r2,
            casp15_relaxed_pb_valid_data_mean,
            yerr=casp15_relaxed_pb_valid_data_std,
            label="PB-Valid",
            color="none",
            edgecolor=colors[1],
            hatch="\\\\\\",
            width=bar_width,
        )

        # add labels, titles, ticks, etc.
        axis.set_xlabel(f"{complex_type.title()}-ligand blind docking ({complex_license})")
        axis.set_ylabel("Percentage of complex predictions")
        axis.set_xlim(1, 19 + 0.1)
        axis.set_ylim(0, 100)

        axis.bar_label(casp15_pb_valid_bar, fmt="{:,.1f}", label_type="center")
        axis.bar_label(casp15_relaxed_pb_valid_bar, fmt="{:,.1f}", label_type="center")

        axis.yaxis.set_major_formatter(mtick.PercentFormatter())

        axis.set_yticks([0, 20, 40, 60, 80, 100])
        axis.axhline(y=0, color="#EAEFF8")
        axis.grid(axis="y", color="#EAEFF8")
        axis.set_axisbelow(True)

        # major ticks carry the method names plus the two category captions
        # ("Conventional blind" at 2 + 1e-3, "DL-based blind" at 11); the single minor
        # tick near x=1 is drawn long to act as a category separator
        axis.set_xticks([2, 2 + 1e-3, 4, 6, 8, 10, 11, 12, 14, 16, 18])
        axis.set_xticks([1 + 0.1], minor=True)
        axis.set_xticklabels(
            [
                "P2Rank-Vina",
                "Conventional blind",
                "DiffDock-L",
                "DynamicBind",
                "NeuralPLexer",
                "RoseTTAFold-AA",
                "DL-based blind",
                "Chai-1-Single-Seq",
                "Chai-1",
                "AF3-Single-Seq",
                "AF3",
            ]
        )

        # NOTE: the previous `grid("off", ...)` calls actually enabled these grids,
        # because Matplotlib treats any non-empty string as truthy; keep them on
        # explicitly so the rendered figure is unchanged
        axis.grid(True, axis="x", color="#EAEFF8")
        axis.grid(True, axis="x", which="minor", color="#EAEFF8")

        axis.tick_params(axis="x", which="minor", direction="out", length=30, color="#EAEFF8")
        # booleans, not "on"/"off" strings: a string like "off" is truthy and would
        # leave the tick marks visible
        axis.tick_params(axis="x", which="major", bottom=False, top=False, color="#EAEFF8")
        axis.tick_params(axis="y", which="major", left=False, right=True, color="#EAEFF8")

        # vertical alignment of xtick labels (push the category captions below the names)
        vert_alignments = [0.0, -0.1, 0.0, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0, 0.0, 0.0]
        for tick, y in zip(axis.get_xticklabels(), vert_alignments):
            tick.set_y(y)

        # add legends
        legend_0 = fig.legend(
            [casp15_pb_valid_bar],
            ["PB-Valid"],
            loc="upper right",
            title="No post-processing",
            bbox_to_anchor=(1, 1, -0.12, -0.05),
        )
        legend_1 = fig.legend(
            [casp15_relaxed_pb_valid_bar],
            ["PB-Valid"],
            loc="upper right",
            title="With relaxation",
            bbox_to_anchor=(1, 1, -0.01, -0.05),
        )
        legend_0.get_frame().set_alpha(0)
        legend_1.get_frame().set_alpha(0)

        # display the plots
        plt.tight_layout()
        plt.savefig(
            f"casp15_{complex_license}_{complex_type}_ligand_relaxed_pb_valid_bar_chart.png",
            dpi=300,
        )
        plt.show()