# In [None]:
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.utils import Sequence
import cv2
import random
import gc
import psutil

# --- Sanity ---
def test_data_sanity():
    """Print a fixed marker line that names data.ipynb as the origin."""
    marker = "✅ from data.ipynb"
    print(marker)

# Label colours as RGB triples; a colour's position in this tuple is its
# integer class index.
_PALETTE = (
    (230, 25, 75),      # 0: BUILDING
    (145, 30, 180),     # 1: CLUTTER
    (60, 180, 75),      # 2: VEGETATION
    (245, 130, 48),     # 3: WATER
    (255, 255, 255),    # 4: GROUND
    (0, 130, 200),      # 5: CAR
)

# Forward (colour -> class index) and reverse (class index -> colour) lookups.
COLOR_TO_CLASS = {rgb: class_id for class_id, rgb in enumerate(_PALETTE)}
CLASS_TO_COLOR = dict(enumerate(_PALETTE))
NUM_CLASSES = len(_PALETTE)

class StreamingDataGenerator(Sequence):
    """Keras ``Sequence`` that builds every batch from randomly sampled tiles.

    Tiles are discovered as ``<tile>-ortho.png`` files in ``image_dir``; the
    matching ``<tile>-elev.npy`` and ``<tile>-label.png`` files are looked up
    in ``elevation_dir`` and ``label_dir``.  Each ``__getitem__`` call draws
    tiles (with replacement) until ``batch_size`` usable samples have been
    collected, so ``index`` only appears in the progress message.

    Args:
        image_dir: Directory holding the ``-ortho.png`` images.
        elevation_dir: Directory holding the ``-elev.npy`` arrays.
        label_dir: Directory holding the colour-coded ``-label.png`` masks.
        batch_size: Number of samples per batch.
        input_type: ``'1ch'`` (red channel), ``'2ch'`` (red channel plus
            elevation), ``'rgb'`` or ``'rgb_elevation'``.
        num_classes: Depth of the one-hot label encoding.
        shuffle: Shuffle the tile list whenever ``on_epoch_end`` runs.
        steps_per_epoch: Value reported by ``len()``.
        fixed: If true, sampling restarts from seed 42 on every
            ``on_epoch_end`` so that each epoch draws the same tile sequence
            (given that batches are requested in the same order).

    Raises:
        ValueError: If ``input_type`` is not one of the supported values.
    """

    _SUPPORTED_INPUT_TYPES = ('1ch', '2ch', 'rgb', 'rgb_elevation')
    _IMAGE_SUFFIX = '-ortho.png'

    def __init__(self, image_dir, elevation_dir, label_dir,
                 batch_size=32, input_type='rgb', num_classes=6,
                 shuffle=True, steps_per_epoch=16, fixed=False):
        # Newer Keras versions expect the base-class constructor to run.
        super().__init__()
        self.image_dir = image_dir
        self.elevation_dir = elevation_dir
        self.label_dir = label_dir
        self.batch_size = batch_size
        self.input_type = input_type
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.steps_per_epoch = steps_per_epoch
        self.fixed = fixed
        self._check_input_type()

        # Strip only the trailing suffix; str.replace would also remove the
        # marker if it happened to occur earlier in a file name.
        suffix = self._IMAGE_SUFFIX
        self.tile_list = [
            name[:-len(suffix)]
            for name in os.listdir(image_dir)
            if name.endswith(suffix)
        ]
        # Private RNG used in fixed mode; None means "use the global random".
        self._rng = None
        self.on_epoch_end()

    def __len__(self):
        """Return the configured number of batches per epoch."""
        return self.steps_per_epoch

    def __getitem__(self, index):
        """Assemble one batch of ``batch_size`` randomly chosen usable tiles.

        Args:
            index: Batch number; used only in the progress message.

        Returns:
            ``(x, y)`` NumPy arrays built from the per-tile inputs and
            one-hot labels.

        Raises:
            ValueError: If ``input_type`` is unsupported.
            IndexError: If no tiles were found in ``image_dir``.
            RuntimeError: If every distinct tile was rejected before a
                single sample could be collected.
        """
        # An unsupported input_type can never yield a sample; fail fast
        # instead of spinning forever.
        self._check_input_type()
        if not self.tile_list:
            raise IndexError(
                f"No '*{self._IMAGE_SUFFIX}' tiles found in {self.image_dir!r}"
            )

        rng = self._rng if self._rng is not None else random
        distinct_tiles = len(set(self.tile_list))
        rejected = set()
        batch_x, batch_y = [], []

        while len(batch_x) < self.batch_size:
            tile = rng.choice(self.tile_list)
            sample = self._load_tile(tile)
            if sample is None:
                rejected.add(tile)
                # Once every tile has been tried and none was usable, the
                # loop could never finish; report that instead of hanging.
                if not batch_x and len(rejected) >= distinct_tiles:
                    raise RuntimeError(
                        f"None of the {distinct_tiles} tiles in "
                        f"{self.image_dir!r} produced a usable sample"
                    )
                continue
            inputs, target = sample
            batch_x.append(inputs)
            batch_y.append(target)

        print(f"📦 Generated batch {index+1}/{self.steps_per_epoch} | Batch size: {len(batch_x)} | Memory: {psutil.virtual_memory().percent}% used")
        gc.collect()  # Force garbage collection after batch

        return np.array(batch_x), np.array(batch_y)

    def on_epoch_end(self):
        """Reset sampling state: reseed in fixed mode, then shuffle if enabled."""
        if self.fixed:
            # A private generator keeps the fixed seed from resetting the
            # process-wide `random` state that other code may rely on.
            self._rng = random.Random(42)
            # Start from a canonical order so the seeded shuffle below gives
            # the same list every epoch, independent of os.listdir order and
            # of earlier shuffles.
            self.tile_list.sort()
        else:
            self._rng = None

        if self.shuffle:
            rng = self._rng if self._rng is not None else random
            rng.shuffle(self.tile_list)

    def _check_input_type(self):
        """Raise ValueError when ``self.input_type`` is not supported."""
        if self.input_type not in self._SUPPORTED_INPUT_TYPES:
            raise ValueError(
                f"Unsupported input_type {self.input_type!r}; "
                f"expected one of {self._SUPPORTED_INPUT_TYPES}"
            )

    def _load_tile(self, tile):
        """Load one tile; return ``(inputs, one_hot_label)`` or ``None`` if unusable.

        A tile is unusable when one of its three files is missing, when its
        label image contains a colour absent from ``COLOR_TO_CLASS``, when
        every label pixel is class 4, or when loading raises.
        """
        try:
            rgb_path = os.path.join(self.image_dir, tile + "-ortho.png")
            elev_path = os.path.join(self.elevation_dir, tile + "-elev.npy")
            label_path = os.path.join(self.label_dir, tile + "-label.png")

            if not (os.path.exists(rgb_path) and os.path.exists(elev_path) and os.path.exists(label_path)):
                return None

            rgb = cv2.cvtColor(cv2.imread(rgb_path), cv2.COLOR_BGR2RGB)
            elev = np.expand_dims(np.load(elev_path), -1)
            label_rgb = cv2.cvtColor(cv2.imread(label_path), cv2.COLOR_BGR2RGB)

            h, w, _ = label_rgb.shape
            # Signed dtype so the -1 "unlabelled" sentinel is representable;
            # an unsigned array cannot hold it and the check below would
            # never match.
            label = np.full((h, w), -1, dtype=np.int16)
            for color, idx in COLOR_TO_CLASS.items():
                mask = np.all(label_rgb == color, axis=-1)
                label[mask] = idx

            # Reject tiles containing any pixel whose colour is not a known class.
            if np.any(label == -1):
                return None

            if np.all(label == 4):  # All white background
                print(f"🧼 Skipping all-background tile: {tile}")
                return None

            label_onehot = tf.keras.utils.to_categorical(label, num_classes=self.num_classes)

            if self.input_type == '1ch':
                merged = np.expand_dims(rgb[:, :, 0], axis=-1)
            elif self.input_type == '2ch':
                merged = np.concatenate([np.expand_dims(rgb[:, :, 0], axis=-1), elev], axis=-1)
            elif self.input_type == 'rgb':
                merged = rgb
            elif self.input_type == 'rgb_elevation':
                merged = np.concatenate([rgb, elev], axis=-1)
            else:
                return None

            # NOTE(review): the /255 scaling is applied to every channel,
            # including elevation when present - confirm this is intended.
            return merged.astype(np.float32) / 255.0, label_onehot.astype(np.float32)

        except Exception as e:
            # Best effort: a tile that fails to load is reported and skipped.
            print(f"⚠️ Error loading tile {tile}: {e}")
            return None
