In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import numpy as np
import zipfile
import cv2
import plotly.graph_objects as go
import plotly.express as px
from ipywidgets import Button, Dropdown, HBox, VBox
from skimage.draw import polygon

In [14]:
class SegmentWidget:
    
    def __init__(self, path_imgs, path_masks):
        
        self._path_imgs = path_imgs

        if not os.path.exists(path_masks):
            os.mkdir(path_masks)

        self._path_masks = path_masks
        
        with zipfile.ZipFile(self._path_imgs, 'r') as img_arch, zipfile.ZipFile(self._path_masks, 'r') as msk_arch:
            fnames = sorted(img_arch.namelist())   
        
        self._ids = fnames
        self._current_id = self._ids[0]
        #This list will be used later to save in memory the coordinates 
        #of the clicks by the user
        self._polygon_coordinates = []
        self._initialize_widget()

    def _load_images(self):
        '''This method will be used to load image and mask when we select another image'''
        with zipfile.ZipFile(self._path_imgs, 'r') as img_arch, zipfile.ZipFile(self._path_masks, 'r') as msk_arch:
            img = cv2.imdecode(
                np.frombuffer(img_arch.read(self._current_id), np.uint8), cv2.IMREAD_COLOR)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            
            m = cv2.imdecode(
                np.frombuffer(msk_arch.read(self._current_id), np.uint8), cv2.IMREAD_GRAYSCALE)
        
        self._current_img = np.array(img)
        h,w, _ = self._current_img.shape
        
        self._current_mask = np.array(m)
        # else:
        #     self._current_mask = np.zeros((h,w))
        #initiate an intermediate mask which will be used to store ongoing work
        self._intermediate_mask = self._current_mask.copy()

    def _gen_mask_from_polygon(self):
        '''This function set to 2 the values inside the polygon defined by the list of points provided'''
        h,w = self._current_mask.shape
        new_mask = np.zeros((h,w), dtype=int)
        rr, cc = polygon([e[0] for e in self._polygon_coordinates], 
                         [e[1] for e in self._polygon_coordinates], shape=new_mask.shape)

        self._intermediate_mask = self._current_mask.copy()
        self._intermediate_mask[rr,cc]=2
    
    def _on_click_figure(self, trace, points, state):
        '''Callback for clicking on the figure. At each click, the coordinates of the click are stored in the polygon coordinates
           and the figure is displayed again
        '''
        #Retrieve coordinates of the clicked point
        i,j = points.point_inds[0]
        #Add the point to the list of points
        self._polygon_coordinates.append((i,j))
        
        if len(self._polygon_coordinates)>2:
            self._gen_mask_from_polygon()
            with self._image_fig.batch_update():
                self._image_fig.data[1].z = self._intermediate_mask      
        
    def _initialize_figures(self):
        '''This function is called to initialize the figure and its callback'''
        self._image_fig = go.FigureWidget()
        self._mask_fig = go.FigureWidget()
        
        self._load_images()
        #We use plotly express to generate the RGB image from the 3D array loaded
        img_trace = px.imshow(self._current_img).data[0]
        #We use plotly HeatMap for the 2D mask array
        mask_trace = go.Heatmap(z=self._current_mask, showscale=False, zmin=0, zmax=1)
        
        #Add the traces to the figures
        self._image_fig.add_trace(img_trace)
        self._image_fig.add_trace(mask_trace)
        self._mask_fig.add_trace(mask_trace)
        
        #A bit of chart formating
        self._image_fig.data[1].opacity = 0.3 #make the mask transparent on image 1
        self._image_fig.data[1].zmax = 2 #the overlayed mask above the image can have values in range 0..2
        self._image_fig.update_xaxes(visible=False)
        self._image_fig.update_yaxes(visible=False)
        self._image_fig.update_layout(margin={"l": 10, "r": 10, "b": 10, "t": 50}, 
                                      title = "Define your Polygon Here",
                                      title_x = 0.5, title_y = 0.95)
        self._mask_fig.update_layout(yaxis=dict(autorange='reversed'), margin={"l": 10, "r": 10, "b": 10, "t": 50},)
        self._mask_fig.update_xaxes(visible=False)
        self._mask_fig.update_yaxes(visible=False)
    
        self._image_fig.data[-1].on_click(self._on_click_figure)

    def _callback_save_button(self, button):
        '''This callback save the current mask and reset the polygon coordinates to start a new label'''
        self._current_mask[self._intermediate_mask==2]=1
        self._current_mask[self._intermediate_mask==0]=0
        mask_path = os.path.join(self._path_masks,f"{self._current_id}.npy")
        np.save(mask_path,self._current_mask)
        self._intermediate_mask = self._current_mask.copy()
        with self._image_fig.batch_update():
            self._image_fig.data[1].z = self._current_mask     
        with self._mask_fig.batch_update():
            self._mask_fig.data[0].z = self._current_mask
        self._polygon_coordinates = []
            
    def _build_save_button(self):
        self._save_button = Button(description="Save Configuration")
        self._save_button.on_click(self._callback_save_button)
        
    def _callback_delete_current_config_button(self, button):
        '''This callback reset the intermediate_mask to the currently saved mask and refresh the figure'''
        self._intermediate_mask = self._current_mask.copy()
        with self._image_fig.batch_update():
            self._image_fig.data[1].z = self._intermediate_mask
        self._polygon_coordinates = []
        
    def _build_delete_current_config_button(self):
        self._delete_current_config_button = Button(description="Delete Current Mask")
        self._delete_current_config_button.on_click(self._callback_delete_current_config_button)
        
    def _callback_delete_all_button(self, button):
        '''This callback reset the intermediate_mask to 0 and refresh the figure'''
        self._intermediate_mask[:] = 0
        with self._image_fig.batch_update():
            self._image_fig.data[1].z = self._intermediate_mask
        self._polygon_coordinates = []
        
    def _build_delete_all_button(self):
        self._delete_all_button = Button(description="Delete All Mask")
        self._delete_all_button.on_click(self._callback_delete_all_button)
        
    def _callback_dropdown(self, change):
        '''This callback is used to navigate through the different images'''
        #Set the new id to the new dropdown value
        self._current_id = change['new']
        
        #Load the new image and the new mask, we already have a method to do this
        self._load_images()
        
        img_trace = px.imshow(self._current_img).data[0]

        #Update both figure
        with self._image_fig.batch_update():
            #Update the trace 0 and the trace 1 containing respectively
            #the image and the mask
            self._image_fig.data[0].source = img_trace.source
            self._image_fig.data[1].z = self._current_mask
        
        with self._mask_fig.batch_update():
            self._mask_fig.data[0].z = self._current_mask
            
        #Reset the list of coordinates used to store current work in progress
        self._polygon_coordinates = []
        
    def _build_dropdown(self):
        #The ids are passed as option for the dropdown
        self._dropdown = Dropdown(options = self._ids)
        self._dropdown.observe(self._callback_dropdown, names="value")            
            
    def _initialize_widget(self):
        '''Function called during the init phase to initialize all the components
           and build the widget layout
        '''
        
        #Initialize the components
        self._initialize_figures()
        self._build_save_button()
        self._build_delete_current_config_button()
        self._build_delete_all_button()
        self._build_dropdown()
        
        #Build the layout
        buttons_widgets = HBox([self._save_button,self._delete_current_config_button,self._delete_all_button])
        figure_widgets = HBox([self._image_fig, self._mask_fig])
        self.widget = VBox([self._dropdown, buttons_widgets, figure_widgets])
    
    def display(self):
        display(self.widget)

In [15]:
base_path = os.path.join(os.path.dirname(os.getcwd()),"data", 'processed')
SegmentWidget(os.path.join(base_path, 'train2.zip'), os.path.join(base_path, 'masks2.zip')).display()

VBox(children=(Dropdown(options=('0486052bb_0079.png', '0486052bb_0080.png', '0486052bb_0081.png', '0486052bb_…

FileNotFoundError: [Errno 2] No such file or directory: 'c:\\Users\\broug\\Projects\\Labeler\\data\\processed\\masks2.zip\\0486052bb_0080.png.npy'