In [None]:
# import 
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import cv2
import importlib

In [None]:
import Modules.Data.SegLabelPreprocess as SegLabelPreprocess

In [None]:
## processing configuration parameters

# NOTE: the dataset root may need to be adjusted to the data location on your
# file system. It can be overridden without editing the notebook by setting the
# HELA_SEG_DATA_ROOT environment variable; the default is the original author's path.
seg_data_root = os.environ.get(
    "HELA_SEG_DATA_ROOT",
    r"E:\Python\DataSet\TorchDataSet\DIC-C2DH-HeLa\Train\DIC-C2DH-HeLa\DIC-C2DH-HeLa",
)

# glob patterns of the manual segmentation label images, one per sequence
seg_file_path_globs = [
    os.path.join(seg_data_root, seq_name, "SEG", "man_seg*.tif")
    for seq_name in ("01_ST", "02_ST")
]

# erosion parameters (passed to SegLabelPreprocess.erode_colored_labels)
erode_shape = (3, 3)     # shape of the all-ones erosion kernel
erode_nof_itrs = 1       # number of erosion iterations
bkg_color = 0            # label value treated as background

# border weight parameters (passed to SegLabelPreprocess.border_distance_gaussian_weight).
# NOTE: "broder" is a misspelling of "border"; the names are kept because later cells use them.
broder_gauss_sigma = 5   # passed as `sigma`
broder_topk = 2          # passed as `topk`
broder_omega0 = 10       # multiplier of the border weights in the total weight map

In [None]:
## get all the data paths

# glob.glob() returns matches in a file-system dependent order, so sort the
# matches of each pattern (keeping the pattern order) to make the preview index
# and the processing order reproducible across machines and runs.
seg_file_paths = []
for cur_path_glob in seg_file_path_globs:
    seg_file_paths += sorted(glob.glob(cur_path_glob))

print(len(seg_file_paths))

# fail fast with an actionable message instead of an IndexError in the preview cell
if not seg_file_paths:
    raise FileNotFoundError(
        "No segmentation files matched seg_file_path_globs: " + repr(seg_file_path_globs)
    )

In [None]:
## preview data processing configurations
importlib.reload(SegLabelPreprocess)


def _preview_image(image, title, cmap=None):
    """Show `image` in its own figure with a colorbar and the given title.

    `cmap=None` uses matplotlib's default colormap. The title is set before
    `tight_layout()` so that the layout accounts for it.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(image, cmap=cmap)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)
    fig.tight_layout()
    plt.show()


# load source segmentation data (IMREAD_UNCHANGED keeps the original bit depth)
check_i_file = 0
check_seg_file_path = seg_file_paths[check_i_file]
check_seg_image = cv2.imread(check_seg_file_path, cv2.IMREAD_UNCHANGED)
if check_seg_image is None:
    # cv2.imread() signals failure by returning None rather than raising
    raise OSError("Failed to read segmentation image: " + check_seg_file_path)

_preview_image(check_seg_image, "src labels", cmap="Set3")

# erode source segmentation data to separate borders
check_erode_kernel = np.ones(erode_shape, dtype=check_seg_image.dtype)
check_erode_image = SegLabelPreprocess.erode_colored_labels(
    check_seg_image,
    kernel=check_erode_kernel,
    nof_itrs=erode_nof_itrs,
    bkg_color=bkg_color,
)

_preview_image(check_erode_image, "eroded labels", cmap="Set3")

# calculate class frequency balance weight for uncolored segmentation output:
# every non-background pixel becomes 1, background stays 0
check_erode_uncolored_image = np.zeros_like(check_erode_image)
check_erode_uncolored_image[check_erode_image != bkg_color] = 1

check_label_weights = SegLabelPreprocess.balanced_weight_colored_labels(check_erode_uncolored_image)

_preview_image(check_label_weights, "label weights")

# calculate border distance gaussian weight
check_border_weights = SegLabelPreprocess.border_distance_gaussian_weight(
    check_erode_image,
    bkg_color=bkg_color,
    sigma=broder_gauss_sigma,
    topk=broder_topk,
)

_preview_image(check_border_weights, "border weights")

# calculate total weights: label weights plus border weights scaled by omega0
check_tot_weights = check_label_weights + broder_omega0 * check_border_weights

_preview_image(check_tot_weights, "total weights")

In [None]:
## iterate through all the data


def _imwrite_or_raise(file_path, image):
    """Write `image` to `file_path` with cv2.imwrite and raise OSError on failure.

    cv2.imwrite() can report a failed write by returning False instead of
    raising, which would otherwise be silently ignored.
    """
    if not cv2.imwrite(file_path, image):
        raise OSError("Failed to write image: " + file_path)


nof_files = len(seg_file_paths)
for i_file, cur_seg_file_path in enumerate(seg_file_paths):
    # 1-based counter so the last file reports nof_files/nof_files
    print(f"Processing {i_file + 1}/{nof_files}")
    print("Processing " + cur_seg_file_path)

    # results go to sibling directories "<seg dir>_ERODE" and "<seg dir>_WEIGHT",
    # keeping the source file name
    cur_seg_file_dir_path, cur_seg_file_name = os.path.split(cur_seg_file_path)
    cur_erode_seg_file_dir_path = cur_seg_file_dir_path + "_ERODE"
    cur_weight_file_dir_path = cur_seg_file_dir_path + "_WEIGHT"
    os.makedirs(cur_erode_seg_file_dir_path, exist_ok=True)
    os.makedirs(cur_weight_file_dir_path, exist_ok=True)

    cur_erode_seg_file_path = os.path.join(cur_erode_seg_file_dir_path, cur_seg_file_name)
    cur_weight_file_path = os.path.join(cur_weight_file_dir_path, cur_seg_file_name)

    # skip files whose two outputs already exist from a previous run
    if os.path.exists(cur_weight_file_path) and os.path.exists(cur_erode_seg_file_path):
        print("Result already exists!")
        continue

    # image processing starting from here (IMREAD_UNCHANGED keeps the bit depth)
    cur_seg_image = cv2.imread(cur_seg_file_path, cv2.IMREAD_UNCHANGED)
    if cur_seg_image is None:
        # cv2.imread() signals failure by returning None rather than raising
        raise OSError("Failed to read segmentation image: " + cur_seg_file_path)

    cur_erode_kernel = np.ones(erode_shape, dtype=cur_seg_image.dtype)

    cur_erode_image = SegLabelPreprocess.erode_colored_labels(
        cur_seg_image,
        kernel=cur_erode_kernel,
        nof_itrs=erode_nof_itrs,
        bkg_color=bkg_color,
    )

    # binary foreground mask: non-background pixels -> 1, background -> 0
    cur_erode_uncolored_image = np.zeros_like(cur_erode_image)
    cur_erode_uncolored_image[cur_erode_image != bkg_color] = 1

    cur_label_weights = SegLabelPreprocess.balanced_weight_colored_labels(cur_erode_uncolored_image)

    cur_border_weights = SegLabelPreprocess.border_distance_gaussian_weight(
        cur_erode_image,
        bkg_color=bkg_color,
        sigma=broder_gauss_sigma,
        topk=broder_topk,
    )

    cur_tot_weights = cur_label_weights + broder_omega0 * cur_border_weights

    # save result
    # NOTE(review): cur_tot_weights is presumably a float array; which float
    # depths the OpenCV TIFF encoder accepts depends on the OpenCV version — confirm.
    _imwrite_or_raise(cur_erode_seg_file_path, cur_erode_image)
    _imwrite_or_raise(cur_weight_file_path, cur_tot_weights)

    print("Erode Seg Saved to: " + cur_erode_seg_file_path)
    print("Weight Saved to: " + cur_weight_file_path)
