In [None]:
# Machine-specific absolute paths: the point CSV to read and the output directory.
# NOTE(review): both names are assigned again, with the same values, in the
# "User parameters" section further down the file.
INPUT_CSV = r"/home/abhayadana/Downloads/SummerLowStatureSampling_pt_2025.csv"
OUTPUT_DIR = r"/home/abhayadana/Documents/GitHub/SLSS_analysis"

In [None]:
"""
SIS workflow per strManagementUnit with diagnostics bundle and robust p_max/winner
computation (no All-NaN slice warnings).

Requested diagnostics per unit:
1) variogram_J_fit.png
2) cv_scores.csv + cv_scores.png
4) variogram_fit_surface_<best_model>.png
5) realizations_panel.png

Fix applied:
- Replace np.nanmax / np.nanargmax on prob_stack with a safe max/argmax that
  handles outside-domain all-NaN slices without warnings.

Also added:
- Verbose per-unit printout of class counts (n) and percent of unit total.

CRS assumption:
- UTM Zone 10, NAD83, meters (EPSG:26910)

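Dependencies:
- numpy, pandas, rasterio, matplotlib
- scipy or scikit-learn (KD-tree neighbor search used by local kriging; one of the two is required)
Optional:
- shapely (envelope buffering)
- scipy (ndimage MMU cleanup; ConvexHull fallback when shapely is unavailable)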
"""

from __future__ import annotations

import json
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import rasterio
from rasterio.features import geometry_mask
from rasterio.transform import from_origin

import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap
from matplotlib.patches import Patch


# ---------------------------------------------------------------------
# User parameters (edit in your notebook)
# ---------------------------------------------------------------------

# Input CSV, output directory and EPSG code of the point coordinates.
# The module docstring states the assumed CRS: UTM Zone 10, NAD83, meters.
# NOTE: INPUT_CSV / OUTPUT_DIR repeat the assignments of the first notebook cell.
INPUT_CSV = r"/home/abhayadana/Downloads/SummerLowStatureSampling_pt_2025.csv"
OUTPUT_DIR = r"/home/abhayadana/Documents/GitHub/SLSS_analysis"
CRS_EPSG = 26910

# Grid
# GRID_CELL_SIZE_M also sets lower bounds for the variogram lag distance and
# range search (3 x cell size) in select_shared_variogram_by_block_cv.
GRID_CELL_SIZE_M = 1.0
EXTENT_BUFFER_M = 0.0

# Envelope domain (convex hull)
USE_POINT_ENVELOPE_MASK = True
ENVELOPE_BUFFER_M = 5.0

# Declustering
# Cell size = clip(DECLUSTER_MULT * median NN distance,
#                  DECLUSTER_CELL_MIN_M, DECLUSTER_CELL_MAX_M).
DECLUSTER_MULT = 2.0
DECLUSTER_CELL_MIN_M = 5.0
DECLUSTER_CELL_MAX_M = 50.0

# Spatial block CV (for variogram family selection on J)
# Block size = clip(BLOCK_SIZE_MULT * median NN distance,
#                   BLOCK_SIZE_MIN_M, BLOCK_SIZE_MAX_M).
CV_FOLDS = 5
BLOCK_SIZE_MULT = 10.0
BLOCK_SIZE_MIN_M = 25.0
BLOCK_SIZE_MAX_M = 100.0

# Variogram fitting
# MAX_DIST_FRACTION_OF_SPAN scales the bounding-box diagonal to get the maximum
# lag distance. NUGGET_MAX is presumably an absolute nugget bound (it is compared
# against the sill, not scaled by it) -- TODO confirm at the call site.
N_LAGS = 12
MAX_DIST_FRACTION_OF_SPAN = 0.5
MAX_PAIRS_FOR_VARIOGRAM = 200_000
RANGE_GRID_SIZE = 18
NUGGET_GRID_SIZE = 10
NUGGET_MAX = 0.30

# Neighborhood for local kriging in CV and SIS
# NOTE(review): SEARCH_RADIUS_FACTOR presumably multiplies the fitted range to
# get the search radius -- its use is not visible in this part of the file.
SEARCH_RADIUS_FACTOR = 1.5
MIN_NEIGHBORS = 8
MAX_NEIGHBORS = 40

# SIS settings
N_REALIZATIONS = 10
RANDOM_SEED = 57718

# Outputs / post-processing
UNCERTAIN_TOLERANCE = 0.50
MMU_CELLS = 2

# Rare-class stabilization (optional)
SILL_FLOOR = 0.0  # set to 0.005 or 0.01 if needed

# Realization panel
PANEL_N = 9  # number of realizations to show in panel (e.g., 9 -> 3x3)

# Small positive constant used to guard divisions and degenerate ranges/sills.
EPS = 1e-12


# ---------------------------------------------------------------------
# Classes (fixed order)
# ---------------------------------------------------------------------

# Class labels in their fixed order; the position in this list defines the code.
CLASSES = ["Short Open", "Tall Open", "Mid Mod", "Short Dense", "Tall Dense"]

# Label -> integer code (1-based, following the order of CLASSES).
CLASS_TO_CODE = dict(zip(CLASSES, range(1, len(CLASSES) + 1)))

# Integer code -> label; code 0 is reserved for "Uncertain".
CODE_TO_CLASS = {0: "Uncertain"}
CODE_TO_CLASS.update({code: label for label, code in CLASS_TO_CODE.items()})

# Colors indexed by code: entry 0 is fully transparent (blank in PNGs),
# entries 1..5 are the class colors.
PALETTE = [
    "#00000000",  # 0 Uncertain/Masked (transparent)
    "#1f77b4",    # 1 Short Open
    "#ff7f0e",    # 2 Tall Open
    "#2ca02c",    # 3 Mid Mod
    "#d62728",    # 4 Short Dense
    "#9467bd",    # 5 Tall Dense
]


# ---------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------


@dataclass
class UnitData:
    """Sample points of one management unit, as built by ``split_by_unit``.

    ``x``, ``y`` and ``code`` are parallel arrays with one entry per point.
    """

    # Management-unit label (the ``strManagementUnit`` group key, as a string).
    unit: str
    # Point coordinates taken from the ``UTM_X`` / ``UTM_Y`` columns (float).
    x: np.ndarray
    y: np.ndarray
    # Integer class code per point (``class_code`` column, values from CLASS_TO_CODE).
    code: np.ndarray
    # Sorted unique class codes that occur in ``code``.
    present_codes: List[int]


@dataclass
class VariogramModel:
    """Parameters of a fitted variogram for the binary indicator J.

    Consumers derive the partial sill as ``max(sill_total - nugget, 0)``.
    """

    # Model family name understood by variogram_gamma / covariance_from_model.
    model_type: str  # "exponential" | "spherical" | "gaussian"
    # Range parameter; distances are divided by it (presumably meters -- same
    # units as the point coordinates).
    range_m: float
    # Nugget variance (discontinuity at the origin).
    nugget: float
    sill_total: float  # total sill for J


@dataclass
class CondPoint:
    """A single location carrying a class code.

    NOTE(review): not referenced in the visible part of the file; presumably a
    conditioning datum for the sequential simulation -- confirm against usage.
    """

    # Location coordinates (presumably in the same CRS as the input points).
    x: float
    y: float
    # Integer class code at this location.
    code: int


@dataclass
class VariogramFitArtifacts:
    """Inputs and objective surface of one variogram grid-search fit.

    Produced by ``fit_variogram_surface`` and consumed by the diagnostic plots.
    """

    # Lag value per bin, empirical semivariance per bin, and the per-bin weight
    # (bins with non-finite semivariance or zero weight are ignored by the fit).
    lag: np.ndarray
    gamma_emp: np.ndarray
    bin_w: np.ndarray
    # Total sill that was held fixed during the fit.
    sill_total: float
    # Candidate ranges (geometrically spaced) and nuggets (linearly spaced).
    range_grid: np.ndarray
    nugget_grid: np.ndarray
    # Weighted SSE for every (nugget, range) pair; all-NaN when no bin was usable.
    sse_grid: np.ndarray  # shape (len(nugget_grid), len(range_grid))


# ---------------------------------------------------------------------
# Small utilities
# ---------------------------------------------------------------------


def safe_name(name: str) -> str:
    """Convert an arbitrary label into a stem usable as a folder/file name.

    Every run of characters other than word characters or ``-`` becomes a
    single underscore; leading/trailing underscores are removed. If nothing
    is left, the placeholder ``"unit"`` is returned.
    """
    trimmed = str(name).strip()
    underscored = re.sub(r"_+", "_", re.sub(r"[^\w\-]+", "_", trimmed))
    stem = underscored.strip("_")
    return stem if stem else "unit"


def nearest_neighbor_distances(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Distance from every point to its closest other point.

    Tries scipy's cKDTree first and scikit-learn's NearestNeighbors second.
    Returns an empty float array when fewer than two points are given or when
    neither backend can be used.
    """
    coords = np.column_stack([x, y]).astype(float)
    nothing = np.array([], dtype=float)
    if len(coords) < 2:
        return nothing

    def _with_scipy() -> np.ndarray:
        from scipy.spatial import cKDTree  # type: ignore

        # k=2: column 0 is the point itself, column 1 its nearest neighbor.
        dist, _ = cKDTree(coords).query(coords, k=2)
        return dist[:, 1].astype(float)

    def _with_sklearn() -> np.ndarray:
        from sklearn.neighbors import NearestNeighbors  # type: ignore

        finder = NearestNeighbors(n_neighbors=2).fit(coords)
        dist, _ = finder.kneighbors(coords, n_neighbors=2)
        return dist[:, 1].astype(float)

    # Best-effort: any failure in one backend moves on to the next one.
    for backend in (_with_scipy, _with_sklearn):
        try:
            return backend()
        except Exception:
            continue

    # last-resort fallback
    return nothing


def print_unit_class_summary(unit_data: UnitData) -> None:
    """
    Print a human-readable summary of one unit.

    The report lists, in order:
      - the count and share of every class code 1..5,
      - the classes present in the unit,
      - nearest-neighbor distance statistics,
      - the declustering cell size and CV block size derived from the median
        nearest-neighbor distance (multiplier, then clamp).
    """
    total = int(len(unit_data.code))
    print(f"Unit: {unit_data.unit}  |  n={total}")

    # One line per class; max(total, 1) avoids dividing by zero for empty units.
    for class_code in range(1, 6):
        count = int(np.sum(unit_data.code == class_code))
        share = 100.0 * count / max(total, 1)
        print(f"  - {CODE_TO_CLASS[class_code]}: n={count} ({share:.1f}%)")

    present = [CODE_TO_CLASS[c] for c in unit_data.present_codes]
    print(f"  Present classes: {present}")

    nn = nearest_neighbor_distances(unit_data.x, unit_data.y)
    if nn.size == 0:
        # No distances: too few points or no KD-tree backend available.
        for line in (
            "  NN distance (m): n/a (need >=2 points and scipy or sklearn)",
            "  Declustering cell (m): n/a",
            "  Block size (m): n/a",
        ):
            print(line)
        return

    q10, q50, q90 = np.percentile(nn, [10, 50, 90])
    nn_med = float(q50)
    stats_line = (
        f"min={float(np.min(nn)):.2f}, p10={q10:.2f}, median={nn_med:.2f}, "
        f"mean={float(np.mean(nn)):.2f}, p90={q90:.2f}, max={float(np.max(nn)):.2f}"
    )
    print("  NN distance (m): " + stats_line)

    # Both sizes scale the median NN distance and are clamped to their bounds.
    decluster_cell = float(
        np.clip(DECLUSTER_MULT * nn_med, DECLUSTER_CELL_MIN_M, DECLUSTER_CELL_MAX_M)
    )
    block_size = float(
        np.clip(BLOCK_SIZE_MULT * nn_med, BLOCK_SIZE_MIN_M, BLOCK_SIZE_MAX_M)
    )

    decluster_line = (
        f"{decluster_cell:.2f}  (mult={DECLUSTER_MULT:g}, "
        f"clamp=[{DECLUSTER_CELL_MIN_M:g},{DECLUSTER_CELL_MAX_M:g}])"
    )
    block_line = (
        f"{block_size:.2f}  (mult={BLOCK_SIZE_MULT:g}, "
        f"clamp=[{BLOCK_SIZE_MIN_M:g},{BLOCK_SIZE_MAX_M:g}])"
    )
    print("  Declustering cell (m): " + decluster_line)
    print("  Block size (m): " + block_line)



def safe_max_and_argmax(prob_stack: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Per-pixel maximum probability and winning class over the last (third) axis.

    Non-finite entries are ignored. Pixels where every class is non-finite are
    reported separately so that no all-NaN reduction is ever evaluated.

    Returns:
      p_max (float32): largest finite probability, NaN where none is finite
      winner (uint8): 1-based index of the largest probability, 0 where none is finite
      all_nan (bool): True where no class has a finite probability
    """
    usable = np.isfinite(prob_stack)
    all_nan = ~np.any(usable, axis=2)

    # -inf never wins against a finite value, so max/argmax stay well defined.
    filled = np.where(usable, prob_stack, -np.inf)
    best_value = np.max(filled, axis=2)
    best_index = np.argmax(filled, axis=2)

    p_max = np.where(all_nan, np.nan, best_value).astype(np.float32)
    winner = np.where(all_nan, 0, best_index + 1).astype(np.uint8)
    return p_max, winner, all_nan


# ---------------------------------------------------------------------
# IO helpers
# ---------------------------------------------------------------------


def read_points(csv_path: str) -> pd.DataFrame:
    """Load the point CSV, validate its columns and keep only usable rows.

    Rows are kept when both coordinates parse as numbers, the management unit
    is present and the (whitespace-stripped) class label is one of CLASSES.
    An integer ``class_code`` column is added.

    Raises:
        ValueError: if a required column is missing or no row survives filtering.
    """
    unit_col = "strManagementUnit"
    class_col = "strAveHeight_cm_PctCover"
    coord_cols = ["UTM_X", "UTM_Y"]

    table = pd.read_csv(csv_path, dtype=str)

    absent = {unit_col, class_col, *coord_cols} - set(table.columns)
    if absent:
        raise ValueError(f"Missing required columns: {sorted(absent)}")

    # Everything is read as text; unparseable coordinates become NaN.
    for col in coord_cols:
        table[col] = pd.to_numeric(table[col], errors="coerce")
    table[class_col] = table[class_col].astype(str).str.strip()

    complete = table[coord_cols + [unit_col, class_col]].notna().all(axis=1)
    known_class = table[class_col].isin(CLASSES)
    table = table.loc[complete & known_class].copy()

    if table.empty:
        raise ValueError("No valid points after filtering to known classes and valid coordinates.")

    table["class_code"] = table[class_col].map(CLASS_TO_CODE).astype(int)
    return table


def split_by_unit(df: pd.DataFrame) -> List[UnitData]:
    """Build one UnitData per ``strManagementUnit`` group, in sorted key order."""

    def _as_unit(label: object, rows: pd.DataFrame) -> UnitData:
        codes = rows["class_code"].to_numpy(int)
        return UnitData(
            unit=str(label),
            x=rows["UTM_X"].to_numpy(float),
            y=rows["UTM_Y"].to_numpy(float),
            code=codes,
            present_codes=sorted(np.unique(codes).tolist()),
        )

    grouped = df.groupby("strManagementUnit", sort=True)
    return [_as_unit(label, rows) for label, rows in grouped]


# ---------------------------------------------------------------------
# Grid + envelope domain
# ---------------------------------------------------------------------


def unit_bbox(unit_data: UnitData, buffer_m: float = 0.0) -> Tuple[float, float, float, float]:
    """Return ``(xmin, ymin, xmax, ymax)`` of the unit's points, grown by ``buffer_m`` on every side."""
    pad = float(buffer_m)
    west, east = float(np.min(unit_data.x)), float(np.max(unit_data.x))
    south, north = float(np.min(unit_data.y)), float(np.max(unit_data.y))
    return west - pad, south - pad, east + pad, north + pad


def build_grid(
    xmin: float,
    ymin: float,
    xmax: float,
    ymax: float,
    cell_size: float,
) -> Tuple[np.ndarray, np.ndarray, rasterio.Affine]:
    """Build a north-up raster grid covering the given bounding box.

    The grid is anchored at the upper-left corner ``(xmin, ymax)``. The number
    of columns/rows is the extent divided by ``cell_size`` rounded up, and is
    never less than one so that a degenerate extent (a single point, or points
    sharing one x or y value) still yields a non-empty raster.

    Args:
        xmin, ymin, xmax, ymax: bounding box in map units.
        cell_size: edge length of a square cell; must be positive and finite.

    Returns:
        ``(x_centers, y_centers, transform)`` where ``x_centers`` increases
        west to east, ``y_centers`` decreases north to south, and ``transform``
        is the matching rasterio affine transform.

    Raises:
        ValueError: if ``cell_size`` is not a positive finite number.
    """
    cell = float(cell_size)
    if not math.isfinite(cell) or cell <= 0.0:
        raise ValueError(f"cell_size must be a positive finite number, got {cell_size!r}")

    # At least one column/row: a zero-sized raster cannot be masked or written.
    width = max(1, int(math.ceil((xmax - xmin) / cell)))
    height = max(1, int(math.ceil((ymax - ymin) / cell)))

    transform = from_origin(xmin, ymax, cell, cell)
    x_centers = xmin + (np.arange(width) + 0.5) * cell
    y_centers = ymax - (np.arange(height) + 0.5) * cell
    return x_centers, y_centers, transform


def build_point_envelope_geometry(x: np.ndarray, y: np.ndarray, buffer_m: float = 0.0) -> Sequence[dict]:
    """
    Build a convex hull geometry around points, optionally buffered.

    shapely is used when it works; otherwise the hull comes from
    ``scipy.spatial.ConvexHull``. The scipy fallback cannot buffer, so a
    non-zero ``buffer_m`` is reported with a ``RuntimeWarning`` instead of
    being dropped silently.

    Args:
        x, y: point coordinates (parallel arrays).
        buffer_m: outward buffer distance applied to the hull (shapely only).

    Returns:
        A one-element list holding a GeoJSON-like geometry dict suitable for
        ``rasterio.features.geometry_mask``.

    Raises:
        ImportError: if neither shapely nor scipy can be used.
    """
    pts = np.column_stack([x, y]).astype(float)
    pad = float(buffer_m) if buffer_m else 0.0

    try:
        from shapely.geometry import MultiPoint  # type: ignore
        from shapely.geometry import mapping  # type: ignore

        hull = MultiPoint(pts).convex_hull
        if pad != 0.0:
            hull = hull.buffer(pad)
        return [mapping(hull)]
    except Exception:
        # Deliberate best-effort: any shapely problem falls through to scipy.
        try:
            from scipy.spatial import ConvexHull  # type: ignore
        except Exception as exc:
            raise ImportError("Envelope masking requires shapely (preferred) or scipy.") from exc

        if pad != 0.0:
            import warnings

            warnings.warn(
                f"shapely is unavailable; envelope buffer of {pad:g} m is ignored "
                "and the unbuffered convex hull is used instead.",
                RuntimeWarning,
                stacklevel=2,
            )

        hull = ConvexHull(pts)
        ring = pts[hull.vertices].tolist()
        ring.append(ring[0])  # close the ring as GeoJSON polygons require
        return [{"type": "Polygon", "coordinates": [ring]}]


def envelope_mask_for_grid(
    geoms: Sequence[dict],
    out_shape: Tuple[int, int],
    transform: rasterio.Affine,
) -> np.ndarray:
    """Rasterize the envelope onto the grid; True marks pixels inside the domain."""
    # invert=True asks rasterio for True inside the geometries rather than
    # its default of True outside them.
    return geometry_mask(
        geometries=geoms,
        out_shape=out_shape,
        transform=transform,
        all_touched=False,
        invert=True,
    )


# ---------------------------------------------------------------------
# Nearest neighbor distance (median)
# ---------------------------------------------------------------------


def estimate_median_nn_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Median distance from each point to its nearest other point.

    Returns 1.0 for fewer than two points. Uses scipy's cKDTree, then
    scikit-learn's NearestNeighbors; if neither works, falls back to 1% of the
    bounding-box diagonal (but at least 1.0).
    """
    coords = np.column_stack([x, y]).astype(float)
    if coords.shape[0] < 2:
        return 1.0

    try:
        from scipy.spatial import cKDTree  # type: ignore

        # k=2 -> column 0 is the query point itself, column 1 the neighbor.
        nearest = cKDTree(coords).query(coords, k=2)[0][:, 1]
    except Exception:
        try:
            from sklearn.neighbors import NearestNeighbors  # type: ignore

            finder = NearestNeighbors(n_neighbors=2).fit(coords)
            nearest = finder.kneighbors(coords, n_neighbors=2)[0][:, 1]
        except Exception:
            span_x = coords[:, 0].max() - coords[:, 0].min()
            span_y = coords[:, 1].max() - coords[:, 1].min()
            return float(max(1.0, 0.01 * math.hypot(span_x, span_y)))

    return float(np.median(nearest))


# ---------------------------------------------------------------------
# Declustering weights
# ---------------------------------------------------------------------


def compute_declustering_weights(x: np.ndarray, y: np.ndarray, cell_size_m: float) -> np.ndarray:
    """
    Cell declustering weights: w_i = 1 / n_cell(i), normalized to mean 1.0.

    Points are binned on a regular grid anchored at the minimum x/y; each point
    is weighted by the inverse of the number of points sharing its cell.

    Args:
        x, y: point coordinates (parallel, non-empty arrays).
        cell_size_m: declustering cell edge length; must be positive and finite.

    Returns:
        Float array of weights, one per point, with mean 1.0.

    Raises:
        ValueError: if ``cell_size_m`` is not a positive finite number (or the
            inputs are empty, raised by the min reduction).
    """
    cell = float(cell_size_m)
    if not math.isfinite(cell) or cell <= 0.0:
        raise ValueError(f"cell_size_m must be a positive finite number, got {cell_size_m!r}")

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    col = np.floor((x - float(np.min(x))) / cell).astype(np.int64)
    row = np.floor((y - float(np.min(y))) / cell).astype(np.int64)

    # Stride derived from the data keeps the combined key unique for any grid
    # size; a fixed multiplier would merge distinct cells once row >= multiplier.
    stride = int(row.max()) + 1
    key = col * stride + row

    _, inverse, counts = np.unique(key, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse].astype(float)
    # The sum is at least 1 (one full unit of weight per occupied cell), so no
    # zero-division guard is needed.
    weights *= weights.size / weights.sum()
    return weights


def weighted_prevalence(code: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Weighted share of each class code 1..5 as an array of shape (5,).

    A non-positive total weight yields the uniform distribution. Otherwise the
    shares are renormalized by their own sum (guarded by EPS).
    """
    total_weight = float(np.sum(weights))
    if total_weight <= 0:
        return np.ones(5, dtype=float) / 5.0

    shares = np.array(
        [float(np.sum(weights[code == k]) / total_weight) for k in range(1, 6)],
        dtype=float,
    )
    return shares / max(float(np.sum(shares)), EPS)


# ---------------------------------------------------------------------
# Spatial blocks for CV
# ---------------------------------------------------------------------


def assign_blocks(x: np.ndarray, y: np.ndarray, block_size_m: float) -> np.ndarray:
    """Assign points to spatial blocks using a regular grid.

    The grid is anchored at the minimum x/y of the points. Each point gets an
    int64 id combining its block column and row; points in the same block share
    an id and points in different blocks never do.

    Args:
        x, y: point coordinates (parallel, non-empty arrays).
        block_size_m: block edge length; must be positive and finite.

    Returns:
        int64 array of block ids, one per point.

    Raises:
        ValueError: if ``block_size_m`` is not a positive finite number.
    """
    size = float(block_size_m)
    if not math.isfinite(size) or size <= 0.0:
        raise ValueError(f"block_size_m must be a positive finite number, got {block_size_m!r}")

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    col = np.floor((x - float(np.min(x))) / size).astype(np.int64)
    row = np.floor((y - float(np.min(y))) / size).astype(np.int64)

    # Keep the historical 1_000_000 multiplier whenever it is collision-free
    # (ids stay identical); widen it only when a row index would overflow it.
    stride = max(1_000_000, int(row.max()) + 1)
    return col * stride + row


def make_block_folds(block_ids: np.ndarray, n_folds: int, seed: int) -> List[np.ndarray]:
    """Split the distinct block ids into ``n_folds`` random groups.

    Returns one boolean mask per fold; a mask is True for the points whose
    block belongs to that fold (the fold's test points). The split is
    reproducible for a given ``seed``.
    """
    generator = np.random.default_rng(seed)
    shuffled_blocks = np.unique(block_ids)
    generator.shuffle(shuffled_blocks)

    masks: List[np.ndarray] = []
    for fold_blocks in np.array_split(shuffled_blocks, n_folds):
        masks.append(np.isin(block_ids, fold_blocks))
    return masks


# ---------------------------------------------------------------------
# Variogram models and covariance
# ---------------------------------------------------------------------


def variogram_gamma(model_type: str, h: np.ndarray, rng_m: float, nugget: float, partial_sill: float) -> np.ndarray:
    """Evaluate the semivariogram γ(h) = nugget + partial_sill * s(h / range).

    Supported families and their structure functions s(r):
      - "exponential": 1 - exp(-r)
      - "gaussian":    1 - exp(-r^2)
      - "spherical":   1.5 r - 0.5 r^3 for r < 1, otherwise 1

    Raises:
        ValueError: for any other ``model_type``.
    """
    lags = np.asarray(h, dtype=float)
    if model_type not in ("exponential", "gaussian", "spherical"):
        raise ValueError(f"Unknown model_type: {model_type!r}")

    # EPS keeps the division defined for a zero range.
    scaled = lags / max(rng_m, EPS)

    if model_type == "exponential":
        return nugget + partial_sill * (1.0 - np.exp(-scaled))

    if model_type == "gaussian":
        return nugget + partial_sill * (1.0 - np.exp(-(scaled * scaled)))

    # Spherical: start at the sill everywhere, then fill in the rising part.
    values = np.full_like(scaled, nugget + partial_sill, dtype=float)
    below_range = scaled < 1.0
    values[below_range] = nugget + partial_sill * (
        1.5 * scaled[below_range] - 0.5 * scaled[below_range] ** 3
    )
    return values


def covariance_from_model(
    model_type: str,
    h: np.ndarray,
    rng_m: float,
    nugget: float,
    partial_sill: float,
) -> np.ndarray:
    """
    Covariance C(h) matching the variogram family:
    - h != 0: partial_sill * corr(h), with corr from the exponential, gaussian
      or spherical model (spherical is exactly 0 at and beyond the range)
    - h == 0: partial_sill + nugget

    Raises ValueError for an unknown ``model_type``.
    """
    lags = np.asarray(h, dtype=float)
    if model_type not in ("exponential", "gaussian", "spherical"):
        raise ValueError(f"Unknown model_type: {model_type!r}")

    # EPS keeps the division defined for a zero range.
    scaled = lags / max(rng_m, EPS)

    if model_type == "exponential":
        structured = partial_sill * np.exp(-scaled)
    elif model_type == "gaussian":
        structured = partial_sill * np.exp(-(scaled * scaled))
    else:
        structured = np.where(
            scaled < 1.0,
            partial_sill * (1.0 - 1.5 * scaled + 0.5 * scaled ** 3),
            0.0,
        )

    # The nugget contributes only at zero separation.
    return np.where(lags == 0.0, partial_sill + nugget, structured)


# ---------------------------------------------------------------------
# Empirical variogram (weighted; pair subsampling)
# ---------------------------------------------------------------------


def _sample_pairs(n: int, max_pairs: int, rng: np.random.Generator) -> Tuple[np.ndarray, np.ndarray]:
    """Return index pairs (i < j) over ``n`` points, at most ``max_pairs`` of them.

    When all ``n * (n - 1) / 2`` pairs fit within ``max_pairs`` they are returned
    exhaustively. Otherwise ``max_pairs`` pairs are drawn at random with
    replacement and self-pairs are dropped, so the result may be shorter than
    ``max_pairs`` and may contain repeated pairs.
    """
    if n < 2:
        return np.array([], dtype=int), np.array([], dtype=int)

    if n * (n - 1) // 2 <= max_pairs:
        rows, cols = np.triu_indices(n, k=1)
        return rows.astype(int), cols.astype(int)

    first = rng.integers(0, n, size=max_pairs, dtype=np.int64)
    second = rng.integers(0, n, size=max_pairs, dtype=np.int64)

    # Order every pair as (smaller, larger), then discard i == j draws.
    low = np.minimum(first, second)
    high = np.maximum(first, second)
    distinct = low != high
    return low[distinct].astype(int), high[distinct].astype(int)


def empirical_variogram_weighted(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
    w: np.ndarray,
    n_lags: int,
    max_dist: float,
    max_pairs: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Weighted empirical variogram on ``n_lags`` equal-width bins over [0, max_dist].

    Every sampled pair (i, j) no farther apart than ``max_dist`` contributes
    0.5 * (z_i - z_j)^2 with weight w_i * w_j to its lag bin.

    Returns:
      (lag_centers, gamma_emp, weight_sum), each of length ``n_lags``.
      ``gamma_emp`` is NaN for bins without pairs. When no pair is usable at
      all, both lag and gamma arrays are NaN and the weights are zero.
    """

    def _no_pairs() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        return np.full(n_lags, np.nan), np.full(n_lags, np.nan), np.zeros(n_lags)

    generator = np.random.default_rng(seed)
    if len(z) < 2:
        return _no_pairs()

    first, second = _sample_pairs(len(z), max_pairs=max_pairs, rng=generator)
    if first.size == 0:
        return _no_pairs()

    off_x = x[first] - x[second]
    off_y = y[first] - y[second]
    separation = np.sqrt(off_x * off_x + off_y * off_y)

    within = separation <= max_dist
    if not np.any(within):
        return _no_pairs()

    first, second, separation = first[within], second[within], separation[within]

    half_sq_diff = 0.5 * (z[first] - z[second]) ** 2
    pair_weight = (w[first] * w[second]).astype(float)

    edges = np.linspace(0.0, max_dist, n_lags + 1)
    # side="right" puts a distance equal to an edge into the upper bin; the
    # clip folds separation == max_dist back into the last bin.
    which_bin = np.clip(np.searchsorted(edges, separation, side="right") - 1, 0, n_lags - 1)

    gamma_sum = np.bincount(which_bin, weights=half_sq_diff * pair_weight, minlength=n_lags)
    weight_sum = np.bincount(which_bin, weights=pair_weight, minlength=n_lags)

    gamma_emp = np.where(weight_sum > 0, gamma_sum / weight_sum, np.nan)
    lag_centers = 0.5 * (edges[:-1] + edges[1:])
    return lag_centers, gamma_emp, weight_sum


# ---------------------------------------------------------------------
# Variogram fitting surface + best fit (grid search)
# ---------------------------------------------------------------------


def fit_variogram_surface(
    model_type: str,
    lag: np.ndarray,
    gamma_emp: np.ndarray,
    bin_w: np.ndarray,
    sill_total: float,
    range_min: float,
    range_max: float,
    nugget_max: float,
    range_grid_size: int,
    nugget_grid_size: int,
) -> Tuple[float, float, VariogramFitArtifacts]:
    """
    Grid-search (range, nugget) for a fixed total sill and keep the whole surface.

    Candidate ranges are geometrically spaced between ``range_min`` and
    ``range_max``; candidate nuggets are linearly spaced between 0 and
    ``nugget_max``. Each candidate is scored by the bin-weighted sum of squared
    differences between empirical and model semivariance, using only bins with
    a finite semivariance and positive weight.

    Returns:
      (best_range, best_nugget, artifacts). If no bin is usable, the middle
      range candidate and ``min(nugget_max, 0.05)`` are returned together with
      an all-NaN surface.
    """
    ranges = np.geomspace(max(range_min, EPS), max(range_max, range_min + EPS), range_grid_size)
    nuggets = np.linspace(0.0, max(nugget_max, 0.0), nugget_grid_size)
    surface_shape = (len(nuggets), len(ranges))

    # Defaults reported when nothing can be (or ever gets) scored.
    default_range = float(ranges[len(ranges) // 2])
    default_nugget = float(min(nugget_max, 0.05))

    def _artifacts(surface: np.ndarray) -> VariogramFitArtifacts:
        return VariogramFitArtifacts(
            lag=lag,
            gamma_emp=gamma_emp,
            bin_w=bin_w,
            sill_total=sill_total,
            range_grid=ranges,
            nugget_grid=nuggets,
            sse_grid=surface,
        )

    usable = np.isfinite(gamma_emp) & (bin_w > 0)
    if not np.any(usable):
        return default_range, default_nugget, _artifacts(np.full(surface_shape, np.nan, dtype=float))

    lags_used = lag[usable]
    gamma_used = gamma_emp[usable]
    weights_used = bin_w[usable]

    surface = np.zeros(surface_shape, dtype=float)
    lowest = np.inf
    chosen_range, chosen_nugget = default_range, default_nugget

    for row, nug in enumerate(nuggets):
        # The sill is fixed, so a larger nugget leaves a smaller structured part.
        partial = max(sill_total - nug, 0.0)
        for col, rr in enumerate(ranges):
            residual = gamma_used - variogram_gamma(model_type, lags_used, rr, nug, partial)
            score = float(np.sum(weights_used * residual ** 2))
            surface[row, col] = score
            # Strict "<" keeps the first candidate on ties and skips NaN scores.
            if score < lowest:
                lowest, chosen_range, chosen_nugget = score, float(rr), float(nug)

    return chosen_range, chosen_nugget, _artifacts(surface)


# ---------------------------------------------------------------------
# Local OK prediction for binary J (CV scoring)
# ---------------------------------------------------------------------


class NeighborIndex:
    """Fixed-radius neighbor lookup over a point set.

    Backed by scipy's cKDTree when it can be built, otherwise by scikit-learn's
    NearestNeighbors. Construction raises ImportError when neither is usable.
    """

    def __init__(self, x: np.ndarray, y: np.ndarray):
        self.xy = np.column_stack([x, y]).astype(float)
        self._backend = None
        self._tree = None

        try:
            from scipy.spatial import cKDTree  # type: ignore

            self._tree = cKDTree(self.xy)
            self._backend = "scipy"
        except Exception:
            try:
                from sklearn.neighbors import NearestNeighbors  # type: ignore

                self._tree = NearestNeighbors(algorithm="ball_tree").fit(self.xy)
                self._backend = "sklearn"
            except Exception as exc:
                raise ImportError("Need scipy or scikit-learn for neighbor search.") from exc

    def query_neighbors(self, targets: np.ndarray, radius: float, max_neighbors: int) -> List[np.ndarray]:
        """For each target row, return indices of up to ``max_neighbors`` nearest points within ``radius``."""
        if self._backend == "scipy":
            return self._query_scipy(targets, radius, max_neighbors)
        return self._query_sklearn(targets, radius, max_neighbors)

    def _query_scipy(self, targets: np.ndarray, radius: float, max_neighbors: int) -> List[np.ndarray]:
        """Ball query, then keep the closest ``max_neighbors`` hits per target."""
        hits_per_target = self._tree.query_ball_point(targets, r=radius)  # type: ignore
        found: List[np.ndarray] = []
        for target, hits in zip(targets, hits_per_target):
            if len(hits) == 0:
                found.append(np.array([], dtype=int))
                continue
            members = np.array(hits, dtype=int)
            dist = np.sqrt(np.sum((self.xy[members] - target) ** 2, axis=1))
            found.append(members[np.argsort(dist)[:max_neighbors]])
        return found

    def _query_sklearn(self, targets: np.ndarray, radius: float, max_neighbors: int) -> List[np.ndarray]:
        """k-nearest query capped at the point count, then filter by radius."""
        k = min(self.xy.shape[0], max_neighbors)
        dist, members = self._tree.kneighbors(targets, n_neighbors=k, return_distance=True)  # type: ignore
        return [row_members[row_dist <= radius] for row_dist, row_members in zip(dist, members)]


def ok_predict_binary_local(
    train_x: np.ndarray,
    train_y: np.ndarray,
    train_z: np.ndarray,
    targets: np.ndarray,
    model: VariogramModel,
    search_radius_m: float,
    min_neighbors: int,
    max_neighbors: int,
) -> np.ndarray:
    """Local ordinary kriging prediction for binary variable train_z at target locations.

    For every target, the nearest training points within ``search_radius_m``
    (at most ``max_neighbors``) form a local ordinary-kriging system built from
    the model covariance. The weighted sum of neighbor values is clipped to
    [0, 1].

    Args:
        train_x, train_y: training point coordinates.
        train_z: training values (parallel to the coordinates).
        targets: array indexed as ``targets[i, 0]`` (x) and ``targets[i, 1]`` (y).
        model: variogram family, range, nugget and total sill.
        search_radius_m: neighbor search radius.
        min_neighbors: targets with fewer neighbors get the global mean of ``train_z``.
        max_neighbors: cap on the neighbors used per target.

    Returns:
        Float array with one prediction per target row.
    """
    nugget = float(model.nugget)
    sill_total = float(model.sill_total)
    # Structured (non-nugget) part of the sill; never negative.
    partial = max(sill_total - nugget, 0.0)

    idx = NeighborIndex(train_x, train_y)
    neigh_list = idx.query_neighbors(targets, radius=search_radius_m, max_neighbors=max_neighbors)

    p = np.full(targets.shape[0], np.nan, dtype=float)
    xy_train = np.column_stack([train_x, train_y]).astype(float)

    for i, nn in enumerate(neigh_list):
        # Too few neighbors for a stable local system: fall back to the global mean.
        if nn.size < min_neighbors:
            p[i] = float(np.mean(train_z))
            continue

        pts = xy_train[nn]
        z = train_z[nn].astype(float)

        # Pairwise neighbor distances -> data-to-data covariance matrix.
        dx = pts[:, 0][:, None] - pts[:, 0][None, :]
        dy = pts[:, 1][:, None] - pts[:, 1][None, :]
        h = np.sqrt(dx * dx + dy * dy)
        C = covariance_from_model(model.model_type, h, model.range_m, nugget, partial)

        # Ordinary-kriging system: covariance block bordered by a row/column of
        # ones (unbiasedness constraint); the corner entry stays 0.
        m = len(nn)
        A = np.zeros((m + 1, m + 1), dtype=float)
        A[:m, :m] = C
        A[:m, m] = 1.0
        A[m, :m] = 1.0

        # Right-hand side: data-to-target covariances plus the constraint value 1.
        d0 = np.sqrt((pts[:, 0] - targets[i, 0]) ** 2 + (pts[:, 1] - targets[i, 1]) ** 2)
        c0 = covariance_from_model(model.model_type, d0, model.range_m, nugget, partial)

        b = np.zeros(m + 1, dtype=float)
        b[:m] = c0
        b[m] = 1.0

        try:
            sol = np.linalg.solve(A, b)
            w = sol[:m]
        except np.linalg.LinAlgError:
            # Singular system: add a small diagonal jitter scaled by the mean
            # variance and retry once.
            jitter = 1e-8 * (np.trace(C) / max(m, 1))
            A[:m, :m] += np.eye(m) * jitter
            try:
                sol = np.linalg.solve(A, b)
                w = sol[:m]
            except np.linalg.LinAlgError:
                # Still singular: use the global mean for this target.
                p[i] = float(np.mean(train_z))
                continue

        # Kriging weights can produce values outside [0, 1]; clip to a valid probability.
        p[i] = float(np.clip(np.dot(w, z), 0.0, 1.0))

    return p


def log_loss_binary(y_true: np.ndarray, p_hat: np.ndarray, eps: float = 1e-6) -> float:
    """Mean binary cross-entropy of predictions ``p_hat`` against labels ``y_true``.

    Predictions are clipped to [eps, 1 - eps] so the logarithms stay finite.
    """
    labels = y_true.astype(float)
    prob = np.clip(p_hat, eps, 1.0 - eps)
    per_sample = -(labels * np.log(prob) + (1.0 - labels) * np.log(1.0 - prob))
    return float(np.mean(per_sample))


# ---------------------------------------------------------------------
# Diagnostics plotting (PNG)
# ---------------------------------------------------------------------


def plot_variogram_fit_png(
    out_png: Path,
    lag: np.ndarray,
    gamma_emp: np.ndarray,
    bin_w: np.ndarray,
    fits: Dict[str, Tuple[float, float, float]],
    max_dist: float,
    title: str,
) -> None:
    """Save a PNG with the empirical variogram and one fitted curve per model.

    ``fits`` maps a model family name to ``(range, nugget, sill_total)``. Only
    bins with finite semivariance and positive weight are drawn as points. The
    output directory is created if needed.
    """
    figure, axis = plt.subplots(figsize=(9, 5))

    usable = np.isfinite(gamma_emp) & (bin_w > 0)
    axis.scatter(lag[usable], gamma_emp[usable], s=35, label="Empirical (weighted bins)")

    dense_lags = np.linspace(0.0, max_dist, 300)
    for family, (fit_range, fit_nugget, fit_sill) in fits.items():
        structured = max(fit_sill - fit_nugget, 0.0)
        curve = variogram_gamma(family, dense_lags, fit_range, fit_nugget, structured)
        axis.plot(dense_lags, curve, label=f"{family}: range={fit_range:.1f}, nug={fit_nugget:.3f}")

    axis.set_title(title)
    axis.set_xlabel("Lag distance (m)")
    axis.set_ylabel("Semivariance γ(h)")
    axis.set_xlim(0, max_dist)
    axis.grid(True, alpha=0.3)
    axis.legend(loc="best", frameon=True)
    figure.tight_layout()

    out_png.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(out_png, dpi=200)
    plt.close(figure)


def plot_cv_scores_png(out_png: Path, cv_df: pd.DataFrame, title: str) -> None:
    """Save a boxplot PNG of per-fold log loss for each variogram model.

    Args:
        out_png: destination file; its parent directory is created if needed.
        cv_df: table with ``model_type`` and ``log_loss`` columns (one row per
            model and fold).
        title: plot title.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    models = ["exponential", "spherical", "gaussian"]
    data = [cv_df.loc[cv_df["model_type"] == m, "log_loss"].values for m in models]

    # Label the boxes after drawing instead of passing ``tick_labels``: that
    # keyword only exists on matplotlib >= 3.9 (earlier releases call it
    # ``labels``), whereas relabeling the ticks works on every version.
    ax.boxplot(data, showmeans=True)
    ax.set_xticklabels(models)

    ax.set_title(title)
    ax.set_ylabel("Log loss (lower is better)")
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, dpi=200)
    plt.close(fig)


def _grid_cell_edges(centers: np.ndarray, log_spaced: bool = False) -> np.ndarray:
    """Return ``len(centers) + 1`` cell edges so that each grid value sits inside its own cell."""
    vals = np.asarray(centers, dtype=float)
    if log_spaced:
        # Work in log space so edges are geometric midpoints of the centers.
        vals = np.log(vals)

    n = vals.size
    if n < 2 or float(vals[-1] - vals[0]) == 0.0:
        # Degenerate axis (single or constant value): draw unit-width cells.
        edges = vals[0] + np.arange(n + 1, dtype=float) - 0.5 * n
    else:
        mids = 0.5 * (vals[:-1] + vals[1:])
        first = vals[0] - (mids[0] - vals[0])
        last = vals[-1] + (vals[-1] - mids[-1])
        edges = np.concatenate(([first], mids, [last]))

    return np.exp(edges) if log_spaced else edges


def plot_variogram_surface_png(out_png: Path, art: VariogramFitArtifacts, best_r: float, best_n: float, title: str) -> None:
    """Save a heatmap PNG of the SSE objective over (range, nugget).

    The surface is drawn with ``pcolormesh`` on explicit cell edges so that
    every column is placed at its actual range value on the logarithmic x-axis.
    (An image with a linear ``extent`` would spread the geometrically spaced
    range candidates evenly and misalign them with the axis and the marker.)

    Args:
        out_png: destination file; its parent directory is created if needed.
        art: fit artifacts providing ``range_grid``, ``nugget_grid`` and
            ``sse_grid`` (shape ``(len(nugget_grid), len(range_grid))``).
        best_r, best_n: selected range and nugget, marked with an "x".
        title: plot title.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    # log10 compresses the dynamic range; EPS avoids log10(0).
    sse_plot = np.log10(np.maximum(art.sse_grid, EPS))

    range_edges = _grid_cell_edges(art.range_grid, log_spaced=True)
    nugget_edges = _grid_cell_edges(art.nugget_grid, log_spaced=False)
    mesh = ax.pcolormesh(range_edges, nugget_edges, sse_plot, shading="flat")

    cbar = fig.colorbar(mesh, ax=ax)
    cbar.set_label("log10(weighted SSE)")

    ax.scatter([best_r], [best_n], marker="x", s=120)
    ax.set_xscale("log")
    ax.set_xlabel("Range (m) [log scale]")
    ax.set_ylabel("Nugget")
    ax.set_title(title)
    fig.tight_layout()

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, dpi=200)
    plt.close(fig)


def plot_realizations_panel_png(out_png: Path, realizations: List[np.ndarray], inside: np.ndarray, title: str) -> None:
    """Save a near-square grid of categorical realization maps as one PNG.

    Each realization holds integer class codes; cells outside ``inside`` are
    shown as code 0 (transparent). Nothing is written when ``realizations`` is
    empty. The output directory is created if needed.
    """
    count = len(realizations)
    if count == 0:
        return

    ncols = int(math.ceil(math.sqrt(count)))
    nrows = int(math.ceil(count / ncols))

    cmap = ListedColormap(PALETTE)
    # One color bin per integer code 0..5.
    norm = BoundaryNorm(boundaries=np.arange(-0.5, 6.5, 1.0), ncolors=cmap.N)

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4 * ncols, 4 * nrows))
    panels = np.atleast_1d(axes).ravel()

    # Hide every frame first; surplus panels simply stay blank.
    for panel in panels:
        panel.axis("off")

    for number, (panel, realization) in enumerate(zip(panels, realizations), start=1):
        shown = realization.copy()
        shown[~inside] = 0
        panel.imshow(shown, cmap=cmap, norm=norm, interpolation="nearest")
        panel.set_title(f"Panel sample {number}")

    legend_handles = [Patch(facecolor="#ffffff", edgecolor="k", label="Uncertain/Masked")]
    legend_handles.extend(
        Patch(facecolor=PALETTE[c], edgecolor="none", label=CODE_TO_CLASS[c]) for c in range(1, 6)
    )

    fig.suptitle(title, y=0.995)
    fig.legend(handles=legend_handles, loc="lower center", ncol=3, frameon=True)
    # Reserve room for the legend at the bottom and the title at the top.
    fig.tight_layout(rect=[0, 0.06, 1, 0.96])

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, dpi=200)
    plt.close(fig)


# ---------------------------------------------------------------------
# CV selection + per-unit fits for diagnostics
# ---------------------------------------------------------------------


def select_shared_variogram_by_block_cv(unit_data: UnitData, seed: int) -> Tuple[VariogramModel, pd.DataFrame]:
    """
    Pick the variogram family and (range, nugget) for the binary indicator J
    (dominant class vs. rest) using spatial block cross-validation.

    For every usable fold, each candidate family is fitted on the training
    points and scored by log loss on the held-out points. The family with the
    lowest mean log loss is then refitted on all points.

    Returns:
      (model_fitted_on_all_data, cv_scores_df)

    When no fold is usable, ``cv_scores_df`` is empty and a heuristic
    exponential model is returned instead.
    """
    xs, ys, codes = unit_data.x, unit_data.y, unit_data.code

    # Bounding-box diagonal of the points, floored at 1.
    extent_x = float(np.max(xs) - np.min(xs))
    extent_y = float(np.max(ys) - np.min(ys))
    diag = float(max(math.hypot(extent_x, extent_y), 1.0))

    median_nn = estimate_median_nn_distance(xs, ys)
    decluster_cell = float(np.clip(DECLUSTER_MULT * median_nn, DECLUSTER_CELL_MIN_M, DECLUSTER_CELL_MAX_M))
    block_size = float(np.clip(BLOCK_SIZE_MULT * median_nn, BLOCK_SIZE_MIN_M, BLOCK_SIZE_MAX_M))

    fold_masks = make_block_folds(
        assign_blocks(xs, ys, block_size_m=block_size),
        n_folds=CV_FOLDS,
        seed=seed,
    )

    # Lag cut-off and range search bounds shared by every fit below.
    max_dist = float(max(MAX_DIST_FRACTION_OF_SPAN * diag, 3.0 * GRID_CELL_SIZE_M))
    range_min = float(max(3.0 * GRID_CELL_SIZE_M, 10.0))
    range_max = float(max(min(0.5 * diag, 250.0), range_min + EPS))

    def _indicator_stats(code_subset: np.ndarray, weights: np.ndarray) -> Tuple[int, np.ndarray, float]:
        """Return (dominant code, J indicator, total sill p(1-p)) for a weighted subset."""
        prevalence = weighted_prevalence(code_subset, weights)
        dominant = int(np.argmax(prevalence) + 1)
        indicator = (code_subset != dominant).astype(float)
        p_rest = float(np.sum(weights * indicator) / max(np.sum(weights), EPS))
        return dominant, indicator, float(max(p_rest * (1.0 - p_rest), EPS))

    def _fit_family(family: str, lag_v, gamma_v, weight_v, sill: float) -> Tuple[float, float]:
        """Grid-fit (range, nugget) for one family with the shared search bounds."""
        fit_range, fit_nugget, _ = fit_variogram_surface(
            model_type=family,
            lag=lag_v,
            gamma_emp=gamma_v,
            bin_w=weight_v,
            sill_total=sill,
            range_min=range_min,
            range_max=range_max,
            nugget_max=min(NUGGET_MAX, sill),
            range_grid_size=RANGE_GRID_SIZE,
            nugget_grid_size=NUGGET_GRID_SIZE,
        )
        return fit_range, fit_nugget

    candidate_families = ("exponential", "spherical", "gaussian")
    min_train = max(MIN_NEIGHBORS + 2, 20)
    score_rows: List[Dict[str, object]] = []

    for fold_no, held_out in enumerate(fold_masks, start=1):
        kept = ~held_out
        # Folds with too few training or test points are skipped entirely.
        if np.sum(kept) < min_train or np.sum(held_out) < 5:
            continue

        w_train = compute_declustering_weights(xs[kept], ys[kept], cell_size_m=decluster_cell)
        dominant_code, j_train, sill_fold = _indicator_stats(codes[kept], w_train)
        j_test = (codes[held_out] != dominant_code).astype(float)

        lag, gamma_emp, bin_w = empirical_variogram_weighted(
            x=xs[kept],
            y=ys[kept],
            z=j_train,
            w=w_train,
            n_lags=N_LAGS,
            max_dist=max_dist,
            max_pairs=MAX_PAIRS_FOR_VARIOGRAM,
            seed=seed + 1000 * fold_no,
        )

        for family in candidate_families:
            fold_range, fold_nugget = _fit_family(family, lag, gamma_emp, bin_w, sill_fold)
            fold_model = VariogramModel(family, fold_range, fold_nugget, sill_fold)

            predicted = ok_predict_binary_local(
                train_x=xs[kept],
                train_y=ys[kept],
                train_z=j_train,
                targets=np.column_stack([xs[held_out], ys[held_out]]).astype(float),
                model=fold_model,
                search_radius_m=float(SEARCH_RADIUS_FACTOR * fold_range),
                min_neighbors=MIN_NEIGHBORS,
                max_neighbors=MAX_NEIGHBORS,
            )

            fold_score = log_loss_binary(j_test, predicted, eps=1e-6)
            score_rows.append(
                {
                    "fold": fold_no,
                    "model_type": family,
                    "log_loss": float(fold_score),
                    "range_m": float(fold_range),
                    "nugget": float(fold_nugget),
                    "sill_total_J": float(sill_fold),
                    "dominant_code": int(dominant_code),
                }
            )

    cv_df = pd.DataFrame(score_rows)

    # All-data indicator and sill; needed by both the fallback and the final fit.
    w_all = compute_declustering_weights(xs, ys, cell_size_m=decluster_cell)
    _, j_all, sill_all = _indicator_stats(codes, w_all)

    if cv_df.empty:
        # fallback: heuristic exponential
        heuristic_range = float(np.clip(8.0 * median_nn, 10.0, 200.0))
        heuristic_nugget = float(min(0.05, sill_all))
        return VariogramModel("exponential", heuristic_range, heuristic_nugget, sill_all), cv_df

    family_means = cv_df.groupby("model_type")["log_loss"].mean().to_dict()
    best_family = min(family_means, key=family_means.get)

    # Final fit on ALL data for the chosen family.
    lag, gamma_emp, bin_w = empirical_variogram_weighted(
        x=xs,
        y=ys,
        z=j_all,
        w=w_all,
        n_lags=N_LAGS,
        max_dist=max_dist,
        max_pairs=MAX_PAIRS_FOR_VARIOGRAM,
        seed=seed + 9999,
    )
    final_range, final_nugget = _fit_family(best_family, lag, gamma_emp, bin_w, sill_all)

    return VariogramModel(best_family, final_range, final_nugget, sill_all), cv_df


def build_unit_J_all(unit_data: UnitData, seed: int) -> Tuple[np.ndarray, np.ndarray, float, float, float, float]:
    """
    Build the binary indicator J (dominant class vs. rest) from ALL points of a unit.

    Returns, in this order:
      j_all, w_all, sill_total_J, max_dist, range_min, range_max

    ``seed`` is accepted for signature compatibility; the body does not use it.
    """
    xs, ys = unit_data.x, unit_data.y

    # Declustering weights from a cell size tied to the median nearest-neighbour distance.
    median_nn = estimate_median_nn_distance(xs, ys)
    cell_m = float(np.clip(DECLUSTER_MULT * median_nn, DECLUSTER_CELL_MIN_M, DECLUSTER_CELL_MAX_M))
    weights = compute_declustering_weights(xs, ys, cell_size_m=cell_m)

    # J = 1 wherever the observed code differs from the weighted-dominant code.
    prevalence = weighted_prevalence(unit_data.code, weights)
    dominant = int(np.argmax(prevalence) + 1)
    indicator = (unit_data.code != dominant).astype(float)

    p_rest = float(np.sum(weights * indicator) / max(np.sum(weights), EPS))
    sill_total = float(max(p_rest * (1.0 - p_rest), EPS))

    # Bounding-box diagonal (floored at 1) drives the lag cut-off and range bounds.
    extent_x = float(np.max(xs) - np.min(xs))
    extent_y = float(np.max(ys) - np.min(ys))
    diag = float(max(math.hypot(extent_x, extent_y), 1.0))

    lag_cutoff = float(max(MAX_DIST_FRACTION_OF_SPAN * diag, 3.0 * GRID_CELL_SIZE_M))
    lower_range = float(max(3.0 * GRID_CELL_SIZE_M, 10.0))
    upper_range = float(max(min(0.5 * diag, 250.0), lower_range + EPS))

    return indicator, weights, sill_total, lag_cutoff, lower_range, upper_range


def write_unit_variogram_diagnostics(
    unit_data: UnitData,
    unit_out: Path,
    best_model: VariogramModel,
    cv_df: pd.DataFrame,
    seed: int,
) -> None:
    """
    Write the variogram diagnostics for one unit into ``unit_out``:

    1) variogram_J_fit.png  - empirical variogram of J with one fitted curve per family
    2) cv_scores.csv + cv_scores.png (the PNG only when ``cv_df`` has rows)
    4) variogram_fit_surface_<best_model>.png - fit surface of the selected family

    The curves and the surface come from fits made here on all points; only the
    family name is taken from ``best_model``.
    """
    unit_out.mkdir(parents=True, exist_ok=True)

    # (2) CV scores: the table is always written, the chart only when there is data.
    scores_csv = cv_df.to_csv(index=False)
    (unit_out / "cv_scores.csv").write_text(scores_csv)
    if not cv_df.empty:
        plot_cv_scores_png(
            out_png=unit_out / "cv_scores.png",
            cv_df=cv_df,
            title=f"{unit_data.unit} - Block CV log loss by variogram family",
        )

    # Empirical variogram of J on ALL data.
    j_all, w_all, sill_total, max_dist, range_min, range_max = build_unit_J_all(unit_data, seed=seed)
    lag, gamma_emp, bin_w = empirical_variogram_weighted(
        x=unit_data.x,
        y=unit_data.y,
        z=j_all,
        w=w_all,
        n_lags=N_LAGS,
        max_dist=max_dist,
        max_pairs=MAX_PAIRS_FOR_VARIOGRAM,
        seed=seed + 4242,
    )

    # (1) One fit per family; keep (range, nugget, sill) for the overlay curves
    # and the fit artifacts for the surface plot.
    nugget_cap = min(NUGGET_MAX, sill_total)
    fits: Dict[str, Tuple[float, float, float]] = {}
    surfaces: Dict[str, VariogramFitArtifacts] = {}
    for family in ("exponential", "spherical", "gaussian"):
        fit_range, fit_nugget, artifacts = fit_variogram_surface(
            model_type=family,
            lag=lag,
            gamma_emp=gamma_emp,
            bin_w=bin_w,
            sill_total=sill_total,
            range_min=range_min,
            range_max=range_max,
            nugget_max=nugget_cap,
            range_grid_size=RANGE_GRID_SIZE,
            nugget_grid_size=NUGGET_GRID_SIZE,
        )
        fits[family] = (float(fit_range), float(fit_nugget), float(sill_total))
        surfaces[family] = artifacts

    plot_variogram_fit_png(
        out_png=unit_out / "variogram_J_fit.png",
        lag=lag,
        gamma_emp=gamma_emp,
        bin_w=bin_w,
        fits=fits,
        max_dist=max_dist,
        title=f"{unit_data.unit} - Empirical variogram for J (dominant vs rest) with fitted curves",
    )

    # (4) Parameter surface for the selected family, marked at this function's own fit.
    chosen = best_model.model_type
    chosen_range, chosen_nugget, _ = fits[chosen]
    plot_variogram_surface_png(
        out_png=unit_out / f"variogram_fit_surface_{chosen}.png",
        art=surfaces[chosen],
        best_r=chosen_range,
        best_n=chosen_nugget,
        title=f"{unit_data.unit} - Variogram fit surface ({chosen})",
    )


# ---------------------------------------------------------------------
# Class-specific sills and shrinkage
# ---------------------------------------------------------------------


def compute_class_sills(unit_data: UnitData) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return ``(p, sill)`` for a unit: the declustered marginal prevalence ``p_k``
    of each class and its total indicator sill ``p_k * (1 - p_k)``.

    When ``SILL_FLOOR`` is set to a positive value, every sill is raised to at
    least that value.
    """
    median_nn = estimate_median_nn_distance(unit_data.x, unit_data.y)
    cell_m = float(np.clip(DECLUSTER_MULT * median_nn, DECLUSTER_CELL_MIN_M, DECLUSTER_CELL_MAX_M))
    weights = compute_declustering_weights(unit_data.x, unit_data.y, cell_size_m=cell_m)

    prevalence = weighted_prevalence(unit_data.code, weights)
    sills = prevalence * (1.0 - prevalence)

    # A falsy SILL_FLOOR (0 / None) disables the floor.
    floor = float(SILL_FLOOR) if SILL_FLOOR else 0.0
    if floor > 0:
        sills = np.maximum(sills, floor)

    return prevalence, sills


def total_to_partial_sills(total_sills: np.ndarray, nugget_shared: float) -> np.ndarray:
    """Subtract the shared nugget from each total sill, clamping negatives to zero."""
    remainder = total_sills - float(nugget_shared)
    return np.clip(remainder, 0.0, None)


# ---------------------------------------------------------------------
# SIS conditioning set (spatial hash)
# ---------------------------------------------------------------------


class SpatialHash:
    """
    Uniform-grid point index used for neighbourhood look-ups during SIS.

    Points are bucketed by the integer cell they fall in (cell size
    ``cell_size_m``, origin at ``(xmin, ymin)``); points can be added one at a
    time while the simulation proceeds.
    """

    def __init__(self, xmin: float, ymin: float, cell_size_m: float):
        self.xmin = float(xmin)
        self.ymin = float(ymin)
        self.cell = float(cell_size_m)
        # Maps (cell_x, cell_y) -> points stored in that cell, in insertion order.
        self._cells: Dict[Tuple[int, int], List[CondPoint]] = {}

    def _cell_index(self, x: float, y: float) -> Tuple[int, int]:
        """Return the integer (cx, cy) cell containing ``(x, y)``; may be negative."""
        return (
            int(math.floor((x - self.xmin) / self.cell)),
            int(math.floor((y - self.ymin) / self.cell)),
        )

    def add_point(self, p: CondPoint) -> None:
        """Append ``p`` to the bucket of the cell it falls in."""
        bucket = self._cells.setdefault(self._cell_index(p.x, p.y), [])
        bucket.append(p)

    def add_many(self, x: np.ndarray, y: np.ndarray, code: np.ndarray) -> None:
        """Add one ``CondPoint`` per (x, y, code) triple."""
        for px, py, pc in zip(x, y, code):
            self.add_point(CondPoint(float(px), float(py), int(pc)))

    def query_radius(self, x0: float, y0: float, radius_m: float, max_neighbors: int) -> List[CondPoint]:
        """
        Return up to ``max_neighbors`` stored points within ``radius_m`` of
        ``(x0, y0)``, ordered by increasing distance. Empty list if none qualify.
        """
        reach = float(radius_m)
        home_cx, home_cy = self._cell_index(x0, y0)
        # Number of cells to look at on each side so the whole radius is covered.
        ring = int(math.ceil(reach / max(self.cell, EPS)))

        pool: List[CondPoint] = []
        for off_x in range(-ring, ring + 1):
            for off_y in range(-ring, ring + 1):
                pool.extend(self._cells.get((home_cx + off_x, home_cy + off_y), ()))

        if not pool:
            return []

        coords = np.array([(q.x, q.y) for q in pool], dtype=float)
        dist = np.sqrt((coords[:, 0] - x0) ** 2 + (coords[:, 1] - y0) ** 2)

        # The square block of cells over-covers the circle, so filter by true distance.
        hit_idx = np.flatnonzero(dist <= reach)
        if hit_idx.size == 0:
            return []

        nearest_first = np.argsort(dist[hit_idx])[:max_neighbors]
        return [pool[hit_idx[j]] for j in nearest_first]


# ---------------------------------------------------------------------
# SIS probability mechanics (OK weights + shrinkage)
# ---------------------------------------------------------------------


def ok_weights(
    cond_pts: List[CondPoint],
    x0: float,
    y0: float,
    model_type: str,
    range_m: float,
    nugget: float,
    partial_sill_ref: float,
) -> Optional[np.ndarray]:
    """
    Solve the ordinary-kriging system for the target ``(x0, y0)`` and return
    one weight per conditioning point.

    Covariances come from ``covariance_from_model`` with ``partial_sill_ref``
    as the sill scale. Returns ``None`` when there are no conditioning points
    or when the system is still singular after a small diagonal jitter.
    """
    n_pts = len(cond_pts)
    if n_pts == 0:
        return None

    coords = np.array([(pt.x, pt.y) for pt in cond_pts], dtype=float)
    px, py = coords[:, 0], coords[:, 1]

    # Point-to-point covariance block.
    sep_x = px[:, None] - px[None, :]
    sep_y = py[:, None] - py[None, :]
    pair_dist = np.sqrt(sep_x * sep_x + sep_y * sep_y)
    cov_pp = covariance_from_model(model_type, pair_dist, range_m, nugget, partial_sill_ref)

    # Augmented matrix: covariance block bordered by ones (unbiasedness
    # constraint) with a zero in the bottom-right corner.
    system = np.ones((n_pts + 1, n_pts + 1), dtype=float)
    system[:n_pts, :n_pts] = cov_pp
    system[n_pts, n_pts] = 0.0

    # Right-hand side: point-to-target covariances followed by 1.
    target_dist = np.sqrt((px - x0) ** 2 + (py - y0) ** 2)
    rhs = np.ones(n_pts + 1, dtype=float)
    rhs[:n_pts] = covariance_from_model(model_type, target_dist, range_m, nugget, partial_sill_ref)

    def _attempt() -> Optional[np.ndarray]:
        """Solve the current system; None if numpy reports it as singular."""
        try:
            return np.linalg.solve(system, rhs)[:n_pts]
        except np.linalg.LinAlgError:
            return None

    weights = _attempt()
    if weights is not None:
        return weights

    # Retry once with a tiny jitter on the covariance diagonal, scaled by its mean.
    jitter = 1e-8 * (np.trace(cov_pp) / max(n_pts, 1))
    system[:n_pts, :n_pts] += np.eye(n_pts) * jitter
    return _attempt()


def conditional_probs_sis(
    cond_pts: List[CondPoint],
    weights: Optional[np.ndarray],
    present_codes: List[int],
    marginal_p: np.ndarray,
    partial_sills: np.ndarray,
    nugget: float,
) -> np.ndarray:
    """
    Turn kriging weights into a length-5 vector of class probabilities.

    For each code in ``present_codes`` the weighted indicator mean (clipped to
    [0, 1]) is blended with the marginal prevalence using
    ``lam = partial_sill / (partial_sill + nugget)``. Codes not listed keep
    probability 0. The vector is normalised to sum to 1; a copy of
    ``marginal_p`` is returned when there are no weights/neighbours or the
    total is non-positive or non-finite.
    """
    if weights is None or len(cond_pts) == 0:
        return marginal_p.copy()

    neighbor_codes = np.array([pt.code for pt in cond_pts], dtype=int)
    krig_w = weights.astype(float)
    shared_nugget = float(nugget)

    blended = np.zeros(5, dtype=float)
    for k in present_codes:
        slot = k - 1  # codes are 1-based, arrays are 0-based
        indicator = (neighbor_codes == k).astype(float)
        kriged_mean = float(np.clip(float(np.dot(krig_w, indicator)), 0.0, 1.0))

        partial = float(partial_sills[slot])
        shrink = partial / max(partial + shared_nugget, EPS)
        blended[slot] = shrink * kriged_mean + (1.0 - shrink) * float(marginal_p[slot])

    blended = np.clip(blended, 0.0, 1.0)
    total = float(np.sum(blended))
    if np.isfinite(total) and total > 0.0:
        return blended / total
    return marginal_p.copy()


# ---------------------------------------------------------------------
# Point-to-grid imprint + MMU cleanup
# ---------------------------------------------------------------------


def rasterize_points_to_grid(
    x: np.ndarray,
    y: np.ndarray,
    code: np.ndarray,
    xmin: float,
    ymax: float,
    cell_size: float,
    height: int,
    width: int,
) -> np.ndarray:
    """
    Burn point class codes into a ``(height, width)`` uint8 grid.

    Columns grow with x from ``xmin`` and rows grow downward from ``ymax``.
    Cells without a point stay 0; points that fall outside the grid are ignored.
    """
    col_idx = np.floor((x - xmin) / cell_size).astype(int)
    row_idx = np.floor((ymax - y) / cell_size).astype(int)

    rows_ok = (row_idx >= 0) & (row_idx < height)
    cols_ok = (col_idx >= 0) & (col_idx < width)
    on_grid = rows_ok & cols_ok

    grid = np.zeros((height, width), dtype=np.uint8)
    grid[row_idx[on_grid], col_idx[on_grid]] = code[on_grid].astype(np.uint8)
    return grid


def fill_uncertain_by_majority(cat: np.ndarray) -> np.ndarray:
    """
    Replace zero cells with the most frequent class (1..5) in their 3x3 window.

    Counts are taken from the input state, ties go to the lowest class code,
    and zero cells with no classified neighbour stay zero. Returns a new array;
    if scipy cannot be imported the input is returned unchanged.
    """
    try:
        from scipy import ndimage  # type: ignore
    except Exception:
        return cat

    filled = cat.copy()
    is_zero = filled == 0
    if not np.any(is_zero):
        return filled

    window = np.ones((3, 3), dtype=int)
    # Layer c-1 holds, per cell, how many cells of class c lie in the 3x3 window.
    per_class = np.stack(
        [
            ndimage.convolve((filled == cls).astype(int), window, mode="nearest")
            for cls in range(1, 6)
        ],
        axis=0,
    )

    majority = np.argmax(per_class, axis=0) + 1
    support = np.max(per_class, axis=0)

    targets = is_zero & (support > 0)
    filled[targets] = majority[targets].astype(filled.dtype)
    return filled


def mmu_cleanup(cat: np.ndarray, mmu_cells: int, n_iters: int = 5) -> np.ndarray:
    """
    Enforce a minimum mapping unit on a categorical grid.

    In each pass, 8-connected patches of every class 1..5 that have fewer than
    ``mmu_cells`` cells are reset to 0, then zeros are refilled with
    ``fill_uncertain_by_majority``. Stops after ``n_iters`` passes or as soon
    as a pass removes nothing. If scipy cannot be imported a warning is printed
    and the input is returned unchanged.
    """
    try:
        from scipy import ndimage  # type: ignore
    except Exception:
        print("WARNING: scipy not available; skipping MMU cleanup.")
        return cat

    result = cat.copy()
    # A full 3x3 structuring element makes diagonal neighbours connected.
    eight_connected = np.ones((3, 3), dtype=int)

    for _pass in range(n_iters):
        removed_any = False

        for cls in range(1, 6):
            cls_mask = result == cls
            if not np.any(cls_mask):
                continue

            patch_ids, n_patches = ndimage.label(cls_mask, structure=eight_connected)
            if n_patches == 0:
                continue

            # Index 0 of the bincount is the background, so look at labels 1.. only.
            patch_sizes = np.bincount(patch_ids.ravel())
            undersized = np.flatnonzero(patch_sizes[1:] < mmu_cells) + 1
            if undersized.size == 0:
                continue

            to_clear = np.isin(patch_ids, undersized)
            if np.any(to_clear):
                result[to_clear] = 0
                removed_any = True

        if not removed_any:
            break

        result = fill_uncertain_by_majority(result)

    return result


# ---------------------------------------------------------------------
# GeoTIFF writing
# ---------------------------------------------------------------------


def write_geotiff(
    path: Path,
    array: np.ndarray,
    transform: rasterio.Affine,
    crs_epsg: int,
    nodata: Optional[float],
    dtype: str,
) -> None:
    """
    Write a 2-D array as a single-band, deflate-compressed, tiled GeoTIFF.

    The parent directory is created if needed. ``nodata`` is only recorded in
    the file when it is not ``None``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    n_rows, n_cols = array.shape
    # Predictor choice follows the dtype of the array as passed in.
    is_float = np.issubdtype(array.dtype, np.floating)

    profile = dict(
        driver="GTiff",
        height=n_rows,
        width=n_cols,
        count=1,
        dtype=dtype,
        crs=rasterio.crs.CRS.from_epsg(crs_epsg),
        transform=transform,
        compress="deflate",
        predictor=2 if is_float else 1,
        tiled=True,
        blockxsize=256,
        blockysize=256,
    )
    if nodata is not None:
        profile["nodata"] = nodata

    with rasterio.open(path, "w", **profile) as sink:
        sink.write(array, 1)


# ---------------------------------------------------------------------
# SIS per unit (includes SAFE p_max / winner computation)
# ---------------------------------------------------------------------


def run_sis_for_unit(unit_data: UnitData, out_dir: Path, variogram: VariogramModel, seed: int) -> Dict[str, object]:
    """
    Run sequential indicator simulation (SIS) for one management unit and write its outputs.

    Steps:
      1. Build the grid over the unit's bounding box and, if enabled, the
         point-envelope mask (``inside``).
      2. Imprint observed points onto the grid; those cells are fixed in every realization.
      3. For each of ``N_REALIZATIONS`` realizations, visit the remaining inside
         cells along a random path, draw a class from the conditional
         probabilities, and add the drawn cell to the conditioning set.
      4. Turn per-class counts into probabilities, derive ``p_max`` and the
         winning class, label low-confidence cells as 0, optionally apply MMU
         cleanup, and write GeoTIFFs, the realizations panel and ``metadata.json``.

    Args:
      unit_data: points of the unit (coordinates, class codes, present codes).
      out_dir: root output directory; files go into ``out_dir / safe_name(unit)``.
      variogram: shared variogram model (family, range, nugget, total sill of J).
      seed: seed for the random path and the class draws.

    Returns:
      A summary row for the unit. When the envelope covers no grid cell, the
      row only holds the unit name and ``status="skipped_empty_envelope"`` and
      no raster is written.
    """
    unit_name = safe_name(unit_data.unit)
    unit_out = out_dir / unit_name
    unit_out.mkdir(parents=True, exist_ok=True)

    xmin, ymin, xmax, ymax = unit_bbox(unit_data, buffer_m=EXTENT_BUFFER_M)
    x_centers, y_centers, transform = build_grid(xmin, ymin, xmax, ymax, cell_size=GRID_CELL_SIZE_M)
    width = len(x_centers)
    height = len(y_centers)

    xx, yy = np.meshgrid(x_centers, y_centers)
    grid_x = xx.astype(float)
    grid_y = yy.astype(float)

    if USE_POINT_ENVELOPE_MASK:
        geoms = build_point_envelope_geometry(unit_data.x, unit_data.y, buffer_m=ENVELOPE_BUFFER_M)
        inside = envelope_mask_for_grid(geoms, out_shape=(height, width), transform=transform)
    else:
        inside = np.ones((height, width), dtype=bool)

    inside_lin = np.flatnonzero(inside.ravel())
    if inside_lin.size == 0:
        return {"strManagementUnit": unit_data.unit, "status": "skipped_empty_envelope"}

    marginal_p, total_sills = compute_class_sills(unit_data)
    partial_sills = total_to_partial_sills(total_sills, nugget_shared=variogram.nugget)

    # Reference sill scale for the kriging system: the larger of the J partial
    # sill and the median class partial sill, never below EPS.
    partial_ref = max(float(variogram.sill_total - variogram.nugget), float(np.nanmedian(partial_sills)), EPS)
    search_radius = float(SEARCH_RADIUS_FACTOR * variogram.range_m)

    fixed_code = rasterize_points_to_grid(
        x=unit_data.x,
        y=unit_data.y,
        code=unit_data.code,
        xmin=xmin,
        ymax=ymax,
        cell_size=GRID_CELL_SIZE_M,
        height=height,
        width=width,
    )
    fixed_code = np.where(inside, fixed_code, 0).astype(np.uint8)
    fixed_mask = fixed_code > 0

    rng = np.random.default_rng(seed)

    # uint16 is enough up to 65535 realizations; widen beyond that so the
    # per-cell counts cannot wrap around silently.
    count_dtype = np.uint16 if N_REALIZATIONS <= np.iinfo(np.uint16).max else np.uint32
    counts = np.zeros((height, width, 5), dtype=count_dtype)

    # Keep some realizations for the panel
    panel_n = max(1, int(PANEL_N))
    if N_REALIZATIONS <= panel_n:
        keep_ids = list(range(1, N_REALIZATIONS + 1))
    else:
        keep_ids = np.linspace(1, N_REALIZATIONS, panel_n, dtype=int).tolist()
        keep_ids = sorted(set(keep_ids))
    keep_lookup = set(keep_ids)

    kept_realizations: List[np.ndarray] = []

    # Loop invariants: the cells to simulate (inside the domain, not fixed by
    # an observation) and the array of class codes to draw from.
    sim_lin = inside_lin[~fixed_mask.ravel()[inside_lin]]
    class_codes = np.arange(1, 6)

    for r_idx in range(1, N_REALIZATIONS + 1):
        # Fresh conditioning set per realization: observations only; simulated
        # cells are appended as the path is walked.
        sh = SpatialHash(xmin=xmin, ymin=ymin, cell_size_m=max(search_radius, EPS))
        sh.add_many(unit_data.x, unit_data.y, unit_data.code)

        cat = np.zeros((height, width), dtype=np.uint8)
        cat[fixed_mask] = fixed_code[fixed_mask]

        path = sim_lin.copy()
        rng.shuffle(path)

        for lin in path:
            row, col = divmod(int(lin), width)
            x0 = float(grid_x[row, col])
            y0 = float(grid_y[row, col])

            neigh = sh.query_radius(x0, y0, radius_m=search_radius, max_neighbors=MAX_NEIGHBORS)

            if len(neigh) < MIN_NEIGHBORS:
                # Too few neighbours to krige: fall back to the marginal prevalence.
                probs = marginal_p.copy()
            else:
                w = ok_weights(
                    cond_pts=neigh,
                    x0=x0,
                    y0=y0,
                    model_type=variogram.model_type,
                    range_m=variogram.range_m,
                    nugget=variogram.nugget,
                    partial_sill_ref=partial_ref,
                )
                probs = conditional_probs_sis(
                    cond_pts=neigh,
                    weights=w,
                    present_codes=unit_data.present_codes,
                    marginal_p=marginal_p,
                    partial_sills=partial_sills,
                    nugget=variogram.nugget,
                )

            # Re-normalise defensively so the draw always gets a valid distribution.
            probs = np.clip(probs, 0.0, 1.0)
            s = float(np.sum(probs))
            probs = probs / s if (s > 0.0 and np.isfinite(s)) else marginal_p.copy()

            code = int(rng.choice(class_codes, p=probs))
            cat[row, col] = np.uint8(code)
            sh.add_point(CondPoint(x=x0, y=y0, code=code))

        for c in range(1, 6):
            counts[:, :, c - 1] += ((cat == c) & inside).astype(count_dtype)

        if r_idx in keep_lookup:
            kept_realizations.append(cat.copy())

        print(f"Unit {unit_data.unit}: realization {r_idx}/{N_REALIZATIONS} complete")

    prob_stack = counts.astype(np.float32) / float(N_REALIZATIONS)
    prob_stack[~inside, :] = np.nan

    # SAFE p_max and winner (no RuntimeWarning)
    p_max, winner, all_nan = safe_max_and_argmax(prob_stack)

    # Uncertainty threshold -> label uncertain as 0
    uncertain_mask = (p_max < float(UNCERTAIN_TOLERANCE)) | ~np.isfinite(p_max)
    winner[uncertain_mask] = 0

    # MMU cleanup (only affects categorical)
    if MMU_CELLS and int(MMU_CELLS) > 1:
        winner = mmu_cleanup(winner, mmu_cells=int(MMU_CELLS)).astype(np.uint8)
        # The cleanup refills every zero cell that has a classified neighbour,
        # which includes cells outside the envelope. Re-apply the domain mask
        # so the categorical raster stays nodata (0) outside the domain, in
        # line with the NaN-masked probability rasters.
        winner[~inside] = 0

    # Write rasters
    for i, cls in enumerate(CLASSES):
        write_geotiff(
            unit_out / f"prob_{safe_name(cls)}.tif",
            prob_stack[:, :, i].astype(np.float32),
            transform=transform,
            crs_epsg=CRS_EPSG,
            nodata=np.nan,
            dtype="float32",
        )

    write_geotiff(
        unit_out / "p_max_confidence.tif",
        p_max.astype(np.float32),
        transform=transform,
        crs_epsg=CRS_EPSG,
        nodata=np.nan,
        dtype="float32",
    )

    write_geotiff(
        unit_out / "categorical_winner.tif",
        winner.astype(np.uint8),
        transform=transform,
        crs_epsg=CRS_EPSG,
        nodata=0,
        dtype="uint8",
    )

    # Realizations panel PNG (diagnostic #5)
    plot_realizations_panel_png(
        out_png=unit_out / "realizations_panel.png",
        realizations=kept_realizations,
        inside=inside,
        title=f"{unit_data.unit} - SIS realizations (panel n={len(kept_realizations)} of {N_REALIZATIONS})",
    )

    meta = {
        "strManagementUnit": unit_data.unit,
        "n_points": int(len(unit_data.x)),
        "present_classes": [CODE_TO_CLASS[c] for c in unit_data.present_codes],
        "grid_cell_size_m": float(GRID_CELL_SIZE_M),
        "extent_buffer_m": float(EXTENT_BUFFER_M),
        "use_point_envelope_mask": bool(USE_POINT_ENVELOPE_MASK),
        "envelope_buffer_m": float(ENVELOPE_BUFFER_M),
        "variogram_selected": {
            "model_type": variogram.model_type,
            "range_m": float(variogram.range_m),
            "nugget": float(variogram.nugget),
            "sill_total_J": float(variogram.sill_total),
        },
        "search_radius_m": float(search_radius),
        "min_neighbors": int(MIN_NEIGHBORS),
        "max_neighbors": int(MAX_NEIGHBORS),
        "n_realizations": int(N_REALIZATIONS),
        "random_seed": int(seed),
        "uncertain_tolerance": float(UNCERTAIN_TOLERANCE),
        # A falsy MMU_CELLS (0 / None) means "cleanup disabled"; record it as 0
        # instead of failing on int(None).
        "mmu_cells": int(MMU_CELLS) if MMU_CELLS else 0,
        "class_to_code": CLASS_TO_CODE,
        "code_to_class": CODE_TO_CLASS,
    }
    (unit_out / "metadata.json").write_text(json.dumps(meta, indent=2))

    return {
        "strManagementUnit": unit_data.unit,
        "n_points": int(len(unit_data.x)),
        "variogram_type": variogram.model_type,
        "range_m": float(variogram.range_m),
        "nugget": float(variogram.nugget),
        "search_radius_m": float(search_radius),
        "n_realizations": int(N_REALIZATIONS),
        "status": "ok",
    }


# ---------------------------------------------------------------------
# Orchestrator: per unit, diagnostics + SIS
# ---------------------------------------------------------------------


def run_sis_workflow(input_csv: str, output_dir: str, seed: int = RANDOM_SEED) -> pd.DataFrame:
    """
    Run the full per-unit workflow: variogram selection, diagnostics and SIS.

    For every unit found in ``input_csv`` this prints a class summary, selects
    the shared variogram by block CV, writes the variogram diagnostics, and
    runs the simulation. One summary row per unit is collected and written to
    ``<output_dir>/sis_summary_by_unit.csv``.

    Args:
      input_csv: path of the point CSV.
      output_dir: root directory for all outputs (created if missing).
      seed: seed passed to the variogram selection, diagnostics and simulation.

    Returns:
      The summary table, sorted by unit name. It is empty when the input
      contains no units.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df = read_points(input_csv)
    units = split_by_unit(df)

    rows: List[Dict[str, object]] = []

    for u in units:
        unit_out = out_dir / safe_name(u.unit)
        print("\n" + "=" * 72)
        print_unit_class_summary(u)

        # Select shared variogram by block CV + CV table for diagnostics
        best_model, cv_df = select_shared_variogram_by_block_cv(u, seed=seed)
        print(
            f"Selected variogram: {best_model.model_type} "
            f"(range={best_model.range_m:.2f} m, nugget={best_model.nugget:.4f}, sill_J={best_model.sill_total:.4f})"
        )

        # Diagnostics 1,2,4
        write_unit_variogram_diagnostics(
            unit_data=u,
            unit_out=unit_out,
            best_model=best_model,
            cv_df=cv_df,
            seed=seed,
        )

        # SIS + outputs + diagnostic 5 (realizations panel)
        rows.append(
            run_sis_for_unit(
                unit_data=u,
                out_dir=out_dir,
                variogram=best_model,
                seed=seed,
            )
        )

    summary = pd.DataFrame(rows)
    # With no units the frame has no columns, and sorting by a missing column
    # raises KeyError; only sort when the key column exists.
    if "strManagementUnit" in summary.columns:
        summary = summary.sort_values("strManagementUnit").reset_index(drop=True)

    summary_path = out_dir / "sis_summary_by_unit.csv"
    summary.to_csv(summary_path, index=False)
    print(f"\nWrote: {summary_path}")
    return summary


# ---------------------------------------------------------------------
# Run (execute in notebook)
# ---------------------------------------------------------------------

# Run the whole workflow with the configured input, output directory and seed.
# The bare expression on the last line is there so a notebook cell displays
# the returned summary table.
summary_df = run_sis_workflow(INPUT_CSV, OUTPUT_DIR, seed=RANDOM_SEED)
summary_df
