## DockGen Structure Prediction Results Plotting

#### Import packages

In [None]:
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import pandas as pd
import seaborn as sns

#### Configure packages

In [None]:
# switch on pandas' copy-on-write mode for every cell that follows
pd.set_option("mode.copy_on_write", True)

#### Define constants

In [None]:
# General variables
baseline_methods = [
    "vina_p2rank",
    "diffdock",
    "dynamicbind",
    "rfaa",
    "alphafold3",
    "chai-lab",
    "neuralplexer",
    "flowdock_hp",
    "flowdock_aft",
    "flowdock_pft",
    "flowdock_esmfold",
    "flowdock",
]
max_num_repeats_per_method = 3

# Filepaths for each baseline method
# Each inference directory is registered as the module-level name
# `<prefix>_output_dir`. The names go through `globals()` because some of
# them (e.g. `chai-lab_output_dir`) are not valid Python identifiers.
_forks_root = os.path.join("..", "forks")
_output_dir_parts = {
    "vina": ("Vina", "inference"),
    "diffdock": ("DiffDock", "inference"),
    "dynamicbind": ("DynamicBind", "inference", "outputs", "results"),
    "rfaa": ("RoseTTAFold-All-Atom", "inference"),
    "alphafold3": ("alphafold3", "inference"),
    "chai-lab": ("chai-lab", "inference"),
    "neuralplexer": ("NeuralPLexer", "inference"),
    "flowdock_hp": ("FlowDock", "hp_inference"),
    "flowdock_aft": ("FlowDock", "aft_inference"),
    "flowdock_pft": ("FlowDock", "pft_inference"),
    "flowdock_esmfold": ("FlowDock", "esmfold_inference"),
    "flowdock": ("FlowDock", "inference"),
}
for _dir_prefix, _dir_parts in _output_dir_parts.items():
    globals()[f"{_dir_prefix}_output_dir"] = os.path.join(_forks_root, *_dir_parts)

# (method key, output-directory prefix, per-repeat run-folder template) of every
# baseline; `{}` in the template is replaced by the repeat index
_dockgen_run_layouts = (
    ("vina_p2rank", "vina", "vina_p2rank_dockgen_outputs_{}"),  # P2Rank-Vina
    ("diffdock", "diffdock", "diffdock_dockgen_output_{}"),  # DiffDock
    ("dynamicbind", "dynamicbind", "dockgen_{}"),  # DynamicBind
    ("rfaa", "rfaa", "rfaa_dockgen_outputs_{}"),  # RoseTTAFold-All-Atom
    ("alphafold3", "alphafold3", "alphafold3_ss_dockgen_outputs_{}"),  # AlphaFold 3 (Single-Seq)
    ("chai-lab", "chai-lab", "chai-lab_ss_dockgen_outputs_{}"),  # Chai-1 (Single-Seq)
    ("neuralplexer", "neuralplexer", "neuralplexer_dockgen_outputs_{}"),  # NeuralPLexer
    ("flowdock_hp", "flowdock_hp", "flowdock_dockgen_outputs_{}"),  # FlowDock-HP
    ("flowdock_aft", "flowdock_aft", "flowdock_dockgen_outputs_{}"),  # FlowDock-AFT
    ("flowdock_pft", "flowdock_pft", "flowdock_dockgen_outputs_{}"),  # FlowDock-PFT
    ("flowdock_esmfold", "flowdock_esmfold", "flowdock_dockgen_outputs_{}"),  # FlowDock-ESMFold
    ("flowdock", "flowdock", "flowdock_dockgen_outputs_{}"),  # FlowDock
)

# register `<method>_dockgen[_relaxed]_bust_results_csv_filepath_<repeat>` for
# every method and repeat; the relaxed run lives in a sibling folder whose name
# carries the `_relaxed` suffix
for repeat_index in range(1, max_num_repeats_per_method + 1):
    for _method_key, _dir_prefix, _folder_template in _dockgen_run_layouts:
        _run_folder = _folder_template.format(repeat_index)
        for _config_suffix in ("", "_relaxed"):
            _filepath_name = (
                f"{_method_key}_dockgen{_config_suffix}_bust_results_csv_filepath_{repeat_index}"
            )
            globals()[_filepath_name] = os.path.join(
                globals()[f"{_dir_prefix}_output_dir"],
                f"{_run_folder}{_config_suffix}",
                "bust_results.csv",
            )

# Mappings
# Display label of every method. The insertion order matters:
# `assign_method_index` ranks a method by the position of its key in here.
method_mapping = dict(
    [
        ("vina_p2rank", "P2Rank-Vina"),
        ("diffdock", "DiffDock-L"),
        ("dynamicbind", "DynamicBind"),
        ("rfaa", "RoseTTAFold-AA"),
        ("alphafold3", "AF3-Single-Seq"),
        ("chai-lab", "Chai-1-Single-Seq"),
        ("neuralplexer", "NeuralPLexer"),
        ("flowdock_hp", "FlowDock-HP"),
        ("flowdock_aft", "FlowDock-AFT"),
        ("flowdock_pft", "FlowDock-PFT"),
        ("flowdock_esmfold", "FlowDock-ESMFold"),
        ("flowdock", "FlowDock-AF3"),
    ]
)

# Category of every method: P2Rank-Vina is the only conventional method,
# all other methods are deep-learning (DL) based.
method_category_mapping = {
    method_key: ("Conventional blind" if method_key == "vina_p2rank" else "DL-based blind")
    for method_key in method_mapping
}

# Metrics
# The test columns are declared group by group and then joined in a fixed order.
# `rmsd_≤_2å` has to stay first: the loading cell derives `pb_valid` from the
# columns that follow it.
_ACCURACY_COLUMNS = ["rmsd_≤_2å"]
_CHEMICAL_VALIDITY_AND_CONSISTENCY_COLUMNS = [
    "mol_pred_loaded",
    "mol_true_loaded",
    "mol_cond_loaded",
    "sanitization",
    "molecular_formula",
    "molecular_bonds",
    "tetrahedral_chirality",
    "double_bond_stereochemistry",
]
_INTRAMOLECULAR_VALIDITY_COLUMNS = [
    "bond_lengths",
    "bond_angles",
    "internal_steric_clash",
    "aromatic_ring_flatness",
    "double_bond_flatness",
    "internal_energy",
]
_INTERMOLECULAR_VALIDITY_COLUMNS = [
    "minimum_distance_to_protein",
    "minimum_distance_to_organic_cofactors",
    "minimum_distance_to_inorganic_cofactors",
    "volume_overlap_with_protein",
    "volume_overlap_with_organic_cofactors",
    "volume_overlap_with_inorganic_cofactors",
]
BUST_TEST_COLUMNS = [
    *_ACCURACY_COLUMNS,
    *_CHEMICAL_VALIDITY_AND_CONSISTENCY_COLUMNS,
    *_INTRAMOLECULAR_VALIDITY_COLUMNS,
    *_INTERMOLECULAR_VALIDITY_COLUMNS,
]

#### Report test results for each baseline method

In [None]:
# load and report test results for each baseline method
for config in ["", "_relaxed"]:
    # value written to the `post-processing` column for this configuration
    post_processing_label = "energy minimization" if config == "_relaxed" else "none"

    for method in baseline_methods:
        for repeat_index in range(1, max_num_repeats_per_method + 1):
            method_title = method_mapping[method]
            run_name = f"{method}_dockgen{config}_bust_results"

            # a run whose results CSV is absent is skipped without any message
            csv_filepath = globals()[f"{run_name}_csv_filepath_{repeat_index}"]
            if not os.path.exists(csv_filepath):
                continue

            # keep the raw CSV contents under `<run_name>_<repeat>` ...
            raw_results = pd.read_csv(csv_filepath)
            globals()[f"{run_name}_{repeat_index}"] = raw_results

            # ... and the reduced table (test columns + raw RMSD) under `<run_name>_table_<repeat>`
            results_table = raw_results[BUST_TEST_COLUMNS + ["rmsd"]]
            globals()[f"{run_name}_table_{repeat_index}"] = results_table

            # a row is PB-valid when every column strictly between the leading
            # `rmsd_≤_2å` flag and the trailing raw `rmsd` value is truthy
            results_table.loc[:, "pb_valid"] = results_table.iloc[:, 1:-1].all(axis=1)

            # bookkeeping columns used by the later cells
            results_table.loc[:, "method"] = method
            results_table.loc[:, "post-processing"] = post_processing_label
            results_table.loc[:, "dataset"] = "dockgen"
            results_table.loc[:, "docked_ligand_successfully_loaded"] = results_table[
                ["mol_pred_loaded", "mol_true_loaded", "mol_cond_loaded"]
            ].all(axis=1)

            # shorter alias (without `_dockgen`) under which the standardization cell finds the table
            globals()[f"{method}{config}_bust_results_table_{repeat_index}"] = results_table

            run_label = f"{method_title}{config}_{repeat_index}"
            rmsd_success_rate = results_table["rmsd_≤_2å"].mean()
            print(f"{run_label} DockGen set `rmsd_≤_2å`: {rmsd_success_rate}")

            # PB-valid successes are divided by ALL rows of the table, not only the PB-valid ones
            pb_valid_rows = results_table[results_table["pb_valid"]]
            pb_valid_success_rate = pb_valid_rows["rmsd_≤_2å"].sum() / len(results_table)
            print(f"{run_label} DockGen set `rmsd_≤_2å and pb_valid`: {pb_valid_success_rate}")

#### Define helper functions

In [None]:
def assign_method_index(method: str) -> int:
    """
    Assign method index for plotting.

    The index is the zero-based position of the method's key in
    `method_mapping`, so sorting by it reproduces the order in which the
    methods are declared there.

    :param method: Method name (a key of `method_mapping`).
    :return: Method index.
    :raises ValueError: If `method` is not a key of `method_mapping`.
    """
    return list(method_mapping).index(method)


def categorize_method(method: str) -> str:
    """
    Look up the plotting category of a method.

    :param method: Method name.
    :return: The category stored for the method in `method_category_mapping`,
        or "Misc" when the method has no entry there.
    """
    if method in method_category_mapping:
        return method_category_mapping[method]
    return "Misc"

#### Standardize metrics

In [None]:
# load and organize the DockGen results CSVs
for repeat_index in range(1, max_num_repeats_per_method + 1):
    # stack whichever per-method tables (unrelaxed and relaxed) exist for this repeat
    candidate_table_names = [
        f"{method}{config}_bust_results_table_{repeat_index}"
        for method in baseline_methods
        for config in ["", "_relaxed"]
    ]
    results_table = pd.concat(
        [globals()[name] for name in candidate_table_names if name in globals()]
    )
    globals()[f"results_table_{repeat_index}"] = results_table

    # plotting metadata derived from the raw method key
    results_table.loc[:, "method_category"] = results_table["method"].apply(categorize_method)
    results_table.loc[:, "method_assignment_index"] = results_table["method"].apply(
        assign_method_index
    )

    # boolean success flags; a missing value counts as a failure
    results_table.loc[:, "rmsd_within_threshold"] = results_table["rmsd_≤_2å"].fillna(False)
    results_table.loc[:, "rmsd_within_threshold_and_pb_valid"] = (
        results_table["rmsd_within_threshold"]
    ) & (results_table["pb_valid"].fillna(False))

    # 0/1 copies of the flags under the labels used by the plots
    results_table.loc[:, "RMSD ≤ 2 Å & PB-Valid"] = results_table[
        "rmsd_within_threshold_and_pb_valid"
    ].astype(int)
    results_table.loc[:, "RMSD ≤ 2 Å"] = (
        results_table["rmsd_within_threshold"].fillna(False).astype(int)
    )

    # replace the raw keys with their display labels (done last, because the
    # lookups above need the raw method keys)
    results_table.loc[:, "dataset"] = results_table["dataset"].map({"dockgen": "DockGen set"})
    results_table.loc[:, "method"] = results_table["method"].map(method_mapping)

#### Make plots

In [None]:
# RMSD Violin Plot of DockGen Set Results (Relaxed vs. Unrelaxed) #

# palette handed to seaborn for the two `post-processing` hue levels
colors = ["#FB8072", "#BEBADA"]

# collect the DockGen rows of every repeat, relaxed rows first within each repeat
combined_data_list = []
for repeat_index in range(1, max_num_repeats_per_method + 1):
    repeat_table = globals()[f"results_table_{repeat_index}"]
    relaxed_rows = repeat_table[
        (repeat_table["dataset"] == "DockGen set")
        & (repeat_table["post-processing"] == "energy minimization")
    ]
    unrelaxed_rows = repeat_table[
        (repeat_table["dataset"] == "DockGen set") & (repeat_table["post-processing"] == "none")
    ]
    combined_data_list.append(pd.concat([relaxed_rows, unrelaxed_rows]))

# order the rows by plotting index so that the methods appear in their declared order
stacked_repeats = pd.concat(combined_data_list)
combined_relaxed_data = stacked_repeats.sort_values("method_assignment_index").reset_index(
    drop=True
)

# define font properties
plt.rcParams.update({"font.size": 12, "axes.labelsize": 14})

# set the size of the figure
plt.figure(figsize=(12, 6))

# only rows with a known RMSD below 10 are drawn, which keeps the violins readable
has_plottable_rmsd = (combined_relaxed_data["rmsd"] < 10) & ~(
    combined_relaxed_data["rmsd"].isna()
)
violin_data = combined_relaxed_data[has_plottable_rmsd]

# one split violin per method: each half shows one post-processing level
sns.violinplot(
    data=violin_data,
    x="method",
    y="rmsd",
    hue="post-processing",
    palette=colors,
    split=True,
    inner="quartile",
    cut=0,
)

# the method names on the ticks make an x-axis title unnecessary
plt.xlabel("")
plt.ylabel("RMSD")

# slant the method names so that they do not overlap
plt.xticks(rotation=45, ha="right")

# anchor the legend just right of the axes
plt.legend(title="Post-processing", bbox_to_anchor=(1.05, 1), loc="best")

# lay out, save and show the figure
plt.tight_layout()
plt.savefig("dockgen_structure_prediction_relaxed_rmsd_violin_plot.png", dpi=300)
plt.show()

In [None]:
# RMSD ≤ 2 Å Bar Chart of DockGen Set Results (Relaxed vs. Unrelaxed) #

# prepare data for the bar charts to plot
colors = ["#FB8072", "#BEBADA"]

bar_width = 0.5
# x positions of the unrelaxed bars (a quarter unit left of the even ticks 2..24)
# and of the relaxed bars (one bar width further right)
r1 = [tick - 0.25 for tick in range(2, 26, 2)]
r2 = [position + bar_width for position in r1]

# per-repeat percentage tables, one list per (metric, post-processing) combination
dockgen_rmsd_lt_2_data_list = []
dockgen_relaxed_rmsd_lt_2_data_list = []
dockgen_rmsd_lt_2_and_pb_valid_data_list = []
dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_list = []


def _count_rows_with_rmsd(table, method_label):
    """Count the rows of `table` that belong to `method_label` and have a non-missing `rmsd`."""
    return len(table[(table["method"] == method_label) & ~(table["rmsd"].isna())])


def _percentage_per_method(table, metric, num_data_points):
    """Sum `metric` per method and express each sum as a percentage of `num_data_points`."""
    per_method = (
        table.groupby("method")
        .agg({metric: "sum", "method_assignment_index": "first"})
        .reset_index()
    )
    per_method[metric] = per_method[metric] / num_data_points * 100
    return per_method


for repeat_index in range(1, max_num_repeats_per_method + 1):
    repeat_table = globals()[f"results_table_{repeat_index}"]
    dockgen_results_table = repeat_table[
        (repeat_table["dataset"] == "DockGen set") & (repeat_table["post-processing"] == "none")
    ].sort_values(by="method_assignment_index")
    dockgen_relaxed_results_table = repeat_table[
        (repeat_table["dataset"] == "DockGen set")
        & (repeat_table["post-processing"] == "energy minimization")
    ].sort_values(by="method_assignment_index")

    dockgen_labels = dockgen_results_table["method"].unique()
    num_methods = len(dockgen_labels)

    # denominator of the percentages: the largest number of rows with a known
    # RMSD that any single method has (the method labels come from the
    # unrelaxed table for both denominators)
    num_dockgen_data_points = max(
        _count_rows_with_rmsd(dockgen_results_table, label) for label in dockgen_labels
    )
    num_dockgen_relaxed_data_points = max(
        _count_rows_with_rmsd(dockgen_relaxed_results_table, label) for label in dockgen_labels
    )

    # DockGen (unrelaxed) results
    dockgen_rmsd_lt_2_data = _percentage_per_method(
        dockgen_results_table, "RMSD ≤ 2 Å", num_dockgen_data_points
    )
    dockgen_rmsd_lt_2_data_list.append(
        dockgen_rmsd_lt_2_data.sort_values("method_assignment_index")
    )

    # DockGen (relaxed) results
    dockgen_relaxed_rmsd_lt_2_data = _percentage_per_method(
        dockgen_relaxed_results_table, "RMSD ≤ 2 Å", num_dockgen_relaxed_data_points
    )
    dockgen_relaxed_rmsd_lt_2_data_list.append(
        dockgen_relaxed_rmsd_lt_2_data.sort_values("method_assignment_index")
    )

    # DockGen (unrelaxed and PB-Valid) results
    dockgen_rmsd_lt_2_and_pb_valid_data = _percentage_per_method(
        dockgen_results_table, "RMSD ≤ 2 Å & PB-Valid", num_dockgen_data_points
    )
    dockgen_rmsd_lt_2_and_pb_valid_data_list.append(
        dockgen_rmsd_lt_2_and_pb_valid_data.sort_values("method_assignment_index")
    )

    # DockGen (relaxed and PB-Valid) results
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data = _percentage_per_method(
        dockgen_relaxed_results_table, "RMSD ≤ 2 Å & PB-Valid", num_dockgen_relaxed_data_points
    )
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_list.append(
        dockgen_relaxed_rmsd_lt_2_and_pb_valid_data.sort_values("method_assignment_index")
    )

# calculate means and standard deviations


def _statistic_across_repeats(per_repeat_frames, metric, statistic):
    """
    Aggregate the per-repeat percentage tables into one value per method.

    :param per_repeat_frames: List of per-repeat tables with the columns
        `method`, `method_assignment_index` and `metric`.
    :param metric: Name of the percentage column to return.
    :param statistic: Name of the groupby reduction to apply ("mean" or "std").
    :return: Series of `metric`, indexed by (method, method_assignment_index)
        and ordered by `method_assignment_index`.
    """
    grouped = pd.concat(per_repeat_frames).groupby(["method", "method_assignment_index"])
    reduced = getattr(grouped, statistic)()
    return reduced.sort_values(["method_assignment_index"])[metric]


# a missing standard deviation (NaN) is replaced by 0 for the error bars
dockgen_rmsd_lt_2_data_mean = _statistic_across_repeats(
    dockgen_rmsd_lt_2_data_list, "RMSD ≤ 2 Å", "mean"
)
dockgen_rmsd_lt_2_data_std = _statistic_across_repeats(
    dockgen_rmsd_lt_2_data_list, "RMSD ≤ 2 Å", "std"
).fillna(0)

dockgen_relaxed_rmsd_lt_2_data_mean = _statistic_across_repeats(
    dockgen_relaxed_rmsd_lt_2_data_list, "RMSD ≤ 2 Å", "mean"
)
dockgen_relaxed_rmsd_lt_2_data_std = _statistic_across_repeats(
    dockgen_relaxed_rmsd_lt_2_data_list, "RMSD ≤ 2 Å", "std"
).fillna(0)

dockgen_rmsd_lt_2_and_pb_valid_data_mean = _statistic_across_repeats(
    dockgen_rmsd_lt_2_and_pb_valid_data_list, "RMSD ≤ 2 Å & PB-Valid", "mean"
)
dockgen_rmsd_lt_2_and_pb_valid_data_std = _statistic_across_repeats(
    dockgen_rmsd_lt_2_and_pb_valid_data_list, "RMSD ≤ 2 Å & PB-Valid", "std"
).fillna(0)

dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_mean = _statistic_across_repeats(
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_list, "RMSD ≤ 2 Å & PB-Valid", "mean"
)
dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_std = _statistic_across_repeats(
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_list, "RMSD ≤ 2 Å & PB-Valid", "std"
).fillna(0)

# define font properties
plt.rcParams["font.size"] = 22
plt.rcParams["axes.labelsize"] = 24

# create the figure with a single axis and hide all four frame lines
fig, axis = plt.subplots(figsize=(34, 14))
for spine_name in ("top", "right", "bottom", "left"):
    axis.spines[spine_name].set_visible(False)

# plot (unrelaxed) data for the DockGen set: a filled bar for "RMSD ≤ 2Å &
# PB-Valid" and, at the same position, a hatched outline bar for "RMSD ≤ 2Å"
dockgen_rmsd_lt_2_and_pb_valid_bar = axis.bar(
    r1,
    dockgen_rmsd_lt_2_and_pb_valid_data_mean,
    yerr=dockgen_rmsd_lt_2_and_pb_valid_data_std,
    label="RMSD ≤ 2Å & PB-Valid",
    color=colors[0],
    width=bar_width,
)
dockgen_rmsd_lt_2_bar = axis.bar(
    r1,
    dockgen_rmsd_lt_2_data_mean,
    yerr=dockgen_rmsd_lt_2_data_std,
    label="RMSD ≤ 2Å",
    color="none",
    edgecolor=colors[0],
    hatch="\\\\\\",
    width=bar_width,
)

# plot (relaxed) data for the DockGen set, using the second color
dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar = axis.bar(
    r2,
    dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_mean,
    yerr=dockgen_relaxed_rmsd_lt_2_and_pb_valid_data_std,
    label="RMSD ≤ 2Å & PB-Valid",
    color=colors[1],
    width=bar_width,
)
dockgen_relaxed_rmsd_lt_2_bar = axis.bar(
    r2,
    dockgen_relaxed_rmsd_lt_2_data_mean,
    yerr=dockgen_relaxed_rmsd_lt_2_data_std,
    label="RMSD ≤ 2Å",
    color="none",
    edgecolor=colors[1],
    hatch="\\\\\\",
    width=bar_width,
)

# add labels, titles, ticks, etc.
axis.set_ylabel("Percentage of predictions")
axis.set_xlim(1, 25 + 0.1)
# the y range extends past 100 so that the value labels and legends have room
axis.set_ylim(0, 125)

# `zip` silently truncates to the shorter input, so make sure that the two bar
# containers of each pair have the same length before annotating them
assert len(dockgen_rmsd_lt_2_bar) == len(dockgen_rmsd_lt_2_and_pb_valid_bar), (
    f"Length of dockgen_rmsd_lt_2_bar ({len(dockgen_rmsd_lt_2_bar)}) "
    f"and dockgen_rmsd_lt_2_and_pb_valid_bar ({len(dockgen_rmsd_lt_2_and_pb_valid_bar)}) "
    "do not match."
)
assert len(dockgen_relaxed_rmsd_lt_2_bar) == len(dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar), (
    f"Length of dockgen_relaxed_rmsd_lt_2_bar ({len(dockgen_relaxed_rmsd_lt_2_bar)}) "
    f"and dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar ({len(dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar)}) "
    "do not match."
)


def _annotate_bar_pairs(target_axis, rmsd_bars, pb_valid_bars, x_divisor):
    """
    Write the heights of each (RMSD, RMSD & PB-Valid) bar pair above the pair.

    :param target_axis: Axis to annotate.
    :param rmsd_bars: Bars holding the "RMSD ≤ 2Å" percentages.
    :param pb_valid_bars: Bars holding the "RMSD ≤ 2Å & PB-Valid" percentages,
        paired by position with `rmsd_bars`.
    :param x_divisor: The label's x position is the bar's left edge plus the
        bar width divided by this value.
    """
    for bar, pb_valid_bar in zip(rmsd_bars, pb_valid_bars):
        height = bar.get_height()
        pb_valid_height = pb_valid_bar.get_height()
        # the RMSD label is placed 5 units higher than the PB-Valid label
        # (unless the PB-Valid bar is taller than that) to prevent overlap
        target_axis.annotate(
            f"{height:.1f}",
            (
                bar.get_x() + bar.get_width() / x_divisor,
                max(height + 5, pb_valid_height) + 2,
            ),
            ha="center",
            va="bottom",
            fontsize=24,
        )
        target_axis.annotate(
            f"{pb_valid_height:.1f}",
            (
                pb_valid_bar.get_x() + pb_valid_bar.get_width() / x_divisor,
                max(height, pb_valid_height) + 2,
            ),
            ha="center",
            va="bottom",
            fontsize=24,
        )


# the two groups use different horizontal offsets for their value labels
_annotate_bar_pairs(axis, dockgen_rmsd_lt_2_bar, dockgen_rmsd_lt_2_and_pb_valid_bar, 2.5)
_annotate_bar_pairs(
    axis, dockgen_relaxed_rmsd_lt_2_bar, dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar, 1.75
)

axis.yaxis.set_major_formatter(mtick.PercentFormatter())

axis.set_yticks([0, 20, 40, 60, 80, 100])
axis.axhline(y=0, color="#EAEFF8")
axis.grid(axis="y", color="#EAEFF8")
axis.set_axisbelow(True)

# one major tick per method at the even positions; the two extra ticks at
# 2 + 1e-3 and 14 + 1e-3 only carry the category labels
# ("Conventional blind" and "DL-based blind")
axis.set_xticks([2, 2 + 1e-3, 4, 6, 8, 10, 12, 14, 14 + 1e-3, 16, 18, 20, 22, 24])
axis.set_xticks([1 + 0.1], minor=True)
axis.set_xticklabels(
    [
        "P2Rank-Vina",
        "Conventional blind",
        "DiffDock-L",
        "DynamicBind",
        "RoseTTAFold-AA",
        "AF3-Single-Seq",
        "Chai-1-Single-Seq",
        "NeuralPLexer",
        "DL-based blind",
        "FlowDock-HP",
        "FlowDock-AFT",
        "FlowDock-PFT",
        "FlowDock-ESMFold",
        "FlowDock-AF3",
    ]
)

# NOTE(review): per the matplotlib documentation, supplying line properties
# such as `color` turns the grid on, so these two calls draw the x grid lines
# in spite of the "off" argument; kept unchanged to preserve the rendering
axis.grid("off", axis="x", color="#EAEFF8")
axis.grid("off", axis="x", which="minor", color="#EAEFF8")

# tick visibility must be given as booleans: the strings "off"/"on" used
# previously are both truthy, so the ticks meant to be hidden were still drawn
axis.tick_params(axis="x", which="minor", direction="out", length=30, color="#EAEFF8")
axis.tick_params(axis="x", which="major", bottom=False, top=False, color="#EAEFF8")
axis.tick_params(axis="y", which="major", left=False, right=True, color="#EAEFF8")

# vertical alignment of xtick labels: the two category labels are lowered
# below the row of method names
vert_alignments = [0.0, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0]
for tick, y in zip(axis.get_xticklabels(), vert_alignments):
    tick.set_y(y)

# add legends: one per post-processing group, side by side at the upper right
legend_0 = fig.legend(
    [dockgen_rmsd_lt_2_bar, dockgen_rmsd_lt_2_and_pb_valid_bar],
    ["RMSD ≤ 2Å", "RMSD ≤ 2Å & PB-Valid"],
    loc="upper right",
    title="No post-processing",
    bbox_to_anchor=(1, 1, -0.20, -0.05),
)
legend_1 = fig.legend(
    [dockgen_relaxed_rmsd_lt_2_bar, dockgen_relaxed_rmsd_lt_2_and_pb_valid_bar],
    ["RMSD ≤ 2Å", "RMSD ≤ 2Å & PB-Valid"],
    loc="upper right",
    title="With relaxation",
    bbox_to_anchor=(1, 1, -0.01, -0.05),
)
# make both legend frames fully transparent
legend_0.get_frame().set_alpha(0)
legend_1.get_frame().set_alpha(0)

# display the plots
plt.tight_layout()
plt.savefig("dockgen_structure_prediction_relaxed_bar_chart.png", dpi=300)
plt.show()