In [None]:
from __future__ import annotations

import re
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import rasterio
from matplotlib.colors import BoundaryNorm, ListedColormap


# ---------------------------------------------------------------------
# Configure these paths
# ---------------------------------------------------------------------

# Input point table; load_points() requires the columns strManagementUnit,
# strAveHeight_cm_PctCover, UTM_X and UTM_Y.
CSV_PATH: str = r"/home/abhayadana/Downloads/SummerLowStatureSampling_pt_2025.csv"
# Root folder that holds one sub-folder per management unit (named with safe_name).
OUTPUT_DIR: Path = Path(r"/home/abhayadana/Documents/GitHub/SLSS_analysis")


# ---------------------------------------------------------------------
# Constants: same 5 classes and codes as your workflow
# ---------------------------------------------------------------------

# Vegetation structure classes in code order: a class's code is its position + 1.
CLASSES = [
    "Short Open",
    "Tall Open",
    "Mid Mod",
    "Short Dense",
    "Tall Dense",
]
CLASS_TO_CODE = dict(zip(CLASSES, range(1, len(CLASSES) + 1)))
# Reverse lookup; code 0 is reserved for pixels with no confident winner.
CODE_TO_CLASS = {0: "Uncertain"}
CODE_TO_CLASS.update((code, cls) for cls, code in CLASS_TO_CODE.items())


def safe_name(name: str) -> str:
    """
    Convert an arbitrary label into a filesystem-friendly stem.

    Each run of characters that are neither word characters nor '-' becomes a
    single '_'. Repeated underscores are then collapsed, and leading/trailing
    underscores are trimmed. If nothing is left, "unit" is returned.

    NOTE: this must stay identical to the naming used by the workflow that
    wrote the per-unit output folders, or folder lookups will miss.
    """
    text = str(name).strip()
    for pattern in (r"[^\w\-]+", r"_+"):
        text = re.sub(pattern, "_", text)
    text = text.strip("_")
    return text if text else "unit"


def load_points(csv_path: str) -> pd.DataFrame:
    """
    Load the sampling-point CSV and keep only rows usable for plotting.

    A row is kept when it has numeric UTM_X/UTM_Y, a non-blank management unit
    and a class label that is one of CLASSES. Text fields are whitespace-stripped.

    Parameters
    ----------
    csv_path:
        Path to the point CSV (anything ``pandas.read_csv`` accepts).

    Returns
    -------
    pd.DataFrame
        The surviving rows with every original column (read as strings),
        UTM_X/UTM_Y converted to numbers, stripped strManagementUnit /
        strAveHeight_cm_PctCover, and an added integer ``class_code``
        taken from CLASS_TO_CODE.

    Raises
    ------
    ValueError
        If any of the required columns is missing.
    """
    df = pd.read_csv(csv_path, dtype=str)
    needed = {"strManagementUnit", "strAveHeight_cm_PctCover", "UTM_X", "UTM_Y"}
    missing = needed - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {sorted(missing)}")

    # Unparseable coordinates become NaN and are dropped below.
    df["UTM_X"] = pd.to_numeric(df["UTM_X"], errors="coerce")
    df["UTM_Y"] = pd.to_numeric(df["UTM_Y"], errors="coerce")

    # Strip through the .str accessor rather than astype(str). That keeps
    # missing cells as NaN instead of the literal string "nan", so the dropna
    # below actually sees them. Cells that are blank after stripping count as
    # missing too. Stripping the unit lets e.g. "1NE " match unit "1NE".
    for col in ("strManagementUnit", "strAveHeight_cm_PctCover"):
        stripped = df[col].str.strip()
        df[col] = stripped.where(stripped != "")

    df = df.dropna(subset=["UTM_X", "UTM_Y", "strManagementUnit", "strAveHeight_cm_PctCover"]).copy()
    df = df[df["strAveHeight_cm_PctCover"].isin(CLASSES)].copy()

    # Every remaining label is in CLASSES, so the mapping has no gaps.
    df["class_code"] = df["strAveHeight_cm_PctCover"].map(CLASS_TO_CODE).astype(int)
    return df


def raster_read_with_extent(
    tif_path: Path,
) -> Tuple[np.ndarray, Tuple[float, float, float, float], Optional[float]]:
    """
    Load band 1 of a GeoTIFF together with its imshow extent and nodata value.

    Returns
    -------
    (array, extent, nodata)
        ``array`` is band 1 as read by rasterio. ``extent`` is
        ``(xmin, xmax, ymin, ymax)`` built from the dataset bounds, in the order
        matplotlib's ``imshow(extent=...)`` expects. ``nodata`` is the value
        the file declares, or ``None`` if it declares none.

    NOTE(review): an extent built from bounds is only exact for north-up,
    non-rotated rasters; confirm that holds for the workflow outputs.
    """
    with rasterio.open(tif_path) as dataset:
        band = dataset.read(1)
        left, bottom, right, top = dataset.bounds
        nodata_value = dataset.nodata
    return band, (left, right, bottom, top), nodata_value


def class_colors() -> Dict[str, str]:
    """
    Return the fixed class -> matplotlib color mapping.

    The same palette colors the categorical raster cells and the point
    markers. Any matplotlib named colors work, as long as they stay distinct.
    A new dict is returned on every call, so callers may modify it freely.
    """
    palette = (
        ("Short Open", "gold"),
        ("Tall Open", "orange"),
        ("Mid Mod", "dodgerblue"),
        ("Short Dense", "limegreen"),
        ("Tall Dense", "crimson"),
    )
    return dict(palette)


def categorical_cmap_and_norm() -> Tuple[ListedColormap, BoundaryNorm, Dict[int, str]]:
    """
    Build a discrete colormap/norm pair for the categorical class codes.

    Code 0 is "Uncertain" and is drawn light gray. Codes 1..len(CLASSES) follow
    the order of CLASSES, so the palette cannot drift out of sync with
    CLASS_TO_CODE if classes are added or reordered.

    Returns:
      cmap: ListedColormap whose color at index i is the color for code i.
      norm: BoundaryNorm with half-integer bin edges, so each integer code
            falls into its own color bin.
      tick_labels: mapping code -> display name, taken from CODE_TO_CLASS.
    """
    colors = class_colors()
    n_codes = len(CLASSES) + 1  # +1 for code 0 ("Uncertain")

    # Derive the color order from CLASSES instead of restating it by hand.
    color_list = ["lightgray"] + [colors[cls] for cls in CLASSES]
    cmap = ListedColormap(color_list, name="veg_classes")

    # Edges -0.5, 0.5, ..., n_codes - 0.5. Using integer arange and shifting
    # avoids float-step arange endpoint surprises.
    boundaries = np.arange(n_codes + 1) - 0.5
    norm = BoundaryNorm(boundaries, cmap.N)

    tick_labels = {code: CODE_TO_CLASS[code] for code in range(n_codes)}
    return cmap, norm, tick_labels


def _infer_raster_type(raster_path: Path, arr: np.ndarray) -> str:
    """Guess "categorical"/"confidence"/"probability" from the file name, then the dtype."""
    name = raster_path.name.lower()
    if "categorical" in name:
        return "categorical"
    if "p_max" in name or "confidence" in name:
        return "confidence"
    if "prob_" in name:
        return "probability"
    # No naming hint: integer rasters are treated as class codes, anything
    # else as a continuous 0..1 surface.
    return "categorical" if np.issubdtype(arr.dtype, np.integer) else "probability"


def _draw_categorical_raster(
    ax,
    arr: np.ndarray,
    extent: Tuple[float, float, float, float],
    nodata: Optional[float],
    alpha: float,
    show_legend: bool,
) -> None:
    """Draw integer class codes with the fixed class palette, plus an optional colorbar."""
    # Explicit nodata cells are recoded to 0 so they render as "Uncertain".
    # NOTE(review): this makes nodata indistinguishable from genuinely
    # uncertain pixels; kept as-is because the workflow relies on it.
    if nodata is not None and np.isfinite(nodata):
        arr = np.where(arr == nodata, 0, arr)

    cmap, norm, tick_labels = categorical_cmap_and_norm()
    im = ax.imshow(
        arr,
        extent=extent,
        origin="upper",
        interpolation="nearest",
        cmap=cmap,
        norm=norm,
        alpha=alpha,
    )

    if show_legend:
        codes = sorted(tick_labels)
        cbar = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=codes)
        cbar.ax.set_yticklabels([tick_labels[code] for code in codes])
        cbar.set_label("Categorical class")


def _draw_continuous_raster(
    ax,
    arr: np.ndarray,
    extent: Tuple[float, float, float, float],
    nodata: Optional[float],
    alpha: float,
    vmin: float,
    vmax: float,
    label: str,
    show_legend: bool,
) -> None:
    """Draw a float raster on a fixed [vmin, vmax] color range, with nodata left blank."""
    values = arr.astype(float)  # astype copies, so the caller's array is untouched
    if nodata is not None and np.isfinite(nodata):
        values[values == nodata] = np.nan  # imshow draws NaN with the colormap's "bad" color

    im = ax.imshow(
        values,
        extent=extent,
        origin="upper",
        interpolation="nearest",
        alpha=alpha,
        vmin=vmin,
        vmax=vmax,
    )

    if show_legend:
        cbar = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label(label)


def _overlay_class_points(ax, unit_df: pd.DataFrame, point_size: float) -> int:
    """Scatter the unit's points as one series per class; return how many series were drawn."""
    colors = class_colors()
    n_series = 0
    for cls in CLASSES:
        g = unit_df[unit_df["strAveHeight_cm_PctCover"] == cls]
        if g.empty:
            continue
        ax.scatter(
            g["UTM_X"].to_numpy(float),
            g["UTM_Y"].to_numpy(float),
            s=point_size,
            c=colors.get(cls, "white"),
            edgecolors="black",
            linewidths=0.25,
            label=f"{cls} (n={len(g)})",
            zorder=3,
        )
        n_series += 1
    return n_series


def plot_unit_raster(
    points_df: pd.DataFrame,
    unit: str,
    raster_path: Path,
    raster_type: str = "auto",
    alpha: float = 0.85,
    point_size: float = 10,
    title: Optional[str] = None,
    show_point_legend: bool = True,
    show_raster_legend: bool = True,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> None:
    """
    Plot one raster for a management unit with that unit's sample points overlaid.

    Points are colored by their ``strAveHeight_cm_PctCover`` class, using the
    same palette as the categorical raster.

    Args:
      points_df: point table (e.g. from ``load_points``) with the columns
        strManagementUnit, strAveHeight_cm_PctCover, UTM_X and UTM_Y.
      unit: management unit to plot, compared as a string.
      raster_path: GeoTIFF drawn underneath the points.
      raster_type: case-insensitive, surrounding whitespace ignored:
        - "categorical": integer class codes 0..5
        - "probability" or "confidence": 0..1 float surfaces
        - "auto": infer from the file name, falling back to the dtype
      alpha: raster opacity.
      point_size: scatter marker size.
      title: figure title. Defaults to "<unit> — <raster file name>".
      show_point_legend: draw the point-class legend (only if any point is drawn).
      show_raster_legend: draw the raster colorbar.
      vmin, vmax: color range for continuous rasters (defaults 0.0 / 1.0);
        ignored for categorical rasters.

    Raises:
      ValueError: raster_type is not recognized, or no points match ``unit``.
      FileNotFoundError: raster_path does not exist.
    """
    # Validate up front: previously a typo such as "prob" silently fell into
    # the continuous branch and got the wrong "p_max confidence" label.
    rt = str(raster_type).lower().strip()
    if rt not in {"auto", "categorical", "probability", "confidence"}:
        raise ValueError(
            "raster_type must be 'auto', 'categorical', 'probability' or "
            f"'confidence', got {raster_type!r}"
        )

    raster_path = Path(raster_path)
    if not raster_path.exists():
        raise FileNotFoundError(str(raster_path))

    unit_df = points_df[points_df["strManagementUnit"].astype(str) == str(unit)]
    if unit_df.empty:
        raise ValueError(f"No points found for unit={unit!r}")

    arr, extent, nodata = raster_read_with_extent(raster_path)
    if rt == "auto":
        rt = _infer_raster_type(raster_path, arr)

    fig, ax = plt.subplots(figsize=(9, 7))

    if rt == "categorical":
        _draw_categorical_raster(ax, arr, extent, nodata, alpha, show_raster_legend)
    else:
        _draw_continuous_raster(
            ax,
            arr,
            extent,
            nodata,
            alpha,
            vmin=0.0 if vmin is None else vmin,
            vmax=1.0 if vmax is None else vmax,
            label="Probability" if rt == "probability" else "p_max confidence",
            show_legend=show_raster_legend,
        )

    n_series = _overlay_class_points(ax, unit_df, point_size)

    ax.set_xlabel("UTM_X (m)")
    ax.set_ylabel("UTM_Y (m)")
    ax.set_aspect("equal", adjustable="box")

    if title is None:
        # The escape keeps a real em dash however the file is re-encoded.
        # The old literal had been mangled into the mojibake "â€”".
        title = f"{unit} \u2014 {raster_path.name}"
    ax.set_title(title)

    # With no labelled series, ax.legend() only warns and draws an empty box.
    if show_point_legend and n_series:
        ax.legend(
            loc="upper right",
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Point classes",
        )

    fig.tight_layout()
    plt.show()


def unit_output_folder(output_dir: Path, unit: str) -> Path:
    """
    Build the path of a unit's output folder under *output_dir*.

    The folder name is ``safe_name(unit)``, matching how the workflow names
    it. Only the path is built; nothing is created or checked on disk.
    """
    base = Path(output_dir)
    return base.joinpath(safe_name(unit))


def list_unit_rasters(output_dir: Path, unit: str) -> Dict[str, Path]:
    """
    Collect the standard rasters that actually exist for a unit.

    Keys, in insertion order:
      "categorical"       -> categorical_winner.tif
      "confidence"        -> p_max_confidence.tif
      "prob_<class name>" -> prob_<safe_name(class)>.tif, one per class in CLASSES
    Any raster whose file is absent is left out of the result.
    """
    unit_dir = unit_output_folder(output_dir, unit)

    candidates: Dict[str, Path] = {
        "categorical": unit_dir / "categorical_winner.tif",
        "confidence": unit_dir / "p_max_confidence.tif",
    }
    candidates.update(
        (f"prob_{cls}", unit_dir / f"prob_{safe_name(cls)}.tif") for cls in CLASSES
    )

    return {key: path for key, path in candidates.items() if path.exists()}



In [None]:
# ---------------------------------------------------------------------
# Example usage (run these lines in your notebook)
# ---------------------------------------------------------------------

points = load_points(CSV_PATH)
unit = "1NE"
rasters = list_unit_rasters(OUTPUT_DIR, unit)

# (raster key, raster_type, alpha) for each map to draw. Keys follow
# list_unit_rasters(): "categorical", "confidence", "prob_<class name>".
plots_to_make = [
    ("categorical", "categorical", 0.85),      # class winner symbolized by class + points
    ("confidence", "confidence", 0.9),         # p_max confidence + points
    ("prob_Tall Dense", "probability", 0.9),   # class probability + points
    ("prob_Short Open", "probability", 0.9),   # class probability + points
]

for key, kind, alpha in plots_to_make:
    raster_path = rasters.get(key)
    if raster_path is None:
        # list_unit_rasters only returns files that exist. Report the gap
        # instead of aborting the whole cell with a bare KeyError.
        print(f"Skipping {key!r}: raster not found in {unit_output_folder(OUTPUT_DIR, unit)}")
        continue
    plot_unit_raster(points, unit, raster_path, raster_type=kind, alpha=alpha)