# Dataset generation for a PyTorch cell-culture-plate well detector

First, we import the necessary libraries.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display


In [2]:
# some global variables for fine tuning

# --- cell segmentation (draw_cell_mask) ---
# block size / offset for cv2.adaptiveThreshold; only referenced by the
# commented-out adaptive-threshold alternative in draw_cell_mask
adaptive_threshold_block_size, adaptive_threshold_c = 401, 13
# BGR channel index to segment cells on; None converts the crop to grayscale
cell_color_channel = None

# --- outer Hough circle detection on the whole image (param1, param2) ---
hough_param1, hough_param2 = 80, 450

# --- inner Hough refinement inside each outer circle (param1, param2) ---
inner_hough_param1, inner_hough_param2 = 25, 60

# pixels subtracted from the inner Hough search's maximum radius
circle_buffer = 4

# radius bounds in pixels (min is used by both Hough passes, max by the outer one)
min_circle_radius, max_circle_radius = 150, 5000



In [4]:
# Interactive labelling workflow:
#   1. load every image in `folder` (cell_culture_plate_images)
#   2. detect candidate wells with OpenCV Hough circles
#   3. show each candidate in a widget; the user answers Yes / No, or Skips the image
#   4. when an image is finished, save the cropped image to X, the filled well
#      mask to Y and the cell-colony mask to Z
# NOTE: one global widget is reused. Every answer advances to the next circle; when
# the circles are exhausted it advances to the next image, and it stops after the last image.

# input and output folders
folder = Path("cell_culture_plate_images")  # source photos
folder_X = Path("X")  # cropped input images
folder_Y = Path("Y")  # well (circle) masks
folder_Z = Path("Z")  # cell colony masks

# labelling state for the current image
circles = []  # circles the user confirmed
current_circle = None  # circle currently displayed
detected_circles = []  # candidate circles still to be reviewed

# UI: image preview on top, the three answer buttons (Yes / No / Skip) below
image_widget = widgets.Image(format='jpg', width=600, height=500)
yes_button = widgets.Button(description='Yes')
no_button = widgets.Button(description='No')
skip_button = widgets.Button(description='Skip')
button_widget = widgets.HBox([yes_button, no_button, skip_button])
box_widget = widgets.VBox([image_widget, button_widget])

# per-image state, filled in by load_image / load_images and the button handlers
image = image_path = mask = cell_mask = current_cell_mask = image_paths = None


def set_image(image):
    """Encode *image* as JPEG and show it in the shared ``image_widget``."""
    _ok, jpeg_buffer = cv2.imencode('.jpg', image)
    image_widget.value = jpeg_buffer.tobytes()

def set_next_circle(last_answer_was_yes=False):
    """Display the next plausible candidate circle, or advance to the next image.

    Candidates are taken from the front of ``detected_circles``. A candidate is
    discarded (removed from the list) when it overlaps a circle the user already
    accepted, or when its radius is below 75 % of the mean accepted radius.
    The first surviving candidate becomes ``current_circle`` and is drawn.
    If no candidate survives, the next image is loaded instead.

    Args:
        last_answer_was_yes: forwarded to ``on_next_image_button_clicked`` when
            the candidates for this image are exhausted.
    """
    global current_circle

    # The size filter only applies once something was accepted; np.mean([]) would
    # return nan and emit a RuntimeWarning on every call before the first "Yes".
    min_radius = 0.75 * np.mean([c[2] for c in circles]) if circles else None

    while detected_circles:
        candidate = detected_circles[0]
        # two circles overlap when their centre distance is below the sum of radii
        overlaps = any(
            np.linalg.norm(np.subtract(c[:2], candidate[:2])) < c[2] + candidate[2]
            for c in circles
        )
        too_small = min_radius is not None and candidate[2] < min_radius
        if overlaps or too_small:
            detected_circles.pop(0)
            continue
        current_circle = candidate
        set_image(draw_circle(current_circle))
        return

    # Nothing left for this image. load_image draws the next image's first circle
    # itself, so return without redrawing the stale current circle; that circle may
    # not even belong to the (possibly cropped) image that is loaded now.
    on_next_image_button_clicked(None, last_answer_was_yes=last_answer_was_yes)

def on_yes_button_clicked(b):
    """Accept the displayed circle and merge its cell mask into ``cell_mask``.

    Afterwards the next candidate circle is shown, or the next image when no
    candidates remain.
    """
    global current_circle, cell_mask, current_cell_mask
    accepted = current_circle
    circles.append(accepted)
    detected_circles.remove(accepted)

    # copy this circle's cell pixels into the per-image mask, in place
    np.copyto(cell_mask, current_cell_mask, where=current_cell_mask > 0)

    if detected_circles:
        set_next_circle(last_answer_was_yes=True)
    else:
        on_next_image_button_clicked(b, last_answer_was_yes=True)
        

def on_no_button_clicked(b):
    """Reject the displayed circle and move on to the next candidate (or image)."""
    global current_circle
    detected_circles.remove(current_circle)
    if not detected_circles:
        on_next_image_button_clicked(b, last_answer_was_yes=False)
        return
    # next candidate that does not collide with an accepted circle
    set_next_circle()

def on_skip_button_clicked(b):
    """Drop every circle accepted for the current image and jump to the next image."""
    global circles
    # with no accepted circles, nothing is saved for the skipped image
    circles = []
    on_next_image_button_clicked(b, last_answer_was_yes=False)

def _save_current_sample():
    """Write the labelled sample of ``image_path`` to the X / Y / Z folders.

    The image, the filled well mask and the cell mask are cropped to the bounding
    box of the accepted ``circles`` plus a 100 px border (clipped to the image).
    The cropped image replaces the module-level ``image`` and the cropped well
    mask is stored in ``mask``.
    """
    global image, mask
    border = 100
    x_min = max(0, min(cx - r for cx, _, r in circles) - border)
    x_max = min(image.shape[1], max(cx + r for cx, _, r in circles) + border)
    y_min = max(0, min(cy - r for _, cy, r in circles) - border)
    y_max = min(image.shape[0], max(cy + r for _, cy, r in circles) + border)

    # well mask: every accepted circle filled with 255 on a full-size canvas,
    # then cropped exactly like the image
    full_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    for cx, cy, r in circles:
        cv2.circle(full_mask, (cx, cy), r, (255, 255, 255), -1)

    mask = full_mask[y_min:y_max, x_min:x_max]
    image = image[y_min:y_max, x_min:x_max]
    cell_crop = cell_mask[y_min:y_max, x_min:x_max]

    image_name = image_path.name.lower()
    # masks are stored losslessly as PNG; only the final suffix is replaced
    mask_name = Path(image_name).with_suffix(".png").name
    outputs = (
        (folder_X / image_name, image),
        (folder_Y / mask_name, mask),
        (folder_Z / mask_name, cell_crop),
    )
    for out_path, data in outputs:
        # cv2.imwrite does not create folders and reports failure only by returning False
        out_path.parent.mkdir(parents=True, exist_ok=True)
        if not cv2.imwrite(str(out_path), data):
            print(f"warning: failed to write {out_path}")


def on_next_image_button_clicked(b, last_answer_was_yes=False):
    """Save the labels of the current image (if any), then load the next image.

    Args:
        b: the clicked button; unused (``None`` when called from code).
        last_answer_was_yes: accepted for backward compatibility but no longer
            used. It used to gate saving, so answering "No" on an image's final
            candidate silently discarded every circle accepted for that image.
            To discard an image, use Skip, which clears ``circles`` first.
    """
    global image_path, circles
    if image_path is not None and circles:
        _save_current_sample()

    circles = []
    if not image_paths:
        # every image has been processed (or load_images was never called)
        return
    image_path = image_paths.pop(0)
    print(image_path)
    load_image(image_path)

def draw_circle(circle):
    """Return a preview crop around *circle* for the labelling widget.

    The circle outline is drawn in green (2 px). The crop extends 100 px past the
    circle on every side, clipped to the image. Pixels flagged by
    ``draw_cell_mask`` have their red channel (BGR index 2) set to 255.
    ``draw_cell_mask`` also updates the global ``current_cell_mask``.
    """
    cx, cy, r = (round(v) for v in circle[:3])
    annotated = image.copy()
    cv2.circle(annotated, (cx, cy), r, (0, 255, 0), 2)

    height, width = image.shape[:2]
    top, bottom = max(0, cy - r - 100), min(height, cy + r + 100)
    left, right = max(0, cx - r - 100), min(width, cx + r + 100)
    annotated = annotated[top:bottom, left:right]

    cells = draw_cell_mask(circle)
    # annotated[..., 2] is a view, so this writes straight into the crop
    annotated[..., 2][cells > 0] = 255
    return annotated

def draw_cell_mask(circle):
    """Segment cell colonies inside *circle* with an Otsu threshold.

    The image is cropped to the circle's bounding square plus a 100 px border
    (clipped to the image). The crop is converted to grayscale (or to channel
    ``cell_color_channel`` when set), blurred, and inverse-Otsu thresholded, so
    pixels darker than the threshold become 255. The result is cleaned with a
    3x3 erode + dilate. Everything outside the circle shrunk by ``circle_buffer``
    px is zeroed.

    Side effect: sets the global ``current_cell_mask`` to a full-image-size uint8
    mask that is zero except for this crop region.

    Returns:
        A view of ``current_cell_mask`` covering the crop region.
    """
    global current_cell_mask
    cx, cy, r = round(circle[0]), round(circle[1]), round(circle[2])
    height, width = image.shape[:2]
    x_min, x_max = max(0, cx - r - 100), min(width, cx + r + 100)
    y_min, y_max = max(0, cy - r - 100), min(height, cy + r + 100)

    # A plain slice is enough: every OpenCV call below returns a new array, so the
    # global image is never modified. No full-image copy per circle is needed.
    crop = image[y_min:y_max, x_min:x_max]

    # region to keep: the circle shrunk by circle_buffer px (clamped so the radius
    # never goes negative), which keeps the segmentation off the circle's edge
    inside = np.zeros(crop.shape[:2], dtype=np.uint8)
    cv2.circle(inside, (cx - x_min, cy - y_min), max(0, r - circle_buffer), (255, 255, 255), -1)

    if cell_color_channel is None:
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
    else:
        gray = crop[..., cell_color_channel]

    # sigmas by keyword: the 4th positional argument of cv2.GaussianBlur is `dst`
    blurred = cv2.GaussianBlur(gray, (9, 9), sigmaX=2, sigmaY=2)

    # Alternative kept for tuning (uses the adaptive_threshold_* globals):
    # thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, adaptive_threshold_block_size, adaptive_threshold_c)
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # erode then dilate (a morphological opening) removes isolated specks
    kernel = np.ones((3, 3), np.uint8)
    thresh = cv2.erode(thresh, kernel, iterations=1)
    thresh = cv2.dilate(thresh, kernel, iterations=1)

    # eliminate everything outside the (shrunk) circle
    thresh[inside == 0] = 0

    current_cell_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    current_cell_mask[y_min:y_max, x_min:x_max] = thresh
    return current_cell_mask[y_min:y_max, x_min:x_max]



def _refine_circle(circle, max_inner=7):
    """Re-run Hough detection inside the bounding square of an outer *circle*.

    Returns up to ``max_inner`` circles as ``[x, y, r]`` int lists in full-image
    coordinates. When nothing is found inside, returns the outer circle itself.
    """
    cx, cy, r = circle[0], circle[1], circle[2]
    # Clamp the square's start to the image: a negative start index would wrap
    # around in numpy slicing and give an empty or wrong crop.
    x0, y0 = max(0, int(cx - r)), max(0, int(cy - r))
    x1, y1 = int(cx + r), int(cy + r)  # numpy clips stops past the image edge
    cropped = image[y0:y1, x0:x1]
    fallback = [[int(cx), int(cy), int(r)]]
    if cropped.size == 0:
        return fallback

    gray = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY)
    # sigmas by keyword: the 4th positional argument of cv2.GaussianBlur is `dst`
    gray_blurred = cv2.GaussianBlur(gray, (9, 9), sigmaX=2, sigmaY=2)
    inner = cv2.HoughCircles(
        gray_blurred, cv2.HOUGH_GRADIENT, 1, 1,
        param1=inner_hough_param1, param2=inner_hough_param2,
        minRadius=min_circle_radius,
        maxRadius=min(cropped.shape[:2]) // 2 - circle_buffer)
    if inner is None:
        return fallback
    # add the crop origin back to get full-image coordinates
    return [[int(ix + x0), int(iy + y0), int(ir)] for ix, iy, ir in inner[0][:max_inner]]


def load_image(image_path):
    """Load the image at *image_path*, detect candidate circles and show the first one.

    Circles are found with a Hough pass on the whole image, then each candidate is
    refined by a second Hough pass inside its bounding square (see
    ``_refine_circle``). If the image cannot be read, or no circle is found, the
    next image is loaded instead.
    """
    global image, cell_mask, detected_circles, current_circle
    image = cv2.imread(str(image_path))
    detected_circles = []
    if image is None:
        # cv2.imread returns None instead of raising for missing/unreadable files
        print(f"warning: could not read {image_path}, skipping")
        on_next_image_button_clicked(None)
        return

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray_blurred = cv2.GaussianBlur(gray, (9, 9), sigmaX=2, sigmaY=2)
    outer = cv2.HoughCircles(
        gray_blurred, cv2.HOUGH_GRADIENT, 2, 1,
        param1=hough_param1, param2=hough_param2,
        minRadius=min_circle_radius, maxRadius=max_circle_radius)
    if outer is None:
        on_next_image_button_clicked(None)
        return

    # Cap the work per image. minDist=1 presumably yields many near-duplicate
    # candidates, so at most 202 outer circles are refined.
    max_outer = 202
    for circle in outer[0][:max_outer]:
        detected_circles.extend(_refine_circle(circle))

    if not detected_circles:
        on_next_image_button_clicked(None)
        return

    # show the first circle; cell_mask collects the accepted cells for this image
    current_circle = detected_circles[0]
    cell_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    set_image(draw_circle(current_circle))

import random
def load_images():
    """Queue every ``*.jpg`` in ``folder`` in random order and show the first image."""
    global image_paths
    queue = list(folder.glob("*.jpg"))
    random.shuffle(queue)
    image_paths = queue
    # image_path is still None here, so this only loads the first image
    on_next_image_button_clicked(None)

# connect each answer button to its handler
for button, handler in (
    (yes_button, on_yes_button_clicked),
    (no_button, on_no_button_clicked),
    (skip_button, on_skip_button_clicked),
):
    button.on_click(handler)

# load the first image (and its first circle), then show the UI
load_images()
display(box_widget)


cell_culture_plate_images/PXL_20230521_091255474.MP.jpg


VBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C…

cell_culture_plate_images/PXL_20230316_172350091.NIGHT.jpg
