# In [None]:
# This code should be in a cell in your main Colab notebook

import tensorflow as tf
import os
import glob

# --- Configuration ---
# The functions below read the following names from the notebook's global
# scope, so they must already be defined before this cell is run
# (example values shown):
#   TILE_SIZE = 512
#   BATCH_SIZE = 1
#   COLOR_TO_CLASS = { ... }

# Buffer size handed to Dataset.shuffle() in get_gan_dataset().
BUFFER_SIZE = 400

# Height and width are both taken from TILE_SIZE.
IMG_HEIGHT = IMG_WIDTH = TILE_SIZE


# --- Helper Function to Decode Labels ---
def decode_coloured_label(label_rgb: tf.Tensor) -> tf.Tensor:
    """Decode an RGB label image into a single-channel integer class-ID mask.

    Every pixel is compared against the colours used as keys of the global
    ``COLOR_TO_CLASS`` mapping, and the class ID stored under the matching
    colour is written to the output.

    Args:
        label_rgb: Label image whose last axis holds the three colour
            channels. It is cast to ``uint8`` before the comparison.

    Returns:
        An ``int64`` tensor with the shape of the first two axes of
        ``label_rgb``, holding the class ID of each pixel.

    Note:
        A pixel whose colour is not a key of ``COLOR_TO_CLASS`` matches
        nothing, so its row of matches is all zeros and it receives the class
        of whichever index ``tf.argmax`` picks for a tie.
        NOTE(review): that is expected to be index 0 (the first entry of the
        mapping) -- confirm, and consider a dedicated "unknown" class.
    """
    label_rgb = tf.cast(label_rgb, tf.uint8)
    pixels = tf.reshape(label_rgb, [-1, 3])  # one row of 3 channels per pixel
    keys = tf.constant(list(COLOR_TO_CLASS.keys()), dtype=tf.uint8)
    # int64 keeps the returned dtype identical to what tf.argmax yields.
    values = tf.constant(list(COLOR_TO_CLASS.values()), dtype=tf.int64)

    # match[p, k] is True when pixel p equals colour k on all three channels.
    match = tf.reduce_all(tf.equal(tf.expand_dims(pixels, 1), keys), axis=2)
    indices = tf.argmax(tf.cast(match, tf.int32), axis=1)

    # argmax only yields the position of the colour inside the mapping;
    # translate that position into the class ID the mapping assigns to it.
    class_ids = tf.gather(values, indices)
    return tf.reshape(class_ids, tf.shape(label_rgb)[:2])


# --- Core Data Loading and Preprocessing Functions ---

def load_image_pair(label_path_tensor, image_path_tensor):
    """Read a label PNG and its corresponding real-image PNG from disk.

    Args:
        label_path_tensor: Path of the label file, handed to
            ``tf.io.read_file``.
        image_path_tensor: Path of the real image file, handed to
            ``tf.io.read_file``.

    Returns:
        A ``(label, image)`` pair, each decoded as a 3-channel PNG.
    """
    def _read_rgb_png(path):
        # channels=3 asks the decoder for an RGB result.
        return tf.image.decode_png(tf.io.read_file(path), channels=3)

    # The label is read and decoded completely before the image is touched.
    decoded_label = _read_rgb_png(label_path_tensor)
    decoded_image = _read_rgb_png(image_path_tensor)
    return decoded_label, decoded_image


def normalize(label, image):
    """Rescale a label/image pair for GAN training.

    Both tensors are cast to ``float32`` and mapped with ``x / 127.5 - 1``,
    which sends 0 to -1 and 255 to 1.

    Returns:
        The rescaled ``(label, image)`` pair.
    """
    def _rescale(pixels):
        return (tf.cast(pixels, tf.float32) / 127.5) - 1

    return _rescale(label), _rescale(image)


def random_jitter(label, image):
    """Randomly jitter a label/image pair: enlarge, then crop back.

    Both inputs are resized with nearest-neighbour sampling to 30 pixels more
    than ``IMG_HEIGHT`` x ``IMG_WIDTH``. The two results are stacked and a
    single random ``IMG_HEIGHT`` x ``IMG_WIDTH`` window is cut from the stack,
    so label and image are cropped at the same offset.

    Returns:
        The cropped ``(label, image)`` pair.
    """
    enlarged_size = [IMG_HEIGHT + 30, IMG_WIDTH + 30]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    enlarged_label = tf.image.resize(label, enlarged_size, method=nearest)
    enlarged_image = tf.image.resize(image, enlarged_size, method=nearest)

    # Stacking first means one crop call covers both tensors.
    pair = tf.stack([enlarged_label, enlarged_image], axis=0)
    crop = tf.image.random_crop(pair, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

    return crop[0], crop[1]


def random_mirror(label, image):
    """Flip label and image left-to-right together when a random draw exceeds 0.5.

    Returns:
        The ``(label, image)`` pair, either both flipped or both untouched.
    """
    # One draw decides for both tensors so the pair stays aligned.
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        label, image = (
            tf.image.flip_left_right(label),
            tf.image.flip_left_right(image),
        )
    return label, image


@tf.function()
def parse_image_pair(label_path, image_path, augment=True):
    """Load one label/image pair and turn it into normalised training tensors.

    Args:
        label_path: Path of the label PNG, passed on to ``load_image_pair``.
        image_path: Path of the matching real-image PNG.
        augment: When true, ``random_jitter`` and ``random_mirror`` are
            applied before normalisation.

    Returns:
        The ``(label, image)`` pair as returned by ``normalize``.
    """
    # load_image_pair consists solely of TensorFlow ops (read_file and
    # decode_png), so it is called directly. Routing it through
    # tf.py_function would run the load in the Python interpreter, limiting
    # the benefit of parallel dataset.map calls and making the graph
    # non-serialisable.
    label, image = load_image_pair(label_path, image_path)

    # Pin the static shape: unknown height/width, three channels.
    label.set_shape([None, None, 3])
    image.set_shape([None, None, 3])

    if augment:
        label, image = random_jitter(label, image)
        label, image = random_mirror(label, image)

    label, image = normalize(label, image)

    return label, image


# --- Main Function to Build the Dataset ---

def get_gan_dataset(chipped_data_dir, augment=True, shuffle=True, split='train'):
    """Build a ``tf.data.Dataset`` of (label, image) pairs for GAN training.

    Labels are taken from ``<chipped_data_dir>/<split>/labels/*-label.png``
    and images from ``<chipped_data_dir>/<split>/images/*-ortho.png``. A label
    is used only when an image with the same tile ID (the file name without
    its suffix) exists.

    Args:
        chipped_data_dir: Root directory of the chipped data.
        augment: Forwarded to ``parse_image_pair``.
        shuffle: When true, the file-path pairs are shuffled with a buffer of
            ``BUFFER_SIZE`` before being loaded.
        split: Sub-directory of ``chipped_data_dir`` to read from. Defaults to
            ``'train'``.

    Returns:
        A dataset of parsed pairs, batched by ``BATCH_SIZE`` and prefetched.
        If no pair is found, a message is printed and an empty dataset is
        returned.
    """
    label_suffix = '-label.png'
    image_suffix = '-ortho.png'
    label_dir = os.path.join(chipped_data_dir, split, 'labels')
    image_dir = os.path.join(chipped_data_dir, split, 'images')

    def _tile_id(path, suffix):
        # Cut the suffix off the end only; str.replace would also remove the
        # same text if it happened to occur inside the file name.
        return os.path.basename(path)[:-len(suffix)]

    # Tile ID -> image path. No sorting needed: this is only used for lookup.
    image_map = {
        _tile_id(path, image_suffix): path
        for path in glob.glob(os.path.join(image_dir, '*' + image_suffix))
    }

    final_label_paths = []
    final_image_paths = []

    # Sorted labels give the pair lists a deterministic order.
    for label_path in sorted(glob.glob(os.path.join(label_dir, '*' + label_suffix))):
        image_path = image_map.get(_tile_id(label_path, label_suffix))
        if image_path is not None:
            final_label_paths.append(label_path)
            final_image_paths.append(image_path)

    if not final_label_paths:
        print(f"Error: No matching image/label pairs found in {chipped_data_dir}.")
        return tf.data.Dataset.from_tensor_slices(([], []))

    dataset = tf.data.Dataset.from_tensor_slices((final_label_paths, final_image_paths))

    # Shuffle the path pairs before the images are loaded.
    if shuffle:
        dataset = dataset.shuffle(BUFFER_SIZE)

    def map_func(label_path, image_path):
        return parse_image_pair(label_path, image_path, augment=augment)

    dataset = dataset.map(map_func, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset