In [1]:
"""
visualize_complexes.py

Notebook helper for interactive visualization of protein–ligand complexes
with py3Dmol and PLIP interaction mapping.

Provides:
  1. list_complexes(): discover prepared complex PDB files.
  2. parse_plip_xml(): parse PLIP XML to extract interaction tuples.
  3. add_plip(): add interaction glyphs to a py3Dmol view.
  4. visualize_pdb(): render a single complex with configurable styling.
  5. ComplexNavigator: ipywidget navigator with live styling controls and PNG snapshot export.

Requires:
  - config.yaml in working directory with `paths.visuals` set.
  - gemmi, py3Dmol, ipywidgets, RDKit, PLIP XML reports for complexes.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Any
import xml.etree.ElementTree as ET

import yaml
import gemmi
import py3Dmol
from IPython.display import display, HTML
import ipywidgets as widgets

# Decode the config explicitly as UTF-8 (the usual YAML encoding) so that
# non-ASCII paths do not depend on the platform's default locale encoding.
cfg = yaml.safe_load(Path('config.yaml').read_text(encoding='utf-8'))
# Directory scanned by list_complexes() for per-complex sub-directories.
VIS_ROOT = Path(cfg['paths']['visuals']).resolve()

# Viewer size in pixels used by visualize_pdb() when no explicit `size` is given.
VIEW_WIDTH = 1200
VIEW_HEIGHT = 800

# Module-level caches keyed by the resolved file path (as a string), filled by
# the _load_* helpers so repeated renders do not re-read or re-parse files.
# Entries are never invalidated, so edits to a file on disk are not picked up.
PDB_TEXT_CACHE: Dict[str, str] = {}
STRUCT_CACHE: Dict[str, gemmi.Structure] = {}
INTERACTION_CACHE: Dict[str, list[tuple[str, str, str]]] = {}

# Inject the shared CSS rules and the `window.saveCanvasSnapshot` JavaScript
# helper into the notebook output. The `_css_injected` flag stored in
# globals() makes later executions of this code skip the injection.
# NOTE(review): the JS helper snapshots the first <canvas> returned by
# document.querySelector, whichever element that is on the page.
if not globals().get('_css_injected', False):
    display(HTML("""
<style>
.custom-btn {
    display:inline-flex;align-items:center;justify-content:center;
    background:linear-gradient(180deg,#8cdede 0%,#6bbcbc 100%);
    border:1px solid #4a9c9c;color:#fff!important;padding:.45em 1.3em;
    font-weight:600;font-size:14px;border-radius:8px;cursor:pointer;
    box-shadow:0 3px 0 #4a9c9c,0 3px 10px rgba(0,0,0,.12);
    line-height:1;transition:transform .12s,box-shadow .12s,filter .2s;
}
.custom-btn:hover {
    transform:translateY(-2px);
    box-shadow:0 5px 0 #4a9c9c,0 6px 16px rgba(0,0,0,.2);
    filter:brightness(1.05);
}
.custom-btn:active {
    transform:translateY(0);
    box-shadow:0 2px 0 #4a9c9c inset,0 2px 8px rgba(0,0,0,.15) inset;
}
.custom-btn:disabled {
    opacity:.6;cursor:not-allowed;
}
.status-msg {font-size:13px;color:#3b3b3b;margin:4px 0 0 0;text-align:center;}
.spin-btn button {background:linear-gradient(90deg,#ffafbd,#ffc3a0);color:#2c3e50;font-weight:700;border:1px solid #ff9a9e;}
.section-box {
    border:1px solid #dfe6e9;border-radius:6px;padding:0.5em;margin:0.3em;
    min-width:220px;max-width:260px;background:#fbfcfd;box-shadow:0 1px 2px rgba(0,0,0,0.05);
    display:flex;flex-direction:column;align-items:stretch;gap:0.35em;
}
.section-box h4 {margin:0 0 0.2em 0;padding:0;font-size:15px;color:#2d3436;text-align:center;font-weight:700;}
.section-box .widget-button, .section-box button {align-self:center;}
.section-box .widget-text, .section-box .widget-select, .section-box .widget-dropdown, .section-box .widget-slider {align-self:stretch;}
.section-box .widget-text input, .section-box .widget-dropdown select {text-align:center;font-size:13px;padding:0.28em;border-radius:4px;}
.slider-label {font-size:13px;color:#555;margin-bottom:2px;display:block;text-align:center;}
.slim-input input {text-align:center;}
</style>
<script>
window.saveCanvasSnapshot = function(filename, scale){
  const canvas = document.querySelector('canvas');
  if(!canvas){ return; }
  const factor = Number(scale || 4);
  const scaled = document.createElement('canvas');
  scaled.width = canvas.width * factor;
  scaled.height = canvas.height * factor;
  const ctx = scaled.getContext('2d');
  ctx.imageSmoothingEnabled = true;
  ctx.imageSmoothingQuality = 'high';
  ctx.scale(factor, factor);
  ctx.drawImage(canvas, 0, 0);
  scaled.toBlob(blob => {
    const a = document.createElement('a');
    a.download = filename;
    a.href = URL.createObjectURL(blob);
    a.click();
    setTimeout(() => URL.revokeObjectURL(a.href), 2000);
  }, 'image/png');
}
</script>
"""))
    globals()['_css_injected'] = True

def _color_value(color: str) -> str:
    color = color.strip()
    if color.startswith('0x'):
        return color
    if color.startswith('#'):
        return '0x' + color[1:]
    return color

def list_complexes() -> list[Path]:
    """Discover the prepared complex PDB files under ``VIS_ROOT``.

    Every immediate sub-directory named ``<tag>__<rec>__<lig>`` (exactly two
    ``__`` separators) is expected to hold
    ``complex__<rec>__<lig>__<tag>.pdb``; directories without that file are
    skipped.

    Returns:
        The PDB paths found, in sorted directory order.

    Raises:
        RuntimeError: if no matching PDB file exists under ``VIS_ROOT``.
    """
    complexes: list[Path] = []
    for entry in sorted(VIS_ROOT.iterdir()):
        if not entry.is_dir():
            continue
        parts = entry.name.split('__')
        if len(parts) != 3:
            continue
        tag, rec, lig = parts
        candidate = entry / f'complex__{rec}__{lig}__{tag}.pdb'
        if candidate.exists():
            complexes.append(candidate)
    if not complexes:
        # The leading glyph was previously a mis-decoded byte sequence.
        raise RuntimeError(f'❌ No PDB files found in {VIS_ROOT}')
    return complexes

# Display name -> colour string for the protein colour dropdown. Values are
# either a colour name or a 0x-prefixed hex string and are passed through
# _color_value() before being handed to py3Dmol.
PROTEIN_COLORS: Dict[str, str] = {
    'Ice Blue': 'lightblue',
    'Sky Blue': '0x74b9ff',
    'Steel': '0x7f8c8d',
    'Stone': '0xbdc3c7',
    'Midnight': '0x2c3e50',
    'Mint': '0x1abc9c',
    'Emerald': '0x2ecc71',
    'Forest': '0x27ae60',
    'Sunrise': '0xf1c40f',
    'Amber': '0xf39c12',
    'Tangerine': '0xe67e22',
    'Crimson': '0xc0392b',
    'Coral': '0xe74c3c',
    'Plum': '0x9b59b6',
    'Lavender': '0xbb8fce',
    'Royal': '0x5b2c6f',
    'Slate': '0x34495e',
    'Graphite': '0x4b4b4b',
    'Pearl': '0xfafafa',
    'Coffee': '0x6f4e37',
}

# Choices offered by the protein style toggle; _apply_protein_style() compares
# the selected option against these exact strings.
PROTEIN_STYLES = (
    'Cartoon',
    'Cartoon + Surface',
    'Surface only',
)

# Display name -> 0x-prefixed hex colour for the ligand colour dropdown; looked
# up by _apply_ligand_style() to colour the stick representation.
LIGAND_COLORS: Dict[str, str] = {
    'Tangerine Orange': '0xffa500',
    'Fire Red': '0xe74c3c',
    'Ruby': '0xc0392b',
    'Cherry': '0xd63031',
    'Goldenrod': '0xf1c40f',
    'Sunflower': '0xf39c12',
    'Olive': '0x808000',
    'Forest Green': '0x2ecc71',
    'Teal': '0x1abc9c',
    'Turquoise': '0x00cec9',
    'Cyan': '0x00a8ff',
    'Azure': '0x0984e3',
    'Deep Blue': '0x30336b',
    'Royal Purple': '0x8e44ad',
    'Magenta': '0xff00ff',
    'Plum': '0x9b59b6',
    'Rose': '0xff6b81',
    'Graphite': '0x2d3436',
    'Silver': '0xbdc3c7',
    'Pearl': '0xf8f9fa',
}

# Glyph colour per interaction kind code emitted by parse_plip_xml():
#   hb hydrogen bond, hp hydrophobic interaction, pc pi-cation, ps pi-stack,
#   sb salt bridge, wb water bridge, ha halogen bond, me metal complex.
# add_plip() falls back to '0x7f8c8d' for a kind not listed here.
INTERACTION_COLORS = {
    'hb': '0xff6b6b',
    'hp': '0xffbe76',
    'pc': '0x686de0',
    'ps': '0xf0932b',
    'sb': '0x45aaf2',
    'wb': '0x7ed6df',
    'ha': '0xbadc58',
    'me': '0x574b90',
}

# Baseline rendering options. visualize_pdb() overlays caller-supplied options
# on top of these, and ComplexNavigator uses them as initial widget values.
# 'protein_color' / 'ligand_color' are keys of PROTEIN_COLORS / LIGAND_COLORS.
# NOTE(review): 'zoom_factor' is carried along but is not read by any code
# visible in this module.
DEFAULT_OPTIONS: Dict[str, Any] = {
    'surface_opacity': 0.55,
    'protein_color': 'Ice Blue',
    'protein_style': 'Cartoon + Surface',
    'ligand_color': 'Tangerine Orange',
    'ligand_radius': 0.3,
    'show_interactions': True,
    'show_labels': False,
    'background': 'White',
    'projection': 'perspective',
    'outline': True,
    'spin': False,
    'spin_speed': 1.0,
    'zoom_factor': 0.6,
    'snapshot_scale': 4,
}

def parse_plip_xml(xml_path: Path) -> list[tuple[str, str, str]]:
    """Extract interaction atom pairs from a PLIP XML report.

    Args:
        xml_path: Path to the PLIP ``report.xml`` file.

    Returns:
        ``(kind, idx_a, idx_b)`` tuples, where ``kind`` is one of the two-letter
        codes ``hb``, ``hp``, ``pc``, ``ps``, ``wb``, ``ha``, ``me``, ``sb``
        and both indices are whitespace-stripped, non-empty strings taken
        from the report. Interactions missing either index are omitted.
        Salt bridges yield one tuple per ligand-index/protein-index pair.

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
        OSError: if the file cannot be read.
    """
    root = ET.parse(xml_path).getroot()

    def first_text(node, *tags: str) -> str:
        """Return the stripped text of the first of *tags* that has any."""
        for tag in tags:
            value = (node.findtext(tag) or '').strip()
            if value:
                return value
        return ''

    def idx_texts(node, path: str) -> list[str]:
        """Return the stripped, non-empty texts of all elements at *path*."""
        texts = ((el.text or '').strip() for el in node.findall(path))
        return [text for text in texts if text]

    # (kind, element xpath, tags for first index, tags for second index).
    # The first tag in each tuple is the name this module has always looked
    # for; later names are fallbacks believed to match PLIP's own XML output
    # (NOTE(review): verify against the PLIP version that wrote the report).
    pair_specs = (
        ('hb', './/hydrogen_bonds/hydrogen_bond',
         ('donoridx',), ('acceptoridx',)),
        ('hp', './/hydrophobic_interactions/hydrophobic_interaction',
         ('ligcarbonidx',), ('protcarbonidx',)),
        ('pc', './/pi_cation_interactions/pi_cation_interaction',
         ('lig_idx_list/idx',), ('prot_idx_list/idx',)),
        ('ps', './/pi_stacks/pi_stack',
         ('ligcentroididx',), ('protcentroididx',)),
        ('wb', './/water_bridges/water_bridge',
         ('a_idx', 'acceptor_idx'), ('d_idx', 'donor_idx')),
        ('ha', './/halogen_bonds/halogen_bond',
         ('halogenidx', 'don_idx'), ('acceptoridx', 'acc_idx')),
        ('me', './/metal_complexes/metal_complex',
         ('metalidx', 'metal_idx'), ('ligidx', 'target_idx')),
    )

    found: list[tuple[str, str, str]] = []
    for kind, xpath, first_tags, second_tags in pair_specs:
        for node in root.findall(xpath):
            idx_a = first_text(node, *first_tags)
            idx_b = first_text(node, *second_tags)
            if idx_a and idx_b:
                found.append((kind, idx_a, idx_b))

    # Salt bridges list several atoms per side; emit the full cross product.
    for node in root.findall('.//salt_bridges/salt_bridge'):
        ligand_ids = idx_texts(node, './lig_idx_list/idx')
        protein_ids = idx_texts(node, './prot_idx_list/idx')
        found.extend(('sb', lig, prot) for lig in ligand_ids for prot in protein_ids)

    return found

def _vec_from_pos(pos) -> Dict[str, float]:
    return {'x': pos.x, 'y': pos.y, 'z': pos.z}

def add_plip(view: py3Dmol.view, st: gemmi.Structure, interactions: list[tuple[str, str, str]]) -> None:
    """Draw one glyph per interaction on *view*.

    Each interaction is rendered as a solid cylinder between the two atoms
    with a small sphere at either end, coloured by interaction kind.

    Args:
        view: The py3Dmol view to add shapes to.
        st: Structure whose atom serial numbers the interaction indices
            refer to.
        interactions: ``(kind, serial_a, serial_b)`` tuples with the serials
            given as strings. Pairs naming a serial that is absent from
            *st* are skipped.
    """
    # Atom serial (as a string, matching the interaction tuples) -> position.
    serial_to_pos = {
        str(atom.serial): atom.pos
        for model in st
        for chain in model
        for residue in chain
        for atom in residue
    }
    for kind, serial_a, serial_b in interactions:
        pos_a = serial_to_pos.get(serial_a)
        pos_b = serial_to_pos.get(serial_b)
        # Only a missing serial should skip the pair; test identity rather
        # than relying on the truthiness of the position object.
        if pos_a is None or pos_b is None:
            continue
        color = _color_value(INTERACTION_COLORS.get(kind, '0x7f8c8d'))
        start = _vec_from_pos(pos_a)
        end = _vec_from_pos(pos_b)
        view.addCylinder({
            'start': start,
            'end': end,
            'radius': 0.14,
            'color': color,
            'dashed': False,
            'opacity': 0.8,
            'fromCap': 1,
            'toCap': 1,
        })
        for center in (start, end):
            view.addSphere({'center': center, 'radius': 0.2, 'color': color, 'opacity': 0.7})

def _apply_environment(view: py3Dmol.view, opts: Dict[str, Any]) -> None:
    view.setBackgroundColor(opts['background'].lower())
    if opts['outline'] and hasattr(view, 'setViewStyle'):
        try:
            view.setViewStyle({'style': 'outline', 'color': 'black', 'width': 0.03})
        except Exception:
            pass
    elif hasattr(view, 'setViewStyle'):
        try:
            view.setViewStyle({})
        except Exception:
            pass
    try:
        view.setProjection('orthographic' if opts['projection'] == 'orthographic' else 'perspective')
    except Exception:
        pass

def _apply_protein_style(view: py3Dmol.view, opts: Dict[str, Any]) -> None:
    """Restyle the non-HETATM atoms of *view* according to *opts*.

    Reads ``'protein_color'`` (a key of ``PROTEIN_COLORS``),
    ``'protein_style'`` (defaults to ``'Cartoon + Surface'``) and, when a
    surface is drawn, ``'surface_opacity'``. Any existing style on the
    selection and all existing surfaces are cleared first.
    """
    protein_sel = {'hetflag': False}
    rgb = _color_value(PROTEIN_COLORS[opts['protein_color']])
    mode = opts.get('protein_style', 'Cartoon + Surface')
    wants_cartoon = mode in ('Cartoon', 'Cartoon + Surface')
    wants_surface = mode in ('Surface only', 'Cartoon + Surface')

    # Start from a clean slate so a previous style never lingers.
    view.setStyle(protein_sel, {})
    view.removeAllSurfaces()

    if wants_cartoon:
        view.setStyle(protein_sel, {'cartoon': {'color': rgb}})
    if wants_surface:
        surface_spec = {'opacity': opts['surface_opacity'], 'color': rgb}
        view.addSurface(py3Dmol.VDW, surface_spec, protein_sel)

def _apply_ligand_style(view: py3Dmol.view, opts: Dict[str, Any]) -> None:
    """Draw every HETATM atom of *view* as sticks.

    Reads ``'ligand_color'`` (a key of ``LIGAND_COLORS``) and
    ``'ligand_radius'`` (stick radius) from *opts*.
    """
    stick_color = _color_value(LIGAND_COLORS[opts['ligand_color']])
    stick_spec = dict(radius=opts['ligand_radius'], color=stick_color)
    view.setStyle(dict(hetflag=True), dict(stick=stick_spec))

def _add_ligand_labels(view: py3Dmol.view, pdb_text: str, max_labels: int = 25) -> None:
    count = 0
    for line in pdb_text.splitlines():
        if not line.startswith('HETATM'):
            continue
        element = line[76:78].strip().upper()
        if element == 'H':
            continue
        serial = line[6:11].strip()
        if not serial:
            continue
        try:
            serial_int = int(serial)
        except ValueError:
            continue
        atom = line[12:16].strip()
        resn = line[17:20].strip()
        chain = line[21].strip()
        resid = line[22:26].strip()
        label = f"{atom} ({resn}{resid}{chain})".strip()
        style = {
            'fontColor': 'black',
            'backgroundColor': 'white',
            'backgroundOpacity': 0.7,
            'fontSize': 14,
            'padding': 2,
            'outline': 'black',
        }
        try:
            view.addLabel(label, style, {'serial': serial_int})
        except Exception:
            pass
        count += 1
        if count >= max_labels:
            break

def _load_pdb_text(pdb_path: Path) -> str:
    """Return the contents of *pdb_path*, reading the file only the first
    time a given resolved path is requested (memoised in ``PDB_TEXT_CACHE``)."""
    cache_key = str(pdb_path.resolve())
    cached = PDB_TEXT_CACHE.get(cache_key)
    if cached is not None:
        return cached
    text = pdb_path.read_text()
    PDB_TEXT_CACHE[cache_key] = text
    return text

def _load_structure(pdb_path: Path) -> gemmi.Structure:
    """Return the gemmi structure parsed from *pdb_path*, parsing it only the
    first time a given resolved path is requested (memoised in ``STRUCT_CACHE``)."""
    cache_key = str(pdb_path.resolve())
    cached = STRUCT_CACHE.get(cache_key)
    if cached is not None:
        return cached
    structure = gemmi.read_structure(str(pdb_path))
    STRUCT_CACHE[cache_key] = structure
    return structure

def _load_interactions(plip_xml: Path) -> list[tuple[str, str, str]]:
    """Return the interactions parsed from *plip_xml*, parsing the report only
    the first time a given resolved path is requested (memoised in
    ``INTERACTION_CACHE``)."""
    cache_key = str(plip_xml.resolve())
    cached = INTERACTION_CACHE.get(cache_key)
    if cached is not None:
        return cached
    parsed = parse_plip_xml(plip_xml)
    INTERACTION_CACHE[cache_key] = parsed
    return parsed

def visualize_pdb(pdb_path: Path, options: Dict[str, Any] | None = None, size: tuple[int, int] | None = None) -> py3Dmol.view:
    """Build a styled py3Dmol view of one protein-ligand complex.

    Args:
        pdb_path: PDB file of the complex (a ``str`` path is accepted too).
            PLIP interactions are read from ``plip/report.xml`` next to it
            when that file exists.
        options: Overrides applied on top of ``DEFAULT_OPTIONS``.
        size: ``(width, height)`` in pixels; defaults to
            ``(VIEW_WIDTH, VIEW_HEIGHT)``.

    Returns:
        The configured view; the caller is responsible for displaying it.

    Raises:
        FileNotFoundError: if *pdb_path* does not exist.
        ValueError: if the PDB file is empty or whitespace-only.
    """
    import warnings  # local import: keeps this block self-contained

    pdb_path = Path(pdb_path)
    opts = {**DEFAULT_OPTIONS, **(options or {})}
    if not pdb_path.exists():
        raise FileNotFoundError(pdb_path)

    pdb_text = _load_pdb_text(pdb_path)
    if not pdb_text.strip():
        raise ValueError(f'{pdb_path} is empty')

    width, height = size if size else (VIEW_WIDTH, VIEW_HEIGHT)
    view = py3Dmol.view(width=width, height=height)
    view.addModel(pdb_text, 'pdb')

    _apply_environment(view, opts)
    _apply_protein_style(view, opts)
    _apply_ligand_style(view, opts)

    view.removeAllShapes()
    plip_xml = pdb_path.parent / 'plip' / 'report.xml'
    if opts['show_interactions'] and plip_xml.exists():
        try:
            structure = _load_structure(pdb_path)
            interactions = _load_interactions(plip_xml)
            add_plip(view, structure, interactions)
        except Exception as exc:
            # Interactions are an overlay: keep rendering the complex, but
            # tell the user why the glyphs are missing instead of hiding it.
            warnings.warn(
                f'Skipping PLIP interactions for {pdb_path.name}: {exc}',
                RuntimeWarning,
                stacklevel=2,
            )

    if opts['show_labels']:
        _add_ligand_labels(view, pdb_text)

    # Frame the camera on the selection; fall back to the whole model if the
    # selection-based zoom is rejected.
    center_sel = {'or': [{'hetflag': True}, {'cartoon': True}]}
    try:
        view.zoomTo(center_sel)
    except Exception:
        view.zoomTo()
    try:
        view.center(center_sel)
    except Exception:
        pass

    if opts['spin']:
        try:
            view.spin('y', opts['spin_speed'])
        except Exception:
            pass

    return view

def _color_css(color: str) -> str:
    color = color.strip()
    if color.startswith('0x'):
        return '#' + color[2:]
    if color.startswith('#'):
        return color
    return color

def _format_label(text: str) -> widgets.HTML:
    """Return an HTML widget showing *text* inside a ``slider-label`` span.

    *text* is inserted into the markup verbatim (no HTML escaping).
    """
    markup = f'<span class="slider-label">{text}</span>'
    return widgets.HTML(markup)

class ComplexNavigator:
    """Interactive ipywidgets browser for a list of complex PDB files.

    Constructing an instance immediately displays the viewer output area and
    the control panel, then renders the first complex. The Previous/Next
    buttons cycle through ``complexes`` (wrapping around), and every styling
    control re-renders the current complex through ``visualize_pdb`` when
    its value changes. The "Save PNG" button invokes the
    ``saveCanvasSnapshot`` JavaScript helper with a file name derived from
    the current PDB file.

    Args:
        complexes: Non-empty list of PDB paths to browse.

    Raises:
        ValueError: if ``complexes`` is empty.
    """

    def __init__(self, complexes: list[Path]):
        # Fail early: an empty list would otherwise surface later as an
        # IndexError in _show() or a ZeroDivisionError in _step().
        if not complexes:
            raise ValueError('ComplexNavigator requires at least one complex PDB path')

        self.complexes = complexes
        self.idx = 0
        self.out = widgets.Output()
        self.curr_view: py3Dmol.view | None = None

        # --- navigation and action widgets --------------------------------
        self.prev_btn = widgets.Button(description='⬅️ Previous')
        self.next_btn = widgets.Button(description='Next ➡️')
        for button in (self.prev_btn, self.next_btn):
            button.add_class('custom-btn')
            button.layout = widgets.Layout(width='130px', height='42px')
        # Holds the raw <button> markup for the PNG export; filled by _show().
        self.save_btn = widgets.HTML()
        self.status = widgets.HTML(value='')
        self.spin_toggle = self._styled_toggle('Spin', DEFAULT_OPTIONS['spin'], 'spin-btn')

        # --- protein controls ---------------------------------------------
        self.protein_color = widgets.Dropdown(
            options=list(PROTEIN_COLORS.keys()),
            value=DEFAULT_OPTIONS['protein_color'],
        )
        self.protein_style = widgets.ToggleButtons(
            options=PROTEIN_STYLES,
            value=DEFAULT_OPTIONS['protein_style'],
            layout=widgets.Layout(button_width='120px'),
        )
        self.surface_opacity = widgets.FloatSlider(
            value=DEFAULT_OPTIONS['surface_opacity'],
            min=0.0, max=1.0, step=0.05, readout_format='.2f',
        )

        # --- ligand controls ----------------------------------------------
        self.ligand_color = widgets.Dropdown(
            options=list(LIGAND_COLORS.keys()),
            value=DEFAULT_OPTIONS['ligand_color'],
        )
        self.ligand_radius = widgets.FloatSlider(
            value=DEFAULT_OPTIONS['ligand_radius'],
            min=0.1, max=0.6, step=0.02, readout_format='.2f',
        )
        self.interactions_toggle = self._styled_toggle('Interactions', DEFAULT_OPTIONS['show_interactions'])
        self.labels_toggle = self._styled_toggle('Labels', DEFAULT_OPTIONS['show_labels'])

        # --- scene controls -----------------------------------------------
        self.background_select = widgets.Dropdown(
            options=['White', 'Black', 'Slate', 'Soft Gray', 'Transparent'],
            value=DEFAULT_OPTIONS['background'],
        )
        self.projection_toggle = widgets.ToggleButtons(
            options=[('Perspective', 'perspective'), ('Ortho', 'orthographic')],
            value=DEFAULT_OPTIONS['projection'],
            layout=widgets.Layout(button_width='120px'),
        )
        self.outline_toggle = self._styled_toggle('Outline', DEFAULT_OPTIONS['outline'])
        self.snapshot_scale = widgets.SelectionSlider(
            options=[2, 3, 4, 5, 6, 8],
            value=DEFAULT_OPTIONS['snapshot_scale'],
            readout=True,
        )
        self.snapshot_help = widgets.HTML(
            '<span style="font-size:12px;color:#555;">Higher multiplier = higher PNG resolution</span>'
        )

        # Each control is captioned by a separate label widget, so its own
        # description stays blank.
        for control in (
            self.protein_color,
            self.protein_style,
            self.surface_opacity,
            self.ligand_color,
            self.ligand_radius,
            self.background_select,
            self.projection_toggle,
            self.snapshot_scale,
        ):
            control.description = ''

        # --- layout ---------------------------------------------------------
        nav_row = widgets.HBox(
            [self.prev_btn, self.next_btn],
            layout=widgets.Layout(justify_content='center', gap='0.6em', padding='0.2em'),
        )
        save_row = widgets.HBox(
            [self.save_btn],
            layout=widgets.Layout(justify_content='center', align_items='center', padding='0.2em'),
        )
        spin_row = widgets.HBox(
            [self.spin_toggle],
            layout=widgets.Layout(justify_content='center', padding='0.2em'),
        )

        protein_box = self._section_box('Protein', [
            _format_label('Color'),
            self.protein_color,
            _format_label('Style'),
            self.protein_style,
            _format_label('Surface opacity'),
            self.surface_opacity,
        ])
        ligand_box = self._section_box('Ligand', [
            _format_label('Color'),
            self.ligand_color,
            _format_label('Stick radius'),
            self.ligand_radius,
            self.interactions_toggle,
            self.labels_toggle,
        ])
        background_box = self._section_box('Background', [
            _format_label('Color'),
            self.background_select,
            _format_label('Projection'),
            self.projection_toggle,
            _format_label('Snapshot resolution ×'),
            self.snapshot_scale,
            self.snapshot_help,
            self.outline_toggle,
        ])

        config_row = widgets.HBox(
            [protein_box, ligand_box, background_box],
            layout=widgets.Layout(justify_content='center', flex_wrap='wrap', gap='0.5em'),
        )
        self.controls = widgets.VBox([nav_row, save_row, spin_row, config_row, self.status])

        # --- wiring ---------------------------------------------------------
        # Any change to a styling control triggers a full re-render.
        for control in (
            self.spin_toggle,
            self.interactions_toggle,
            self.labels_toggle,
            self.outline_toggle,
            self.protein_color,
            self.protein_style,
            self.surface_opacity,
            self.ligand_color,
            self.ligand_radius,
            self.background_select,
            self.projection_toggle,
            self.snapshot_scale,
        ):
            control.observe(self._refresh, names='value')

        self.prev_btn.on_click(lambda _: self._step(-1))
        self.next_btn.on_click(lambda _: self._step(1))

        display(self.out)
        display(self.controls)
        self._show()

    @staticmethod
    def _styled_toggle(description: str, value: bool, *extra_classes: str):
        """Create a 160x42px ToggleButton carrying the ``custom-btn`` CSS
        class plus any *extra_classes*."""
        toggle = widgets.ToggleButton(value=value, description=description)
        toggle.add_class('custom-btn')
        for css_class in extra_classes:
            toggle.add_class(css_class)
        toggle.layout = widgets.Layout(width='160px', height='42px')
        return toggle

    @staticmethod
    def _section_box(title: str, children: list):
        """Wrap *children* under an ``<h4>`` *title* in a ``section-box`` VBox."""
        box = widgets.VBox(
            [widgets.HTML(f'<h4>{title}</h4>'), *children],
            layout=widgets.Layout(align_items='stretch', gap='0.35em'),
        )
        box.add_class('section-box')
        return box

    @staticmethod
    def _save_button_html(fname: str, scale) -> str:
        """Return the markup of the "Save PNG" button.

        *fname* is encoded as a JavaScript string literal and then
        HTML-attribute-escaped, so quotes or other special characters in a
        file name cannot break out of the inline ``onclick`` handler.
        """
        import html
        import json

        js_name = html.escape(json.dumps(fname), quote=True)
        return f"""
            <button onclick="saveCanvasSnapshot({js_name}, {scale})"
                    class='custom-btn' style='margin:6px auto;width:150px;height:44px;'>💾 Save PNG</button>
            """

    def _current_options(self) -> Dict[str, Any]:
        """Collect the current widget values into a ``visualize_pdb`` options
        dict. ``spin_speed`` and ``zoom_factor`` have no widget and always
        take their ``DEFAULT_OPTIONS`` values."""
        return {
            'surface_opacity': self.surface_opacity.value,
            'protein_color': self.protein_color.value,
            'protein_style': self.protein_style.value,
            'ligand_color': self.ligand_color.value,
            'ligand_radius': self.ligand_radius.value,
            'show_interactions': self.interactions_toggle.value,
            'show_labels': self.labels_toggle.value,
            'background': self.background_select.value,
            'projection': self.projection_toggle.value,
            'outline': self.outline_toggle.value,
            'spin': self.spin_toggle.value,
            'spin_speed': DEFAULT_OPTIONS['spin_speed'],
            'zoom_factor': DEFAULT_OPTIONS['zoom_factor'],
            'snapshot_scale': self.snapshot_scale.value,
        }

    def _step(self, delta: int):
        """Move *delta* positions through the complexes (wrapping) and re-render."""
        self.idx = (self.idx + delta) % len(self.complexes)
        self._show()

    def _refresh(self, change):
        """Widget observer callback: re-render with the current settings."""
        self._show()

    def _show(self):
        """Render the current complex into the output area and refresh the
        "Save PNG" button for it."""
        with self.out:
            self.out.clear_output(wait=True)
            pdb_fp = self.complexes[self.idx]
            print(f'[{self.idx + 1}/{len(self.complexes)}] {pdb_fp.name}')
            opts = self._current_options()
            self.curr_view = visualize_pdb(pdb_fp, opts)
            display(self.curr_view)
            fname = pdb_fp.stem.replace('complex__', '') + '_snapshot.png'
            self.save_btn.value = self._save_button_html(fname, opts['snapshot_scale'])
            self.status.value = ''

def _main():
    """Discover the prepared complexes and open a navigator over them."""
    found = list_complexes()
    ComplexNavigator(found)


# Run when executed as a script or as a notebook cell.
if __name__ == '__main__':
    _main()


Output()

VBox(children=(HBox(children=(Button(description='⬅️ Previous', layout=Layout(height='42px', width='130px'), s…