In [None]:
from ultralytics import YOLO
import cv2

import os
from os.path import basename, join

from scipy import ndimage as ndi 

from skimage import io, color, data, measure, filters, exposure, img_as_ubyte, img_as_float
from skimage.measure import regionprops, label
from skimage.segmentation import clear_border
from skimage.filters import threshold_otsu
from skimage.morphology import binary_dilation, disk, binary_erosion, convex_hull_image, remove_small_objects, erosion, remove_small_holes, dilation


import pandas as pd

import numpy as np

import torch

from PIL import Image

import matplotlib.pyplot as plt

from tqdm import tqdm

import re

import math

In [2]:
# Trained YOLO segmentation weights, resolved relative to the notebook's working directory.
MODEL_PATH = "best.pt"
model = YOLO(model=MODEL_PATH)

In [None]:
# Segment the benchmark image of one series into individual nematodes, record each
# nematode's bounding box in a CSV, then save a padded crop of every nematode from
# every dated time-lapse image into sam_input/<date>/<date>_<id>.png.
benchmark_path = r'Test\1;1;1;23;510;01_benchmark.png'
benchmark = io.imread(benchmark_path)

input_image_folder = r'Test\1;1;1;23;510;01'
DATE_PATTERN = r'\d{4}-\d{2}-\d{2}'
# Keep only regular files whose name carries a YYYY-MM-DD date. The sam_input /
# sam_output folders are created *inside* input_image_folder, so an unfiltered
# listdir() would hand those directories to io.imread when the notebook is re-run.
input_image_list = sorted(
    name for name in os.listdir(input_image_folder)
    if os.path.isfile(os.path.join(input_image_folder, name))
    and re.search(DATE_PATTERN, name)
)

series_name = '1;1;1;23;510;01'

sam_input_folder = r'Test\1;1;1;23;510;01\sam_input'
sam_csv_path = r'Test\sam_locate'
os.makedirs(sam_csv_path, exist_ok=True)

benchmark_output_csv = os.path.join(sam_csv_path, series_name + '.csv')

CROP_PAD = 15                    # pixels of context added on every side of a bbox
MAX_FOREGROUND_PIXELS = 1000000  # a larger Yen foreground is discarded as a failed threshold

if benchmark.ndim == 3:
    # Colour benchmark: rgb2gray only accepts 3 channels, so flatten RGBA first.
    rgb = color.rgba2rgb(benchmark) if benchmark.shape[-1] == 4 else benchmark
    benchmark_img = color.rgb2gray(rgb)
    thresh_img = filters.threshold_yen(benchmark_img)
    binary_img = benchmark_img > thresh_img

    if binary_img.sum() > MAX_FOREGROUND_PIXELS:
        seg_mask = np.zeros(binary_img.shape)
    else:
        seg_mask = binary_img
else:
    # Single-channel benchmark is used directly as the mask (presumably already
    # binary/labelled — TODO confirm).
    seg_mask = benchmark

# Discard objects that touch the image border.
seg_mask = clear_border(seg_mask)

seg_mask_base, num = ndi.label(seg_mask)
base_mask_regions = regionprops(seg_mask_base)

# bbox is (min_row, min_col, max_row, max_col); ids are 1-based in regionprops order.
base_bbox_list = [region.bbox for region in base_mask_regions]
base_nematode_id = list(range(1, len(base_bbox_list) + 1))
region_nematode = pd.DataFrame({'nematode id': base_nematode_id,
                                'nematode bbox': base_bbox_list})
region_nematode.to_csv(benchmark_output_csv, index=None)

for image_name in tqdm(input_image_list):

    image_path = os.path.join(input_image_folder, image_name)
    image = io.imread(image_path)
    get_date = re.search(DATE_PATTERN, image_name).group(0)
    single_nematode_folder = os.path.join(sam_input_folder, get_date)
    os.makedirs(single_nematode_folder, exist_ok=True)

    for base_num, (minr, minc, maxr, maxc) in enumerate(base_bbox_list, start=1):
        # Clamp the padded start at 0: a negative start would wrap around to the
        # far edge and yield an empty/wrong crop. NumPy already clips the end.
        row_start = max(minr - CROP_PAD, 0)
        col_start = max(minc - CROP_PAD, 0)
        nematode = image[row_start:maxr + CROP_PAD, col_start:maxc + CROP_PAD]

        output_single_nematode_path = os.path.join(
            single_nematode_folder, get_date + '_' + str(base_num) + '.png')
        io.imsave(output_single_nematode_path, img_as_ubyte(nematode))

In [None]:
# Run the YOLO segmentation model on every saved nematode crop, merge the predicted
# instance masks per crop, and paste them back into one full-frame 0/255 mask per date.
CROP_PAD = 15                  # must equal the padding used when the crops were saved
ROUNDNESS_RANGE = (0.6, 1.05)  # exclusive bounds on 4*pi*area / perimeter**2
OVERLAP_THRESHOLD = 0.8        # 2-mask case: keep only the intersection above this overlap
# Output canvas size (rows, cols) — assumed to be the camera frame; TODO derive from images.
FULL_FRAME_SHAPE = (3040, 4056)


def _round_regions(mask_data):
    """Keep the connected components of one 2-D mask whose roundness is in ROUNDNESS_RANGE.

    Returns a float array of the same shape: 1 on kept pixels, 0 elsewhere.
    """
    labelled, _ = ndi.label(mask_data)
    kept = np.zeros(mask_data.shape)
    low, high = ROUNDNESS_RANGE
    for region in regionprops(labelled):
        perimeter = region.perimeter
        if perimeter == 0:
            # Degenerate component (e.g. a single pixel): roundness is undefined.
            continue
        roundness = (4 * math.pi * region.area) / (perimeter * perimeter)
        if low < roundness < high:
            coords = region.coords
            kept[coords[:, 0], coords[:, 1]] = 1
    return kept


def _overlap_fraction(overlap, mask):
    """Fraction of ``mask``'s pixels that are in ``overlap``; 0.0 for an empty mask."""
    area = mask.sum()
    return overlap.sum() / area if area else 0.0


def _combine_masks(masks):
    """Merge the instance masks of one YOLO result into a uint8 0/1 mask at model resolution.

    1-3 masks are roundness-filtered first; more than 3 are OR-ed unfiltered.
    """
    per_instance = [masks[k].data[0].cpu().numpy() for k in range(len(masks))]

    if len(per_instance) == 1:
        round1 = _round_regions(per_instance[0])
        # remove_small_objects needs a *bool* array: an integer array is treated as a
        # label image, and with a single label nothing would ever be removed.
        combined = remove_small_objects(round1.astype(bool), min_size=round1.sum() / 5)

    elif len(per_instance) == 2:
        round1, round2 = (_round_regions(m) for m in per_instance)
        overlap = np.logical_and(round1, round2)
        if (_overlap_fraction(overlap, round1) > OVERLAP_THRESHOLD
                or _overlap_fraction(overlap, round2) > OVERLAP_THRESHOLD):
            # The two predictions largely agree: keep only their intersection.
            combined = overlap
        else:
            combined = np.logical_or(round1, round2)
            combined = remove_small_objects(combined, min_size=combined.sum() / 5)
            combined = clear_border(combined)

    elif len(per_instance) == 3:
        combined = np.logical_or.reduce([_round_regions(m) for m in per_instance])
        combined = remove_small_objects(combined, min_size=combined.sum() / 5)
        combined = clear_border(combined)

    else:
        combined = np.zeros((masks.shape[1], masks.shape[2]))
        for instance in per_instance:
            combined = np.logical_or(combined, instance)

    return combined.astype('uint8')


def _paste_max(canvas, patch, row, col):
    """Max-merge ``patch`` into ``canvas`` at (row, col), clipped at the canvas edge.

    Merging (rather than assigning) keeps masks already pasted from overlapping
    padded crops of neighbouring nematodes.
    """
    target = canvas[row:row + patch.shape[0], col:col + patch.shape[1]]
    np.maximum(target, patch[:target.shape[0], :target.shape[1]], out=target)


sam_output_folder = r'Test\1;1;1;23;510;01\sam_output'
os.makedirs(sam_output_folder, exist_ok=True)
sam_input_list = sorted(
    name for name in os.listdir(sam_input_folder)
    if os.path.isdir(os.path.join(sam_input_folder, name))
)

df = pd.read_csv(benchmark_output_csv)
df_bbox = df['nematode bbox']
df_id = df['nematode id'].values.tolist()

for folder_name in tqdm(sam_input_list):

    final_sam_path = os.path.join(sam_output_folder, folder_name + '.png')
    input_sam_crop_path = os.path.join(sam_input_folder, folder_name)
    final_sam_mask = np.zeros(FULL_FRAME_SHAPE, dtype='uint8')

    for nematode_id, bbox in zip(df_id, df_bbox):
        # bbox was written as the string "(minr, minc, maxr, maxc)".
        minr, minc, _, _ = (int(v) for v in bbox[1:-1].split(','))
        # Same clamped origin as the saved crop.
        crop_row = max(minr - CROP_PAD, 0)
        crop_col = max(minc - CROP_PAD, 0)

        get_crop_path = os.path.join(
            input_sam_crop_path, folder_name + '_' + str(nematode_id) + '.png')
        if io.imread(get_crop_path).max() == 0:
            continue  # blank crop: nothing to segment

        result = model(source=get_crop_path)[0]
        if result.masks is None:
            continue

        final_mask = _combine_masks(result.masks)
        # cv2.resize takes (width, height); orig_shape is (rows, cols) of the crop.
        get_final_mask = cv2.resize(final_mask, (result.orig_shape[1], result.orig_shape[0]))
        _paste_max(final_sam_mask, get_final_mask, crop_row, crop_col)

    io.imsave(final_sam_path, final_sam_mask * 255)





