# Calculate the final grid box

- According to grid search results (See the hierachical clustering gri search notebook) Use the median clustering method with following parameters:

     - Standard factor (for the median filter): 100

     - pind_weight: 2.5

     - Clustering with Kmean

In [3]:
# Define data path
data_path = "/Users/nicha/dev/Protein-preparation-pipeline/data/toy_examples_clustering"

# System and OS utilities
import os
import sys

# Numerical and Data Processing
import numpy as np
import pandas as pd

# Scientific and Bioinformatics Tools
from math import e
from pymol import cmd, stored
import numpy as np
from Bio import PDB
from scipy.spatial import distance_matrix, KDTree
from scipy.stats import shapiro
from joblib import Parallel, delayed
from tqdm import tqdm
from itertools import combinations, product


# Machine Learning & Clustering
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import silhouette_score
from sklearn.metrics.pairwise import euclidean_distances

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import py3Dmol

# Color Mapping for Visualization
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
from mpl_toolkits.mplot3d import Axes3D

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.metrics import silhouette_score
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from scipy.linalg import eigh
from scipy.sparse.csgraph import laplacian
from sklearn.neighbors import kneighbors_graph


# Add project-specific source path
sys.path.append('/Users/nicha/dev/Protein-preparation-pipeline/src/')

# Custom Modules from Your Project
from pdb_retrival.data_retriever import PDBDataRetriever

In [4]:
# Function for the data preparation

# 1️⃣ Extract residue coordinates from PDB
def extract_residue_coordinates(pdb_file, residue_number, chain_id="A"):
    """
    Extracts atomic coordinates for a given residue from a PDB file.

    Args:
        pdb_file (str): Path to the PDB file.
        residue_number (int): Residue number to extract.
        chain_id (str): Chain ID of the residue.

    Returns:
        list: List of tuples (atom_type, x, y, z).
    """
    coordinates = []
    try:
        with open(pdb_file, "r") as file:
            for line in file:
                if line.startswith(("ATOM", "HETATM")) and line[21] == chain_id:
                    resi = int(line[22:26].strip())
                    if resi == residue_number:
                        atom_type = line[76:78].strip()
                        x = float(line[30:38].strip())
                        y = float(line[38:46].strip())
                        z = float(line[46:54].strip())
                        coordinates.append((atom_type, x, y, z))
    except FileNotFoundError:
        print(f"Error: File {pdb_file} not found.")
    except Exception as e:
        print(f"Unexpected error: {e}")
    return coordinates


# 2️⃣ Compute Weighted Center of Mass
def calculate_weighted_center_of_mass(coordinates):
    """
    Calculates the weighted center of mass for a given set of atomic coordinates.

    Args:
        coordinates (list): List of tuples (atom_type, x, y, z).

    Returns:
        list: [x, y, z] coordinates of the weighted center of mass.
    """
    total_weight = 0
    weighted_coords = np.zeros(3)
    atomic_weights = {"H": 1.008, "C": 12.011, "N": 14.007, "O": 15.999, "S": 32.06}

    for atom_type, x, y, z in coordinates:
        weight = atomic_weights.get(atom_type.upper(), 1.0)
        weighted_coords += np.array([x, y, z]) * weight
        total_weight += weight

    return np.round(weighted_coords / total_weight, 3).tolist()


# 3️⃣ Process a Single PDB File and Update the DataFrame
def process_pdb_file(data_path, pdb_code, df):
    """
    Reads a PDB file and updates a DataFrame with residue center of mass.

    Args:
        data_path (str): Path to the directory containing the PDB and CSV files.
        pdb_code (str): PDB code of the protein.
        df (pd.DataFrame): DataFrame containing residue data.

    Returns:
        pd.DataFrame: Updated DataFrame with center of mass information.
    """
    pdb_file = os.path.join(data_path, f"{pdb_code}.pdb")

    try:
        for index, row in df.iterrows():
            residue_number = row["resi"]
            chain_id = row["chain"]
            coordinates = extract_residue_coordinates(pdb_file, residue_number, chain_id)
            if coordinates:
                center_of_mass = calculate_weighted_center_of_mass(coordinates)
                df.loc[index, "resn_coordinates"] = str(coordinates)
                df.loc[index, ["center_of_mass_x", "center_of_mass_y", "center_of_mass_z"]] = center_of_mass
            else:
                print(f"Residue {residue_number} in chain {chain_id} not found in {pdb_file}.")
    except FileNotFoundError:
        print(f"Error: {pdb_file} not found.")
    except Exception as e:
        print(f"Unexpected error: {e}")

    return df


# 4️⃣ Process and Update Data for a Single PDB Structure
def process_and_update_pdb_data(data_path, pdb_code):
    """
    Reads and processes a PDB file and corresponding CSV file to update residue information.

    Args:
        data_path (str): Path to the directory containing the PDB and CSV files.
        pdb_code (str): PDB code of the protein.

    Returns:
        pd.DataFrame: Updated DataFrame with residue information.
    """
    csv_file_path = f"{data_path}/results_{pdb_code}.csv"

    try:
        df = pd.read_csv(csv_file_path)
        df_updated = process_pdb_file(data_path, pdb_code, df)
        df_updated.to_csv(f"{data_path}/results_{pdb_code}_updated.csv", index=False)
        return df_updated
    except Exception as e:
        print(f"Error processing {pdb_code}: {e}")
        return None


# 5️⃣ Process All PDB Files in a Directory
def process_all_pdb_files(data_path):
    """
    Process all PDB-related files in the given directory and combine results.

    Args:
        data_path (str): Path to the directory containing the PDB and CSV files.

    Returns:
        pd.DataFrame: Combined DataFrame with all processed data.
    """
    combined_df = pd.DataFrame()

    for file in os.listdir(data_path):
        if file.startswith("results_") and file.endswith(".csv"):
            pdb_code = file.split("_")[1].split(".")[0]
            updated_df = process_and_update_pdb_data(data_path, pdb_code)
            if updated_df is not None:
                updated_df["PDBcode"] = pdb_code
                combined_df = pd.concat([combined_df, updated_df], ignore_index=True)

    combined_df.drop_duplicates(inplace=True)
    combined_df.to_csv(f"{data_path}/combined_results.csv", index=False)
    return combined_df


# 6️⃣ Compute Center of Protein Using PyMOL
def calculate_center_of_mass_pymol(pdb_file):
    """
    Calculates the center of mass of a protein using PyMOL.

    Args:
        pdb_file (str): Path to the PDB file.

    Returns:
        tuple: (x, y, z) coordinates of the protein's center of mass.
    """
    cmd.load(pdb_file, "protein")
    center_of_mass = cmd.centerofmass("protein")
    cmd.delete("all")
    return round(center_of_mass[0], 3), round(center_of_mass[1], 3), round(center_of_mass[2], 3)


# 7️⃣ Compute Protein Diameter
def calculate_protein_diameter(df):
    """
    Computes the diameter of a protein, defined as the maximum distance between residues.

    Args:
        df (pd.DataFrame): DataFrame containing residue data with center of mass coordinates.

    Returns:
        float: Maximum distance between residues.
    """
    coordinates = df[['center_of_mass_x', 'center_of_mass_y', 'center_of_mass_z']].values
    distances = np.linalg.norm(coordinates[:, np.newaxis] - coordinates, axis=2)
    return round(np.max(distances), 3)


# 8️⃣ Process Protein Data for a Single PDB
def process_protein_data(pdb_code, dfs, data_path):
    """
    Processes protein data by computing center of mass and diameter.

    Args:
        pdb_code (str): The PDB code of the protein.
        dfs (dict): Dictionary containing DataFrames indexed by PDB code.
        data_path (str): Path to the directory containing the PDB files.

    Returns:
        pd.DataFrame: Updated DataFrame with center of protein and diameter.
    """
    pdb_file_path = f"{data_path}/{pdb_code}.pdb"
    df = dfs[pdb_code]
    center_of_mass = calculate_center_of_mass_pymol(pdb_file_path)
    protein_diameter = calculate_protein_diameter(df)

    df["protein_x"], df["protein_y"], df["protein_z"] = center_of_mass
    df["protein_diameter"] = protein_diameter

    return df



def normalize_features(df, feature_columns):
    """
    Normalize or standardize features based on their distribution.

    - **Standardizes (Z-score) if normally distributed**.
    - **Normalizes (Min-Max) if not normally distributed**.
    - **Adds new columns instead of replacing existing ones**.
    - **Stores scalers for consistent transformation in clustering**.

    Args:
        df (pd.DataFrame): DataFrame containing features.
        feature_columns (list): List of feature column names (e.g., 3D coordinates).

    Returns:
        pd.DataFrame: DataFrame with added transformed feature columns.
        dict: Dictionary of fitted scalers (for later reuse).
    """
    processed_df = df.copy()
    scalers = {}

    for col in feature_columns:
        scaler = MinMaxScaler()  # Normalization
        new_col_name = f"{col}_normalized"

        processed_df[new_col_name] = scaler.fit_transform(df[[col]])  # Apply transformation
        scalers[col] = scaler  # Store scaler for reuse

    return processed_df, scalers


def weighted_pbind(df, pbind_column='p(bind)', weights=[1.0]):
    """
    Normalize and apply multiple weight factors to p(bind).

    - **Scales p(bind) between 0 and 1**.
    - **Applies multiple weight factors** to create additional columns.
    - **Adds new columns instead of replacing p(bind)**.
    - **Stores scaler** for reuse in clustering.

    Args:
        df (pd.DataFrame): DataFrame containing p(bind).
        pbind_column (str): Column name for p(bind).
        weights (list): List of weight multipliers (e.g., [1.0, 5.0, 10.0, 100.0]).

    Returns:
        pd.DataFrame: DataFrame with weighted p(bind) columns.
        dict: Dictionary of fitted scalers.
    """
    processed_df = df.copy()
    scalers = {}

    if pbind_column in df.columns:
        scaler = MinMaxScaler()
        normalized_pbind = scaler.fit_transform(df[[pbind_column]])  # Scale to 0-1
        scalers[pbind_column] = scaler  # Store scaler

        for weight in weights:
            weighted_col_name = f"{pbind_column}_weight_{weight}"
            processed_df[weighted_col_name] = normalized_pbind * weight  # Apply weight

    return processed_df, scalers
    """
    Normalize and apply multiple weight factors to p(bind).

    - **Scales p(bind) between 0 and 1**.
    - **Applies multiple weight factors** to create additional columns.
    - **Stores scaler** for reuse in clustering.

    Args:
        df (pd.DataFrame): DataFrame containing p(bind).
        pbind_column (str): Column name for p(bind).
        weights (list): List of weight multipliers (e.g., [1.0, 1.5, 2.0]).

    Returns:
        pd.DataFrame: DataFrame with weighted p(bind) columns.
        dict: Dictionary of fitted scalers.
    """
    processed_df = df.copy()
    scalers = {}

    if pbind_column in df.columns:
        scaler = MinMaxScaler()
        normalized_pbind = scaler.fit_transform(df[[pbind_column]])  # Scale to 0-1
        scalers[pbind_column] = scaler  # Store scaler

        for weight in weights:
            weighted_col_name = f"{pbind_column}_weight_{weight}"
            processed_df[weighted_col_name] = normalized_pbind * weight  # Apply weight

    return processed_df, scalers


In [82]:
# Function to evaluate the clustering performance
from numpy import size


def calculate_protein_center_of_mass(pdb_file):
    """
    Calculate the center of mass of the entire protein using PyMOL.

    Args:
        pdb_file (str): Path to the PDB file.

    Returns:
        tuple: Center of mass coordinates (x, y, z).
    """
    cmd.load(pdb_file, "protein")
    center_of_mass = cmd.centerofmass("protein")
    cmd.delete("all")  # Clear the loaded structure
    return tuple(round(coord, 3) for coord in center_of_mass)

def calculate_cluster_centers(df, cluster_column, pdb_file):
    """
    Calculate the center of mass for each cluster using PyMOL.

    Args:
        df (pd.DataFrame): DataFrame containing residue and cluster information.
        cluster_column (str): Column name containing cluster IDs.
        pdb_file (str): Path to the PDB file.

    Returns:
        dict: Cluster IDs as keys and their center of mass coordinates as values.
    """
    cmd.load(pdb_file, "protein")
    cluster_centers = {}
    clusters = df[cluster_column].unique()

    for cluster_id in clusters:
        cluster_data = df[df[cluster_column] == cluster_id]
        resi_selection = "+".join(map(str, cluster_data['resi']))
        cmd.select(f"cluster_{cluster_id}", f"resi {resi_selection}")
        center_of_mass = cmd.centerofmass(f"cluster_{cluster_id}")
        cluster_centers[cluster_id] = tuple(round(coord, 3) for coord in center_of_mass)
        cmd.delete(f"cluster_{cluster_id}")

    cmd.delete("all")  # Clear the loaded structure
    return cluster_centers

def calculate_ligand_centers_slow(pdb_file):
    """
    Calculate the center of mass for each ligand in the structure using PyMOL.

    Args:
        pdb_file (str): Path to the PDB file.

    Returns:
        dict: Ligand identifiers as keys and their center of mass coordinates as values.
    """
    cmd.load(pdb_file, "protein")
    cmd.select("ligands", "organic")
    ligand_centers = {}

    for atom in cmd.get_model("ligands").atom:
        ligand_id = f"{atom.chain}_{atom.resn}_{atom.resi}"
        cmd.select("ligand", f"chain {atom.chain} and resn {atom.resn} and resi {atom.resi}")
        center_of_mass = cmd.centerofmass("ligand")
        ligand_centers[ligand_id] = tuple(round(coord, 3) for coord in center_of_mass)
        cmd.delete("ligand")  # Clear the selection

    cmd.delete("all")  # Clear the loaded structure
    return ligand_centers


def calculate_ligand_centers(pdb_file):
    """
    Calculate the center of mass for each ligand in the structure using PyMOL.

    Args:
        pdb_file (str): Path to the PDB file.

    Returns:
        dict: Ligand identifiers as keys and their center of mass coordinates as values.
    """
    cmd.load(pdb_file, "protein")
    cmd.select("ligands", "organic")
    
    ligand_centers = {}
    stored.stored_atoms = []  # Corrected: Define `stored_atoms` in the PyMOL namespace

    # Iterate through ligands and collect unique identifiers
    cmd.iterate("ligands", "stored.stored_atoms.append((chain, resn, resi))")

    unique_ligands = set(stored.stored_atoms)  # Get unique ligand identifiers

    for chain, resn, resi in unique_ligands:
        ligand_id = f"{chain}_{resn}_{resi}"
        cmd.select("ligand", f"chain {chain} and resn {resn} and resi {resi}")
        center_of_mass = cmd.centerofmass("ligand")
        ligand_centers[ligand_id] = tuple(round(coord, 3) for coord in center_of_mass)
        cmd.delete("ligand")  # Clear selection

    cmd.delete("all")  # Clear everything from PyMOL session
    return ligand_centers

def calculate_ligand_diameter(pdb_file):
    """
    Calculate the diameter of each ligand (maximum pairwise distance between its atoms).

    Args:
        pdb_file (str): Path to the PDB file.

    Returns:
        dict: Ligand identifiers as keys and their diameters as values.
    """
    cmd.load(pdb_file, "protein")
    cmd.select("ligands", "organic")

    ligand_diameters = {}

    # Store ligand atom coordinates
    stored.ligand_atoms = []
    cmd.iterate_state(1, "ligands", "stored.ligand_atoms.append((chain, resn, resi, x, y, z))")

    unique_ligands = {}
    for chain, resn, resi, x, y, z in stored.ligand_atoms:
        ligand_id = f"{chain}_{resn}_{resi}"
        if ligand_id not in unique_ligands:
            unique_ligands[ligand_id] = []
        unique_ligands[ligand_id].append((x, y, z))

    # Compute ligand diameters
    for ligand_id, coordinates in unique_ligands.items():
        if len(coordinates) > 1:
            max_distance = max(
                np.linalg.norm(np.array(p1) - np.array(p2))
                for p1, p2 in combinations(coordinates, 2)
            )
            ligand_diameters[ligand_id] = round(max_distance, 3)
        else:
            ligand_diameters[ligand_id] = 0.0  # Single atom ligand, diameter is zero

    cmd.delete("all")  # Clear the loaded structure
    return ligand_diameters


def calculate_grid_size_ligand(ligand_diameter):
    """
    Compute the grid size based on the ligand diameter.

    Args:
        ligand_diameter (float): Diameter of the ligand.

    Returns:
        float: Computed grid size.
    """
    return round(16 + (0.8 * ligand_diameter), 3)


def define_bounding_box_ligand(ligand_centers, ligand_diameters):
    """
    Define a bounding box around the ligands based on their diameters.

    Args:
        ligand_centers (dict): Ligand identifiers and their center of mass coordinates.
        ligand_diameters (dict): Ligand identifiers and their diameters.

    Returns:
        dict: Bounding box coordinates for each ligand.
    """
    bounding_boxes = {}

    for ligand_id, center in ligand_centers.items():
        if ligand_id not in ligand_diameters:
            print(f"Warning: Missing diameter for ligand {ligand_id}")
            continue

        grid_size = calculate_grid_size_ligand(ligand_diameters[ligand_id])

        min_coords = [center[i] - grid_size / 2 for i in range(3)]
        max_coords = [center[i] + grid_size / 2 for i in range(3)]

        bounding_boxes[ligand_id] = {"min": min_coords, "max": max_coords}

        # Add pseudoatoms in PyMOL for visualization
        cmd.pseudoatom(f"box_{ligand_id}_min", pos=min_coords, color="blue")
        cmd.pseudoatom(f"box_{ligand_id}_max", pos=max_coords, color="red")

        print(f"Bounding box for {ligand_id}: Min={min_coords}, Max={max_coords}")

    return bounding_boxes

def select_cluster_define_bounding_box(df, cluster_id, cluster_column, selection_name="cluster_selection"):
    """
    Select residues based on the clustering method and create a bounding box in PyMOL.

    Args:
        df (pd.DataFrame): DataFrame containing residue data.
        cluster_id (int): The cluster ID to select residues for.
        cluster_column (str): Column name that stores cluster assignments.
        selection_name (str): Name of the selection in PyMOL.

    Returns:
        dict: Bounding box coordinates { "min": np.array, "max": np.array } for the selected cluster.
    """
    # Ensure required columns exist
    required_columns = {"chain", "resi", "center_of_mass_x", "center_of_mass_y", "center_of_mass_z"}
    missing_columns = required_columns - set(df.columns)
    if missing_columns:
        raise ValueError(f"Missing columns in DataFrame: {missing_columns}")

    # Filter residues based on the selected cluster
    selected_residues = df[df[cluster_column] == cluster_id]
    
    if selected_residues.empty:
        print(f"⚠️ No residues found for cluster {cluster_id} in column '{cluster_column}'")
        return None

    # Construct a PyMOL selection string
    selection_string = " or ".join(
        [f"(chain {row['chain']} and resi {row['resi']})" for _, row in selected_residues.iterrows()]
    )

    # Check if PyMOL is available
    try:
        cmd.select(selection_name, selection_string)
    except Exception as e:
        print(f"⚠️ PyMOL selection failed: {e}")
        return None

    # Compute the bounding box
    min_coords = selected_residues[['center_of_mass_x', 'center_of_mass_y', 'center_of_mass_z']].min().to_numpy()
    max_coords = selected_residues[['center_of_mass_x', 'center_of_mass_y', 'center_of_mass_z']].max().to_numpy()

    # Convert coordinates to float for compatibility
    min_coords = np.array([float(coord) for coord in min_coords])
    max_coords = np.array([float(coord) for coord in max_coords])

    # Add pseudoatoms for visualization
    try:
        cmd.pseudoatom(f"{selection_name}_box_min", pos=min_coords.tolist(), color="blue")
        cmd.pseudoatom(f"{selection_name}_box_max", pos=max_coords.tolist(), color="red")
    except Exception as e:
        print(f"⚠️ Failed to create pseudoatoms in PyMOL: {e}")

    print(f"✅ Selection '{selection_name}' created for cluster {cluster_id}")
    print(f"📦 Bounding box: Min={min_coords.tolist()}, Max={max_coords.tolist()}")
    
    #euclidean distance
    size = np.linalg.norm(max_coords - min_coords)

    return {"min": min_coords, "max": max_coords, "size": size}


def calculate_dice_score(cluster_selection, ligand_selection):
    """
    Compute the Dice similarity coefficient between a cluster and a ligand.

    Args:
        cluster_selection (str): PyMOL selection name for the cluster.
        ligand_selection (str): PyMOL selection name for the ligand.

    Returns:
        float: Dice score indicating spatial overlap (higher means better overlap).
    """
    # Get atoms in selections
    cluster_atoms = set(cmd.index(cluster_selection))
    ligand_atoms = set(cmd.index(ligand_selection))

    if not cluster_atoms or not ligand_atoms:
        print(f"Error: One or both selections are empty: {cluster_selection}, {ligand_selection}")
        return 0.0

    # Compute Dice coefficient
    intersection = len(cluster_atoms & ligand_atoms)
    dice_score = (2 * intersection) / (len(cluster_atoms) + len(ligand_atoms))

    print(f"Dice Score between {cluster_selection} and {ligand_selection}: {round(dice_score, 3)}")
    
    return round(dice_score, 3)

def calculate_distances_between_clusters_and_ligands(cluster_centers, ligand_centers):
    """
    Calculate the minimum distance between each cluster and the ligands.

    Args:
        cluster_centers (dict): Cluster IDs and their center of mass coordinates.
        ligand_centers (dict): Ligand identifiers and their center of mass coordinates.

    Returns:
        dict: Cluster IDs as keys and tuples of (closest ligand ID, distance) as values.
    """
    distances = {}

    for cluster_id, cluster_center in cluster_centers.items():
        min_distance = float('inf')
        closest_ligand = None

        for ligand_id, ligand_center in ligand_centers.items():
            distance = np.linalg.norm(np.array(cluster_center) - np.array(ligand_center))
            if distance < min_distance:
                min_distance = round(distance, 3)
                closest_ligand = ligand_id

        distances[cluster_id] = (closest_ligand, min_distance)

    return distances

def evaluate_clustering(pdb_file, df, cluster_column):
    """
    Evaluate the clustering by calculating distances between cluster centers and ligand centers.

    Args:
        pdb_file (str): Path to the PDB file.
        df (pd.DataFrame): DataFrame containing residue and cluster information.
        cluster_column (str): Column name containing cluster IDs.

    Returns:
        dict: Distances between each cluster and the closest ligand.
    """
    print("Calculating cluster centers...")
    cluster_centers = calculate_cluster_centers(df, cluster_column, pdb_file)

    print("Calculating ligand centers...")
    ligand_centers = calculate_ligand_centers(pdb_file)

    print("Calculating distances between clusters and ligands...")
    distances = calculate_distances_between_clusters_and_ligands(cluster_centers, ligand_centers)

    for cluster_id, (ligand_id, distance) in distances.items():
        print(f"Cluster {cluster_id} -> Closest Ligand: {ligand_id}, Distance: {distance} Å")

    return distances

# Pipeline Function
def run_pipeline(data_path, pdb_file, clustering_method, cluster_params, pbind_column='p(bind)', cutoff_method='median_std', std_factor=0.5):
    """
    Run the clustering pipeline, including clustering, normalization, filtering, and evaluation.

    Args:
        data_path (str): Path to the data directory.
        pdb_file (str): Path to the PDB file.
        clustering_method (str): Clustering method to use ('kmeans' or 'spectral').
        cluster_params (dict): Parameters for the clustering method.
        pbind_column (str): Column name for binding probabilities.
        cutoff_method (str): Method for calculating cutoff ('median_std' or 'percentile').
        std_factor (float): Factor for standard deviation when calculating cutoff.

    Returns:
        dict: Evaluation results including distances between clusters and ligands.
    """
    print("Loading data...")
    combined_df = process_all_pdb_files(data_path)

    print("Applying clustering...")
    if clustering_method == 'kmeans':
        n_clusters = cluster_params.get('n_clusters', 3)
        combined_df, _ = kmeans_clustering(combined_df, feature_columns=cluster_params['feature_columns'], n_clusters=n_clusters)
    elif clustering_method == 'spectral':
        n_clusters = cluster_params.get('n_clusters', 3)
        affinity = cluster_params.get('affinity', 'nearest_neighbors')
        combined_df, _ = spectral_clustering(combined_df, feature_columns=cluster_params['feature_columns'], n_clusters=n_clusters, affinity=affinity)
    else:
        raise ValueError("Invalid clustering method. Choose 'kmeans' or 'spectral'.")

    print("Normalizing p(bind) within clusters...")
    combined_df = normalize_within_clusters(combined_df, cluster_column=clustering_method+'_cluster', pbind_column=pbind_column)

    print("Filtering residues based on p(bind) cutoff...")
    filtered_df = filter_by_pbind_cutoff(combined_df, pbind_column=pbind_column, cluster_column=clustering_method+'_cluster', cutoff_method=cutoff_method, std_factor=std_factor)

    print("Evaluating clustering...")
    evaluation_results = evaluate_clustering(pdb_file, filtered_df, cluster_column=clustering_method+'_cluster')

    return {
        'clustered_df': combined_df,
        'filtered_df': filtered_df,
        'evaluation_results': evaluation_results
    }


In [6]:
# Execute the data processing pipeline

# Process all files and get the combined DataFrame
combined_df = process_all_pdb_files(data_path)

# Ensure 'PDBcode' is present in combined_df
if 'PDBcode' not in combined_df.columns:
    raise ValueError("Error: 'PDBcode' column is missing in combined_df. Check preprocessing.")

# Create a dictionary of dataframes grouped by PDBcode
dfs = {pdb: df for pdb, df in combined_df.groupby('PDBcode')}

# Check if extracted dataframes match the original CSV files
for pdb in combined_df['PDBcode'].unique():
    original_csv_path = os.path.join(data_path, f"results_{pdb}.csv")
    
    if not os.path.exists(original_csv_path):
        print(f"Warning: Original CSV file missing for {pdb}. Skipping...")
        continue
    
    original_df = pd.read_csv(original_csv_path)
    
    if pdb in dfs:
        extracted_df = dfs[pdb]
    else:
        print(f"Warning: {pdb} missing in processed data. Reprocessing...")
        extracted_df = process_and_update_pdb_data(data_path, pdb)

    # Check if the number of entries match
    if len(original_df) == len(extracted_df):
        print(f"{pdb}: ✅ Match - {len(original_df)} entries")
    else:
        print(f"{pdb}: ❌ Mismatch - Original: {len(original_df)} entries, Extracted: {len(extracted_df)} entries")
        
        # Attempt to fix by loading updated version
        updated_csv_path = os.path.join(data_path, f"results_{pdb}_updated.csv")
        
        if os.path.exists(updated_csv_path):
            print(f"Fixing process by using updated data for {pdb}...")
            extracted_df = pd.read_csv(updated_csv_path)
            dfs[pdb] = extracted_df  # Update dictionary
            
            if len(original_df) == len(extracted_df):
                print(f"{pdb}: ✅ Fixed - Match after update: {len(original_df)} entries")
            else:
                print(f"{pdb}: ❌ Still Mismatch after update: {len(original_df)} vs {len(extracted_df)} entries")
        else:
            print(f"Error: No updated file found for {pdb}. Manual check required.")

# Define feature and p(bind) column
feature_columns = ['center_of_mass_x', 'center_of_mass_y', 'center_of_mass_z']
pbind_column = 'p(bind)'

# Process features for each PDB entry
for pdb_code in dfs.keys():
    # Normalize spatial features
    dfs[pdb_code], feature_scalers = normalize_features(dfs[pdb_code], feature_columns)
    
    # Apply multiple p(bind) weight factors
    dfs[pdb_code], pbind_scalers = weighted_pbind(dfs[pdb_code], pbind_column, weights=[1, 2, 5, 10, 20, 50, 100])

    # Process protein metadata (diameter, center of mass)
    dfs[pdb_code] = process_protein_data(pdb_code, dfs, data_path)

    # Rename and clean up dataframe
    dfs[pdb_code].rename(columns=lambda x: x.strip(), inplace=True)
    dfs[pdb_code].drop(columns="Unnamed: 0", inplace=True, errors="ignore")  # Avoid errors if column missing

# Print an example output for verification
print("\n✅ Processed Data Example:")
print(dfs[list(dfs.keys())[0]].head())  # Print first PDB's processed dataframe


1qcf: ✅ Match - 450 entries
1ubq: ✅ Match - 76 entries
3g5d: ✅ Match - 513 entries
3zln: ✅ Match - 144 entries
4f9w: ✅ Match - 336 entries
1lyz: ✅ Match - 129 entries
1hvy: ✅ Match - 1152 entries
3hvc: ✅ Match - 327 entries
3cpa: ✅ Match - 307 entries
1pw6: ✅ Match - 250 entries
1ema: ✅ Match - 225 entries
1kv1: ✅ Match - 331 entries
1be9: ✅ Match - 120 entries
3ptb: ✅ Match - 220 entries
6o0k: ✅ Match - 141 entries
2bal: ✅ Match - 338 entries
4pti: ✅ Match - 58 entries
1yer: ✅ Match - 207 entries
1cz2: ✅ Match - 90 entries
3qkd: ✅ Match - 282 entries
1ao6: ✅ Match - 1156 entries
1h61: ✅ Match - 364 entries
1ny3: ✅ Match - 277 entries

✅ Processed Data Example:
     chain  resi resn   p(bind)  \
8133     A     5    S  0.008225   
8134     A     6    E  0.013932   
8135     A     7    V  0.044013   
8136     A     8    A  0.005330   
8137     A     9    H  0.007117   

                                       resn_coordinates  center_of_mass_x  \
8133  [('N', 56.653, 51.017, 34.141), ('C'

In [119]:
# Median clustering function

import re

from sqlalchemy import true


def silhouette_analysis(data, max_clusters=10):
    """
    Perform silhouette analysis for a range of clusters and suggest the optimal n_clusters.

    Args:
        data (array-like): Dataset to cluster.
        max_clusters (int): Maximum number of clusters to evaluate.

    Returns:
        tuple: Suggested n_clusters, silhouette scores for each k.
    """
    silhouette_scores = []
    
    for n_clusters in range(2, max_clusters + 1):  # Silhouette starts from k=2
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = kmeans.fit_predict(data)
        silhouette_scores.append(silhouette_score(data, labels))

    suggested_k = np.argmax(silhouette_scores) + 2  # +2 because we start at k=2
    print(f"Suggested number of clusters (Silhouette): {suggested_k}")
    print("Silhouette Scores:", silhouette_scores)
    
    return suggested_k, silhouette_scores


def kmeans_clustering(df, feature_columns, n_clusters):
    """
    Perform KMeans clustering and add cluster labels to the DataFrame.

    Args:
        df (pd.DataFrame): Input DataFrame.
        feature_columns (list): List of feature column names.
        n_clusters (int): Number of clusters.

    Returns:
        pd.DataFrame: Updated DataFrame with cluster labels.
        KMeans: Fitted KMeans model.
    """
    features = df[feature_columns].values
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    df['kmeans_cluster'] = kmeans.fit_predict(features)
    return df, kmeans


def normalize_within_clusters(df, cluster_column, pbind_column):
    """
    Normalize p(bind) values within each cluster.

    Args:
        df (pd.DataFrame): Input DataFrame.
        cluster_column (str): Column containing cluster IDs.
        pbind_column (str): Column containing p(bind) values.

    Returns:
        pd.DataFrame: Updated DataFrame with normalized p(bind) values.
    """
    df[f"normalized_{pbind_column}"] = df.groupby(cluster_column)[pbind_column].transform(
        lambda x: (x - x.min()) / (x.max() - x.min())
    )
    return df


def filter_by_pbind_cutoff(df, pbind_column, cluster_column, cutoff_method, std_factor):
    """
    Filter residues based on p(bind) cutoff.

    Args:
        df (pd.DataFrame): Input DataFrame.
        pbind_column (str): Column containing p(bind) values.
        cluster_column (str): Column containing cluster IDs.
        cutoff_method (str): Method for calculating cutoff ('median_std' or 'percentile').
        std_factor (float): Factor for standard deviation in cutoff calculation.

    Returns:
        pd.DataFrame: Filtered DataFrame.
    """
    filtered_df = pd.DataFrame()

    for cluster_id in df[cluster_column].unique():
        cluster_data = df[df[cluster_column] == cluster_id]
        if cutoff_method == 'median_std':
            median_value = cluster_data[pbind_column].median()
            std_value = cluster_data[pbind_column].std()
            cutoff = median_value + std_factor * std_value
        elif cutoff_method == 'percentile':
            cutoff = cluster_data[pbind_column].quantile(0.95)
        else:
            raise ValueError("Invalid cutoff method. Use 'median_std' or 'percentile'.")

        filtered_cluster = cluster_data[cluster_data[pbind_column] > cutoff]
        filtered_df = pd.concat([filtered_df, filtered_cluster])

    return filtered_df.reset_index(drop=True)


def compute_centroids(df, cluster_column):
                # Compute the cluster centers and round to 3 decimal places
                centroids = df.groupby(cluster_column)[['p(bind)', 'center_of_mass_x', 'center_of_mass_y', 'center_of_mass_z']].mean().round(3)
                centroids.sort_values(by='p(bind)')
                
                return centroids


def med_clustering(df, std_factors, feature_columns, pbind_column, pbind_weight, ligand_centers, output_csv="results_med_clustering.csv"):
    """ Perform clustering to identify residues with high binding probabilities. """
    
    # Initialize list to store all results
    results = []
    
    # 🔹 Step 1:Compute median and std for filtering
    median_value = df[pbind_column].median()
    std_value = df[pbind_column].std()
    
    cutoff = median_value + std_factors * std_value
    df_filtered = df.loc[df[pbind_column] > cutoff]
    
    if df_filtered.shape[0] < 3:
            print(f"⚠️ Skipping std_factor={std_filter} as only {df_filtered.shape[0]} entries remain. Setting distance to inf.")
            
    feature_columns = feature_columns + [pbind_column]*pbind_weight
    
    # 🔹 Step 2: Determine optimal number of clusters safely
    optimal_kmeans, _ = silhouette_analysis(df_filtered[feature_columns].values, max_clusters=min(10, len(df_filtered) - 1))
    
    # 🔹 Step 3: Run Clustering with Optimal Clusters
    df_kmeans, _ = kmeans_clustering(df_filtered, feature_columns, n_clusters=optimal_kmeans)
    
    # Store cluster labels
    #df_filtered[f'kmeans_pbind'] = df_kmeans['kmeans_cluster']
    
    # 🔹 Step 4: Compute Cluster Center
    centroids = compute_centroids(df_kmeans, 'kmeans_cluster')
    
    # Compute residue IDs for each cluster: pls 
    df_kmeans['residue_id'] = df_kmeans['chain'] + '_' + df_kmeans['resi'].astype(str)
    centroids["residue_ids"] = df_kmeans.groupby('kmeans_cluster')['residue_id'].apply(list)
    
    # make in this format w/o , and []: A_105 A_106 A_107
    centroids["residue_ids"] = centroids["residue_ids"].apply(lambda x: ' '.join(x))
    
    # Rename columns
    centroids['rank'] = centroids['p(bind)'].rank(ascending=False).astype(int)
    centroids['name'] = "pocket" + centroids['rank'].astype(str)
    centroids['name'] = centroids['name'].apply(lambda x: re.sub(r'\.0', '', x))
    
    # Select and rename columns
    centroids = centroids[['name', 'rank', 'p(bind)', 'center_of_mass_x', 'center_of_mass_y', 'center_of_mass_z', 'residue_ids']]
    centroids.rename(columns={'p(bind)': 'probability', 'center_of_mass_x': 'center_x', 'center_of_mass_y': 'center_y', 'center_of_mass_z': 'center_z'}, inplace=True)
    
    # Sort by binding probability
    centroids.sort_values(by='probability', ascending=False, inplace=True)
    
    # Save to CSV
    centroids.to_csv(output_csv, index=False)
    
    return df_kmeans, centroids


    
    

In [120]:
df = dfs['3hvc']
pdb_file =f"{data_path}/3hvc.pdb"
ligand_center = calculate_ligand_centers(pdb_file)
feature_columns = ['center_of_mass_x_normalized', 'center_of_mass_y_normalized', 'center_of_mass_z_normalized']
pbind_column = 'p(bind)_weight_2'


df_kmean, centroid = med_clustering(df, 2.5, feature_columns, pbind_column, 100, ligand_center, output_csv="results_med_clustering.csv")


Suggested number of clusters (Silhouette): 2
Silhouette Scores: [0.6515612511949008, 0.5638726791131474, 0.5174007395983826, 0.6065039468262533, 0.6279202847821913, 0.5535207186224813, 0.4798482430849183, 0.42068758246781646, 0.34881737959423226]


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['kmeans_cluster'] = kmeans.fit_predict(features)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_kmeans['residue_id'] = df_kmeans['chain'] + '_' + df_kmeans['resi'].astype(str)


In [121]:
centroid

Unnamed: 0_level_0,name,rank,probability,center_x,center_y,center_z,residue_ids
kmeans_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,pocket1,1,0.84,-3.75,35.536,20.913,A_36 A_38 A_51 A_53 A_71 A_84 A_106 A_167 A_168
0,pocket2,2,0.577,0.726,32.606,20.853,A_31 A_67 A_104 A_109 A_149 A_150 A_155


In [73]:
# Compute ligand centers and diameters
ligand_centers = calculate_ligand_centers(pdb_file)
ligand_diameters = calculate_ligand_diameter(pdb_file)

# Compute grid sizes
grid_sizes = {lig_id: calculate_grid_size_ligand(diameter) for lig_id, diameter in ligand_diameters.items()}

# Print results
print("Ligand Centers:", ligand_centers)
print("Ligand Diameters:", ligand_diameters)
print("Computed Grid Sizes:", grid_sizes)

Ligand Centers: {'A_GG5_361': (-3.732, 37.54, 19.149), 'A_GG5_362': (24.619, 22.289, 26.946)}
Ligand Diameters: {'A_GG5_361': 8.652, 'A_GG5_362': 8.632}
Computed Grid Sizes: {'A_GG5_361': 22.922, 'A_GG5_362': 22.906}


In [74]:
define_bounding_box_ligand(ligand_centers, ligand_diameters)

Bounding box for A_GG5_361: Min=[-15.193000000000001, 26.079, 7.688000000000001], Max=[7.729, 49.001, 30.61]
Bounding box for A_GG5_362: Min=[13.166, 10.836000000000002, 15.493000000000002], Max=[36.072, 33.742000000000004, 38.399]


{'A_GG5_361': {'min': [-15.193000000000001, 26.079, 7.688000000000001],
  'max': [7.729, 49.001, 30.61]},
 'A_GG5_362': {'min': [13.166, 10.836000000000002, 15.493000000000002],
  'max': [36.072, 33.742000000000004, 38.399]}}

In [75]:
df_kmean

Unnamed: 0,chain,resi,resn,p(bind),resn_coordinates,center_of_mass_x,center_of_mass_y,center_of_mass_z,PDBcode,center_of_mass_x_normalized,...,p(bind)_weight_10,p(bind)_weight_20,p(bind)_weight_50,p(bind)_weight_100,protein_x,protein_y,protein_z,protein_diameter,kmeans_cluster,residue_id
2826,A,31,G,0.55723,"[('N', -5.827, 34.477, 11.096), ('C', -4.911, ...",-5.722,33.118,11.88,3hvc,0.32318,...,5.733093,11.466186,28.665466,57.330931,6.706,34.061,23.333,66.946,0,A_31
2828,A,36,G,0.766697,"[('N', -5.954, 27.711, 18.58), ('C', -7.293, 2...",-6.642,28.671,17.527,3hvc,0.308704,...,7.890504,15.781008,39.452519,78.905038,6.706,34.061,23.333,66.946,1,A_36
2830,A,38,V,0.881284,"[('N', -8.068, 33.133, 15.253), ('C', -8.043, ...",-7.559,34.871,15.162,3hvc,0.294276,...,9.070701,18.141403,45.353507,90.707014,6.706,34.061,23.333,66.946,1,A_38
2843,A,51,A,0.852709,"[('N', -8.065, 41.966, 18.286), ('C', -7.896, ...",-7.906,40.462,18.581,3hvc,0.288816,...,8.776386,17.552772,43.88193,87.763861,6.706,34.061,23.333,66.946,1,A_51
2845,A,53,K,0.971511,"[('N', -9.336, 35.82, 19.8), ('C', -8.89, 34.4...",-7.546,33.385,20.729,3hvc,0.29448,...,10.0,20.0,50.0,100.0,6.706,34.061,23.333,66.946,1,A_53
2859,A,67,R,0.642798,"[('N', -3.827, 23.684, 30.012), ('C', -3.212, ...",-3.542,24.102,26.564,3hvc,0.35748,...,6.614405,13.228809,33.072023,66.144045,6.706,34.061,23.333,66.946,0,A_67
2863,A,71,E,0.76848,"[('N', -2.792, 29.609, 29.857), ('C', -2.793, ...",-3.364,31.209,27.314,3hvc,0.360281,...,7.908875,15.817749,39.544373,79.088745,6.706,34.061,23.333,66.946,1,A_71
2876,A,84,I,0.838801,"[('N', 1.789, 40.957, 24.754), ('C', 0.362, 40...",0.202,40.264,24.265,3hvc,0.416389,...,8.633139,17.266278,43.165696,86.331392,6.706,34.061,23.333,66.946,1,A_84
2896,A,104,L,0.496336,"[('N', -10.454, 35.194, 23.855), ('C', -9.481,...",-8.818,35.524,24.353,3hvc,0.274467,...,5.105914,10.211828,25.529571,51.059142,6.706,34.061,23.333,66.946,0,A_104
2898,A,106,T,0.792738,"[('N', -7.422, 41.132, 22.49), ('C', -6.258, 4...",-5.757,41.891,22.039,3hvc,0.322629,...,8.158719,16.317437,40.793593,81.587185,6.706,34.061,23.333,66.946,1,A_106


In [83]:

select_cluster_define_bounding_box(df_kmean, 1, 'kmeans_cluster', selection_name="cluster_selection")


✅ Selection 'cluster_selection' created for cluster 1
📦 Bounding box: Min=[-7.906, 28.671, 15.162], Max=[2.606, 41.891, 27.314]


{'min': array([-7.906, 28.671, 15.162]),
 'max': array([ 2.606, 41.891, 27.314]),
 'size': 20.807249890362733}

In [25]:
df_kmean.head()

Unnamed: 0,chain,resi,resn,p(bind),resn_coordinates,center_of_mass_x,center_of_mass_y,center_of_mass_z,PDBcode,center_of_mass_x_normalized,...,p(bind)_weight_5,p(bind)_weight_10,p(bind)_weight_20,p(bind)_weight_50,p(bind)_weight_100,protein_x,protein_y,protein_z,protein_diameter,kmeans_cluster
2826,A,31,G,0.55723,"[('N', -5.827, 34.477, 11.096), ('C', -4.911, ...",-5.722,33.118,11.88,3hvc,0.32318,...,2.866547,5.733093,11.466186,28.665466,57.330931,6.706,34.061,23.333,66.946,0
2828,A,36,G,0.766697,"[('N', -5.954, 27.711, 18.58), ('C', -7.293, 2...",-6.642,28.671,17.527,3hvc,0.308704,...,3.945252,7.890504,15.781008,39.452519,78.905038,6.706,34.061,23.333,66.946,1
2830,A,38,V,0.881284,"[('N', -8.068, 33.133, 15.253), ('C', -8.043, ...",-7.559,34.871,15.162,3hvc,0.294276,...,4.535351,9.070701,18.141403,45.353507,90.707014,6.706,34.061,23.333,66.946,1
2843,A,51,A,0.852709,"[('N', -8.065, 41.966, 18.286), ('C', -7.896, ...",-7.906,40.462,18.581,3hvc,0.288816,...,4.388193,8.776386,17.552772,43.88193,87.763861,6.706,34.061,23.333,66.946,1
2845,A,53,K,0.971511,"[('N', -9.336, 35.82, 19.8), ('C', -8.89, 34.4...",-7.546,33.385,20.729,3hvc,0.29448,...,5.0,10.0,20.0,50.0,100.0,6.706,34.061,23.333,66.946,1


In [87]:
centroid

Unnamed: 0_level_0,name,binding_probability,center_x,center_y,center_z,residue_ids
kmeans_cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,pocket1,0.84,-3.75,35.536,20.913,A_36 A_38 A_51 A_53 A_71 A_84 A_106 A_167 A_168
0,pocket2,0.577,0.726,32.606,20.853,A_31 A_67 A_104 A_109 A_149 A_150 A_155


In [None]:
def select_cluster_define_bounding_box(df, cluster_id, cluster_column, selection_name="cluster_selection"):
    """
    Select residues based on the clustering method and create a bounding box in PyMOL.

    Args:
        df (pd.DataFrame): DataFrame containing residue data.
        cluster_id (int): The cluster ID to select residues for.
        cluster_column (str): Column name that stores cluster assignments.
        selection_name (str): Name of the selection in PyMOL.

    Returns:
        dict: Bounding box coordinates { "min": np.array, "max": np.array } for the selected cluster.
    """
    # Ensure required columns exist
    required_columns = {"chain", "resi", "center_of_mass_x", "center_of_mass_y", "center_of_mass_z"}
    missing_columns = required_columns - set(df.columns)
    if missing_columns:
        raise ValueError(f"Missing columns in DataFrame: {missing_columns}")

    # Filter residues based on the selected cluster
    selected_residues = df[df[cluster_column] == cluster_id]
    
    if selected_residues.empty:
        print(f"⚠️ No residues found for cluster {cluster_id} in column '{cluster_column}'")
        return None

    # Construct a PyMOL selection string
    selection_string = " or ".join(
        [f"(chain {row['chain']} and resi {row['resi']})" for _, row in selected_residues.iterrows()]
    )
    print(f"string: {selection_string}")

    # Check if PyMOL is available
    try:
        cmd.select(selection_name, selection_string)
    except Exception as e:
        print(f"⚠️ PyMOL selection failed: {e}")
        return None

    # Compute the bounding box
    min_coords = selected_residues[['center_of_mass_x', 'center_of_mass_y', 'center_of_mass_z']].min().to_numpy()
    max_coords = selected_residues[['center_of_mass_x', 'center_of_mass_y', 'center_of_mass_z']].max().to_numpy()

    # Convert coordinates to float for compatibility
    min_coords = np.array([float(coord) for coord in min_coords])
    max_coords = np.array([float(coord) for coord in max_coords])

    # Add pseudoatoms for visualization
    try:
        cmd.pseudoatom(f"{selection_name}_box_min", pos=min_coords.tolist(), color="blue")
        cmd.pseudoatom(f"{selection_name}_box_max", pos=max_coords.tolist(), color="red")
    except Exception as e:
        print(f"⚠️ Failed to create pseudoatoms in PyMOL: {e}")

    print(f"✅ Selection '{selection_name}' created for cluster {cluster_id}")
    print(f"📦 Bounding box: Min={min_coords.tolist()}, Max={max_coords.tolist()}")
    
    #euclidean distance
    size = np.linalg.norm(max_coords - min_coords)

    return {"min": min_coords, "max": max_coords, "size": size}

def calculate_size_selected_residue(df, cluster_name, selection_name="cluster_selection"):
    """
    Calculate the size of the selected residues based on the clustering method.

    Args:
        df (pd.DataFrame): DataFrame containing residue data.
        cluster_name (str): The cluster name to select residues for.
        selection_name (str): Name of the selection in PyMOL.

    Returns:
        dict: Bounding box coordinates { "min": np.array, "max": np.array, "size": float } for the selected cluster.
    """
    # Ensure required columns exist
    required_columns = {"name", "residue_ids", "center_x", "center_y", "center_z"}
    missing_columns = required_columns - set(df.columns)
    if missing_columns:
        raise ValueError(f"Missing columns in DataFrame: {missing_columns}")

    # Filter residues based on the selected cluster
    selected_residues = df[df['name'] == cluster_name]

    if selected_residues.empty:
        print(f"⚠️ No residues found for cluster {cluster_name}")
        return None

    # Construct a PyMOL selection string
    selection_string = " or ".join(
        [f"(chain {resi.split('_')[0]} and resi {resi.split('_')[1]})" for resi in selected_residues['residue_ids']]
    )
    print(f"Selection string: {selection_string}")

   # Check if PyMOL is available
    try:
        cmd.select(selection_name, selection_string)
    except Exception as e:
        print(f"⚠️ PyMOL selection failed: {e}")
        return None

    # Compute the bounding box
    min_coords = selected_residues[['center_x', 'center_y', 'center_z']].min().to_numpy()
    max_coords = selected_residues[['center_x', 'center_y', 'center_z']].max().to_numpy()

    # Convert coordinates to float for compatibility
    min_coords = np.array([float(coord) for coord in min_coords])
    max_coords = np.array([float(coord) for coord in max_coords])

    # Add pseudoatoms for visualization
    try:
        cmd.pseudoatom(f"{selection_name}_box_min", pos=min_coords.tolist(), color="blue")
        cmd.pseudoatom(f"{selection_name}_box_max", pos=max_coords.tolist(), color="red")
    except Exception as e:
        print(f"⚠️ Failed to create pseudoatoms in PyMOL: {e}")

    print(f"✅ Selection '{selection_name}' created for cluster {cluster_id}")
    print(f"📦 Bounding box: Min={min_coords.tolist()}, Max={max_coords.tolist()}")
    
    #euclidean distance
    size = np.linalg.norm(max_coords - min_coords)

    return {"min": min_coords, "max": max_coords, "size": size}


select_cluster_define_bounding_box(df_kmean, 1, 'kmeans_cluster', selection_name="cluster_selection")


string: (chain A and resi 36) or (chain A and resi 38) or (chain A and resi 51) or (chain A and resi 53) or (chain A and resi 71) or (chain A and resi 84) or (chain A and resi 106) or (chain A and resi 167) or (chain A and resi 168)
✅ Selection 'cluster_selection' created for cluster 1
📦 Bounding box: Min=[-7.906, 28.671, 15.162], Max=[2.606, 41.891, 27.314]


{'min': array([-7.906, 28.671, 15.162]),
 'max': array([ 2.606, 41.891, 27.314]),
 'size': 20.807249890362733}

In [None]:
def calculate_size_selected_residue(df, cluster_name, selection_name="cluster_selection"):
    """
    Calculate the size of the selected residues based on the clustering method.

    Args:
        df (pd.DataFrame): DataFrame containing residue data.
        cluster_name (str): The cluster name to select residues for.
        selection_name (str): Name of the selection in PyMOL.

    Returns:
        dict: Bounding box coordinates { "min": np.array, "max": np.array, "size": float } for the selected cluster.
    """
    # Ensure required columns exist
    required_columns = {"name", "residue_ids", "center_x", "center_y", "center_z"}
    missing_columns = required_columns - set(df.columns)
    if missing_columns:
        raise ValueError(f"Missing columns in DataFrame: {missing_columns}")

    # Filter residues based on the selected cluster
    selected_residues = df[df['name'] == cluster_name]

    if selected_residues.empty:
        print(f"⚠️ No residues found for cluster {cluster_name}")
        return None

    # Flatten 'residue_ids' column and construct a PyMOL selection string
    residue_list = selected_residues['residue_ids'].tolist()
    
    # Split each entry to extract chain and residue number
    selection_string = " or ".join(
        [f"(chain {res.split('_')[0]} and resi {res.split('_')[1]})" for res_list in residue_list for res in res_list.split()]
    )
    
    print(f"Selection string: {selection_string}")

   # Check if PyMOL is available
    try:
        cmd.select(selection_name, selection_string)
    except Exception as e:
        print(f"⚠️ PyMOL selection failed: {e}")
        return None

    # Compute the bounding box
    min_coords = selection_string[['center_x', 'center_y', 'center_z']].min().to_numpy()
    max_coords = selection_string[['center_x', 'center_y', 'center_z']].max().to_numpy()

    # Convert coordinates to float for compatibility
    min_coords = np.array([float(coord) for coord in min_coords])
    max_coords = np.array([float(coord) for coord in max_coords])

    # Add pseudoatoms for visualization
    try:
        cmd.pseudoatom(f"{selection_name}_box_min", pos=min_coords.tolist(), color="blue")
        cmd.pseudoatom(f"{selection_name}_box_max", pos=max_coords.tolist(), color="red")
    except Exception as e:
        print(f"⚠️ Failed to create pseudoatoms in PyMOL: {e}")

    #print(f"✅ Selection '{selection_name}' created for cluster {cluster_id}")
    print(f"📦 Bounding box: Min={min_coords.tolist()}, Max={max_coords.tolist()}")
    
    #euclidean distance
    size = np.linalg.norm(max_coords - min_coords)

    return {"min": min_coords, "max": max_coords, "size": size}


name     ,  rank,   score, probability, sas_points, surf_atoms,   center_x,   center_y,   center_z, residue_ids, surf_atom_ids
pocket1  ,     1,    8.80,       0.471,         87,         46,    -2.8817,    37.6926,    18.5196, A_104 A_105 A_106 A_107 A_108 A_109 A_110 A_111 A_112 A_154 A_157 A_167 A_168 A_169 A_30 A_32 A_38 A_51 A_53 A_71 A_75 A_84 A_86, 231 239 256 257 258 351 363 365 366 367 524 557 558 559 634 636 637 649 794 795 796 802 807 811 812 817 830 832 835 836 837 843 846 847 853 856 1170 1193 1272 1273 1274 1275 1280 1281 1291 1292
pocket2  ,     2,    2.70,       0.081,         37,         31,    24.1331,    21.1997,    28.0418, A_192 A_195 A_196 A_197 A_246 A_249 A_250 A_251 A_252 A_255 A_291 A_293,