In [34]:
!pip install rasterio




In [33]:
# 1_data_preparation.ipynb

import rasterio
import numpy as np
import os
import random
import time
from rasterio.warp import reproject, Resampling

# --- Parameters ---
# Side length, in pixels, of each square sampling window.
WINDOW_SIZE = 256
# Number of samples the export attempts; failed attempts are skipped, so fewer
# files may actually be written.
NUM_SAMPLES = 100
# Directory that receives the sample_XXXX.npz files (created if missing).
OUTPUT_DIR = "training_samples"
# Seed for the stdlib RNG that draws the window offsets, so re-running the
# notebook from a fresh kernel samples the same windows.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# Input rasters. The LiDAR DEM also serves as the reference grid: windows are
# drawn in its pixel space and the NAIP imagery is resampled onto it.
LIDAR_PATH = "South_Clear_Creek/Lidar_DEM_Hillshade/South_Clear_Creek_BareEarth_DEM_1m.tif"
HILLSHADE_PATH = "South_Clear_Creek/Lidar_DEM_Hillshade/South_Clear_Creek_BareEarth_Hillshade_1m.tif"
NAIP_PATH = "South_Clear_Creek/NAIP/South_Clear_Creek_2023_NAIP_1m.tif"
# Road mask raster; pixels with a value > 0 are treated as road.
MASK_PATH = "South_Clear_Creek/Roads_Boundary/South_Clear_Creek_Roads_Mask.tif"

# --- Helper Functions ---
def normalize(image):
    """Min-max scale an array to [0, 1], ignoring NaNs when finding the range.

    A 3-D input is treated as a stack of bands along axis 0 and each band is
    scaled independently. NaN entries stay NaN in the scaled output.

    Returns an all-zero array shaped like the input when every value is NaN or
    when all non-NaN values are equal (there is no range to scale by).
    """
    if image.ndim == 3:
        scaled_bands = [normalize(band) for band in image]
        return np.array(scaled_bands)

    valid = ~np.isnan(image)
    if not valid.any():
        return np.zeros_like(image)

    usable = image[valid]
    lo = usable.min()
    hi = usable.max()
    if hi > lo:
        return (image - lo) / (hi - lo)
    return np.zeros_like(image)

def get_random_window(path, size):
    """Pick a random ``size`` x ``size`` pixel window on the raster at ``path``.

    Offsets are drawn with ``random.randint`` (inclusive bounds) so the window
    fits inside the raster. When the raster is smaller than ``size`` along an
    axis, the offset on that axis is 0 and the window extends past the edge.
    """
    with rasterio.open(path) as dataset:
        max_col_off = max(0, dataset.width - size)
        max_row_off = max(0, dataset.height - size)

    # Column offset is drawn before row offset (same order of RNG calls).
    col_off = random.randint(0, max_col_off)
    row_off = random.randint(0, max_row_off)
    return rasterio.windows.Window(col_off, row_off, size, size)

def read_band(path, window):
    """Read band 1 of the raster at ``path`` inside ``window`` as float32.

    The read is boundless with a fill value of 0.0, and NaNs in the result are
    replaced by 0.0 before the cast.

    NOTE(review): only NaN is remapped here; a numeric nodata sentinel in the
    raster (if any) is passed through unchanged - confirm against the data.
    """
    with rasterio.open(path) as dataset:
        raw = dataset.read(1, window=window, boundless=True, fill_value=0.0)
    cleaned = np.nan_to_num(raw, nan=0.0)
    return cleaned.astype(np.float32)

def read_mask(path, window):
    """Read band 1 of the raster at ``path`` inside ``window`` as a binary mask.

    The read is boundless with a fill value of 0. Returns a uint8 array holding
    1 where the raster value is greater than 0 and 0 elsewhere.
    """
    with rasterio.open(path) as dataset:
        raster = dataset.read(1, window=window, boundless=True, fill_value=0)
    is_positive = np.greater(raster, 0)
    return is_positive.astype(np.uint8)

def read_and_resample_naip(naip_path, ref_path, window):
    """Read NAIP imagery and resample it onto a window of the reference raster.

    Args:
        naip_path: Path to the NAIP raster; all of its bands are read.
        ref_path: Path to the raster whose grid and CRS define the output.
        window: Pixel window on the reference raster to resample onto.

    Returns:
        float32 array of shape (bands, window.height, window.width), bilinearly
        resampled onto the reference grid and scaled to [0, 1].

    Integer imagery is scaled by the maximum of its own dtype, so 8-bit data
    spans [0, 1] instead of being squeezed into [0, 1/257] by a fixed 65535
    divisor; 16-bit unsigned data is scaled exactly as before. Non-integer
    imagery keeps the 0..65535 clip and divisor.

    NOTE(review): any other code that normalizes NAIP bands (e.g. at inference
    time) should apply the same scaling - confirm.
    """
    # Local import: the file-level import only brings in reproject/Resampling.
    from rasterio.warp import transform_bounds

    with rasterio.open(ref_path) as ref_src:
        dst_transform = ref_src.window_transform(window)
        dst_crs = ref_src.crs
        dst_width, dst_height = window.width, window.height
        dst_bounds = rasterio.windows.bounds(window, ref_src.transform)

    with rasterio.open(naip_path) as naip_src:
        # dst_bounds are in the reference CRS; express them in the NAIP CRS
        # before turning them into a NAIP pixel window.
        src_bounds = dst_bounds
        if dst_crs is not None and naip_src.crs is not None and naip_src.crs != dst_crs:
            src_bounds = transform_bounds(dst_crs, naip_src.crs, *dst_bounds)

        src_window = rasterio.windows.from_bounds(*src_bounds, transform=naip_src.transform)
        raw = naip_src.read(window=src_window, boundless=True)

        if np.issubdtype(raw.dtype, np.integer):
            full_scale = float(np.iinfo(raw.dtype).max)
        else:
            full_scale = 65535.0
        naip_data = np.clip(raw.astype(np.float32), 0.0, full_scale) / np.float32(full_scale)

        # Zero-initialised so any destination pixel the warp does not write
        # holds a deterministic value rather than uninitialised memory.
        resampled = np.zeros((naip_data.shape[0], dst_height, dst_width), dtype=np.float32)
        reproject(
            source=naip_data,
            destination=resampled,
            src_transform=naip_src.window_transform(src_window),
            src_crs=naip_src.crs,
            dst_transform=dst_transform,
            dst_crs=dst_crs,
            resampling=Resampling.bilinear
        )
    return resampled

def generate_sample(lidar_path, hillshade_path, naip_path, mask_path, window_size, max_attempts=10):
    """Draw random windows until one contains road pixels and build a sample.

    Args:
        lidar_path: LiDAR DEM raster; also the grid the random window is drawn on.
        hillshade_path: Hillshade raster, read with the same pixel window.
        naip_path: NAIP raster, resampled onto the LiDAR window.
        mask_path: Road mask raster, read with the same pixel window.
        window_size: Side length in pixels of the square window.
        max_attempts: Number of random windows to try before giving up.

    Returns:
        (x, y) where x is float32 of shape (5, H, W) stacked as
        [lidar, hillshade, naip[0], naip[1], naip[2]] and y is uint8 of shape
        (1, H, W) holding the road mask.

    Raises:
        ValueError: If no window with at least one mask pixel is found.

    NOTE(review): the hillshade and mask are read with the LiDAR pixel window,
    which presumes all three rasters share the same grid - confirm.
    """
    for _ in range(max_attempts):
        window = get_random_window(lidar_path, window_size)

        # Check the mask first: a window without road pixels is rejected
        # before paying for the band reads and the NAIP reprojection.
        mask = read_mask(mask_path, window)
        if not mask.any():
            continue

        lidar = normalize(read_band(lidar_path, window))
        hillshade = normalize(read_band(hillshade_path, window))
        naip = read_and_resample_naip(naip_path, lidar_path, window)

        x = np.stack([lidar, hillshade, naip[0], naip[1], naip[2]], axis=0)  # (5, H, W)
        y = np.expand_dims(mask, axis=0)  # (1, H, W)
        return x.astype(np.float32), y.astype(np.uint8)

    raise ValueError(f"❌ No valid mask found after {max_attempts} attempts")

def export_samples(lidar_path, hillshade_path, naip_path, mask_path, window_size, output_dir, num_samples=100):
    """Generate up to ``num_samples`` samples and save each as a compressed .npz.

    Args:
        lidar_path, hillshade_path, naip_path, mask_path: Input raster paths,
            forwarded to ``generate_sample``.
        window_size: Side length in pixels of each square sample window.
        output_dir: Directory for the output files; created if missing.
        num_samples: Number of samples to attempt.

    Each attempt ``i`` writes ``sample_{i:04d}.npz`` with arrays ``x`` and
    ``y``. A failed attempt is reported and skipped, so the file numbering can
    have gaps; a final summary line reports how many samples were saved.
    """
    os.makedirs(output_dir, exist_ok=True)
    start_time = time.perf_counter()  # monotonic clock for the elapsed time
    saved_count = 0

    for i in range(num_samples):
        file_name = f"sample_{i:04d}.npz"
        print(f"\n📦 Generating sample {i + 1}/{num_samples}...")
        try:
            x, y = generate_sample(lidar_path, hillshade_path, naip_path, mask_path, window_size)
            np.savez_compressed(os.path.join(output_dir, file_name), x=x, y=y)
        except Exception as e:
            # Deliberately best-effort: one bad sample must not abort the run.
            print(f"⚠️ Skipped sample {i} due to error: {e}")
        else:
            saved_count += 1
            print(f"✅ Saved {file_name}")

    skipped_count = num_samples - saved_count
    print(f"\n📊 Saved {saved_count}/{num_samples} samples ({skipped_count} skipped)")
    print(f"\n⏱️ Completed in {time.perf_counter() - start_time:.2f} seconds")

# --- Run the Export ---
# Keyword arguments make the mapping from config constants to parameters explicit.
export_samples(
    lidar_path=LIDAR_PATH,
    hillshade_path=HILLSHADE_PATH,
    naip_path=NAIP_PATH,
    mask_path=MASK_PATH,
    window_size=WINDOW_SIZE,
    output_dir=OUTPUT_DIR,
    num_samples=NUM_SAMPLES,
)



📦 Generating sample 1/100...
✅ Saved sample_0000.npz

📦 Generating sample 2/100...
✅ Saved sample_0001.npz

📦 Generating sample 3/100...
⚠️ Skipped sample 2 due to error: ❌ No valid mask found after 10 attempts

📦 Generating sample 4/100...
✅ Saved sample_0003.npz

📦 Generating sample 5/100...
⚠️ Skipped sample 4 due to error: ❌ No valid mask found after 10 attempts

📦 Generating sample 6/100...
✅ Saved sample_0005.npz

📦 Generating sample 7/100...
✅ Saved sample_0006.npz

📦 Generating sample 8/100...
✅ Saved sample_0007.npz

📦 Generating sample 9/100...
✅ Saved sample_0008.npz

📦 Generating sample 10/100...
✅ Saved sample_0009.npz

📦 Generating sample 11/100...
✅ Saved sample_0010.npz

📦 Generating sample 12/100...
✅ Saved sample_0011.npz

📦 Generating sample 13/100...
⚠️ Skipped sample 12 due to error: ❌ No valid mask found after 10 attempts

📦 Generating sample 14/100...
‚ö†Ô∏è Skipped sample 13 due to 