In [1]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from skimage import io
from skimage.color import rgb2grey
from skimage.filters import sobel
from skimage import filters

def imread_convert(f):
    return rgb2grey(io.imread(f))
out = widgets.Output()
from skimage.segmentation import flood

In [113]:
from IPython.display import clear_output
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path
from matplotlib.path import Path



class single_image_segmenter:
    def __init__(self, img, classes, overlay_alpha=.5,figsize=(10,10)):
        """
        TODO allow for intializing with a shape instead of an image
        
        parameters
        ----------
        classes : Int or list
            Number of classes or a list of class names
        ensure_rgba : boolean
            whether to force the displayed image to have an alpha channel to enable transparent overlay
        """
        plt.ioff() # see https://github.com/matplotlib/matplotlib/issues/17013
        self.fig = plt.figure(figsize=figsize)
        self.ax = self.fig.gca()
        lineprops = {'color': 'black', 'linewidth': 1, 'alpha': 0.8}
        self.lasso = LassoSelector(self.ax, self.onselect,lineprops=lineprops)
        self.lasso.set_visible(True)
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        
        
        # setup lasso stuff
        
        self.shape = None


        self.new_image(img)

        

        plt.ion()
        
        if isinstance(classes, int):
            classes = np.arange(classes)
        if len(classes)<=10:
            self.colors = 'tab10'
        elif len(classes)<=20:
            self.colors = 'tab20'
        else:
            raise ValueError(f'Currently only up to 20 classes are supported, you tried to use {len(classes)} classes')
        
        self.colors = plt.get_cmap(self.colors)(np.arange(len(classes)))[:,:3]
        
        self.class_dropdown = widgets.Dropdown(
                options=[(str(classes[i]), i) for i in range(len(classes))],
                value=0,
                description='Class:',
                disabled=False,
            )
        self.lasso_button = widgets.Button(
            description='lasso select',
            disabled=False,
            button_style='success', # 'success', 'info', 'warning', 'danger' or ''
            icon='mouse-pointer', # (FontAwesome names without the `fa-` prefix)
        )
        self.flood_button = widgets.Button(
            description='flood fill',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            icon='fill-drip', # (FontAwesome names without the `fa-` prefix)
        )
        
        self.erase_check_box = widgets.Checkbox(
            value=False,
            description='Erase Mode',
            disabled=False,
            indent=False
        )
        
        self.reset_button = widgets.Button(
            description='reset',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            icon='refresh', # (FontAwesome names without the `fa-` prefix)
        )
        self.reset_button.on_click(self.reset)
        def button_click(button):
            if button.description == 'flood fill':
                self.flood_button.button_style='success'
                self.lasso_button.button_style=''
                self.lasso.set_active(False)
            else:
                self.flood_button.button_style=''
                self.lasso_button.button_style='success'
                self.lasso.set_active(True)
        
        self.lasso_button.on_click(button_click)
        self.flood_button.on_click(button_click)
        self.overlay_alpha = overlay_alpha
    def new_image(self, img):
        self.img = img
        if img.shape != self.shape:
            self.shape = img.shape
            pix_x = np.arange(self.shape[0])
            pix_y = np.arange(self.shape[1])
            xv, yv = np.meshgrid(pix_y,pix_x)
            self.pix = np.vstack( (xv.flatten(), yv.flatten()) ).T
            self.displayed = self.ax.imshow(self.img)
            self.class_mask = -np.ones([self.shape[0],self.shape[1]],dtype=np.int)
        else:
            self.displayed.set_data(self.img)
            self.class_mask[:,:] = -1
        

    def reset(self,*args):
        self.displayed.set_data(self.img)
        self.class_mask[:,:] = -1
        self.fig.canvas.draw()

    def onclick(self, event):
        """
        handle clicking to remove already added stuff
        """

        if event.xdata is not None and not self.lasso.active:
            with out:
                # transpose x and y bc imshow transposes
                self.indices = flood(self.class_mask,(np.int(event.ydata), np.int(event.xdata)))
                self.updateArray()
        
    def updateArray(self):
        with out:
            array = self.displayed.get_array().data
            
            if self.erase_check_box.value:
                self.class_mask[self.indices] = -1
                array[self.indices] = self.img[self.indices]
            else:
                self.class_mask[self.indices] = self.class_dropdown.value
                # https://en.wikipedia.org/wiki/Alpha_compositing#Straight_versus_premultiplied           
                c_overlay = self.colors[self.class_dropdown.value]*255*self.overlay_alpha
                array[self.indices] = c_overlay + self.img[self.indices]*(1-self.overlay_alpha)
            self.displayed.set_data(array)
        self.ax.set_title(np.sum(array==1.1))
        
    def onselect(self,verts):
        self.verts = verts
        p = Path(verts)

        self.indices = p.contains_points(self.pix, radius=0).reshape(450,540)

        self.updateArray()
        self.fig.canvas.draw_idle()
        
    def render(self):
        layers = [widgets.HBox([self.lasso_button, self.flood_button])]
        layers.append(widgets.HBox([self.reset_button, self.class_dropdown,self.erase_check_box]))
        layers.append(self.fig.canvas)    
        return widgets.VBox(layers)
    
    def _ipython_display_(self):
        display(self.render())

def zoom_factory(ax,base_scale = 1.1):
    def limits_to_range(lim):
        return lim[1] - lim[0]
    
    fig = ax.get_figure() # get the figure of interest
    toolbar = fig.canvas.toolbar
    toolbar.push_current()
    orig_xlim = ax.get_xlim()
    orig_ylim = ax.get_ylim()
    orig_yrange = limits_to_range(orig_ylim)
    orig_xrange = limits_to_range(orig_xlim)
    orig_center = ((orig_xlim[0]+orig_xlim[1])/2, (orig_ylim[0]+orig_ylim[1])/2)

    def zoom_fun(event):
        # get the current x and y limits
        cur_xlim = ax.get_xlim()
        cur_ylim = ax.get_ylim()
        # set the range
        cur_xrange = (cur_xlim[1] - cur_xlim[0])*.5
        cur_yrange = (cur_ylim[1] - cur_ylim[0])*.5
        xdata = event.xdata # get event x location
        ydata = event.ydata # get event y location
        if event.button == 'up':
            # deal with zoom in
            scale_factor = base_scale
        elif event.button == 'down':
            # deal with zoom out
#             if orig_xlim[0]<cur_xlim[0] 
            scale_factor = 1/base_scale
        else:
            # deal with something that should never happen
            scale_factor = 1
#             print(event.button)
        # set new limits
        new_xlim = [xdata - (xdata-cur_xlim[0]) / scale_factor,
                     xdata + (cur_xlim[1]-xdata) / scale_factor]
        new_ylim = [ydata - (ydata-cur_ylim[0]) / scale_factor,
                         ydata + (cur_ylim[1]-ydata) / scale_factor]
        new_yrange = limits_to_range(new_ylim)
        new_xrange = limits_to_range(new_xlim)

        with out:
            if np.abs(new_yrange)>np.abs(orig_yrange):
                new_ylim = orig_center[1] -new_yrange/2 , orig_center[1] +new_yrange/2
            if np.abs(new_xrange)>np.abs(orig_xrange):
                new_xlim = orig_center[0] -new_xrange/2 , orig_center[0] +new_xrange/2
        ax.set_xlim(new_xlim)
        ax.set_ylim(new_ylim)

        toolbar.push_current()
        ax.figure.canvas.draw_idle() # force re-draw


    # attach the call back
    cid = fig.canvas.mpl_connect('scroll_event',zoom_fun)
    def disconnect_zoom():
        fig.canvas.mpl_disconnect(cid)    

    #return the function
    return disconnect_zoom
out.clear_output()
plt.close('all')
tstimage = io.imread('test-image.jpg')
obj = single_image_segmenter(tstimage, ['class1', 'yeast'])
# zoom_factory(obj.ax)
from sidecar import Sidecar
from ipywidgets import IntSlider
sc = Sidecar(title='Segmentation area')
sl = IntSlider(description='Some slider')
with sc:
    # force the _nav_stack to record the initial position so the home button works as expected
    obj.fig.canvas.toolbar.push_current()
    zoom_factory(obj.ax)
    display(obj)
out

Output()

In [36]:
from matplotlib.backend_bases import MouseEvent

<matplotlib.backend_bases.MouseEvent at 0x7f08681ab640>

In [3]:
small = obj.displayed.get_array()
small = small[:200,:]
obj.displayed.set_data(small)

In [4]:
obj.ax.imshow(small)

<matplotlib.image.AxesImage at 0x7f08ad2612e0>

In [104]:
obj.ax.get_ylim()

(449.5, -0.5)

In [17]:
obj.ax.drag_pan(1,None,20,3)

AttributeError: 'dict' object has no attribute 'x'

In [16]:
class pan:

obj.ax._pan_start = {'x':0,'y':10}

In [50]:
obj.fig.canvas.toolbar

Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous view', 'arrow-l…

In [51]:
plt.figure()

AttributeError: module 'matplotlib.pyplot' has no attribute '_figure'

In [71]:
# based on https://github.com/matplotlib/matplotlib/blob/b5d4a6c6b484afaa97ce999b36c03a28ec750ea3/lib/matplotlib/backend_bases.py#L3119-L3135
list(obj.fig.canvas.toolbar._nav_stack[0].items())

[(<matplotlib.axes._subplots.AxesSubplot at 0x7f0861205bb0>,
  ((-0.5, 539.5, 449.5, -0.5),
   (Bbox([[0.125, 0.10999999999999999], [0.9, 0.88]]),
    Bbox([[0.125, 0.1720833333333333], [0.9, 0.8179166666666666]]))))]

In [74]:
obj.ax.get_position(True)

Bbox([[0.125, 0.10999999999999999], [0.9, 0.88]])

In [76]:
obj.ax.get_position()

Bbox([[0.125, 0.17208333333333342], [0.9, 0.8179166666666666]])

In [77]:
obj.ax.set_position(obj.ax.get_position(True), 'active')

In [78]:
obj.fig.canvas.draw()

In [81]:
orig = obj.ax._get_view()

In [82]:
obj.ax._set_view(orig)