In [1]:
import argparse
import os
from pathlib import Path
from typing import List, Tuple, Dict
import sys
import random
import keras
import numpy as np
from PIL import Image

import tensorflow as tf
from keras import layers, Model, ops
from skimage.segmentation import find_boundaries
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.segmentation import watershed

In [None]:
from pathlib import Path

def read_labelmap(labelmap_path: Path) -> Tuple[List[str], List[Tuple[int, int, int]]]:
    """
    Parse a VOC-style labelmap file into class names and RGB colors.

    Blank lines and lines starting with '#' are skipped. Every other line must look
    like ``name:r,g,b[:...]``: only the name and the first colon-separated field
    after it (the color) are read; any further fields are ignored.

    Args:
        labelmap_path: path to the labelmap file (``str`` or ``pathlib.Path``).

    Returns:
        (names, colors): class names and the matching ``(r, g, b)`` integer tuples,
        in file order.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if a line has no colon, a color without exactly 3 components,
            non-integer components, or components outside [0, 255].
    """
    # Normalize first so plain strings are accepted too (str has no .exists()).
    labelmap_path = Path(labelmap_path)
    if not labelmap_path.exists():
        raise FileNotFoundError(f"File not found: {labelmap_path}")

    names, colors = [], []
    for raw in labelmap_path.read_text(encoding="utf-8").splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if ":" not in line:
            raise ValueError(f"Missing colon in line: {line}")

        # Split once to get the name and the rest (color info + optional extra fields)
        name, rest = line.split(":", 1)
        name = name.strip()

        # Take the first field (color only)
        color_field = rest.split(":", 1)[0]
        comps = color_field.split(",")
        if len(comps) != 3:
            raise ValueError(f"RGB must have 3 components: {line}")

        try:
            r, g, b = (int(c.strip()) for c in comps)
        except ValueError as e:
            raise ValueError(f"Non-integer RGB values in line: {line}") from e

        # Colors are later packed into 24-bit keys (8 bits per channel) when masks
        # are decoded, so out-of-range components could alias a real pixel color
        # instead of failing loudly.
        if not all(0 <= v <= 255 for v in (r, g, b)):
            raise ValueError(f"RGB components must be in [0, 255]: {line}")

        names.append(name)
        colors.append((r, g, b))

    return names, colors


# ======= CALLING THE FUNCTION =======

# Path to the labelmap; replace with your actual file location.
labelmap_path = Path("labelmap.txt")

# The dataset class below takes `names` / `colors` as arguments, so load them
# here. Any read/parse error is allowed to surface rather than being printed and
# leaving these names undefined for later cells.
names, colors = read_labelmap(labelmap_path)
print("Labelmap loaded successfully!")
print("Label names:", names)
print("RGB colors:", colors)


Labelmap loaded successfully!
Label names: ['background', 'cheetah', 'fox', 'hyena', 'lion', 'tiger', 'wolf']
RGB colors: [(0, 0, 0), (224, 64, 64), (160, 96, 64), (7, 73, 80), (32, 128, 224), (112, 192, 192), (102, 178, 138)]


In [4]:
def build_color_to_index(colors: List[Tuple[int,int,int]]) -> Dict[Tuple[int,int,int], int]:
    """
    Build a lookup table from mask RGB color to class index.

    A color's class index is its position in ``colors``. Components are coerced to
    plain ``int`` so the keys are ordinary Python int tuples. If a color appears
    more than once, its last position wins.

    Example:
        [(0, 0, 0), (224, 64, 64), (160, 96, 64)]
        -> {(0, 0, 0): 0, (224, 64, 64): 1, (160, 96, 64): 2}

    Purpose: this table is what ``mask_rgb_to_index`` uses to turn an RGB
    segmentation mask into per-pixel class ids.
    """
    lookup = {}
    for class_idx, color in enumerate(colors):
        key = tuple(int(component) for component in color)
        lookup[key] = class_idx
    return lookup

In [5]:
def mask_rgb_to_index(mask_img: Image.Image, color_to_index: Dict[Tuple[int,int,int], int], ignore_index=255) -> np.ndarray:
    """
    Map every pixel of a color-coded segmentation mask to its class index.

    The mask is converted to RGB first, so palette and truecolor masks both work.
    A pixel whose (r, g, b) is a key of ``color_to_index`` receives that key's
    value; every other pixel receives ``ignore_index``.

    Returns:
        uint8 array of shape (H, W).
    """
    rgb = np.array(mask_img.convert("RGB"), dtype=np.uint8)  # (H, W, 3)
    height, width = rgb.shape[0], rgb.shape[1]

    # Pack each pixel into one 24-bit integer (0xRRGGBB) so a color can be matched
    # with a single vectorized comparison instead of three.
    red, green, blue = (rgb[..., ch].astype(np.int32) for ch in range(3))
    packed = (red << 16) | (green << 8) | blue

    labels = np.full((height, width), ignore_index, dtype=np.uint8)
    for (cr, cg, cb), class_idx in color_to_index.items():
        labels[packed == ((cr << 16) | (cg << 8) | cb)] = class_idx
    return labels

In [7]:
# Data wrapper
class MultiRootVOCDataset:
    """
    Read VOC-style segmentation samples from multiple dataset roots.

    Each root must contain:
      JPEGImages/                              images ``<id>.jpg`` (``<id>.png`` as fallback)
      SegmentationClass/                       color masks ``<id>.png``
      ImageSets/Segmentation/<image_set>.txt   one sample id per line

    A single, unified (names, colors) pair defines the global classes. Mask pixels
    whose color equals ``colors[i]`` become class ``i``; any other color becomes
    ``ignore_index``.

    Items are returned by ``get_item(idx)`` (or ``dataset[idx]``) as numpy arrays:
      image: float32 (crop_size, crop_size, 3), normalized with ImageNet mean/std
      mask:  int64   (crop_size, crop_size) class indices
    """
    def __init__(self, roots: List[str], image_set: str,
                 names: List[str], colors: List[Tuple[int,int,int]],
                 crop_size: int = 512, random_scale=(0.5, 2.0),
                 hflip_prob: float = 0.5, ignore_index: int = 255):
        """
        Args:
            roots: dataset root directories (str or path-like).
            image_set: split name; selects ImageSets/Segmentation/<image_set>.txt.
                "train" enables random resize/crop/flip. Any other value uses a
                deterministic upscale + center crop.
            names, colors: class names and their mask RGB colors, in the same order.
            crop_size: side length of the square output crop.
            random_scale: (min, max) range of the random rescale factor used in
                training; a falsy value disables rescaling.
            hflip_prob: probability of a horizontal flip in training.
            ignore_index: label for unknown mask colors and for padded pixels.

        Raises:
            FileNotFoundError: if a root has no split file for ``image_set``.
        """
        super().__init__()
        self.roots = [Path(r) for r in roots]
        self.image_set = image_set
        self.names, self.colors = names, colors
        self.ignore_index = ignore_index
        self.crop_size, self.random_scale, self.hflip_prob = crop_size, random_scale, hflip_prob
        self.color_to_index = build_color_to_index(colors)

        # Flat list of (root, id) pairs, e.g.
        #   [(Path(".../cheetah"), "2008_000123"), (Path(".../lion"), "2011_003210"), ...]
        self.samples = []
        for root in self.roots:
            set_file = root / "ImageSets" / "Segmentation" / f"{image_set}.txt"
            ids = [s.strip() for s in set_file.read_text(encoding="utf-8").splitlines() if s.strip()]
            self.samples.extend((root, img_id) for img_id in ids)

        # Normalization (ImageNet stats)
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Sequence-style access; identical to ``get_item(idx)``."""
        return self.get_item(idx)

    def _load_sample(self, root: Path, img_id: str):
        """
        Load one sample.

        Returns:
            (RGB PIL image, (H, W) uint8 class-index mask).

        Raises:
            ValueError: if the mask and image sizes differ.
        """
        img_dir, mask_dir = root / "JPEGImages", root / "SegmentationClass"
        img_path = img_dir / f"{img_id}.jpg"
        if not img_path.exists():
            alt = img_dir / f"{img_id}.png"
            img_path = alt if alt.exists() else img_path
        mask_path = mask_dir / f"{img_id}.png"

        # Context managers close the underlying files once the pixels are decoded.
        with Image.open(img_path) as raw_img:
            image = raw_img.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = mask_rgb_to_index(raw_mask, self.color_to_index, ignore_index=self.ignore_index)  # (H,W) uint8

        # Every augmentation indexes image and mask with the same coordinates, so a
        # size mismatch would silently misalign labels or change output shapes.
        if mask.shape != (image.height, image.width):
            raise ValueError(
                f"Mask size {mask.shape[::-1]} != image size {image.size} for {root}/{img_id}"
            )
        return image, mask

    @staticmethod
    def _resize_mask(mask, size):
        """Nearest-neighbour resize of a (H, W) label array to ``size=(W, H)``; returns int64."""
        # Labels fit in uint8 (they come from mask_rgb_to_index). A uint8 2-D array
        # lets Pillow infer mode "L", avoiding the `mode=` argument that newer Pillow
        # releases deprecate.
        pil_mask = Image.fromarray(np.asarray(mask, dtype=np.uint8))
        return np.array(pil_mask.resize(size, Image.NEAREST), dtype=np.int64)

    def _random_resize(self, img, mask):
        """Rescale image (bilinear) and mask (nearest) by one factor drawn from random_scale."""
        if self.random_scale:
            s = np.random.uniform(*self.random_scale)
            # Clamp to 1 px so a tiny image combined with a small scale never becomes empty.
            new_w, new_h = max(1, int(img.width * s)), max(1, int(img.height * s))
            img = img.resize((new_w, new_h), Image.BILINEAR)
            mask = self._resize_mask(mask, (new_w, new_h))
        return img, mask

    def _random_crop(self, img, mask):
        """Pad to at least crop_size (image with 0, mask with ignore_index), then crop a random square window."""
        th, tw = self.crop_size, self.crop_size
        # Pad bottom/right if needed
        if img.height < th or img.width < tw:
            pad_h, pad_w = max(0, th - img.height), max(0, tw - img.width)
            img = Image.fromarray(np.pad(np.array(img),
                                         ((0, pad_h), (0, pad_w), (0, 0)),
                                         mode="constant", constant_values=0).astype(np.uint8))
            mask = np.pad(mask, ((0, pad_h), (0, pad_w)),
                          mode="constant", constant_values=self.ignore_index)

        # Random crop
        i = np.random.randint(0, img.height - th + 1)
        j = np.random.randint(0, img.width - tw + 1)
        img = img.crop((j, i, j + tw, i + th))
        mask = mask[i:i + th, j:j + tw]
        return img, mask

    def _hflip(self, img, mask):
        """Flip image and mask left-right together with probability hflip_prob."""
        if np.random.rand() < self.hflip_prob:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask[:, ::-1]
        return img, mask

    def _center_crop_or_resize(self, img, mask):
        """Deterministic eval transform: upscale so the short side >= crop_size, then center crop."""
        short = min(img.width, img.height)
        if short < self.crop_size:
            s = self.crop_size / short
            # round() + clamp instead of int(): (crop/short)*short can land just below
            # crop_size in floating point and truncate to crop_size-1. PIL's crop would
            # then zero-pad the image to full size while the numpy mask slice stayed
            # short, producing mismatched image/mask shapes.
            new_w = max(self.crop_size, round(img.width * s))
            new_h = max(self.crop_size, round(img.height * s))
            img = img.resize((new_w, new_h), Image.BILINEAR)
            mask = self._resize_mask(mask, (new_w, new_h))
        # center crop
        th, tw = self.crop_size, self.crop_size
        i = max(0, (img.height - th) // 2)
        j = max(0, (img.width - tw) // 2)
        img = img.crop((j, i, j + tw, i + th))
        mask = mask[i:i + th, j:j + tw]
        return img, mask

    def get_item(self, idx):
        """
        Load, transform and normalize sample ``idx``.

        Returns:
            (img_np, mask_np): float32 (H, W, 3) normalized image and int64 (H, W) mask.
        """
        root, img_id = self.samples[idx]
        img, mask = self._load_sample(root, img_id)

        if self.image_set == "train":
            img, mask = self._random_resize(img, mask)
            img, mask = self._random_crop(img, mask)
            img, mask = self._hflip(img, mask)
        else:
            img, mask = self._center_crop_or_resize(img, mask)

        # (H,W,3) uint8 -> float32 in [0,1] -> ImageNet-normalized
        img_np = np.asarray(img, dtype=np.float32) / 255.0
        img_np = (img_np - self.mean) / self.std
        mask_np = mask.astype(np.int64)
        return img_np, mask_np
