### NAPARI PART

In [1]:
import csv
import datetime
import json
import os
from enum import Enum
from pathlib import Path

import napari
import numpy as np
import pandas as pd
import skimage.data
import skimage.filters
from magicgui import magicgui
from napari.layers import Points
from napari.types import PointsData

### Inputs (edit these before running):

In [2]:
# CSV with the candidate predictions; it must contain Z, Y, X and prob columns.
gui_csv = "path_to_prediction_table.csv"
# Image opened in napari underneath the points.
image_path = "path_to_the_image"
# Voxel spacing in Z, Y, X order, passed as `scale` to both the image and the
# points layer. The napari cell reads this name, so it must be defined here;
# replace the placeholder with your image's real voxel size.
resolution = (1.0, 1.0, 1.0)

### Setting up:

This loads the prediction CSV (just run it):

In [None]:
# Read the prediction table and display it as the cell output.
gui_df = pd.read_csv(filepath_or_buffer=gui_csv)
gui_df

This extracts the point coordinates and probabilities and defines the helper functions (just run it):

In [123]:
# Candidate coordinates (one row per candidate, columns Z, Y, X) and the
# predicted probability for each candidate.
zyx = gui_df.loc[:, ['Z', 'Y', 'X']].to_numpy()
preds = gui_df['prob'].to_numpy()
# Points layer whose points cannot be dragged to a new position.
class FixedPoints(Points):
    """napari ``Points`` layer with point moving disabled."""

    def _move(self, *args, **kwargs):
        """Points are not allowed to move.

        napari invokes ``_move`` with extra positional arguments (the
        selection and target position in the versions checked; confirm for
        the installed napari). A zero-argument override therefore raised
        TypeError instead of silently ignoring the move, so every argument is
        accepted and ignored here.
        """
    
def prepare_output(npz_filename):
    """
    Returns centroids and ROI start needed to use write csv from synspy.

    Parameters:
        npz_filename (string): path to npz with candidates; must contain
            'centroids' and 'properties' (JSON text stored as raw bytes with
            a 'slice_origin' entry)
    Returns:
        centroids: centroid coordinates in the ROI space, in pixels (int32)
        slice_origin: ROI start (int32 array)
    """
    # Use the NpzFile as a context manager so the archive is closed; both
    # values are copied out (astype / tobytes) before it closes.
    with np.load(npz_filename) as parts:
        centroids = parts['centroids'].astype(np.int32)
        # ndarray.tostring() was removed in NumPy 2.0; tobytes() is identical.
        props = json.loads(parts['properties'].tobytes().decode('utf8'))

    slice_origin = np.array(props['slice_origin'], dtype=np.int32)

    return centroids, slice_origin

def dump_segment_info_to_csv(centroids, measures, status, offset_origin,
                             outfilename, saved_params=None,
                             all_segments=True, zx_swap=False,
                             zyx_grid_scale=None, filter_status=None):
    """Write segment centroids, measures and override status to a CSV file.

       Arguments:
         centroids: Nx3 array of Z,Y,X segment coordinates
         measures: NxK array of segment measures (K == 5 adds a 'red' column)
         status: N array of segment status
         offset_origin: CSV coordinates = offset_origin + centroid coordinates
         outfilename: file to open to write CSV content
         saved_params: dict or None if saving threshold params row
         all_segments: True: dump all, False: dump only when matching filter_status values
         zx_swap: True: input centroids are in X,Y,Z order
         zyx_grid_scale: input centroids have been scaled by these coefficients in Z,Y,X order
         filter_status: set of values to include in outputs or None implies all non-zero values

       Raises:
         ValueError: if zyx_grid_scale does not have exactly three entries
    """
    if zx_swap:
        # Reverse the coordinate axis: X,Y,Z -> Z,Y,X.
        centroids = centroids[:, ::-1]
    if zyx_grid_scale is not None:
        zyx_grid_scale = np.array(zyx_grid_scale, dtype=np.float32)
        if zyx_grid_scale.shape != (3,):
            raise ValueError("zyx_grid_scale must have exactly 3 entries (Z, Y, X)")
        centroids = centroids * zyx_grid_scale
    # correct dumped centroids to global coordinate space of unsliced source image
    centroids = centroids + np.array(offset_origin, np.int32)

    # Decide which rows to write, as described for all_segments/filter_status.
    if all_segments:
        keep = np.ones(status.shape, dtype=bool)
    elif filter_status is not None:
        # list(): np.isin would treat a set as a single object element.
        keep = np.isin(status, list(filter_status))
    else:
        keep = status > 0
    indices = keep.nonzero()[0]

    header = (
        ('Z', 'Y', 'X', 'raw core', 'raw hollow', 'DoG core', 'DoG hollow')
        + (('red',) if (measures.shape[1] == 5) else ())
        + ('override',)
    )

    with open(outfilename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(header)
        if saved_params:
            # Threshold-parameter row; it reuses the Z/Y columns for a label.
            writer.writerow(
                (
                    'saved',
                    'parameters',
                    saved_params.get('X', ''),
                    saved_params.get('raw core', ''),
                    saved_params.get('raw hollow', ''),
                    saved_params.get('DoG core', ''),
                    saved_params.get('DoG hollow', ''),
                )
                + ((saved_params.get('red', ''),) if 'red' in saved_params else ())
                + (saved_params.get('override', ''),)
            )

        for i in indices:
            Z, Y, X = centroids[i]
            # A zero status is written as an empty override cell.
            writer.writerow((Z, Y, X) + tuple(measures[i]) + (status[i] or '',))
    
def write_to_csv(npz_file, binary_labels, outfilename):
    """Write the candidates selected by ``binary_labels`` to a synspy-style CSV.

    Parameters:
        npz_file: path to the npz with candidate 'centroids', 'measures' and
            'properties' (see prepare_output)
        binary_labels: boolean mask over the candidates in the npz; True keeps
            the candidate
        outfilename: path of the CSV to write
    """
    print(f"Gonna do it to {outfilename}")
    mask = np.asarray(binary_labels, dtype=bool)

    # Filter measures with the same mask as the centroids so each CSV row's
    # measures belong to its own centroid.
    with np.load(npz_file) as npz:
        measures = npz['measures'][mask]
    print("Got npz")

    # Placeholder threshold-parameter row written right after the CSV header.
    placeholder_params = {'Z': 'saved',
                          'Y': 'parameters',
                          'X': '(core, vicinity, zerolvl, toplvl, transp):',
                          'raw core': '0.0',
                          'raw hollow': '3500.0',
                          'DoG core': '0.0',
                          'DoG hollow': '65535.0',
                          'override': '0.8'}
    centroids, offset_origin = prepare_output(npz_file)
    centroids = centroids[mask, :]
    # Every exported candidate gets override status 7.
    status = 7 * np.ones(centroids.shape[0], dtype='int32')

    dump_segment_info_to_csv(centroids, measures, status, offset_origin, outfilename,
                             saved_params=placeholder_params)
    print(f"Saved to {outfilename}")
    


### Running NAPARI 
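#### 1. Finding the threshold
Use the threshold box on the right to relabel points by probability, then correct individual points with these keys:

'f' : flip the label of the selected points (good → bad or bad → good)

'+' : increase the point size

'-' : decrease the point size

'b' : hide/show 'bad' points

'g' : hide/show 'good' points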

This launches napari (just run it):

In [120]:
with napari.gui_qt():

    viewer = napari.Viewer()
    viewer.open(image_path, scale=resolution)

    n_points = zyx.shape[0]
    point_size = np.array(n_points * [2])
    visibility = np.array(n_points * [1])
    # Smallest size '-' can shrink a shown point to. Size 0 is reserved as the
    # "hidden" marker checked by the 'f', 'b' and 'g' handlers.
    min_visible_size = 0.5

    # Per-point properties stored on the layer:
    #   good_point - current label, initially prob > 0.5; drives the face colour
    #   preds      - prediction probability from the CSV
    #   old_size   - size last set with '+'/'-' while shown; used to restore
    #   visible    - 1 while the point is shown, 0 while hidden by 'b'/'g'
    point_properties = {
        'good_point': preds > 0.5,
        'preds': preds,
        'old_size': point_size,
        'visible': visibility,
    }

    points_layer = viewer.add_layer(FixedPoints(
        zyx, ndim=3,
        size=point_size,
        properties=point_properties, scale=resolution,
        face_color='good_point',
        face_color_cycle=['yellow', 'magenta'],
        edge_width=1, name='points',
    ))

    def _resize_points(layer, delta):
        """Add ``delta`` to the size of every shown point; hidden points stay at 0."""
        is_visible = layer.properties['visible']
        new_size = (np.maximum(layer.size + delta, min_visible_size)
                    * is_visible[:, np.newaxis])
        layer.size = new_size
        # Only shown points update their remembered size; hidden points keep
        # theirs so they can be restored to it instead of to 0.
        old_size = np.reshape(layer.properties['old_size'], (len(new_size), -1))
        layer.properties['old_size'] = np.where(
            is_visible[:, np.newaxis] > 0, new_size, old_size)

    def _toggle_group(layer, target, other):
        """Hide the ``target`` points if they are shown, otherwise show them.

        Re-shown points take the size of the ``other`` group when that group
        is shown, otherwise their own remembered size from 'old_size'.
        """
        if not np.any(target):
            # Empty group (e.g. the threshold put every point in the other
            # class): nothing to toggle, and taking the first unique size of
            # an empty selection would raise IndexError.
            return
        if layer.size[target].min() == 0:
            layer.properties['visible'][target] = 1
            other_size = layer.size[other].min() if np.any(other) else 0
            if other_size > 0:
                layer.size[target] = other_size
            else:
                layer.size[target] = layer.properties['old_size'][target].min()
        else:
            layer.properties['visible'][target] = 0
            layer.size[target] = 0
        layer.refresh()

    @viewer.bind_key('f')
    def change_color(viewer):
        """Flip the label of every selected point that is currently shown."""
        layer = viewer.layers['points']
        for index in list(layer.selected_data):
            # hidden points have size 0 and must not be flipped
            if layer.size[index][0] > 0:
                current_status = layer.properties['good_point'][index]
                layer.properties['good_point'][index] = not current_status
        layer.selected_data = []
        layer.refresh_colors(update_color_mapping=True)

    @viewer.bind_key('+')
    def make_larger(viewer):
        """Grow every shown point by 0.5."""
        _resize_points(viewer.layers['points'], 0.5)

    @viewer.bind_key('-')
    def make_smaller(viewer):
        """Shrink every shown point by 0.5, never below ``min_visible_size``."""
        _resize_points(viewer.layers['points'], -0.5)

    @viewer.bind_key('b')
    def hide_yellow_points(viewer):
        """Hide/show the points currently labelled bad (good_point False)."""
        layer = viewer.layers['points']
        good = np.asarray(layer.properties['good_point'], dtype=bool)
        _toggle_group(layer, np.logical_not(good), good)

    @viewer.bind_key('g')
    def hide_magenta_points(viewer):
        """Hide/show the points currently labelled good (good_point True)."""
        layer = viewer.layers['points']
        good = np.asarray(layer.properties['good_point'], dtype=bool)
        _toggle_group(layer, good, np.logical_not(good))

    # thresholding:
    @magicgui(auto_call=True)
    def prob_threshold_box(threshold_prc=50):
        """Relabel points as good when prob > threshold_prc / 100.

        Only the labels (and hence the colours) change; this replaces the
        whole label array, so manual flips made with 'f' are overwritten.
        """
        threshold = threshold_prc / 100
        preds = viewer.layers['points'].properties['preds']
        viewer.layers['points'].properties['good_point'] = preds > threshold
        viewer.layers['points'].refresh_colors(update_color_mapping=True)

    @magicgui(
        call_button="Save Segmentation"
    )
    def save_segmentation(save_file=Path('/save/path.ext'), npz_file=Path('')):
        """Write the points currently labelled good to ``save_file``."""
        status = viewer.layers['points'].properties['good_point']
        write_to_csv(npz_file, status, save_file)

    # Dock the save button and the threshold box on the right.
    viewer.window.add_dock_widget(save_segmentation, area='right')
    viewer.window.add_dock_widget(prob_threshold_box, area='right')
        