In [1]:
# Uncomment the lines below to install the required dependencies. 
#!pip install opencv_python_headless==4.1.2.30 --exists-action i -q
#!pip install git+https://github.com/AllenInstitute/aics-segmentation@d1c666b26e901cbe6d7c74f57eafefa5af9443a5 --exists-action i -q

[31mERROR: itkwidgets 0.22.0 has requirement itk-meshtopolydata>=0.5.1, but you'll have itk-meshtopolydata 0.3.4 which is incompatible.[0m


In [None]:
"""
The notebook consists of two parts, each in its own cell.

Cell 1 : 

Author: Nick Schaub (nick.schaub@nih.gov)

Description: This cell contains WippPy which consists of classes for interacting with WIPP.

"""

import copy
import json as json_lib
import logging
import re
import time
from pathlib import Path

import requests

# Configure the root handler's message/date format once, then create the
# package-level "wipp" logger and set its level to WARNING.
logging.basicConfig(
    format='%(asctime)s - %(name)-8s - %(levelname)-8s - %(message)s',
    datefmt='%d-%b-%y %H:%M:%S',
)
logger = logging.getLogger(name="wipp")
logger.setLevel('WARNING')

# Initialize the WippData Class
class WippData(object):
    """ Wipp data superclass

    This class should be implemented by all Wipp data type classes. Subclasses set
    ``_entry_point`` (the REST collection path appended to ``api_route``) and
    ``_data_type_name``.

    When an instance is built, every top-level key of the backend json is copied
    onto it as an attribute (e.g. ``name`` and ``id``, used by ``__repr__``) and the
    raw dict is kept in ``json``.
    """
    _entry_point = None
    _data_type_name = None
    api_route = 'http://wipp-ui.ci.aws.labshare.org/api/'
    _headers = {'Content-Type': 'application/json'}
    _logger = logging.getLogger('wipp.Data')

    def __init__(self,name=None,create=False,json=None,api_route=None,**kwargs):
        """Create, wrap, or fetch a WIPP object.

        Args:
            name: name of the object to create (only used with ``create=True``).
            create: when True and ``name`` is given, POST ``{'name': name}``.
            json: backend json to wrap directly, without any request.
            api_route: overrides the class-level ``api_route`` for this instance.
            **kwargs: forwarded to ``create``/``_get``; a ``data`` key triggers a create.
        """
        if api_route is not None:
            self._logger.info('api_route: {}'.format(api_route))
            self.api_route = api_route
        if create and name is not None:
            self._logger.info('create(): {}'.format(name))
            kwargs['data'] = {'name': name}
            self.json = self.create(**kwargs)
        elif 'data' in kwargs:
            self._logger.info('create(): attempting to create instance of {}'.format(self.__class__.__name__))
            self.json = self.create(**kwargs)
        elif json:
            self._logger.debug('creating object using json')
            self.json = json
        else:
            self.json = self._get(**kwargs)

        # _post returns None on a 409 conflict; then no attributes are copied
        if self.json is not None:
            for key,value in self.json.items():
                self._logger.debug('setattr(): {}={}'.format(key,value))
                setattr(self,key,value)

    def __repr__(self):
        return f'{self.name} (id: {self.id})'

    def delete(self):
        """Send a DELETE request for this object (the response is not checked)."""
        self._logger.info('delete(): {}'.format(self.api_route + self._entry_point + '/' + self.id))
        requests.delete(self.api_route + self._entry_point + '/' + self.id)

    def create(self,**kwargs):
        """Create the object in WIPP; see ``_post`` for arguments and return value."""
        self._logger.info('create(): {}'.format(self.api_route + self._entry_point))
        return self._post(**kwargs)

    def _post(self,**kwargs):
        """POST to ``entrypoint`` (default ``_entry_point``).

        ``params`` and ``headers`` are passed through; ``data`` is JSON-encoded.
        Any other keyword is ignored.

        Returns:
            The decoded response json on 200/201, or None on 409 (already exists).

        Raises:
            ValueError: for any other status code.
        """
        self._logger.debug('_post(): endpoint={}'.format(self.api_route + self._entry_point))

        config = {key:value for key,value in kwargs.items() if key in ['params','headers']}
        if 'data' in kwargs:
            config['data'] = json_lib.dumps(kwargs['data'])
        if 'headers' not in config:
            config['headers'] = self._headers

        for key,val in config.items():
            self._logger.debug('_post(): {}={}'.format(key,val))

        entrypoint = kwargs.get('entrypoint',self._entry_point)

        r = requests.post(self.api_route + entrypoint,**config)
        self._logger.debug('_post(): status_code={}'.format(r.status_code))
        if r.status_code in (200,201):
            return r.json()
        elif r.status_code==409:
            self._logger.warning('_post(): {} already exists.'.format(self._data_type_name))
        else:
            self._logger.critical('_post(): message={}'.format(r.text))
            raise ValueError(self.__class__.__name__ + ' Error (status code {}): {}'.format(r.status_code,r.text))

    def _get(self,entrypoint=None,**kwargs):
        """GET ``entrypoint`` (default ``_entry_point``) and return the decoded json.

        ``params``/``headers`` are passed through; ``data`` is JSON-encoded.

        Raises:
            ValueError: if the status code is not 200.
        """
        if entrypoint is None:
            entrypoint=self._entry_point
        self._logger.debug('_get(): endpoint={}'.format(self.api_route + entrypoint))

        config = {key:value for key,value in kwargs.items() if key in ['params','headers','data']}
        if 'data' in kwargs:
            config['data'] = json_lib.dumps(kwargs['data'])
        if 'headers' not in config:
            config['headers'] = self._headers

        for key,val in config.items():
            self._logger.debug('_get(): {}={}'.format(key,val))

        r = requests.get(self.api_route + entrypoint,**config)
        self._logger.debug('_get(): status_code={}'.format(r.status_code))
        if r.status_code==200:
            return r.json()
        else:
            self._logger.critical('_get(): message={}'.format(r.text))
            raise ValueError(self.__class__.__name__ + ' Error (status code {}): {}'.format(r.status_code,r.text))

    @classmethod
    def setWippUrl(cls,url):
        """Set the API base url used by this class (and subclasses that don't override it)."""
        cls._logger.info('setWippUrl(): {}'.format(url))
        cls.api_route = url

    @classmethod
    def all(cls,entry_point=False):
        """Get all instances of a data type

        Pages of 1000 items are requested until ``page.totalPages`` is reached and
        merged into one dictionary.

        Args:
            cls: Class reference for handling a WIPP data type
            entry_point: API entry point, appended to api path (default ``_entry_point``)

        Returns:
            A dictionary, where the keys are hashes referencing a data
            instance and values are data_class objects. Empty if the first request
            fails; later pages that fail are skipped with a warning.
        """
        if not entry_point:
            entry_point = cls._entry_point
        cls._logger.info('all(): getting all instances...')
        url = cls.api_route + entry_point
        numel = 1000
        data = {}
        page = 0
        total_pages = 1
        while page < total_pages:
            r = requests.get(url,params={'page':page,'size':numel})
            if r.status_code!=200:
                # A failed first page leaves total_pages at 1, so the loop ends with {}
                cls._logger.warning('all(): page {} failed with status code {}'.format(page,r.status_code))
                page += 1
                continue
            body = r.json()
            if page==0:
                total_pages = body.get('page',{}).get('totalPages',1)
            # Empty collections may omit '_embedded' entirely
            for datum in body.get('_embedded',{}).get(cls._entry_point,[]):
                data[datum['id']] = cls(json=datum)
                cls._logger.debug('all(): object={}'.format(data[datum['id']]))
            page += 1
        return data

    @classmethod
    def get_by_id(cls,oid):
        """Get data by id

        Args:
            cls: Class reference for handling a WIPP data type
            oid: Hash reference of data to access

        Returns:
            An object of type cls, or None if the request does not return 200
        """
        cls._logger.debug('get_by_id(): oid={}'.format(oid))
        r = requests.get(cls.api_route + cls._entry_point + '/' + oid)
        if r.status_code==200:
            return cls(json=r.json())
        cls._logger.warning('get_by_id(): returning NoneType')
        return None

    @staticmethod
    def get_name(dtype,value):
        """ Get the name of a data instance

        Args:
            dtype: WIPP data type (matched against each direct subclass's ``_entry_point``)
            value: Unique hash reference

        Returns:
            A string containing the human readable dataset name, or None if no
            subclass matches ``dtype``
        """
        for cls in WippData.__subclasses__():
            if dtype==cls._entry_point:
                cls._logger.debug('get_name(): finding object associated with id={}'.format(value))
                return cls.all()[value].name
    
class WippJob(WippData):
    """ Class to handle WIPP Jobs

    Attributes are copied from the backend json (e.g. ``name`` and ``id``); the raw
    dict is kept in ``json``. Judging by the ``_payload`` template used to create
    jobs, the json presumably also carries ``wippExecutable`` (plugin id), ``type``
    (plugin name), ``parameters``, ``outputParameters``, ``dependencies`` and
    ``wippWorkflow`` — confirm against the WIPP API.

    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all jobs {job hash: WippJob object}
        get_by_id(jid): Returns job with hash equal to jid, or None if not found

    Object Methods (inherited from WippData):
        delete(): Delete the job from WIPP.
        create(): Create the job in WIPP.
    """
    _entry_point = 'jobs'
    _data_type_name = 'Job'
    _logger = logging.getLogger('wipp.Data.WippJob')

    # Job template; WippWorkflow.add_job deep-copies it and fills in the fields
    _payload = {'name': '',            # name of job
                'wippExecutable': '',  # plugin id
                'type': '',            # name of the plugin
                'dependencies': [],    # job ids for dependencies
                'parameters': {},      # dictionary of parameters
                'outputParameters': {},# dictionary of output parameters
                'wippWorkflow': ''}    # wipp workflow id

    def __repr__(self):
        return f'{self.name} (id: {self.id})'
    
class WippWorkflow(WippData):
    """ Class to handle WIPP Workflows

    Attributes are copied from the backend json (e.g. ``name`` and ``id``); the raw
    dict is kept in ``json``.

    Class Methods (inherited from WippData):
        all(): Returns a dictionary of all workflows {workflow hash: WippWorkflow object}
        get_by_id(wid): Returns workflow with hash equal to wid, or None if not found

    Object Methods:
        delete(): Delete the workflow from WIPP.
        create(): Create the workflow in WIPP.
        jobs(): Returns dictionary of all jobs in workflow, {job hash: WippJob object}
        update(): Refresh this object's attributes from WIPP.
        submit(): Submit the workflow for execution.
        add_job(): Create a job for a plugin inside this workflow.
    """
    _entry_point = 'workflows'
    _data_type_name = 'Workflow'
    _logger = logging.getLogger('wipp.Data.WippWorkflow')

    def jobs(self):
        """Return every job of this workflow as ``{job id: WippJob}``.

        Raises:
            KeyError: if the workflow cannot be fetched from WIPP.
            ValueError: if the job search request fails.
        """
        self._logger.info('jobs(): Getting all jobs for workflow={}'.format(self.id))
        # One lookup by id instead of downloading (and paging through) every workflow
        if WippWorkflow.get_by_id(self.id) is None:
            self._logger.critical('jobs(): Could not find workflow jobs')
            raise KeyError('Invalid workflow id.')

        r = self._get(entrypoint='jobs/search/findByWippWorkflow?wippWorkflow='+self.id)

        jobs = {job_json['id']:WippJob(json=job_json)
                for job_json in r.get('_embedded',{}).get('jobs',[])}
        for job in jobs.values():
            self._logger.debug('jobs(): job={}'.format(job))

        return jobs

    def update(self):
        """Refresh this object's attributes (and ``json``) from WIPP.

        Raises:
            KeyError: if the workflow cannot be fetched from WIPP.
        """
        self._logger.info('update(): updating workflow - {}'.format(self.name))
        wf = WippWorkflow.get_by_id(self.id)
        if wf is None:
            raise KeyError('Invalid workflow id.')
        self.json = wf.json
        for key,value in wf.json.items():
            setattr(self,key,value)

    def submit(self):
        """Ask WIPP to run this workflow.

        NOTE(review): ``_post`` only forwards ``params``/``headers``/``data``, so the
        ``parameters`` keyword below never reaches the request; the submit endpoint
        presumably identifies the workflow from the URL alone — confirm before changing.
        """
        self._post(entrypoint='workflows/' + self.id + '/submit',parameters={'wippWorkflow': self.id})

    def add_job(self,plugin_name,job_name,inputs,plugin_version=None):
        """Create a job running ``plugin_name`` inside this workflow.

        Args:
            plugin_name: name of an installed WIPP plugin.
            job_name: name of the new job.
            inputs: dict mapping plugin input names to values. A string value of the
                form ``'{{ <job id>.<name> }}'`` also records ``<job id>`` as a
                dependency of the new job.
            plugin_version: exact plugin version; the newest one is used when None.

        Returns:
            The created WippJob.

        Raises:
            KeyError: if a required plugin input is missing from ``inputs``.
        """
        payload = copy.deepcopy(WippJob._payload)
        plugin = WippPlugin.get_by_name(plugin_name,plugin_version)
        dependency_pattern = r'\{\{ (.*)\.(.*) \}\}'
        self._logger.info('add_job(): job_name={}, plugin_name={}, plugin_version={}'.format(job_name,plugin_name,plugin_version))

        # Add basic info to the payload
        payload['name'] = job_name
        payload['wippExecutable'] = plugin.id
        payload['type'] = plugin.name
        payload['wippWorkflow'] = self.id

        # validate and set inputs
        for inp in plugin.inputs:
            input_name = inp['name']
            if input_name not in inputs:
                if inp['required']:
                    self._logger.critical('add_job(): Missing input {} for plugin {}'.format(input_name,plugin.name))
                    raise KeyError('Missing input {} for plugin {}'.format(input_name,plugin.name))
                continue
            value = inputs[input_name]
            self._logger.debug('add_job(): {}={}'.format(input_name,value))
            payload['parameters'][input_name] = value

            # If input has {{ }}, then it has a dependency
            if isinstance(value,str):
                dependency = re.match(dependency_pattern,value)
                if dependency is not None:
                    self._logger.info('add_job(): adding dependency {}'.format(dependency.group(1)))
                    payload['dependencies'].append(dependency.group(1))

        return WippJob(data=payload)

class WippImageCollection(WippData):
    """ Class to handle WIPP Image Collections

    Attributes are copied from the backend json (e.g. ``name``, ``id`` and ``locked``,
    which ``delete``/``add_image``/``images`` read); the raw dict is kept in ``json``.

    Class Methods:
        all(): Returns a dictionary of all image collections, {image collection hash: WippImageCollection object}
        get_by_id(icid): Returns image collection with hash equal to icid, or None if not found
        get_by_name(ic_name): Returns the first image collection named ic_name ([] if none)

    Object Methods:
        create(): Create the image collection in WIPP.
        delete(): Delete the collection (refused when locked).
        update(): Refresh attributes from WIPP.
        add_image(file_path): Return a WippImage for the file, ready for send().
        lock(): Lock the collection.
        images(): Return a list of dictionaries containing information on every image in the collection
    """
    _entry_point = 'imagesCollections'
    _data_type_name = 'Image Collection'
    # Cache for images(); instances replace it with their own list (never mutated in place)
    _images = []
    _logger = logging.getLogger('wipp.Data.WippImageCollection')

    def delete(self):
        """ Throw an error only if image collection is locked"""
        self._logger.info('delete(): deleting image collection - {}'.format(self.name))
        if self.locked:
            self._logger.critical('delete(): Cannot delete locked image collection.')
            raise PermissionError('Cannot delete locked image collection.')
        super().delete()

    @classmethod
    def get_by_name(cls,ic_name):
        """Return the first image collection named ``ic_name``.

        Returns:
            A WippImageCollection, or ``[]`` when the request fails or nothing
            matches (the ``[]`` sentinel is kept for backward compatibility).
        """
        cls._logger.info('get_by_name(): getting image collection - {}'.format(ic_name))
        r = requests.get(cls.api_route + cls._entry_point + '/search/findByName',params={'name':ic_name})
        cls._logger.debug('get_by_name(): status_code={}'.format(r.status_code))
        if r.status_code==200:
            matches = r.json().get('_embedded',{}).get(cls._entry_point,[])
            if matches:
                return cls(json=matches[0])
            cls._logger.warning('get_by_name(): no image collection named {}'.format(ic_name))
        return []

    def update(self):
        """Refresh this object's attributes (and ``json``) from WIPP, looked up by id.

        Raises:
            KeyError: if the collection cannot be fetched from WIPP.
        """
        self._logger.info('update(): updating image collection - {}'.format(self.name))
        # Look up by id: several collections can share a name (get_by_name takes the first)
        ic = WippImageCollection.get_by_id(self.id)
        if ic is None:
            raise KeyError('Invalid image collection id.')
        self.json = ic.json
        for key,value in ic.json.items():
            setattr(self,key,value)

    def add_image(self,file_path):
        """Return a WippImage for ``file_path`` attached to this collection.

        Raises:
            PermissionError: if the collection is locked.
        """
        self._logger.info('add_image(): file_path={}'.format(file_path))
        if self.locked:
            self._logger.info('add_image(): cannot add image to locked collection')
            raise PermissionError('Cannot add images to locked collection.')
        if not isinstance(file_path,Path):
            file_path = Path(file_path)
        return WippImage(self.id,file_path)

    def lock(self):
        """PATCH ``locked: True`` on the collection and mirror it locally on success."""
        self._logger.info('lock(): locking imaging collection - {}'.format(self.name))
        r = requests.patch(self.api_route + self._entry_point + '/' + self.id,
                           headers={'Content-Type': 'application/json'},
                           data=json_lib.dumps({'locked': True}))
        self._logger.debug('lock(): status_code={}'.format(r.status_code))
        if r.status_code in (200,204):
            # Keep the local flag in sync so images() can start using its cache
            self.locked = True
        else:
            self._logger.warning('lock(): could not lock collection - {}'.format(r.text))

    def images(self):
        """Return the image records (dicts from the backend) of this collection.

        Results are cached, and the cache is only reused once the collection is
        locked (a locked collection cannot receive new images).

        Raises:
            ValueError: if any page request fails.
        """
        self._logger.info('images(): getting all images for image collection - {}'.format(self.name))
        if len(self._images) > 0 and self.locked:
            return self._images
        entrypoint = self._entry_point + '/' + self.id + '/images'
        numel = 1000
        r = self._get(entrypoint=entrypoint,params={'page':0,'size':numel})
        if '_embedded' not in r:
            return []

        images = list(r['_embedded']['images'])
        for page in range(1,r['page']['totalPages']):
            r = self._get(entrypoint=entrypoint,params={'page':page,'size':numel})
            images.extend(r.get('_embedded',{}).get('images',[]))
        self._images = images
        return images
    
class WippCsvCollection(WippData):
    """ WIPP CSV collection wrapper.

    Every key of the backend json becomes an attribute (``name``, ``id``, ...);
    the raw dict stays available as ``json``.

    Class Methods:
        all(): {csv collection hash: WippCsvCollection object} for every csv collection.

    Object Methods:
        delete(): Remove the csv collection from WIPP.
        create(): Create the csv collection in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippCsvCollection')
    _data_type_name = 'CSV Collection'
    _entry_point = 'csvCollections'
    
class WippNotebook(WippData):
    """ WIPP notebook wrapper.

    Every key of the backend json becomes an attribute (``name``, ``id``, ...);
    the raw dict stays available as ``json``.

    Class Methods:
        all(): {notebook hash: WippNotebook object} for every notebook.

    Object Methods:
        delete(): Remove the notebook from WIPP.
        create(): Create the notebook in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippNotebook')
    _data_type_name = 'Notebook'
    _entry_point = 'notebooks'

class WippStitchingVector(WippData):
    """ WIPP stitching vector wrapper.

    Every key of the backend json becomes an attribute (``name``, ``id``, ...);
    the raw dict stays available as ``json``.

    Class methods:
        all(): {stitching vector hash: WippStitchingVector object} for every stitching vector.

    Object Methods:
        delete(): Remove the stitching vector from WIPP.
        create(): Create the stitching vector in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippStitchingVector')
    _data_type_name = 'Stitching Vector'
    _entry_point = 'stitchingVectors'

class WippPyramid(WippData):
    """ WIPP image pyramid wrapper.

    Every key of the backend json becomes an attribute (``name``, ``id``, ...);
    the raw dict stays available as ``json``.

    Class Methods:
        all(): {pyramid hash: WippPyramid object} for every pyramid.

    Object Methods:
        delete(): Remove the pyramid from WIPP.
        create(): Create the pyramid in WIPP.
    """
    _logger = logging.getLogger('wipp.Data.WippPyramid')
    _data_type_name = 'Image Pyramid'
    _entry_point = 'pyramids'

class WippImage(object):
    """ Class to handle WIPP Images

    Unlike most other WIPP classes, the WippImage class acts very differently from the
    other data types. Part of this comes from images being a child of an image collection
    and therefore necessitates attachment to a WippImageCollection id.

    The usual way to get one is WippImageCollection.add_image(), which prepares an
    image for upload to an unlocked collection. (WippImageCollection.images() returns
    plain dicts, not WippImage objects.)

    The file is uploaded in chunks, one POST per chunk, described by ``flow*`` query
    parameters (presumably the flow.js resumable-upload protocol — confirm against the
    WIPP backend). Every chunk except the last is ``_flowChunkSize`` bytes; the last
    one carries the remainder.

    Attributes:
        file_path: Path of the local file to upload.
        params: the flow* query parameters sent with each chunk.

    Object Methods:
        get_name(): Return the local file name.
        set_name(name): Change the name the file is uploaded under.
        send(): Send the image to WIPP.
    """
    _entry_point = 'imagesCollections/{}/images'
    _data_type_name = 'Image'
    _flowChunkSize = 1048576
    _logger = logging.getLogger('wipp.Data.WippImage')

    def __init__(self,ic_id,file_path):
        """Prepare ``file_path`` for upload to image collection ``ic_id``.

        Raises:
            FileNotFoundError: if ``file_path`` is not an existing file.
        """
        self._entry_point = WippData.api_route + self._entry_point.format(ic_id)
        self.file_path = Path(file_path)
        if not self.file_path.is_file():
            self._logger.critical('__init__(): could not find file - {}'.format(str(self.file_path.absolute())))
            raise FileNotFoundError('Could not find file: {}'.format(str(self.file_path.absolute())))

        with open(self.file_path,'rb') as in_file:
            in_file.seek(0,2)
            self._flowTotalSize = in_file.tell()
        # The last chunk absorbs the remainder, so floor division is used; a file
        # smaller than one chunk must still be sent as a single chunk (not zero chunks).
        self._flowTotalChunks = max(self._flowTotalSize//self._flowChunkSize,1)

        self._flowFilename = self.file_path.name

        self.params = {'flowChunkNumber': 1,
                       'flowChunkSize': self._flowChunkSize,
                       'flowCurrentChunkSize': self._flowChunkSize,
                       'flowTotalSize': self._flowTotalSize,
                       'flowIdentifier': str(self._flowTotalSize) + '-' + self.file_path.name.replace('.',''),
                       'flowFilename': self.file_path.name,
                       'flowRelativePath': self.file_path.name,
                       'flowTotalChunks': self._flowTotalChunks}

        for key,val in self.params.items():
            self._logger.debug('__init__(): {}={}'.format(key,val))

    def get_name(self):
        """Return the local file name."""
        self._logger.info('get_name(): name={}'.format(self.file_path.name))
        return self.file_path.name

    def set_name(self,name):
        """Upload under ``name`` (its extension is replaced by the local file's suffixes)."""
        self._logger.info('set_name(): name={}'.format(name))
        suffix = ''.join(self.file_path.suffixes)
        name = name.split('.')[0] + suffix
        self.params['flowIdentifier'] = str(self._flowTotalSize) + '-' + name.replace('.','')
        self._logger.debug('set_name(): flowIdentifier={}'.format(self.params['flowIdentifier']))
        self.params['flowFilename'] = name
        self._logger.debug('set_name(): flowFilename={}'.format(self.params['flowFilename']))
        self.params['flowRelativePath'] = name
        self._logger.debug('set_name(): flowRelativePath={}'.format(self.params['flowRelativePath']))

    def send(self):
        """Upload the file, one chunk per POST; each chunk is retried up to 10 times.

        Raises:
            Exception: the last upload error of a chunk that failed 10 times.
        """
        self._logger.info('send(): file={}'.format(self.file_path))
        with open(self.file_path,'rb') as in_file:
            for chunk in range(1,self._flowTotalChunks):
                self._logger.debug('send(): sending chunk {} of {} for file {}'.format(chunk,self._flowTotalChunks,self.file_path))
                self.params['flowChunkNumber'] = chunk
                # Reset in case a previous send() left the last chunk's size here
                self.params['flowCurrentChunkSize'] = self._flowChunkSize
                self._post_chunk(in_file.read(self._flowChunkSize))
            # The final chunk carries everything left in the file
            remaining = self._flowTotalSize-in_file.tell()
            self.params['flowChunkNumber'] = self._flowTotalChunks
            self.params['flowCurrentChunkSize'] = remaining
            self._logger.debug('send(): sending chunk {} of {} for file {}'.format(self._flowTotalChunks,self._flowTotalChunks,self.file_path))
            self._post_chunk(in_file.read(remaining))

    def _post_chunk(self,data):
        """POST one chunk with the current ``self.params``, retrying up to 10 times."""
        for retry in range(0,10):
            try:
                r = requests.post(self._entry_point,
                                  params=self.params,
                                  headers={'Content-Type': 'image/tiff'},
                                  data=data)
            except Exception:
                if retry==9:
                    print('{}: Reached max tries.'.format(self.params['flowFilename']))
                    raise
                print('{}: There was an upload error, will retry in 3 seconds (try {})'.format(self.params['flowFilename'],retry+1))
                time.sleep(3)
            else:
                self._logger.debug('send(): status_code={}'.format(r.status_code))
                return r

class WippTensorflowModel(WippData):
    """ Class to handle WIPP Tensorflow Models

    Attributes are copied from the backend json (e.g. ``name`` and ``id``); the raw
    dict is kept in ``json``.

    Class Methods:
        all(): Returns a dictionary of all models, {tensorflow model hash: WippTensorflowModel object}

    Object Methods:
        delete(): Delete the tensorflow model from WIPP.
        create(): Create the tensorflow model in WIPP.
    """
    # Was misspelled 'tesorflowModels', which made every request hit a non-existent route
    _entry_point = 'tensorflowModels'
    _data_type_name = 'Tensorflow Models'
    _logger = logging.getLogger('wipp.Data.WippTensorflowModel')
        
class WippPlugin(WippData):
    """ Class to handle WIPP Plugins

    Attributes are copied from the backend json; this class relies on ``name``,
    ``id`` and ``version``, and plugins presumably also expose ``inputs``,
    ``outputs`` and ``ui`` — confirm against the WIPP API. The raw dict is kept in
    ``json``.

    Class Methods:
        all(): Returns a dictionary of all plugins, {plugin hash: WippPlugin object}
        get_by_name(name, version=None): Return the matching (by default newest) plugin.
        install(json): Register a plugin manifest with WIPP and return it.

    Object Methods:
        delete(): Delete the plugin from WIPP.
        create(): Create the plugin in WIPP.
    """
    _entry_point = 'plugins'
    _data_type_name = 'Plugin'
    _logger = logging.getLogger('wipp.Data.WippPlugin')

    @classmethod
    def get_by_name(cls,name,version=None):
        """Return the plugin called ``name``.

        Args:
            name: plugin name.
            version: exact version string; when None the highest major.minor.patch
                version is returned (ties go to the plugin listed last by ``all()``).

        Raises:
            ValueError: if no plugin has that name, or the requested version is missing.
        """
        cls._logger.info('get_by_name(): name={}, version={}'.format(name,version))
        matching_plugins = [p for p in cls.all().values() if p.name==name]

        # If there are no matching plugins, throw an error
        if len(matching_plugins)==0:
            raise ValueError('No plugins match the supplied name: {}'.format(name))

        # If no version provided, get the latest version of the plugin.
        # max() keeps the first maximal element it meets, so iterating in reverse
        # makes ties resolve to the last-listed plugin.
        if version is None:
            return max(reversed(matching_plugins),key=cls._version_key)

        # Return specified version of plugin
        for p in reversed(matching_plugins):
            if p.version==version:
                return p
        # If the specified version could not be found, throw an error
        raise ValueError('Version {} of plugin {} was not found in WIPP. Try installing it.'.format(version,name))

    @staticmethod
    def _version_key(plugin):
        """Sort key ``(major, minor, patch)``; suffixes like ``-dev`` are ignored and
        unparseable versions sort below every valid one."""
        match = re.match(r"([0-9]+)\.([0-9]+)\.([0-9]+)",str(plugin.version))
        if match is None:
            return (-1,-1,-1)
        return tuple(int(part) for part in match.groups())

    def __repr__(self):
        return f'{self.name} (id: {self.id}, version: {self.version})'

    @classmethod
    def all(cls):
        return super().all(entry_point='plugins/')

    @classmethod
    def install(cls,json):
        """POST a plugin manifest to WIPP.

        Args:
            json: the plugin manifest dict (sent as the request body).

        Returns:
            The new WippPlugin. On a 409 conflict the object has ``json`` None and
            no other attributes.
        """
        cls._logger.info('install(): installing plugin...')
        cls._logger.debug('install(): json={}'.format(json))
        return cls(data=json)

In [None]:
import os
import ipywidgets as widgets
from IPython.display import display, clear_output
import cv2
import numpy as np
from aicssegmentation.core.visual import seg_fluo_side_by_side,  single_fluorescent_view, segmentation_quick_view
from aicsimageio import AICSImage
from aicssegmentation.core.vessel import filament_2d_wrapper
from aicssegmentation.core.pre_processing_utils import intensity_normalization, image_smoothing_gaussian_3d
from aicssegmentation.core.utils import get_middle_frame, hole_filling, get_3dseed_from_mid_frame
from skimage.morphology import remove_small_objects, watershed, dilation, ball
import json
import markdown



# Module-level UI state shared by the widget callbacks below.
IMG = list()                # loaded images; indexed as IMG[i][0, channel, z, y, x] by the callbacks
segmented_images = list()   # segment_image() output, one entry per IMG entry
checkbox_dict = dict()      # image index -> True for the currently selected image
z_value_dict = dict()       # image index -> currently selected z slice
toggle_dict = dict()        # selection label -> image index
z_sliders = list()
image_list = list()

def get_selection_id(selection):
    """Pull the id out of a dropdown label such as ``'name (id: abc123)'``.

    Raises AttributeError when the label does not contain ``(id: ...)``.
    """
    found = re.match(r".* \(id: ([0-9A-Za-z]+).*\)",selection)
    return found.group(1)

def fig2data(fig):
    """ Render a matplotlib figure and return its pixels as an RGB image.

    Input:
        fig: matplotlib figure (its canvas must offer ``draw()``,
            ``get_width_height()`` and either ``tostring_rgb()`` or ``buffer_rgba()``)
    Output:
        data: writable uint8 array of shape (height, width, 3)
    """
    fig.canvas.draw()
    width, height = fig.canvas.get_width_height()
    if hasattr(fig.canvas, 'tostring_rgb'):
        # frombuffer replaces the deprecated binary mode of np.fromstring; copy so the
        # result stays writable like the fromstring result was.
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
        return data.reshape((height, width, 3))
    # Newer matplotlib removed tostring_rgb: take the RGBA buffer and drop alpha.
    rgba = np.asarray(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape((height, width, 4))
    return np.array(rgba[..., :3])

def parse_list_string(list_str):
    """Turn a flat list literal such as ``'[1, 2.5, 3]'`` into a list of floats."""
    stripped = list_str.replace(' ', '').replace('[', '').replace(']', '')
    return list(map(float, stripped.split(',')))

def parse_nested_list_string(nested_str):
    """Turn a nested list literal such as ``'[[1, 2], [3.5, 4]]'`` into lists of floats."""
    compact = nested_str.replace(" ", "")
    rows = compact[1:-1].split('],[')
    return [
        [float(token) for token in row.replace('[', '').replace(']', '').split(',')]
        for row in rows
    ]


def segment_image(image, structure_channel=0, config=None):
    """Segment one channel of an image with the aicssegmentation filament/seed workflow.

    Args:
        image: array indexed as image[0, channel, z, y, x] (the leading axis is
            assumed to be a singleton — TODO confirm against the image reader).
        structure_channel: channel to segment (previously always 0).
        config: parameter dict; defaults to the global ``config_data``. Keys read:
            intensity_scaling_param, gaussian_smoothing_sigma, middle_frame_method,
            f2_param, hole_max, hole_min.

    Returns:
        uint8 array with the shape of the selected (z, y, x) channel: 255 on
        segmented voxels, 0 elsewhere.
    """
    cfg = config_data if config is None else config

    struct_img0 = image[0,structure_channel,:,:,:].copy()
    struct_img = intensity_normalization(struct_img0, scaling_param=cfg['intensity_scaling_param'])
    structure_img_smooth = image_smoothing_gaussian_3d(struct_img, sigma=cfg['gaussian_smoothing_sigma'])

    mid_z = get_middle_frame(structure_img_smooth, method=cfg['middle_frame_method'])
    bw_mid_z = filament_2d_wrapper(structure_img_smooth[mid_z,:,:], cfg['f2_param'])

    hole_max = cfg['hole_max']
    hole_min = cfg['hole_min']
    bw_fill_mid_z = hole_filling(bw_mid_z, hole_min, hole_max)
    # The holes filled in the middle slice (xor of filled vs. unfilled) are turned
    # into 3-D watershed seeds (presumably one seed per hole — see aicssegmentation)
    seed = get_3dseed_from_mid_frame(np.logical_xor(bw_fill_mid_z, bw_mid_z), struct_img.shape, mid_z, hole_min)
    bw_filled = watershed(struct_img, seed.astype(int), watershed_line=True)>0
    # Keep only the voxels added by a radius-1 dilation: a shell just outside the filled regions
    seg = np.logical_xor(bw_filled, dilation(bw_filled, selem=ball(1)))

    return seg.astype(np.uint8) * 255
        
    
def display_img():
    """Show the selected z slice of the selected raw image in ``img_display``.

    The image index is the last key of ``checkbox_dict`` whose value is true and its
    slice comes from ``z_value_dict``. Does nothing when no image is selected
    (previously this raised UnboundLocalError).
    """
    selected = [key for key, checked in checkbox_dict.items() if checked]
    if not selected:
        return
    index = selected[-1]
    z_index = z_value_dict[index]
    frame = IMG[index][0,0,z_index,:,:]
    # Stretch the slice's intensity range to 0-255 for the PNG preview
    frame = np.interp(frame,(np.min(frame),np.max(frame)), (0,1))*255
    # ndarray.tostring() is deprecated; tobytes() returns the same bytes
    img_display.value = cv2.imencode('.png', frame)[1].tobytes()
    
def display_segmented_img():
    """Show the selected z slice of the selected segmentation in ``seg_display``.

    The image index is the last key of ``checkbox_dict`` whose value is true and its
    slice comes from ``z_value_dict``. Does nothing when no image is selected or when
    that image has not been segmented yet (previously UnboundLocalError/IndexError,
    e.g. when the selection changed before the Segment button was pressed).
    """
    selected = [key for key, checked in checkbox_dict.items() if checked]
    if not selected:
        return
    index = selected[-1]
    if index >= len(segmented_images):
        return
    z_index = z_value_dict[index]
    frame = segmented_images[index][z_index,:,:]
    # Stretch the slice's intensity range to 0-255 for the PNG preview
    frame = np.interp(frame,(np.min(frame),np.max(frame)), (0,1))*255
    # ndarray.tostring() is deprecated; tobytes() returns the same bytes
    seg_display.value = cv2.imencode('.png', frame)[1].tobytes()
    
    
def update_segmentation(*args):
    """Re-run ``segment_image`` on every loaded image and refresh the mask view.

    Connected to the "Segment" button (``*args`` absorbs the Button passed by
    ``on_click``). The button is disabled and relabelled while work is in
    progress and always restored afterwards, even if segmentation raises; on
    failure the previous ``segmented_images`` list is kept.
    """
    global segmented_images

    # The widget trait is ``disabled``; assigning a misspelt attribute would
    # silently create a plain Python attribute and leave the button active.
    segment_button.disabled = True
    segment_button.description = 'Segmenting...'
    try:
        segmented_images = [segment_image(img) for img in IMG]
    finally:
        segment_button.disabled = False
        segment_button.description = 'Segment'

    display_segmented_img()
    
    
def slider_callback(index):
    """Create the observer for the z-position slider of image panel ``index``.

    The returned callable stores the slider's new value (converted to int) in
    ``z_value_dict[index]`` and then redraws the raw and segmented views.
    """
    def on_z_change(*args):
        change = args[0]
        z_value_dict[index] = int(change['new'])
        display_img()
        display_segmented_img()

    return on_z_change


def image_checkbox_observer():
    """Create the observer for the image-selection radio buttons.

    When the user picks an "Image N" option, the returned callable marks that
    image as the only selected one in ``checkbox_dict`` (looked up through
    ``toggle_dict``) and redraws the raw and segmented views.
    """
    def call_back(*args):
        selection = args[0]['new']
        if selection not in toggle_dict:
            # add_image() adds the radio option before an image is loaded into
            # that panel (toggle_dict is filled only on load); this also covers
            # a None selection. Nothing to show yet.
            return
        index = toggle_dict[selection]
        for key in checkbox_dict:
            checkbox_dict[key] = (key == index)
        display_img()
        display_segmented_img()

    return call_back
   

# Directory of the most recently chosen WIPP image collection (set by the observer below).
image_collection_path = ''


def image_collection_observer(image_collections, images_widget):
    """Create the observer for a panel's "Image Collections" combobox.

    On selection, the returned callable resolves the collection's image
    directory under ``/opt/shared/wipp/collections``, stores it in the global
    ``image_collection_path``, enables ``images_widget`` and fills it with the
    directory listing, then refreshes the collection list from WIPP.
    """
    def on_collection_selected(*args):
        global image_collection_path
        chosen = args[0]['new']
        collection_id = get_selection_id(chosen)
        image_collection_path = os.path.join('/opt/shared/wipp/collections', collection_id, 'images')
        # Unlock the per-image picker and list the collection's files in it.
        images_widget.disabled = False
        images_widget.options = os.listdir(image_collection_path)
        image_collections.value = chosen
        image_collections.options = [str(ic) for ic in WippImageCollection.all().values()]

    return on_collection_selected


def image_display_observer(index, z_slider):
    def call_back(*args):
        global image_collection_path
        global img_display
        global IMG
        global out
        global config_data
        global seg_display
        global z_value_dict
        global toggle_dict
        global segmented_images
        global metadata_textbox
        
        FILE_NAME = os.path.join(image_collection_path, args[0]['new'])
        reader = AICSImage(FILE_NAME) 
        IMG.append(reader.data.astype(np.float32)) 
        if index == 0:
            checkbox_dict[index] = True
        else:
            checkbox_dict[index] = False
        
        toggle_dict['Image {}'.format(index+1)] = index
            
        if IMG[index].shape[2] > 1:
            z_slider.disabled = False
            depth = IMG[index].shape[2] - 1        
            z_slider.max = depth   
        z_value_dict[index] = 0       
        display_img()       
        segmented_images.append(segment_image(IMG[index]))
        display_segmented_img()
        
    return call_back
   

def add_image(*args):
    """Append a new "Select Image N" panel to ``select_image_panel``.

    The panel holds an image-collection combobox, an image combobox (enabled
    once a collection is chosen) and a z-position slider (enabled once a
    multi-slice image is loaded). The slider is registered in ``z_sliders`` and
    an "Image N" option is added to ``image_selection_widget``. ``*args``
    absorbs the Button passed by ``on_click``.
    """
    # Index of the panel being created; keys all per-image state for it.
    index = len(select_image_panel.children)
    label = 'Image {}'.format(index + 1)

    image_collections = widgets.Combobox(placeholder='Click on the box or start typing!',
                                         options=[str(ic) for ic in WippImageCollection.all().values()],
                                         description='Image Collections',
                                         ensure_option=True,
                                         disabled=False,
                                         layout=widgets.Layout(width='95%'))

    # widget to list images in the chosen image collection
    images = widgets.Combobox(placeholder='Select an image collection first',
                              options=[],
                              description='Images',
                              ensure_option=True,
                              disabled=True,
                              layout=widgets.Layout(width='95%'))

    # Starts at 0 so the slider agrees with z_value_dict[index], which
    # image_display_observer initialises to 0 when an image is loaded.
    z_slider = widgets.FloatSlider(value=0,
                                   min=0,
                                   max=5,
                                   step=1,
                                   description="Z Position",
                                   continuous_update=False,
                                   orientation='horizontal',
                                   readout=False,
                                   layout=widgets.Layout(width='68%'),
                                   disabled=True)

    z_sliders.append(z_slider)
    image_selection_widget.options = image_selection_widget.options + (label,)

    # link the widgets with their corresponding call back functions
    image_collections.observe(image_collection_observer(image_collections, images), 'value')
    images.observe(image_display_observer(index, z_slider), 'value')
    z_slider.observe(slider_callback(index), 'value')

    # update the UI accordion
    select_image_panel.children = select_image_panel.children + (widgets.VBox([image_collections, images, z_slider]),)
    select_image_panel.set_title(index, 'Select {}'.format(label))
    #image_list_panel.children = ((widgets.VBox(image_list),))

def reset(*args):
    """Discard all loaded images and rebuild the UI with one empty image panel.

    Clears the image panels and the ``z_sliders`` registry, resets the
    per-image state (``IMG``, ``segmented_images``, ``checkbox_dict``,
    ``z_value_dict``, ``toggle_dict``), blanks both image panes, recreates the
    image-selection radio buttons, restores the initial status text and adds a
    fresh "Select Image 1" panel. ``*args`` absorbs the Button passed by
    ``on_click``.
    """
    global IMG
    global segmented_images
    global checkbox_dict
    global z_value_dict
    global toggle_dict
    global image_selection_widget

    select_image_panel.children = []
    # add_image() appends one slider per panel; drop the removed panels'
    # sliders so z_sliders stays aligned with select_image_panel.children.
    z_sliders.clear()

    IMG = []
    segmented_images = []
    checkbox_dict = {}
    z_value_dict = {}
    toggle_dict = {}

    # tobytes() gives the same bytes as the deprecated (NumPy 2.0: removed) tostring().
    blank_png = cv2.imencode('.png', np.zeros((1024, 1024)))[1].tobytes()
    img_display.value = blank_png
    seg_display.value = blank_png

    # Fresh radio buttons; add_image() below appends the "Image 1" option.
    image_selection_widget = widgets.RadioButtons(options=[],
                                                  layout={'width': 'max-content'},
                                                  description='Images:',
                                                  disabled=False)

    image_list_panel.children = [image_selection_widget]
    # Same text the status box shows when the notebook first starts.
    metadata_textbox.value = 'Waiting to load image..'
    add_image()
    image_selection_widget.observe(image_checkbox_observer(), 'value')
    
    
# Both image panes start out showing a blank 1024x1024 PNG.
# tobytes() gives the same bytes as ndarray.tostring(), which is deprecated and removed in NumPy 2.0.
_blank_png = cv2.imencode('.png', np.zeros((1024, 1024)))[1].tobytes()

# raw image pane
img_display = widgets.Image(value=_blank_png,
                            format='png',
                            width=500,
                            height=500)

# segmentation pane
seg_display = widgets.Image(value=_blank_png,
                            format='png',
                            width=500,
                            height=500)

# one radio option per image panel; add_image() appends the "Image N" entries
image_selection_widget = widgets.RadioButtons(options=[],
                                              layout={'width': 'max-content'},
                                              description='Images:',
                                              disabled=False)

# status / config read-out (param_observer writes the config JSON here)
metadata_textbox = widgets.Textarea(value='Waiting to load image..',
                                    placeholder='Type something',
                                    description='String:',
                                    disabled=False,
                                    layout=widgets.Layout(width='300px', height='500px'))


# ui accordions
select_image_panel = widgets.Accordion(children=[],
                                       description='Job inputs:')
image_list_panel = widgets.Accordion(children=[image_selection_widget],
                                     description='Job inputs:')
image_list_panel.set_title(0, 'Image List')


# start with one image panel, then wire the selection radio and the buttons
add_image()
image_selection_widget.observe(image_checkbox_observer(), 'value')
add_image_button = widgets.Button(description='Add Image')
add_image_button.on_click(add_image)
reset_button = widgets.Button(description='Reset')
reset_button.on_click(reset)
AddImage_reset_buttons = widgets.HBox([add_image_button, reset_button])
###############################################################################################

# Default segmentation parameters. param_observer writes widget edits into
# this dict, and the segmentation step reads its values.
config_data = {
    "workflow_name": "Playground_shell",
    "intensity_scaling_param": [
        4000
    ],
    "gaussian_smoothing_sigma": 1,
    # Lower-case on purpose: the supported methods are 'intensity' and 'z',
    # the same values the middle_frame_method radio buttons produce.
    "middle_frame_method": 'intensity',
    "f2_param": [
        [
            0.5,
            0.01
        ]
    ],
    "hole_max": 40000,
    "hole_min": 400
}



def param_observer(key):
    """Create an observer that copies one tuning widget's value into ``config_data``.

    Parameters
    ----------
    key : str
        Parameter controlled by the widget: ``'intensity'`` (stored as
        ``intensity_scaling_param``), ``'gaussian_smoothing_sigma'``,
        ``'middle_frame_method'``, ``'f2_param'``, ``'hole_max'`` or
        ``'hole_min'``. Any other key leaves ``config_data`` unchanged.

    Returns
    -------
    callable
        Traitlets ``observe`` callback. After each change it shows the whole
        ``config_data`` as indented JSON in ``metadata_textbox``. If the new
        value cannot be parsed, it shows the error instead and keeps the
        previous setting.
    """
    def call_back(*args):
        # Local import: the file's visible top-level import binds the module
        # as ``json_lib``, so a bare ``json`` may not exist at module scope.
        import json

        # widget key -> (config_data key, parser for the raw widget value).
        # Project parsers are wrapped so they are looked up only when used.
        handlers = {
            'intensity': ('intensity_scaling_param', lambda value: parse_list_string(value)),
            'gaussian_smoothing_sigma': ('gaussian_smoothing_sigma', float),
            'middle_frame_method': ('middle_frame_method', lambda value: value),
            'f2_param': ('f2_param', lambda value: parse_nested_list_string(value)),
            'hole_max': ('hole_max', float),
            'hole_min': ('hole_min', float),
        }
        new_value = args[0]['new']
        if key in handlers:
            config_key, parse = handlers[key]
            try:
                config_data[config_key] = parse(new_value)
            except (ValueError, SyntaxError, TypeError) as err:
                # Text widgets can report half-typed input such as "[40";
                # show the problem instead of raising inside the callback.
                metadata_textbox.value = 'Could not parse {!r} for {}: {}'.format(new_value, config_key, err)
                return
        metadata_textbox.value = json.dumps(config_data, indent=4)

    return call_back



# Step 1 (pre-processing) parameters
intensity_textbox = widgets.Textarea(value='[4000]',
                                     placeholder='Type something',
                                     description='Intensity:',
                                     disabled=False,
                                     layout=widgets.Layout(width='200px', height='30px'))

gaussian_smoothing_sigma = widgets.FloatText(value=1,
                                             description='Gaus. Sigma:',
                                             layout=widgets.Layout(width='50%'),
                                             disabled=False,
                                             tooltip='Sigma of the 3D Gaussian smoothing')

# Step 2 (core algorithm) parameters
middle_frame_method = widgets.RadioButtons(options=['intensity', 'z'],
                                           value='intensity',
                                           description='middle_frame_method:',
                                           disabled=False)

f2_param_textbox = widgets.Textarea(value='[[0.5,0.01]]',
                                    placeholder='Type something',
                                    description='f2 param:',
                                    disabled=False,
                                    layout=widgets.Layout(width='300px', height='30px'))

# Step 3 (watershed) parameters
hole_max = widgets.FloatText(value=40000,
                             description='hole_max:',
                             layout=widgets.Layout(width='50%'),
                             disabled=False)

hole_min = widgets.FloatText(value=400,
                             description='hole_min:',
                             layout=widgets.Layout(width='50%'),
                             disabled=False)

# link each widget to the config_data entry it edits
intensity_textbox.observe(param_observer('intensity'), 'value')
gaussian_smoothing_sigma.observe(param_observer('gaussian_smoothing_sigma'), 'value')
f2_param_textbox.observe(param_observer('f2_param'), 'value')
middle_frame_method.observe(param_observer('middle_frame_method'), 'value')
hole_max.observe(param_observer('hole_max'), 'value')
hole_min.observe(param_observer('hole_min'), 'value')

# Action buttons.
# NOTE(review): save_config_button is never given an on_click handler in this
# part of the file, so clicking it does nothing -- TODO wire it up or drop it.
segment_button = widgets.Button(description='Segment')
save_config_button = widgets.Button(description='Save Config')
segment_button.on_click(update_segmentation)
buttons = widgets.HBox([segment_button, save_config_button])

# Parameter panel: a heading per workflow step, its widgets, then the buttons.
step1 = widgets.HTML(markdown.markdown("""<h4>Step 1: Pre-Processing</h4>"""))
step2 = widgets.HTML(markdown.markdown("""<h4>Step 2: Core Algorithm</h4>"""))
step3 = widgets.HTML(markdown.markdown("""<h4>Step 3: Water Shed</h4>"""))
param_box = widgets.VBox([
    step1, intensity_textbox, gaussian_smoothing_sigma,
    step2, middle_frame_method, f2_param_textbox,
    step3, hole_min, hole_max,
    buttons,
])
parameter_panel = widgets.Accordion(children=[param_box], description='Parameters')
parameter_panel.set_title(0, 'Parameters')


########################################################################################################

########################################################################################################

# Documentation widgets (Markdown rendered to HTML), shown in the
# documentation accordion at the bottom of the UI.
title = widgets.HTML(markdown.markdown(""" <br><h1>Documentation</h1>"""))

intro = widgets.HTML(markdown.markdown("""   

This notebook contains the workflows for lamin B1 (interphase-specific), and serves as a starting point for developing a classic segmentation workflow for your data with shell-like shapes.

----------------------------------------

Cell Structure Observations:

* [Lamin B1](https://www.allencell.org/cell-observations/category/lamin)

----------------------------------------

Key steps of the workflows:

* Min-max intensity normalization / Auto-contrast
* 3D Gaussian smoothing 
* 2D filament filter 
* watershed


 """))

step_1 = widgets.HTML(markdown.markdown(""" 

About selected algorithms and tuned parameters

* **Intensity normalization**: Parameter `intensity_scaling_param` has two options: two values, say `[A, B]`, or single value, say `[K]`. For the first case, `A` and `B` are non-negative values indicating that the full intensity range of the stack will first be cut-off into **[mean - A * std, mean + B * std]** and then rescaled to **[0, 1]**. The smaller the values of `A` and `B` are, the higher the contrast will be. For the second case, `K`>0 indicates min-max Normalization with an absolute intensity upper bound `K` (i.e., anything above `K` will be chopped off and reset as the minimum intensity of the stack) and `K`=0 means min-max Normalization without any intensity bound.

    * Parameter for Lamin B1 (interphase specific):  `intensity_scaling_param = [4000]`


* **Smoothing** 

    3D gaussian smoothing with `gaussian_smoothing_sigma = 1`. The larger the value is, the more the image will be smoothed.


 """))



# Documentation for the core algorithm (step 2) and the watershed (step 3).
step_2 = widgets.HTML(markdown.markdown("""

#### Apply 2d filament filter on the middle frame 

* **Part 1: get the middle frame**: We support two methods to get the middle frame: `method='intensity'` and `method='z'`. The `'intensity'` method assumes the number of foreground pixels (estimated by intensity) along the z dimension has a unimodal distribution (such as Gaussian). Then, the middle frame is defined as the frame at the peak of the distribution along z. The `'z'` method simply returns the middle z frame. 

    * Parameter for lamin b1 (interphase-specific):  `method='intensity'`


* **Part 2: apply 2d filament filter on the middle frame**

    * Parameter syntax: `[[scale_1, cutoff_1], [scale_2, cutoff_2], ....]` 
        * `scale_x` is set based on the estimated width of your target curvilinear shape. For example, if visually the width of the objects is usually 3~4 pixels, then you may want to set `scale_x` as `1` or something near `1` (like `1.25`). Multiple scales can be used if you have objects of very different sizes.  
        * `cutoff_x` is a threshold applied on the actual filter response to get the binary result. Smaller `cutoff_x` may yield fatter segmentation, while larger `cutoff_x` could be less permissive and yield fewer objects and slimmer segmentation. 
    * Parameter for lamin b1 (interphase-specific):  `f2_param = [[0.5, 0.01]]`


"""))

step_3 = widgets.HTML(markdown.markdown(""" 
 
Apply watershed to get the shell:

Parameters:
  
* hole_max = 40000
* hole_min = 400


 """))

# Collapsible documentation: one accordion section per workflow step, all
# collapsed initially (selected_index=None).
documentation_sections = [
    (intro, 'Introduction:  Segmentation workflow for lamin b1 (interphase)'),
    (step_1, 'Step 1: Pre-Processing'),
    (step_2, 'Step 2: Core Algorithm'),
    (step_3, 'Step 3: Water Shed '),
]
documentation_panel = widgets.Accordion(children=[section for section, heading in documentation_sections],
                                        description='Documentation', selected_index=None)
for position, (section_widget, heading) in enumerate(documentation_sections):
    documentation_panel.set_title(position, heading)

#########################################################################################################################

# Layout: controls column beside the raw and segmented panes, documentation underneath.
left_column = widgets.VBox([select_image_panel, AddImage_reset_buttons, image_list_panel, parameter_panel])
image_row = widgets.HBox([left_column, img_display, seg_display])
display(widgets.VBox([image_row, title, documentation_panel]))