In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import cv2
import random
import matplotlib.pyplot as plt

# Human-readable class names; the position in the list is the class index.
CLASS_NAMES = [
    'Building',
    'Clutter',
    'Vegetation',
    'Water',
    'Background',
    'Car',
]

# Number of real classes (the extra "ignore" index below is not counted).
NUM_CLASSES = 6

# Label colour whose pixels carry no real class.
IGNORE_COLOR = (255, 0, 255)

# Label colour triple -> class index. Entries are listed in class-index order,
# and the plotting code in this file iterates the dict in that order.
# IGNORE_COLOR maps to index 6, one past the last real class.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,
    (145, 30, 180): 1,
    (60, 180, 75): 2,
    (245, 130, 48): 3,
    (255, 255, 255): 4,
    (0, 130, 200): 5,
    IGNORE_COLOR: 6,
}

# Class indices singled out as important: Building, Clutter, Water, Car.
IMPORTANT_CLASSES = [0, 1, 3, 5]

def compute_keep_prob_refined(row):
    """Return the probability with which a tile should be kept for training.

    The score rewards tiles that contain several non-background classes and
    tiles with a high share of Car, Water or Building pixels, then squashes
    the result through a sigmoid.

    Args:
        row: Mapping (e.g. a DataFrame row) indexed by keys of the form
            ``"<index>: <class name>"`` for every entry of ``CLASS_NAMES``,
            each holding that class's ratio for the tile.

    Returns:
        ``0.0`` when the Background ratio is at least 0.999, otherwise a
        value in (0, 1) produced by the sigmoid.
    """
    # Column keys look like "0: Building", "1: Clutter", ...
    ratios = {}
    for idx, name in enumerate(CLASS_NAMES):
        ratios[name] = row[f"{idx}: {name}"]

    # Tiles that are essentially all background are never kept.
    if ratios['Background'] >= 0.999:
        return 0.0

    foreground = [name for name in CLASS_NAMES if name != 'Background']

    # A class counts as "present" once it covers more than 1% of the tile.
    present = np.array([ratios[name] > 0.01 for name in foreground], dtype=np.float32)
    num_present = np.sum(present)

    fg_ratios = np.array([ratios[name] for name in foreground])
    std_dev = np.std(fg_ratios)

    # More classes present raises the score; an uneven spread lowers it.
    diversity_score = (num_present / 5.0) - (0.5 * std_dev)

    # Extra weight for the Car, Water and Building ratios.
    rare_score = (
        3.0 * ratios['Car'] +
        2.0 * ratios['Water'] +
        1.5 * ratios['Building']
    )

    final_score = diversity_score + rare_score

    # Sigmoid centred at 0.15 with steepness 8.
    return 1 / (1 + np.exp(-8 * (final_score - 0.15)))

def load_image_paths(df, image_dir, elevation_dir, label_dir):
    """Build one file-path record per tile listed in ``df``.

    Args:
        df: DataFrame with a ``tile_id`` column and, optionally, a
            ``keep_prob`` column.
        image_dir: Directory holding the ``<tile_id>-ortho.png`` files.
        elevation_dir: Directory holding the ``<tile_id>-elev.npy`` files.
        label_dir: Directory holding the ``<tile_id>-label.png`` files.

    Returns:
        List of ``(ortho_path, elev_path, label_path, tile_id, keep_prob)``
        tuples in DataFrame order. ``keep_prob`` is 1.0 for every tile when
        the column is missing.
    """
    tile_ids = df['tile_id'].tolist()

    if 'keep_prob' in df.columns:
        probs = df['keep_prob'].tolist()
    else:
        probs = [1.0] * len(tile_ids)

    def _record(tile_id, prob):
        # File names are the tile id plus a fixed per-modality suffix.
        ortho = os.path.join(image_dir, tile_id + "-ortho.png")
        elev = os.path.join(elevation_dir, tile_id + "-elev.npy")
        label = os.path.join(label_dir, tile_id + "-label.png")
        return (ortho, elev, label, tile_id, prob)

    return [_record(tile_id, prob) for tile_id, prob in zip(tile_ids, probs)]

def _load_npy_impl(filepath):
    """Load a ``.npy`` file and return its contents as a float32 array.

    Args:
        filepath: Object whose ``.numpy()`` returns the path as bytes
            (``load_elev_npy`` passes it through ``tf.py_function``).
    """
    path_str = filepath.numpy().decode()
    array = np.load(path_str)
    return array.astype(np.float32)

def load_elev_npy(filepath):
    """Load an elevation ``.npy`` file as a float32 tensor with a channel axis.

    Args:
        filepath: String tensor holding the path of the ``.npy`` file.

    Returns:
        float32 tensor of rank 3; the loaded array is asserted to be rank 2
        and a trailing axis of size 1 is appended.
    """
    # np.load is plain Python, so it has to run through py_function.
    loaded = tf.py_function(func=_load_npy_impl, inp=[filepath], Tout=tf.float32)
    # py_function loses static shape information; restore the known rank.
    two_dim = tf.ensure_shape(loaded, [None, None])
    return tf.expand_dims(two_dim, axis=-1)

def augment_image(rgb, elev, label):
    """Randomly flip an (rgb, elevation, label) triple, keeping them aligned.

    A horizontal flip and then a vertical flip are each applied when an
    independent uniform draw exceeds 0.5. Whenever a flip is applied, it is
    applied to all three tensors.

    Args:
        rgb: Image tensor.
        elev: Elevation tensor.
        label: Label tensor.

    Returns:
        Tuple ``(rgb, elev, label)`` after the random flips.
    """
    def _maybe_flip(flip_fn, tensors):
        # A single draw decides the flip for the whole triple.
        should_flip = tf.random.uniform([]) > 0.5
        return tf.cond(
            should_flip,
            lambda: tuple(flip_fn(t) for t in tensors),
            lambda: tensors,
        )

    tensors = (rgb, elev, label)
    tensors = _maybe_flip(tf.image.flip_left_right, tensors)
    tensors = _maybe_flip(tf.image.flip_up_down, tensors)
    return tensors

def parse_tile(rgb_path, elev_path, label_path, tile_id, input_type='rgb', split='train', augment=False, tile_size=256):
    """Load one tile from disk and turn it into a ``(input, one-hot label)`` pair.

    Args:
        rgb_path: Path of the 3-channel PNG image.
        elev_path: Path of the elevation ``.npy`` file; only read when
            ``input_type`` is ``"2ch"`` or ``"rgb_elev"``.
        label_path: Path of the single-channel label PNG holding class indices.
        tile_id: Tile identifier; not used by this function.
        input_type: One of ``"1ch"`` (grayscale), ``"2ch"`` (grayscale +
            elevation), ``"rgb"`` or ``"rgb_elev"`` (RGB + elevation).
        split: Augmentation is only applied when this is ``'train'``.
        augment: Whether to apply random flips (training split only).
        tile_size: Side length both outputs are resized to.

    Returns:
        ``(input_image, label_onehot)``: the model input resized with
        ``tf.image.resize``'s default method, and the ``NUM_CLASSES``-deep
        one-hot label resized with nearest-neighbour interpolation.

    Raises:
        ValueError: If ``input_type`` is not one of the supported values.
    """
    # Image: PNG bytes -> 3-channel tensor -> float32.
    rgb_bytes = tf.io.read_file(rgb_path)
    rgb = tf.image.convert_image_dtype(
        tf.image.decode_png(rgb_bytes, channels=3),
        tf.float32,
    )

    # Label: PNG bytes -> single-channel int32 class indices.
    label_bytes = tf.io.read_file(label_path)
    label = tf.cast(tf.image.decode_png(label_bytes, channels=1), tf.int32)

    needs_elevation = input_type in ["2ch", "rgb_elev"]
    if needs_elevation:
        elev = load_elev_npy(elev_path)
    else:
        # Zero placeholder shaped like the label; it only travels through
        # augmentation and never reaches the output for these input types.
        elev = tf.zeros_like(tf.cast(label, tf.float32))

    if augment and split == 'train':
        rgb, elev, label = augment_image(rgb, elev, label)

    # Pixels labelled 255 are rewritten to class 0 before one-hot encoding.
    is_255 = tf.equal(label, 255)
    label_clean = tf.where(is_255, tf.constant(0, dtype=tf.int32), label)
    class_indices = tf.squeeze(label_clean, axis=-1)
    label_onehot = tf.one_hot(class_indices, depth=NUM_CLASSES)

    # Assemble the model input for the requested channel layout.
    if input_type == "rgb":
        input_image = rgb
    elif input_type == "rgb_elev":
        input_image = tf.concat([rgb, elev], axis=-1)
    elif input_type == "1ch":
        input_image = tf.image.rgb_to_grayscale(rgb)
    elif input_type == "2ch":
        gray = tf.image.rgb_to_grayscale(rgb)
        input_image = tf.concat([gray, elev], axis=-1)
    else:
        raise ValueError(f"Unknown input_type: {input_type}")

    target_size = [tile_size, tile_size]
    input_image = tf.image.resize(input_image, target_size)
    # Nearest-neighbour keeps the label strictly one-hot after resizing.
    label_onehot = tf.image.resize(
        label_onehot,
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

    return input_image, label_onehot

def build_tf_dataset(df, image_dir, elevation_dir, label_dir,
                     input_type='rgb', batch_size=32, split='train',
                     augment=False, shuffle=True, tile_size=256):
    """Build a batched ``tf.data`` pipeline of ``(input, one-hot label)`` pairs.

    Args:
        df: DataFrame with a ``tile_id`` column and, optionally, ``keep_prob``.
        image_dir: Directory with the ortho PNG files.
        elevation_dir: Directory with the elevation ``.npy`` files.
        label_dir: Directory with the label PNG files.
        input_type: Channel layout forwarded to ``parse_tile``.
        batch_size: Number of tiles per batch.
        split: ``'train'`` enables per-epoch reshuffling and stochastic
            tile filtering by ``keep_prob``.
        augment: Forwarded to ``parse_tile``.
        shuffle: Whether to shuffle the tiles before filtering and parsing.
        tile_size: Side length forwarded to ``parse_tile``.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(input_image, label_onehot)``.
    """
    paths = load_image_paths(df, image_dir, elevation_dir, label_dir)

    # tf.data treats a Python list as ONE tensor, and a list of mixed
    # (str, str, str, str, float) records cannot be converted to a tensor.
    # Transpose the records into homogeneous columns and pass them as a tuple,
    # so every dataset element is a 5-tuple of scalar tensors. Explicit dtypes
    # keep this working for an empty DataFrame.
    columns = (
        tf.constant([record[0] for record in paths], dtype=tf.string),
        tf.constant([record[1] for record in paths], dtype=tf.string),
        tf.constant([record[2] for record in paths], dtype=tf.string),
        tf.constant([record[3] for record in paths], dtype=tf.string),
        tf.constant([record[4] for record in paths], dtype=tf.float32),
    )
    dataset = tf.data.Dataset.from_tensor_slices(columns)

    if shuffle:
        dataset = dataset.shuffle(
            # shuffle() rejects a zero-sized buffer.
            buffer_size=max(len(paths), 1),
            reshuffle_each_iteration=(split == 'train'),
        )

    if split == 'train':
        def filter_fn(rgb_path, elev_path, label_path, tile_id, keep_prob):
            # Keep each tile with probability keep_prob; redrawn on every pass.
            rand = tf.random.uniform([], 0, 1)
            return rand < keep_prob
        dataset = dataset.filter(filter_fn)

    def map_fn(rgb_path, elev_path, label_path, tile_id, keep_prob):
        # keep_prob has done its job in the filter and is dropped here.
        return parse_tile(rgb_path, elev_path, label_path, tile_id,
                          input_type, split, augment, tile_size)

    dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset

def plot_class_distribution_from_dataset(dataset, num_batches=1000, title="Sampled Training Class Distribution"):
    """Plot the per-class pixel proportions seen in the first batches of a dataset.

    Args:
        dataset: Dataset supporting ``take`` whose elements are
            ``(inputs, labels)`` pairs; ``labels`` must offer ``.numpy()`` and
            hold per-class scores along the last axis (argmax picks the class).
        num_batches: Maximum number of batches to inspect.
        title: Title of the bar chart.

    Returns:
        None. The chart is displayed with ``plt.show()``.
    """
    total_counts = np.zeros(NUM_CLASSES, dtype=np.int64)

    for _, y_batch in dataset.take(num_batches):
        labels = np.argmax(y_batch.numpy(), axis=-1)
        # One vectorised histogram per batch instead of one pass per class.
        # The slice drops any index >= NUM_CLASSES, matching the per-class loop
        # this replaces.
        counts = np.bincount(labels.ravel(), minlength=NUM_CLASSES)
        total_counts += counts[:NUM_CLASSES]

    total_pixels = total_counts.sum()
    if total_pixels > 0:
        pixel_props = total_counts / total_pixels
    else:
        # Empty dataset: show zeros rather than dividing by zero.
        pixel_props = np.zeros(NUM_CLASSES, dtype=np.float64)

    class_labels = [f"{i}: {CLASS_NAMES[i]}" for i in range(NUM_CLASSES)]

    # Look colours up by the class id stored in COLOR_TO_CLASS instead of
    # relying on the dict's insertion order matching the class order.
    colour_map = {
        cls: np.array(colour) / 255.0
        for colour, cls in COLOR_TO_CLASS.items()
        if cls < NUM_CLASSES
    }
    colours_ordered = [colour_map[i] for i in range(NUM_CLASSES)]

    plt.figure(figsize=(10, 5))
    bars = plt.bar(class_labels, pixel_props, color=colours_ordered, edgecolor='black')
    plt.title(title)
    plt.xlabel("Class")
    plt.ylabel("Proportion")
    plt.grid(True, axis='y', linestyle='--', alpha=0.5)

    # Annotate each bar with its percentage.
    for bar, prop in zip(bars, pixel_props):
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f"{prop:.2%}",
                 ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.show()