# Visualize Localizations

Goal:
- Visualize the DeepSTORM prediction output (widefield image, predicted image, detected localizations).
- Compare ground-truth and detected localizations.
- Filter localizations by a confidence threshold and save them as a new localization file, ready to be rendered as a high-resolution image in ThunderSTORM.

*Prediction directory:* Select the directory containing the single test frames and the localization files (prediction output of DeepSTORM2D, sec. 6.1). Optionally, this directory also contains the ground-truth localization file.<br/>
*Confidence threshold:* All localizations below this threshold will be filtered out. If no filtering should be applied, set the value to 0.<br/>
*Save:* The localizations are filtered by the confidence threshold and saved as a new CSV file.<br/>

In [1]:
from ImageBinner.widgets import widgetsVisLocs, loadVisLocs
from ImageBinner.tools import visLocs
from ImageBinner.save import saveVisLocs
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import plotly
import plotly.graph_objects as go
import plotly.io as pio
import numpy as np

### Load

In [2]:
# Directory selector: a text box holding the path plus a "browse" button.
prediction_dir = widgetsVisLocs.DefineDirectory("Prediction", value=r"")
# Keep the selector's state in sync when the text box is edited by hand ...
prediction_dir.dir_box.observe(prediction_dir.change_dir_box)
# ... and when a directory is picked via the browse button.
prediction_dir.dir_button.on_click(prediction_dir.open_dir)
display(prediction_dir.dir_box, prediction_dir.dir_button)

Text(value='', description='Prediction', placeholder='directory to be searched in', style=DescriptionStyle(des…

Button(description='browse', style=ButtonStyle(), tooltip='browse for directory')

### Parameters

In [3]:
# Shared parameter widgets. Judging by the rendered widgets, the positional
# defaults are (pixel size [nm], confidence threshold) -- TODO confirm against
# widgetsVisLocs.Parameters.
widget_parameters = widgetsVisLocs.Parameters(107, 0)  # adjust the default parameters
display(widget_parameters.pixel_size)

Text(value='107', description='Pixel size [nm]', placeholder='Insert pixel size', style=DescriptionStyle(descr…

### Display

In [4]:
def scroll_in_time(frame, show_single=True):
    """Show one frame of the selected movie, plain and with its localizations overlaid.

    NOTE(review): this definition is shadowed by the ``scroll_in_time`` defined
    right after it in the same cell, so it is not the one ``run`` uses. It also
    reads ``data.movies``, ``data.loc_files`` and
    ``widget_parameters.display_movie_idx``, which nothing else in this
    notebook references -- confirm they exist before re-enabling it.

    Args:
        frame: 1-based frame number; the image stack is indexed with ``frame - 1``
            while the localization table is matched on its "frame" column.
        show_single: if True, also display this frame's localization rows.
    """
    movie_idx = int(widget_parameters.display_movie_idx.value) - 1
    movie = data.movies[movie_idx]
    locs = data.loc_files[movie_idx]
    frame_locs = locs[locs["frame"] == frame]
    image = movie[frame - 1]
    n_rows, n_cols = image.shape[:2]

    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_size_inches(14, 14, forward=True)
    fig.tight_layout()
    ax1.imshow(image, interpolation="nearest", cmap="gray")
    ax2.imshow(image, interpolation="nearest", cmap="gray")
    # xy_loc_vis lives in the visLocs module imported at the top of the notebook.
    x, y = visLocs.xy_loc_vis(frame_locs, int(widget_parameters.pixel_size.value))
    ax2.scatter(x=x, y=y, marker="x")
    # x spans the image columns, y the rows (inverted to keep imshow's top-left origin).
    ax2.set_xlim([-0.5, n_cols - 0.5])
    ax2.set_ylim([n_rows - 0.5, -0.5])
    ax1.axis("off")
    ax2.axis("off")
    plt.show()
    if show_single:
        display(frame_locs)


def _clip_to_image(ax, shape):
    """Limit ``ax`` to the pixel extent of an image of shape (rows, cols, ...)."""
    ax.set_xlim(-0.5, shape[1] - 0.5)
    # Inverted y range keeps the image's top-left origin used by imshow.
    ax.set_ylim(shape[0] - 0.5, -0.5)


def scroll_in_time(frame):
    """Plot one frame of the prediction output in a 2x2 grid.

    Panels: widefield image (top left), predicted image (top right), widefield
    image with ground-truth localizations (bottom left, only when ground-truth
    localizations were loaded) and predicted image with predicted
    localizations coloured by confidence (bottom right).

    Args:
        frame: 1-based frame number, matched against the "frame" column of the
            localization tables; image stacks are indexed with ``frame - 1``.
    """
    pixel_size = int(widget_parameters.pixel_size.value)
    widefield = data.gt_tifs[frame - 1]
    predicted = data.predicted_stack[frame - 1]

    fig, axs = plt.subplots(2, 2)
    fig.set_size_inches(14, 14, forward=True)
    fig.tight_layout()
    for ax in axs.flat:
        ax.axis("off")

    axs[0, 0].imshow(widefield, interpolation="nearest", cmap="gray")
    axs[0, 0].title.set_text("Widefield image")
    axs[0, 1].imshow(predicted, interpolation="nearest", cmap="gray", vmin=0, vmax=1)
    axs[0, 1].title.set_text("Predicted image")

    if len(data.gt_locs):
        gt_frame = data.gt_locs[data.gt_locs["frame"] == frame]
        axs[1, 0].imshow(widefield, interpolation="nearest", cmap="gray")
        x, y = visLocs.xy_loc_vis(gt_frame, pixel_size)
        axs[1, 0].scatter(x=x, y=y, marker="x", s=3)
        _clip_to_image(axs[1, 0], widefield.shape)
        axs[1, 0].title.set_text("Widefield image & gt localizations")

    pred_frame = data.predicted_locs[data.predicted_locs["frame"] == frame]
    axs[1, 1].imshow(predicted, interpolation="nearest", cmap="gray")
    x, y = visLocs.xy_loc_vis(pred_frame, pixel_size)
    # A colormap name avoids plt.cm.get_cmap, which was removed in Matplotlib 3.9.
    sc = axs[1, 1].scatter(
        x=x,
        y=y,
        c=list(pred_frame["confidence [a.u]"]),
        vmin=0,
        vmax=1,
        cmap="plasma",
        marker="x",
        s=3,
    )
    cax = make_axes_locatable(axs[1, 1]).append_axes("right", size="5%", pad=0.1)
    fig.colorbar(sc, ax=axs[1, 1], cax=cax)
    # NOTE(review): limits follow the widefield frame shape (as before); if the
    # predicted image is upsampled relative to it, this crops the panel -- confirm.
    _clip_to_image(axs[1, 1], widefield.shape)
    axs[1, 1].title.set_text("Predicted image & predicted localizations")

    plt.show()


def get_px_vals_prediction_stacked(predicted_tif, filter_threshold=0, n_bins=None, n_px=500000):
    """Plot a histogram of pixel values pooled from a stack of predicted images.

    Values ``<= filter_threshold`` are discarded. At most ``n_px`` of the
    remaining values are drawn at random (without replacement), which keeps
    the plot responsive for large stacks.

    Args:
        predicted_tif: array of predicted images; all of its values are pooled.
        filter_threshold: only values strictly greater than this are plotted.
        n_bins: maximum number of histogram bins; None lets plotly decide.
        n_px: maximum number of pixel values to plot; None plots all of them.
    """
    # Use the argument rather than a global, so any stack can be plotted.
    values = np.asarray(predicted_tif).ravel()
    values = values[values > filter_threshold]
    if n_px is not None and values.size > n_px:
        values = np.random.default_rng().choice(values, size=n_px, replace=False)
    fig = go.Figure(data=[go.Histogram(x=values, histnorm="probability", nbinsx=n_bins)])
    fig.update_layout(template="plotly_white", title="Pixel values of predicted image (per frame)", xaxis_title="pixel value [a.u]", yaxis_title="Probability")
    fig.show()

def run(event):
    """Load the selected prediction directory and show every visualization.

    Click handler for ``widget_run.run_load_button``; ``event`` is unused.
    Shows a frame slider driving ``scroll_in_time``, then the widefield/predicted
    overview, the confidence histogram, the localization density plot and the
    pixel-value histogram of the predicted stack.
    """
    widget_run.create_clear_output()
    display(widget_run.run_load_button)
    data.get_files(prediction_dir.dir_box.value)
    n_frames = data.gt_tifs.shape[0]
    # Frames are numbered from 1, so the initial value must be 1 (0 is outside [min, max]).
    frame_slider = widgets.IntSlider(min=1, max=n_frames, step=1, value=1, continuous_update=False)
    interact(scroll_in_time, frame=frame_slider)
    print("*" * 100)
    visLocs.vis_wf_predicted_images(data.widefield_tif, data.predicted_tif)
    visLocs.confidence_hist(data.predicted_locs)
    visLocs.vis_density(data.gt_tifs.shape, data.gt_locs, data.predicted_locs, int(widget_parameters.pixel_size.value))
    # adjust the number of bins (n_bins=integer), the filter threshold (everything below is ignored, float) and the number of processed px values (n_px=integer)
    get_px_vals_prediction_stacked(data.predicted_stack, filter_threshold=0, n_bins=None, n_px=500000)
    
# Loader whose get_files() is called by run(); run() and the plotting
# functions read the loaded stacks and localization tables from it.
data = loadVisLocs.LoadFiles()
widget_run = widgetsVisLocs.RunAnalysis()
# Attach the handler, then show the "load & display" button.
widget_run.run_load_button.on_click(run)
display(widget_run.run_load_button)

Button(description='load & display', style=ButtonStyle(), tooltip='display input data')

### Filter & Save

Filter out localizations below the confidence threshold (a value in [0, 1]) and save the remaining ones as a new CSV file.

In [5]:
# Text box for the confidence threshold; the save cell parses its value with float().
display(widget_parameters.confidence_threshold)

Text(value='0', description='Confidence threshold', placeholder='locs below are filtered out', style=Descripti…

In [6]:
widget_save = widgetsVisLocs.SaveResults()
display(widget_save.save_button)


def save_analysis(event):
    """Filter the predicted localizations by confidence and save them.

    Click handler for ``widget_save.save_button``; ``event`` is unused. The
    threshold comes from the confidence-threshold text box and must be a
    number in [0, 1]. Invalid input is reported with a printed message instead
    of an exception raised inside the widget callback, where it would likely
    not show up in the cleared cell output.
    """
    widget_save.create_clear_output()
    display(widget_save.save_button)
    raw_threshold = widget_parameters.confidence_threshold.value
    try:
        threshold = float(raw_threshold)
    except ValueError:
        print(f"Invalid confidence threshold {raw_threshold!r}: enter a number in [0, 1].")
        return
    # The comparison also rejects NaN.
    if not 0 <= threshold <= 1:
        print(f"Confidence threshold {threshold} is outside [0, 1]; nothing was saved.")
        return
    saveVisLocs.save(prediction_dir.dir_box.value, threshold, data.predicted_locs)


widget_save.save_button.on_click(save_analysis)

Button(description='save', style=ButtonStyle(), tooltip='save the results')