In [None]:
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.utils import Sequence
import cv2
import random
import gc
import pandas as pd

# Tile-name prefixes for the training split (a curated subset of the dataset).
# StreamingDataGenerator's train split keeps every tile whose name starts with
# one of these prefixes.
# Every entry needs its trailing comma: adjacent string literals are silently
# concatenated by Python, which would fuse two prefixes into one that matches
# nothing.
# NOTE(review): some prefixes also appear in val_files
# ("1476907971_CHADGRISMOPENPIPELINE", "dabec5e872_E8AD935CEDINSPIRE",
# "cc4b443c7d_A9CBEF2C97INSPIRE") or test_files
# ("1d4fbe33f3_F1BE1D4184INSPIRE", "d9161f7e18_C05BA1BC72OPENPIPELINE"),
# so those tiles are used for both training and evaluation — confirm intended.
train_files = [
    "d9161f7e18_C05BA1BC72OPENPIPELINE",
    "c2e8370ca3_3340CAC7AEOPENPIPELINE",
    "1d4fbe33f3_F1BE1D4184INSPIRE",
    "551063e3c5_8FCB044F58INSPIRE",
    "ec09336a6f_06BA0AF311OPENPIPELINE",

    "2ef883f08d_F317F9C1DFOPENPIPELINE",
    "c37dbfae2f_84B52814D2OPENPIPELINE",
    "f4dd768188_NOLANOPENPIPELINE",
    "1553541585_APIGENERATED",
    "dabec5e872_E8AD935CEDINSPIRE",
    "b61673f780_4413A67E91INSPIRE",
    "d06b2c67d2_2A62B67B52OPENPIPELINE",
    "cc4b443c7d_A9CBEF2C97INSPIRE",
    "11cdce7802_B6A62F8BE0INSPIRE",
    "1476907971_CHADGRISMOPENPIPELINE",

    # Tile ids contain no commas; the 10-character hex id here is 8D20F02042.
    "c68d5dd5e5_8D20F02042OPENPIPELINE",
    "ae5f6f1a67_E8AD935CEDINSPIRE",
    "203b0f447b_536DE05ED2OPENPIPELINE",
    "c0b48f9e09_84B52814D2OPENPIPELINE",
    "89d1b23f42_60693DB04DINSPIRE",
    "aa7e37f4e2_625EDFBAB6OPENPIPELINE",
    "fc2f3b5504_8FCB044F58OPENPIPELINE",
    "b0d23b19d8_A9CBEF2C97INSPIRE",
    "6e13af485d_53197F206FOPENPIPELINE",
    "7c216a3883_0CCD105428INSPIRE",
]

# Tile-name prefixes for the validation split.
# Every entry needs its trailing comma: adjacent string literals are silently
# concatenated by Python, which would fuse two prefixes into one that matches
# nothing.
val_files = [
    "1476907971_CHADGRISMOPENPIPELINE",
    "dabec5e872_E8AD935CEDINSPIRE",
    "c6d131e346_536DE05ED2OPENPIPELINE",
    "57426ebe1e_84B52814D2OPENPIPELINE",
    "1726eb08ef_60693DB04DINSPIRE",
    "9170479165_625EDFBAB6OPENPIPELINE",
    "520947aa07_8FCB044F58OPENPIPELINE",
    "cc4b443c7d_A9CBEF2C97INSPIRE",
    "12fa5e614f_53197F206FOPENPIPELINE",
    "2ef3a4994a_0CCD105428INSPIRE",
]

# Tile-name prefixes for the held-out test split.
test_files = [
    "1d4fbe33f3_F1BE1D4184INSPIRE",
    "f9f43e5144_1DB9E6F68BINSPIRE",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE",
    "a1af86939f_F1BE1D4184OPENPIPELINE",
    "1553541487_APIGENERATED",
    "74d7796531_EB81FE6E2BOPENPIPELINE",
    "8710b98ea0_06E6522D6DINSPIRE",
    "c644f91210_27E21B7F30OPENPIPELINE",
    "d9161f7e18_C05BA1BC72OPENPIPELINE",
]

# Class names in class-id order. The metadata CSV names its per-class columns
# "<id>: <name>" using this list.
CLASS_NAMES = [
    'Building',
    'Clutter',
    'Vegetation',
    'Water',
    'Background',
    'Car',
]

# --- Sanity ---
def test_data_sanity():
    """Print a marker line confirming this data module was imported."""
    banner = "✅ from data.ipynb"
    print(banner)

def load_train_metadata(csv_path, class_names=None):
    """Load per-tile class-ratio vectors from a metadata CSV.

    The CSV must have a ``tile_id`` column plus one column per class named
    ``"<index>: <name>"`` (for example ``"0: Building"``).

    Args:
        csv_path: Path to the CSV file.
        class_names: Class names in class-index order. Defaults to the
            module-level ``CLASS_NAMES``.

    Returns:
        dict mapping each ``tile_id`` to a 1-D float array of that row's class
        columns, in class-index order. If a ``tile_id`` repeats, the last row
        wins.

    Raises:
        KeyError: if ``tile_id`` or any expected class column is missing.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    df = pd.read_csv(csv_path)
    ratio_columns = [f'{i}: {name}' for i, name in enumerate(class_names)]
    # Convert all rows at once instead of a slow per-row ``iterrows`` pass.
    ratios = df[ratio_columns].to_numpy(dtype=float)
    return dict(zip(df['tile_id'], ratios))

# Label-PNG colour (R, G, B) -> class id; the ids index CLASS_NAMES.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,     # Building
    (145, 30, 180): 1,    # Clutter
    (60, 180, 75): 2,     # Vegetation
    (245, 130, 48): 3,    # Water
    (255, 255, 255): 4,   # Background
    (0, 130, 200): 5,     # Car
}
# Label pixels with this colour are treated as "ignore" by the data generator.
IGNORE_COLOR = (255, 0, 255)
# Inverse lookup: class id -> label colour.
CLASS_TO_COLOR = dict(zip(COLOR_TO_CLASS.values(), COLOR_TO_CLASS.keys()))
NUM_CLASSES = len(COLOR_TO_CLASS)
IMPORTANT_CLASSES = [0, 1, 3, 5]

class StreamingDataGenerator(Sequence):
    """Keras ``Sequence`` that streams randomly sampled segmentation tiles from disk.

    Each batch is built by drawing tiles uniformly at random, with replacement,
    from ``self.tile_list``. The batch ``index`` therefore does not select
    specific tiles. A tile ``<t>`` is read from three files:
    ``<image_dir>/<t>-ortho.png``, ``<elevation_dir>/<t>-elev.npy`` and
    ``<label_dir>/<t>-label.png``.

    Args:
        image_dir, elevation_dir, label_dir: Directories holding the files above.
        batch_size: Target tiles per batch. A batch can come back smaller when
            more than ``5 * batch_size`` draws are needed.
        input_type: ``'1ch'`` (first RGB channel), ``'2ch'`` (first RGB channel
            plus elevation), ``'rgb'`` or ``'rgb_elevation'``.
        num_classes: Depth of the one-hot label.
        shuffle: Shuffle ``tile_list`` in ``on_epoch_end``.
        steps: Batches per epoch. ``None`` means
            ``ceil(len(tile_list) / batch_size)``.
        fixed: Reseed the global ``random`` module with 42 at each epoch end.
        augment: On the train split, enable tile filtering plus random
            flips and rotations.
        background_threshold: Stored but not read by this class.
        metadata_csv_path: CSV read with ``load_train_metadata``. Required when
            ``augment`` is used on the train split.
        split: ``'train'``, ``'val'`` or ``'test'``.
        val_files, test_files: Tile-name prefixes selecting the val/test tiles.
            The train split uses the module-level ``train_files`` instead.

    Raises:
        ValueError: Unknown ``split`` or ``input_type``, or train-split
            augmentation requested without metadata.
    """

    # Channel count produced for each supported ``input_type``. Also sizes the
    # fallback batch.
    _INPUT_CHANNELS = {'1ch': 1, '2ch': 2, 'rgb': 3, 'rgb_elevation': 4}
    # On the augmented train split, tiles containing any of these class ids
    # bypass the 25% random pre-filter.
    _PRIORITY_CLASSES = (0, 3, 5)
    # Spatial size of the all-zero batch returned when no tile could be loaded.
    # NOTE(review): assumes 256x256 tiles — confirm against the dataset.
    _FALLBACK_TILE_SIZE = 256

    def __init__(self, image_dir, elevation_dir, label_dir,
                 batch_size=32, input_type='rgb', num_classes=6,
                 shuffle=True, steps=None, fixed=False, augment=False,
                 background_threshold=0.95, metadata_csv_path=None,
                 split='train', val_files=None, test_files=None):
        # Keras 3's PyDataset (aliased as Sequence) expects its initialiser to run.
        super().__init__()
        if input_type not in self._INPUT_CHANNELS:
            # Previously every tile was silently skipped, yielding all-zero batches.
            raise ValueError(
                f"Unknown input_type: {input_type!r}; "
                f"expected one of {sorted(self._INPUT_CHANNELS)}")

        self.image_dir = image_dir
        self.elevation_dir = elevation_dir
        self.label_dir = label_dir
        self.batch_size = batch_size
        self.input_type = input_type
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.steps = steps
        self.fixed = fixed
        self.background_threshold = background_threshold
        self.augment = augment
        self.metadata = load_train_metadata(metadata_csv_path) if metadata_csv_path else {}
        self.split = split

        if split == 'train':
            # Curated subset. To train on the full dataset instead, keep every
            # tile that does NOT start with a val_files / test_files prefix.
            prefixes = train_files
        elif split == 'val':
            prefixes = val_files or []
        elif split == 'test':
            prefixes = test_files or []
        else:
            raise ValueError("Unknown split type: " + str(split))

        if split == 'train' and augment and not self.metadata:
            # _keep_for_augmentation rejects every tile that has no metadata, so
            # this configuration could only ever produce the all-zero fallback batch.
            raise ValueError(
                "augment=True on the 'train' split requires metadata_csv_path "
                "with per-tile class ratios")

        all_tiles = [f.replace('-ortho.png', '') for f in os.listdir(image_dir)
                     if f.endswith('-ortho.png')]
        prefix_tuple = tuple(prefixes)  # str.startswith accepts a tuple of prefixes
        self.tile_list = [t for t in all_tiles if t.startswith(prefix_tuple)]

    def __len__(self):
        """Return ``steps``, or ``ceil(len(tile_list) / batch_size)`` when unset."""
        if self.steps is not None:
            return self.steps
        return -(-len(self.tile_list) // self.batch_size)

    def __getitem__(self, index):
        """Return one ``(x, y)`` batch of float32 arrays; ``index`` is not used."""
        batch_x, batch_y = [], []
        attempts = 0
        max_attempts = 5 * self.batch_size

        while len(batch_x) < self.batch_size and attempts < max_attempts:
            tile = random.choice(self.tile_list)
            attempts += 1
            try:
                sample = self._build_sample(tile, len(batch_x))
            except Exception as e:
                # Best effort: a broken tile is reported and another one is drawn.
                print(f"⚠️ Error loading tile {tile}: {e}")
                continue
            if sample is None:
                continue
            x, y = sample
            batch_x.append(x)
            batch_y.append(y)

        if not batch_x:
            # Fallback shape must match the configured input, not always 3 channels.
            size = self._FALLBACK_TILE_SIZE
            channels = self._INPUT_CHANNELS[self.input_type]
            batch_x = np.zeros((self.batch_size, size, size, channels), dtype=np.float32)
            batch_y = np.zeros((self.batch_size, size, size, self.num_classes), dtype=np.float32)

        # Explicit collection after every batch (kept from the original); it
        # costs time per batch.
        gc.collect()
        return np.array(batch_x), np.array(batch_y)

    def _build_sample(self, tile, n_collected):
        """Turn one tile into an ``(x, y)`` pair, or return ``None`` to reject it.

        ``n_collected`` is the number of tiles already in the current batch. The
        keep-score filter only applies while it is below half the batch size.
        """
        rgb_path = os.path.join(self.image_dir, tile + "-ortho.png")
        elev_path = os.path.join(self.elevation_dir, tile + "-elev.npy")
        label_path = os.path.join(self.label_dir, tile + "-label.png")
        if not (os.path.exists(rgb_path) and os.path.exists(elev_path)
                and os.path.exists(label_path)):
            return None

        rgb = cv2.cvtColor(cv2.imread(rgb_path), cv2.COLOR_BGR2RGB)
        # assumes the .npy holds a 2-D (h, w) elevation map — TODO confirm
        elev = np.expand_dims(np.load(elev_path), -1)
        label = self._decode_label(cv2.cvtColor(cv2.imread(label_path), cv2.COLOR_BGR2RGB))
        valid_mask = label != 255

        # A disabled filter that skipped black-dominant tiles (>50% pure-black
        # RGB pixels) previously lived here.

        if self.split == 'train':
            if not valid_mask.any():
                return None
            if self.augment:
                if not self._keep_for_augmentation(tile, label, n_collected):
                    return None
                rgb, elev, label, valid_mask = self._random_flip_rotate(
                    rgb, elev, label, valid_mask)

        # Ignored pixels end up with an all-zero one-hot vector.
        label_temp = label.copy()
        label_temp[label_temp == 255] = 0
        label_onehot = tf.keras.utils.to_categorical(label_temp, num_classes=self.num_classes)
        label_onehot[~valid_mask] = 0

        # NOTE(review): the elevation channel is also divided by 255 — confirm
        # that matches the elevation value range.
        x = self._merge_inputs(rgb, elev).astype(np.float32) / 255.0
        return x, label_onehot.astype(np.float32)

    @staticmethod
    def _decode_label(label_rgb):
        """Map an (h, w, 3) RGB label image to uint8 class ids; 255 marks ignored pixels."""
        h, w, _ = label_rgb.shape
        # Colours not listed in COLOR_TO_CLASS stay 255 (ignored). Written
        # explicitly instead of relying on -1 wrapping to 255 in uint8.
        label = np.full((h, w), 255, dtype=np.uint8)
        for color, idx in COLOR_TO_CLASS.items():
            label[np.all(label_rgb == color, axis=-1)] = idx
        label[np.all(label_rgb == IGNORE_COLOR, axis=-1)] = 255
        return label

    def _keep_for_augmentation(self, tile, label, n_collected):
        """Decide whether an augmented train tile is kept (draws from ``random``)."""
        # Drop tiles that are entirely background (class 4).
        if np.all(label == 4):
            return False
        # Tiles with a priority class are always eligible; others pass 25% of the time.
        contains_priority = bool(np.isin(label, self._PRIORITY_CLASSES).any())
        if not contains_priority and random.random() > 0.25:
            return False

        class_ratios = self.metadata.get(tile)
        if class_ratios is None:
            return False

        if n_collected < self.batch_size // 2:
            (building_ratio, clutter_ratio, vegetation_ratio,
             water_ratio, background_ratio, car_ratio) = class_ratios[:6]
            # Weighted score favouring rare classes (car, water) over background.
            keep_score = (
                256.0 * car_ratio +
                6.5 * water_ratio +
                0.08 * clutter_ratio +
                1.25 * building_ratio +
                0.9 * vegetation_ratio -
                1.7 * background_ratio
            )
            keep_threshold = random.uniform(-1.25, 0.45)
            details = (
                f"score={keep_score:.3f}, threshold={keep_threshold:.3f} "
                f"(car={car_ratio:.2f}, water={water_ratio:.2f}, clutter={clutter_ratio:.2f}, "
                f"building={building_ratio:.2f}, vegetation={vegetation_ratio:.2f}, "
                f"background={background_ratio:.2f})")
            if keep_score < keep_threshold:
                print(f"⛔ Skipping {tile}, {details}")
                return False
            print(f"✅ Keeping {tile}, {details}")
        return True

    @staticmethod
    def _random_flip_rotate(*arrays):
        """Apply one random h-flip / v-flip / 90°-rotation combo to all arrays alike."""
        flip_horizontal = random.choice([True, False])
        flip_vertical = random.choice([True, False])
        rotation_k = random.choice([0, 1, 2, 3])
        if flip_horizontal:
            arrays = tuple(np.fliplr(a) for a in arrays)
        if flip_vertical:
            arrays = tuple(np.flipud(a) for a in arrays)
        if rotation_k > 0:
            arrays = tuple(np.rot90(a, k=rotation_k) for a in arrays)
        return arrays

    def _merge_inputs(self, rgb, elev):
        """Assemble the model input for ``self.input_type`` from RGB and elevation."""
        if self.input_type == '1ch':
            return rgb[:, :, :1]
        if self.input_type == '2ch':
            return np.concatenate([rgb[:, :, :1], elev], axis=-1)
        if self.input_type == 'rgb':
            return rgb
        if self.input_type == 'rgb_elevation':
            return np.concatenate([rgb, elev], axis=-1)
        raise ValueError(f"Unknown input_type: {self.input_type!r}")

    def on_epoch_end(self):
        """Optionally reseed, then optionally reshuffle ``tile_list``."""
        if self.fixed:
            # NOTE: reseeds the process-global ``random`` module, not a private RNG.
            random.seed(42)
        if self.shuffle:
            random.shuffle(self.tile_list)
