# In [None]:
import os
import numpy as np
import cv2
from tqdm import tqdm
import csv
from scipy.stats import entropy
from collections import defaultdict
import shutil

# === CONFIG ===
# Source dataset root; read from its images/, elevations/ and labels/ subfolders.
IN_DIR = '/content/dataset-medium'
# Output root: training chips go under OUT_DIR/train, large raw tiles under RAW_DIR.
OUT_DIR = '/content/chipped_data'
RAW_DIR = os.path.join(OUT_DIR, 'raw')
# Side length, in pixels, of the square training chips.
TILE_SIZE = 256
# Step between neighbouring chip origins; STRIDE < TILE_SIZE, so chips overlap.
STRIDE = 128
# Side length, in pixels, of the square tiles written to RAW_DIR.
LARGE_TILE_SIZE = 1024
# Label colour of pixels to ignore. Magenta reads the same in RGB and BGR order,
# so it can be compared directly against the BGR label image from cv2.imread.
IGNORE_COLOR = (255, 0, 255)
# A chip is dropped when its fraction of IGNORE_COLOR pixels exceeds this value;
# 0.0 means a single ignored pixel disqualifies the chip.
IGNORE_THRESHOLD = 0.0
# Class index treated as background (white in COLOR_TO_CLASS).
BACKGROUND_CLASS = 4
# Chips whose background fraction exceeds this are dropped unless they contain
# at least one pixel of a RARE_CLASSES class.
BACKGROUND_SKIP_THRESHOLD = 0.95

# RGB label colour -> class index; class index i is named CLASS_NAMES[i].
COLOR_TO_CLASS = {
    (230, 25, 75): 0,
    (145, 30, 180): 1,
    (60, 180, 75): 2,
    (245, 130, 48): 3,
    (255, 255, 255): 4,
    (0, 130, 200): 5
}

# Class indices whose presence keeps an otherwise mostly-background chip.
RARE_CLASSES = [0, 1, 3, 5]
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
NUM_CLASSES = len(CLASS_NAMES)

# Per-source-image rectangles (pixel coords: x/y = top-left corner, w/h = size)
# excluded from training chips; any chip overlapping one is skipped.
# NOTE(review): presumably areas with bad labels or imagery — confirm.
PROBLEM_REGIONS = {
    "25f1c24f30_EB81FE6E2BOPENPIPELINE": [
        {"x": 4050, "y": 0, "w": 1300, "h": 1050},
        {"x": 3700, "y": 3650, "w": 200, "h": 170},
        {"x": 3525, "y": 3810, "w": 250, "h": 190},
        {"x": 3780, "y": 3580, "w": 200, "h": 160}
    ],
    "39e77bedd0_729FB913CDOPENPIPELINE": [
        {"x": 2900, "y": 2700, "w": 250, "h": 100}
    ],
    "a1af86939f_F1BE1D4184OPENPIPELINE": [
        {"x": 0, "y": 800, "w": 300, "h": 110}
    ]
}

def read_elevation_float32(path):
    """Load an elevation raster and return it as a float32 array.

    The file is read with ``cv2.IMREAD_UNCHANGED`` so its native bit depth is
    preserved before conversion. Rasters with more than one channel are reduced
    to their first channel.

    Args:
        path: Filesystem path of the elevation image.

    Returns:
        The raster as a float32 ``np.ndarray``, or ``None`` when OpenCV could not
        read the file.
    """
    raster = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raster is not None:
        has_extra_bands = raster.ndim == 3 and raster.shape[2] > 1
        band = raster[..., 0] if has_extra_bands else raster
        return band.astype(np.float32)
    return None

def overlaps_problem_region(x, y, base_name, tile_size=None, regions=None):
    """Return True if a square tile overlaps any problem region of its source image.

    Args:
        x, y: Top-left pixel coordinates of the tile in the source image.
        base_name: Source image identifier, used as the key into ``regions``.
        tile_size: Tile side length in pixels; defaults to ``TILE_SIZE``.
        regions: Mapping ``base_name -> list of {"x", "y", "w", "h"}`` rectangles;
            defaults to ``PROBLEM_REGIONS``.

    Returns:
        True when the tile ``[x, x + tile_size) x [y, y + tile_size)`` intersects at
        least one rectangle listed for ``base_name``. False otherwise, including when
        ``base_name`` has no entry.
    """
    # Resolve defaults at call time so the module-level constants are looked up
    # lazily, as before, and callers can reuse this for other tile sizes.
    if tile_size is None:
        tile_size = TILE_SIZE
    if regions is None:
        regions = PROBLEM_REGIONS

    for region in regions.get(base_name, ()):
        rx, ry, rw, rh = region['x'], region['y'], region['w'], region['h']
        # Half-open interval overlap on both axes; merely touching edges is no overlap.
        if x + tile_size > rx and x < rx + rw and y + tile_size > ry and y < ry + rh:
            return True
    return False

def _label_tile_to_ids(label_tile_bgr):
    """Convert a BGR colour-coded label tile into per-pixel class ids.

    Args:
        label_tile_bgr: 3-channel label tile in BGR order, as loaded by
            ``cv2.imread(..., cv2.IMREAD_COLOR)``.

    Returns:
        int32 array with the tile's height/width, holding the COLOR_TO_CLASS index
        of each pixel, or -1 where the colour is not listed in COLOR_TO_CLASS.
    """
    label_tile_rgb = cv2.cvtColor(label_tile_bgr, cv2.COLOR_BGR2RGB)
    label_ids = np.full(label_tile_rgb.shape[:2], -1, dtype=np.int32)
    for color_rgb, class_idx in COLOR_TO_CLASS.items():
        label_ids[np.all(label_tile_rgb == color_rgb, axis=-1)] = class_idx
    return label_ids


def chip_image(rgb_path, elev_path, label_path, base_name, metadata_rows):
    """Cut one source image into overlapping training chips and write them to disk.

    A ``TILE_SIZE`` x ``TILE_SIZE`` window slides over the image with step ``STRIDE``.
    A window is skipped when:

    - it overlaps a PROBLEM_REGIONS rectangle;
    - its fraction of IGNORE_COLOR pixels exceeds IGNORE_THRESHOLD;
    - it contains a label colour missing from COLOR_TO_CLASS;
    - it is pure background, or is above BACKGROUND_SKIP_THRESHOLD background
      with no RARE_CLASSES pixel.

    Each kept chip is written to ``OUT_DIR/train/``: ``images/`` as PNG,
    ``elevations/`` as ``.npy``, ``labels/`` as PNG.

    Args:
        rgb_path: Path to the orthophoto.
        elev_path: Path to the elevation raster.
        label_path: Path to the colour-coded label image.
        base_name: Source identifier, used in chip ids and for the PROBLEM_REGIONS lookup.
        metadata_rows: List mutated in place. One row is appended per kept chip:
            ``[tile_id, base_name, x, y, <NUM_CLASSES class fractions>, entropy_bits]``.

    Returns:
        None. If an input cannot be read, or the three rasters differ in size,
        a message is printed and the image is skipped.
    """
    rgb_bgr = cv2.imread(rgb_path)
    elev = read_elevation_float32(elev_path)
    label = cv2.imread(label_path, cv2.IMREAD_COLOR)

    # cv2.imread (and read_elevation_float32) return None instead of raising;
    # report it here rather than failing later with an opaque error.
    if rgb_bgr is None or elev is None or label is None:
        tqdm.write(f"❌ Skipping {base_name}: could not read an input file")
        return

    h, w = rgb_bgr.shape[:2]
    # All three rasters are cut with the same window; a size mismatch would
    # silently yield misaligned or truncated chips.
    if elev.shape[:2] != (h, w) or label.shape[:2] != (h, w):
        tqdm.write(f"❌ Skipping {base_name}: image/elevation/label sizes differ")
        return

    # Destination folders do not depend on the chip, so create them once.
    images_dir = os.path.join(OUT_DIR, 'train', 'images')
    elev_dir = os.path.join(OUT_DIR, 'train', 'elevations')
    labels_dir = os.path.join(OUT_DIR, 'train', 'labels')
    for out_dir in (images_dir, elev_dir, labels_dir):
        os.makedirs(out_dir, exist_ok=True)

    for y in range(0, h - TILE_SIZE + 1, STRIDE):
        for x in range(0, w - TILE_SIZE + 1, STRIDE):
            if overlaps_problem_region(x, y, base_name):
                continue

            label_tile = label[y:y+TILE_SIZE, x:x+TILE_SIZE]

            # label_tile is BGR; IGNORE_COLOR (magenta) reads the same in either order.
            if np.mean(np.all(label_tile == IGNORE_COLOR, axis=-1)) > IGNORE_THRESHOLD:
                continue

            label_ids = _label_tile_to_ids(label_tile)
            # Any colour outside COLOR_TO_CLASS disqualifies the chip.
            if np.any(label_ids == -1):
                continue

            counts = np.array([(label_ids == i).sum() for i in range(NUM_CLASSES)], dtype=np.float32)
            class_percentages = counts / max(counts.sum(), 1)
            background_pct = class_percentages[BACKGROUND_CLASS]
            has_rare_class = any(counts[i] > 0 for i in RARE_CLASSES)

            # Drop pure-background chips, and near-pure-background chips without a rare class.
            if background_pct == 1.0 or (background_pct > BACKGROUND_SKIP_THRESHOLD and not has_rare_class):
                continue

            tile_id = f"{base_name}_{x}_{y}"
            # Write the ortho tile straight from the BGR image cv2 loaded; an
            # RGB round-trip would only swap the channels twice.
            rgb_tile = rgb_bgr[y:y+TILE_SIZE, x:x+TILE_SIZE]
            elev_tile = elev[y:y+TILE_SIZE, x:x+TILE_SIZE]

            cv2.imwrite(os.path.join(images_dir, f'{tile_id}-ortho.png'), rgb_tile)
            np.save(os.path.join(elev_dir, f'{tile_id}-elev.npy'), elev_tile)
            cv2.imwrite(os.path.join(labels_dir, f'{tile_id}-label.png'), label_tile)

            # Shannon entropy (bits) of the class distribution; the epsilon avoids log(0).
            entropy_val = entropy(class_percentages + 1e-9, base=2)
            metadata_rows.append([tile_id, base_name, x, y] + class_percentages.tolist() + [float(entropy_val)])

def chip_raw_large_tiles():
    """Cut every source image into non-overlapping large tiles under RAW_DIR.

    Inputs are the ``*-ortho.tif`` files in ``IN_DIR/images`` that have matching
    elevation and label files. From each, as many whole ``LARGE_TILE_SIZE`` tiles
    as fit are cut from a grid centred in the image.

    If the image is smaller than ``LARGE_TILE_SIZE`` along an axis, it yields one
    tile spanning that whole axis, so that tile is smaller than ``LARGE_TILE_SIZE``.
    Unlike ``chip_image``, no tile is filtered out.

    Tiles go to ``RAW_DIR/{images,elevations,labels}``. Per-tile class fractions
    and entropy (bits) are written to ``RAW_DIR/raw_metadata.csv``. Fractions are
    relative to the tile's pixels whose colour appears in COLOR_TO_CLASS.
    """
    suffix = '-ortho.tif'
    images_dir = os.path.join(RAW_DIR, 'images')
    elev_dir = os.path.join(RAW_DIR, 'elevations')
    labels_dir = os.path.join(RAW_DIR, 'labels')
    for out_dir in (images_dir, elev_dir, labels_dir):
        os.makedirs(out_dir, exist_ok=True)

    metadata_rows = []
    for fname in tqdm(os.listdir(os.path.join(IN_DIR, 'images')), desc="🔄 Chipping raw directory"):
        if not fname.endswith(suffix):
            continue
        base = fname[:-len(suffix)]

        rgb_path = os.path.join(IN_DIR, 'images', f'{base}-ortho.tif')
        elev_path = os.path.join(IN_DIR, 'elevations', f'{base}-elev.tif')
        label_path = os.path.join(IN_DIR, 'labels', f'{base}-label.png')

        # Same policy as chip_all: samples with missing companion files are skipped.
        if not (os.path.exists(elev_path) and os.path.exists(label_path)):
            tqdm.write(f"❌ Skipping {base} due to missing files")
            continue

        rgb_bgr = cv2.imread(rgb_path)
        elev = read_elevation_float32(elev_path)
        label = cv2.imread(label_path, cv2.IMREAD_COLOR)
        # cv2.imread / read_elevation_float32 signal failure by returning None.
        if rgb_bgr is None or elev is None or label is None:
            tqdm.write(f"❌ Skipping {base}: could not read an input file")
            continue

        h, w = rgb_bgr.shape[:2]
        if elev.shape[:2] != (h, w) or label.shape[:2] != (h, w):
            tqdm.write(f"❌ Skipping {base}: image/elevation/label sizes differ")
            continue

        nx = max(1, w // LARGE_TILE_SIZE)
        ny = max(1, h // LARGE_TILE_SIZE)
        # Centre the grid of whole tiles. Clamp at 0: for an image smaller than one
        # tile the centred origin would be negative, and a negative slice start
        # wraps around to the opposite edge of the array.
        x_start = max(0, (w - nx * LARGE_TILE_SIZE) // 2)
        y_start = max(0, (h - ny * LARGE_TILE_SIZE) // 2)

        for i in range(nx):
            for j in range(ny):
                x = x_start + i * LARGE_TILE_SIZE
                y = y_start + j * LARGE_TILE_SIZE

                # The ortho tile is written straight from the BGR image cv2 loaded.
                rgb_tile = rgb_bgr[y:y+LARGE_TILE_SIZE, x:x+LARGE_TILE_SIZE]
                elev_tile = elev[y:y+LARGE_TILE_SIZE, x:x+LARGE_TILE_SIZE]
                label_tile = label[y:y+LARGE_TILE_SIZE, x:x+LARGE_TILE_SIZE]

                tile_id = f"{base}_{x}_{y}"
                cv2.imwrite(os.path.join(images_dir, f'{tile_id}-ortho.png'), rgb_tile)
                np.save(os.path.join(elev_dir, f'{tile_id}-elev.npy'), elev_tile)
                cv2.imwrite(os.path.join(labels_dir, f'{tile_id}-label.png'), label_tile)

                label_tile_rgb = cv2.cvtColor(label_tile, cv2.COLOR_BGR2RGB)
                # Size from the tile itself; it is smaller than LARGE_TILE_SIZE when
                # the image is smaller than one tile along an axis.
                label_ids = np.full(label_tile_rgb.shape[:2], -1, dtype=np.int32)
                for color_rgb, class_idx in COLOR_TO_CLASS.items():
                    label_ids[np.all(label_tile_rgb == color_rgb, axis=-1)] = class_idx

                counts = np.array([(label_ids == k).sum() for k in range(NUM_CLASSES)], dtype=np.float32)
                class_percentages = counts / max(counts.sum(), 1)
                entropy_val = entropy(class_percentages + 1e-9, base=2)
                metadata_rows.append([tile_id, base, x, y] + class_percentages.tolist() + [float(entropy_val)])

    csv_path = os.path.join(RAW_DIR, 'raw_metadata.csv')
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        header = ['tile_id', 'source_file', 'x', 'y'] + [f'{i}: {name}' for i, name in enumerate(CLASS_NAMES)] + ['entropy']
        writer.writerow(header)
        writer.writerows(metadata_rows)
    print(f"✅ Raw metadata saved to {csv_path}")
    print(f"📊 Total raw tiles: {len(metadata_rows)}")

def chip_all():
    """Chip every source image in IN_DIR into training tiles and write the metadata CSV.

    For each ``*-ortho.tif`` in ``IN_DIR/images`` whose elevation and label files
    exist, ``chip_image`` is called. Images with missing files are reported and
    skipped. All collected chip rows are written to ``OUT_DIR/train_metadata.csv``.
    """
    suffix = '-ortho.tif'
    train_rows = []

    for fname in tqdm(os.listdir(os.path.join(IN_DIR, 'images')), desc="🔄 Chipping dataset"):
        if not fname.endswith(suffix):
            continue
        base = fname[:-len(suffix)]
        rgb_path = os.path.join(IN_DIR, 'images', f'{base}-ortho.tif')
        elev_path = os.path.join(IN_DIR, 'elevations', f'{base}-elev.tif')
        label_path = os.path.join(IN_DIR, 'labels', f'{base}-label.png')

        if all(os.path.exists(p) for p in (rgb_path, elev_path, label_path)):
            chip_image(rgb_path, elev_path, label_path, base, train_rows)
        else:
            # Make dropped samples visible instead of silently shrinking the dataset;
            # tqdm.write keeps the progress bar intact.
            tqdm.write(f"❌ Skipping {base} due to missing files")

    csv_path = os.path.join(OUT_DIR, 'train_metadata.csv')
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        header = ['tile_id', 'source_file', 'x', 'y'] + [f'{i}: {name}' for i, name in enumerate(CLASS_NAMES)] + ['entropy']
        writer.writerow(header)
        writer.writerows(train_rows)

    print(f"✅ Train metadata saved to {csv_path}")
    print(f"📊 Total train tiles: {len(train_rows)}")

def copy_raw_dataset():
    """Copy full-size source images with their elevation and label files into RAW_DIR.

    Only ``*-ortho.tif`` files in ``IN_DIR/images`` define a sample. A sample is
    copied only when all three of its files exist, so RAW_DIR never receives a
    partial set. Incomplete samples, or samples whose copy fails, are reported
    and skipped.
    """
    suffix = '-ortho.tif'
    for folder in ('images', 'elevations', 'labels'):
        os.makedirs(os.path.join(RAW_DIR, folder), exist_ok=True)

    for src_file in os.listdir(os.path.join(IN_DIR, 'images')):
        # Any other file in images/ would otherwise be treated as a base name.
        if not src_file.endswith(suffix):
            continue
        base = src_file[:-len(suffix)]

        copies = [
            (os.path.join(IN_DIR, 'images', f'{base}-ortho.tif'), os.path.join(RAW_DIR, 'images', f'{base}-ortho.tif')),
            (os.path.join(IN_DIR, 'elevations', f'{base}-elev.tif'), os.path.join(RAW_DIR, 'elevations', f'{base}-elev.tif')),
            (os.path.join(IN_DIR, 'labels', f'{base}-label.png'), os.path.join(RAW_DIR, 'labels', f'{base}-label.png')),
        ]
        # Check up front so a missing elevation/label doesn't leave a partial copy.
        if not all(os.path.exists(src) for src, _ in copies):
            print(f"❌ Skipping {base} due to missing files")
            continue

        try:
            for src, dst in copies:
                shutil.copy(src, dst)
        except OSError as err:
            # Catch only filesystem errors so Ctrl-C / SystemExit still propagate.
            print(f"❌ Failed to copy {base}: {err}")


# === Run ===
# Guarded so importing this module does not start the (slow, disk-writing) pipeline.
# In a Jupyter kernel __name__ is "__main__", so the cell still runs as before.
if __name__ == "__main__":
    # Pre-create the output folders of the train split.
    for folder in ['images', 'elevations', 'labels']:
        os.makedirs(os.path.join(OUT_DIR, 'train', folder), exist_ok=True)

    chip_all()
    chip_raw_large_tiles()
    print("\n✅ Done! Train = chips, Raw = full image chips.")
