In [None]:
'''
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import random

# RGB colour found in the label PNGs -> integer class ID.
# The IDs line up with the positions in CLASS_NAMES below.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,      # Building
    (145, 30, 180): 1,     # Clutter
    (60, 180, 75): 2,      # Vegetation
    (245, 130, 48): 3,     # Water
    (255, 255, 255): 4,    # Background
    (0, 130, 200): 5,      # Car
}

# Human-readable class names, indexed by class ID.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']

# Depth of the one-hot label encoding.
NUM_CLASSES = len(CLASS_NAMES)

# Class targeted by the water CutMix augmentation (== 3).
WATER_CLASS_ID = CLASS_NAMES.index('Water')

def decode_coloured_label(label_rgb, unknown_class=0):
    """
    Decodes an RGB label image into a single-channel integer class label mask.

    Every pixel is compared against each colour in COLOR_TO_CLASS and mapped to
    that colour's class ID (the dict *value*), so the result stays correct even
    if the mapping is reordered or uses non-contiguous IDs.

    Args:
        label_rgb (tf.Tensor): A 3-channel RGB label image tensor (H, W, 3).
        unknown_class (int): Class ID assigned to pixels whose colour is not a
            key of COLOR_TO_CLASS. Defaults to 0, which matches the previous
            behaviour.

    Returns:
        tf.Tensor: An int64 class label mask tensor (H, W).
    """
    label_rgb = tf.cast(label_rgb, tf.uint8)
    flat = tf.reshape(label_rgb, [-1, 3])                                   # (H*W, 3)
    keys = tf.constant(list(COLOR_TO_CLASS.keys()), dtype=tf.uint8)         # (K, 3)
    values = tf.constant(list(COLOR_TO_CLASS.values()), dtype=tf.int64)     # (K,)

    # match[p, k] is True when pixel p has exactly colour k.
    match = tf.reduce_all(tf.equal(tf.expand_dims(flat, 1), keys), axis=2)  # (H*W, K)

    # Position of the matching colour, translated into its real class ID.
    key_index = tf.argmax(tf.cast(match, tf.int32), axis=1)
    class_ids = tf.gather(values, key_index)

    # argmax returns 0 for rows with no match; make the fallback explicit instead.
    fallback = tf.fill(tf.shape(class_ids), tf.cast(unknown_class, tf.int64))
    class_ids = tf.where(tf.reduce_any(match, axis=1), class_ids, fallback)

    return tf.reshape(class_ids, tf.shape(label_rgb)[:2])

def _load_npy(path):
    """
    Reads a .npy array from disk and returns it as float32.

    Intended to be called through tf.py_function, which hands the path over as
    an eager string tensor. Any non-tensor path (e.g. a str) is passed to
    np.load unchanged.

    Args:
        path (tf.Tensor | str): Scalar string tensor or plain path.

    Returns:
        np.ndarray: The stored array, cast to float32.
    """
    file_path = path.numpy().decode("utf-8") if isinstance(path, tf.Tensor) else path
    array = np.load(file_path)
    return array.astype(np.float32)

def load_image_paths(df, image_dir, elev_dir, slope_dir, label_dir):
    """
    Builds the per-tile file paths for every modality from a DataFrame of tile IDs.

    For a tile ID ``tid`` the files are ``<image_dir>/<tid>-ortho.png``,
    ``<elev_dir>/<tid>-elev.npy``, ``<slope_dir>/<tid>-slope.npy`` and
    ``<label_dir>/<tid>-label.png``. All returned lists share the row order of ``df``.

    Args:
        df (pd.DataFrame): DataFrame with a "tile_id" column.
        image_dir (str): Directory of the RGB orthophotos.
        elev_dir (str): Directory of the elevation .npy files.
        slope_dir (str): Directory of the slope .npy files.
        label_dir (str): Directory of the colour-coded label PNGs.

    Returns:
        tuple: (image_paths, elev_paths, slope_paths, label_paths, tile_ids) lists.
    """
    tile_ids = df["tile_id"].tolist()
    image_paths, elev_paths, slope_paths, label_paths = [], [], [], []

    # One pass over the IDs fills all four path lists in lockstep.
    for tid in tile_ids:
        image_paths.append(os.path.join(image_dir, f"{tid}-ortho.png"))
        elev_paths.append(os.path.join(elev_dir, f"{tid}-elev.npy"))
        slope_paths.append(os.path.join(slope_dir, f"{tid}-slope.npy"))
        label_paths.append(os.path.join(label_dir, f"{tid}-label.png"))

    return image_paths, elev_paths, slope_paths, label_paths, tile_ids

def augment_rgb_label(rgb, label):
    """
    Applies random geometric and colour augmentations to an RGB image and its label.

    Flips and the 90/180/270 degree rotation are applied identically to both
    tensors so they stay pixel-aligned. Colour jitter only touches RGB and is
    followed by clipping back to [0, 1], because brightness/contrast adjustment
    of float images does not clip on its own.

    Args:
        rgb (tf.Tensor): RGB image tensor (H, W, 3), expected as float in [0, 1]
            (as produced by tf.image.convert_image_dtype to float32).
        label (tf.Tensor): Label mask tensor (H, W) of integer class IDs.

    Returns:
        tuple: Augmented rgb and label tensors.
    """
    label = tf.expand_dims(label, axis=-1)  # (H, W, 1): tf.image ops need a channel axis

    # Geometric augmentations (applied to both RGB and label)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_left_right(rgb)
        label = tf.image.flip_left_right(label)

    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_up_down(rgb)
        label = tf.image.flip_up_down(label)

    # Always rotate by 90, 180 or 270 degrees. Combined with the random flips above,
    # every one of the 8 square symmetries (including the identity) is still reachable.
    k = tf.random.uniform([], 1, 4, dtype=tf.int32)
    rgb = tf.image.rot90(rgb, k)
    label = tf.image.rot90(label, k)

    # Colour augmentations (only applied to RGB)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_brightness(rgb, max_delta=0.016)

    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_contrast(rgb, lower=0.98, upper=1.03)

    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_saturation(rgb, lower=0.98, upper=1.03)

    if tf.random.uniform([]) > 0.9:
        rgb = tf.image.random_hue(rgb, max_delta=0.016)

    # Brightness/contrast jitter can push pixels slightly outside [0, 1]; clip back.
    rgb = tf.clip_by_value(rgb, 0.0, 1.0)

    label = tf.squeeze(label, axis=-1)  # back to (H, W)
    return rgb, label

def augment_all(rgb, elev, slope, label):
    """
    Applies random geometric and colour augmentations to RGB, elevation, slope and label.

    Flips and the 90/180/270 degree rotation are applied identically to all four
    tensors so they stay pixel-aligned. Colour jitter only touches RGB and is
    followed by clipping back to [0, 1], because brightness/contrast adjustment
    of float images does not clip on its own. Elevation and slope values are
    never modified, only moved.

    Args:
        rgb (tf.Tensor): RGB image tensor (H, W, 3), expected as float in [0, 1].
        elev (tf.Tensor): Elevation tensor, expected (H, W, 1).
        slope (tf.Tensor): Slope tensor, expected (H, W, 1).
        label (tf.Tensor): Label mask tensor (H, W) of integer class IDs.

    Returns:
        tuple: Augmented rgb, elev, slope, and label tensors.
    """
    label = tf.expand_dims(label, axis=-1)  # (H, W, 1): tf.image ops need a channel axis

    # Geometric augmentations (applied to all inputs)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_left_right(rgb)
        elev = tf.image.flip_left_right(elev)
        slope = tf.image.flip_left_right(slope)
        label = tf.image.flip_left_right(label)

    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_up_down(rgb)
        elev = tf.image.flip_up_down(elev)
        slope = tf.image.flip_up_down(slope)
        label = tf.image.flip_up_down(label)

    # Always rotate by 90, 180 or 270 degrees. Combined with the random flips above,
    # every one of the 8 square symmetries (including the identity) is still reachable.
    k = tf.random.uniform([], 1, 4, dtype=tf.int32)
    rgb = tf.image.rot90(rgb, k)
    elev = tf.image.rot90(elev, k)
    slope = tf.image.rot90(slope, k)
    label = tf.image.rot90(label, k)

    # Colour augmentations (only applied to RGB)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_brightness(rgb, max_delta=0.016)

    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_contrast(rgb, lower=0.98, upper=1.03)

    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_saturation(rgb, lower=0.98, upper=1.03)

    if tf.random.uniform([]) > 0.9:
        rgb = tf.image.random_hue(rgb, max_delta=0.016)

    # Brightness/contrast jitter can push pixels slightly outside [0, 1]; clip back.
    rgb = tf.clip_by_value(rgb, 0.0, 1.0)

    label = tf.squeeze(label, axis=-1)  # back to (H, W)
    return rgb, elev, slope, label



    
def _apply_water_cutmix_rgb_only(rgb, label, tile_size, water_class_id, cutmix_prob=0.9, patch_size_range=(64, 64)):
    """
    Applies a 'self-cutmix' augmentation for the water class, operating only on RGB and Label.
    It finds a water patch within the current image and pastes it into another random location
    within the same image, updating both the RGB image and its label.

    The patch is a square centred on a randomly chosen water pixel and shifted to stay inside
    the tile. It is copied as-is, so any non-water pixels inside it are pasted too; the source
    region itself is left unchanged. The inputs are returned unchanged when:
      - the image has fewer than patch_size_range[0]**2 // 4 water pixels, or
      - less than a quarter of the cropped patch is water.

    NOTE(review): the Python `if` / `or` below test tensor values. They work eagerly or when the
    function is converted by AutoGraph (e.g. inside Dataset.map); confirm before using it in a
    hand-built graph.

    Args:
        rgb (tf.Tensor): RGB image tensor, indexed as (H, W, 3) below.
        label (tf.Tensor): Label mask tensor (H, W) of integer class IDs.
        tile_size (int): The target size of the image tiles (e.g., 256); assumed to match the
            height and width of the inputs, since the crop window and paste location are
            clamped to it.
        water_class_id (int): The integer ID for the water class.
        cutmix_prob (float): Probability of applying this augmentation.
        patch_size_range (tuple): (min_size, max_size) for the square patch, both inclusive.

    Returns:
        tuple: Augmented rgb and label tensors (the inputs when the augmentation is skipped).
    """
    # Initialize output variables to the original inputs.
    # These will be updated only if cutmix is successfully applied.
    rgb_out = rgb
    label_out = label

    # Outer probability check for applying cutmix
    if tf.random.uniform([]) < cutmix_prob:
        water_mask = tf.cast(tf.equal(label, water_class_id), tf.float32)
        # (num_water_pixels, 2) array of (row, col) coordinates of water pixels.
        water_coords = tf.cast(tf.where(water_mask > 0), tf.int32)
        
        # Ensure the upper bound for selecting a random water coordinate is always at least 1.
        # tf.shape(water_coords)[0] gives the number of water pixels. If 0, maxval would be 0.
        num_water_pixels = tf.shape(water_coords)[0]
        idx_max = tf.maximum(1, num_water_pixels) 

        # Select a random index into water_coords
        # If num_water_pixels is 0, idx will be 0, which is handled below.
        idx = tf.random.uniform([], 0, idx_max, dtype=tf.int32)

        # Get center from water_coords, conditionally to avoid indexing into an empty tensor
        center_y = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][0], lambda: tf.constant(0, dtype=tf.int32))
        center_x = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][1], lambda: tf.constant(0, dtype=tf.int32))

        # maxval is exclusive, hence +1 so that max_size itself can be drawn.
        patch_size = tf.random.uniform([], patch_size_range[0], patch_size_range[1] + 1, dtype=tf.int32)

        # Centre the window on the water pixel, clamped to the tile's lower/upper edges.
        y1_src = tf.maximum(0, center_y - patch_size // 2)
        x1_src = tf.maximum(0, center_x - patch_size // 2)
        y2_src = tf.minimum(tile_size, y1_src + patch_size)
        x2_src = tf.minimum(tile_size, x1_src + patch_size)

        # If the window was cut by the far edge, slide its start back so it stays patch_size
        # wide (it is only smaller when patch_size exceeds tile_size).
        y1_src = y2_src - patch_size
        x1_src = x2_src - patch_size
        y1_src = tf.maximum(0, y1_src)
        x1_src = tf.maximum(0, x1_src)

        patch_height = y2_src - y1_src
        patch_width = x2_src - x1_src

        # The image must contain at least a quarter of the smallest patch area in water pixels.
        min_pixels_for_patch = (patch_size_range[0] * patch_size_range[0]) // 4

        # Combined check for initial early exit conditions
        if num_water_pixels < min_pixels_for_patch or \
           patch_height <= 0 or patch_width <= 0:
            pass # No augmentation, rgb_out and label_out remain original
        else:
            # If augmentation conditions are met, proceed with cutting and pasting
            rgb_patch = rgb[y1_src:y2_src, x1_src:x2_src, :]
            label_patch = label[y1_src:y2_src, x1_src:x2_src]

            # At least 25% of the cropped patch must be water.
            patch_water_mask = tf.cast(tf.equal(label_patch, water_class_id), tf.float32)
            required_patch_pixels = tf.cast(patch_height * patch_width / 4, tf.float32)
            patch_water_mask_sum = tf.reduce_sum(patch_water_mask)

            # Check if the cropped patch actually contains a significant amount of water
            if patch_water_mask_sum < required_patch_pixels:
                pass # Patch not water-rich enough, rgb_out and label_out remain original
            else:
                # Proceed with actual pasting if all checks pass
                # Ensure maxval for random_uniform is always > minval
                dest_y_max = tf.maximum(1, tile_size - patch_height + 1)
                dest_x_max = tf.maximum(1, tile_size - patch_width + 1)

                dest_y = tf.random.uniform([], 0, dest_y_max, dtype=tf.int32)
                dest_x = tf.random.uniform([], 0, dest_x_max, dtype=tf.int32)

                # 'ij' meshgrid enumerates (y, x, c) in the same row-major order as
                # tf.reshape(rgb_patch, [-1]), so each index lines up with its value.
                indices_rgb = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(3), indexing='ij'
                ), axis=-1)
                rgb_out = tf.tensor_scatter_nd_update(rgb, tf.reshape(indices_rgb, [-1, 3]), tf.reshape(rgb_patch, [-1]))

                # Same destination window for the label, without a channel axis.
                indices_label = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width), indexing='ij'
                ), axis=-1)
                label_out = tf.tensor_scatter_nd_update(label, tf.reshape(indices_label, [-1, 2]), tf.reshape(label_patch, [-1]))
    return rgb_out, label_out

def _apply_water_cutmix_rgb_elev(rgb, elev, slope, label, tile_size, water_class_id, cutmix_prob=0.9, patch_size_range=(64, 64)):
    """
    Applies a 'self-cutmix' augmentation for the water class, operating on RGB, Elevation, Slope, and Label.
    It finds a water patch within the current image and pastes it into another random location
    within the same image, updating all modalities accordingly.

    Patch selection is the same as in the RGB-only variant. The square is centred on a random
    water pixel and shifted to stay inside the tile. The same window is copied from all four
    tensors, including any non-water pixels inside it. Nothing changes when the image has fewer
    than patch_size_range[0]**2 // 4 water pixels, or when less than a quarter of the cropped
    patch is water.

    NOTE(review): the Python `if` / `or` below test tensor values. They work eagerly or when the
    function is converted by AutoGraph (e.g. inside Dataset.map); confirm before using it in a
    hand-built graph.

    Args:
        rgb (tf.Tensor): RGB image tensor, indexed as (H, W, 3) below.
        elev (tf.Tensor): Elevation mask tensor, indexed as (H, W, 1) below (tf.range(1) channel).
        slope (tf.Tensor): Slope mask tensor, indexed as (H, W, 1) below (tf.range(1) channel).
        label (tf.Tensor): Label mask tensor (H, W) of integer class IDs.
        tile_size (int): The target size of the image tiles (e.g., 256); assumed to match the
            height and width of the inputs.
        water_class_id (int): The integer ID for the water class.
        cutmix_prob (float): Probability of applying this augmentation.
        patch_size_range (tuple): (min_size, max_size) for the square patch, both inclusive.

    Returns:
        tuple: Augmented rgb, elev, slope, and label tensors (the inputs when skipped).
    """
    # Initialize output variables to the original inputs.
    rgb_out = rgb
    elev_out = elev
    slope_out = slope
    label_out = label

    if tf.random.uniform([]) < cutmix_prob:
        water_mask = tf.cast(tf.equal(label, water_class_id), tf.float32)
        # (num_water_pixels, 2) array of (row, col) coordinates of water pixels.
        water_coords = tf.cast(tf.where(water_mask > 0), tf.int32)
        
        # Upper bound clamped to >= 1 so uniform() is valid even without water pixels.
        num_water_pixels = tf.shape(water_coords)[0]
        idx_max = tf.maximum(1, num_water_pixels) 
        idx = tf.random.uniform([], 0, idx_max, dtype=tf.int32)

        # tf.cond avoids indexing into an empty water_coords.
        center_y = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][0], lambda: tf.constant(0, dtype=tf.int32))
        center_x = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][1], lambda: tf.constant(0, dtype=tf.int32))

        # maxval is exclusive, hence +1 so that max_size itself can be drawn.
        patch_size = tf.random.uniform([], patch_size_range[0], patch_size_range[1] + 1, dtype=tf.int32)

        # Centre the window on the water pixel, clamped to the tile's lower/upper edges.
        y1_src = tf.maximum(0, center_y - patch_size // 2)
        x1_src = tf.maximum(0, center_x - patch_size // 2)
        y2_src = tf.minimum(tile_size, y1_src + patch_size)
        x2_src = tf.minimum(tile_size, x1_src + patch_size)

        # Slide the start back if the far edge cut the window, keeping it patch_size wide.
        y1_src = y2_src - patch_size
        x1_src = x2_src - patch_size
        y1_src = tf.maximum(0, y1_src)
        x1_src = tf.maximum(0, x1_src)

        patch_height = y2_src - y1_src
        patch_width = x2_src - x1_src

        # The image must contain at least a quarter of the smallest patch area in water pixels.
        min_pixels_for_patch = (patch_size_range[0] * patch_size_range[0]) // 4

        if num_water_pixels < min_pixels_for_patch or \
           patch_height <= 0 or patch_width <= 0:
            pass  # not enough water / empty window: keep the originals
        else:
            rgb_patch = rgb[y1_src:y2_src, x1_src:x2_src, :]
            label_patch = label[y1_src:y2_src, x1_src:x2_src]

            # At least 25% of the cropped patch must be water.
            patch_water_mask = tf.cast(tf.equal(label_patch, water_class_id), tf.float32)
            required_patch_pixels = tf.cast(patch_height * patch_width / 4, tf.float32)
            patch_water_mask_sum = tf.reduce_sum(patch_water_mask)

            if patch_water_mask_sum < required_patch_pixels:
                pass  # patch not water-rich enough: keep the originals
            else:
                elev_patch = elev[y1_src:y2_src, x1_src:x2_src, :]
                slope_patch = slope[y1_src:y2_src, x1_src:x2_src, :]

                # Ensure maxval for random_uniform is always > minval
                dest_y_max = tf.maximum(1, tile_size - patch_height + 1)
                dest_x_max = tf.maximum(1, tile_size - patch_width + 1)

                dest_y = tf.random.uniform([], 0, dest_y_max, dtype=tf.int32)
                dest_x = tf.random.uniform([], 0, dest_x_max, dtype=tf.int32)

                # 'ij' meshgrid enumerates (y, x, c) in the same row-major order as
                # tf.reshape(patch, [-1]), so each index lines up with its value.
                indices_rgb = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(3), indexing='ij'
                ), axis=-1)
                rgb_out = tf.tensor_scatter_nd_update(rgb, tf.reshape(indices_rgb, [-1, 3]), tf.reshape(rgb_patch, [-1]))

                indices_label = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width), indexing='ij'
                ), axis=-1)
                label_out = tf.tensor_scatter_nd_update(label, tf.reshape(indices_label, [-1, 2]), tf.reshape(label_patch, [-1]))

                # Elevation and slope use a single-channel index grid (tf.range(1)).
                indices_elev = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(1), indexing='ij'
                ), axis=-1)
                elev_out = tf.tensor_scatter_nd_update(elev, tf.reshape(indices_elev, [-1, 3]), tf.reshape(elev_patch, [-1]))

                indices_slope = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(1), indexing='ij'
                ), axis=-1)
                slope_out = tf.tensor_scatter_nd_update(slope, tf.reshape(indices_slope, [-1, 3]), tf.reshape(slope_patch, [-1]))
    
    return rgb_out, elev_out, slope_out, label_out



def _paste_donor_water_patch_rgb(dst_images, dst_label, src_images, src_label, tile_size,
                                 water_class_id, patch_size_range):
    """
    Copies a water-rich square patch from a donor chip into a random location of a target chip.

    Patch selection follows _apply_water_cutmix_rgb_only:
      - the square is centred on a random water pixel of the donor and clamped inside the tile;
      - the donor needs at least patch_size_range[0]**2 // 4 water pixels;
      - at least a quarter of the cropped patch must be water.
    Must run eagerly (tensor values are converted to Python ints), e.g. inside tf.py_function.

    Args:
        dst_images (list[tf.Tensor]): Target tensors, each (H, W, C), pasted in lockstep.
        dst_label (tf.Tensor): Target label mask (H, W).
        src_images (list[tf.Tensor]): Donor tensors in the same order/dtypes as dst_images.
        src_label (tf.Tensor): Donor label mask (H, W).
        tile_size (int): Height/width of the chips.
        water_class_id (int): Class ID of water.
        patch_size_range (tuple): (min_size, max_size) of the square patch, both inclusive.

    Returns:
        tuple: (list of updated target tensors, updated target label). The inputs are
        returned unchanged when the donor has no suitable water patch.
    """
    water_coords = tf.where(tf.equal(src_label, water_class_id))  # (N, 2) row/col of water
    num_water = int(tf.shape(water_coords)[0])
    min_pixels = (patch_size_range[0] * patch_size_range[0]) // 4
    if num_water == 0 or num_water < min_pixels:
        return dst_images, dst_label

    idx = int(tf.random.uniform([], 0, num_water, dtype=tf.int32))
    center_y = int(water_coords[idx][0])
    center_x = int(water_coords[idx][1])
    patch_size = int(tf.random.uniform([], patch_size_range[0], patch_size_range[1] + 1, dtype=tf.int32))

    # Centre the window on the water pixel, then keep it inside the tile.
    y2 = min(tile_size, max(0, center_y - patch_size // 2) + patch_size)
    x2 = min(tile_size, max(0, center_x - patch_size // 2) + patch_size)
    y1 = max(0, y2 - patch_size)
    x1 = max(0, x2 - patch_size)
    patch_h, patch_w = y2 - y1, x2 - x1
    if patch_h <= 0 or patch_w <= 0:
        return dst_images, dst_label

    water_in_patch = int(tf.reduce_sum(tf.cast(tf.equal(src_label[y1:y2, x1:x2], water_class_id), tf.int32)))
    if water_in_patch < patch_h * patch_w / 4:
        return dst_images, dst_label

    dest_y = int(tf.random.uniform([], 0, max(1, tile_size - patch_h + 1), dtype=tf.int32))
    dest_x = int(tf.random.uniform([], 0, max(1, tile_size - patch_w + 1), dtype=tf.int32))
    # (patch_h * patch_w, 2) destination (y, x) coordinates in row-major order.
    rows, cols = tf.meshgrid(tf.range(dest_y, dest_y + patch_h), tf.range(dest_x, dest_x + patch_w), indexing='ij')
    yx = tf.reshape(tf.stack([rows, cols], axis=-1), [-1, 2])

    def paste(dst, src):
        patch = src[y1:y2, x1:x2]
        # (N,) updates for a label, (N, C) for an image: whole pixels are replaced.
        updates = tf.reshape(patch, [patch_h * patch_w] + patch.shape[2:].as_list())
        return tf.tensor_scatter_nd_update(dst, yx, updates)

    new_images = [paste(dst, src) for dst, src in zip(dst_images, src_images)]
    return new_images, paste(dst_label, src_label)


def apply_cutmix_rgb_batch(rgb_batch, label_batch, tile_size, water_class_id=3, inter_prob=0.3, patch_size_range=(64, 64)):
    """
    Applies intra-chip or inter-chip CutMix for water class in RGB-only setting.

    - Chips that contain water get the intra-chip (self) CutMix of
      _apply_water_cutmix_rgb_only, with probability 0.9.
    - Chips without water, with probability ``inter_prob``, receive a water patch copied
      from a randomly chosen other chip of the batch that does contain water. The chip's
      own content is kept.

    Runs eagerly (Python control flow on tensor values); the dataset builder wraps it in
    tf.py_function.

    Args:
        rgb_batch (tf.Tensor): Batch of RGB images (B, H, W, 3).
        label_batch (tf.Tensor): Batch of integer label masks (B, H, W).
        tile_size (int): Height/width of the chips.
        water_class_id (int): Class ID of water.
        inter_prob (float): Probability of inter-chip CutMix for a chip without water.
        patch_size_range (tuple): (min_size, max_size) of the square patch.

    Returns:
        tuple: Augmented (rgb_batch, label_batch) tensors.
    """
    batch_size = rgb_batch.shape[0]
    # Which chips contain water: computed once instead of re-scanning per target chip.
    has_water = [bool(tf.reduce_any(tf.equal(label_batch[j], water_class_id))) for j in range(batch_size)]

    rgb_batch_out = []
    label_batch_out = []
    for i in range(batch_size):
        rgb = rgb_batch[i]
        label = label_batch[i]

        if has_water[i]:
            rgb, label = _apply_water_cutmix_rgb_only(rgb, label, tile_size, water_class_id, cutmix_prob=0.9, patch_size_range=patch_size_range)
        elif tf.random.uniform([]) < inter_prob:
            donors = [j for j in range(batch_size) if j != i and has_water[j]]
            if donors:
                donor_idx = donors[int(tf.random.uniform([], 0, len(donors), dtype=tf.int32))]
                # Paste a donor water patch into THIS chip (its own pixels are kept elsewhere).
                (rgb,), label = _paste_donor_water_patch_rgb(
                    [rgb], label, [rgb_batch[donor_idx]], label_batch[donor_idx],
                    tile_size, water_class_id, patch_size_range,
                )

        rgb_batch_out.append(rgb)
        label_batch_out.append(label)

    return tf.stack(rgb_batch_out), tf.stack(label_batch_out)


def _paste_donor_water_patch_rgb_elev(dst_images, dst_label, src_images, src_label, tile_size,
                                      water_class_id, patch_size_range):
    """
    Copies a water-rich square patch from a donor chip into a random location of a target chip.

    Patch selection follows _apply_water_cutmix_rgb_elev:
      - the square is centred on a random water pixel of the donor and clamped inside the tile;
      - the donor needs at least patch_size_range[0]**2 // 4 water pixels;
      - at least a quarter of the cropped patch must be water.
    Must run eagerly (tensor values are converted to Python ints), e.g. inside tf.py_function.

    Args:
        dst_images (list[tf.Tensor]): Target tensors (e.g. rgb, elev, slope), each (H, W, C).
        dst_label (tf.Tensor): Target label mask (H, W).
        src_images (list[tf.Tensor]): Donor tensors in the same order/dtypes as dst_images.
        src_label (tf.Tensor): Donor label mask (H, W).
        tile_size (int): Height/width of the chips.
        water_class_id (int): Class ID of water.
        patch_size_range (tuple): (min_size, max_size) of the square patch, both inclusive.

    Returns:
        tuple: (list of updated target tensors, updated target label). The inputs are
        returned unchanged when the donor has no suitable water patch.
    """
    water_coords = tf.where(tf.equal(src_label, water_class_id))  # (N, 2) row/col of water
    num_water = int(tf.shape(water_coords)[0])
    min_pixels = (patch_size_range[0] * patch_size_range[0]) // 4
    if num_water == 0 or num_water < min_pixels:
        return dst_images, dst_label

    idx = int(tf.random.uniform([], 0, num_water, dtype=tf.int32))
    center_y = int(water_coords[idx][0])
    center_x = int(water_coords[idx][1])
    patch_size = int(tf.random.uniform([], patch_size_range[0], patch_size_range[1] + 1, dtype=tf.int32))

    # Centre the window on the water pixel, then keep it inside the tile.
    y2 = min(tile_size, max(0, center_y - patch_size // 2) + patch_size)
    x2 = min(tile_size, max(0, center_x - patch_size // 2) + patch_size)
    y1 = max(0, y2 - patch_size)
    x1 = max(0, x2 - patch_size)
    patch_h, patch_w = y2 - y1, x2 - x1
    if patch_h <= 0 or patch_w <= 0:
        return dst_images, dst_label

    water_in_patch = int(tf.reduce_sum(tf.cast(tf.equal(src_label[y1:y2, x1:x2], water_class_id), tf.int32)))
    if water_in_patch < patch_h * patch_w / 4:
        return dst_images, dst_label

    dest_y = int(tf.random.uniform([], 0, max(1, tile_size - patch_h + 1), dtype=tf.int32))
    dest_x = int(tf.random.uniform([], 0, max(1, tile_size - patch_w + 1), dtype=tf.int32))
    # (patch_h * patch_w, 2) destination (y, x) coordinates in row-major order.
    rows, cols = tf.meshgrid(tf.range(dest_y, dest_y + patch_h), tf.range(dest_x, dest_x + patch_w), indexing='ij')
    yx = tf.reshape(tf.stack([rows, cols], axis=-1), [-1, 2])

    def paste(dst, src):
        patch = src[y1:y2, x1:x2]
        # (N,) updates for a label, (N, C) for an image: whole pixels are replaced.
        updates = tf.reshape(patch, [patch_h * patch_w] + patch.shape[2:].as_list())
        return tf.tensor_scatter_nd_update(dst, yx, updates)

    new_images = [paste(dst, src) for dst, src in zip(dst_images, src_images)]
    return new_images, paste(dst_label, src_label)


def apply_cutmix_rgb_elev_batch(rgb_batch, elev_batch, slope_batch, label_batch, tile_size, water_class_id=3, inter_prob=0.3, patch_size_range=(64, 64)):
    """
    Applies intra- or inter-chip CutMix for water class in RGB+elev+slope setting.

    - Chips that contain water get the intra-chip (self) CutMix of
      _apply_water_cutmix_rgb_elev, with probability 0.9.
    - Chips without water, with probability ``inter_prob``, receive a water patch (RGB,
      elevation, slope and label together) copied from a randomly chosen other chip of the
      batch that contains water. The chip's own content is kept.

    Runs eagerly (Python control flow on tensor values); the dataset builder wraps it in
    tf.py_function.

    Args:
        rgb_batch (tf.Tensor): Batch of RGB images (B, H, W, 3).
        elev_batch (tf.Tensor): Batch of elevation rasters (B, H, W, 1).
        slope_batch (tf.Tensor): Batch of slope rasters (B, H, W, 1).
        label_batch (tf.Tensor): Batch of integer label masks (B, H, W).
        tile_size (int): Height/width of the chips.
        water_class_id (int): Class ID of water.
        inter_prob (float): Probability of inter-chip CutMix for a chip without water.
        patch_size_range (tuple): (min_size, max_size) of the square patch.

    Returns:
        tuple: Augmented (rgb, elev, slope, label) batch tensors.
    """
    batch_size = rgb_batch.shape[0]
    # Which chips contain water: computed once instead of re-scanning per target chip.
    has_water = [bool(tf.reduce_any(tf.equal(label_batch[j], water_class_id))) for j in range(batch_size)]

    rgb_out, elev_out, slope_out, label_out = [], [], [], []
    for i in range(batch_size):
        rgb = rgb_batch[i]
        elev = elev_batch[i]
        slope = slope_batch[i]
        label = label_batch[i]

        if has_water[i]:
            rgb, elev, slope, label = _apply_water_cutmix_rgb_elev(rgb, elev, slope, label, tile_size, water_class_id, 0.9, patch_size_range)
        elif tf.random.uniform([]) < inter_prob:
            donors = [j for j in range(batch_size) if j != i and has_water[j]]
            if donors:
                d = donors[int(tf.random.uniform([], 0, len(donors), dtype=tf.int32))]
                # Paste a donor water patch into THIS chip (its own pixels are kept elsewhere).
                (rgb, elev, slope), label = _paste_donor_water_patch_rgb_elev(
                    [rgb, elev, slope], label,
                    [rgb_batch[d], elev_batch[d], slope_batch[d]], label_batch[d],
                    tile_size, water_class_id, patch_size_range,
                )

        rgb_out.append(rgb)
        elev_out.append(elev)
        slope_out.append(slope)
        label_out.append(label)

    return (
        tf.stack(rgb_out),
        tf.stack(elev_out),
        tf.stack(slope_out),
        tf.stack(label_out)
    )


'''


# Variant with batch-level intra- and inter-chip water CutMix
'''
def parse_tile(rgb_path, label_path, tile_id, split='train', augment=False, tile_size=256):
    """
    Reads one RGB tile and its colour-coded label and resizes both to tile_size x tile_size.

    For split == 'train' with augment=True, augment_rgb_label is applied. The label stays as
    integer class IDs; one-hot encoding is left to the caller so that batch-level CutMix can
    work on integer masks.

    Args:
        rgb_path (tf.Tensor): Path to the RGB PNG.
        label_path (tf.Tensor): Path to the colour-coded label PNG.
        tile_id (tf.Tensor): Tile ID (unused here).
        split (str): 'train', 'val' or 'test'.
        augment (bool): Whether to augment (only honoured for 'train').
        tile_size (int): Output height and width.

    Returns:
        tuple: (rgb float32 (tile_size, tile_size, 3), label int32 (tile_size, tile_size)).
    """
    target_hw = [tile_size, tile_size]

    image = tf.image.decode_png(tf.io.read_file(rgb_path), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, target_hw)

    class_mask = decode_coloured_label(tf.image.decode_png(tf.io.read_file(label_path), channels=3))
    # Nearest-neighbour so class IDs are never blended together.
    class_mask = tf.image.resize(tf.expand_dims(class_mask, -1), target_hw, method='nearest')
    class_mask = tf.cast(tf.reshape(class_mask, target_hw), tf.int32)

    if augment and split == 'train':
        image, class_mask = augment_rgb_label(image, class_mask)

    return image, class_mask

def parse_elevation(rgb_path, elev_path, slope_path, label_path, tile_id,
                    split='train', augment=False, tile_size=256):
    """
    Reads RGB, elevation, slope and label for one tile and resizes them to tile_size x tile_size.

    For split == 'train' with augment=True, augment_all is applied. The label stays as
    integer class IDs so that batch-level CutMix can run before one-hot encoding.

    Args:
        rgb_path (tf.Tensor): Path to the RGB PNG.
        elev_path (tf.Tensor): Path to the elevation .npy (2-D array).
        slope_path (tf.Tensor): Path to the slope .npy (2-D array).
        label_path (tf.Tensor): Path to the colour-coded label PNG.
        tile_id (tf.Tensor): Tile ID (unused here).
        split (str): 'train', 'val' or 'test'.
        augment (bool): Whether to augment (only honoured for 'train').
        tile_size (int): Output height and width.

    Returns:
        tuple: (rgb (H, W, 3), elev (H, W, 1), slope (H, W, 1), label int32 (H, W)).
    """
    target_hw = [tile_size, tile_size]

    image = tf.image.decode_png(tf.io.read_file(rgb_path), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)

    class_mask = decode_coloured_label(tf.image.decode_png(tf.io.read_file(label_path), channels=3))

    def load_raster(path):
        # py_function loses static shape information, so restore the 2-D rank first.
        raster = tf.py_function(_load_npy, [path], tf.float32)
        raster.set_shape([None, None])
        return tf.expand_dims(raster, axis=-1)

    elevation = load_raster(elev_path)
    slope_map = load_raster(slope_path)

    image = tf.image.resize(image, target_hw)
    elevation = tf.image.resize(elevation, target_hw)
    slope_map = tf.image.resize(slope_map, target_hw)
    # Nearest-neighbour so class IDs are never blended together.
    class_mask = tf.image.resize(tf.expand_dims(class_mask, -1), target_hw, method='nearest')
    class_mask = tf.cast(tf.reshape(class_mask, target_hw), tf.int32)

    if augment and split == 'train':
        image, elevation, slope_map, class_mask = augment_all(image, elevation, slope_map, class_mask)

    return image, elevation, slope_map, class_mask


def build_tf_dataset(df, image_dir, elev_dir, slope_dir, label_dir,
                     input_type='rgb', batch_size=32, split='train',
                     augment=False, shuffle=True, tile_size=256):
    """
    Builds a TensorFlow Dataset for semantic segmentation with intra + inter-chip CutMix.

    Every split yields the same element structure, ``(inputs, one_hot_labels)``:
      - inputs are RGB (B, H, W, 3) for input_type='rgb';
      - for 'rgb_elev', inputs are RGB, elevation and slope concatenated on the channel
        axis (B, H, W, 5);
      - labels are one-hot (B, H, W, NUM_CLASSES).
    The batch-level water CutMix runs only when split == 'train' and augment is True.

    Args:
        df (pd.DataFrame): DataFrame containing a 'tile_id' column.
        image_dir (str): Directory containing RGB images.
        elev_dir (str): Directory containing elevation .npy files.
        slope_dir (str): Directory containing slope .npy files.
        label_dir (str): Directory containing label images.
        input_type (str): 'rgb' or 'rgb_elev'.
        batch_size (int): Number of samples per batch.
        split (str): 'train', 'val' or 'test'; augmentation only applies to 'train'.
        augment (bool): Whether to apply augmentations (train split only).
        shuffle (bool): Whether to reshuffle the tiles every epoch.
        tile_size (int): Output height and width.

    Returns:
        tf.data.Dataset: Prefetched dataset of (inputs, one_hot_labels) batches.

    Raises:
        ValueError: If input_type is not supported.
    """
    image_paths, elev_paths, slope_paths, label_paths, tile_ids = load_image_paths(
        df, image_dir, elev_dir, slope_dir, label_dir
    )
    use_cutmix = split == 'train' and augment

    def to_one_hot(label):
        # py_function outputs have no static shape; pin (B, H, W) before one-hot encoding.
        label = tf.reshape(label, [tf.shape(label)[0], tile_size, tile_size])
        return tf.one_hot(label, depth=NUM_CLASSES)

    if input_type == 'rgb':
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths, tile_ids))

        def map_fn(rgb_path, label_path, tile_id):
            return parse_tile(rgb_path, label_path, tile_id, split, augment, tile_size)

        def batch_fn(rgb, label):
            if use_cutmix:
                rgb_new, label_new = tf.py_function(
                    func=lambda r, l: apply_cutmix_rgb_batch(r, l, tile_size),
                    inp=[rgb, label],
                    Tout=[tf.float32, tf.int32]
                )
                rgb_new.set_shape(rgb.shape)
                rgb, label = rgb_new, label_new
            return rgb, to_one_hot(label)

    elif input_type == 'rgb_elev':
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, elev_paths, slope_paths, label_paths, tile_ids))

        def map_fn(rgb_path, elev_path, slope_path, label_path, tile_id):
            return parse_elevation(rgb_path, elev_path, slope_path, label_path, tile_id, split, augment, tile_size)

        def batch_fn(rgb, elev, slope, label):
            if use_cutmix:
                rgb_new, elev_new, slope_new, label_new = tf.py_function(
                    func=lambda r, e, s, l: apply_cutmix_rgb_elev_batch(r, e, s, l, tile_size),
                    inp=[rgb, elev, slope, label],
                    Tout=[tf.float32, tf.float32, tf.float32, tf.int32]
                )
                rgb_new.set_shape(rgb.shape)
                elev_new.set_shape(elev.shape)
                slope_new.set_shape(slope.shape)
                rgb, elev, slope, label = rgb_new, elev_new, slope_new, label_new
            # Same (inputs, labels) structure for every split.
            return tf.concat([rgb, elev, slope], axis=-1), to_one_hot(label)

    else:
        raise ValueError(f"Unsupported input_type: {input_type}")

    if shuffle:
        # Shuffle the cheap path tuples before decoding. The buffer then holds strings
        # instead of every decoded tile; max(1, ...) keeps an empty df valid.
        dataset = dataset.shuffle(buffer_size=max(1, len(image_paths)), reshuffle_each_iteration=True)

    dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    # CutMix (train + augment only), channel concatenation and one-hot at batch level.
    dataset = dataset.map(batch_fn, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.prefetch(tf.data.AUTOTUNE)

'''



'''
def build_tf_dataset(df, image_dir, elev_dir, slope_dir, label_dir,
                     input_type='rgb', batch_size=32, split='train',
                     augment=False, shuffle=True, tile_size=256):
    """
    Builds a TensorFlow Dataset for semantic segmentation.

    Tiles are shuffled as file-path tuples before they are decoded. The shuffle buffer
    therefore never has to hold decoded images and one-hot labels for the whole dataset.

    Args:
        df (pd.DataFrame): DataFrame containing 'tile_id' column.
        image_dir (str): Directory containing RGB images.
        elev_dir (str): Directory containing elevation .npy files.
        slope_dir (str): Directory containing slope .npy files.
        label_dir (str): Directory containing label images.
        input_type (str): 'rgb' or 'rgb_elev' to determine input channels.
        batch_size (int): Number of samples per batch.
        split (str): 'train', 'val', or 'test' (influences augmentation).
        augment (bool): Whether to apply data augmentations.
        shuffle (bool): Whether to shuffle the dataset.
        tile_size (int): Desired height and width of the output tiles.

    Returns:
        tf.data.Dataset: A TensorFlow dataset ready for training/evaluation.

    Raises:
        ValueError: If input_type is neither 'rgb' nor 'rgb_elev'.
    """
    image_paths, elev_paths, slope_paths, label_paths, tile_ids = load_image_paths(
        df, image_dir, elev_dir, slope_dir, label_dir
    )

    if input_type == 'rgb':
        # Dataset for RGB only
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths, tile_ids))

        def map_fn(rgb_path, label_path, tile_id):
            return parse_tile(rgb_path, label_path, tile_id, split, augment, tile_size)

    elif input_type == 'rgb_elev':
        # Dataset for RGB, Elevation, and Slope
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, elev_paths, slope_paths, label_paths, tile_ids))

        def map_fn(rgb_path, elev_path, slope_path, label_path, tile_id):
            return parse_elevation(rgb_path, elev_path, slope_path, label_path, tile_id, split, augment, tile_size)

    else:
        raise ValueError(f"Unsupported input_type: {input_type}")

    if shuffle:
        # Shuffle the small path tuples (not decoded tiles) with a full-dataset buffer;
        # max(1, ...) keeps an empty DataFrame from producing an invalid buffer size of 0.
        dataset = dataset.shuffle(buffer_size=max(1, len(image_paths)), reshuffle_each_iteration=True)

    # Decode, resize and (optionally) augment each tile
    dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch the dataset
    dataset = dataset.batch(batch_size)
    # Prefetch data to overlap data preprocessing and model execution
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

    


def parse_tile(rgb_path, label_path, tile_id, split='train', augment=False, tile_size=256):
    """
    Loads one RGB tile and its colour-coded label and prepares them for training.

    Both are resized to tile_size x tile_size, using nearest-neighbour for the label so
    class IDs are never blended. For split == 'train' with augment=True, the pair then goes
    through augment_rgb_label followed by the intra-chip water CutMix. The label is
    returned one-hot encoded.

    Args:
        rgb_path (tf.Tensor): Path to the RGB image.
        label_path (tf.Tensor): Path to the label image.
        tile_id (tf.Tensor): The ID of the tile (unused here).
        split (str): 'train', 'val', or 'test' (influences augmentation).
        augment (bool): Whether to apply data augmentations.
        tile_size (int): Desired height and width of the output tiles.

    Returns:
        tuple: (rgb (tile_size, tile_size, 3), one-hot label (tile_size, tile_size, NUM_CLASSES)).
    """
    size = [tile_size, tile_size]

    rgb_bytes = tf.io.read_file(rgb_path)
    rgb = tf.image.convert_image_dtype(tf.image.decode_png(rgb_bytes, channels=3), tf.float32)

    label_bytes = tf.io.read_file(label_path)
    label = decode_coloured_label(tf.image.decode_png(label_bytes, channels=3))

    # Fixed spatial size before augmentation.
    rgb = tf.image.resize(rgb, size)
    label = tf.image.resize(tf.expand_dims(label, -1), size, method='nearest')
    label = tf.cast(tf.reshape(label, size), tf.int32)

    if split == 'train' and augment:
        rgb, label = augment_rgb_label(rgb, label)
        rgb, label = _apply_water_cutmix_rgb_only(rgb, label, tile_size, WATER_CLASS_ID)

    return rgb, tf.one_hot(label, depth=NUM_CLASSES)

def parse_elevation(rgb_path, elev_path, slope_path, label_path, tile_id,
                    split='train', augment=False, tile_size=256):
    """
    Parses RGB, elevation, slope, and label data for a single tile.
    Applies augmentations if specified.

    Args:
        rgb_path (tf.Tensor): Path to the RGB image.
        elev_path (tf.Tensor): Path to the elevation .npy file.
        slope_path (tf.Tensor): Path to the slope .npy file.
        label_path (tf.Tensor): Path to the label image.
        tile_id (tf.Tensor): The ID of the tile (currently unused in this function).
        split (str): 'train', 'val', or 'test' (influences augmentation).
        augment (bool): Whether to apply data augmentations.
        tile_size (int): Desired height and width of the output tiles.

    Returns:
        tuple: Processed input_image (concatenation of rgb, elev, slope) and label tensors.
    """
    rgb = tf.io.read_file(rgb_path)
    rgb = tf.image.decode_png(rgb, channels=3)
    rgb = tf.image.convert_image_dtype(rgb, tf.float32)

    label = tf.io.read_file(label_path)
    label = tf.image.decode_png(label, channels=3)
    label = decode_coloured_label(label)

    # Load .npy files using tf.py_function
    elev = tf.py_function(_load_npy, [elev_path], tf.float32)
    slope = tf.py_function(_load_npy, [slope_path], tf.float32)
    # Set shapes for the loaded npy data, as py_function returns dynamic shapes
    elev.set_shape([None, None]) # Height, Width
    slope.set_shape([None, None]) # Height, Width
    elev = tf.expand_dims(elev, axis=-1) # Add channel dimension
    slope = tf.expand_dims(slope, axis=-1) # Add channel dimension

    # Resize all inputs to the target tile_size.
    rgb = tf.image.resize(rgb, [tile_size, tile_size])
    elev = tf.image.resize(elev, [tile_size, tile_size])
    slope = tf.image.resize(slope, [tile_size, tile_size])
    label = tf.image.resize(label[..., tf.newaxis], [tile_size, tile_size], method='nearest')
    label = tf.reshape(label, [tile_size, tile_size]) # Remove channel dimension
    label = tf.cast(label, tf.int32) # Ensure label is integer type

    if split == 'train' and augment:
        # Apply standard geometric and color augmentations
        rgb, elev, slope, label = augment_all(rgb, elev, slope, label)
        # Call the RGB+Elev+Slope cutmix function here
        rgb, elev, slope, label = _apply_water_cutmix_rgb_elev(rgb, elev, slope, label, tile_size, WATER_CLASS_ID)

    # Concatenate RGB, elevation, and slope to form the combined input image
    input_image = tf.concat([rgb, elev, slope], axis=-1)
    # One-hot encode the label for model training
    label = tf.one_hot(label, depth=NUM_CLASSES)
    return input_image, label
'''


In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import random

# -------------------------- CONSTANTS --------------------------
# Mapping from the RGB colour used in the label PNGs to its integer class ID.
# The IDs are positions in CLASS_NAMES below.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,      # Building
    (145, 30, 180): 1,     # Clutter
    (60, 180, 75): 2,      # Vegetation
    (245, 130, 48): 3,     # Water
    (255, 255, 255): 4,    # Background
    (0, 130, 200): 5,      # Car
}
# Human-readable class names, indexed by class ID.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
# Depth used for one-hot encoding of labels.
NUM_CLASSES = len(CLASS_NAMES)
# Class targeted by the water CutMix augmentation (looked up by name, == 3).
WATER_CLASS_ID = CLASS_NAMES.index('Water')

# -------------------------- UTILS --------------------------
def decode_coloured_label(label_rgb):
    """Convert an RGB-coloured label image into a per-pixel class-ID map.

    Every pixel's RGB triple is looked up in ``COLOR_TO_CLASS`` and replaced by
    the class ID stored there.

    Args:
        label_rgb (tf.Tensor): Label image of shape (H, W, 3); cast to uint8.

    Returns:
        tf.Tensor: int64 tensor of shape (H, W) holding class IDs.

    Note:
        Pixels whose colour is not a key of ``COLOR_TO_CLASS`` are silently
        given the class ID of the first dictionary entry (``argmax`` of an
        all-zero match row is 0).
    """
    label_rgb = tf.cast(label_rgb, tf.uint8)
    flat = tf.reshape(label_rgb, [-1, 3])
    keys = tf.constant(list(COLOR_TO_CLASS.keys()), dtype=tf.uint8)
    values = tf.constant(list(COLOR_TO_CLASS.values()), dtype=tf.int64)
    # (N, 1, 3) vs (K, 3) -> (N, K): True where pixel n has exactly colour k.
    match = tf.reduce_all(tf.equal(tf.expand_dims(flat, 1), keys), axis=2)
    # Position of the matching colour in the key list...
    key_index = tf.argmax(tf.cast(match, tf.int32), axis=1)
    # ...translated to its class ID, so the result stays correct even if the
    # dictionary's values are not 0..K-1 in insertion order.
    class_ids = tf.gather(values, key_index)
    return tf.reshape(class_ids, tf.shape(label_rgb)[:2])

def _load_npy(path):
    """Load a ``.npy`` file and return its contents as a float32 NumPy array.

    Meant to be wrapped in ``tf.py_function``. When it is, ``path`` arrives as
    an eager string tensor, which is decoded to a UTF-8 Python string first.
    Any other value is passed straight to ``np.load``.
    """
    if isinstance(path, tf.Tensor):
        path = path.numpy().decode("utf-8")
    array = np.load(path)
    return array.astype(np.float32)

def load_image_paths(df, image_dir, elev_dir, slope_dir, label_dir):
    """Build the per-tile file paths for every ``tile_id`` in ``df``.

    File names follow ``{tile_id}-ortho.png``, ``{tile_id}-elev.npy``,
    ``{tile_id}-slope.npy`` and ``{tile_id}-label.png`` inside their
    respective directories.

    Returns:
        tuple: (image_paths, elev_paths, slope_paths, label_paths, tile_ids),
        all lists in the row order of ``df``.
    """
    tile_ids = df["tile_id"].tolist()

    def _paths_in(directory, suffix):
        return [os.path.join(directory, f"{tid}-{suffix}") for tid in tile_ids]

    return (
        _paths_in(image_dir, "ortho.png"),
        _paths_in(elev_dir, "elev.npy"),
        _paths_in(slope_dir, "slope.npy"),
        _paths_in(label_dir, "label.png"),
        tile_ids,
    )

# -------------------------- AUGMENTATION FUNCS --------------------------
def augment_rgb_label(rgb, label):
    """Randomly augment an RGB tile and its label.

    Geometric transforms (two random flips and a random multiple-of-90-degree
    rotation) are applied identically to ``rgb`` and ``label``. Colour jitter
    (brightness, contrast, saturation, hue) is applied to ``rgb`` only.

    Args:
        rgb (tf.Tensor): Float image of shape (H, W, 3), values expected in [0, 1].
        label (tf.Tensor): Integer class map of shape (H, W).

    Returns:
        tuple: (rgb, label) with the same shapes as the inputs.
    """
    label = tf.expand_dims(label, axis=-1)  # tf.image ops need a channel axis
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_left_right(rgb)
        label = tf.image.flip_left_right(label)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_up_down(rgb)
        label = tf.image.flip_up_down(label)
    # k is drawn from {0, 1, 2, 3} (including the identity). Combined with the
    # two independent flips, this makes all 8 orientations equally likely.
    k = tf.random.uniform([], 0, 4, dtype=tf.int32)
    rgb = tf.image.rot90(rgb, k)
    label = tf.image.rot90(label, k)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_brightness(rgb, max_delta=0.016)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_contrast(rgb, lower=0.98, upper=1.03)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_saturation(rgb, lower=0.98, upper=1.03)
    if tf.random.uniform([]) > 0.9:
        rgb = tf.image.random_hue(rgb, max_delta=0.016)
    # tf.image colour ops do not clip float images; keep pixels in [0, 1].
    rgb = tf.clip_by_value(rgb, 0.0, 1.0)
    return rgb, tf.squeeze(label, axis=-1)

def augment_all(rgb, elev, slope, label):
    """Randomly augment an RGB + elevation + slope tile and its label.

    Flips and multiple-of-90-degree rotations are applied identically to
    ``rgb``, ``elev``, ``slope`` and ``label``, so all modalities stay
    pixel-aligned. Colour jitter touches ``rgb`` only.

    NOTE(review): flipping/rotating ``slope`` assumes it is orientation-free
    (e.g. slope magnitude rather than aspect/direction) — confirm against the data.

    Args:
        rgb (tf.Tensor): Float image (H, W, 3), values expected in [0, 1].
        elev (tf.Tensor): Elevation raster, expected (H, W, 1).
        slope (tf.Tensor): Slope raster, expected (H, W, 1).
        label (tf.Tensor): Integer class map (H, W).

    Returns:
        tuple: (rgb, elev, slope, label) with the same shapes as the inputs.
    """
    label = tf.expand_dims(label, axis=-1)  # tf.image ops need a channel axis
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_left_right(rgb)
        elev = tf.image.flip_left_right(elev)
        slope = tf.image.flip_left_right(slope)
        label = tf.image.flip_left_right(label)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_up_down(rgb)
        elev = tf.image.flip_up_down(elev)
        slope = tf.image.flip_up_down(slope)
        label = tf.image.flip_up_down(label)
    # k is drawn from {0, 1, 2, 3} (including the identity). Combined with the
    # two independent flips, this makes all 8 orientations equally likely.
    k = tf.random.uniform([], 0, 4, dtype=tf.int32)
    rgb = tf.image.rot90(rgb, k)
    elev = tf.image.rot90(elev, k)
    slope = tf.image.rot90(slope, k)
    label = tf.image.rot90(label, k)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_brightness(rgb, max_delta=0.016)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_contrast(rgb, lower=0.98, upper=1.03)
    if tf.random.uniform([]) > 0.8:
        rgb = tf.image.random_saturation(rgb, lower=0.98, upper=1.03)
    if tf.random.uniform([]) > 0.9:
        rgb = tf.image.random_hue(rgb, max_delta=0.016)
    # tf.image colour ops do not clip float images; keep pixels in [0, 1].
    rgb = tf.clip_by_value(rgb, 0.0, 1.0)
    return rgb, elev, slope, tf.squeeze(label, axis=-1)

# -------------------------- CUTMIX FUNCS --------------------------
  
def _apply_water_cutmix_rgb_only(rgb, label, tile_size, water_class_id, cutmix_prob=0.9, patch_size_range=(64, 64)):
    """
    Applies a 'self-cutmix' augmentation for the water class, operating only on RGB and Label.
    It picks a random water pixel in the current image, cuts a square patch around it and
    pastes that patch at another random location within the same image, updating both the
    RGB image and its label.

    The inputs are returned unchanged when:
      * the ``cutmix_prob`` draw fails,
      * the image has fewer than ``patch_size_range[0] ** 2 // 4`` water pixels, or
      * less than a quarter of the cut patch is water.

    The control flow uses Python ``if`` on tensor values, so this is written to run eagerly
    (the batch CutMix helpers call it inside ``tf.py_function``).
    NOTE(review): calling it from graph code would rely on AutoGraph converting these
    branches — confirm before doing so.

    Args:
        rgb (tf.Tensor): RGB image tensor of shape (H, W, 3).
        label (tf.Tensor): Label mask tensor (integer class IDs) of shape (H, W).
        tile_size (int): The target size of the image tiles (e.g., 256). Used as the
            height/width bound when clamping the source and destination boxes.
        water_class_id (int): The integer ID for the water class.
        cutmix_prob (float): Probability of attempting this augmentation.
        patch_size_range (tuple): (min_size, max_size), both inclusive, for the square
            patch side.

    Returns:
        tuple: Augmented rgb and label tensors (same shapes as the inputs).
    """
    # Initialize output variables to the original inputs.
    # These will be updated only if cutmix is successfully applied.
    rgb_out = rgb
    label_out = label

    # Outer probability check for applying cutmix
    if tf.random.uniform([]) < cutmix_prob:
        # Binary water mask and the (row, col) coordinates of every water pixel.
        water_mask = tf.cast(tf.equal(label, water_class_id), tf.float32)
        water_coords = tf.cast(tf.where(water_mask > 0), tf.int32)
        
        # Ensure the upper bound for selecting a random water coordinate is always at least 1.
        # tf.shape(water_coords)[0] gives the number of water pixels. If 0, maxval would be 0.
        num_water_pixels = tf.shape(water_coords)[0]
        idx_max = tf.maximum(1, num_water_pixels) 

        # Select a random index into water_coords
        # If num_water_pixels is 0, idx will be 0, which is handled below.
        idx = tf.random.uniform([], 0, idx_max, dtype=tf.int32)

        # Get center from water_coords, conditionally to avoid indexing into an empty tensor
        center_y = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][0], lambda: tf.constant(0, dtype=tf.int32))
        center_x = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][1], lambda: tf.constant(0, dtype=tf.int32))

        # Side of the square patch, uniform over [min_size, max_size].
        patch_size = tf.random.uniform([], patch_size_range[0], patch_size_range[1] + 1, dtype=tf.int32)

        # Centre a patch_size box on the chosen pixel, clamped at the top/left edge
        # and cut off at the bottom/right edge...
        y1_src = tf.maximum(0, center_y - patch_size // 2)
        x1_src = tf.maximum(0, center_x - patch_size // 2)
        y2_src = tf.minimum(tile_size, y1_src + patch_size)
        x2_src = tf.minimum(tile_size, x1_src + patch_size)

        # ...then slide the start back from the bottom/right edge so the box keeps its
        # full size whenever tile_size >= patch_size.
        y1_src = y2_src - patch_size
        x1_src = x2_src - patch_size
        y1_src = tf.maximum(0, y1_src)
        x1_src = tf.maximum(0, x1_src)

        patch_height = y2_src - y1_src
        patch_width = x2_src - x1_src

        # Minimum number of water pixels in the whole image: a quarter of the
        # smallest allowed patch area.
        min_pixels_for_patch = (patch_size_range[0] * patch_size_range[0]) // 4

        # Combined check for initial early exit conditions
        if num_water_pixels < min_pixels_for_patch or \
           patch_height <= 0 or patch_width <= 0:
            pass # No augmentation, rgb_out and label_out remain original
        else:
            # If augmentation conditions are met, proceed with cutting and pasting
            rgb_patch = rgb[y1_src:y2_src, x1_src:x2_src, :]
            label_patch = label[y1_src:y2_src, x1_src:x2_src]

            # The patch must be at least 25% water to be worth pasting.
            patch_water_mask = tf.cast(tf.equal(label_patch, water_class_id), tf.float32)
            required_patch_pixels = tf.cast(patch_height * patch_width / 4, tf.float32)
            patch_water_mask_sum = tf.reduce_sum(patch_water_mask)

            # Check if the cropped patch actually contains a significant amount of water
            if patch_water_mask_sum < required_patch_pixels:
                pass # Patch not water-rich enough, rgb_out and label_out remain original
            else:
                # Proceed with actual pasting if all checks pass
                # Ensure maxval for random_uniform is always > minval
                dest_y_max = tf.maximum(1, tile_size - patch_height + 1)
                dest_x_max = tf.maximum(1, tile_size - patch_width + 1)

                # Destination top-left corner, uniform over positions where the patch fits.
                # The destination may overlap the source patch.
                dest_y = tf.random.uniform([], 0, dest_y_max, dtype=tf.int32)
                dest_x = tf.random.uniform([], 0, dest_x_max, dtype=tf.int32)

                # (y, x, c) index of every destination element, laid out row-major
                # ('ij' meshgrid) to match the order of tf.reshape(rgb_patch, [-1]).
                indices_rgb = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(3), indexing='ij'
                ), axis=-1)
                rgb_out = tf.tensor_scatter_nd_update(rgb, tf.reshape(indices_rgb, [-1, 3]), tf.reshape(rgb_patch, [-1]))

                # Same destination box for the 2-D label, using (y, x) indices.
                indices_label = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width), indexing='ij'
                ), axis=-1)
                label_out = tf.tensor_scatter_nd_update(label, tf.reshape(indices_label, [-1, 2]), tf.reshape(label_patch, [-1]))
    return rgb_out, label_out

def _apply_water_cutmix_rgb_elev(rgb, elev, slope, label, tile_size, water_class_id, cutmix_prob=0.9, patch_size_range=(64, 64)):
    """
    Applies a 'self-cutmix' augmentation for the water class, operating on RGB, Elevation, Slope, and Label.
    It picks a random water pixel in the current image, cuts a square patch around it and
    pastes that patch at another random location within the same image. The same source and
    destination boxes are used for every modality, so they stay pixel-aligned.

    The inputs are returned unchanged when:
      * the ``cutmix_prob`` draw fails,
      * the image has fewer than ``patch_size_range[0] ** 2 // 4`` water pixels, or
      * less than a quarter of the cut patch is water.

    The control flow uses Python ``if`` on tensor values, so this is written to run eagerly
    (the batch CutMix helpers call it inside ``tf.py_function``).
    NOTE(review): calling it from graph code would rely on AutoGraph converting these
    branches — confirm before doing so.

    Args:
        rgb (tf.Tensor): RGB image tensor of shape (H, W, 3).
        elev (tf.Tensor): Elevation raster tensor of shape (H, W, 1).
        slope (tf.Tensor): Slope raster tensor of shape (H, W, 1).
        label (tf.Tensor): Label mask tensor (integer class IDs) of shape (H, W).
        tile_size (int): The target size of the image tiles (e.g., 256). Used as the
            height/width bound when clamping the source and destination boxes.
        water_class_id (int): The integer ID for the water class.
        cutmix_prob (float): Probability of attempting this augmentation.
        patch_size_range (tuple): (min_size, max_size), both inclusive, for the square
            patch side.

    Returns:
        tuple: Augmented rgb, elev, slope, and label tensors (same shapes as the inputs).
    """
    # Initialize output variables to the original inputs.
    rgb_out = rgb
    elev_out = elev
    slope_out = slope
    label_out = label

    # Outer probability check for applying cutmix.
    if tf.random.uniform([]) < cutmix_prob:
        # Binary water mask and the (row, col) coordinates of every water pixel.
        water_mask = tf.cast(tf.equal(label, water_class_id), tf.float32)
        water_coords = tf.cast(tf.where(water_mask > 0), tf.int32)
        
        # idx_max >= 1 keeps random.uniform valid even when there is no water.
        num_water_pixels = tf.shape(water_coords)[0]
        idx_max = tf.maximum(1, num_water_pixels) 
        idx = tf.random.uniform([], 0, idx_max, dtype=tf.int32)

        # Patch centre. Falls back to (0, 0) when there is no water; that case is
        # rejected by the min_pixels_for_patch check below.
        center_y = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][0], lambda: tf.constant(0, dtype=tf.int32))
        center_x = tf.cond(num_water_pixels > 0, lambda: water_coords[idx][1], lambda: tf.constant(0, dtype=tf.int32))

        # Side of the square patch, uniform over [min_size, max_size].
        patch_size = tf.random.uniform([], patch_size_range[0], patch_size_range[1] + 1, dtype=tf.int32)

        # Centre a patch_size box on the chosen pixel, clamped at the top/left edge
        # and cut off at the bottom/right edge...
        y1_src = tf.maximum(0, center_y - patch_size // 2)
        x1_src = tf.maximum(0, center_x - patch_size // 2)
        y2_src = tf.minimum(tile_size, y1_src + patch_size)
        x2_src = tf.minimum(tile_size, x1_src + patch_size)

        # ...then slide the start back from the bottom/right edge so the box keeps its
        # full size whenever tile_size >= patch_size.
        y1_src = y2_src - patch_size
        x1_src = x2_src - patch_size
        y1_src = tf.maximum(0, y1_src)
        x1_src = tf.maximum(0, x1_src)

        patch_height = y2_src - y1_src
        patch_width = x2_src - x1_src

        # Minimum number of water pixels in the whole image: a quarter of the
        # smallest allowed patch area.
        min_pixels_for_patch = (patch_size_range[0] * patch_size_range[0]) // 4

        if num_water_pixels < min_pixels_for_patch or \
           patch_height <= 0 or patch_width <= 0:
            pass  # Not enough water / degenerate box: keep the originals.
        else:
            rgb_patch = rgb[y1_src:y2_src, x1_src:x2_src, :]
            label_patch = label[y1_src:y2_src, x1_src:x2_src]

            # The patch must be at least 25% water to be worth pasting.
            patch_water_mask = tf.cast(tf.equal(label_patch, water_class_id), tf.float32)
            required_patch_pixels = tf.cast(patch_height * patch_width / 4, tf.float32)
            patch_water_mask_sum = tf.reduce_sum(patch_water_mask)

            if patch_water_mask_sum < required_patch_pixels:
                pass  # Patch not water-rich enough: keep the originals.
            else:
                # Cut the same box from the auxiliary rasters.
                elev_patch = elev[y1_src:y2_src, x1_src:x2_src, :]
                slope_patch = slope[y1_src:y2_src, x1_src:x2_src, :]

                # Ensure maxval for random_uniform is always > minval
                dest_y_max = tf.maximum(1, tile_size - patch_height + 1)
                dest_x_max = tf.maximum(1, tile_size - patch_width + 1)

                # Destination top-left corner, uniform over positions where the patch fits.
                # The destination may overlap the source patch.
                dest_y = tf.random.uniform([], 0, dest_y_max, dtype=tf.int32)
                dest_x = tf.random.uniform([], 0, dest_x_max, dtype=tf.int32)

                # (y, x, c) index of every destination element, laid out row-major
                # ('ij' meshgrid) to match the order of tf.reshape(patch, [-1]).
                indices_rgb = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(3), indexing='ij'
                ), axis=-1)
                rgb_out = tf.tensor_scatter_nd_update(rgb, tf.reshape(indices_rgb, [-1, 3]), tf.reshape(rgb_patch, [-1]))

                # 2-D label: (y, x) indices only.
                indices_label = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width), indexing='ij'
                ), axis=-1)
                label_out = tf.tensor_scatter_nd_update(label, tf.reshape(indices_label, [-1, 2]), tf.reshape(label_patch, [-1]))

                # Single-channel rasters: (y, x, 0) indices.
                indices_elev = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(1), indexing='ij'
                ), axis=-1)
                elev_out = tf.tensor_scatter_nd_update(elev, tf.reshape(indices_elev, [-1, 3]), tf.reshape(elev_patch, [-1]))

                indices_slope = tf.stack(tf.meshgrid(
                    tf.range(dest_y, dest_y + patch_height),
                    tf.range(dest_x, dest_x + patch_width),
                    tf.range(1), indexing='ij'
                ), axis=-1)
                slope_out = tf.tensor_scatter_nd_update(slope, tf.reshape(indices_slope, [-1, 3]), tf.reshape(slope_patch, [-1]))
    
    return rgb_out, elev_out, slope_out, label_out






def _sample_water_box(label, water_class_id, patch_size_range):
    """Pick a water-rich square box in ``label``; return (y1, x1, height, width) or None.

    The box is centred on a uniformly chosen water pixel, then shifted to lie
    inside the array, mirroring ``_apply_water_cutmix_rgb_only``. None is
    returned when ``label`` has no water, or when less than a quarter of the
    box is water.
    """
    coords = np.argwhere(label == water_class_id)
    if coords.shape[0] == 0:
        return None
    center_y, center_x = coords[np.random.randint(coords.shape[0])]
    size = np.random.randint(patch_size_range[0], patch_size_range[1] + 1)
    height, width = label.shape[:2]
    # Centre on the pixel, clamp at the top/left edge, cut at the bottom/right
    # edge, then slide back so the box keeps its full size when it fits.
    y2 = min(height, max(0, center_y - size // 2) + size)
    x2 = min(width, max(0, center_x - size // 2) + size)
    y1 = max(0, y2 - size)
    x1 = max(0, x2 - size)
    box_h, box_w = y2 - y1, x2 - x1
    if box_h <= 0 or box_w <= 0:
        return None
    water_in_box = np.count_nonzero(label[y1:y2, x1:x2] == water_class_id)
    if water_in_box < box_h * box_w / 4:
        return None
    return int(y1), int(x1), int(box_h), int(box_w)


def _paste_from_donor(targets, donors, water_class_id, patch_size_range):
    """Paste a water patch cut from ``donors`` at a random spot in copies of ``targets``.

    ``targets`` and ``donors`` are parallel lists of per-chip numpy arrays whose
    first two axes are (H, W); the LAST entry of each list must be the integer
    label map. The same box is copied for every entry, keeping modalities
    aligned. Returns a new list (inputs are not modified). The targets are
    returned unchanged if the donor yields no water-rich box or the box does
    not fit in the target.
    """
    box = _sample_water_box(donors[-1], water_class_id, patch_size_range)
    height, width = targets[-1].shape[:2]
    if box is None or box[2] > height or box[3] > width:
        return list(targets)
    y1, x1, box_h, box_w = box
    dest_y = np.random.randint(0, height - box_h + 1)
    dest_x = np.random.randint(0, width - box_w + 1)
    pasted = []
    for target, donor in zip(targets, donors):
        out = target.copy()
        out[dest_y:dest_y + box_h, dest_x:dest_x + box_w] = donor[y1:y1 + box_h, x1:x1 + box_w]
        pasted.append(out)
    return pasted


def apply_cutmix_rgb_batch(rgb_batch, label_batch, tile_size, water_class_id=3, inter_prob=0.3, patch_size_range=(64, 64)):
    """
    Applies intra-chip or inter-chip CutMix for water class in RGB-only setting.

    * Chips that contain water get intra-chip self-CutMix
      (``_apply_water_cutmix_rgb_only`` with probability 0.9).
    * Chips without water, with probability ``inter_prob``, receive a water
      patch cut from a randomly chosen water-containing chip of the same batch.
    * All other chips pass through unchanged.

    Args:
        rgb_batch (tf.Tensor): Eager tensor, expected (B, H, W, 3) float32.
        label_batch (tf.Tensor): Eager tensor, expected (B, H, W) int32 class IDs.
        tile_size (int): Tile height/width, forwarded to the intra-chip routine.
        water_class_id (int): Class ID of water.
        inter_prob (float): Probability of inter-chip CutMix for a water-free chip.
        patch_size_range (tuple): (min, max) square patch side, inclusive.

    Returns:
        tuple: (rgb, label) tensors stacked along the batch axis.
    """
    rgb_batch = rgb_batch.numpy()
    label_batch = label_batch.numpy()
    batch_size = rgb_batch.shape[0]
    # Computed once so donor selection is O(B) per chip rather than rescanning labels.
    has_water = [bool(np.any(label_batch[j] == water_class_id)) for j in range(batch_size)]

    rgb_batch_out = []
    label_batch_out = []
    for i in range(batch_size):
        rgb = rgb_batch[i]
        label = label_batch[i]
        if has_water[i]:
            rgb_aug, label_aug = _apply_water_cutmix_rgb_only(
                tf.convert_to_tensor(rgb), tf.convert_to_tensor(label),
                tile_size, water_class_id, 0.9, patch_size_range
            )
        else:
            arrays = [rgb, label]
            if np.random.rand() < inter_prob:
                donors = [j for j in range(batch_size) if j != i and has_water[j]]
                if donors:
                    donor_idx = random.choice(donors)
                    # Paste the donor's water patch INTO this chip (the chip is kept).
                    arrays = _paste_from_donor(
                        arrays, [rgb_batch[donor_idx], label_batch[donor_idx]],
                        water_class_id, patch_size_range,
                    )
            rgb_aug, label_aug = (tf.convert_to_tensor(a) for a in arrays)

        rgb_batch_out.append(rgb_aug)
        label_batch_out.append(label_aug)

    return tf.stack(rgb_batch_out), tf.stack(label_batch_out)


def apply_cutmix_rgb_elev_batch(rgb_batch, elev_batch, slope_batch, label_batch, tile_size, water_class_id=3, inter_prob=0.3, patch_size_range=(64, 64)):
    """
    Applies intra- or inter-chip CutMix for water class in RGB+elev+slope setting.

    Same policy as ``apply_cutmix_rgb_batch``. The elevation and slope rasters
    are cut and pasted with exactly the same boxes as RGB and the label.

    Args:
        rgb_batch (tf.Tensor): Eager tensor, expected (B, H, W, 3) float32.
        elev_batch (tf.Tensor): Eager tensor, expected (B, H, W, 1) float32.
        slope_batch (tf.Tensor): Eager tensor, expected (B, H, W, 1) float32.
        label_batch (tf.Tensor): Eager tensor, expected (B, H, W) int32 class IDs.
        tile_size (int): Tile height/width, forwarded to the intra-chip routine.
        water_class_id (int): Class ID of water.
        inter_prob (float): Probability of inter-chip CutMix for a water-free chip.
        patch_size_range (tuple): (min, max) square patch side, inclusive.

    Returns:
        tuple: (rgb, elev, slope, label) tensors stacked along the batch axis.
    """
    rgb_batch = rgb_batch.numpy()
    elev_batch = elev_batch.numpy()
    slope_batch = slope_batch.numpy()
    label_batch = label_batch.numpy()

    batch_size = rgb_batch.shape[0]
    has_water = [bool(np.any(label_batch[j] == water_class_id)) for j in range(batch_size)]

    rgb_out, elev_out, slope_out, label_out = [], [], [], []
    for i in range(batch_size):
        rgb = rgb_batch[i]
        elev = elev_batch[i]
        slope = slope_batch[i]
        label = label_batch[i]

        if has_water[i]:
            rgb_aug, elev_aug, slope_aug, label_aug = _apply_water_cutmix_rgb_elev(
                tf.convert_to_tensor(rgb), tf.convert_to_tensor(elev), tf.convert_to_tensor(slope),
                tf.convert_to_tensor(label), tile_size, water_class_id, 0.9, patch_size_range
            )
        else:
            arrays = [rgb, elev, slope, label]
            if np.random.rand() < inter_prob:
                donors = [j for j in range(batch_size) if j != i and has_water[j]]
                if donors:
                    donor_idx = random.choice(donors)
                    # Paste the donor's water patch INTO this chip (the chip is kept).
                    arrays = _paste_from_donor(
                        arrays,
                        [rgb_batch[donor_idx], elev_batch[donor_idx],
                         slope_batch[donor_idx], label_batch[donor_idx]],
                        water_class_id, patch_size_range,
                    )
            rgb_aug, elev_aug, slope_aug, label_aug = (tf.convert_to_tensor(a) for a in arrays)

        rgb_out.append(rgb_aug)
        elev_out.append(elev_aug)
        slope_out.append(slope_aug)
        label_out.append(label_aug)

    return (
        tf.stack(rgb_out),
        tf.stack(elev_out),
        tf.stack(slope_out),
        tf.stack(label_out)
    )



# -------------------------- PARSE FUNCS --------------------------
def parse_tile(rgb_path, label_path, tile_id, split='train', augment=False, tile_size=256):
    """Load one RGB tile and its colour-coded label, resized to ``tile_size``.

    Args:
        rgb_path (tf.Tensor): Path to the RGB PNG.
        label_path (tf.Tensor): Path to the colour-coded label PNG.
        tile_id (tf.Tensor): Tile identifier (not used here).
        split (str): Dataset split; augmentation only runs for 'train'.
        augment (bool): Whether to run ``augment_rgb_label``.
        tile_size (int): Output height and width.

    Returns:
        tuple: (rgb, label), where rgb is float32 (tile_size, tile_size, 3) and
        label is int32 (tile_size, tile_size). One-hot encoding is left to the caller.
    """
    target_hw = [tile_size, tile_size]

    image = tf.image.decode_png(tf.io.read_file(rgb_path), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)

    class_map = decode_coloured_label(
        tf.image.decode_png(tf.io.read_file(label_path), channels=3)
    )

    image = tf.image.resize(image, target_hw)
    # Nearest-neighbour keeps class IDs intact (no interpolation between classes).
    class_map = tf.image.resize(tf.expand_dims(class_map, -1), target_hw, method='nearest')
    class_map = tf.cast(tf.reshape(class_map, target_hw), tf.int32)

    if split == 'train' and augment:
        image, class_map = augment_rgb_label(image, class_map)
    return image, class_map

def parse_elevation(rgb_path, elev_path, slope_path, label_path, tile_id,
                    split='train', augment=False, tile_size=256):
    """Load one RGB tile, its elevation and slope rasters and its label.

    All outputs are resized to ``tile_size``.

    Args:
        rgb_path (tf.Tensor): Path to the RGB PNG.
        elev_path (tf.Tensor): Path to the elevation ``.npy`` (2-D array).
        slope_path (tf.Tensor): Path to the slope ``.npy`` (2-D array).
        label_path (tf.Tensor): Path to the colour-coded label PNG.
        tile_id (tf.Tensor): Tile identifier (not used here).
        split (str): Dataset split; augmentation only runs for 'train'.
        augment (bool): Whether to run ``augment_all``.
        tile_size (int): Output height and width.

    Returns:
        tuple: (rgb, elev, slope, label). With S = tile_size:
            * rgb is float32 (S, S, 3),
            * elev and slope are float32 (S, S, 1),
            * label is int32 (S, S).
        Concatenation and one-hot encoding are left to the caller.
    """
    target_hw = [tile_size, tile_size]

    image = tf.image.decode_png(tf.io.read_file(rgb_path), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)

    class_map = decode_coloured_label(
        tf.image.decode_png(tf.io.read_file(label_path), channels=3)
    )

    # .npy rasters are read through Python, so the static shape must be restored.
    elevation = tf.py_function(_load_npy, [elev_path], tf.float32)
    slope_map = tf.py_function(_load_npy, [slope_path], tf.float32)
    elevation.set_shape([None, None])
    slope_map.set_shape([None, None])

    image = tf.image.resize(image, target_hw)
    elevation = tf.image.resize(tf.expand_dims(elevation, axis=-1), target_hw)
    slope_map = tf.image.resize(tf.expand_dims(slope_map, axis=-1), target_hw)
    # Nearest-neighbour keeps class IDs intact (no interpolation between classes).
    class_map = tf.image.resize(tf.expand_dims(class_map, -1), target_hw, method='nearest')
    class_map = tf.cast(tf.reshape(class_map, target_hw), tf.int32)

    if split == 'train' and augment:
        image, elevation, slope_map, class_map = augment_all(image, elevation, slope_map, class_map)
    return image, elevation, slope_map, class_map


# -------------------------- BUILD DATASET --------------------------
def build_tf_dataset(df, image_dir, elev_dir, slope_dir, label_dir,
                     input_type='rgb', batch_size=32, split='train',
                     augment=False, shuffle=True, tile_size=256):
    """Build a batched ``tf.data`` pipeline of (input, one-hot label) pairs.

    Args:
        df (pd.DataFrame): Must contain a ``tile_id`` column.
        image_dir, elev_dir, slope_dir, label_dir (str): Data directories.
        input_type (str): 'rgb' (3-channel input) or 'rgb_elev'
            (RGB + elevation + slope, 5-channel input).
        batch_size (int): Batch size.
        split (str): 'train', 'val' or 'test'; augmentation and CutMix only
            run for 'train'.
        augment (bool): Enable per-tile augmentation and batch-level water CutMix.
        shuffle (bool): Shuffle the tile order each epoch.
        tile_size (int): Output tile height and width.

    Returns:
        tf.data.Dataset: yields ``(inputs, labels)``. ``inputs`` has shape
        (B, S, S, 3) or (B, S, S, 5); ``labels`` has shape (B, S, S, NUM_CLASSES).
        The structure is the same for every split, with or without augmentation.

    Raises:
        ValueError: If ``input_type`` is not supported.
    """
    image_paths, elev_paths, slope_paths, label_paths, tile_ids = load_image_paths(
        df, image_dir, elev_dir, slope_dir, label_dir
    )
    apply_cutmix = split == 'train' and augment

    if input_type == 'rgb':
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths, tile_ids))

        def map_fn(rgb_path, label_path, tile_id):
            return parse_tile(rgb_path, label_path, tile_id, split, augment, tile_size)

        def cutmix_fn(rgb, label):
            rgb_new, label_new = tf.py_function(
                func=lambda r, l: apply_cutmix_rgb_batch(r, l, tile_size),
                inp=[rgb, label],
                Tout=[tf.float32, tf.int32]
            )
            # py_function loses static shapes; restore them for downstream ops.
            rgb_new.set_shape([None, tile_size, tile_size, 3])
            label_new.set_shape([None, tile_size, tile_size])
            return rgb_new, label_new

        def finalize_fn(rgb, label):
            return rgb, tf.one_hot(label, depth=NUM_CLASSES)

    elif input_type == 'rgb_elev':
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, elev_paths, slope_paths, label_paths, tile_ids))

        def map_fn(rgb_path, elev_path, slope_path, label_path, tile_id):
            return parse_elevation(rgb_path, elev_path, slope_path, label_path, tile_id, split, augment, tile_size)

        def cutmix_fn(rgb, elev, slope, label):
            rgb_new, elev_new, slope_new, label_new = tf.py_function(
                func=lambda r, e, s, l: apply_cutmix_rgb_elev_batch(r, e, s, l, tile_size),
                inp=[rgb, elev, slope, label],
                Tout=[tf.float32, tf.float32, tf.float32, tf.int32]
            )
            # py_function loses static shapes; restore them for downstream ops.
            rgb_new.set_shape([None, tile_size, tile_size, 3])
            elev_new.set_shape([None, tile_size, tile_size, 1])
            slope_new.set_shape([None, tile_size, tile_size, 1])
            label_new.set_shape([None, tile_size, tile_size])
            return rgb_new, elev_new, slope_new, label_new

        def finalize_fn(rgb, elev, slope, label):
            input_image = tf.concat([rgb, elev, slope], axis=-1)
            return input_image, tf.one_hot(label, depth=NUM_CLASSES)

    else:
        raise ValueError(f"Unsupported input_type: {input_type}")

    if shuffle:
        # Shuffle the cheap path tuples before decoding, so the full-size buffer
        # holds file names rather than decoded images. max(1, ...) keeps
        # buffer_size valid for an empty DataFrame.
        dataset = dataset.shuffle(buffer_size=max(1, len(image_paths)), reshuffle_each_iteration=True)

    dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)

    if apply_cutmix:
        dataset = dataset.map(cutmix_fn, num_parallel_calls=tf.data.AUTOTUNE)

    # Always produce (inputs, one-hot labels), so train/val/test share one structure.
    dataset = dataset.map(finalize_fn, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.prefetch(tf.data.AUTOTUNE)
