## Failure Modes Analysis Plotting

#### Import packages

In [None]:
import glob
import os
import shutil
import subprocess
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pypdb
import seaborn as sns

from posebench.analysis.inference_analysis import BUST_TEST_COLUMNS
from posebench.utils.data_utils import parse_fasta

#### Configure packages

In [None]:
# Explicitly opt out of pandas' Copy-on-Write mode for this notebook.
# NOTE(review): presumably to keep the legacy view/copy semantics that the in-place
# `.loc[:, column] = ...` assignments in later cells were written against — confirm.
pd.set_option("mode.copy_on_write", False)

#### Define constants

In [None]:
# General variables
baseline_methods = [
    "vina_p2rank",
    "diffdock",
    "dynamicbind",
    "neuralplexer",
    "rfaa",
    "chai-lab_ss",
    "chai-lab",
    "alphafold3_ss",
    "alphafold3",
]
max_num_repeats_per_method = 3
method_max_training_cutoff_date = "2021-09-30"

datasets = ["astex_diverse", "posebusters_benchmark", "dockgen", "casp15"]

# PoseBusters Benchmark deposition dates
pb_deposition_dates_filepath = "posebusters_benchmark_complex_pdb_deposition_dates.csv"
# NOTE: raise explicitly instead of using `assert`, which is skipped when Python runs with `-O`
if not os.path.exists(pb_deposition_dates_filepath):
    raise FileNotFoundError(
        "Please prepare the PoseBusters Benchmark complex PDB deposition dates CSV file via later steps in `failure_modes_analysis_plotting.ipynb` before proceeding."
    )

# keep only complexes deposited after the maximum training cutoff date of any method,
# exposed as a `PDB ID` -> deposition-date-string mapping
pb_pdb_id_deposition_date_mapping_df = pd.read_csv(pb_deposition_dates_filepath)
pb_pdb_id_deposition_date_mapping_df["Deposition Date"] = pd.to_datetime(
    pb_pdb_id_deposition_date_mapping_df["Deposition Date"]
)
pb_pdb_id_deposition_date_mapping_df = pb_pdb_id_deposition_date_mapping_df[
    pb_pdb_id_deposition_date_mapping_df["Deposition Date"] > method_max_training_cutoff_date
]
pb_pdb_id_deposition_date_mapping = dict(
    zip(
        pb_pdb_id_deposition_date_mapping_df["PDB ID"],
        pb_pdb_id_deposition_date_mapping_df["Deposition Date"].astype(str),
    )
)

# Filepaths for each baseline method
# NOTE: `globals()` is used because some names (e.g., `chai-lab_output_dir`) are not valid identifiers
globals()["vina_output_dir"] = os.path.join("..", "forks", "Vina", "inference")
globals()["diffdock_output_dir"] = os.path.join("..", "forks", "DiffDock", "inference")
globals()["dynamicbind_output_dir"] = os.path.join(
    "..", "forks", "DynamicBind", "inference", "outputs", "results"
)
globals()["neuralplexer_output_dir"] = os.path.join("..", "forks", "NeuralPLexer", "inference")
globals()["rfaa_output_dir"] = os.path.join("..", "forks", "RoseTTAFold-All-Atom", "inference")
globals()["chai-lab_output_dir"] = os.path.join("..", "forks", "chai-lab", "inference")
globals()["alphafold3_output_dir"] = os.path.join("..", "forks", "alphafold3", "inference")
globals()["casp15_output_dir"] = os.path.join("..", "data", "test_cases", "casp15")

# For each method (in assignment order): the name of the global holding its root output
# directory, and the template of its per-(dataset, repeat, config) output subdirectory used
# for every non-CASP15 dataset
method_output_dir_layouts = {
    # P2Rank-Vina results
    "vina_p2rank": ("vina_output_dir", "vina_p2rank_{dataset}_outputs_{repeat_index}{config}"),
    # DiffDock results (NOTE: `_output_`, not `_outputs_`)
    "diffdock": ("diffdock_output_dir", "diffdock_{dataset}_output_{repeat_index}{config}"),
    # DynamicBind results (NOTE: no method-name prefix)
    "dynamicbind": ("dynamicbind_output_dir", "{dataset}_{repeat_index}{config}"),
    # NeuralPLexer results
    "neuralplexer": (
        "neuralplexer_output_dir",
        "neuralplexer_{dataset}_outputs_{repeat_index}{config}",
    ),
    # RoseTTAFold-All-Atom results
    "rfaa": ("rfaa_output_dir", "rfaa_{dataset}_outputs_{repeat_index}{config}"),
    # Chai-1 (Single-Seq) results
    "chai-lab_ss": (
        "chai-lab_output_dir",
        "chai-lab_ss_{dataset}_outputs_{repeat_index}{config}",
    ),
    # Chai-1 results
    "chai-lab": ("chai-lab_output_dir", "chai-lab_{dataset}_outputs_{repeat_index}{config}"),
    # AlphaFold 3 (Single-Seq) results
    "alphafold3_ss": (
        "alphafold3_output_dir",
        "alphafold3_ss_{dataset}_outputs_{repeat_index}{config}",
    ),
    # AlphaFold 3 results
    "alphafold3": (
        "alphafold3_output_dir",
        "alphafold3_{dataset}_outputs_{repeat_index}{config}",
    ),
}

for config in ["", "_relaxed"]:
    for dataset in datasets:
        for repeat_index in range(1, max_num_repeats_per_method + 1):
            for method, (
                output_dir_name,
                output_subdir_template,
            ) in method_output_dir_layouts.items():
                if dataset == "casp15":
                    # CASP15 scoring results of all methods live under one (optionally relaxed) directory
                    results_csv_filepath = os.path.join(
                        globals()["casp15_output_dir"] + config,
                        f"top_{method}_ensemble_predictions_{repeat_index}",
                        "scoring_results.csv",
                    )
                else:
                    results_csv_filepath = os.path.join(
                        globals()[output_dir_name],
                        output_subdir_template.format(
                            dataset=dataset, repeat_index=repeat_index, config=config
                        ),
                        "bust_results.csv",
                    )
                globals()[
                    f"{method}_{dataset}{config}_bust_results_csv_filepath_{repeat_index}"
                ] = results_csv_filepath

# Mappings
method_mapping = {
    "vina_p2rank": "P2Rank-Vina",
    "diffdock": "DiffDock-L",
    "dynamicbind": "DynamicBind",
    "neuralplexer": "NeuralPLexer",
    "rfaa": "RFAA",
    "chai-lab_ss": "Chai-1-Single-Seq",
    "chai-lab": "Chai-1",
    "alphafold3_ss": "AF3-Single-Seq",
    "alphafold3": "AF3",
}

method_category_mapping = {
    "vina_p2rank": "Conventional blind",
    "diffdock": "DL-based blind",
    "dynamicbind": "DL-based blind",
    "neuralplexer": "DL-based blind",
    "rfaa": "DL-based blind",
    "chai-lab_ss": "DL-based blind",
    "chai-lab": "DL-based blind",
    "alphafold3_ss": "DL-based blind",
    "alphafold3": "DL-based blind",
}

dataset_mapping = {
    "astex_diverse": "Astex Diverse set",
    "posebusters_benchmark": "Posebusters Benchmark set",
    "dockgen": "DockGen set",
    "casp15": "CASP15 set",
}

casp15_target_pdb_id_mapping = {
    # NOTE: `?` indicates that the target's crystal structure is not publicly available
    "H1135": "7z8y",
    "H1171v1": "7pbl",
    "H1171v2": "7pbl",
    "H1172v1": "7pbp",
    "H1172v2": "7pbp",
    "H1172v3": "7pbp",
    "H1172v4": "7pbp",
    "T1124": "7ux8",
    "T1127v2": "?",
    "T1146": "?",
    "T1152": "7r1l",
    "T1158v1": "8sx8",
    "T1158v2": "8sxb",
    "T1158v3": "8sx7",
    "T1158v4": "8swn",
    "T1170": "7pbr",
    "T1181": "?",
    "T1186": "?",
    "T1187": "8ad2",
    "T1188": "8c6z",
}

#### Load test results for each baseline method

In [None]:
# load and report test results for each baseline method
for config in [""]:
    for dataset in datasets:
        for method in baseline_methods:
            for repeat_index in range(1, max_num_repeats_per_method + 1):
                method_title = method_mapping[method]

                bust_results_csv_filepath = globals()[
                    f"{method}_{dataset}{config}_bust_results_csv_filepath_{repeat_index}"
                ]
                if not os.path.exists(bust_results_csv_filepath):
                    continue

                # build the table locally, then publish it under its global name once finalized
                bust_results_table = pd.read_csv(bust_results_csv_filepath)

                if dataset == "casp15":
                    # count the number of ligands in each target complex, and assign these corresponding numbers to the ligands (rows) of each complex
                    bust_results_table.loc[:, "num_target_ligands"] = bust_results_table.groupby(
                        ["target", "mdl"]
                    )["pose"].transform("count")

                    # filter out non-relevant ligand predictions, and for all methods select only their first model for each ligand
                    # NOTE: `.copy()` makes the filtered table independent, so the column assignments below do not warn
                    bust_results_table = bust_results_table[
                        np.where(bust_results_table.relevant, True, False)
                        & (bust_results_table.mdl == 1)
                    ].copy()

                    # finalize bust (i.e., scoring) results for CASP15, using dummy values for `pb_valid` and `crmsd_≤_1å`
                    bust_results_table.loc[:, "rmsd_≤_2å"] = (
                        bust_results_table.loc[:, "rmsd"] <= 2
                    )
                    # NOTE: the metric standardization step reads `pb_valid` and fills missing values with
                    # `False`, so this dummy column must be named `pb_valid` (previously misspelled `pdb_valid`)
                    bust_results_table.loc[:, "pb_valid"] = True
                    bust_results_table.loc[:, "crmsd_≤_1å"] = True

                else:
                    bust_results_table = bust_results_table[
                        BUST_TEST_COLUMNS + ["rmsd", "centroid_distance", "mol_id"]
                    ].copy()
                    # a pose is PB-valid when every selected column except the first one and the
                    # trailing `rmsd`, `centroid_distance`, and `mol_id` columns is truthy
                    # NOTE(review): the first `BUST_TEST_COLUMNS` entry is presumably not a validity test — confirm
                    bust_results_table.loc[:, "pb_valid"] = bust_results_table.iloc[:, 1:-3].all(
                        axis=1
                    )
                    bust_results_table.loc[:, "crmsd_≤_1å"] = (
                        bust_results_table["centroid_distance"] < 1
                    )
                    bust_results_table.loc[:, "pdb_id"] = bust_results_table["mol_id"]

                bust_results_table.loc[:, "method"] = method
                bust_results_table.loc[:, "post-processing"] = (
                    "energy minimization" if config == "_relaxed" else "none"
                )
                bust_results_table.loc[:, "dataset"] = dataset
                bust_results_table.loc[:, "docked_ligand_successfully_loaded"] = (
                    True
                    if dataset == "casp15"
                    else bust_results_table[
                        ["mol_pred_loaded", "mol_true_loaded", "mol_cond_loaded"]
                    ].all(axis=1)
                )

                if dataset == "posebusters_benchmark":
                    # keep only the results for complexes deposited in the PDB after the maximum cutoff date for any method's training data
                    bust_results_table["pdb_id"] = bust_results_table["mol_id"].map(
                        lambda x: x.lower().split("_")[0]
                    )
                    bust_results_table = bust_results_table[
                        bust_results_table["pdb_id"].isin(pb_pdb_id_deposition_date_mapping.keys())
                    ]

                globals()[f"{method}_{dataset}{config}_bust_results_table_{repeat_index}"] = (
                    bust_results_table
                )

#### Define helper functions

In [None]:
def assign_method_index(method: str) -> int:
    """
    Assign method index for plotting.

    The index is the position of ``method`` among the keys of ``method_mapping``,
    giving every method a stable ordering (e.g., for plots).

    :param method: Method name (a key of ``method_mapping``).
    :return: Zero-based method index.
    :raises ValueError: If ``method`` is not a key of ``method_mapping``.
    """
    return list(method_mapping).index(method)


def categorize_method(method: str) -> str:
    """
    Look up the plotting category of a method.

    :param method: Method name, looked up in ``method_category_mapping``.
    :return: The method's registered category, or ``"Misc"`` if it has none.
    """
    if method in method_category_mapping:
        return method_category_mapping[method]
    return "Misc"


def subset_counter(
    subset_counter: Counter, superset_counter: Counter, normalize_subset: bool = False
) -> Counter:
    """
    Subset a superset counter by a subset counter.

    Every key of ``subset_counter`` appears in the result, in the same order. A key whose
    superset count is missing or zero maps to ``0``. Any other key maps to its superset
    count, or, when ``normalize_subset`` is true, to its subset count divided by its
    superset count.

    :param subset_counter: Subset counter; its keys select the entries of the result.
    :param superset_counter: Superset counter.
    :param normalize_subset: Normalize subset counter by superset counter.
    :return: Subsetted counter.
    """
    subsetted = Counter()
    for key, subset_count in subset_counter.items():
        superset_count = superset_counter[key] if key in superset_counter else 0
        if superset_count == 0:
            # nothing (or zero) to compare against; also avoids division by zero
            subsetted[key] = 0
        elif normalize_subset:
            subsetted[key] = subset_count / superset_count
        else:
            subsetted[key] = superset_count
    return subsetted

#### Standardize metrics

In [None]:
# load and organize the results CSVs
for repeat_index in range(1, max_num_repeats_per_method + 1):
    # names of all per-(method, dataset) tables that were actually loaded for this repeat
    available_table_names = [
        f"{method_key}_{dataset_key}{config_suffix}_bust_results_table_{repeat_index}"
        for dataset_key in datasets
        for method_key in baseline_methods
        for config_suffix in [""]
        if f"{method_key}_{dataset_key}{config_suffix}_bust_results_table_{repeat_index}"
        in globals()
    ]
    merged_results = pd.concat([globals()[name] for name in available_table_names])

    # method metadata used for grouping and ordering
    merged_results.loc[:, "method_category"] = merged_results["method"].apply(categorize_method)
    merged_results.loc[:, "method_assignment_index"] = merged_results["method"].apply(
        assign_method_index
    )

    # boolean success flags, treating missing values as failures
    merged_results.loc[:, "crmsd_within_threshold"] = merged_results.loc[
        :, "crmsd_≤_1å"
    ].fillna(False)
    merged_results.loc[:, "rmsd_within_threshold"] = merged_results.loc[
        :, "rmsd_≤_2å"
    ].fillna(False)
    rmsd_flags = merged_results.loc[:, "rmsd_within_threshold"]
    pb_valid_flags = merged_results.loc[:, "pb_valid"].fillna(False)
    merged_results.loc[:, "rmsd_within_threshold_and_pb_valid"] = rmsd_flags & pb_valid_flags

    # integer (0/1) versions of the success flags under display-ready column names
    merged_results.loc[:, "RMSD ≤ 2 Å & PB-Valid"] = merged_results.loc[
        :, "rmsd_within_threshold_and_pb_valid"
    ].astype(int)
    merged_results.loc[:, "cRMSD ≤ 1 Å"] = (
        merged_results.loc[:, "crmsd_within_threshold"].fillna(False).astype(int)
    )
    merged_results.loc[:, "RMSD ≤ 2 Å"] = (
        merged_results.loc[:, "rmsd_within_threshold"].fillna(False).astype(int)
    )

    # swap internal dataset/method keys for their display names
    merged_results.loc[:, "dataset"] = merged_results.loc[:, "dataset"].map(dataset_mapping)
    merged_results.loc[:, "method"] = merged_results.loc[:, "method"].map(method_mapping)

    globals()[f"results_table_{repeat_index}"] = merged_results

#### Collect the complex PDB IDs of each dataset

In [None]:
# find PDB IDs of complexes across all datasets
for dataset in datasets:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        # NOTE: `.copy()` makes this per-dataset subset explicitly independent of the full results
        # table, so (over)writing its `pdb_id` column below does not raise a `SettingWithCopyWarning`
        dataset_results_table = globals()[f"results_table_{repeat_index}"].loc[
            globals()[f"results_table_{repeat_index}"].loc[:, "dataset"]
            == dataset_mapping[dataset]
        ].copy()

        if dataset == "casp15":
            # CASP15 rows are keyed by target name; map targets to PDB IDs (`?` = not publicly available)
            dataset_results_table.loc[:, "pdb_id"] = dataset_results_table.loc[:, "target"].map(
                casp15_target_pdb_id_mapping
            )

        globals()[f"{dataset}_complexes_{repeat_index}"] = set(
            dataset_results_table.loc[:, "pdb_id"].unique()
        )

#### Plot distribution of complex types across all datasets

In [None]:
# plot functional keyword statistics of the complexes across all datasets
pdb_info_cache = dict()

for repeat_index in [1]:  # NOTE: we only consider the first repeat
    all_complexes_df = []
    for dataset in datasets:
        complexes_df = pd.DataFrame(
            globals()[f"{dataset}_complexes_{repeat_index}"], columns=["pdb_id"]
        )
        complexes_df["dataset"] = dataset_mapping[dataset]
        all_complexes_df.append(complexes_df)
    all_complexes_df = pd.concat(all_complexes_df, ignore_index=True)

    if all_complexes_df.empty:
        print("No complexes for any dataset.")
        continue

    # look up (and cache) the first listed PDB functional keyword of each complex, skipping
    # unreleased CASP15 targets (`?`) and complexes for which `pypdb` returns no info
    complex_function_annotations = []
    for pdb_id in set(all_complexes_df["pdb_id"]):
        pdb_id = pdb_id.lower().split("_")[0]
        if pdb_id == "?":
            continue
        if pdb_id in pdb_info_cache:
            pdb_id_info = pdb_info_cache[pdb_id]
        else:
            pdb_id_info = pypdb.get_all_info(pdb_id)
            pdb_info_cache[pdb_id] = pdb_id_info
        if not pdb_id_info:
            continue
        complex_function_annotations.append(
            # NOTE: these represent functional keywords
            pdb_id_info["struct_keywords"]["pdbx_keywords"]
            .lower()
            .split(", ")[0]
        )

    complex_function_annotation_counts = Counter(complex_function_annotations)
    df = pd.DataFrame(
        complex_function_annotation_counts.items(),
        columns=["Keyword", "Frequency"],
    )
    df["Frequency"] = df["Frequency"].astype(int)
    df = df.sort_values(by="Frequency", ascending=False)

    if df.empty:
        # with no annotations, `max_freq` below would be NaN and `range()` would raise a `TypeError`
        print("No functional keyword annotations could be retrieved for any complex.")
        continue

    plt.figure(figsize=(20, 10))
    sns.barplot(data=df, x="Frequency", y="Keyword", palette="viridis")

    # one integer tick per possible frequency value
    max_freq = df["Frequency"].max()
    plt.xticks(ticks=range(0, max_freq + 1), labels=range(0, max_freq + 1), rotation=60)

    plt.xlabel("Frequency")
    plt.ylabel("Complex Annotation")

    plt.tight_layout()
    plt.savefig(f"complexes_functional_keywords_{repeat_index}.png", bbox_inches="tight")
    plt.show()

    plt.close("all")

    print(f"{len(complex_function_annotations)} complex annotations across all datasets.")

#### Plot distribution of the PoseBusters Benchmark dataset's complex deposition dates

In [None]:
# NOTE: this cell is disabled (fully commented out). When enabled, it queries the PDB via
# `pypdb` for the deposition date of every PoseBusters Benchmark complex in `all_complexes_df`,
# writes them to `posebusters_benchmark_complex_pdb_deposition_dates.csv` (the file the
# constants cell requires), and plots their distribution relative to
# `method_max_training_cutoff_date`.
# NOTE(review): by the time this cell would run, the loaded PoseBusters Benchmark results have
# already been restricted to the post-cutoff PDB IDs listed in that same CSV, so regenerating the
# CSV from `all_complexes_df` would presumably omit all pre-cutoff complexes — confirm before
# re-enabling (e.g., regenerate it from the unfiltered results instead).
# # report the PDB deposition date of each PoseBusters Benchmark complex
# posebusters_complex_deposition_dates = dict()
# for pdb_id in set(
#     all_complexes_df[all_complexes_df["dataset"] == "Posebusters Benchmark set"]["pdb_id"]
# ):
#     pdb_id = pdb_id.lower().split("_")[0]
#     if pdb_id == "?":
#         continue
#     if pdb_id in pdb_info_cache:
#         pdb_id_info = pdb_info_cache[pdb_id]
#     else:
#         pdb_id_info = pypdb.get_all_info(pdb_id)
#         pdb_info_cache[pdb_id] = pdb_id_info
#     if not pdb_id_info:
#         continue
#     posebusters_complex_deposition_dates[pdb_id] = pdb_id_info["rcsb_accession_info"][
#         "deposit_date"
#     ]

# # analyze and plot statistics of the PoseBusters complexes' deposition dates
# posebusters_complex_pdb_deposition_dates_df = pd.DataFrame(
#     {
#         "PDB ID": posebusters_complex_deposition_dates.keys(),
#         "Deposition Date": posebusters_complex_deposition_dates.values(),
#     }
# )
# posebusters_complex_pdb_deposition_dates_df["Deposition Date"] = pd.to_datetime(
#     posebusters_complex_pdb_deposition_dates_df["Deposition Date"]
# )
# posebusters_complex_pdb_deposition_dates_df.to_csv(
#     "posebusters_benchmark_complex_pdb_deposition_dates.csv", index=False
# )

# posebusters_pre_cutoff_complexes = posebusters_complex_pdb_deposition_dates_df[
#     posebusters_complex_pdb_deposition_dates_df["Deposition Date"]
#     <= method_max_training_cutoff_date
# ]
# posebusters_post_cutoff_complexes = posebusters_complex_pdb_deposition_dates_df[
#     posebusters_complex_pdb_deposition_dates_df["Deposition Date"]
#     > method_max_training_cutoff_date
# ]
# print(
#     f"{len(posebusters_pre_cutoff_complexes)}/{len(posebusters_complex_pdb_deposition_dates_df)} PoseBusters Benchmark complexes deposited before maximum cutoff of {method_max_training_cutoff_date}."
# )
# print(
#     f"{len(posebusters_post_cutoff_complexes)}/{len(posebusters_complex_pdb_deposition_dates_df)} PoseBusters Benchmark complexes deposited after maximum cutoff of {method_max_training_cutoff_date}."
# )

# sns.histplot(posebusters_complex_pdb_deposition_dates_df["Deposition Date"].values, bins=25)
# plt.xlabel("Deposition Date")
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.savefig("posebusters_benchmark_complex_pdb_deposition_dates.png")
# plt.show()
# plt.close("all")

#### Identify failure modes across all datasets

In [None]:
# find PDB IDs of complexes for which the correct (e.g., RMSD ≤ 2 Å & PB-Valid) binding conformation was not found by any method
for dataset in datasets:
    docking_success_column = "RMSD ≤ 2 Å & PB-Valid"

    for repeat_index in range(1, max_num_repeats_per_method + 1):
        # NOTE: `.copy()` makes this per-dataset subset explicitly independent of the full results
        # table, so (over)writing its `pdb_id` column below does not raise a `SettingWithCopyWarning`
        dataset_results_table = globals()[f"results_table_{repeat_index}"].loc[
            globals()[f"results_table_{repeat_index}"].loc[:, "dataset"]
            == dataset_mapping[dataset]
        ].copy()

        if dataset == "casp15":
            # CASP15 rows are keyed by target name; map targets to PDB IDs (`?` = not publicly available)
            dataset_results_table.loc[:, "pdb_id"] = dataset_results_table.loc[:, "target"].map(
                casp15_target_pdb_id_mapping
            )

        # a complex counts as docked if at least one (method) row for it meets the success criterion
        globals()[f"{dataset}_complexes_docked_by_any_method_{repeat_index}"] = set(
            dataset_results_table.loc[
                (dataset_results_table.loc[:, docking_success_column]).astype(bool),
                "pdb_id",
            ].unique()
        )
        globals()[f"{dataset}_complexes_not_docked_by_any_method_{repeat_index}"] = set(
            dataset_results_table.loc[
                ~dataset_results_table.loc[:, "pdb_id"].isin(
                    globals()[f"{dataset}_complexes_docked_by_any_method_{repeat_index}"]
                ),
                "pdb_id",
            ].unique()
        )

#### Find commonalities among the failure modes of each dataset

In [None]:
# plot functional keyword statistics of the failed complexes across all datasets
for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
    all_failed_complexes_df = []
    for dataset in datasets:
        failed_complexes_df = pd.DataFrame(
            globals()[f"{dataset}_complexes_not_docked_by_any_method_{repeat_index}"],
            columns=["pdb_id"],
        )
        failed_complexes_df["dataset"] = dataset_mapping[dataset]
        all_failed_complexes_df.append(failed_complexes_df)
    all_failed_complexes_df = pd.concat(all_failed_complexes_df, ignore_index=True)

    if all_failed_complexes_df.empty:
        print("No failed complexes for any dataset.")
        continue

    # look up (and cache) the first listed PDB functional keyword of each failed complex,
    # skipping unreleased CASP15 targets (`?`) and complexes for which `pypdb` returns no info
    failed_complex_function_annotations = []
    for pdb_id in set(all_failed_complexes_df["pdb_id"]):
        pdb_id = pdb_id.lower().split("_")[0]
        if pdb_id == "?":
            continue
        if pdb_id in pdb_info_cache:
            pdb_id_info = pdb_info_cache[pdb_id]
        else:
            pdb_id_info = pypdb.get_all_info(pdb_id)
            pdb_info_cache[pdb_id] = pdb_id_info
        if not pdb_id_info:
            continue
        failed_complex_function_annotations.append(
            # NOTE: these represent functional keywords
            pdb_id_info["struct_keywords"]["pdbx_keywords"]
            .lower()
            .split(", ")[0]
        )

    # per keyword: (# failed complexes) / (# annotated complexes overall)
    failed_complex_function_annotation_counts = subset_counter(
        Counter(failed_complex_function_annotations),
        complex_function_annotation_counts,
        normalize_subset=True,
    )
    df = pd.DataFrame(
        failed_complex_function_annotation_counts.items(),
        columns=["Keyword", "Failed Ratio"],
    )
    df["Frequency"] = df["Keyword"].map(complex_function_annotation_counts)
    df.sort_values(
        by=["Failed Ratio", "Frequency"], ascending=False, inplace=True, ignore_index=True
    )

    if df.empty:
        print("No functional keyword annotations could be retrieved for any failed complex.")
        continue

    plt.figure(figsize=(10, 6))
    sns.barplot(data=df, x="Failed Ratio", y="Keyword", palette="viridis")

    plt.xlabel("Failed Ratio")
    plt.ylabel("Complex Annotation")

    plt.xlim(0, 1.1)

    # annotate bars with the frequency of each keyword
    # NOTE: `ignore_index=True` above makes each row's index equal to its bar's y position
    for index, row in df.iterrows():
        plt.text(
            x=row["Failed Ratio"] + 0.01,
            y=index,
            s=f"{row['Failed Ratio']:.2f} ({row['Frequency']})",
            va="center",
        )

    plt.tight_layout()
    # NOTE: `bbox_inches="tight"` (as for the all-complexes keyword plot) keeps bar annotations
    # that extend past the right edge of the axes inside the saved image
    plt.savefig(f"failed_complexes_functional_keywords_{repeat_index}.png", bbox_inches="tight")
    plt.show()

    plt.close("all")

    print(f"{len(failed_complex_function_annotations)} complex annotations across all datasets.")

#### Identify AlphaFold 3's failure modes

In [None]:
# find complexes that AlphaFold 3 failed to correctly predict
for dataset in datasets:
    docking_success_column = "RMSD ≤ 2 Å & PB-Valid"

    for repeat_index in range(1, max_num_repeats_per_method + 1):
        # NOTE: `method` holds display names at this point, so AlphaFold 3 rows are labeled "AF3";
        # `.copy()` makes this subset explicitly independent of the full results table, so
        # (over)writing its `pdb_id` column below does not raise a `SettingWithCopyWarning`
        dataset_results_table = globals()[f"results_table_{repeat_index}"].loc[
            (
                globals()[f"results_table_{repeat_index}"].loc[:, "dataset"]
                == dataset_mapping[dataset]
            )
            & (globals()[f"results_table_{repeat_index}"].loc[:, "method"] == "AF3")
        ].copy()

        if dataset == "casp15":
            # CASP15 rows are keyed by target name; map targets to PDB IDs (`?` = not publicly available)
            dataset_results_table.loc[:, "pdb_id"] = dataset_results_table.loc[:, "target"].map(
                casp15_target_pdb_id_mapping
            )

        globals()[f"{dataset}_complexes_docked_by_af3_{repeat_index}"] = set(
            dataset_results_table.loc[
                (dataset_results_table.loc[:, docking_success_column]).astype(bool),
                "pdb_id",
            ].unique()
        )
        globals()[f"{dataset}_complexes_not_docked_by_af3_{repeat_index}"] = set(
            dataset_results_table.loc[
                ~dataset_results_table.loc[:, "pdb_id"].isin(
                    globals()[f"{dataset}_complexes_docked_by_af3_{repeat_index}"]
                ),
                "pdb_id",
            ].unique()
        )

#### Record and plot AlphaFold 3's failure mode metadata

In [None]:
# plot functional keyword statistics of AlphaFold 3's failed complexes across all datasets
all_failed_af3_complexes_df = []
for dataset in datasets:
    for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
        failed_af3_complexes_df = pd.DataFrame(
            globals()[f"{dataset}_complexes_not_docked_by_af3_{repeat_index}"],
            columns=["pdb_id"],
        )
        failed_af3_complexes_df["dataset"] = dataset_mapping[dataset]
        failed_af3_complexes_df["repeat_index"] = repeat_index
        all_failed_af3_complexes_df.append(failed_af3_complexes_df)
all_failed_af3_complexes_df = pd.concat(all_failed_af3_complexes_df, ignore_index=True)

# collect the first (primary) functional keyword of each failed complex's PDB entry,
# fetching entry metadata via `pypdb` only when it is not already in `pdb_info_cache`
failed_af3_complex_function_annotations = []
for raw_pdb_id in set(all_failed_af3_complexes_df["pdb_id"]):
    # skip missing IDs (e.g., NaN produced by mapping a CASP15 target with no known PDB ID)
    if not isinstance(raw_pdb_id, str):
        continue
    # strip any `_`-suffix (e.g., a ligand code) to obtain the bare, lowercased PDB ID
    pdb_id = raw_pdb_id.lower().split("_")[0]
    if pdb_id == "?":
        continue
    if pdb_id not in pdb_info_cache:
        pdb_info_cache[pdb_id] = pypdb.get_all_info(pdb_id)
    pdb_id_info = pdb_info_cache[pdb_id]
    if not pdb_id_info:
        continue
    # NOTE: these represent functional keywords; entries without them are skipped rather
    # than aborting the whole plot with a `KeyError`
    pdbx_keywords = (pdb_id_info.get("struct_keywords") or {}).get("pdbx_keywords")
    if not pdbx_keywords:
        continue
    failed_af3_complex_function_annotations.append(pdbx_keywords.lower().split(", ")[0])

# ratio of failed complexes per keyword, normalized by each keyword's overall frequency
failed_af3_complex_function_annotation_counts = subset_counter(
    Counter(failed_af3_complex_function_annotations),
    complex_function_annotation_counts,
    normalize_subset=True,
)
df = pd.DataFrame(
    failed_af3_complex_function_annotation_counts.items(),
    columns=["Keyword", "Failed Ratio"],
)
df["Frequency"] = df["Keyword"].map(complex_function_annotation_counts)
# `ignore_index=True` makes row positions match the bar positions used for annotation below
df.sort_values(by=["Failed Ratio", "Frequency"], ascending=False, inplace=True, ignore_index=True)

plt.figure(figsize=(12, 8))
sns.barplot(data=df, x="Failed Ratio", y="Keyword", palette="viridis")

plt.xlabel("Failed Ratio")
plt.ylabel("Complex Annotation")

plt.xlim(0, 1.09)

# annotate bars with the frequency of each keyword
for index, row in df.iterrows():
    plt.text(
        x=row["Failed Ratio"] + 0.01,
        y=index,
        s=f"{row['Failed Ratio']:.2f} ({row['Frequency']})",
        va="center",
    )

plt.tight_layout()
plt.savefig("failed_af3_complexes_functional_keywords.png")
plt.show()

plt.close("all")

print(f"{len(failed_af3_complex_function_annotations)} complex annotations across all datasets.")

#### Study PDB statistics of different types of complexes

In [None]:
# combine all CSV files from each custom PDB report (one subdirectory of `pdb_reports` per report type)
pdb_reports_root = "pdb_reports"
# only consider subdirectories: stray files (e.g. `.DS_Store`) would otherwise make the
# `os.listdir` call below fail; sort for a reproducible processing order
report_types = sorted(
    entry
    for entry in os.listdir(pdb_reports_root)
    if os.path.isdir(os.path.join(pdb_reports_root, entry))
)

for report_type in report_types:
    pdb_report_dir = os.path.join(pdb_reports_root, report_type)
    pdb_report_files = sorted(
        os.path.join(pdb_report_dir, f) for f in os.listdir(pdb_report_dir) if f.endswith(".csv")
    )
    if not pdb_report_files:
        # `pd.concat` raises on an empty list, so skip report types with no CSV exports
        print(f"No CSV files found for the custom {report_type} report; skipping.")
        continue

    # NOTE: `skiprows=1` drops the first line of each CSV (presumably a title line above
    # the column header in the exported report -- confirm against an export)
    pdb_report_df = pd.concat(
        [pd.read_csv(pdb_report_file, skiprows=1) for pdb_report_file in pdb_report_files],
        ignore_index=True,
    )

    # analyze and plot statistics of the custom PDB report
    # commas are stripped before numeric parsing; values that still cannot be parsed become NaN
    pdb_report_df["Refinement Resolution (Å)"] = pd.to_numeric(
        pdb_report_df["Refinement Resolution (Å)"].astype(str).str.replace(",", ""),
        errors="coerce",
    )
    pdb_report_df["Deposition Date"] = pd.to_datetime(pdb_report_df["Deposition Date"])

    pdb_report_df.to_csv(f"{report_type}_pdb_report.csv", index=False)

    print(f"{len(pdb_report_df)} PDB entries in the custom {report_type} report.")

    sns.histplot(pdb_report_df["Refinement Resolution (Å)"].values)
    plt.xlim(0, 10)
    plt.xlabel("Refinement Resolution (Å)")
    plt.tight_layout()
    plt.savefig(f"{report_type}_pdb_report_resolution.png")
    plt.show()

    plt.close("all")

    sns.histplot(pdb_report_df["Deposition Date"].values)
    plt.xlabel("Deposition Date")
    plt.tight_layout()
    plt.savefig(f"{report_type}_pdb_report_deposition_date.png")
    plt.show()

    plt.close("all")

#### Study AlphaFold 3's relationship between training-test set sequence overlap and structure prediction performance

In [None]:
# find PoseBusters Benchmark set and CASP15 complexes that AlphaFold 3 failed to correctly predict
# NOTE: for each dataset and repeat, a complex counts as "docked" if at least one of its AF3
# result rows satisfies `docking_success_column`; every other complex evaluated for that
# dataset and repeat counts as "not docked"
docking_success_column = "RMSD ≤ 2 Å & PB-Valid"
for dataset in ["posebusters_benchmark", "casp15"]:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        repeat_results_table = globals()[f"results_table_{repeat_index}"]
        # `.copy()` so the CASP15 `pdb_id` assignment below writes to an independent frame
        # rather than to a slice of `repeat_results_table` (copy-on-write is disabled in this
        # notebook, so writing to the slice would raise `SettingWithCopyWarning`)
        dataset_results_table = repeat_results_table.loc[
            (repeat_results_table.loc[:, "dataset"] == dataset_mapping[dataset])
            & (repeat_results_table.loc[:, "method"] == "AF3")
        ].copy()

        if dataset == "casp15":
            # CASP15 rows are identified by target name, so translate them to PDB IDs
            dataset_results_table.loc[:, "pdb_id"] = dataset_results_table.loc[:, "target"].map(
                casp15_target_pdb_id_mapping
            )

        success_values = dataset_results_table.loc[:, docking_success_column]
        # NaN is truthy under `astype(bool)`; treat missing success values as failures
        # instead of silently counting them as successful predictions
        docking_succeeded = success_values.notna() & success_values.astype(bool)

        af3_docked_pdb_ids = set(
            dataset_results_table.loc[docking_succeeded, "pdb_id"].unique()
        )
        af3_not_docked_pdb_ids = set(
            dataset_results_table.loc[
                ~dataset_results_table.loc[:, "pdb_id"].isin(af3_docked_pdb_ids),
                "pdb_id",
            ].unique()
        )
        globals()[f"{dataset}_complexes_docked_by_af3_{repeat_index}"] = af3_docked_pdb_ids
        globals()[f"{dataset}_complexes_not_docked_by_af3_{repeat_index}"] = af3_not_docked_pdb_ids

#### Plot deposition dates of proteins most similar to AlphaFold 3's failed PoseBusters Benchmark set and CASP15 complexes

In [None]:
# plot max sequence overlap deposition dates of AlphaFold 3's failed PoseBusters Benchmark set and CASP15 complexes
af3_overlap_datasets = ["posebusters_benchmark", "casp15"]


def _get_cached_pdb_info(pdb_id):
    """Return `pypdb` entry metadata for `pdb_id`, memoized in the notebook-level `pdb_info_cache`.

    `pypdb.get_all_info` is only called on a cache miss. Its result is cached even when falsy
    (no metadata available), and callers skip falsy results.
    """
    if pdb_id not in pdb_info_cache:
        pdb_info_cache[pdb_id] = pypdb.get_all_info(pdb_id)
    return pdb_info_cache[pdb_id]


def _deposited_after_cutoff(deposition_date):
    """Return whether a deposition date string falls strictly after `method_max_training_cutoff_date`.

    Only the leading `YYYY-MM-DD` characters are compared. If the date carries a time suffix
    (presumably RCSB returns values like `2021-09-30T00:00:00+0000` -- confirm), then a plain
    string comparison would rank it above the bare cutoff `2021-09-30`. That would wrongly mark
    entries deposited on the cutoff day itself as post-cutoff. Truncation is a no-op for bare
    dates.
    """
    return deposition_date[:10] > method_max_training_cutoff_date


all_af3_failed_complexes_df = []
for dataset in af3_overlap_datasets:
    for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
        failed_complexes_df = pd.DataFrame(
            globals()[f"{dataset}_complexes_not_docked_by_af3_{repeat_index}"],
            columns=["pdb_id"],
        )
        failed_complexes_df["dataset"] = dataset_mapping[dataset]
        failed_complexes_df["repeat_index"] = repeat_index
        all_af3_failed_complexes_df.append(failed_complexes_df)
all_af3_failed_complexes_df = pd.concat(all_af3_failed_complexes_df, ignore_index=True)
all_af3_failed_complex_pdb_ids = list(all_af3_failed_complexes_df.loc[:, "pdb_id"].unique())

# prepare target and database FASTA files
target_fasta = "target.fasta"
database_fasta = os.path.join("..", "data", "pdb_data", "pdb_seqres.txt")

# parse PDB sequences
pdb_sequences = parse_fasta(
    database_fasta,
    only_mols=["protein"],
    collate_by_pdb_id=True,
)

result_file = "result.m8"
tmp_dir = "mmseqs_tmp"
try:
    # write one FASTA record per parsed protein entry of each failed complex
    # (presumably `seq[0]` is the chain ID and `seq[-1]` the sequence -- see `parse_fasta`)
    with open(target_fasta, "w") as target_f:
        for pdb_id in all_af3_failed_complex_pdb_ids:
            # skip missing IDs (e.g., NaN produced by mapping an unknown CASP15 target)
            if not isinstance(pdb_id, str):
                continue
            pdb_id_ = pdb_id.lower().split("_")[0]
            if pdb_id_ in pdb_sequences:
                for seq in pdb_sequences[pdb_id_]:
                    target_f.write(f">{pdb_id_}_{seq[0]}\n{seq[-1]}\n")

    # run MMseqs2 to find the best match for each PDB chain; `check=True` makes a failed
    # search raise immediately instead of continuing with missing (or stale) results
    os.makedirs(tmp_dir, exist_ok=True)
    subprocess.run(
        [
            "mmseqs",
            "easy-search",
            target_fasta,
            database_fasta,
            result_file,
            tmp_dir,
            "--format-output",
            "query,target,pident,qcov,tcov,evalue,bits",
        ],
        check=True,
    )

    if not os.path.exists(result_file):
        raise ValueError(
            "No results found. Ensure MMseqs2 is correctly installed and the input sequences are valid."
        )

    # parse MMseqs2 top match for each PDB chain: keep the first listed hit per query
    # that comes from a different PDB entry (i.e., ignore self-matches)
    query_top_match_pdb_id_mappings = dict()
    with open(result_file, "r") as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) < 2:  # tolerate blank/truncated lines
                continue
            query_pdb_id, top_match_pdb_id = fields[0], fields[1]
            if (
                query_pdb_id not in query_top_match_pdb_id_mappings
                and query_pdb_id.split("_")[0] != top_match_pdb_id.split("_")[0]
            ):
                query_top_match_pdb_id_mappings[query_pdb_id] = top_match_pdb_id
finally:
    # remove intermediate MMseqs2 files even if the search or result parsing fails
    for intermediate_path in (target_fasta, result_file):
        if os.path.exists(intermediate_path):
            os.remove(intermediate_path)
    shutil.rmtree(tmp_dir, ignore_errors=True)

# from here on, "failed complex IDs" are the chain-level query IDs (`<pdb_id>_<chain>`)
all_af3_failed_complex_pdb_ids = list(query_top_match_pdb_id_mappings.keys())
top_match_pdb_ids = list(query_top_match_pdb_id_mappings.values())

# find the deposition dates of failed complexes and their top matches; each failed chain and
# its top match are recorded together (or not at all), so the parallel lists below stay
# aligned pair-by-pair for the `zip` in the following cell
failed_complexes = []
failed_complex_indices = set()
failed_complexes_after_cutoff_deposition_date = []
failed_complexes_after_cutoff_deposition_date_indices = set()
top_match_complexes = []
top_match_complexes_after_cutoff_deposition_date = []
for pair_index, (failed_complex_pdb_id, top_match_pdb_id) in enumerate(
    zip(all_af3_failed_complex_pdb_ids, top_match_pdb_ids)
):
    failed_complex_pdb_id_ = failed_complex_pdb_id.lower().split("_")[0]
    top_match_pdb_id_ = top_match_pdb_id.lower().split("_")[0]
    if "?" in (failed_complex_pdb_id_, top_match_pdb_id_):
        continue
    failed_complex_pdb_info = _get_cached_pdb_info(failed_complex_pdb_id_)
    if not failed_complex_pdb_info:
        continue
    top_match_pdb_info = _get_cached_pdb_info(top_match_pdb_id_)
    if not top_match_pdb_info:
        continue

    failed_deposition_date = failed_complex_pdb_info["rcsb_accession_info"]["deposit_date"]
    top_match_deposition_date = top_match_pdb_info["rcsb_accession_info"]["deposit_date"]

    failed_complex_indices.add(pair_index)
    failed_complexes.append((failed_complex_pdb_id_, failed_deposition_date))
    top_match_complexes.append((top_match_pdb_id_, top_match_deposition_date))
    if _deposited_after_cutoff(failed_deposition_date):
        failed_complexes_after_cutoff_deposition_date_indices.add(pair_index)
        failed_complexes_after_cutoff_deposition_date.append(
            (failed_complex_pdb_id_, failed_deposition_date)
        )
        top_match_complexes_after_cutoff_deposition_date.append(
            (top_match_pdb_id_, top_match_deposition_date)
        )

# analyze and plot statistics of the failed complexes and their top matches' deposition dates
failed_complex_pdb_deposition_dates_df = pd.DataFrame(
    {
        "Deposition Date": [com[1] for com in list(dict.fromkeys(failed_complexes))]
    }  # remove chain duplicates in-order
)
failed_complex_pdb_deposition_dates_df["Deposition Date"] = pd.to_datetime(
    failed_complex_pdb_deposition_dates_df["Deposition Date"]
)
failed_complex_pdb_deposition_dates_df.to_csv(
    "af3_failed_complex_pdb_deposition_dates.csv", index=False
)
sns.histplot(failed_complex_pdb_deposition_dates_df["Deposition Date"].values, bins=25)
plt.xlabel("Deposition Date")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig("af3_failed_complex_pdb_deposition_dates.png")
plt.show()
plt.close("all")

print(
    f"{len(failed_complex_pdb_deposition_dates_df)} date annotations across {af3_overlap_datasets}."
)

top_match_pdb_deposition_dates_df = pd.DataFrame(
    {
        "Deposition Date": [com[1] for com in list(dict.fromkeys(top_match_complexes))]
    }  # remove chain duplicates in-order
)
top_match_pdb_deposition_dates_df["Deposition Date"] = pd.to_datetime(
    top_match_pdb_deposition_dates_df["Deposition Date"]
)
top_match_pdb_deposition_dates_df.to_csv(
    "af3_failed_complex_top_match_pdb_deposition_dates.csv", index=False
)
sns.histplot(top_match_pdb_deposition_dates_df["Deposition Date"].values, bins=30)
plt.xlabel("Deposition Date")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig("af3_failed_complex_top_match_pdb_deposition_dates.png")
plt.show()
plt.close("all")

print(
    f"{len(top_match_pdb_deposition_dates_df)} top match date annotations across {af3_overlap_datasets}."
)

#### Identify specific complexes AlphaFold 3 failed to predict that are worth studying further

In [None]:
# filter for failed complexes and their top matches that were both deposited after the cutoff deposition date


def _deposited_after_cutoff(deposition_date):
    """Return whether a deposition date string falls strictly after `method_max_training_cutoff_date`.

    Only the leading `YYYY-MM-DD` characters are compared. If the date carries a time suffix
    (presumably RCSB returns values like `2021-09-30T00:00:00+0000` -- confirm), then a plain
    string comparison would rank it above the bare cutoff `2021-09-30`. That would wrongly keep
    entries deposited on the cutoff day itself. Truncation is a no-op for bare dates.
    """
    return deposition_date[:10] > method_max_training_cutoff_date


# group chain-level ((failed complex, deposition date), (top match, deposition date)) pairs
# by the failed complex's PDB ID
collated_complexes = defaultdict(list)
for failed_complex, top_match_complex in zip(
    failed_complexes_after_cutoff_deposition_date, top_match_complexes_after_cutoff_deposition_date
):
    collated_complexes[failed_complex[0]].append((failed_complex, top_match_complex))

failed_complexes_to_study_further = []
failed_complex_pdb_ids_to_study_further = []
failed_complex_function_annotations_to_study_further = []
for complex_pdb_id, collated_complex_chains in collated_complexes.items():
    # keep a complex only if, for every one of its chains, both the complex and the
    # chain's top match were deposited after the cutoff
    if not all(
        _deposited_after_cutoff(complex_deposition_date)
        and _deposited_after_cutoff(top_match_complex_deposition_date)
        for (_, complex_deposition_date), (_, top_match_complex_deposition_date) in collated_complex_chains
    ):
        continue
    complex_pdb_id_ = complex_pdb_id.lower().split("_")[0]
    if complex_pdb_id_ == "?":
        continue
    if complex_pdb_id_ not in pdb_info_cache:
        pdb_info_cache[complex_pdb_id_] = pypdb.get_all_info(complex_pdb_id_)
    complex_pdb_info = pdb_info_cache[complex_pdb_id_]
    if not complex_pdb_info:
        continue
    # NOTE: these represent functional keywords; entries without them are skipped rather
    # than aborting the cell with a `KeyError`
    pdbx_keywords = (complex_pdb_info.get("struct_keywords") or {}).get("pdbx_keywords")
    if not pdbx_keywords:
        continue
    complex_function_annotation = pdbx_keywords.lower().split(", ")[0]
    failed_complexes_to_study_further.append(collated_complex_chains)
    failed_complex_pdb_ids_to_study_further.append(complex_pdb_id)
    failed_complex_function_annotations_to_study_further.append(complex_function_annotation)
    print(f"{complex_pdb_id}, {complex_function_annotation} : {collated_complex_chains}")

# report where each selected complex comes from: a PoseBusters Benchmark structure file if
# one exists on disk, otherwise the CASP15 target(s) mapped to its PDB ID
for complex_index, (complex_pdb_id, complex_function_annotation) in enumerate(
    zip(
        failed_complex_pdb_ids_to_study_further,
        failed_complex_function_annotations_to_study_further,
    )
):
    if complex_index == 0:
        print()
    # sort so the reported file is deterministic when several PDB files match
    pb_matches = sorted(
        glob.glob(
            os.path.join(
                "..", "data", "posebusters_benchmark_set", f"{complex_pdb_id.upper()}_*", "*.pdb"
            )
        )
    )
    if pb_matches:
        pb_path = pb_matches[0]
        print(
            f"Posebusters Benchmark target: {complex_pdb_id}, {complex_function_annotation} -> {pb_path}"
        )
    else:
        # `complex_pdb_id` is lowercased, so normalize the mapped IDs the same way
        # (lowercase, any `_`-suffix dropped) before comparing
        casp_matches = [
            k
            for k, v in casp15_target_pdb_id_mapping.items()
            if isinstance(v, str) and v.lower().split("_")[0] == complex_pdb_id
        ]
        print(f"CASP15 target: {complex_pdb_id}, {complex_function_annotation} -> {casp_matches}")

print(
    f"{len(failed_complexes_to_study_further)} ((failed complex, deposition date), (top match, deposition date)) novel (multi-chain) protein-ligand PDB complexes AlphaFold 3 failed to predict that are worth studying further."
)