In [None]:
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.utils import Sequence
import cv2
import random
import gc
import psutil

# --- Sanity ---
def test_data_sanity():
    """Print a fixed marker line (presumably to confirm this module was loaded)."""
    marker = "✅ from data.ipynb"
    print(marker)

import pandas as pd

# Load metadata once at generator init
def load_train_metadata(csv_path, class_names=None):
    """Load per-tile class metadata from a CSV into a ``{tile_id: vector}`` dict.

    The CSV must contain a ``tile_id`` column plus one column per class named
    ``"{index}: {name}"``, where ``index`` is the position of ``name`` in
    ``class_names``. Any other columns are ignored.

    Args:
        csv_path: Path or file-like object accepted by ``pandas.read_csv``.
        class_names: Ordered class names used to build the column names.
            Defaults to the module-level ``CLASS_NAMES`` (not defined in this
            section; presumably defined elsewhere in the project).

    Returns:
        dict mapping each ``tile_id`` to a 1-D float ndarray with one entry per
        class, in ``class_names`` order. If a ``tile_id`` repeats, the last
        row wins.

    Raises:
        KeyError: If ``tile_id`` or any expected class column is missing.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    columns = [f'{i}: {name}' for i, name in enumerate(class_names)]

    df = pd.read_csv(csv_path)
    # One vectorised conversion instead of a per-row ``iterrows`` loop, which is
    # slow and upcasts each row to a common dtype.
    values = df[columns].to_numpy(dtype=float)
    return dict(zip(df['tile_id'], values))



# RGB colour -> class index, used to decode the colour-coded label images.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,    # BUILDING
    (145, 30, 180): 1,   # CLUTTER
    (60, 180, 75): 2,    # VEGETATION
    (245, 130, 48): 3,   # WATER
    (255, 255, 255): 4,  # GROUND
    (0, 130, 200): 5,    # CAR
}
# Inverse lookup: class index -> RGB colour.
CLASS_TO_COLOR = dict(zip(COLOR_TO_CLASS.values(), COLOR_TO_CLASS.keys()))
NUM_CLASSES = len(COLOR_TO_CLASS)
# Class indices flagged as important: BUILDING, CLUTTER, WATER, CAR.
IMPORTANT_CLASSES = [0, 1, 3, 5]

class StreamingDataGenerator(Sequence):
    """Keras ``Sequence`` that streams randomly drawn, optionally augmented tiles.

    Each batch is filled by repeatedly drawing a tile name from ``tile_list``
    with ``random.choice``. Draws are with replacement, and the ``index`` given
    to ``__getitem__`` is ignored.

    For each draw the generator:

    - loads ``<tile>-ortho.png`` (RGB), ``<tile>-elev.npy`` (elevation) and
      ``<tile>-label.png`` (colour-coded label);
    - decodes the label via ``COLOR_TO_CLASS``;
    - optionally filters and augments the tile;
    - emits ``x`` scaled by 1/255 and ``y`` one-hot encoded.

    Args:
        image_dir: Directory of ``*-ortho.png`` images; defines ``tile_list``.
        elevation_dir: Directory of ``*-elev.npy`` arrays.
        label_dir: Directory of ``*-label.png`` colour masks.
        batch_size: Samples per batch.
        input_type: One of ``'1ch'``, ``'2ch'``, ``'rgb'``, ``'rgb_elevation'``.
        num_classes: Depth of the one-hot label encoding.
        shuffle: Shuffle ``tile_list`` in ``on_epoch_end``.
        steps: Batches per epoch. When ``None``, ``__len__`` returns
            ``ceil(len(tile_list) / batch_size)``.
        fixed: Reseed the *global* ``random`` module with 42 at each epoch end.
        augment: Enable pure-background skipping, biased sampling for the first
            half of each batch, and random flips/90-degree rotations.
        background_threshold: Stored on the instance; not used by this class.
        metadata_csv_path: Optional CSV loaded with ``load_train_metadata`` into
            ``self.metadata``, which this class does not use.
        max_failed_draws: Number of consecutive failed draws tolerated before
            ``__getitem__`` raises ``RuntimeError`` instead of looping forever.
            Failed draws are missing files, read errors and unmapped label
            colours. ``None`` disables the guard.

    Raises:
        ValueError: If ``input_type`` is not supported.
    """

    # Supported values of ``input_type`` (see ``_merge_inputs``).
    _INPUT_TYPES = ('1ch', '2ch', 'rgb', 'rgb_elevation')
    # Fill value for pixels whose colour is not in COLOR_TO_CLASS. It must be a
    # valid uint8 outside the class range: -1 cannot be stored in uint8, so a
    # -1 sentinel never matches and unmapped pixels go undetected.
    _UNLABELED = 255

    def __init__(self, image_dir, elevation_dir, label_dir,
                 batch_size=32, input_type='rgb', num_classes=6,
                 shuffle=True, steps=None, fixed=False, augment=False,
                 background_threshold=0.95, metadata_csv_path=None,
                 max_failed_draws=1000):
        # Keras 3's PyDataset (tf.keras.utils.Sequence) expects subclasses to
        # call its initialiser; Keras 2's Sequence accepts the no-arg call too.
        super().__init__()
        if input_type not in self._INPUT_TYPES:
            # Fail fast: with an unknown type every draw would be discarded and
            # __getitem__ could never fill a batch.
            raise ValueError(
                f"Unsupported input_type {input_type!r}; "
                f"expected one of {self._INPUT_TYPES}"
            )

        self.image_dir = image_dir
        self.elevation_dir = elevation_dir
        self.label_dir = label_dir
        self.batch_size = batch_size
        self.input_type = input_type
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.steps = steps
        self.fixed = fixed
        self.background_threshold = background_threshold
        self.augment = augment
        self.max_failed_draws = max_failed_draws
        # Tile names come from the RGB images. The matching elevation and
        # label files are checked per draw in __getitem__.
        self.tile_list = [f.replace('-ortho.png', '') for f in os.listdir(image_dir) if f.endswith('-ortho.png')]
        self.on_epoch_end()
        self.metadata = load_train_metadata(metadata_csv_path) if metadata_csv_path else {}

    def __len__(self):
        """Return the number of batches per epoch."""
        if self.steps is not None:
            return self.steps
        # No explicit step count: cover the tile list about once per epoch.
        return (len(self.tile_list) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        """Assemble one batch of ``batch_size`` randomly drawn tiles.

        ``index`` is ignored because tiles are sampled with ``random.choice``.

        Returns:
            ``(x, y)`` as float32 arrays:

            - ``x`` stacks the merged inputs scaled by 1/255;
            - ``y`` stacks one-hot labels with ``num_classes`` channels.

        Raises:
            RuntimeError: After ``max_failed_draws`` consecutive draws that could
                not be loaded or decoded, for example when a directory is wrong.
            IndexError: If ``tile_list`` is empty (from ``random.choice``).
        """
        batch_x, batch_y = [], []
        # Only the first half of each batch goes through the biased keep test.
        biased_slots = self.batch_size // 2
        failed_draws = 0

        while len(batch_x) < self.batch_size:
            if self.max_failed_draws is not None and failed_draws >= self.max_failed_draws:
                raise RuntimeError(
                    f"Gave up after {failed_draws} consecutive failed draws; "
                    f"check image_dir, elevation_dir and label_dir"
                )
            tile = random.choice(self.tile_list)

            try:
                loaded = self._load_tile(tile)
                if loaded is None:
                    failed_draws += 1
                    continue
                rgb, elev, label = loaded

                if self.augment:
                    if not self._keep_tile(label, biased=len(batch_x) < biased_slots):
                        continue  # deliberate rejection, not a failure
                    rgb, elev, label = self._augment(rgb, elev, label)

                label_onehot = tf.keras.utils.to_categorical(label, num_classes=self.num_classes)
                merged = self._merge_inputs(rgb, elev)
                # NOTE(review): this also divides the elevation channel by 255;
                # confirm that is the intended elevation scaling.
                sample_x = merged.astype(np.float32) / 255.0
                sample_y = label_onehot.astype(np.float32)
            except Exception as e:
                # Best-effort: report the bad tile and draw another.
                print(f"⚠️ Error loading tile {tile}: {e}")
                failed_draws += 1
                continue

            batch_x.append(sample_x)
            batch_y.append(sample_y)
            failed_draws = 0

        gc.collect()
        return np.array(batch_x), np.array(batch_y)

    def _load_tile(self, tile):
        """Return ``(rgb, elev, label)`` for ``tile``, or ``None`` if unusable.

        Returns ``None`` when any of the three files is missing, or when the
        label image contains a colour that is not in ``COLOR_TO_CLASS``.
        """
        rgb_path = os.path.join(self.image_dir, tile + "-ortho.png")
        elev_path = os.path.join(self.elevation_dir, tile + "-elev.npy")
        label_path = os.path.join(self.label_dir, tile + "-label.png")

        if not (os.path.exists(rgb_path) and os.path.exists(elev_path) and os.path.exists(label_path)):
            return None

        # OpenCV reads BGR; convert to RGB to match COLOR_TO_CLASS.
        rgb = cv2.cvtColor(cv2.imread(rgb_path), cv2.COLOR_BGR2RGB)
        elev = np.expand_dims(np.load(elev_path), -1)
        label_rgb = cv2.cvtColor(cv2.imread(label_path), cv2.COLOR_BGR2RGB)

        label = np.full(label_rgb.shape[:2], self._UNLABELED, dtype=np.uint8)
        for color, idx in COLOR_TO_CLASS.items():
            label[np.all(label_rgb == color, axis=-1)] = idx

        if np.any(label == self._UNLABELED):
            return None
        return rgb, elev, label

    def _keep_tile(self, label, biased):
        """Decide whether an augmented-mode tile is kept.

        Pure-background tiles (all class 4, GROUND) are always rejected. When
        ``biased`` is true, the tile is kept only if a weighted class-ratio
        score reaches a random threshold drawn from U(0.1, 0.25). The weights
        favour CAR, WATER, CLUTTER and BUILDING and penalise VEGETATION and
        GROUND.
        """
        if np.all(label == 4):
            return False
        if not biased:
            return True

        ratios = np.bincount(label.ravel(), minlength=self.num_classes) / label.size
        keep_score = (
            64.0 * ratios[5] +   # CAR
            6.0 * ratios[3] +    # WATER
            1.1 * ratios[1] +    # CLUTTER
            1.25 * ratios[0] -   # BUILDING
            1.0 * ratios[2] -    # VEGETATION
            1.8 * ratios[4]      # GROUND
        )
        return keep_score >= random.uniform(0.1, 0.25)

    @staticmethod
    def _augment(rgb, elev, label):
        """Apply the same random flips and 90-degree rotation to all three arrays.

        NOTE(review): ``np.rot90`` swaps height and width. Batches can only be
        stacked if tiles are square, which is presumably the case; confirm.
        """
        flip_horizontal = random.choice([True, False])
        flip_vertical = random.choice([True, False])
        rotation_k = random.choice([0, 1, 2, 3])

        arrays = (rgb, elev, label)
        if flip_horizontal:
            arrays = tuple(np.fliplr(a) for a in arrays)
        if flip_vertical:
            arrays = tuple(np.flipud(a) for a in arrays)
        if rotation_k > 0:
            arrays = tuple(np.rot90(a, k=rotation_k) for a in arrays)
        return arrays

    def _merge_inputs(self, rgb, elev):
        """Build the network input for ``self.input_type`` from RGB and elevation."""
        if self.input_type == '1ch':
            return np.expand_dims(rgb[:, :, 0], axis=-1)
        if self.input_type == '2ch':
            return np.concatenate([np.expand_dims(rgb[:, :, 0], axis=-1), elev], axis=-1)
        if self.input_type == 'rgb':
            return rgb
        if self.input_type == 'rgb_elevation':
            return np.concatenate([rgb, elev], axis=-1)
        raise ValueError(f"Unsupported input_type {self.input_type!r}")

    def on_epoch_end(self):
        """Optionally reseed the RNG and shuffle ``tile_list`` after each epoch.

        NOTE: ``fixed=True`` reseeds the *global* ``random`` module, which also
        affects any other code in the process that uses ``random``. Shuffling
        does not change which tiles can be drawn, because ``__getitem__``
        samples with ``random.choice``.
        """
        if self.fixed:
            random.seed(42)
        if self.shuffle:
            random.shuffle(self.tile_list)
