# In [None]:
import tensorflow as tf
import os
import glob


# --- Configuration ---
# Number of file-path pairs held in the shuffle buffer of the dataset.
BUFFER_SIZE = 400

# Pix2Pix and other GANs often train with a batch size of 1.
BATCH_SIZE = 1

# Height and width (in pixels) of the tiles produced by the random crop.
IMG_HEIGHT = 512
IMG_WIDTH = 512

# --- Helper Function to Decode Labels ---
# (This is the same function from your previous data scripts)
def decode_coloured_label(label_rgb: tf.Tensor) -> tf.Tensor:
    """Decode an RGB label image into a single-channel class-ID mask.

    Each pixel is compared against every colour key of ``COLOR_TO_CLASS``
    (expected to be a module-level mapping of RGB tuples to integer class
    IDs, defined outside this function) and replaced by the class ID stored
    for the matching colour.

    Args:
        label_rgb: Tensor of shape ``[height, width, 3]``. It is cast to
            ``uint8`` before the comparison.

    Returns:
        An ``int64`` tensor of shape ``[height, width]`` holding class IDs.
        A pixel whose colour matches no key receives the class ID of the
        first entry of ``COLOR_TO_CLASS``, because ``argmax`` over an
        all-zero row yields index 0.
    """
    label_rgb = tf.cast(label_rgb, tf.uint8)
    flat = tf.reshape(label_rgb, [-1, 3])
    # keys() and values() iterate in the same order, so keys[i] pairs with
    # values[i]. int64 keeps the return dtype identical to tf.argmax's.
    keys = tf.constant(list(COLOR_TO_CLASS.keys()), dtype=tf.uint8)
    values = tf.constant(list(COLOR_TO_CLASS.values()), dtype=tf.int64)

    # match[p, k] is True when pixel p equals colour k on all three channels.
    match = tf.reduce_all(tf.equal(tf.expand_dims(flat, 1), keys), axis=2)
    indices = tf.argmax(tf.cast(match, tf.int32), axis=1)
    # Translate the position of the matching colour into its class ID;
    # returning the raw position is only correct when IDs are 0..N-1 in order.
    class_ids = tf.gather(values, indices)
    return tf.reshape(class_ids, tf.shape(label_rgb)[:2])

# --- Core Data Loading and Preprocessing Functions ---

def load_image_pair(label_path_tensor, image_path_tensor):
    """Load a label PNG and its corresponding real-image PNG.

    Args:
        label_path_tensor: Path of the label PNG, as a scalar string tensor
            (a Python ``str``/``bytes`` is accepted as well).
        image_path_tensor: Path of the real-image PNG, same accepted types.

    Returns:
        A ``(label, image)`` tuple of tensors decoded with three channels.
    """
    # The paths are handed to tf.io.read_file unchanged. Converting them with
    # tf.compat.as_str_any first would call str() on the Tensor object and
    # yield its printed representation instead of the file path it holds.
    label = tf.io.read_file(label_path_tensor)
    label = tf.image.decode_png(label, channels=3)

    image = tf.io.read_file(image_path_tensor)
    image = tf.image.decode_png(image, channels=3)

    return label, image

def normalize(label, image):
    """Convert both tensors to float32 and rescale them with ``x / 127.5 - 1``.

    For inputs in the 0-255 range this yields values in [-1, 1]. The label
    map is rescaled like the image rather than being one-hot encoded.

    Returns:
        The rescaled ``(label, image)`` pair.
    """
    half_range = 127.5
    label_f32 = tf.cast(label, tf.float32)
    image_f32 = tf.cast(image, tf.float32)
    return (label_f32 / half_range) - 1, (image_f32 / half_range) - 1

def random_jitter(label, image):
    """Enlarge both tensors by 30 pixels per side length, then crop them back.

    Label and image are resized with nearest-neighbour interpolation,
    stacked, and cropped with a single random window so that both receive
    the same offset.

    Returns:
        The cropped ``(label, image)`` pair, each ``[IMG_HEIGHT, IMG_WIDTH, 3]``.
    """
    enlarged_size = [IMG_HEIGHT + 30, IMG_WIDTH + 30]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    enlarged_label = tf.image.resize(label, enlarged_size, method=nearest)
    enlarged_image = tf.image.resize(image, enlarged_size, method=nearest)

    # One crop over the stacked pair keeps label and image aligned.
    pair = tf.stack([enlarged_label, enlarged_image], axis=0)
    cropped_pair = tf.image.random_crop(
        pair, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

    return cropped_pair[0], cropped_pair[1]

def random_mirror(label, image):
    """Flip label and image left-to-right together when a uniform draw exceeds 0.5.

    Returns:
        The ``(label, image)`` pair, either both flipped or both unchanged.
    """
    draw = tf.random.uniform(())
    if draw > 0.5:
        # Both tensors are flipped under the same draw so they stay aligned.
        image = tf.image.flip_left_right(image)
        label = tf.image.flip_left_right(label)
    return label, image

@tf.function()
def parse_image_pair(label_path, image_path, augment=True):
    """Load one label/image pair, optionally augment it, and normalize it.

    Args:
        label_path: Path of the label PNG.
        image_path: Path of the real-image PNG.
        augment: When true, apply random_jitter followed by random_mirror
            before normalizing.

    Returns:
        The normalized ``(label, image)`` pair.
    """
    lbl, img = load_image_pair(label_path, image_path)

    if augment:
        lbl, img = random_jitter(lbl, img)
        lbl, img = random_mirror(lbl, img)

    return normalize(lbl, img)

# --- Main Function to Build the Dataset ---

def get_gan_dataset(chipped_data_dir, augment=True, shuffle=True):
    """
    Builds and returns a tf.data.Dataset for training a GAN.

    Label files are read from ``<chipped_data_dir>/train/labels/*.png``; the
    matching image is looked up in ``<chipped_data_dir>/train/images`` under
    the same file name with a trailing ``-label.png`` replaced by
    ``-ortho.png``.

    Args:
        chipped_data_dir (str): The path to the root of the chipped data (e.g., 'chipped_data_512').
        augment (bool): Whether to apply data augmentation.
        shuffle (bool): Whether to shuffle the dataset.

    Returns:
        A tf.data.Dataset object that yields batches (of BATCH_SIZE) of
        (label_map, real_image) float32 pairs. The dataset is empty when no
        label file is found.
    """
    # Find all the label files, which will be our primary key
    label_dir = os.path.join(chipped_data_dir, 'train', 'labels')
    image_dir = os.path.join(chipped_data_dir, 'train', 'images')

    # glob.escape keeps characters such as '[' or '*' in the directory name
    # from being interpreted as wildcards.
    label_paths = sorted(glob.glob(os.path.join(glob.escape(label_dir), '*.png')))

    # Derive each image path from the label's file name only, so that text in
    # the directory part of the path can never be rewritten by accident.
    label_suffix = '-label.png'
    image_paths = []
    for path in label_paths:
        file_name = os.path.basename(path)
        if file_name.endswith(label_suffix):
            file_name = file_name[:-len(label_suffix)] + '-ortho.png'
        image_paths.append(os.path.join(image_dir, file_name))

    # An explicit string dtype keeps the element type correct even when the
    # lists are empty (an empty Python list would otherwise become float32).
    dataset = tf.data.Dataset.from_tensor_slices((
        tf.constant(label_paths, dtype=tf.string),
        tf.constant(image_paths, dtype=tf.string),
    ))

    if shuffle:
        dataset = dataset.shuffle(BUFFER_SIZE)

    # With augmentation every sample is cropped to IMG_HEIGHT x IMG_WIDTH;
    # without it the samples keep the size stored in the PNG files.
    static_shape = [IMG_HEIGHT, IMG_WIDTH, 3] if augment else [None, None, 3]

    # Use a py_function to wrap the parsing function for compatibility
    def wrapped_parse(label_path, image_path):
        label, image = tf.py_function(
            func=lambda lp, ip: parse_image_pair(lp, ip, augment=augment),
            inp=[label_path, image_path],
            Tout=[tf.float32, tf.float32]
        )
        # tf.py_function outputs carry no static shape; restore what is known
        # so that downstream layers can infer the channel dimension.
        label.set_shape(static_shape)
        image.set_shape(static_shape)
        return label, image

    dataset = dataset.map(wrapped_parse, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset

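# --- Example Usage (for testing the script directly) ---
# NOTE: the example below is disabled because it is wrapped in a string
# literal. To run it, remove the surrounding triple quotes and add
# `import matplotlib.pyplot as plt`, which is not imported above.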
'''
if __name__ == '__main__':
    # This block will only run if you execute this script directly
    # e.g., python data_gan.py
    
    # Make sure to create a util.py with TILE_SIZE = 512 etc.
    # or define them here for a quick test.
    if 'TILE_SIZE' not in globals():
        TILE_SIZE = 512
        IMG_WIDTH = TILE_SIZE
        IMG_HEIGHT = TILE_SIZE
        # Dummy color map for direct execution
        COLOR_TO_CLASS = {
            (230, 25, 75): 0, (60, 180, 75): 1, (245, 130, 48): 2,
            (255, 255, 255): 3, (0, 130, 200): 4, (128, 128, 128): 5,
        }

    print("Testing GAN data pipeline...")
    chipped_data_path = 'chipped_data_512' # Assumes this folder is in the same directory
    
    if not os.path.exists(chipped_data_path):
        print(f"Error: Dataset not found at '{chipped_data_path}'. Please run the chipping script first.")
    else:
        gan_dataset = get_gan_dataset(chipped_data_path)
        
        # Fetch and visualize one batch to verify
        for label, image in gan_dataset.take(1):
            print("Successfully fetched one batch.")
            print(f"Label batch shape: {label.shape}")
            print(f"Image batch shape: {image.shape}")
            
            # Denormalize for visualization
            label_vis = (label[0] * 0.5 + 0.5).numpy()
            image_vis = (image[0] * 0.5 + 0.5).numpy()
            
            plt.figure(figsize=(8, 4))
            plt.subplot(1, 2, 1)
            plt.title("Sample Label Map (Input)")
            plt.imshow(label_vis)
            plt.axis('off')
            
            plt.subplot(1, 2, 2)
            plt.title("Sample Real Image (Target)")
            plt.imshow(image_vis)
            plt.axis('off')
            
            plt.show()
            '''