# In [None]:

import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.utils import Sequence
import cv2
import random

def test_data_sanity():
    """Print a fixed marker line to standard output and return ``None``."""
    marker = "✅ from data.ipynb"
    print(marker)

# Lookup table from a label-mask colour triplet to its integer class index.
# The index of each class is its position in the tuple below (0..5), so the
# order of the entries is significant.
# NOTE(review): StreamingDataGenerator.__data_generation converts label images
# from BGR to RGB before looking pixels up here, so these triplets must be in
# RGB channel order -- confirm against the actual label PNGs.
COLOR_TO_CLASS = {
    color: class_index
    for class_index, color in enumerate(
        (
            (75, 25, 230),     # 0: BUILDING
            (180, 30, 145),    # 1: CLUTTER
            (75, 180, 60),     # 2: VEGETATION
            (48, 130, 245),    # 3: WATER
            (255, 255, 255),   # 4: GROUND
            (200, 130, 0),     # 5: CAR
        )
    )
}

class StreamingDataGenerator(Sequence):
    """Keras ``Sequence`` that streams image / elevation / label tiles from disk.

    Tiles are discovered by listing ``image_dir`` for files ending in
    ``-ortho.png``.  For a tile with base name ``<tile>`` the generator reads
    ``<image_dir>/<tile>-ortho.png``, ``<elevation_dir>/<tile>-elev.npy`` and
    ``<label_dir>/<tile>-label.png``.

    Every epoch consists of exactly one batch (``__len__`` returns 1) that is
    drawn at random, without replacement, from the tile list.

    Args:
        image_dir: Directory holding the ``*-ortho.png`` images.
        elevation_dir: Directory holding the ``*-elev.npy`` arrays.
        label_dir: Directory holding the ``*-label.png`` masks.
        batch_size: Maximum number of tiles per batch.  If fewer tiles exist,
            the batch contains all of them.
        input_type: One of ``'1ch'``, ``'2ch'``, ``'rgb'``, ``'rgb_elevation'``.
            Only consulted by the private ``__data_generation`` helper;
            ``__getitem__`` always returns the RGB image concatenated with the
            elevation channel regardless of this setting.
        num_classes: Depth of the one-hot label encoding.
        shuffle: Whether ``on_epoch_end`` shuffles the tile list in place.
        **kwargs: Forwarded to the Keras base-class constructor.
    """

    # Spatial size (height, width) that __data_generation requires of a tile.
    _TILE_SHAPE = (512, 512)

    def __init__(self, image_dir, elevation_dir, label_dir, batch_size=32, input_type='rgb', num_classes=6, shuffle=True, **kwargs):
        # The Keras base class expects its constructor to be called so that
        # its own options are initialised.
        super().__init__(**kwargs)
        self.image_dir = image_dir
        self.elevation_dir = elevation_dir
        self.label_dir = label_dir
        self.batch_size = batch_size
        self.input_type = input_type
        self.shuffle = shuffle
        self.num_classes = num_classes

        # Strip only the trailing suffix (str.replace would also remove the
        # marker from the middle of a name) and sort because os.listdir
        # returns entries in arbitrary order.
        suffix = '-ortho.png'
        self.tile_list = [
            name[:-len(suffix)]
            for name in sorted(os.listdir(image_dir))
            if name.endswith(suffix)
        ]
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (always 1)."""
        return 1  # We generate a new batch each epoch from shuffled list

    @staticmethod
    def _read_image(path, *flags):
        """Read ``path`` with ``cv2.imread``; raise ``OSError`` if it fails."""
        image = cv2.imread(path, *flags)
        if image is None:
            # cv2.imread signals a missing or undecodable file by returning
            # None; fail here with the path instead of deep inside OpenCV.
            raise OSError("Could not read image file: {}".format(path))
        return image

    @staticmethod
    def _load_elevation(path):
        """Load an elevation ``.npy`` file as an ``(H, W, 1)`` array."""
        elevation = np.load(path)
        if elevation.ndim == 3:
            # Keep only the first channel of arrays stored as (H, W, C).
            elevation = elevation[:, :, 0]
        return np.expand_dims(elevation, axis=-1)

    def __getitem__(self, index):
        """Return one randomly sampled batch ``(x, y)``.

        ``index`` is ignored: tiles are drawn with ``np.random.choice``
        without replacement on every call.

        Returns:
            A tuple ``(x, y)`` of ``float32`` arrays.  ``x`` stacks, per tile,
            the RGB image concatenated with the elevation channel along the
            last axis, and the whole array (elevation included) is divided by
            255.  ``y`` stacks the one-hot encoding of the label image read as
            a single-channel grayscale image.

        Raises:
            ValueError: If no tiles were found in ``image_dir``.
            OSError: If an image or label file cannot be read.
        """
        if not self.tile_list:
            raise ValueError(
                "No '*-ortho.png' tiles found in {!r}".format(self.image_dir)
            )

        batch_x = []
        batch_y = []

        # Sampling without replacement cannot return more items than exist,
        # so clamp the request to the number of available tiles.
        sample_size = min(self.batch_size, len(self.tile_list))
        selected = np.random.choice(self.tile_list, sample_size, replace=False)

        for file_name in selected:
            # Load RGB (OpenCV decodes to BGR, so swap the channel order).
            rgb_path = os.path.join(self.image_dir, file_name + "-ortho.png")
            rgb_image = cv2.cvtColor(self._read_image(rgb_path), cv2.COLOR_BGR2RGB)

            # Load elevation as a trailing single channel.
            elev_path = os.path.join(self.elevation_dir, file_name + "-elev.npy")
            elevation_data = self._load_elevation(elev_path)

            # Load label; pixel values are used directly as class indices.
            label_path = os.path.join(self.label_dir, file_name + "-label.png")
            label_image = self._read_image(label_path, cv2.IMREAD_GRAYSCALE)

            # One-hot encode
            label_onehot = tf.keras.utils.to_categorical(label_image, num_classes=self.num_classes)

            # Combine input
            combined_input = np.concatenate([rgb_image, elevation_data], axis=-1)

            batch_x.append(combined_input)
            batch_y.append(label_onehot)

        return np.array(batch_x, dtype=np.float32) / 255.0, np.array(batch_y, dtype=np.float32)

    def on_epoch_end(self):
        """Shuffle the tile list in place when ``shuffle`` is enabled."""
        if self.shuffle:
            random.shuffle(self.tile_list)

    def rgb_to_class_index(self, label_rgb):
        """Convert an ``(H, W, 3)`` colour mask into an ``(H, W)`` class map.

        Each pixel is matched against the keys of ``COLOR_TO_CLASS`` and
        replaced by the corresponding class index (``int32``).

        Raises:
            ValueError: If any pixel colour is not a key of ``COLOR_TO_CLASS``.
        """
        h, w, _ = label_rgb.shape
        # -1 marks pixels that have not been matched to any known colour yet.
        label = np.full((h, w), -1, dtype=np.int32)

        for color, idx in COLOR_TO_CLASS.items():
            mask = np.all(label_rgb == color, axis=-1)
            label[mask] = idx

        if np.any(label == -1):
            raise ValueError("❌ Unknown RGB values in label mask.")

        return label

    def __data_generation(self, batch_files):
        """Build ``(X, y)`` arrays for the given tile base names.

        Unlike ``__getitem__`` this helper honours ``input_type``, decodes the
        label as a colour mask via ``rgb_to_class_index`` and silently skips
        (best effort) any tile whose image or label cannot be read, whose
        required elevation file is missing, whose spatial size is not
        512 x 512, or whose label contains an unknown colour.  It is not
        called by any other method of this class.

        Returns:
            ``(X, y)`` where ``X`` stacks the merged inputs divided by 255
            (elevation included) and ``y`` stacks the one-hot labels.  Both
            are empty arrays when every tile was skipped.

        Raises:
            ValueError: If ``input_type`` is not one of the supported values.
        """
        X, y = [], []
        needs_elevation = self.input_type in ['2ch', 'rgb_elevation']

        for base_name in batch_files:
            image_path = os.path.join(self.image_dir, base_name + '-ortho.png')
            label_path = os.path.join(self.label_dir, base_name + '-label.png')
            elev_path = os.path.join(self.elevation_dir, base_name + '-elev.npy')

            # Load RGB image
            image = cv2.imread(image_path)
            if image is None or image.shape[:2] != self._TILE_SHAPE:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Load elevation only for the input types that use it.
            if needs_elevation:
                if not os.path.exists(elev_path):
                    continue
                elev = self._load_elevation(elev_path)
                if elev.shape[:2] != self._TILE_SHAPE:
                    continue

            # Load label
            label_rgb = cv2.imread(label_path)
            if label_rgb is None or label_rgb.shape[:2] != self._TILE_SHAPE:
                continue
            label_rgb = cv2.cvtColor(label_rgb, cv2.COLOR_BGR2RGB)

            try:
                label = self.rgb_to_class_index(label_rgb)
            except ValueError:
                # Deliberate best effort: drop tiles with unknown label colours.
                continue
            label = tf.keras.utils.to_categorical(label, num_classes=self.num_classes)

            # Combine inputs
            if self.input_type == '1ch':
                merged = np.expand_dims(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), axis=-1)
            elif self.input_type == '2ch':
                grayscale = np.expand_dims(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), axis=-1)
                merged = np.concatenate([grayscale, elev], axis=-1)
            elif self.input_type == 'rgb':
                merged = image
            elif self.input_type == 'rgb_elevation':
                merged = np.concatenate([image, elev], axis=-1)
            else:
                raise ValueError("Invalid input_type")

            X.append(merged / 255.0)
            y.append(label)

        return np.array(X), np.array(y)
