In [None]:
from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import rasterio
from rasterio.windows import Window
from rasterio.enums import Resampling
import torch


@dataclass
class PreprocessConfig:
    """Settings for turning multispectral GeoTIFFs (+ masks) into training samples.

    Images are searched recursively under ``data_dir / images_dirname`` and masks
    under ``data_dir / masks_dirname``; outputs go to ``data_dir / output_dirname``.

    Settings are validated in ``__post_init__`` so mistakes surface immediately
    rather than after the (expensive) normalisation-statistics pass.

    Raises:
        ValueError: if ``save_format`` is not "npz"/"pt", ``band_indices`` is not
            five 1-based band numbers, or ``mask_band`` is < 1.
    """

    data_dir: Path
    images_dirname: str = "images"
    masks_dirname: str = "masks"
    # Patterns matched recursively (rglob) under the images / masks directories.
    image_glob: str = "*.tif"
    mask_glob: str = "*.tif"
    # Append an NDSI channel computed from the Green and SWIR bands.
    add_ndsi: bool = True
    save_format: str = "npz"  # one of {"npz", "pt"}
    output_dirname: str = "preprocessed"
    # 1-based rasterio band numbers, in the order [Blue, Green, Red, SWIR, Thermal].
    # The NDSI step relies on this order (Green at position 1, SWIR at position 3).
    band_indices: Tuple[int, int, int, int, int] = (1, 2, 3, 4, 5)
    # If masks are separate TIFFs, set to True; if they are embedded, set False and provide mask_band
    separate_masks: bool = True
    # 1-based band number of the embedded mask (only read when separate_masks is False).
    mask_band: int = 1

    def __post_init__(self) -> None:
        # Accept str / os.PathLike for convenience; downstream code joins with "/".
        self.data_dir = Path(self.data_dir)
        if self.save_format not in ("npz", "pt"):
            raise ValueError(
                f"Unsupported save_format: {self.save_format!r} (expected 'npz' or 'pt')"
            )
        bands = tuple(self.band_indices)
        if len(bands) != 5 or any(b < 1 for b in bands):
            raise ValueError(
                "band_indices must be five 1-based band numbers ordered "
                f"[Blue, Green, Red, SWIR, Thermal], got {self.band_indices!r}"
            )
        self.band_indices = bands
        if self.mask_band < 1:
            raise ValueError(f"mask_band must be a 1-based band number, got {self.mask_band!r}")


# The data root defaults to ./data and can be overridden with the DATA_DIR
# environment variable (e.g. on a cluster) without editing the notebook.
cfg = PreprocessConfig(data_dir=Path(os.environ.get("DATA_DIR", "data")))
output_dir = cfg.data_dir / cfg.output_dirname
output_dir.mkdir(parents=True, exist_ok=True)

print("Config:\n", cfg)


In [None]:
def list_pairs(cfg: PreprocessConfig) -> List[Tuple[Path, Optional[Path]]]:
    """Pair every image under ``cfg.images_dirname`` with its mask file.

    Images are found recursively with ``cfg.image_glob`` and returned sorted.
    When ``cfg.separate_masks`` is True, each image is matched to a mask under
    ``cfg.masks_dirname`` using, in order of preference:

    1. the same relative path (mirrored sub-directory layout);
    2. the same file name directly in the masks root;
    3. a mask found recursively with ``cfg.mask_glob`` whose stem equals the
       image stem, or starts with it followed by a non-alphanumeric separator
       (e.g. ``scene_01_mask.tif`` for ``scene_01.tif``). Consequently
       ``scene_1`` never picks up ``scene_10.tif``. Ties are broken by sorted
       path order, so the result is deterministic.

    Returns:
        ``(image_path, mask_path)`` tuples. ``mask_path`` is None both when masks
        are embedded (``separate_masks`` False) and when no mask was found;
        callers must use ``cfg.separate_masks`` to tell these apart.
    """
    images_root = cfg.data_dir / cfg.images_dirname
    masks_root = cfg.data_dir / cfg.masks_dirname
    image_paths = sorted(images_root.rglob(cfg.image_glob))
    if not cfg.separate_masks:
        # Masks are embedded in the images; nothing to look up.
        return [(img_path, None) for img_path in image_paths]

    # Glob the mask tree once (sorted => deterministic) instead of re-running
    # rglob for every image that lacks an exact match. Matching by string
    # comparison also avoids glob metacharacters in file stems.
    mask_paths = sorted(masks_root.rglob(cfg.mask_glob))

    def _find_by_stem(stem: str) -> Optional[Path]:
        # Prefer an identical stem; otherwise "<stem><sep>..." with a non-alnum <sep>.
        for candidate in mask_paths:
            if candidate.stem == stem:
                return candidate
        for candidate in mask_paths:
            tail = candidate.stem[len(stem):]
            if candidate.stem.startswith(stem) and tail and not tail[0].isalnum():
                return candidate
        return None

    pairs: List[Tuple[Path, Optional[Path]]] = []
    for img_path in image_paths:
        mirrored = masks_root / img_path.relative_to(images_root)
        same_name = masks_root / img_path.name
        if mirrored.is_file():
            mask_path: Optional[Path] = mirrored
        elif same_name.is_file():
            mask_path = same_name
        else:
            mask_path = _find_by_stem(img_path.stem)
        pairs.append((img_path, mask_path))
    return pairs


def read_multispectral(image_path: Path, band_indices: Tuple[int, ...]) -> np.ndarray:
    """Load selected bands of a raster as one float32 stack.

    Args:
        image_path: Raster file opened with ``rasterio.open``.
        band_indices: rasterio band numbers to read; the output channel order
            follows this tuple.

    Returns:
        ``float32`` array of shape ``(len(band_indices), H, W)``.
    """
    with rasterio.open(image_path) as dataset:
        layers = [dataset.read(band_no).astype(np.float32) for band_no in band_indices]
        stacked = np.stack(layers, axis=0)
    return stacked


def read_mask(mask_path: Path | None, image_path: Path, band: int) -> np.ndarray:
    """Read a segmentation mask and binarise it to a ``uint8`` array of 0/1.

    If ``mask_path`` is None the mask is read from band ``band`` of the image
    itself (embedded mask); otherwise band 1 of ``mask_path`` is read and
    ``band`` is ignored. Values > 0 become 1; everything else (including NaN)
    becomes 0.
    """
    if mask_path is not None:
        source, band_no = mask_path, 1
    else:
        source, band_no = image_path, band
    with rasterio.open(source) as dataset:
        raw = dataset.read(band_no)
    return (raw > 0).astype(np.uint8)


def compute_ndsi(green: np.ndarray, swir: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalised Difference Snow Index, ``(green - swir) / (green + swir + eps)``.

    ``eps`` keeps the division finite where ``green + swir`` is exactly 0.
    """
    numerator = green - swir
    denominator = green + swir + eps
    return numerator / denominator


def compute_mean_std(paths: List[Path], band_indices: Tuple[int, ...]) -> Tuple[np.ndarray, np.ndarray]:
    """Per-channel mean and population standard deviation over all pixels of ``paths``.

    Every pixel of every image has equal weight. Statistics are accumulated in
    float64 in a single pass: each image's count, mean and sum of squared
    deviations are merged into the running totals with the parallel-variance
    update of Chan et al. Each file is therefore read once (instead of twice),
    and the result avoids the cancellation of a naive sum-of-squares. Each image
    is temporarily held as float64.

    Args:
        paths: Image files readable by ``read_multispectral``.
        band_indices: Bands to read (passed to ``read_multispectral``).

    Returns:
        ``(mean, std)``: float32 arrays of shape ``(C,)``. ``std`` is floored at
        1e-6 (variance floor 1e-12) so it is always safe to divide by.

    Raises:
        ValueError: if there are no pixels at all (e.g. ``paths`` is empty) or if
            images disagree on the number of channels.
    """
    count = 0
    mean64: Optional[np.ndarray] = None  # running per-channel mean (float64)
    m2: Optional[np.ndarray] = None  # running per-channel sum of squared deviations
    for p in paths:
        arr = read_multispectral(p, band_indices)  # (C, H, W)
        flat = arr.reshape(arr.shape[0], -1).astype(np.float64)
        n_img = flat.shape[1]
        if n_img == 0:
            continue  # zero-sized raster contributes nothing
        mean_img = flat.mean(axis=1)
        m2_img = ((flat - mean_img[:, None]) ** 2).sum(axis=1)
        if mean64 is None:
            count, mean64, m2 = n_img, mean_img, m2_img
            continue
        if mean_img.shape != mean64.shape:
            raise ValueError(
                f"{p} has {mean_img.shape[0]} channels, expected {mean64.shape[0]}"
            )
        # Chan et al. merge of (count, mean64, m2) with this image's moments.
        total = count + n_img
        delta = mean_img - mean64
        mean64 = mean64 + delta * (n_img / total)
        m2 = m2 + m2_img + delta ** 2 * (count * n_img / total)
        count = total
    if mean64 is None:
        raise ValueError("compute_mean_std: no pixels to average (empty or zero-sized inputs)")
    var = m2 / count
    mean = mean64.astype(np.float32)
    std = np.sqrt(np.maximum(var, 1e-12)).astype(np.float32)
    return mean, std


def normalize(arr: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Z-score each channel of a ``(C, H, W)`` array: ``(arr[c] - mean[c]) / std[c]``.

    ``mean`` and ``std`` are indexed per channel and broadcast over H and W.
    """
    center = mean[:, np.newaxis, np.newaxis]
    scale = std[:, np.newaxis, np.newaxis]
    shifted = arr - center
    return shifted / scale



In [None]:
# Discover files
pairs = list_pairs(cfg)
print(f"Found {len(pairs)} image/mask pairs")
if not pairs:
    # Fail with a clear message instead of a cryptic error inside compute_mean_std.
    raise FileNotFoundError(
        f"No images matching {cfg.image_glob!r} under {cfg.data_dir / cfg.images_dirname}"
    )
# NOTE: statistics use ALL discovered images. If validation/test images live
# under the same directory, restrict image_paths to the training files here,
# otherwise evaluation data leaks into the normalisation.
image_paths = [p for p, _ in pairs]

# Compute mean/std on the 5 original bands
mean_5, std_5 = compute_mean_std(image_paths, cfg.band_indices)
print("Per-channel mean (5):", mean_5)
print("Per-channel std (5):", std_5)

# Save stats for reuse; band_indices is stored too so the stats can be matched
# to the band order they were computed for.
stats_path = output_dir / "stats_npz.npz"
np.savez_compressed(
    stats_path, mean_5=mean_5, std_5=std_5, band_indices=np.asarray(cfg.band_indices)
)
print("Saved stats to", stats_path)


In [None]:
def process_one(
    image_path: Path,
    mask_path: Optional[Path],
    mean_5: np.ndarray,
    std_5: np.ndarray,
    add_ndsi: bool,
    ndsi_mean: float = 0.0,
    ndsi_std: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build one normalised ``(image, mask)`` sample.

    Reads ``cfg.band_indices`` (order [Blue, Green, Red, SWIR, Thermal]) from
    ``image_path``. Optionally appends an NDSI channel. Z-scores every channel
    and reads the binary mask.

    The NDSI channel is normalised with the fixed ``ndsi_mean`` / ``ndsi_std``.
    The defaults 0 / 1 leave NDSI unchanged; it already lies in (-1, 1) when both
    bands are non-negative. The former per-image z-scoring is gone: it erased each
    image's absolute snow signal (an all-snow scene and a snow-free scene both
    became zero-mean) and gave the channel a different scale in every sample.
    Pass dataset-level NDSI statistics to get a z-scored NDSI channel.

    Args:
        image_path: Multispectral raster.
        mask_path: Separate mask file, or None for an embedded mask.
        mean_5, std_5: Per-band statistics for the 5 input bands (float32, (5,)).
        add_ndsi: Whether to append the NDSI channel.
        ndsi_mean, ndsi_std: Normalisation constants for the NDSI channel.

    Returns:
        image: float32 ``(C, H, W)``; C is ``len(cfg.band_indices)``, plus 1 with NDSI.
        mask: uint8 ``(H, W)`` with values in {0, 1}.

    Raises:
        FileNotFoundError: if ``cfg.separate_masks`` is True but ``mask_path`` is
            None. Otherwise ``read_mask`` would silently use image band
            ``cfg.mask_band`` as the labels.
        ValueError: if ``ndsi_std`` <= 0, or the mask's shape differs from the image's.
    """
    if cfg.separate_masks and mask_path is None:
        raise FileNotFoundError(f"No mask file found for {image_path}")

    arr = read_multispectral(image_path, cfg.band_indices)  # (5, H, W)
    mean, std = mean_5, std_5
    if add_ndsi:
        if ndsi_std <= 0:
            raise ValueError(f"ndsi_std must be positive, got {ndsi_std!r}")
        green_idx, swir_idx = 1, 3  # positions in [Blue, Green, Red, SWIR, Thermal]
        ndsi = compute_ndsi(arr[green_idx], arr[swir_idx]).astype(np.float32)
        arr = np.concatenate([arr, ndsi[None, ...]], axis=0)  # (6, H, W)
        mean = np.concatenate([mean_5, np.array([ndsi_mean], dtype=np.float32)])
        std = np.concatenate([std_5, np.array([ndsi_std], dtype=np.float32)])
    arr_norm = normalize(arr, mean, std)

    mask = read_mask(mask_path, image_path, cfg.mask_band)
    if mask.shape != arr_norm.shape[1:]:
        raise ValueError(
            f"Mask shape {mask.shape} does not match image shape {arr_norm.shape[1:]} "
            f"for {image_path}"
        )
    return arr_norm.astype(np.float32, copy=False), mask.astype(np.uint8, copy=False)


def save_sample(image: np.ndarray, mask: np.ndarray, stem: str, fmt: str, out_dir: Path) -> Path:
    """Write one sample to ``out_dir/<stem>.<fmt>`` and return that path.

    Formats:
      * ``"npz"``: compressed NumPy archive with arrays ``image`` and ``mask``;
      * ``"pt"``: ``torch.save`` of ``{"image": Tensor, "mask": Tensor}``.

    ``out_dir`` (and its parents) is created before the format is checked.

    Raises:
        ValueError: for any other ``fmt``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    if fmt not in ("npz", "pt"):
        raise ValueError("Unsupported save_format: " + fmt)
    target = out_dir / f"{stem}.{fmt}"
    if fmt == "pt":
        payload = {"image": torch.from_numpy(image), "mask": torch.from_numpy(mask)}
        torch.save(payload, target)
    else:
        np.savez_compressed(target, image=image, mask=mask)
    return target


In [None]:
images_dir = cfg.data_dir / cfg.images_dirname
saved = []
skipped = []
used_stems = set()
for img_path, mask_path in pairs:
    if cfg.separate_masks and mask_path is None:
        # No mask file was found: skip rather than let read_mask fall back to
        # reading an image band as the labels.
        skipped.append(img_path)
        continue
    # Encode sub-directories in the output name so equally named images in
    # different folders do not overwrite each other. Flat layouts keep the plain stem.
    stem = "__".join(img_path.relative_to(images_dir).with_suffix("").parts)
    if stem in used_stems:
        raise ValueError(f"Output name collision for {img_path} (stem {stem!r})")
    used_stems.add(stem)
    image, mask = process_one(img_path, mask_path, mean_5, std_5, cfg.add_ndsi)
    out_path = save_sample(image, mask, stem, cfg.save_format, output_dir)
    saved.append(out_path)

print(f"Saved {len(saved)} samples to {output_dir}")
if skipped:
    print(f"Skipped {len(skipped)} images without a mask, e.g. {skipped[0]}")
