## PoseBusters Benchmark Method Interaction Analysis Plotting

#### Import packages

In [None]:
import glob
import os
import shutil
import subprocess
import tempfile

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from beartype import beartype
from beartype.typing import Any, Literal
from Bio.PDB import PDBIO, PDBParser, Select
from posecheck import PoseCheck
from rdkit import Chem
from tqdm import tqdm

from posebench import resolve_method_ligand_dir, resolve_method_protein_dir
from posebench.utils.data_utils import count_num_residues_in_pdb_file

#### Configure packages

In [None]:
pd.set_option("mode.copy_on_write", True)  # enable pandas Copy-on-Write semantics for this notebook

#### Define constants

In [None]:
# General variables
# Mappings: method identifier -> display name used in plots and log messages
method_mapping = {
    "diffdock": "DiffDock-L",
    "diffdockv1": "-> w/o SCT",
    "fabind": "FABind",
    "dynamicbind": "DynamicBind",
    "neuralplexer": "NeuralPLexer",
    "rfaa": "RoseTTAFold-AA",
    "chai-lab": "Chai-1",
    "tulip": "TULIP",
    "vina_diffdock": "DiffDock-L-Vina",
    "vina_p2rank": "P2Rank-Vina",
    "consensus_ensemble": "Ensemble (Con)",
}
# every mapped method is analyzed, in the (insertion) order listed above
baseline_methods = list(method_mapping)

# NOTE: Here, to simplify the analysis, we only consider the first run of each method
max_num_repeats_per_method = 1

pb_set_dir = os.path.join("..", "data", "posebusters_benchmark_set")
assert os.path.exists(
    pb_set_dir
), "Please download the PoseBusters Benchmark set from `https://zenodo.org/records/13858866` before proceeding."

# Only PoseBusters Benchmark targets with protein sequences below this threshold can be analyzed
MAX_POSEBUSTERS_BENCHMARK_ANALYSIS_PROTEIN_SEQUENCE_LENGTH = 2000

#### Define utility functions

In [None]:
class ProteinSelect(Select):
    """Biopython ``Select`` filter that keeps only standard (non-hetero) residues of a PDB file."""

    def accept_residue(self, residue: Any):
        """
        Decide whether a residue should be written out as part of the protein.

        :param residue: The residue to check.
        :return: True when the residue's hetero flag (first element of its id) is blank, False otherwise.
        """
        hetero_flag = residue.id[0]
        # NOTE: standard residues carry a blank hetero flag; HETATM residues carry a non-blank one
        return hetero_flag == " "


class LigandSelect(Select):
    """A class to select only ligand residues from a PDB file."""

    def accept_residue(self, residue: Any):
        """
        Only accept residues that are part of a ligand.

        In Biopython, the first element of ``residue.id`` is the hetero flag: a blank for
        standard residues, ``"W"`` for waters, and ``"H_<resname>"`` for other hetero residues.

        :param residue: The residue to check.
        :return: True if the residue is a non-water hetero residue, False otherwise.
        """
        # NOTE: waters are flagged `W` (not `H_*`), so they are not mistaken for ligand residues
        return residue.id[0].startswith("H_")


@beartype
def create_temp_pdb_with_only_molecule_type_residues(
    input_pdb_filepath: str,
    molecule_type: Literal["protein", "ligand"],
    add_element_types: bool = False,
) -> str:
    """
    Create a temporary PDB file with only residues of a chosen molecule type.

    The file is created with ``delete=False``, so it persists after this function
    returns and the caller is responsible for removing it.

    :param input_pdb_filepath: The input PDB file path.
    :param molecule_type: The molecule type to keep (either "protein" or "ligand").
    :param add_element_types: Whether to add element types to the atoms by running
        the ``pdb_element`` command-line tool on the filtered file.
    :return: The temporary PDB file path.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(molecule_type, input_pdb_filepath)

    io = PDBIO()
    io.set_structure(structure)

    # reserve a unique temporary path, closing the handle right away so no file descriptor
    # leaks and so `PDBIO` can reopen the path for writing on every platform
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdb") as temp_pdb_file:
        temp_pdb_filepath = temp_pdb_file.name
    io.save(
        temp_pdb_filepath, ProteinSelect() if molecule_type == "protein" else LigandSelect()
    )

    if add_element_types:
        # only swap the trailing suffix; `str.replace` could also rewrite the directory part
        elem_pdb_filepath = temp_pdb_filepath.removesuffix(".pdb") + "_elem.pdb"
        with open(elem_pdb_filepath, "w") as f:
            # NOTE: an argument list (no shell) keeps paths with spaces/metacharacters intact
            subprocess.run(  # nosec
                ["pdb_element", temp_pdb_filepath],
                check=True,
                stdout=f,
            )
        shutil.move(elem_pdb_filepath, temp_pdb_filepath)

    return temp_pdb_filepath

#### Compute interaction fingerprints

##### Analyze `PoseBusters Benchmark` set interactions as a baseline

In [None]:
if not os.path.exists("posebusters_benchmark_interaction_dataframes.h5"):
    # collect (protein PDB, ligand SDF) pairs for every target directory in the benchmark set;
    # directory entries are sorted so the `df_{i}` keys written below are reproducible
    pb_protein_ligand_filepath_pairs = []
    for item in sorted(os.listdir(pb_set_dir)):
        ligand_item_path = os.path.join(pb_set_dir, item)
        if os.path.isdir(ligand_item_path):
            protein_filepath = os.path.join(ligand_item_path, f"{item}_protein.pdb")
            ligand_filepath = os.path.join(ligand_item_path, f"{item}_ligand.sdf")
            if os.path.exists(protein_filepath) and os.path.exists(ligand_filepath):
                pb_protein_ligand_filepath_pairs.append((protein_filepath, ligand_filepath))

    pc = (
        PoseCheck()
    )  # NOTE: despite what `PoseCheck` might say, `reduce` should be available in the `PoseBench` environment
    pb_protein_ligand_interaction_dfs = []
    for protein_filepath, ligand_filepath in tqdm(
        pb_protein_ligand_filepath_pairs, desc="Processing PoseBusters Benchmark set"
    ):
        temp_protein_filepath = None
        try:
            # strip non-protein residues so only protein-ligand interactions are profiled
            temp_protein_filepath = create_temp_pdb_with_only_molecule_type_residues(
                protein_filepath, molecule_type="protein"
            )
            pc.load_protein_from_pdb(temp_protein_filepath)
            pc.load_ligands_from_sdf(ligand_filepath)
            interaction_df = pc.calculate_interactions()
        except Exception as e:
            # NOTE: skip failing targets rather than aborting; a crash would otherwise leave a
            # partially written cache file that the existence check above never recomputes
            print(
                f"Error processing PoseBusters Benchmark target {protein_filepath} due to: {e}. Skipping..."
            )
            continue
        finally:
            # the protein-only PDB is a throwaway temporary file that is not deleted automatically
            if temp_protein_filepath is not None and os.path.exists(temp_protein_filepath):
                os.remove(temp_protein_filepath)
        pb_protein_ligand_interaction_dfs.append(interaction_df)

        # NOTE: we iteratively save the interaction dataframes to an HDF5 file; earlier
        # dataframes are already stored, so only the newest one needs to be written
        with pd.HDFStore("posebusters_benchmark_interaction_dataframes.h5") as store:
            store.put(f"df_{len(pb_protein_ligand_interaction_dfs) - 1}", interaction_df)

##### Analyze interactions of each method

In [None]:
# calculate and cache PoseBusters Benchmark interaction statistics for each baseline method
dataset = "posebusters_benchmark"
for method in baseline_methods:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        method_title = method_mapping[method]
        # NOTE: one path is used both to detect cached results and to write new ones
        interaction_dataframes_filepath = (
            f"{method}_{dataset}_interaction_dataframes_{repeat_index}.h5"
        )

        if not os.path.exists(interaction_dataframes_filepath):
            v1_baseline = method == "diffdockv1"
            base_method = (
                method
                if method == "consensus_ensemble"
                else method.replace("v1", "").split("_")[0]
            )
            vina_binding_site_method = method.split("_")[-1]

            method_protein_dir = os.path.join(
                resolve_method_protein_dir(
                    base_method,
                    dataset,
                    repeat_index,
                    pocket_only_baseline=False,
                ),
            )
            method_ligand_dir = os.path.join(
                "..",
                resolve_method_ligand_dir(
                    base_method,
                    dataset,
                    vina_binding_site_method,
                    repeat_index,
                    pocket_only_baseline=False,
                    v1_baseline=v1_baseline,
                ),
            )

            if not os.path.isdir(method_protein_dir):
                method_protein_dir = os.path.join("..", method_protein_dir)

            if not os.path.isdir(method_ligand_dir):
                # NOTE: this handles DynamicBind's output formats
                ligand_dirs = [
                    p
                    for p in glob.glob(f"{method_ligand_dir}_*_1/index0_idx_0")
                    if "pocket_only" not in p
                ]
            else:
                ligand_dirs = [
                    p
                    for p in os.listdir(method_ligand_dir)
                    if "_relaxed" not in p and "bust_results" not in p
                ]

            posebusters_protein_ligand_complex_filepaths = []
            for item in ligand_dirs:
                # reset for every item so files found for a previous item can never be paired
                # with this one (previously only the directory branch below reset these)
                protein_pdb_filepath, ligand_sdf_filepath = None, None
                protein_item_path = os.path.join(method_protein_dir, f"{item}*")
                ligand_item_path = os.path.join(method_ligand_dir, item)

                complex_filepaths = []
                if method == "dynamicbind":
                    protein_item_path = glob.glob(os.path.join(item, "rank1_receptor*.pdb"))
                    ligand_item_path = glob.glob(os.path.join(item, "rank1_ligand*.sdf"))
                    if not protein_item_path or not ligand_item_path:
                        continue
                    complex_filepaths = [protein_item_path[0], ligand_item_path[0]]
                elif os.path.isfile(ligand_item_path) and "_relaxed" not in item:
                    # NOTE: this handles FABind's output formats
                    complex_filepaths = glob.glob(
                        os.path.join(
                            os.path.dirname(protein_item_path),
                            "_".join(item.split("_")[:2]) + ".pdb",
                        )
                    ) + glob.glob(ligand_item_path)
                elif os.path.isdir(ligand_item_path) and "_relaxed" not in item:
                    complex_filepaths = glob.glob(
                        os.path.join(protein_item_path, "*rank1*.pdb")
                    ) + glob.glob(os.path.join(ligand_item_path, "*rank1.sdf"))
                    if not len(complex_filepaths) == 2:
                        # NOTE: this handles DiffDock and TULIP's output formats
                        complex_filepaths = glob.glob(f"{protein_item_path}.pdb") + glob.glob(
                            os.path.join(ligand_item_path, "rank1.sdf")
                        )
                    if not len(complex_filepaths) == 2:
                        # NOTE: this handles RFAA's output formats
                        complex_filepaths = glob.glob(
                            os.path.join(protein_item_path, "*_protein.pdb")
                        ) + glob.glob(os.path.join(ligand_item_path, "*_ligand.sdf"))
                    if not len(complex_filepaths) == 2:
                        # NOTE: this handles Vina's output formats
                        complex_filepaths = glob.glob(f"{protein_item_path}.pdb") + glob.glob(
                            os.path.join(ligand_item_path, f"{item}.sdf")
                        )

                    if method == "neuralplexer":
                        # NOTE: this handles NeuralPlexer's output formats
                        complex_filepaths = [
                            p
                            for p in glob.glob(
                                os.path.join(protein_item_path.removesuffix("*"), "*rank1_*.pdb")
                            )
                            + glob.glob(
                                os.path.join(ligand_item_path.removesuffix("*"), "*rank1_*.sdf")
                            )
                            if "relaxed" not in p and "aligned" not in p
                        ]
                    elif method == "chai-lab":
                        # NOTE: this handles Chai-1's output formats
                        complex_filepaths = [
                            p
                            for p in glob.glob(
                                os.path.join(
                                    protein_item_path.removesuffix("*"),
                                    "pred.model_idx_0_protein.pdb",
                                )
                            )
                            + glob.glob(
                                os.path.join(
                                    ligand_item_path.removesuffix("*"),
                                    "pred.model_idx_0_ligand.sdf",
                                )
                            )
                        ]
                    elif method == "consensus_ensemble":
                        # NOTE: this handles the Consensus Ensemble's output formats
                        complex_filepaths = glob.glob(
                            os.path.join(protein_item_path.removesuffix("*"), "*.pdb")
                        ) + glob.glob(os.path.join(ligand_item_path.removesuffix("*"), "*.sdf"))

                if not len(complex_filepaths) == 2:
                    continue

                for file in complex_filepaths:
                    if file.endswith(".pdb"):
                        protein_pdb_filepath = file
                    elif file.endswith(".sdf"):
                        ligand_sdf_filepath = file
                if protein_pdb_filepath is not None and ligand_sdf_filepath is not None:
                    posebusters_protein_ligand_complex_filepaths.append(
                        (protein_pdb_filepath, ligand_sdf_filepath)
                    )
                else:
                    print(
                        f"Warning: Could not find `rank1` protein-ligand complex files for {item}. Skipping..."
                    )

            pc = (
                PoseCheck()
            )  # NOTE: despite what `PoseCheck` might say, `reduce` should be available in the `PoseBench` environment
            posebusters_protein_ligand_interaction_dfs = []
            for protein_ligand_complex_filepath in tqdm(
                posebusters_protein_ligand_complex_filepaths,
                desc=f"Processing interactions for {method_title}",
            ):
                try:
                    protein_filepath, ligand_filepath = protein_ligand_complex_filepath
                    num_residues_in_target_protein = count_num_residues_in_pdb_file(
                        protein_filepath
                    )
                    if (
                        num_residues_in_target_protein
                        > MAX_POSEBUSTERS_BENCHMARK_ANALYSIS_PROTEIN_SEQUENCE_LENGTH
                    ):
                        print(
                            f"{method_title} target {protein_ligand_complex_filepath} has too many protein residues ({num_residues_in_target_protein} > {MAX_POSEBUSTERS_BENCHMARK_ANALYSIS_PROTEIN_SEQUENCE_LENGTH}) for `MDAnalysis` to fit into CPU memory. Skipping..."
                        )
                        continue
                    ligand_mol = Chem.MolFromMolFile(ligand_filepath)
                    pc.load_protein_from_pdb(protein_filepath)
                    # each disconnected ligand fragment is profiled as its own ligand
                    pc.load_ligands_from_mols(
                        Chem.GetMolFrags(ligand_mol, asMols=True, sanitizeFrags=False)
                    )
                    posebusters_protein_ligand_interaction_dfs.append(pc.calculate_interactions())
                except Exception as e:
                    print(
                        f"Error processing {method_title} target {protein_ligand_complex_filepath} due to: {e}. Skipping..."
                    )
                    continue

                # NOTE: we iteratively save the interaction dataframes to an HDF5 file; earlier
                # dataframes are already stored, so only the newest one needs to be written
                with pd.HDFStore(interaction_dataframes_filepath) as store:
                    store.put(
                        f"df_{len(posebusters_protein_ligand_interaction_dfs) - 1}",
                        posebusters_protein_ligand_interaction_dfs[-1],
                    )

#### Plot interaction statistics for each method

In [None]:
# one long-format interaction-count DataFrame per data source (each method run and the reference set)
dfs = []


# define a function to process each method
def process_method(file_path, category):
    """
    Count selected protein-ligand interaction types stored in an HDF5 file of interaction dataframes.

    Every key of the HDF5 store holds one interaction dataframe whose column labels are
    tuples with the interaction type name at position 2 (e.g., ``HBAcceptor``). Each row is
    counted on its own, using only the interaction columns that are set in that row.

    :param file_path: Path to the HDF5 file of interaction dataframes.
    :param category: Label (e.g., a method title or "Reference") attached to every output row.
    :return: A long-format DataFrame with columns ``Category``, ``InteractionType`` and
        ``NumInteractions``, holding four rows (one per counted interaction type) per input row.
    """
    # interaction type name in the stored columns -> display label used in the plots
    interaction_type_labels = {
        "HBAcceptor": "Hydrogen Bond Acceptors",
        "HBDonor": "Hydrogen Bond Donors",
        "VdWContact": "Van der Waals Contacts",
        "Hydrophobic": "Hydrophobic Interactions",
    }
    df_rows = []
    with pd.HDFStore(file_path) as store:
        for key in store.keys():
            # read each stored dataframe once instead of re-reading it from disk for every row
            interaction_df = store[key]
            for _, row in interaction_df.iterrows():
                # NOTE: all rows share the same column labels, so only the columns flagged in
                # *this* row are counted; counting every label would credit a row with
                # interactions that only other rows of the same dataframe exhibit
                present = row.notna() & row.astype(bool)
                interaction_types = [column[2] for column in row.index[present.to_numpy()]]
                for interaction_type, label in interaction_type_labels.items():
                    df_rows.append(
                        {
                            "Category": category,
                            "InteractionType": label,
                            "NumInteractions": interaction_types.count(interaction_type),
                        }
                    )
    # explicit columns keep the expected schema even when the store holds no rows
    return pd.DataFrame(df_rows, columns=["Category", "InteractionType", "NumInteractions"])


# load data from files
for method in baseline_methods:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        method_title = method_mapping[method]
        file_path = f"{method}_posebusters_benchmark_interaction_dataframes_{repeat_index}.h5"
        if os.path.exists(file_path):
            dfs.append(process_method(file_path, method_title))

if os.path.exists("posebusters_benchmark_interaction_dataframes.h5"):
    dfs.append(process_method("posebusters_benchmark_interaction_dataframes.h5", "Reference"))

# combine statistics
assert len(dfs) > 0, "No interaction dataframes found."
# NOTE: every per-source frame starts its index at 0, so build a fresh unique index when
# combining them to avoid duplicate labels downstream
df = pd.concat(dfs, ignore_index=True)

# define font properties
plt.rcParams["font.size"] = 14
plt.rcParams["axes.labelsize"] = 16

# plot statistics
fig, axes = plt.subplots(2, 2, figsize=(34, 14), sharey=False)

interaction_types = [
    "Hydrogen Bond Acceptors",
    "Hydrogen Bond Donors",
    "Van der Waals Contacts",
    "Hydrophobic Interactions",
]
plot_types = ["box", "box", "violin", "violin"]

for ax, interaction, plot_type in zip(axes.flatten(), interaction_types, plot_types):
    data = df[df["InteractionType"] == interaction]

    if plot_type == "box":
        sns.boxplot(data=data, x="Category", y="NumInteractions", ax=ax, showfliers=True)
    elif plot_type == "violin":
        sns.violinplot(data=data, x="Category", y="NumInteractions", ax=ax)
    else:
        raise ValueError(f"Unsupported plot type: {plot_type}")

    # overlay the individual interaction counts on top of the distribution summary
    sns.stripplot(
        data=data,
        x="Category",
        y="NumInteractions",
        ax=ax,
        color="black",
        alpha=0.3,
        jitter=True,
    )

    ax.set_title(interaction)
    ax.set_ylabel("No. Interactions")
    ax.set_xlabel("")
    ax.grid(True)

plt.tight_layout()
plt.savefig("posebusters_benchmark_method_interaction_analysis.png", dpi=300)
plt.show()