In [1]:
# @markdown # Step 1: Install all neccessary packages
#Download pymol in colabsystem
from IPython.utils import io
import tqdm.notebook
import os
"""The PyMOL installation is done inside two nested context managers. This approach
was inspired by Dr. Christopher Schlicksup's (of the Phenix group at
Lawrence Berkeley National Laboratory) method for installing cctbx
in a Colab Notebook. He presented his work on September 1, 2021 at the IUCr
Crystallographic Computing School. I adapted Chris's approach here. It replaces my first approach
that requires seven steps. My approach was presentated at the SciPy2021 conference
in July 2021 and published in the
[proceedings](http://conference.scipy.org/proceedings/scipy2021/blaine_mooers.html).
The new approach is easier for beginners to use. The old approach is easier to debug
and could be used as a back-up approach.

"""
total = 100
with tqdm.notebook.tqdm(total=total) as pbar:
    with io.capture_output() as captured:

        !pip install -q condacolab
        import condacolab
        condacolab.install()
        pbar.update(10)

        import sys
        sys.path.append('/usr/local/lib/python3.7/site-packages/')
        pbar.update(20)

        # Install PyMOL
        %shell mamba install -c schrodinger pymol-bundle --yes

        pbar.update(90)

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# @markdown # Step 2: Run prescreening (Geometric parameters applied)
# pip install scipy

import os
import glob
import pandas as pd
import numpy as np
from Bio.PDB import PDBParser
import itertools
import requests
from concurrent.futures import ProcessPoolExecutor
import traceback
from scipy.spatial import KDTree # <-- Import KDTree

# --- Constants ---
# Input directory containing PDB files to be processed
Target_pdb_directory = "/content/drive/MyDrive/All/Mn_2His_1Glu" # @param {type:"string"}
# Output directory where processed Excel files will be saved
Prescreening_result_directory = "/content/drive/MyDrive/250413_Final/Mn_2His_1Glu_1" # @param {type:"string"}
# The rest of this cell (threshold download, process_pdb_file, batch runner) reads these names.
input_folder = Target_pdb_directory
output_folder = Prescreening_result_directory
# Create the output directory if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

# --- Download threshold configuration ---
base_url = "https://raw.githubusercontent.com/SNU-Songlab/Metal-Installer-code/main/Threshold"
# NOTE(review): Metal / Combinations should describe the same site type as the PDB set in
# Target_pdb_directory; the defaults below (Cu, 3His) differ from that path's default
# (Mn_2His_1Glu). Confirm the intended pairing before running.
Metal = 'Cu'  # @param ["Zn", "Mn", "Cu", "Fe"]
Combinations = '3His'  # @param ["3His", "2His_1Asp", "2His_1Glu", "2His_1Cys"]
Range = '3'  # @param ["1", "2", "3", "4","5"]
thresholds_url = f"{base_url}/{Metal}/{Combinations}/{Range}.xlsx"
thresholds_file = os.path.join(output_folder, "thresholds.xlsx")

print(f"⬇️ Downloading thresholds from: {thresholds_url}")
# Without a timeout, requests may wait indefinitely on a stalled connection.
response = requests.get(thresholds_url, timeout=60)
if response.status_code == 200:
    with open(thresholds_file, "wb") as file:
        file.write(response.content)
    print(f"✅ Thresholds downloaded successfully to {thresholds_file}")
else:
    raise ValueError(f"❌ Failed to download thresholds from {thresholds_url}. Status code: {response.status_code}")

# --- Load thresholds ---
print("⚙️ Loading thresholds...")
thresholds_df = pd.read_excel(thresholds_file, sheet_name="Sheet1")
# Parameter -> (Min, Max); rows with either bound missing are skipped.
thresholds = {
    row["Parameter"]: (row["Min"], row["Max"])
    for _, row in thresholds_df.iterrows()
    if pd.notna(row["Min"]) and pd.notna(row["Max"])
}

# Fail with a clear message (as Step 4 does) instead of a bare KeyError on first use.
prescreen_required_keys = [
    "alpha_distance_range",
    "beta_distance_range",
    "ratio_threshold_range",
    "pie_threshold_range",
]
missing_keys = [key for key in prescreen_required_keys if key not in thresholds]
if missing_keys:
    raise KeyError(f"Missing key(s) {missing_keys} in thresholds file {thresholds_file}.")

alpha_distance_range = thresholds["alpha_distance_range"]
beta_distance_range = thresholds["beta_distance_range"]
ratio_threshold_range = thresholds["ratio_threshold_range"]
pie_threshold_range = thresholds["pie_threshold_range"]

print("📊 Thresholds loaded:")
for key, value in thresholds.items():
    print(f"  - {key}: Min={value[0]}, Max={value[1]}")

# --- Helper Functions ---
# (These functions remain the same: calculate_pie, standardize_residue_identity)
def calculate_pie(v1, v2):
    """Return the angle between vectors ``v1`` and ``v2`` in degrees.

    Returns NaN when either vector has zero length (the angle is undefined). The cosine is
    clipped to [-1, 1] so rounding error cannot push ``arccos`` out of its domain.
    """
    length_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    if length_product == 0:
        return np.nan
    cosine = np.clip(np.dot(v1, v2) / length_product, -1.0, 1.0)
    return np.degrees(np.arccos(cosine))

def standardize_residue_identity(row):
    """Build an order-independent identity key for a triad row.

    The key is the sorted tuple of ``(residue_name, residue_number)`` pairs taken from the
    ``Coord_residue_name{1,2,3}`` / ``Coord_residue_number{1,2,3}`` fields. Chain IDs are not
    part of the key, so triads with the same names and numbers on different chains compare equal.
    """
    pairs = [
        (row[f"Coord_residue_name{n}"], row[f"Coord_residue_number{n}"])
        for n in (1, 2, 3)
    ]
    return tuple(sorted(pairs))

# --- Main Processing Function ---

def _spatial_triad_indices(kdtree, coords, pairs, max_dist):
    """Close neighbour pairs into triangles: index triples whose points are pairwise within max_dist.

    Args:
        kdtree: scipy ``KDTree`` built on ``coords``.
        coords: Point coordinates, one row per residue (the CA atoms).
        pairs: Output of ``kdtree.query_pairs(r=max_dist)``.
        max_dist: Neighbour radius (inclusive, as in the KDTree queries).

    Returns:
        Sorted list of unique ``(i, j, k)`` tuples with ``i < j < k``. Sorting makes the
        downstream row order, and therefore which duplicate ``drop_duplicates`` keeps,
        reproducible.
    """
    # One batched ball query for every point, instead of two queries per pair.
    neighbor_sets = [set(nbrs) for nbrs in kdtree.query_ball_point(coords, r=max_dist)]
    triads = set()
    for i, j in pairs:
        for k in neighbor_sets[i] & neighbor_sets[j]:
            if k != i and k != j:
                triads.add(tuple(sorted((i, j, k))))
    return sorted(triads)


def _triad_pair_geometry(comb):
    """Pairwise CA/CB geometry of a residue triad, or None if any pair fails the distance windows.

    Pairs are visited in ``itertools.combinations`` order (1-2, 1-3, 2-3). Each pair's CA-CA and
    CB-CB distances must lie inside ``alpha_distance_range`` / ``beta_distance_range``
    (inclusive bounds). The pie angle of a pair is the angle between its CA->CA and CB->CB
    vectors, computed with ``calculate_pie``.

    Args:
        comb: Three Bio.PDB residues, each having CA and CB atoms.

    Returns:
        ``(alpha_distances, beta_distances, pie_angles)`` with three values each, or None.
    """
    alpha_distances, beta_distances, pie_angles = [], [], []
    for res1, res2 in itertools.combinations(comb, 2):
        v_ca = res2["CA"].coord - res1["CA"].coord
        v_cb = res2["CB"].coord - res1["CB"].coord
        d_ca = np.linalg.norm(v_ca)
        d_cb = np.linalg.norm(v_cb)
        if not (alpha_distance_range[0] <= d_ca <= alpha_distance_range[1]
                and beta_distance_range[0] <= d_cb <= beta_distance_range[1]):
            return None
        alpha_distances.append(d_ca)
        beta_distances.append(d_cb)
        pie_angles.append(calculate_pie(v_ca, v_cb))
    return alpha_distances, beta_distances, pie_angles


def process_pdb_file(pdb_file):
    """
    Prescreen one PDB file for candidate metal-binding residue triads and write them to Excel.

    Pipeline (first model of the structure, standard residues only):
      1. KDTree pre-filter on CA atoms: residue triples whose CA atoms are pairwise within
         ``alpha_distance_range[1] * 1.1`` and that contain at least two HIS/ASP/GLU/CYS.
      2. Exact pairwise CA-CA / CB-CB distance windows (inclusive).
      3. CA/CB distance-ratio window ``ratio_threshold_range`` (inclusive) for all three pairs.
      4. Pie-angle window ``pie_threshold_range`` (exclusive) for all three pairs.
      5. Deduplication by residue (name, number) identity via ``standardize_residue_identity``.

    Reads module-level configuration: ``output_folder`` and the four threshold ranges.

    Args:
        pdb_file: Path to the input ``.pdb`` file.

    Returns:
        None after writing ``<output_folder>/<stem>_processed.xlsx`` (sheet
        ``4_Final_Deduplicated_KD``, read by Step 3) or after an error has been printed;
        an empty DataFrame (and no file) when fewer than three residues have CA atoms.
    """
    pdb_name = os.path.basename(pdb_file)
    output_file = os.path.join(output_folder, f"{os.path.splitext(pdb_name)[0]}_processed.xlsx")
    pie_cols = ["Pie_1_2", "Pie_1_3", "Pie_2_3"]
    print(f"🔄 Processing: {pdb_name}")
    try:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", pdb_file)
        model = structure[0]

        target_residues = {"HIS", "ASP", "GLU", "CYS"}
        print(f"  Target residue types for filtering (at least 2 required): {target_residues}")

        # Hetero flag " " selects standard residues (no waters / ligands).
        all_residues_full = [res for chain in model for res in chain if res.get_id()[0] == " "]
        print(f"  Found {len(all_residues_full)} standard residues initially.")

        # --- 1. KDTree pre-filter on CA atoms ---
        residues_for_tree = [res for res in all_residues_full if res.has_id("CA")]
        if len(residues_for_tree) < 3:
            print(f"  ⚠️ Skipping {pdb_name}: Not enough residues (<3) with CA atoms.")
            return pd.DataFrame()

        coords_ca = np.array([res["CA"].coord for res in residues_for_tree])
        kdtree = KDTree(coords_ca)

        # Largest allowed CA-CA distance plus a 10% buffer; exact windows are applied per pair below.
        max_dist = alpha_distance_range[1] * 1.1
        print(f"  KDTree Max search distance (CA-CA): {max_dist:.2f} Å")

        pairs = kdtree.query_pairs(r=max_dist)
        print(f"  Found {len(pairs)} pairs within distance using KDTree.")

        potential_triad_indices = _spatial_triad_indices(kdtree, coords_ca, pairs, max_dist)
        print(f"  Generated {len(potential_triad_indices)} unique potential spatial triads.")

        triads_to_process = []
        for triple in potential_triad_indices:
            comb = tuple(residues_for_tree[idx] for idx in triple)
            if sum(res.get_resname() in target_residues for res in comb) >= 2:
                triads_to_process.append(comb)
        print(f"  Identified {len(triads_to_process)} triads meeting spatial and target residue criteria.")

        # --- 2. Exact pairwise distance windows (pie angles computed from the same residues) ---
        results = []
        for comb in triads_to_process:
            # Residues without a CB atom (e.g. glycine) cannot be scored.
            if not all(res.has_id("CA") and res.has_id("CB") for res in comb):
                continue
            try:
                geometry = _triad_pair_geometry(comb)
            except Exception as e_inner:
                print(f"    Error processing triad {comb}: {e_inner}")
                continue
            if geometry is None:
                continue
            alpha_distances, beta_distances, pie_angles = geometry

            chain_ids = [res.get_full_id()[2] for res in comb]
            row = {
                "PDB_ID": pdb_name,
                "Triad_Type": "intra" if len(set(chain_ids)) == 1 else "inter",
            }
            for n, res in enumerate(comb, start=1):
                full_id = res.get_full_id()
                row[f"Coord_chain_id_number{n}"] = full_id[2]
                row[f"Coord_residue_number{n}"] = full_id[3][1]
                row[f"Coord_residue_name{n}"] = res.get_resname()
            for n in range(3):
                row[f"Alpha Distance {n + 1}"] = alpha_distances[n]
                row[f"Beta Distance {n + 1}"] = beta_distances[n]
            # Pie columns come after the distances, preserving the written column order.
            row.update(zip(pie_cols, pie_angles))
            results.append(row)

        print(f"  Found {len(results)} triads passing initial distance filters.")
        df = pd.DataFrame(results)

        # --- 3. CA/CB distance-ratio window (inclusive); all three pairs must pass ---
        if not df.empty:
            alpha = df[[f"Alpha Distance {n}" for n in (1, 2, 3)]].to_numpy(dtype=float)
            beta = df[[f"Beta Distance {n}" for n in (1, 2, 3)]].to_numpy(dtype=float)
            with np.errstate(divide="ignore", invalid="ignore"):
                ratios = alpha / beta
            ratio_ok = (
                (beta != 0)
                & (ratios >= ratio_threshold_range[0])
                & (ratios <= ratio_threshold_range[1])
            )
            df_ratio = df[ratio_ok.all(axis=1)].copy()
            print(f"  Found {len(df_ratio)} triads passing ratio filter.")
        else:
            df_ratio = pd.DataFrame()

        # --- 4. Pie-angle window (exclusive bounds); NaN angles fail ---
        if not df_ratio.empty:
            for col in pie_cols:
                df_ratio[f"{col}_Filter"] = (
                    df_ratio[col].gt(pie_threshold_range[0]) & df_ratio[col].lt(pie_threshold_range[1])
                )
            df_ratio["Pie_Filter"] = df_ratio[[f"{col}_Filter" for col in pie_cols]].all(axis=1)
            df_final = df_ratio[df_ratio["Pie_Filter"]].copy()
            print(f"  Found {len(df_final)} triads passing pie angle filter.")
        else:
            df_final = pd.DataFrame()

        # --- 5. Redundancy removal by residue (name, number) identity ---
        if not df_final.empty:
            print("  Applying redundancy removal based on residue identity (name and number)...")
            df_final["Triad_Identity"] = df_final.apply(standardize_residue_identity, axis=1)
            df_deduplicated = df_final.drop_duplicates(subset="Triad_Identity").drop(columns=["Triad_Identity"])
            print(f"  Kept {len(df_deduplicated)} unique triads after deduplication.")
        else:
            df_deduplicated = pd.DataFrame()

        # --- Output: one workbook per PDB file ---
        with pd.ExcelWriter(output_file) as writer:
            df_deduplicated.to_excel(writer, sheet_name="4_Final_Deduplicated_KD", index=False)
        print(f"✅ Finished: {pdb_name}. Results saved to {output_file} (Unique triads: {len(df_deduplicated)})")
        return None

    except FileNotFoundError:
        print(f"❌ Error: Input PDB file not found at {pdb_file}")
        return None
    except Exception as e:
        print(f"❌ An unexpected error occurred while processing {pdb_name}: {e}")
        traceback.print_exc()
        return None


# --- Run Processing for All PDB Files ---
if __name__ == "__main__":
    print("\n--- Starting Batch Processing with KDTree Pre-filtering ---")
    # Sorted so files are submitted in a reproducible order.
    pdb_files = sorted(glob.glob(os.path.join(input_folder, "*.pdb")))
    print(f"Found {len(pdb_files)} PDB files in {input_folder}")

    if not pdb_files:
        print("⚠️ No PDB files found. Exiting.")
    else:
        num_workers = min(6, os.cpu_count() or 1)  # cap on worker processes; adjust as needed
        print(f"🚀 Starting parallel processing with up to {num_workers} workers...")

        # NOTE(review): workers call a function defined in this notebook and read notebook globals
        # (output_folder, threshold ranges); this presumably relies on the "fork" start method
        # used on Linux/Colab. Confirm before running under a different start method.
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            # Executor.map re-raises a worker's exception only when that result is retrieved, so
            # drain the iterator; otherwise failures such as a broken process pool go unnoticed.
            # process_pdb_file writes its own Excel file, so the returned values are discarded.
            for _ in executor.map(process_pdb_file, pdb_files):
                pass

        # Alternative: if process_pdb_file is changed to return its DataFrame, collect the
        # non-empty results here, pd.concat them, optionally deduplicate globally, and write one
        # combined file (e.g. ALL_RESULTS_final_deduplicated_KDTree.csv) to output_folder.

    print("\n🎉 All processing finished.")

In [None]:
# @markdown # Step 3: Extraction of coordinates of the prescreened target
import os
import pandas as pd
import numpy as np
from Bio.PDB import PDBParser

# 🔧 Folder paths (adjust as needed)
prescreening_directory = "/content/drive/MyDrive/250413_Final/Zn_2His_1Glu_3"        # @param {type:"string"}
pdb_directory  = "/content/drive/MyDrive/All/Zn_2His_1Glu"           # @param {type:"string"}
coordinate_directory = "/content/drive/MyDrive/250413_Final/Zn_2His_1Glu_3_coordinate"  # @param {type:"string"}

# The extraction loop in this cell reads `excel_folder`, `pdb_folder` and `output_folder`.
# Bind them to this cell's form fields. Otherwise the first two are undefined, and
# `output_folder` is whatever an earlier cell left behind (Step 2's result folder).
excel_folder = prescreening_directory
pdb_folder = pdb_directory
output_folder = coordinate_directory

os.makedirs(output_folder, exist_ok=True)

# 🧠 Coordinate extraction helper
def extract_coordinates(chain, res_id, atom_name):
    """Return the coordinates of one atom, or ``[None, None, None]`` if it cannot be found.

    Args:
        chain: A Bio.PDB chain, or None (``chains.get`` returns None for an unknown chain ID).
        res_id: Residue sequence number as read from Excel, or a full residue ID tuple.
        atom_name: Atom name such as ``'CA'`` or ``'CB'``.

    Returns:
        The atom's ``coord`` array, or ``[None, None, None]`` for a missing chain, residue or
        atom, or for a residue number that cannot be converted to an int (e.g. NaN).
    """
    if chain is None:
        return [None, None, None]
    try:
        # Values read back from Excel may be numpy or float numbers. Bio.PDB's shorthand lookup
        # by sequence number expects a plain int, so normalise non-tuple IDs here.
        key = res_id if isinstance(res_id, tuple) else int(res_id)
        return chain[key][atom_name].coord
    except Exception:
        return [None, None, None]

# 🔁 Loop through all Excel files
for excel_name in os.listdir(excel_folder):
    # Only the per-PDB workbooks written by Step 2 are processed.
    if not excel_name.endswith("_processed.xlsx"):
        continue

    pdb_id = excel_name.replace("_processed.xlsx", "")
    excel_path = os.path.join(excel_folder, excel_name)
    pdb_path = os.path.join(pdb_folder, f"{pdb_id}.pdb")
    output_path = os.path.join(output_folder, f"{pdb_id}_with_coordinates.xlsx")

    if not os.path.isfile(pdb_path):
        print(f"❌ Skipping {pdb_id}: PDB file not found.")
        continue

    try:
        df_pie = pd.read_excel(excel_path, sheet_name="4_Final_Deduplicated_KD")

        structure = PDBParser(QUIET=True).get_structure("protein", pdb_path)
        chains = {chain.id: chain for chain in structure[0]}

        # Per atom type, one flat [X1, Y1, Z1, X2, Y2, Z2, X3, Y3, Z3] row per triad.
        flat_coords = {"CA": [], "CB": []}
        for _, site in df_pie.iterrows():
            site_chains = [chains.get(site[f"Coord_chain_id_number{n}"]) for n in (1, 2, 3)]
            site_resnums = [site[f"Coord_residue_number{n}"] for n in (1, 2, 3)]
            for atom_name, collected in flat_coords.items():
                values = []
                for site_chain, resnum in zip(site_chains, site_resnums):
                    values.extend(extract_coordinates(site_chain, resnum, atom_name))
                collected.append(values)

        # Column names CA1_X ... CA3_Z and CB1_X ... CB3_Z.
        coord_frames = [
            pd.DataFrame(
                collected,
                columns=[f"{atom_name}{n}_{axis}" for n in (1, 2, 3) for axis in "XYZ"],
            )
            for atom_name, collected in flat_coords.items()
        ]

        df_pie = pd.concat([df_pie.reset_index(drop=True), *coord_frames], axis=1)
        if 'PDB_ID' in df_pie.columns:
            df_pie['PDB_ID'] = df_pie['PDB_ID'].str.replace('.pdb', '', regex=False)

        df_pie.to_excel(output_path, index=False)
        print(f"✅ Saved: {output_path}")

    except Exception as e:
        print(f"❌ Error processing {pdb_id}: {e}")


In [7]:
# @markdown # Step 4: Probability density map
import numpy as np
import pandas as pd
import os
import glob
import requests
import traceback # Import traceback for better error printing
import multiprocessing # Import multiprocessing
from Bio.PDB import PDBParser # PDBParser is still needed for structure loading
from scipy.spatial import KDTree # Import KDTree

# --- Configuration and Setup ---
# Base directories: edit them through the Colab form fields.
excel_directory = "/content/drive/MyDrive/250413_Final/Zn_3His_1_coordinate" # @param {type:"string"}
pdb_directory_base = "/content/drive/MyDrive/All/Zn_3His"                    # @param {type:"string"}
output_result_directory = "/content/drive/MyDrive/Benchmark_Final/Zn_3His_1" # @param {type:"string"}

# Both input directories must already exist; the output directory is created on demand.
for _dir_label, _dir_path in (("Excel", excel_directory), ("PDB", pdb_directory_base)):
    if not os.path.exists(_dir_path):
        raise FileNotFoundError(f"{_dir_label} directory not found: {_dir_path}")
os.makedirs(output_result_directory, exist_ok=True)

# Define file paths for downloaded data
prob_map_file = '/content/map.xlsx'
thresholds_file = '/content/threshold.xlsx'

# --- Download Data from GitHub ---
# No trailing slash: the URLs below insert their own "/" separators, and a trailing slash here
# would produce ".../probability//<Metal>/...". Same convention as Step 2's base_url.
base_url = "https://raw.githubusercontent.com/SNU-Songlab/Metal-Installer-code/main/probability"
Metal = 'Zn'  # @param ["Zn", "Mn", "Cu", "Fe"]
Combinations = '3His'  # @param ["3His", "2His_1Asp", "2His_1Glu", "2His_1Cys"]
map_url = f"{base_url}/{Metal}/{Combinations}/map.xlsx"
thresholds_url = f"{base_url}/{Metal}/{Combinations}/threshold.xlsx"

# Timeouts keep the cell from hanging indefinitely on a stalled connection.
print(f"Downloading probability map from: {map_url}")
response_map = requests.get(map_url, timeout=60)
if response_map.status_code == 200:
    with open(prob_map_file, 'wb') as file:
        file.write(response_map.content)
    print(f"Downloaded map data to {prob_map_file}")
else:
    raise ValueError(f"Failed to download map file from {map_url}. Status code: {response_map.status_code}")

print(f"Downloading thresholds from: {thresholds_url}")
response_thresh = requests.get(thresholds_url, timeout=60)
if response_thresh.status_code == 200:
    with open(thresholds_file, 'wb') as file:
        file.write(response_thresh.content)
    print(f"Downloaded thresholds data to {thresholds_file}")
else:
    raise ValueError(f"Failed to download thresholds file from {thresholds_url}. Status code: {response_thresh.status_code}")

# --- Load and Process Data --- (Load data ONCE in the main process)
print("Loading thresholds...")
thresholds_df = pd.read_excel(thresholds_file, sheet_name='Sheet1')
print("Loading probability map...")
df_precomputed_prob_map = pd.read_excel(prob_map_file)

print("Processing thresholds...")
# Parameter -> (Min, Max); rows with either bound missing are skipped, later duplicates win.
thresholds = {
    rec['Parameter']: (rec['Min'], rec['Max'])
    for _, rec in thresholds_df.iterrows()
    if pd.notna(rec['Min']) and pd.notna(rec['Max'])
}

required_keys = ['ca_distances_calc', 'cb_distances_calc', 'ratio', 'angle']
_missing_required = [name for name in required_keys if name not in thresholds]
if _missing_required:
    raise KeyError(f"Missing key '{_missing_required[0]}' in thresholds file.")

print("Processing probability map...")
ca_bins, cb_bins, angle_bins = (
    np.sort(df_precomputed_prob_map[axis_col].unique())
    for axis_col in ('Calpha_Zn_Dist', 'Cbeta_Zn_Dist', 'CA-Zn-CB_Angle')
)
pivoted_prob_map = df_precomputed_prob_map.pivot_table(
    index='Calpha_Zn_Dist', columns=['Cbeta_Zn_Dist', 'CA-Zn-CB_Angle'], values='Probability', fill_value=0
)
_n_ca, _n_cb, _n_angle = len(ca_bins), len(cb_bins), len(angle_bins)
expected_shape = (_n_ca, _n_cb * _n_angle)
if pivoted_prob_map.shape != expected_shape:
    raise ValueError(f"Pivoted map shape {pivoted_prob_map.shape} doesn't match expected shape {expected_shape} for reshaping.")
# Relies on pivot_table's default sorting (Cbeta outer, angle inner in the column MultiIndex),
# so a row-major reshape yields axes (CA bin, CB bin, angle bin).
prob_map_3d = pivoted_prob_map.values.reshape((_n_ca, _n_cb, _n_angle))
print("Probability map processed into 3D array.")


# --- Helper Function Definitions ---

def calculate_ratio(current_point, ca_xyz, cb_xyz):
    """Per-row ratio of (distance to CA) / (distance to CB), measured from ``current_point``.

    ``ca_xyz`` and ``cb_xyz`` hold one coordinate row per residue (norms are taken along
    axis 1); ``current_point`` is broadcast against each row. Rows whose CB distance is zero
    get ``inf`` instead of a division error.
    """
    dist_to_ca = np.linalg.norm(ca_xyz - current_point, axis=1)
    dist_to_cb = np.linalg.norm(cb_xyz - current_point, axis=1)
    cb_is_zero = dist_to_cb == 0
    return np.divide(dist_to_ca, dist_to_cb, out=np.full_like(dist_to_ca, np.inf), where=~cb_is_zero)

def load_pdb_structure(entry_id, pdb_directory):
    """Parse ``<pdb_directory>/<entry_id>.pdb`` with Bio.PDB.

    Returns the parsed structure, or None after printing the reason if the file is missing or
    cannot be parsed.
    """
    parser = PDBParser(QUIET=True)
    path = os.path.join(pdb_directory, f"{entry_id}.pdb")
    try:
        return parser.get_structure(entry_id, path)
    except FileNotFoundError:
        print(f"❌ PDB file not found for loading: {path}")
    except Exception as e:
        print(f"❌ Error loading PDB file {path}: {e}")
    return None

def score_zn_predictions(ca_distances, cb_distances, angles, prob_map_3d, ca_bins, cb_bins, angle_bins):
    """Score a candidate position as the product of per-residue probability-map values.

    Each value is binned with ``np.digitize(values, bins[1:], right=True)`` and clipped to
    ``[0, len(bins) - 1]``. The (CA, CB, angle) bin triple for each residue indexes
    ``prob_map_3d``.

    Returns:
        The product of the looked-up probabilities. Returns 0.0 if the inputs are empty, any
        bin index falls outside ``prob_map_3d``, or any looked-up probability is <= 0.
    """
    def _to_bin(values, bins):
        return np.clip(np.digitize(values, bins[1:], right=True), 0, len(bins) - 1)

    looked_up = []
    for ca_i, cb_i, ang_i in zip(_to_bin(ca_distances, ca_bins),
                                 _to_bin(cb_distances, cb_bins),
                                 _to_bin(angles, angle_bins)):
        shape = prob_map_3d.shape
        inside = (0 <= ca_i < shape[0]) and (0 <= cb_i < shape[1]) and (0 <= ang_i < shape[2])
        if not inside:
            return 0.0
        prob = prob_map_3d[ca_i, cb_i, ang_i]
        if prob <= 0:
            return 0.0
        looked_up.append(prob)
    return np.prod(looked_up) if looked_up else 0.0

def calculate_angles(zn_coords, ca_coords_triplet, cb_coords_triplet):
    """Return the three CA-metal-CB angles (degrees), one per residue of the triad.

    For residue i the angle is between the vectors metal->CA_i and metal->CB_i. When either
    vector has zero length, the angle is reported as 0.0.
    """
    def _angle_at_metal(ca_point, cb_point):
        to_ca = ca_point - zn_coords
        to_cb = cb_point - zn_coords
        len_ca = np.linalg.norm(to_ca)
        len_cb = np.linalg.norm(to_cb)
        if len_ca == 0 or len_cb == 0:
            return 0.0
        cosine = np.dot(to_ca, to_cb) / (len_ca * len_cb)
        return np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0)))

    return [_angle_at_metal(ca_coords_triplet[i], cb_coords_triplet[i]) for i in range(3)]

def define_excluded_triads(triad_res_nums, structure):
    """Collect ``(chain_id, resseq)`` for every residue whose sequence number is one of the triad's.

    Matching is by sequence number only, across every chain of every model, so the same number
    is excluded in all chains. Returns an empty set when ``structure`` is None, or (after
    printing a warning) when a triad number cannot be converted to int.
    """
    if structure is None:
        return set()
    try:
        wanted_numbers = {int(num) for num in triad_res_nums}
    except (ValueError, TypeError):
        print(f"⚠️ Warning: Could not convert all triad residue numbers {triad_res_nums} to integers.")
        return set()

    return {
        (chain.id, residue.id[1])
        for model in structure
        for chain in model
        for residue in chain
        if residue.id[1] in wanted_numbers
    }

# --- NEW Proximity Filter using KDTree ---
def proximity_filter_kdtree(kdtree, zn_candidate, exclusion_radius=2.5):
    """
    Proximity check for a candidate position against a pre-built SciPy KDTree.

    Returns True when no indexed point lies within ``exclusion_radius`` of ``zn_candidate``
    (or when ``kdtree`` is None, i.e. there is nothing to clash with). Returns False when a
    point is that close, or when the query itself fails (the error is printed).
    """
    if kdtree is None:
        return True

    try:
        # return_length=True yields just the count of points inside the radius.
        clash_count = kdtree.query_ball_point(zn_candidate, r=exclusion_radius, return_length=True)
        return not (clash_count > 0)
    except Exception as e:
        print(f"❌ Error during KDTree query: {e}")
        return False


# --- Main prediction function (KDTree-based proximity check) ---
def _no_metal_result():
    """Return a fresh 'no prediction' result tuple for estimate_zn_iterative."""
    return "no metal", 0, [None, None, None]


def _build_exclusion_kdtree(structure, excluded_residues, entry_id, site_index_name):
    """Build a KDTree over every atom whose (chain id, residue number) is not excluded.

    Returns None when no atom is left, when the coordinates are not an (N, 3)
    array, or when building the tree raises (the error is printed).
    """
    kept_coords = [
        atom.coord
        for atom in structure.get_atoms()
        if (atom.get_parent().get_parent().id, atom.get_parent().id[1]) not in excluded_residues
    ]
    if not kept_coords:
        return None
    try:
        points = np.array(kept_coords, dtype=np.float64)
        if points.ndim == 2 and points.shape[1] == 3:
            return KDTree(points)
    except Exception as kdtree_error:
        print(f"❌ Error building KDTree for {entry_id}, site {site_index_name}: {kdtree_error}")
    return None


def _shared_search_box(ca_xyz, cb_xyz, thresholds, buffer):
    """Return per-axis (lows, highs) of the intersection of the residues' padded boxes.

    For each residue, the box spans its CA and CB coordinates padded by a distance
    upper threshold plus ``buffer``. The boxes of all residues are intersected.
    """
    ca_pad = thresholds['ca_distances_calc'][1]
    cb_pad = thresholds['cb_distances_calc'][1]
    # NOTE(review): x and z are padded with the CA threshold, but y with the CB
    # threshold. Every point within a residue's CA (x, z) or CB (y) limit still lies
    # inside the box, so no valid point is lost. The asymmetry looks unintentional,
    # but it is kept so the sampled grid (and so the first-found candidate) is unchanged.
    pad = np.array([ca_pad, cb_pad, ca_pad], dtype=np.float64)
    lows = (np.minimum(ca_xyz, cb_xyz) - pad - buffer).max(axis=0)
    highs = (np.maximum(ca_xyz, cb_xyz) + pad + buffer).min(axis=0)
    return lows, highs


def estimate_zn_iterative(
    ca_coords_site_flat, # Coords for ONE site - FLAT array (9,) expected
    cb_coords_site_flat, # Coords for ONE site - FLAT array (9,) expected
    site_info,      # DataFrame row or dict with PDB_ID and residue numbers
    structure,      # Pass the loaded structure
    thresholds,     # Pass the thresholds dict
    prob_map_3d, ca_bins, cb_bins, angle_bins, # Pass map and bins
    grid_resolution=0.2,
    exclusion_radius=2.5,
    ):
    """Estimate a metal position for ONE candidate site by grid search.

    The site is defined by three residues. Their CA and CB coordinates arrive as
    flat length-9 sequences (x, y, z of residue 1, then 2, then 3). The function
    returns the FIRST grid point that passes every filter, tested in this order:

    1. distances to all three CAs / CBs within ``thresholds['ca_distances_calc']``
       / ``thresholds['cb_distances_calc']`` (inclusive (min, max) pairs);
    2. every angle from ``calculate_angles`` within ``thresholds['angle']``;
    3. every ratio from ``calculate_ratio`` within ``thresholds['ratio']``;
    4. a strictly positive ``score_zn_predictions`` score (None and NaN rejected);
    5. no non-triad atom within ``exclusion_radius`` (KDTree proximity check).

    Grid corners are scanned x-major, then y, then z. For each corner, the cell
    centre (corner + grid_resolution / 2) is also tested when it lies strictly
    below the box's upper bounds.

    Args:
        ca_coords_site_flat, cb_coords_site_flat: 9 values each; non-numeric
            entries or any other shape give a "no metal" result.
        site_info: pandas row (Series) or dict with 'PDB_ID' and
            'Coord_residue_number1'..'3'. A row's index label (``.name``) is used
            only in error messages; dicts are reported as ``None``.
        structure: loaded structure exposing ``get_atoms()`` (presumably a
            Bio.PDB Structure), or None.
        thresholds: dict of (min, max) pairs keyed 'ca_distances_calc',
            'cb_distances_calc', 'angle', 'ratio'.
        prob_map_3d, ca_bins, cb_bins, angle_bins: forwarded unchanged to
            ``score_zn_predictions``.
        grid_resolution: grid spacing, in the coordinate units.
        exclusion_radius: clash radius for the proximity filter. The default of
            2.5 is the value previously hard-coded.

    Returns:
        ``(point, score, angles)`` for the first accepted point (``point`` is a
        length-3 float ndarray), or ``("no metal", 0, [None, None, None])`` when
        the coordinates are invalid, ``structure`` is None, the search box is
        empty, or no point passes.
    """
    entry_id = site_info['PDB_ID']
    # Series rows carry their index label in `.name`; plain dicts have no such attribute.
    site_index_name = getattr(site_info, 'name', None)

    # --- Coordinate validation and reshape ---
    try:
        ca_numeric = pd.to_numeric(np.asarray(ca_coords_site_flat), errors='coerce')
        cb_numeric = pd.to_numeric(np.asarray(cb_coords_site_flat), errors='coerce')
        if np.isnan(ca_numeric).any() or np.isnan(cb_numeric).any():
            return _no_metal_result()
        if ca_numeric.shape != (9,) or cb_numeric.shape != (9,):
            return _no_metal_result()
        ca_xyz = ca_numeric.astype(np.float64).reshape(3, 3)
        cb_xyz = cb_numeric.astype(np.float64).reshape(3, 3)
    except (ValueError, TypeError) as e:
        print(f"❌ Error validating/reshaping coordinates for site index {site_index_name} in {entry_id}: {e}")
        return _no_metal_result()

    if structure is None:
        return _no_metal_result()

    # --- Proximity tree over all atoms except the three coordinating residues ---
    triad_res_nums = [
        site_info['Coord_residue_number1'],
        site_info['Coord_residue_number2'],
        site_info['Coord_residue_number3']
    ]
    excluded_residues = define_excluded_triads(triad_res_nums, structure)
    kdtree = _build_exclusion_kdtree(structure, excluded_residues, entry_id, site_index_name)

    # --- Search space ---
    lows, highs = _shared_search_box(ca_xyz, cb_xyz, thresholds, grid_resolution * 2)
    if np.any(lows >= highs):
        return _no_metal_result()

    # The 1e-9 epsilon lets the upper bound itself be sampled despite float rounding.
    x_range, y_range, z_range = (
        np.arange(lo, hi + 1e-9, grid_resolution) for lo, hi in zip(lows, highs)
    )
    if not (x_range.size > 0 and y_range.size > 0 and z_range.size > 0):
        return _no_metal_result()

    # Loop-invariant threshold bounds.
    ca_lo, ca_hi = thresholds['ca_distances_calc'][0], thresholds['ca_distances_calc'][1]
    cb_lo, cb_hi = thresholds['cb_distances_calc'][0], thresholds['cb_distances_calc'][1]
    angle_lo, angle_hi = thresholds['angle'][0], thresholds['angle'][1]
    ratio_lo, ratio_hi = thresholds['ratio'][0], thresholds['ratio'][1]
    half_step = grid_resolution / 2.0

    # --- Grid search: return the first candidate that passes every filter ---
    for x in x_range:
        for y in y_range:
            for z in z_range:
                corner_point = np.array([x, y, z])
                center_point = corner_point + half_step
                points_to_check = [corner_point]
                if np.all(center_point < highs):
                    points_to_check.append(center_point)

                for point in points_to_check:
                    distances_ca = np.linalg.norm(ca_xyz - point, axis=1)
                    distances_cb = np.linalg.norm(cb_xyz - point, axis=1)
                    if not (np.all((ca_lo <= distances_ca) & (distances_ca <= ca_hi)) and
                            np.all((cb_lo <= distances_cb) & (distances_cb <= cb_hi))):
                        continue

                    angles = calculate_angles(point, ca_xyz, cb_xyz)
                    if not all(angle_lo <= angle <= angle_hi for angle in angles):
                        continue

                    ratios = calculate_ratio(point, ca_xyz, cb_xyz)
                    if not np.all((ratio_lo <= ratios) & (ratios <= ratio_hi)):
                        continue

                    score = score_zn_predictions(distances_ca, distances_cb, angles, prob_map_3d, ca_bins, cb_bins, angle_bins)
                    # `not score > 0` also rejects NaN. `score <= 0` would accept NaN
                    # and end the search on a point the caller later discards.
                    if score is None or not score > 0:
                        continue

                    if not proximity_filter_kdtree(kdtree, point, exclusion_radius=exclusion_radius):
                        continue

                    return point, score, angles

    return _no_metal_result()


# --- Worker Function for Multiprocessing ---
def _store_site_prediction(df, index, zn_coords, zn_score, zn_angles):
    """Write one site's score, grid coordinates and angles into row ``index`` of ``df``."""
    df.loc[index, 'Zn_Score'] = zn_score
    # estimate_zn_iterative returns the string "no metal" instead of an array when nothing is found.
    if isinstance(zn_coords, np.ndarray):
        df.loc[index, 'Zn_X_Grid'] = zn_coords[0]
        df.loc[index, 'Zn_Y_Grid'] = zn_coords[1]
        df.loc[index, 'Zn_Z_Grid'] = zn_coords[2]
    if isinstance(zn_angles, (list, np.ndarray)) and len(zn_angles) == 3:
        df.loc[index, 'Angle_1'] = zn_angles[0]
        df.loc[index, 'Angle_2'] = zn_angles[1]
        df.loc[index, 'Angle_3'] = zn_angles[2]


def process_pdb_entry(args):
    """Predict metal positions for every candidate site of one PDB entry.

    Intended as a multiprocessing worker, so all inputs arrive packed in one tuple:
    ``(excel_path, pdb_directory_base, base_name, output_result_directory,
    thresholds, prob_map_3d, ca_bins, cb_bins, angle_bins)``.

    Steps: load ``<pdb_directory_base>/<base_name>.pdb`` via ``load_pdb_structure``,
    read the candidate sites from ``excel_path``, run ``estimate_zn_iterative`` on
    each row, then write rows with a positive score to
    ``<output_result_directory>/<base_name>_result.xlsx``. The output directory is
    created if missing.

    Returns:
        ``(base_name, success, n_sites_saved)``. ``success`` is True when the
        entry was processed, even if no site passed the filters. It is False when
        the entry was skipped or an error occurred; errors are printed, not raised.
    """
    excel_path, pdb_directory_base, base_name, output_result_directory, \
    thresholds_local, prob_map_3d_local, ca_bins_local, cb_bins_local, angle_bins_local = args

    process_id = os.getpid()
    print(f"[PID {process_id}] Processing PDB ID: {base_name}")

    pdb_path = os.path.join(pdb_directory_base, f"{base_name}.pdb")
    if not os.path.exists(pdb_path):
        print(f"[PID {process_id}] ❌ Skipping {base_name}, PDB file not found at {pdb_path}.")
        return base_name, False, 0

    # Load the structure once; every site of this entry reuses it.
    structure = load_pdb_structure(base_name, pdb_directory_base)
    if structure is None:
        print(f"[PID {process_id}] ❌ Failed to load structure for {base_name}, skipping.")
        return base_name, False, 0

    try:
        df_alanine = pd.read_excel(excel_path)
        if df_alanine.empty:
            print(f"[PID {process_id}] ⚠️ Input Excel file is empty for {base_name}, skipping.")
            return base_name, False, 0

        df_alanine['PDB_ID'] = base_name
        df_alanine = df_alanine.reset_index(drop=True)

        ca_cols = ['CA1_X', 'CA1_Y', 'CA1_Z', 'CA2_X', 'CA2_Y', 'CA2_Z', 'CA3_X', 'CA3_Y', 'CA3_Z']
        cb_cols = ['CB1_X', 'CB1_Y', 'CB1_Z', 'CB2_X', 'CB2_Y', 'CB2_Z', 'CB3_X', 'CB3_Y', 'CB3_Z']
        res_num_cols = ['Coord_residue_number1', 'Coord_residue_number2', 'Coord_residue_number3']
        required_input_cols = ca_cols + cb_cols + res_num_cols
        missing_cols = [col for col in required_input_cols if col not in df_alanine.columns]
        if missing_cols:
            print(f"[PID {process_id}] ❌ Missing required columns in {os.path.basename(excel_path)}: {missing_cols}, skipping {base_name}.")
            return base_name, False, 0

        # Result columns; sites without a prediction keep these defaults.
        df_alanine['Zn_X_Grid'] = None
        df_alanine['Zn_Y_Grid'] = None
        df_alanine['Zn_Z_Grid'] = None
        df_alanine['Zn_Score'] = 0.0
        df_alanine['Angle_1'] = None
        df_alanine['Angle_2'] = None
        df_alanine['Angle_3'] = None

        for index, site_info in df_alanine.iterrows():
            zn_coords, zn_score, zn_angles = estimate_zn_iterative(
                site_info[ca_cols].values, site_info[cb_cols].values, site_info, structure,
                thresholds_local, prob_map_3d_local, ca_bins_local, cb_bins_local, angle_bins_local,
                grid_resolution=0.2
            )
            _store_site_prediction(df_alanine, index, zn_coords, zn_score, zn_angles)

        # Keep only sites that received a positive score.
        df_output = df_alanine[df_alanine['Zn_Score'] > 0].copy()

        if df_output.empty:
            print(f"[PID {process_id}] ⚠️ No valid Zn predictions passed filters for {base_name}.")
            return base_name, True, 0

        # Create the folder on demand. Otherwise a missing directory makes to_excel
        # fail only after all the per-site grid searches have already run.
        os.makedirs(output_result_directory, exist_ok=True)
        output_file_path = os.path.join(output_result_directory, f"{base_name}_result.xlsx")
        df_output.to_excel(output_file_path, index=False)
        print(f"[PID {process_id}] ✅ Saved results for {len(df_output)} site(s) from {base_name} to: {os.path.basename(output_file_path)}")
        return base_name, True, len(df_output)

    except Exception as e:
        print(f"[PID {process_id}] ❌ Error processing {base_name}: {e}")
        print(traceback.format_exc())
        return base_name, False, 0


# --- Main Execution ---
def _run_prescreening_batch(num_processes=8):
    """Run process_pdb_entry over every '*_with_coordinates.xlsx' input and print a summary.

    Reads these module-level names, which earlier notebook code must define:
    excel_directory, pdb_directory_base, output_result_directory, thresholds,
    prob_map_3d, ca_bins, cb_bins, angle_bins.

    When no input file is found it returns instead of calling ``exit()``. Inside
    an IPython/Jupyter kernel, ``exit()`` asks the kernel itself to shut down,
    discarding everything loaded so far.

    Args:
        num_processes: maximum number of worker processes; capped at the number of tasks.
    """
    # Not among the visible imports (the cell imports ProcessPoolExecutor); a repeat import is harmless.
    import multiprocessing

    print(f"\nSearching for input Excel files in: {excel_directory}")
    excel_files = glob.glob(os.path.join(excel_directory, '*_with_coordinates.xlsx'))
    print(f"Found {len(excel_files)} potential input files.")

    if not excel_files:
        print("❌ No input Excel files found matching pattern '*_with_coordinates.xlsx'. Exiting.")
        return

    # One task tuple per input file, in the order process_pdb_entry unpacks it.
    tasks = []
    for excel_path in excel_files:
        # Strip the suffix and any 'processed_' marker to get the stem of '<stem>.pdb'.
        base_name = os.path.basename(excel_path).replace('_with_coordinates.xlsx', '')
        base_name = base_name.replace('processed_', '')
        tasks.append((
            excel_path, pdb_directory_base, base_name, output_result_directory,
            thresholds, prob_map_3d, ca_bins, cb_bins, angle_bins
        ))

    # More workers than tasks would only add idle processes.
    worker_count = min(num_processes, len(tasks))
    print(f"\nInitializing multiprocessing pool with {worker_count} workers...")

    successful_files = 0
    failed_files = 0
    total_sites_saved = 0

    with multiprocessing.Pool(processes=worker_count) as pool:
        print("Starting parallel processing...")
        for processed_count, (pdb_id, status, site_count) in enumerate(
                pool.imap_unordered(process_pdb_entry, tasks), start=1):
            if status:
                successful_files += 1
                total_sites_saved += site_count
                print(f"  ({processed_count}/{len(tasks)}) Completed: {pdb_id} ({site_count} sites saved)")
            else:
                failed_files += 1
                print(f"  ({processed_count}/{len(tasks)}) Failed/Skipped: {pdb_id}")

    print("\n--- Multiprocessing Pool Finished ---")
    print(f"\n🏁 Batch processing summary:")
    print(f"  Total input files found: {len(tasks)}")
    print(f"  Successfully processed files: {successful_files}")
    print(f"  Failed/Skipped files: {failed_files}")
    print(f"  Total result sites saved: {total_sites_saved}")


if __name__ == "__main__":
    print("\n--- Starting Main Process ---")
    # thresholds, prob_map_3d, ca_bins, cb_bins and angle_bins are loaded earlier in the notebook.
    _run_prescreening_batch()

Downloading probability map from: https://raw.githubusercontent.com/SNU-Songlab/Metal-Installer-code/main/probability//Zn/3His/map.xlsx
Downloaded map data to /content/map.xlsx
Downloading thresholds from: https://raw.githubusercontent.com/SNU-Songlab/Metal-Installer-code/main/probability//Zn/3His/threshold.xlsx
Downloaded thresholds data to /content/threshold.xlsx
Loading thresholds...
Loading probability map...
Processing thresholds...
Processing probability map...
Probability map processed into 3D array.

--- Starting Main Process ---

Searching for input Excel files in: /content/drive/MyDrive/250413_Final/Zn_3His_1_coordinate/
Found 117 potential input files.

Initializing multiprocessing pool with 8 workers...
[PID 6353] Processing PDB ID: 1fr2_chainA[PID 6354] Processing PDB ID: 1a85_chainA[PID 6355] Processing PDB ID: 1bkc_chainA[PID 6356] Processing PDB ID: 1atl_chainA[PID 6357] Processing PDB ID: 1bmc_chainA


[PID 6359] Processing PDB ID: 2esl_chainA
[PID 6358] Processing PDB

KeyboardInterrupt: 

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import os
from Bio.PDB import PDBParser
import requests

# @markdown # Step 5: Analyze the results (apply to the PDB file)

# Prediction results written by the prescreening step.
input_file_path = "/content/3ttis_coordinates_1_result.xlsx" # @param {type:"string"}
df_new = pd.read_excel(input_file_path)

# Number the sites 1..N; the number becomes part of every PyMOL object name.
df_new['Combination_Number'] = range(1, len(df_new) + 1)


def _pymol_commands_for_site(row):
    """Return the PyMOL commands that visualise one predicted site.

    The commands select the three coordinating residues, copy each into its own
    object, and either add a metal pseudoatom shown as a sphere (when all three
    grid coordinates are present) or emit a comment marking the site as non-binding.
    Assumes the structure is already loaded in PyMOL as an object named
    row['PDB_ID'] — TODO confirm against how the .pdb is loaded.
    """
    residues = [
        (row[f'Coord_chain_id_number{k}'], row[f'Coord_residue_number{k}'])
        for k in (1, 2, 3)
    ]
    metal_xyz = (row['Zn_X_Grid'], row['Zn_Y_Grid'], row['Zn_Z_Grid'])
    obj = f"obj{row['Combination_Number']:02d}"

    selection = " or ".join(f"(chain {chain} and resi {resi})" for chain, resi in residues)
    commands = [f"select {obj}, {selection}"]
    commands.extend(
        f"create {obj}_residue{k}, /{row['PDB_ID']}//{chain}/{resi}"
        for k, (chain, resi) in enumerate(residues, start=1)
    )

    if any(pd.isna(coord) for coord in metal_xyz):
        commands.append(f"# {obj} does not bind Zn")
    else:
        metal_name = f"{obj}_Metal"
        mx, my, mz = metal_xyz
        commands.append(f"pseudoatom {metal_name}, pos=[{mx}, {my}, {mz}], elem=Metal, name={metal_name}")
        commands.append(f"show sphere, {metal_name}")
    return commands


# Collect commands for every site, binding or not, in row order.
pymol_script_commands = []
for index, row in df_new.iterrows():
    pymol_script_commands.extend(_pymol_commands_for_site(row))

# Write the commands as a PyMOL script, one command per line.
pymol_script_file = "/content/3tis_final.pml" # @param {type:"string"}
script_header = "# PyMOL script for visualizing both Zn-binding and non-binding residue combinations\n\n"
with open(pymol_script_file, 'w') as script_handle:
    script_handle.write(script_header)
    script_handle.writelines(f"{command}\n" for command in pymol_script_commands)

print(f"PyMOL script saved to {pymol_script_file}")