## DockGen Inference Results Plotting

#### Import packages

In [None]:
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import pandas as pd
import seaborn as sns

from posebench.analysis.inference_analysis import BUST_TEST_COLUMNS

#### Configure packages

In [None]:
# Turn on pandas' Copy-on-Write mode for the whole notebook. Later cells take column
# subsets of the loaded result frames and then add columns to those subsets via `.loc`.
pd.options.mode.copy_on_write = True

#### Define constants

In [None]:
# General variables
# Method keys; they also serve as prefixes of the global names registered below.
baseline_methods = [
    "vina_p2rank",
    "diffdock",
    "dynamicbind",
    "neuralplexer",
    "rfaa",
    "chai-lab_ss",
    "chai-lab",
    "alphafold3_ss",
    "alphafold3",
]
max_num_repeats_per_method = 3

# Filepaths for each baseline method
# Every method family's inference root is registered as the global `<family>_output_dir`.
globals().update(
    {
        f"{family}_output_dir": os.path.join("..", "forks", *fork_subdirs)
        for family, fork_subdirs in (
            ("vina", ("Vina", "inference")),
            ("diffdock", ("DiffDock", "inference")),
            ("dynamicbind", ("DynamicBind", "inference", "outputs", "results")),
            ("neuralplexer", ("NeuralPLexer", "inference")),
            ("rfaa", ("RoseTTAFold-All-Atom", "inference")),
            ("chai-lab", ("chai-lab", "inference")),
            ("alphafold3", ("alphafold3", "inference")),
        )
    }
)
for repeat_index in range(1, max_num_repeats_per_method + 1):
    # PLIF metrics (the same CSV name is registered for every repeat)
    globals()[f"dockgen_plif_metrics_csv_filepath_{repeat_index}"] = "dockgen_plif_metrics.csv"

    # For each method, register the unrelaxed and then the relaxed `bust_results.csv`
    # path as `<method>_dockgen[_relaxed]_bust_results_csv_filepath_<repeat>`.
    # Table rows: (method key, output-dir family, run-directory prefix).
    globals().update(
        {
            f"{method}_dockgen{suffix}_bust_results_csv_filepath_{repeat_index}": os.path.join(
                globals()[f"{family}_output_dir"],
                f"{run_dir_prefix}_{repeat_index}{suffix}",
                "bust_results.csv",
            )
            for method, family, run_dir_prefix in (
                # P2Rank-Vina results
                ("vina_p2rank", "vina", "vina_p2rank_dockgen_outputs"),
                # DiffDock results (directory name uses singular "output")
                ("diffdock", "diffdock", "diffdock_dockgen_output"),
                # DynamicBind results
                ("dynamicbind", "dynamicbind", "dockgen"),
                # NeuralPLexer results
                ("neuralplexer", "neuralplexer", "neuralplexer_dockgen_outputs"),
                # RoseTTAFold-All-Atom results
                ("rfaa", "rfaa", "rfaa_dockgen_outputs"),
                # Chai-1 (Single-Seq) results
                ("chai-lab_ss", "chai-lab", "chai-lab_ss_dockgen_outputs"),
                # Chai-1 results
                ("chai-lab", "chai-lab", "chai-lab_dockgen_outputs"),
                # AlphaFold 3 (Single-Seq) results
                ("alphafold3_ss", "alphafold3", "alphafold3_ss_dockgen_outputs"),
                # AlphaFold 3 results
                ("alphafold3", "alphafold3", "alphafold3_dockgen_outputs"),
            )
            for suffix in ("", "_relaxed")
        }
    )
# Mappings
# Method key -> display name used in printed reports and as plot labels.
method_mapping = {
    "vina_p2rank": "P2Rank-Vina",
    "diffdock": "DiffDock-L",
    "dynamicbind": "DynamicBind",
    "neuralplexer": "NeuralPLexer",
    "rfaa": "RoseTTAFold-AA",
    "chai-lab_ss": "Chai-1-Single-Seq",
    "chai-lab": "Chai-1",
    "alphafold3_ss": "AF3-Single-Seq",
    "alphafold3": "AF3",
}

# Method key -> method category. Every method is categorized as DL-based blind
# docking except the P2Rank-Vina baseline, which is the conventional one.
method_category_mapping = dict.fromkeys(method_mapping, "DL-based blind")
method_category_mapping["vina_p2rank"] = "Conventional blind"

#### Report test results for each baseline method

In [None]:
# load and report test results for each baseline method
for config in ("", "_relaxed"):
    for method in baseline_methods:
        for repeat_index in range(1, max_num_repeats_per_method + 1):
            method_title = method_mapping[method]
            run_prefix = f"{method}_dockgen{config}_bust_results"

            # a run whose results file is absent is skipped without any message
            csv_filepath = globals()[f"{run_prefix}_csv_filepath_{repeat_index}"]
            if not os.path.exists(csv_filepath):
                continue

            # register the raw results and the column subset used downstream
            bust_results = pd.read_csv(csv_filepath)
            globals()[f"{run_prefix}_{repeat_index}"] = bust_results
            bust_table = bust_results[BUST_TEST_COLUMNS + ["rmsd"]]
            globals()[f"{run_prefix}_table_{repeat_index}"] = bust_table

            # a row is `pb_valid` when every selected column other than the first
            # one and the trailing `rmsd` is truthy
            bust_table.loc[:, "pb_valid"] = bust_table.iloc[:, 1:-1].all(axis=1)

            # annotate the rows with their provenance
            bust_table.loc[:, "method"] = method
            bust_table.loc[:, "post-processing"] = (
                "energy minimization" if config == "_relaxed" else "none"
            )
            bust_table.loc[:, "dataset"] = "dockgen"
            bust_table.loc[:, "docked_ligand_successfully_loaded"] = bust_table[
                ["mol_pred_loaded", "mol_true_loaded", "mol_cond_loaded"]
            ].all(axis=1)

            # the same table is also reachable under the shorter `<method><config>_...` name
            globals()[f"{method}{config}_bust_results_table_{repeat_index}"] = bust_table

            run_label = f"{method_title}{config}_{repeat_index} DockGen set"
            print(f"\n{run_label} `rmsd_≤_2å`: {bust_table['rmsd_≤_2å'].mean()}")
            num_accurate_and_valid = bust_table.loc[bust_table["pb_valid"], "rmsd_≤_2å"].sum()
            print(
                f"{run_label} `rmsd_≤_2å and pb_valid`: {num_accurate_and_valid / len(bust_table)}\n"
            )

#### Define helper functions

In [None]:
def assign_method_index(method: str) -> int:
    """
    Assign method index for plotting.

    The index is the position of `method` among the keys of the module-level
    `method_mapping`, i.e. its insertion order in that dictionary.

    :param method: Method name (a key of `method_mapping`).
    :return: Zero-based method index.
    :raises ValueError: If `method` is not a key of `method_mapping`.
    """
    # iterating a dict yields its keys, so no explicit `.keys()` is needed
    return list(method_mapping).index(method)


def assign_category_index(category: str) -> int:
    """
    Assign category index for plotting.

    Note that the lookup is performed against the *values* of the module-level
    `method_mapping` (the methods' display names), not against
    `method_category_mapping`. In this notebook it is applied to the `Category`
    column of the PLIF metrics table.

    :param category: Category name; must equal one of the values of `method_mapping`.
    :return: Zero-based index of the first value of `method_mapping` equal to `category`.
    :raises ValueError: If `category` is not among the values of `method_mapping`.
    """
    return list(method_mapping.values()).index(category)


def categorize_method(method: str) -> str:
    """
    Look up the plotting category of a method.

    :param method: Method name.
    :return: The entry of `method_category_mapping` for `method`, or "Misc" when
        the method has no entry.
    """
    if method in method_category_mapping:
        return method_category_mapping[method]
    return "Misc"

#### Standardize metrics

In [None]:
# load and organize the DockGen results CSVs
for repeat_index in range(1, max_num_repeats_per_method + 1):
    plif_metrics_table = pd.read_csv(
        globals()[f"dockgen_plif_metrics_csv_filepath_{repeat_index}"]
    )
    globals()[f"dockgen_plif_metrics_table_{repeat_index}"] = plif_metrics_table

    # stack whichever per-method tables were registered for this repeat
    candidate_table_names = [
        f"{method}{config}_bust_results_table_{repeat_index}"
        for method in baseline_methods
        for config in ("", "_relaxed")
    ]
    results_table = pd.concat(
        [globals()[name] for name in candidate_table_names if name in globals()]
    )
    globals()[f"results_table_{repeat_index}"] = results_table

    # plotting metadata derived from the raw method key / PLIF category
    results_table.loc[:, "method_category"] = results_table["method"].apply(categorize_method)
    results_table.loc[:, "method_assignment_index"] = results_table["method"].apply(
        assign_method_index
    )
    plif_metrics_table.loc[:, "category_assignment_index"] = plif_metrics_table[
        "Category"
    ].apply(assign_category_index)

    # success flags, with missing values counted as failures
    results_table.loc[:, "rmsd_within_threshold"] = results_table.loc[:, "rmsd_≤_2å"].fillna(
        False
    )
    pb_valid_flags = results_table.loc[:, "pb_valid"].fillna(False)
    results_table.loc[:, "rmsd_within_threshold_and_pb_valid"] = (
        results_table.loc[:, "rmsd_within_threshold"] & pb_valid_flags
    )

    # integer copies of the flags under their display names
    results_table.loc[:, "RMSD ≤ 2 Å & PB-Valid"] = results_table.loc[
        :, "rmsd_within_threshold_and_pb_valid"
    ].astype(int)
    results_table.loc[:, "RMSD ≤ 2 Å"] = (
        results_table.loc[:, "rmsd_within_threshold"].fillna(False).astype(int)
    )

    # replace the dataset and method keys by their display names
    results_table.loc[:, "dataset"] = results_table.loc[:, "dataset"].map(
        {"dockgen": "DockGen set"}
    )
    results_table.loc[:, "method"] = results_table.loc[:, "method"].map(method_mapping)

#### Make plots

In [None]:
# RMSD Violin Plot of DockGen Set Results (Relaxed vs. Unrelaxed) #

# prepare data for the violin plots to plot
colors = ["#FB8072", "#BEBADA"]

# combine results across all repeats (relaxed rows first, then unrelaxed, per repeat)
combined_data_list = []
for repeat_index in range(1, max_num_repeats_per_method + 1):
    repeat_results = globals()[f"results_table_{repeat_index}"]
    is_dockgen = repeat_results["dataset"] == "DockGen set"
    pb_relaxed_results_table = repeat_results[
        is_dockgen & (repeat_results["post-processing"] == "energy minimization")
    ]
    pb_unrelaxed_results_table = repeat_results[
        is_dockgen & (repeat_results["post-processing"] == "none")
    ]
    combined_data_list.append(pd.concat([pb_relaxed_results_table, pb_unrelaxed_results_table]))
combined_relaxed_data = pd.concat(combined_data_list).sort_values("method_assignment_index")

# define font properties
plt.rcParams.update({"font.size": 12, "axes.labelsize": 14})

# set the size of the figure
plt.figure(figsize=(12, 6))

# create a violin plot; only rows with an RMSD of at most 20 are drawn
rmsd_in_plot_range = combined_relaxed_data["rmsd"] <= 20
sns.violinplot(
    data=combined_relaxed_data[rmsd_in_plot_range],
    x="method",
    y="rmsd",
    hue="post-processing",
    palette=colors,
    split=True,
    inner="quartile",
    cut=0,
)

# set labels and title
plt.xlabel("Primary-ligand docking")
plt.ylabel("RMSD")

# rotate x-axis labels for better readability
plt.xticks(rotation=45, ha="right")

# display legend outside the plot
plt.legend(title="Post-processing", bbox_to_anchor=(1.05, 1), loc="best")

# display the plots
plt.tight_layout()
plt.savefig("dockgen_primary_ligand_relaxed_rmsd_violin_plot.png", dpi=300)
plt.show()

In [None]:
# RMSD ≤ 2 Å Bar Chart of DockGen Set Results (Relaxed vs. Unrelaxed) #

# prepare data for the bar charts to plot
colors = ["#FB8072", "#BEBADA", "#FCCDE5"]

# x positions of the three bar series; each series is shifted by one bar width
bar_width = 0.5
r1 = [group_position - 0.5 for group_position in range(2, 20, 2)]
r2 = [position + bar_width for position in r1]
r3 = [position + bar_width for position in r2]


def _percentage_per_method(results_table, metric_column, num_data_points):
    """
    Sum a 0/1 metric column per method and express it as a percentage.

    :param results_table: Results rows carrying `method`, `method_assignment_index`
        and `metric_column`.
    :param metric_column: Name of the column to sum per method.
    :param num_data_points: Denominator shared by all methods.
    :return: One row per method with the percentage stored in `metric_column`.
    """
    per_method = (
        results_table.groupby("method")
        .agg({metric_column: "sum", "method_assignment_index": "first"})
        .reset_index()
    )
    per_method[metric_column] = per_method[metric_column] / num_data_points * 100
    return per_method


# one summary frame per repeat is collected for each plotted series
dockgen_rmsd_lt_2_data_list = []
dockgen_relaxed_rmsd_lt_2_data_list = []
dockgen_rmsd_lt_2_and_pb_valid_data_list = []
dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_list = []
dockgen_plif_wm_data_list = []
for repeat_index in range(1, max_num_repeats_per_method + 1):
    repeat_results = globals()[f"results_table_{repeat_index}"]
    is_dockgen = repeat_results["dataset"] == "DockGen set"
    dockgen_results_table = repeat_results[
        is_dockgen & (repeat_results["post-processing"] == "none")
    ].sort_values(by="method_assignment_index")
    dockgen_relaxed_results_table = repeat_results[
        is_dockgen & (repeat_results["post-processing"] == "energy minimization")
    ].sort_values(by="method_assignment_index")

    dockgen_labels = dockgen_results_table["method"].unique()
    num_methods = len(dockgen_labels)

    # the largest per-method row count is the denominator for every method
    num_dockgen_data_points = max(
        int((dockgen_results_table["method"] == method).sum()) for method in dockgen_labels
    )
    num_dockgen_relaxed_data_points = max(
        int((dockgen_relaxed_results_table["method"] == method).sum())
        for method in dockgen_labels
    )

    # DockGen (unrelaxed) results
    dockgen_rmsd_lt_2_data = _percentage_per_method(
        dockgen_results_table, "RMSD ≤ 2 Å", num_dockgen_data_points
    )
    dockgen_rmsd_lt_2_data_list.append(
        dockgen_rmsd_lt_2_data.sort_values("method_assignment_index")
    )

    # DockGen (relaxed) results
    dockgen_relaxed_rmsd_lt_2_data = _percentage_per_method(
        dockgen_relaxed_results_table, "RMSD ≤ 2 Å", num_dockgen_relaxed_data_points
    )
    dockgen_relaxed_rmsd_lt_2_data_list.append(
        dockgen_relaxed_rmsd_lt_2_data.sort_values("method_assignment_index")
    )

    # DockGen (unrelaxed and PB-Valid) results
    dockgen_rmsd_lt_2_and_pb_valid_data = _percentage_per_method(
        dockgen_results_table, "RMSD ≤ 2 Å & PB-Valid", num_dockgen_data_points
    )
    dockgen_rmsd_lt_2_and_pb_valid_data_list.append(
        dockgen_rmsd_lt_2_and_pb_valid_data.sort_values("method_assignment_index")
    )

    # DockGen (relaxed and PB-Valid) results
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data = _percentage_per_method(
        dockgen_relaxed_results_table, "RMSD ≤ 2 Å & PB-Valid", num_dockgen_relaxed_data_points
    )
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_list.append(
        dockgen_relaxed_rmsd_lt_2_and_pb_valid_data.sort_values("method_assignment_index")
    )

    # DockGen PLIF-WM results (mean `WM` per category, ordered by category index)
    dockgen_plif_wm_data = (
        globals()[f"dockgen_plif_metrics_table_{repeat_index}"]
        .groupby("Category")
        .agg({"WM": "mean", "category_assignment_index": "first"})
        .sort_values("category_assignment_index")
    )
    dockgen_plif_wm_data_list.append(dockgen_plif_wm_data)

# calculate means and standard deviations
def _mean_and_std_across_repeats(per_repeat_frames, group_keys, value_column):
    """
    Aggregate per-repeat summary frames into a mean and a standard deviation.

    :param per_repeat_frames: List of summary frames, one per repeat.
    :param group_keys: Two grouping keys; the result is ordered by the last one.
    :param value_column: Column whose mean and standard deviation are returned.
    :return: Tuple `(mean, std)` of Series; missing standard deviations are set to 0.
    """
    grouped = pd.concat(per_repeat_frames).groupby(group_keys)
    ordering_key = group_keys[-1:]
    mean = grouped.mean().sort_values(ordering_key)[value_column]
    std = grouped.std().sort_values(ordering_key)[value_column]
    return mean, std.fillna(0)


(
    dockgen_rmsd_lt_2_data_mean,
    dockgen_rmsd_lt_2_data_std,
) = _mean_and_std_across_repeats(
    dockgen_rmsd_lt_2_data_list, ["method", "method_assignment_index"], "RMSD ≤ 2 Å"
)

(
    dockgen_relaxed_rmsd_lt_2_data_mean,
    dockgen_relaxed_rmsd_lt_2_data_std,
) = _mean_and_std_across_repeats(
    dockgen_relaxed_rmsd_lt_2_data_list, ["method", "method_assignment_index"], "RMSD ≤ 2 Å"
)

(
    dockgen_rmsd_lt_2_and_pb_valid_data_mean,
    dockgen_rmsd_lt_2_and_pb_valid_data_std,
) = _mean_and_std_across_repeats(
    dockgen_rmsd_lt_2_and_pb_valid_data_list,
    ["method", "method_assignment_index"],
    "RMSD ≤ 2 Å & PB-Valid",
)

(
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_mean,
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_std,
) = _mean_and_std_across_repeats(
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_list,
    ["method", "method_assignment_index"],
    "RMSD ≤ 2 Å & PB-Valid",
)

# convert PLIF-WM scores to percentages
(
    dockgen_plif_wm_data_mean,
    dockgen_plif_wm_data_std,
) = _mean_and_std_across_repeats(
    dockgen_plif_wm_data_list, ["Category", "category_assignment_index"], "WM"
)
dockgen_plif_wm_data_mean = dockgen_plif_wm_data_mean * 100.0
dockgen_plif_wm_data_std = dockgen_plif_wm_data_std * 100.0

# define font properties
plt.rcParams.update({"font.size": 22, "axes.labelsize": 24})

# create the figure and a list of axes
fig, axis = plt.subplots(figsize=(34, 14))
for spine_name in ("top", "right", "bottom", "left"):
    axis.spines[spine_name].set_visible(False)

hatch_pattern = "\\\\\\"

# plot (unrelaxed) data for the DockGen set; the hatched outline bar is drawn
# on top of the filled bar at the same x positions
dockgen_rmsd_lt_2_and_pb_valid_bar = axis.bar(
    r1,
    dockgen_rmsd_lt_2_and_pb_valid_data_mean,
    width=bar_width,
    yerr=dockgen_rmsd_lt_2_and_pb_valid_data_std,
    color=colors[0],
    label="RMSD ≤ 2Å & PB-Valid",
)
dockgen_rmsd_lt_2_bar = axis.bar(
    r1,
    dockgen_rmsd_lt_2_data_mean,
    width=bar_width,
    yerr=dockgen_rmsd_lt_2_data_std,
    color="none",
    edgecolor=colors[0],
    hatch=hatch_pattern,
    label="RMSD ≤ 2Å",
)

# plot (relaxed) data for the DockGen set
dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar = axis.bar(
    r2,
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_mean,
    width=bar_width,
    yerr=dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_std,
    color=colors[1],
    label="RMSD ≤ 2Å & PB-Valid",
)
dockgen_relaxed_rmsd_lt_2_bar = axis.bar(
    r2,
    dockgen_relaxed_rmsd_lt_2_data_mean,
    width=bar_width,
    yerr=dockgen_relaxed_rmsd_lt_2_data_std,
    color="none",
    edgecolor=colors[1],
    hatch=hatch_pattern,
    label="RMSD ≤ 2Å",
)

# plot PLIF-WM data for the DockGen set
dockgen_plif_wm_bar = axis.bar(
    r3,
    dockgen_plif_wm_data_mean,
    width=bar_width,
    yerr=dockgen_plif_wm_data_std,
    color=colors[2],
    hatch=hatch_pattern,
    label="PLIF-WM",
)

# add labels, titles, ticks, etc.
axis.set_ylabel("Percentage of predictions")
axis.set_xlim(1, 19 + 0.1)
axis.set_ylim(0, 125)

# outline bars are labelled at their top edge, filled bars at their center
percent_label_format = "{:,.1f}%"
for bar_container, label_placement in (
    (dockgen_rmsd_lt_2_bar, {"label_type": "edge"}),
    (dockgen_rmsd_lt_2_and_pb_valid_bar, {"label_type": "center", "padding": 5}),
    (dockgen_relaxed_rmsd_lt_2_bar, {"label_type": "edge"}),
    (dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar, {"label_type": "center", "padding": 5}),
    (dockgen_plif_wm_bar, {"label_type": "edge"}),
):
    axis.bar_label(bar_container, fmt=percent_label_format, **label_placement)

axis.yaxis.set_major_formatter(mtick.PercentFormatter())

axis.set_yticks([0, 20, 40, 60, 80, 100])
axis.axhline(y=0, color="#EAEFF8")
axis.grid(axis="y", color="#EAEFF8")
axis.set_axisbelow(True)

# (x position, tick label, vertical label offset); the two category captions
# share the x axis with the method names and are pushed below them
x_tick_layout = [
    (2, "P2Rank-Vina", 0.0),
    (2 + 1e-3, "Conventional blind", -0.1),
    (4, "DiffDock-L", 0.0),
    (6, "DynamicBind", 0.0),
    (8, "NeuralPLexer", 0.0),
    (10, "RoseTTAFold-AA", 0.0),
    (11, "DL-based blind", -0.1),
    (12, "Chai-1-Single-Seq", 0.0),
    (14, "Chai-1", 0.0),
    (16, "AF3-Single-Seq", 0.0),
    (18, "AF3", 0.0),
]
axis.set_xticks([position for position, _, _ in x_tick_layout])
axis.set_xticks([1 + 0.1], minor=True)
axis.set_xticklabels([tick_label for _, tick_label, _ in x_tick_layout])

# NOTE(review): the flags below are given as the strings "off"/"on"; recent Matplotlib
# documents these parameters as booleans, so confirm the strings have the intended
# effect before relying on them. They are kept unchanged to preserve the rendered figure.
axis.grid("off", axis="x", color="#EAEFF8")
axis.grid("off", axis="x", which="minor", color="#EAEFF8")

axis.tick_params(axis="x", which="minor", direction="out", length=30, color="#EAEFF8")
axis.tick_params(axis="x", which="major", bottom="off", top="off", color="#EAEFF8")
axis.tick_params(axis="y", which="major", left="off", right="on", color="#EAEFF8")

# vertical alignment of xtick labels
vert_alignments = [vertical_offset for _, _, vertical_offset in x_tick_layout]
for tick, y in zip(axis.get_xticklabels(), vert_alignments):
    tick.set_y(y)

# add legends (one per bar group, placed side by side along the top right)
rmsd_legend_labels = ["RMSD ≤ 2Å", "RMSD ≤ 2Å & PB-Valid"]
legend_0 = fig.legend(
    [dockgen_rmsd_lt_2_bar, dockgen_rmsd_lt_2_and_pb_valid_bar],
    list(rmsd_legend_labels),
    title="No post-processing",
    loc="upper right",
    bbox_to_anchor=(1, 1, -0.40, -0.05),
)
legend_1 = fig.legend(
    [dockgen_relaxed_rmsd_lt_2_bar, dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar],
    list(rmsd_legend_labels),
    title="With relaxation",
    loc="upper right",
    bbox_to_anchor=(1, 1, -0.20, -0.05),
)
legend_2 = fig.legend(
    [dockgen_plif_wm_bar],
    ["PLIF-WM"],
    title="Protein-ligand interactions\n    (no post-processing)",
    loc="upper right",
    bbox_to_anchor=(1, 1, -0.01, -0.05),
)
for legend in (legend_0, legend_1, legend_2):
    legend.get_frame().set_alpha(0)

# display the plots
plt.tight_layout()
plt.savefig("dockgen_primary_ligand_relaxed_bar_chart.png", dpi=300)
plt.show()