# Imports

In [1]:
import json
import glob
import os
import pandas as pd
import math

import cv2
import numpy as np
from matplotlib import pyplot as plt

# Getting the file paths of json file

In [2]:
def get_files(path):
    all_files = []
    for root, dirs, files in os.walk(path):
        files = glob.glob(os.path.join(root, '*.json'))
        for f in files:
            all_files.append(os.path.abspath(f))
    return all_files

# Storing file path in a list

In [3]:
# Collision Detection Function
def has_collision(data):
    for pivot in range(data['colonies_number']):
        for compare in range(pivot + 1, data['colonies_number']):
            c1 = data['labels'][pivot]
            c2 = data['labels'][compare]

            # Not sure if only using height is ok, upon manual inspection width and height are always equal
            r1 = c1['height']/2
            r2 = c2['height']/2

            # NOTE: x and y (in the JSON file) is the top-left corner of the colony bounding box;
            x1 = c1['x'] + c1['width']/2
            x2 = c2['x'] + c2['width']/2
            y1 = c1['y'] + c1['height']/2
            y2 = c2['y'] + c2['height']/2

            # Detect if ANY pair of two colonies are colliding
            if (r1 + r2 > math.sqrt((x2 - x1)**2 + (y2 - y1)**2)):
                # print('Collision: ' + '(' + str(x1) + ', ' + str(y1) + ') (' + str(x2) + ', ' + str(y2) + ')')
                return True

In [4]:
dir = os.path.dirname("__file__")
json_file = get_files(os.path.join(dir, 'Sample Set'))

In [5]:
cleaned_data = []
for i in json_file:
    with open(i, "r") as f:
        data = json.load(f)
        if (data['background'] == 'vague'):
            # print('Background: ' + data['background'])
            # print('Filename: ' + json_file)
            cleaned_data.append(i)

# Resize the Image

In [6]:
def resizeImage(img):
    CONST_HEIGHT = 1000
    return cv2.resize(img, (int(img.shape[1]/4), int(img.shape[0]/4)))

# Convert the Image to Grayscale
The blue channel is used instead of getting the average intensities of each channel because of the fact that difference in intensities between the colonies and the dish itself is more apparent in this channel. 

In [7]:
def getBlueChannel(img):
    (B, G, R) = cv2.split(img)
    
    if (DEBUGGING):
        plt.subplots(figsize = (10, 10))
        plt.title("Grayscale Image (Blue Channel)")
        plt.imshow(B, cmap = plt.cm.gray)
        plt.show()
    
    return B

# Automatic Petri Dish Bounds Detection

In [8]:
def detectPetriDish(file_name):
    
    f = open("bounds.json")
    
    data = json.load(f)
    
    bounds = data['bounds']
    
    for d in bounds:
        if (d['file_name'] == os.path.splitext(os.path.basename(file_name))[0]):
            return int(d['h']), int(d['k']), int(d['r'])
    
    return -1, -1, -1

# Customized Histogram Equalization within Petri Dish Bounds

This special type of HE builds the cumulative histogram using only the pixels within the bounds of the petri dish found through the Circular Hough Transform.  

In [9]:
def histogramEqualization(img, h, k, r):
    
    img_equalized = img
    
    # Create a histogram using only the pixels within the petri dish
    hist_list = [0] * 256
    for i in range(img_equalized.shape[0]):
        for j in range(img_equalized.shape[1]):
            if ((i - h)**2 + (j - k)**2 < r**2):
                hist_list[img_equalized[i][j]] += 1;

    hist = np.array(hist_list)
    cdf = hist.cumsum()
    cdf_normalized = cdf * float(hist.max()) / cdf.max()
    
    # Show the CDF and histogram of the image
    if (DEBUGGING):
        plt.plot(cdf_normalized, color = 'b')
        plt.hist(img_equalized.flatten(),256,[0,256], color = 'r')
        plt.xlim([0,256])
        plt.legend(('Cumulative Distribution Function','Histogram'), loc = 'upper left')
        plt.show()
    
    cdf_m = np.ma.masked_equal(cdf,0)
    cdf_m = (cdf_m - cdf_m.min())*255/(cdf_m.max()-cdf_m.min())
    cdf = np.ma.filled(cdf_m,0).astype('uint8')

    img_equalized = cdf[img_equalized]
    
    # Show the equalized image'
    if (DEBUGGING):
        plt.subplots(figsize = (10, 10))
        plt.title("Histogram Equalization")
        plt.imshow(img_equalized, cmap = plt.cm.gray)
        plt.show()
    
    return img_equalized

# Non-Local Means Denoising

In [10]:
def denoise(img):
    # TODO: Change h-value?
    img_denoised = cv2.fastNlMeansDenoising(img, None, h = 31)
    
    if (DEBUGGING):
        plt.subplots(figsize = (10, 10))
        plt.title("Non-Local Means Denoising")
        plt.imshow(img_denoised, cmap = plt.cm.gray)
        plt.show()
    
    return img_denoised

# Blob Detection

Default Parameters of the SimpleBlobDetector are here: <br> https://github.com/opencv/opencv/blob/4.x/modules/features2d/src/blobdetector.cpp

OpenCVs blob detector iteratively binarizes the image from 'minThreshold' to 'maxThreshold' in steps of 'thresholdStep' and then finds the contours in that image. A contour is a group of pixels that have the same or similar values, in this case, they are either groups of black or groups of white pixels formed by the binarization process. The blobs detected may then be filtered to fit certain criteria (see parameters).

Notes: 
1. 'minRepeatability' is the number of times a centroid is must be found between each 'thresholdStep' of binarization to be considered a blob. This centroid also considers 'minDistBetweenBlobs' in that any two centroids that are under this minimum distance is considered as one centroid.
2. Inertia Ratio measures the elongatedness of each blob. 0 is a line, 1 is a perfect circle.
3. Convexity is the ratio between the area of the blob and the convex hull that encloses the shape.

In [11]:
def detectBlobs(img, file_name, h, k, r, min_repeatability, min_dist, min_inertia_ratio, min_convexity): 
    img_with_blobs = img

    params = cv2.SimpleBlobDetector_Params()
    
    # TODO: Adjust parameters?
    params.minThreshold = 0
    params.maxThreshold = 255

    params.thresholdStep = 1
    params.minRepeatability = min_repeatability
    params.minDistBetweenBlobs = min_dist

    params.minInertiaRatio = min_inertia_ratio
    params.minConvexity = min_convexity
    
    DETECT = "TRANSLUCENT"
    
    if DETECT == "OPAQUE":
        params.minArea = 45
        params.maxArea = 70
    elif DETECT == "TRANSLUCENT":
        params.minArea = 300
        
    actual_count = 0
    counted = 0
    tp = 0
    fp = 0
    fn = 0

    detectorobj = cv2.SimpleBlobDetector_create(params)
    keypoint_info = detectorobj.detect(img_with_blobs)

#   img_with_blobs = cv2.drawKeypoints(img_with_blobs, keypoint_info, np.array([]), (255, 0, 0), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    
#     if (True):
#         plt.subplots(figsize = (10, 10))
#         plt.title("Blob Detection")
#         plt.imshow(img_with_blobs, cmap=plt.cm.gray)
#         plt.show()
    
    # "Number of Colonies: " + str(len(keypoint_info))
    # cv2.imwrite("Blob Results/" + file_name, img_with_blobs)
    
    f = open(f"Sample Set/{os.path.splitext(os.path.basename(file_name))[0]}.json")
    data = json.load(f)
    
    size_threshold = 60
    
    # Count actual colonies, whether OPAQUE or TRANSLUCENT depending on setting
    for colony in data['labels']:
        if colony['height'] < size_threshold and DETECT == "OPAQUE":
            actual_count +=1 
        elif colony['height'] >= size_threshold and DETECT == "TRANSLUCENT":
            actual_count +=1 
    
    # print(f'ACTUAL COUNT ({DETECT}): {actual_count}')
    print("") if False else None
    
    has_match = []
    for keypoint in keypoint_info:
        
#         # Do not count this circle if it is outside of the petri dish
#         if ((keypoint.pt[0] - h)**2 + (keypoint.pt[1] - k)**2 < r**2):
#             counted += 1
#         else:
#             print("SKIPPED")
#             continue

        counted += 1
        print(f"COUNT #{counted}, ({int(keypoint.pt[0])}, {int(keypoint.pt[1])}, {int(keypoint.size/2)})", end = "") if False else None
        
        for colony in data['labels']:
            if DETECT == "OPAQUE" and colony['height'] >= size_threshold:
                continue
            
            if DETECT == "TRANSLUCENT" and colony['height'] < size_threshold:
                continue
            
            colony_id = colony['id']
            radius = int(colony['height']/2)
            x = colony['x'] + radius
            y = colony['y'] + radius
            
            # Remap coordinates based on resizing factor
            radius /= 4
            x /= 4
            y /= 4
            
            # If the center of the detected colony is within max % of the radius of the actual colony, 
            # and the radius of the colony is within min % error of the actual radius, 
            # and there is no match for that colony yet, that is a TRUE POSITIVE
            
            MAX_RADIUS_DIST = radius*0.5
            MIN_RADIUS_ERROR = 0.95
            
            # RED -> FP
            img_with_blobs = cv2.drawKeypoints(img_with_blobs, [keypoint], np.array([]), (0, 0, 255), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
            
            if ((keypoint.pt[0] - x)**2 + (keypoint.pt[1] - y)**2 < MAX_RADIUS_DIST**2):
                fr1 = "{:.2f}".format(radius)
                fr2 = "{:.2f}".format(keypoint.size/2)
                err = "{:.2f}".format(radius - (keypoint.size/2)/(keypoint.size/2))
                print (f" DIST MATCH, R1({fr1}) R2({fr2}) ERR({err})", end = "") if False else None
                # BLUE -> FP, distance matched but not size
                img_with_blobs = cv2.drawKeypoints(img_with_blobs, [keypoint], np.array([]), (255, 0, 0), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
                # if (abs((radius - (keypoint.size/2))/(keypoint.size/2)) < MIN_RADIUS_ERROR):
                if (True):
                    print (" SIZE MATCH", end = "") if False else None
                    # YELLOW -> FP, distance and size matched but there is already a circle for that colony
                    img_with_blobs = cv2.drawKeypoints(img_with_blobs, [keypoint], np.array([]), (255, 255, 0), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
                    if colony_id not in has_match:
                        has_match.append(colony_id)
                        print(f" ({x}, {y}, {radius}) MATCH", end = "") if False else None
                        img_with_blobs = cv2.drawKeypoints(img_with_blobs, [keypoint], np.array([]), (0, 255, 0), cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
                        break
                   # else:
                         # print(f"{x}({circle[0]}), {y}({circle[1]}), {radius}({circle[2]}) DUPLICATE")
        
        print("") if False else None
        
    if (False):
        plt.subplots(figsize = (50, 50))
        plt.title("Blob Detection")
        plt.imshow(img_with_blobs, cmap=plt.cm.gray)
        plt.show()
    
    tp = len(has_match)
    
    # FALSE POSITIVE = COUNTED - TRUE POSITIVE
    # FALSE NEGATIVE = ACTUAL - COUNTED
    fp = counted - tp
    
    if actual_count < counted:
        fn = 0
    else:
        fn = actual_count - counted
    
    if (True):
        try:
            precision = tp/(tp+fp)
            prec = "{:.2%}".format(precision)
        except:
            prec = "0.00%"
            
        try:  
            recall = tp/(tp+fn)
            rec = "{:.2%}".format(recall)
        except:
            recall = "0.00%"
        
        try:
            fscore = (2*precision*recall)/(precision+recall)
            f1 = "{:.2%}".format(fscore)
        except:
            f1 = "0.00%"
        
        # ID, F, P, R, ACTUAL, COUNTED, TP, FP, FN
        # print(f"{file_name}, {str(f1)}, {str(prec)}, {str(rec)}, {actual_count}, {counted}, {tp}, {fp}, {fn}")
        # cv2.imwrite(f"OPTIMIZED PARAMETER RESULTS/BLOB {DETECT}/" + file_name, img_with_blobs)
    
    return actual_count, counted, tp, fp, fn

# Complete Detection Function

In [12]:
def countColonies(path, file_name, min_repeatability, min_dist, min_inertia_ratio, min_convexity):
    img_orig = cv2.imread(path)
    img_resized = resizeImage(img_orig)
    img_gray = getBlueChannel(img_resized)
    h, k, r = detectPetriDish(file_name)
    
    if not(h == -1 and k == -1 and r == -1):
        img_equalized = histogramEqualization(img_gray, h, k, r)
        img_denoised = denoise(img_equalized)

        # TODO: Sharpen image?
        return detectBlobs(img_denoised, file_name, h, k, r, min_repeatability, min_dist, min_inertia_ratio, min_convexity)

    else:
        return -1, -1, -1, -1, -1

# Main Program

In [13]:
DEBUGGING = False

# MIN_REPEATABILITY_LIST = np.arange(2, 4, 1)
# MIN_DIST_LIST = np.arange(2, 4, 1)
# MIN_INERTIA_RATIO_LIST = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
# MIN_CONVEXITY_LIST = [0.75, 0.8, 0.85, 0.9, 0.95]

MIN_REPEATABILITY_LIST = [2, 3, 4]
MIN_DIST_LIST = [2, 3]
MIN_INERTIA_RATIO_LIST = [0.4, 0.5, 0.6, 0.7]
MIN_CONVEXITY_LIST = [0.7, 0.8, 0.9]


# Iterate through all possible combinations of parameters

import itertools
for params in itertools.product(MIN_REPEATABILITY_LIST, MIN_DIST_LIST, MIN_INERTIA_RATIO_LIST, MIN_CONVEXITY_LIST):
                                
    # For each set of parameters, detect colonies in all images
    min_repeatability, min_dist, min_inertia, min_convexity = params
    
    total_actual_count = 0
    total_counted = 0
    total_tp = 0
    total_fp = 0
    total_fn = 0
    
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    
    for file in cleaned_data:
        file_name = os.path.splitext(os.path.basename(file))[0] + ".jpg"
        actual_count, counted, tp, fp, fn = countColonies("Sample Set/" + file_name, file_name, min_repeatability, min_dist, min_inertia, min_convexity)

        try:
            total_precision += tp/(tp+fp)
        except:
            total_precision += 0
        
        try:
            total_recall += tp/(tp+fn)
        except:
            total_recall += 0
            
        try:
            total_f1 += (2*(tp/(tp+fp))*(tp/(tp+fn)))/((tp/(tp+fp))+(tp/(tp+fn)))
        except:
            total_f1 += 0
        
        total_actual_count += actual_count
        total_counted += counted
        total_tp += tp
        total_fp += fp
        total_fn += fn
        
        # break
        
    try:
        # prec_int = total_tp/(total_tp+total_fp)
        prec_int = total_precision/30
        precision = "{:.2%}".format(prec_int)
    except:
        precision = "0.00%"
    
    try:
        # rec_int = total_tp/(total_tp+total_fn)
        rec_int = total_recall/30
        recall = "{:.2%}".format(rec_int)
    except:
        recall = "0.00%"
    
    try:
        # f1 = "{:.2%}".format((2*prec_int*rec_int)/(prec_int+rec_int))
        f1 = "{:.2%}".format(total_f1/30)
    except:
        f1 = "0.00%"

    # print(f"P: {precision} R: {recall} F: {f1}, ACTUAL: {str(total_actual_count)}, COUNT: {str(total_counted)}, TP: {str(total_tp)} FP: {str(total_fp)} FN: {str(total_fn)}, MN_REP: {min_repeatability} MN_DIST: {min_dist} MN_IN: {min_inertia} MN_CNVX: {min_convexity}")
    print(f"{precision}, {recall}, {f1}, {str(total_actual_count)}, {str(total_counted)}, {str(total_tp)}, {str(total_fp)}, {str(total_fn)}, {min_repeatability}, {min_dist}, {min_inertia}, {min_convexity}")

53.94%, 90.87%, 64.00%, 556, 565, 317, 248, 53, 2, 2, 0.4, 0.7
61.69%, 82.18%, 66.71%, 556, 490, 315, 175, 92, 2, 2, 0.4, 0.8
81.63%, 61.53%, 69.18%, 556, 364, 302, 62, 192, 2, 2, 0.4, 0.9
58.47%, 85.09%, 65.30%, 556, 500, 308, 192, 84, 2, 2, 0.5, 0.7
65.99%, 74.74%, 66.51%, 556, 442, 307, 135, 127, 2, 2, 0.5, 0.8
83.85%, 58.45%, 67.81%, 556, 341, 292, 49, 215, 2, 2, 0.5, 0.9
65.25%, 71.03%, 63.81%, 556, 415, 284, 131, 151, 2, 2, 0.6, 0.7
71.92%, 65.11%, 64.05%, 556, 377, 283, 94, 185, 2, 2, 0.6, 0.8
88.30%, 53.00%, 64.88%, 556, 303, 271, 32, 253, 2, 2, 0.6, 0.9
69.57%, 53.55%, 57.79%, 556, 335, 247, 88, 225, 2, 2, 0.7, 0.7
74.67%, 48.99%, 57.69%, 556, 312, 245, 67, 247, 2, 2, 0.7, 0.8
87.48%, 43.93%, 57.32%, 556, 268, 243, 25, 288, 2, 2, 0.7, 0.9
53.94%, 90.87%, 64.00%, 556, 565, 317, 248, 53, 2, 3, 0.4, 0.7
61.69%, 82.18%, 66.71%, 556, 490, 315, 175, 92, 2, 3, 0.4, 0.8
81.63%, 61.53%, 69.18%, 556, 364, 302, 62, 192, 2, 3, 0.4, 0.9
58.47%, 85.09%, 65.30%, 556, 500, 308, 192, 84, 2, 3,