# VLab4Mic: a universal validation tool for microscopy.

## VLab4Mic is a modular package where you can:
- Model labelling strategies of macromolecular complexes
- Simulate image acquisitions under diverse microscopy modalities
- Find optimal parameters for feature recovery

# Import dependencies

In [None]:
from ipywidgets import GridspecLayout
from IPython.utils import io
from supramolsim import experiments
import ipywidgets as widgets
import matplotlib.pyplot as plt
from supramolsim.generate.labels import construct_label
from supramolsim.workflows import probe_model
import copy
import os
import numpy as np
from supramolsim.utils import data_format
from supramolsim.workflows import create_imaging_system
from supramolsim.jupyter_widgets.widget_generator import widgen
from ezinput import EZInput
from pathlib import Path
import os
from ipyfilechooser import FileChooser


# Virtual sample model: Choose a structure and a probe

In [None]:
#@title Virtual Sample

# Structure identifiers offered in the structure dropdown.
# A longer list that can be swapped in: ["9I0K", "1XI5", "7R5K", "8GMO"]
structures = ["9I0K", "1XI5"]

# Experiment that the following cells configure and image. Its probe
# configuration tables are shallow-copied for use by the widgets below.
my_experiment = experiments.ExperimentParametrisation()
config_probes_per_structure = copy.copy(my_experiment.config_probe_per_structure_names)
config_probe_parameters = copy.copy(my_experiment.config_probe_params)
config_vlab_probes = copy.copy(my_experiment.config_global_probes_names)


main_widget = GridspecLayout(7, 10)
# Number of grid rows used by the control widgets; previews go underneath.
params_section = 2
list_of_experiments = {}
structure_target_suggestion = {}
# Build one experiment per structure up front; printed output is captured so
# the cell stays quiet.
with io.capture_output() as captured:
    for struct in structures:
        struct_experiment = experiments.ExperimentParametrisation()
        list_of_experiments[struct] = struct_experiment
        struct_experiment.structure_id = struct
        struct_experiment._build_structure()
        protein_name, _1, site, sequence = (
            struct_experiment.structure.get_peptide_motif(position="cterminal")
        )
        # Fallback target used for probes that do not define their own target.
        structure_target_suggestion[struct] = {
            "probe_target_type": "Sequence",
            "probe_target_value": sequence,
        }
structure_name = widgets.Dropdown(options=structures)
# The IntSlider trait is `step`; the previous `steps=` keyword is not a trait
# of the widget, so the requested step size never took effect.
n_atoms = widgets.IntSlider(
    value=10_000,
    min=0,
    max=100_000,
    step=100,
    description="Atoms to display",
    style={'description_width': 'initial'},
    continuous_update=False,
)
h_rotaiton = widgets.IntSlider(
    value=0,
    min=-90,
    max=90,
    description="Horizontal rotation",
    style={'description_width': 'initial'},
    continuous_update=False,
)
v_rotation = widgets.IntSlider(
    value=0,
    min=-90,
    max=90,
    description="Vertical rotation",
    style={'description_width': 'initial'},
    continuous_update=False,
)
structure_params = [structure_name, n_atoms, h_rotaiton, v_rotation]
structure_output = widgets.Output()
def on_changes(change):
    """Redraw the structure preview from the current widget values.

    Used as an observer callback; ``change`` itself is not inspected. The
    number of atoms requested on the slider is converted into the fraction of
    assembly atoms to plot, capped at 1.0.
    """
    structure_output.clear_output()
    with structure_output:
        selected_id = structure_name.value
        requested_atoms = n_atoms.value
        horizontal = h_rotaiton.value
        vertical = v_rotation.value
        structure = list_of_experiments[selected_id].structure
        atom_count = structure.num_assembly_atoms
        # Only subsample when the structure has more atoms than requested.
        fraction = requested_atoms / atom_count if atom_count > requested_atoms else 1.0
        with io.capture_output() as captured:
            figure = list_of_experiments[selected_id].structure.show_assembly_atoms(
                assembly_fraction=fraction,
                view_init=[vertical, horizontal, 0],
            )
        plt.close()
        display(figure)
# Redraw the structure preview whenever any of the structure controls changes.
structure_name.observe(on_changes, names="value")
n_atoms.observe(on_changes, names="value")
h_rotaiton.observe(on_changes, names="value")
v_rotation.observe(on_changes, names="value")
# First three grid columns: structure controls on top, preview underneath.
main_widget[:params_section, :3]  = widgets.VBox(structure_params)
main_widget[params_section:, :3] = structure_output
# probes
# Maps probe name -> {"probe_object", "probe_structure", optional
# "probe_info_text"}; read by show_probe to draw the probe preview.
list_of_probe_objects = {}
with io.capture_output() as captured:
    # Only `experiment.configuration_path` is used below; `vsample` is not
    # referenced again in this cell.
    vsample, experiment = experiments.generate_virtual_sample(
        clear_probes=True,
        )
    for probe_name in config_probe_parameters.keys():
        print(probe_name)
        # Each probe is described by <configuration_path>/probes/<name>.yaml.
        label_config_path = os.path.join(experiment.configuration_path, "probes", probe_name + ".yaml")
        probe_obj, probe_parameters = construct_label(label_config_path)
        if probe_obj.model:
            # Probe backed by a structural model: build the model and use it
            # to place the probe's emitters.
            (
                probe_structure_obj,
                probe_emitter_sites,
                anchor_point,
                direction_point,
                probe_epitope,
            ) = probe_model(
                model=probe_obj.model,
                binding=probe_obj.binding,
                conjugation_sites=probe_obj.conjugation,
                epitope=probe_obj.epitope,
                config_dir=experiment.configuration_path,
            )
            # Only a single 3-component anchor point is used to set the axis.
            if anchor_point.shape == (3,):
                    print("setting new axis")
                    probe_obj.set_axis(pivot=anchor_point, direction=direction_point)
            if (
                probe_epitope["coordinates"] is not None
                and probe_parameters["as_linker"]
            ):
                print("Generating linker from epitope site")
                # TODO: this decision needs to take into account if there is a
                # secondary label for this specific probe
                probe_obj.set_emitters(probe_epitope["coordinates"])
            else:
                probe_obj.set_emitters(probe_emitter_sites)
            probe_parameters["coordinates"] = probe_obj.gen_labeling_entity()
            list_of_probe_objects[probe_name] = {}
            list_of_probe_objects[probe_name]["probe_object"] = probe_obj
            list_of_probe_objects[probe_name]["probe_structure"] = probe_structure_obj
        else:
            # Probe without a structural model: nothing to plot, so keep a
            # short description of its target instead.
            list_of_probe_objects[probe_name] = {}
            list_of_probe_objects[probe_name]["probe_structure"] = None
            list_of_probe_objects[probe_name]["probe_object"] = probe_obj
            # NOTE(review): when the target type is None no "probe_info_text"
            # entry is stored for this probe.
            if probe_parameters["target"]["type"] is not None:
                if probe_parameters["target"]["type"] == "Sequence":
                    text = "This probe targets the sequence: "
                    text = text + probe_parameters["target"]["value"]
                else:
                    text = "This probe targets a residue: "
                    # NOTE(review): assumes "residues" is a string; a list here
                    # would make the concatenation fail - confirm.
                    text = text + probe_parameters["target"]["value"]["residues"]
                list_of_probe_objects[probe_name]["probe_info_text"] = text

def show_probe(probe, n_atoms, h_rotation=0, v_rotation=0):
    """Return a figure previewing ``probe``.

    Parameters
    ----------
    probe
        Probe name; looked up in ``list_of_probe_objects``.
    n_atoms
        Maximum number of assembly atoms to draw for probes that have a
        structural model.
    h_rotation, v_rotation
        Values passed as the horizontal / vertical components of ``view_init``.

    Returns
    -------
    The plot object produced for the probe, or ``None`` when ``probe`` is not
    in ``list_of_probe_objects``.
    """
    entry = list_of_probe_objects.get(probe)
    if entry is None:
        return None

    if probe == "Linker":
        return entry["probe_object"].plot_emitters(return_plot=True)

    probe_structure = entry["probe_structure"]
    if probe_structure is None:
        # No structural model: show the textual description instead. The text
        # is only stored for probes with a target type, so fall back to a
        # generic message rather than raising KeyError.
        info_text = entry.get(
            "probe_info_text", "No target information available for this probe"
        )
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, info_text, fontsize=14, ha='center')
        ax.set_axis_off()  # This hides the axes
        plt.close()
        return fig

    total = probe_structure.num_assembly_atoms
    # Only subsample when the model has more atoms than requested.
    fraction = n_atoms / total if total > n_atoms else 1.0
    probe_structure.plotting_params["assemblyatoms"]["plotalpha"] = 0.3
    with io.capture_output() as captured:
        plot = probe_structure.show_target_labels(
            with_assembly_atoms=True,
            assembly_fraction=fraction,
            view_init=[v_rotation, h_rotation, 0],
            show_axis=False,
            return_plot=True,
        )
    plt.close()
    return plot

# Probes offered for the currently selected structure: structure-specific
# probes first (when configured), followed by the global probes.
probes2show = []
current_structure = structure_name.value
if current_structure in config_probes_per_structure.keys():
    probes2show.extend(
        copy.copy(config_probes_per_structure[current_structure])
    )
probes2show.extend(copy.copy(config_vlab_probes))

w_probe_name = widgets.Dropdown(options=probes2show)
# The IntSlider trait is `step`; the previous `steps=` keyword is not a trait
# of the widget, so the requested step size never took effect.
w_probe_n_atoms = widgets.IntSlider(
    value=100,
    min=0,
    max=1000,
    step=10,
    description="Atoms to display",
    style={'description_width': 'initial'},
    continuous_update=False,
)
w_probe_h_rotaiton = widgets.IntSlider(
    value=0,
    min=-90,
    max=90,
    description="Horizontal rotation",
    style={'description_width': 'initial'},
    continuous_update=False,
)
w_probe_v_rotation = widgets.IntSlider(
    value=0,
    min=-90,
    max=90,
    description="Vertical rotation",
    style={'description_width': 'initial'},
    continuous_update=False,
)
w_probe_params = [w_probe_name, w_probe_n_atoms, w_probe_h_rotaiton, w_probe_v_rotation]
w_probe_model_output = widgets.Output()

def probe_on_change(change):
    """Redraw the probe preview from the current probe widget values.

    Used as an observer callback; ``change`` itself is not inspected.
    """
    w_probe_model_output.clear_output()
    with w_probe_model_output:
        preview = show_probe(
            probe=w_probe_name.value,
            n_atoms=w_probe_n_atoms.value,
            h_rotation=w_probe_h_rotaiton.value,
            v_rotation=w_probe_v_rotation.value,
        )
        plt.close()
        display(preview)

# Redraw the probe preview whenever any of the probe controls changes.
w_probe_name.observe(probe_on_change, names="value")
w_probe_n_atoms.observe(probe_on_change, names="value")
w_probe_h_rotaiton.observe(probe_on_change, names="value")
w_probe_v_rotation.observe(probe_on_change, names="value")

def my_update(change):
    """Refresh the probe dropdown options for the newly selected structure.

    ``change.new`` is the selected structure id. Structure-specific probes
    (when configured for that id) are listed first, then the global probes.
    """
    available = []
    if change.new in config_probes_per_structure:
        available += list(config_probes_per_structure[change.new])
    available += list(config_vlab_probes)
    w_probe_name.options = available

# Changing the structure also refreshes the list of selectable probes.
structure_name.observe(my_update, names="value")

# Grid columns 3-5: probe controls on top, probe preview underneath.
main_widget[:params_section, 3:6] = widgets.VBox(w_probe_params)
main_widget[params_section:, 3:6] = w_probe_model_output

## show particle
def calculate_labelled_particle(b):
    """Label the selected structure with the selected probe and display it.

    Button callback (``b`` is unused). Enables the particle controls, applies
    the probe to both ``my_experiment`` and the per-structure experiment,
    builds the particle of the latter and shows it in ``particle_output``.
    """
    # A labelled particle is about to exist, so its controls become usable.
    for control in (
        set_button,
        emitter_plotsize,
        epitope_plotsize,
        particle_h_rotation,
        particle_v_rotation,
    ):
        control.disabled = False

    struct = structure_name.value
    probe_name = w_probe_name.value
    emitter_size = emitter_plotsize.value
    epitope_size = epitope_plotsize.value
    horizontal = particle_h_rotation.value
    vertical = particle_v_rotation.value

    # Probes without a configured target fall back to the target suggested
    # for the structure.
    target = {"probe_target_type": None, "probe_target_value": None}
    if config_probe_parameters[probe_name]["target"]["type"] is None:
        suggestion = structure_target_suggestion[struct]
        target["probe_target_type"] = suggestion["probe_target_type"]
        target["probe_target_value"] = suggestion["probe_target_value"]

    particle_output.clear_output()
    with particle_output:
        with io.capture_output() as captured:
            my_experiment.structure_id = struct
            my_experiment.remove_probes()
            my_experiment.add_probe(probe_name, **target)

            structure_experiment = list_of_experiments[struct]
            structure_experiment.remove_probes()
            structure_experiment.add_probe(probe_name, **target)
            structure_experiment.build(modules=["particle",])

            figure = show_particle(
                struct=struct,
                emitter_plotsize=emitter_size,
                source_plotsize=epitope_size,
                hview=horizontal,
                vview=vertical,
            )
            plt.close()
        display(figure)

def update_plot(change):
    """Redraw the labelled particle with the current size and view sliders.

    Used as an observer callback; ``change`` itself is not inspected.
    """
    struct = structure_name.value
    particle_output.clear_output()
    with particle_output:
        figure = show_particle(
            struct=struct,
            emitter_plotsize=emitter_plotsize.value,
            source_plotsize=epitope_plotsize.value,
            hview=particle_h_rotation.value,
            vview=particle_v_rotation.value,
        )
        plt.close()
        display(figure)

def show_particle(struct= None,
                  emitter_plotsize = 1,
                  source_plotsize = 1,
                  hview=0,
                  vview=0):
    """Return a 3D figure of the labelled particle built for ``struct``.

    ``struct`` selects the entry of ``list_of_experiments`` whose particle is
    drawn; the remaining arguments are forwarded to the particle's
    ``gen_axis_plot`` as marker sizes and viewing angles. The particle output
    area is cleared before drawing.
    """
    particle_output.clear_output()
    with io.capture_output() as captured:
        particle_figure = plt.figure()
        axis_3d = particle_figure.add_subplot(111, projection="3d")
        plot_options = dict(
            axis_object=axis_3d,
            with_sources=True,
            axesoff=True,
            emitter_plotsize=emitter_plotsize,
            source_plotsize=source_plotsize,
            view_init=[vview, hview, 0],
        )
        list_of_experiments[struct].particle.gen_axis_plot(**plot_options)
        plt.close()
    return particle_figure


def select_model_action(b):
    """Adopt the previewed structure/probe pair as the virtual sample model.

    Button callback (``b`` is unused). Copies the structure of the previewed
    experiment into ``my_experiment``, builds its particle and reports the
    selection in ``feedback_text``. Both buttons are disabled while the model
    is updated and are always re-enabled afterwards, even if the build fails.
    """
    with io.capture_output() as captured:
        set_button.disabled = True
        preview_button.disabled = True
        feedback_text.value = "<b>Updating virtual sample model...</b>"
        try:
            structname = my_experiment.structure_id
            my_experiment.structure = copy.deepcopy(list_of_experiments[structname].structure)
            my_experiment.objects_created["structure"] = True
            my_experiment.build(modules=["particle",])
            probe_name = next(iter(my_experiment.probe_parameters))
            feedback_text.value = (
                "<b>Selected model for virtual sample: " + "<br>"
                + structname + " with probe " + probe_name + "</b>"
            )
        except Exception:
            # Do not leave the "Updating..." message on screen after a failure.
            feedback_text.value = "<b>Could not update the virtual sample model</b>"
            raise
        finally:
            # Without this the interface stays locked after an error.
            set_button.disabled = False
            preview_button.disabled = False

# Particle view controls. They start disabled and are enabled by
# calculate_labelled_particle once a labelled particle has been requested.
emitter_plotsize = widgets.IntSlider(value=1, min=0, max=24, description="Emitters size",  style = {'description_width': 'initial'}, continuous_update=False, disabled=True)
epitope_plotsize = widgets.IntSlider(value=1, min=0, max=24, description="Epitope size",  style = {'description_width': 'initial'}, continuous_update=False, disabled=True)
particle_h_rotation = widgets.IntSlider(value=0, min=-90, max=90, description="Horizontal rotation",  style = {'description_width': 'initial'}, continuous_update=True, disabled=True)
particle_v_rotation = widgets.IntSlider(value=0, min=-90, max=90, description="Vertical rotation",  style = {'description_width': 'initial'}, continuous_update=True, disabled=True)
particle_output = widgets.Output()
feedback_text = widgets.HTML("<b>No model has been selected</b>", style = dict(font_size= "15px", font_weight='bold'))
preview_button = widgets.Button(description = "Update labelling", layout=widgets.Layout(width='auto'))
set_button = widgets.Button(description = "Set this model for virtual sample", layout=widgets.Layout(width='auto'), disabled = True)
preview_button.on_click(calculate_labelled_particle)
set_button.on_click(select_model_action)
# Any change of the size or rotation sliders redraws the particle.
emitter_plotsize.observe(update_plot, names="value")
epitope_plotsize.observe(update_plot, names="value")
particle_h_rotation.observe(update_plot, names="value")
particle_v_rotation.observe(update_plot, names="value")
buttons_widget = widgets.HBox([preview_button, set_button])
size_widgets = widgets.VBox([emitter_plotsize, epitope_plotsize])
visualisation_widgets  = widgets.VBox([particle_h_rotation, particle_v_rotation])
sliders_box = widgets.HBox([size_widgets, visualisation_widgets])
# Grid columns 6 onwards: buttons, sliders and feedback on top, particle
# preview underneath.
main_widget[:params_section, 6:]  = widgets.VBox([buttons_widget, sliders_box, feedback_text], layout=widgets.Layout(margin='0px'))
main_widget[params_section:, 6:] = particle_output
# Selecting the second structure changes the dropdown value, which notifies
# the observers registered on structure_name.
structure_name.value = structures[1]
# Last expression of the cell: displays the assembled grid.
main_widget

# Image virtual sample: customise the virtual sample and choose one or more imaging modalities

In [None]:

# "default" restricts the modality list to `modalities_default`; any other
# value uses the modality names stored on the experiment.
mode = "default"
modalities_default = ["Widefield", "Confocal", "STED", "SMLM"]
local_config_dir = my_experiment.configuration_path
wgen = widgen()


grid2 = GridspecLayout(7, 3)
# Separate copy of the experiment used only for acquisition previews. It is
# taken before the coordinate field below is built on `my_experiment`.
preview_exp = copy.deepcopy(my_experiment)
with io.capture_output() as captured:
    my_experiment._build_coordinate_field(
                keep=True,
                nparticles=1
            )
nparticles = widgets.IntSlider(value=1, min=1, max=20, description="Number of particles",  style = {'description_width': 'initial'},continuous_update=False)
angle_view = widgets.IntSlider(value=20, min=-90, max=90, description="Angle view",  style = {'description_width': 'initial'},continuous_update=False)
random_orientations = widgets.Checkbox(description = "Randomise orientations", value=True)
# The order of this list matters: the on_changes callback of this cell
# identifies the widget that changed by its position in the list.
vsample_params = [nparticles,angle_view,random_orientations]
vsample_output = widgets.Output()
def on_changes(change):
    """Update the virtual sample field and redraw it.

    The widget that triggered the change is identified by its position in
    ``vsample_params``:

    * 0 (number of particles): rebuild the coordinate field;
    * 1 (angle view): redraw with the new angle only;
    * 2 (randomise orientations): regenerate the orientations, randomly or
      along the default direction, and rebuild the static field.
    """
    vsample_output.clear_output()
    with vsample_output:
        source = vsample_params.index(change.owner)
        with io.capture_output() as captured:
            view = [angle_view.value, 0, 0]
            if source == 0:
                my_experiment._build_coordinate_field(
                    keep=True,
                    nparticles=nparticles.value,
                    random_orientations=random_orientations.value,
                )
            elif source == 1:
                view = [change.new, 0, 0]
            elif source == 2:
                if change.new:
                    my_experiment.coordinate_field.generate_random_orientations()
                else:
                    default_axis = my_experiment.coordinate_field.molecules_default_orientation
                    my_experiment.coordinate_field.generate_global_orientation(
                        global_orientation=default_axis["direction"]
                    )
                my_experiment.coordinate_field.construct_static_field(reorient=True)
            plot = my_experiment.coordinate_field.show_field(
                return_fig=True,
                view_init=view,
            )
            plt.close()
        display(plot)
# Redraw the virtual sample whenever one of its controls changes.
nparticles.observe(on_changes, names="value")
angle_view.observe(on_changes, names="value")
random_orientations.observe(on_changes, names="value")
# First grid column: sample controls on top, field preview underneath.
grid2[:2, 0]  = widgets.VBox(vsample_params)
grid2[2:, 0] = vsample_output
# Changing the angle notifies the observer registered above, which draws the
# first preview of the field.
angle_view.value = 90
# modalities
modalities_options = []
if mode == "default":
    modalities_list = modalities_default
else:
    modalities_list = my_experiment.local_modalities_names
# Dropdown options are the modality names plus a trailing "All" entry;
# button_method relies on "All" being the last option.
modalities_options = copy.copy(modalities_list)
modalities_options.append("All")
# Compiled parameters per modality, read by show_modality.
modality_info = {}
for mod in modalities_list:
    mod_info = data_format.configuration_format.compile_modality_parameters(
        mod, local_config_dir
    )
    modality_info[mod] = mod_info
# Temporary imaging system; show_modality reads its PSF stacks for the preview.
with io.capture_output() as captured:
    temp_imager, tmp_modality_parameters = create_imaging_system(
        modalities_id_list=modalities_list,
        config_dir=local_config_dir,
    )


def show_modality(modality_name):
    """Plot the central slice of a modality's PSF with its key parameters.

    Nothing is drawn for the "All" option. The figure is left open rather
    than closed or returned; presumably the interactive output created by the
    widget generator renders it - confirm against ``widgen``.

    Parameters
    ----------
    modality_name
        Key of ``modality_info`` / ``temp_imager.modalities``, or "All".
    """
    if modality_name == "All":
        return

    # Look everything up before creating the figure so that a missing entry
    # does not leave an empty figure behind.
    info = modality_info[modality_name]
    # NOTE(review): the factor 1000 presumably converts micrometres to
    # nanometres - confirm the unit of the configured pixel size.
    pixelsize_nm = info["detector"]["pixelsize"] * 1000
    psf_sd = np.array(info["psf_params"]["std_devs"])
    psf_voxel = np.array(info["psf_params"]["voxelsize"])
    psf_sd_metric = np.multiply(psf_voxel, psf_sd)
    psf_stack = temp_imager.modalities[modality_name]["psf"]["psf_stack"]
    middle_slice = psf_stack.shape[2] // 2

    fig, axs = plt.subplots()
    axs.imshow(
        psf_stack[:, :, middle_slice],
        cmap="gray",
        interpolation="none",
        vmin=0,
        vmax=np.max(psf_stack),
    )
    axs.set_xticks([])
    axs.set_yticks([])

    captions = (
        "Detector pixelsize (nm): " + str(pixelsize_nm),
        "PSF sd (nm): " + str(psf_sd_metric),
        # The micro sign is written as an escape so it cannot be mis-decoded
        # again (it was previously displayed as the two characters "Âµ").
        "PSF preview (on a 1x1 \u00b5m field of view)",
    )
    # Captions are stacked bottom-up in axes coordinates.
    for y_position, caption in zip((0.1, 0.15, 0.2), captions):
        axs.text(0.05, y_position, caption, transform=axs.transAxes, size=10, color="w")


# Dropdown that calls show_modality with the selected option.
wgt2 = wgen.gen_interactive_dropdown(
            options=modalities_options,
            orientation="vertical",
            routine=show_modality
)
# Heading of the HTML summary listing the modalities added to the experiment.
mods_text_base = "<b> Selected modalities and acquisition parameters: </b>"

def _mods_text_update(mods_text_base, mod_acq_params, keys_to_use = ["exp_time", "noise"]):
    mods_text = mods_text_base + "<br>"
    for modality_name, acq_params in mod_acq_params.items():
        if acq_params is None:
            acq_params = "Default"
        else:
            keys_subset = {key: acq_params[key] for key in keys_to_use}
            acq_params["exp_time"] = round(keys_subset["exp_time"], 3)
        mods_text +=  modality_name + ": " + "&emsp;" +  str(keys_subset) + "<br>"
    return mods_text

# HTML summary of the modalities currently selected on the experiment.
selected_mods_feedback = widgets.HTML(_mods_text_update(mods_text_base, my_experiment.selected_mods),
                                  style = dict(font_size= "15px", font_weight='bold'))
# Append the summary after the dropdown; button_method and button_method2
# later update it through wgt2.children[0].children[-1].
wgt2.children[0].children += (selected_mods_feedback,)
wgt2.children[0].children[0].description = "Modality preview"
wgt2.children[0].children[0].style = {'description_width': 'initial'}
# Second grid column: modality controls on top, PSF preview underneath.
grid2[:2, 1]  = wgt2.children[0]
grid2[2:, 1] = wgt2.children[1]


# NOTE(review): nothing in this cell reads this variable, and the assignment
# inside preview_acquisition binds a local of the same name - confirm intent.
current_acq = dict()

def preview_acquisition(widget, exposure_time, noise):
    """Simulate one acquisition of the current field and return it as a figure.

    The field of ``my_experiment`` is exported into ``preview_exp``, where the
    selected modality is re-added with the requested acquisition settings and
    imaged. The acquisition parameters used are stored in the module-level
    ``current_acq``.

    Parameters
    ----------
    widget
        Widget whose ``children[0].children[0]`` holds the selected modality.
    exposure_time
        Exposure time forwarded to the modality and to the imager.
    noise
        Whether noise is enabled, forwarded to the modality and the imager.

    Returns
    -------
    The figure showing the first frame, or ``None`` when "All" is selected
    (no single modality to preview).
    """
    # Previously the assignment below created a function-local variable, so
    # the module-level `current_acq` was never updated.
    global current_acq

    field = my_experiment.coordinate_field.export_field()
    preview_exp.exported_coordinate_field = field
    preview_exp.objects_created["exported_coordinate_field"] = True
    selected_mod = widget.children[0].children[0].value
    if selected_mod == "All":
        return None

    with io.capture_output() as captured:
        # Remove and re-add the modality so the preview uses fresh settings.
        preview_exp.update_modality(modality_name=selected_mod, remove=True)
        preview_exp.add_modality(modality_name=selected_mod, save=False)
        preview_exp.set_modality_acq(modality_name=selected_mod, exp_time=exposure_time, noise=noise)
        preview_exp.build(modules=["imager",])
        # consider using run_simulation
        timeseries, calibration_beads = (
            preview_exp.imager.generate_imaging(
                modality=selected_mod, exp_time=exposure_time, noise=noise
            )
        )
        current_acq = preview_exp.selected_mods[selected_mod]

    # Create the figure only after the simulation succeeded, so a failure
    # above does not leave an empty figure open.
    first_frame = timeseries[0]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.imshow(
        first_frame,
        cmap="gray",
        interpolation="none",
        vmin=np.min(first_frame),
        vmax=np.max(first_frame),
    )
    ax.set_xticks([])
    ax.set_yticks([])
    # Close this specific figure; the caller displays the returned object.
    plt.close(fig)
    return fig

def button_method(b):
    """Add the selected modality (or all of them) to the experiment.

    Button callback (``b`` is unused). A single modality is added with the
    exposure time and noise currently set in the acquisition widgets; with
    "All" every modality is added without setting acquisition parameters.
    The summary text is refreshed afterwards.
    """
    selected_mod = wgt2.children[0].children[0].value
    acquisition_controls = static.children[0].children
    exp_time = acquisition_controls[0].value
    noise = acquisition_controls[1].value
    if selected_mod != "All":
        my_experiment.add_modality(modality_name=selected_mod, save=True)
        my_experiment.set_modality_acq(
            modality_name=selected_mod,
            exp_time=exp_time,
            noise=noise,
            save=True,
        )
    else:
        # Every option except the trailing "All" entry.
        for mod_names in modalities_options[:-1]:
            my_experiment.add_modality(modality_name=mod_names, save=True)
    summary = _mods_text_update(mods_text_base, my_experiment.selected_mods)
    wgt2.children[0].children[-1].value = summary

def button_method2(b):
    """Remove every imaging modality from the experiment and refresh the summary.

    Button callback (``b`` is unused).
    """
    # Iterate over a snapshot of the names: modalities are removed in the loop.
    for mod in list(my_experiment.imaging_modalities.keys()):
        my_experiment.update_modality(modality_name=mod, remove=True)
    summary = _mods_text_update(mods_text_base, my_experiment.selected_mods)
    wgt2.children[0].children[-1].value = summary


# Acquisition panel: exposure-time slider, noise checkbox, two buttons and a
# "Preview acquisition" action that runs preview_acquisition.
# button_method reads the exposure time and noise from
# static.children[0].children[0] and [1] respectively.
static = wgen.gen_action_with_options(
    param_widget=wgt2,
    routine=preview_acquisition,
    exposure_time = ["float_slider", [0.01,0,0.05,0.001]],
    noise = ["checkbox", True],
    button1 = ["button", ["Add current parameters", button_method]],
    button2 = ["button", ["Clear all modalities", button_method2]],
    options=None,
    action_name="Preview acquisition")
# Start from an empty modality selection and refresh the summary text; the
# argument is ignored by button_method2.
button_method2(True)

# Third grid column: acquisition controls on top, preview underneath.
grid2[:2, 2]  = static.children[0]
grid2[2:, 2] = static.children[1]
# Last expression of the cell: displays the assembled grid.
grid2

# Run image acquisition with the selected modalities and sample

In [None]:
#@title Run simulation
# Default location for simulation results, inside the user's home directory.
output_path = Path.home() / "vlab4mic_outputs"
# Create the directory if needed; exist_ok avoids the check-then-create race
# of testing for existence first.
output_path.mkdir(parents=True, exist_ok=True)

experiment_gui = EZInput(title="experiment")
# Build the imager with the modalities selected in the previous cell.
my_experiment.build(modules=["imager",])
def run_simulation(b):
    """Run the image simulation with the settings entered in the GUI.

    Button callback (``b`` is unused). Results are saved only when a saving
    directory is selected; in that case it becomes the experiment's output
    directory. The GUI settings are stored after the run.
    """
    # NOTE(review): the button is disabled here and never re-enabled, so the
    # cell must be re-run to simulate again - confirm this is intended.
    experiment_gui["Acquire"].disabled = True
    sav_dir = experiment_gui["saving_directory"].value
    # `save` must be defined on both paths; it was previously unbound when no
    # directory was selected, raising UnboundLocalError.
    save = sav_dir is not None
    if save:
        my_experiment.output_directory = sav_dir
    exp_name = experiment_gui["experiment_name"].value
    my_experiment.run_simulation(name=exp_name, save=save)
    experiment_gui.save_settings()

# Assemble the run form: experiment name, saving directory and run button.
experiment_gui.add_label("Set experiment name")
experiment_gui.add_text_area(
    "experiment_name", value="Exp_name", remember_value=True
)
experiment_gui.add_label("Set saving directory")
# The file chooser is inserted directly into the GUI's element table under
# the key that run_simulation reads.
# NOTE(review): the title asks for a directory but show_only_dirs is False -
# confirm whether files should be listed as well.
experiment_gui.elements["saving_directory"] = FileChooser(
    output_path,
    title="<b>Select output directory</b>",
    show_hidden=False,
    select_default=True,
    show_only_dirs=False,
)
experiment_gui.add_button("Acquire", description="Run Simulation")
experiment_gui["Acquire"].on_click(run_simulation)
experiment_gui.show()
