In [None]:
"""
visualize_complexes.py

Jupyter‑style notebook converted to script for interactive visualization of
protein–ligand complexes with py3Dmol and PLIP interaction mapping.

This module provides:
  1. list_complexes(): discover prepared complex PDB files.
  2. parse_plip_xml(): parse PLIP XML to extract interaction tuples.
  3. add_plip(): add interaction cylinders to a py3Dmol view.
  4. visualize_pdb(): render a single complex with cartoon, surface, ligands, and interactions.
  5. ComplexNavigator: an ipywidget navigator to browse multiple complexes and save snapshots.

Requires:
  - config.yaml in the working directory with `paths.visuals` set.
  - PyYAML, gemmi, py3Dmol, ipywidgets and IPython.
  - Optionally, a PLIP report at `plip/report.xml` inside each complex folder
    (interactions are drawn only when it exists).
"""

import base64
import html
import json
import xml.etree.ElementTree as ET
from pathlib import Path

import gemmi
import ipywidgets as widgets
import py3Dmol
import yaml
from IPython.display import HTML, display

# ─── 0) Load configuration and set visuals root ──────────────────────────
# config.yaml is read from the current working directory. Decode it as UTF-8
# explicitly rather than with the platform's locale encoding (e.g. cp1252 on
# Windows), which would mis-decode non-ASCII paths.
cfg = yaml.safe_load(Path("config.yaml").read_text(encoding="utf-8")) or {}
try:
    # `paths.visuals` names the folder scanned for prepared complexes.
    VIS_ROOT = Path(cfg["paths"]["visuals"]).resolve()
except (KeyError, TypeError) as err:
    # An empty file, a missing key, or a non-mapping/null value would
    # otherwise surface as an unexplained KeyError/TypeError.
    raise KeyError("config.yaml must define `paths.visuals`") from err

# ─── 1) Inject custom CSS & JavaScript for snapshot button ───────────────
# The `_css_injected` flag stored in globals() keeps the <style>/<script>
# block from being displayed again when this cell is re-executed.
if not globals().get("_css_injected", False):
    # Define button styles and the snapshot function.
    #
    # saveCanvasSnapshot(filename, origin) downloads an 8x nearest-neighbour
    # upscale of a viewer canvas as a PNG. When `origin` (the clicked element)
    # is supplied, the search is limited to its notebook cell (`.jp-Cell` in
    # JupyterLab, `.cell` in the classic notebook), so viewers in other cells
    # are not picked up. Otherwise, or if that cell has no canvas, the last
    # canvas on the page is used.
    display(HTML("""
<style>
.custom-btn {
    display:inline-flex;align-items:center;justify-content:center;
    background:linear-gradient(180deg,#8cdede 0%,#6bbcbc 100%);
    border:1px solid #4a9c9c;color:#fff!important;padding:.55em 1.6em;
    font-weight:600;font-size:15px;border-radius:8px;cursor:pointer;
    box-shadow:0 4px 0 #4a9c9c,0 4px 12px rgba(0,0,0,.15);
    line-height:1;transition:transform .1s,box-shadow .1s,filter .2s;
}
.custom-btn:hover {
    transform:translateY(-2px);
    box-shadow:0 6px 0 #4a9c9c,0 6px 16px rgba(0,0,0,.2);
    filter:brightness(1.05);
}
.custom-btn:active {
    transform:translateY(0);
    box-shadow:0 2px 0 #4a9c9c inset,0 2px 8px rgba(0,0,0,.15) inset;
}
.custom-btn:disabled {
    opacity:.6;cursor:not-allowed;
}
</style>
<script>
function saveCanvasSnapshot(filename, origin) {
  const factor = 8;
  // Scope the lookup to the clicked button's cell when possible.
  const cell = (origin && origin.closest) ? origin.closest('.jp-Cell, .cell') : null;
  let canvases = cell ? cell.querySelectorAll('canvas') : [];
  if (!canvases.length) {
    canvases = document.querySelectorAll('canvas');
  }
  const canvas = canvases[canvases.length - 1];
  if (!canvas) {
    console.warn('saveCanvasSnapshot: no canvas found');
    return;
  }
  const scaled = document.createElement('canvas');
  scaled.width = canvas.width * factor;
  scaled.height = canvas.height * factor;
  const ctx = scaled.getContext('2d');
  ctx.imageSmoothingEnabled = false;  // nearest-neighbour upscale
  ctx.scale(factor, factor);
  ctx.drawImage(canvas, 0, 0);
  scaled.toBlob(blob => {
    if (!blob) {
      console.warn('saveCanvasSnapshot: canvas could not be encoded');
      return;
    }
    const url = URL.createObjectURL(blob);
    const a = document.createElement('a');
    a.download = filename;
    a.href = url;
    a.click();
    // Release the blob once the download has been handed off.
    setTimeout(() => URL.revokeObjectURL(url), 1000);
  }, 'image/png');
}
</script>
"""))
    globals()["_css_injected"] = True

# ─── 2) Locate available complex PDB files ────────────────────────────────
def list_complexes(root: "Path | None" = None) -> list[Path]:
    """
    Scan a visuals directory for directories named <tag>__<receptor>__<ligand>
    and return paths to the corresponding complex PDB files.

    Parameters
    ----------
    root : Path, optional
        Directory to scan. Defaults to the module-level VIS_ROOT taken from
        `paths.visuals` in config.yaml.

    Returns
    -------
    List[Path]
        One path per matching sub-directory, in sorted directory order, each
        pointing to <root>/<tag>__<rec>__<lig>/complex__<rec>__<lig>__<tag>.pdb.

    Raises
    ------
    RuntimeError
        If no valid complex files are found under the scanned directory.
    """
    base = VIS_ROOT if root is None else Path(root)
    comps: list[Path] = []
    for d in sorted(base.iterdir()):
        # Directory names should contain exactly two '__' separators.
        if not d.is_dir() or d.name.count("__") != 2:
            continue
        tag, rec, lig = d.name.split("__")
        fp = d / f"complex__{rec}__{lig}__{tag}.pdb"
        # is_file() rather than exists(): a directory with that name is not a PDB.
        if fp.is_file():
            comps.append(fp)
    if not comps:
        raise RuntimeError(f"❌ No PDB files found in {base}")
    return comps

# ─── 3) Define PLIP interaction color map ────────────────────────────────
# Every interaction kind gets its own colour so cylinders can be told apart.
# Orange is deliberately not used: visualize_pdb draws hetero-atom sticks in
# orange, so orange cylinders would blend into the ligand.
COLOR: dict[str, str] = {
    "hb": "red",      # hydrogen bond
    "hp": "yellow",   # hydrophobic contact
    "pc": "purple",   # pi–cation
    "ps": "green",    # pi–stacking
    "sb": "blue",     # salt bridge
    "wb": "cyan",     # water bridge
    "ha": "lime",     # halogen bond
    "me": "magenta",  # metal coordination
}

def parse_plip_xml(xml_path: Path) -> list[tuple[str, str, str]]:
    """
    Parse a PLIP report.xml to extract interaction tuples.

    Parameters
    ----------
    xml_path : Path
        Path to PLIP XML report (anything accepted by ``ET.parse``).

    Returns
    -------
    list of (kind, ligand_idx, protein_idx) tuples
        ``kind`` is one of "hb", "hp", "pc", "ps", "wb", "ha", "me", "sb".
        Indices are atom serial numbers as whitespace-stripped strings, and
        entries missing either index are dropped. For hydrogen bonds and
        water bridges, the ``protisdon`` flag decides which of donor/acceptor
        goes first. Pi-cation and pi-stacking groups are represented by their
        first listed atom. Salt bridges yield one tuple per ligand x protein
        atom pair.
    """
    root = ET.parse(xml_path).getroot()
    out: list[tuple[str, str, str]] = []

    def txt(node, *tags) -> str:
        # Return the first non-blank text among *tags*, tried in order. The
        # short names used historically come first; PLIP's own field names
        # (e.g. donor_idx, metal_idx) follow as fallbacks.
        for tag in tags:
            value = (node.findtext(tag) or "").strip()
            if value:
                return value
        return ""

    def idx_list(node, list_tag) -> list[str]:
        # Non-blank <idx> children of <list_tag> (e.g. lig_idx_list/idx).
        return [i.text.strip() for i in node.findall(f"./{list_tag}/idx")
                if i.text and i.text.strip()]

    def first_idx(node, list_tag, *tags) -> str:
        # Prefer an explicit single-index tag, else the first list entry.
        value = txt(node, *tags)
        if value:
            return value
        ids = idx_list(node, list_tag)
        return ids[0] if ids else ""

    def flag(node, tag) -> str:
        # Lower-cased text of a boolean-like field ("true"/"false"/"").
        return (node.findtext(tag) or "").strip().lower()

    def hydrogen_bond(n):
        don = txt(n, "donoridx")
        acc = txt(n, "acceptoridx")
        # protisdon == True: the donor is the protein atom, so the ligand
        # atom is the acceptor and goes first.
        out.append(("hb", acc, don) if flag(n, "protisdon") == "true"
                   else ("hb", don, acc))

    def water_bridge(n):
        acc = txt(n, "a_idx", "acceptor_idx")
        don = txt(n, "d_idx", "donor_idx")
        # Keep (acceptor, donor) unless the report states the ligand donates.
        out.append(("wb", don, acc) if flag(n, "protisdon") == "false"
                   else ("wb", acc, don))

    def salt_bridge(n):
        # Salt bridges involve charged groups, so emit every atom pairing.
        prot = idx_list(n, "prot_idx_list")
        for lig in idx_list(n, "lig_idx_list"):
            for p in prot:
                out.append(("sb", lig, p))

    # XPath of each interaction node -> handler appending tuple(s) to `out`.
    handlers = {
        ".//hydrogen_bonds/hydrogen_bond": hydrogen_bond,
        ".//hydrophobic_interactions/hydrophobic_interaction":
            lambda n: out.append(("hp", txt(n, "ligcarbonidx"),
                                  txt(n, "protcarbonidx"))),
        ".//pi_cation_interactions/pi_cation_interaction":
            lambda n: out.append(("pc", first_idx(n, "lig_idx_list"),
                                  first_idx(n, "prot_idx_list"))),
        ".//pi_stacks/pi_stack":
            lambda n: out.append(("ps",
                                  first_idx(n, "lig_idx_list", "ligcentroididx"),
                                  first_idx(n, "prot_idx_list", "protcentroididx"))),
        ".//water_bridges/water_bridge": water_bridge,
        ".//halogen_bonds/halogen_bond":
            lambda n: out.append(("ha", txt(n, "halogenidx", "don_idx"),
                                  txt(n, "acceptoridx", "acc_idx"))),
        ".//metal_complexes/metal_complex":
            lambda n: out.append(("me", txt(n, "metalidx", "metal_idx"),
                                  txt(n, "ligidx", "target_idx"))),
        ".//salt_bridges/salt_bridge": salt_bridge,
    }

    # The helpers above tolerate missing/empty elements, so no blanket
    # exception handler is needed to skip malformed nodes.
    for xpath, handle in handlers.items():
        for node in root.findall(xpath):
            handle(node)

    # Filter out incomplete entries.
    return [(k, a, b) for k, a, b in out if a and b]

def add_plip(view, st: gemmi.Structure, interactions: list[tuple[str,str,str]]) -> None:
    """
    Draw one dashed cylinder per PLIP interaction onto a py3Dmol view.

    Parameters
    ----------
    view : py3Dmol.view
        Viewer that receives the ``addCylinder`` calls.
    st : gemmi.Structure
        Structure whose atom serial numbers are used to look up coordinates
        (either the PLIP-side PDB or the displayed complex).
    interactions : list of tuples
        ``(kind, atom_serial_1, atom_serial_2)`` entries. ``kind`` selects the
        colour from ``COLOR``; unknown kinds are drawn gray.

    Pairs whose serials cannot both be resolved in ``st`` are skipped.
    """
    # Serial -> coordinate lookup over every atom of every model; a later
    # atom carrying the same serial replaces an earlier one.
    positions = {}
    for model in st:
        for chain in model:
            for residue in chain:
                for atom in residue:
                    positions[str(atom.serial)] = atom.pos

    for kind, serial_a, serial_b in interactions:
        start = positions.get(serial_a)
        end = positions.get(serial_b)
        if not (start and end):
            continue
        # Flat-capped dashed cylinder between the two interacting atoms.
        view.addCylinder({
            "start": {"x": start.x, "y": start.y, "z": start.z},
            "end": {"x": end.x, "y": end.y, "z": end.z},
            "radius": 0.15,
            "color": COLOR.get(kind, "gray"),
            "dashed": True,
            "fromCap": 1,
            "toCap": 1,
        })

# ─── 4) Visualize a single PDB complex ────────────────────────────────
def visualize_pdb(pdb_path: Path):
    """
    Render a complex PDB in py3Dmol with cartoon, surface, ligand sticks,
    and PLIP interactions if available.

    Parameters
    ----------
    pdb_path : Path
        Path to complex__<rec>__<lig>__<tag>.pdb file. A PLIP report is looked
        for at ``<pdb_path.parent>/plip/report.xml``.

    Returns
    -------
    py3Dmol.view
        The configured 3D view object.
    """
    cdir = pdb_path.parent

    # Read structure for visualization.
    st_view = gemmi.read_structure(str(pdb_path))
    v = py3Dmol.view(width=1200, height=800)
    v.addModel(st_view.make_pdb_string(), "pdb")

    # Render protein chains (chains with at least one residue whose gemmi
    # het_flag is "A", i.e. an ATOM record) as cartoon + semi-transparent
    # surface. Sorted so the call sequence does not depend on string hashing.
    chains = sorted({c.name for m in st_view for c in m
                     if any(r.het_flag == "A" for r in c)})
    for ch in chains:
        v.setStyle({"chain": ch}, {"cartoon": {"color": "lightblue"}})
        v.addSurface(py3Dmol.VDW,
                     {"opacity": 0.6, "color": "lightblue"},
                     {"chain": ch, "hetflag": False})

    # Render all hetero atoms as orange sticks.
    v.setStyle({"hetflag": True},
               {"stick": {"radius": 0.25, "color": "orange"}})

    # Parse and add PLIP interactions if the report exists.
    plip_xml = cdir / "plip" / "report.xml"
    if plip_xml.is_file():
        # Use the PLIP-side PDB, when present, to map report atom indices to
        # coordinates. Path.glob yields matches in directory order, so sort to
        # pick the same file every time if several match.
        plip_file = next(iter(sorted(cdir.glob("*_plip.pdb"))), None)
        st_map = gemmi.read_structure(str(plip_file)) if plip_file else st_view
        add_plip(v, st_map, parse_plip_xml(plip_xml))

    # Auto-zoom to the hetero-atom region, then zoom out for context.
    v.zoomTo({"hetflag": True})
    v.zoom(0.5)
    return v

# ─── 5) Interactive navigator for multiple complexes ────────────────────
class ComplexNavigator:
    """
    Interactive Jupyter widget to browse through multiple complexes.

    Displays previous/next buttons, current PDB name, 3D view, and
    a Save PNG snapshot button. The widgets are displayed as a side effect
    of construction.

    Parameters
    ----------
    complexes : list[Path]
        Complex PDB files to browse; must be non-empty.

    Raises
    ------
    ValueError
        If ``complexes`` is empty (indexing and the modular stepping below
        both require at least one entry).
    """
    def __init__(self, complexes: list[Path]):
        if not complexes:
            raise ValueError("ComplexNavigator requires at least one complex PDB path")
        self.complexes = complexes
        self.idx = 0
        self.out = widgets.Output()
        self.curr_view = None
        self.pdb_fp = None  # path currently shown; set by _show()

        # Navigation buttons
        self.prev_btn = widgets.Button(description="⬅️ Previous")
        self.prev_btn.add_class("custom-btn")
        self.next_btn = widgets.Button(description="Next ➡️")
        self.next_btn.add_class("custom-btn")
        # Raw HTML rather than widgets.Button so its onclick can call the
        # injected JavaScript snapshot function directly.
        self.save_btn = widgets.HTML()

        # Bind button callbacks
        self.prev_btn.on_click(lambda _: self._step(-1))
        self.next_btn.on_click(lambda _: self._step(1))

        # Display output area and controls
        display(self.out)
        display(widgets.HBox(
            [self.prev_btn, self.save_btn, self.next_btn],
            layout=widgets.Layout(justify_content="center", padding="0.6em 0")
        ))
        self._show()

    def _step(self, d: int):
        """
        Advance the index by d (±1), wrapping around, and update the view.
        """
        self.idx = (self.idx + d) % len(self.complexes)
        self._show()

    @staticmethod
    def _save_button_html(fname: str) -> str:
        """
        Return the Save PNG button markup that downloads the snapshot as *fname*.
        """
        # json.dumps produces a valid JavaScript string literal, and
        # html.escape makes it safe inside the double-quoted onclick attribute.
        # Quotes, '<', '>' or '&' in the file name can then neither break the
        # markup nor inject script.
        js_arg = html.escape(json.dumps(fname), quote=True)
        # `this` (the button) is passed so saveCanvasSnapshot may scope its
        # canvas lookup to this cell. JavaScript ignores surplus arguments, so
        # a filename-only implementation still works.
        return f"""
            <button onclick="saveCanvasSnapshot({js_arg}, this)"
                    class='custom-btn' style='margin:0 10px;'>
              💾 Save PNG
            </button>
            """

    def _show(self):
        """
        Render the current complex in the output area and update the Save button.
        """
        with self.out:
            self.out.clear_output(wait=True)
            self.pdb_fp = self.complexes[self.idx]
            print(f"[{self.idx+1}/{len(self.complexes)}] {self.pdb_fp.name}")
            self.curr_view = visualize_pdb(self.pdb_fp)
            display(self.curr_view)

            # Configure the Save PNG button for the complex now on screen.
            fname = self.pdb_fp.stem.replace("complex__", "") + "_snapshot.png"
            self.save_btn.value = self._save_button_html(fname)

# ─── 6) Launch navigator when run as script ────────────────────────────
if __name__ == "__main__":
    # Discover every prepared complex, then hand the list to the navigator,
    # which displays itself on construction.
    complexes = list_complexes()
    ComplexNavigator(complexes=complexes)



Output()

HBox(children=(Button(description='⬅️ Previous', style=ButtonStyle(), _dom_classes=('custom-btn',)), HTML(valu…