In [None]:
import glob
import os

import numpy as np
import plotly.express as px
import tifffile
from tqdm.contrib import tzip

# Specifying the dataset

In [None]:
# Dataset split to process ("train", "test", ...) and the dataset folder name.
specific = "test"
dataset_name = "MBSDS"

# Images and masks are paired positionally later (zip), so both lists must be in the
# same order. glob.glob returns files in arbitrary, filesystem-dependent order, hence sorted().
images_path = sorted(glob.glob(pathname=f"data/{dataset_name}/tiff/{specific}/*.tiff"))
masks_path = sorted(glob.glob(pathname=f"data/{dataset_name}/tiff/{specific}_labels/*.tif"))

print(f"{len(images_path)} images found")
print(f"{len(masks_path)} masks found")

# zip() would silently drop unmatched files; fail loudly instead.
assert len(images_path) == len(masks_path), "every image needs exactly one mask"

# Auxiliary functions

In [None]:
def split_array_by_resolution(
    array: np.ndarray, resolution: int, axis: int
) -> list[np.ndarray]:
    """Cut ``array`` along ``axis`` into consecutive chunks of ``resolution`` elements.

    The final chunk is shorter when the axis length is not a multiple of
    ``resolution``; no padding is applied here.

    Args:
        array: Array to split.
        resolution: Chunk length along ``axis``.
        axis: Axis to split along.

    Returns:
        List of sub-arrays, in order along ``axis``.
    """
    axis_length = array.shape[axis]
    # Split *before* indices resolution, 2*resolution, ... (np.split semantics).
    cut_points = np.arange(resolution, axis_length, resolution)
    return np.split(array, cut_points, axis=axis)

In [None]:
def fix_image_array_shape(
    image_array: np.ndarray, expected_resolution: int
) -> np.ndarray:
    """Zero-pad a tile on the bottom/right so its first two axes equal ``expected_resolution``.

    Only the first two (spatial) axes are padded; any trailing axes, such as a
    channel axis, are left unchanged. This makes the function work for both
    (H, W, C) images and (H, W) single-channel masks.

    Args:
        image_array: Tile with at least two dimensions, no larger than
            ``expected_resolution`` along the first two axes.
        expected_resolution: Target size of the first two axes.

    Returns:
        New padded array with the same dtype; padded cells are 0.

    Raises:
        ValueError: (from ``np.pad``) if the tile is larger than
            ``expected_resolution`` along a spatial axis (negative padding).
    """
    missing_rows = expected_resolution - image_array.shape[0]
    missing_cols = expected_resolution - image_array.shape[1]
    # One (before, after) pair per remaining axis, all zero: never pad channels.
    trailing_axes = [(0, 0)] * (image_array.ndim - 2)

    fixed_image = np.pad(
        array=image_array,
        pad_width=[(0, missing_rows), (0, missing_cols), *trailing_axes],
        mode="constant",
    )

    return fixed_image

In [None]:
def split_image_array(image_array: np.ndarray, resolution: int) -> list[np.ndarray]:
    """Tile an image into ``resolution`` x ``resolution`` pieces.

    Tiles are produced row band by row band (top to bottom), and left to right
    within each band. Edge tiles that fall short of the full size are zero-padded
    on the bottom/right by ``fix_image_array_shape``, so every tile shares the same
    spatial shape.

    Args:
        image_array: Image with at least two dimensions; the first two are
            treated as rows and columns.
        resolution: Tile edge length in pixels.

    Returns:
        List of tiles (a list, not a stacked array).
    """
    all_images = []

    row_bands = split_array_by_resolution(array=image_array, resolution=resolution, axis=0)
    for band in row_bands:
        band_tiles = split_array_by_resolution(array=band, resolution=resolution, axis=1)

        for tile in band_tiles:
            # Compare only the spatial dims, so this also works for non-3-channel inputs.
            if tile.shape[:2] != (resolution, resolution):
                tile = fix_image_array_shape(
                    image_array=tile, expected_resolution=resolution
                )

            all_images.append(tile)

    return all_images

In [None]:
def check_valid_mask(mask: np.ndarray, threshold: float) -> bool:
    """Return True when the fraction of mask pixels equal to 1 is at least ``threshold``.

    For a mask with three or more dimensions only channel 0 (``mask[:, :, 0]``) is
    inspected; a 2-D mask is used as-is.

    Args:
        mask: Mask tile; the first two axes are rows and columns.
        threshold: Minimum fraction (0-1) of pixels that must equal 1.

    Returns:
        Whether the tile contains enough foreground to be kept.
    """
    label_plane = mask[:, :, 0] if mask.ndim >= 3 else mask
    total_values = mask.shape[0] * mask.shape[1]
    num_ones = np.count_nonzero(label_plane == 1)

    return bool(num_ones >= total_values * threshold)

In [None]:
def save_smaller_arrays(
    arrays: list[np.ndarray], array_path: str, is_label: bool, dataset_name: str
) -> None:
    """Write each array as a numbered TIFF file derived from ``array_path``.

    Output files go to ``<directory of array_path>/labels`` when ``is_label`` is
    True, otherwise ``<directory of array_path>/tiles``. The folder is created if
    needed. Each file is named
    ``{dataset_name}_{LABEL|RGB}_{code}_p{index:04d}.tif``, where ``code`` is the
    last component of ``array_path`` with underscores removed and ``index`` is the
    position in ``arrays``.

    Args:
        arrays: Tiles to save, in order.
        array_path: Extension-less path; its directory selects the output folder
            and its last component becomes the image code. Both "/" and "\\"
            are accepted as separators.
        is_label: Save as masks ("LABEL", ``labels/``) instead of images
            ("RGB", ``tiles/``).
        dataset_name: Prefix used in every filename.
    """
    image_type = "LABEL" if is_label else "RGB"
    subfolder = "labels" if is_label else "tiles"

    # glob yields "\\" separators on Windows and "/" on POSIX; handle both so the
    # output layout is the same on every OS.
    normalized_path = array_path.replace("\\", "/")
    parent_dir, _, image_name = normalized_path.rpartition("/")
    folder = f"{parent_dir}/{subfolder}" if parent_dir else subfolder
    image_code = image_name.replace("_", "")

    os.makedirs(name=folder, exist_ok=True)

    for index, array in enumerate(arrays):
        path = f"{folder}/{dataset_name}_{image_type}_{image_code}_p{index:04d}.tif"

        tifffile.imwrite(path, data=array)

# Main: tile every image/mask pair and save the tiles that contain enough foreground

In [None]:
# Tile edge length in pixels, and the minimum fraction of foreground (== 1) pixels
# a mask tile must contain for the tile pair to be kept.
TILE_RESOLUTION = 256
MIN_FOREGROUND_FRACTION = 0.1

for image_path, mask_path in tzip(images_path, masks_path):
    # Strip the real extension: images end in ".tiff" but masks in ".tif", so a
    # fixed [:-5] slice would cut the last character off every mask name.
    image_stem = os.path.splitext(image_path)[0]
    mask_stem = os.path.splitext(mask_path)[0]
    # Pairing is positional; make sure image and mask really belong together.
    assert os.path.basename(image_stem) == os.path.basename(mask_stem), (
        f"image/mask mismatch: {image_path} vs {mask_path}"
    )

    original_size_image_array = tifffile.imread(files=image_path)
    original_size_mask = tifffile.imread(files=mask_path)

    # "tiff/<split>" -> "preprocessed/<split>", and "<split>_labels" -> "<split>",
    # so tiles and labels of one split end up in sibling folders.
    new_image_arrays_path = image_stem.replace("tiff", "preprocessed", 1)
    new_mask_arrays_path = mask_stem.replace("tiff", "preprocessed", 1).replace(
        f"{specific}_labels", specific
    )

    smaller_image_arrays = split_image_array(
        image_array=original_size_image_array, resolution=TILE_RESOLUTION
    )

    smaller_mask_arrays = split_image_array(
        image_array=original_size_mask, resolution=TILE_RESOLUTION
    )

    kept_images = []
    kept_masks = []

    for image, mask in zip(smaller_image_arrays, smaller_mask_arrays):
        clipped_mask = np.clip(a=mask, a_min=0.0, a_max=1.0)

        # Images and masks are filtered together so their saved indices stay aligned.
        if check_valid_mask(mask=clipped_mask, threshold=MIN_FOREGROUND_FRACTION):
            kept_images.append(image)
            kept_masks.append(clipped_mask)

    save_smaller_arrays(
        arrays=kept_images,
        array_path=new_image_arrays_path,
        is_label=False,
        dataset_name=dataset_name,
    )

    save_smaller_arrays(
        arrays=kept_masks,
        array_path=new_mask_arrays_path,
        is_label=True,
        dataset_name=dataset_name,
    )

In [None]:
# Preview one (unfiltered) tile of the last image processed by the loop above; the
# loop variables remain in the notebook namespace. Clamp the index so small images
# with fewer tiles do not raise IndexError.
preview_index = min(11, len(smaller_image_arrays) - 1)
px.imshow(
    smaller_image_arrays[preview_index],
    title=f"Tile {preview_index} of {os.path.basename(image_path)}",
)