In [None]:
"""
COMPARE GEOS-LDAS SURFACE SOIL MOISTURE (SFMC) TO ERA5-LAND SWVL1
==================================================================
This notebook:
 1. Reads the GEOS-LDAS tilecoord binary and model output.
 2. Builds the exact M36 grid geometry from tile bounds.
 3. Regrids ERA5-Land monthly soil moisture, soil temperature and snow cover to that grid using xESMF (conservative_normed).
 4. Aggregates model tile data to the same grid (unweighted mean of the tiles in each cell, land only).
 5. Computes anomaly correlation & ubRMSE between model and ERA5-Land.
"""

import numpy as np
import xarray as xr
import pandas as pd
import struct
import matplotlib.pyplot as plt
import xesmf as xe
import datetime as dt
import os
from collections import OrderedDict

# === File paths (edit for your system) ===
# ftc      : GEOS-LDAS tilecoord binary (tile -> grid-cell geometry).
# era_nc   : merged ERA5-Land monthly NetCDF.
# model_nc : model monthly output on tiles.
# NOTE(review): ftc lives under the LS_OLv8_M36 experiment while model_nc comes
# from LS_DAv8_M36 — presumably both share the same M36 tile space; confirm.
ftc = '/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/land_sweeper/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/rc_out/LS_OLv8_M36.ldas_tilecoord.bin'
era_nc = '/Users/amfox/Desktop/GEOSldas_diagnostics/Jupyter/era5l_monthly_nc/ERA5L_monthly_merged.nc'
model_nc = '/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/M21C_land_sweeper/LS_DAv8_M36_v2/LS_DAv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/DAv8_land_variables_2000_2024_compressed.nc'
# Alternative model file (open-loop experiment); swap with the line above to use it.
# model_nc = '/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/land_sweeper/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/OLv8_land_variables_2000_2024_compressed.nc'

In [None]:
# --- Sanity check: conservative regridder between two synthetic global grids ---
# Cell edges come first; centers are the mid-points of neighbouring edges.

# SOURCE: 1° global grid
lon_b_src = np.linspace(0.0, 360.0, 361)          # 361 edges
lat_b_src = np.linspace(-90.0, 90.0, 181)         # 181 edges
lon_src   = 0.5*(lon_b_src[:-1] + lon_b_src[1:])  # 360 centers
lat_src   = 0.5*(lat_b_src[:-1] + lat_b_src[1:])  # 180 centers

# TARGET: 6° global grid
lon_b_tgt = np.linspace(0.0, 360.0, 61)           # 61 edges
lat_b_tgt = np.linspace(-90.0, 90.0, 31)          # 31 edges
lon_tgt   = 0.5*(lon_b_tgt[:-1] + lon_b_tgt[1:])  # 60 centers
lat_tgt   = 0.5*(lat_b_tgt[:-1] + lat_b_tgt[1:])  # 30 centers

src = xr.Dataset(
    {
        "lon":    (("lon",), lon_src),
        "lat":    (("lat",), lat_src),
        "lon_b":  (("lon_b",), lon_b_src),
        "lat_b":  (("lat_b",), lat_b_src),
    }
)
tgt = xr.Dataset(
    {
        "lon":    (("lon",), lon_tgt),
        "lat":    (("lat",), lat_tgt),
        "lon_b":  (("lon_b",), lon_b_tgt),
        "lat_b":  (("lat_b",), lat_b_tgt),
    }
)

# Identical CF metadata on both grids: centers point at their bounds variable.
_cf_grid_attrs = {
    "lon":   dict(standard_name="longitude", units="degrees_east", bounds="lon_b"),
    "lat":   dict(standard_name="latitude",  units="degrees_north", bounds="lat_b"),
    "lon_b": dict(standard_name="longitude", units="degrees_east"),
    "lat_b": dict(standard_name="latitude",  units="degrees_north"),
}
for _grid in (src, tgt):
    for _var, _meta in _cf_grid_attrs.items():
        _grid[_var].attrs.update(_meta)

# --- Build regridder (conservative needs the cell bounds supplied above) ---
# NOTE(review): xESMF documents `periodic` as relevant to non-conservative
# methods only — confirm it has no effect here.
r = xe.Regridder(src, tgt, method="conservative", periodic=True, reuse_weights=False)
print(r)  # prints the regridder summary


In [None]:
def read_tilecoord(fname):
    """Load a GEOS-LDAS tilecoord file written as little-endian Fortran records.

    The file is read as a sequence of records, each wrapped by a 4-byte
    marker on both sides (the marker values are read and discarded, not
    checked). The first record holds the tile count; it is followed by one
    record per field, each containing ``N_tile`` 4-byte values.

    Parameters
    ----------
    fname : path-like
        Location of the tilecoord binary.

    Returns
    -------
    dict
        ``'N_tile'`` (int) followed by one array per field, in file order.
        Index-like fields (``tile_id``, ``typ``, ``pfaf``, ``i_indg``,
        ``j_indg``) are ``int32``; all other fields are widened from
        float32 to ``float64``.
    """
    endian = '<'
    marker_fmt = f'{endian}i'
    integer_fields = ('tile_id', 'typ', 'pfaf', 'i_indg', 'j_indg')
    field_order = (
        'tile_id', 'typ', 'pfaf', 'com_lon', 'com_lat', 'min_lon', 'max_lon',
        'min_lat', 'max_lat', 'i_indg', 'j_indg', 'frac_cell', 'frac_pfaf',
        'area', 'elev',
    )

    def _next_int32(stream):
        # Used both for record markers and for the tile count itself.
        return struct.unpack(marker_fmt, stream.read(4))[0]

    coords = {}
    with open(fname, 'rb') as stream:
        _next_int32(stream)                    # opening marker of header record
        n_tiles = _next_int32(stream)
        coords['N_tile'] = n_tiles
        _next_int32(stream)                    # closing marker of header record

        for name in field_order:
            _next_int32(stream)                # opening marker
            payload = stream.read(n_tiles * 4)
            if name in integer_fields:
                values = np.frombuffer(payload, dtype=f'{endian}i').astype(np.int32)
            else:
                values = np.frombuffer(payload, dtype=f'{endian}f').astype(np.float64)
            coords[name] = values
            _next_int32(stream)                # closing marker
    return coords

def build_m36_grid_cf(tc):
    """Assemble cell corners and centers of the M36 grid from per-tile bounds.

    Each tile writes its ``min/max`` longitude and latitude into the four
    corner nodes of the cell addressed by ``(j_indg, i_indg)``. When several
    tiles touch the same node, the tile processed last wins. Corner nodes
    that no tile touches stay NaN, and so do the centers derived from them.

    Parameters
    ----------
    tc : mapping
        Must provide ``i_indg``, ``j_indg``, ``min_lon``, ``max_lon``,
        ``min_lat`` and ``max_lat`` arrays of equal length.

    Returns
    -------
    xr.Dataset
        ``lon``/``lat`` on dims ``(y, x)`` and ``lon_b``/``lat_b`` on dims
        ``(y_b, x_b)``, with CF ``standard_name``/``units`` attributes and a
        ``bounds`` attribute on the centers.
    """
    cols = tc['i_indg'].astype(int)
    rows = tc['j_indg'].astype(int)
    west_edges, east_edges = tc['min_lon'], tc['max_lon']
    south_edges, north_edges = tc['min_lat'], tc['max_lat']
    n_cols = int(cols.max()) + 1
    n_rows = int(rows.max()) + 1

    node_shape = (n_rows + 1, n_cols + 1)
    lon_b = np.full(node_shape, np.nan)
    lat_b = np.full(node_shape, np.nan)

    per_tile = zip(west_edges, east_edges, south_edges, north_edges, cols, rows)
    for west, east, south, north, i, j in per_tile:
        # Longitude is constant along each vertical cell edge ...
        for col, lon_value in ((i, west), (i + 1, east)):
            lon_b[j, col] = lon_value
            lon_b[j + 1, col] = lon_value
        # ... and latitude along each horizontal cell edge.
        for row, lat_value in ((j, south), (j + 1, north)):
            lat_b[row, i] = lat_value
            lat_b[row, i + 1] = lat_value

    # Center = mean of the lower-left and upper-right corner of each cell.
    lon_c = 0.5*(lon_b[:-1, :-1] + lon_b[1:, 1:])
    lat_c = 0.5*(lat_b[:-1, :-1] + lat_b[1:, 1:])

    grid = xr.Dataset(
        {
            "lon": (("y", "x"), lon_c),
            "lat": (("y", "x"), lat_c),
            "lon_b": (("y_b", "x_b"), lon_b),
            "lat_b": (("y_b", "x_b"), lat_b),
        }
    )
    cf_attrs = (
        ("lon",   dict(standard_name="longitude", units="degrees_east", bounds="lon_b")),
        ("lat",   dict(standard_name="latitude",  units="degrees_north", bounds="lat_b")),
        ("lon_b", dict(standard_name="longitude", units="degrees_east")),
        ("lat_b", dict(standard_name="latitude",  units="degrees_north")),
    )
    for var_name, meta in cf_attrs:
        grid[var_name].attrs.update(meta)
    return grid

def tiles_to_m36_grid_unweighted(tile_values, tc, m36_grid, land_only=True):
    """
    Average tile values into M36 grid cells, giving every tile equal weight.

    Parameters
    ----------
    tile_values : array-like, shape (time, tile) or (tile,)
        A 1-D input is treated as a single time step.
    tc : mapping
        Provides ``i_indg`` / ``j_indg`` (column / row index of each tile).
    m36_grid : xr.Dataset
        Supplies the 2-D ``lat`` and ``lon`` attached as coordinates.
    land_only : bool
        When True, additionally mask cells that hold no finite value at any
        time step.

    Returns
    -------
    xr.DataArray
        Dims ``(time, y, x)``. Non-finite tiles are left out of the mean and
        cells without any contributing tile are NaN. ``time`` is the
        placeholder ``0..T-1``; assign real timestamps after the call.
    """
    cols = tc["i_indg"].astype(int)
    rows = tc["j_indg"].astype(int)
    n_rows = int(rows.max()) + 1
    n_cols = int(cols.max()) + 1

    series = np.asarray(tile_values)
    if series.ndim == 1:
        series = series[np.newaxis, :]  # promote to (time, tile)
    n_times, n_tiles = series.shape
    assert n_tiles == cols.size == rows.size, "tile index arrays must match tile dimension"

    # Per-cell running sum and contributor count for every time step.
    totals = np.zeros((n_times, n_rows, n_cols), dtype=np.float64)
    counts = np.zeros_like(totals)

    for step, snapshot in enumerate(series):
        finite = np.isfinite(snapshot)
        if finite.any():
            cells = (rows[finite], cols[finite])
            # add.at accumulates correctly when several tiles share a cell
            np.add.at(totals[step], cells, snapshot[finite])
            np.add.at(counts[step], cells, 1.0)

    # Dividing by NaN where nothing contributed yields NaN without a warning.
    cell_mean = totals / np.where(counts > 0, counts, np.nan)

    gridded = xr.DataArray(
        cell_mean,
        dims=("time", "y", "x"),
        coords={
            "time": np.arange(n_times),
            "lat": (("y", "x"), m36_grid["lat"].values),
            "lon": (("y", "x"), m36_grid["lon"].values),
        },
        attrs={"note": "unweighted mean of tiles per M36 cell"}
    )

    if not land_only:
        return gridded

    # Cells that received a finite value at least once over the whole record.
    ever_filled = np.isfinite(cell_mean).any(axis=0)
    return gridded.where(ever_filled)

def layer_weights_era5l(target_top=0.0, target_bot=1.0):
    """Depth weights of the three coded ERA5-Land soil layers for a target interval.

    The layers considered are 0-0.07 m, 0.07-0.28 m and 0.28-1.00 m. Each
    weight is the thickness a layer shares with ``[target_top, target_bot]``
    divided by the target thickness, so the weights sum to 1 only when the
    target lies entirely within 0-1 m.

    Parameters
    ----------
    target_top, target_bot : float
        Upper and lower depth of the target interval in metres.

    Returns
    -------
    numpy.ndarray
        Three weights, ordered from the shallowest layer down.

    Raises
    ------
    ValueError
        If the target interval has zero or negative thickness.
    """
    thickness = target_bot - target_top
    if thickness <= 0:
        raise ValueError("target_bot must be > target_top")

    layer_top = np.array([0.00, 0.07, 0.28], dtype=float)
    layer_bot = np.array([0.07, 0.28, 1.00], dtype=float)

    # Overlap of each layer with the target; negative means no overlap.
    shared = np.minimum(layer_bot, target_bot) - np.maximum(layer_top, target_top)
    shared = np.maximum(0.0, shared)
    return shared / thickness

def anomalies_monthly(da, dim="time"):
    """
    Subtract the month-of-year climatology from `da`.

    The climatology is the mean along `dim` of all values sharing the same
    calendar month (NaNs skipped); each value then has the mean of its own
    calendar month removed.
    """
    by_month = da.groupby(f"{dim}.month")
    monthly_mean = by_month.mean(dim=dim, skipna=True)
    return by_month - monthly_mean

def anom_metrics_monthly(a, b, dim="time", min_pairs=24):
    """
    Anomaly correlation and ubRMSE between two anomaly fields, pairwise NaN-safe.

    Only time steps where both `a` and `b` are finite contribute. Results are
    set to NaN wherever fewer than `min_pairs` such pairs exist along `dim`.

    Parameters
    ----------
    a, b : xr.DataArray
        Anomaly series sharing `dim`.
    dim : str
        Name of the dimension to reduce over.
    min_pairs : int
        Minimum number of valid pairs required (default 24).

    Returns
    -------
    (anomR, ubRMSE) : tuple of xr.DataArray
        Both reduced over `dim`.

    Notes
    -----
    No mean is removed here: the correlation is ``mean(a*b)`` over
    ``sqrt(mean(a**2) * mean(b**2))`` and ubRMSE is ``sqrt(mean((a-b)**2))``.
    These equal the usual centred statistics only if the valid-pair subsets
    of `a` and `b` have zero mean.
    """
    # NumPy ufuncs dispatch on DataArrays and exist in every xarray release,
    # unlike the `xarray.ufuncs` namespace, which is absent from some versions.
    valid = np.isfinite(a) & np.isfinite(b)
    n = valid.sum(dim=dim)

    a_v = a.where(valid)
    b_v = b.where(valid)

    cov   = (a_v * b_v).mean(dim=dim, skipna=True)
    var_a = (a_v**2).mean(dim=dim, skipna=True)
    var_b = (b_v**2).mean(dim=dim, skipna=True)

    denom = np.sqrt(var_a * var_b)
    # A zero (or NaN) denominator means a constant/empty series: undefined R.
    anomR  = xr.where(denom > 0, cov / denom, np.nan)
    ubRMSE = np.sqrt(((a_v - b_v)**2).mean(dim=dim, skipna=True))

    # require enough valid months
    enough = n >= min_pairs
    anomR  = xr.where(enough, anomR, np.nan)
    ubRMSE = xr.where(enough, ubRMSE, np.nan)
    return anomR, ubRMSE


In [None]:
# Read tilecoord and build grid
tc = read_tilecoord(ftc)
print(f"N_tile = {tc['N_tile']}")
m36_grid = build_m36_grid_cf(tc)
# Dataset.sizes is the name -> length mapping; Dataset.dims is deprecated
# for this use in recent xarray releases.
ny, nx = m36_grid.sizes["y"], m36_grid.sizes["x"]
print("M36 grid:", ny, "x", nx)

# Load model monthly (tile-based)
ds_mod = xr.open_dataset(model_nc)
ds_mod = xr.decode_cf(ds_mod)  # ensure time decoding

# Every variable below follows the same recipe: pull the (time, tile) array,
# average tiles into M36 cells, then replace the placeholder time axis with
# the model's timestamps.

# Surface SM (tiles → M36) — required by all later cells
sfmc_tiles = ds_mod["SFMC"].values.astype(np.float64)  # (time, tile)
sfmc_grid  = tiles_to_m36_grid_unweighted(sfmc_tiles, tc, m36_grid, land_only=True)
sfmc_grid = sfmc_grid.assign_coords(time=ds_mod["time"].values)
print("SFMC grid shape:", sfmc_grid.shape)

# Root-zone SM — optional; stays None when the file does not carry it
rzmc_grid = None
if "RZMC" in ds_mod.data_vars:
    rzmc_tiles = ds_mod["RZMC"].values.astype(np.float64)
    rzmc_grid  = tiles_to_m36_grid_unweighted(rzmc_tiles, tc, m36_grid, land_only=True)
    rzmc_grid  = rzmc_grid.assign_coords(time=ds_mod["time"].values)
    print("RZMC grid shape:", rzmc_grid.shape)
else:
    print("RZMC not present in model file; skipping model RZ comparison for now.")

# Snow cover fraction — stays None when absent
scf_grid = None
if "FRLANDSNO" in ds_mod.data_vars:
    scf_tiles = ds_mod["FRLANDSNO"].values.astype(np.float64)   # units: 1 (0–1)
    scf_grid  = tiles_to_m36_grid_unweighted(scf_tiles, tc, m36_grid, land_only=True)
    scf_grid  = scf_grid.assign_coords(time=ds_mod["time"].values)
    print("SCF (FRLANDSNO) grid shape:", scf_grid.shape)
else:
    print("FRLANDSNO not present in model file; skipping SCF.")

# Soil temperature layer 1 (tiles → M36) — stays None when absent
tsoil1_grid = None
if "TSOIL1" in ds_mod.data_vars:
    tsoil1_tiles = ds_mod["TSOIL1"].values.astype(np.float64)   # units: K
    tsoil1_grid  = tiles_to_m36_grid_unweighted(tsoil1_tiles, tc, m36_grid, land_only=True)
    tsoil1_grid  = tsoil1_grid.assign_coords(time=ds_mod["time"].values)
    print("TSOIL1 grid shape:", tsoil1_grid.shape)
else:
    print("TSOIL1 not present in model file; skipping soil temperature.")


In [None]:
# Load ERA5-Land monthly (soil moisture layers + stl1 + snowc)
ds_era = xr.open_dataset(era_nc)

# Ensure the time coordinate is called 'time'
if 'valid_time' in ds_era:
    ds_era = ds_era.rename({'valid_time': 'time'})

# Ensure lat ascending and lon in [0,360)
if ds_era.latitude[0] > ds_era.latitude[-1]:
    ds_era = ds_era.reindex(latitude=list(reversed(ds_era.latitude.values)))
ds_era = ds_era.assign_coords(longitude=np.mod(ds_era.longitude, 360.0)).sortby("longitude")

# Keep only the ERA months that the model also covers, so less data is regridded.
# The match is made on calendar month rather than on exact timestamps: the two
# products need not stamp a month at the same instant, and an exact match would
# then discard months that the later month-based alignment could have paired.
if 'time' in ds_mod and 'time' in ds_era:
    _mod_months = pd.to_datetime(ds_mod['time'].values).to_period('M').astype(str)
    _era_months = pd.to_datetime(ds_era['time'].values).to_period('M').astype(str)
    ds_era = ds_era.isel(time=np.flatnonzero(np.isin(_era_months, _mod_months)))
    common_time = ds_era['time'].values  # ERA timestamps retained

# Make sure M36 grid has CF attrs (already done in builder), then build regridder
# NOTE: for intensive vars (SM, T, fraction) consider "bilinear" or "conservative_normed".
WEIGHTS = "weights_era_to_m36_consnormed.nc"

# First run computes the weights and writes them to WEIGHTS; later runs load
# that file instead.
# NOTE(review): a stale WEIGHTS file is presumably reused as-is — delete it if
# either grid changes.
regridder = xe.Regridder(
    ds_era, m36_grid,
    method="conservative_normed",
    periodic=True,
    filename=WEIGHTS,
    reuse_weights=os.path.exists(WEIGHTS)
)
print(regridder)

# Regrid ERA5-Land soil moisture layers to M36
era_swvl1 = regridder(ds_era["swvl1"]).rename("ERA5L_swvl1")
era_swvl2 = regridder(ds_era["swvl2"]).rename("ERA5L_swvl2")
era_swvl3 = regridder(ds_era["swvl3"]).rename("ERA5L_swvl3")

# Regrid soil temperature (K) and snow cover (%)
era_stl1_K   = regridder(ds_era["stl1"]).rename("ERA5L_stl1")
era_snow_pct = regridder(ds_era["snowc"]).rename("ERA5L_snowc")

# Convert snow cover from percent to a fraction, clipped to [0, 1]
era_snow_frac = (era_snow_pct / 100.0).clip(0.0, 1.0).rename("ERA5L_snow_frac")

print("ERA5-Land regridded shapes:",
      era_swvl1.shape, era_stl1_K.shape, era_snow_frac.shape)


In [None]:
# Inputs (already on M36 grid):
# Model: sfmc_grid, rzmc_grid (optional), scf_grid, tsoil1_grid
# ERA:   era_swvl1, era_swvl2, era_swvl3, era_stl1_K, era_snow_frac

def _to_month_period(dt64):
    dt = pd.to_datetime(dt64)
    return dt.to_period('M')

# Snow cover and soil temperature are optional when the model file is gridded,
# but everything below needs them. Fail here with a clear message rather than
# with an AttributeError on None further down.
_absent = [label for label, grid in (("FRLANDSNO", scf_grid), ("TSOIL1", tsoil1_grid))
           if grid is None]
if _absent:
    raise RuntimeError(
        "Cannot align model and ERA5-Land fields: model variable(s) "
        + ", ".join(_absent)
        + " were not available when the model output was gridded."
    )

# Determine ERA time coord name from era_swvl1
tname = 'valid_time' if 'valid_time' in era_swvl1.dims or 'valid_time' in era_swvl1.coords else 'time'

# Sort by time
sfmc_grid   = sfmc_grid.sortby('time')
scf_grid    = scf_grid.sortby('time')
tsoil1_grid = tsoil1_grid.sortby('time')

era_swvl1   = era_swvl1.sortby(tname)
era_swvl2   = era_swvl2.sortby(tname)
era_swvl3   = era_swvl3.sortby(tname)
era_stl1_K  = era_stl1_K.sortby(tname)
era_snow_frac = era_snow_frac.sortby(tname)

# Drop duplicate months (keep first). The index sets are derived from SFMC /
# swvl1 and applied to the sibling fields, which share the same time axis.
_, idx_mod_unique = np.unique(_to_month_period(sfmc_grid['time'].values), return_index=True)
_mod_keep = np.sort(idx_mod_unique)
sfmc_grid   = sfmc_grid.isel(time=_mod_keep)
scf_grid    = scf_grid.isel(time=_mod_keep)
tsoil1_grid = tsoil1_grid.isel(time=_mod_keep)

_, idx_era_unique = np.unique(_to_month_period(era_swvl1[tname].values), return_index=True)
_era_keep = np.sort(idx_era_unique)
era_swvl1   = era_swvl1.isel({tname: _era_keep})
era_swvl2   = era_swvl2.isel({tname: _era_keep})
era_swvl3   = era_swvl3.isel({tname: _era_keep})
era_stl1_K  = era_stl1_K.isel({tname: _era_keep})
era_snow_frac = era_snow_frac.isel({tname: _era_keep})

# Intersect by monthly period strings ('YYYY-MM' sorts chronologically)
ym_mod = _to_month_period(sfmc_grid['time'].values)
ym_era = _to_month_period(era_swvl1[tname].values)
commonM = np.intersect1d(ym_mod.astype(str), ym_era.astype(str))
assert commonM.size > 0, "No overlapping months between model and ERA5-Land."

midx = np.isin(ym_mod.astype(str), commonM)
eidx = np.isin(ym_era.astype(str), commonM)
_mod_sel = np.where(midx)[0]
_era_sel = np.where(eidx)[0]

# Subset model
sfmc_aln    = sfmc_grid.isel(time=_mod_sel)
scf_aln     = scf_grid.isel(time=_mod_sel)
tsoil1_aln  = tsoil1_grid.isel(time=_mod_sel)

# Subset ERA
era_sfc_aln = era_swvl1.isel({tname: _era_sel})
era_l2_aln  = era_swvl2.isel({tname: _era_sel})
era_l3_aln  = era_swvl3.isel({tname: _era_sel})
era_t1_aln  = era_stl1_K.isel({tname: _era_sel})
era_scf_aln = era_snow_frac.isel({tname: _era_sel})

# Give both products one shared monthly timestamp and rename ERA time -> 'time'
common_ts = pd.PeriodIndex(commonM, freq='M').to_timestamp('M')

sfmc_aln    = sfmc_aln.assign_coords(time=common_ts)
scf_aln     = scf_aln.assign_coords(time=common_ts)
tsoil1_aln  = tsoil1_aln.assign_coords(time=common_ts)

era_sfc_aln = era_sfc_aln.assign_coords({tname: common_ts}).rename({tname: 'time'})
era_l2_aln  = era_l2_aln.assign_coords({tname: common_ts}).rename({tname: 'time'})
era_l3_aln  = era_l3_aln.assign_coords({tname: common_ts}).rename({tname: 'time'})
era_t1_aln  = era_t1_aln.assign_coords({tname: common_ts}).rename({tname: 'time'})
era_scf_aln = era_scf_aln.assign_coords({tname: common_ts}).rename({tname: 'time'})

# ERA RZSM (0–100 cm): thickness-weighted blend of swvl1, swvl2 and swvl3
w = layer_weights_era5l(0.0, 1.0)
era_rz_aln = (w[0]*era_sfc_aln + w[1]*era_l2_aln + w[2]*era_l3_aln).rename('ERA5L_RZSM')

# Align model RZMC if available (same de-duplication and month subset as SFMC)
if 'rzmc_grid' in locals() and rzmc_grid is not None:
    rzmc_aln = rzmc_grid.sortby('time').isel(time=_mod_keep)
    rzmc_aln = rzmc_aln.isel(time=_mod_sel).assign_coords(time=common_ts)
else:
    rzmc_aln = None

# Ensure float dtype
sfmc_aln    = sfmc_aln.astype('float64')
scf_aln     = scf_aln.astype('float64')
tsoil1_aln  = tsoil1_aln.astype('float64')
rzmc_aln    = rzmc_aln.astype('float64') if rzmc_aln is not None else None
era_sfc_aln = era_sfc_aln.astype('float64')
era_l2_aln  = era_l2_aln.astype('float64')
era_l3_aln  = era_l3_aln.astype('float64')
era_t1_aln  = era_t1_aln.astype('float64')
era_scf_aln = era_scf_aln.astype('float64')
era_rz_aln  = era_rz_aln.astype('float64')

print(f"Aligned months: {sfmc_aln.time.size} from {str(sfmc_aln.time.values[0])[:10]} to {str(sfmc_aln.time.values[-1])[:10]}")
print("RZ weights (0–100 cm):", w, "sum =", w.sum())


In [None]:

# --- Random 12-month visual checks for SM, RZSM, Tsoil, SCF (with per-month min/max) ---

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr

def plot_random_months(mod_da, era_da, title, units, N_SAMPLES=12, SEED=42):
    """Plot model, ERA and model-minus-ERA maps for a random sample of months.

    Parameters
    ----------
    mod_da, era_da : xr.DataArray
        Fields with dims ``("time", "y", "x")`` and identical time values.
    title : str
        Label used in panel titles and printed summaries.
    units : str
        Colorbar label.
    N_SAMPLES : int
        Number of months to draw (capped at the number available).
    SEED : int
        Seed for the random generator, so the same months are drawn each run.

    One figure row is drawn per sampled month. The model and ERA panels share
    color limits (2nd-98th percentile of the sampled values); the difference
    panel uses symmetric limits from the 98th percentile of |difference|.
    Per-month and overall min/max values are printed.
    """
    # sanity
    assert mod_da.dims == era_da.dims == ("time","y","x"), f"{title}: dims mismatch or unexpected"
    assert (mod_da["time"].values == era_da["time"].values).all(), f"{title}: time coords differ"
    n_times = mod_da.sizes["time"]

    # draw the months, then restore chronological order
    picks = np.random.default_rng(SEED).choice(n_times, size=min(N_SAMPLES, n_times), replace=False)
    picks.sort()
    stamps = pd.to_datetime(mod_da["time"].values[picks])

    # color limits shared by model and ERA panels
    mod_pick = mod_da.isel(time=picks)
    era_pick = era_da.isel(time=picks)
    pooled = xr.concat([mod_pick, era_pick], dim="stack").values
    vmin = np.nanpercentile(pooled, 2)
    vmax = np.nanpercentile(pooled, 98)
    usable_limits = np.isfinite(vmin) and np.isfinite(vmax) and vmin != vmax
    if not usable_limits:
        vmin, vmax = np.nanmin(pooled), np.nanmax(pooled)

    # symmetric limits for the difference panel
    abs_diff = np.abs((mod_pick - era_pick).values)
    dabs = np.nanpercentile(abs_diff, 98)
    if not np.isfinite(dabs) or dabs == 0:
        peak = np.nanmax(abs_diff)
        dabs = peak if np.isfinite(peak) else 1.0

    # running extremes over all sampled months: [min, max] per series
    extremes = {"mod": [np.inf, -np.inf], "era": [np.inf, -np.inf], "dif": [np.inf, -np.inf]}

    n_rows = len(picks)
    fig, axes = plt.subplots(n_rows, 3, figsize=(14, 3.2*n_rows), constrained_layout=True)
    if n_rows == 1:
        axes = np.array([axes])  # keep 2-D indexing valid for a single row

    for row, (t_i, ts) in enumerate(zip(picks, stamps)):
        mod_t = mod_da.isel(time=t_i)
        era_t = era_da.isel(time=t_i)
        diff  = mod_t - era_t

        # per-month min/max
        (mmin, mmax), (emin, emax), (dmin, dmax) = [
            (np.nanmin(field.values), np.nanmax(field.values))
            for field in (mod_t, era_t, diff)
        ]
        print(f"{title}  {ts:%Y-%m}  MOD[min,max]=[{mmin:.4f}, {mmax:.4f}]  "
              f"ERA[min,max]=[{emin:.4f}, {emax:.4f}]  DIFF[min,max]=[{dmin:.4f}, {dmax:.4f}]")

        for key, low, high in (("mod", mmin, mmax), ("era", emin, emax), ("dif", dmin, dmax)):
            extremes[key][0] = min(extremes[key][0], low)
            extremes[key][1] = max(extremes[key][1], high)

        # one panel per column: model, ERA, difference
        panels = (
            (mod_t, f"{title} • Model @ {ts:%Y-%m}", dict(vmin=vmin, vmax=vmax)),
            (era_t, f"{title} • ERA @ {ts:%Y-%m}",   dict(vmin=vmin, vmax=vmax)),
            (diff,  "Model − ERA",                   dict(vmin=-dabs, vmax=dabs, cmap="coolwarm")),
        )
        for ax, (field, heading, style) in zip(axes[row], panels):
            image = ax.imshow(field, **style)
            ax.set_title(heading)
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            colorbar = plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
            colorbar.set_label(units)

    plt.show()
    print(f"\n{title} • Overall across sampled months:")
    print(f"  MOD  min/max: [{extremes['mod'][0]:.4f}, {extremes['mod'][1]:.4f}]")
    print(f"  ERA  min/max: [{extremes['era'][0]:.4f}, {extremes['era'][1]:.4f}]")
    print(f"  DIFF min/max: [{extremes['dif'][0]:.4f}, {extremes['dif'][1]:.4f}]\n")

# ---- Visual check for each model/ERA pair, in a fixed order ----
_visual_checks = [(sfmc_aln, era_sfc_aln, "Surface SM (SFMC vs swvl1)", "m³ m⁻³")]
if 'rzmc_aln' in locals() and rzmc_aln is not None:
    # root zone is optional: only present when the model file carried RZMC
    _visual_checks.append((rzmc_aln, era_rz_aln, "Root-zone SM (RZMC vs ERA RZ)", "m³ m⁻³"))
_visual_checks.append((tsoil1_aln, era_t1_aln, "Soil Temperature (TSOIL1 vs stl1)", "K"))
_visual_checks.append((scf_aln, era_scf_aln, "Snow Cover Fraction (SCF)", "1"))

for _mod_field, _era_field, _label, _unit in _visual_checks:
    plot_random_months(_mod_field, _era_field, title=_label, units=_unit)


In [None]:
# Build boolean screening masks that keep only warm, snow-free samples.
# Inputs (aligned by time, dims = ("time","y","x")):
#   Model: tsoil1_aln (K), scf_aln (0–1)
#   ERA:   era_t1_aln (K), era_scf_aln (0–1)

# T and snow thresholds
TEMP_THRESH_K = 275.15   # 2°C
SNOW_EPS     = 1e-2      # treat <1% as "snow-free" after regridding

# Sanity: same time axis
assert (tsoil1_aln["time"].values == era_t1_aln["time"].values).all(), "Time coords differ between model and ERA."

# Per-dataset masks: True where soil temperature exceeds the threshold AND snow
# cover is below SNOW_EPS. Comparisons against NaN are False, so cells with
# missing temperature or snow data are excluded.
mask_model = (tsoil1_aln > TEMP_THRESH_K) & (scf_aln    < SNOW_EPS)
mask_era   = (era_t1_aln > TEMP_THRESH_K) & (era_scf_aln < SNOW_EPS)

# Intersection (use this to mask SM for comparison)
mask_both  = mask_model & mask_era

# Quick coverage summary
def cov_str(m):
    """Format the share of True entries in mask `m` as a percentage string."""
    frac = float(m.mean().values)  # fraction over all time,y,x
    return f"{100.0*frac:.1f}% of (time,y,x)"

# The denominator is every (time, y, x) element, including cells where the
# mask is False only because the input was NaN.
print("Mask coverage:")
print("  Model-only mask  :", cov_str(mask_model))
print("  ERA-only mask    :", cov_str(mask_era))
print("  BOTH (intersection):", cov_str(mask_both))

In [None]:
# Compute per-cell anomaly correlation and ubRMSE for surface SM, root-zone SM
# (when available) and snow cover fraction.
MIN_PAIRS = 24  # months required for stats

# 1) Surface SM (masked by intersection of criteria)
# The same mask is applied to both products before anomalies are formed, so
# the month-of-year climatology is built from the screened samples only.
sm_mod_masked = sfmc_aln.where(mask_both).astype("float64")
sm_era_masked = era_sfc_aln.where(mask_both).astype("float64")

sm_mod_an = anomalies_monthly(sm_mod_masked)
sm_era_an = anomalies_monthly(sm_era_masked)
anomR_sfc, ubRMSE_sfc = anom_metrics_monthly(sm_mod_an, sm_era_an, min_pairs=MIN_PAIRS)

# Printed means are unweighted averages over all grid cells with a valid metric.
print("Surface SM — mean anomR:",
      float(anomR_sfc.mean(skipna=True).values),
      " mean ubRMSE:",
      float(ubRMSE_sfc.mean(skipna=True).values))

# 2) Root-zone SM (masked) — only if available
# Both results stay None when either root-zone field is missing.
anomR_rz = ubRMSE_rz = None
if ('rzmc_aln' in locals() and rzmc_aln is not None and
    'era_rz_aln' in locals() and era_rz_aln is not None):

    rz_mod_masked = rzmc_aln.where(mask_both).astype("float64")
    rz_era_masked = era_rz_aln.where(mask_both).astype("float64")

    rz_mod_an = anomalies_monthly(rz_mod_masked)
    rz_era_an = anomalies_monthly(rz_era_masked)

    anomR_rz, ubRMSE_rz = anom_metrics_monthly(rz_mod_an, rz_era_an, min_pairs=MIN_PAIRS)

    print("Root-zone SM — mean anomR:",
          float(anomR_rz.mean(skipna=True).values),
          " mean ubRMSE:",
          float(ubRMSE_rz.mean(skipna=True).values))
else:
    print("Root-zone SM not available; skipping RZ metrics.")

# 3) Snow Cover Fraction (unmasked; masking by snow-free would trivialize SCF)
scf_mod_an = anomalies_monthly(scf_aln.astype("float64"))
scf_era_an = anomalies_monthly(era_scf_aln.astype("float64"))
anomR_scf, ubRMSE_scf = anom_metrics_monthly(scf_mod_an, scf_era_an, min_pairs=MIN_PAIRS)

print("SCF — mean anomR:",
      float(anomR_scf.mean(skipna=True).values),
      " mean ubRMSE:",
      float(ubRMSE_scf.mean(skipna=True).values))

# Optional: keep outputs together for downstream use
# Missing root-zone metrics are replaced by all-NaN fields shaped like the
# surface metrics, so the dataset always has the same variables.
metrics = xr.Dataset(
    dict(
        anomR_sfc=anomR_sfc, ubRMSE_sfc=ubRMSE_sfc,
        anomR_rz=anomR_rz if anomR_rz is not None else xr.full_like(anomR_sfc, np.nan),
        ubRMSE_rz=ubRMSE_rz if ubRMSE_rz is not None else xr.full_like(ubRMSE_sfc, np.nan),
        anomR_scf=anomR_scf, ubRMSE_scf=ubRMSE_scf,
    )
)

In [None]:
# --- Quick global maps of anomaly metrics (Surface SM, RZ SM, SCF) ---

def robust_ub_range(m):
    """Return a color-scale upper bound for an ubRMSE field.

    The bound is the 98th percentile of the finite values of `m`. The
    fallback 1.0 is returned when `m` has no finite values or when the
    percentile is not a positive finite number.
    """
    finite = np.asarray(m)
    finite = finite[np.isfinite(finite)]
    if finite.size:
        p98 = np.nanpercentile(finite, 98)
        if np.isfinite(p98) and p98 > 0:
            return float(p98)
    return 1.0

# Assemble rows to plot: (title, anomaly-correlation field, ubRMSE field, units)
rows = []
rows.append(("Surface SM", anomR_sfc, ubRMSE_sfc, "m³ m⁻³"))
# Root zone is shown only when its metrics exist and hold at least one finite value.
if 'anomR_rz' in locals() and anomR_rz is not None and np.any(np.isfinite(anomR_rz)):
    rows.append(("Root-zone SM", anomR_rz, ubRMSE_rz, "m³ m⁻³"))
rows.append(("Snow Cover Fraction", anomR_scf, ubRMSE_scf, "1"))

nrows = len(rows)
fig, axes = plt.subplots(nrows, 2, figsize=(12, 4.2*nrows), constrained_layout=True)
if nrows == 1:
    axes = np.array([axes])  # keep axes[r, c] indexing valid for a single row

# NOTE: the loop variables below (r, title, units, anomR, ub, vmax, ...) are
# notebook globals and overwrite any earlier variables with the same names.
for r, (title, anomR_da, ub_da, units) in enumerate(rows):
    # Ensure plain numpy 2D arrays for imshow
    anomR = np.asarray(anomR_da)
    ub    = np.asarray(ub_da)

    # Row: left = corr, right = ubRMSE
    # Correlation uses the fixed range [-1, 1] with a diverging colormap.
    im0 = axes[r, 0].imshow(anomR, vmin=-1, vmax=1, cmap="RdBu_r")
    axes[r, 0].set_title(f"{title} — Anomaly correlation")
    axes[r, 0].set_xlabel("x"); axes[r, 0].set_ylabel("y")
    cb0 = plt.colorbar(im0, ax=axes[r, 0], fraction=0.046, pad=0.04)
    cb0.set_label("corr")

    # ubRMSE is non-negative: scale from 0 to a robust upper bound.
    vmax = robust_ub_range(ub)
    im1 = axes[r, 1].imshow(ub, vmin=0, vmax=vmax)
    axes[r, 1].set_title(f"{title} — ubRMSE")
    axes[r, 1].set_xlabel("x"); axes[r, 1].set_ylabel("y")
    cb1 = plt.colorbar(im1, ax=axes[r, 1], fraction=0.046, pad=0.04)
    cb1.set_label(units)

    # Print quick summaries (unweighted mean over all finite grid cells)
    mean_corr = float(np.nanmean(anomR))
    mean_ub   = float(np.nanmean(ub))
    print(f"{title}: mean corr = {mean_corr:.3f}, mean ubRMSE = {mean_ub:.4f} {units}")

plt.show()


In [None]:
# --- Save useful outputs to NetCDF (includes lat/lon) ---

# Output NetCDF file name; a relative path, so it resolves against the
# current working directory.
out_nc = "ERA5L_vs_DAv8_M36_summary.nc"

def add_var(ds, name, da, attrs=None):
    """Store `da` in `ds` under `name`, optionally attaching attributes.

    Nothing happens when `da` is None or when `name` already exists in `ds`,
    so an existing entry is never overwritten. A non-empty `attrs` mapping is
    merged into the stored variable's ``attrs``.
    """
    if da is None or name in ds:
        return
    ds[name] = da
    if attrs:
        ds[name].attrs.update(attrs)

# Build dataset
ds_out = xr.Dataset()

# Coordinates (time, y, x)
# Taken from sfmc_aln; y/x fall back to plain 0..N-1 indices when the aligned
# field carries no coordinate for them. 2-D lat/lon come from the M36 grid.
ds_out = ds_out.assign_coords(
    time = sfmc_aln["time"].copy(),
    y    = sfmc_aln["y"].copy() if "y" in sfmc_aln.coords else xr.DataArray(np.arange(sfmc_aln.sizes["y"]), dims=("y",)),
    x    = sfmc_aln["x"].copy() if "x" in sfmc_aln.coords else xr.DataArray(np.arange(sfmc_aln.sizes["x"]), dims=("x",)),
    lat  = (("y","x"), m36_grid["lat"].values),
    lon  = (("y","x"), m36_grid["lon"].values),
)
ds_out["lat"].attrs.update(dict(standard_name="latitude",  units="degrees_north"))
ds_out["lon"].attrs.update(dict(standard_name="longitude", units="degrees_east"))

# Core aligned variables (monthly)
add_var(ds_out, "SM_model",   sfmc_aln.astype("float64"), dict(long_name="Model surface soil moisture", units="m3 m-3"))
add_var(ds_out, "SM_era",     era_sfc_aln.astype("float64"), dict(long_name="ERA5-Land swvl1", units="m3 m-3"))

if 'rzmc_aln' in locals() and rzmc_aln is not None:
    add_var(ds_out, "RZ_model", rzmc_aln.astype("float64"), dict(long_name="Model root-zone soil moisture", units="m3 m-3"))
if 'era_rz_aln' in locals() and era_rz_aln is not None:
    add_var(ds_out, "RZ_era",   era_rz_aln.astype("float64"), dict(long_name="ERA5-Land 0–100 cm RZ proxy", units="m3 m-3"))

add_var(ds_out, "SCF_model",  scf_aln.astype("float64"), dict(long_name="Model snow cover fraction", units="1"))
add_var(ds_out, "SCF_era",    era_scf_aln.astype("float64"), dict(long_name="ERA5-Land snow cover fraction", units="1"))

add_var(ds_out, "Tsoil_model", tsoil1_aln.astype("float64"), dict(long_name="Model soil temperature level 1", units="K"))
add_var(ds_out, "Tsoil_era",   era_t1_aln.astype("float64"), dict(long_name="ERA5-Land soil temperature level 1", units="K"))

# Masks (written only if the masking cell has been run)
if 'mask_model' in locals():
    add_var(ds_out, "mask_model", mask_model.astype("bool"), dict(long_name="Model mask: T>2C & SCF≈0", units="1"))
if 'mask_era' in locals():
    add_var(ds_out, "mask_era",   mask_era.astype("bool"),   dict(long_name="ERA mask: T>2C & SCF≈0", units="1"))
if 'mask_both' in locals():
    add_var(ds_out, "mask_both",  mask_both.astype("bool"),  dict(long_name="Intersection mask (model & ERA)", units="1"))

# Metrics (spatial 2D; written only if the metrics cell has been run)
if 'anomR_sfc' in locals():
    add_var(ds_out, "anomR_sfc",  anomR_sfc.astype("float64"),  dict(long_name="Anomaly correlation (surface SM)", units="1"))
if 'ubRMSE_sfc' in locals():
    add_var(ds_out, "ubRMSE_sfc", ubRMSE_sfc.astype("float64"), dict(long_name="ubRMSE (surface SM)", units="m3 m-3"))

if 'anomR_rz' in locals() and anomR_rz is not None:
    add_var(ds_out, "anomR_rz",   anomR_rz.astype("float64"),   dict(long_name="Anomaly correlation (root-zone SM)", units="1"))
if 'ubRMSE_rz' in locals() and ubRMSE_rz is not None:
    add_var(ds_out, "ubRMSE_rz",  ubRMSE_rz.astype("float64"),  dict(long_name="ubRMSE (root-zone SM)", units="m3 m-3"))

if 'anomR_scf' in locals():
    add_var(ds_out, "anomR_scf",  anomR_scf.astype("float64"),  dict(long_name="Anomaly correlation (SCF)", units="1"))
if 'ubRMSE_scf' in locals():
    add_var(ds_out, "ubRMSE_scf", ubRMSE_scf.astype("float64"), dict(long_name="ubRMSE (SCF)", units="1"))

# Global attributes
# NOTE(review): CF names this global attribute "Conventions" (capital C);
# the lowercase key is kept as-is — confirm whether it should be renamed.
ds_out.attrs.update(OrderedDict(
    title      = "ERA5-Land vs Model (M36) aligned fields, masks, and anomaly metrics",
    conventions= "CF-1.8",
    note       = "SM fields masked with (T>2C & SCF≈0) when computing metrics; SCF unmasked.",
))

# Encodings for compression (float vars → float32 with zlib; masks → int8)
encoding = {}
for v in ds_out.data_vars:
    if str(ds_out[v].dtype).startswith(("bool",)):
        # Masks are stored as 0/1 and have no missing values. They must not
        # get a fill value of 0: that would mark every False as missing and
        # decode it to NaN when the file is read back.
        encoding[v] = dict(zlib=True, complevel=4, dtype="int8", _FillValue=None)
    else:
        encoding[v] = dict(zlib=True, complevel=4, dtype="float32", _FillValue=np.float32(np.nan))

# Coordinate encodings (optional: compress lat/lon too)
encoding["lat"] = dict(zlib=True, complevel=4, dtype="float32", _FillValue=np.float32(np.nan))
encoding["lon"] = dict(zlib=True, complevel=4, dtype="float32", _FillValue=np.float32(np.nan))

# Save
ds_out.to_netcdf(out_nc, format="NETCDF4", encoding=encoding)
print(f"Saved: {out_nc}")
print("Variables saved:", list(ds_out.data_vars))
