# In [None]:  (Jupyter notebook cell marker left over from export)
import tensorflow as tf
import numpy as np
import pandas as pd
import os
from typing import Tuple, List, Union, Literal

# -------------------------- CONSTANTS --------------------------
WATER_CLASS_ID = 2  # Class ID passed to the water self-CutMix; same ID as the "Water" colour below
NUM_CLASSES = 6  # Number of classes; used as the one-hot depth for label masks
TILE_SIZE = 512  # Default side length (pixels) that tiles are resized to by the parse functions

# RGB colour -> integer class ID, used to decode colour-coded label PNGs.
# Only "Background" and "Water" carry real names here; the remaining entries
# are placeholders and should be aligned with the dataset's actual legend.
COLOR_TO_CLASS = {
    (0, 0, 0): 0,    # Background
    (255, 0, 0): 1,  # Class 1
    (0, 0, 255): 2,  # Water
    (0, 255, 0): 3,  # Class 3
    (255, 255, 0): 4,  # Class 4
    (255, 0, 255): 5,  # Class 5
}

# -------------------------- UTILS --------------------------

def decode_coloured_label(label_rgb: tf.Tensor) -> tf.Tensor:
    """Decodes an RGB label image into a single-channel integer class ID mask.

    Every pixel colour is compared against the keys of the module-level
    `COLOR_TO_CLASS` mapping and replaced by the class ID stored for the
    matching colour. Only TensorFlow ops are used, so the function can run
    inside a `tf.data` pipeline.

    Args:
        label_rgb: A `tf.Tensor` of shape `(H, W, 3)` holding an RGB label image.
            It is cast to `tf.uint8` before comparison.

    Returns:
        A `tf.Tensor` of shape `(H, W)` and dtype `tf.int64` containing the class
        ID of each pixel. Pixels whose colour is not present in `COLOR_TO_CLASS`
        receive the class ID of the first entry of the mapping.
    """
    label_rgb = tf.cast(label_rgb, tf.uint8)
    flat = tf.reshape(label_rgb, [-1, 3])

    keys = tf.constant(list(COLOR_TO_CLASS.keys()), dtype=tf.uint8)
    # int64 keeps the return dtype identical to the previous argmax-based output.
    values = tf.constant(list(COLOR_TO_CLASS.values()), dtype=tf.int64)

    # (num_pixels, num_colours): True where all three channels equal a key colour.
    match = tf.reduce_all(tf.equal(tf.expand_dims(flat, 1), keys), axis=2)
    # Position of the matching colour; argmax yields 0 when nothing matches.
    indices = tf.argmax(tf.cast(match, tf.int32), axis=1)

    # Translate the position in `keys` into the class ID actually mapped to that
    # colour. Returning the position directly is only correct while the mapping's
    # values happen to be 0..N-1 in insertion order.
    class_ids = tf.gather(values, indices)
    return tf.reshape(class_ids, tf.shape(label_rgb)[:2])

def _load_npy(path: Union[tf.Tensor, str]) -> np.ndarray:
    """Reads a `.npy` file and returns its contents as a float32 array.

    Intended to be called through `tf.py_function`, in which case the path
    arrives as a string tensor and has to be unwrapped first.

    Args:
        path: Either a plain string path to the `.npy` file or a `tf.Tensor`
            holding that path as UTF-8 encoded bytes.

    Returns:
        The loaded `numpy.ndarray`, converted to `np.float32`.
    """
    if isinstance(path, tf.Tensor):
        # Tensor -> bytes -> str
        path = path.numpy().decode("utf-8")
    array = np.load(path)
    return array.astype(np.float32)

def load_image_paths(
    df: pd.DataFrame,
    image_dir: str,
    elev_dir: str,
    label_dir: str,
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Builds the per-tile file paths for RGB, elevation and label data.

    For every value in the "tile_id" column of `df`, three paths are produced
    by joining the relevant base directory with `<tile_id>-ortho.png`,
    `<tile_id>-elev.npy` and `<tile_id>-label.png` respectively. No check is
    made that the files exist.

    Args:
        df: A `pandas.DataFrame` with a "tile_id" column.
        image_dir: Base directory of the RGB images.
        elev_dir: Base directory of the elevation arrays.
        label_dir: Base directory of the label masks.

    Returns:
        A tuple `(image_paths, elev_paths, label_paths, tile_ids)` of four
        lists, all in the row order of `df`.
    """
    tile_ids = df["tile_id"].tolist()

    image_paths: List[str] = []
    elev_paths: List[str] = []
    label_paths: List[str] = []
    # Single pass over the IDs, filling all three modality lists in lockstep.
    for tid in tile_ids:
        image_paths.append(os.path.join(image_dir, f"{tid}-ortho.png"))
        elev_paths.append(os.path.join(elev_dir, f"{tid}-elev.npy"))
        label_paths.append(os.path.join(label_dir, f"{tid}-label.png"))

    return image_paths, elev_paths, label_paths, tile_ids

# -------------------------- AUGMENTATION FUNCTIONS --------------------------

def augment_rgb_label(rgb: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Applies geometric and color augmentations to RGB images and labels.

    Geometric augmentations (flip, rotate) are applied identically to both the
    RGB image and its corresponding label mask to maintain pixel-wise correspondence.
    Color augmentations (brightness, contrast, saturation, hue) are applied only
    to the RGB image. Each augmentation is gated by its own independent
    `tf.random.uniform([])` draw:

    - horizontal flip: applied when the draw is > 0.5
    - vertical flip: applied when the draw is > 0.5
    - rotation by k * 90 degrees, k drawn from {1, 2, 3}: applied when the draw is > 0.25
    - brightness, contrast, saturation: each applied when its draw is > 0.75
    - hue: applied when the draw is > 0.8

    The `if` statements branch on tensors. Outside eager execution (e.g. when
    traced by `tf.data.Dataset.map`) they depend on AutoGraph converting them
    into graph conditionals.

    Args:
        rgb: A `tf.Tensor` representing the RGB image `(H, W, 3)`, typically `tf.float32`
            in the range [0, 1].
        label: A `tf.Tensor` representing the label mask `(H, W)`, typically `tf.int32`
            (integer class IDs).

    Returns:
        A `tuple` containing two `tf.Tensor` objects:
        - The augmented RGB image `(H, W, 3)`.
        - The augmented label mask `(H, W)`.
    """
    # The tf.image flip/rot90 ops work on images with a channel axis, so the
    # label temporarily becomes (H, W, 1); the axis is removed again on return.
    label = tf.expand_dims(label, axis=-1)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_left_right(rgb)
        label = tf.image.flip_left_right(label)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_up_down(rgb)
        label = tf.image.flip_up_down(label)
    if tf.random.uniform([]) > 0.25:  # Rotate in 75% of calls
        # maxval is exclusive, so k is 1, 2 or 3 quarter turns (never a no-op).
        k = tf.random.uniform([], 1, 4, dtype=tf.int32)
        rgb = tf.image.rot90(rgb, k)
        label = tf.image.rot90(label, k)
    # Colour jitter below touches the image only; the label is unaffected.
    # NOTE(review): the result is not clipped afterwards, so values may leave
    # the input range slightly -- confirm this is acceptable downstream.
    if tf.random.uniform([]) > 0.75:
        rgb = tf.image.random_brightness(rgb, max_delta=0.025)
    if tf.random.uniform([]) > 0.75:
        rgb = tf.image.random_contrast(rgb, lower=0.97, upper=1.03)
    if tf.random.uniform([]) > 0.75:
        rgb = tf.image.random_saturation(rgb, lower=0.97, upper=1.03)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_hue(rgb, max_delta=0.025)
    return rgb, tf.squeeze(label, axis=-1)

def augment_rgb_elev(rgb: tf.Tensor, elev: tf.Tensor, label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Applies geometric augmentations to RGB, elevation, and labels; color augmentations to RGB only.

    Geometric augmentations (flip, rotate) are applied identically to all input
    modalities (RGB, elevation, label) to maintain spatial consistency. Color
    augmentations are applied only to the RGB image. Each augmentation is gated
    by its own independent `tf.random.uniform([])` draw:

    - horizontal flip: applied when the draw is > 0.5
    - vertical flip: applied when the draw is > 0.5
    - rotation by k * 90 degrees, k drawn from {1, 2, 3}: applied when the draw is > 0.25
    - brightness, contrast, saturation: each applied when its draw is > 0.75
    - hue: applied when the draw is > 0.8

    The `if` statements branch on tensors. Outside eager execution (e.g. when
    traced by `tf.data.Dataset.map`) they depend on AutoGraph converting them
    into graph conditionals.

    Args:
        rgb: A `tf.Tensor` representing the RGB image `(H, W, 3)`, typically `tf.float32`
            in the range [0, 1].
        elev: A `tf.Tensor` representing the elevation data `(H, W, 1)`, typically `tf.float32`.
        label: A `tf.Tensor` representing the label mask `(H, W)`, typically `tf.int32`
            (integer class IDs).

    Returns:
        A `tuple` containing three `tf.Tensor` objects:
        - The augmented RGB image `(H, W, 3)`.
        - The augmented elevation data `(H, W, 1)`.
        - The augmented label mask `(H, W)`.
    """
    # The tf.image flip/rot90 ops work on images with a channel axis, so the
    # label temporarily becomes (H, W, 1); the axis is removed again on return.
    label = tf.expand_dims(label, axis=-1)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_left_right(rgb)
        elev = tf.image.flip_left_right(elev)
        label = tf.image.flip_left_right(label)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_up_down(rgb)
        elev = tf.image.flip_up_down(elev)
        label = tf.image.flip_up_down(label)
    if tf.random.uniform([]) > 0.25:  # Rotate in 75% of calls
        # maxval is exclusive, so k is 1, 2 or 3 quarter turns (never a no-op).
        k = tf.random.uniform([], 1, 4, dtype=tf.int32)
        rgb = tf.image.rot90(rgb, k)
        elev = tf.image.rot90(elev, k)
        label = tf.image.rot90(label, k)
    # Colour jitter below touches the RGB image only; elevation and label are unaffected.
    # NOTE(review): the result is not clipped afterwards, so values may leave
    # the input range slightly -- confirm this is acceptable downstream.
    if tf.random.uniform([]) > 0.75:
        rgb = tf.image.random_brightness(rgb, max_delta=0.025)
    if tf.random.uniform([]) > 0.75:
        rgb = tf.image.random_contrast(rgb, lower=0.97, upper=1.03)
    if tf.random.uniform([]) > 0.75:
        rgb = tf.image.random_saturation(rgb, lower=0.97, upper=1.03)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_hue(rgb, max_delta=0.025)
    return rgb, elev, tf.squeeze(label, axis=-1)

# -------------------------- CUTMIX FUNCTIONS --------------------------

def _apply_water_cutmix_rgb_only(
    rgb: tf.Tensor,
    label: tf.Tensor,
    tile_size: int,
    water_class_id: int,
    cutmix_prob: float = 0.9,
    patch_size_range: Tuple[int, int] = (64, 64),
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Applies self-CutMix for the water class in an RGB + label setting.

    A square patch centred (before clamping to the tile) on a randomly chosen
    water pixel is copied from the image and pasted at a random location of the
    same image; the label mask is updated with the matching label patch.

    The paste is skipped, and the inputs are returned unchanged, when any of
    the following holds:

    - the random draw is not below `cutmix_prob`;
    - the tile contains fewer water pixels than a quarter of the area of the
      smallest allowed patch (`patch_size_range[0] ** 2 // 4`);
    - the clamped source window is empty;
    - less than a quarter of the selected patch is water.

    The destination is sampled independently of the source window, so the two
    may overlap.

    Args:
        rgb: A `tf.Tensor` representing the RGB image, with shape `(H, W, 3)` and
            `dtype=tf.float32` (typically normalized to [0, 1]).
        label: A `tf.Tensor` representing the integer label mask, with shape `(H, W)`
            and `dtype=tf.int32` (class IDs).
        tile_size: An `int` indicating the size of the square image tile (e.g., 512).
            It is used as the bound for all window coordinates, so it is expected
            to equal `H` and `W`.
        water_class_id: An `int` representing the class index for 'water'.
        cutmix_prob: A `float` between 0 and 1, indicating the probability that CutMix
            will be attempted on a given image. Defaults to 0.9.
        patch_size_range: A `tuple` of two integers `(min_size, max_size)` defining
            the inclusive range for the side length of the square patch to be cut and
            pasted. The default `(64, 64)` always yields a 64-pixel patch.

    Returns:
        A `tuple` containing two `tf.Tensor` objects:
        - The augmented RGB image `(H, W, 3)`.
        - The augmented label mask `(H, W)`.
        If CutMix is not applied (due to probability or validation checks), the
        original `rgb` and `label` tensors are returned.
    """
    def return_original():
        return rgb, label

    def do_cutmix():
        water_mask = tf.cast(tf.equal(label, water_class_id), tf.float32)
        # (N, 2) array of [row, col] coordinates of all water pixels.
        water_coords = tf.cast(tf.where(water_mask > 0), tf.int32)
        num_water_pixels = tf.shape(water_coords)[0]
        # Keep maxval >= 1 so the uniform draw stays valid when there is no water.
        idx_max = tf.maximum(1, num_water_pixels)
        idx = tf.random.uniform([], 0, idx_max, dtype=tf.int32)

        # Guarded lookups: indexing an empty coordinate list must be avoided.
        center_y = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][0], lambda: tf.constant(0, dtype=tf.int32))
        center_x = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][1], lambda: tf.constant(0, dtype=tf.int32))

        # maxval is exclusive, hence the +1 to make the upper bound inclusive.
        patch_size = tf.random.uniform([], patch_size_range[0], patch_size_range[1] + 1, dtype=tf.int32)
        # Centre the window on the chosen pixel, clamp it to the tile, then pull
        # the top-left corner back so the window keeps its full size near the
        # bottom/right edges.
        y1_src = tf.maximum(0, center_y - patch_size // 2)
        x1_src = tf.maximum(0, center_x - patch_size // 2)
        y2_src = tf.minimum(tile_size, y1_src + patch_size)
        x2_src = tf.minimum(tile_size, x1_src + patch_size)
        y1_src = tf.maximum(0, y2_src - patch_size)
        x1_src = tf.maximum(0, x2_src - patch_size)

        patch_height = y2_src - y1_src
        patch_width = x2_src - x1_src
        # Tile-level gate: a quarter of the smallest patch area, in water pixels.
        min_pixels_for_patch = (patch_size_range[0] * patch_size_range[0]) // 4

        # True -> bail out: too little water in the tile or a degenerate window.
        condition = tf.logical_or(
            tf.less(num_water_pixels, min_pixels_for_patch),
            tf.logical_or(tf.less_equal(patch_height, 0), tf.less_equal(patch_width, 0))
        )

        def return_early():
            return rgb, label

        def do_patch():
            rgb_patch = rgb[y1_src:y2_src, x1_src:x2_src, :]
            label_patch = label[y1_src:y2_src, x1_src:x2_src]
            patch_water_mask = tf.cast(tf.equal(label_patch, water_class_id), tf.float32)
            # Patch-level gate: at least a quarter of the patch must be water.
            required_patch_pixels = tf.cast(patch_height * patch_width / 4, tf.float32)
            patch_water_mask_sum = tf.reduce_sum(patch_water_mask)

            def return_due_to_low_water():
                return rgb, label

            def paste_patch():
                # Top-left corner range that keeps the whole patch inside the tile.
                dest_y_max = tf.maximum(1, tile_size - patch_height + 1)
                dest_x_max = tf.maximum(1, tile_size - patch_width + 1)
                dest_y = tf.random.uniform([], 0, dest_y_max, dtype=tf.int32)
                dest_x = tf.random.uniform([], 0, dest_x_max, dtype=tf.int32)

                # Every (y, x, channel) index of the destination window. With
                # 'ij' indexing the flattened order is row-major, matching the
                # flattened patch passed as updates.
                indices_rgb = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(3), indexing='ij'
                ), axis=-1)
                rgb_out = tf.tensor_scatter_nd_update(
                    rgb, tf.reshape(indices_rgb, [-1, 3]), tf.reshape(rgb_patch, [-1])
                )

                # Same destination window for the 2-D label mask.
                indices_label = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width), indexing='ij'
                ), axis=-1)
                label_out = tf.tensor_scatter_nd_update(
                    label, tf.reshape(indices_label, [-1, 2]), tf.reshape(label_patch, [-1])
                )

                return rgb_out, label_out

            return tf.cond(patch_water_mask_sum < required_patch_pixels, return_due_to_low_water, paste_patch)

        return tf.cond(condition, return_early, do_patch)

    return tf.cond(tf.random.uniform([]) < cutmix_prob, do_cutmix, return_original)

def _apply_water_cutmix_rgb_elev(
    rgb: tf.Tensor,
    elev: tf.Tensor,
    label: tf.Tensor,
    tile_size: int,
    water_class_id: int,
    cutmix_prob: float = 0.9,
    patch_size_range: Tuple[int, int] = (64, 64),
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Applies self-CutMix for the water class in an RGB + elevation + label setting.

    A square patch centred (before clamping to the tile) on a randomly chosen
    water pixel is copied from the image and pasted at a random location of the
    same image; the elevation and label tensors are updated with the matching
    patches so that all modalities stay aligned.

    The paste is skipped, and the inputs are returned unchanged, when any of
    the following holds:

    - the random draw is not below `cutmix_prob`;
    - the tile contains fewer water pixels than a quarter of the area of the
      smallest allowed patch (`patch_size_range[0] ** 2 // 4`);
    - the clamped source window is empty;
    - less than a quarter of the selected patch is water.

    The destination is sampled independently of the source window, so the two
    may overlap.

    Args:
        rgb: A `tf.Tensor` representing the RGB image, with shape `(H, W, 3)` and
            `dtype=tf.float32` (typically normalized to [0, 1]).
        elev: A `tf.Tensor` representing the elevation data, with shape `(H, W, 1)`
            and `dtype=tf.float32`.
        label: A `tf.Tensor` representing the integer label mask, with shape `(H, W)`
            and `dtype=tf.int32` (class IDs).
        tile_size: An `int` indicating the size of the square image tile (e.g., 512).
            It is used as the bound for all window coordinates, so it is expected
            to equal `H` and `W`.
        water_class_id: An `int` representing the class index for 'water'.
        cutmix_prob: A `float` between 0 and 1, indicating the probability that CutMix
            will be attempted on a given image. Defaults to 0.9.
        patch_size_range: A `tuple` of two integers `(min_size, max_size)` defining
            the inclusive range for the side length of the square patch to be cut and
            pasted. The default `(64, 64)` always yields a 64-pixel patch.

    Returns:
        A `tuple` containing three `tf.Tensor` objects:
        - The augmented RGB image `(H, W, 3)`.
        - The augmented elevation data `(H, W, 1)`.
        - The augmented label mask `(H, W)`.
        If CutMix is not applied (due to probability or validation checks), the
        original input tensors are returned.
    """
    def return_original():
        return rgb, elev, label

    def do_cutmix():
        water_mask = tf.cast(tf.equal(label, water_class_id), tf.float32)
        # (N, 2) array of [row, col] coordinates of all water pixels.
        water_coords = tf.cast(tf.where(water_mask > 0), tf.int32)
        num_water_pixels = tf.shape(water_coords)[0]
        # Keep maxval >= 1 so the uniform draw stays valid when there is no water.
        idx_max = tf.maximum(1, num_water_pixels)
        idx = tf.random.uniform([], 0, idx_max, dtype=tf.int32)

        # Guarded lookups: indexing an empty coordinate list must be avoided.
        center_y = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][0], lambda: tf.constant(0, dtype=tf.int32))
        center_x = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][1], lambda: tf.constant(0, dtype=tf.int32))

        # maxval is exclusive, hence the +1 to make the upper bound inclusive.
        patch_size = tf.random.uniform([], patch_size_range[0], patch_size_range[1] + 1, dtype=tf.int32)
        # Centre the window on the chosen pixel, clamp it to the tile, then pull
        # the top-left corner back so the window keeps its full size near the
        # bottom/right edges.
        y1_src = tf.maximum(0, center_y - patch_size // 2)
        x1_src = tf.maximum(0, center_x - patch_size // 2)
        y2_src = tf.minimum(tile_size, y1_src + patch_size)
        x2_src = tf.minimum(tile_size, x1_src + patch_size)
        y1_src = tf.maximum(0, y2_src - patch_size)
        x1_src = tf.maximum(0, x2_src - patch_size)

        patch_height = y2_src - y1_src
        patch_width = x2_src - x1_src
        # Tile-level gate: a quarter of the smallest patch area, in water pixels.
        min_pixels_for_patch = (patch_size_range[0] * patch_size_range[0]) // 4

        # True -> bail out: too little water in the tile or a degenerate window.
        condition = tf.logical_or(
            tf.less(num_water_pixels, min_pixels_for_patch),
            tf.logical_or(tf.less_equal(patch_height, 0), tf.less_equal(patch_width, 0))
        )

        def return_early():
            return rgb, elev, label

        def do_patch():
            rgb_patch = rgb[y1_src:y2_src, x1_src:x2_src, :]
            elev_patch = elev[y1_src:y2_src, x1_src:x2_src, :]
            label_patch = label[y1_src:y2_src, x1_src:x2_src]
            patch_water_mask = tf.cast(tf.equal(label_patch, water_class_id), tf.float32)
            # Patch-level gate: at least a quarter of the patch must be water.
            required_patch_pixels = tf.cast(patch_height * patch_width / 4, tf.float32)
            patch_water_mask_sum = tf.reduce_sum(patch_water_mask)

            def return_due_to_low_water():
                return rgb, elev, label

            def paste_patch():
                # Top-left corner range that keeps the whole patch inside the tile.
                dest_y_max = tf.maximum(1, tile_size - patch_height + 1)
                dest_x_max = tf.maximum(1, tile_size - patch_width + 1)
                dest_y = tf.random.uniform([], 0, dest_y_max, dtype=tf.int32)
                dest_x = tf.random.uniform([], 0, dest_x_max, dtype=tf.int32)

                # Every (y, x, channel) index of the destination window. With
                # 'ij' indexing the flattened order is row-major, matching the
                # flattened patch passed as updates.
                indices_rgb = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(3), indexing='ij'
                ), axis=-1)
                rgb_out = tf.tensor_scatter_nd_update(
                    rgb, tf.reshape(indices_rgb, [-1, 3]), tf.reshape(rgb_patch, [-1])
                )

                # Elevation has a single channel, but each index is still a
                # (y, x, channel) triple -- hence the reshape to [-1, 3] below.
                indices_elev = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(1), indexing='ij'
                ), axis=-1)
                elev_out = tf.tensor_scatter_nd_update(
                    elev, tf.reshape(indices_elev, [-1, 3]), tf.reshape(elev_patch, [-1])
                )

                # Same destination window for the 2-D label mask.
                indices_label = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width), indexing='ij'
                ), axis=-1)
                label_out = tf.tensor_scatter_nd_update(
                    label, tf.reshape(indices_label, [-1, 2]), tf.reshape(label_patch, [-1])
                )

                return rgb_out, elev_out, label_out

            return tf.cond(patch_water_mask_sum < required_patch_pixels, return_due_to_low_water, paste_patch)

        return tf.cond(condition, return_early, do_patch)

    return tf.cond(tf.random.uniform([]) < cutmix_prob, do_cutmix, return_original)

# -------------------------- PARSING FUNCTIONS --------------------------

def parse_tile(
    rgb_path: tf.Tensor,
    label_path: tf.Tensor,
    tile_id: tf.Tensor,  # Accepted for dataset-element compatibility; not read here.
    split: str = 'train',
    augment: bool = False,
    tile_size: int = TILE_SIZE,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Loads one RGB tile and its label mask and prepares them for the model.

    Both PNG files are read and decoded, the colour-coded label is turned into
    class IDs, and image and mask are resized to `tile_size` x `tile_size`
    (the mask with nearest-neighbour interpolation). When `augment` is truthy
    and `split` equals 'train', the geometric/colour augmentation and the water
    self-CutMix are applied. The mask is one-hot encoded last.

    Args:
        rgb_path: String `tf.Tensor` with the path of the RGB PNG.
        label_path: String `tf.Tensor` with the path of the label PNG.
        tile_id: `tf.Tensor` with the tile ID; present only so the function
            matches the dataset element structure.
        split: Dataset split name; augmentation runs only for 'train'.
        augment: Whether augmentation is requested.
        tile_size: Side length of the square output tensors.

    Returns:
        A `tuple` of two `tf.Tensor` objects:
        - The RGB image `(tile_size, tile_size, 3)`, `tf.float32`.
        - The one-hot label mask `(tile_size, tile_size, NUM_CLASSES)`,
            `tf.float32`.
    """
    target_hw = [tile_size, tile_size]

    # Image: file -> uint8 PNG -> float32 scaled by convert_image_dtype.
    rgb = tf.image.decode_png(tf.io.read_file(rgb_path), channels=3)
    rgb = tf.image.convert_image_dtype(rgb, tf.float32)

    # Label: file -> colour PNG -> per-pixel class IDs.
    label_png = tf.image.decode_png(tf.io.read_file(label_path), channels=3)
    class_ids = decode_coloured_label(label_png)

    rgb = tf.image.resize(rgb, target_hw)
    # Nearest-neighbour keeps class IDs discrete; resize needs a channel axis.
    class_ids = tf.image.resize(class_ids[..., tf.newaxis], target_hw, method='nearest')
    class_ids = tf.reshape(class_ids, target_hw)
    label = tf.cast(class_ids, tf.int32)

    if augment and split == 'train':
        rgb, label = augment_rgb_label(rgb, label)
        rgb, label = _apply_water_cutmix_rgb_only(rgb, label, tile_size, WATER_CLASS_ID)

    return rgb, tf.one_hot(label, depth=NUM_CLASSES)

def parse_elevation(
    rgb_path: tf.Tensor,
    elev_path: tf.Tensor,
    label_path: tf.Tensor,
    tile_id: tf.Tensor,  # Accepted for dataset-element compatibility; not read here.
    split: str = 'train',
    augment: bool = False,
    tile_size: int = TILE_SIZE,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Loads RGB, elevation and label data for one tile and prepares them for the model.

    The RGB and label PNGs are read and decoded, the colour-coded label is
    turned into class IDs, and the elevation array is loaded from its `.npy`
    file through `tf.py_function`. All three are resized to
    `tile_size` x `tile_size` (the mask with nearest-neighbour interpolation).
    When `augment` is truthy and `split` equals 'train', the geometric/colour
    augmentation and the water self-CutMix are applied. Finally RGB and
    elevation are concatenated channel-wise and the mask is one-hot encoded.

    Args:
        rgb_path: String `tf.Tensor` with the path of the RGB PNG.
        elev_path: String `tf.Tensor` with the path of the elevation `.npy` file.
        label_path: String `tf.Tensor` with the path of the label PNG.
        tile_id: `tf.Tensor` with the tile ID; present only so the function
            matches the dataset element structure.
        split: Dataset split name; augmentation runs only for 'train'.
        augment: Whether augmentation is requested.
        tile_size: Side length of the square output tensors.

    Returns:
        A `tuple` of two `tf.Tensor` objects:
        - `input_image`: `(tile_size, tile_size, 4)`, `tf.float32`
            (3 RGB channels followed by 1 elevation channel).
        - `label`: the one-hot mask `(tile_size, tile_size, NUM_CLASSES)`,
            `tf.float32`.
    """
    target_hw = [tile_size, tile_size]

    # Image: file -> uint8 PNG -> float32 scaled by convert_image_dtype.
    rgb = tf.image.decode_png(tf.io.read_file(rgb_path), channels=3)
    rgb = tf.image.convert_image_dtype(rgb, tf.float32)

    # Label: file -> colour PNG -> per-pixel class IDs.
    label_png = tf.image.decode_png(tf.io.read_file(label_path), channels=3)
    class_ids = decode_coloured_label(label_png)

    # Elevation: py_function output carries no static shape, so declare it as
    # a 2-D array before adding the channel axis.
    elev_2d = tf.py_function(_load_npy, [elev_path], tf.float32)
    elev_2d.set_shape([None, None])
    elev = tf.expand_dims(elev_2d, axis=-1)

    rgb = tf.image.resize(rgb, target_hw)
    elev = tf.image.resize(elev, target_hw)
    # Nearest-neighbour keeps class IDs discrete; resize needs a channel axis.
    class_ids = tf.image.resize(class_ids[..., tf.newaxis], target_hw, method='nearest')
    class_ids = tf.reshape(class_ids, target_hw)
    label = tf.cast(class_ids, tf.int32)

    if augment and split == 'train':
        rgb, elev, label = augment_rgb_elev(rgb, elev, label)
        rgb, elev, label = _apply_water_cutmix_rgb_elev(rgb, elev, label, tile_size, WATER_CLASS_ID)

    # 4 channels: RGB (3) + Elevation (1)
    input_image = tf.concat([rgb, elev], axis=-1)
    return input_image, tf.one_hot(label, depth=NUM_CLASSES)

# -------------------------- DATASET BUILDER --------------------------

def build_tf_dataset(
    df: pd.DataFrame,
    image_dir: str,
    elev_dir: str,
    label_dir: str,
    input_type: Literal['rgb', 'rgb_elev'] = 'rgb',
    batch_size: int = 32,
    split: Literal['train', 'val', 'test', 'custom'] = 'train',  # Added 'custom' to align with usage
    augment: bool = False,
    shuffle: bool = True,
    tile_size: int = TILE_SIZE,
) -> tf.data.Dataset:
    """Builds a TensorFlow Dataset for semantic segmentation.

    The pipeline is: file paths -> (optional) shuffle -> parse/augment ->
    batch -> prefetch. Shuffling is done on the path tuples *before* they are
    parsed, so the shuffle buffer holds only short strings rather than every
    decoded image and one-hot mask of the dataset; the order is still fully
    reshuffled on each iteration.

    Args:
        df: A `pandas.DataFrame` containing a "tile_id" column, which identifies
            the unique tiles to be processed.
        image_dir: The base directory path where RGB image files are stored.
        elev_dir: The base directory path where elevation data files (`.npy`) are stored.
            Only read when `input_type` is 'rgb_elev'.
        label_dir: The base directory path where label mask files are stored.
        input_type: A `str` specifying the input modality for the dataset.
            Must be either 'rgb' (3 channels) or 'rgb_elev' (4 channels: RGB + Elevation).
        batch_size: An `int` specifying the number of samples per batch.
        split: A `str` indicating the dataset split ('train', 'val', 'test', or 'custom').
            Augmentations are applied only when `split='train'` and `augment=True`;
            any other value (e.g. 'custom') disables them.
        augment: A `bool` flag indicating whether to apply data augmentations.
            Augmentations are only applied if `split` is 'train'.
        shuffle: A `bool` flag indicating whether to shuffle the dataset.
            Typically `True` for training sets and `False` for validation/test sets.
        tile_size: An `int` specifying the desired height and width for the output
            square image tiles.

    Returns:
        A batched, prefetched `tf.data.Dataset` yielding `(input, one_hot_label)`
        pairs. An empty `df` yields an empty dataset.

    Raises:
        ValueError: If an unsupported `input_type` is provided.
    """
    image_paths, elev_paths, label_paths, tile_ids = load_image_paths(df, image_dir, elev_dir, label_dir)

    # Explicit string tensors: an empty Python list would otherwise be inferred
    # as float32 and break the file-reading ops when the map function is traced.
    image_paths_t = tf.constant(image_paths, dtype=tf.string)
    elev_paths_t = tf.constant(elev_paths, dtype=tf.string)
    label_paths_t = tf.constant(label_paths, dtype=tf.string)

    if input_type == 'rgb':
        # Dataset for RGB only
        dataset = tf.data.Dataset.from_tensor_slices((image_paths_t, label_paths_t, tile_ids))

        def map_fn(rgb_path, label_path, tile_id):
            return parse_tile(rgb_path, label_path, tile_id, split, augment, tile_size)

    elif input_type == 'rgb_elev':
        # Dataset for RGB and Elevation
        dataset = tf.data.Dataset.from_tensor_slices((image_paths_t, elev_paths_t, label_paths_t, tile_ids))

        def map_fn(rgb_path, elev_path, label_path, tile_id):
            return parse_elevation(rgb_path, elev_path, label_path, tile_id, split, augment, tile_size)

    else:
        raise ValueError(f"Unsupported input_type: {input_type}")

    if shuffle and image_paths:
        # Shuffle the lightweight path tuples, not the decoded tensors: a buffer
        # covering the whole dataset then costs a few bytes per tile instead of
        # keeping every parsed image and mask in memory. Skipped for an empty
        # dataset because a zero-sized shuffle buffer is invalid.
        dataset = dataset.shuffle(buffer_size=len(image_paths), reshuffle_each_iteration=True)

    # Map the parsing function across the dataset elements
    dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch the dataset
    dataset = dataset.batch(batch_size)
    # Prefetch data to overlap data preprocessing and model execution
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

# -------------------------- EXAMPLE USAGE (Optional) --------------------------
if __name__ == "__main__":
    # Smoke-test of the pipeline. The tile IDs and directories below are
    # placeholders: the files must exist on disk for iteration to succeed.
    df = pd.DataFrame({'tile_id': [f"tile_{i:03d}" for i in range(100)]})
    # RGB + elevation input with training-time augmentation enabled.
    dataset = build_tf_dataset(
        df,
        image_dir="./data/images",
        elev_dir="./data/elevation",
        label_dir="./data/labels",
        input_type='rgb_elev',
        batch_size=32,
        split='train',
        augment=True,
        shuffle=True
    )
    # Pull a single batch and report its shapes.
    for inputs, labels in dataset.take(1):
        print(f"Input shape: {inputs.shape}, Label shape: {labels.shape}")