This notebook creates masks for the organoids so that the surrounding tissue can be excluded. For a more detailed explanation of how it works, see the `tutorial_mask_creation_v2.ipynb` notebook.

In [1]:
%matplotlib widget

import matplotlib.pyplot as plt
import ipywidgets as widgets
import numpy as np
from skimage import io
from skimage.draw import polygon as sk_polygon       # Changed name to avoid confusion
from matplotlib.patches import Polygon as mpl_polygon  # Changed name to avoid confusion
from pathlib import Path
from skimage.exposure import rescale_intensity
from skimage.filters import (gaussian, threshold_multiotsu)
from skimage.morphology import remove_small_objects
from scipy import ndimage as ndi
from skimage import (img_as_uint, img_as_bool)

In [2]:
class PolyPicker:
    """Interactive polygon selector that marks the image regions to keep.

    Controls (on the figure that shows ``orig_img``):
      * left-click  - add a corner to the polygon currently being drawn
      * right-click - close the current polygon (needs at least 3 corners)
      * "u"         - remove the last corner of the polygon being drawn
      * "a"         - apply: pixels inside any closed polygon keep their value,
                      all others are set to 0 in ``self.new_img``; input
                      handling is then disconnected

    Status messages go to ``self.output`` if set; otherwise to a global named
    ``output`` in the defining namespace (the notebook creates one); otherwise
    to stdout.
    """

    def __init__(self, ax, orig_img, output=None):
        """
        Parameters
        ----------
        ax : matplotlib Axes on which the image is shown and clicks are read.
        orig_img : numpy array to mask; the filter mask is created with the
            same shape (the notebook passes a 2-D image).
        output : optional ipywidgets.Output that receives status messages.
        """
        self.ax = ax
        # ax.plot() returns a list with one Line2D; start with an empty line
        self.line = self.ax.plot([], [], color="DarkOrange")[0]
        self.output = output

        # Corners of the polygon currently being drawn
        self.xs = []      # x coordinates
        self.ys = []      # y coordinates
        self.coords = []  # (x, y) tuples
        # Closed polygons, each a list of (x, y) tuples in data coordinates
        self.polygons = []

        self.orig_img = orig_img

        canvas = self.line.figure.canvas
        self.cid_button = canvas.mpl_connect('button_press_event', self.button_press)
        self.cid_key = canvas.mpl_connect('key_press_event', self.key_press)

    def _log(self, *messages):
        """Print status messages into the output widget, or to stdout if there is none."""
        widget = getattr(self, "output", None)
        if widget is None:
            # The notebook creates a global `output` widget next to the picker
            widget = globals().get("output")
        if widget is None:
            for message in messages:
                print(message)
            return
        with widget:
            for message in messages:
                print(message)

    def _redraw_line(self):
        """Redraw the line through the corners of the polygon being drawn."""
        self.line.set_data(self.xs, self.ys)
        self.line.figure.canvas.draw_idle()

    def key_press(self, event):
        """Handle "u" (undo last corner) and "a" (apply polygons and stop)."""
        key_name = event.key

        if key_name == "u":  # Undo
            if not self.coords:
                # Deleting from empty lists would raise IndexError
                self._log("Nothing to undo")
                return
            self._log("Undo")
            self.xs.pop()
            self.ys.pop()
            self.coords.pop()
            self._redraw_line()

        elif key_name == "a":  # Apply
            self._apply()

    def _apply(self):
        """Build the filter mask, create ``self.new_img`` and stop gathering input."""
        if self.coords:
            # Only closed (right-clicked) polygons are used
            self._log("Warning: discarding unfinished polygon with {} corners".format(len(self.coords)))
        if not self.polygons:
            self._log("Warning: no polygons drawn - the new image will be all zeros")
        self._log("Finished with {} polygons".format(len(self.polygons)),
                  "Creating new image...")

        self.filter_mask = self.create_filter_mask()
        # Pixels outside every polygon are multiplied by 0
        self.new_img = self.orig_img * self.filter_mask

        self._log("New image created successfully")

        canvas = self.line.figure.canvas
        canvas.mpl_disconnect(self.cid_button)
        canvas.mpl_disconnect(self.cid_key)

    def button_press(self, event):
        """Left-click adds a corner; right-click closes the current polygon."""
        # Ignore clicks outside the image axes (xdata/ydata would be None)
        if event.inaxes != self.line.axes:
            return
        # Ignore clicks while the zoom/pan tool is active; they are navigation, not corners
        toolbar = getattr(self.line.figure.canvas, "toolbar", None)
        if toolbar is not None and getattr(toolbar, "mode", ""):
            return

        button_name = event.button.name

        if button_name == "LEFT":
            ix, iy = event.xdata, event.ydata
            self.xs.append(ix)
            self.ys.append(iy)
            self.coords.append((ix, iy))
            self._redraw_line()

        elif button_name == "RIGHT":
            if len(self.coords) < 3:
                # Fewer than 3 corners enclose no area
                self._log("A polygon needs at least 3 corners, got {}".format(len(self.coords)))
                return
            self._log("Completed polygon with {} corners".format(len(self.coords)))
            self.polygons.append(self.coords)
            # Shade the finished polygon where the line was
            self.ax.add_patch(mpl_polygon(self.coords, color="DarkOrange", alpha=0.3))
            # Start a new polygon (new lists; the stored one is not mutated)
            self.xs, self.ys, self.coords = [], [], []
            self._redraw_line()

    def create_filter_mask(self):
        """Return a uint8 array shaped like ``orig_img``: 1 inside any polygon, 0 elsewhere."""
        filter_mask = np.zeros(self.orig_img.shape, dtype=np.uint8)
        for vertices in self.polygons:
            if len(vertices) < 3:
                continue  # degenerate polygon, encloses no pixels
            xx, yy = zip(*vertices)
            # skimage takes (row, col) = (y, x); clip to the image extent
            rr, cc = sk_polygon(yy, xx, shape=filter_mask.shape[:2])
            # Mark pixels inside the polygon as kept
            filter_mask[rr, cc] = 1
        return filter_mask

In [3]:
def rescale_image(img, min_quant=0, max_quant=0.98):
    """Contrast-stretch ``img`` between two of its own intensity quantiles.

    Parameters
    ----------
    img : numpy array (integer or float).
    min_quant, max_quant : quantiles (in [0, 1]) of ``img`` that become the
        lower and upper bound of ``in_range`` for
        ``skimage.exposure.rescale_intensity``.

    Returns
    -------
    The result of ``rescale_intensity`` on a floating-point copy of ``img``.
    """
    # Multiply by 1.0 so integer images become float before rescaling
    as_float = img * 1.0
    low, high = np.quantile(as_float, [min_quant, max_quant])
    return rescale_intensity(as_float, in_range=(low, high))

In [4]:
# All point names: Point0000 ... Point0073
N_POINTS = 74
points = [f"Point{idx:04d}" for idx in range(N_POINTS)]

# Specify "special" points and corresponding missing cycles
points_special = {
    "Point0000": ["cycle15"],
    "Point0001": ["cycle15"],
    "Point0042": ["cycle18"],
    "Point0065": ["cycle0", "cycle1", "cycle2", "cycle5_0"],
    "Point0066": ["cycle0", "cycle1", "cycle2"],
    "Point0067": ["cycle0", "cycle1"],
    "Point0070": ["cycle1"],
    "Point0071": ["cycle1"],
    "Point0072": ["cycle1"],
    "Point0073": ["cycle1", "cycle16", "cycle17", "cycle18", "cycle19",
                  "cycle20", "cycle20_0", "cycle21", "cycle1_2", "cycle1_3"]
}

# Points excluded from analysis
points_excluded = ["Point0047", "Point0052", "Point0053", "Point0058", "Point0059", "Point0062", "Point0063",
                   "Point0064", "Point0068", "Point0069"]

# For now just remove Points 0000 and 0001 because they are missing cycle15.
# Removed by name rather than by position so this stays correct if `points` changes.
points_missing_cycle15 = ["Point0000", "Point0001"]
points = [p for p in points if p not in points_missing_cycle15]

# Index of the current point. Starts at -1 because the selection cell increments it before use
iterator = -1

Rather than loading all stains, it is probably sufficient to load only a Hoechst channel and a membrane stain, which saves a lot of time. Because of the AB order permutation, the membrane stain isn't always in the same cycle, so both candidate cycles are loaded. This takes slightly longer but has little impact.

In [5]:
# Cycles to load and, for each, the indices of the channels to keep
img_dict = dict(
    cycle1=[2],   # Hoechst stain
    cycle15=[1],  # Membrane stain (normal order)
    cycle18=[1],  # Membrane stain (permuted order)
)

---

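### Instructions:
Draw polygons around the regions to **keep**; everything outside them is set to 0.

- left-click: add a polygon corner
- right-click: finalise the current polygon
- u: undo the last corner of the polygon being drawn
- o: rectangle zoom tool
- c: back to previous view
- a: apply the mask and finish drawing (finalise the current polygon with a right-click first; an unfinished polygon is not used)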

In [12]:
# Advance to the next point that is not excluded; re-run this cell to move on.
# Special points stop here with a warning: some of their cycles are missing, so
# cycles/stains may need to be adjusted manually in `img_dict` below.
while True:
    iterator += 1
    if iterator >= len(points):
        raise IndexError("All {} points have been processed".format(len(points)))
    point = points[iterator]
    print(point)
    if point in points_special:
        # Repeated 10 times, presumably so it can't be overlooked
        print(10*"Missing cycles: {}\n".format(points_special[point]))
        print("\nMaybe cycles/stains need to be changed for this point!")
        break
    if point in points_excluded:
        print("Skipping excluded point")
        continue
    break

# Change paths and file name once current point is defined.
# To process a specific point instead, override it here, e.g.:
#point = "Point0043"
experiment_name = "tr_af_bs70_mask"
path_input = Path(r"/links/groups/treutlein/USERS/pascal_noser/plate14_results/alignment/")/experiment_name/point
#path_input = Path(r"/links/groups/treutlein/DATA/imaging/PW/4i/plate14/aligned")/point
path_output = Path(r"/links/groups/treutlein/USERS/pascal_noser/plate14_results/masks/")/experiment_name
out_file = path_output/(point+"_mask.png")

Point0003


In [13]:
# Per-point copy of the cycles/channels to load: edit here if the current
# point is missing one of these cycles
img_dict = dict(
    cycle1=[2],   # Hoechst stain
    cycle15=[1],  # Membrane stain (normal order)
    cycle18=[1],  # Membrane stain (permuted order)
)

In [14]:
# Load the selected channels of every cycle and contrast-stretch each channel separately
images = []
for cycle, relev_chans in img_dict.items():
    print(cycle)
    filename = point+"_"+cycle+".tif"
    img_path = path_input/cycle/filename
    # Keep only the channels listed for this cycle
    img = io.imread(str(img_path))[..., relev_chans]
    img_rescaled = np.zeros(img.shape)
    for channel in range(img.shape[-1]):
        img_rescaled[..., channel] = rescale_image(img[..., channel])
    images.append(img_rescaled)

# Stack the rescaled channels of all cycles along the third axis
rescaled_array = np.dstack(images)
print("Array shape:", np.shape(rescaled_array))

plt.close('all')
# Average all rescaled channels into a single image to draw on
img_avg = np.average(rescaled_array, axis=2)
fig, ax = plt.subplots(figsize=(7,7))
ax.set_title(point)
ax.imshow(img_avg)

# Output widget in which the picker's status messages are printed
output = widgets.Output()

# Interactive picker: draw polygons around the regions to KEEP; when "a" is
# pressed, everything outside them is set to 0 in picker.new_img
picker = PolyPicker(ax, img_avg)

display(output)

cycle1
cycle15
Array shape: (7749, 7746, 2)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Output()

In [15]:
# Mask-creation parameters (sigmas and sizes in pixels)
SIGMA_FIRST_BLUR = 20      # blur applied to the picker output before the first threshold
N_CLASSES_BLURRY = 4       # multi-Otsu classes; the lowest threshold is used (fairly conservative)
SIGMA_SECOND_BLUR = 50     # blur of the thresholded image, to increase the size of the object
THR_DOUBLE_BLURRY = 0.1    # threshold on the double-blurred image
N_CLASSES_SHARP = 5        # multi-Otsu classes for the unblurred image; the lowest threshold is used
MIN_OBJECT_SIZE = 150_000  # smaller connected objects are removed as artefacts

# Image from the picker: everything outside the drawn polygons is 0
if not hasattr(picker, "new_img"):
    raise RuntimeError("Press 'a' in the picker figure to apply the polygons before running this cell")
img_erased = picker.new_img

# Blur, threshold and fill holes
print("First round of blurring.....")
img_blurry = gaussian(img_erased, sigma=SIGMA_FIRST_BLUR)
thr_blurry = threshold_multiotsu(img_blurry, classes=N_CLASSES_BLURRY)[0]
img_thr_blurry = ndi.binary_fill_holes(img_blurry > thr_blurry)

# Blur the thresholded image again to increase the size of the object
print("Second round of blurring.....")
img_double_blurry = gaussian(img_thr_blurry, sigma=SIGMA_SECOND_BLUR)
img_thr_double_blurry = img_double_blurry > THR_DOUBLE_BLURRY

# Threshold the unblurred image to get the edges right
print("Creating final mask...")
threshold = threshold_multiotsu(img_erased, classes=N_CLASSES_SHARP)[0]
img_thr_sharp = img_erased >= threshold
# Final mask: intersection of the sharp and the (enlarged) blurry masks
img_mask = img_thr_sharp & img_thr_double_blurry
# Fill holes again for good measure, then remove artefacts
img_mask = ndi.binary_fill_holes(img_mask)
img_mask = remove_small_objects(img_mask, min_size=MIN_OBJECT_SIZE)
print("Done")

# Compare original image to masked image before saving
fig, ax = plt.subplots(1, 2, figsize=(12,6))
ax[0].imshow(img_avg)
ax[0].set_title("Original image")
ax[1].imshow(img_avg * img_mask)
ax[1].set_title("Masked image")
for axis in ax:
    axis.set_axis_off()
fig.tight_layout()
plt.show()

First round of blurring.....
Second round of blurring.....
Creating final mask...
Done


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Show the final binary mask on its own
fig, ax = plt.subplots(figsize=(4,4))
ax.imshow(img_mask)
ax.set(title="Final mask")
ax.set_axis_off()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
# Save mask if everything seems okay (bool mask converted to uint16 for the PNG)
out_file.parent.mkdir(parents=True, exist_ok=True)  # imsave fails if the directory is missing
io.imsave(str(out_file), img_as_uint(img_mask))

# Load the image back into memory and make sure it's identical to the mask
test_img = img_as_bool(io.imread(str(out_file)))
print("Shape:", test_img.shape)
n_different = np.count_nonzero(test_img != img_mask)
print("Number of non-identical pixels:", n_different)
assert n_different == 0, "Saved mask differs from img_mask"

Shape: (7749, 7746)
Number of non-identical pixels: 0
