In [1]:
INPUT_CSV = r"/home/abhayadana/Downloads/SummerLowStatureSampling_pt_2025.csv"
OUTPUT_DIR = r"/home/abhayadana/Documents/GitHub/SLSS"

In [2]:
"""
Kernel categorical spatial interpolation per strManagementUnit
with:
- spatial block CV bandwidth selection
- probability rasters per class
- categorical winner raster (with uncertainty threshold + MMU)
- masking to the ENVELOPE of sample points (convex hull, optional buffer)

INPUT
- CSV fields:
    strManagementUnit
    strAveHeight_cm_PctCover
    UTM_X
    UTM_Y
- CRS assumed: UTM Zone 10, NAD83, meters (EPSG:26910)

OUTPUT (per unit folder)
- prob_<Class>.tif (float32, nodata=NaN)
- p_max_confidence.tif (float32, nodata=NaN)
- categorical_winner.tif (uint8, nodata=0; 0=Uncertain)
- metadata.json

DEPENDENCIES
- numpy, pandas
- rasterio
- scipy recommended (KD-tree + MMU + convex hull fallback)
- shapely optional (preferred for convex hull + buffering)
- scikit-learn optional fallback for neighbor search if scipy is missing

NOTES
- "Envelope of sample points" is implemented as a convex hull polygon around the
  points (optionally buffered), then rasters are masked outside the hull.
"""

from __future__ import annotations

import json
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import rasterio
from rasterio.features import geometry_mask
from rasterio.transform import from_origin

# ---------------------------------------------------------------------
# User parameters
# ---------------------------------------------------------------------

INPUT_CSV = r"/home/abhayadana/Downloads/SummerLowStatureSampling_pt_2025.csv"
OUTPUT_DIR = r"/home/abhayadana/Documents/GitHub/SLSS"

GRID_CELL_SIZE_M = 1.0

CANDIDATE_BANDWIDTHS_M = [3, 5, 8, 12, 20]
SEARCH_RADIUS_FACTOR = 3.0

BLOCK_CV_BLOCK_SIZE_M = 30.0
BLOCK_CV_FOLDS = 5

UNCERTAIN_TOLERANCE = 0.50
MMU_CELLS = 16

# Grid extent: bbox of points (required for a raster transform), then mask outside hull.
EXTENT_BUFFER_M = 0.0

# Kernel type: "gaussian" or "exponential"
KERNEL = "gaussian"

# If a grid cell has no neighbors within search radius:
# True -> prob rasters are NaN; categorical becomes 0 (Uncertain) there
NO_NEIGHBOR_AS_NAN = True

# Mask to envelope (convex hull) of sample points
USE_POINT_ENVELOPE_MASK = True
ENVELOPE_TYPE = "convex_hull"  # currently only convex_hull
ENVELOPE_BUFFER_M = 0.0        # optional outward buffer in meters (e.g., 5.0)

# CRS for output rasters
CRS_EPSG = 26910

EPS = 1e-12  # numerical safety for log-loss / normalization

# ---------------------------------------------------------------------
# Constants: fixed class order and codes
# ---------------------------------------------------------------------

CLASSES = [
    "Short Open",
    "Tall Open",
    "Mid Mod",
    "Short Dense",
    "Tall Dense",
]
CLASS_TO_CODE = {cls: i + 1 for i, cls in enumerate(CLASSES)}
CODE_TO_CLASS = {0: "Uncertain", **{i + 1: cls for i, cls in enumerate(CLASSES)}}


# ---------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------


def safe_name(name: str) -> str:
    """Make a safe folder/file stem from a string."""
    name = str(name).strip()
    name = re.sub(r"[^\w\-]+", "_", name)
    name = re.sub(r"_+", "_", name).strip("_")
    return name or "unit"


def gaussian_kernel(d: np.ndarray, h: float) -> np.ndarray:
    """Gaussian kernel weights for distances d and bandwidth h."""
    z = d / max(h, EPS)
    return np.exp(-0.5 * z * z)


def exponential_kernel(d: np.ndarray, h: float) -> np.ndarray:
    """Exponential kernel weights for distances d and bandwidth h."""
    z = d / max(h, EPS)
    return np.exp(-z)


def get_kernel_fn(kernel: str):
    """Return kernel function by name."""
    k = kernel.lower().strip()
    if k == "gaussian":
        return gaussian_kernel
    if k in ("exponential", "exp"):
        return exponential_kernel
    raise ValueError(f"Unknown kernel: {kernel!r}. Use 'gaussian' or 'exponential'.")


def build_grid(
    xmin: float,
    ymin: float,
    xmax: float,
    ymax: float,
    cell_size: float,
) -> Tuple[np.ndarray, np.ndarray, rasterio.Affine]:
    """
    Create a raster grid covering the bounding box.

    Returns:
    - x_centers: 1D array of x cell centers (width)
    - y_centers: 1D array of y cell centers (height)
    - transform: affine transform for rasterio
    """
    width = int(math.ceil((xmax - xmin) / cell_size))
    height = int(math.ceil((ymax - ymin) / cell_size))

    # Upper-left corner at (xmin, ymax)
    transform = from_origin(xmin, ymax, cell_size, cell_size)

    x_centers = xmin + (np.arange(width) + 0.5) * cell_size
    y_centers = ymax - (np.arange(height) + 0.5) * cell_size
    return x_centers, y_centers, transform


@dataclass
class UnitData:
    """Container for one management unit's point data."""
    unit: str
    x: np.ndarray
    y: np.ndarray
    y_code: np.ndarray
    present_codes: List[int]


def read_points(csv_path: str) -> pd.DataFrame:
    """Read CSV and validate required columns."""
    df = pd.read_csv(csv_path, dtype=str)

    required = {"strManagementUnit", "strAveHeight_cm_PctCover", "UTM_X", "UTM_Y"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {sorted(missing)}")

    df["UTM_X"] = pd.to_numeric(df["UTM_X"], errors="coerce")
    df["UTM_Y"] = pd.to_numeric(df["UTM_Y"], errors="coerce")

    df = df.dropna(
        subset=["UTM_X", "UTM_Y", "strManagementUnit", "strAveHeight_cm_PctCover"]
    ).copy()

    df["strAveHeight_cm_PctCover"] = df["strAveHeight_cm_PctCover"].astype(str).str.strip()
    df = df[df["strAveHeight_cm_PctCover"].isin(CLASSES)].copy()

    if df.empty:
        raise ValueError("No valid points after filtering to known classes and valid coordinates.")

    return df


def split_by_unit(df: pd.DataFrame) -> List[UnitData]:
    """Split dataframe into per-unit UnitData objects."""
    units: List[UnitData] = []
    for unit, g in df.groupby("strManagementUnit", sort=True):
        x = g["UTM_X"].to_numpy(dtype=float)
        y = g["UTM_Y"].to_numpy(dtype=float)
        y_code = g["strAveHeight_cm_PctCover"].map(CLASS_TO_CODE).to_numpy(dtype=int)
        present = sorted(np.unique(y_code).tolist())
        units.append(UnitData(unit=str(unit), x=x, y=y, y_code=y_code, present_codes=present))
    return units


# ---------------------------------------------------------------------
# Envelope (convex hull) masking
# ---------------------------------------------------------------------


def build_point_envelope_geometry(
    x: np.ndarray,
    y: np.ndarray,
    envelope_type: str = "convex_hull",
    buffer_m: float = 0.0,
) -> Sequence[dict]:
    """
    Build an envelope geometry (convex hull) around points, optionally buffered.

    Returns GeoJSON-like geometries suitable for rasterio.features.geometry_mask.
    """
    envelope_type = envelope_type.lower().strip()
    if envelope_type != "convex_hull":
        raise ValueError("Only envelope_type='convex_hull' is supported.")

    pts = np.column_stack([x, y]).astype(float)

    # Preferred: shapely
    try:
        from shapely.geometry import MultiPoint  # type: ignore
        from shapely.geometry import mapping  # type: ignore

        hull = MultiPoint(pts).convex_hull
        if buffer_m and float(buffer_m) != 0.0:
            hull = hull.buffer(float(buffer_m))
        return [mapping(hull)]
    except Exception:
        # Fallback: scipy ConvexHull -> polygon
        try:
            from scipy.spatial import ConvexHull  # type: ignore
        except Exception as exc:
            raise ImportError(
                "To use envelope masking, install shapely (preferred) or scipy."
            ) from exc

        hull = ConvexHull(pts)
        ring = pts[hull.vertices].tolist()
        ring.append(ring[0])

        geom = {"type": "Polygon", "coordinates": [ring]}
        return [geom]


def apply_envelope_mask(
    prob_stack: np.ndarray,
    p_max: np.ndarray,
    winner: np.ndarray,
    transform: rasterio.Affine,
    envelope_geoms: Sequence[dict],
    nodata_float: float = np.nan,
    nodata_cat: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Mask rasters outside envelope geometry.

    - prob_stack: (H, W, 5)
    - p_max: (H, W)
    - winner: (H, W)

    Outside envelope:
    - prob_stack -> NaN
    - p_max -> NaN
    - winner -> 0
    """
    h, w = winner.shape

    # geometry_mask returns True where pixels should be masked.
    # invert=True returns True for pixels INSIDE the geometry.
    inside = geometry_mask(
        geometries=envelope_geoms,
        out_shape=(h, w),
        transform=transform,
        invert=True,
        all_touched=False,
    )
    outside = ~inside

    prob_out = prob_stack.copy()
    pmax_out = p_max.copy()
    win_out = winner.copy()

    prob_out[outside, :] = nodata_float
    pmax_out[outside] = nodata_float
    win_out[outside] = nodata_cat

    return prob_out, pmax_out, win_out


# ---------------------------------------------------------------------
# Neighbor search backend
# ---------------------------------------------------------------------


class NeighborIndex:
    """
    KD-tree neighbor index with radius queries.

    Prefers scipy.spatial.cKDTree; falls back to sklearn NearestNeighbors.
    """

    def __init__(self, x: np.ndarray, y: np.ndarray):
        pts = np.column_stack([x, y]).astype(float)
        self.pts = pts
        self._backend = None
        self._tree = None

        try:
            from scipy.spatial import cKDTree  # type: ignore

            self._backend = "scipy"
            self._tree = cKDTree(pts)
        except Exception:
            try:
                from sklearn.neighbors import NearestNeighbors  # type: ignore

                self._backend = "sklearn"
                self._tree = NearestNeighbors(algorithm="ball_tree")
                self._tree.fit(pts)
            except Exception as exc:
                raise ImportError(
                    "Need scipy (recommended) or scikit-learn for neighbor search."
                ) from exc

    def query_radius_indices(self, targets: np.ndarray, radius: float) -> List[np.ndarray]:
        """Return neighbor indices within radius for each target point."""
        if self._backend == "scipy":
            return self._tree.query_ball_point(targets, r=radius)  # type: ignore

        # sklearn fallback: kNN then filter by radius (approx; can be slow)
        n = self.pts.shape[0]
        k = min(n, 5000)
        dists, inds = self._tree.kneighbors(targets, n_neighbors=k, return_distance=True)  # type: ignore

        out: List[np.ndarray] = []
        for di, ii in zip(dists, inds):
            out.append(ii[di <= radius])
        return out


# ---------------------------------------------------------------------
# Kernel probability prediction
# ---------------------------------------------------------------------


def predict_probabilities(
    train_x: np.ndarray,
    train_y: np.ndarray,
    train_code: np.ndarray,
    present_codes: List[int],
    targets_xy: np.ndarray,
    bandwidth_m: float,
    search_radius_m: float,
    kernel: str,
    no_neighbor_as_nan: bool = True,
) -> np.ndarray:
    """
    Predict class probabilities at targets using kernel-weighted frequencies.

    Returns:
        probs: (n_targets, 5) probabilities in CLASSES order.
               If no neighbors within radius:
                 - NaNs if no_neighbor_as_nan=True
                 - zeros otherwise
    """
    kernel_fn = get_kernel_fn(kernel)
    idx = NeighborIndex(train_x, train_y)
    neighbors = idx.query_radius_indices(targets_xy, radius=search_radius_m)

    n_targets = targets_xy.shape[0]
    probs = np.zeros((n_targets, len(CLASSES)), dtype=float)

    train_xy = np.column_stack([train_x, train_y]).astype(float)
    present_set = set(present_codes)
    absent_codes = [c for c in range(1, 6) if c not in present_set]

    for i in range(n_targets):
        nn = neighbors[i]
        if len(nn) == 0:
            probs[i, :] = np.nan if no_neighbor_as_nan else 0.0
            continue

        d = np.sqrt(np.sum((train_xy[nn] - targets_xy[i]) ** 2, axis=1))
        w = kernel_fn(d, h=bandwidth_m)
        w_sum = float(np.sum(w))

        if w_sum <= 0.0 or not np.isfinite(w_sum):
            probs[i, :] = np.nan if no_neighbor_as_nan else 0.0
            continue

        for code in present_codes:
            mask = train_code[nn] == code
            probs[i, code - 1] = float(np.sum(w[mask]) / w_sum) if np.any(mask) else 0.0

        for code in absent_codes:
            probs[i, code - 1] = 0.0

        s = np.nansum(probs[i, :])
        if s > 0:
            probs[i, :] = probs[i, :] / s
        else:
            probs[i, :] = np.nan if no_neighbor_as_nan else 0.0

    return probs


# ---------------------------------------------------------------------
# Spatial block CV bandwidth selection
# ---------------------------------------------------------------------


def assign_blocks(x: np.ndarray, y: np.ndarray, block_size_m: float) -> np.ndarray:
    """Assign each point to a spatial block id based on a block grid."""
    xmin = float(np.min(x))
    ymin = float(np.min(y))
    bx = np.floor((x - xmin) / block_size_m).astype(int)
    by = np.floor((y - ymin) / block_size_m).astype(int)
    return bx.astype(np.int64) * 1_000_000 + by.astype(np.int64)


def make_block_folds(block_ids: np.ndarray, n_folds: int, seed: int = 42) -> List[np.ndarray]:
    """Split unique block ids into folds; return boolean masks for test points."""
    rng = np.random.default_rng(seed)
    unique_blocks = np.unique(block_ids)
    rng.shuffle(unique_blocks)

    folds: List[np.ndarray] = []
    parts = np.array_split(unique_blocks, n_folds)
    for blocks_in_fold in parts:
        folds.append(np.isin(block_ids, blocks_in_fold))
    return folds


def log_loss_from_probs(true_codes: np.ndarray, probs: np.ndarray, eps: float = EPS) -> float:
    """Multiclass log loss with probs shape (n, 5), true_codes in 1..5."""
    p_true = probs[np.arange(len(true_codes)), true_codes - 1]
    p_true = np.clip(p_true, eps, 1.0)
    return float(np.mean(-np.log(p_true)))


def select_bandwidth_by_block_cv(
    unit_data: UnitData,
    candidate_bandwidths_m: Iterable[float],
    search_radius_factor: float,
    block_size_m: float,
    kernel: str,
    n_folds: int = 5,
    seed: int = 42,
) -> float:
    """Pick bandwidth minimizing spatial block CV log loss for one unit."""
    x, y, y_code = unit_data.x, unit_data.y, unit_data.y_code
    cand = sorted([float(h) for h in candidate_bandwidths_m])

    if len(x) < 10:
        return float(cand[len(cand) // 2])

    block_ids = assign_blocks(x, y, block_size_m=block_size_m)
    folds = make_block_folds(block_ids, n_folds=n_folds, seed=seed)

    best_h: Optional[float] = None
    best_score = np.inf

    for h in cand:
        r = float(search_radius_factor * h)
        fold_scores: List[float] = []

        for test_mask in folds:
            train_mask = ~test_mask
            if np.sum(train_mask) < 2 or np.sum(test_mask) < 1:
                continue

            probs = predict_probabilities(
                train_x=x[train_mask],
                train_y=y[train_mask],
                train_code=y_code[train_mask],
                present_codes=unit_data.present_codes,
                targets_xy=np.column_stack([x[test_mask], y[test_mask]]).astype(float),
                bandwidth_m=h,
                search_radius_m=r,
                kernel=kernel,
                no_neighbor_as_nan=False,  # keep finite for CV scoring
            )

            probs = np.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
            s = probs.sum(axis=1, keepdims=True)
            s[s == 0] = 1.0
            probs = probs / s
            probs = np.clip(probs, EPS, 1.0)
            probs = probs / probs.sum(axis=1, keepdims=True)

            fold_scores.append(log_loss_from_probs(y_code[test_mask], probs, eps=EPS))

        if not fold_scores:
            continue

        mean_score = float(np.mean(fold_scores))
        if mean_score < best_score:
            best_score = mean_score
            best_h = h

    return float(best_h) if best_h is not None else float(cand[len(cand) // 2])


# ---------------------------------------------------------------------
# MMU cleanup for categorical raster
# ---------------------------------------------------------------------


def fill_uncertain_by_majority(cat: np.ndarray) -> np.ndarray:
    """Fill zeros using 3x3 neighborhood majority class (excluding zero)."""
    try:
        from scipy import ndimage  # type: ignore
    except Exception:
        return cat

    out = cat.copy()
    uncertain = out == 0
    if not np.any(uncertain):
        return out

    kernel = np.ones((3, 3), dtype=int)
    class_counts = []
    for code in range(1, 6):
        m = (out == code).astype(int)
        class_counts.append(ndimage.convolve(m, kernel, mode="nearest"))

    counts = np.stack(class_counts, axis=0)  # (5, H, W)
    winner = np.argmax(counts, axis=0) + 1
    max_count = np.max(counts, axis=0)

    fillable = uncertain & (max_count > 0)
    out[fillable] = winner[fillable].astype(out.dtype)
    return out


def mmu_cleanup(cat: np.ndarray, mmu_cells: int, n_iters: int = 5) -> np.ndarray:
    """
    Remove patches smaller than mmu_cells for each class (>0), fill by neighborhood majority.
    Requires scipy.ndimage; if missing, returns input unchanged.
    """
    try:
        from scipy import ndimage  # type: ignore
    except Exception:
        print("WARNING: scipy not available; skipping MMU cleanup.")
        return cat

    out = cat.copy()
    structure = np.ones((3, 3), dtype=int)  # 8-connectivity

    for _ in range(n_iters):
        changed = False
        for code in range(1, 6):
            mask = out == code
            if not np.any(mask):
                continue

            labeled, nlab = ndimage.label(mask, structure=structure)
            if nlab == 0:
                continue

            sizes = np.bincount(labeled.ravel())
            small_labels = np.where((sizes < mmu_cells) & (np.arange(len(sizes)) != 0))[0]
            if len(small_labels) == 0:
                continue

            small_mask = np.isin(labeled, small_labels)
            if np.any(small_mask):
                out[small_mask] = 0
                changed = True

        if not changed:
            break

        out = fill_uncertain_by_majority(out)

    return out


# ---------------------------------------------------------------------
# Raster writing
# ---------------------------------------------------------------------


def write_geotiff(
    path: Path,
    array: np.ndarray,
    transform: rasterio.Affine,
    crs_epsg: int,
    nodata: Optional[float] = None,
    dtype: Optional[str] = None,
) -> None:
    """Write a single-band GeoTIFF."""
    path.parent.mkdir(parents=True, exist_ok=True)

    if dtype is None:
        dtype = str(array.dtype)

    height, width = array.shape
    profile = {
        "driver": "GTiff",
        "height": height,
        "width": width,
        "count": 1,
        "dtype": dtype,
        "crs": rasterio.crs.CRS.from_epsg(crs_epsg),
        "transform": transform,
        "compress": "deflate",
        "predictor": 2 if np.issubdtype(array.dtype, np.floating) else 1,
        "tiled": True,
        "blockxsize": 256,
        "blockysize": 256,
    }
    if nodata is not None:
        profile["nodata"] = nodata

    with rasterio.open(path, "w", **profile) as dst:
        dst.write(array, 1)


# ---------------------------------------------------------------------
# Main per-unit processing
# ---------------------------------------------------------------------


def unit_extent(unit_data: UnitData, buffer_m: float = 0.0) -> Tuple[float, float, float, float]:
    """Compute (xmin, ymin, xmax, ymax) extent from points, expanded by buffer."""
    xmin = float(np.min(unit_data.x)) - buffer_m
    ymin = float(np.min(unit_data.y)) - buffer_m
    xmax = float(np.max(unit_data.x)) + buffer_m
    ymax = float(np.max(unit_data.y)) + buffer_m
    return xmin, ymin, xmax, ymax


def predict_unit_rasters(
    unit_data: UnitData,
    output_dir: Path,
    cell_size_m: float,
    bandwidth_m: float,
    search_radius_factor: float,
    kernel: str,
    uncertain_tolerance: float,
    mmu_cells: int,
    extent_buffer_m: float = 0.0,
    no_neighbor_as_nan: bool = True,
    crs_epsg: int = CRS_EPSG,
    use_envelope_mask: bool = True,
    envelope_type: str = "convex_hull",
    envelope_buffer_m: float = 0.0,
) -> None:
    """Predict and write rasters for one unit."""
    unit_name = safe_name(unit_data.unit)
    unit_out = output_dir / unit_name
    unit_out.mkdir(parents=True, exist_ok=True)

    xmin, ymin, xmax, ymax = unit_extent(unit_data, buffer_m=extent_buffer_m)
    x_centers, y_centers, transform = build_grid(xmin, ymin, xmax, ymax, cell_size=cell_size_m)
    width = len(x_centers)
    height = len(y_centers)

    xx, yy = np.meshgrid(x_centers, y_centers)
    targets = np.column_stack([xx.ravel(), yy.ravel()]).astype(float)

    search_radius_m = float(search_radius_factor * bandwidth_m)

    probs = predict_probabilities(
        train_x=unit_data.x,
        train_y=unit_data.y,
        train_code=unit_data.y_code,
        present_codes=unit_data.present_codes,
        targets_xy=targets,
        bandwidth_m=float(bandwidth_m),
        search_radius_m=search_radius_m,
        kernel=kernel,
        no_neighbor_as_nan=no_neighbor_as_nan,
    )

    prob_stack = probs.reshape((height, width, 5)).astype(np.float32)

    # Renormalize per cell (keep NaN for all-NaN)
    s = np.nansum(prob_stack, axis=2, keepdims=True)
    s = np.where(s == 0, np.nan, s)
    prob_stack = prob_stack / s

    # Safe categorical + confidence (handles all-NaN cells)
    all_nan = np.all(~np.isfinite(prob_stack), axis=2)

    p_max = np.nanmax(prob_stack, axis=2).astype(np.float32)
    p_max[all_nan] = np.nan

    prob_for_argmax = np.where(np.isfinite(prob_stack), prob_stack, -np.inf)
    winner = (np.argmax(prob_for_argmax, axis=2) + 1).astype(np.uint8)
    winner[all_nan] = 0

    # Uncertainty threshold
    if uncertain_tolerance is not None:
        uncertain_mask = (p_max < float(uncertain_tolerance)) | ~np.isfinite(p_max)
        winner[uncertain_mask] = 0

    # Mask outside point envelope (convex hull)
    if use_envelope_mask:
        envelope_geoms = build_point_envelope_geometry(
            x=unit_data.x,
            y=unit_data.y,
            envelope_type=envelope_type,
            buffer_m=envelope_buffer_m,
        )
        prob_stack, p_max, winner = apply_envelope_mask(
            prob_stack=prob_stack,
            p_max=p_max,
            winner=winner,
            transform=transform,
            envelope_geoms=envelope_geoms,
            nodata_float=np.nan,
            nodata_cat=0,
        )

    # MMU cleanup (categorical only, after masking)
    if mmu_cells and int(mmu_cells) > 1:
        winner = mmu_cleanup(winner, mmu_cells=int(mmu_cells)).astype(np.uint8)

    # Write outputs
    nodata_float = np.nan

    for i, cls in enumerate(CLASSES):
        out_path = unit_out / f"prob_{safe_name(cls)}.tif"
        write_geotiff(
            out_path,
            prob_stack[:, :, i].astype(np.float32),
            transform=transform,
            crs_epsg=crs_epsg,
            nodata=nodata_float,
            dtype="float32",
        )

    write_geotiff(
        unit_out / "p_max_confidence.tif",
        p_max.astype(np.float32),
        transform=transform,
        crs_epsg=crs_epsg,
        nodata=nodata_float,
        dtype="float32",
    )

    write_geotiff(
        unit_out / "categorical_winner.tif",
        winner.astype(np.uint8),
        transform=transform,
        crs_epsg=crs_epsg,
        nodata=0,
        dtype="uint8",
    )

    meta = {
        "strManagementUnit": unit_data.unit,
        "bandwidth_m": float(bandwidth_m),
        "search_radius_m": float(search_radius_m),
        "grid_cell_size_m": float(cell_size_m),
        "kernel": kernel,
        "uncertain_tolerance": float(uncertain_tolerance),
        "mmu_cells": int(mmu_cells),
        "use_point_envelope_mask": bool(use_envelope_mask),
        "envelope_type": envelope_type,
        "envelope_buffer_m": float(envelope_buffer_m),
        "extent_buffer_m": float(extent_buffer_m),
        "class_to_code": CLASS_TO_CODE,
        "code_to_class": CODE_TO_CLASS,
        "present_codes_in_unit": unit_data.present_codes,
    }
    (unit_out / "metadata.json").write_text(json.dumps(meta, indent=2))


def run_workflow(
    input_csv: str,
    output_dir: str,
    grid_cell_size_m: float,
    candidate_bandwidths_m: Iterable[float],
    search_radius_factor: float,
    block_cv_block_size_m: float,
    uncertain_tolerance: float,
    mmu_cells: int,
    extent_buffer_m: float = 0.0,
    kernel: str = "gaussian",
    no_neighbor_as_nan: bool = True,
    crs_epsg: int = CRS_EPSG,
    use_envelope_mask: bool = True,
    envelope_type: str = "convex_hull",
    envelope_buffer_m: float = 0.0,
) -> pd.DataFrame:
    """
    Run per-unit workflow. Returns a summary dataframe with selected bandwidths.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df = read_points(input_csv)
    units = split_by_unit(df)

    results = []
    for u in units:
        print(f"\n--- Unit: {u.unit} (n={len(u.x)}) ---")
        print(f"Present classes: {[CODE_TO_CLASS[c] for c in u.present_codes]}")

        best_h = select_bandwidth_by_block_cv(
            unit_data=u,
            candidate_bandwidths_m=candidate_bandwidths_m,
            search_radius_factor=search_radius_factor,
            block_size_m=block_cv_block_size_m,
            kernel=kernel,
            n_folds=BLOCK_CV_FOLDS,
            seed=42,
        )
        print(f"Selected bandwidth (m): {best_h}")

        predict_unit_rasters(
            unit_data=u,
            output_dir=out_dir,
            cell_size_m=grid_cell_size_m,
            bandwidth_m=best_h,
            search_radius_factor=search_radius_factor,
            kernel=kernel,
            uncertain_tolerance=uncertain_tolerance,
            mmu_cells=mmu_cells,
            extent_buffer_m=extent_buffer_m,
            no_neighbor_as_nan=no_neighbor_as_nan,
            crs_epsg=crs_epsg,
            use_envelope_mask=use_envelope_mask,
            envelope_type=envelope_type,
            envelope_buffer_m=envelope_buffer_m,
        )

        results.append(
            {
                "strManagementUnit": u.unit,
                "n_points": int(len(u.x)),
                "present_classes": ", ".join([CODE_TO_CLASS[c] for c in u.present_codes]),
                "selected_bandwidth_m": float(best_h),
                "search_radius_m": float(search_radius_factor * best_h),
            }
        )

    summary = pd.DataFrame(results).sort_values("strManagementUnit").reset_index(drop=True)
    summary_path = out_dir / "unit_bandwidth_summary.csv"
    summary.to_csv(summary_path, index=False)
    print(f"\nWrote: {summary_path}")
    return summary



In [3]:
# ---------------------------------------------------------------------
# Run (execute this cell in your notebook)
# ---------------------------------------------------------------------

summary_df = run_workflow(
    input_csv=INPUT_CSV,
    output_dir=OUTPUT_DIR,
    grid_cell_size_m=GRID_CELL_SIZE_M,
    candidate_bandwidths_m=CANDIDATE_BANDWIDTHS_M,
    search_radius_factor=SEARCH_RADIUS_FACTOR,
    block_cv_block_size_m=BLOCK_CV_BLOCK_SIZE_M,
    uncertain_tolerance=UNCERTAIN_TOLERANCE,
    mmu_cells=MMU_CELLS,
    extent_buffer_m=EXTENT_BUFFER_M,
    kernel=KERNEL,
    no_neighbor_as_nan=NO_NEIGHBOR_AS_NAN,
    crs_epsg=CRS_EPSG,
)
summary_df



--- Unit: 1NE (n=308) ---
Present classes: ['Short Open', 'Mid Mod', 'Short Dense', 'Tall Dense']
Selected bandwidth (m): 20.0

--- Unit: 2NW_2NE_N (n=116) ---
Present classes: ['Mid Mod', 'Tall Dense']
Selected bandwidth (m): 20.0

--- Unit: 2NW_2NE_S (n=61) ---
Present classes: ['Mid Mod', 'Tall Dense']
Selected bandwidth (m): 8.0

--- Unit: 3NC_3NE_3C (n=865) ---
Present classes: ['Short Open', 'Tall Open', 'Mid Mod', 'Short Dense', 'Tall Dense']
Selected bandwidth (m): 20.0


  p_max = np.nanmax(prob_stack, axis=2).astype(np.float32)



--- Unit: 3S_4NE (n=1597) ---
Present classes: ['Short Open', 'Tall Open', 'Mid Mod', 'Short Dense', 'Tall Dense']
Selected bandwidth (m): 20.0

--- Unit: 6N (n=105) ---
Present classes: ['Mid Mod', 'Tall Dense']
Selected bandwidth (m): 12.0

Wrote: /home/abhayadana/Documents/GitHub/SLSS/unit_bandwidth_summary.csv


Unnamed: 0,strManagementUnit,n_points,present_classes,selected_bandwidth_m,search_radius_m
0,1NE,308,"Short Open, Mid Mod, Short Dense, Tall Dense",20.0,60.0
1,2NW_2NE_N,116,"Mid Mod, Tall Dense",20.0,60.0
2,2NW_2NE_S,61,"Mid Mod, Tall Dense",8.0,24.0
3,3NC_3NE_3C,865,"Short Open, Tall Open, Mid Mod, Short Dense, T...",20.0,60.0
4,3S_4NE,1597,"Short Open, Tall Open, Mid Mod, Short Dense, T...",20.0,60.0
5,6N,105,"Mid Mod, Tall Dense",12.0,36.0
