In [None]:
#!/usr/bin/env python3
"""
Production‑Quality Pipeline for De Novo 3FTx Binder Design with Partial Diffusion Optimization

This pipeline implements a workflow inspired by recent Nature work and BindCraft.
It:
  • Generates a consensus cytotoxin sequence from a FASTA file.
  • Creates fold‑conditioning tensors (secondary structure and block adjacency tensors)
    based on the consensus sequence.
  • Uses RFdiffusion to generate binder backbone candidates conditioned on these tensors.
  • Designs binder sequences using ProteinMPNN.
  • Screens the designs using an AF2 wrapper.
  • For designs whose initial AF2 score is below --af2_threshold (lower scores are
    treated as better), applies partial diffusion optimization by:
       - Running partial noising steps (T = 10 and T = 20 out of 50) via RFdiffusion's partial_diffuse()
       - Refining with Rosetta FastRelax (called via subprocess)
       - Re-evaluating with AF2 (using af2_score)
       - Filtering designs (criteria: AF2 PAE < 10, pLDDT > 80, ddG < -40)
  • Selects the best binder, optionally introduces a disulfide bond,
    and merges it with the toxin structure.
  • Refines the complex with MD simulation using OpenMM.
  • Extracts the binder sequence, performs codon optimization for gene synthesis,
    and predicts immunogenicity using NetMHCpan.

Usage example:
  python cytotoxin_binder_pipeline.py --cytotoxins_fasta path/to/cytotoxins.fasta --toxin_pdb path/to/toxin.pdb --chain A --loop_ranges "30-45,70-85" --rf_config path/to/rfdiffusion_config.yaml --mpnn_config path/to/mpnn_config.yaml [other options]
"""

import os
import copy
import logging
import argparse
import random
import subprocess
import tempfile
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch

# BioPython imports
from Bio import AlignIO
from Bio.Align.Applications import ClustalOmegaCommandline
from Bio.Align import AlignInfo
from Bio.PDB import PDBParser, DSSP, PDBIO, Structure, Model
from Bio.PDB.PPBuilder import PPBuilder
from Bio.Seq import Seq

# --- Design Modules ---
from rfdiffusion.model import RFDiffusionModel
from proteinmpnn.model import ProteinMPNNModel
from af2wrapper import af2_score  # Must return dict with keys: "PAE", "pLDDT", "ddG", "score"

# OpenMM for MD simulation
from openmm.app import PDBFile, ForceField, Simulation, PME, PDBReporter
from openmm import LangevinIntegrator
from openmm.unit import kelvin, picosecond, femtosecond, nanometer

# DnaChisel for codon optimization (if available)
try:
    from dnachisel import DnaOptimizationProblem, EnforceTranslation, AvoidRareCodons
    DNACHISEL_AVAILABLE = True
except ImportError:
    DNACHISEL_AVAILABLE = False

# -----------------------------------------------------------------------------
# Logging Configuration
# -----------------------------------------------------------------------------
# Configure root logging only when this file is executed directly (script or
# notebook). Importing it as a library must not override the caller's
# logging configuration.
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
logger = logging.getLogger(__name__)

# =============================================================================
# Step 1. Consensus Sequence Generation from Cytotoxins FASTA
# =============================================================================
def generate_consensus_sequence(fasta_file: str, clustal_exe: str = "clustalo", out_format: str = "clustal") -> str:
    """
    Build a consensus sequence from the sequences in a FASTA file.

    The sequences are aligned with Clustal Omega. The alignment is collapsed
    with Biopython's ``SummaryInfo.dumb_consensus`` (threshold 0.5). Columns
    without a qualifying residue become ``'X'``.

    Args:
        fasta_file: Path to the input FASTA file.
        clustal_exe: Name or path of the Clustal Omega executable.
        out_format: Alignment format that Clustal Omega writes and AlignIO reads back.

    Returns:
        The consensus sequence as a plain string.
    """
    # Reserve a unique file name for the alignment. Clustal Omega overwrites
    # the empty placeholder because of force=True.
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".aln") as aln_file:
        aln_filename = aln_file.name
    try:
        clustal_cline = ClustalOmegaCommandline(cmd=clustal_exe, infile=fasta_file, outfile=aln_filename,
                                                 verbose=True, auto=True, force=True, outfmt=out_format)
        logger.info("Running Clustal Omega for MSA...")
        clustal_cline()
        alignment = AlignIO.read(aln_filename, out_format)
    finally:
        # Remove the temp file even when Clustal Omega or parsing fails.
        if os.path.exists(aln_filename):
            os.remove(aln_filename)
    summary = AlignInfo.SummaryInfo(alignment)
    consensus_str = str(summary.dumb_consensus(threshold=0.5, ambiguous='X'))
    logger.info(f"Consensus sequence generated (length {len(consensus_str)}).")
    return consensus_str

# =============================================================================
# Step 2. Generation of Conditioning Tensors
# =============================================================================
def generate_conditioning_tensors(seq: str, beta_positions: List[int]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build secondary-structure and block-adjacency conditioning tensors.

    Every residue index listed in ``beta_positions`` is marked as β-strand
    (one-hot ``[0, 1, 0, 0]``). Every other residue is marked as masked
    (``[0, 0, 0, 1]``).

    In the adjacency tensor, a pair ``(i, j)`` where both residues are
    β positions gets ``[0, 1, 0]`` ("adjacent"). All other pairs get the
    masked one-hot ``[0, 0, 1]``.

    Args:
        seq: Sequence whose length L sets the tensor sizes (only ``len(seq)`` is used).
        beta_positions: 0-based residue indices to mark as β-strand. Values
            outside ``range(L)``, including negatives, are ignored.

    Returns:
        ``(sec_tensor, adj_tensor)`` with shapes ``(L, 4)`` and ``(L, L, 3)``,
        both of integer dtype.
    """
    L = len(seq)
    # A set gives O(1) membership. Iterating range(L) drops out-of-range indices
    # instead of letting numpy wrap negative ones.
    beta_set = set(beta_positions)
    is_beta = np.fromiter((i in beta_set for i in range(L)), dtype=bool, count=L)

    sec_tensor = np.zeros((L, 4), dtype=int)
    sec_tensor[:, 3] = 1                 # masked by default
    sec_tensor[is_beta] = [0, 1, 0, 0]   # β-strand

    adj_tensor = np.zeros((L, L, 3), dtype=int)
    adj_tensor[..., 2] = 1               # masked by default
    both_beta = is_beta[:, None] & is_beta[None, :]
    adj_tensor[both_beta] = [0, 1, 0]    # adjacent
    return sec_tensor, adj_tensor

# =============================================================================
# Step 2a. Loop Region Extraction
# =============================================================================
def extract_loop_regions(pdb_path: str, chain_id: str, loop_ranges: List[Tuple[int, int]]) -> List[int]:
    """
    Collect the residue numbers of one chain that fall inside any loop range.

    Args:
        pdb_path: Path to the PDB file to parse.
        chain_id: Chain identifier, looked up in the first model.
        loop_ranges: Inclusive ``(start, end)`` residue-number ranges.

    Returns:
        Sorted list of distinct residue sequence numbers inside at least one range.

    Raises:
        ValueError: If the chain (or the first model) is not present.
    """
    structure = PDBParser(QUIET=True).get_structure("structure", pdb_path)
    try:
        chain = structure[0][chain_id]
    except KeyError as exc:
        raise ValueError(f"Chain {chain_id} not found in structure") from exc

    residue_numbers = [res.get_id()[1] for res in chain.get_residues()]
    inside = {
        number
        for low, high in loop_ranges
        for number in residue_numbers
        if low <= number <= high
    }
    return sorted(inside)

# =============================================================================
# Step 3. RFdiffusion-based Binder Backbone Generation
# =============================================================================
def generate_binder_backbones(toxin_structure: Structure.Structure,
                              hotspot_residues: List[int],
                              rf_config: str,
                              secondary_tensor: np.ndarray,
                              adjacency_tensor: np.ndarray,
                              num_samples: int = 2000) -> List[Structure.Structure]:
    """
    Generate binder backbone candidates against the toxin with RFdiffusion.

    Builds an ``RFDiffusionModel`` from ``rf_config`` and calls its
    ``design_binder`` with the target, hotspots and conditioning tensors.
    ``num_iterations`` and ``learning_rate`` are fixed at 500 and 0.001.

    Args:
        toxin_structure: Target structure to bind.
        hotspot_residues: Residue numbers on the target to focus on.
        rf_config: Path to the RFdiffusion configuration file.
        secondary_tensor: Secondary-structure conditioning tensor.
        adjacency_tensor: Block-adjacency conditioning tensor.
        num_samples: Number of backbones requested.

    Returns:
        Whatever ``design_binder`` returns (the backbone candidates).
    """
    model = RFDiffusionModel(config_path=rf_config)
    params = dict(
        target_structure=toxin_structure,
        hotspot_residues=hotspot_residues,
        secondary_tensor=secondary_tensor,
        adjacency_tensor=adjacency_tensor,
        num_iterations=500,
        num_samples=num_samples,
        learning_rate=0.001,
    )
    logger.info("Generating binder backbones with RFdiffusion (real call)...")
    backbones = model.design_binder(params)
    logger.info(f"RFdiffusion generated {len(backbones)} backbone candidates.")
    return backbones

# =============================================================================
# Step 4. Sequence Design using ProteinMPNN
# =============================================================================
def design_binder_sequences(backbones: List[Structure.Structure],
                            mpnn_config: str) -> List[Structure.Structure]:
    """
    Run ProteinMPNN sequence design on every backbone, preserving input order.

    Args:
        backbones: Backbone structures to design sequences for.
        mpnn_config: Path to the ProteinMPNN configuration file.

    Returns:
        One ``design_sequence`` result per input backbone.
    """
    designer = ProteinMPNNModel(config_path=mpnn_config)
    logger.info("Designing sequences on binder backbones using ProteinMPNN...")
    designed = [designer.design_sequence(bb) for bb in backbones]
    logger.info(f"ProteinMPNN designed sequences for {len(designed)} candidates.")
    return designed

# =============================================================================
# Step 5. Partial Diffusion Optimization and Rosetta FastRelax
# =============================================================================
def partial_diffusion_optimization(design: Structure.Structure, partial_T: int, rf_model: RFDiffusionModel) -> Structure.Structure:
    """
    Run RFdiffusion partial diffusion on a design.

    Thin wrapper around ``rf_model.partial_diffuse`` that logs the step count.

    Args:
        design: Structure to re-diffuse.
        partial_T: Number of partial diffusion steps, forwarded as ``partial_T``.
        rf_model: An already-constructed RFdiffusion model.

    Returns:
        The structure returned by ``partial_diffuse``.
    """
    logger.info(f"Applying partial diffusion with T = {partial_T} steps...")
    return rf_model.partial_diffuse(design, partial_T=partial_T)

def ros_fast_relax(design: Structure.Structure,
                   relax_exe: str = "relax.linuxgccrelease") -> Structure.Structure:
    """
    Run Rosetta FastRelax on a design and return the relaxed structure.

    Steps:
      1. The design is written into a private temporary directory.
      2. The Rosetta relax executable runs with that directory as its working
         directory, so its model PDBs and score file stay out of the caller's cwd.
      3. The model with the lowest ``total_score`` is parsed back with Biopython.

    Everything is deleted afterwards.

    Args:
        design: Structure to relax.
        relax_exe: Rosetta relax executable name or path.

    Returns:
        The relaxed structure. Returns ``design`` unchanged if Rosetta cannot be
        run, exits with an error, or writes no model PDB.
    """
    import glob
    from rfdiffusion.utils import save_pdb

    # Rosetta runs inside the temporary directory, so resolve the protocol
    # file against the caller's working directory first.
    protocol_file = os.path.abspath("FastRelax.xml")
    with tempfile.TemporaryDirectory(prefix="fastrelax_") as work_dir:
        input_filename = os.path.join(work_dir, "design.pdb")
        save_pdb(design, input_filename)
        # Adjust flags as needed for your installation.
        # NOTE(review): confirm "-out:file:protocol" is accepted by your relax build.
        # If it is rejected, every call will take the error path below.
        command = [
            relax_exe, "-s", input_filename,
            "-out:file:protocol", protocol_file,
            "-nstruct", "5", "-out:file:scorefile", "relax_score.sc",
        ]
        try:
            subprocess.check_call(command, cwd=work_dir)
        except (OSError, subprocess.CalledProcessError) as e:
            logger.error("Error during Rosetta FastRelax: " + str(e))
            return design
        # By default Rosetta names its models <input stem>_NNNN.pdb (design_0001.pdb, ...).
        model_paths = sorted(glob.glob(os.path.join(work_dir, "design_*.pdb")))
        if not model_paths:
            logger.error("Rosetta FastRelax produced no output PDB; returning the unrelaxed design.")
            return design
        best_path = _lowest_total_score_pdb(os.path.join(work_dir, "relax_score.sc"), model_paths)
        # Parse while the temporary directory still exists.
        relaxed_structure = PDBParser(QUIET=True).get_structure("relaxed", best_path)
    logger.info("Rosetta FastRelax complete.")
    return relaxed_structure


def _lowest_total_score_pdb(score_path: str, pdb_paths: List[str]) -> str:
    """
    Return the path in ``pdb_paths`` whose Rosetta ``total_score`` is lowest.

    Score rows are matched to files through the ``description`` column, which is
    compared with the PDB file name without ``.pdb``. Falls back to
    ``pdb_paths[0]`` if the score file is unreadable or no row matches.
    """
    by_tag = {os.path.splitext(os.path.basename(path))[0]: path for path in pdb_paths}
    best_path, best_score = pdb_paths[0], float("inf")
    try:
        with open(score_path) as handle:
            lines = handle.readlines()
    except OSError:
        logger.warning(f"Rosetta score file {score_path} not readable; using {os.path.basename(best_path)}.")
        return best_path

    header: Optional[List[str]] = None
    for line in lines:
        tokens = line.split()
        if not tokens or tokens[0] != "SCORE:":
            continue
        if "total_score" in tokens and "description" in tokens:
            header = tokens  # column-name row
            continue
        if header is None or len(tokens) != len(header):
            continue
        row = dict(zip(header, tokens))
        path = by_tag.get(row["description"])
        try:
            score = float(row["total_score"])
        except ValueError:
            continue
        if path is not None and score < best_score:
            best_path, best_score = path, score
    return best_path

def evaluate_design_metrics(design: Structure.Structure) -> Dict[str, float]:
    """
    Score a design's sequence with the AF2 wrapper and log the key metrics.

    The sequence comes from ``extract_protein_sequence(design)`` (longest chain).
    Missing metrics are logged with the fallbacks PAE=100, pLDDT=0, ddG=0.

    Returns:
        The metrics dict returned by ``af2_score``, unmodified.
    """
    sequence = extract_protein_sequence(design)
    result = af2_score(sequence)
    pae, plddt, ddg = (result.get(key, fallback)
                       for key, fallback in (("PAE", 100), ("pLDDT", 0), ("ddG", 0)))
    logger.info(f"AF2 evaluation: PAE={pae}, pLDDT={plddt}, ddG={ddg}")
    return result

def screen_and_optimize_designs(designs: List[Structure.Structure],
                                af2_threshold: float = 0.7,
                                partial_diff_iters: int = 50,
                                rf_config: Optional[str] = None,
                                partial_T_values: Tuple[int, ...] = (10, 20),
                                max_pae: float = 10.0,
                                min_plddt: float = 80.0,
                                max_ddg: float = -40.0) -> List[Tuple[Structure.Structure, float]]:
    """
    Screen designs with AF2, then refine and re-filter the promising ones.

    A design is optimized only when its initial AF2 ``score`` (default 1.0 if
    absent) is below ``af2_threshold``. For each ``T`` in ``partial_T_values``,
    it is partially diffused, relaxed with Rosetta FastRelax, re-scored, and
    kept when all of the following hold:

      * ``PAE < max_pae``
      * ``pLDDT > min_plddt``
      * ``ddG < max_ddg``

    Missing metrics count as failing.

    Args:
        designs: Candidate binder structures.
        af2_threshold: Designs scoring below this are optimized (lower is better).
        partial_diff_iters: Accepted for interface compatibility; not currently
            used by this function.
        rf_config: RFdiffusion config used to build the partial-diffusion model.
        partial_T_values: Partial diffusion step counts to try per design.
        max_pae: Upper bound (exclusive) on AF2 PAE.
        min_plddt: Lower bound (exclusive) on AF2 pLDDT.
        max_ddg: Upper bound (exclusive) on ddG.

    Returns:
        ``(relaxed_design, score)`` pairs for every passing (design, T)
        combination. The score falls back to the initial score if the re-scored
        metrics lack one.
    """
    refined_designs: List[Tuple[Structure.Structure, float]] = []
    rf_model = RFDiffusionModel(config_path=rf_config)
    logger.info("Screening designs with AF2 and applying partial diffusion and FastRelax...")
    for design in designs:
        seq = extract_protein_sequence(design)
        base_score = af2_score(seq).get("score", 1.0)
        logger.info(f"Initial design {seq[:10]}... base AF2 score: {base_score:.3f}")
        # Written as "not <" (not ">=") so a NaN score is skipped.
        if not base_score < af2_threshold:
            continue
        for T in partial_T_values:
            design_T = partial_diffusion_optimization(design, partial_T=T, rf_model=rf_model)
            design_relaxed = ros_fast_relax(design_T)
            metrics_T = evaluate_design_metrics(design_relaxed)
            if (metrics_T.get("PAE", 100) < max_pae
                    and metrics_T.get("pLDDT", 0) > min_plddt
                    and metrics_T.get("ddG", 0) < max_ddg):
                refined_designs.append((design_relaxed, metrics_T.get("score", base_score)))
                logger.info(f"Design {seq[:10]} passed filtering with T={T}.")
    logger.info(f"{len(refined_designs)} designs passed partial diffusion optimization and filtering.")
    return refined_designs

# =============================================================================
# Step 6. Binder Evaluation and Selection
# =============================================================================
def select_best_binder(candidates: List[Tuple[Structure.Structure, float]]) -> Structure.Structure:
    """
    Pick the candidate with the lowest score.

    On ties, the earliest candidate in the list wins.

    Args:
        candidates: ``(structure, score)`` pairs.

    Returns:
        The structure of the lowest-scoring pair.

    Raises:
        RuntimeError: If ``candidates`` is empty.
    """
    if candidates:
        winner = min(candidates, key=lambda pair: pair[1])[0]
        logger.info("Best binder candidate selected.")
        return winner
    logger.error("No binder candidates passed screening!")
    raise RuntimeError("Binder design failed.")

# =============================================================================
# Step 7. Optional Disulfide Bond Introduction
# =============================================================================
def introduce_disulfide(design: Structure.Structure, loop_residues: List[int]) -> Structure.Structure:
    """
    Record a disulfide-bond request on a design (placeholder implementation).

    This does NOT mutate any residue to cysteine or change any coordinates. It
    only sets ``design.xtra["disulfide_introduced"] = True`` and emits a warning,
    so that a user who passed ``--apply_disulfide`` is not misled into thinking
    the structure changed.

    Args:
        design: Structure to annotate. Modified in place and also returned.
        loop_residues: Reserved for a real implementation; currently unused.

    Returns:
        The same ``design`` object.
    """
    logger.info("Introducing disulfide bond for stability improvement...")
    # Create the xtra dict only if the object lacks one; otherwise keep existing entries.
    if not hasattr(design, "xtra"):
        design.xtra = {}
    design.xtra["disulfide_introduced"] = True
    logger.warning("introduce_disulfide is a placeholder: only a metadata flag was set; "
                   "no cysteine mutations or structural changes were made.")
    return design

# =============================================================================
# Step 8. Merge Toxin with Binder and MD Refinement
# =============================================================================
def merge_toxin_and_binder(toxin_structure: Structure.Structure, binder_structure: Structure.Structure, output_file: str) -> None:
    """
    Write the toxin and binder as one complex file.

    Delegates to ``rfdiffusion.utils.merge_structures``.

    Args:
        toxin_structure: Target structure.
        binder_structure: Designed binder structure.
        output_file: Destination path for the merged complex.
    """
    from rfdiffusion.utils import merge_structures as write_complex
    write_complex(toxin_structure, binder_structure, output_file)
    logger.info(f"Merged complex saved to {output_file}")

def analyze_stability(pdb_file: str, temperatures: List[float], simulation_steps: int, report_interval: int) -> Dict[float, float]:
    """
    Run a short MD simulation per temperature and record the final potential energy.

    Setup is done once:
      * The structure is loaded and missing hydrogens are added.
      * It is parameterized with Amber14.

    For each temperature, a fresh Langevin simulation is energy-minimized and
    run for ``simulation_steps`` 2 fs steps. Results are also written to
    ``stability_report.txt`` in the current working directory.

    Args:
        pdb_file: Input structure (PDB).
        temperatures: Temperatures in kelvin.
        simulation_steps: MD steps per temperature.
        report_interval: Accepted for interface compatibility; not currently used.

    Returns:
        Mapping from temperature (K) to final potential energy in kJ/mol.
    """
    from openmm.app import HBonds, Modeller, NoCutoff
    from openmm.unit import kilojoule_per_mole

    logger.info("Performing stability analysis at multiple temperatures...")
    pdb = PDBFile(pdb_file)
    forcefield = ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')
    # Design outputs often lack hydrogens, and createSystem needs complete residues.
    modeller = Modeller(pdb.topology, pdb.positions)
    modeller.addHydrogens(forcefield)
    # PME requires periodic box vectors; use NoCutoff for a non-periodic (vacuum) structure.
    if modeller.topology.getPeriodicBoxVectors() is not None:
        nonbonded = dict(nonbondedMethod=PME, nonbondedCutoff=1.0*nanometer)
    else:
        nonbonded = dict(nonbondedMethod=NoCutoff)
    # Constraining bonds to hydrogen is what makes a 2 fs time step stable.
    system = forcefield.createSystem(modeller.topology, constraints=HBonds, **nonbonded)

    stability_results: Dict[float, float] = {}
    for temp in temperatures:
        integrator = LangevinIntegrator(temp*kelvin, 1.0/picosecond, 2*femtosecond)
        simulation = Simulation(modeller.topology, system, integrator)
        simulation.context.setPositions(modeller.positions)
        simulation.minimizeEnergy()
        simulation.step(simulation_steps)
        state = simulation.context.getState(getEnergy=True)
        energy = state.getPotentialEnergy().value_in_unit(kilojoule_per_mole)
        stability_results[temp] = energy
        logger.info(f"Temperature {temp} K: Energy = {energy} kJ/mol")
    with open("stability_report.txt", "w") as f:
        for t, e in stability_results.items():
            f.write(f"{t} K: {e}\n")
    return stability_results

def in_silico_refinement(complex_pdb_file: str, refined_pdb_file: str, simulation_steps: int = 5000, report_interval: int = 1000) -> None:
    """
    Energy-minimize and briefly simulate a complex at 300 K.

    Outputs:
      * ``refined_pdb_file``: the single final frame.
      * ``<refined_pdb_file stem>_traj.pdb``: frames sampled every
        ``report_interval`` steps.

    Missing hydrogens are added before parameterization with Amber14.

    Args:
        complex_pdb_file: Input complex (PDB).
        refined_pdb_file: Output path for the final refined structure.
        simulation_steps: Number of 2 fs MD steps after minimization.
        report_interval: Trajectory sampling interval in steps.
    """
    from openmm.app import HBonds, Modeller, NoCutoff

    pdb = PDBFile(complex_pdb_file)
    forcefield = ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')
    # Design outputs often lack hydrogens, and createSystem needs complete residues.
    modeller = Modeller(pdb.topology, pdb.positions)
    modeller.addHydrogens(forcefield)
    # PME requires periodic box vectors; use NoCutoff for a non-periodic (vacuum) complex.
    if modeller.topology.getPeriodicBoxVectors() is not None:
        nonbonded = dict(nonbondedMethod=PME, nonbondedCutoff=1.0*nanometer)
    else:
        nonbonded = dict(nonbondedMethod=NoCutoff)
    # Constraining bonds to hydrogen is what makes a 2 fs time step stable.
    system = forcefield.createSystem(modeller.topology, constraints=HBonds, **nonbonded)
    integrator = LangevinIntegrator(300*kelvin, 1.0/picosecond, 2*femtosecond)
    simulation = Simulation(modeller.topology, system, integrator)
    simulation.context.setPositions(modeller.positions)
    logger.info("Performing energy minimization for complex refinement...")
    simulation.minimizeEnergy()
    logger.info("Starting MD simulation for complex refinement...")
    trajectory_file = f"{os.path.splitext(refined_pdb_file)[0]}_traj.pdb"
    simulation.reporters.append(PDBReporter(trajectory_file, report_interval))
    simulation.step(simulation_steps)
    # Write the actual final frame as the refined structure. Readers of a
    # multi-model trajectory would otherwise see only its first frame.
    final_positions = simulation.context.getState(getPositions=True).getPositions()
    with open(refined_pdb_file, "w") as handle:
        PDBFile.writeFile(simulation.topology, final_positions, handle, keepIds=True)
    logger.info(f"MD refinement complete. Refined structure saved to {refined_pdb_file}")

# =============================================================================
# Step 9. Codon Optimization for Gene Synthesis
# =============================================================================
def codon_optimize_seq(aa_seq: str, organism: str = "E.coli") -> str:
    """
    Convert a protein sequence into a DNA sequence for gene synthesis.

    With DnaChisel installed:
      1. A naive reverse translation is built.
      2. It is optimized toward ``organism``'s codon usage (CodonOptimize).
      3. The encoded protein is kept fixed (EnforceTranslation).

    Without DnaChisel, a fixed one-codon-per-residue table is used.

    Args:
        aa_seq: One-letter amino-acid sequence; ``'*'`` encodes a stop codon.
        organism: Organism name such as ``"E.coli"`` or ``"E. coli"``. It is
            normalized to DnaChisel's species naming (``"e_coli"``).

    Returns:
        The DNA sequence as a string.

    Raises:
        ValueError: In the fallback path, if ``aa_seq`` contains a letter not in the table.
    """
    if DNACHISEL_AVAILABLE:
        from dnachisel import CodonOptimize
        from dnachisel.biotools import reverse_translate

        logger.info("Using DnaChisel for codon optimization.")
        # "E.coli" / "E. coli" -> "e_coli"; already-normalized names pass through unchanged.
        species = organism.replace(" ", "").replace(".", "_").lower()
        problem = DnaOptimizationProblem(
            sequence=reverse_translate(aa_seq),
            constraints=[EnforceTranslation()],
            objectives=[CodonOptimize(species=species)],
        )
        problem.resolve_constraints()
        problem.optimize()
        return problem.sequence

    logger.warning("DnaChisel not available; using fallback codon optimization.")
    codon_dict = {
        'A': 'GCT', 'C': 'TGT', 'D': 'GAT', 'E': 'GAA', 'F': 'TTT',
        'G': 'GGT', 'H': 'CAT', 'I': 'ATT', 'K': 'AAA', 'L': 'TTA',
        'M': 'ATG', 'N': 'AAT', 'P': 'CCT', 'Q': 'CAA', 'R': 'CGT',
        'S': 'TCT', 'T': 'ACT', 'V': 'GTT', 'W': 'TGG', 'Y': 'TAT',
        '*': 'TAA',
    }
    codons = []
    for aa in aa_seq:
        if aa not in codon_dict:
            raise ValueError(f"Unknown amino acid: {aa}")
        codons.append(codon_dict[aa])
    return "".join(codons)

# =============================================================================
# Step 10. Immunogenicity Prediction using NetMHCpan
# =============================================================================
def predict_immunogenicity(seq: str, allele: str = "HLA-A*02:01", affinity_threshold_nm: float = 500.0) -> float:
    """
    Estimate immunogenicity as the fraction of 9-mers predicted to bind an MHC-I allele.

    All overlapping 9-mers of ``seq`` are submitted to NetMHCpan in peptide mode
    (``-p``) with binding-affinity prediction (``-BA``). A peptide counts as a
    binder when its predicted affinity is below ``affinity_threshold_nm``.

    Args:
        seq: Protein sequence (one-letter code).
        allele: MHC allele, e.g. ``"HLA-A*02:01"``. The ``*`` is stripped because
            NetMHCpan expects names like ``HLA-A02:01``.
        affinity_threshold_nm: Affinity cutoff in nM.

    Returns:
        The binder fraction in [0, 1]. Returns 1.0 (pessimistic) when ``seq`` is
        shorter than 9 residues, NetMHCpan cannot be run, or no affinity can be
        parsed from its output.
    """
    peptides = [seq[i:i + 9] for i in range(len(seq) - 9 + 1)]
    if not peptides:
        return 1.0
    netmhc_allele = allele.replace("*", "")
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".pep") as tmp_pep:
        # Peptide-mode input: one peptide per line.
        tmp_pep.write("\n".join(peptides) + "\n")
        peptide_file = tmp_pep.name
    try:
        output = subprocess.check_output(
            ["netMHCpan", "-p", "-a", netmhc_allele, "-f", peptide_file, "-BA"],
            universal_newlines=True,
        )
    except (OSError, subprocess.SubprocessError) as e:
        logger.error("Error running NetMHCpan: " + str(e))
        return 1.0
    finally:
        os.remove(peptide_file)

    affinities = _parse_netmhcpan_affinities(output)
    if not affinities:
        return 1.0
    binder_count = sum(1 for aff in affinities if aff < affinity_threshold_nm)
    immunogenicity_score = binder_count / len(affinities)
    logger.info(f"NetMHCpan: {binder_count}/{len(affinities)} peptides below {affinity_threshold_nm:g} nM "
                f"(score: {immunogenicity_score:.3f})")
    return immunogenicity_score


def _parse_netmhcpan_affinities(output: str) -> List[float]:
    """
    Extract predicted affinities (nM) from NetMHCpan's tabular output.

    The affinity column is located by name (``Aff(nM)`` or ``Affinity(nM)``) in
    each header row; it is not assumed to be at a fixed position. Data rows are
    those whose first token is an integer position.
    """
    affinities: List[float] = []
    aff_col: Optional[int] = None
    for line in output.splitlines():
        tokens = line.split()
        if not tokens or line.startswith("#"):
            continue
        header_hits = [i for i, tok in enumerate(tokens) if tok in ("Aff(nM)", "Affinity(nM)")]
        if header_hits:
            aff_col = header_hits[0]
            continue
        if aff_col is None or not tokens[0].isdigit() or len(tokens) <= aff_col:
            continue
        try:
            affinities.append(float(tokens[aff_col]))
        except ValueError:
            continue
    return affinities

# =============================================================================
# Step 11. Protein Sequence Extraction
# =============================================================================
def extract_protein_sequence(structure: Structure.Structure, chain_id: Optional[str] = None) -> str:
    """
    Return the polypeptide sequence of a chain in the first model.

    Peptide fragments built by ``PPBuilder`` are concatenated per chain, so
    chain breaks produce no gap character.

    Args:
        structure: Biopython structure to read.
        chain_id: Chain to extract. If omitted, the longest chain sequence is returned.

    Returns:
        The sequence string, or ``""`` if the chain or any chain is absent.
    """
    builder = PPBuilder()
    per_chain: Dict[str, str] = {
        chain.id: "".join(str(pp.get_sequence()) for pp in builder.build_peptides(chain))
        for chain in structure[0]
        if chain_id is None or chain.id == chain_id
    }
    if chain_id:
        return per_chain.get(chain_id, "")
    if not per_chain:
        return ""
    return max(per_chain.values(), key=len)

# =============================================================================
# Main Pipeline Function
# =============================================================================
def _parse_loop_ranges(text: str) -> List[Tuple[int, int]]:
    """
    Parse ``"START-END,START-END,..."`` into inclusive ``(start, end)`` tuples.

    Raises:
        ValueError: On a malformed item, a reversed range (start > end), or when
            no ranges are given.
    """
    ranges: List[Tuple[int, int]] = []
    for item in (part.strip() for part in text.split(",")):
        if not item:
            continue  # tolerate stray or trailing commas
        start_text, sep, end_text = item.partition("-")
        if not sep:
            raise ValueError(f"Invalid loop range {item!r}: expected START-END")
        start, end = int(start_text), int(end_text)
        if start > end:
            raise ValueError(f"Invalid loop range {item!r}: start is greater than end")
        ranges.append((start, end))
    if not ranges:
        raise ValueError("No loop ranges given")
    return ranges


def _parse_temperatures(text: str) -> List[float]:
    """
    Parse a comma-separated list of positive temperatures in kelvin.

    Raises:
        ValueError: On a non-numeric or non-positive entry, or an empty list.
    """
    temps = [float(part) for part in (p.strip() for p in text.split(",")) if part]
    if not temps:
        raise ValueError("No stability temperatures given")
    if any(t <= 0 for t in temps):
        raise ValueError(f"Temperatures must be positive (kelvin): {temps}")
    return temps


def main() -> None:
    """Parse command-line arguments and run the full binder-design pipeline end to end."""
    parser = argparse.ArgumentParser(description="Cytotoxin Binder Design Pipeline with Partial Diffusion Optimization")
    parser.add_argument("--cytotoxins_fasta", type=str, required=True, help="FASTA file with cytotoxin sequences")
    parser.add_argument("--toxin_pdb", type=str, required=True, help="Path to cytotoxin PDB file (e.g., consensus model or representative structure)")
    parser.add_argument("--chain", type=str, default="A", help="Chain identifier for the toxin structure")
    parser.add_argument("--loop_ranges", type=str, default="30-45,70-85", help="Comma-separated loop ranges (e.g., '30-45,70-85')")
    parser.add_argument("--rf_config", type=str, required=True, help="Path to RFdiffusion configuration YAML")
    parser.add_argument("--mpnn_config", type=str, required=True, help="Path to ProteinMPNN configuration YAML")
    parser.add_argument("--af2_threshold", type=float, default=0.7, help="AF2 screening score threshold")
    parser.add_argument("--partial_diff_iters", type=int, default=50, help="Partial diffusion iterations for optimization")
    parser.add_argument("--refinement_iters", type=int, default=3, help="Number of iterative refinement cycles")
    parser.add_argument("--apply_disulfide", action="store_true", help="Introduce disulfide bond for stability improvement")
    parser.add_argument("--stability_temps", type=str, default="300,320", help="Comma-separated temperatures (K) for stability analysis")
    parser.add_argument("--stability_steps", type=int, default=3000, help="MD steps for stability analysis at each temperature")
    parser.add_argument("--md_steps", type=int, default=5000, help="MD simulation steps for complex refinement")
    parser.add_argument("--report_interval", type=int, default=1000, help="MD simulation report interval (in steps)")
    args = parser.parse_args()

    # Validate free-form CLI strings before any expensive step runs. A typo then
    # fails immediately instead of after hours of design and MD.
    try:
        loop_ranges = _parse_loop_ranges(args.loop_ranges)
        temps = _parse_temperatures(args.stability_temps)
    except ValueError as exc:
        parser.error(str(exc))
    # NOTE(review): --refinement_iters is parsed but not consumed by any step below.

    # Step 1: Generate consensus sequence from cytotoxin FASTA
    consensus_seq = generate_consensus_sequence(args.cytotoxins_fasta)
    logger.info(f"Consensus cytotoxin sequence: {consensus_seq[:30]}... (length: {len(consensus_seq)})")

    # Step 2: Conditioning tensors. Beta-strand positions are an example
    # (positions 10-20 and 40-50).
    beta_positions = list(range(10, 21)) + list(range(40, 51))
    secondary_tensor, adjacency_tensor = generate_conditioning_tensors(consensus_seq, beta_positions)

    # Step 2a: Extract loop (hotspot) residues from the toxin structure
    hotspot_residues = extract_loop_regions(args.toxin_pdb, args.chain, loop_ranges)

    # Load toxin structure from PDB
    toxin_structure = PDBParser(QUIET=True).get_structure("toxin", args.toxin_pdb)

    # Step 3: Generate binder backbones using RFdiffusion with conditioning tensors
    binder_backbones = generate_binder_backbones(toxin_structure, hotspot_residues, args.rf_config,
                                                 secondary_tensor, adjacency_tensor, num_samples=2000)

    # Step 4: Design binder sequences on the backbones using ProteinMPNN
    binder_designs = design_binder_sequences(binder_backbones, args.mpnn_config)

    # Step 5: Initial AF2 screening and partial diffusion optimization
    screened_initial = screen_and_optimize_designs(binder_designs, af2_threshold=args.af2_threshold,
                                                   partial_diff_iters=args.partial_diff_iters, rf_config=args.rf_config)

    # Step 6: Select best binder candidate from optimized designs
    best_binder = select_best_binder(screened_initial)

    # Step 7: Optionally introduce disulfide bond for stability improvement
    if args.apply_disulfide:
        best_binder = introduce_disulfide(best_binder, hotspot_residues)

    # Save final binder design
    binder_file = "final_binder_design.pdb"
    from rfdiffusion.utils import save_pdb
    save_pdb(best_binder, binder_file)
    logger.info(f"Final binder design saved to {binder_file}")

    # Step 8: Merge toxin and binder to form a complex and refine via MD simulation
    complex_file = "binder_toxin_complex.pdb"
    merge_toxin_and_binder(toxin_structure, best_binder, complex_file)
    refined_complex_file = "refined_complex.pdb"
    in_silico_refinement(complex_file, refined_complex_file, simulation_steps=args.md_steps, report_interval=args.report_interval)

    # Step 9: Extract binder sequence and perform codon optimization for gene synthesis
    binder_seq = extract_protein_sequence(best_binder, chain_id="B")
    if binder_seq:
        gene_sequence = codon_optimize_seq(binder_seq, organism="E.coli")
        logger.info("Codon‑optimized gene sequence for the binder:")
        logger.info(gene_sequence)
        with open("binder_gene.txt", "w") as f:
            f.write(gene_sequence)
        logger.info("Gene sequence saved to binder_gene.txt")
    else:
        logger.error("Failed to extract binder sequence.")

    # Step 10: Stability analysis via MD at multiple temperatures (parsed and validated above)
    stability_results = analyze_stability(refined_complex_file, temps, simulation_steps=args.stability_steps, report_interval=args.report_interval)
    logger.info(f"Stability analysis results: {stability_results}")

    # Step 11: Predict immunogenicity of binder using NetMHCpan
    if binder_seq:
        immuno_score = predict_immunogenicity(binder_seq, allele="HLA-A*02:01")
        with open("immunogenicity_report.txt", "w") as f:
            f.write(f"Predicted immunogenicity score: {immuno_score:.3f}\n")
        logger.info("Immunogenicity report saved to immunogenicity_report.txt")
    else:
        logger.error("Cannot predict immunogenicity; binder sequence unavailable.")

# -----------------------------------------------------------------------------
# Main Entry Point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Run the pipeline only when this file is executed directly, not when imported.
    main()