In [None]:
import os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm 

# Join path components instead of hard-coding "\" separators: the paths then work on
# every OS and avoid invalid string escape sequences such as "\c" and "\s".
DATASET_DIR = os.path.join(os.getcwd(), "cat_dog_segmentation_dataset")
input_dir = os.path.join(DATASET_DIR, "segmentation")  # colour (RGB) segmentation masks to encode
output_dir = os.path.join(DATASET_DIR, "masks")        # destination for single-channel label masks

In [4]:
def create_color_to_label_map(mask_folder_path):
    """
    Scan every mask in a folder and assign an integer label to each distinct RGB colour.

    Colours are sorted lexicographically as (R, G, B) tuples before labelling, so the
    mapping does not depend on directory listing order; pure black (0, 0, 0), when
    present, always receives label 0.

    Args:
        mask_folder_path: Directory whose files are all readable colour mask images.

    Returns:
        dict mapping (R, G, B) tuples of Python ints to consecutive labels 0..K-1.

    Raises:
        ValueError: if a file in the folder cannot be decoded as an image.
    """
    unique_colors = set()
    mask_files = sorted(os.listdir(mask_folder_path))

    for filename in tqdm(mask_files, desc="Finding unique colors"):
        mask_path = os.path.join(mask_folder_path, filename)
        mask_bgr = cv2.imread(mask_path)
        if mask_bgr is None:
            # cv2.imread signals failure by returning None rather than raising;
            # report the offending path instead of crashing inside cvtColor.
            raise ValueError(f"Could not read mask image: {mask_path}")
        mask_rgb = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)

        # One row per pixel, then de-duplicate rows to get this image's colours.
        unique_pixel_colors = np.unique(mask_rgb.reshape(-1, 3), axis=0)
        # Store plain ints so keys print cleanly and don't carry NumPy scalar types.
        unique_colors.update(tuple(int(c) for c in color) for color in unique_pixel_colors)

    color_to_label = {color: label for label, color in enumerate(sorted(unique_colors))}

    print("\nScan complete!")
    print(f"Found {len(color_to_label)} unique classes.")

    return color_to_label


COLOR_TO_LABEL = create_color_to_label_map(input_dir)
# Encoded masks are stored as uint8, so at most 256 labels are representable; an empty
# map means the input folder is wrong or empty.
assert 0 < len(COLOR_TO_LABEL) <= 256, f"unexpected number of classes: {len(COLOR_TO_LABEL)}"

Finding unique colors:   0%|          | 0/150 [00:00<?, ?it/s]

Finding unique colors: 100%|██████████| 150/150 [00:28<00:00,  5.35it/s]


Scan complete!
Found 3 unique classes.





In [7]:
def encode_mask_to_grayscale(mask_path, color_map):
    """
    Convert an RGB segmentation mask into a single-channel class-label mask.

    Args:
        mask_path: Path to the colour mask image on disk.
        color_map: dict mapping (R, G, B) tuples to integer labels, e.g. the
            output of ``create_color_to_label_map``.

    Returns:
        np.ndarray of shape (H, W) and dtype uint8 holding each pixel's label.
        Pixels whose colour is not in ``color_map`` are left as 0, which is
        indistinguishable from a colour mapped to label 0.

    Raises:
        ValueError: if the image cannot be read, or a label does not fit in uint8.
    """
    mask_bgr = cv2.imread(mask_path)
    if mask_bgr is None:
        # cv2.imread returns None instead of raising on unreadable files.
        raise ValueError(f"Could not read mask image: {mask_path}")
    mask_rgb = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)

    mask_grayscale = np.zeros(mask_rgb.shape[:2], dtype=np.uint8)

    for color, label in color_map.items():
        # The output is uint8: reject labels that would overflow (or silently wrap on older NumPy).
        if not 0 <= label <= 255:
            raise ValueError(f"Label {label} for color {color} does not fit in uint8")
        # Boolean (H, W) mask of pixels whose three channels all equal `color`.
        mask_grayscale[np.all(mask_rgb == color, axis=-1)] = label

    return mask_grayscale

In [6]:
os.makedirs(output_dir, exist_ok=True)

# Sorted so progress and any failure occur in a reproducible order.
for filename in tqdm(sorted(os.listdir(input_dir)), desc="Encoding masks"):
    mask_path = os.path.join(input_dir, filename)
    grayscale_mask = encode_mask_to_grayscale(mask_path, COLOR_TO_LABEL)

    # Always save as .png (lossless) so the exact label values are preserved.
    save_path = os.path.join(output_dir, os.path.splitext(filename)[0] + ".png")
    # cv2.imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(save_path, grayscale_mask):
        raise IOError(f"Failed to write encoded mask: {save_path}")


Encoding masks: 100%|██████████| 150/150 [00:02<00:00, 72.03it/s]
