In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os

# RGB colours used in the label images, listed in class-id order: the position
# of a colour in this tuple is the class id it is mapped to.
_LABEL_PALETTE = (
    (230, 25, 75),
    (145, 30, 180),
    (60, 180, 75),
    (245, 130, 48),
    (255, 255, 255),
    (0, 130, 200),
    (255, 0, 255),
)

# Lookup from an (R, G, B) triple to its integer class id (0-6).
COLOR_TO_CLASS = {rgb: class_id for class_id, rgb in enumerate(_LABEL_PALETTE)}

# Human-readable names for class ids 0-5. Id 6 has no name and lies outside
# range(NUM_CLASSES) -- presumably an "ignore" marker; confirm with the
# dataset documentation.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
NUM_CLASSES = len(CLASS_NAMES)

def decode_coloured_label(label_rgb, unmatched_class=0):
    """Convert a colour-coded label image into a map of integer class ids.

    Every pixel is compared against the colours in ``COLOR_TO_CLASS`` and
    replaced by the class id stored for the matching colour.

    Args:
        label_rgb: tensor whose last axis holds the 3 colour components; it is
            cast to ``uint8`` before comparison. Its first two axes define the
            shape of the result.
        unmatched_class: class id assigned to pixels whose colour is not a key
            of ``COLOR_TO_CLASS``. The default ``0`` reproduces the historical
            behaviour (such pixels ended up as class 0).

    Returns:
        An ``int64`` tensor with the first two dimensions of ``label_rgb``
        containing one class id per pixel.
    """
    label_rgb = tf.cast(label_rgb, tf.uint8)
    pixels = tf.reshape(label_rgb, [-1, 3])

    palette = tf.constant(list(COLOR_TO_CLASS.keys()), dtype=tf.uint8)
    class_ids = tf.constant(list(COLOR_TO_CLASS.values()), dtype=tf.int64)

    # match[i, j] is True when pixel i has exactly the colour of palette row j.
    match = tf.reduce_all(tf.equal(tf.expand_dims(pixels, 1), palette), axis=2)
    palette_index = tf.argmax(tf.cast(match, tf.int32), axis=1)

    # Translate the palette row into the class id stored in the mapping rather
    # than relying on the dict listing ids 0..N-1 in order.
    decoded = tf.gather(class_ids, palette_index)

    # argmax yields 0 for an all-False row, so pixels without any matching
    # colour need an explicit fallback.
    has_match = tf.reduce_any(match, axis=1)
    fallback = tf.fill(tf.shape(decoded), tf.cast(unmatched_class, tf.int64))
    decoded = tf.where(has_match, decoded, fallback)

    return tf.reshape(decoded, tf.shape(label_rgb)[:2])

def _load_npy(path):
    """Load a ``.npy`` file and return its contents as a ``float32`` array.

    ``path`` may be a regular path or a ``tf.Tensor`` holding the path as a
    UTF-8 encoded byte string; a tensor is unwrapped and decoded first.
    """
    filename = path.numpy().decode("utf-8") if isinstance(path, tf.Tensor) else path
    array = np.load(filename)
    return array.astype(np.float32)

def load_image_paths(df, image_dir, elev_dir, slope_dir, label_dir):
    """Build the per-tile file paths for every row of ``df``.

    Args:
        df: table with a ``tile_id`` column.
        image_dir: directory of the ``<tile_id>-ortho.png`` files.
        elev_dir: directory of the ``<tile_id>-elev.npy`` files.
        slope_dir: directory of the ``<tile_id>-slope.npy`` files.
        label_dir: directory of the ``<tile_id>-label.png`` files.

    Returns:
        A 5-tuple ``(image_paths, elev_paths, slope_paths, label_paths,
        tile_ids)`` of lists, all in the row order of ``df``.
    """
    tile_ids = df["tile_id"].tolist()

    def _paths_in(directory, suffix):
        # One path per tile: <directory>/<tile_id><suffix>
        return [os.path.join(directory, f"{tid}{suffix}") for tid in tile_ids]

    return (
        _paths_in(image_dir, "-ortho.png"),
        _paths_in(elev_dir, "-elev.npy"),
        _paths_in(slope_dir, "-slope.npy"),
        _paths_in(label_dir, "-label.png"),
        tile_ids,
    )

def augment_rgb_label(rgb, label):
    """Apply the same random flips and rotation to an image and its label map.

    Args:
        rgb: image tensor accepted by the ``tf.image`` flip / ``rot90`` ops.
        label: label map without a channel axis; a trailing axis is added for
            the ``tf.image`` ops and removed again before returning.

    Returns:
        ``(rgb, label)`` after augmentation, transformed identically.
    """
    mask = tf.expand_dims(label, axis=-1)

    # Horizontal mirror for roughly half of the draws.
    mirror_lr = tf.random.uniform([]) > 0.5
    if mirror_lr:
        rgb, mask = tf.image.flip_left_right(rgb), tf.image.flip_left_right(mask)

    # Vertical mirror for roughly half of the draws.
    mirror_ud = tf.random.uniform([]) > 0.5
    if mirror_ud:
        rgb, mask = tf.image.flip_up_down(rgb), tf.image.flip_up_down(mask)

    # With a 0.0 threshold this branch is taken for essentially every draw,
    # rotating by 1, 2 or 3 quarter turns.
    rotate = tf.random.uniform([]) > 0.0
    if rotate:
        quarter_turns = tf.random.uniform([], 1, 4, dtype=tf.int32)
        rgb, mask = tf.image.rot90(rgb, quarter_turns), tf.image.rot90(mask, quarter_turns)

    return rgb, tf.squeeze(mask, axis=-1)

def augment_all(rgb, elev, slope, label):
    """Apply the same random flips and rotation to all inputs and the label.

    Args:
        rgb: image tensor accepted by the ``tf.image`` flip / ``rot90`` ops.
        elev: elevation tensor, transformed alongside ``rgb``.
        slope: slope tensor, transformed alongside ``rgb``.
        label: label map without a channel axis; a trailing axis is added for
            the ``tf.image`` ops and removed again before returning.

    Returns:
        ``(rgb, elev, slope, label)`` after augmentation, all transformed
        identically.
    """
    mask = tf.expand_dims(label, axis=-1)

    # Horizontal mirror for roughly half of the draws.
    mirror_lr = tf.random.uniform([]) > 0.5
    if mirror_lr:
        rgb, elev, slope, mask = (
            tf.image.flip_left_right(rgb),
            tf.image.flip_left_right(elev),
            tf.image.flip_left_right(slope),
            tf.image.flip_left_right(mask),
        )

    # Vertical mirror for roughly half of the draws.
    mirror_ud = tf.random.uniform([]) > 0.5
    if mirror_ud:
        rgb, elev, slope, mask = (
            tf.image.flip_up_down(rgb),
            tf.image.flip_up_down(elev),
            tf.image.flip_up_down(slope),
            tf.image.flip_up_down(mask),
        )

    # With a 0.0 threshold this branch is taken for essentially every draw,
    # rotating by 1, 2 or 3 quarter turns.
    rotate = tf.random.uniform([]) > 0.0
    if rotate:
        quarter_turns = tf.random.uniform([], 1, 4, dtype=tf.int32)
        rgb, elev, slope, mask = (
            tf.image.rot90(rgb, quarter_turns),
            tf.image.rot90(elev, quarter_turns),
            tf.image.rot90(slope, quarter_turns),
            tf.image.rot90(mask, quarter_turns),
        )

    return rgb, elev, slope, tf.squeeze(mask, axis=-1)

def parse_tile(rgb_path, label_path, tile_id, split='train', augment=False, tile_size=256):
    """Load one RGB tile and its label and return a model-ready pair.

    Args:
        rgb_path: path of the RGB PNG.
        label_path: path of the colour-coded label PNG.
        tile_id: accepted so the signature matches the dataset elements; not
            used here.
        split: augmentation is only applied when this equals ``'train'``.
        augment: enables random flips/rotation (train split only).
        tile_size: height and width both outputs are resized to.

    Returns:
        ``(rgb, label)`` where ``rgb`` is the float32 image resized to
        ``tile_size`` x ``tile_size`` and ``label`` is the one-hot encoded
        class map with ``NUM_CLASSES`` channels.
    """
    target_hw = [tile_size, tile_size]

    rgb = tf.image.convert_image_dtype(
        tf.image.decode_png(tf.io.read_file(rgb_path), channels=3),
        tf.float32,
    )
    label = decode_coloured_label(
        tf.image.decode_png(tf.io.read_file(label_path), channels=3)
    )

    if augment and split == 'train':
        rgb, label = augment_rgb_label(rgb, label)

    rgb = tf.image.resize(rgb, target_hw)

    # Nearest-neighbour keeps the label values valid class ids (no blending).
    label = tf.image.resize(tf.expand_dims(label, axis=-1), target_hw, method='nearest')
    class_ids = tf.cast(tf.reshape(label, target_hw), tf.int32)

    return rgb, tf.one_hot(class_ids, depth=NUM_CLASSES)

def parse_elevation(rgb_path, elev_path, slope_path, label_path, tile_id,
                    split='train', augment=False, tile_size=256):
    """Load RGB, elevation, slope and label for one tile.

    Args:
        rgb_path: path of the RGB PNG.
        elev_path: path of the elevation ``.npy`` file.
        slope_path: path of the slope ``.npy`` file.
        label_path: path of the colour-coded label PNG.
        tile_id: accepted so the signature matches the dataset elements; not
            used here.
        split: augmentation is only applied when this equals ``'train'``.
        augment: enables random flips/rotation (train split only).
        tile_size: height and width all outputs are resized to.

    Returns:
        ``(input_image, label)`` where ``input_image`` is RGB, elevation and
        slope concatenated along the last axis and ``label`` is the one-hot
        encoded class map with ``NUM_CLASSES`` channels.
    """
    target_hw = [tile_size, tile_size]

    def _load_plane(npy_path):
        # Read the array through numpy, declare it rank 2 (py_function loses
        # the static shape) and append a channel axis.
        plane = tf.py_function(_load_npy, [npy_path], tf.float32)
        plane.set_shape([None, None])
        return tf.expand_dims(plane, axis=-1)

    rgb = tf.image.convert_image_dtype(
        tf.image.decode_png(tf.io.read_file(rgb_path), channels=3),
        tf.float32,
    )
    label = decode_coloured_label(
        tf.image.decode_png(tf.io.read_file(label_path), channels=3)
    )
    elev = _load_plane(elev_path)
    slope = _load_plane(slope_path)

    if augment and split == 'train':
        rgb, elev, slope, label = augment_all(rgb, elev, slope, label)

    rgb = tf.image.resize(rgb, target_hw)
    elev = tf.image.resize(elev, target_hw)
    slope = tf.image.resize(slope, target_hw)

    # Nearest-neighbour keeps the label values valid class ids (no blending).
    label = tf.image.resize(tf.expand_dims(label, axis=-1), target_hw, method='nearest')
    class_ids = tf.cast(tf.reshape(label, target_hw), tf.int32)
    one_hot_label = tf.one_hot(class_ids, depth=NUM_CLASSES)

    stacked_inputs = tf.concat([rgb, elev, slope], axis=-1)
    return stacked_inputs, one_hot_label

def build_tf_dataset(df, image_dir, elev_dir, slope_dir, label_dir,
                     input_type='rgb', batch_size=32, split='train',
                     augment=False, shuffle=True, tile_size=256):
    """Build a batched, prefetched ``tf.data.Dataset`` of tiles and labels.

    Args:
        df: table with a ``tile_id`` column listing the tiles to load.
        image_dir: directory of the RGB PNGs.
        elev_dir: directory of the elevation ``.npy`` files.
        slope_dir: directory of the slope ``.npy`` files.
        label_dir: directory of the label PNGs.
        input_type: ``'rgb'`` for RGB-only inputs or ``'rgb_elev'`` for RGB
            concatenated with elevation and slope.
        batch_size: number of tiles per batch.
        split: forwarded to the parser; augmentation only runs for ``'train'``.
        augment: forwarded to the parser; enables random flips/rotation.
        shuffle: reshuffle the tiles on every iteration over the dataset.
        tile_size: height and width the tiles are resized to.

    Returns:
        A dataset yielding ``(inputs, one_hot_labels)`` batches.

    Raises:
        ValueError: if ``input_type`` is not one of the supported values.
    """
    image_paths, elev_paths, slope_paths, label_paths, tile_ids = load_image_paths(
        df, image_dir, elev_dir, slope_dir, label_dir)

    if input_type == 'rgb':
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths, tile_ids))

        def map_fn(rgb_path, label_path, tile_id):
            return parse_tile(rgb_path, label_path, tile_id, split, augment, tile_size)

    elif input_type == 'rgb_elev':
        dataset = tf.data.Dataset.from_tensor_slices(
            (image_paths, elev_paths, slope_paths, label_paths, tile_ids))

        def map_fn(rgb_path, elev_path, slope_path, label_path, tile_id):
            return parse_elevation(rgb_path, elev_path, slope_path, label_path, tile_id,
                                   split, augment, tile_size)

    else:
        raise ValueError(f"Unsupported input_type: {input_type}")

    if shuffle:
        # Shuffle the lightweight path tuples *before* decoding. A full-size
        # shuffle buffer placed after ``map`` would hold every decoded tile and
        # one-hot label in memory and delay the first batch until the whole
        # dataset had been parsed.
        dataset = dataset.shuffle(buffer_size=len(image_paths), reshuffle_each_iteration=True)

    dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset
