## Failure Modes Analysis Plotting

#### Import packages

In [None]:
import json
import os
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
import pypdb
import seaborn as sns
from beartype.typing import Any, Dict, List, Optional, Tuple
from pdbeccdutils.core import ccd_reader
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.DataStructs import TanimotoSimilarity
from tqdm import tqdm

from posebench.analysis.inference_analysis import BUST_TEST_COLUMNS

#### Configure packages

In [None]:
# enable pandas Copy-on-Write mode for all subsequent cells
pd.set_option("mode.copy_on_write", True)

#### Define constants

In [None]:
# General variables
baseline_methods = [
    "vina_p2rank",
    "diffdock",
    "dynamicbind",
    "neuralplexer",
    "rfaa",
    "chai-lab_ss",
    "chai-lab",
    "alphafold3_ss",
    "alphafold3",
]
max_num_repeats_per_method = 3
method_max_training_cutoff_date = "2021-09-30"

datasets = ["astex_diverse", "posebusters_benchmark", "dockgen"]

# Filepaths for each baseline method
# inference-output root of each method's fork, published as a module-level global named by the key
_output_dir_parts = {
    "vina_output_dir": ("Vina", "inference"),
    "diffdock_output_dir": ("DiffDock", "inference"),
    "dynamicbind_output_dir": ("DynamicBind", "inference", "outputs", "results"),
    "neuralplexer_output_dir": ("NeuralPLexer", "inference"),
    "rfaa_output_dir": ("RoseTTAFold-All-Atom", "inference"),
    "chai-lab_output_dir": ("chai-lab", "inference"),
    "alphafold3_output_dir": ("alphafold3", "inference"),
}
for _dir_name, _path_parts in _output_dir_parts.items():
    globals()[_dir_name] = os.path.join("..", "forks", *_path_parts)

# per method: (output-root global, results-subdirectory template); relaxed
# (post-processed) results live in the same subdirectory name suffixed with "_relaxed"
_results_subdir_templates = {
    "vina_p2rank": ("vina_output_dir", "vina_p2rank_{dataset}_outputs_{repeat_index}"),  # P2Rank-Vina
    "diffdock": ("diffdock_output_dir", "diffdock_{dataset}_output_{repeat_index}"),  # DiffDock
    "dynamicbind": ("dynamicbind_output_dir", "{dataset}_{repeat_index}"),  # DynamicBind
    "neuralplexer": (  # NeuralPLexer
        "neuralplexer_output_dir",
        "neuralplexer_{dataset}_outputs_{repeat_index}",
    ),
    "rfaa": ("rfaa_output_dir", "rfaa_{dataset}_outputs_{repeat_index}"),  # RoseTTAFold-All-Atom
    "chai-lab_ss": (  # Chai-1 (Single-Seq)
        "chai-lab_output_dir",
        "chai-lab_ss_{dataset}_outputs_{repeat_index}",
    ),
    "chai-lab": ("chai-lab_output_dir", "chai-lab_{dataset}_outputs_{repeat_index}"),  # Chai-1
    "alphafold3_ss": (  # AlphaFold 3 (Single-Seq)
        "alphafold3_output_dir",
        "alphafold3_ss_{dataset}_outputs_{repeat_index}",
    ),
    "alphafold3": (  # AlphaFold 3
        "alphafold3_output_dir",
        "alphafold3_{dataset}_outputs_{repeat_index}",
    ),
}

# publish one global per (method, dataset, config, repeat) pointing at its `bust_results.csv`
for dataset in datasets:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        for _method, (_dir_key, _subdir_template) in _results_subdir_templates.items():
            _subdir = _subdir_template.format(dataset=dataset, repeat_index=repeat_index)
            for _suffix in ("", "_relaxed"):
                globals()[
                    f"{_method}_{dataset}{_suffix}_bust_results_csv_filepath_{repeat_index}"
                ] = os.path.join(globals()[_dir_key], _subdir + _suffix, "bust_results.csv")

# Mappings
# display name of each method; the key order also determines each method's plotting index
# (`assign_method_index` returns a method's position among these keys), so keep it stable
method_mapping = {
    "vina_p2rank": "P2Rank-Vina",
    "diffdock": "DiffDock-L",
    "dynamicbind": "DynamicBind",
    "neuralplexer": "NeuralPLexer",
    "rfaa": "RoseTTAFold-AA",
    "chai-lab_ss": "Chai-1-Single-Seq",
    "chai-lab": "Chai-1",
    "alphafold3_ss": "AF3-Single-Seq",
    "alphafold3": "AF3",
}

# plotting category of each method (`categorize_method` falls back to "Misc" for unlisted keys)
method_category_mapping = {
    "vina_p2rank": "Conventional blind",
    "diffdock": "DL-based blind",
    "dynamicbind": "DL-based blind",
    "neuralplexer": "DL-based blind",
    "rfaa": "DL-based blind",
    "chai-lab_ss": "DL-based blind",
    "chai-lab": "DL-based blind",
    "alphafold3_ss": "DL-based blind",
    "alphafold3": "DL-based blind",
}

# display name of each dataset (CASP15 is mapped here even though it is not listed in `datasets`)
dataset_mapping = {
    "astex_diverse": "Astex Diverse set",
    "posebusters_benchmark": "Posebusters Benchmark set",
    "dockgen": "DockGen set",
    "casp15": "CASP15 set",
}

# display names of the methods that must ALL have results for a ligand before the
# "not docked by any method" analysis reports that ligand as a failure case
ligand_prediction_methods = set(method_mapping.values()) - {
    # NOTE: we exclude Vina in this analysis since it often is missing predictions due to timeouts
    "P2Rank-Vina",
}

#### Load test results for each baseline method

In [None]:
# load and report test results for each baseline method
for config in [""]:
    for dataset in datasets:
        for method in baseline_methods:
            for repeat_index in range(1, max_num_repeats_per_method + 1):
                bust_results_csv_filepath = globals()[
                    f"{method}_{dataset}{config}_bust_results_csv_filepath_{repeat_index}"
                ]
                # skip (method, dataset, repeat) combinations whose results are not available
                if not os.path.exists(bust_results_csv_filepath):
                    continue

                bust_results = pd.read_csv(bust_results_csv_filepath)
                globals()[f"{method}_{dataset}{config}_bust_results_{repeat_index}"] = bust_results

                bust_results_table = bust_results[
                    BUST_TEST_COLUMNS + ["rmsd", "centroid_distance", "inchi_crystal"]
                ]
                # a pose counts as PB-valid when all selected test columns are truthy, i.e.,
                # every `BUST_TEST_COLUMNS` entry except the first one (the trailing `rmsd`,
                # `centroid_distance`, and `inchi_crystal` columns are excluded)
                bust_results_table.loc[:, "pb_valid"] = bust_results_table.iloc[:, 1:-3].all(
                    axis=1
                )
                # use `<=` so that the flag agrees with the "≤ 1 Å" in its column name
                bust_results_table.loc[:, "crmsd_≤_1å"] = (
                    bust_results_table["centroid_distance"] <= 1
                )

                # metadata used to group and label results downstream
                bust_results_table.loc[:, "method"] = method
                bust_results_table.loc[:, "post-processing"] = (
                    "energy minimization" if config == "_relaxed" else "none"
                )
                bust_results_table.loc[:, "dataset"] = dataset
                bust_results_table.loc[:, "docked_ligand_successfully_loaded"] = (
                    bust_results_table[
                        ["mol_pred_loaded", "mol_true_loaded", "mol_cond_loaded"]
                    ].all(axis=1)
                )

                globals()[
                    f"{method}_{dataset}{config}_bust_results_table_{repeat_index}"
                ] = bust_results_table

#### Define helper functions

In [None]:
def assign_method_index(method: str) -> int:
    """
    Assign method index for plotting.

    :param method: Method name (a key of ``method_mapping``).
    :return: Zero-based position of ``method`` among the keys of ``method_mapping``.
    :raises ValueError: If ``method`` is not a key of ``method_mapping``.
    """
    return list(method_mapping).index(method)


def categorize_method(method: str) -> str:
    """
    Look up the plotting category of a method.

    :param method: Method name (a key of ``method_category_mapping``).
    :return: The mapped category, or ``"Misc"`` when the method is not listed.
    """
    if method in method_category_mapping:
        return method_category_mapping[method]
    return "Misc"


def find_closest_inchi(
    target_inchi: str, inchi_list: List[str], candidate_fp_cache: Optional[Dict[str, Any]] = None
) -> Tuple[Optional[str], float]:
    """
    Find the closest InChI string to the target InChI string from a list of InChI strings.

    Similarity is the Tanimoto similarity between 2048-bit Morgan fingerprints of radius 2.
    The search stops at the first candidate with a similarity of exactly 1.0.

    :param target_inchi: Target InChI string.
    :param inchi_list: List of InChI strings.
    :param candidate_fp_cache: Optional cache mapping candidate InChI strings to their
        fingerprints (or ``None`` for candidates RDKit could not parse). It is updated in
        place so that repeated searches can reuse it; a fresh cache is used when ``None``.
    :return: Closest InChI string and its Tanimoto similarity. The InChI string is ``None``
        (with similarity ``0.0``) when the target cannot be parsed or no candidate has a
        positive similarity.
    """
    if candidate_fp_cache is None:
        candidate_fp_cache = {}

    target_mol = Chem.MolFromInchi(target_inchi)
    if target_mol is None:
        # an unparsable target cannot be fingerprinted, so nothing can match it
        return None, 0.0
    target_fp = AllChem.GetMorganFingerprintAsBitVect(target_mol, radius=2, nBits=2048)

    best_match = None
    highest_similarity = 0.0

    for candidate_inchi in tqdm(inchi_list, desc="Finding closest InChI"):
        if candidate_inchi in candidate_fp_cache:
            candidate_fp = candidate_fp_cache[candidate_inchi]
        else:
            candidate_mol = Chem.MolFromInchi(candidate_inchi)
            candidate_fp = (
                None
                if candidate_mol is None
                else AllChem.GetMorganFingerprintAsBitVect(candidate_mol, radius=2, nBits=2048)
            )
            # cache unparsable candidates too, so they are not re-parsed on later searches
            candidate_fp_cache[candidate_inchi] = candidate_fp
        if candidate_fp is None:
            continue

        similarity = TanimotoSimilarity(target_fp, candidate_fp)

        if similarity == 1.0:
            # a perfect match cannot be beaten, so stop searching
            return candidate_inchi, similarity
        elif similarity > highest_similarity:
            highest_similarity = similarity
            best_match = candidate_inchi

    return best_match, highest_similarity

#### Standardize metrics

In [None]:
# load and organize the results CSVs
for repeat_index in range(1, max_num_repeats_per_method + 1):
    # names of every per-(dataset, method, config) table that could exist for this repeat
    candidate_table_names = [
        f"{method}_{dataset}{config}_bust_results_table_{repeat_index}"
        for dataset in datasets
        for method in baseline_methods
        for config in [""]
    ]
    # stack only the tables that were actually loaded into one table for this repeat
    results_table = pd.concat(
        [globals()[table_name] for table_name in candidate_table_names if table_name in globals()]
    )
    globals()[f"results_table_{repeat_index}"] = results_table

    # derived from the raw method keys, so this must happen before `method` is renamed below
    results_table.loc[:, "method_category"] = results_table["method"].apply(categorize_method)
    results_table.loc[:, "method_assignment_index"] = results_table["method"].apply(
        assign_method_index
    )

    # missing threshold flags count as failures
    results_table.loc[:, "crmsd_within_threshold"] = results_table.loc[:, "crmsd_≤_1å"].fillna(
        False
    )
    results_table.loc[:, "rmsd_within_threshold"] = results_table.loc[:, "rmsd_≤_2å"].fillna(False)
    rmsd_and_pb_valid = results_table.loc[:, "rmsd_within_threshold"] & (
        results_table.loc[:, "pb_valid"].fillna(False)
    )
    results_table.loc[:, "rmsd_within_threshold_and_pb_valid"] = rmsd_and_pb_valid

    # 0/1 success indicators under display-friendly column names
    results_table.loc[:, "RMSD ≤ 2 Å & PB-Valid"] = results_table.loc[
        :, "rmsd_within_threshold_and_pb_valid"
    ].astype(int)
    for display_column, flag_column in (
        ("cRMSD ≤ 1 Å", "crmsd_within_threshold"),
        ("RMSD ≤ 2 Å", "rmsd_within_threshold"),
    ):
        results_table.loc[:, display_column] = (
            results_table.loc[:, flag_column].fillna(False).astype(int)
        )

    # swap raw dataset/method keys for their display names
    results_table.loc[:, "dataset"] = results_table.loc[:, "dataset"].map(dataset_mapping)
    results_table.loc[:, "method"] = results_table.loc[:, "method"].map(method_mapping)

#### Load Ligand Expo data

In [None]:
CCD_COMPONENTS_FILEPATH = os.path.join("..", "data", "ccd_data", "components.cif")
CCD_COMPONENTS_INCHI_FILEPATH = os.path.join("..", "data", "ccd_data", "components_inchi.json")

# load all InChI strings in the PDB Chemical Component Dictionary (CCD), mapping each InChI
# string to metadata of its CCD component; the mapping is cached as JSON after the first run

CCD_COMPONENTS_INCHI = None

if os.path.exists(CCD_COMPONENTS_INCHI_FILEPATH):
    print(f"Loading CCD component InChI strings from {CCD_COMPONENTS_INCHI_FILEPATH}.")
    with open(CCD_COMPONENTS_INCHI_FILEPATH) as f:
        CCD_COMPONENTS_INCHI = json.load(f)
elif os.path.exists(CCD_COMPONENTS_FILEPATH):
    print(
        f"Loading CCD components from {CCD_COMPONENTS_FILEPATH} to extract all available InChI strings (~3 minutes, one-time only)."
    )
    CCD_COMPONENTS = ccd_reader.read_pdb_components_file(
        CCD_COMPONENTS_FILEPATH,
        sanitize=False,  # Reduce loading time
    )
    # NOTE: components sharing an InChI string overwrite each other (the last one wins)
    CCD_COMPONENTS_INCHI = {
        CCD_COMPONENTS[ccd_code].component.inchi: {
            "atoms_ids": CCD_COMPONENTS[ccd_code].component.atoms_ids,
            "formula": CCD_COMPONENTS[ccd_code].component.formula,
            "id": CCD_COMPONENTS[ccd_code].component.id,
            "inchikey": CCD_COMPONENTS[ccd_code].component.inchikey,
            "modified_date": str(CCD_COMPONENTS[ccd_code].component.modified_date),
            "name": CCD_COMPONENTS[ccd_code].component.name,
            "number_atoms": CCD_COMPONENTS[ccd_code].component.number_atoms,
            "pdb_id": CCD_COMPONENTS[ccd_code].component.ccd_cif_block.find(
                "_chem_comp.", ["pdbx_model_coordinates_db_code"]
            )[0][0],
            "released": CCD_COMPONENTS[ccd_code].component.released,
            "smiles": Chem.MolToSmiles(CCD_COMPONENTS[ccd_code].component.mol_no_h),
            "weight": CCD_COMPONENTS[ccd_code].component._cif_properties.weight,
        }
        for ccd_code in CCD_COMPONENTS
    }
    # build and serialize the whole mapping BEFORE opening the cache file, so that an error
    # during extraction/serialization cannot leave an empty or truncated JSON cache behind
    # (which later runs would try to load instead of rebuilding it)
    ccd_components_inchi_json = json.dumps(CCD_COMPONENTS_INCHI)
    print(
        f"Saving CCD component InChI strings to {CCD_COMPONENTS_INCHI_FILEPATH} (one-time only)."
    )
    with open(CCD_COMPONENTS_INCHI_FILEPATH, "w") as f:
        f.write(ccd_components_inchi_json)
else:
    # fail early with a clear message instead of a TypeError on `None` in later cells
    raise FileNotFoundError(
        f"Neither {CCD_COMPONENTS_INCHI_FILEPATH} nor {CCD_COMPONENTS_FILEPATH} exists; "
        "the PDB Chemical Component Dictionary is required for the failure mode analysis."
    )

#### Identify failure modes

In [None]:
# find InChI strings of ligands for which the correct (e.g., RMSD ≤ 2 Å & PB-Valid) binding conformation was not found by any method
ccd_components_inchi_fp_cache = dict()
ccd_components_inchi_keys = list(CCD_COMPONENTS_INCHI)

for dataset in datasets:
    # NOTE: for DockGen, we consider centroid RMSD (cRMSD) ≤ 1 Å as a surrogate docking success criterion
    docking_success_column = "cRMSD ≤ 1 Å" if dataset == "dockgen" else "RMSD ≤ 2 Å & PB-Valid"

    for repeat_index in range(1, max_num_repeats_per_method + 1):
        dataset_results_table = globals()[f"results_table_{repeat_index}"].loc[
            globals()[f"results_table_{repeat_index}"].loc[:, "dataset"]
            == dataset_mapping[dataset]
        ]

        globals()[f"ligands_docked_by_any_method_{repeat_index}"] = set(
            dataset_results_table.loc[
                (dataset_results_table.loc[:, docking_success_column]).astype(bool),
                "inchi_crystal",
            ].unique()
        )
        globals()[f"ligands_not_docked_by_any_method_{repeat_index}"] = set(
            dataset_results_table.loc[
                ~dataset_results_table.loc[:, "inchi_crystal"].isin(
                    globals()[f"ligands_docked_by_any_method_{repeat_index}"]
                ),
                "inchi_crystal",
            ].unique()
        )

        # collect metadata of ligands for which the correct (e.g., RMSD ≤ 2 Å & PB-Valid) binding conformation was not found by any method
        globals()[f"{dataset}_failed_ligands_{repeat_index}"] = []
        for ligand in tqdm(
            globals()[f"ligands_not_docked_by_any_method_{repeat_index}"],
            desc=f"Processing {dataset} failed ligands",
        ):
            ligand_results = dataset_results_table.loc[
                dataset_results_table.loc[:, "inchi_crystal"] == ligand
            ]
            ligand_result_methods = set(ligand_results.loc[:, "method"].unique())
            # only report ligands for which every method in `ligand_prediction_methods` has results
            if ligand_result_methods >= ligand_prediction_methods:
                row = ligand_results.iloc[0]
                row_inchi = CCD_COMPONENTS_INCHI.get(row.inchi_crystal)
                if not row_inchi:
                    # (slowly) find the closest matching CCD component InChI string if necessary (e.g., for DockGen)
                    closest_inchi, closest_inchi_similarity = find_closest_inchi(
                        row.inchi_crystal,
                        ccd_components_inchi_keys,
                        candidate_fp_cache=ccd_components_inchi_fp_cache,
                    )
                    if closest_inchi is None:
                        # no CCD component had a positive similarity to this ligand
                        print(f"No similar CCD component InChI string found for: {ligand}")
                        continue
                    row_inchi = CCD_COMPONENTS_INCHI[closest_inchi]
                    print(
                        f"Found closest CCD component InChI string (similarity={closest_inchi_similarity}) for: {ligand}"
                    )
                if row_inchi:
                    # if row_inchi and row_inchi["modified_date"] > method_max_training_cutoff_date:
                    print(f"Failed {dataset_mapping[dataset]} docking case for: {ligand}")
                    print(f"CCD component InChI metadata: {row_inchi}\n")
                    globals()[f"{dataset}_failed_ligands_{repeat_index}"].append(row_inchi)

#### Record and plot failure mode metadata

In [None]:
# plot statistics of the failed ligands
# NOTE(review): this disables pandas Copy-on-Write for every subsequent cell as well;
# the reason is not evident from this notebook — confirm it is still needed
pd.options.mode.copy_on_write = False

for dataset in datasets:
    for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
        failed_ligands_df = pd.DataFrame(globals()[f"{dataset}_failed_ligands_{repeat_index}"])
        if failed_ligands_df.empty:
            print(f"No failed ligands for {dataset}.")
            continue

        failed_ligands_df["weight"] = failed_ligands_df["weight"].astype(float)
        failed_ligands_df["number_atoms"] = failed_ligands_df["number_atoms"].astype(int)

        failed_ligands_df.to_csv(f"{dataset}_failed_ligands.csv", index=False)

        # draw each histogram on its own figure; otherwise the weight histogram would be
        # overlaid on the atom-count histogram's axes and both would end up in one image
        plt.figure()
        sns.histplot(failed_ligands_df["number_atoms"].values)
        plt.xlabel("Number of atoms")
        plt.savefig(f"{dataset}_failed_ligands_number_atoms.png")
        plt.close()

        plt.figure()
        sns.histplot(failed_ligands_df["weight"].values)
        plt.xlabel("Weight")
        plt.savefig(f"{dataset}_failed_ligands_weight.png")

        plt.close("all")

#### Find commonalities among the failure modes of each dataset

In [None]:
# plot functional keyword statistics of the failed ligands across all datasets
for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
    all_failed_ligands_df = []
    for dataset in datasets:
        failed_ligands_df = pd.DataFrame(globals()[f"{dataset}_failed_ligands_{repeat_index}"])
        failed_ligands_df["dataset"] = dataset_mapping[dataset]
        all_failed_ligands_df.append(failed_ligands_df)
    all_failed_ligands_df = pd.concat(all_failed_ligands_df, ignore_index=True)

    if all_failed_ligands_df.empty:
        print("No failed ligands for any dataset.")
        continue

    failed_ligand_function_annotations = []
    for pdb_id in all_failed_ligands_df["pdb_id"]:
        # skip CCD components whose `pdbx_model_coordinates_db_code` is the unknown marker "?"
        if pdb_id == "?":
            continue
        pdb_id_info = pypdb.get_all_info(pdb_id)
        if not pdb_id_info:
            # the entry's metadata could not be retrieved
            continue
        pdbx_keywords = (pdb_id_info.get("struct_keywords") or {}).get("pdbx_keywords")
        if not pdbx_keywords:
            # entries without functional keywords cannot be categorized
            continue
        # NOTE: these represent functional keywords; keep only the first listed keyword
        failed_ligand_function_annotations.append(pdbx_keywords.lower().split(", ")[0])

    failed_ligand_function_annotation_counts = Counter(failed_ligand_function_annotations)
    df = pd.DataFrame(
        failed_ligand_function_annotation_counts.items(),
        columns=["Keyword", "Frequency"],
    )
    df["Frequency"] = df["Frequency"].astype(int)
    df = df.sort_values(by="Frequency", ascending=False)

    if df.empty:
        # without any keyword there is nothing to plot (and `max()` below would be NaN)
        print("No functional keywords found for the failed ligands.")
        continue

    # plot with Seaborn
    plt.figure(figsize=(10, 6))
    sns.barplot(data=df, x="Frequency", y="Keyword", palette="viridis")

    # integer-only x ticks, since the x axis counts occurrences
    max_freq = df["Frequency"].max()
    plt.xticks(ticks=range(0, max_freq + 1), labels=range(0, max_freq + 1))

    plt.xlabel("Frequency")
    plt.ylabel("Protein Function")

    plt.tight_layout()
    plt.savefig(f"failed_ligands_functional_keywords_{repeat_index}.png")
    plt.show()

    plt.close("all")

#### Identify AlphaFold 3's failure modes

In [None]:
# find ligands that AlphaFold 3 failed to correctly predict
ccd_components_inchi_fp_cache = dict()
ccd_components_inchi_keys = list(CCD_COMPONENTS_INCHI)

for dataset in datasets:
    # NOTE: for DockGen, we consider centroid RMSD (cRMSD) ≤ 1 Å as a surrogate docking success criterion
    docking_success_column = "cRMSD ≤ 1 Å" if dataset == "dockgen" else "RMSD ≤ 2 Å & PB-Valid"

    for repeat_index in range(1, max_num_repeats_per_method + 1):
        dataset_results_table = globals()[f"results_table_{repeat_index}"].loc[
            (
                globals()[f"results_table_{repeat_index}"].loc[:, "dataset"]
                == dataset_mapping[dataset]
            )
            & (globals()[f"results_table_{repeat_index}"].loc[:, "method"] == "AF3")
        ]

        globals()[f"ligands_docked_by_af3_{repeat_index}"] = set(
            dataset_results_table.loc[
                (dataset_results_table.loc[:, docking_success_column]).astype(bool),
                "inchi_crystal",
            ].unique()
        )
        globals()[f"ligands_not_docked_by_af3_{repeat_index}"] = set(
            dataset_results_table.loc[
                ~dataset_results_table.loc[:, "inchi_crystal"].isin(
                    globals()[f"ligands_docked_by_af3_{repeat_index}"]
                ),
                "inchi_crystal",
            ].unique()
        )

        # collect metadata of ligands for which AlphaFold 3 did not find the correct (e.g., RMSD ≤ 2 Å & PB-Valid) binding conformation
        globals()[f"{dataset}_af3_failed_ligands_{repeat_index}"] = []
        for ligand in tqdm(
            globals()[f"ligands_not_docked_by_af3_{repeat_index}"],
            desc=f"Processing {dataset} AF3 failed ligands",
        ):
            ligand_results = dataset_results_table.loc[
                dataset_results_table.loc[:, "inchi_crystal"] == ligand
            ]
            if ligand_results.empty:
                # e.g., a missing (NaN) crystal InChI never compares equal to itself, so it
                # matches no rows and `iloc[0]` below would raise an IndexError
                continue
            row = ligand_results.iloc[0]
            row_inchi = CCD_COMPONENTS_INCHI.get(row.inchi_crystal)
            if not row_inchi:
                # (slowly) find the closest matching CCD component InChI string if necessary (e.g., for DockGen)
                closest_inchi, closest_inchi_similarity = find_closest_inchi(
                    row.inchi_crystal,
                    ccd_components_inchi_keys,
                    candidate_fp_cache=ccd_components_inchi_fp_cache,
                )
                if closest_inchi is None:
                    # no CCD component had a positive similarity to this ligand
                    print(f"No similar CCD component InChI string found for: {ligand}")
                    continue
                row_inchi = CCD_COMPONENTS_INCHI[closest_inchi]
                print(
                    f"Found closest CCD component InChI string (similarity={closest_inchi_similarity}) for: {ligand}"
                )
            if row_inchi:
                # if row_inchi and row_inchi["modified_date"] > method_max_training_cutoff_date:
                print(f"Failed {dataset_mapping[dataset]} AF3 docking case for: {ligand}")
                print(f"CCD component InChI metadata: {row_inchi}\n")
                globals()[f"{dataset}_af3_failed_ligands_{repeat_index}"].append(row_inchi)

#### Record and plot AlphaFold 3 failure mode metadata

In [None]:
# plot functional keyword statistics of AlphaFold 3's failed ligands across all datasets
all_failed_ligands_df = []
for dataset in datasets:
    for repeat_index in [1]:  # NOTE: for now, we only consider the first repeat
        failed_ligands_df = pd.DataFrame(globals()[f"{dataset}_af3_failed_ligands_{repeat_index}"])
        failed_ligands_df["dataset"] = dataset_mapping[dataset]
        failed_ligands_df["repeat_index"] = repeat_index
        all_failed_ligands_df.append(failed_ligands_df)
all_failed_ligands_df = pd.concat(all_failed_ligands_df, ignore_index=True)

failed_ligand_function_annotations = []
# NOTE: the `pdb_id` column is absent when AlphaFold 3 had no failed ligands in any dataset
for pdb_id in all_failed_ligands_df.get("pdb_id", []):
    # skip CCD components whose `pdbx_model_coordinates_db_code` is the unknown marker "?"
    if pdb_id == "?":
        continue
    pdb_id_info = pypdb.get_all_info(pdb_id)
    if not pdb_id_info:
        # the entry's metadata could not be retrieved
        continue
    pdbx_keywords = (pdb_id_info.get("struct_keywords") or {}).get("pdbx_keywords")
    if not pdbx_keywords:
        # entries without functional keywords cannot be categorized
        continue
    # NOTE: these represent functional keywords; keep only the first listed keyword
    failed_ligand_function_annotations.append(pdbx_keywords.lower().split(", ")[0])

failed_ligand_function_annotation_counts = Counter(failed_ligand_function_annotations)
df = pd.DataFrame(
    failed_ligand_function_annotation_counts.items(),
    columns=["Keyword", "Frequency"],
)
df["Frequency"] = df["Frequency"].astype(int)
df = df.sort_values(by="Frequency", ascending=False)

if df.empty:
    # without any keyword there is nothing to plot (and `max()` below would be NaN)
    print("No functional keywords found for AlphaFold 3's failed ligands.")
else:
    # plot with Seaborn
    plt.figure(figsize=(10, 6))
    sns.barplot(data=df, x="Frequency", y="Keyword", palette="viridis")

    # integer-only x ticks, since the x axis counts occurrences
    max_freq = df["Frequency"].max()
    plt.xticks(ticks=range(0, max_freq + 1), labels=range(0, max_freq + 1))

    plt.xlabel("Frequency")
    plt.ylabel("Protein Function")

    plt.tight_layout()
    plt.savefig("failed_af3_ligands_functional_keywords.png")
    plt.show()

    plt.close("all")

#### Study PDB statistics of transferases

In [None]:
# combine all CSV files from a custom transferase PDB report
transferase_pdb_report_dir = os.path.join("pdb_reports", "transferase")
# sort the files so that the combined table has a deterministic row order
# (`os.listdir` returns entries in arbitrary order)
transferase_pdb_report_files = sorted(
    os.path.join(transferase_pdb_report_dir, f)
    for f in os.listdir(transferase_pdb_report_dir)
    if f.endswith(".csv")
)
if not transferase_pdb_report_files:
    raise FileNotFoundError(f"No CSV report files found in {transferase_pdb_report_dir}.")

# NOTE: the first line of each report file is skipped (`skiprows=1`); presumably it is a
# title line preceding the header — verify against the report export format
transferase_pdb_report_dfs = [
    pd.read_csv(transferase_pdb_report_file, skiprows=1)
    for transferase_pdb_report_file in transferase_pdb_report_files
]
transferase_pdb_report_df = pd.concat(transferase_pdb_report_dfs, ignore_index=True)

# analyze and plot statistics of the custom transferase PDB report
# cast to `str` first: `read_csv` may already have parsed the column as numeric (e.g., when
# no value contains a comma), in which case the `.str` accessor would raise
transferase_pdb_report_df["Refinement Resolution (Å)"] = pd.to_numeric(
    transferase_pdb_report_df["Refinement Resolution (Å)"].astype(str).str.replace(",", ""),
    errors="coerce",
)
transferase_pdb_report_df["Deposition Date"] = pd.to_datetime(
    transferase_pdb_report_df["Deposition Date"]
)

transferase_pdb_report_df.to_csv("transferase_pdb_report.csv", index=False)

sns.histplot(transferase_pdb_report_df["Refinement Resolution (Å)"].values)
plt.xlim(0, 10)
plt.savefig("transferase_pdb_report_resolution.png")

plt.close("all")

sns.histplot(transferase_pdb_report_df["Deposition Date"].values)
plt.savefig("transferase_pdb_report_deposition_date.png")

plt.close("all")