## CASP15 Method Interaction Analysis Plotting

#### Import packages

In [None]:
import glob
import os
import shutil
import subprocess
import tempfile

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from beartype import beartype
from beartype.typing import Any, Literal
from Bio.PDB import PDBIO, PDBParser, Select
from posecheck import PoseCheck
from rdkit import Chem
from tqdm import tqdm

from posebench.utils.data_utils import count_num_residues_in_pdb_file

#### Configure packages

In [None]:
# Turn on the pandas copy-on-write option for every dataframe operation in this notebook
pd.options.mode.copy_on_write = True

#### Define constants

In [None]:
# Mappings
# NOTE: each key is a method identifier (used to build directory and cache file names),
# and each value is the title shown in progress bars and plots
method_mapping = {
    "diffdock": "DiffDock-L",
    "diffdockv1": "-> w/o SCT",
    "dynamicbind": "DynamicBind",
    "neuralplexer": "NeuralPLexer",
    "neuralplexer_no_ilcl": "-> w/o ILCL",
    "rfaa": "RoseTTAFold-AA",
    "chai-lab": "Chai-1",
    "tulip": "TULIP",
    "vina_diffdock": "DiffDock-L-Vina",
    "vina_p2rank": "P2Rank-Vina",
    "consensus_ensemble": "Ensemble (Con)",
}

# General variables
# NOTE: every mapped method is analyzed, in the order listed above; assign an explicit
# subset of the mapping's keys here to analyze fewer methods
baseline_methods = list(method_mapping)
# NOTE: Here, to simplify the analysis, we only consider the first run of each method
max_num_repeats_per_method = 1

casp15_set_dir = os.path.join("..", "data", "casp15_set", "targets")
assert os.path.exists(
    casp15_set_dir
), "Please download the (public) CASP15 set from `https://zenodo.org/records/13858866` before proceeding."

# NOTE: these will be skipped since they were not scoreable
CASP15_ANALYSIS_TARGETS_TO_SKIP = ["T1170"]
# Only CASP15 targets with protein sequences below this threshold can be analyzed
MAX_CASP15_ANALYSIS_PROTEIN_SEQUENCE_LENGTH = 2000

#### Define utility functions

In [None]:
class ProteinSelect(Select):
    """A residue filter that keeps only protein residues when writing a PDB file."""

    def accept_residue(self, residue: Any):
        """
        Decide whether a residue should be written out as part of the protein.

        :param residue: The residue to check.
        :return: True if the first field of the residue's ID is a blank, False otherwise.
        """
        # NOTE: the first field of a residue ID holds its `HETATM` flag, which is a blank for
        # protein residues
        # NOTE(review): this test does not look at the residue name, so any other non-`HETATM`
        # residue would presumably be kept as well -- confirm if that matters for the inputs used
        hetero_flag = residue.id[0]
        return hetero_flag == " "


class LigandSelect(Select):
    """A residue filter that keeps only ligand residues when writing a PDB file."""

    def accept_residue(self, residue: Any):
        """
        Decide whether a residue should be written out as part of a ligand.

        :param residue: The residue to check.
        :return: True if the first field of the residue's ID is not a blank, False otherwise.
        """
        # NOTE: the first field of a residue ID holds its `HETATM` flag, which must be filled
        # (i.e., not a blank) for ligand residues
        # NOTE(review): any non-blank flag is accepted, so waters would presumably be kept as
        # well -- confirm if that matters for the inputs used
        hetero_flag = residue.id[0]
        return hetero_flag != " "


@beartype
def create_temp_pdb_with_only_molecule_type_residues(
    input_pdb_filepath: str,
    molecule_type: Literal["protein", "ligand"],
    add_element_types: bool = False,
) -> str:
    """
    Create a temporary PDB file with only residues of a chosen molecule type.

    NOTE: The returned file is not deleted automatically, so the caller is responsible
    for removing it once it is no longer needed.

    :param input_pdb_filepath: The input PDB file path.
    :param molecule_type: The molecule type to keep (either "protein" or "ligand").
    :param add_element_types: Whether to add element types to the atoms
        (by running the `pdb_element` command-line tool on the filtered file).
    :return: The temporary PDB file path.
    :raises subprocess.CalledProcessError: If `pdb_element` exits with a non-zero status.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(molecule_type, input_pdb_filepath)

    io = PDBIO()
    io.set_structure(structure)

    # reserve a unique temporary PDB filepath, closing the handle right away so that no
    # file descriptor is leaked and the file can be freely reopened by name below
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdb") as temp_pdb_file:
        temp_pdb_filepath = temp_pdb_file.name
    # NOTE: only the file extension is rewritten, so a `.pdb` elsewhere in the path is left intact
    temp_elem_pdb_filepath = f"{os.path.splitext(temp_pdb_filepath)[0]}_elem.pdb"

    try:
        residue_selector = ProteinSelect() if molecule_type == "protein" else LigandSelect()
        io.save(temp_pdb_filepath, residue_selector)

        if add_element_types:
            # write the element-annotated structure to a sibling file first, since the tool
            # reads from the very file that is to be replaced
            with open(temp_elem_pdb_filepath, "w") as f:
                subprocess.run(  # nosec
                    ["pdb_element", temp_pdb_filepath],
                    check=True,
                    stdout=f,
                )
            shutil.move(temp_elem_pdb_filepath, temp_pdb_filepath)
    except Exception:
        # do not leave half-written temporary files behind when no path is returned
        for leftover_filepath in (temp_pdb_filepath, temp_elem_pdb_filepath):
            if os.path.exists(leftover_filepath):
                os.remove(leftover_filepath)
        raise

    return temp_pdb_filepath

#### Compute interaction fingerprints

##### Analyze `CASP15` set interactions as a baseline

In [None]:
if not os.path.exists("casp15_interaction_dataframes.h5"):
    # NOTE: the cache file is created as soon as the first target has been processed, so an
    # interrupted run leaves behind a partial file that later runs will treat as complete;
    # delete the file to recompute it from scratch
    casp15_protein_ligand_complex_filepaths = [
        os.path.join(casp15_set_dir, item)
        for item in os.listdir(casp15_set_dir)
        if item.endswith("_lig.pdb") and item.split("_")[0] not in CASP15_ANALYSIS_TARGETS_TO_SKIP
    ]

    pc = (
        PoseCheck()
    )  # NOTE: despite what `PoseCheck` might say, `reduce` should be available in the `PoseBench` environment
    casp15_protein_ligand_interaction_dfs = []
    for protein_ligand_complex_filepath in tqdm(
        casp15_protein_ligand_complex_filepaths, desc="Processing CASP15 set"
    ):
        temp_protein_filepath, temp_ligand_filepath = None, None
        try:
            temp_protein_filepath = create_temp_pdb_with_only_molecule_type_residues(
                protein_ligand_complex_filepath, molecule_type="protein"
            )
            num_residues_in_target_protein = count_num_residues_in_pdb_file(temp_protein_filepath)
            if num_residues_in_target_protein > MAX_CASP15_ANALYSIS_PROTEIN_SEQUENCE_LENGTH:
                print(
                    f"CASP15 target {protein_ligand_complex_filepath} has too many protein residues ({num_residues_in_target_protein} > {MAX_CASP15_ANALYSIS_PROTEIN_SEQUENCE_LENGTH}) for `MDAnalysis` to fit into CPU memory. Skipping..."
                )
                continue
            temp_ligand_filepath = create_temp_pdb_with_only_molecule_type_residues(
                protein_ligand_complex_filepath, molecule_type="ligand"
            )
            ligand_mol = Chem.MolFromPDBFile(temp_ligand_filepath)
            pc.load_protein_from_pdb(temp_protein_filepath)
            # NOTE: each disconnected fragment of the ligand file is loaded as its own ligand
            pc.load_ligands_from_mols(
                Chem.GetMolFrags(ligand_mol, asMols=True, sanitizeFrags=False)
            )
            casp15_protein_ligand_interaction_dfs.append(pc.calculate_interactions())
        except Exception as e:
            # best-effort analysis: report the failing target and move on to the next one
            print(
                f"Error processing CASP15 target {protein_ligand_complex_filepath} due to: {e}. Skipping..."
            )
        else:
            # NOTE: we iteratively save the interaction dataframes to an HDF5 file; only the
            # newly computed dataframe is written, since all earlier ones were already stored
            # under their own keys by previous iterations
            new_df_index = len(casp15_protein_ligand_interaction_dfs) - 1
            with pd.HDFStore("casp15_interaction_dataframes.h5") as store:
                store.put(
                    f"df_{new_df_index}", casp15_protein_ligand_interaction_dfs[new_df_index]
                )
        finally:
            # remove this target's temporary PDB files so that they do not accumulate on disk
            for temp_filepath in (temp_protein_filepath, temp_ligand_filepath):
                if temp_filepath is not None and os.path.exists(temp_filepath):
                    os.remove(temp_filepath)

##### Analyze interactions of each method

In [None]:
# calculate and cache CASP15 interaction statistics for each baseline method
for method in baseline_methods:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        method_title = method_mapping[method]

        # NOTE: the cache file is created as soon as the first target has been processed, so an
        # interrupted run leaves behind a partial file that later runs will treat as complete;
        # delete the file to recompute it from scratch
        method_interactions_filepath = f"{method}_casp15_interaction_dataframes_{repeat_index}.h5"
        if os.path.exists(method_interactions_filepath):
            continue

        method_casp15_set_dir = os.path.join(
            "..",
            "data",
            "test_cases",
            "casp15",
            f"top_{method}{'' if 'ensemble' in method else '_ensemble'}_predictions_{repeat_index}",
        )

        # collect the `rank1` protein (PDB) and ligand (SDF) files of each predicted target
        casp15_protein_ligand_complex_filepaths = []
        for item in os.listdir(method_casp15_set_dir):
            item_path = os.path.join(method_casp15_set_dir, item)
            if (
                item.split("_")[0] in CASP15_ANALYSIS_TARGETS_TO_SKIP
                or not os.path.isdir(item_path)
                or "_relaxed" in item
            ):
                continue

            protein_pdb_filepath, ligand_sdf_filepath = None, None
            complex_filepaths = glob.glob(os.path.join(item_path, "*rank1*.pdb")) + glob.glob(
                os.path.join(item_path, "*rank1*.sdf")
            )
            # NOTE: if several files of one kind match, the last match is the one that is used
            for file in complex_filepaths:
                if file.endswith(".pdb"):
                    protein_pdb_filepath = file
                elif file.endswith(".sdf"):
                    ligand_sdf_filepath = file
            if protein_pdb_filepath is None or ligand_sdf_filepath is None:
                raise FileNotFoundError(
                    f"Could not find `rank1` protein-ligand complex files for {item}"
                )
            casp15_protein_ligand_complex_filepaths.append(
                (protein_pdb_filepath, ligand_sdf_filepath)
            )

        pc = (
            PoseCheck()
        )  # NOTE: despite what `PoseCheck` might say, `reduce` should be available in the `PoseBench` environment
        casp15_protein_ligand_interaction_dfs = []
        for protein_ligand_complex_filepath in tqdm(
            casp15_protein_ligand_complex_filepaths,
            desc=f"Processing interactions for {method_title}",
        ):
            temp_protein_filepath = None
            try:
                protein_filepath, ligand_filepath = protein_ligand_complex_filepath
                temp_protein_filepath = create_temp_pdb_with_only_molecule_type_residues(
                    protein_filepath, molecule_type="protein", add_element_types=True
                )
                num_residues_in_target_protein = count_num_residues_in_pdb_file(
                    temp_protein_filepath
                )
                if num_residues_in_target_protein > MAX_CASP15_ANALYSIS_PROTEIN_SEQUENCE_LENGTH:
                    print(
                        f"{method_title} target {protein_ligand_complex_filepath} has too many protein residues ({num_residues_in_target_protein} > {MAX_CASP15_ANALYSIS_PROTEIN_SEQUENCE_LENGTH}) for `MDAnalysis` to fit into CPU memory. Skipping..."
                    )
                    continue
                ligand_mol = Chem.MolFromMolFile(ligand_filepath)
                pc.load_protein_from_pdb(temp_protein_filepath)
                # NOTE: each disconnected fragment of the ligand file is loaded as its own ligand
                pc.load_ligands_from_mols(
                    Chem.GetMolFrags(ligand_mol, asMols=True, sanitizeFrags=False)
                )
                casp15_protein_ligand_interaction_dfs.append(pc.calculate_interactions())
            except Exception as e:
                # best-effort analysis: report the failing target and move on to the next one
                print(
                    f"Error processing {method_title} target {protein_ligand_complex_filepath} due to: {e}. Skipping..."
                )
            else:
                # NOTE: we iteratively save the interaction dataframes to an HDF5 file; only the
                # newly computed dataframe is written, since all earlier ones were already
                # stored under their own keys by previous iterations
                new_df_index = len(casp15_protein_ligand_interaction_dfs) - 1
                with pd.HDFStore(method_interactions_filepath) as store:
                    store.put(
                        f"df_{new_df_index}",
                        casp15_protein_ligand_interaction_dfs[new_df_index],
                    )
            finally:
                # remove this target's temporary PDB file so that such files do not accumulate
                if temp_protein_filepath is not None and os.path.exists(temp_protein_filepath):
                    os.remove(temp_protein_filepath)

#### Plot interaction statistics for each method

In [None]:
# collects the dataframes returned by `process_method()`: one per cached method file found,
# plus one for the reference file if it exists
dfs = []


# define a function to process each method
def process_method(file_path, category):
    """
    Summarize the interaction dataframes cached in an HDF5 file as a long-format dataframe.

    For every row of every dataframe in the store, four records are emitted (one per
    counted interaction type), each holding the number of the dataframe's column labels
    whose third element names that interaction type.

    :param file_path: Path to the HDF5 file holding the cached interaction dataframes.
    :param category: Label to write into the `Category` column of every record.
    :return: A dataframe with the columns `Category`, `InteractionType`, and
        `NumInteractions` (without any columns if the store yields no rows).
    """
    # (plot label, interaction name as it appears in the column labels)
    counted_interaction_types = (
        ("Hydrogen Bond Acceptors", "HBAcceptor"),
        ("Hydrogen Bond Donors", "HBDonor"),
        ("Van der Waals Contacts", "VdWContact"),
        ("Hydrophobic Interactions", "Hydrophobic"),
    )

    df_rows = []
    with pd.HDFStore(file_path) as store:
        for key in store.keys():
            # read each dataframe from disk only once, rather than once per row access
            interaction_df = store[key]
            num_rows = len(interaction_df)
            if num_rows == 0:
                continue

            # NOTE: the counts are derived from the column labels alone, so they are identical
            # for every row of a dataframe and only need to be computed once
            # NOTE(review): the cell values are never inspected, so a column counts towards a
            # row even if that row's value in it is falsy -- confirm that this is intended
            interaction_types = [column[2] for column in interaction_df.columns.tolist()]
            interaction_counts = [
                (label, interaction_types.count(interaction_name))
                for label, interaction_name in counted_interaction_types
            ]
            for _ in range(num_rows):
                df_rows.extend(
                    {
                        "Category": category,
                        "InteractionType": label,
                        "NumInteractions": num_interactions,
                    }
                    for label, num_interactions in interaction_counts
                )
    return pd.DataFrame(df_rows)


# load data from files
# NOTE: methods without a cached file are silently left out of the plots
for method in baseline_methods:
    for repeat_index in range(1, max_num_repeats_per_method + 1):
        method_title = method_mapping[method]
        file_path = f"{method}_casp15_interaction_dataframes_{repeat_index}.h5"
        if not os.path.exists(file_path):
            continue
        dfs.append(process_method(file_path, method_title))

# the reference (CASP15 set) statistics are appended last, if they were cached
reference_file_path = "casp15_interaction_dataframes.h5"
if os.path.exists(reference_file_path):
    dfs.append(process_method(reference_file_path, "Reference"))

# combine statistics
assert len(dfs) > 0, "No interaction dataframes found."
df = pd.concat(dfs)

# define font properties
plt.rcParams.update({"font.size": 14, "axes.labelsize": 16})

# plot statistics (one panel per interaction type, laid out on a 2x2 grid)
fig, axes = plt.subplots(2, 2, figsize=(34, 14), sharey=False)

interaction_types = [
    "Hydrogen Bond Acceptors",
    "Hydrogen Bond Donors",
    "Van der Waals Contacts",
    "Hydrophobic Interactions",
]
plot_types = ["box", "box", "violin", "violin"]

for ax, interaction, plot_type in zip(axes.flatten(), interaction_types, plot_types):
    data = df[df["InteractionType"] == interaction]
    # keyword arguments shared by the summary plot and the point overlay of this panel
    shared_kwargs = dict(data=data, x="Category", y="NumInteractions", ax=ax)

    # draw the distribution summary for this panel
    if plot_type == "box":
        sns.boxplot(showfliers=True, **shared_kwargs)
    elif plot_type == "violin":
        sns.violinplot(**shared_kwargs)

    # overlay the individual data points on top of either kind of summary plot
    if plot_type in ("box", "violin"):
        sns.stripplot(color="black", alpha=0.3, jitter=True, **shared_kwargs)

    ax.set_title(interaction)
    ax.set_ylabel("No. Interactions")
    ax.set_xlabel("")
    ax.grid(True)

plt.tight_layout()
plt.savefig("casp15_method_interaction_analysis.png", dpi=300)
plt.show()