In [82]:
import os
import cv2
import skimage
import numpy as np
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt
import SimpleITK as sitk

import utils
import orb

In [59]:
def is_chromosome(mask, threshold=200):
    """Return True when ``mask`` covers more than ``threshold`` foreground pixels.

    The pixel count is ``mask.sum() / mask.max()``, which equals the number of
    foreground pixels for a binary mask whose foreground value is its maximum
    (e.g. 0/255 or 0/1 masks). An all-background mask is never a chromosome.

    Args:
        mask: numeric or boolean array.
        threshold: exclusive lower bound on the foreground pixel count
            (previously ignored; 200 was hard-coded, which is still the default).

    Returns:
        A (numpy) boolean.
    """
    peak = mask.max()
    # Guard the division: an empty mask would otherwise evaluate 0/0 (NaN plus a RuntimeWarning).
    if peak <= 0:
        return False
    return mask.sum() / peak > threshold

def get_bbox(mask):
    """Return the bounding box of the non-zero pixels of a 2-D ``mask``.

    The box uses the same convention as ``skimage.measure.regionprops().bbox``:
    ``(min_row, min_col, max_row, max_col)`` with the max bounds EXCLUSIVE, so
    ``mask[min_row:max_row, min_col:max_col]`` contains every foreground pixel.
    This matters because callers use this box interchangeably with ``prop.bbox``.

    Raises:
        ValueError: if ``mask`` has no non-zero pixel.
    """
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        raise ValueError('get_bbox: mask has no foreground pixels')
    # +1 turns the inclusive last index into a half-open bound (regionprops convention).
    return int(rows.min()), int(cols.min()), int(rows.max()) + 1, int(cols.max()) + 1

def orb_match(src, tar, nfeatures=10000):
    """Detect ORB features in ``src`` and ``tar`` and 2-NN match them by Hamming distance.

    Args:
        src: query image passed to ``detectAndCompute``.
        tar: train image passed to ``detectAndCompute``.
        nfeatures: maximum number of ORB keypoints per image (default 10000,
            the previously hard-coded value).

    Returns:
        ``(kpts1, kpts2, matches)``. ``matches`` holds one neighbour list per
        query descriptor (up to two entries each, as returned by ``knnMatch``),
        or is an empty list when either image yields no descriptors.
    """
    detector = cv2.ORB_create(nfeatures)

    kpts1, desc1 = detector.detectAndCompute(src, None)
    kpts2, desc2 = detector.detectAndCompute(tar, None)

    # detectAndCompute returns None descriptors when no keypoint is found;
    # knnMatch would raise on them, so report "no matches" instead.
    if desc1 is None or desc2 is None:
        return kpts1, kpts2, []

    matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
    matches = matcher.knnMatch(desc1, desc2, 2)

    return kpts1, kpts2, matches

def return_ransac_num(src, tar, ratio=0.95, min_matches=4):
    """Count RANSAC inliers of a partial-affine fit between ORB matches of ``src`` and ``tar``.

    Matches are first filtered with Lowe's ratio test (``best < ratio * second``).
    If at least ``min_matches`` survive, ``cv2.estimateAffinePartial2D`` is run
    and the number of inliers it reports is returned.

    Args:
        src: query image.
        tar: train image.
        ratio: ratio-test threshold (default 0.95, the previous hard-coded value).
        min_matches: minimum number of good matches needed to attempt the fit.

    Returns:
        Number of inlier matches (int); 0 when there are too few matches or the fit fails.
    """
    kpts1, kpts2, matches = orb_match(src, tar)

    # knnMatch can return fewer than two neighbours for a query descriptor;
    # skip those instead of failing on tuple unpacking.
    good = [pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < ratio * pair[1].distance]
    if len(good) < min_matches:
        return 0

    src_pts = np.float32([kpts1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kpts2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)

    _, ransac_mask = cv2.estimateAffinePartial2D(src_pts, dst_pts)
    # No model could be estimated: treat as zero inliers rather than crashing.
    if ransac_mask is None:
        return 0
    return int(ransac_mask.sum())

def sub_chromosome_generator(img):
    """Yield ``(sub_image, sub_mask)`` for every chromosome-sized object in ``img``.

    ``img`` is segmented with ``utils.threshold_segm(img, 250)``. Each connected
    component is cropped with ``utils.show_spec_area`` (fill 255 for the image,
    0 for the mask), hole-filled, and kept only if ``is_chromosome`` accepts it.
    A 5x5 erosion is then used to split components that are joined by thin
    bridges: if erosion leaves more than one piece, each piece is dilated back
    and yielded separately (cropped to its own bbox), otherwise the whole
    component is yielded.

    NOTE(review): presumably ``show_spec_area`` returns a same-size copy with
    everything outside the bbox replaced by the fill value — TODO confirm in utils.
    """
    erode_kernel = np.ones((5, 5), np.uint8)
    msk = utils.threshold_segm(img, 250)
    for prop in regionprops(label(msk)):
        sub_img = utils.show_spec_area(img, prop.bbox, 255)
        sub_msk = utils.fill_holes(utils.show_spec_area(msk, prop.bbox, 0))
        if not is_chromosome(sub_msk):
            continue

        labels = label(cv2.erode(sub_msk, erode_kernel))
        pieces = regionprops(labels)
        if len(pieces) <= 1:
            yield sub_img, sub_msk
            continue

        # Use the labels regionprops found: np.unique(labels)[1:] would drop a
        # real label whenever the eroded mask has no background (0) pixel.
        for piece in pieces:
            piece_msk = cv2.dilate(255 * (labels == piece.label).astype(np.uint8), erode_kernel)
            if is_chromosome(piece_msk):
                bbox = get_bbox(piece_msk)
                yield utils.show_spec_area(sub_img, bbox, 255), piece_msk
                
def get_sub_chromosome(src_img, src_msk, dst_img):
    """Score every chromosome of ``src_img`` against ``dst_img``.

    For each connected component of ``src_msk`` that ``is_chromosome`` accepts,
    a 5x5 erosion is used to split touching chromosomes (same procedure as
    ``sub_chromosome_generator``). Each resulting grayscale crop is scored with
    ``return_ransac_num(crop, dst_img)``.

    Args:
        src_img: source (K) grayscale image.
        src_msk: binary segmentation of ``src_img``.
        dst_img: destination (A) image to match against.

    Yields:
        ``(inlier_count, bbox)`` where ``bbox`` is ``(min_row, min_col, max_row, max_col)``
        of the scored crop.
    """
    erode_kernel = np.ones((5, 5), np.uint8)
    for prop in regionprops(label(src_msk)):
        sub_src_img = utils.show_spec_area(src_img, prop.bbox, 255)
        sub_src_msk = utils.show_spec_area(src_msk, prop.bbox, 0)
        if not is_chromosome(sub_src_msk):
            continue

        labels = label(cv2.erode(sub_src_msk, erode_kernel))
        pieces = regionprops(labels)
        if len(pieces) <= 1:
            yield return_ransac_num(sub_src_img, dst_img), prop.bbox
            continue

        # Iterate regionprops' labels: np.unique(labels)[1:] would drop a real
        # label whenever the eroded mask has no background (0) pixel.
        for piece in pieces:
            piece_msk = cv2.dilate(255 * (labels == piece.label).astype(np.uint8), erode_kernel)
            if is_chromosome(piece_msk):
                bbox = get_bbox(piece_msk)
                # Score the grayscale crop, not the binary mask: the caller later
                # matches the image crop for this bbox, so the ranking must use the same input.
                piece_img = utils.show_spec_area(sub_src_img, bbox, 255)
                yield return_ransac_num(piece_img, dst_img), bbox

In [61]:
def clearBoxAndNum(imageK, dark_thresh=100, band=40):
    """Return a copy of ``imageK`` with a white band painted below every dark row.

    A row is "dark" when its mean intensity is below ``dark_thresh``. For each
    dark row ``r``, rows ``r .. r+band-1`` of the copy are set to 255. Row means
    are computed on the original input, before any painting.

    Args:
        imageK: 2-D grayscale image (not modified).
        dark_thresh: mean-intensity threshold (default 100, the former constant).
        band: number of rows whitened from each dark row down (default 40).
    """
    image = imageK.copy()
    # ndarray.mean accumulates in float64; the builtin sum over a uint8 row
    # wraps around under NumPy >= 2 promotion rules and gives bogus means.
    row_means = image.mean(axis=tuple(range(1, image.ndim)))
    for row in np.flatnonzero(row_means < dark_thresh):
        image[row:row + band, :] = 255
    return image

def calcTheBottomPoint(imageK):
    """Locate the midpoints of background runs in sparse rows of ``imageK``.

    A row is examined when ``255 * (fraction of non-zero pixels) < 150``. In
    such a row, every run of zero pixels that is closed by a non-zero pixel
    contributes ``[row, (run_start + run_end) // 2]``, where ``run_end`` is the
    first non-zero column after the run. A run that starts at column 0 uses
    ``run_start = 0``; a run that reaches the right edge is ignored.
    NOTE(review): presumably these are the class marker lines of the karyotype
    sheet — TODO confirm.

    Args:
        imageK: 2-D grayscale image.

    Returns:
        ``np.ndarray`` of shape ``(N, 2)`` with ``[row, col]`` points, or an
        empty array when no point is found.
    """
    foreground = np.asarray(imageK) > 0
    rows, cols = foreground.shape[:2]
    points = []
    for i in range(rows):
        row = foreground[i]
        if 255 * int(np.count_nonzero(row)) / cols >= 150:
            continue
        # Reset per row: previously a row beginning with background reused the
        # previous row's run start (or raised UnboundLocalError on the first one).
        run_start = 0
        for j in np.flatnonzero(np.diff(row.astype(np.int8))) + 1:
            if row[j]:   # background -> foreground: the run ends here
                points.append([i, int((run_start + j) // 2)])
                run_start = 0
            else:        # foreground -> background: a run starts here
                run_start = int(j)
    return np.array(points)

def findNearestPoint(theClassBottomPoint, bbox):
    """Return the 1-based index of the point nearest to the bottom-centre of ``bbox``.

    Args:
        theClassBottomPoint: sequence/array of ``[row, col]`` points.
        bbox: ``(min_row, min_col, max_row, max_col)``.

    Returns:
        ``argmin(distance) + 1`` — the first point wins on ties.

    Raises:
        ValueError: if there are no points.
    """
    min_row, min_col, max_row, max_col = bbox
    anchor = np.array([max_row, (min_col + max_col) // 2])
    candidates = np.reshape(theClassBottomPoint, (-1, 2))
    distances = np.linalg.norm(candidates - anchor, axis=1)
    return np.argmin(distances) + 1

In [90]:
import shutil

src_path = './image/k/'
dst_path = './image/a/'

root = './test'

src_names = sorted(os.listdir(src_path))
dst_names = sorted(os.listdir(dst_path))

SRC_SUFFIX = '.K.TIF'
DST_SUFFIX = '.A.TIF'
RATIO_TEST = 0.95       # Lowe ratio-test threshold, same value return_ransac_num uses
MIN_GOOD_MATCHES = 4    # minimum ratio-test survivors before fitting an affine model

k_sharpen = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], np.float32)

# For every K image: rank its chromosomes by match quality against the matching
# A image, warp each source mask into A's frame, and save the stacked labels.
for name in src_names:
    # Any file that is not a *.K.TIF image is skipped.
    if not name.endswith(SRC_SUFFIX):
        continue
    name = name[:-len(SRC_SUFFIX)]

    src_img_name = os.path.join(src_path, name + SRC_SUFFIX)
    dst_img_name = os.path.join(dst_path, name + DST_SUFFIX)
    save_path = os.path.join(root, name)
    save_matches_path = os.path.join(save_path, 'matches')
    os.makedirs(save_matches_path, exist_ok=True)   # also creates save_path
    # shutil.copy avoids spawning a shell, so unusual characters in names are safe,
    # and a missing file raises instead of failing silently.
    shutil.copy(src_img_name, save_path)
    shutil.copy(dst_img_name, save_path)

    dst_msk_name = os.path.join(save_path, name + '.A.msk.nrrd')

    src_img = cv2.imread(src_img_name, 0)
    dst_img = cv2.imread(dst_img_name, 0)
    if src_img is None or dst_img is None:
        print(f'skip {name}: could not read input images')
        continue
    dst_img = cv2.filter2D(dst_img, -1, k_sharpen)

    src_msk = utils.threshold_segm(src_img, 250)
    dst_msk = []

    cls_pts = calcTheBottomPoint(src_img)

    # Highest inlier count first, so the most reliable chromosomes are placed
    # (and blanked out of dst_img) before the weaker ones.
    match_infos = sorted(get_sub_chromosome(src_img, src_msk, dst_img),
                         key=lambda info: info[0], reverse=True)

    index = 1

    for _, bbox in match_infos:
        sub_src_img = utils.show_spec_area(src_img, bbox, 255)
        sub_src_msk = utils.show_spec_area(src_msk, bbox, 0)

        cur_cls = findNearestPoint(cls_pts, bbox)

        kpts1, kpts2, matches = orb_match(sub_src_img, dst_img)
        # knnMatch can return fewer than two neighbours; skip those pairs.
        good = [pair[0] for pair in matches
                if len(pair) == 2 and pair[0].distance < RATIO_TEST * pair[1].distance]
        if len(good) < MIN_GOOD_MATCHES:
            continue

        src_pts = np.float32([kpts1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kpts2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)

        M, ransac_mask = cv2.estimateAffinePartial2D(src_pts, dst_pts)
        if M is None or ransac_mask is None:   # no model could be estimated
            continue
        matches_mask = ransac_mask.ravel().tolist()

        draw_params = dict(matchColor=(0, 255, 0),    # draw matches in green color
                           singlePointColor=None,
                           matchesMask=matches_mask,  # draw only inliers
                           flags=2)
        # Draw the box on a copy: drawing on src_img itself would leak grey
        # rectangles into the crops (and ORB features) of later chromosomes.
        src_vis = src_img.copy()
        cv2.rectangle(src_vis, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])),
                      (128, 128, 128), 1)

        matches_draw = cv2.drawMatches(src_vis, kpts1, dst_img, kpts2, good, None, **draw_params)
        cv2.imwrite(os.path.join(save_matches_path, f'{index:02d}.png'), matches_draw)
        index += 1

        rotate_msk = cv2.warpAffine(sub_src_msk, M, dst_img.shape[::-1]).astype(bool)

        dst_msk.append(cur_cls * rotate_msk)
        # Blank the matched region so later chromosomes are less likely to match it again.
        dst_img[rotate_msk] = 255

    if not dst_msk:
        print(f'{name}: no chromosome could be matched, {dst_msk_name} not written')
        continue
    dst_msk = sitk.GetImageFromArray(np.array(dst_msk))
    sitk.WriteImage(dst_msk, dst_msk_name)