In [None]:
"""
COMPARE GEOS-LDAS SURFACE SOIL MOISTURE (SFMC) TO ERA5-LAND SWVL1
==================================================================
This notebook:
 1. Reads the GEOS-LDAS tilecoord binary and model output.
 2. Builds the exact M36 grid geometry from tile bounds.
 3. Regrids ERA5-Land monthly soil moisture to that grid using xESMF (conservative).
 4. Aggregates model tile data to the same grid (area weighted, land only).
 5. Computes anomaly correlation & ubRMSE between model and ERA5-Land.
"""

import numpy as np
import xarray as xr
import pandas as pd
import struct
import matplotlib.pyplot as plt
import xesmf as xe
import datetime as dt

# === File paths (edit for your system) ===
# GEOS-LDAS tilecoord Fortran binary; parsed by read_tilecoord() below.
ftc = '/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/land_sweeper/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/rc_out/LS_OLv8_M36.ldas_tilecoord.bin'
# era_nc = '/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/M21C_land_sweeper/Evaluation/era_land_sm_2017_monhtly_test.nc'
# ERA5-Land monthly NetCDF; opened later as `ds_era` (swvl1, swvl2, swvl3 are read from it).
era_nc = '/Users/amfox/Desktop/GEOSldas_diagnostics/Jupyter/era5l_monthly_nc/ERA5L_monthly_merged.nc'
# Tile-space model output NetCDF; opened later as `ds_mod` (SFMC is read, RZMC only if present).
model_nc = '/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/M21C_land_sweeper/LS_DAv8_M36_v2/LS_DAv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/DAv8_land_variables_2000_2024_compressed.nc'


In [None]:
# --- xESMF sanity check: conservative regridding between two regular global grids ---
def _edge_midpoints(edges):
    """Cell centers as the midpoints between consecutive cell edges."""
    return 0.5*(edges[:-1] + edges[1:])


def _rectilinear_grid(lon_c, lat_c, lon_e, lat_e):
    """Wrap 1-D centers and edges in a Dataset with CF longitude/latitude attributes."""
    grid = xr.Dataset(
        {
            "lon":    (("lon",), lon_c),
            "lat":    (("lat",), lat_c),
            "lon_b":  (("lon_b",), lon_e),
            "lat_b":  (("lat_b",), lat_e),
        }
    )
    grid["lon"].attrs.update(standard_name="longitude", units="degrees_east", bounds="lon_b")
    grid["lat"].attrs.update(standard_name="latitude",  units="degrees_north", bounds="lat_b")
    grid["lon_b"].attrs.update(standard_name="longitude", units="degrees_east")
    grid["lat_b"].attrs.update(standard_name="latitude",  units="degrees_north")
    return grid


# Source grid: 1° cells -> 361 x 181 edges, 360 x 180 centers
lon_b_src = np.linspace(0.0, 360.0, 361)
lat_b_src = np.linspace(-90.0, 90.0, 181)
lon_src   = _edge_midpoints(lon_b_src)
lat_src   = _edge_midpoints(lat_b_src)
src = _rectilinear_grid(lon_src, lat_src, lon_b_src, lat_b_src)

# Target grid: 6° cells -> 61 x 31 edges, 60 x 30 centers
lon_b_tgt = np.linspace(0.0, 360.0, 61)
lat_b_tgt = np.linspace(-90.0, 90.0, 31)
lon_tgt   = _edge_midpoints(lon_b_tgt)
lat_tgt   = _edge_midpoints(lat_b_tgt)
tgt = _rectilinear_grid(lon_tgt, lat_tgt, lon_b_tgt, lat_b_tgt)

# Both datasets carry lon_b/lat_b, which the conservative method needs
r = xe.Regridder(src, tgt, method="conservative", periodic=True, reuse_weights=False)
print(r)  # summary of the regridder that was just built


In [None]:
def read_tilecoord(fname):
    """Parse a GEOS-LDAS tilecoord file (little-endian Fortran sequential binary).

    The file is read as a sequence of records, each wrapped in 4-byte
    markers: first a single int32 tile count, then one record per field
    holding ``N_tile`` 4-byte values.

    Parameters
    ----------
    fname : path of the tilecoord binary.

    Returns
    -------
    dict
        ``'N_tile'`` (int) followed by one array per field, in file order.
        Index/ID fields are returned as int32; every other field is read as
        float32 and widened to float64.
    """
    integer_fields = ('tile_id', 'typ', 'pfaf', 'i_indg', 'j_indg')
    field_order = (
        'tile_id', 'typ', 'pfaf', 'com_lon', 'com_lat', 'min_lon', 'max_lon',
        'min_lat', 'max_lat', 'i_indg', 'j_indg', 'frac_cell', 'frac_pfaf',
        'area', 'elev',
    )

    with open(fname, 'rb') as stream:

        def next_int32():
            # Also used to step over record markers; the value is then discarded.
            return struct.unpack('<i', stream.read(4))[0]

        next_int32()                       # opening marker of the count record
        n_tiles = next_int32()
        next_int32()                       # closing marker of the count record

        parsed = {'N_tile': n_tiles}
        for name in field_order:
            next_int32()                   # opening record marker
            payload = stream.read(n_tiles*4)
            if name in integer_fields:
                parsed[name] = np.frombuffer(payload, dtype='<i').astype(np.int32)
            else:
                parsed[name] = np.frombuffer(payload, dtype='<f').astype(np.float64)
            next_int32()                   # closing record marker
    return parsed

def build_m36_grid_cf(tc):
    """Assemble 2-D cell-center and cell-corner coordinates from tilecoord bounds.

    Each tile writes its min/max lon/lat into the four corners of grid cell
    (j_indg, i_indg). Centers are the average of each cell's [j, i] and
    [j+1, i+1] corners.

    NOTE(review): corner arrays start as NaN and are only filled where a tile
    exists, so corners not touched by any tile (and the centers derived from
    them) stay NaN.

    Parameters
    ----------
    tc : dict with 'i_indg', 'j_indg', 'min_lon', 'max_lon', 'min_lat', 'max_lat'.

    Returns
    -------
    xr.Dataset with 'lon'/'lat' on (y, x) and 'lon_b'/'lat_b' on (y_b, x_b),
    each carrying CF standard_name/units (and 'bounds' on the centers).
    """
    cols = tc['i_indg'].astype(int)
    rows = tc['j_indg'].astype(int)
    n_cols = int(cols.max()) + 1
    n_rows = int(rows.max()) + 1

    corner_shape = (n_rows + 1, n_cols + 1)
    lon_b = np.full(corner_shape, np.nan)
    lat_b = np.full(corner_shape, np.nan)

    per_tile = zip(cols, rows, tc['min_lon'], tc['max_lon'], tc['min_lat'], tc['max_lat'])
    for i, j, lon_lo, lon_hi, lat_lo, lat_hi in per_tile:
        row_pair = [j, j+1]
        col_pair = [i, i+1]
        lon_b[row_pair, i]   = lon_lo
        lon_b[row_pair, i+1] = lon_hi
        lat_b[j,   col_pair] = lat_lo
        lat_b[j+1, col_pair] = lat_hi

    lon_c = 0.5*(lon_b[:-1,:-1] + lon_b[1:,1:])
    lat_c = 0.5*(lat_b[:-1,:-1] + lat_b[1:,1:])

    center_dims, corner_dims = ("y","x"), ("y_b","x_b")
    m36 = xr.Dataset(
        {"lon": (center_dims, lon_c),
         "lat": (center_dims, lat_c),
         "lon_b": (corner_dims, lon_b),
         "lat_b": (corner_dims, lat_b)}
    )
    m36["lon"].attrs.update(standard_name="longitude", units="degrees_east", bounds="lon_b")
    m36["lat"].attrs.update(standard_name="latitude",  units="degrees_north", bounds="lat_b")
    m36["lon_b"].attrs.update(standard_name="longitude", units="degrees_east")
    m36["lat_b"].attrs.update(standard_name="latitude",  units="degrees_north")
    return m36

def tiles_to_m36_grid(tile_values, tc, m36_grid, land_only=True):
    """Aggregate tile values onto the M36 grid as a frac_cell-weighted mean.

    Weights are summed separately for every time step and only over tiles
    with a finite value, so (a) the result does not depend on the number of
    time steps and (b) missing tiles do not pull the cell mean toward zero.

    Parameters
    ----------
    tile_values : array-like, shape (Ntile,) or (T, Ntile)
        Values per tile; a 1-D input is treated as a single time step.
    tc : dict
        Tilecoord dictionary providing 'i_indg', 'j_indg' and 'frac_cell'.
    m36_grid : xr.Dataset
        Grid dataset providing 2-D 'lat' and 'lon' on (y, x).
    land_only : bool, default True
        Mask cells that contain no tiles. Those cells have zero weight and
        are therefore already NaN, so the flag does not change the values;
        it is kept for interface compatibility.

    Returns
    -------
    xr.DataArray
        Dims ("time", "y", "x") with an integer time coordinate 0..T-1 and
        2-D 'lat'/'lon' coordinates. Cells without any finite tile value at
        a given time step are NaN.
    """
    i_indg = tc['i_indg'].astype(int)
    j_indg = tc['j_indg'].astype(int)
    frac   = np.asarray(tc['frac_cell'], dtype=np.float64)
    nx = int(i_indg.max()) + 1
    ny = int(j_indg.max()) + 1
    cell = (j_indg, i_indg)

    tile_values = np.asarray(tile_values, dtype=np.float64)
    if tile_values.ndim == 1:
        tile_values = tile_values[np.newaxis, :]
    T = tile_values.shape[0]

    # Time-independent weight: identifies cells that contain at least one tile.
    cell_wt = np.zeros((ny, nx), dtype=np.float64)
    np.add.at(cell_wt, cell, frac)

    out = np.full((T, ny, nx), np.nan, dtype=np.float64)
    for t in range(T):
        valid = np.isfinite(tile_values[t])
        wt = np.where(valid, frac, 0.0)           # missing tiles carry no weight
        num = np.zeros((ny, nx), dtype=np.float64)
        den = np.zeros((ny, nx), dtype=np.float64)
        # np.add.at accumulates correctly when several tiles share one cell
        np.add.at(num, cell, np.where(valid, tile_values[t], 0.0) * wt)
        np.add.at(den, cell, wt)
        has_data = den > 0
        out[t][has_data] = num[has_data] / den[has_data]

    da = xr.DataArray(out, dims=("time","y","x"),
                      coords={"time": np.arange(T),
                              "lat": (("y","x"), m36_grid["lat"].values),
                              "lon": (("y","x"), m36_grid["lon"].values)})
    if land_only:
        land_mask = cell_wt > 0
        da = da.where(land_mask)
    return da

def layer_weights_era5l(target_top=0.0, target_bot=1.0):
    """Fractional contribution of ERA5-Land soil layers to a target depth range [m].

    Only the three layers spanning 0-0.07, 0.07-0.28 and 0.28-1.00 m are
    represented here. Each weight is the thickness of the overlap between a
    layer and [target_top, target_bot], divided by the target thickness, so
    the weights sum to 1 only when the target lies entirely within 0-1 m.

    Returns
    -------
    numpy.ndarray of length 3.

    Raises
    ------
    ValueError
        If ``target_bot`` is not greater than ``target_top``.
    """
    thickness = target_bot - target_top
    if thickness <= 0:
        raise ValueError("target_bot must be > target_top")

    layer_top = np.array([0.00, 0.07, 0.28], dtype=float)
    layer_bot = np.array([0.07, 0.28, 1.00], dtype=float)
    shared = np.minimum(layer_bot, target_bot) - np.maximum(layer_top, target_top)
    shared = np.maximum(0.0, shared)      # layers that do not intersect contribute nothing
    return shared / thickness

def anomalies(da, dim='time'):
    return da - da.mean(dim=dim)

def anom_metrics(a, b, dim='time'):
    """Anomaly correlation and RMSE between two anomaly fields along `dim`.

    Both inputs are first masked to the samples where *both* are non-null.
    Without this, the covariance would be averaged over the common samples
    while each variance would be averaged over that input's own valid
    samples, which makes the ratio inconsistent whenever the NaN patterns
    of `a` and `b` differ.

    Parameters
    ----------
    a, b : xr.DataArray
        Anomalies (the mean is not removed here).
    dim : str, default 'time'
        Dimension to reduce over.

    Returns
    -------
    (anomR, ubRMSE)
        anomR  = mean(a*b) / sqrt(mean(a**2) * mean(b**2))
        ubRMSE = sqrt(mean((a - b)**2))
    """
    both_valid = a.notnull() & b.notnull()
    a = a.where(both_valid)
    b = b.where(both_valid)

    cov   = (a * b).mean(dim=dim)
    var_a = (a**2).mean(dim=dim)
    var_b = (b**2).mean(dim=dim)
    anomR = cov / np.sqrt(var_a * var_b)
    ubRMSE = np.sqrt(((a - b)**2).mean(dim=dim))
    return anomR, ubRMSE


In [None]:
# Read tilecoord and build grid
tc = read_tilecoord(ftc)
print(f"N_tile = {tc['N_tile']}")
m36_grid = build_m36_grid_cf(tc)
# Use .sizes for the name -> length mapping; mapping-style access on Dataset.dims is deprecated
ny, nx = m36_grid.sizes["y"], m36_grid.sizes["x"]
print("M36 grid:", ny, "x", nx)

# Load model monthly (tile-based)
ds_mod = xr.open_dataset(model_nc)
ds_mod = xr.decode_cf(ds_mod)  # ensure time decoding

# Surface SM (tiles → M36); the gridded array gets the model's time coordinate
sfmc_tiles = ds_mod["SFMC"].values.astype(np.float64)  # (time, tile)
sfmc_grid  = tiles_to_m36_grid(sfmc_tiles, tc, m36_grid, land_only=True)
sfmc_grid = sfmc_grid.assign_coords(time=ds_mod["time"].values)
print("SFMC grid shape:", sfmc_grid.shape)

# Root-zone SM is aggregated the same way when the file provides it;
# rzmc_grid stays None otherwise so later cells can skip the RZ comparison.
rzmc_grid = None
if "RZMC" in ds_mod.data_vars:
    rzmc_tiles = ds_mod["RZMC"].values.astype(np.float64)
    rzmc_grid  = tiles_to_m36_grid(rzmc_tiles, tc, m36_grid, land_only=True)
    rzmc_grid  = rzmc_grid.assign_coords(time=ds_mod["time"].values)
    print("RZMC grid shape:", rzmc_grid.shape)
else:
    print("RZMC not present in model file; skipping model RZ comparison for now.")


In [None]:
# Load ERA5-Land monthly fields (layers swvl1-swvl3 are regridded below)
ds_era = xr.open_dataset(era_nc)

# Put latitude in ascending order and wrap longitude into [0, 360)
era_lat = ds_era.latitude
if era_lat[0] > era_lat[-1]:
    ds_era = ds_era.reindex(latitude=era_lat.values[::-1])
wrapped_lon = np.mod(ds_era.longitude, 360.0)
ds_era = ds_era.assign_coords(longitude=wrapped_lon).sortby("longitude")

# Conservative regridder from the ERA5-Land grid to the M36 grid built from tilecoord
regridder = xe.Regridder(ds_era, m36_grid, method="conservative", periodic=True, reuse_weights=False)
print(regridder)

# Apply the same regridder to each soil layer and give the result a distinct name
era_swvl1, era_swvl2, era_swvl3 = (
    regridder(ds_era[layer]).rename(out_name)
    for layer, out_name in (
        ("swvl1", "ERA5L_swvl1"),
        ("swvl2", "ERA5L_swvl2"),
        ("swvl3", "ERA5L_swvl3"),
    )
)

print("ERA5-Land regridded shapes:", era_swvl1.shape)


In [None]:
# --- robust month alignment for ERA with 'valid_time' ---

def _to_month_period(dt64):
    """Convert datetime-like input (scalar or array) to pandas monthly period(s)."""
    return pd.to_datetime(dt64).to_period('M')

# Sort by time/valid_time
sfmc_grid = sfmc_grid.sortby('time')
tname = 'valid_time'
era_swvl1 = era_swvl1.sortby(tname)
era_swvl2 = era_swvl2.sortby(tname)
era_swvl3 = era_swvl3.sortby(tname)

# Drop duplicate months (keep first) so each month appears once per dataset.
# The ERA indices are derived from swvl1 and reused for the other layers.
_, idx_mod_unique = np.unique(_to_month_period(sfmc_grid['time'].values), return_index=True)
_, idx_era_unique = np.unique(_to_month_period(era_swvl1[tname].values), return_index=True)
sfmc_grid  = sfmc_grid.isel(time=np.sort(idx_mod_unique))
era_swvl1  = era_swvl1.isel({tname: np.sort(idx_era_unique)})
era_swvl2  = era_swvl2.isel({tname: np.sort(idx_era_unique)})
era_swvl3  = era_swvl3.isel({tname: np.sort(idx_era_unique)})

# Intersect by monthly periods ('YYYY-MM' strings sort chronologically)
ym_mod  = _to_month_period(sfmc_grid['time'].values)
ym_era  = _to_month_period(era_swvl1[tname].values)
commonM = np.intersect1d(ym_mod.astype(str), ym_era.astype(str))
assert commonM.size > 0, "No overlapping months between model and ERA5-Land."

midx = np.isin(ym_mod.astype(str), commonM)
eidx = np.isin(ym_era.astype(str), commonM)

# Subset
sfmc_aln   = sfmc_grid.isel(time=np.where(midx)[0])
era_sfc_aln = era_swvl1.isel({tname: np.where(eidx)[0]})
era_l2_aln  = era_swvl2.isel({tname: np.where(eidx)[0]})
era_l3_aln  = era_swvl3.isel({tname: np.where(eidx)[0]})


# Assign a shared end-of-month timestamp and rename ERA time to 'time'
common_ts = pd.PeriodIndex(commonM, freq='M').to_timestamp('M')
sfmc_aln = sfmc_aln.assign_coords(time=common_ts)

era_sfc_aln = era_sfc_aln.assign_coords({tname: common_ts}).rename({tname: 'time'})
era_l2_aln  = era_l2_aln.assign_coords({tname: common_ts}).rename({tname: 'time'})
era_l3_aln  = era_l3_aln.assign_coords({tname: common_ts}).rename({tname: 'time'})

# Optional snow/freeze mask (uncomment when available)
# ds_snow   = xr.open_dataset(model_snow_nc)  # tiles, monthly
# snow_grid = tiles_to_m36_grid(ds_snow['FRLANDSNO'].values, tc, m36_grid, land_only=True)
# snow_aln  = snow_grid.sel(time=common_ts)
# mask      = snow_aln > 0.1
# sfmc_aln     = sfmc_aln.where(~mask)
# era_sfc_aln  = era_sfc_aln.where(~mask)
# era_l2_aln   = era_l2_aln.where(~mask)
# era_l3_aln   = era_l3_aln.where(~mask)

# RZSM (0–100 cm) from ERA layers with overlap weights
w = layer_weights_era5l(0.0, 1.0)
era_rz_aln = (w[0]*era_sfc_aln + w[1]*era_l2_aln + w[2]*era_l3_aln).rename('ERA5L_RZSM')

# Align model RZMC if available. `idx_mod_unique` and `midx` are positions in the
# sorted, de-duplicated SFMC series, so RZMC must go through the same sort and
# de-duplication before those positions are applied to it.
if 'rzmc_grid' in locals() and rzmc_grid is not None:
    rzmc_aln = (rzmc_grid.sortby('time')
                         .isel(time=np.sort(idx_mod_unique))
                         .isel(time=np.where(midx)[0]))
    rzmc_aln = rzmc_aln.assign_coords(time=common_ts)
else:
    rzmc_aln = None

# Ensure float dtype
sfmc_aln     = sfmc_aln.astype('float64')
era_sfc_aln  = era_sfc_aln.astype('float64')
era_l2_aln   = era_l2_aln.astype('float64')
era_l3_aln   = era_l3_aln.astype('float64')
era_rz_aln   = era_rz_aln.astype('float64')

print("Aligned months:", sfmc_aln.time.size, "from", str(sfmc_aln.time.values[0])[:10], "to", str(sfmc_aln.time.values[-1])[:10])
print("RZ weights (0–100 cm):", w, "sum =", w.sum())


In [None]:
# Surface layer: anomalies relative to the mean over the aligned months, then skill maps
sfmc_anom  = anomalies(sfmc_aln)
era_sfc_an = anomalies(era_sfc_aln)
anomR_sfc, ubRMSE_sfc = anom_metrics(sfmc_anom, era_sfc_an, dim='time')

mean_r_sfc = float(anomR_sfc.mean().values)
mean_e_sfc = float(ubRMSE_sfc.mean().values)
print("Surface SM — mean anomR:", mean_r_sfc,
      " mean ubRMSE:", mean_e_sfc)

# Root zone: same procedure, only when the model provided RZMC
anomR_rz = None
ubRMSE_rz = None
if rzmc_aln is None:
    print("RZMC not provided; skipping RZ metrics.")
else:
    rzmc_anom  = anomalies(rzmc_aln)
    era_rz_an  = anomalies(era_rz_aln)
    anomR_rz, ubRMSE_rz = anom_metrics(rzmc_anom, era_rz_an, dim='time')
    mean_r_rz = float(anomR_rz.mean().values)
    mean_e_rz = float(ubRMSE_rz.mean().values)
    print("RZSM — mean anomR:", mean_r_rz,
          " mean ubRMSE:", mean_e_rz)


In [None]:
def _show_metric_pair(corr_map, rmse_map, corr_title, rmse_title):
    """Draw a correlation map and an RMSE map side by side and display the figure."""
    figure, axes = plt.subplots(1, 2, figsize=(12, 4))
    corr_map.plot(ax=axes[0], vmin=0, vmax=1, cmap='viridis')
    axes[0].set_title(corr_title)
    rmse_map.plot(ax=axes[1], vmin=0, vmax=0.1, cmap='magma')
    axes[1].set_title(rmse_title)
    plt.tight_layout()
    plt.show()
    return figure, axes


# Surface soil moisture maps
fig, ax = _show_metric_pair(
    anomR_sfc, ubRMSE_sfc,
    "Anomaly correlation — Surface SM (Model vs ERA5-Land)",
    "ubRMSE (anomalies) — Surface SM",
)

# Root-zone maps, only when RZ metrics were computed
if anomR_rz is not None:
    fig, ax = _show_metric_pair(
        anomR_rz, ubRMSE_rz,
        "Anomaly correlation — RZSM (Model vs ERA5-Land)",
        "ubRMSE (anomalies) — RZSM",
    )


In [None]:
# ensure coords are attached + CF attrs
lat2d = m36_grid["lat"].assign_attrs(standard_name="latitude", units="degrees_north")
lon2d = m36_grid["lon"].assign_attrs(standard_name="longitude", units="degrees_east")

# build dataset
ds_out = xr.Dataset(
    data_vars={
        "anomR_sfc":  (("y","x"), anomR_sfc.values,
                       {"long_name": "Anomaly correlation (surface SM: model vs ERA5-Land swvl1)",
                        "units": "1"}),
        "ubRMSE_sfc": (("y","x"), ubRMSE_sfc.values,
                       {"long_name": "Unbiased RMSE of anomalies (surface SM)",
                        "units": "m3 m-3"}),
    },
    coords={
        "lat": (("y","x"), lat2d.values, lat2d.attrs),
        "lon": (("y","x"), lon2d.values, lon2d.attrs),
    },
    attrs={
        "title": "GEOS-LDAS vs ERA5-Land diagnostics on EASEv2 M36",
        # Only swvl1, swvl2 and swvl3 are regridded in this notebook.
        "source": "GEOS-LDAS monthly SFMC/RZMC; ERA5-Land swvl1–3 regridded with xESMF",
        # Timezone-aware UTC "now"; datetime.utcnow() is deprecated.
        "history": f"created {dt.datetime.now(dt.timezone.utc):%Y-%m-%dT%H:%MZ}",
        "Conventions": "CF-1.8",
        "grid": "EASEv2 M36 (exact, from tilecoord bounds)",
        "notes": "Anomalies computed over common months; use mask/freeze filtering if applied upstream.",
    },
)

# add RZ maps if you computed them (names may be undefined if the metrics cell was skipped)
rz_corr_map = globals().get("anomR_rz")
rz_rmse_map = globals().get("ubRMSE_rz")
if rz_corr_map is not None and rz_rmse_map is not None:
    ds_out["anomR_rz"]  = (("y","x"), rz_corr_map.values)
    ds_out["anomR_rz"].attrs.update(
        long_name="Anomaly correlation (root-zone SM: model RZMC vs ERA5-Land RZ)",
        units="1"
    )
    ds_out["ubRMSE_rz"] = (("y","x"), rz_rmse_map.values)
    ds_out["ubRMSE_rz"].attrs.update(
        long_name="Unbiased RMSE of anomalies (root-zone SM)",
        units="m3 m-3"
    )

# compression (smaller files); maps are stored as float32 with NaN as the fill value
comp = dict(zlib=True, complevel=4)
enc  = {v: {**comp, "dtype":"float32", "_FillValue": np.float32(np.nan)} for v in ds_out.data_vars}
enc.update({"lat": {"dtype":"float32", **comp}, "lon": {"dtype":"float32", **comp}})

ds_out.to_netcdf("GEOS_v_ERA5L_SM_diagnostics_M36.nc", format="NETCDF4", encoding=enc)
print("Wrote GEOS_v_ERA5L_SM_diagnostics_M36.nc")


In [None]:

# === ERA5-Land masking & SCF preparation (inserted by assistant) ===
# This cell expects **regridded ERA5-Land to model tiles** to be available as either:
#   - `era_on_tiles`  (preferred name used above), or
#   - `era_regrid`    (fallback name used in some notebooks)
# and the model Dataset as `mod` (or `model`).
#
# It constructs a comparison dataset `cmp_ds` with masked soil moisture and aligned SCF:
#   - SM_model:  model surface soil moisture (SFMC), masked by (ERA stl1 > 275.15K) & (ERA snow == 0)
#   - SM_era:    ERA5-Land swvl1, masked by the same mask
#   - SCF_model: model snow fraction (FRLANDSNO, 0–1)
#   - SCF_era:   ERA5-Land snow fraction (0–1), converted from % if needed
# It also exposes `soil_mask_bool` (True where stl1>2°C & SCF==0) to reuse in other analyses.

import xarray as xr
import numpy as np

# ---- Resolve dataset handles robustly ----
def _first_defined_global(candidate_names):
    """Return the first global among `candidate_names` that is bound to a non-None object.

    Returns None when no candidate qualifies.
    """
    namespace = globals()
    for candidate in candidate_names:
        found = namespace.get(candidate)
        if found is not None:
            return found
    return None


# The lookup runs BEFORE the result name is (re)bound. Pre-assigning
# `era_rg = None` / `model_ds = None` would put those names into globals(), and
# since both appear in their own candidate lists the search would stop there
# with None and never reach the later candidates (`era_tiles`, `ds_mod`).
era_rg = _first_defined_global(["era_on_tiles", "era_regrid", "era_rg", "era_tiles"])
if era_rg is None:
    raise RuntimeError("Could not find regridded ERA5-Land object. Expected one of: era_on_tiles, era_regrid, era_rg, era_tiles")

model_ds = _first_defined_global(["mod", "model", "model_ds", "ds_mod"])
if model_ds is None:
    raise RuntimeError("Could not find model dataset. Expected one of: mod, model, model_ds, ds_mod")

# ---- Identify variable names (ERA) ----
# Candidate names per ERA field, in priority order; pick_var() below returns
# the first one that is present in `era_rg`.
era_sm_candidates = ["swvl1", "SFMC_era", "sm_era", "sfmc_era"]  # surface soil moisture (required)
era_st_candidates = ["stl1", "TSOIL1_era", "soil_temp_era"]  # soil temperature (optional: a miss only prints a warning)
era_snow_candidates = ["snow_frac", "snowc", "SCF_era"]  # snow cover (required; "snowc" is always divided by 100 below)

def pick_var(ds, candidates):
    """Return the first name in `candidates` that is contained in `ds`, or None if none is."""
    return next((name for name in candidates if name in ds), None)

# Resolve the ERA variable names that are actually present
era_sm_var, era_st_var, era_snow_var = (
    pick_var(era_rg, names)
    for names in (era_sm_candidates, era_st_candidates, era_snow_candidates)
)

if era_sm_var is None:
    raise RuntimeError(f"ERA5-Land soil moisture variable not found. Tried: {era_sm_candidates}")
if era_st_var is None:
    # Soil temperature is optional: the mask then relies on the snow criterion only.
    print("Warning: ERA5-Land soil temperature not found; mask will ignore the temperature criterion.")
if era_snow_var is None:
    raise RuntimeError(f"ERA5-Land snow variable not found. Tried: {era_snow_candidates}")

# Snow as a 0-1 fraction: "snowc", or anything whose units say percent, is divided by 100
era_snow = era_rg[era_snow_var]
snow_is_percent = (
    era_snow_var == "snowc"
    or era_snow.attrs.get("units", "").strip() in ("%", "percent")
)
era_snow_frac = era_snow / 100.0 if snow_is_percent else era_snow

# ---- Identify variable names (Model) ----
mod_sm_var, mod_scf_var = (
    wanted if wanted in model_ds else None
    for wanted in ("SFMC", "FRLANDSNO")
)
if mod_sm_var is None:
    raise RuntimeError("Model surface soil moisture 'SFMC' not found.")
if mod_scf_var is None:
    raise RuntimeError("Model snow fraction 'FRLANDSNO' not found.")

# ---- Build mask: (ERA soil temp > 275.15 K) & (ERA snow frac == 0) ----
# Non-finite snow values count as "not snow free".
snow_clear = xr.where(np.isfinite(era_snow_frac), era_snow_frac == 0.0, False)

if era_st_var is None or era_st_var not in era_rg:
    # No soil temperature available: the temperature criterion is always satisfied
    warm_enough = xr.full_like(snow_clear, True, dtype=bool)
else:
    warm_enough = era_rg[era_st_var] > 275.15  # 2°C in Kelvin

soil_mask_bool = (snow_clear & warm_enough)

# ---- Apply the same mask to both soil moisture fields ----
sm_era_masked = era_rg[era_sm_var].where(soil_mask_bool)
sm_mod_masked = model_ds[mod_sm_var].where(soil_mask_bool)

# ---- Package comparison dataset ----
cmp_ds = xr.Dataset(
    data_vars={
        "SM_era": sm_era_masked,
        "SM_model": sm_mod_masked,
        "SCF_era": era_snow_frac,
        "SCF_model": model_ds[mod_scf_var],
        "soil_mask_bool": soil_mask_bool,
    }
)

# Descriptive attributes: variable -> (long_name, units)
_cmp_attrs = {
    "SM_era": ("ERA5-Land surface soil moisture (masked: T>2C & SCF==0)", "m3 m-3"),
    "SM_model": ("Model surface soil moisture (masked: T>2C & SCF==0)", "m3 m-3"),
    "SCF_era": ("ERA5-Land snow cover fraction", "1"),
    "SCF_model": ("Model snow cover fraction", "1"),
    "soil_mask_bool": ("Mask (True where ERA stl1>2C and ERA snow frac==0)", "1"),
}
for _var, (_long_name, _units) in _cmp_attrs.items():
    cmp_ds[_var].attrs.update(long_name=_long_name, units=_units)

# ---- (Optional) Keep only time stamps present in both model and ERA ----
_all_have_time = all("time" in obj.dims for obj in (cmp_ds, model_ds, era_rg))
if _all_have_time:
    common_time = np.intersect1d(model_ds["time"].values, era_rg["time"].values)
    cmp_ds = cmp_ds.sel(time=common_time)

# Quick sanity prints
print("Mask coverage (mean over all points):", float(cmp_ds["soil_mask_bool"].mean().values))
print("Sample stats SM (masked): ERA mean =", float(cmp_ds['SM_era'].mean(skipna=True).values),
      ", Model mean =", float(cmp_ds['SM_model'].mean(skipna=True).values))

# `cmp_ds['SM_model']` vs `cmp_ds['SM_era']` (and `cmp_ds['SCF_model']` vs
# `cmp_ds['SCF_era']`) can now be fed to the same comparison routines.
# Persist to disk if desired:
# cmp_ds.to_netcdf("ERA5L_model_cmp_masked.nc")
# === end inserted block ===
