# Creation of masks
**Note:** This tool can be used to draw polygons around regions that should be excluded. I later figured that it makes more sense to instead draw a polygon around the regions you actually want to keep (see v2). In case someone changes their mind about this, I will keep this notebook around.
## Eraser tool
The PolyEraser class allows the user to manually erase sections of an image by drawing polygons around them.
#### Instructions
After loading an image, create a figure (`fig, ax = plt.subplots()`) and plot the image onto the axis (`ax.imshow()`). Then create a `PolyEraser` object, supplying both the axis the image is plotted on and the image itself. The user can then "draw" polygons directly onto the image:

- **Left click** adds a corner; consecutive corners are automatically connected by a line.
- **Right click** anywhere in the image completes the current polygon; the first and last points are connected automatically. For example, to draw a triangle, left click three corners and then right click.
- **u** ("undo") removes the last point of the polygon currently being drawn.
- **a** ("apply") creates a filter mask that is 0 inside all polygons and 1 elsewhere, multiplies the original image by this mask to create the new image, and stops listening for further input.

There is no direct command to remove an already completed polygon, but every completed polygon is stored in the object's `polygons` list, so it can still be removed afterwards if needed. There is no theoretical limit to the number of polygons or the number of corners per polygon. The new image can be accessed as `obj.new_img` (and the mask as `obj.filter_mask`), where `obj` is the `PolyEraser` object. Status messages are printed into an `ipywidgets.Output` widget that must exist under the global name `output`.
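The "apply" step boils down to rasterizing the clicked polygons into a mask and multiplying. The sketch below reproduces that logic outside the interactive class, assuming polygons are stored as lists of `(x, y)` click coordinates like `PolyEraser.polygons`; the function name `polygons_to_filter_mask` and the toy image are hypothetical.

```python
import numpy as np
from skimage.draw import polygon as sk_polygon

def polygons_to_filter_mask(polygons, shape):
    """Hypothetical helper: uint8 mask that is 0 inside every polygon, 1 elsewhere.

    `polygons` is a list of polygons, each a list of (x, y) tuples as collected
    from matplotlib clicks (x = column, y = row).
    """
    mask = np.ones(shape, dtype=np.uint8)
    for corners in polygons:
        xs = [x for x, y in corners]
        ys = [y for x, y in corners]
        # skimage expects (rows, cols) = (y, x); `shape` clips pixels outside the image
        rr, cc = sk_polygon(ys, xs, shape=shape)
        mask[rr, cc] = 0
    return mask

# Hypothetical example: a 10x10 image with one square polygon drawn over it
image = np.full((10, 10), 5.0)
polygons = [[(2, 2), (5, 2), (5, 5), (2, 5)]]
mask = polygons_to_filter_mask(polygons, image.shape)
new_image = image * mask  # pixels inside the polygon become 0
```

Because the mask is rebuilt from the list each time, removing a completed polygon afterwards amounts to deleting it from the list (e.g. `del eraser.polygons[-1]`) and calling `eraser.create_filter_mask()` again.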
#### Requirements
To run the notebook with the interactive `%matplotlib widget` backend, ipympl must be installed following [these instructions](https://github.com/matplotlib/ipympl). If you're using a virtual environment, you may need to install it in the base environment as well.

In [2]:
%matplotlib widget

import matplotlib.pyplot as plt
import ipywidgets as widgets
import numpy as np
from skimage import io
from skimage.draw import polygon as sk_polygon       # Renamed to avoid a clash with matplotlib's Polygon
from matplotlib.patches import Polygon as mpl_polygon  # Renamed to avoid a clash with skimage's polygon
from pathlib import Path
from skimage.exposure import rescale_intensity
from skimage.filters import (gaussian, threshold_multiotsu)
from skimage.morphology import remove_small_objects
from scipy import ndimage as ndi
from skimage import (img_as_uint, img_as_float, img_as_float32, img_as_bool)

In [3]:
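class PolyEraser:
    """Manually erase regions of an image by drawing polygons on a matplotlib axis.

    Status messages are printed into a global ipywidgets Output widget named
    `output`, which must exist before the first click or key press.
    """
    def __init__(self, ax, orig_img):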
        self.ax = ax
        line = self.ax.plot([], [])[0]  # empty line, [0] necessary because ax.plot() returns a list with 1 element
        self.line = line
        
        # Create empty lists that will be filled with user input
        self.xs = []  # filled with x coordinates
        self.ys = []  # filled with y coordinates
        self.coords = []  # filled with coordinate tuples
        self.polygons = []  # filled with lists of coordinate tuples that define all polygons
        
        # Store the original image
        self.orig_img = orig_img
        
        # Connections to key and button presses
        self.cid_button = line.figure.canvas.mpl_connect('button_press_event', self.button_press)
        self.cid_key = line.figure.canvas.mpl_connect('key_press_event', self.key_press)

        
    def key_press(self, event):
        # Save name of the key pressed
        key_name = event.key

        if key_name == "u":  # Undo
            # Nothing to undo if no point has been clicked for the current polygon
            if not self.coords:
                return
            with output:
                print("Undo")
            # Remove last entry
            del self.xs[-1]
            del self.ys[-1]
            del self.coords[-1]
            # Redraw line
            self.line.set_data(self.xs, self.ys)
            self.line.figure.canvas.draw()
        
        if key_name == "a":  # Apply
            with output:
                print("Finished with {} polygons".format(len(self.polygons)))
                print("Creating new image...")
        
            # Create the filter mask (0 inside all polygons, 1 elsewhere)
            self.filter_mask = self.create_filter_mask()

            # Create the new image by applying the filter mask to the original image
            self.new_img = self.orig_img * self.filter_mask
                        
            with output:
                print("New image created successfully")
            
            # Stop gathering input
            self.line.figure.canvas.mpl_disconnect(self.cid_button)
            self.line.figure.canvas.mpl_disconnect(self.cid_key)

            
    def button_press(self, event):
        # Prevent error if user clicks outside of image
        if event.inaxes!=self.line.axes: return
        # Save name of mouse button clicked
        button_name = event.button.name
        
        if button_name == "LEFT":
            ix = event.xdata
            iy = event.ydata
            #with output:  # Print coordinates
            #    print('x = {}, y = {}'.format(int(ix), int(iy)))
            
            # Append coordinates to lists and redraw line
            self.xs.append(ix)
            self.ys.append(iy)
            self.coords.append((ix, iy))
            self.line.set_data(self.xs, self.ys)
            self.line.figure.canvas.draw()
            
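        elif button_name == "RIGHT":
            # A polygon needs at least 3 corners; ignore the click otherwise
            if len(self.coords) < 3:
                with output:
                    print("A polygon needs at least 3 corners")
                return
            with output:
                print("Completed polygon with {} corners".format(len(self.xs)))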
            # Add coordinates of finished polygon to the polygon list
            self.polygons.append(self.coords)
            # Draw polygon where the line was
            self.ax.add_patch(mpl_polygon(self.coords))
            # Reset coords and xs,ys
            self.xs = []
            self.ys = []
            self.coords = []
            # Reset line
            self.line.set_data(self.xs, self.ys)
            self.line.figure.canvas.draw()
        
        
    def create_filter_mask(self):
        # Create new array in correct size
        filter_mask = np.ones(self.orig_img.shape, dtype=np.uint8)
        # Iterate over all objects in the polygons list
        for shape in self.polygons:
            # Get x and y coordinates for current polygon as separate lists
            xx, yy = [[i for i,j in shape], [j for i,j in shape]]
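            # Get coordinates of all pixels within the polygon (skimage expects rows, cols);
            # `shape` clips pixels that would fall outside the image
            rr, cc = sk_polygon(yy, xx, shape=filter_mask.shape)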
            # Set value to 0
            filter_mask[rr, cc] = 0
        return filter_mask

In [4]:
def rescale_image(img, min_quant=0, max_quant=0.98): 
    img = img * 1.0 # turn type to float before rescaling
    min_val = np.quantile(img, min_quant)
    max_val = np.quantile(img, max_quant)
    img = rescale_intensity(img, in_range=(min_val, max_val))
    return img

### Creating a mask
The goal is to create a mask that includes only the organoid. This is achieved in multiple steps:  
First, the cycles and their relevant channels are defined (see `img_dict`). The specified images are loaded, the relevant channels are rescaled (intensities between the 0 and 0.98 quantiles are stretched to [0, 1] by `rescale_image()`) and stacked into a new array, and finally the average across channels is taken. This image is then modified using the PolyEraser tool to manually remove unwanted objects.
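The rescale-then-average step can be sketched on synthetic data. The two random `uint16` channels below are hypothetical stand-ins for the loaded stains, with deliberately different raw intensity ranges; rescaling each channel separately lets both contribute comparably to the average.

```python
import numpy as np
from skimage.exposure import rescale_intensity

def rescale_image(img, min_quant=0, max_quant=0.98):
    # Stretch the [min_quant, max_quant] quantile range to [0, 1]; brighter pixels are clipped to 1
    img = img.astype(float)
    lo, hi = np.quantile(img, min_quant), np.quantile(img, max_quant)
    return rescale_intensity(img, in_range=(lo, hi))

# Hypothetical stand-ins for two loaded channels with very different intensity ranges
rng = np.random.default_rng(0)
chan_a = rng.integers(0, 4000, size=(64, 64), dtype=np.uint16)
chan_b = rng.integers(0, 60000, size=(64, 64), dtype=np.uint16)

# Rescale each channel independently, stack along the last axis, then average
stack = np.dstack([rescale_image(chan_a), rescale_image(chan_b)])
img_avg = stack.mean(axis=2)
print(stack.shape, img_avg.min(), img_avg.max())
```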

In [5]:
# Paths and images
point = "Point0043"
path_input = Path(r"/links/groups/treutlein/DATA/imaging/PW/4i/plate14/aligned")/point

img_dict = {
    "cycle1": [2],  # Hoechst stain
    "cycle15": [1]  # Membrane stain
}


images = []
for cycle in img_dict:
    print(cycle)
    filename = point+"_"+cycle+".tif"
    img_path = path_input/cycle/filename
    img = io.imread(str(img_path))
    # Only keep relevant channels
    relev_chans = img_dict[cycle]
    img = img[...,relev_chans]
    # Rescale intensities
    img_rescaled = np.zeros(img.shape)
    for channel in range(len(relev_chans)):
        channel_rescaled = rescale_image(img[..., channel])
        img_rescaled[..., channel] = channel_rescaled
    # Append to list
    images.append(img_rescaled)

# Stack all images into a single array with their colour channels along the third axis
rescaled_array = np.dstack(images)
print("Array shape:", np.shape(rescaled_array))

cycle1
cycle15
Array shape: (7749, 7746, 2)


In [6]:
plt.close('all')
# take average of rescaled array
img_avg = np.average(rescaled_array, axis=2)
# Manually remove objects
fig, ax = plt.subplots(figsize=(9,9))
ax.imshow(img_avg)

# Create eraser object
eraser = PolyEraser(ax, img_avg)

# Output widget needed to display print statements
output = widgets.Output()
display(output)

In [7]:
img_erased = eraser.new_img
filter_mask = eraser.filter_mask * 1.0  # turn dtype to float by multiplying by 1.0
# Show image
fig,ax = plt.subplots(1,3, figsize=(10,3))
ax[0].imshow(img_avg)
ax[0].set_title("Old image")
ax[1].imshow(filter_mask)
ax[1].set_title("Filter mask")
ax[2].imshow(img_erased)
ax[2].set_title("New image")
for a in ax:
    a.set_axis_off()
plt.show()

Next, the image is blurred and a threshold separating foreground from background is determined using `threshold_multiotsu()` with 4 classes, taking the lowest of the resulting thresholds. This is fairly conservative and keeps the low-intensity areas along the borders of the organoid. After applying the threshold, any holes within the organoid are filled. This process yields a mask that roughly encompasses the organoid and everything inside it.
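A minimal sketch of blur, lowest multi-Otsu threshold and hole filling, assuming a hypothetical synthetic "organoid": a bright ring with a dim core on a dark background. The dim core falls below the threshold, which is exactly the kind of hole `binary_fill_holes` closes.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.filters import gaussian, threshold_multiotsu

# Hypothetical test image: bright ring (30 < r < 60) with a dim core on a dark background
yy, xx = np.mgrid[:200, :200]
r = np.hypot(yy - 100, xx - 100)
img = np.where((r > 30) & (r < 60), 1.0, 0.05)

img_blurry = gaussian(img, sigma=3)
# Lowest of the 3 thresholds that split the histogram into 4 classes: a conservative cut
thr = threshold_multiotsu(img_blurry, classes=4)[0]
mask = img_blurry > thr
# Fill the enclosed dim core so the mask covers the whole object
mask_filled = ndi.binary_fill_holes(mask)
```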

In [8]:
# blur image, find threshold and apply it
print("First round of blurring.....")
img_blurry = gaussian(img_erased, sigma = 20)
thr_blurry = threshold_multiotsu(img_blurry, classes=4)[0]  # Fairly conservative with 4 classes
img_thr_blurry = img_blurry > thr_blurry
# fill holes
img_thr_blurry = ndi.binary_fill_holes(img_thr_blurry)

# plotting the mask
fig, axes = plt.subplots(1,2, figsize=(8,4))
ax = axes.ravel()
ax[0].imshow(img_avg)
ax[1].imshow(img_thr_blurry)
for a in ax:
    a.set_axis_off()
plt.show()

The problem with this mask is that, due to the blurring, the faint edges of the organoid are no longer included after thresholding. To counter this, the mask can be "grown" outwards by performing another round of blurring and selecting a low threshold. There may be more straightforward and efficient ways to do this, but this does the job just fine.
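How far the blur-and-threshold trick grows the mask can be estimated: for a straight edge, blurring a binary mask with a Gaussian of width sigma and keeping values above t extends it by roughly sigma times the inverse normal CDF of (1 - t), i.e. about 1.28 sigma for t = 0.1 (around 64 px for the sigma = 50 used in the next cell); curved or concave edges deviate from this. The sketch below, on a hypothetical disk mask, compares this with one possible alternative that is not used in the notebook: growing by an exact pixel distance with a Euclidean distance transform.

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.filters import gaussian

# Hypothetical binary mask: a filled disk of radius 40 in a 200x200 image
yy, xx = np.mgrid[:200, :200]
mask = np.hypot(yy - 100, xx - 100) < 40

# Notebook approach: blur the binary mask, then keep everything above a low threshold
grown_blur = gaussian(mask.astype(float), sigma=10) > 0.1

# Alternative (assumption, not the notebook's method): keep every pixel within
# grow_px of the original mask, using the distance to the nearest mask pixel
grow_px = 10
grown_edt = ndi.distance_transform_edt(~mask) <= grow_px
print(mask.sum(), grown_blur.sum(), grown_edt.sum())
```

The distance-transform version makes the growth distance explicit, while the blur version rounds off sharp corners as a side effect.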

In [9]:
# blur the thresholded image again to increase the size of the object
img_double_blurry = gaussian(img_thr_blurry, sigma = 50)
thr_double_blurry = 0.1
img_thr_double_blurry = img_double_blurry > thr_double_blurry

# Comparing blurred to double blurred
fig, axes = plt.subplots(1, 3, figsize=(9,3))
ax = axes.ravel()
ax[0].set_title("img_thr_blurry")
ax[0].imshow(img_thr_blurry)
ax[0].set_axis_off()
ax[1].set_title("img_double_blurry")
ax[1].imshow(img_double_blurry)
ax[1].set_axis_off()
ax[2].set_title("Difference")
ax[2].imshow(img_thr_double_blurry ^ img_thr_blurry)
ax[2].set_axis_off()

While we can now assume that the entire organoid is included in the mask, there is still a problem at the edges. Due to the blurring, the edge of the mask is extremely smooth, which is not at all what the real organoid looks like. To fix this, we apply a threshold to the unblurred image to capture the shape of the edges and then combine the two masks by taking their intersection. The threshold is determined by `threshold_multiotsu()` with 5 classes, again taking the lowest value. The number of classes here is rather arbitrary; it is what I settled on after experimenting a little.
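The combination step can be sketched with hypothetical masks: a smooth, generous envelope standing in for the double-blurred mask, and a detailed "sharp" mask containing an interior hole and two small specks. Intersection keeps the sharp outline but nothing outside the envelope, hole filling closes the interior, and `remove_small_objects` drops detached specks inside the envelope (the `min_size` here is chosen for the toy image, not the notebook's 100 000).

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.morphology import remove_small_objects

yy, xx = np.mgrid[:100, :100]
# Hypothetical smooth envelope (stand-in for the double-blurred mask)
envelope = np.hypot(yy - 50, xx - 50) < 40

# Hypothetical sharp mask: detailed outline, an interior hole and two small specks
sharp = np.zeros((100, 100), dtype=bool)
sharp[20:80, 25:95] = True     # object that partly extends past the envelope
sharp[45:55, 45:55] = False    # hole inside the object
sharp[2:5, 2:5] = True         # speck outside the envelope
sharp[12:15, 48:51] = True     # detached speck inside the envelope

combined = sharp & envelope                 # intersection of both masks
combined = ndi.binary_fill_holes(combined)  # close the interior hole
final = remove_small_objects(combined, min_size=50)  # drop small detached objects
```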

In [10]:
threshold = threshold_multiotsu(img_erased, classes=5)[0]
img_thr_sharp = (img_erased >= threshold)
fig, ax = plt.subplots()
ax.imshow(img_thr_sharp)

In [30]:
# Create final mask by taking the intersection (logical AND) of the sharp and blurry masks
img_mask = img_thr_sharp & img_thr_double_blurry
# fill holes again for good measure
img_mask = ndi.binary_fill_holes(img_mask)
# remove artefacts
img_mask = remove_small_objects(img_mask, min_size=100_000)
print("Done")

fig, ax = plt.subplots(1, 4, figsize=(12,3))
ax[0].imshow(img_thr_double_blurry)
ax[0].set_title("Blurry mask")
ax[1].imshow(img_thr_sharp)
ax[1].set_title("'Sharp' mask")
ax[2].imshow(img_mask)
ax[2].set_title("Final mask")
ax[3].imshow(img_avg)
ax[3].set_title("Average image")
for a in ax:
    a.set_axis_off()
plt.show()

Now all that remains is to save the mask. Since we're dealing with binary images, it makes sense to save them as png rather than tiff files, because png can compress them heavily without losing information. For this mask (saved as uint16, see below), choosing png reduces the file size from > 100 MB to around 100 KB. Neither file format, however, supports a boolean data type, so the mask needs to be converted to an integer type first. `img_as_uint()` from skimage converts it to uint16 (True becomes 65535); `img_as_ubyte()` would give uint8 instead. When loading the image back into memory, it can be converted back into a boolean data type using `img_as_bool()`.
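A small check of the conversions on a hypothetical 4x4 mask: `img_as_uint()` yields uint16 and `img_as_ubyte()` yields uint8, and `img_as_bool()` recovers the original mask from either.

```python
import numpy as np
from skimage.util import img_as_bool, img_as_ubyte, img_as_uint

# Hypothetical small boolean mask
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

as_uint16 = img_as_uint(mask)   # what the notebook saves: uint16, True -> 65535
as_uint8 = img_as_ubyte(mask)   # smaller alternative: uint8, True -> 255
print(as_uint16.dtype, as_uint16.max())  # uint16 65535
print(as_uint8.dtype, as_uint8.max())    # uint8 255

# Converting back recovers the original boolean mask exactly
assert np.array_equal(img_as_bool(as_uint16), mask)
assert np.array_equal(img_as_bool(as_uint8), mask)
```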

In [21]:
img_mask.dtype

dtype('bool')

In [22]:
path_output = Path(r"/links/groups/treutlein/USERS/pascal_noser/plate14_results/masks/version01")
out_file = path_output/(point+"_mask.png")
io.imsave(str(out_file), img_as_uint(img_mask))

In [28]:
test_img = io.imread(out_file)
# Check if shape or intensities changed
print("Shape:",test_img.shape)
print("Number of non-identical pixels:", np.count_nonzero(img_as_bool(test_img) != img_mask))

Shape: (7749, 7746)
Number of non-identical pixels: 0
