In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import cv2
import random
import matplotlib.pyplot as plt

# Display names, positioned by class id (index 0 is class 0, and so on).
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']

NUM_CLASSES = 6
IMPORTANT_CLASSES = [0, 1, 3, 5]
IGNORE_COLOR = (255, 0, 255)

# Class id -> RGB colour used in the label images, for ids 0..5.
CLASS_TO_COLOR = dict(enumerate([
    (230, 25, 75),
    (145, 30, 180),
    (60, 180, 75),
    (245, 130, 48),
    (255, 255, 255),
    (0, 130, 200),
]))

# Reverse lookup (RGB colour -> class id), in the same order as above.
COLOR_TO_CLASS = {rgb: class_id for class_id, rgb in CLASS_TO_COLOR.items()}
# The ignore colour is given id 6, which lies outside range(NUM_CLASSES).
COLOR_TO_CLASS[IGNORE_COLOR] = 6

def normalise_elevation(elev, valid_min=-50, valid_max=500, raster_nodata=-32767):
    """Rescale an elevation raster linearly into the range [0, 1].

    Values are clipped to ``[valid_min, valid_max]`` and then mapped so that
    ``valid_min`` becomes 0 and ``valid_max`` becomes 1. Cells equal to
    ``raster_nodata`` (and any NaN cells) come out as 0.

    Args:
        elev: NumPy array of elevation values.
        valid_min: Lower clipping bound, mapped to 0.
        valid_max: Upper clipping bound, mapped to 1.
        raster_nodata: Sentinel value marking cells without data.

    Returns:
        A float32 array with the same shape as ``elev``.
    """
    span = valid_max - valid_min
    # Turn nodata cells into NaN first: NaN passes through the clip untouched
    # and is replaced by 0 at the very end.
    masked = np.where(elev == raster_nodata, np.nan, elev)
    scaled = (np.clip(masked, valid_min, valid_max) - valid_min) / span
    return np.nan_to_num(scaled, nan=0.0).astype(np.float32)

def standardise_elevation(elev, raster_nodata=-32767):
    """Z-score an elevation raster over its valid cells and add a channel axis.

    The mean and standard deviation are computed only from cells that differ
    from ``raster_nodata``. Those cells are replaced by
    ``(value - mean) / std``; nodata cells stay at 0. When the valid cells
    have zero spread, or when there are no valid cells at all, the result is
    all zeros.

    Args:
        elev: NumPy array of elevation values.
        raster_nodata: Sentinel value marking cells without data.

    Returns:
        A float32 array shaped like ``elev`` with one extra trailing axis of
        size 1.
    """
    valid_mask = elev != raster_nodata
    valid_elev = elev[valid_mask]
    standardised = np.zeros_like(elev, dtype=np.float32)

    # A tile made entirely of nodata has nothing to take statistics from;
    # calling mean()/std() on the empty selection would emit RuntimeWarnings
    # and produce NaN. The output is all zeros in that case regardless.
    if valid_elev.size:
        mean = valid_elev.mean()
        std = valid_elev.std()
        # A constant surface (std == 0) is left at zero instead of dividing
        # by zero.
        if std > 0:
            standardised[valid_mask] = (valid_elev - mean) / std

    return np.expand_dims(standardised, axis=-1)

def _load_npy_elev_and_standardise(path):
    """Load a ``.npy`` elevation raster and return its standardised form.

    Args:
        path: File path as ``bytes``; it is decoded as UTF-8 before loading.
            (``parse_elevation`` calls this through ``tf.numpy_function``.)

    Returns:
        Whatever ``standardise_elevation`` returns for the loaded array.
    """
    npy_path = path.decode("utf-8")
    raw_elev = np.load(npy_path)
    return standardise_elevation(raw_elev)

def decode_coloured_label(label_rgb):
    """Convert a colour-coded label image into a map of integer class ids.

    Every pixel is compared against the colours in ``COLOR_TO_CLASS`` and
    replaced by the id of the colour it matches exactly. A pixel that matches
    none of them receives the id registered for ``IGNORE_COLOR`` rather than
    silently falling into class 0.

    Args:
        label_rgb: Tensor of shape ``[height, width, 3]`` holding RGB values;
            it is cast to ``uint8`` before comparison.

    Returns:
        An ``int32`` tensor of shape ``[height, width]`` with one class id per
        pixel.
    """
    label_rgb = tf.cast(label_rgb, tf.uint8)
    label_flat = tf.reshape(label_rgb, [-1, 3])
    colors = tf.constant(list(COLOR_TO_CLASS.keys()), dtype=tf.uint8)
    class_ids = tf.constant(list(COLOR_TO_CLASS.values()), dtype=tf.int32)

    # matches[p, c] is True when pixel p equals palette colour c on all three
    # channels.
    matches = tf.reduce_all(tf.equal(tf.expand_dims(label_flat, 1), colors), axis=2)
    indices = tf.argmax(tf.cast(matches, tf.int32), axis=1)
    mapped = tf.gather(class_ids, indices)

    # argmax over an all-False row yields index 0, which would label an
    # unrecognised colour as the first palette entry. Send those pixels to
    # the ignore id instead.
    has_match = tf.reduce_any(matches, axis=1)
    ignore_ids = tf.fill(tf.shape(mapped), COLOR_TO_CLASS[IGNORE_COLOR])
    mapped = tf.where(has_match, mapped, ignore_ids)

    return tf.reshape(mapped, tf.shape(label_rgb)[:2])

def augment_image(rgb, elev, label):
    """Apply the same random flips to an image, its elevation and its label.

    Two independent draws from ``tf.random.uniform`` decide, each when the
    draw exceeds 0.5, whether to flip left-right and whether to flip up-down.
    Every flip is applied to all three inputs so they stay aligned.

    Args:
        rgb: Image tensor accepted by the ``tf.image.flip_*`` functions.
        elev: Elevation tensor, flipped together with ``rgb``.
        label: Label map. A rank-2 label gets a temporary trailing axis for
            the flips. Any other rank is used as-is and must already end in
            an axis of size 1, because that axis is squeezed before
            returning.

    Returns:
        Tuple ``(rgb, elev, label)`` after the flips, with the label's last
        axis squeezed away.
    """
    label = tf.convert_to_tensor(label)
    static_rank = label.shape.rank
    if static_rank is None:
        # Rank is only known at run time, so fall back to a dynamic check.
        label = tf.expand_dims(label, axis=-1) if tf.rank(label) == 2 else label
    elif static_rank == 2:
        # Decide in Python when the rank is known statically. A tensor-valued
        # condition is traced as a tf.cond whose branches differ in rank,
        # which erases the label's static shape for everything downstream.
        label = tf.expand_dims(label, axis=-1)

    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_left_right(rgb)
        elev = tf.image.flip_left_right(elev)
        label = tf.image.flip_left_right(label)
    if tf.random.uniform([]) > 0.5:
        rgb = tf.image.flip_up_down(rgb)
        elev = tf.image.flip_up_down(elev)
        label = tf.image.flip_up_down(label)

    label = tf.squeeze(label, axis=-1)
    return rgb, elev, label

def parse_elevation(rgb_path, elev_path, label_path, tile_id, split='train', augment=False, tile_size=256):
    """Load one RGB + elevation tile and its label as a training pair.

    The RGB PNG is decoded and converted to float32, the elevation ``.npy``
    file is loaded and standardised, and the colour-coded label PNG is
    decoded to class ids. RGB and elevation are stacked into a 4-channel
    image and resized; the label is resized with nearest-neighbour sampling
    and one-hot encoded.

    Args:
        rgb_path: Path of the RGB PNG file.
        elev_path: Path of the elevation ``.npy`` file.
        label_path: Path of the colour-coded label PNG file.
        tile_id: Not used by this function.
        split: Augmentation is applied only when this equals ``'train'``.
        augment: Whether to apply random flips (``'train'`` split only).
        tile_size: Height and width of the returned tensors.

    Returns:
        Tuple ``(input_image, label_onehot)``: a ``[tile_size, tile_size, 4]``
        image and a ``[tile_size, tile_size, NUM_CLASSES]`` one-hot label.
        Class ids outside ``range(NUM_CLASSES)`` encode as all-zero vectors.
    """
    rgb = tf.io.read_file(rgb_path)
    rgb = tf.image.decode_png(rgb, channels=3)
    rgb = tf.image.convert_image_dtype(rgb, tf.float32)

    label = tf.io.read_file(label_path)
    label = tf.image.decode_png(label, channels=3)
    label = decode_coloured_label(label)
    label.set_shape([None, None])  # Ensure shape is known for resizing

    # tf.numpy_function returns a tensor without static shape information.
    elev = tf.numpy_function(_load_npy_elev_and_standardise, [elev_path], tf.float32)
    elev.set_shape([None, None, 1])

    if split == 'train' and augment:
        rgb, elev, label = augment_image(rgb, elev, label)
        # Re-assert the static rank after augmentation: tf.image.resize
        # rejects tensors whose rank is unknown at trace time, and
        # conditional ops inside the augmentation can drop that information.
        label.set_shape([None, None])

    input_image = tf.concat([rgb, elev], axis=-1)
    input_image = tf.image.resize(input_image, [tile_size, tile_size])
    # Nearest-neighbour keeps the label values as valid integer class ids.
    label = tf.image.resize(label[..., tf.newaxis], [tile_size, tile_size], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    label = tf.squeeze(label, axis=-1)
    label_onehot = tf.one_hot(label, depth=NUM_CLASSES)
    label_onehot.set_shape([tile_size, tile_size, NUM_CLASSES])

    return input_image, label_onehot

def parse_tile(rgb_path, label_path, tile_id, split='train', augment=False, tile_size=256):
    """Load one RGB tile and its label as a training pair.

    The RGB PNG is decoded and converted to float32, and the colour-coded
    label PNG is decoded to class ids. The image is resized; the label is
    resized with nearest-neighbour sampling and one-hot encoded.

    Args:
        rgb_path: Path of the RGB PNG file.
        label_path: Path of the colour-coded label PNG file.
        tile_id: Not used by this function.
        split: Augmentation is applied only when this equals ``'train'``.
        augment: Whether to apply random flips (``'train'`` split only).
        tile_size: Height and width of the returned tensors.

    Returns:
        Tuple ``(input_image, label_onehot)``: a ``[tile_size, tile_size, 3]``
        image and a ``[tile_size, tile_size, NUM_CLASSES]`` one-hot label.
        Class ids outside ``range(NUM_CLASSES)`` encode as all-zero vectors.
    """
    rgb = tf.io.read_file(rgb_path)
    rgb = tf.image.decode_png(rgb, channels=3)
    rgb = tf.image.convert_image_dtype(rgb, tf.float32)

    label = tf.io.read_file(label_path)
    label = tf.image.decode_png(label, channels=3)
    label = decode_coloured_label(label)
    label.set_shape([None, None])  # Fix missing shape

    if split == 'train' and augment:
        # augment_image expects an elevation input; pass a single-channel
        # placeholder of zeros and discard what comes back for it.
        rgb, _, label = augment_image(rgb, tf.zeros_like(rgb[..., :1]), label)
        # Re-assert the static rank after augmentation: tf.image.resize
        # rejects tensors whose rank is unknown at trace time, and
        # conditional ops inside the augmentation can drop that information.
        label.set_shape([None, None])

    input_image = tf.image.resize(rgb, [tile_size, tile_size])
    # Nearest-neighbour keeps the label values as valid integer class ids.
    label = tf.image.resize(label[..., tf.newaxis], [tile_size, tile_size], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    label = tf.squeeze(label, axis=-1)
    label_onehot = tf.one_hot(label, depth=NUM_CLASSES)
    label_onehot.set_shape([tile_size, tile_size, NUM_CLASSES])

    return input_image, label_onehot
