In [1]:
# Re-import edited modules (e.g. QligFEP) before each cell runs, so library changes
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import html
import json
import re
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

from IPython.display import clear_output, display
from ipywidgets import Layout, widgets
from openff.toolkit import Molecule
from QligFEP.chemIO import MoleculeIO
from QligFEP.CLI.utils import get_avail_restraint_methods
from QligFEP.logger import logger, setup_logger
from QligFEP.restraints.restraint_setter import RestraintSetter
from QligFEP.visualization import render_ligand_restraints
from rdkit import Chem


def apply_and_render_restraint_v2(
    ligand1: Union[Molecule, Chem.Mol, str, Path],
    ligand2: Union[Molecule, Chem.Mol, str, Path],
    restraint_method: str = "hybridization_p",
    show_atom_idxs: bool = True,
    size: tuple[int, int] = (900, 500),
    sphere_palette: str = "hsv",
    sphere_radius: float = 0.6,
    sphere_alpha: float = 0.8,
) -> Tuple[dict[int, int], Any]:
    """Compute restraints between two ligands and build their (unrendered) view.

    Variant of ``apply_and_render_restraint`` that returns the view object
    (``render_ligand_restraints(..., render=False)``) instead of rendering it, so
    callers can embed the view in their own widgets.

    Args:
        ligand1: First ligand; any input accepted by
            ``RestraintSetter.input_to_small_molecule_component``.
        ligand2: Second ligand; same accepted inputs as ``ligand1``.
        restraint_method: ``"kartograf"`` or ``"<atom_compare_method>_<level>"``,
            where ``<level>`` is ``p``, ``ls`` or ``strict``. An optional numeric
            suffix (e.g. ``"heavyatom_p_1.2"``) is passed to ``RestraintSetter`` as
            ``kartograf_max_atom_distance``; without it 0.95 is used.
        show_atom_idxs: Forwarded to ``render_ligand_restraints``.
        size: Forwarded to ``render_ligand_restraints``.
        sphere_palette: Forwarded to ``render_ligand_restraints``.
        sphere_radius: Forwarded to ``render_ligand_restraints``.
        sphere_alpha: Forwarded to ``render_ligand_restraints``.

    Returns:
        Tuple[dict[int, int], Any]: Tuple containing:
            - Dictionary of atom indexes containing the atoms to be restrained
            - The view object returned by ``render_ligand_restraints``

    Raises:
        ValueError: If the method is not one of ``get_avail_restraint_methods()``,
            is ``"overlap"`` (unsupported here), or uses an unknown
            permissiveness level.
    """
    # Optional trailing "_<number>" overrides the max atom distance, e.g. "heavyatom_p_1.2".
    distance_pattern = r"_(\d+\.?\d*)"
    match = re.search(distance_pattern, restraint_method)
    if match:
        atom_max_distance = float(match.group(1))
        restraint_method = re.sub(distance_pattern, "", restraint_method)
    else:
        atom_max_distance = 0.95

    # Validate the method before doing any (potentially expensive) ligand conversion.
    avail_methods = get_avail_restraint_methods()
    if restraint_method not in avail_methods:
        raise ValueError(
            f"restraint_method should be one of {avail_methods}, got {restraint_method}"
        )
    if restraint_method == "overlap":
        raise ValueError(
            "Overlap method is not supported by this method yet, use `kartograf` instead."
        )

    # set_restraints keyword arguments for each permissiveness level.
    permissiveness_params = {
        "p": {"strict_surround": False},
        "ls": {"strict_surround": True, "ignore_surround_atom_type": True},
        "strict": {"strict_surround": True, "ignore_surround_atom_type": False},
    }
    if restraint_method != "kartograf":
        atom_compare_method, permissiveness_lvl = restraint_method.rsplit("_", 1)
        params = permissiveness_params.get(permissiveness_lvl)
        if params is None:
            # Previously this fell through with `params` unbound (UnboundLocalError).
            raise ValueError(
                f"Unknown permissiveness level {permissiveness_lvl!r} in restraint_method "
                f"{restraint_method!r}; expected one of {sorted(permissiveness_params)}"
            )

    ligand1 = RestraintSetter.input_to_small_molecule_component(ligand1)
    ligand2 = RestraintSetter.input_to_small_molecule_component(ligand2)

    rsetter = RestraintSetter(
        ligand1, ligand2, kartograf_max_atom_distance=atom_max_distance
    )

    if restraint_method == "kartograf":
        restraint_dict = rsetter.set_restraints(
            kartograf_native=True,
        )
    else:
        restraint_dict = rsetter.set_restraints(
            atom_compare_method=atom_compare_method, **params
        )
        logger.debug(
            f"Restraints set using {restraint_method} method. Parameters: {params}"
        )

    # render=False makes render_ligand_restraints return the view instead of showing it.
    view = render_ligand_restraints(
        ligand1,
        ligand2,
        restraint_dict,
        show_atom_idxs=show_atom_idxs,
        size=size,
        sphere_palette=sphere_palette,
        sphere_radius=sphere_radius,
        sphere_alpha=sphere_alpha,
        render=False,
    )

    return restraint_dict, view


class ViewWrapper:
    """Adapter that exposes only the HTML representation of a wrapped view.

    Rich frontends get the wrapped object's ``_repr_html_`` output, while the
    plain-text repr is an empty string so no textual repr of the view is shown.
    """

    def __init__(self, view):
        # Handle to the wrapped object; its HTML repr is delegated below.
        self.view = view

    def __repr__(self):
        # Deliberately empty: suppresses the plain-text representation.
        return ""

    def _repr_html_(self):
        wrapped = self.view
        return wrapped._repr_html_()


class RestraintsViewer:
    """Paginated ipywidgets viewer for a sequence of ligand-restraint views.

    Views are added with ``add_view``; Previous/Next buttons step through them,
    one label shows the position ("Pair i of n") and another shows the
    ``from -> to`` molecule names of the pair currently on screen.
    """

    def __init__(self):
        self.current_idx = 0  # index of the view currently on screen
        self.views = []  # ViewWrapper objects, one per added pair
        self.restraint_dicts = []  # restraint dicts, aligned with self.views
        self.descriptions = []  # stored per view; not rendered by the widget
        self.molecule_names = []  # (from_name, to_name) tuples, or None

        # Navigation buttons start disabled; _update_buttons enables them as views arrive.
        self.prev_button = widgets.Button(
            description="Previous",
            disabled=True,
            button_style="",
            tooltip="Go to previous pair",
            icon="chevron-left",
            layout=Layout(width="100px"),
        )

        self.next_button = widgets.Button(
            description="Next",
            disabled=True,
            button_style="",
            tooltip="Go to next pair",
            icon="chevron-right",
            layout=Layout(width="100px"),
        )

        # HTML (rather than Label) widgets so the text can be centered with inline CSS.
        self.index_label = widgets.HTML(value="0/0", layout=Layout(width="100px"))
        self.molecule_names_label = widgets.HTML(
            value="", layout=Layout(width="100%")  # full width so the text can center
        )

        self.prev_button.on_click(self._on_prev_clicked)
        self.next_button.on_click(self._on_next_clicked)

        # Output area that holds the currently displayed view.
        self.output = widgets.Output()

        # Row with [Previous] [Pair i of n] [Next], centered.
        self.nav_box = widgets.HBox(
            [self.prev_button, self.index_label, self.next_button],
            layout=Layout(
                display="flex",
                flex_flow="row",
                align_items="center",
                justify_content="center",
                width="100%",
            ),
        )

        # Top-level widget: view, molecule names, navigation row.
        self.widget = widgets.VBox(
            [self.output, self.molecule_names_label, self.nav_box],
            layout=Layout(
                display="flex", flex_flow="column", align_items="center", width="100%"
            ),
        )

    def add_view(
        self,
        restraint_dict: Dict,
        view: Any,
        description: str = "",
        molecule_names: Optional[Tuple[str, str]] = None,
    ):
        """Append a view and refresh the navigation state.

        Args:
            restraint_dict: Restraints associated with this view (stored only).
            view: Object exposing ``_repr_html_``; it is wrapped in ``ViewWrapper``.
            description: Free-text description (stored only).
            molecule_names: Optional ``(from_name, to_name)`` shown above the buttons.
        """
        self.restraint_dicts.append(restraint_dict)
        self.views.append(ViewWrapper(view))
        self.descriptions.append(description)
        self.molecule_names.append(molecule_names)

        self._update_buttons()

        # The first view is shown immediately; later ones are reached via the buttons.
        if len(self.views) == 1:
            self._display_current()

    def _update_buttons(self):
        """Sync button enabled-state, the position label and the molecule-names label."""
        n_views = len(self.views)
        self.prev_button.disabled = self.current_idx <= 0
        self.next_button.disabled = self.current_idx >= n_views - 1
        self.index_label.value = (
            "<div style='text-align: center'>"
            f"Pair {self.current_idx + 1} of {n_views}</div>"
        )

        names = self.molecule_names[self.current_idx] if self.molecule_names else None
        if names:
            from_mol, to_mol = names
            # Escape the names: they are injected into an HTML widget.
            self.molecule_names_label.value = (
                '<div style="text-align: center">'
                f"Comparing: {html.escape(str(from_mol))} \u2192 {html.escape(str(to_mol))}"
                "</div>"
            )
        else:
            # Clear the label so names from a previously shown pair don't linger.
            self.molecule_names_label.value = ""

    def _display_current(self):
        """Replace the output area's content with the current view."""
        with self.output:
            clear_output(wait=True)
            display(self.views[self.current_idx])

    def _step(self, delta: int):
        """Move ``delta`` views if the target index exists, then refresh the display."""
        target = self.current_idx + delta
        if 0 <= target < len(self.views):
            self.current_idx = target
            self._display_current()
            self._update_buttons()

    def _on_prev_clicked(self, b):
        """Handle previous button click."""
        self._step(-1)

    def _on_next_clicked(self, b):
        """Handle next button click."""
        self._step(1)

    def display(self):
        """Display the widget."""
        display(self.widget)


def process_multiple_restraints(
    ligand_pairs, molecule_names, method="hybridization_p", size=(900, 500)
):
    """Build and display a paginated restraint viewer for several ligand pairs.

    Every pair is processed with the same restraint ``method`` through
    ``apply_and_render_restraint_v2``.

    Args:
        ligand_pairs: Iterable of ``(ligand1, ligand2)`` tuples.
        molecule_names: Iterable of ``(from_name, to_name)`` tuples aligned with
            ``ligand_pairs``; shown in the viewer's label.
        method: Restraint method string forwarded to ``apply_and_render_restraint_v2``.
        size: View size forwarded to ``apply_and_render_restraint_v2``.

    Returns:
        The ``RestraintsViewer`` that was displayed.

    Raises:
        ValueError: If ``ligand_pairs`` and ``molecule_names`` differ in length
            (``zip`` would otherwise silently drop the unmatched entries).
    """
    ligand_pairs = list(ligand_pairs)
    molecule_names = list(molecule_names)
    if len(ligand_pairs) != len(molecule_names):
        raise ValueError(
            "ligand_pairs and molecule_names must have the same length, "
            f"got {len(ligand_pairs)} and {len(molecule_names)}"
        )

    viewer = RestraintsViewer()

    for idx, ((lig1, lig2), (from_name, to_name)) in enumerate(
        zip(ligand_pairs, molecule_names)
    ):
        restraint_dict, view = apply_and_render_restraint_v2(
            lig1, lig2, restraint_method=method, size=size
        )
        description = f"Pair {idx + 1} - Method: {method}"
        viewer.add_view(
            restraint_dict, view, description, molecule_names=(from_name, to_name)
        )

    viewer.display()
    return viewer

# JACS dataset

## bace

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "bace"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="hybridization_p", size=(1200, 500)
)

## cdk2

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "cdk2"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="hybridization_p", size=(1200, 500)
)

## jnk1

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "jnk1"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="hybridization_p", size=(1200, 500)
)

## mcl1

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "mcl1"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_p", size=(1200, 500)
)

## p38

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "p38"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_p", size=(1200, 500)
)

## ptp1b

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "ptp1b"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_p", size=(1200, 500)
)

## thrombin

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "thrombin"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="hybridization_p", size=(1200, 500)
)

## tyk2

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "tyk2"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="hybridization_p", size=(1200, 500)
)

# Merck dataset

## cdk8

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "cdk8"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="hybridization_p", size=(1200, 500)
)

## cmet

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "cmet"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_ls", size=(1200, 500)
)

## eg5

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "eg5"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_p", size=(1200, 500)
)

## hif2a

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "hif2a"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_p", size=(1200, 500)
)

## shp2

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "shp2"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

# The "_1.2" suffix sets the max atom distance passed to RestraintSetter (default 0.95).
viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_p_1.2", size=(1200, 500)
)

## syk

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "syk"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_p", size=(1200, 500)
)

## pfkfb3

In [1]:
# TODO: build the ligand pairs and render the restraints for this system, as in the other dataset cells.

## tnks2

In [None]:
# Build one (ligand1, ligand2) pair per edge of the perturbation network and show
# their restraints in a paginated viewer. The system name drives both input paths.
system = "tnks2"
chemio = MoleculeIO(f"../perturbations/{system}/ligands.sdf")
network_dict = json.loads(Path(f"../perturbations/{system}/mapping.json").read_text())

ligand_pairs = []  # (ligand1, ligand2) RDKit molecules, one tuple per edge
molecule_name_pairs = []  # (from, to) ligand names, aligned with ligand_pairs
for edge in network_dict["edges"]:
    _from, _to = edge["from"], edge["to"]
    missing = {_from, _to}.difference(chemio.lig_names)
    if missing:
        # Name the system and ligands instead of list.index's bare "not in list".
        raise ValueError(
            f"{system}: mapping.json edge {_from} -> {_to} references ligands "
            f"not found in ligands.sdf: {sorted(missing)}"
        )
    ligand1 = chemio.molecules[chemio.lig_names.index(_from)].to_rdkit()
    ligand2 = chemio.molecules[chemio.lig_names.index(_to)].to_rdkit()
    ligand_pairs.append((ligand1, ligand2))
    molecule_name_pairs.append((_from, _to))

viewer = process_multiple_restraints(
    ligand_pairs, molecule_name_pairs, method="heavyatom_p", size=(1200, 500)
)

# Visualize systems

## eg5

In [None]:
from pathlib import Path

from QligFEP.chemIO import MoleculeIO
from QligFEP.visualization import render_system

# Render the protein together with every ligand read from the system's SDF file.
fileroot = Path("eg5/")
protein = fileroot / "protein" / "protein.pdb"
ligands = fileroot / "ligands" / "ligands.sdf"

chemio = MoleculeIO(ligands)
render_system(molecules=chemio.molecules, protein_path=protein)

## bace

In [None]:
from pathlib import Path

from QligFEP.chemIO import MoleculeIO
from QligFEP.visualization import render_system

# Render the protein together with every ligand read from the system's SDF file.
fileroot = Path("bace/")
protein = fileroot / "protein" / "protein.pdb"
ligands = fileroot / "ligands" / "ligands.sdf"

chemio = MoleculeIO(ligands)
render_system(molecules=chemio.molecules, protein_path=protein)

## cdk2

In [None]:
from pathlib import Path

from QligFEP.chemIO import MoleculeIO
from QligFEP.visualization import render_system

# Render the protein together with every ligand read from the system's SDF file.
fileroot = Path("cdk2/")
protein = fileroot / "protein" / "protein.pdb"
ligands = fileroot / "ligands" / "ligands.sdf"

chemio = MoleculeIO(ligands)
render_system(molecules=chemio.molecules, protein_path=protein)