In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import re
from pathlib import Path
from typing import Any, Dict, Tuple, Union

from IPython.display import clear_output, display
from ipywidgets import Layout, widgets
from openff.toolkit import Molecule
from QligFEP.chemIO import MoleculeIO
from QligFEP.CLI.utils import get_avail_restraint_methods
from QligFEP.logger import logger, setup_logger
from QligFEP.restraints.restraint_setter import RestraintSetter
from QligFEP.visualization import render_ligand_restraints, render_system
from rdkit import Chem


def apply_and_render_restraint_v2(
    ligand1: Union[Molecule, Chem.Mol, str, Path],
    ligand2: Union[Molecule, Chem.Mol, str, Path],
    restraint_method: str = "hybridization_p",
    show_atom_idxs: bool = True,
    size: tuple[int, int] = (900, 500),
    sphere_palette: str = "hsv",
    sphere_radius: float = 0.6,
    sphere_alpha: float = 0.8,
) -> Tuple[dict[int, int], Any]:
    """Compute the restraints between two ligands and build their 3D view.

    Variant of ``apply_and_render_restraint`` that returns the restraint dictionary
    together with the view object. ``render_ligand_restraints`` is called with
    ``render=False``, so the caller decides where to display the view.

    Args:
        ligand1: First ligand. Accepts anything handled by
            ``RestraintSetter.input_to_small_molecule_component``.
        ligand2: Second ligand. Accepts the same input types as ``ligand1``.
        restraint_method: Either ``"kartograf"`` or ``"<atom_compare_method>_<level>"``,
            where ``<level>`` is ``p``, ``ls`` or ``strict``. An optional ``_<number>``
            chunk (e.g. ``"heavyatom_p_1.2"``) sets the kartograf max atom distance,
            which otherwise defaults to 0.95.
        show_atom_idxs, size, sphere_palette, sphere_radius, sphere_alpha:
            Forwarded unchanged to ``render_ligand_restraints``.

    Returns:
        Tuple[dict[int, int], Any]: A tuple containing:
            - The restraint dictionary returned by ``RestraintSetter.set_restraints``.
            - The view object returned by ``render_ligand_restraints``.

    Raises:
        ValueError: If the method name (after stripping the distance chunk) is not in
            ``get_avail_restraint_methods()``, is ``"overlap"``, or ends in an unknown
            permissiveness level.
    """
    ligand1 = RestraintSetter.input_to_small_molecule_component(ligand1)
    ligand2 = RestraintSetter.input_to_small_molecule_component(ligand2)

    # An optional "_<number>" chunk (e.g. "heavyatom_p_1.2") overrides the kartograf
    # max atom distance and is stripped from the method name.
    pattern = r"_(\d+\.?\d*)"
    match = re.search(pattern, restraint_method)
    if match:
        atom_max_distance = float(match.group(1))
        restraint_method = re.sub(pattern, "", restraint_method)
    else:
        atom_max_distance = 0.95

    avail_methods = get_avail_restraint_methods()
    if restraint_method not in avail_methods:
        raise ValueError(
            f"restraint_method should be one of {avail_methods}, got {restraint_method}"
        )
    if restraint_method == "overlap":
        raise ValueError(
            "Overlap method is not supported by this method yet, use `kartograf` instead."
        )

    rsetter = RestraintSetter(
        ligand1, ligand2, kartograf_max_atom_distance=atom_max_distance
    )

    if restraint_method == "kartograf":
        restraint_dict = rsetter.set_restraints(kartograf_native=True)
    else:
        # Split on the last "_" only, so the permissiveness level is isolated even if
        # the atom comparison name itself contains an underscore.
        atom_compare_method, permissiveness_lvl = restraint_method.rsplit("_", 1)
        # Maps each permissiveness level to its keyword arguments for set_restraints.
        level_params = {
            "p": {"strict_surround": False},
            "ls": {"strict_surround": True, "ignore_surround_atom_type": True},
            "strict": {"strict_surround": True, "ignore_surround_atom_type": False},
        }
        if permissiveness_lvl not in level_params:
            # Without this guard an unknown level would leave `params` unbound and
            # surface as an UnboundLocalError instead of a clear message.
            raise ValueError(
                f"Unknown permissiveness level {permissiveness_lvl!r} in restraint_method "
                f"{restraint_method!r}; expected one of {sorted(level_params)}."
            )
        params = level_params[permissiveness_lvl]
        restraint_dict = rsetter.set_restraints(
            atom_compare_method=atom_compare_method, **params
        )
        logger.debug(
            f"Restraints set using {restraint_method} method. Parameters: {params}"
        )

    # render=False: presumably builds the view without displaying it, leaving
    # display to the caller (e.g. RestraintsViewer).
    view = render_ligand_restraints(
        ligand1,
        ligand2,
        restraint_dict,
        show_atom_idxs=show_atom_idxs,
        size=size,
        sphere_palette=sphere_palette,
        sphere_radius=sphere_radius,
        sphere_alpha=sphere_alpha,
        render=False,
    )

    return restraint_dict, view


class RestraintsViewer:
    """Paged ipywidgets viewer that shows one restraint view at a time.

    Creating an instance immediately displays a control row
    (Previous | "i/n" counter | Next | description) above an ``Output`` area.
    Views registered with ``add_view`` are rendered into that area one at a time.
    """

    def __init__(self):
        # Parallel lists: position i in each list belongs to the same ligand pair.
        self.current_idx = 0
        self.views = []
        self.restraint_dicts = []
        self.descriptions = []

        narrow = "100px"
        self.prev_button = widgets.Button(
            description="Previous", disabled=True, layout=Layout(width=narrow)
        )
        self.next_button = widgets.Button(
            description="Next", disabled=True, layout=Layout(width=narrow)
        )
        self.index_label = widgets.Label(value="0/0", layout=Layout(width=narrow))
        self.description_label = widgets.Label(value="", layout=Layout(width="400px"))

        # Wire the navigation buttons to their handlers.
        self.prev_button.on_click(self._on_prev_clicked)
        self.next_button.on_click(self._on_next_clicked)

        controls = [
            self.prev_button,
            self.index_label,
            self.next_button,
            self.description_label,
        ]
        self.button_box = widgets.HBox(controls)

        self.output = widgets.Output()
        display(self.button_box, self.output)

    def add_view(self, restraint_dict: Dict, view: Any, description: str = ""):
        """Register a view together with its restraints and caption.

        The first view ever added is rendered right away. Later views are
        reached through the navigation buttons.
        """
        for bucket, item in (
            (self.restraint_dicts, restraint_dict),
            (self.views, view),
            (self.descriptions, description),
        ):
            bucket.append(item)

        self._update_buttons()

        if len(self.views) == 1:
            self._display_current()

    def _update_buttons(self):
        """Sync button enabled-state, counter and caption with ``current_idx``."""
        idx = self.current_idx
        total = len(self.views)
        self.prev_button.disabled = not idx > 0
        self.next_button.disabled = not idx < total - 1
        self.index_label.value = f"{idx + 1}/{total}"
        self.description_label.value = self.descriptions[idx]

    def _display_current(self):
        """Replace the output area's content with the view at ``current_idx``."""
        with self.output:
            clear_output(wait=True)
            display(self.views[self.current_idx])
        self._update_buttons()

    def _on_prev_clicked(self, b):
        """Step back one view; does nothing on the first view."""
        if self.current_idx <= 0:
            return
        self.current_idx -= 1
        self._display_current()

    def _on_next_clicked(self, b):
        """Step forward one view; does nothing on the last view."""
        if self.current_idx >= len(self.views) - 1:
            return
        self.current_idx += 1
        self._display_current()


def process_multiple_restraints(
    ligand_pairs, method="hybridization_p", size=(900, 500)
):
    """Compute restraints for several ligand pairs with one method and page through them.

    A ``RestraintsViewer`` is created first (which displays its controls). Each pair
    is then passed through ``apply_and_render_restraint_v2`` and appended to the
    viewer as one page.

    Args:
        ligand_pairs: Iterable of ``(ligand1, ligand2)`` tuples, in any input form
            accepted by ``apply_and_render_restraint_v2``.
        method: Restraint method applied to every pair, e.g. ``"heavyatom_p"`` or
            ``"heavyatom_p_1.2"`` (the numeric suffix sets the max atom distance).
        size: Forwarded as ``size`` to ``apply_and_render_restraint_v2`` for every view.

    Returns:
        RestraintsViewer: The viewer holding every view and restraint dictionary.
    """
    viewer = RestraintsViewer()

    for idx, (lig1, lig2) in enumerate(ligand_pairs):
        restraint_dict, view = apply_and_render_restraint_v2(
            lig1, lig2, restraint_method=method, size=size
        )
        # 1-based pair number so the caption matches the viewer's "i/n" counter.
        description = f"Pair {idx + 1} - Method: {method}"
        viewer.add_view(restraint_dict, view, description)

    return viewer

# JACS dataset

## bace

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/bace/ligands.sdf")
network_dict = json.loads(Path("../perturbations/bace/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
bace_viewer = process_multiple_restraints(
    ligand_pairs, method="hybridization_p", size=(1200, 500)
)

## cdk2

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/cdk2/ligands.sdf")
network_dict = json.loads(Path("../perturbations/cdk2/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
cdk2_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p", size=(1200, 500)
)

## jnk1

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/jnk1/ligands.sdf")
network_dict = json.loads(Path("../perturbations/jnk1/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
jnk1_viewer = process_multiple_restraints(
    ligand_pairs, method="hybridization_p", size=(1200, 500)
)

## mcl1

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/mcl1/ligands.sdf")
network_dict = json.loads(Path("../perturbations/mcl1/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
mcl1_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p", size=(1200, 500)
)

## p38

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/p38/ligands.sdf")
network_dict = json.loads(Path("../perturbations/p38/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
p38_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p", size=(1200, 500)
)

## ptp1b

In [None]:
# Set the log level explicitly so this cell does not depend on an earlier cell's state.
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/ptp1b/ligands.sdf")
network_dict = json.loads(Path("../perturbations/ptp1b/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
ptp1b_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p", size=(1200, 500)
)

## thrombin

In [None]:
# Set the log level explicitly so this cell does not depend on an earlier cell's state.
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/thrombin/ligands.sdf")
network_dict = json.loads(Path("../perturbations/thrombin/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
thrombin_viewer = process_multiple_restraints(
    ligand_pairs, method="hybridization_p", size=(1200, 500)
)

## tyk2

In [None]:
# Set the log level explicitly so this cell does not depend on an earlier cell's state.
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/tyk2/ligands.sdf")
network_dict = json.loads(Path("../perturbations/tyk2/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
tyk2_viewer = process_multiple_restraints(
    ligand_pairs, method="hybridization_p", size=(1200, 500)
)

# Merck dataset

## cdk8

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/cdk8/ligands.sdf")
network_dict = json.loads(Path("../perturbations/cdk8/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
cdk8_viewer = process_multiple_restraints(
    ligand_pairs, method="hybridization_p", size=(1200, 500)
)

## cmet

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/cmet/ligands.sdf")
network_dict = json.loads(Path("../perturbations/cmet/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
cmet_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_ls", size=(1200, 500)
)

## eg5

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/eg5/ligands.sdf")
network_dict = json.loads(Path("../perturbations/eg5/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
eg5_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p", size=(1200, 500)
)

## hif2a

In [None]:
# Use the same upper-case level name as every other cell.
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/hif2a/ligands.sdf")
network_dict = json.loads(Path("../perturbations/hif2a/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
hif2a_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p", size=(1200, 500)
)

## shp2

In [None]:
# Use the same upper-case level name as every other cell.
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/shp2/ligands.sdf")
network_dict = json.loads(Path("../perturbations/shp2/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The "_1.2" suffix raises the kartograf max atom distance from its 0.95 default.
# Assigning the viewer keeps a handle to it and stops Jupyter echoing its repr.
shp2_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p_1.2", size=(1200, 500)
)

## syk

In [None]:
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/syk/ligands.sdf")
network_dict = json.loads(Path("../perturbations/syk/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
syk_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p", size=(1200, 500)
)

## pfkfb3

In [1]:
# TODO: add the pfkfb3 restraint visualization (same pattern as the other Merck targets).

## tnks2

In [None]:
# Use the same upper-case level name as every other cell.
setup_logger("INFO")

chemio = MoleculeIO("../perturbations/tnks2/ligands.sdf")
network_dict = json.loads(Path("../perturbations/tnks2/mapping.json").read_text())
# Look ligands up by name once instead of scanning `lig_names` for every edge.
mols_by_name = dict(zip(chemio.lig_names, chemio.molecules))
ligand_pairs = [
    (mols_by_name[edge["from"]].to_rdkit(), mols_by_name[edge["to"]].to_rdkit())
    for edge in network_dict["edges"]
]

# The viewer displays itself on creation; assigning the result keeps a handle to it
# and stops Jupyter from also echoing the object's repr under the widget.
tnks2_viewer = process_multiple_restraints(
    ligand_pairs, method="heavyatom_p", size=(1200, 500)
)

# Visualize systems

## eg5

In [None]:
# Render the eg5 protein with all of its ligands. Path, MoleculeIO and render_system
# come from the imports cell at the top of the notebook.
# NOTE(review): paths are relative to "eg5/", not "../perturbations/eg5" as in the
# restraint cells above; confirm this is the intended directory.
fileroot = Path("eg5/")
ligands = fileroot / "ligands/ligands.sdf"
protein = fileroot / "protein/protein.pdb"

chemio = MoleculeIO(ligands)

render_system(molecules=chemio.molecules, protein_path=protein)

## bace

In [None]:
# Render the bace protein with all of its ligands. Path, MoleculeIO and render_system
# come from the imports cell at the top of the notebook.
fileroot = Path("bace/")
ligands = fileroot / "ligands/ligands.sdf"
protein = fileroot / "protein/protein.pdb"

chemio = MoleculeIO(ligands)

render_system(molecules=chemio.molecules, protein_path=protein)

## cdk2

In [None]:
# Render the cdk2 protein with all of its ligands. Path, MoleculeIO and render_system
# come from the imports cell at the top of the notebook.
fileroot = Path("cdk2/")
ligands = fileroot / "ligands/ligands.sdf"
protein = fileroot / "protein/protein.pdb"

chemio = MoleculeIO(ligands)

render_system(molecules=chemio.molecules, protein_path=protein)