### Labeling stage 4 : The new black border (heavy carbon areas) labeling
This part masks the "black border" of each .mrc file and adds that label to the existing V3 masks, based on feedback from the existing model's output: the model confuses the black border with the particles.

It generates two files:
1. .npy, the label of every single pixel of the .mrc file, inherited from V1
2. .png, the visualization of the labeling

### The new labeling strategy

Cell: 1,2,3,4

Background: 0

Black Border: 6

In [None]:
# import the package
import numpy as np
import mrcfile
import matplotlib.pyplot as plt
from pathlib import Path
from skimage import measure, filters
import warnings
from tqdm import tqdm
import cv2

# ignore the warning message
warnings.filterwarnings('ignore')


class BlackBorderFinder:
    """
    Add the "black border" label to existing segmentation masks.

    For every .mrc image the dark region connected to the image edge is
    detected (threshold + morphological close + flood fill) and written into
    a copy of the matching .npy mask with the label BORDER_LABEL.
    """

    # labels that mark cell pixels in the masks
    CELL_LABELS = (1, 2, 3, 4)
    # label written for the black border pixels
    BORDER_LABEL = 6
    # the longest side of the working image used for the border detection
    MAX_SIDE = 2048

    def __init__(self):
        """
        initialization
        """
        # not read inside this class, kept for existing users of the attribute
        self.stages = ['stageI', 'stageII', 'stageIII', 'stageIV']
        # folder name of the .mrc stage -> folder name of the .npy stage
        self.stage_map = {'stageI': 'stage1',
                          'stageII': 'stage2',
                          'stageIII': 'stage3',
                          'stageIV': 'stage4'}

    def read_mrc(self, mrc_path):
        """
        read the .mrc file and rescale it to uint8

        argument:
        mrc_path: the path of the .mrc file

        return:
        the image as a uint8 array, min-max normalized to 0..255

        raise:
        ValueError: the file has no readable data block
        """
        # read the mrc file
        with mrcfile.open(mrc_path, permissive=True) as mrc:
            # permissive mode leaves data as None when the data block is invalid
            if mrc.data is None:
                raise ValueError(f"No readable data in {mrc_path}")
            data = mrc.data.astype(np.float32)  # float for accuracy

        return self._normalize_to_uint8(data)

    @staticmethod
    def _normalize_to_uint8(data):
        """
        min-max normalize an array to 0..255, uint8 for cv2 compatibility

        argument:
        data: the numeric array

        return:
        the uint8 array with the same shape, all zeros when the input has
        no usable intensity range (constant or non-finite values)
        """
        data = np.asarray(data, dtype=np.float32)
        if data.size == 0:
            return data.astype(np.uint8)

        data_min = data.min()
        value_range = data.max() - data_min
        # a flat image would divide by zero and give NaN
        if not np.isfinite(value_range) or value_range <= 0:
            return np.zeros(data.shape, dtype=np.uint8)

        # truncation (not rounding) on the cast, same as before
        return ((data - data_min) / value_range * 255).astype(np.uint8)

    @classmethod
    def _cell_mask(cls, mask):
        """
        bool mask of the pixels that carry one of the cell labels

        argument:
        mask: the label array
        """
        return np.isin(mask, cls.CELL_LABELS)

    def black_border_detection(self, image):
        """
        detect the black border

        argument:
        image: the 2D uint8 image of the .mrc file

        return:
        the bool mask with the shape of image, True on the black border

        raise:
        ValueError: the image has no pixels
        """
        h, w = image.shape
        if h == 0 or w == 0:
            raise ValueError("Cannot detect the border of an empty image")

        # work on a copy whose longest side is about MAX_SIDE, origin mrc too big
        ratio = min(self.MAX_SIDE / h, self.MAX_SIDE / w)
        # at least 1 pixel per side, a very elongated image must not collapse to 0
        down_h = max(1, int(h * ratio))
        down_w = max(1, int(w * ratio))
        # resize the image, use interarea
        down_image = cv2.resize(image, (down_w, down_h),
                                interpolation=cv2.INTER_AREA)

        # flood fill on the small graph, to uint8 for cv2
        small_mask = self._flood_fill(down_image).astype(np.uint8)

        # back to the original size, nearest keeps the mask binary
        border_mask = cv2.resize(small_mask, (w, h),
                                 interpolation=cv2.INTER_NEAREST)
        # to bool mask
        return border_mask.astype(bool)

    @staticmethod
    def _pick_seed(binary):
        """
        choose the start point of the flood fill on a white pixel

        argument:
        binary: the 2D image that holds only 0 and 255

        return:
        (x, y) as python ints, or None when the image has no white pixel
        """
        h, w = binary.shape
        seed_x, seed_y = w // 2, h // 2

        # the center itself is white
        if binary[seed_y, seed_x] != 0:
            return seed_x, seed_y

        # clip the central 1/9 and pick the middle one of its white pixels
        center_region = binary[h // 3: 2 * h // 3, w // 3: 2 * w // 3]
        all_y, all_x = np.where(center_region == 255)
        if len(all_y) > 0:
            middle = len(all_y) // 2
            # from 1/9 to original picture
            return int(all_x[middle] + w // 3), int(all_y[middle] + h // 3)

        # last resort: the white pixel closest to the center in the whole image,
        # a black seed would flood the dark area and invert the result
        all_y, all_x = np.where(binary == 255)
        if len(all_y) == 0:
            return None
        nearest = np.argmin((all_y - seed_y) ** 2 + (all_x - seed_x) ** 2)
        return int(all_x[nearest]), int(all_y[nearest])

    def _flood_fill(self, image):
        """
        detect the black border with flood fill

        The image is binarized at half of the Otsu threshold, closed with an
        11x11 ellipse, and the white area connected to a central seed is taken
        as the main region. Every other connected region that touches the
        image edge is reported as border.

        argument:
        image: the 2D uint8 image (already downsized)

        return:
        the bool mask with the shape of image, True on the black border
        """
        h, w = image.shape

        # find the brightness threshold automatically
        threshold = filters.threshold_otsu(image)
        # mrc to binary, white or black
        binary = (image > threshold * 0.5).astype(np.uint8) * 255

        # morphological close (dilate first and then erode), smooth, seamless
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)

        # must start from white, 255
        seed_point = self._pick_seed(binary)
        if seed_point is None:
            # no bright area at all, nothing can be told apart as border
            return np.zeros((h, w), dtype=bool)

        # cv2 needs a bigger zero mask, avoid crossing the border
        mask = np.zeros((h + 2, w + 2), dtype=np.uint8)

        # flood fills the white area to 254, a value the binary image never holds
        filled = binary.copy()
        cv2.floodFill(filled, mask, seed_point, (254,))

        # everything outside the filled white area is a border candidate
        border_mask = filled != 254

        # these candidates can be seen as connected regions, label each of them
        labels = measure.label(border_mask)

        # the real black border is at the edge: keep the regions that own
        # at least one pixel on the outer rows / columns
        edge_labels = np.unique(np.concatenate((labels[0, :], labels[-1, :],
                                                labels[:, 0], labels[:, -1])))
        # 0 is the label of the non-candidate pixels
        edge_labels = edge_labels[edge_labels != 0]

        return np.isin(labels, edge_labels)

    def visualize_integrated_result(self, mrc_image, original_mask, updated_mask, save_path = None):
        """
        visualize the result, the old mask on the left and the new one on the right

        argument:
        mrc_image: the 2D image of the .mrc file
        original_mask: the mask with the cell labels only
        updated_mask: the upgraded mask with the black border
        save_path: where to save the figure; when empty the figure is not
                   saved and stays open
        """
        h, w = mrc_image.shape
        # 2 plots
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # left: RGB overlay, red color on the cell
        cell_overlay = np.zeros((h, w, 3), dtype=np.float32)
        cell_overlay[self._cell_mask(original_mask)] = [1, 0, 0]
        axes[0].imshow(mrc_image, cmap='gray', alpha=0.7)
        axes[0].imshow(cell_overlay, alpha=0.5)
        axes[0].set_title('Cells only(Red)', fontsize=14)
        axes[0].axis('off')

        # right: red for the cell, light blue for the border
        border_overlay = np.zeros((h, w, 3), dtype=np.float32)
        border_overlay[self._cell_mask(updated_mask)] = [1, 0, 0]
        border_overlay[updated_mask == self.BORDER_LABEL] = [0.6, 0.8, 1]
        axes[1].imshow(mrc_image, cmap='gray', alpha=0.7)
        axes[1].imshow(border_overlay, alpha=0.5)
        axes[1].set_title('Cell(Red), Black edge(Light blue)', fontsize=14)
        axes[1].axis('off')

        fig.tight_layout()

        # save the graph and release this figure
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close(fig)

    def process_single_file(self, mrc_path, npy_path, output_npy_path, output_png_path):
        """
        single file processing

        argument:
        mrc_path: the path of the .mrc file
        npy_path: the path of the .npy file
        output_npy_path: the path of the updated .npy file
        output_png_path: the path of the .png file

        return:
        the dict with 'cell_percentage', 'border_percentage' and
        'overlap_percentage', each relative to all pixels of the image

        raise:
        ValueError: the mask and the image have different shapes
        """
        # mrc loading
        mrc_image = self.read_mrc(mrc_path)

        # npy loading
        original_mask = np.load(npy_path)
        # fail early and clearly, before anything is written
        if original_mask.shape != mrc_image.shape:
            raise ValueError(f"Mask shape {original_mask.shape} does not match "
                             f"image shape {mrc_image.shape}")

        # detect the black border
        border_mask = self.black_border_detection(mrc_image)

        # update the .npy, the border label overwrites whatever was there
        updated_mask = original_mask.astype(np.int16)
        updated_mask[border_mask] = self.BORDER_LABEL

        # save the new .npy
        np.save(output_npy_path, updated_mask)

        # visualize
        self.visualize_integrated_result(mrc_image, original_mask, updated_mask, output_png_path)

        # statistic
        total_pixels = mrc_image.size
        cell = self._cell_mask(original_mask)
        cell_pixels = np.sum(cell)
        border_pixels = np.sum(border_mask)
        # overlap of cell and border, the less the good
        overlap_pixels = np.sum(cell & border_mask)

        return {'cell_percentage': (cell_pixels / total_pixels) * 100,
                'border_percentage': (border_pixels / total_pixels) * 100,
                'overlap_percentage': (overlap_pixels / total_pixels) * 100}

    def process_dataset(self, mrc_dir='Dataset-processed',
                        npy_dir='Image segmentation Level 1',
                        output_dir='Image segmentation Level 2'):
        """
        process the whole dataset

        Every 'mask_<name>_v3.npy' that matches a '<name>.mrc' is relabeled and
        saved as 'mask_<name>_v4.npy' together with a 'mask_<name>_v4.png'.
        A file that fails is reported and skipped, the run goes on.

        argument:
        mrc_dir: the path that stores the .mrc files
        npy_dir: the path that stores the .npy files
        output_dir: the path for saving the new .npy and .png

        return:
        the list of the statistic dicts of process_single_file, each extended
        with 'stage' and 'file'
        """
        # specify the mrc and npy path
        mrc_path = Path(mrc_dir)
        npy_path = Path(npy_dir)

        # construct the output path, nested paths are allowed
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        all_stats = []
        total_files = 0

        print("Start black border labeling")
        print()

        for mrc_stage, npy_stage in self.stage_map.items():
            # stage folder
            mrc_stage_path = mrc_path / mrc_stage
            npy_stage_path = npy_path / npy_stage

            if not mrc_stage_path.exists() or not npy_stage_path.exists():
                print(f"Warning: {mrc_stage} or {npy_stage} not found")
                continue

            # construct the new stage folder only for a stage that has input
            output_stage_path = output_path / npy_stage
            output_stage_path.mkdir(parents=True, exist_ok=True)

            # find all mrc
            mrc_files = sorted(mrc_stage_path.glob('*.mrc'))
            print(f"\n {npy_stage}: Processing {len(mrc_files)} files")
            print()

            stage_stats = []
            # find the npy according to mrc
            for mrc_file in tqdm(mrc_files, desc=f"Processing {mrc_stage}"):
                # get the name of npy
                npy_filename = 'mask_' + mrc_file.stem + '_v3.npy'
                npy_file = npy_stage_path / npy_filename

                if not npy_file.exists():
                    print(f"Warning: {npy_filename} not found, skipping")
                    continue

                # output .npy and png, the new V4 mask
                output_npy = output_stage_path / f"mask_{mrc_file.stem}_v4.npy"
                output_png = output_stage_path / f"mask_{mrc_file.stem}_v4.png"

                try:
                    # process a single pair of mrc and npy
                    statistic = self.process_single_file(mrc_file, npy_file, output_npy, output_png)
                except Exception as e:
                    # best effort: one broken file must not stop the batch
                    print(f"Error processing {mrc_file}: {e}")
                    continue

                # save the stage and file name
                statistic['stage'] = mrc_stage
                statistic['file'] = mrc_file.name
                stage_stats.append(statistic)

            all_stats.extend(stage_stats)
            total_files += len(stage_stats)

            # the summary of each stage, the mean of nothing would be nan
            if stage_stats:
                avg_border = np.mean([s['border_percentage'] for s in stage_stats])
                print(f"For stage {mrc_stage}: Average border percentage: {avg_border:.2f}%")
            else:
                print(f"For stage {mrc_stage}: no files processed")

        print("\n" + "=" * 60)
        print(f"Process completed with {total_files} files")

        return all_stats

In [None]:
if __name__ == "__main__":
    # relabel the V3 masks of the whole dataset and write the V4 masks
    integrator = BlackBorderFinder()

    stats = integrator.process_dataset(
        'Dataset-processed',            # mrc_dir: where the .mrc files are
        'Image segmentation Level 3',   # npy_dir: where the existing masks are
        'Image segmentation Level 4',   # output_dir: where the new masks go
    )