In [1]:
import os
import cv2
import numpy as np
import csv
from re import findall
import matplotlib.pyplot as plt
from matplotlib import colors

In [2]:
class ImgSegSLIC():
    """SLIC superpixel segmentation of one image plus per-superpixel colour statistics.

    The source image (``src_path``) is segmented with
    ``cv2.ximgproc.createSuperpixelSLIC``. The superpixel boundary mask is then
    applied to the reference image (``ref_path``) to build ``image_contour``,
    which serves as the backdrop of the annotated figure written by
    :meth:`find_threshold`.

    Parameters
    ----------
    src_path : str
        Path of the image that is segmented and measured.
    ref_path : str
        Path of the image the superpixel contours are drawn on. It must have the
        same height and width as the source image.
    algorithm : int
        SLIC variant forwarded to ``createSuperpixelSLIC``
        (e.g. ``cv2.ximgproc.SLIC``).
    region_size, ruler :
        Forwarded unchanged to ``createSuperpixelSLIC`` (see the OpenCV docs).
    max_iter : int
        Number of iterations passed to ``SuperpixelSLIC.iterate``.
    min_element_size : int
        Passed to ``SuperpixelSLIC.enforceLabelConnectivity``.

    Raises
    ------
    FileNotFoundError
        If either image cannot be read by OpenCV.
    ValueError
        If the source and reference images differ in height or width.
    """

    def __init__(self, src_path, ref_path, algorithm, region_size, ruler, max_iter, min_element_size):
        self.src_path = src_path
        self.ref_path = ref_path
        self.algorithm = algorithm
        self.region_size = region_size
        self.ruler = ruler
        self.max_iter = max_iter
        self.min_element_size = min_element_size

        self.image_src = self._read_rgb(src_path)
        self.image_ref = self._read_rgb(ref_path)
        # The contour mask is computed on the source image but applied to the
        # reference image, so both must share the same spatial size.
        if self.image_src.shape[:2] != self.image_ref.shape[:2]:
            raise ValueError(
                f"size mismatch: {src_path} is {self.image_src.shape[:2]}, "
                f"{ref_path} is {self.image_ref.shape[:2]}"
            )

        self.SLIC()

    @staticmethod
    def _read_rgb(path):
        """Load ``path`` with OpenCV and return it converted from BGR to RGB."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def SLIC(self):
        """Run SLIC on the source image and store the segmentation results.

        Sets ``self.labels`` (per-pixel label array from ``getLabels``),
        ``self.num_label`` (``getNumberOfSuperpixels``) and
        ``self.image_contour`` (the reference image restricted to the
        non-boundary pixels of the segmentation).
        """
        slic = cv2.ximgproc.createSuperpixelSLIC(self.image_src, self.algorithm, self.region_size, self.ruler)
        slic.iterate(self.max_iter)
        slic.enforceLabelConnectivity(self.min_element_size)
        # getLabelContourMask marks the superpixel boundaries; inverting it keeps
        # every non-boundary pixel when used as the mask below.
        slic_inv_mask = cv2.bitwise_not(slic.getLabelContourMask())

        self.labels = slic.getLabels()
        self.num_label = slic.getNumberOfSuperpixels()
        self.image_contour = cv2.bitwise_or(src1=self.image_ref, src2=self.image_ref, mask=slic_inv_mask)

    def seg_stats(self, serial, image_masked, mask, writer):
        """Write one CSV row of per-channel mean/std for the pixels selected by ``mask``.

        Parameters
        ----------
        serial : int
            Superpixel label written to the first column.
        image_masked : numpy.ndarray
            Three-channel image to measure. Only pixels where ``mask`` is
            non-zero contribute, so the image need not be pre-masked.
        mask : numpy.ndarray
            8-bit single-channel mask passed to ``cv2.meanStdDev``.
        writer : csv.writer
            Receives ``[serial, mean_1, mean_2, mean_3, std_1, std_2, std_3]``,
            with the statistics rounded to 4 decimals.

        Returns
        -------
        tuple
            ``(means, stds)`` exactly as returned by ``cv2.meanStdDev`` (unrounded).
        """
        means, stds = cv2.meanStdDev(image_masked, mask=mask)
        rounded = np.around(np.concatenate([means.ravel(), stds.ravel()]), 4)
        # tolist() writes plain Python floats rather than numpy scalars.
        writer.writerow([serial, *rounded.tolist()])
        return means, stds

    def _output_prefix(self, out_path):
        """Return ``out_path`` followed by the digits of the source *file name*.

        Only the base name is used, so digits in parent directory names do not
        leak into the output file names.
        """
        return out_path + ''.join(findall(r'\d', os.path.basename(self.src_path)))

    def find_threshold(self, out_path):
        """Export per-superpixel RGB/LAB/HSV statistics and an annotated overlay.

        For every label in ``range(self.num_label)`` the per-channel mean and
        standard deviation are written to ``<prefix>_rgb.csv``,
        ``<prefix>_lab.csv`` and ``<prefix>_hsv.csv``. The label number is drawn
        at the superpixel centroid on ``self.image_contour``, and the figure is
        saved as ``<prefix>_annotation.jpg``. ``<prefix>`` is ``out_path``
        followed by the digits of the source file name. ``out_path`` is
        concatenated as-is, so end it with a separator when it names a directory.

        Parameters
        ----------
        out_path : str
            Output prefix, typically an existing directory ending in '/'.
        """
        prefix = self._output_prefix(out_path)

        # Colour conversions are per-pixel and the statistics are masked, so
        # converting the whole image once gives the same numbers as converting a
        # masked copy for every label.
        image_lab = cv2.cvtColor(self.image_src, cv2.COLOR_RGB2LAB)
        image_hsv = cv2.cvtColor(self.image_src, cv2.COLOR_RGB2HSV)

        fig, ax = plt.subplots(figsize=(16, 12))
        try:
            ax.imshow(self.image_contour)

            # newline='' is the csv module's requirement for files it writes.
            with open(prefix + '_rgb.csv', mode='w', newline='') as rgb_csv, \
                    open(prefix + '_lab.csv', mode='w', newline='') as lab_csv, \
                    open(prefix + '_hsv.csv', mode='w', newline='') as hsv_csv:
                rgb_writer = csv.writer(rgb_csv, delimiter=',')
                rgb_writer.writerow(['serial','R_mean','G_mean','B_mean','R_std','G_std','B_std'])
                lab_writer = csv.writer(lab_csv, delimiter=',')
                lab_writer.writerow(['serial','L_mean','A_mean','B_mean','L_std','A_std','B_std'])
                hsv_writer = csv.writer(hsv_csv, delimiter=',')
                hsv_writer.writerow(['serial','H_mean','S_mean','V_mean','H_std','S_std','V_std'])

                for label in range(self.num_label):
                    label_mask = (self.labels == label).astype(np.uint8)

                    means_rgb, _ = self.seg_stats(label, self.image_src, label_mask, rgb_writer)
                    self.seg_stats(label, image_lab, label_mask, lab_writer)
                    self.seg_stats(label, image_hsv, label_mask, hsv_writer)

                    moments = cv2.moments(label_mask, binaryImage=True)
                    if moments["m00"] == 0:
                        # Label with no pixels: there is no centroid to annotate.
                        continue
                    centroid_x = int(moments["m10"] / moments["m00"])
                    centroid_y = int(moments["m01"] / moments["m00"])

                    # Draw the label in the inverse of the superpixel's mean colour.
                    inverse = (255 - means_rgb.ravel()[:3]) / 255
                    text_color = colors.to_rgb(tuple(inverse))
                    ax.annotate(str(label), (centroid_x, centroid_y), color=text_color)

            fig.savefig(prefix + '_annotation.jpg')
        finally:
            # Always release the figure, even if a statistic or write failed.
            plt.close(fig)

In [3]:
# Input/output directories, relative to the notebook. They are joined to file
# names by plain string concatenation, so each one keeps its trailing '/'.
src_path, ref_path, out_path = (
    "region_map_ori/",          # images that are segmented and measured
    "region_map_ref/",          # images the superpixel contours are drawn on
    "region_map_result_data/",  # CSV statistics and annotated figures
)

In [4]:
# Pair each source image with the reference image(s) whose file names contain the
# same digits, segment the pair with SLIC and export the per-superpixel statistics.
SLIC_PARAMS = dict(
    algorithm=cv2.ximgproc.SLIC,  # == 100, the value previously hard-coded
    region_size=20,
    ruler=20.0,
    max_iter=10,
    min_element_size=40,
)

# List only the files directly inside each directory. Iterating os.walk kept the
# file list of the *last* directory it visited, not the top-level one.
files_src = sorted(name for name in os.listdir(src_path)
                   if os.path.isfile(os.path.join(src_path, name)))
files_ref = sorted(name for name in os.listdir(ref_path)
                   if os.path.isfile(os.path.join(ref_path, name)))

# Index references by their digits. Several references may share an index, and
# each of them is still paired with every matching source.
refs_by_idx = {}
for ref_name in files_ref:
    refs_by_idx.setdefault(''.join(findall(r'\d', ref_name)), []).append(ref_name)

# find_threshold writes straight into out_path, so make sure it exists.
os.makedirs(out_path, exist_ok=True)

for src_name in files_src:
    src_idx = ''.join(findall(r'\d', src_name))
    for ref_name in refs_by_idx.get(src_idx, []):
        Seg = ImgSegSLIC(src_path + src_name, ref_path + ref_name, **SLIC_PARAMS)
        Seg.find_threshold(out_path)

In [None]:
# Prediction images to blend onto the references, and where the blends are saved.
# Both are used with string concatenation, so they keep their trailing '/'.
pred_path, merge_path = 'region_map_result/', 'region_map_result_merged/'

In [None]:
# Blend every prediction 50/50 with its reference image and save the result as
# merged_<prediction name>. A prediction matches a reference when the first four
# characters of their file names agree.
files_ref = sorted(name for name in os.listdir(ref_path)
                   if os.path.isfile(os.path.join(ref_path, name)))
files_pred = sorted(name for name in os.listdir(pred_path)
                    if os.path.isfile(os.path.join(pred_path, name)))

preds_by_prefix = {}
for pred_name in files_pred:
    preds_by_prefix.setdefault(pred_name[:4], []).append(pred_name)

# cv2.imwrite reports failure through its return value, so create the target first.
os.makedirs(merge_path, exist_ok=True)

for ref_name in files_ref:
    matches = preds_by_prefix.get(ref_name[:4], [])
    if not matches:
        continue
    img_ref = cv2.imread(ref_path + ref_name)
    if img_ref is None:
        raise FileNotFoundError(f"could not read reference image: {ref_path + ref_name}")
    for pred_name in matches:
        img_pred = cv2.imread(pred_path + pred_name)
        if img_pred is None:
            raise FileNotFoundError(f"could not read prediction image: {pred_path + pred_name}")
        img_merged = cv2.addWeighted(img_ref, 0.5, img_pred, 0.5, 0)
        merged_file = merge_path + 'merged_' + pred_name
        if not cv2.imwrite(merged_file, img_merged):
            raise OSError(f"could not write merged image: {merged_file}")