In [2]:
# @markdown # Step 1: Install all necessary packages
#Download pymol in colabsystem
from IPython.utils import io
import tqdm.notebook
import os
"""The PyMOL installation is done inside two nested context managers. This approach
was inspired by Dr. Christopher Schlicksup's (of the Phenix group at
Lawrence Berkeley National Laboratory) method for installing cctbx
in a Colab Notebook. He presented his work on September 1, 2021 at the IUCr
Crystallographic Computing School. I adapted Chris's approach here. It replaces my first approach
that requires seven steps. My approach was presented at the SciPy2021 conference
in July 2021 and published in the
[proceedings](http://conference.scipy.org/proceedings/scipy2021/blaine_mooers.html).
The new approach is easier for beginners to use. The old approach is easier to debug
and could be used as a back-up approach.

"""
total = 100
with tqdm.notebook.tqdm(total=total) as pbar:
    with io.capture_output() as captured:

        !pip install -q condacolab
        import condacolab
        condacolab.install()
        pbar.update(10)

        import sys
        sys.path.append('/usr/local/lib/python3.7/site-packages/')
        pbar.update(20)

        # Install PyMOL
        %shell mamba install -c schrodinger pymol-bundle --yes

        pbar.update(90)

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [1]:
# @markdown # Run prescreening & Extract Coordinates
# pip install scipy pytz pandas openpyxl BioPython
import os
import pandas as pd
import numpy as np
from Bio.PDB import PDBParser # For PDB parsing
from Bio.PDB.Residue import Residue # For type hinting if needed
from Bio.PDB.Chain import Chain # For type hinting if needed
import itertools
import requests
import traceback
from scipy.spatial import KDTree
import datetime
import pytz
from typing import List, Optional, Tuple # For type hinting

# --- Configuration ---
# The `# @param` / `# @markdown` annotations below drive the Colab form UI;
# they must stay on the same lines as the assignments they describe.
# Specify the SINGLE PDB file path
Target_pdb_file = "/content/1EP0_alanine_dimer.pdb" # @param {type:"string"}

# Specify Specific Residue Number
# @markdown ### Specify a residue number (filters for triads containing AT LEAST ONE residue with this number). Use 0 for no specific number filtering.
specific_residue_number = 0  # @param {type:"integer"}

# --- Specify Exact Output Excel File Path ---
Output_excel_file_path = "/content/1EP0_full_Coords.xlsx"# @param {type:"string"}

# --- Determine Output Directory from Output File Path ---
# The directory component of the output path is created if it does not exist.
# A bare filename (empty directory component) falls back to the current directory.
output_directory = os.path.dirname(Output_excel_file_path)
if output_directory:
    os.makedirs(output_directory, exist_ok=True)
    print(f"➡️ Output directory: {output_directory}")
else:
    output_directory = "."
    print(f"➡️ Output directory: Current directory")

print(f"➡️ Output Excel file with coordinates will be saved as: {Output_excel_file_path}")

# 0 means "no residue-number filter"; any other value triggers a reminder.
if specific_residue_number != 0:
    print(f"⚠️ WARNING: Filtering by residue number '{specific_residue_number}'.")


# --- Download threshold configuration ---
# The thresholds workbook is fetched from GitHub and saved next to the output
# Excel file. The URL is built from the metal, residue set and range chosen in
# the Colab form below.
base_url = "https://raw.githubusercontent.com/SNU-Songlab/Metal-Installer-code/main/Threshold"
Metal = 'Cu'  # @param ["Zn", "Mn", "Cu", "Fe"]
Threshold_Download_Set = '3His'  # @param ["3His", "2His_1Asp", "2His_1Glu", "2His_1Cys"]
Range = '3'  # @param ["1", "2", "3", "4","5"]
thresholds_url = f"{base_url}/{Metal}/{Threshold_Download_Set}/{Range}.xlsx"
thresholds_file = os.path.join(output_directory, f"thresholds_{Metal}_{Threshold_Download_Set}_R{Range}.xlsx")

print(f"⬇️ Downloading threshold set '{Threshold_Download_Set}' for '{Metal}' (Range {Range}) from: {thresholds_url}")
print(f"⚠️ NOTE: These thresholds will be applied to any found triad (regardless of residue type).")

# Without a timeout, requests waits forever on an unresponsive server and the
# cell hangs; 30 s is generous for a small spreadsheet.
response = requests.get(thresholds_url, timeout=30)
if response.status_code == 200:
    with open(thresholds_file, "wb") as file:
        file.write(response.content)
    print(f"✅ Thresholds downloaded successfully to {thresholds_file}")
else:
    # Any non-200 status aborts the cell; later steps need this file.
    raise ValueError(f"❌ Failed to download thresholds from {thresholds_url}. Status code: {response.status_code}")

# --- Load thresholds ---
print("⚙️ Loading thresholds...")
try:
    thresholds_df = pd.read_excel(thresholds_file, sheet_name="Sheet1")
except FileNotFoundError:
    print(f"❌ Error: Thresholds file not found at {thresholds_file}. Cannot load thresholds.")
    raise  # Nothing downstream can work without the thresholds

# Map each parameter name to its (Min, Max) pair, skipping rows where either
# bound is missing.
thresholds = {
    entry["Parameter"]: (entry["Min"], entry["Max"])
    for _, entry in thresholds_df.iterrows()
    if not (pd.isna(entry["Min"]) or pd.isna(entry["Max"]))
}

# The geometry checks need all four ranges, so abort if any is absent.
required_thresholds = ["alpha_distance_range", "beta_distance_range", "ratio_threshold_range", "pie_threshold_range"]
if any(key not in thresholds for key in required_thresholds):
    missing = [key for key in required_thresholds if key not in thresholds]
    raise ValueError(f"❌ Missing required thresholds in the downloaded file: {missing}")

# Expose each range under its own module-level name (same order as the list above).
alpha_distance_range, beta_distance_range, ratio_threshold_range, pie_threshold_range = (
    thresholds.get(name) for name in required_thresholds
)


print("📊 Thresholds loaded:")
# Report every loaded parameter, or warn if one of the four ranges is missing.
if any(t is None for t in [alpha_distance_range, beta_distance_range, ratio_threshold_range, pie_threshold_range]):
    print("   ⚠️ Warning: Could not load all required thresholds.")
else:
    for key, value in thresholds.items():
        print(f"   - {key}: Min={value[0]}, Max={value[1]}")


# --- Helper Functions ---
def calculate_pie(v1, v2):
    """Return the angle between vectors ``v1`` and ``v2`` in degrees.

    Returns NaN when either vector has zero length, since the angle is then
    undefined.
    """
    magnitude_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    if magnitude_product == 0:
        return np.nan
    # Rounding can push the cosine slightly outside [-1, 1]; clamp it so
    # arccos stays within its domain.
    cosine = np.clip(np.dot(v1, v2) / magnitude_product, -1.0, 1.0)
    return np.degrees(np.arccos(cosine))

# --- NEW Coordinate Extraction Helper ---
def extract_coordinates(chain: Optional[Chain], res_id: int, atom_name: str) -> List[Optional[float]]:
    """Look up one atom's coordinates, returning NaNs instead of raising.

    Args:
        chain: Chain to search, or None.
        res_id: Residue sequence number. The standard key ``(' ', res_id, ' ')``
            is tried first, then the bare number.
        atom_name: Name of the atom inside the residue (e.g. ``'CA'``).

    Returns:
        ``[x, y, z]`` as floats, or ``[nan, nan, nan]`` when the chain is None,
        the residue or atom is absent, the coordinate does not have three
        components, or any exception occurs during the lookup.
    """
    not_found = [np.nan, np.nan, np.nan]

    if chain is None:
        return not_found

    try:
        # Prefer the full key (blank hetero flag, blank insertion code);
        # fall back to the plain residue number.
        standard_key = (' ', res_id, ' ')
        if standard_key in chain:
            residue = chain[standard_key]
        elif res_id in chain:
            residue = chain[res_id]
        else:
            return not_found

        if atom_name not in residue:
            return not_found

        xyz = residue[atom_name].coord
        if xyz is None or len(xyz) != 3:
            return not_found
        return [float(component) for component in xyz]

    except Exception:
        # Deliberately best-effort: any lookup failure yields NaNs.
        return not_found


# --- Main Processing Function ---
def process_pdb_file(pdb_file_path, output_excel_path):
    """
    Screen one PDB file for residue triads and write the survivors to Excel.

    Steps, in order:
      1. Parse the file and take the first model.
      2. Collect residues whose hetero flag is " " and that have a CA atom.
      3. Use a KDTree on the CA coordinates to enumerate triads whose three CA
         atoms all lie within 1.1 x the upper alpha-distance threshold of
         one another.
      4. Optionally keep only triads containing at least one residue numbered
         ``specific_residue_number`` (0 disables this filter).
      5. Keep triads whose three CA-CA and three CB-CB distances fall inside
         ``alpha_distance_range`` / ``beta_distance_range`` (bounds inclusive).
      6. Keep triads whose three CA/CB distance ratios fall inside
         ``ratio_threshold_range`` (bounds inclusive).
      7. Keep triads whose three "pie" angles (the angle between the CA->CA
         and CB->CB vectors of each residue pair) lie strictly inside
         ``pie_threshold_range``.
      8. Append the CA and CB coordinates of every surviving triad and save
         the table to a single sheet of ``output_excel_path``.

    REDUNDANCY REMOVAL IS DISABLED.

    Relies on module-level names defined earlier in the notebook:
    ``specific_residue_number``, ``alpha_distance_range``,
    ``beta_distance_range``, ``ratio_threshold_range`` and
    ``pie_threshold_range``.

    Args:
        pdb_file_path: Path of the PDB file to parse.
        output_excel_path: Path of the Excel workbook to write.

    Returns:
        None. Progress and errors are printed; exceptions raised during
        processing are caught and reported rather than propagated. No file is
        written when no triad passes all filters.
    """
    pdb_name = os.path.basename(pdb_file_path)
    print(f"🔄 Processing: {pdb_name} (Number Filter: {specific_residue_number if specific_residue_number != 0 else 'OFF'}, No Redundancy Removal)")

    # --- Check if thresholds loaded ---
    if not all(t is not None for t in [alpha_distance_range, beta_distance_range, ratio_threshold_range, pie_threshold_range]):
         print(f"   ❌ Cannot proceed: Required geometric thresholds were not loaded successfully.")
         return # Stop processing this file if thresholds are missing

    try:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", pdb_file_path)
        # Ensure only one model is processed if multiple exist
        if len(structure) > 1:
            print(f"   ⚠️ Warning: Multiple models found in {pdb_name}. Using only the first model (ID: {structure[0].id}).")
        model = structure[0]

        # --- Create Chain Lookup ---
        # Store chains in a dictionary for faster access
        chains_dict = {chain.id: chain for chain in model}
        if not chains_dict:
             print(f"   ❌ Error: No chains found in model {model.id} of {pdb_name}.")
             return

        # Keep only residues whose hetero flag (first element of the id) is blank.
        all_residues_full = [res for chain in chains_dict.values() for res in chain if res.get_id()[0] == " "]
        print(f"   Found {len(all_residues_full)} standard residues.")

        # --- KDTree Pre-filtering ---
        residues_for_tree = [res for res in all_residues_full if res.has_id("CA")]
        if len(residues_for_tree) < 3:
            print(f"   ⚠️ Skipping {pdb_name}: Not enough residues (<3) with CA atoms.")
            return
        coords_ca = np.array([res["CA"].coord for res in residues_for_tree])
        residue_map = residues_for_tree # Map KDTree index back to Residue object
        kdtree = KDTree(coords_ca)
        # Search radius: 10% larger than the upper alpha-distance threshold.
        max_dist = alpha_distance_range[1] * 1.1
        pairs = kdtree.query_pairs(r=max_dist)
        potential_triad_indices = set()
        # For each close pair (i, j), any third residue k close to both forms a triad.
        for i, j in pairs:
            indices_k_near_i = kdtree.query_ball_point(coords_ca[i], r=max_dist)
            indices_k_near_j = kdtree.query_ball_point(coords_ca[j], r=max_dist)
            common_neighbors = set(indices_k_near_i).intersection(indices_k_near_j)
            for k in common_neighbors:
                if k != i and k != j:
                    # Sorting makes the tuple order-independent so the set de-duplicates triads.
                    triad_indices = tuple(sorted((i, j, k)))
                    potential_triad_indices.add(triad_indices)
        print(f"   Generated {len(potential_triad_indices)} unique potential spatial triads.")

        # --- Map Indices to Residues ---
        all_spatial_triads = []
        for idx_i, idx_j, idx_k in potential_triad_indices:
             # Ensure indices are within bounds of residue_map
             if all(idx < len(residue_map) for idx in [idx_i, idx_j, idx_k]):
                 comb = (residue_map[idx_i], residue_map[idx_j], residue_map[idx_k])
                 all_spatial_triads.append(comb)
             else:
                  print(f"   ⚠️ Warning: KDTree index out of bounds. Skipping triad indices {(idx_i, idx_j, idx_k)}.")
        print(f"   Mapped {len(all_spatial_triads)} spatial triads.")


        # --- Apply Residue NUMBER Filter ---
        triads_meeting_number_criteria = []
        if specific_residue_number == 0:
            print(f"   Specific Residue Number Filter: OFF")
            triads_meeting_number_criteria = all_spatial_triads
        else:
            print(f"   Filtering {len(all_spatial_triads)} triads: At least one residue must have number {specific_residue_number}")
            for comb in all_spatial_triads:
                # Only the sequence number is compared; the chain is not considered here.
                if any(res.get_id()[1] == specific_residue_number for res in comb):
                    triads_meeting_number_criteria.append(comb)
            print(f"   {len(triads_meeting_number_criteria)} triads have at least one residue with number {specific_residue_number}.")

        # --- Final list for geometric checks ---
        final_triads_to_process = triads_meeting_number_criteria
        print(f"   Proceeding to detailed geometric checks for {len(final_triads_to_process)} triads.")

        # --- Detailed Geometric Filtering & Initial DataFrame Creation ---
        def get_triad_type(comb):
            # "intra" when all three residues share one chain id, otherwise "inter".
            chains = [res.get_full_id()[2] for res in comb]
            return "intra" if len(set(chains)) == 1 else "inter"

        results = []
        for comb in final_triads_to_process:
            try:
                # Ensure CA and CB atoms exist for distance calculations
                if not all(res.has_id("CA") and res.has_id("CB") for res in comb): continue

                alpha_distances, beta_distances = [], []
                valid_distances = True
                # The three residue pairs are visited in the order (1,2), (1,3), (2,3).
                for res1, res2 in itertools.combinations(comb, 2):
                     d_ca = np.linalg.norm(res1["CA"].coord - res2["CA"].coord)
                     d_cb = np.linalg.norm(res1["CB"].coord - res2["CB"].coord)
                     # Apply distance thresholds
                     if not (alpha_distance_range[0] <= d_ca <= alpha_distance_range[1] and \
                             beta_distance_range[0] <= d_cb <= beta_distance_range[1]):
                         valid_distances = False; break
                     alpha_distances.append(d_ca); beta_distances.append(d_cb)

                if valid_distances and len(alpha_distances) == 3:
                     row = {"PDB_ID": pdb_name, "Triad_Type": get_triad_type(comb)}
                     for i, res in enumerate(comb):
                         full_id = res.get_full_id()
                         row[f"Coord_chain_id_number{i+1}"] = full_id[2]
                         row[f"Coord_residue_number{i+1}"] = res.get_id()[1]
                         row[f"Coord_residue_name{i+1}"] = res.get_resname()
                     for i in range(3):
                         row[f"Alpha Distance {i+1}"] = alpha_distances[i]
                         row[f"Beta Distance {i+1}"] = beta_distances[i]
                     results.append(row) # Add row if distances are valid

            except KeyError: continue # Skip if atoms are missing during coord access
            except Exception as e_inner: print(f"     Error processing triad {comb} during geometry checks: {e_inner}"); continue

        print(f"   Found {len(results)} triads passing initial distance filters.")
        df = pd.DataFrame(results) # DataFrame of triads passing distance filter

        # --- Apply Ratio Filter ---
        if not df.empty:
             def pass_ratio(row):
                 # True only when all three alpha/beta distance ratios are within range.
                 try:
                     for i in range(3):
                         # Check for zero beta distance before division
                         if row[f"Beta Distance {i+1}"] is None or row[f"Beta Distance {i+1}"] == 0: return False
                         ratio = row[f"Alpha Distance {i+1}"] / row[f"Beta Distance {i+1}"]
                         if not (ratio_threshold_range[0] <= ratio <= ratio_threshold_range[1]): return False
                     return True
                 except (TypeError, KeyError): return False # Handle potential None values or missing keys
             df_ratio = df[df.apply(pass_ratio, axis=1)].copy()
             print(f"   Found {len(df_ratio)} triads passing ratio filter.")
        else: df_ratio = pd.DataFrame()

        # --- Apply Pie Angle Filter ---
        if not df_ratio.empty:
            def compute_pie(row):
                # Returns the three pair angles as a Series; NaNs on any lookup failure.
                try:
                    # Use chains_dict (created earlier from the model) for lookup
                    chain1 = chains_dict.get(row['Coord_chain_id_number1'])
                    chain2 = chains_dict.get(row['Coord_chain_id_number2'])
                    chain3 = chains_dict.get(row['Coord_chain_id_number3'])
                    if not all([chain1, chain2, chain3]): # Check if all chains were found
                         return pd.Series([np.nan]*3, index=["Pie_1_2", "Pie_1_3", "Pie_2_3"])

                    # Retrieve residue objects (handle potential errors if res number not in chain)
                    # NOTE(review): residues are re-fetched with a blank insertion code; a residue
                    # that carries an insertion code would not be found by this key — confirm
                    # whether such inputs need to be supported.
                    res1 = chain1[(' ', row['Coord_residue_number1'], ' ')] if (' ', row['Coord_residue_number1'], ' ') in chain1 else None
                    res2 = chain2[(' ', row['Coord_residue_number2'], ' ')] if (' ', row['Coord_residue_number2'], ' ') in chain2 else None
                    res3 = chain3[(' ', row['Coord_residue_number3'], ' ')] if (' ', row['Coord_residue_number3'], ' ') in chain3 else None
                    if not all([res1, res2, res3]):
                        return pd.Series([np.nan]*3, index=["Pie_1_2", "Pie_1_3", "Pie_2_3"])

                    res_objs = [res1, res2, res3]
                    angles = []
                    for i, j in [(0,1), (0,2), (1,2)]:
                         # Check atoms exist before calculating vectors
                         if not (res_objs[i].has_id("CA") and res_objs[i].has_id("CB") and \
                                 res_objs[j].has_id("CA") and res_objs[j].has_id("CB")):
                             angles.append(np.nan) # Append NaN if atoms missing
                             continue
                         # Angle between the CA->CA vector and the CB->CB vector of this pair.
                         v_ca = res_objs[j]["CA"].coord - res_objs[i]["CA"].coord
                         v_cb = res_objs[j]["CB"].coord - res_objs[i]["CB"].coord
                         angles.append(calculate_pie(v_ca, v_cb))
                    # Ensure we always return a Series of length 3
                    while len(angles) < 3: angles.append(np.nan)
                    return pd.Series(angles, index=["Pie_1_2", "Pie_1_3", "Pie_2_3"])

                except Exception: # Catch any other error during residue lookup or calculation
                    return pd.Series([np.nan]*3, index=["Pie_1_2", "Pie_1_3", "Pie_2_3"])


            pie_results = df_ratio.apply(compute_pie, axis=1)
            df_ratio[["Pie_1_2", "Pie_1_3", "Pie_2_3"]] = pie_results
            # Filter based on calculated pie angles
            for col in ["Pie_1_2", "Pie_1_3", "Pie_2_3"]:
                 # Apply filter only if pie angle is not NaN
                 # Bounds are exclusive here, unlike the inclusive distance and ratio checks.
                 df_ratio[f"{col}_Filter"] = df_ratio[col].apply(
                     lambda x: pie_threshold_range[0] < x < pie_threshold_range[1] if pd.notnull(x) else False)
            # A triad passes only when all three of its angles pass.
            df_ratio['Pie_Filter'] = df_ratio[[f'{col}_Filter' for col in ['Pie_1_2', 'Pie_1_3', 'Pie_2_3']]].all(axis=1)
            df_final = df_ratio[df_ratio['Pie_Filter']].copy() # This is the final DataFrame before coordinate extraction
            print(f"   Found {len(df_final)} triads passing pie angle filter.")
        else: df_final = pd.DataFrame() # df_final is empty if df_ratio was empty


        # --- Coordinate Extraction (Integrated Step 3) ---
        df_with_coords = pd.DataFrame() # Initialize empty DataFrame for results
        if not df_final.empty:
            print(f"   Extracting CA and CB coordinates for {len(df_final)} final triads...")
            ca_coords, cb_coords = [], []

            # Use the chains_dict created earlier
            for idx, row in df_final.iterrows():
                # Get chain objects using the dictionary (safer than re-parsing)
                # Use .get() to handle potential missing chain IDs gracefully
                chain1 = chains_dict.get(row['Coord_chain_id_number1'])
                chain2 = chains_dict.get(row['Coord_chain_id_number2'])
                chain3 = chains_dict.get(row['Coord_chain_id_number3'])

                # Extract Cα coordinates
                ca1 = extract_coordinates(chain1, row['Coord_residue_number1'], 'CA')
                ca2 = extract_coordinates(chain2, row['Coord_residue_number2'], 'CA')
                ca3 = extract_coordinates(chain3, row['Coord_residue_number3'], 'CA')
                ca_coords.append([*ca1, *ca2, *ca3]) # Flatten coordinates [x1,y1,z1, x2,y2,z2, x3,y3,z3]

                # Extract Cβ coordinates
                cb1 = extract_coordinates(chain1, row['Coord_residue_number1'], 'CB')
                cb2 = extract_coordinates(chain2, row['Coord_residue_number2'], 'CB')
                cb3 = extract_coordinates(chain3, row['Coord_residue_number3'], 'CB')
                cb_coords.append([*cb1, *cb2, *cb3]) # Flatten coordinates

            # Define column names for coordinate DataFrames
            ca_cols = ['CA1_X', 'CA1_Y', 'CA1_Z', 'CA2_X', 'CA2_Y', 'CA2_Z', 'CA3_X', 'CA3_Y', 'CA3_Z']
            cb_cols = ['CB1_X', 'CB1_Y', 'CB1_Z', 'CB2_X', 'CB2_Y', 'CB2_Z', 'CB3_X', 'CB3_Y', 'CB3_Z']

            # Create DataFrames from the extracted coordinates
            df_ca = pd.DataFrame(ca_coords, columns=ca_cols)
            df_cb = pd.DataFrame(cb_coords, columns=cb_cols)

            # Combine the original filtered data with the new coordinate data
            # Use reset_index to ensure indices align correctly during concatenation
            df_with_coords = pd.concat([df_final.reset_index(drop=True), df_ca, df_cb], axis=1)
            print(f"   Coordinate extraction complete.")

        else:
            print("   No triads passed all filters, skipping coordinate extraction.")
            df_with_coords = df_final # Assign empty df_final if no triads


        # --- Redundancy Removal DISABLED ---

        # --- Output ---
        # Save the DataFrame that includes coordinates (df_with_coords)
        if not df_with_coords.empty:
             print(f"   Saving final results with coordinates...")
             # Use the exact output path provided
             with pd.ExcelWriter(output_excel_path) as writer:
                  num_suffix = f"_Num{specific_residue_number}" if specific_residue_number != 0 else "_ANY_Num"
                  # Update sheet name to reflect content
                  sheet_name = f"Final_Coords{num_suffix}"
                  # Save the DataFrame with coordinates
                  df_with_coords.to_excel(writer, sheet_name=sheet_name, index=False)
             print(f"✅ Finished: {pdb_name}. Results with coordinates saved to {output_excel_path} (Total triads found: {len(df_with_coords)})")
        else:
             print(f"✅ Finished: {pdb_name}. No triads passed all filters. No output file created.")

    # Errors are reported and swallowed so the notebook cell itself does not fail.
    except FileNotFoundError: print(f"❌ Error: Input PDB file not found at {pdb_file_path}")
    except Exception as e: print(f"❌ An unexpected error occurred while processing {pdb_name}: {e}"); traceback.print_exc()


# --- Run Processing for the Single PDB File ---
if __name__ == "__main__":
    # Banner with a Seoul-time timestamp, then a summary of the run settings.
    try:
        kst = pytz.timezone('Asia/Seoul')
        current_time_kst = datetime.datetime.now(kst)
        print("\n--- Starting Single File Processing & Coordinate Extraction ---")
        print("Current Time (KST): " + current_time_kst.strftime('%Y-%m-%d %H:%M:%S %Z%z'))
    except ImportError:
        print("\n--- Starting Single File Processing & Coordinate Extraction ---")
        print("Note: Could not determine KST time (pytz not installed?). Run 'pip install pytz' if needed.")

    print("\n".join([
        f"Input PDB File: {Target_pdb_file}",
        f"Specific Number Filter: {specific_residue_number if specific_residue_number != 0 else 'OFF'}",
        f"Threshold Set Used (for download): {Metal} / {Threshold_Download_Set} / Range {Range}",
        f"Outputting results with coordinates to file: {Output_excel_file_path}",
        f"(Thresholds file saved in: {output_directory})",
        "Redundancy Removal: DISABLED",
    ]))

    # Run the full pipeline only when the input file actually exists.
    if os.path.isfile(Target_pdb_file):
        process_pdb_file(Target_pdb_file, Output_excel_file_path)
    else:
        print(f"⚠️ Error: Input PDB file not found at '{Target_pdb_file}'. Please check the path.")

    print("\n🎉 Processing finished.")

✅ Finished: 1EP0_alanine_dimer.pdb. Results with coordinates saved to /content/1EP0_full_Coords.xlsx (Total triads found: 4314)

🎉 Processing finished.


In [None]:
# @markdown # Step 4: Probability density map (Parallel Processing WITHIN a Single PDB)
# pip install scipy pandas openpyxl requests BioPython numpy pytz

import numpy as np
import pandas as pd
import os
import requests
import traceback # Import traceback for better error printing
import multiprocessing # Import multiprocessing
from Bio.PDB import PDBParser # PDBParser is still needed for structure loading
from scipy.spatial import KDTree # Import KDTree
import time # For timing if desired
import datetime
import pytz
from typing import List, Optional, Tuple # For type hinting
from Bio.PDB.Chain import Chain # For type hinting if needed


# --- Configuration and Setup ---
# Paths of the SINGLE input coordinate Excel file, the matching PDB file and
# the result workbook. The `# @param` annotations drive the Colab form UI.
input_coords_file = '/content/1EP0_full_Coords.xlsx' # @param {type:"string"}
Input_pdb_file = '/content/1EP0_alanine_dimer.pdb' # @param {type:"string"}
Output_result_excel_file = '/content/Full_result.xlsx' # @param {type:"string"}
# Create output directory if it doesn't exist
output_result_directory = os.path.dirname(Output_result_excel_file)
if output_result_directory:
    os.makedirs(output_result_directory, exist_ok=True)
    print(f"➡️ Output directory: {output_result_directory}")
else:
    # If no directory specified, output files go in the current directory
    output_result_directory = "."
    print(f"➡️ Output directory: Current directory")


# Check if input files exist
# The message must use `input_coords_file`, the name defined above; the
# previously used name was never defined and raised NameError instead of
# reporting the missing file.
if not os.path.isfile(input_coords_file):
    raise FileNotFoundError(f"Input coordinate Excel file not found: {input_coords_file}")
if not os.path.isfile(Input_pdb_file):
    raise FileNotFoundError(f"Input PDB file not found: {Input_pdb_file}")

# Define local file paths for downloaded data
prob_map_file = os.path.join(output_result_directory, 'map.xlsx') # Save downloads in output dir
thresholds_file = os.path.join(output_result_directory, 'threshold.xlsx')

# --- Download Data from GitHub ---
# The probability map and its thresholds are fetched for the metal and residue
# combination chosen in the Colab form below.
base_url = "https://raw.githubusercontent.com/SNU-Songlab/Metal-Installer-code/main/probability/"
Metal = 'Cu'  # @param ["Zn", "Mn", "Cu", "Fe"]
Combinations = '3His'  # @param ["3His", "2His_1Asp", "2His_1Glu", "2His_1Cys"]
# base_url ends with '/', so strip it before joining; otherwise the URL
# contains a double slash ("probability//Cu/...").
map_url = f"{base_url.rstrip('/')}/{Metal}/{Combinations}/map.xlsx"
thresholds_url = f"{base_url.rstrip('/')}/{Metal}/{Combinations}/threshold.xlsx"

# A timeout on each request keeps the cell from hanging forever if the server
# never answers.
print(f"Downloading probability map from: {map_url}")
response_map = requests.get(map_url, timeout=30)
if response_map.status_code == 200:
    with open(prob_map_file, 'wb') as file:
        file.write(response_map.content)
    print(f"Downloaded map data to {prob_map_file}")
else:
    raise ValueError(f"Failed to download map file from {map_url}. Status code: {response_map.status_code}")

print(f"Downloading thresholds from: {thresholds_url}")
response_thresh = requests.get(thresholds_url, timeout=30)
if response_thresh.status_code == 200:
    with open(thresholds_file, 'wb') as file:
        file.write(response_thresh.content)
    print(f"Downloaded thresholds data to {thresholds_file}")
else:
    raise ValueError(f"Failed to download thresholds file from {thresholds_url}. Status code: {response_thresh.status_code}")


# --- Load and Process Data (Load ONCE in the main process) ---
print("\n--- Loading Shared Data ---")

# PDB structure: the identifier is the file name without its extension.
pdb_id = os.path.splitext(os.path.basename(Input_pdb_file))[0]
print(f"Loading PDB structure: {pdb_id}...")
pdb_parser = PDBParser(QUIET=True)
try:
    structure = pdb_parser.get_structure(pdb_id, Input_pdb_file)
except Exception as e:
    print(f"❌ Error loading PDB file {Input_pdb_file}: {e}")
    raise  # Nothing else can run without the structure
else:
    print(f"Loaded structure.")

# Thresholds: one (Min, Max) pair per parameter, skipping incomplete rows.
print("Loading thresholds...")
try:
    thresholds_df = pd.read_excel(thresholds_file, sheet_name='Sheet1')
except FileNotFoundError:
    print(f"❌ Error: Thresholds file not found at {thresholds_file}")
    raise
thresholds = {}
for _, row in thresholds_df.iterrows():
    parameter, min_value, max_value = row['Parameter'], row['Min'], row['Max']
    if pd.isna(min_value) or pd.isna(max_value):
        continue
    thresholds[parameter] = (min_value, max_value)

# Every key below must be present in the thresholds file.
required_keys = ['ca_distances_calc', 'cb_distances_calc', 'ratio', 'angle']
if any(key not in thresholds for key in required_keys):
    missing_keys = [key for key in required_keys if key not in thresholds]
    raise KeyError(f"Missing key(s) {missing_keys} in thresholds file.")
print("Thresholds loaded.")

# Load and Process Probability Map
print("Loading probability map...")
try:
    df_precomputed_prob_map = pd.read_excel(prob_map_file)
except FileNotFoundError:
    print(f"❌ Error: Probability map file not found at {prob_map_file}")
    raise

print("Processing probability map...")
# The map must provide the three bin axes and the probability value.
map_req_cols = ['Calpha_Zn_Dist', 'Cbeta_Zn_Dist', 'CA-Zn-CB_Angle', 'Probability']
if any(col not in df_precomputed_prob_map.columns for col in map_req_cols):
    missing_map_cols = [col for col in map_req_cols if col not in df_precomputed_prob_map.columns]
    raise ValueError(f"Missing required columns in map file: {missing_map_cols}")

# Sorted unique values along each axis serve as the bin coordinates.
ca_bins = np.sort(df_precomputed_prob_map['Calpha_Zn_Dist'].unique())
cb_bins = np.sort(df_precomputed_prob_map['Cbeta_Zn_Dist'].unique())
angle_bins = np.sort(df_precomputed_prob_map['CA-Zn-CB_Angle'].unique())

# Pivot to (CA distance) x (CB distance, angle), then reshape into a 3D grid.
# The reshape is only performed when the pivoted table has the expected shape.
try:
    pivoted_prob_map = df_precomputed_prob_map.pivot_table(
        index='Calpha_Zn_Dist',
        columns=['Cbeta_Zn_Dist', 'CA-Zn-CB_Angle'],
        values='Probability',
        fill_value=0,
    )
    expected_shape = (len(ca_bins), len(cb_bins) * len(angle_bins))
    if pivoted_prob_map.shape != expected_shape:
        raise ValueError(f"Pivoted map shape {pivoted_prob_map.shape} doesn't match expected shape {expected_shape} for reshaping.")
    prob_map_3d = pivoted_prob_map.values.reshape((len(ca_bins), len(cb_bins), len(angle_bins)))
    print("Probability map processed into 3D array.")
except Exception as e:
    print(f"❌ Error processing probability map: {e}")
    raise  # Re-raise so the cell stops with the original error

# Load Input Coordinate Data
print(f"Loading input coordinate data from: {input_coords_file}...")
try:
    df_sites = pd.read_excel(input_coords_file)
except FileNotFoundError:
     print(f"❌ Error: Input coordinate file not found at {input_coords_file}")
     raise
if df_sites.empty:
    print("⚠️ Input coordinate file is empty. Nothing to process.")
    # Stop this cell without calling exit(): in a notebook kernel exit()
    # requests a kernel shutdown, which discards everything loaded so far.
    raise SystemExit(0)
# Ensure PDB_ID column exists or add it based on filename
if 'PDB_ID' not in df_sites.columns:
     df_sites['PDB_ID'] = pdb_id
print(f"Loaded {len(df_sites)} candidate sites.")

# --- Helper Function Definitions ---
# (calculate_ratio, calculate_angles, score_zn_predictions, define_excluded_triads,
#  proximity_filter_kdtree, estimate_zn_iterative remain unchanged from previous version)
def calculate_ratio(current_point, ca_xyz, cb_xyz):
    """Return, per residue, the CA-distance / CB-distance ratio to a point.

    Distances are measured from ``current_point`` to each row of ``ca_xyz``
    and ``cb_xyz``. Where the CB distance is zero the ratio is ``inf``.
    """
    dist_to_ca = np.linalg.norm(ca_xyz - current_point, axis=1)
    dist_to_cb = np.linalg.norm(cb_xyz - current_point, axis=1)
    # Start from inf everywhere and overwrite only where the division is defined.
    result = np.full_like(dist_to_ca, np.inf)
    divisible = dist_to_cb != 0
    np.divide(dist_to_ca, dist_to_cb, out=result, where=divisible)
    return result

def calculate_angles(zn_coords, ca_coords_triplet, cb_coords_triplet):
    """Return the CA-metal-CB angle, in degrees, for each of the three residues.

    Args:
        zn_coords: Candidate metal position (array-like supporting subtraction
            from a coordinate row).
        ca_coords_triplet: Indexable of three C-alpha coordinates.
        cb_coords_triplet: Indexable of three C-beta coordinates.

    Returns:
        A list of exactly three values. Entry ``i`` is the angle at the metal
        between the vectors to CA ``i`` and CB ``i``; it is ``nan`` when either
        vector has zero length, because the angle is undefined there.
    """
    angles = []
    # The loop always appends exactly one value per residue, so the result has
    # three entries without any extra padding step.
    for i in range(3):
        v_ca = ca_coords_triplet[i] - zn_coords
        v_cb = cb_coords_triplet[i] - zn_coords
        norm_v_ca = np.linalg.norm(v_ca)
        norm_v_cb = np.linalg.norm(v_cb)
        if norm_v_ca == 0 or norm_v_cb == 0:
            angles.append(np.nan)  # undefined angle
            continue
        # Clip so floating-point round-off cannot push the cosine outside
        # arccos's domain.
        cos_theta = np.clip(np.dot(v_ca, v_cb) / (norm_v_ca * norm_v_cb), -1.0, 1.0)
        angles.append(np.degrees(np.arccos(cos_theta)))
    return angles

def score_zn_predictions(ca_distances, cb_distances, angles, prob_map_3d, ca_bins, cb_bins, angle_bins):
    """Score a candidate as the product of its three per-residue map probabilities.

    Args:
        ca_distances, cb_distances, angles: Per-residue measurements for the
            candidate (three values each in the callers of this notebook).
        prob_map_3d: Array indexed as ``[ca_bin, cb_bin, angle_bin]``.
        ca_bins, cb_bins, angle_bins: Sorted bin values for each map axis.

    Returns:
        The product of the three looked-up probabilities, or ``0.0`` when any
        input is NaN, any probability is non-positive, the lookup fails, or the
        lookup does not yield three values.
    """
    for measurement in (ca_distances, cb_distances, angles):
        if np.isnan(measurement).any():
            return 0.0

    # np.digitize(..., right=True) uses intervals closed on the RIGHT:
    # index i means bins[i] < value <= bins[i + 1] (values <= bins[1] get 0).
    # The clip keeps every index inside the corresponding map axis.
    ca_idx = np.clip(np.digitize(ca_distances, ca_bins[1:], right=True), 0, len(ca_bins) - 1)
    cb_idx = np.clip(np.digitize(cb_distances, cb_bins[1:], right=True), 0, len(cb_bins) - 1)
    angle_idx = np.clip(np.digitize(angles, angle_bins[1:], right=True), 0, len(angle_bins) - 1)

    try:
        # One fancy-indexing lookup fetches the probability of every residue.
        probs = prob_map_3d[ca_idx, cb_idx, angle_idx]
        has_non_positive = np.any(probs <= 0)
    except Exception:
        # Any lookup failure (e.g. out-of-bounds index) scores the candidate 0.
        return 0.0

    if has_non_positive:
        return 0.0
    if len(probs) != 3:
        return 0.0
    return np.prod(probs)

def define_excluded_triads(triad_res_nums, structure):
    """Collect ``(chain_id, residue_number)`` pairs matching the triad numbers.

    Matching is by residue sequence number only, so a number is picked up in
    every model and every chain where it occurs.

    Args:
        triad_res_nums: Iterable of residue numbers; missing values (as judged
            by ``pd.notna``) are ignored, the rest must be convertible by ``int``.
        structure: Nested iterable of models -> chains -> residues, where chains
            expose ``.id`` and residues expose ``.id`` as
            ``(hetfield, sequence_identifier, insertion_code)``; ``None`` is allowed.

    Returns:
        Set of ``(chain.id, sequence_number)`` tuples; empty when ``structure``
        is ``None``, no usable numbers were given, or conversion failed.
    """
    excluded_residues = set()
    if structure is None:
        return excluded_residues

    try:
        wanted_numbers = {int(num) for num in triad_res_nums if pd.notna(num)}
    except (ValueError, TypeError):
        print(f"⚠️ Warning: Could not convert all triad residue numbers {triad_res_nums} to integers for exclusion.")
        return excluded_residues

    if len(wanted_numbers) == 0:
        return excluded_residues

    all_chains = (chain for model in structure for chain in model)
    for chain in all_chains:
        for residue in chain:
            try:
                seq_number = residue.id[1]
                if seq_number in wanted_numbers:
                    excluded_residues.add((chain.id, seq_number))
            except (TypeError, IndexError):
                # Residue id is not in the expected tuple form; ignore it.
                continue
    return excluded_residues

def proximity_filter_kdtree(kdtree, zn_candidate, exclusion_radius=2.5):
    """Tell whether a candidate position is free of neighbours in the tree.

    Args:
        kdtree: Object offering ``query_ball_point(point, r=..., return_length=True)``
            (the notebook passes a ``scipy.spatial.KDTree``), or ``None``.
        zn_candidate: Candidate position to test.
        exclusion_radius: Radius inside which any indexed point rejects the
            candidate.

    Returns:
        ``True`` when no tree was supplied or no indexed point lies within the
        radius; ``False`` when a neighbour exists or the query itself raised.
    """
    if kdtree is None:
        # Without a tree there is nothing to clash with.
        return True
    try:
        neighbour_count = kdtree.query_ball_point(
            zn_candidate, r=exclusion_radius, return_length=True
        )
        is_clear = neighbour_count == 0
    except Exception:
        # Fail safe: a candidate that cannot be checked is rejected.
        return False
    return is_clear

def estimate_zn_iterative(
    ca_coords_site_flat,
    cb_coords_site_flat,
    site_info,
    structure_local,
    thresholds_local,
    prob_map_3d_local, ca_bins_local, cb_bins_local, angle_bins_local,
    grid_resolution=0.2
    ):
    """Estimates Zn coordinate for a SINGLE site.

    A regular grid is scanned inside the intersection of the axis-aligned boxes
    around the three C-alpha atoms, and the FIRST grid point that passes the
    distance, angle, ratio, probability-score and proximity filters is returned.

    Args:
        ca_coords_site_flat: Nine numbers, CA x/y/z of residues 1-3 in order.
        cb_coords_site_flat: Nine numbers, CB x/y/z of residues 1-3 in order.
        site_info: Mapping read with ``.get`` for ``Coord_residue_number1..3``.
        structure_local: Structure offering ``get_atoms()`` and iteration over
            models/chains/residues; ``None`` yields the "no metal" result.
        thresholds_local: Mapping with ``'ca_distances_calc'``,
            ``'cb_distances_calc'``, ``'angle'`` and ``'ratio'`` entries, each
            indexable as ``(min, max)``.
        prob_map_3d_local, ca_bins_local, cb_bins_local, angle_bins_local:
            Forwarded unchanged to ``score_zn_predictions``.
        grid_resolution: Grid spacing, in the same units as the coordinates.

    Returns:
        ``(coords, score, angles)``. On success ``coords`` is a length-3 numpy
        array; otherwise the tuple is ``("no metal", 0, [nan, nan, nan])``.

    Raises:
        KeyError: If ``thresholds_local`` lacks one of the four entries above.
    """
    no_metal = ("no metal", 0, [np.nan, np.nan, np.nan])

    # --- Coordinate Validation ---
    try:
        ca_coords_numeric = pd.to_numeric(np.asarray(ca_coords_site_flat), errors='coerce')
        cb_coords_numeric = pd.to_numeric(np.asarray(cb_coords_site_flat), errors='coerce')
        if np.isnan(ca_coords_numeric).any() or np.isnan(cb_coords_numeric).any():
            return no_metal
        if ca_coords_numeric.shape != (9,) or cb_coords_numeric.shape != (9,):
            return no_metal
        ca_xyz = ca_coords_numeric.astype(np.float64).reshape(3, 3)
        cb_xyz = cb_coords_numeric.astype(np.float64).reshape(3, 3)
    except (ValueError, TypeError):
        return no_metal

    if structure_local is None:
        return no_metal

    # --- Excluded Residues & KDTree ---
    # Atoms of the triad residues themselves must not count as clashes.
    triad_res_nums = [
        site_info.get('Coord_residue_number1'),
        site_info.get('Coord_residue_number2'),
        site_info.get('Coord_residue_number3')
    ]
    excluded_residues_set = define_excluded_triads(triad_res_nums, structure_local)

    non_excluded_coords_list = []
    try:
        for atom in structure_local.get_atoms():
            residue = atom.get_parent()
            # Check each parent before dereferencing it, so a detached atom is
            # skipped instead of aborting the whole scan with AttributeError.
            if residue is None:
                continue
            chain = residue.get_parent()
            if chain is None:
                continue
            if (chain.id, residue.id[1]) in excluded_residues_set:
                continue
            if isinstance(atom.coord, np.ndarray) and atom.coord.shape == (3,):
                non_excluded_coords_list.append(atom.coord)
    except Exception as atom_iter_err:
        print(f"Warning: Error iterating atoms for KDTree build: {atom_iter_err}")

    kdtree = None
    if non_excluded_coords_list:
        try:
            non_excluded_coords = np.array(non_excluded_coords_list, dtype=np.float64)
            if non_excluded_coords.ndim == 2 and non_excluded_coords.shape[1] == 3 and non_excluded_coords.shape[0] > 0:
                kdtree = KDTree(non_excluded_coords)
        except Exception:
            # Best effort: without a tree the proximity filter accepts every point.
            kdtree = None

    # --- Thresholds (looked up once; they do not change during the scan) ---
    ca_lo, ca_hi = thresholds_local['ca_distances_calc'][0], thresholds_local['ca_distances_calc'][1]
    cb_lo, cb_hi = thresholds_local['cb_distances_calc'][0], thresholds_local['cb_distances_calc'][1]
    angle_lo, angle_hi = thresholds_local['angle'][0], thresholds_local['angle'][1]
    ratio_lo, ratio_hi = thresholds_local['ratio'][0], thresholds_local['ratio'][1]

    # --- Search Space ---
    # Intersect the boxes of half-width buffer_dist around each CA, then pad
    # the result by two grid steps on every side.
    buffer_dist = max(ca_hi, cb_hi)
    buffer_grid = grid_resolution * 2
    lower = (ca_xyz - buffer_dist).max(axis=0) - buffer_grid
    upper = (ca_xyz + buffer_dist).min(axis=0) + buffer_grid
    if not np.all(lower < upper):
        return no_metal

    # --- Grid Search ---
    x_range = np.arange(lower[0], upper[0], grid_resolution)
    y_range = np.arange(lower[1], upper[1], grid_resolution)
    z_range = np.arange(lower[2], upper[2], grid_resolution)
    if not (x_range.size > 0 and y_range.size > 0 and z_range.size > 0):
        return no_metal

    num_z = z_range.size
    for x in x_range:
        for y in y_range:
            # One z-column of grid points is evaluated per (x, y) pair.
            points = np.column_stack([np.full(num_z, x), np.full(num_z, y), z_range])

            # dist_*[p, r] is the distance from grid point p to atom r.
            dist_ca = np.linalg.norm(ca_xyz[np.newaxis, :, :] - points[:, np.newaxis, :], axis=2)
            dist_cb = np.linalg.norm(cb_xyz[np.newaxis, :, :] - points[:, np.newaxis, :], axis=2)
            dist_ca_ok = np.all((ca_lo <= dist_ca) & (dist_ca <= ca_hi), axis=1)
            dist_cb_ok = np.all((cb_lo <= dist_cb) & (dist_cb <= cb_hi), axis=1)
            dist_ok_mask = dist_ca_ok & dist_cb_ok
            if not np.any(dist_ok_mask):
                continue

            survivors = zip(points[dist_ok_mask], dist_ca[dist_ok_mask], dist_cb[dist_ok_mask])
            for point, current_dist_ca, current_dist_cb in survivors:
                angles = calculate_angles(point, ca_xyz, cb_xyz)
                if np.isnan(angles).any() or not all(angle_lo <= ang <= angle_hi for ang in angles):
                    continue

                ratios = calculate_ratio(point, ca_xyz, cb_xyz)
                if np.isinf(ratios).any() or not np.all((ratio_lo <= ratios) & (ratios <= ratio_hi)):
                    continue

                score = score_zn_predictions(current_dist_ca, current_dist_cb, angles, prob_map_3d_local, ca_bins_local, cb_bins_local, angle_bins_local)
                if score <= 0:
                    continue

                if not proximity_filter_kdtree(kdtree, point, exclusion_radius=2.0):
                    continue

                return point, score, angles  # Return first valid point

    return no_metal  # No point found

# --- Worker Function for Multiprocessing ---
def process_single_site(args):
    """Worker function to process a single candidate site (row).

    Args:
        args: 8-tuple ``(site_index, site_data_dict, structure, thresholds,
            prob_map_3d, ca_bins, cb_bins, angle_bins)``. ``site_data_dict`` must
            hold the ``CA{1..3}_{X,Y,Z}`` and ``CB{1..3}_{X,Y,Z}`` keys.

    Returns:
        ``(site_index, zn_coords, zn_score, zn_angles)`` as produced by
        ``estimate_zn_iterative``; if anything raises while handling the site,
        ``(site_index, "error", 0, [nan, nan, nan])`` is returned instead.
    """
    (site_index, site_data_dict, structure_shared, thresholds_shared,
     prob_map_3d_shared, ca_bins_shared, cb_bins_shared, angle_bins_shared) = args

    ca_cols = ['CA1_X', 'CA1_Y', 'CA1_Z', 'CA2_X', 'CA2_Y', 'CA2_Z', 'CA3_X', 'CA3_Y', 'CA3_Z']
    cb_cols = ['CB1_X', 'CB1_Y', 'CB1_Z', 'CB2_X', 'CB2_Y', 'CB2_Z', 'CB3_X', 'CB3_Y', 'CB3_Z']

    try:
        # Flatten the per-site coordinates in column order.
        ca_values = [site_data_dict[name] for name in ca_cols]
        ca_coords_flat = np.array(ca_values, dtype=np.float64)
        cb_values = [site_data_dict[name] for name in cb_cols]
        cb_coords_flat = np.array(cb_values, dtype=np.float64)

        estimate = estimate_zn_iterative(
            ca_coords_flat, cb_coords_flat, site_data_dict, structure_shared,
            thresholds_shared, prob_map_3d_shared, ca_bins_shared,
            cb_bins_shared, angle_bins_shared,
            grid_resolution=0.2
        )
        zn_coords, zn_score, zn_angles = estimate
    except Exception:
        # A failing site must not take the whole pool down; report it instead.
        return site_index, "error", 0, [np.nan, np.nan, np.nan]
    return site_index, zn_coords, zn_score, zn_angles

# --- Main Execution Guard ---
if __name__ == "__main__":
    # Banner with the current time in KST.
    # NOTE(review): the calls inside this try do not import anything, so the
    # ImportError branch looks unreachable once pytz is imported - confirm.
    try:
        kst = pytz.timezone('Asia/Seoul')
        current_time_kst = datetime.datetime.now(kst)
        print(f"\n--- Starting Main Process for Single PDB Site Parallelization ---")
        print(f"Current Time (KST): {current_time_kst.strftime('%Y-%m-%d %H:%M:%S %Z%z')}")
    except ImportError:
        print("\n--- Starting Main Process for Single PDB Site Parallelization ---")
        print("Note: Could not determine KST time (pytz not installed?). Run 'pip install pytz' if needed.")

    start_time = time.time()

    # --- Prepare Tasks for Multiprocessing ---
    tasks = []
    required_input_cols = (
        ['CA1_X', 'CA1_Y', 'CA1_Z', 'CA2_X', 'CA2_Y', 'CA2_Z', 'CA3_X', 'CA3_Y', 'CA3_Z'] +
        ['CB1_X', 'CB1_Y', 'CB1_Z', 'CB2_X', 'CB2_Y', 'CB2_Z', 'CB3_X', 'CB3_Y', 'CB3_Z'] +
        ['Coord_residue_number1', 'Coord_residue_number2', 'Coord_residue_number3', 'PDB_ID']
    )
    missing_cols = [col for col in required_input_cols if col not in df_sites.columns]
    if missing_cols:
        raise ValueError(f"Input coordinate Excel file is missing required columns: {missing_cols}")

    # Plain dicts are handed to the workers instead of DataFrame rows.
    sites_data_list = df_sites[required_input_cols].to_dict('records')

    # Each task carries the original DataFrame index so results, which arrive
    # in arbitrary order, can be written back to the right row.
    for index, site_dict in enumerate(sites_data_list):
        tasks.append((
            df_sites.index[index],
            site_dict,
            structure,
            thresholds,
            prob_map_3d,
            ca_bins, cb_bins, angle_bins
        ))

    print(f"Prepared {len(tasks)} tasks for parallel processing.")

    # --- Initialize Results Columns ---
    df_sites['Zn_X_Grid'] = np.nan
    df_sites['Zn_Y_Grid'] = np.nan
    df_sites['Zn_Z_Grid'] = np.nan
    df_sites['Zn_Score'] = 0.0
    df_sites['Angle_1'] = np.nan
    df_sites['Angle_2'] = np.nan
    df_sites['Angle_3'] = np.nan

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single worker instead of failing inside min().
    num_processes = min(8, os.cpu_count() or 1)
    print(f"\nInitializing multiprocessing pool with {num_processes} workers...")

    processed_count = 0
    # --- Run Multiprocessing ---
    pool = None
    pool_completed = False  # set only after every result has been consumed
    try:
        pool = multiprocessing.Pool(processes=num_processes)
        print("Starting parallel processing of sites...")
        results_iterator = pool.imap_unordered(process_single_site, tasks)

        for result in results_iterator:
            processed_count += 1
            try:
                site_idx, zn_coords_res, zn_score_res, zn_angles_res = result

                df_sites.loc[site_idx, 'Zn_Score'] = zn_score_res
                # Coordinates are only stored for real hits; "no metal" and
                # "error" results arrive as strings and leave the NaN defaults.
                if isinstance(zn_coords_res, np.ndarray) and zn_coords_res.shape == (3,):
                    df_sites.loc[site_idx, 'Zn_X_Grid'] = zn_coords_res[0]
                    df_sites.loc[site_idx, 'Zn_Y_Grid'] = zn_coords_res[1]
                    df_sites.loc[site_idx, 'Zn_Z_Grid'] = zn_coords_res[2]
                if isinstance(zn_angles_res, (list, np.ndarray)) and len(zn_angles_res) == 3:
                    # Store plain Python floats (NaN where the angle is missing).
                    df_sites.loc[site_idx, 'Angle_1'] = float(zn_angles_res[0]) if pd.notna(zn_angles_res[0]) else np.nan
                    df_sites.loc[site_idx, 'Angle_2'] = float(zn_angles_res[1]) if pd.notna(zn_angles_res[1]) else np.nan
                    df_sites.loc[site_idx, 'Angle_3'] = float(zn_angles_res[2]) if pd.notna(zn_angles_res[2]) else np.nan

                if processed_count % 50 == 0 or processed_count == len(tasks):
                    print(f"  Processed {processed_count}/{len(tasks)} sites...")

            except Exception as result_error:
                print(f"❌ Error processing result for one site: {result_error}")
                if isinstance(result, tuple) and len(result) > 0: print(f"   Problem occurred for site index: {result[0]}")
                continue

        print(f"Finished processing all {processed_count} assigned tasks.")
        pool_completed = True

    except Exception as pool_error:
        print(f"❌ An error occurred during multiprocessing: {pool_error}")
        print(traceback.format_exc())
    finally:
        # Ensure pool resources are released. After a failure or interrupt the
        # remaining results are never read, so stop the workers right away
        # instead of waiting for every queued site to finish.
        if pool:
            if pool_completed:
                pool.close()
            else:
                pool.terminate()
            pool.join()

    print("\n--- Multiprocessing Pool Finished ---")

    # --- Post-processing and Saving Results ---
    print("Filtering results (Zn_Score > 0)...")
    # Ensure Zn_Score is numeric before filtering
    df_sites['Zn_Score'] = pd.to_numeric(df_sites['Zn_Score'], errors='coerce').fillna(0.0)
    df_output = df_sites[df_sites['Zn_Score'] > 0].copy()

    total_sites_saved = len(df_output)

    if not df_output.empty:
        print(f"Saving {total_sites_saved} site(s) with positive scores...")
        try:
            # Force the result columns to numeric dtype before writing.
            for col in ['Zn_X_Grid', 'Zn_Y_Grid', 'Zn_Z_Grid', 'Angle_1', 'Angle_2', 'Angle_3']:
                if col in df_output.columns:
                    df_output[col] = pd.to_numeric(df_output[col], errors='coerce')
            df_output.to_excel(Output_result_excel_file, index=False)
            print(f"✅ Successfully saved results to: {Output_result_excel_file}")
        except Exception as save_error:
            print(f"❌ Error saving results to Excel file: {save_error}")
            print(traceback.format_exc())
    else:
        print("⚠️ No valid Zn predictions passed filters (Score > 0). No output file created.")

    end_time = time.time()
    print(f"\n🏁 Processing summary:")
    print(f"  Total input sites: {len(df_sites)}")
    print(f"  Total sites processed by pool: {processed_count}") # Note: Might be less than total if errors occur early
    print(f"  Total result sites saved (Score > 0): {total_sites_saved}")
    print(f"  Total execution time: {end_time - start_time:.2f} seconds")


In [None]:
# @markdown # Step 4: Probability density map (Parallel Processing WITHIN a Single PDB - Find BEST Score)
# pip install scipy pandas openpyxl requests BioPython numpy pytz

import numpy as np
import pandas as pd
import os
import requests
import traceback # Import traceback for better error printing
import multiprocessing # Import multiprocessing
from Bio.PDB import PDBParser # PDBParser is still needed for structure loading
from scipy.spatial import KDTree # Import KDTree
import time # For timing if desired
import datetime
import pytz
from typing import List, Optional, Tuple # For type hinting
from Bio.PDB.Chain import Chain # For type hinting if needed


# --- Configuration and Setup ---
# Colab form fields: the coordinate workbook, the matching PDB file and the
# workbook the results are written to.
input_coords_file = '/content/1EP0_full_Coords.xlsx' # @param {type:"string"}
Input_pdb_file = '/content/1EP0_alanine_dimer.pdb' # @param {type:"string"}
Output_result_excel_file = '/content/1EP0_full_best_score.xlsx' # @param {type:"string"}

# Resolve the output directory; a bare file name means the current directory.
output_result_directory = os.path.dirname(Output_result_excel_file)
if not output_result_directory:
    output_result_directory = "."
    print(f"➡️ Output directory: Current directory")
else:
    os.makedirs(output_result_directory, exist_ok=True)
    print(f"➡️ Output directory: {output_result_directory}")


# Fail fast when either input file is missing.
coords_file_present = os.path.isfile(input_coords_file)
if not coords_file_present:
    raise FileNotFoundError(f"Input coordinate Excel file not found: {input_coords_file}")
pdb_file_present = os.path.isfile(Input_pdb_file)
if not pdb_file_present:
    raise FileNotFoundError(f"Input PDB file not found: {Input_pdb_file}")

# Define local file paths for downloaded data
prob_map_file = os.path.join(output_result_directory, 'map.xlsx')
thresholds_file = os.path.join(output_result_directory, 'threshold.xlsx')

# --- Download Data from GitHub ---
base_url = "https://raw.githubusercontent.com/SNU-Songlab/Metal-Installer-code/main/probability/"
Metal = 'Cu'  # @param ["Zn", "Mn", "Cu", "Fe"]
Combinations = '3His'  # @param ["3His", "2His_1Asp", "2His_1Glu", "2His_1Cys"]
# base_url already ends with "/"; strip it before joining so the resulting
# URLs do not contain an accidental "//" path segment.
map_url = f"{base_url.rstrip('/')}/{Metal}/{Combinations}/map.xlsx"
thresholds_url = f"{base_url.rstrip('/')}/{Metal}/{Combinations}/threshold.xlsx"


def _download_file(url, destination, label, timeout=30):
    """Download ``url`` to ``destination`` and return the HTTP response.

    Args:
        url: Address to fetch.
        destination: Local path the response body is written to (binary mode).
        label: Short name used in the progress and error messages.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the notebook indefinitely.

    Raises:
        ValueError: If the server answers with a status code other than 200.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"Failed to download {label} file from {url}. Status code: {response.status_code}")
    with open(destination, 'wb') as file:
        file.write(response.content)
    print(f"Downloaded {label} data to {destination}")
    return response


print(f"Downloading probability map from: {map_url}")
response_map = _download_file(map_url, prob_map_file, "map")

print(f"Downloading thresholds from: {thresholds_url}")
response_thresh = _download_file(thresholds_url, thresholds_file, "thresholds")


# --- Load and Process Data (Load ONCE in the main process) ---
print("\n--- Loading Shared Data ---")
# Load PDB Structure
# The structure id is the PDB file name with its extension removed.
pdb_id = os.path.basename(Input_pdb_file)
pdb_id = os.path.splitext(pdb_id)[0]
print(f"Loading PDB structure: {pdb_id}...")
pdb_parser = PDBParser(QUIET=True)
try:
    structure = pdb_parser.get_structure(pdb_id, Input_pdb_file)
    print(f"Loaded structure.")
except Exception as e:
    # Report which file failed, then let the original error propagate.
    print(f"❌ Error loading PDB file {Input_pdb_file}: {e}")
    raise

# Load Thresholds
print("Loading thresholds...")
try:
    thresholds_df = pd.read_excel(thresholds_file, sheet_name='Sheet1')
except FileNotFoundError:
    print(f"❌ Error: Thresholds file not found at {thresholds_file}")
    raise

# Map each parameter name to its (min, max) pair; rows missing either bound
# are left out.
thresholds = {
    parameter: (min_value, max_value)
    for parameter, min_value, max_value in zip(
        thresholds_df['Parameter'], thresholds_df['Min'], thresholds_df['Max']
    )
    if pd.notna(min_value) and pd.notna(max_value)
}

# Every filter used by the grid search needs its bounds to be present.
required_keys = ['ca_distances_calc', 'cb_distances_calc', 'ratio', 'angle']
missing_keys = [key for key in required_keys if key not in thresholds]
if missing_keys:
    raise KeyError(f"Missing key(s) {missing_keys} in thresholds file.")
print("Thresholds loaded.")

# Load and Process Probability Map
print("Loading probability map...")
try:
    df_precomputed_prob_map = pd.read_excel(prob_map_file)
except FileNotFoundError:
    print(f"❌ Error: Probability map file not found at {prob_map_file}")
    raise

print("Processing probability map...")
map_req_cols = ['Calpha_Zn_Dist', 'Cbeta_Zn_Dist', 'CA-Zn-CB_Angle', 'Probability']
missing_map_cols = [col for col in map_req_cols if col not in df_precomputed_prob_map.columns]
if missing_map_cols:
    raise ValueError(f"Missing required columns in map file: {missing_map_cols}")

# Sorted distinct values of each axis column define the bins of the 3D map.
ca_bins, cb_bins, angle_bins = (
    np.sort(df_precomputed_prob_map[axis_col].unique())
    for axis_col in map_req_cols[:3]
)

try:
    # Rows are CA bins; columns are (CB bin, angle bin) pairs. Combinations
    # absent from the table are filled with probability 0.
    pivoted_prob_map = df_precomputed_prob_map.pivot_table(
        values='Probability',
        index='Calpha_Zn_Dist',
        columns=['Cbeta_Zn_Dist', 'CA-Zn-CB_Angle'],
        fill_value=0
    )
    expected_shape = (len(ca_bins), len(cb_bins) * len(angle_bins))
    if pivoted_prob_map.shape != expected_shape:
        raise ValueError(f"Pivoted map shape {pivoted_prob_map.shape} doesn't match expected shape {expected_shape} for reshaping.")
    # Unfold the (CB, angle) column pairs into two separate axes.
    prob_map_3d = pivoted_prob_map.values.reshape((len(ca_bins), len(cb_bins), len(angle_bins)))
    print("Probability map processed into 3D array.")
except Exception as e:
    print(f"❌ Error processing probability map: {e}")
    raise

# Load Input Coordinate Data
# Every row of the workbook is one candidate site to be evaluated.
print(f"Loading input coordinate data from: {input_coords_file}...")
try:
    df_sites = pd.read_excel(input_coords_file)
except FileNotFoundError:
    print(f"❌ Error: Input coordinate file not found at {input_coords_file}")
    raise
if df_sites.empty:
    print("⚠️ Input coordinate file is empty. Nothing to process.")
    # Stop this cell/script cleanly. The interactive `exit()` helper is avoided
    # because inside a Jupyter/Colab kernel it requests a kernel shutdown.
    raise SystemExit(0)
# Ensure PDB_ID column exists or add it based on filename
if 'PDB_ID' not in df_sites.columns:
    df_sites['PDB_ID'] = pdb_id
print(f"Loaded {len(df_sites)} candidate sites.")

# --- Helper Function Definitions ---
# (calculate_ratio, calculate_angles, score_zn_predictions, define_excluded_triads,
#  proximity_filter_kdtree remain unchanged)
def calculate_ratio(current_point, ca_xyz, cb_xyz):
    """Compute the CA-distance / CB-distance ratio of each residue for one point.

    Args:
        current_point: Candidate position; it is subtracted from every row of
            the coordinate arrays, so it must broadcast against them.
        ca_xyz: 2-D array with one C-alpha coordinate per row.
        cb_xyz: 2-D array with one C-beta coordinate per row.

    Returns:
        1-D array holding ``|CA - point| / |CB - point|`` for each row. Rows
        whose C-beta distance is exactly zero are reported as ``inf``.
    """
    numerator = np.linalg.norm(ca_xyz - current_point, axis=1)
    denominator = np.linalg.norm(cb_xyz - current_point, axis=1)
    # Rows excluded by the `where` mask keep the inf placed here beforehand.
    quotient = np.full_like(numerator, np.inf)
    np.divide(numerator, denominator, out=quotient, where=(denominator != 0))
    return quotient

def calculate_angles(zn_coords, ca_coords_triplet, cb_coords_triplet):
    """Return the CA-metal-CB angle, in degrees, for each of the three residues.

    Args:
        zn_coords: Candidate metal position (array-like supporting subtraction
            from a coordinate row).
        ca_coords_triplet: Indexable of three C-alpha coordinates.
        cb_coords_triplet: Indexable of three C-beta coordinates.

    Returns:
        A list of exactly three values. Entry ``i`` is the angle at the metal
        between the vectors to CA ``i`` and CB ``i``; it is ``nan`` when either
        vector has zero length, because the angle is undefined there.
    """
    angles = []
    # The loop always appends exactly one value per residue, so the result has
    # three entries without any extra padding step.
    for i in range(3):
        v_ca = ca_coords_triplet[i] - zn_coords
        v_cb = cb_coords_triplet[i] - zn_coords
        norm_v_ca = np.linalg.norm(v_ca)
        norm_v_cb = np.linalg.norm(v_cb)
        if norm_v_ca == 0 or norm_v_cb == 0:
            angles.append(np.nan)  # undefined angle
            continue
        # Clip so floating-point round-off cannot push the cosine outside
        # arccos's domain.
        cos_theta = np.clip(np.dot(v_ca, v_cb) / (norm_v_ca * norm_v_cb), -1.0, 1.0)
        angles.append(np.degrees(np.arccos(cos_theta)))
    return angles

def score_zn_predictions(ca_distances, cb_distances, angles, prob_map_3d, ca_bins, cb_bins, angle_bins):
    """Score a candidate as the product of its three per-residue map probabilities.

    Args:
        ca_distances, cb_distances, angles: Per-residue measurements for the
            candidate (three values each in the callers of this notebook).
        prob_map_3d: Array indexed as ``[ca_bin, cb_bin, angle_bin]``.
        ca_bins, cb_bins, angle_bins: Sorted bin values for each map axis.

    Returns:
        The product of the three looked-up probabilities, or ``0.0`` when any
        input is NaN, any probability is non-positive, the lookup fails, or the
        lookup does not yield three values.
    """
    if np.isnan(ca_distances).any():
        return 0.0
    if np.isnan(cb_distances).any():
        return 0.0
    if np.isnan(angles).any():
        return 0.0

    # With right=True, index i means bins[i] < value <= bins[i + 1]; values at
    # or below bins[1] map to 0. Clipping keeps indices inside the map axes.
    ca_bin_indices = np.digitize(ca_distances, ca_bins[1:], right=True)
    ca_bin_indices = np.clip(ca_bin_indices, 0, len(ca_bins) - 1)
    cb_bin_indices = np.digitize(cb_distances, cb_bins[1:], right=True)
    cb_bin_indices = np.clip(cb_bin_indices, 0, len(cb_bins) - 1)
    angle_bin_indices = np.digitize(angles, angle_bins[1:], right=True)
    angle_bin_indices = np.clip(angle_bin_indices, 0, len(angle_bins) - 1)

    try:
        looked_up = prob_map_3d[ca_bin_indices, cb_bin_indices, angle_bin_indices]
        rejected = np.any(looked_up <= 0)
    except Exception:
        # Any failed lookup (including out-of-bounds indices) scores zero.
        return 0.0

    if rejected or len(looked_up) != 3:
        return 0.0
    return np.prod(looked_up)

def define_excluded_triads(triad_res_nums, structure):
    """Collect ``(chain_id, residue_number)`` pairs matching the triad numbers.

    Matching is by residue sequence number only, so a number is picked up in
    every model and every chain where it occurs.

    Args:
        triad_res_nums: Iterable of residue numbers; missing values (as judged
            by ``pd.notna``) are ignored, the rest must be convertible by ``int``.
        structure: Nested iterable of models -> chains -> residues, where chains
            expose ``.id`` and residues expose ``.id`` as
            ``(hetfield, sequence_identifier, insertion_code)``; ``None`` is allowed.

    Returns:
        Set of ``(chain.id, sequence_number)`` tuples; empty when ``structure``
        is ``None``, no usable numbers were given, or conversion failed.
    """
    excluded_residues = set()
    if structure is None:
        return excluded_residues
    try:
        res_nums_to_find = {int(num) for num in triad_res_nums if pd.notna(num)}
    except (ValueError, TypeError):
        # Say so instead of failing silently: with no exclusions the triad's
        # own atoms stay in the clash tree and can reject every candidate.
        print(f"⚠️ Warning: Could not convert all triad residue numbers {triad_res_nums} to integers for exclusion.")
        return excluded_residues
    if not res_nums_to_find:
        return excluded_residues
    for model in structure:
        for chain in model:
            for residue in chain:
                try:
                    res_seq_num = residue.id[1]
                    if res_seq_num in res_nums_to_find:
                        excluded_residues.add((chain.id, res_seq_num))
                except (TypeError, IndexError):
                    # Residue id is not in the expected tuple form; ignore it.
                    continue
    return excluded_residues

def proximity_filter_kdtree(kdtree, zn_candidate, exclusion_radius=2.5):
    """Tell whether a candidate position is free of neighbours in the tree.

    Args:
        kdtree: Object offering ``query_ball_point(point, r=..., return_length=True)``
            (the notebook passes a ``scipy.spatial.KDTree``), or ``None``.
        zn_candidate: Candidate position to test.
        exclusion_radius: Radius inside which any indexed point rejects the
            candidate.

    Returns:
        ``True`` when no tree was supplied or no indexed point lies within the
        radius; ``False`` when a neighbour exists or the query itself raised.
    """
    if kdtree is None:
        # Without a tree there is nothing to clash with.
        return True
    try:
        n_within_radius = kdtree.query_ball_point(
            zn_candidate, r=exclusion_radius, return_length=True
        )
        candidate_is_clear = n_within_radius == 0
    except Exception:
        # Fail safe: a candidate that cannot be checked is rejected.
        return False
    return candidate_is_clear

# --- MODIFIED Zn Estimation Function (Finds BEST Score) ---
def estimate_zn_iterative(
    ca_coords_site_flat,
    cb_coords_site_flat,
    site_info,
    structure_local,
    thresholds_local,
    prob_map_3d_local, ca_bins_local, cb_bins_local, angle_bins_local,
    grid_resolution=0.2
    ):
    """
    Estimates Zn coordinate for a SINGLE site by searching the entire grid
    and returning the position with the HIGHEST score.

    The grid spans the intersection of the axis-aligned boxes around the three
    C-alpha atoms. A grid point is a candidate only if it passes the distance,
    angle, ratio, probability-score and proximity filters; among candidates the
    first one reaching the highest score is kept.

    Args:
        ca_coords_site_flat: Nine numbers, CA x/y/z of residues 1-3 in order.
        cb_coords_site_flat: Nine numbers, CB x/y/z of residues 1-3 in order.
        site_info: Mapping read with ``.get`` for ``Coord_residue_number1..3``.
        structure_local: Structure offering ``get_atoms()`` and iteration over
            models/chains/residues; ``None`` yields the "no metal" result.
        thresholds_local: Mapping with ``'ca_distances_calc'``,
            ``'cb_distances_calc'``, ``'angle'`` and ``'ratio'`` entries, each
            indexable as ``(min, max)``.
        prob_map_3d_local, ca_bins_local, cb_bins_local, angle_bins_local:
            Forwarded unchanged to ``score_zn_predictions``.
        grid_resolution: Grid spacing, in the same units as the coordinates.

    Returns:
        ``(coords, score, angles)``. When a candidate was found ``coords`` is a
        length-3 numpy array; otherwise the result is
        ``("no metal", 0, [nan, nan, nan])`` (score ``0.0`` if the grid was
        searched without success).

    Raises:
        KeyError: If ``thresholds_local`` lacks one of the four entries above.
    """
    no_metal = ("no metal", 0, [np.nan, np.nan, np.nan])

    # --- Coordinate Validation ---
    try:
        ca_coords_numeric = pd.to_numeric(np.asarray(ca_coords_site_flat), errors='coerce')
        cb_coords_numeric = pd.to_numeric(np.asarray(cb_coords_site_flat), errors='coerce')
        if np.isnan(ca_coords_numeric).any() or np.isnan(cb_coords_numeric).any():
            return no_metal
        if ca_coords_numeric.shape != (9,) or cb_coords_numeric.shape != (9,):
            return no_metal
        ca_xyz = ca_coords_numeric.astype(np.float64).reshape(3, 3)
        cb_xyz = cb_coords_numeric.astype(np.float64).reshape(3, 3)
    except (ValueError, TypeError):
        return no_metal

    if structure_local is None:
        return no_metal

    # --- Excluded Residues & KDTree ---
    # Atoms of the triad residues themselves must not count as clashes.
    triad_res_nums = [site_info.get(f'Coord_residue_number{i+1}') for i in range(3)]
    excluded_residues_set = define_excluded_triads(triad_res_nums, structure_local)
    non_excluded_coords_list = []
    try:
        for atom in structure_local.get_atoms():
            residue = atom.get_parent()
            # Check each parent before dereferencing it, so a detached atom is
            # skipped instead of aborting the whole scan with AttributeError.
            if residue is None:
                continue
            chain = residue.get_parent()
            if chain is None:
                continue
            if (chain.id, residue.id[1]) in excluded_residues_set:
                continue
            if isinstance(atom.coord, np.ndarray) and atom.coord.shape == (3,):
                non_excluded_coords_list.append(atom.coord)
    except Exception as atom_iter_err:
        print(f"Warning: Error iterating atoms for KDTree build: {atom_iter_err}")
    kdtree = None
    if non_excluded_coords_list:
        try:
            non_excluded_coords = np.array(non_excluded_coords_list, dtype=np.float64)
            if non_excluded_coords.ndim == 2 and non_excluded_coords.shape[1] == 3 and non_excluded_coords.shape[0] > 0:
                kdtree = KDTree(non_excluded_coords)
        except Exception:
            # Best effort: without a tree the proximity filter accepts every point.
            kdtree = None

    # --- Thresholds (looked up once; they do not change during the scan) ---
    ca_lo, ca_hi = thresholds_local['ca_distances_calc'][0], thresholds_local['ca_distances_calc'][1]
    cb_lo, cb_hi = thresholds_local['cb_distances_calc'][0], thresholds_local['cb_distances_calc'][1]
    angle_lo, angle_hi = thresholds_local['angle'][0], thresholds_local['angle'][1]
    ratio_lo, ratio_hi = thresholds_local['ratio'][0], thresholds_local['ratio'][1]

    # --- Search Space ---
    # Intersect the boxes of half-width buffer_dist around each CA, then pad
    # the result by two grid steps on every side.
    buffer_dist = max(ca_hi, cb_hi)
    buffer_grid = grid_resolution * 2
    lower = (ca_xyz - buffer_dist).max(axis=0) - buffer_grid
    upper = (ca_xyz + buffer_dist).min(axis=0) + buffer_grid
    if not np.all(lower < upper):
        return no_metal

    # --- Grid Search Initialization ---
    x_range = np.arange(lower[0], upper[0], grid_resolution)
    y_range = np.arange(lower[1], upper[1], grid_resolution)
    z_range = np.arange(lower[2], upper[2], grid_resolution)
    if not (x_range.size > 0 and y_range.size > 0 and z_range.size > 0):
        return no_metal

    best_score = 0.0
    best_coords = "no metal"
    best_angles = [np.nan, np.nan, np.nan]

    # --- Grid Search Loop ---
    num_z = z_range.size
    for x in x_range:
        for y in y_range:
            # One z-column of grid points is evaluated per (x, y) pair.
            points = np.column_stack([np.full(num_z, x), np.full(num_z, y), z_range])

            # Vectorized distance check; dist_*[p, r] is the distance from grid
            # point p to atom r.
            dist_ca = np.linalg.norm(ca_xyz[np.newaxis, :, :] - points[:, np.newaxis, :], axis=2)
            dist_cb = np.linalg.norm(cb_xyz[np.newaxis, :, :] - points[:, np.newaxis, :], axis=2)
            dist_ca_ok = np.all((ca_lo <= dist_ca) & (dist_ca <= ca_hi), axis=1)
            dist_cb_ok = np.all((cb_lo <= dist_cb) & (dist_cb <= cb_hi), axis=1)
            dist_ok_mask = dist_ca_ok & dist_cb_ok
            if not np.any(dist_ok_mask):
                continue

            survivors = zip(points[dist_ok_mask], dist_ca[dist_ok_mask], dist_cb[dist_ok_mask])
            for point, current_dist_ca, current_dist_cb in survivors:
                # Angle Filter
                angles = calculate_angles(point, ca_xyz, cb_xyz)
                if np.isnan(angles).any() or not all(angle_lo <= ang <= angle_hi for ang in angles):
                    continue

                # Ratio Filter
                ratios = calculate_ratio(point, ca_xyz, cb_xyz)
                if np.isinf(ratios).any() or not np.all((ratio_lo <= ratios) & (ratios <= ratio_hi)):
                    continue

                # Probability Score Filter
                score = score_zn_predictions(current_dist_ca, current_dist_cb, angles, prob_map_3d_local, ca_bins_local, cb_bins_local, angle_bins_local)
                # Only a strictly better score can replace the current best
                # (best_score starts at 0.0, so non-positive scores are
                # rejected here too); skip the KD-tree query for the rest.
                if not score > best_score:
                    continue

                # Proximity Filter
                if not proximity_filter_kdtree(kdtree, point, exclusion_radius=2.5):
                    continue

                best_score = score
                best_coords = point  # numpy array of shape (3,)
                best_angles = angles
                # Do NOT return here, continue searching the rest of the grid

    # After checking all points, return the best one found (or defaults if none found)
    return best_coords, best_score, best_angles


# --- Worker Function for Multiprocessing ---
# (process_single_site remains unchanged, it calls the modified estimate_zn_iterative)
def process_single_site(args):
    """Worker function to process a single candidate site (row).

    Intended to be mapped over a task list by a multiprocessing pool, so it
    takes one packed tuple and never raises for per-site failures.

    Args:
        args: 8-tuple ``(site_index, site_data_dict, structure, thresholds,
            prob_map_3d, ca_bins, cb_bins, angle_bins)``. ``site_data_dict``
            must provide the nine ``CA{1..3}_{X,Y,Z}`` and nine
            ``CB{1..3}_{X,Y,Z}`` keys; the remaining items are forwarded
            unchanged to ``estimate_zn_iterative``.

    Returns:
        ``(site_index, zn_coords, zn_score, zn_angles)`` where the last three
        values come straight from ``estimate_zn_iterative``. If anything fails
        for this site the sentinel ``(site_index, "error", 0, [nan, nan, nan])``
        is returned instead and a one-line diagnostic is written to stderr.
    """
    import sys  # local import: only needed for the failure diagnostic

    site_index, site_data_dict, structure_shared, thresholds_shared, \
    prob_map_3d_shared, ca_bins_shared, cb_bins_shared, angle_bins_shared = args
    ca_cols = ['CA1_X', 'CA1_Y', 'CA1_Z', 'CA2_X', 'CA2_Y', 'CA2_Z', 'CA3_X', 'CA3_Y', 'CA3_Z']
    cb_cols = ['CB1_X', 'CB1_Y', 'CB1_Z', 'CB2_X', 'CB2_Y', 'CB2_Z', 'CB3_X', 'CB3_Y', 'CB3_Z']
    try:
        # Flatten the three CA / CB positions into 9-element float vectors.
        ca_coords_flat = np.array([site_data_dict[col] for col in ca_cols], dtype=np.float64)
        cb_coords_flat = np.array([site_data_dict[col] for col in cb_cols], dtype=np.float64)
        # Grid search for the best-scoring Zn position for this site.
        zn_coords, zn_score, zn_angles = estimate_zn_iterative(
            ca_coords_flat, cb_coords_flat, site_data_dict, structure_shared,
            thresholds_shared, prob_map_3d_shared, ca_bins_shared,
            cb_bins_shared, angle_bins_shared,
            grid_resolution=0.2
        )
        return site_index, zn_coords, zn_score, zn_angles
    except Exception as e:
        # Keep the pool alive, but do not hide the failure: without this message
        # a crashed site is indistinguishable from a site with no valid Zn position.
        print(
            f"WARNING: site {site_index} failed during Zn estimation "
            f"({type(e).__name__}: {e})",
            file=sys.stderr, flush=True,
        )
        return site_index, "error", 0, [np.nan, np.nan, np.nan]


# --- Main Execution Guard ---
if __name__ == "__main__":
    # Start-up banner. The KST timestamp is purely informational, so any failure
    # while building it (pytz missing, unknown zone, ...) degrades to a note
    # instead of aborting the whole run.
    print("\n--- Starting Main Process for Single PDB Site Parallelization (Find Best Score) ---")
    try:
        kst = pytz.timezone('Asia/Seoul')
        current_time_kst = datetime.datetime.now(kst)
        print(f"Current Time (KST): {current_time_kst.strftime('%Y-%m-%d %H:%M:%S %Z%z')}")
    except Exception:
        print("Note: Could not determine KST time (pytz not installed?). Run 'pip install pytz' if needed.")

    start_time = time.time()

    # --- Prepare Tasks for Multiprocessing ---
    # One task per row of df_sites; every task carries the row's data plus the
    # shared structure / threshold / probability-map objects the worker needs.
    tasks = []
    required_input_cols = (
        ['CA1_X', 'CA1_Y', 'CA1_Z', 'CA2_X', 'CA2_Y', 'CA2_Z', 'CA3_X', 'CA3_Y', 'CA3_Z'] +
        ['CB1_X', 'CB1_Y', 'CB1_Z', 'CB2_X', 'CB2_Y', 'CB2_Z', 'CB3_X', 'CB3_Y', 'CB3_Z'] +
        ['Coord_residue_number1', 'Coord_residue_number2', 'Coord_residue_number3', 'PDB_ID']
    )
    missing_cols = [col for col in required_input_cols if col not in df_sites.columns]
    if missing_cols:
        raise ValueError(f"Input coordinate Excel file is missing required columns: {missing_cols}")
    sites_data_list = df_sites[required_input_cols].to_dict('records')
    for index, site_dict in enumerate(sites_data_list):
        # The DataFrame index label travels with the task so the (unordered)
        # result can be written back to the right row.
        tasks.append((df_sites.index[index], site_dict, structure, thresholds, prob_map_3d, ca_bins, cb_bins, angle_bins))
    print(f"Prepared {len(tasks)} tasks for parallel processing.")

    # --- Initialize Results Columns ---
    df_sites['Zn_X_Grid'] = np.nan; df_sites['Zn_Y_Grid'] = np.nan; df_sites['Zn_Z_Grid'] = np.nan
    df_sites['Zn_Score'] = 0.0
    df_sites['Angle_1'] = np.nan; df_sites['Angle_2'] = np.nan; df_sites['Angle_3'] = np.nan

    # os.cpu_count() may return None when the CPU count cannot be determined;
    # fall back to a single worker instead of failing inside min().
    num_processes = min(8, os.cpu_count() or 1) # Adjust as needed
    print(f"\nInitializing multiprocessing pool with {num_processes} workers...")

    processed_count = 0
    pool = None
    try:
        pool = multiprocessing.Pool(processes=num_processes)
        print("Starting parallel processing of sites...")
        # Results arrive in completion order, not submission order.
        results_iterator = pool.imap_unordered(process_single_site, tasks)
        for result in results_iterator:
            processed_count += 1
            try:
                site_idx, zn_coords_res, zn_score_res, zn_angles_res = result
                df_sites.loc[site_idx, 'Zn_Score'] = zn_score_res
                # Coordinates are only stored for a real 3-vector; the worker's
                # "error" sentinel (or any non-array value) leaves the NaN defaults.
                if isinstance(zn_coords_res, np.ndarray) and zn_coords_res.shape == (3,):
                     df_sites.loc[site_idx, 'Zn_X_Grid'] = zn_coords_res[0]
                     df_sites.loc[site_idx, 'Zn_Y_Grid'] = zn_coords_res[1]
                     df_sites.loc[site_idx, 'Zn_Z_Grid'] = zn_coords_res[2]
                if isinstance(zn_angles_res, (list, np.ndarray)) and len(zn_angles_res) == 3:
                    df_sites.loc[site_idx, 'Angle_1'] = float(zn_angles_res[0]) if pd.notna(zn_angles_res[0]) else np.nan
                    df_sites.loc[site_idx, 'Angle_2'] = float(zn_angles_res[1]) if pd.notna(zn_angles_res[1]) else np.nan
                    df_sites.loc[site_idx, 'Angle_3'] = float(zn_angles_res[2]) if pd.notna(zn_angles_res[2]) else np.nan
                if processed_count % 50 == 0 or processed_count == len(tasks): print(f"  Processed {processed_count}/{len(tasks)} sites...")
            except Exception as result_error:
                 # A malformed result for one site must not stop the remaining sites.
                 print(f"❌ Error processing result for one site: {result_error}")
                 if isinstance(result, tuple) and len(result) > 0: print(f"   Problem occurred for site index: {result[0]}")
                 continue
        print(f"Finished processing all {processed_count} assigned tasks.")
    except Exception as pool_error:
        print(f"❌ An error occurred during multiprocessing: {pool_error}")
        print(traceback.format_exc())
    finally:
        # Always release the worker processes, even after a failure.
        if pool:
            pool.close()
            pool.join()
    print("\n--- Multiprocessing Pool Finished ---")

    # --- Post-processing and Saving Results ---
    print("Filtering results (Zn_Score > 0)...")
    # Non-numeric scores are coerced to 0.0 so they are dropped by the filter.
    df_sites['Zn_Score'] = pd.to_numeric(df_sites['Zn_Score'], errors='coerce').fillna(0.0)
    df_output = df_sites[df_sites['Zn_Score'] > 0].copy()
    total_sites_saved = len(df_output)
    if not df_output.empty:
        print(f"Saving {total_sites_saved} site(s) with positive scores...")
        try:
            for col in ['Zn_X_Grid', 'Zn_Y_Grid', 'Zn_Z_Grid', 'Angle_1', 'Angle_2', 'Angle_3']:
                 if col in df_output.columns: df_output[col] = pd.to_numeric(df_output[col], errors='coerce')
            df_output.to_excel(Output_result_excel_file, index=False)
            print(f"✅ Successfully saved results to: {Output_result_excel_file}")
        except Exception as save_error:
            print(f"❌ Error saving results to Excel file: {save_error}")
            print(traceback.format_exc())
    else:
        print("⚠️ No valid Zn predictions passed filters (Score > 0). No output file created.")

    end_time = time.time()
    print(f"\n🏁 Processing summary:")
    print(f"  Total input sites: {len(df_sites)}")
    print(f"  Total sites processed by pool: {processed_count}")
    print(f"  Total result sites saved (Score > 0): {total_sites_saved}")
    print(f"  Total execution time: {end_time - start_time:.2f} seconds")

In [55]:
# Import necessary libraries
import numpy as np
import pandas as pd
import os
from Bio.PDB import PDBParser
import requests

# Markdown documentation for file pathways

# @markdown # Step 5: Analyze the results (apply to the PDB file)

# Read the scored-site table that the PyMOL script is generated from
input_file_path = "/content/1EP0_full_best_score.xlsx" # @param {type:"string"}
df_new = pd.read_excel(input_file_path)

# Number the rows 1..N; that number is embedded in every PyMOL object name below
pymol_script_commands = []
df_new['Combination_Number'] = range(1, len(df_new) + 1)

# Emit one group of PyMOL commands per row (residue triplet), whether or not
# the row carries Zn coordinates
for index, row in df_new.iterrows():
    # (chain, residue number) of the three residues, in slot order 1..3
    residue_slots = [
        (row[f'Coord_chain_id_number{slot_no}'], row[f'Coord_residue_number{slot_no}'])
        for slot_no in (1, 2, 3)
    ]
    (chain1, res1), (chain2, res2), (chain3, res3) = residue_slots

    # Predicted Zn position for this row (NaN when absent)
    zn_x, zn_y, zn_z = (row[axis_col] for axis_col in ('Zn_X_Grid', 'Zn_Y_Grid', 'Zn_Z_Grid'))

    selection_name = f"obj{row['Combination_Number']:02d}"

    # One named selection covering all three residues
    residue_clauses = " or ".join(
        f"(chain {slot_chain} and resi {slot_res})" for slot_chain, slot_res in residue_slots
    )
    pymol_script_commands.append(f"select {selection_name}, {residue_clauses}")

    # One object per residue, addressed as /<PDB_ID>//<chain>/<resi>
    for slot_no, (slot_chain, slot_res) in enumerate(residue_slots, start=1):
        pymol_script_commands.append(
            f"create {selection_name}_residue{slot_no}, /{row['PDB_ID']}//{slot_chain}/{slot_res}"
        )

    if pd.isna(zn_x) or pd.isna(zn_y) or pd.isna(zn_z):
        # Any missing coordinate: leave a comment line in the script instead of an atom
        pymol_script_commands.append(f"# {selection_name} does not bind Zn")
    else:
        # All three coordinates present: place a pseudoatom there and show it as a sphere.
        # NOTE(review): "elem=Metal" is not an element symbol; presumably Zn was
        # intended -- confirm against the PyMOL pseudoatom documentation.
        zn_name = f"{selection_name}_Metal"
        pymol_script_commands.append(f"pseudoatom {zn_name}, pos=[{zn_x}, {zn_y}, {zn_z}], elem=Metal, name={zn_name}")
        pymol_script_commands.append(f"show sphere, {zn_name}")

# Write the collected commands to a .pml file, one command per line
pymol_script_file = "/content/1EP0_full_best_score.pml" # @param {type:"string"}
with open(pymol_script_file, 'w') as f:
    f.write("# PyMOL script for visualizing both Zn-binding and non-binding residue combinations\n\n")
    f.writelines(f"{command}\n" for command in pymol_script_commands)

print(f"PyMOL script saved to {pymol_script_file}")

PyMOL script saved to /content/1EP0_full_best_score.pml
