In [1]:
import os
import glob
import argparse
import tempfile
import os
from pathlib import Path
from typing import Tuple, Dict, List

import pandas as pd
import MDAnalysis as mda
import prolif as plf
from rdkit import DataStructs
from rdkit import Chem
from prolif.molecule import Molecule

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
BASE_DIR = '/Users/aoxu/projects/DrugDiscovery/PoseBench'
df = pd.read_csv(f"{BASE_DIR}/notebooks/posebusters_results_filtered_with_descriptors_predicted_ligands.csv")

for col in ["protein_pdb", "predicted_ligand", "true_ligand"]:
    df[col] = df[col].str.replace("/Users/aoxu/projects/DrugDiscovery/PoseBench", BASE_DIR)
assert not df['predicted_ligand'].isnull().any()
df.head()[['protein', 'method', 'true_ligand', 'protein_pdb', 'predicted_ligand']]

  df = pd.read_csv(f"{BASE_DIR}/notebooks/posebusters_results_filtered_with_descriptors_predicted_ligands.csv")


Unnamed: 0,protein,method,true_ligand,protein_pdb,predicted_ligand
0,7ZDY_6MJ,icm,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/f...
1,7ZDY_6MJ,icm,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/f...
2,7ZDY_6MJ,icm,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/f...
3,7ZDY_6MJ,icm,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/f...
4,7ZDY_6MJ,icm,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/d...,/Users/aoxu/projects/DrugDiscovery/PoseBench/f...


In [9]:
class ConformationAnalyzer:
    """
    Analyzes protein-ligand interactions for multiple conformations using ProLIF.
    This class handles loading structures and generating interaction fingerprints
    for multiple poses of the same protein-ligand system.
    """
    def __init__(self, df: pd.DataFrame):
        """
        Initialize with a DataFrame containing paths to structures.

        Parameters:
            df: DataFrame with columns for predicted_ligand, protein_pdb, method, and rank
        """
        self.df = df
        self.fp_calculator = plf.Fingerprint()
        self.results = {}  # Stores method -> fingerprint DataFrame

    def combine_conformations(self, method_name: str) -> Tuple[str, str]:
        """
        Combines multiple conformations into a single SDF file for a given method.
        Returns the combined ligand SDF path and protein PDB path.
        """
        method_data = self.df[self.df['method'] == method_name].sort_values('rank')

        temp_sdf = tempfile.NamedTemporaryFile(suffix='.sdf', delete=False)
        writer = Chem.SDWriter(temp_sdf.name)

        for _, row in method_data.iterrows():
            mol = Chem.SDMolSupplier(str(row['predicted_ligand']))[0]
            if mol is not None:
                mol.SetProp('_Name', f"Pose_{row['rank']}")
                writer.write(mol)
        writer.close()

        protein_path = str(method_data.iloc[0]['protein_pdb'])
        return temp_sdf.name, protein_path

    def analyze_method(self, method_name: str) -> pd.DataFrame:
        """
        Analyzes protein-ligand interactions for all conformations of a method
        and returns a ProLIF fingerprint DataFrame.
        """
        try:
            ligand_file, protein_path = self.combine_conformations(method_name)

            ligands = plf.sdf_supplier(ligand_file)
            rdkit_prot = Chem.MolFromPDBFile(protein_path, removeHs=False)
            protein = plf.Molecule(rdkit_prot)

            self.fp_calculator.run_from_iterable(ligands, protein)
            fp_df = self.fp_calculator.to_dataframe(index_col="Pose")
            self.results[method_name] = fp_df

            # Cleanup temp file
            os.unlink(ligand_file)

            return fp_df
        except Exception as e:
            print(f"Error analyzing method {method_name}: {str(e)}")
            if 'ligand_file' in locals():
                os.unlink(ligand_file)
            raise

    def analyze_all_methods(self) -> dict:
        """
        Runs analyze_method() for all unique methods in self.df.
        Returns a dict of {method: fingerprint_df}.
        """
        methods = self.df['method'].unique()
        for method in methods:
            self.analyze_method(method)
        return self.results

    def summarize_interactions(self, method_name: str) -> pd.DataFrame:
        """
        Summarizes how often each interaction appears across all poses of a method.
        """
        if method_name not in self.results:
            self.analyze_method(method_name)
        fp_df = self.results[method_name]

        summary = pd.DataFrame({
            'occurrence_rate': fp_df.mean(),
            'always_present': fp_df.all(),
            'never_present': ~fp_df.any(),
            'variable': fp_df.any() & ~fp_df.all()
        })
        return summary

    def compare_to_reference(self, method_name: str, reference_df: pd.DataFrame) -> pd.DataFrame:
        """
        Compare docked poses' interactions against a reference ligand fingerprint.
        
        Parameters:
            method_name: Name of the docking method to analyze
            reference_df: DataFrame containing reference interactions (from ref_fp.csv)
            
        Returns:
            DataFrame with comparison results for each pose
        """
        if method_name not in self.results:
            self.analyze_method(method_name)
            
        # Get docked poses results
        docked_df = self.results[method_name]
        
        # Preprocess reference dataframe
        reference = self._process_reference(reference_df)
        
        # Align columns between reference and docked poses
        docked_df, reference = docked_df.align(reference, axis=1, fill_value=False)
        
        # Compare interactions
        comparison = pd.DataFrame()
        for pose_name, pose in docked_df.iterrows():
            comparison = pd.concat([
                comparison,
                self._compare_single_pose(pose, reference, pose_name)
            ])
            
        return comparison
    
    def _process_reference(self, reference_df: pd.DataFrame) -> pd.Series:
        """Process reference dataframe into a boolean Series"""
        # Convert multiindex columns to tuple (protein, interaction)
        reference_df.columns = pd.MultiIndex.from_tuples(
            zip(reference_df.columns.get_level_values('protein'),
                reference_df.columns.get_level_values('interaction'))
        )
        
        # Get reference interactions (assuming first row is reference)
        return reference_df.iloc[0].astype(bool)
    
    def _compare_single_pose(self, pose: pd.Series, reference: pd.Series, pose_name: str) -> pd.DataFrame:
        """Compare a single pose against reference"""
        return pd.DataFrame({
            'pose': pose_name,
            'interaction': pose.index,
            'in_reference': reference.values,
            'in_pose': pose.values,
            'match': pose.values == reference.values,
            'missing': reference.values & ~pose.values,
            'extra': ~reference.values & pose.values
        })
    
    def plot_comparison(self, comparison_df: pd.DataFrame, top_n: int = 10):
        """Visualize comparison results"""
        import matplotlib.pyplot as plt
        import seaborn as sns
        
        # Aggregate results
        agg_df = comparison_df.groupby('interaction').agg({
            'missing': 'sum',
            'extra': 'sum'
        })
        
        # Split interaction into residue and type
        agg_df[['residue', 'interaction_type']] = agg_df.index.to_series()\
            .apply(lambda x: (x[0], x[1])).tolist()
        
        # Plot missing interactions
        plt.figure(figsize=(12, 6))
        missing = agg_df.groupby(['residue', 'interaction_type'])['missing'].sum().nlargest(top_n)
        missing.unstack().plot(kind='bar', stacked=True, title='Most Common Missing Interactions')
        plt.ylabel('Frequency')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        
        # Plot extra interactions
        plt.figure(figsize=(12, 6))
        extra = agg_df.groupby(['residue', 'interaction_type'])['extra'].sum().nlargest(top_n)
        extra.unstack().plot(kind='bar', stacked=True, title='Most Common Extra Interactions')
        plt.ylabel('Frequency')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()

In [11]:
protein = '8B8H' 
method = 'icm'
# Load reference data
reference_df = pd.read_csv("ref_fp.csv", header=[0, 1, 2], skiprows=[3], index_col=0)

# Initialize analyzer
analyzer = ConformationAnalyzer(df[df['protein'] == protein])

# Compare to reference for a specific method
comparison = analyzer.compare_to_reference(method, reference_df)

# Generate visualizations
analyzer.plot_comparison(comparison)

# Show problematic poses
problematic_poses = comparison.groupby('pose').agg({
    'missing': 'sum', 
    'extra': 'sum'
}).sort_values(['missing', 'extra'], ascending=False)

print("Poses with most discrepancies:")
print(problematic_poses.head())

Error analyzing method icm: single positional indexer is out-of-bounds


IndexError: single positional indexer is out-of-bounds