In [None]:
from IPython.display import Javascript, display
import ipywidgets as widgets
from ipywidgets import (
    Button,
    Dropdown,
    GridspecLayout,
    Layout,
    Output,
    RadioButtons,
    Text,
    VBox,
)
from utils import io, visualize

from pathlib import Path
import re


# --- Utility: Copy to clipboard ---
def copy_to_clipboard(value):
    """Copy ``value`` to the browser clipboard.

    Displays a Javascript snippet that calls
    ``navigator.clipboard.writeText`` with the text form of ``value``.

    Args:
        value: Object whose ``str()`` form should land on the clipboard.
    """
    import json

    # json.dumps yields a double-quoted, fully escaped literal that is also a
    # valid JavaScript string, so quotes, backslashes (e.g. Windows paths) or
    # newlines in the value can neither break the snippet nor inject code.
    js_literal = json.dumps(str(value))
    display(Javascript(f"navigator.clipboard.writeText({js_literal})"))


# --- Widget Factory Functions ---
def create_dropdown(options, description="", **kwargs):
    """Build a ``Dropdown`` with auto-sized height and width.

    Args:
        options: Choices handed to the widget as ``options``.
        description: Label shown next to the widget (empty by default).
        **kwargs: Extra keyword arguments forwarded to ``Dropdown``.

    Returns:
        The newly created ``Dropdown``.
    """
    auto_layout = Layout(height="auto", width="auto")
    dropdown = Dropdown(
        options=options, description=description, layout=auto_layout, **kwargs
    )
    return dropdown


def create_button(description, style="", **kwargs):
    """Build a ``Button`` with auto-sized height and width.

    Args:
        description: Caption shown on the button.
        style: Value passed to the widget as ``button_style``.
        **kwargs: Extra keyword arguments forwarded to ``Button``.

    Returns:
        The newly created ``Button``.
    """
    auto_layout = Layout(height="auto", width="auto")
    button = Button(
        description=description, button_style=style, layout=auto_layout, **kwargs
    )
    return button


def create_radio(description, options):
    """Build an enabled ``RadioButtons`` group sized to fit its content.

    Args:
        description: Label shown next to the group.
        options: Choices handed to the widget as ``options``.

    Returns:
        The newly created ``RadioButtons``.
    """
    radio_settings = {
        "options": options,
        "layout": {"width": "max-content"},
        "description": description,
        "disabled": False,
    }
    return RadioButtons(**radio_settings)


# --- SMILES Input ---
# Text field holding the SMILES string that both tabs read via
# smiles_inp.value; it starts out pre-filled with "[CH]1CO1".
smiles_inp = Text(
    placeholder="Enter SMILES...",
    description="SMILES:",
    value="[CH]1CO1",
    layout=Layout(height="auto", width="auto"),
)

# --- Optimization Tab Layout ---
# 5x3 grid: column 0 holds the selectors, rows 0-3 of columns 1-2 hold a single
# Output used as the viewer, and row 4 of columns 1-2 is filled later by
# update_opt_tab with buttons.
opt_layout = GridspecLayout(5, 3, height="500px")
methods_dropdown = create_dropdown(["XTB", "REVDSD", "CCSDT"])
# Radio group of whatever io.get_xyz returns for the initial SMILES and method.
xyzs_radio = create_radio(
    "XYZ Files:", io.get_xyz(smiles=smiles_inp.value, method=methods_dropdown.value)
)
opt_layout[:, 0] = VBox([methods_dropdown, xyzs_radio])
opt_layout[:4, 1:] = Output()


def _build_submit_script(jobs):
    """Return the bash driver that submits ``jobs`` one after another.

    Args:
        jobs: Sequence of ``(output_dir, label, xyz_stem)`` tuples in
            submission order.

    For every job the script changes into ``output_dir``, submits
    ``submit_{label}.sh`` with ``sbatch`` and polls ``sacct`` until the job
    reports COMPLETED. Between two consecutive jobs it copies the finished
    job's ``{xyz_stem}.xyz`` to ``init.xyz`` in the next job's directory.
    Nothing is copied after the last job, so the script never ends with a
    ``cp`` that lacks a destination.
    """
    parts = [
        "#!/bin/bash\n\n#SBATCH --partition=batch\n#SBATCH --ntasks=1\n#SBATCH --time=12:00:00\n#SBATCH --mem-per-cpu=50\n\n"
    ]
    previous_xyz_stem = None
    for output_dir, label, xyz_stem in jobs:
        if previous_xyz_stem is None:
            sbatch_cmd = f"sbatch submit_{label}.sh"
        else:
            # At this point the script is still inside the previous job's
            # directory and $slurm_id still holds that job's id.
            parts.append(
                f"cp $slurm_id/{previous_xyz_stem}.xyz {output_dir}/init.xyz\n\n"
            )
            sbatch_cmd = f"sbatch --dependency=afterok:$slurm_id submit_{label}.sh"
        parts.append(
            f"cd {str(output_dir)}\n"
            f"slurm_id=$({sbatch_cmd} | awk '{{print $NF}}')\n"
            "sacct -j $slurm_id --format=State --noheader -n | grep -q 'COMPLETED' || "
            "while [[ $(sacct -j $slurm_id --format=State --noheader -n | grep -c 'COMPLETED') -eq 0 ]]; do sleep 5; done\n"
        )
        previous_xyz_stem = xyz_stem
    return "".join(parts)


def update_opt_tab(update):
    """Rebuild the optimization tab for the current SMILES, method and XYZ file.

    Registered as a widget observer; ``update`` is the change notification
    and is not inspected.

    When an XYZ file is selected, the structure (or trajectory, for paths
    containing "trj" or "allxyz") is rendered and the bottom row shows two
    buttons carrying the texts returned by ``io.get_energies``. Clicking one
    copies the first number found in its text to the clipboard.

    When no XYZ file is selected, the SMILES is rendered, four method/type
    dropdown pairs and a multiplicity slider replace the left column, and the
    bottom row offers a "Run optimization" button. Clicking it prepares every
    selected step through ``io.orca_optimization``, writes a chained
    ``submit.sh`` and copies the matching ``sbatch`` command to the clipboard.
    """
    xyzs_radio.options = io.get_xyz(
        smiles=smiles_inp.value, method=methods_dropdown.value
    )
    opt_layout[:4, 1:].clear_output()
    xyz_path = xyzs_radio.value
    with opt_layout[:4, 1:3]:
        if xyz_path is not None:
            opt_layout[:, 0] = VBox([methods_dropdown, xyzs_radio])
            if "trj" in str(xyz_path) or "allxyz" in str(xyz_path):
                visualize.visualize_traj(xyz_path)
            else:
                visualize.visualize_xyz(xyz_path)
        else:
            try:
                visualize.visualize_smiles(smiles_inp.value)
            except Exception as e:
                print("Enter a valid SMILES string.")
                print(e)
            # --- Optimization method/type selectors ---
            method_options = [
                ("NONE", None),
                ("XTB", "XTB"),
                ("REVDSD", "REVDSD-PBEP86-D4/2021 def2-TZVPP def2-TZVPP/c"),
                (
                    "CCSD(T)",
                    "CCSD(T)-F12/RI cc-pVDZ-F12 cc-pVDZ-F12-CABS cc-pVTZ/c",
                ),
            ]
            type_options = [
                ("NONE", None),
                ("GOAT", "goat"),
                ("OPT", "opt"),
                ("FREQUENCY", "opt_freq"),
                ("SPC", "elec"),
            ]
            # Flat list laid out as [method 1, type 1, method 2, type 2, ...].
            method_type_selectors = []
            for i in range(1, 5):
                method_type_selectors.append(
                    create_dropdown(method_options, description=f"Method {i})")
                )
                method_type_selectors.append(
                    create_dropdown(type_options, description=f"Type {i})")
                )
            # NOTE(review): min=0 lets the user pick a multiplicity of 0;
            # confirm io.orca_optimization accepts that value.
            multiplicity_inp = widgets.IntSlider(
                value=2,
                min=0,
                max=10,
                step=1,
                description="Multiplicity:",
            )
            opt_layout[0, 0] = multiplicity_inp
            opt_layout[1:, 0] = VBox(method_type_selectors)

    # --- Energies or Submit Button ---
    if xyz_path is not None:
        zpv_text, spc_text = io.get_energies(xyz_path)
        for column, energy_text in ((1, zpv_text), (2, spc_text)):
            energy_btn = create_button(energy_text, "info")
            numbers = re.findall(r"[-+]?\d*\.\d+|\d+", energy_text)
            number = numbers[0] if numbers else ""
            # Bind the number as a default so each button copies its own value.
            energy_btn.on_click(lambda b, number=number: copy_to_clipboard(number))
            opt_layout[4, column] = energy_btn
    else:
        submit_btn = create_button("Run optimization", "info")

        def on_submit_click(b):
            # Pair each method dropdown with its type dropdown and keep only
            # the steps whose method is not "NONE".
            selections = [
                (method_dd.value, type_dd.value)
                for method_dd, type_dd in zip(
                    method_type_selectors[::2], method_type_selectors[1::2]
                )
                if method_dd.value is not None
            ]
            if not selections:
                # Without a selected step there is no identifier to build the
                # script path from; report instead of raising.
                with opt_layout[:4, 1:3]:
                    print("Select at least one method before submitting.")
                return

            jobs = []
            identifier = None
            for position, (method, job_type) in enumerate(selections):
                # Only the first selected step is prepared with guess=True.
                output_dir, label, identifier = io.orca_optimization(
                    method=method,
                    smiles=smiles_inp.value,
                    job_type=job_type,
                    guess=position == 0,
                    multiplicity=multiplicity_inp.value,
                )
                # The "GOAT" dropdown entry has the value "goat"; the legacy
                # "opt_goat" spelling is still accepted. The suffix applies to
                # the copied geometry file only, not to the submit script.
                # NOTE(review): relies on ORCA GOAT writing
                # <label>.globalminimum.xyz and on the GOAT submit script
                # being named submit_<label>.sh; confirm against
                # io.orca_optimization.
                if job_type in ("goat", "opt_goat"):
                    xyz_stem = f"{label}.globalminimum"
                else:
                    xyz_stem = label
                jobs.append((output_dir, label, xyz_stem))

            # The script lives under the identifier of the last prepared step.
            par_dir = Path.home() / f"C5O-Kinetics/calc/{identifier}/Optimization/run"
            path_out = par_dir / "submit.sh"
            path_out.write_text(_build_submit_script(jobs))
            copy_to_clipboard(f"sbatch {path_out}")

        submit_btn.on_click(on_submit_click)
        opt_layout[4, 1:] = submit_btn


# --- Transition Tab Layout ---
# 5x3 grid: column 0 holds the dropdowns; rows 0-3 of column 1 and of column 2
# each get their own Output (update_trans_tab draws the structure in the first
# and the scan energy in the second).
trans_layout = GridspecLayout(5, 3, height="500px")
trans_opts_dropdown = create_dropdown(
    ["Show scan", "Show vibrational mode", "New transition"]
)
scans_dropdown = create_dropdown(io.get_transitions(smiles=smiles_inp.value))
# Starts empty; update_trans_tab fills it from io.get_scan_xyzs and adds it to
# the left column when "Show scan" is active.
allxyzs_dropdown = create_dropdown([])
# NOTE(review): this re-assigns the optimization tab's radio options with the
# same io.get_xyz call that populated them at creation; it looks redundant in
# this section -- confirm before removing.
xyzs_radio.options = io.get_xyz(smiles=smiles_inp.value, method=methods_dropdown.value)
trans_layout[:, 0] = VBox([trans_opts_dropdown, scans_dropdown])
trans_layout[:4, 1] = Output()
trans_layout[:4, 2] = Output()


def update_trans_tab(update):
    """Refresh the transition tab after one of its observed widgets changes.

    The scan dropdown is always reloaded from ``io.get_transitions``. Only the
    "Show scan" mode renders anything: it reloads the XYZ dropdown, adds it to
    the left column, clears both output panes, and draws the selected
    structure (or trajectory) next to its scan energy plot. The other two
    modes currently do nothing further.

    ``update`` is the change notification and is not inspected.
    """
    scans_dropdown.options = io.get_transitions(smiles=smiles_inp.value)

    # "Show vibrational mode" and "New transition" have no rendering yet.
    if trans_opts_dropdown.value != "Show scan":
        return

    allxyzs_dropdown.options = io.get_scan_xyzs(trans_dir=scans_dropdown.value)
    left_column = VBox([trans_opts_dropdown, scans_dropdown, allxyzs_dropdown])
    trans_layout[:, 0] = left_column

    structure_pane = trans_layout[:4, 1]
    energy_pane = trans_layout[:4, 2]
    structure_pane.clear_output()
    energy_pane.clear_output()

    selected_xyz = allxyzs_dropdown.value
    has_selection = selected_xyz is not None

    with structure_pane:
        if has_selection:
            path_text = str(selected_xyz)
            is_trajectory = "trj" in path_text or "allxyz" in path_text
            if is_trajectory:
                visualize.visualize_traj(selected_xyz, width=350)
            else:
                visualize.visualize_xyz(selected_xyz, width=350)

    with energy_pane:
        if has_selection:
            visualize.visualize_scan_energy(selected_xyz)


# --- Tabs and Display ---
# Both grids become pages of a single Tab widget, displayed together with the
# SMILES field.
tab = widgets.Tab()
tab.children = [opt_layout, trans_layout]
tab.titles = ["Optimization", "Transition"]

display(smiles_inp, tab)

# --- Observers ---
# The callbacks are only registered here, never called directly in this cell,
# so each tab is first rebuilt when one of the observed values changes.
# Optimization tab: method, SMILES text, or selected XYZ file.
methods_dropdown.observe(update_opt_tab, "value")
smiles_inp.observe(update_opt_tab, "value")
xyzs_radio.observe(update_opt_tab, "value")
# Transition tab: SMILES text or any of its three dropdowns.
smiles_inp.observe(update_trans_tab, "value")
trans_opts_dropdown.observe(update_trans_tab, "value")
scans_dropdown.observe(update_trans_tab, "value")
allxyzs_dropdown.observe(update_trans_tab, "value")

Text(value='[CH]1CO1', description='SMILES:', layout=Layout(height='auto', width='auto'), placeholder='Enter S…

Tab(children=(GridspecLayout(children=(VBox(children=(Dropdown(layout=Layout(height='auto', width='auto'), opt…