In [None]:
import subprocess
import requests
import os
import shutil
import logging
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import Pool
from pathlib import Path
# --- imports used by the poseviewer, parsing, sequence-alignment and batch helpers below ---
import re, shlex
from concurrent.futures import ProcessPoolExecutor, as_completed
from collections import defaultdict
from Bio.PDB import PDBParser, PPBuilder
from Bio import pairwise2
import warnings



In [None]:
# This cell downloads ground-truth PDB files and prepares them with the Schrödinger wrapper scripts.
def fetch_pdb(pdb_code, output_dir, timeout=60):
    """
    Download a PDB-format file from RCSB and save it as ``<output_dir>/<pdb_code>.pdb``.

    Args:
        pdb_code (str): PDB identifier, used verbatim in the download URL and file name.
        output_dir (str): Directory to write into; created if it does not exist.
        timeout (float | None): Seconds to wait for RCSB before giving up (forwarded to
            ``requests.get``). Pass ``None`` to wait indefinitely.

    Returns:
        str: Path of the written ``.pdb`` file.

    Raises:
        FileNotFoundError: If RCSB answers with any HTTP status other than 200.
        requests.RequestException: On network errors, including a timeout.
    """
    pdb_url = f"https://files.rcsb.org/download/{pdb_code}.pdb"
    os.makedirs(output_dir, exist_ok=True)
    pdb_file = os.path.join(output_dir, f"{pdb_code}.pdb")

    # Without a timeout a stalled connection blocks the calling loop / pool worker forever.
    response = requests.get(pdb_url, timeout=timeout)
    if response.status_code != 200:
        raise FileNotFoundError(f"Could not download PDB file for {pdb_code}.")

    with open(pdb_file, "w") as file:
        file.write(response.text)
    print(f"Downloaded PDB file for {pdb_code} to {pdb_file}")
    return pdb_file

def _run_prep_wrapper(cmd, tag, label, target):
    """Run one prep wrapper script; echo its stdout, or report the failure and return False."""
    try:
        res = subprocess.run(
            cmd,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
    except subprocess.CalledProcessError as e:
        print(f"❌ prep failed for {label} ({target}), exit {e.returncode}")
        print(e.stderr)
        return False
    print(f"[{tag}] {res.stdout}")
    return True


def download_prep(pdb_code, predicted_cif, output_dir, pdb_dir):
    """
    Fetch the ground-truth PDB, then run the Schrödinger prep wrappers on it and on the
    predicted CIF.

    Args:
        pdb_code: PDB ID downloaded from RCSB into `pdb_dir` (via fetch_pdb).
        predicted_cif: Path to the predicted-model CIF file.
        output_dir: Directory passed to both wrapper scripts as their output location.
        pdb_dir: Directory that receives the downloaded ground-truth PDB.

    Returns:
        (truth_out, pred_out): ``<output_dir>/<pdb_code>_preped.pdb`` and
        ``<output_dir>/<cif file stem>_preped.pdb``, or ``(None, None)`` if either wrapper
        exits with a non-zero status. NOTE(review): these output names are assumed, not
        checked on disk — confirm they match what run_prep_single_*.sh actually write.

    Raises:
        FileNotFoundError: If a wrapper script is missing or not executable.
        Exceptions raised by fetch_pdb (e.g. a failed download) propagate unchanged.
    """
    # 1) Locate the shell wrapper scripts (relative to the current working directory).
    #    Checked before any network access so a misconfigured run fails fast.
    script_pdb = "./run_prep_single_pdb.sh"
    script_cif = "./run_prep_single_cif.sh"
    for script in (script_pdb, script_cif):
        if not os.path.isfile(script) or not os.access(script, os.X_OK):
            raise FileNotFoundError(f"Cannot execute wrapper script: {script}")

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(pdb_dir, exist_ok=True)

    # 2) Fetch the ground-truth PDB.
    ground_truth_pdb = fetch_pdb(pdb_code, pdb_dir)

    # 3) Build the prepped-output paths. Only the trailing extension is stripped:
    #    str.replace(".cif", ...) left "*.CIF" names untouched (so pred_out pointed at
    #    the input CIF) and would also rewrite a ".cif" occurring mid-name.
    truth_out = os.path.join(output_dir, f"{pdb_code}_preped.pdb")
    pred_stem = os.path.splitext(os.path.basename(predicted_cif))[0]
    pred_out = os.path.join(output_dir, f"{pred_stem}_preped.pdb")

    # 4) Prepare the ground-truth PDB, then 5) the predicted CIF.
    if not _run_prep_wrapper([script_pdb, ground_truth_pdb, output_dir],
                             "prep pdb", "ground truth", ground_truth_pdb):
        return None, None
    if not _run_prep_wrapper([script_cif, predicted_cif, output_dir],
                             "prep cif", "predicted", predicted_cif):
        return None, None

    # 6) Return the two resulting PDB paths.
    return truth_out, pred_out

def _worker_download_prep(args):
    """
    Unpack arguments tuple and call download_prep.
    Returns (pdb_code, truth_pdb, pred_pdb) or (pdb_code, None, None) on error.
    """
    pdb_code, cif_path, output_dir, pdb_dir = args
    try:
        truth_pdb, pred_pdb = download_prep(pdb_code, cif_path, output_dir, pdb_dir)
        return (pdb_code, truth_pdb, pred_pdb)
    except Exception as e:
        print(f"[ERROR] {pdb_code}: {e}")
        return (pdb_code, None, None)

def multiprocess_directory_prep(
    directory: str,
    output_dir: str,
    pdb_dir: str,
    processes: int = 40
) -> pd.DataFrame:
    """
    Download and prep every ``*_model.cif`` in `directory` in parallel.

    For each matching file (suffix compared case-insensitively) the PDB code is the file
    name minus ``_model.cif``; NOTE(review): that prefix must be a bare PDB ID for the
    RCSB download inside download_prep to succeed.

    Args:
        directory: Folder scanned (non-recursively) for ``*_model.cif`` files.
        output_dir: Where the prep wrappers write their output.
        pdb_dir: Where ground-truth PDBs are downloaded.
        processes: Maximum worker processes; never more than the number of files.
            ``None`` lets multiprocessing choose.

    Returns:
        pd.DataFrame with columns ['pdb_code', 'prepped_truth_pdb', 'prepped_predicted_pdb'],
        one row per file whose prep succeeded, in file-name order.
    """
    columns = ['pdb_code', 'prepped_truth_pdb', 'prepped_predicted_pdb']
    suffix = '_model.cif'

    # Sorted listing -> deterministic job (and therefore row) order across runs/filesystems.
    jobs = [
        (fn[:-len(suffix)], os.path.join(directory, fn), output_dir, pdb_dir)
        for fn in sorted(os.listdir(directory))
        if fn.lower().endswith(suffix)
    ]
    if not jobs:
        # Nothing to do: skip spawning a pool of idle processes.
        return pd.DataFrame([], columns=columns)

    workers = None if processes is None else min(processes, len(jobs))
    with Pool(workers) as pool:
        results = pool.map(_worker_download_prep, jobs)

    # Keep only entries where both prepped paths were produced.
    successful = [
        (code, truth, pred)
        for code, truth, pred in results
        if truth is not None and pred is not None
    ]
    return pd.DataFrame(successful, columns=columns)

In [3]:
def rename_complex_prep_files(directory):
    """
    Shorten prepared complex file names in `directory` to just their leading PDB ID.

    Every ``*_model_prep.pdb`` whose name contains ``'__'`` is renamed so that everything
    from the first ``'__'`` onward is dropped, e.g.::

        1cbs__1__1.a__1.b_model_prep.pdb -> 1cbs_model_prep.pdb

    Files are processed in sorted name order. If several files in the same call map to the
    same short name, only the first is renamed; the others are left in place and reported,
    instead of silently overwriting each other. A short-named file left over from an
    earlier run is still replaced, on every platform.
    """
    base = Path(directory)
    # Materialise the listing before renaming anything in the same directory.
    candidates = sorted(p for p in base.glob("*_model_prep.pdb") if "__" in p.name)

    claimed = {}  # new_name -> original name it was produced from in this call
    for f in candidates:
        # f.stem is e.g. "1cbs__1__1.a__1.b_model_prep"
        prefix = f.stem.split("__", 1)[0]
        new_name = f"{prefix}_model_prep.pdb"
        if new_name in claimed:
            print(f"Skipped: {f.name} → {new_name} (already produced from {claimed[new_name]})")
            continue
        # Path.replace overwrites an existing target on POSIX and Windows alike.
        f.replace(f.with_name(new_name))
        claimed[new_name] = f.name
        print(f"Renamed: {f.name} → {new_name}")

In [None]:


# optional: quiet the pairwise2 deprecation warning noise
warnings.filterwarnings("ignore", category=UserWarning, module="Bio.pairwise2")

# ======================
# poseviewer + analysis
# ======================

def _poseviewer_txt_name(pdb_path: str) -> str:
    base, _ = os.path.splitext(os.path.basename(pdb_path))
    return f"{base}_pv_interactions.txt"


def _run_poseviewer(
    pdb_path: str,
    cwd: str,
    force: bool = False,
    sbgrid_rc: str = "/programs/sbgrid.shrc",
    schrodinger_root: str | None = None,
) -> str:
    """
    Run poseviewer_interactions.py on a PDB and return the *_pv_interactions.txt path.
    If `schrodinger_root` is provided, use <root>/run directly.
    Otherwise, source SBGrid (`sbgrid_rc`) to populate $SCHRODINGER, then run "$SCHRODINGER/run".
    Output and logs are written in `cwd`.
    """
    out_txt = os.path.join(cwd, _poseviewer_txt_name(pdb_path))
    if os.path.exists(out_txt) and not force:
        return out_txt

    log_path = os.path.join(cwd, f"{os.path.splitext(os.path.basename(pdb_path))[0]}_poseviewer.log")

    if schrodinger_root:
        runner = os.path.join(schrodinger_root, "run")
        if not os.path.exists(runner):
            raise FileNotFoundError(f"Schrödinger runner not found: {runner}")
        cmd = [runner, "poseviewer_interactions.py", pdb_path]
        with open(log_path, "w") as log:
            proc = subprocess.run(cmd, cwd=cwd, stdout=log, stderr=subprocess.STDOUT, check=False)
    else:
        # Use bash -lc so we can source the environment and then execute.
        quoted_pdb = shlex.quote(pdb_path)
        quoted_rc  = shlex.quote(sbgrid_rc)
        bash_cmd = f'source {quoted_rc} >/dev/null 2>&1 && "$SCHRODINGER/run" poseviewer_interactions.py {quoted_pdb}'
        with open(log_path, "w") as log:
            proc = subprocess.run(["bash", "-lc", bash_cmd], cwd=cwd, stdout=log, stderr=subprocess.STDOUT, check=False)

    if proc.returncode != 0:
        raise RuntimeError(f"poseviewer_interactions.py failed for {pdb_path} (see log: {log_path})")
    if not os.path.exists(out_txt):
        raise FileNotFoundError(f"Expected interactions file not found: {out_txt}")
    return out_txt


# ---------------------------
# YOUR interaction extraction
# ---------------------------

def refined_extract_interaction_data_with_chain(file_path):
    """
    Parse a poseviewer ``*_pv_interactions.txt`` report into a per-residue interaction table.

    Only text between "Interactions grouped by receptor residue:" and the following
    "Total <word> Contacts interactions:" line is considered. Inside it, every line shaped
    like ``<Type> <Chain>:<ResNum>(<RES>) <n> <contacts>`` becomes one row.

    The chain is the part of the residue label before the first ':' (the whole label if it
    has none). The residue number is read from the part after the last ':', so numeric
    chain IDs ("1:45") and negative numbers ("A:-3") are handled.

    Returns:
        pd.DataFrame with columns "Interaction Type", "Residue Chain", "Residue Type",
        "Residue Number" (int) and "# of Contacts" (int). Rows typed "Ugly" or "Bad"
        (case-insensitive) are dropped.
    """
    with open(file_path, 'r') as file:
        content = file.read()
    interaction_sections = re.findall(
        r"Interactions grouped by receptor residue:\n(.*?)Total \w+ Contacts interactions:",
        content, re.S
    )
    # '-' is allowed in the residue label so negative residue numbers are not skipped.
    line_pattern = re.compile(r"(\w+)\s+([A-Za-z0-9:\-]+)\((\w{3})\)\s+\d+\s+(\d+)")
    columns = ["Interaction Type", "Residue Chain", "Residue Type", "Residue Number", "# of Contacts"]

    records = []
    for section in interaction_sections:
        for line in section.splitlines():
            match = line_pattern.match(line.strip())
            if not match:
                continue
            interaction_type, residue, amino_acid, contacts = match.groups()
            chain_id = residue.split(":")[0]
            # Searching the whole label would pick up digits of the chain ID itself.
            number = re.search(r"-?\d+", residue.rsplit(":", 1)[-1])
            if number is None:
                continue  # label without a residue number: nothing to map
            records.append((interaction_type, chain_id, amino_acid, int(number.group()), int(contacts)))

    df = pd.DataFrame(records, columns=columns)
    # drop "Ugly" and "Bad" if present
    if not df.empty:
        df = df[~df["Interaction Type"].str.lower().isin(["ugly", "bad"])].reset_index(drop=True)
    return df


# ---------------------------
# YOUR sequence mapping funcs
# ---------------------------

def extract_sequence_and_residues(pdb_file, chain_id=None):
    """
    Build a one-letter sequence and an index-aligned residue list for each chain of the
    first model in `pdb_file`.

    Chain breaks are handled by concatenating every peptide PPBuilder finds in a chain, so
    position i of the sequence always corresponds to element i of the residue list.

    Args:
        pdb_file: Path to a PDB file.
        chain_id: If given (truthy), only that chain is reported.

    Returns:
        dict mapping chain id -> (sequence, residues), where residues is a list of
        (resname, chain_id, resnum, insertion_code) tuples. Chains yielding no peptide
        are omitted.
    """
    first_model = PDBParser(QUIET=True).get_structure("structure", pdb_file)[0]
    builder = PPBuilder()

    per_chain = {}
    for chain in first_model:
        if chain_id and chain.id != chain_id:
            continue
        seq_parts = []
        residues = []
        for peptide in builder.build_peptides(chain):
            seq_parts.append(str(peptide.get_sequence()))
            # (resname, chain, resnum, icode)
            residues.extend(
                (res.get_resname(), chain.id, res.id[1], res.id[2]) for res in peptide
            )
        sequence = "".join(seq_parts)
        if sequence:
            per_chain[chain.id] = (sequence, residues)
    return per_chain


def align_sequences(seq1, seq2):
    """
    Globally align two one-letter sequences with pairwise2's ``globalxx`` scoring
    (identity scores 1, mismatch 0, no gap penalties).

    Returns:
        One optimal pairwise2 alignment; elements 0 and 1 are the gapped `seq1` and `seq2`
        strings consumed by map_residues.

    ``one_alignment_only=True`` matters: without it globalxx enumerates every equally
    scoring alignment, which explodes in time and memory for long, similar chains even
    though only the first one is used.
    """
    return pairwise2.align.globalxx(seq1, seq2, one_alignment_only=True)[0]


def map_residues(alignment, residues1, residues2):
    """
    Convert a pairwise alignment into a residue-to-residue correspondence.

    `alignment[0]` / `alignment[1]` are the gapped strings for `residues1` / `residues2`,
    and each residue list holds (resname, chain, resnum, icode) tuples in sequence order.
    Every alignment column with no gap on either side links
    (resnum, chain) of residues1 -> (resnum, chain) of residues2.

    Returns:
        defaultdict(list) keyed by (resnum, chain) of residues1. A key collects several
        targets when the same (resnum, chain) occurs more than once.
    """
    gapped1, gapped2 = alignment[0], alignment[1]
    correspondence = defaultdict(list)
    pos1 = pos2 = -1  # index of the last non-gap residue consumed on each side
    for a, b in zip(gapped1, gapped2):
        pos1 += a != "-"
        pos2 += b != "-"
        if a == "-" or b == "-":
            continue
        src = residues1[pos1]
        dst = residues2[pos2]
        correspondence[(src[2], src[1])].append((dst[2], dst[1]))
    return correspondence


def expand_predicted_interactions(predicted_interactions, all_mappings):
    """
    Attach ground-truth residue coordinates to each predicted interaction.

    Each row is looked up in `all_mappings` by (Residue Number, Residue Chain). It is
    emitted once per mapped target, with "Mapped Residue Number" and "Mapped Chain" added.
    Rows without a mapping are dropped, and the original row index is kept (and repeated
    for multi-target rows).

    Returns:
        pd.DataFrame of expanded rows (an empty frame if nothing maps).
    """
    def _copies_for(row):
        key = (row["Residue Number"], row["Residue Chain"])
        for target_number, target_chain in all_mappings.get(key, []):
            clone = row.copy()
            clone["Mapped Residue Number"] = target_number
            clone["Mapped Chain"] = target_chain
            yield clone

    expanded = [clone for _, row in predicted_interactions.iterrows() for clone in _copies_for(row)]
    return pd.DataFrame(expanded)


def _collect_residue_mappings(pred_seq_data, true_data):
    """Align every predicted chain against every truth chain and merge the resulting residue mappings."""
    # Also accept a bare (sequence, residues) tuple for the truth side.
    if isinstance(true_data, tuple):
        true_data = {None: true_data}
    all_mappings = defaultdict(list)  # one-to-many: a predicted residue may map into several truth chains
    for true_seq, true_reslist in true_data.values():
        for pred_seq, pred_reslist in pred_seq_data.values():
            alignment = align_sequences(pred_seq, true_seq)
            for key, targets in map_residues(alignment, pred_reslist, true_reslist).items():
                all_mappings[key].extend(targets)
    return all_mappings


def _interaction_key(df, number_column):
    """Join interaction type, residue type and the given residue-number column into one string per row."""
    return (
        df["Interaction Type"].astype(str) + "_"
        + df["Residue Type"].astype(str) + "_"
        + df[number_column].astype(str)
    )


def analyze_interaction_recovery_with_expansion(predicted_file, true_file, predicted_pdb, true_pdb):
    """
    Percentage of ground-truth interactions reproduced by the predicted model.

    Both poseviewer reports are parsed. Predicted residues are mapped onto ground-truth
    residues by aligning every predicted chain sequence against every truth chain sequence.
    A truth interaction counts as recovered when some mapped predicted interaction has the
    same interaction type, residue type, (mapped) residue number and chain.

    Args:
        predicted_file / true_file: poseviewer ``*_pv_interactions.txt`` reports.
        predicted_pdb / true_pdb: the structures those reports were computed from.

    Returns:
        (recovery_percentage, true_interactions): the percentage (0.0 when there are no
        truth interactions or no mappable predicted ones) and the parsed truth table,
        always with "Key" and "Recovered" columns added.
    """
    predicted_interactions = refined_extract_interaction_data_with_chain(predicted_file)
    true_interactions = refined_extract_interaction_data_with_chain(true_file)

    all_mappings = _collect_residue_mappings(
        extract_sequence_and_residues(predicted_pdb),
        extract_sequence_and_residues(true_pdb),
    )
    predicted_interactions = expand_predicted_interactions(predicted_interactions, all_mappings)

    # Set of (key, chain) pairs the prediction covers: O(1) lookup per truth row instead
    # of scanning every predicted row for each truth row.
    if predicted_interactions.empty:
        recovered_pairs = set()
    else:
        predicted_keys = _interaction_key(predicted_interactions, "Mapped Residue Number")
        recovered_pairs = set(zip(predicted_keys, predicted_interactions["Mapped Chain"]))

    # Always add Key/Recovered so the saved truth table has the same schema in every case.
    true_interactions["Key"] = _interaction_key(true_interactions, "Residue Number")
    true_interactions["Recovered"] = [
        (key, chain) in recovered_pairs
        for key, chain in zip(true_interactions["Key"], true_interactions["Residue Chain"])
    ]

    recovery_percentage = float(true_interactions["Recovered"].mean() * 100.0) if not true_interactions.empty else 0.0
    print(f"Recovery Percentage: {recovery_percentage:.2f}%")
    return recovery_percentage, true_interactions


# Optional visualization helper: tabulates a residue mapping built by map_residues.
def _residue_names_by_key(reslist):
    """Map (resnum, chain) -> resname for (resname, chain, resnum, icode) tuples; first occurrence wins."""
    names = {}
    for res in reslist or ():
        names.setdefault((res[2], res[1]), res[0])
    return names


def _residue_fields(res, names):
    """Return (name, chain, number) for a mapping entry in either supported layout."""
    if len(res) >= 3:
        # Legacy layout: (resname, chain, resnum, ...)
        return res[0], res[1], res[2]
    # map_residues layout: (resnum, chain); the name comes from the residue list.
    resnum, chain = res[0], res[1]
    return names.get((resnum, chain)), chain, resnum


def visualize_mapping_details(mapping, predicted_reslist, true_reslist):
    """
    Print and return a table of residue correspondences.

    Args:
        mapping: dict as returned by map_residues, i.e. predicted (resnum, chain) ->
            list of truth (resnum, chain). Entries laid out as (resname, chain, resnum, ...)
            are also accepted. The old code assumed only this layout and raised IndexError
            on map_residues output.
        predicted_reslist / true_reslist: (resname, chain, resnum, icode) tuples used to
            look up residue names; names not found are reported as None.

    Returns:
        pd.DataFrame with one row per mapped pair (predicted/true name, chain and number).
    """
    columns = [
        "Predicted Residue Name", "Predicted Chain", "Predicted Residue Number",
        "True Residue Name", "True Chain", "True Residue Number",
    ]
    pred_names = _residue_names_by_key(predicted_reslist)
    true_names = _residue_names_by_key(true_reslist)

    mapping_table = []
    for pred_res, true_res_list in mapping.items():
        p_name, p_chain, p_num = _residue_fields(pred_res, pred_names)
        for true_res in true_res_list:
            t_name, t_chain, t_num = _residue_fields(true_res, true_names)
            mapping_table.append({
                "Predicted Residue Name": p_name,
                "Predicted Chain": p_chain,
                "Predicted Residue Number": p_num,
                "True Residue Name": t_name,
                "True Chain": t_chain,
                "True Residue Number": t_num,
            })
    mapping_df = pd.DataFrame(mapping_table, columns=columns)
    print("Residue Mapping Details:")
    print(mapping_df)
    return mapping_df


# ---------- pair + batch (now using YOUR analyzer) ----------

def _analyze_pair_interactions(
    pdb_id: str,
    pair_dir: str,
    force: bool = False,
    sbgrid_rc: str = "/programs/sbgrid.shrc",
    schrodinger_root: str | None = None,
) -> dict:
    """
    - runs poseviewer on aligned_ref_*.pdb and aligned_mob_*.pdb (with SBGrid env)
    - computes interaction recovery via your analyze_interaction_recovery_with_expansion(...)
    - writes <pdb_id>_true_interactions.csv (the parsed truth table)
    """
    ref_pdb = os.path.join(pair_dir, f"aligned_ref_{pdb_id}.pdb")
    mob_pdb = os.path.join(pair_dir, f"aligned_mob_{pdb_id}.pdb")

    if not os.path.exists(ref_pdb):
        return {"pdb_id": pdb_id, "interaction_recovery_percent": "N/A", "status": f"missing {os.path.basename(ref_pdb)}"}
    if not os.path.exists(mob_pdb):
        return {"pdb_id": pdb_id, "interaction_recovery_percent": "N/A", "status": f"missing {os.path.basename(mob_pdb)}"}

    try:
        true_txt = _run_poseviewer(ref_pdb, cwd=pair_dir, force=force, sbgrid_rc=sbgrid_rc, schrodinger_root=schrodinger_root)
        pred_txt = _run_poseviewer(mob_pdb,  cwd=pair_dir, force=force, sbgrid_rc=sbgrid_rc, schrodinger_root=schrodinger_root)

        # Use YOUR analyzer (sequence alignment -> residue mapping)
        recovery_pct, true_interactions_df = analyze_interaction_recovery_with_expansion(
            pred_txt, true_txt, mob_pdb, ref_pdb
        )

        out_csv = os.path.join(pair_dir, f"{pdb_id}_true_interactions.csv")
        true_interactions_df.to_csv(out_csv, index=False)

        return {"pdb_id": pdb_id, "interaction_recovery_percent": round(float(recovery_pct), 2), "status": "ok"}
    except Exception as e:
        return {"pdb_id": pdb_id, "interaction_recovery_percent": "N/A", "status": f"error: {e}"}


def analyze_interactions_batch(
    work_root: str,
    out_csv: str,
    workers: int = 4,
    force: bool = False,
    sbgrid_rc: str = "/programs/sbgrid.shrc",
    schrodinger_root: str | None = None,
) -> pd.DataFrame:
    """
    Finds work_root/<pdb_id>/aligned_ref_<pdb_id>.pdb and aligned_mob_<pdb_id>.pdb,
    runs poseviewer + interaction recovery, writes CSV: pdb_id,interaction_recovery_percent,status
    """
    candidates = []
    for name in os.listdir(work_root):
        d = os.path.join(work_root, name)
        if not os.path.isdir(d):
            continue
        if (os.path.exists(os.path.join(d, f"aligned_ref_{name}.pdb")) and
            os.path.exists(os.path.join(d, f"aligned_mob_{name}.pdb"))):
            candidates.append((name, d))
    if not candidates:
        raise RuntimeError(f"No aligned pairs found under {work_root}")

    rows = []
    with ProcessPoolExecutor(max_workers=workers) as ex:
        futs = [
            ex.submit(_analyze_pair_interactions, pid, d, force, sbgrid_rc, schrodinger_root)
            for (pid, d) in candidates
        ]
        for fut in as_completed(futs):
            rows.append(fut.result())

    df = pd.DataFrame(rows, columns=["pdb_id", "interaction_recovery_percent", "status"])
    df.to_csv(out_csv, index=False)
    print(f"Saved interaction recovery to {out_csv}")
    return df


In [4]:
def collect_files_and_scores(source_dir, destination_dir):
    """
    Collect AF3 ``*_model.cif`` files and the best ranking score of each prediction folder.

    Every directory under `source_dir` (walked recursively, in sorted order) that contains
    both a ``*_model.cif`` file and ``ranking_scores.csv`` contributes:

    * a copy of its first (alphabetically) ``*_model.cif`` into `destination_dir`;
    * if the CSV has a ``ranking_score`` column, one row with ``id`` = the first four
      characters of the folder name and ``Highest Ranking Score`` = that column's maximum.

    Args:
        source_dir (str): Root directory of the AF3 prediction folders.
        destination_dir (str): Directory that receives the copied CIF files (created if missing).

    Returns:
        pd.DataFrame: Columns ``id`` and ``Highest Ranking Score``. The columns are present
        even when nothing was found, so downstream code can always index ``id``.
    """
    os.makedirs(destination_dir, exist_ok=True)

    scores_data = []
    for root, dirs, files in os.walk(source_dir):
        dirs.sort()  # in-place: os.walk then descends in a stable order
        cif_files = sorted(f for f in files if f.endswith("_model.cif"))
        if not cif_files or "ranking_scores.csv" not in files:
            continue

        shutil.copy(os.path.join(root, cif_files[0]), os.path.join(destination_dir, cif_files[0]))

        df = pd.read_csv(os.path.join(root, "ranking_scores.csv"))
        if "ranking_score" in df.columns:
            # normpath so a trailing slash on source_dir does not yield an empty folder name.
            folder_name = os.path.basename(os.path.normpath(root))
            scores_data.append({"id": folder_name[0:4], "Highest Ranking Score": df["ranking_score"].max()})

    return pd.DataFrame(scores_data, columns=["id", "Highest Ranking Score"])

In [None]:
# Configuration: every input/output location used by the pipeline (replace the placeholders),
# then collect the predicted models and their ranking scores.
source_dir = "/path/to/AF3_Prediction_directory"   # AlphaFold3 prediction folders
analysis_dir = "/path/to/your/analysis/directory"
dest_dir = "/path/to/predicted_files_directory"      # receives the collected *_model.cif files
work_dir = "/path/to/work_dir"                       # output location passed to the prep scripts
scoring_path = "/output/path/to/ranking_scores.csv"
pdb_dir = "/path/to/real/pdb_files"                  # ground-truth PDB downloads (absolute, like the others)
analysis_result_csv_path = "/output_path/analysis_results.csv"
interaction_csv_path = "/output_path/interaction_recovery.csv"

# Create the working directories, plus the parent folder of every output CSV so the
# to_csv calls here and in later cells do not fail on a missing directory.
for folder in (analysis_dir, dest_dir, work_dir, pdb_dir):
    os.makedirs(folder, exist_ok=True)
for csv_path in (scoring_path, analysis_result_csv_path, interaction_csv_path):
    os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)

score_df = collect_files_and_scores(source_dir, dest_dir)
score_df.to_csv(scoring_path, index=False)
print("Collected files and scores, saved to:", scoring_path)




In [None]:
# Download the ground-truth PDB for every collected prediction; failures are reported and skipped.
# dtype=str keeps IDs such as "1e10" from being parsed as numbers by read_csv.
score_df = pd.read_csv(scoring_path, dtype={"id": str})
for pdb_code in score_df["id"]:
    try:
        fetch_pdb(pdb_code, pdb_dir)
    except Exception as e:
        print(f"Failed to fetch PDB for {pdb_code}: {e}")

In [None]:
# Directory-level Schrödinger prep commands (executed in the next two cells):
# cmd1 preps the collected predicted CIFs in dest_dir, cmd2 the ground-truth PDBs in pdb_dir;
# work_dir is passed to both scripts as their output location.
cmd1 = ["./run_prep_dir_cif.sh", dest_dir, work_dir]
cmd2 = ["./run_prep_dir_pdb.sh", pdb_dir, work_dir]


In [10]:
# Run the CIF prep script (cmd1). A non-zero exit is logged and echoed; the notebook continues.
try:
    proc = subprocess.run(cmd1, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as err:
    stderr_text = err.stderr.strip()
    logging.error(f"script failed: {stderr_text}")
    print(stderr_text)



In [11]:
# Run the PDB prep script (cmd2). A non-zero exit is logged and echoed; the notebook continues.
try:
    proc = subprocess.run(
        cmd2,
        capture_output=True,
        text=True,
        check=True,
    )
except subprocess.CalledProcessError as failure:
    message = failure.stderr.strip()
    logging.error("script failed: " + message)
    print(message)

In [None]:
# Rename the prepared complex files so only the PDB ID remains in each name.
rename_complex_prep_files(directory=work_dir)

In [None]:
# Run the alignment analysis on the prepared files inside the SBGrid Schrödinger environment.
# Its CSV (analysis_result_csv_path) supplies protein_ca_rmsd / ligand_rmsd for the final merge.
ALIGN_WORKERS = 20
align_args = [
    "--real_dir", work_dir,
    "--pred_dir", work_dir,
    "--work_dir", work_dir,
    "--out_csv", analysis_result_csv_path,
    "--workers", str(ALIGN_WORKERS),
]
# Every argument is shell-quoted: the paths are interpolated into a bash command line,
# where an unquoted space or metacharacter would split or corrupt the command.
cmd = (
    "source /programs/sbgrid.shrc && "
    '"$SCHRODINGER/run" python3 align-batch-schrodinger-similarity.py '
    + " ".join(shlex.quote(arg) for arg in align_args)
)
subprocess.run(cmd, shell=True, check=True, executable='/bin/bash')


In [None]:
# Interaction analysis with Schrödinger poseviewer over every aligned pair in work_dir.
int_df = analyze_interactions_batch(
    work_root=work_dir,
    out_csv=interaction_csv_path,
    workers=10,
    force=False,  # reuse existing *_pv_interactions.txt reports instead of re-running poseviewer
    sbgrid_rc="/programs/sbgrid.shrc",  # sourced to define $SCHRODINGER (no schrodinger_root given)
)

In [None]:
# Final processing: join the alignment metrics with interaction recovery per PDB ID.
final_result_csv_path = "/output_path/final_analysis_results.csv"

# Read IDs as strings on both sides so codes like "1e10" are not parsed as numbers,
# which would silently break the join.
int_df = pd.read_csv(interaction_csv_path, dtype={"pdb_id": str})
analysis_result = pd.read_csv(analysis_result_csv_path, dtype={"pdb_id": str})

# Inner join on pdb_id, keep the reported metrics, and drop incomplete rows
# (a failed pair's "N/A" recovery is read back as NaN by read_csv).
merged = (
    analysis_result
    .merge(int_df, on="pdb_id")
    [["pdb_id", "protein_ca_rmsd", "ligand_rmsd", "interaction_recovery_percent"]]
    .dropna()
)

# Save the combined table to a single CSV.
os.makedirs(os.path.dirname(final_result_csv_path) or ".", exist_ok=True)
merged.to_csv(final_result_csv_path, index=False)

