In [None]:
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.utils import Sequence
import cv2
import random
import matplotlib.pyplot as plt


def test_data_sanity():
    """Print a marker confirming that this data notebook's code has been loaded."""
    marker = "✅ from data.ipynb"
    print(marker)

# Label-image pixel colour -> integer class index (index = position in the
# tuple below). Keys are compared against label images after
# cv2.cvtColor(..., COLOR_BGR2RGB), i.e. as (R, G, B) tuples.
# NOTE(review): confirm these tuples are RGB and not copied from a BGR palette;
# a channel-order mismatch would make the generator skip tiles containing them.
COLOR_TO_CLASS = {
    colour: class_index
    for class_index, colour in enumerate((
        (75, 25, 230),    # 0: BUILDING
        (180, 30, 145),   # 1: CLUTTER
        (75, 180, 60),    # 2: VEGETATION
        (48, 130, 245),   # 3: WATER
        (255, 255, 255),  # 4: GROUND
        (200, 130, 0),    # 5: CAR
    ))
}

class StreamingDataGenerator(Sequence):
    """Keras ``Sequence`` that streams ``(inputs, one_hot_labels)`` batches from disk.

    Tile stems are discovered by listing ``image_dir`` for files ending in
    ``-ortho.png``. For a stem ``<name>`` the generator reads
    ``image_dir/<name>-ortho.png``, ``elevation_dir/<name>-elev.npy`` and
    ``label_dir/<name>-label.png``. Label images are decoded pixel-by-pixel via
    the module-level ``COLOR_TO_CLASS`` mapping.

    Args:
        image_dir: Directory containing the ``*-ortho.png`` images.
        elevation_dir: Directory containing the ``*-elev.npy`` arrays.
        label_dir: Directory containing the ``*-label.png`` label images.
        batch_size: Tiles requested per batch. A trailing partial batch is
            dropped (see ``__len__``).
        input_type: Model-input layout: ``'1ch'``, ``'2ch'``, ``'rgb'`` or
            ``'rgb_elevation'``.
        num_classes: Depth of the one-hot label encoding.
        shuffle: Shuffle the tile order on construction and after every epoch.
        preview: Keyword-only. When True (the historical default), show a
            matplotlib preview of the first loaded tile of every batch; pass
            False for training runs.
        **kwargs: Forwarded to the Keras base class (on Keras 3 e.g.
            ``workers``, ``use_multiprocessing``, ``max_queue_size``).

    Raises:
        ValueError: If ``input_type`` is not supported.
    """

    _INPUT_TYPES = ('1ch', '2ch', 'rgb', 'rgb_elevation')
    _IMAGE_SUFFIX = '-ortho.png'
    # Sentinel for label pixels whose colour is not in COLOR_TO_CLASS.
    _UNKNOWN_CLASS = -1

    def __init__(self, image_dir, elevation_dir, label_dir, batch_size=32, input_type='rgb',
                 num_classes=6, shuffle=True, *, preview=True, **kwargs):
        # Keras 3's PyDataset (which Sequence aliases there) expects this call to
        # set up its worker options; with no kwargs it is harmless on older Keras.
        super().__init__(**kwargs)
        self.image_dir = image_dir
        self.elevation_dir = elevation_dir
        self.label_dir = label_dir
        self.batch_size = batch_size
        self.input_type = input_type
        self.shuffle = shuffle
        self.num_classes = num_classes
        self.preview = preview
        # Fail fast on a bad configuration instead of per tile later on.
        self._check_input_type()
        suffix = self._IMAGE_SUFFIX
        # Sorted so the unshuffled order does not depend on os.listdir's
        # arbitrary, filesystem-specific ordering. Slicing strips only the
        # trailing suffix (str.replace would strip every occurrence).
        self.tile_list = sorted(
            f[:-len(suffix)] for f in os.listdir(image_dir) if f.endswith(suffix)
        )
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.tile_list) // self.batch_size

    def __getitem__(self, index):
        """Load batch ``index`` as a pair of float32 arrays ``(x, y)``.

        Tiles with a missing file, a label colour not in ``COLOR_TO_CLASS``, or
        any load/merge error are skipped with a printed message, so a batch can
        hold fewer than ``batch_size`` samples.

        Raises:
            ValueError: If ``self.input_type`` is not supported.
        """
        # Checked outside the per-tile try/except below so a configuration error
        # propagates instead of being swallowed as a per-tile failure.
        self._check_input_type()

        batch_x = []
        batch_y = []
        start = index * self.batch_size
        selected = self.tile_list[start:start + self.batch_size]
        previewed = False

        for file_name in selected:
            try:
                sample = self._load_sample(file_name)
            except Exception as e:  # best-effort: one bad tile must not abort the batch
                print(f"❌ Error processing tile {file_name}: {str(e)}")
                continue
            if sample is None:
                continue
            merged, label_onehot, rgb_image, label = sample

            if self.preview and not previewed:
                previewed = True
                try:
                    self._show_preview(rgb_image, label)
                except Exception as e:  # a display problem must not drop training data
                    print(f"⚠️ Preview failed for tile {file_name}: {e}")

            # NOTE(review): this also divides any elevation channels by 255;
            # confirm that is the intended scaling for elevation data.
            batch_x.append(merged.astype(np.float32) / 255.0)
            batch_y.append(label_onehot.astype(np.float32))

        # NOTE(review): if every tile in the slice was skipped these arrays are
        # empty (shape (0,)), which Keras will likely reject.
        return np.array(batch_x, dtype=np.float32), np.array(batch_y, dtype=np.float32)

    def on_epoch_end(self):
        """Reshuffle the tile order in place when ``shuffle`` is enabled."""
        if self.shuffle:
            random.shuffle(self.tile_list)

    def _check_input_type(self):
        """Raise ValueError if ``self.input_type`` is not a supported layout."""
        if self.input_type not in self._INPUT_TYPES:
            raise ValueError(f"Invalid input_type: {self.input_type}")

    def _load_sample(self, file_name):
        """Read one tile; return ``(merged, label_onehot, rgb_image, label)`` or ``None`` to skip it."""
        rgb_path = os.path.join(self.image_dir, file_name + "-ortho.png")
        elev_path = os.path.join(self.elevation_dir, file_name + "-elev.npy")
        label_path = os.path.join(self.label_dir, file_name + "-label.png")

        for kind, path in (("RGB", rgb_path), ("Elevation", elev_path), ("Label", label_path)):
            if not os.path.exists(path):
                print(f"❌ {kind} file not found: {path}")
                return None

        rgb_image = self._read_rgb(rgb_path)
        elevation_data = np.load(elev_path)
        if elevation_data.ndim == 2:
            # Add a trailing channel axis so it can be concatenated with image channels.
            elevation_data = np.expand_dims(elevation_data, axis=-1)
        label_rgb = self._read_rgb(label_path)

        label = self._label_to_class_index(label_rgb)
        if np.any(label == self._UNKNOWN_CLASS):
            print(f"⚠️ Unknown colours found in label at {label_path}. Skipping.")
            return None

        label_onehot = tf.keras.utils.to_categorical(label, num_classes=self.num_classes)
        merged = self._merge_channels(rgb_image, elevation_data)
        return merged, label_onehot, rgb_image, label

    @staticmethod
    def _read_rgb(path):
        """Read an image with OpenCV and convert it from BGR to RGB channel order.

        Raises:
            IOError: If OpenCV cannot read or decode the file.
        """
        image = cv2.imread(path)
        if image is None:  # cv2.imread reports unreadable files by returning None
            raise IOError(f"cv2.imread could not read {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _label_to_class_index(self, label_rgb):
        """Map each pixel colour of ``label_rgb`` to its class index via ``COLOR_TO_CLASS``.

        Pixels matching no entry are set to ``_UNKNOWN_CLASS``.
        """
        h, w = label_rgb.shape[:2]
        # int16 rather than uint8 so the -1 sentinel is representable; a uint8
        # buffer wraps -1 to 255 or rejects it, depending on the NumPy version.
        label = np.full((h, w), self._UNKNOWN_CLASS, dtype=np.int16)
        for color, idx in COLOR_TO_CLASS.items():
            label[np.all(label_rgb == color, axis=-1)] = idx
        return label

    def _merge_channels(self, rgb_image, elevation_data):
        """Build the model-input array for ``self.input_type``.

        Raises:
            ValueError: If ``self.input_type`` is not supported.
        """
        # NOTE(review): '1ch'/'2ch' use channel 0 of the RGB image (red), not a
        # luminance grayscale; confirm this is intended.
        first_channel = rgb_image[:, :, :1]
        if self.input_type == '1ch':
            return first_channel
        if self.input_type == '2ch':
            return np.concatenate([first_channel, elevation_data], axis=-1)
        if self.input_type == 'rgb':
            return rgb_image
        if self.input_type == 'rgb_elevation':
            return np.concatenate([rgb_image, elevation_data], axis=-1)
        raise ValueError(f"Invalid input_type: {self.input_type}")

    @staticmethod
    def _show_preview(rgb_image, label):
        """Display ``rgb_image`` side by side with its class-index ``label`` map."""
        fig, axs = plt.subplots(1, 2, figsize=(10, 4))
        axs[0].imshow(rgb_image)
        axs[0].set_title("RGB Preview")
        axs[0].axis("off")
        axs[1].imshow(label)
        axs[1].set_title("Label Preview")
        axs[1].axis("off")
        plt.tight_layout()
        plt.show()
        # Close explicitly: non-interactive backends keep figures alive after
        # show(), and one figure per batch would otherwise accumulate.
        plt.close(fig)

    # rgb_to_class_index is intentionally commented out
    # def rgb_to_class_index(self, label_rgb):
    #     h, w, _ = label_rgb.shape
    #     label = np.full((h, w), -1, dtype=np.int32)
    #     for color, idx in COLOR_TO_CLASS.items():
    #         mask = np.all(label_rgb == color, axis=-1)
    #         label[mask] = idx
    #     if np.any(label == -1):
    #         raise ValueError("❌ Unknown RGB values in label mask.")
    #     return label
