In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Plot Theil–Sen trend maps from LS_theilsen_mk_states_increments.nc
with flexible significance rendering (no stipple) and sensible precip units.

Significance rendering styles:
  - SIG_STYLE="fade"    -> non-significant tiles faded (low alpha), uses SAME cmap/norm as base
  - SIG_STYLE="outline" -> thin black halo around significant tiles
  - SIG_STYLE="contour" -> boundary line (requires quick grid)
  - SIG_STYLE="mask"    -> semi-transparent grey overlay where NOT significant
"""

import struct
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm

from geospatial_plotting import plot_region, REGION_BOUNDS

# -----------------------------
# Config
# -----------------------------
# Inputs: Theil–Sen / Mann–Kendall trend file and the GEOS-LDAS tilecoord binary
NC_TRENDS = "LS_theilsen_mk_states_increments.nc"
FILE_TILECO = (
    "/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/land_sweeper/LS_OLv8_M36/"
    "output/SMAP_EASEv2_M36_GLOBAL/rc_out/LS_OLv8_M36.ldas_tilecoord.bin"
)

# Variables to plot / summarize (prefixes of the {var}_slope_* and {var}_p_* fields)
VARS = ["SFMC", "RZMC", "SNOW", "PREC"]

# Significance settings
ALPHA = 0.05                     # p-value threshold: p < ALPHA counts as significant
APPLY_SIG = False                # True: blank (NaN) non-significant tiles in the trend maps
USE_STRICT_SIG_FOR_DELTA = True  # True: DELTA significant only where BOTH CNTL and DA are
PLOT_MASKS = True                # also draw the binary significance-mask maps

# Significance overlay drawn on the trend maps (no stipple):
# "fade" | "outline" | "contour" | "mask"
SIG_STYLE = "mask"

# Key into REGION_BOUNDS passed to plot_region
REGION = "global"

# Units / scaling (applied to slopes before plotting)
SEC_PER_YEAR = 365.25 * 86400.0
VAR_SCALE = dict.fromkeys(("SFMC", "RZMC", "SNOW"), 1.0)
VAR_SCALE["PREC"] = SEC_PER_YEAR   # mm/s -> mm/yr; the "per decade" part is unchanged
VAR_UNITS = dict.fromkeys(("SFMC", "RZMC", "SNOW"), "per decade")
VAR_UNITS["PREC"] = "mm/yr per decade"

# Quantile of |slope| used for the symmetric color limits
QUANT_CLIP_MAIN = 0.995
QUANT_CLIP_DELTA = 0.995

# Output names; {mode} is "CNTL", "DA" or "DELTA" (mask maps also use "DELTA_any"/"DELTA_both")
FNAME_MAIN = "{var}_{mode}_trend_map.png"
FNAME_MASK = "{var}_{mode}_sigmask.png"
NC_MASKS = "trend_significance_masks.nc"

# -----------------------------
# tilecoord reader
# -----------------------------
def _read_fortran_record(ifp, nbytes, fname, machfmt='<'):
    """Read one Fortran sequential-unformatted record with exactly ``nbytes`` of payload.

    The record is framed by a leading and a trailing 4-byte integer marker that
    both hold the payload length. Raises ValueError on a short read or when a
    marker does not match, instead of silently returning a truncated payload.
    """
    def _marker(which):
        raw = ifp.read(4)
        if len(raw) != 4:
            raise ValueError(f"{fname}: unexpected end of file reading {which} record marker")
        return struct.unpack(f'{machfmt}i', raw)[0]

    head = _marker("leading")
    if head != nbytes:
        raise ValueError(f"{fname}: record marker {head} != expected payload size {nbytes}")
    payload = ifp.read(nbytes)
    if len(payload) != nbytes:
        raise ValueError(f"{fname}: truncated record ({len(payload)} of {nbytes} bytes)")
    tail = _marker("trailing")
    if tail != head:
        raise ValueError(f"{fname}: trailing record marker {tail} != leading marker {head}")
    return payload


def read_tilecoord(fname):
    """Read a GEOS-LDAS tilecoord Fortran binary file (little-endian).

    Layout read here: one record holding ``N_tile`` (a 4-byte int), followed by one
    record per field below, each holding ``N_tile`` 4-byte values.

    Returns
    -------
    dict with key 'N_tile' (int) plus one numpy array per field; integer fields
    ('tile_id', 'typ', 'pfaf', 'i_indg', 'j_indg') as int32, the rest as float64.

    Raises
    ------
    ValueError if the file is truncated or a record marker disagrees with the
    expected payload size (i.e. the layout is not the one assumed above).
    """
    machfmt = '<'
    int_fields = ('tile_id', 'typ', 'pfaf', 'i_indg', 'j_indg')
    fields = ['tile_id', 'typ', 'pfaf', 'com_lon', 'com_lat', 'min_lon', 'max_lon',
              'min_lat', 'max_lat', 'i_indg', 'j_indg', 'frac_cell', 'frac_pfaf',
              'area', 'elev']
    tile_coord = {}
    with open(fname, 'rb') as ifp:
        header = _read_fortran_record(ifp, 4, fname, machfmt)
        n_tile = struct.unpack(f'{machfmt}i', header)[0]
        tile_coord['N_tile'] = n_tile
        for field in fields:
            payload = _read_fortran_record(ifp, n_tile * 4, fname, machfmt)
            if field in int_fields:
                arr = np.frombuffer(payload, dtype=f'{machfmt}i4').astype(np.int32)
            else:
                arr = np.frombuffer(payload, dtype=f'{machfmt}f4').astype(np.float64)
            tile_coord[field] = arr
    return tile_coord

# -----------------------------
# helpers
# -----------------------------
def signif_mask(p, alpha=ALPHA):
    """Return a boolean mask that is True where the p-value is finite and strictly below ``alpha``."""
    pvals = np.asarray(p)
    has_value = np.isfinite(pvals)
    return has_value & (pvals < alpha)

def robust_sym_limit(arrays, q=0.995):
    """Symmetric color limit: the ``q``-quantile of |value| pooled over all finite entries.

    Parameters
    ----------
    arrays : iterable of array-likes (``None`` entries are skipped).
    q : quantile in [0, 1].

    Returns
    -------
    float, always > 0. Falls back to the max |value| when the quantile is 0
    (e.g. mostly-zero trends). Returns 1.0 when there is no finite data or
    everything is zero. A 0 limit would make cmin == cmax and a degenerate
    color normalization.
    """
    chunks = []
    for a in arrays:
        if a is None:
            continue
        a = np.asarray(a, dtype=float).ravel()
        a = a[np.isfinite(a)]
        if a.size:
            chunks.append(np.abs(a))
    if not chunks:
        return 1.0
    pooled = np.concatenate(chunks)
    lim = float(np.quantile(pooled, q))
    if lim > 0:
        return lim
    peak = float(pooled.max())
    return peak if peak > 0 else 1.0

def make_map_array(values, lon, lat, mask=None):
    """Stack per-tile values and coordinates into an (N, 3) float array: [value, lon, lat].

    Where ``mask`` is given and False, the value column is set to NaN.
    """
    data = np.asarray(values, dtype=float)
    if mask is not None:
        data = np.where(mask, data, np.nan)
    table = np.full((lon.shape[0], 3), np.nan, dtype=float)
    for col, column_values in enumerate((data, lon, lat)):
        table[:, col] = column_values
    return table

def quick_grid(lon, lat, vals, res=1.0):
    """Bin scattered tile values onto a regular lon/lat grid (``res`` degrees); used by 'contour'.

    Cell ``(j, i)`` covers ``[lon_edge_i, lon_edge_i + res) x [lat_edge_j, lat_edge_j + res)``.
    Points exactly on an edge go to the cell that starts there; lon == 180 / lat == 90
    are folded into the last real cell. Cells receiving several tiles hold the MEAN
    of their finite values. The original relied on repeated fancy-index assignment,
    whose winner NumPy leaves unspecified. Empty cells are NaN.

    Returns
    -------
    (lon_centers, lat_centers, Z) with ``Z.shape == (lat_centers.size, lon_centers.size)``.
    Coordinates are cell CENTERS, so contour lines are not shifted by half a cell.
    The last row/column (centers at 90 + res/2 and 180 + res/2) are always NaN;
    they are kept so the array shapes match the previous version.
    """
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    vals = np.broadcast_to(np.asarray(vals, dtype=float), lon.shape)

    lon_edges = np.arange(-180, 180 + res, res)
    lat_edges = np.arange(-90, 90 + res, res)
    nx, ny = lon_edges.size, lat_edges.size

    # side="right" - 1 -> index of the cell whose lower edge is <= the coordinate
    ix = np.clip(np.searchsorted(lon_edges, lon, side="right") - 1, 0, max(nx - 2, 0))
    iy = np.clip(np.searchsorted(lat_edges, lat, side="right") - 1, 0, max(ny - 2, 0))

    good = np.isfinite(vals) & np.isfinite(lon) & np.isfinite(lat)
    flat = iy[good] * nx + ix[good]
    sums = np.bincount(flat, weights=vals[good], minlength=nx * ny)
    counts = np.bincount(flat, minlength=nx * ny)

    Z = np.full(nx * ny, np.nan)
    filled = counts > 0
    Z[filled] = sums[filled] / counts[filled]
    return lon_edges + res / 2.0, lat_edges + res / 2.0, Z.reshape(ny, nx)

def get_ax_cmap_norm(ax):
    """
    Try to recover the cmap & norm used by plot_region's base artist (pcolormesh/PathCollection/etc.).

    Every matplotlib Collection is a ScalarMappable, so non-data collections (e.g.
    coastline/feature collections) also report a default cmap/norm. Artists that
    actually carry data (``get_array()`` is not None) are therefore preferred. Only
    if none exist does this fall back to the first artist exposing a cmap and norm.
    Collections are searched before images in both passes.

    Returns (cmap, norm) or (None, None) if not found.
    """
    artists = list(getattr(ax, "collections", [])) + list(getattr(ax, "images", []))

    def _cmap_norm(artist):
        # (cmap, norm) if the artist exposes both and neither is None, else None
        if hasattr(artist, "get_cmap") and hasattr(artist, "get_norm"):
            cmap = artist.get_cmap()
            norm = artist.get_norm()
            if cmap is not None and norm is not None:
                return cmap, norm
        return None

    # Pass 1: data-bearing artists only
    for artist in artists:
        get_array = getattr(artist, "get_array", None)
        if get_array is None or get_array() is None:
            continue
        pair = _cmap_norm(artist)
        if pair is not None:
            return pair

    # Pass 2: previous behavior (first artist with a cmap and norm)
    for artist in artists:
        pair = _cmap_norm(artist)
        if pair is not None:
            return pair
    return None, None

# -----------------------------
# significance rendering (no stipple)
# -----------------------------
def draw_sig_fade(ax, map_array, lon, lat, sig_mask, cmap, norm, fade_alpha=0.25):
    """
    Re-draw the tile colors using the SAME cmap/norm as the base layer:
      - significant tiles at full color
      - non-significant tiles faded toward the axes background color

    Parameters
    ----------
    ax : axes (plain matplotlib or cartopy GeoAxes) that already holds the base layer.
    map_array : (N, 3) array from make_map_array; column 0 holds the plotted values.
    lon, lat : tile coordinates in degrees, length N.
    sig_mask : boolean-like, length N; True where the trend is significant.
    cmap, norm : colormap and normalization of the base layer.
    fade_alpha : weight of the tile color when blended with the background for
        non-significant tiles (0 = background only, 1 = no fade).

    Non-significant colors are pre-blended with the axes facecolor and drawn
    opaque. A translucent copy of a color drawn over the identical opaque base
    color is visually indistinguishable from the base, so it produces no fade.
    Tiles with non-finite values are skipped: a scalar ``alpha=`` passed to
    scatter replaces the alpha of every facecolor, which would turn the
    colormap's transparent "bad" color into an opaque dark square.
    Markers are fixed-size squares (s=7) and may not exactly cover the base tiles.
    """
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    sig = np.asarray(sig_mask, dtype=bool)
    values = np.asarray(map_array, dtype=float)[:, 0]

    drawable = np.isfinite(lon) & np.isfinite(lat) & np.isfinite(values)
    ok_sig = drawable & sig
    ok_nonsig = drawable & ~sig

    xtra = {}
    if hasattr(ax, "projection"):
        # cartopy GeoAxes: the scatter points are lon/lat degrees
        import cartopy.crs as ccrs
        xtra["transform"] = ccrs.PlateCarree()

    if ok_nonsig.any():
        rgba = np.array(cmap(norm(values[ok_nonsig])), dtype=float)
        background = np.asarray(ax.get_facecolor(), dtype=float)[:3]
        rgba[:, :3] = fade_alpha * rgba[:, :3] + (1.0 - fade_alpha) * background
        ax.scatter(lon[ok_nonsig], lat[ok_nonsig],
                   s=7, marker='s', facecolors=rgba, edgecolors='none',
                   zorder=6, **xtra)
    if ok_sig.any():
        ax.scatter(lon[ok_sig], lat[ok_sig],
                   s=7, marker='s', facecolors=cmap(norm(values[ok_sig])), edgecolors='none',
                   zorder=7, **xtra)

def draw_sig_outline(ax, lon, lat, sig_mask):
    """Draw a thin black hollow marker (halo) at every significant tile with finite coordinates.

    cartopy is imported only when ``ax`` is a GeoAxes (has ``projection``), where
    it is necessarily available. Previously the import ran unconditionally under a
    catch-all ``except Exception: pass``, which could silently drop the
    PlateCarree transform.
    """
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    sig = np.asarray(sig_mask, dtype=bool)
    ok = np.isfinite(lon) & np.isfinite(lat) & sig
    if not ok.any():
        return
    xtra = {}
    if hasattr(ax, "projection"):
        import cartopy.crs as ccrs
        xtra["transform"] = ccrs.PlateCarree()
    ax.scatter(lon[ok], lat[ok],
               s=12, facecolors='none', edgecolors='k',
               linewidths=0.25, zorder=9, **xtra)

def draw_sig_contour(ax, lon, lat, sig_mask, res=1.0):
    """Draw a black boundary line around the significant region.

    Tiles with finite coordinates are binned onto a ``res``-degree grid with
    quick_grid, and the 0.5 level of the gridded 0/1 significance field is
    contoured. Nothing is drawn when no tile has finite coordinates, or when the
    gridded field never crosses 0.5 (uniformly significant or non-significant).
    That case would otherwise only trigger a "no contour levels" warning.
    cartopy is imported only for GeoAxes, where it is necessarily available.
    """
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    ok = np.isfinite(lon) & np.isfinite(lat)
    if not ok.any():
        return

    sig01 = np.asarray(sig_mask, dtype=bool).astype(float)
    grid_lon, grid_lat, grid = quick_grid(lon[ok], lat[ok], sig01[ok], res=res)
    finite = grid[np.isfinite(grid)]
    if finite.size == 0 or finite.min() >= 0.5 or finite.max() <= 0.5:
        return

    xtra = {}
    if hasattr(ax, "projection"):
        import cartopy.crs as ccrs
        xtra["transform"] = ccrs.PlateCarree()
    ax.contour(grid_lon, grid_lat, grid, levels=[0.5], linewidths=0.8, colors='k', zorder=9, **xtra)

def draw_sig_mask(ax, lon, lat, sig_mask):
    """Overlay small semi-transparent grey squares on every NON-significant tile with finite coordinates.

    cartopy is imported only when ``ax`` is a GeoAxes (has ``projection``). The
    previous catch-all ``except Exception: pass`` around the import could hide
    failures and silently drop the PlateCarree transform.
    """
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    nonsig = ~np.asarray(sig_mask, dtype=bool)
    ok = np.isfinite(lon) & np.isfinite(lat) & nonsig
    if not ok.any():
        return
    xtra = {}
    if hasattr(ax, "projection"):
        import cartopy.crs as ccrs
        xtra["transform"] = ccrs.PlateCarree()
    ax.scatter(lon[ok], lat[ok],
               s=2, marker='s', c='#9e9e9e',
               alpha=0.1, linewidths=0, zorder=8, **xtra)

# -----------------------------
# plotting
# -----------------------------
def plot_one(var, mode, ds, lon, lat, cmin=None, cmax=None, mask=None, title_extra="", sig_mask=None):
    """
    Plot one trend map (CNTL, DA or DELTA) and save it as FNAME_MAIN.

    Parameters
    ----------
    var : variable prefix; slopes are read from ``{var}_slope_CNTL`` / ``{var}_slope_DA``.
    mode : "CNTL", "DA" or "DELTA" (DELTA = DA slope minus CNTL slope).
    ds : xarray Dataset holding the slope fields.
    lon, lat : per-tile coordinates.
    cmin, cmax : color limits passed to plot_region (also used by the fallback norm).
    mask : optional boolean array; tiles where it is False are blanked (NaN).
    title_extra : suffix appended to the first title line.
    sig_mask : optional boolean significance mask, drawn according to SIG_STYLE.

    Returns
    -------
    str : name of the saved PNG.

    Raises
    ------
    ValueError : if ``mode`` is not one of the supported modes.
    """
    units_label = VAR_UNITS.get(var, "per decade")
    scale = VAR_SCALE.get(var, 1.0)

    if mode in ("CNTL", "DA"):
        vals = ds[f"{var}_slope_{mode}"].values
    elif mode == "DELTA":
        vals = ds[f"{var}_slope_DA"].values - ds[f"{var}_slope_CNTL"].values
    else:
        raise ValueError(mode)

    vals_plot = vals * scale
    map_array = make_map_array(vals_plot, lon, lat, mask=mask)

    # Extrema for the title; guarded so an all-NaN map (e.g. fully masked) gives
    # "nan" without nanmax/nanmin warnings or errors.
    plotted = map_array[:, 0]
    finite = plotted[np.isfinite(plotted)]
    vmax = float(finite.max()) if finite.size else np.nan
    vmin = float(finite.min()) if finite.size else np.nan

    # Draw base layer (plot_region presumably sets its own cmap & norm internally)
    fig, ax = plot_region(
        map_array,
        region_bounds=REGION_BOUNDS[REGION],
        meanflag=True,
        plot_title=(f"{var} {mode} trend ({units_label}){title_extra}\n"
                    f"(Max: {vmax:.3g}  Min: {vmin:.3g})"),
        units=units_label,
        cmin=cmin,
        cmax=cmax,
    )

    # Recover the cmap & norm used by the base layer so overlays match it
    cmap, norm = get_ax_cmap_norm(ax)
    if norm is None:
        norm = mcolors.Normalize(vmin=cmin, vmax=cmax)
    if cmap is None:
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
        cmap = plt.get_cmap("viridis")

    # Significance overlay
    if sig_mask is not None:
        if SIG_STYLE == "fade":
            draw_sig_fade(ax, map_array, lon, lat, sig_mask, cmap=cmap, norm=norm)
        elif SIG_STYLE == "outline":
            draw_sig_outline(ax, lon, lat, sig_mask)
        elif SIG_STYLE == "contour":
            draw_sig_contour(ax, lon, lat, sig_mask, res=1.0)
        elif SIG_STYLE == "mask":
            draw_sig_mask(ax, lon, lat, sig_mask)
        else:
            import warnings
            warnings.warn(f"Unknown SIG_STYLE {SIG_STYLE!r}; no significance overlay drawn.")

    fig.tight_layout()
    outfile = FNAME_MAIN.format(var=var, mode=mode)
    fig.savefig(outfile, dpi=180)
    # Figure intentionally not closed (presumably so it renders inline in the
    # notebook); call plt.close(fig) in batch runs to free memory.
    return outfile

def plot_mask(var, mode, mask_arr, lon, lat, title_note="p<0.05"):
    """Render a binary significance mask (1 = significant, 0 = not) and save it as FNAME_MASK.

    ``mask_arr`` must support ``.astype`` (e.g. a numpy bool array). Returns the PNG name.
    """
    binary = make_map_array(mask_arr.astype(float), lon, lat, mask=None)
    heading = (f"{var} {mode} significance mask ({title_note})\n"
               f"(1=significant, 0=not)")
    fig, ax = plot_region(
        binary,
        region_bounds=REGION_BOUNDS[REGION],
        meanflag=True,
        plot_title=heading,
        units="Binary mask",
        cmin=0.0,
        cmax=1.0,
    )
    fig.tight_layout()
    png = FNAME_MASK.format(var=var, mode=mode)
    fig.savefig(png, dpi=180)
    # Figure left open (plt.close is not called), presumably for inline notebook display.
    return png

# -----------------------------
# main
# -----------------------------
def _check_trend_inputs(ds, n_tile, lon, lat):
    """Raise if tilecoord geometry or any required trend field is missing or has the wrong length."""
    if lat.shape[0] != n_tile or lon.shape[0] != n_tile:
        raise ValueError("Tilecoord lon/lat length mismatch")
    for v in VARS:
        for sfx in ("CNTL", "DA"):
            for name in (f"{v}_slope_{sfx}", f"{v}_p_{sfx}"):
                if name not in ds:
                    raise KeyError(f"Missing {v} {sfx} fields ({name})")
                if ds[name].size != n_tile:
                    raise ValueError(
                        f"{name} has {ds[name].size} values but tilecoord has {n_tile} tiles"
                    )


def main():
    """
    Plot CNTL / DA / DELTA trend maps (and, if PLOT_MASKS, the binary significance
    maps) for every variable in VARS, then write all significance masks to NC_MASKS.

    Raises
    ------
    ValueError : tilecoord lon/lat or a trend field does not have N_tile entries.
    KeyError : a required ``{var}_slope_*`` / ``{var}_p_*`` field is missing.
    """
    mask_ds_vars = {}
    outputs = []

    # Context manager so the trend file is closed once all plots are made.
    with xr.open_dataset(NC_TRENDS) as ds:
        tc = read_tilecoord(FILE_TILECO)
        print(f"N_tile (tilecoord) = {tc['N_tile']}")
        n_tile = int(tc["N_tile"])
        lat = np.asarray(tc["com_lat"])
        lon = np.asarray(tc["com_lon"])

        # Explicit checks rather than `assert` (asserts are stripped under `python -O`).
        _check_trend_inputs(ds, n_tile, lon, lat)

        for var in VARS:
            s_c = ds[f"{var}_slope_CNTL"].values
            s_d = ds[f"{var}_slope_DA"].values
            p_c = ds[f"{var}_p_CNTL"].values
            p_d = ds[f"{var}_p_DA"].values

            sig_c = signif_mask(p_c, alpha=ALPHA)
            sig_d = signif_mask(p_d, alpha=ALPHA)
            sig_delta_any = sig_c | sig_d
            sig_delta_both = sig_c & sig_d
            sig_delta = sig_delta_both if USE_STRICT_SIG_FOR_DELTA else sig_delta_any

            mask_ds_vars[f"{var}_sig_CNTL"] = (("tile",), sig_c.astype("i1"))
            mask_ds_vars[f"{var}_sig_DA"] = (("tile",), sig_d.astype("i1"))
            mask_ds_vars[f"{var}_sig_DELTA_any"] = (("tile",), sig_delta_any.astype("i1"))
            mask_ds_vars[f"{var}_sig_DELTA_both"] = (("tile",), sig_delta_both.astype("i1"))

            # With APPLY_SIG, non-significant tiles are blanked in the trend maps
            mask_cntl = sig_c if APPLY_SIG else None
            mask_da = sig_d if APPLY_SIG else None
            mask_d = sig_delta if APPLY_SIG else None

            # CNTL and DA share one symmetric color scale; DELTA gets its own
            scale = VAR_SCALE.get(var, 1.0)
            lim_main = robust_sym_limit([s_c * scale, s_d * scale], q=QUANT_CLIP_MAIN)
            delta = (s_d - s_c) * scale
            lim_delta = robust_sym_limit([delta], q=QUANT_CLIP_DELTA)

            t_extra = " [sig mask]" if APPLY_SIG else ""

            outputs.append(
                plot_one(var, "CNTL", ds, lon, lat,
                         cmin=-lim_main, cmax=+lim_main,
                         mask=mask_cntl, title_extra=t_extra, sig_mask=sig_c)
            )
            outputs.append(
                plot_one(var, "DA", ds, lon, lat,
                         cmin=-lim_main, cmax=+lim_main,
                         mask=mask_da, title_extra=t_extra, sig_mask=sig_d)
            )
            outputs.append(
                plot_one(var, "DELTA", ds, lon, lat,
                         cmin=-lim_delta, cmax=+lim_delta,
                         mask=mask_d, title_extra=t_extra, sig_mask=sig_delta)
            )

            if PLOT_MASKS:
                outputs.append(plot_mask(var, "CNTL", sig_c, lon, lat, title_note=f"p<{ALPHA}"))
                outputs.append(plot_mask(var, "DA", sig_d, lon, lat, title_note=f"p<{ALPHA}"))
                outputs.append(plot_mask(var, "DELTA_any", sig_delta_any, lon, lat, title_note=f"p<{ALPHA} (DA or CNTL)"))
                outputs.append(plot_mask(var, "DELTA_both", sig_delta_both, lon, lat, title_note=f"p<{ALPHA} (DA & CNTL)"))

    coords = {
        "tile": np.arange(n_tile, dtype=np.int64),
        "lat":  (("tile",), lat),
        "lon":  (("tile",), lon),
    }
    mask_ds = xr.Dataset(
        data_vars=mask_ds_vars,
        coords=coords,
        attrs={
            "description": "Significance masks derived from MK p-values (1=significant, 0=not).",
            "alpha": ALPHA,
            "note": "DELTA_any = (sig_CNTL OR sig_DA), DELTA_both = (sig_CNTL AND sig_DA).",
            "precip_units_note": "Precip slopes converted for plotting only: mm/s per decade → mm/yr per decade.",
        },
    )
    mask_ds.to_netcdf(NC_MASKS)
    print(f"Wrote masks NetCDF: {NC_MASKS}")

    print("Wrote figures:")
    for o in outputs:
        print("  -", o)

# Run the plotting pipeline when executed as a script (presumably also when this
# cell runs in a notebook kernel, where __name__ is "__main__" too — confirm).
if __name__ == "__main__":
    main()

In [None]:
# -------- Trend map summary stats (area-weighted), save to CSV --------
import csv
from typing import Dict, Tuple

# Named regions (keys of REGION_BOUNDS from geospatial_plotting) to summarize.
# Each name must exist in REGION_BOUNDS, or run_trend_summary raises KeyError.
REGIONS_TO_SUM = ["global"]  # e.g., ["global", "NA", "EU", "AUS"] if you have them defined

def region_mask(lon, lat, bounds):
    """Boolean mask of points inside ``bounds = [min_lon, max_lon, min_lat, max_lat]`` (degrees, inclusive).

    Longitudes are assumed to be in [-180, 180]. When ``min_lon > max_lon`` the
    box is taken to cross the dateline, e.g. [170, -170, ...] selects lon >= 170
    or lon <= -170. Previously such bounds silently produced an empty mask.

    NOTE(review): the [min_lon, max_lon, min_lat, max_lat] ordering comes from
    this function's original docstring; confirm it matches REGION_BOUNDS in
    geospatial_plotting.
    """
    lo0, lo1, la0, la1 = bounds
    lon = np.asarray(lon)
    lat = np.asarray(lat)
    in_lat = (lat >= la0) & (lat <= la1)
    if lo0 <= lo1:
        in_lon = (lon >= lo0) & (lon <= lo1)
    else:
        in_lon = (lon >= lo0) | (lon <= lo1)
    return in_lon & in_lat

def aw_mean(x, w):
    """Weighted mean of ``x`` using weights ``w``, over entries with finite x and finite, positive w.

    Returns NaN when no entry qualifies.
    """
    valid = np.isfinite(x) & np.isfinite(w) & (w > 0)
    if valid.any():
        xv = x[valid]
        wv = w[valid]
        return np.nansum(xv * wv) / np.nansum(wv)
    return np.nan

def aw_median(x, w):
    """Weighted median of ``x``: the first sorted value whose cumulative weight fraction reaches 0.5.

    Only entries with finite x and finite, positive w are used. Returns NaN when none qualify.
    """
    valid = np.isfinite(x) & np.isfinite(w) & (w > 0)
    if not valid.any():
        return np.nan
    x_valid = x[valid]
    w_valid = w[valid]
    order = np.argsort(x_valid)
    x_sorted = x_valid[order]
    w_sorted = w_valid[order]
    cdf = np.cumsum(w_sorted) / np.sum(w_sorted)
    return x_sorted[np.searchsorted(cdf, 0.5)]

def frac_area(mask, w):
    """Fraction of total weight (over finite, positive ``w``) that falls where ``mask`` is True.

    Returns NaN when no weight is usable.
    """
    usable = np.isfinite(w) & (w > 0)
    if not usable.any():
        return np.nan
    weights = w[usable]
    hits = mask[usable].astype(float)
    return float(np.nansum(weights * hits) / np.nansum(weights))

def pearsonr_masked(a, b, w=None):
    """(Weighted) Pearson correlation of ``a`` and ``b`` over entries where both are finite.

    When ``w`` is given, only entries with finite, positive weight are used. This
    matches aw_mean/frac_area, so NaN weights (e.g. outside a region) simply
    exclude a tile. Previously that worked only implicitly through nansum, and
    all-invalid weights produced a 0/0 RuntimeWarning.

    Returns NaN when fewer than two entries qualify or either series has zero variance.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    m = np.isfinite(a) & np.isfinite(b)
    if w is not None:
        w = np.asarray(w, dtype=float)
        m &= np.isfinite(w) & (w > 0)
    if np.count_nonzero(m) < 2:
        return np.nan

    a1 = a[m]
    b1 = b[m]
    ww = np.ones_like(a1) if w is None else w[m]
    wsum = ww.sum()
    da = a1 - (ww * a1).sum() / wsum
    db = b1 - (ww * b1).sum() / wsum
    num = (ww * da * db).sum()
    den = np.sqrt((ww * da * da).sum() * (ww * db * db).sum())
    return float(np.nan if den == 0 else num / den)

def summarize_for_region(var, ds, lon, lat, area_w, bounds, scale_prec=SEC_PER_YEAR):
    """
    Returns dict of stats for one var in one region, or None if no tile lies in the region.

    - Slopes are used in the file's units (per decade), except PREC, which is
      multiplied by ``scale_prec`` (mm/s -> mm/yr, i.e. mm/yr per decade).
    - All area fractions are relative to the REGION's total weight, not the
      global total. Previously ``area_w`` was passed as the denominator, which is
      only correct for a global region.
    - For SFMC/RZMC, also the weighted correlation of the DA−CNTL slope
      difference with the CNTL precipitation slope.
    """
    s_c = ds[f"{var}_slope_CNTL"].values
    s_d = ds[f"{var}_slope_DA"].values
    p_c = ds[f"{var}_p_CNTL"].values
    p_d = ds[f"{var}_p_DA"].values

    # unit conversion for plotting-consistent readout
    if var == "PREC":
        s_c = s_c * scale_prec
        s_d = s_d * scale_prec

    # region clip: weights are NaN outside the region, and every helper below drops
    # non-finite weights, so all statistics (including denominators) are regional.
    R = region_mask(lon, lat, bounds)
    if not R.any():
        return None
    w = np.asarray(area_w, dtype=float).copy()
    w[~R] = np.nan

    # significance masks
    sig_c = signif_mask(p_c, alpha=ALPHA)
    sig_d = signif_mask(p_d, alpha=ALPHA)

    out = {
        "area_frac_sig_CNTL": frac_area(sig_c & R, w),
        "area_frac_sig_DA":   frac_area(sig_d & R, w),
        "aw_mean_slope_CNTL": aw_mean(s_c, w),
        "aw_mean_slope_DA":   aw_mean(s_d, w),
        "aw_median_slope_CNTL": aw_median(s_c, w),
        "aw_median_slope_DA":   aw_median(s_d, w),
    }

    # sign agreement where both significant
    both = R & sig_c & sig_d
    agree = both & (np.sign(s_c) == np.sign(s_d))
    out["area_frac_both_sig"] = frac_area(both, w)
    out["area_frac_both_sig_and_agree"] = frac_area(agree, w)

    # delta stats (DA - CNTL)
    d = s_d - s_c
    out["aw_mean_delta"] = aw_mean(d, w)
    out["aw_median_delta"] = aw_median(d, w)
    out["area_frac_delta_pos"] = frac_area((d > 0) & R, w)

    # relationship to PREC (moisture vars only); CNTL precip trend used as reference
    if var in ("SFMC", "RZMC"):
        p_ref = ds["PREC_slope_CNTL"].values * scale_prec
        out["corr_deltaSM_vs_PREC"] = pearsonr_masked(d, p_ref, w=w)

    return out

def write_summary_csv(summary: Dict[Tuple[str,str], Dict], path="trend_summary.csv"):
    """Write ``{(var, region): stats}`` to CSV: one row per key, one column per stat name.

    Stat columns are the sorted union of all stat names; missing stats are written as NaN.
    """
    columns = set()
    for stats in summary.values():
        columns.update(stats)
    columns = sorted(columns)

    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["var", "region", *columns])
        writer.writerows(
            [var, region, *(stats.get(c, np.nan) for c in columns)]
            for (var, region), stats in summary.items()
        )
    print(f"Wrote {path}")

def run_trend_summary():
    """
    Compute per-variable, per-region trend statistics (see summarize_for_region)
    for VARS x REGIONS_TO_SUM, write them to ``trend_summary.csv`` and return them.

    Returns
    -------
    dict mapping (var, region) -> stats dict; regions containing no tiles are omitted.
    """
    # Context manager so the trend file is closed once the statistics are computed.
    with xr.open_dataset(NC_TRENDS) as ds:
        tc = read_tilecoord(FILE_TILECO)
        lon = np.asarray(tc["com_lon"])
        lat = np.asarray(tc["com_lat"])
        # Weights: the tilecoord 'area' field as stored (no unit conversion). Only
        # weight ratios enter the statistics, so the units cancel.
        area_w = np.asarray(tc.get("area", np.ones_like(lon)), dtype=float)
        if not np.isfinite(area_w).any():
            area_w = np.ones_like(lon, dtype=float)

        summary = {}
        for var in VARS:
            for region in REGIONS_TO_SUM:
                bounds = REGION_BOUNDS[region]
                stats = summarize_for_region(var, ds, lon, lat, area_w, bounds)
                if stats is not None:
                    summary[(var, region)] = stats

    write_summary_csv(summary, "trend_summary.csv")
    return summary

# Entry point for this cell: compute the summary (also written to CSV), then echo it.
if __name__ == "__main__":
    s = run_trend_summary()
    for (var, region), d in s.items():
        print(f"\n[{var} @ {region}]")
        for k, v in d.items():
            if isinstance(v, float):
                print(f"  {k}: {v:.4g}")
            else:
                print(f"  {k}: {v}")
