## Failure Modes Analysis Plotting

#### Import packages

In [None]:
import copy
import glob
import os
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pypdb
import seaborn as sns
from scipy.stats import pearsonr, spearmanr

from posebench.analysis.inference_analysis import BUST_TEST_COLUMNS

#### Configure packages

In [None]:
# Explicitly disable pandas' Copy-on-Write mode for this notebook.
# NOTE(review): presumably required because later cells assign columns into filtered tables via `.loc` — confirm.
pd.options.mode.copy_on_write = False

#### Define constants

In [None]:
# General variables
# NOTE: commented-out methods are skipped by every loop over `baseline_methods`
baseline_methods = [
    "vina_p2rank",
    "diffdock",
    "dynamicbind",
    "neuralplexer",
    "rfaa",
    "chai-lab_ss",
    "chai-lab",
    # "boltz_ss",
    # "boltz",
    "alphafold3_ss",
    "alphafold3",
]
max_num_repeats_per_method = 3
method_max_training_cutoff_date = "2021-09-30"

datasets = ["astex_diverse", "posebusters_benchmark", "dockgen", "casp15"]

# PoseBusters Benchmark deposition dates
pb_deposition_dates_filepath = "posebusters_benchmark_complex_pdb_deposition_dates.csv"
assert os.path.exists(pb_deposition_dates_filepath), (
    "Please prepare the PoseBusters Benchmark complex PDB deposition dates CSV file via later steps in `failure_modes_analysis_plotting_plinder.ipynb` before proceeding."
)

# parse the deposition dates, then keep only the complexes deposited strictly after the cutoff date
pb_pdb_id_deposition_date_mapping_df = (
    pd.read_csv(pb_deposition_dates_filepath)
    .assign(**{"Deposition Date": lambda table: pd.to_datetime(table["Deposition Date"])})
    .loc[lambda table: table["Deposition Date"] > method_max_training_cutoff_date]
)

# PDB ID -> deposition date (as a string) of each retained complex
pb_pdb_id_deposition_date_mapping = {
    pdb_id: deposition_date
    for pdb_id, deposition_date in zip(
        pb_pdb_id_deposition_date_mapping_df["PDB ID"],
        pb_pdb_id_deposition_date_mapping_df["Deposition Date"].astype(str),
    )
}

# Filepaths for each baseline method
# NOTE: names such as `chai-lab_output_dir` are not valid identifiers, hence the `globals()` registration
for _output_dir_name, _output_dir_parts in {
    "vina_output_dir": ("forks", "Vina", "inference"),
    "diffdock_output_dir": ("forks", "DiffDock", "inference"),
    "dynamicbind_output_dir": ("forks", "DynamicBind", "inference", "outputs", "results"),
    "neuralplexer_output_dir": ("forks", "NeuralPLexer", "inference"),
    "rfaa_output_dir": ("forks", "RoseTTAFold-All-Atom", "inference"),
    "chai-lab_output_dir": ("forks", "chai-lab", "inference"),
    "boltz_output_dir": ("forks", "boltz", "inference"),
    "alphafold3_output_dir": ("forks", "alphafold3", "inference"),
    "casp15_output_dir": ("data", "test_cases", "casp15"),
}.items():
    globals()[_output_dir_name] = os.path.join("..", *_output_dir_parts)

# For each method: (name of its output directory variable, template of its per-run results
# directory for the non-CASP15 datasets). NOTE: DiffDock uses `_output_` (singular) and
# DynamicBind uses no method prefix, so the templates are spelled out per method.
_method_results_layout = {
    "vina_p2rank": ("vina_output_dir", "vina_p2rank_{dataset}_outputs_{repeat_index}{config}"),
    "diffdock": ("diffdock_output_dir", "diffdock_{dataset}_output_{repeat_index}{config}"),
    "dynamicbind": ("dynamicbind_output_dir", "{dataset}_{repeat_index}{config}"),
    "neuralplexer": (
        "neuralplexer_output_dir",
        "neuralplexer_{dataset}_outputs_{repeat_index}{config}",
    ),
    "rfaa": ("rfaa_output_dir", "rfaa_{dataset}_outputs_{repeat_index}{config}"),
    "chai-lab_ss": ("chai-lab_output_dir", "chai-lab_ss_{dataset}_outputs_{repeat_index}{config}"),
    "chai-lab": ("chai-lab_output_dir", "chai-lab_{dataset}_outputs_{repeat_index}{config}"),
    "boltz_ss": ("boltz_output_dir", "boltz_ss_{dataset}_outputs_{repeat_index}{config}"),
    "boltz": ("boltz_output_dir", "boltz_{dataset}_outputs_{repeat_index}{config}"),
    "alphafold3_ss": (
        "alphafold3_output_dir",
        "alphafold3_ss_{dataset}_outputs_{repeat_index}{config}",
    ),
    "alphafold3": ("alphafold3_output_dir", "alphafold3_{dataset}_outputs_{repeat_index}{config}"),
}

# Register `{method}_{dataset}{config}_bust_results_csv_filepath_{repeat_index}` for every
# method, dataset, post-processing config, and repeat
for config in ["", "_relaxed"]:
    for dataset in datasets:
        for repeat_index in range(1, max_num_repeats_per_method + 1):
            for _method_name, (_dir_name, _subdir_template) in _method_results_layout.items():
                if dataset == "casp15":
                    # CASP15 results live under the (config-suffixed) CASP15 test case directory
                    _csv_filepath = os.path.join(
                        globals()["casp15_output_dir"] + config,
                        f"top_{_method_name}_ensemble_predictions_{repeat_index}",
                        "scoring_results.csv",
                    )
                else:
                    # all other datasets live under the method's own inference directory
                    _csv_filepath = os.path.join(
                        globals()[_dir_name],
                        _subdir_template.format(
                            dataset=dataset, repeat_index=repeat_index, config=config
                        ),
                        "bust_results.csv",
                    )

                globals()[
                    f"{_method_name}_{dataset}{config}_bust_results_csv_filepath_{repeat_index}"
                ] = _csv_filepath

# Mappings
# method identifier -> display name
# NOTE: the key order matters, since `assign_method_index` returns a method's position among these keys
method_mapping = {
    "vina_p2rank": "P2Rank-Vina",
    "diffdock": "DiffDock-L",
    "dynamicbind": "DynamicBind",
    "neuralplexer": "NeuralPLexer",
    "rfaa": "RFAA",
    "chai-lab_ss": "Chai-1-Single-Seq",
    "chai-lab": "Chai-1",
    "boltz_ss": "Boltz-1-Single-Seq",
    "boltz": "Boltz-1",
    "alphafold3_ss": "AF3-Single-Seq",
    "alphafold3": "AF3",
}

# method identifier -> method category (P2Rank-Vina is the only conventional method)
method_category_mapping = {
    "vina_p2rank": "Conventional blind",
    **dict.fromkeys(
        (
            "diffdock",
            "dynamicbind",
            "neuralplexer",
            "rfaa",
            "chai-lab_ss",
            "chai-lab",
            "boltz_ss",
            "boltz",
            "alphafold3_ss",
            "alphafold3",
        ),
        "DL-based blind",
    ),
}

# dataset identifier -> display name
dataset_mapping = dict(
    astex_diverse="Astex Diverse set",
    posebusters_benchmark="Posebusters Benchmark set",
    dockgen="DockGen set",
    casp15="CASP15 set",
)

# CASP15 target -> PDB ID
casp15_target_pdb_id_mapping = {
    # NOTE: `?` indicates that the target's crystal structure is not publicly available
    "H1135": "7z8y",
    **{f"H1171v{version}": "7pbl" for version in (1, 2)},
    **{f"H1172v{version}": "7pbp" for version in (1, 2, 3, 4)},
    "T1124": "7ux8",
    "T1127v2": "?",
    "T1146": "?",
    "T1152": "7r1l",
    "T1158v1": "8sx8",
    "T1158v2": "8sxb",
    "T1158v3": "8sx7",
    "T1158v4": "8swn",
    "T1170": "7pbr",
    "T1181": "?",
    "T1186": "?",
    "T1187": "8ad2",
    "T1188": "8c6z",
}

#### Load test results for each baseline method

In [None]:
# load and report test results for each baseline method
# NOTE: each finalized table is registered in `globals()` as
# `{method}_{dataset}{config}_bust_results_table_{repeat_index}`; combinations without a
# results CSV on disk are skipped and get no table
for config in [""]:
    for dataset in datasets:
        for method in baseline_methods:
            for repeat_index in range(1, max_num_repeats_per_method + 1):
                method_title = method_mapping[method]

                bust_results_csv_filepath = globals()[
                    f"{method}_{dataset}{config}_bust_results_csv_filepath_{repeat_index}"
                ]
                if not os.path.exists(bust_results_csv_filepath):
                    continue

                # build the table locally and register it under its global name once finalized
                bust_results_table = pd.read_csv(bust_results_csv_filepath)

                if dataset == "casp15":
                    # count the number of ligands in each target complex, and assign these corresponding numbers to the ligands (rows) of each complex
                    bust_results_table.loc[:, "num_target_ligands"] = bust_results_table.groupby(
                        ["target", "mdl"]
                    )["pose"].transform("count")

                    # filter out non-relevant ligand predictions, and for all methods select only their first model for each ligand
                    # (copied explicitly so that the column assignments below do not write into a slice)
                    bust_results_table = bust_results_table[
                        np.where(bust_results_table.relevant, True, False)
                        & (bust_results_table.mdl == 1)
                    ].copy()

                    # finalize bust (i.e., scoring) results for CASP15, using dummy values for `pb_valid` and `crmsd_≤_1å`
                    bust_results_table.loc[:, "rmsd_≤_2å"] = (
                        bust_results_table.loc[:, "rmsd"] <= 2
                    )
                    # NOTE: the misspelled `pdb_valid` column is retained for backward compatibility
                    bust_results_table.loc[:, "pdb_valid"] = True
                    # later cells read `pb_valid`; without it every CASP15 row would be NaN there and
                    # count as not PB-valid, so add the dummy value unless the CSV already provides one
                    if "pb_valid" not in bust_results_table.columns:
                        bust_results_table.loc[:, "pb_valid"] = True
                    bust_results_table.loc[:, "crmsd_≤_1å"] = True

                else:
                    # keep only the PoseBusters test columns plus the three columns needed downstream
                    bust_results_table = bust_results_table[
                        BUST_TEST_COLUMNS + ["rmsd", "centroid_distance", "mol_id"]
                    ].copy()
                    # a pose is PB-valid when all columns pass, excluding the first column and the
                    # three trailing columns appended above
                    bust_results_table.loc[:, "pb_valid"] = bust_results_table.iloc[:, 1:-3].all(
                        axis=1
                    )
                    bust_results_table.loc[:, "crmsd_≤_1å"] = (
                        bust_results_table["centroid_distance"] < 1
                    )
                    bust_results_table.loc[:, "pdb_id"] = bust_results_table["mol_id"]

                # annotate every row with its provenance
                bust_results_table.loc[:, "method"] = method
                bust_results_table.loc[:, "post-processing"] = (
                    "energy minimization" if config == "_relaxed" else "none"
                )
                bust_results_table.loc[:, "dataset"] = dataset
                bust_results_table.loc[:, "docked_ligand_successfully_loaded"] = (
                    True
                    if dataset == "casp15"
                    else bust_results_table[
                        ["mol_pred_loaded", "mol_true_loaded", "mol_cond_loaded"]
                    ].all(axis=1)
                )

                if dataset == "posebusters_benchmark":
                    # keep only the results for complexes deposited in the PDB after the maximum cutoff date for any method's training data
                    bust_results_table["pdb_id"] = bust_results_table["mol_id"].map(
                        lambda x: x.lower().split("_")[0]
                    )
                    bust_results_table = bust_results_table[
                        bust_results_table["pdb_id"].isin(pb_pdb_id_deposition_date_mapping.keys())
                    ]

                globals()[
                    f"{method}_{dataset}{config}_bust_results_table_{repeat_index}"
                ] = bust_results_table

#### Define helper functions

In [None]:
def assign_method_index(method: str) -> int:
    """
    Assign method index for plotting.

    The index is the position of `method` among the keys of `method_mapping`.

    :param method: Method name (a key of `method_mapping`).
    :return: Method index.
    :raises ValueError: If `method` is not a key of `method_mapping`.
    """
    return list(method_mapping).index(method)


def categorize_method(method: str) -> str:
    """
    Look up the plotting category of a method.

    :param method: Method name.
    :return: Category listed in `method_category_mapping`, or "Misc" for unlisted methods.
    """
    try:
        return method_category_mapping[method]
    except KeyError:
        return "Misc"


def subset_counter(
    subset_counter: Counter, superset_counter: Counter, normalize_subset: bool = False
) -> Counter:
    """
    Restrict a superset counter to the keys of a subset counter.

    :param subset_counter: Counter whose keys select the entries to report.
    :param superset_counter: Counter providing the reference counts.
    :param normalize_subset: If set, report each subset count divided by its superset count
        instead of the superset count itself.
    :return: Counter over the subset's keys; keys with a missing or zero superset count map to 0.
    """
    result = Counter()
    for key, subset_count in subset_counter.items():
        superset_count = superset_counter.get(key, 0)
        if superset_count == 0:
            # no reference count to report (or to divide by)
            result[key] = 0
            continue
        result[key] = subset_count / superset_count if normalize_subset else superset_count
    return result

#### Standardize metrics

In [None]:
# load and organize the results CSVs
for repeat_index in range(1, max_num_repeats_per_method + 1):
    # stack every per-method, per-dataset table that was loaded for this repeat
    results_table = pd.concat(
        [
            globals()[f"{method}_{dataset}{config}_bust_results_table_{repeat_index}"]
            for dataset in datasets
            for method in baseline_methods
            for config in [""]
            if f"{method}_{dataset}{config}_bust_results_table_{repeat_index}" in globals()
        ]
    )
    # register the combined table up front; the in-place column updates below act on this same object
    globals()[f"results_table_{repeat_index}"] = results_table

    # plotting metadata derived from the raw method identifier
    results_table.loc[:, "method_category"] = results_table["method"].apply(categorize_method)
    results_table.loc[:, "method_assignment_index"] = results_table["method"].apply(
        assign_method_index
    )

    # boolean success flags, where missing values count as failures
    results_table.loc[:, "crmsd_within_threshold"] = results_table.loc[:, "crmsd_≤_1å"].fillna(
        False
    )
    results_table.loc[:, "rmsd_within_threshold"] = results_table.loc[:, "rmsd_≤_2å"].fillna(
        False
    )
    results_table.loc[:, "rmsd_within_threshold_and_pb_valid"] = (
        results_table.loc[:, "rmsd_within_threshold"]
    ) & (results_table.loc[:, "pb_valid"].fillna(False))

    # integer (0/1) versions of the flags under their display names
    results_table.loc[:, "RMSD ≤ 2 Å & PB-Valid"] = results_table.loc[
        :, "rmsd_within_threshold_and_pb_valid"
    ].astype(int)
    results_table.loc[:, "cRMSD ≤ 1 Å"] = (
        results_table.loc[:, "crmsd_within_threshold"].fillna(False).astype(int)
    )
    results_table.loc[:, "RMSD ≤ 2 Å"] = (
        results_table.loc[:, "rmsd_within_threshold"].fillna(False).astype(int)
    )

    # finally, swap the dataset and method identifiers for their display names
    results_table.loc[:, "dataset"] = results_table.loc[:, "dataset"].map(dataset_mapping)
    results_table.loc[:, "method"] = results_table.loc[:, "method"].map(method_mapping)

#### Collect metadata across all datasets

In [None]:
# find PDB IDs of complexes across all datasets
# NOTE: registers `{dataset}_complexes_{repeat_index}` (a set of PDB IDs) in `globals()`
for dataset in datasets:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        results_table = globals()[f"results_table_{repeat_index}"]

        # copy the dataset's rows explicitly, so that the CASP15 `pdb_id` assignment below acts
        # on an independent table rather than on a slice of the shared per-repeat results table
        dataset_results_table = results_table.loc[
            results_table.loc[:, "dataset"] == dataset_mapping[dataset]
        ].copy()

        if dataset == "casp15":
            # CASP15 rows are identified by target name, so look up their PDB IDs
            dataset_results_table.loc[:, "pdb_id"] = dataset_results_table.loc[:, "target"].map(
                casp15_target_pdb_id_mapping
            )

        globals()[f"{dataset}_complexes_{repeat_index}"] = set(
            dataset_results_table.loc[:, "pdb_id"].unique()
        )

#### Plot distribution of complex types across all datasets

In [None]:
# plot functional keyword statistics of the complexes across all datasets
# NOTE: `pdb_info_cache` stores each PDB entry's fetched info by PDB ID, so that an entry is requested only once
pdb_info_cache = dict()

for repeat_index in [1]:  # NOTE: we only consider the first repeat
    # one row per (complex, dataset) pair
    all_complexes_df = []
    for dataset in datasets:
        complexes_df = pd.DataFrame(
            globals()[f"{dataset}_complexes_{repeat_index}"], columns=["pdb_id"]
        )
        complexes_df["dataset"] = dataset_mapping[dataset]
        all_complexes_df.append(complexes_df)
    all_complexes_df = pd.concat(all_complexes_df, ignore_index=True)

    if all_complexes_df.empty:
        print("No complexes for any dataset.")
        continue

    # collect the first functional keyword of every unique complex
    complex_function_annotations = []
    for pdb_id in set(all_complexes_df["pdb_id"]):
        pdb_id = pdb_id.lower().split("_")[0]
        if pdb_id == "?":
            # no PDB entry is listed for this complex
            continue

        try:
            pdb_id_info = pdb_info_cache[pdb_id]
        except KeyError:
            pdb_id_info = pdb_info_cache[pdb_id] = pypdb.get_all_info(pdb_id)
        if not pdb_id_info:
            continue

        # NOTE: these represent functional keywords
        pdbx_keywords = pdb_id_info["struct_keywords"]["pdbx_keywords"]
        complex_function_annotations.append(pdbx_keywords.lower().split(", ")[0])

    # tabulate the keyword frequencies, most frequent first
    complex_function_annotation_counts = Counter(complex_function_annotations)
    df = (
        pd.DataFrame(
            complex_function_annotation_counts.items(),
            columns=["Keyword", "Frequency"],
        )
        .astype({"Frequency": int})
        .sort_values(by="Frequency", ascending=False)
    )

    plt.figure(figsize=(20, 10))
    sns.barplot(data=df, x="Frequency", y="Keyword", palette="viridis")

    # one tick per integer frequency
    max_freq = df["Frequency"].max()
    frequency_ticks = range(0, max_freq + 1)
    plt.xticks(ticks=frequency_ticks, labels=frequency_ticks, rotation=60)

    plt.xlabel("Frequency")
    plt.ylabel("Complex Annotation")

    plt.tight_layout()
    plt.savefig(f"complexes_functional_keywords_{repeat_index}.png", bbox_inches="tight")
    plt.show()

    plt.close("all")

    print(f"{len(complex_function_annotations)} complex annotations across all datasets.")

#### Plot distribution of the PoseBusters Benchmark dataset's complex deposition dates

In [None]:
# # report the PDB deposition date of each PoseBusters Benchmark complex
# posebusters_complex_deposition_dates = dict()
# for pdb_id in set(
#     all_complexes_df[all_complexes_df["dataset"] == "Posebusters Benchmark set"]["pdb_id"]
# ):
#     pdb_id = pdb_id.lower().split("_")[0]
#     if pdb_id == "?":
#         continue
#     if pdb_id in pdb_info_cache:
#         pdb_id_info = pdb_info_cache[pdb_id]
#     else:
#         pdb_id_info = pypdb.get_all_info(pdb_id)
#         pdb_info_cache[pdb_id] = pdb_id_info
#     if not pdb_id_info:
#         continue
#     posebusters_complex_deposition_dates[pdb_id] = pdb_id_info["rcsb_accession_info"][
#         "deposit_date"
#     ]

# # analyze and plot statistics of the PoseBusters complexes' deposition dates
# posebusters_complex_pdb_deposition_dates_df = pd.DataFrame(
#     {
#         "PDB ID": posebusters_complex_deposition_dates.keys(),
#         "Deposition Date": posebusters_complex_deposition_dates.values(),
#     }
# )
# posebusters_complex_pdb_deposition_dates_df["Deposition Date"] = pd.to_datetime(
#     posebusters_complex_pdb_deposition_dates_df["Deposition Date"]
# )
# posebusters_complex_pdb_deposition_dates_df.to_csv(
#     "posebusters_benchmark_complex_pdb_deposition_dates.csv", index=False
# )

# posebusters_pre_cutoff_complexes = posebusters_complex_pdb_deposition_dates_df[
#     posebusters_complex_pdb_deposition_dates_df["Deposition Date"]
#     <= method_max_training_cutoff_date
# ]
# posebusters_post_cutoff_complexes = posebusters_complex_pdb_deposition_dates_df[
#     posebusters_complex_pdb_deposition_dates_df["Deposition Date"]
#     > method_max_training_cutoff_date
# ]
# print(
#     f"{len(posebusters_pre_cutoff_complexes)}/{len(posebusters_complex_pdb_deposition_dates_df)} PoseBusters Benchmark complexes deposited before maximum cutoff of {method_max_training_cutoff_date}."
# )
# print(
#     f"{len(posebusters_post_cutoff_complexes)}/{len(posebusters_complex_pdb_deposition_dates_df)} PoseBusters Benchmark complexes deposited after maximum cutoff of {method_max_training_cutoff_date}."
# )

# sns.histplot(posebusters_complex_pdb_deposition_dates_df["Deposition Date"].values, bins=25)
# plt.xlabel("Deposition Date")
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.savefig("posebusters_benchmark_complex_pdb_deposition_dates.png")
# plt.show()
# plt.close("all")

#### Identify failure modes across all datasets

In [None]:
# find PDB IDs of complexes for which the correct (e.g., RMSD ≤ 2 Å & PB-Valid) binding conformation was not found by any method
# NOTE: registers `{dataset}_complexes_docked_by_any_method_{repeat_index}` and
# `{dataset}_complexes_not_docked_by_any_method_{repeat_index}` (sets of PDB IDs) in `globals()`
for dataset in datasets:
    docking_success_column = "RMSD ≤ 2 Å & PB-Valid"

    for repeat_index in range(1, max_num_repeats_per_method + 1):
        results_table = globals()[f"results_table_{repeat_index}"]

        # copy the dataset's rows explicitly, so that the CASP15 `pdb_id` assignment below acts
        # on an independent table rather than on a slice of the shared per-repeat results table
        dataset_results_table = results_table.loc[
            results_table.loc[:, "dataset"] == dataset_mapping[dataset]
        ].copy()

        if dataset == "casp15":
            # CASP15 rows are identified by target name, so look up their PDB IDs
            dataset_results_table.loc[:, "pdb_id"] = dataset_results_table.loc[:, "target"].map(
                casp15_target_pdb_id_mapping
            )

        # complexes with at least one successful row, from any method
        docked_complexes = set(
            dataset_results_table.loc[
                (dataset_results_table.loc[:, docking_success_column]).astype(bool),
                "pdb_id",
            ].unique()
        )
        globals()[f"{dataset}_complexes_docked_by_any_method_{repeat_index}"] = docked_complexes

        # every remaining complex of the dataset has no successful row at all
        globals()[f"{dataset}_complexes_not_docked_by_any_method_{repeat_index}"] = set(
            dataset_results_table.loc[
                ~dataset_results_table.loc[:, "pdb_id"].isin(docked_complexes),
                "pdb_id",
            ].unique()
        )

#### Find commonalities among the failure modes of each dataset

In [None]:
# plot functional keyword statistics of the failed complexes across all datasets
# NOTE: relies on `pdb_info_cache` and `complex_function_annotation_counts` from the earlier
# keyword-distribution cell
for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
    # one row per (failed complex, dataset) pair
    all_failed_complexes_df = []
    for dataset in datasets:
        predicted_complexes_df = pd.DataFrame(
            globals()[f"{dataset}_complexes_not_docked_by_any_method_{repeat_index}"],
            columns=["pdb_id"],
        )
        predicted_complexes_df["dataset"] = dataset_mapping[dataset]
        all_failed_complexes_df.append(predicted_complexes_df)
    all_failed_complexes_df = pd.concat(all_failed_complexes_df, ignore_index=True)

    if all_failed_complexes_df.empty:
        print("No failed complexes for any dataset.")
        continue

    # collect the first functional keyword of every unique failed complex
    failed_complex_function_annotations = []
    for pdb_id in set(all_failed_complexes_df["pdb_id"]):
        pdb_id = pdb_id.lower().split("_")[0]
        if pdb_id == "?":
            # no PDB entry is listed for this complex
            continue
        if pdb_id in pdb_info_cache:
            pdb_id_info = pdb_info_cache[pdb_id]
        else:
            pdb_id_info = pypdb.get_all_info(pdb_id)
            pdb_info_cache[pdb_id] = pdb_id_info
        if not pdb_id_info:
            continue
        failed_complex_function_annotations.append(
            # NOTE: these represent functional keywords
            pdb_id_info["struct_keywords"]["pdbx_keywords"]
            .lower()
            .split(", ")[0]
        )

    # per keyword: the number of failed complexes divided by the number of all complexes
    failed_complex_function_annotation_counts = subset_counter(
        Counter(failed_complex_function_annotations),
        complex_function_annotation_counts,
        normalize_subset=True,
    )
    df = pd.DataFrame(
        failed_complex_function_annotation_counts.items(),
        columns=["Keyword", "Failed Ratio"],
    )
    df["Frequency"] = df["Keyword"].map(complex_function_annotation_counts)
    # NOTE: the index is reset so that it matches each bar's position along the y-axis
    df = df.sort_values(by=["Failed Ratio", "Frequency"], ascending=False, ignore_index=True)

    plt.figure(figsize=(10, 6))
    sns.barplot(data=df, x="Failed Ratio", y="Keyword", palette="viridis")

    plt.xlabel("Failed Ratio")
    plt.ylabel("Complex Annotation")

    # leave room to the right of full-length bars for their labels
    plt.xlim(0, 1.1)

    # annotate bars with the frequency of each keyword
    for bar_position, failed_ratio, frequency in zip(
        df.index, df["Failed Ratio"], df["Frequency"]
    ):
        plt.text(
            x=failed_ratio + 0.01,
            y=bar_position,
            s=f"{failed_ratio:.2f} ({frequency})",
            va="center",
        )

    plt.tight_layout()
    plt.savefig(f"failed_complexes_functional_keywords_{repeat_index}.png")
    plt.show()

    plt.close("all")

    # report the failed-complex count as such (not as the count of all complex annotations)
    print(
        f"{len(failed_complex_function_annotations)} failed complex annotations across all datasets."
    )

#### Identify AlphaFold 3's failure modes

In [None]:
# find complexes that AlphaFold 3 failed to correctly predict
# NOTE: for every dataset and repeat, two sets of PDB IDs are published as notebook globals:
#   `{dataset}_complexes_docked_by_af3_{repeat_index}`     -> IDs with at least one successful AF3 row
#   `{dataset}_complexes_not_docked_by_af3_{repeat_index}` -> IDs of AF3 rows not in the set above
docking_success_column = "RMSD ≤ 2 Å & PB-Valid"

for dataset in datasets:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        repeat_results_table = globals()[f"results_table_{repeat_index}"]

        # NOTE: take an explicit copy so that the CASP15 `pdb_id` assignment below writes to an
        # independent frame instead of a slice of the shared results table (which would
        # otherwise trigger pandas' `SettingWithCopyWarning`)
        dataset_results_table = repeat_results_table.loc[
            (repeat_results_table.loc[:, "dataset"] == dataset_mapping[dataset])
            & (repeat_results_table.loc[:, "method"] == "AF3")
        ].copy()

        if dataset == "casp15":
            # derive each row's `pdb_id` from its `target` via the CASP15 target mapping
            dataset_results_table.loc[:, "pdb_id"] = dataset_results_table.loc[:, "target"].map(
                casp15_target_pdb_id_mapping
            )

        # a complex counts as docked as soon as any of its AF3 rows is flagged as successful
        complexes_docked_by_af3 = set(
            dataset_results_table.loc[
                dataset_results_table.loc[:, docking_success_column].astype(bool),
                "pdb_id",
            ].unique()
        )
        globals()[f"{dataset}_complexes_docked_by_af3_{repeat_index}"] = complexes_docked_by_af3

        # every remaining AF3 complex has no successful row at all
        globals()[f"{dataset}_complexes_not_docked_by_af3_{repeat_index}"] = set(
            dataset_results_table.loc[
                ~dataset_results_table.loc[:, "pdb_id"].isin(complexes_docked_by_af3),
                "pdb_id",
            ].unique()
        )

#### Record and plot AlphaFold 3's failure mode metadata

In [None]:
# plot functional keyword statistics of AlphaFold 3's failed complexes across all datasets

# gather, per dataset, the PDB IDs that AlphaFold 3 failed on (tagged with dataset and repeat)
af3_failure_frames = []
for dataset in datasets:
    for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
        failed_af3_complexes_df = pd.DataFrame(
            globals()[f"{dataset}_complexes_not_docked_by_af3_{repeat_index}"],
            columns=["pdb_id"],
        ).assign(dataset=dataset_mapping[dataset], repeat_index=repeat_index)
        af3_failure_frames.append(failed_af3_complexes_df)
all_failed_af3_complexes_df = pd.concat(af3_failure_frames, ignore_index=True)

# look up the leading functional keyword of each failed complex (reusing cached PDB metadata)
failed_af3_complex_function_annotations = []
for pdb_id in set(all_failed_af3_complexes_df["pdb_id"]):
    pdb_id = pdb_id.lower().split("_")[0]
    if pdb_id == "?":
        continue
    if pdb_id not in pdb_info_cache:
        pdb_info_cache[pdb_id] = pypdb.get_all_info(pdb_id)
    pdb_id_info = pdb_info_cache[pdb_id]
    if not pdb_id_info:
        continue
    # NOTE: these represent functional keywords
    pdbx_keywords = pdb_id_info["struct_keywords"]["pdbx_keywords"]
    failed_af3_complex_function_annotations.append(pdbx_keywords.lower().split(", ")[0])

# relate the failed keyword counts to the overall keyword counts
failed_af3_complex_function_annotation_counts = subset_counter(
    Counter(failed_af3_complex_function_annotations),
    complex_function_annotation_counts,
    normalize_subset=True,
)

# one row per keyword, ordered by failed ratio and then by overall frequency (both descending)
df = (
    pd.DataFrame(
        failed_af3_complex_function_annotation_counts.items(),
        columns=["Keyword", "Failed Ratio"],
    )
    .assign(Frequency=lambda frame: frame["Keyword"].map(complex_function_annotation_counts))
    .sort_values(by=["Failed Ratio", "Frequency"], ascending=False, ignore_index=True)
)

plt.figure(figsize=(12, 8))
sns.barplot(data=df, x="Failed Ratio", y="Keyword", palette="viridis")

plt.xlabel("Failed Ratio")
plt.ylabel("Complex Annotation")

# leave room to the right of the bars for the text annotations
plt.xlim(0, 1.09)

# annotate bars with the frequency of each keyword
for index, row in df.iterrows():
    bar_label = f"{row['Failed Ratio']:.2f} ({row['Frequency']})"
    plt.text(x=row["Failed Ratio"] + 0.01, y=index, s=bar_label, va="center")

plt.tight_layout()
plt.savefig("failed_af3_complexes_functional_keywords.png")
plt.show()

plt.close("all")

print(f"{len(failed_af3_complex_function_annotations)} complex annotations across all datasets.")

#### Study PDB statistics of different types of complexes

In [None]:
# combine all CSV files from a custom PDB report
# NOTE: each sub-directory of `pdb_reports` is treated as one report type; its CSV files are
# merged, saved as `{report_type}_pdb_report.csv`, and summarized as two histograms
pdb_reports_root = "pdb_reports"

# NOTE: sort for a reproducible processing order, and ignore stray files (e.g., `.DS_Store`),
# since only directories can be listed for CSV files below
report_types = sorted(
    entry
    for entry in os.listdir(pdb_reports_root)
    if os.path.isdir(os.path.join(pdb_reports_root, entry))
)

for report_type in report_types:
    pdb_report_dir = os.path.join(pdb_reports_root, report_type)
    pdb_report_files = sorted(
        os.path.join(pdb_report_dir, f) for f in os.listdir(pdb_report_dir) if f.endswith(".csv")
    )
    if not pdb_report_files:
        # `pd.concat()` raises a `ValueError` when there is nothing to concatenate
        print(f"No CSV files found for the custom {report_type} report. Skipping...")
        continue

    # NOTE: the first line of each file is skipped, so the second line is parsed as the header
    pdb_report_dfs = [
        pd.read_csv(pdb_report_file, skiprows=1) for pdb_report_file in pdb_report_files
    ]
    pdb_report_df = pd.concat(pdb_report_dfs, ignore_index=True)

    # analyze and plot statistics of the custom PDB report
    # NOTE: commas are stripped before the numeric conversion, and any value that still
    # cannot be parsed becomes NaN
    resolution_column = "Refinement Resolution (Å)"
    pdb_report_df[resolution_column] = pd.to_numeric(
        pdb_report_df[resolution_column].astype(str).str.replace(",", "", regex=False),
        errors="coerce",
    )
    pdb_report_df["Deposition Date"] = pd.to_datetime(pdb_report_df["Deposition Date"])

    pdb_report_df.to_csv(f"{report_type}_pdb_report.csv", index=False)

    print(f"{len(pdb_report_df)} PDB entries in the custom {report_type} report.")

    # histogram of refinement resolutions (x-axis limited to 0-10)
    sns.histplot(pdb_report_df[resolution_column].values)
    plt.xlim(0, 10)
    plt.xlabel("Refinement Resolution (Å)")
    plt.tight_layout()
    plt.savefig(f"{report_type}_pdb_report_resolution.png")
    plt.show()

    plt.close("all")

    # histogram of deposition dates
    sns.histplot(pdb_report_df["Deposition Date"].values)
    plt.xlabel("Deposition Date")
    plt.tight_layout()
    plt.savefig(f"{report_type}_pdb_report_deposition_date.png")
    plt.show()

    plt.close("all")

<!-- #### Study AlphaFold 3's relationship between training-test set ligand-binding pocket structural overlap and structure prediction performance -->

#### Plot correlation between performance of method predictions for PoseBusters Benchmark set complexes and their maximum similarity to the PDB

In [None]:
# compare max (training set) ligand-binding pocket structural overlap and (test set) RMSD of each method's predicted PoseBusters Benchmark set complexes
af3_overlap_datasets = ["posebusters_benchmark"]

# collect an independent copy of every method's result rows for the selected datasets
all_predicted_complexes_df = []
for dataset in af3_overlap_datasets:
    for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
        repeat_results_table = globals()[f"results_table_{repeat_index}"]
        predicted_complexes_df = repeat_results_table.loc[
            repeat_results_table.loc[:, "dataset"] == dataset_mapping[dataset]
        ].copy()
        predicted_complexes_df["dataset"] = dataset_mapping[dataset]
        predicted_complexes_df["repeat_index"] = repeat_index
        all_predicted_complexes_df.append(predicted_complexes_df)
all_predicted_complexes_df = pd.concat(all_predicted_complexes_df, ignore_index=True)
all_predicted_complex_pdb_ids = list(all_predicted_complexes_df.loc[:, "pdb_id"].unique())

# analyze and plot statistics of each method's predictions and the novelty of the target binding modes
# NOTE: benchmark directories are named `{PDB ID}_{CCD code}`; map each predicted complex's
# PDB ID (as spelled in the directory name) to its ligand CCD code
pb_dataset_dir = os.path.join("..", "data", "posebusters_benchmark_set")
predicted_pdb_id_lookup = set(all_predicted_complex_pdb_ids)  # O(1) membership tests

pb_pdb_id_ccd_code_mapping = {}
for pdb_ccd_code in os.listdir(pb_dataset_dir):
    if any(s in pdb_ccd_code for s in ("plots", "msas", "structures")):
        continue  # auxiliary directories, not complexes
    if not os.path.isdir(os.path.join(pb_dataset_dir, pdb_ccd_code)):
        continue
    name_parts = pdb_ccd_code.split("_")
    if len(name_parts) < 2:
        continue  # no CCD code in the directory name
    if name_parts[0].lower() in predicted_pdb_id_lookup:
        pb_pdb_id_ccd_code_mapping[name_parts[0]] = name_parts[1]

# load annotations from CSV
annotated_df = pd.read_csv(os.path.join("..", "data", "plinder", "annotations.csv"))
annotated_df["target_release_date"] = pd.to_datetime(annotated_df["target_release_date"])

# filter annotated_df to rows where (entry_pdb_id, ligand_ccd_code) matches pb_pdb_id_ccd_code_mapping
# NOTE: the mask is built column-wise (rather than with a row-wise `DataFrame.apply()`), which
# also behaves correctly for an empty annotation table
pb_pdb_id_ccd_code_pairs = pb_pdb_id_ccd_code_mapping.items()
is_annotated_benchmark_system = np.array(
    [
        (entry_pdb_id.upper(), ligand_ccd_code) in pb_pdb_id_ccd_code_pairs
        and not pd.isna(target_system)
        for entry_pdb_id, ligand_ccd_code, target_system in zip(
            annotated_df["entry_pdb_id"],
            annotated_df["ligand_ccd_code"],
            annotated_df["target_system"],
        )
    ],
    dtype=bool,
)
# NOTE: copy the selection so that the column assignment below does not write to a slice
annotated_df = annotated_df.loc[is_annotated_benchmark_system].copy()

# ensure the key columns match in format
# NOTE: the RMSD rows are taken from the combined table of all datasets/repeats (not from the
# loop variable of the last iteration), and are upper-cased on a private projection so that
# the source tables are left untouched
predicted_rmsd_df = all_predicted_complexes_df[["pdb_id", "method", "rmsd"]].copy()
predicted_rmsd_df["pdb_id"] = predicted_rmsd_df["pdb_id"].str.upper()
annotated_df["entry_pdb_id"] = annotated_df["entry_pdb_id"].str.upper()

# merge RMSD values into annotated_df
# NOTE: a left join keeps annotations without a prediction (with missing `method`/`rmsd`)
annotated_df = annotated_df.merge(
    predicted_rmsd_df,
    left_on="entry_pdb_id",
    right_on="pdb_id",
    how="left",
)


def remove_rmsd_outliers(df: pd.DataFrame, col: str = "rmsd", factor: float = 1.5) -> pd.DataFrame:
    """Keep only the rows whose `col` value lies within the IQR fences.

    The fences are `Q1 - factor * IQR` and `Q3 + factor * IQR` (both inclusive), with the
    quartiles computed from `df[col]`. Rows with a missing `col` value are dropped as well,
    since a missing value satisfies neither bound.

    :param df: DataFrame to filter.
    :param col: Name of the numeric column to screen for outliers.
    :param factor: Multiple of the interquartile range by which the fences are widened.
    :return: The rows of `df` whose `col` value falls within the fences.
    """
    values = df[col]
    first_quartile, third_quartile = values.quantile([0.25, 0.75])
    fence_margin = factor * (third_quartile - first_quartile)
    within_fences = values.between(first_quartile - fence_margin, third_quartile + fence_margin)
    return df[within_fences]


# keep only the analysis columns, then drop rows with missing or outlier values
generalization_columns = ["sucos_shape_pocket_qcov", "method", "rmsd"]
method_generalization_df = remove_rmsd_outliers(annotated_df[generalization_columns].dropna())

# --- compute per-method correlations between pocket similarity and RMSD ---
# NOTE: these are computed on the outlier-filtered rows, and methods with fewer than
# three rows are left out
method_corrs = []
for method_name, df_method in method_generalization_df.groupby("method"):
    num_points = len(df_method)
    if num_points >= 3:
        pocket_similarity = df_method["sucos_shape_pocket_qcov"]
        rmsd_values = df_method["rmsd"]
        pr, pp = pearsonr(pocket_similarity, rmsd_values)
        sr, sp = spearmanr(pocket_similarity, rmsd_values)
        method_corrs.append(
            dict(
                method=method_name,
                pearson_r=pr,
                pearson_p=pp,
                spearman_r=sr,
                spearman_p=sp,
                n=num_points,
            )
        )

corr_df = pd.DataFrame(method_corrs)
if corr_df.empty:
    raise ValueError("No methods have >= 3 points; nothing to plot.")

# --- order the methods by ascending Pearson r ---
# NOTE: the split into positive vs. negative Pearson r is currently disabled:
# pos_df = corr_df[corr_df["pearson_r"] >= 0].sort_values("pearson_r", ascending=False)
# neg_df = corr_df[corr_df["pearson_r"] < 0].sort_values("pearson_r")
# pos_methods = pos_df["method"].tolist()
neg_df = corr_df.sort_values("pearson_r")
neg_methods = neg_df["method"].tolist()

# --- determine clipping to ignore extreme RMSD values for axis scaling ---
# the y-axis spans from the smallest RMSD (but not below zero) up to the 95th percentile
ymin = max(0.0, method_generalization_df["rmsd"].min())
ymax = np.percentile(method_generalization_df["rmsd"], 95)

# --- create the (single-panel) figure ---
# NOTE: the two-panel layout is currently disabled:
# fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
# ax_pos, ax_neg = axes
fig, axes = plt.subplots(1, 1, figsize=(10, 6), sharey=True)
ax_neg = axes


def plot_group(ax, methods, corr_subdf):
    """Plot RMSD against SuCOS-pocket similarity for a group of methods on one axes.

    For each method, this draws its scatter points and a regression line in a dedicated
    color, and marks the points lying above the y-axis limit with triangles at that limit.
    A text box listing the correlations in `corr_subdf` is added as well. The data and the
    y-axis limits are read from the notebook-level `method_generalization_df`, `ymin`, and
    `ymax`.

    :param ax: Matplotlib axes to draw on.
    :param methods: Names of the methods to plot; when empty, only a placeholder message
        is shown.
    :param corr_subdf: DataFrame with one row per method and the columns `method`,
        `pearson_r`, `pearson_p`, and `spearman_r`.
    """
    x_column, y_column = "sucos_shape_pocket_qcov", "rmsd"

    if not methods:
        ax.text(0.5, 0.5, "No methods in this group", ha="center", va="center")
        ax.set_xlabel("SuCOS-pocket similarity")
        ax.set_ylim(ymin, ymax)
        return

    method_colors = sns.color_palette(n_colors=len(methods))
    for method, color in zip(methods, method_colors):
        dfm = method_generalization_df[method_generalization_df["method"] == method]

        # scatter points
        sns.scatterplot(
            data=dfm, x=x_column, y=y_column, ax=ax, label=method, color=color, alpha=0.5, s=40
        )
        # regression line (no scatter here)
        sns.regplot(
            data=dfm,
            x=x_column,
            y=y_column,
            ax=ax,
            scatter=False,
            ci=None,
            color=color,
            line_kws={"lw": 2},
        )

        # mark outliers that were clipped for axis scaling
        clipped = dfm[dfm[y_column] > ymax]
        if clipped.empty:
            continue
        ax.scatter(
            clipped[x_column],
            [ymax] * len(clipped),
            marker="^",
            s=60,
            edgecolor="k",
            linewidth=0.5,
            color=color,
            alpha=0.85,
        )

    ax.set_xlabel("SuCOS-pocket similarity")
    ax.set_ylim(ymin, ymax)
    ax.legend(fontsize=8, loc="upper right")

    # correlation table for this panel
    summary_lines = []
    for _, corr in corr_subdf.iterrows():
        summary_lines.append(
            f"{corr['method']}: r={corr['pearson_r']:.2f} (p={corr['pearson_p']:.1g}), ρ={corr['spearman_r']:.2f}"
        )
    ax.text(
        0.02,
        0.95 * ymax,
        "\n".join(summary_lines),
        fontsize=9,
        bbox={"facecolor": "white", "edgecolor": "gray", "boxstyle": "round,pad=0.3"},
    )


# NOTE: the separate panel for methods with a positive Pearson r is currently disabled:
# plot_group(ax_pos, pos_methods, pos_df)
# # ax_pos.set_title("Methods with positive Pearson r")
# ax_pos.set_ylabel("RMSD")


# draw all methods in the single remaining panel
plot_group(ax_neg, neg_methods, neg_df)
# ax_neg.set_title("Methods with negative Pearson r")
# NOTE: with the left-hand panel disabled, this is the only panel, so it has to carry the
# y-axis label itself (rather than leaving it blank for a shared label on the left)
ax_neg.set_ylabel("RMSD")

plt.tight_layout()
plt.savefig("posebusters_benchmark_methods_pos_vs_neg_correlation.png", dpi=300)
plt.show()

# release the figure, as is done after every other plot in this notebook
plt.close("all")

# print sorted summary to console
print(corr_df.sort_values("pearson_r", ascending=False).to_string(index=False))