### The python script: Image segmentation preprocessing
#### Author: Siyu Liu
#### University of Bristol

### Labeling section 1: The correct particle label
This script combines each .mrc file with its matching .svg annotation to produce the particle labels.

It generates two files:
1. .npy, the label of each single pixel of the .mrc file
2. .png, the visualization of the labeling

In [None]:
# import the package
import numpy as np
import cv2
import mrcfile
from pathlib import Path
from typing import Tuple, List, Dict
import warnings
from tqdm import tqdm
import matplotlib.pyplot as plt
import cairosvg
from PIL import Image
import io

# Silence all warnings for the whole process.
# NOTE: no category or module filter is given, so this also hides NumPy
# runtime warnings (e.g. divide-by-zero, overflow) raised by the code below.
warnings.filterwarnings('ignore')

class Segment_Labeller:
    """
    Build per-pixel label masks from paired .mrc images and .svg annotations.

    Pixels that are drawn in red in the rendered .svg receive ``stage_label``
    in the mask; every other pixel stays 0. For each pair a .npy mask and a
    two-panel .png preview are written.
    """

    def __init__(self, stage_label: int = 1):
        """
        initialization

        argument:
        stage_label: 1 or 2 or 3 or 4, depending on the stage label.
        """
        self.stage_label = stage_label

    def load_mrc_file(self, mrc_path: str) -> np.ndarray:
        """
        read the .mrc file

        argument:
        mrc_path: the path of the .mrc file

        return: a copy of the data stored in the .mrc file

        raises: RuntimeError if the file cannot be opened or read
        """
        try:
            with mrcfile.open(mrc_path, mode='r', permissive=True) as mrc:
                # copy so the array stays usable after the file is closed
                return mrc.data.copy()
        except Exception as e:
            # chain the cause so the original traceback is not lost
            raise RuntimeError(f"Failed to load .mrc file {mrc_path}: {e}") from e

    def parse_svg_cairosvg(self, svg_path: str, image_shape: Tuple[int, int]) -> np.ndarray:
        """
        render the .svg with "cairosvg" and extract the red annotation

        argument:
        svg_path: the path of the .svg file
        image_shape: the (height, width) shape of the target image

        return: a boolean mask of shape image_shape, True on annotated pixels

        raises: RuntimeError if rendering or decoding the .svg fails
        """
        try:
            # rasterize the .svg into PNG bytes
            png_data = cairosvg.svg2png(url=svg_path)

            # "BytesIO" wraps the bytes in a file-like object for PIL
            with Image.open(io.BytesIO(png_data)) as pil_image:
                svg_array = np.array(pil_image)

            if svg_array.ndim == 3 and svg_array.shape[2] >= 3:
                # RGB (H, W, 3) or RGBA (H, W, 4): strong red, weak green/blue
                red_mask = ((svg_array[:, :, 0] > 200)
                            & (svg_array[:, :, 1] < 100)
                            & (svg_array[:, :, 2] < 100))
            elif svg_array.ndim == 3:
                # (H, W, 1) or (H, W, 2): threshold the first channel
                red_mask = svg_array[:, :, 0] > 128
            else:
                # 2D greyscale: everything that is not (almost) white
                red_mask = svg_array < 240

            # bring the mask to the image size if the rendering differs
            if red_mask.shape != tuple(image_shape):
                # cv2.resize expects (width, height); nearest keeps the mask binary
                red_mask = cv2.resize(red_mask.astype(np.uint8),
                                      (image_shape[1], image_shape[0]),
                                      interpolation=cv2.INTER_NEAREST) > 0

            return red_mask

        except Exception as e:
            raise RuntimeError(f"Failed to parse the .svg file {svg_path}: {e}") from e

    def _label_pixels(self, image_shape: Tuple[int, ...], svg_mask: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Turn a boolean annotation mask into the uint8 label mask and its statistics."""
        # full-zero mask with the size of the .mrc image
        segmentation_mask = np.zeros(image_shape, dtype=np.uint8)
        # annotated pixels get the stage label (1/2/3/4)
        segmentation_mask[svg_mask] = self.stage_label

        total_pixels = segmentation_mask.size
        # plain int (not a NumPy scalar) so the statistics can be serialized
        occupied_pixels = int(np.count_nonzero(segmentation_mask == self.stage_label))

        stats = {'occupied_pixels': occupied_pixels,
                 'occupied_ratio': occupied_pixels / total_pixels if total_pixels else 0.0}
        return segmentation_mask, stats

    def _mask_for_image(self, image: np.ndarray, svg_path: str) -> Tuple[np.ndarray, Dict]:
        """Create the label mask for an already loaded image and its .svg annotation."""
        svg_mask = self.parse_svg_cairosvg(svg_path, image.shape)
        return self._label_pixels(image.shape, svg_mask)

    def create_mask(self, mrc_path: str, svg_path: str) -> Tuple[np.ndarray, Dict]:
        """
        integrate the previous functions to create the mask

        argument:
        mrc_path: the path of the .mrc file
        svg_path: the path of the .svg file

        return: the uint8 label mask and a dict with 'occupied_pixels'
        and 'occupied_ratio'

        raises: RuntimeError if either file cannot be loaded or parsed
        """
        image = self.load_mrc_file(mrc_path)
        return self._mask_for_image(image, svg_path)

    @staticmethod
    def _to_uint8(image: np.ndarray) -> np.ndarray:
        """Return a uint8 copy of image, min-max scaled to [0, 255] unless it already is uint8."""
        if image.dtype == np.uint8:
            return image.copy()

        # work in float64 so integer inputs (e.g. int16) cannot overflow
        scaled = image.astype(np.float64)
        lowest = scaled.min()
        span = scaled.max() - lowest
        # constant or non-finite images have no usable range: render them black
        if not np.isfinite(span) or span == 0:
            return np.zeros(image.shape, dtype=np.uint8)
        return ((scaled - lowest) / span * 255).astype(np.uint8)

    def save_visualization_png(self, original_image: np.ndarray, mask: np.ndarray, save_path: str):
        """
        save the labeling result as a PNG file

        The left panel shows the raw image, the right panel shows the image
        with the labelled pixels tinted red.

        argument:
        original_image: the .mrc NumPy array
        mask: the label mask
        save_path: the path to save the PNG file

        return: None, the figure is written to save_path
        """
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        try:
            # left panel: the original picture
            axes[0].imshow(original_image, cmap='gray')
            axes[0].set_title('The raw image')
            axes[0].axis('off')

            # right panel: three-channel grayscale blended with a red mask
            overlay = self._to_uint8(original_image)
            overlay_rgb = np.stack([overlay, overlay, overlay], axis=-1)
            colored_mask = np.zeros_like(overlay_rgb)
            colored_mask[mask == self.stage_label] = [255, 0, 0]

            # 70% of the plain image and 30% of the red color
            result = cv2.addWeighted(overlay_rgb, 0.7, colored_mask, 0.3, 0)
            axes[1].imshow(result)

            # percentage of labelled pixels, shown in the title
            occupied_ratio = np.sum(mask == self.stage_label) / mask.size * 100
            axes[1].set_title(f'Cell pixel occupation ({occupied_ratio:.1f}%)')
            axes[1].axis('off')

            fig.tight_layout()
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        finally:
            # always release the figure, even when plotting or saving fails
            plt.close(fig)

    def batch_process(self, mrc_files: List[str], svg_files: List[str], output_dir: str) -> Dict:
        """
        process the batch of files

        argument:
        mrc_files: the path of all .mrc files in the batch
        svg_files: the path of all .svg files in the batch
        output_dir: output directory

        return: the batch statistics (counts, failed files, pixel totals
        and the overall occupied ratio)

        raises: ValueError if the two file lists differ in length
        """
        if len(mrc_files) != len(svg_files):
            raise ValueError(f"Discrepancy between .mrc and .svg: {len(mrc_files)} vs {len(svg_files)}")

        # create the output folder if needed
        output_path = Path(output_dir)
        output_path.mkdir(exist_ok=True, parents=True)

        batch_stats = {'total_files': len(mrc_files),
                       'success_count': 0,
                       'failed_files': [],
                       'total_occupied_pixels': 0,
                       'total_pixels': 0}

        for mrc_file, svg_file in tqdm(zip(mrc_files, svg_files),
                                       desc=f"Process stage{self.stage_label}",
                                       total=len(mrc_files)):
            try:
                # load the image once and reuse it for both mask and preview
                original_image = self.load_mrc_file(mrc_file)
                mask, stats = self._mask_for_image(original_image, svg_file)

                filename = Path(mrc_file).stem

                # save the labeling result to .npy
                npy_file = output_path / f"mask_{filename}_v1.npy"
                np.save(npy_file, mask)

                # save the .png preview next to it
                png_file = output_path / f"mask_{filename}.png"
                self.save_visualization_png(original_image, mask, str(png_file))

                batch_stats['success_count'] += 1
                batch_stats['total_occupied_pixels'] += stats['occupied_pixels']
                batch_stats['total_pixels'] += mask.size

            # a failing pair is recorded and the batch continues
            except Exception as e:
                batch_stats['failed_files'].append({'mrc_file': mrc_file,
                                                    'svg_file': svg_file,
                                                    'error': str(e)})

        # overall ratio of labelled pixels across all successful files
        if batch_stats['total_pixels'] > 0:
            batch_stats['overall_occupied_ratio'] = (
                batch_stats['total_occupied_pixels'] / batch_stats['total_pixels'])
        else:
            batch_stats['overall_occupied_ratio'] = 0.0

        # print the summary
        print(f"\n The stage{self.stage_label}batch process complete")
        print(f"Involved files: {batch_stats['total_files']}")
        print(f"Processed successfully: {batch_stats['success_count']}")
        print(f"Failure process: {len(batch_stats['failed_files'])}")
        if batch_stats['total_pixels'] > 0:
            print(f"The average percentage of mask pixel occupation: {batch_stats['overall_occupied_ratio'] * 100:.2f}%")
        print(f"The output dictionary: {output_path}")

        # print at most the first 10 failed files
        if batch_stats['failed_files']:
            print("\n Failed files:")
            for failed in batch_stats['failed_files'][:10]:
                print(f"{Path(failed['mrc_file']).name}: {failed['error']}")

        return batch_stats

def process_all_stages(data_config: Dict[str, Dict], output_base_dir: str = "masks") -> Dict:
    """
    run the labeller over every configured stage

    argument:
    data_config: maps each stage label to a dict holding the 'mrc_files'
                 and 'svg_files' path lists of that stage
    output_base_dir: the folder under which one "stage<label>" folder per
                     stage is used as output directory

    return: a dict mapping "stage<label>" to the batch statistics of that stage
    """
    all_stats = {}

    for stage_label, files_config in data_config.items():
        print(f"\n Now processing the stage {stage_label}")

        # one labeller per stage, writing into its own stage folder
        labeller = Segment_Labeller(stage_label = stage_label)
        target_dir = Path(output_base_dir) / f"stage{stage_label}"

        all_stats[f'stage{stage_label}'] = labeller.batch_process(
            mrc_files = files_config['mrc_files'],
            svg_files = files_config['svg_files'],
            output_dir = str(target_dir))

    print(f"\n === All stages processed ===")

    # per-stage (name, success, failed) tallies, in processing order
    tallies = [(stage_name, stats['success_count'], len(stats['failed_files']))
               for stage_name, stats in all_stats.items()]

    for stage_name, success, failed in tallies:
        print(f"{stage_name}: {success} success, {failed} failed")

    # grand totals over all stages
    total_success = sum(success for _, success, _ in tallies)
    total_failed = sum(failed for _, _, failed in tallies)

    print(f"\n In general: {total_success} success, {total_failed} failed.")
    return all_stats

In [None]:
def main():
    """
    Collect the paired .mrc/.svg files of every stage folder under
    "Dataset-processed" and hand them to process_all_stages.

    Each stage folder is probed for files numbered 01 to 100; only numbers
    for which both the .mrc and the .svg exist are processed.
    """
    dataset_root = Path("Dataset-processed")
    if not dataset_root.exists():
        print(f"Failed to locate the root path {dataset_root}")
        return

    # stage number (1-based) -> file lists of that stage
    data_config = {}

    for stage_number, stage_folder in enumerate(("stageI", "stageII", "stageIII", "stageIV"), 1):
        stage_path = dataset_root / stage_folder
        if not stage_path.exists():
            print(f"Failed to find the stage folder {stage_path}")
            continue

        # file names use the folder name as prefix, e.g. "stageI- 01.mrc"
        candidates = [(stage_path / f"{stage_folder}- {file_num:02d}.mrc",
                       stage_path / f"{stage_folder}- {file_num:02d}.svg")
                      for file_num in range(1, 101)]

        # keep only the numbers where both files of the pair exist
        present = [(mrc_file, svg_file) for mrc_file, svg_file in candidates
                   if mrc_file.exists() and svg_file.exists()]
        missing_count = len(candidates) - len(present)

        if not present:
            print(f"For stage {stage_number} ({stage_folder}): No available files found.")
            continue

        data_config[stage_number] = {'mrc_files': [str(mrc_file) for mrc_file, _ in present],
                                     'svg_files': [str(svg_file) for _, svg_file in present]}

        missing_note = f", and {missing_count} pairs of files are missing." if missing_count > 0 else ""
        print(f"For stage {stage_number} ({stage_folder}): there are {len(present)} pairs of files in total{missing_note}")

    # the total number of files (400 in this research)
    total_files = sum(len(config['mrc_files']) for config in data_config.values())
    print(f"The total number of files to be processed: {total_files}")

    # the output directory
    output_base_dir = "Image segmentation Level 1"

    # top-level boundary: report any failure with its traceback
    try:
        process_all_stages(data_config, output_base_dir = output_base_dir)
    except Exception as e:
        print(f"\n Error encountered during the process: {e}")
        import traceback
        traceback.print_exc()

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()