# 图像分割

In [None]:
import os

import cv2
from skimage import img_as_float
from skimage.filters import threshold_multiotsu
from skimage.measure import label, regionprops_table

import numpy as np
from sklearn.cluster import KMeans, DBSCAN, OPTICS, SpectralClustering as SpC

import matplotlib.pyplot as plt

In [None]:
class OtsuSeg():
    """Foreground segmentation of an RGB image by multi-level Otsu thresholding.

    Pipeline (see ``Segmentation``):
      1. zero the red channel (done once in ``__init__``);
      2. CLAHE contrast enhancement on the L channel of the LAB image;
      3. Gaussian blur and conversion to grey;
      4. multi-Otsu split into ``classes`` intensity classes, where every class
         except label 0 (the lowest-intensity class) counts as foreground;
      5. keep only 8-connected foreground regions whose area is above
         ``threshold`` pixels;
      6. fill holes enclosed by the remaining foreground.

    Parameters
    ----------
    image : ndarray
        Input image in RGB channel order (it is split as r, g, b and converted
        with ``COLOR_RGB2LAB``). Presumably uint8; TODO confirm.
    clipLimit, tileGrid : CLAHE clip limit and square tile-grid size.
    ksize, sigma : Gaussian kernel size (square) and ``sigmaX``.
    classes : number of multi-Otsu classes.
    threshold : a region is kept only if its area is strictly greater than this.
    """

    def __init__(self, image, clipLimit=5, tileGrid=4, ksize=5, sigma=9, classes=4, threshold=40000):
        r, g, b = cv2.split(image)
        # Replace the red channel with zeros.
        # NOTE(review): presumably this suppresses a red-dominated background; confirm.
        self.image = cv2.merge((np.zeros_like(r), g, b))
        self.clipLimit = clipLimit
        self.tileGrid = tileGrid
        self.ksize = ksize
        self.sigma = sigma
        self.classes = classes
        self.threshold = threshold

    def Segmentation(self):
        """Return the foreground mask as uint8 (0 = background, 255 = foreground)."""
        labels = self.Otsu()
        # Everything above the lowest Otsu class is treated as foreground.
        binary_mask = (labels != 0).astype(np.uint8)
        labels_dissected = label(binary_mask, connectivity=2)
        areas = regionprops_table(labels_dissected, properties=['area'])['area']
        mask = self._keep_large_regions(labels_dissected, areas, self.threshold)
        return self._fill_holes(mask)

    @staticmethod
    def _keep_large_regions(labels_dissected, areas, threshold):
        """Return a uint8 0/1 mask of the labelled regions whose area exceeds ``threshold``.

        ``areas[i]`` is the area of region label ``i + 1``; label 0 is background.
        """
        # Lookup table indexed by region label: one vectorized pass instead of
        # one full-image comparison per region.
        keep = np.zeros(len(areas) + 1, dtype=np.uint8)
        keep[1:] = np.asarray(areas) > threshold
        return keep[labels_dissected]

    @staticmethod
    def _fill_holes(mask):
        """Fill background holes enclosed by foreground in a 0/1 mask; return a 0/255 mask."""
        # A 2-pixel zero border guarantees the (0, 0) seed lies in the outer background.
        mask_canvas = np.zeros((mask.shape[0] + 4, mask.shape[1] + 4), dtype=np.uint8)
        mask_canvas[2:-2, 2:-2] = mask
        # floodFill's optional mask must be 2 pixels larger than the image (OpenCV requirement).
        fill_mask = np.zeros((mask_canvas.shape[0] + 2, mask_canvas.shape[1] + 2), dtype=np.uint8)
        cv2.floodFill(mask_canvas, fill_mask, (0, 0), 255)
        # After inversion the outer background becomes 0, holes become 255 and
        # foreground (1) becomes 254. OR-ing with the 0/1 mask therefore yields
        # 255 for foreground and holes, and 0 elsewhere.
        return mask | cv2.bitwise_not(mask_canvas)[2:-2, 2:-2]

    def Otsu(self):
        """Return the per-pixel multi-Otsu class label (0 .. classes-1) of the enhanced image."""
        image_enhanced = self.CLAHE()
        image_blurred = cv2.GaussianBlur(image_enhanced, (self.ksize, self.ksize), sigmaX=self.sigma)
        image_blurred = img_as_float(cv2.cvtColor(image_blurred, cv2.COLOR_RGB2GRAY))
        thresholds = threshold_multiotsu(image_blurred, classes=self.classes)
        return np.digitize(image_blurred, bins=thresholds)

    def CLAHE(self):
        """Apply CLAHE to the L channel of ``self.image`` (via LAB) and return the RGB result."""
        # cv2.split returns a tuple in recent OpenCV releases; copy it into a
        # list so the L channel can be replaced.
        lab_channels = list(cv2.split(cv2.cvtColor(self.image, cv2.COLOR_RGB2LAB)))
        clahe = cv2.createCLAHE(clipLimit=self.clipLimit, tileGridSize=(self.tileGrid, self.tileGrid))
        lab_channels[0] = clahe.apply(lab_channels[0])
        return cv2.cvtColor(cv2.merge(lab_channels), cv2.COLOR_LAB2RGB)

In [None]:
class SLICSeg():
    """Segment an image by clustering its SLIC superpixels on mean colour.

    The image is over-segmented with OpenCV SLIC superpixels. Each superpixel is
    summarised by its mean L, a, b colour, and the superpixels are grouped by one
    of several scikit-learn clustering algorithms using only the a and b
    components (L is dropped). Each clustering method returns the array built by
    ``get_image``: one masked copy of ``self.image`` per cluster, stacked along a
    new first axis.

    Note: the method names ``KMeans``/``DBSCAN``/``OPTICS`` exist only as class
    attributes. Inside the method bodies the bare names still resolve to the
    module-level scikit-learn estimators.
    """

    def __init__(self, image, algorithm=100, region_size=30, ruler=30.0, max_iter=10, min_element_size=40):
        # Image to segment. It is converted with COLOR_RGB2LAB, so RGB order is expected.
        self.image = image
        # Forwarded to cv2.ximgproc.createSuperpixelSLIC.
        self.algorithm = algorithm
        self.region_size = region_size
        self.ruler = ruler
        # Number of SLIC iterations.
        self.max_iter = max_iter
        # Forwarded to enforceLabelConnectivity.
        self.min_element_size = min_element_size

    def KMeans(self, num_cluster=3):
        """Cluster superpixels with k-means into ``num_cluster`` groups."""
        data = self.get_means()
        labels_new = KMeans(n_clusters=num_cluster).fit(data[:, 1:]).labels_
        return self.get_image(num_cluster, labels_new)

    def DBSCAN(self, eps=2, min_samples=5):
        """Cluster superpixels with DBSCAN; noise superpixels (if any) form the last stack entry."""
        data = self.get_means()
        labels_new = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(data[:, 1:])
        labels_new = self._relabel_noise(labels_new)
        num_cluster = len(set(labels_new))
        return self.get_image(num_cluster, labels_new)

    def OPTICS(self, min_samples=26):
        """Cluster superpixels with OPTICS; noise superpixels (if any) form the last stack entry."""
        data = self.get_means()
        labels_new = OPTICS(min_samples=min_samples).fit_predict(data[:, 1:])
        labels_new = self._relabel_noise(labels_new)
        num_cluster = len(set(labels_new))
        return self.get_image(num_cluster, labels_new)

    def SPC(self, num_cluster=4, gamma=1.0):
        """Cluster superpixels with spectral clustering into ``num_cluster`` groups."""
        data = self.get_means()
        labels_new = SpC(n_clusters=num_cluster, gamma=gamma).fit_predict(data[:, 1:])
        return self.get_image(num_cluster, labels_new)

    @staticmethod
    def _relabel_noise(labels_new):
        """Map negative (noise) labels to one past the largest label, so they get their own slot."""
        labels_new = np.asarray(labels_new)
        noise = labels_new < 0
        if not noise.any():
            return labels_new
        # If everything is noise, max() is -1 and noise becomes cluster 0.
        return np.where(noise, labels_new.max() + 1, labels_new)

    def get_image(self, num_cluster, labels_new):
        """Return one masked copy of ``self.image`` per cluster id ``0 .. num_cluster-1``.

        ``labels_new[i]`` is the cluster id of superpixel ``i``. Requires
        ``SLIC()`` (via ``get_means``) to have set ``self.labels`` first.
        """
        # Per-superpixel cluster ids -> per-pixel cluster map in one lookup.
        new_label_mask = np.asarray(labels_new)[self.labels]

        image_stack = []
        for i in range(num_cluster):
            mask = (new_label_mask == i).astype(np.uint8)
            image_stack.append(cv2.bitwise_and(self.image, self.image, mask=mask))
        return np.asarray(image_stack)

    def get_means(self):
        """Run SLIC and return a (num_labels, 3) array of mean L, a, b per superpixel.

        A superpixel with no pixels gets a row of zeros.
        """
        self.SLIC()
        image_lab = cv2.cvtColor(self.image, cv2.COLOR_RGB2LAB)
        flat_labels = self.labels.ravel()
        counts = np.bincount(flat_labels, minlength=self.num_labels)[:self.num_labels]

        data = np.zeros((self.num_labels, 3))
        for channel in range(3):
            sums = np.bincount(flat_labels, weights=image_lab[..., channel].ravel(),
                               minlength=self.num_labels)[:self.num_labels]
            data[:, channel] = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
        return data

    def SLIC(self):
        """Compute SLIC superpixels of ``self.image``; set ``self.labels`` and ``self.num_labels``."""
        slic = cv2.ximgproc.createSuperpixelSLIC(self.image, algorithm=self.algorithm,
                                                 region_size=self.region_size, ruler=self.ruler)
        slic.iterate(self.max_iter)
        slic.enforceLabelConnectivity(self.min_element_size)
        # For visual inspection, slic.getLabelContourMask() gives the superpixel boundaries.
        self.labels = slic.getLabels()
        self.num_labels = slic.getNumberOfSuperpixels()

In [None]:
# Input folder holding the sample images to segment.
src_path = 'fig_sample_labo/'
# Output folder. NOTE(review): not referenced in this section.
save_path = "fakepath/"

In [None]:
# Segment every image directly inside src_path. An Otsu foreground mask isolates
# the object first; then SLIC superpixels + OPTICS clustering split the masked
# image. Each entry of `stack` is the per-cluster image array from SLICSeg.OPTICS.
stack = []

# Only the top level of src_path. os.walk would recurse and leave the file list
# of the last directory visited, and it silently yields nothing for a missing folder.
files_src = sorted(
    name for name in os.listdir(src_path)
    if os.path.isfile(os.path.join(src_path, name))
)

for src_name in files_src:
    image_bgr = cv2.imread(os.path.join(src_path, src_name))
    if image_bgr is None:
        # cv2.imread returns None for files it cannot decode (e.g. non-images).
        print(f"Skipping unreadable image: {src_name}")
        continue
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    Seg1 = OtsuSeg(image)
    mask = Seg1.Segmentation()
    image_masked = cv2.bitwise_and(image, image, mask=mask)

    Seg2 = SLICSeg(image_masked)
    # Store this image's segmentation result in `stack`.
    stack.append(Seg2.OPTICS(min_samples=4))