In [None]:
# --- Difference DA vs OL for selected metrics, flatten, and plot (one map per metric) ---

import numpy as np
import xarray as xr
from geospatial_plotting import plot_region, REGION_BOUNDS

# Summary files for the two experiments compared in this cell
FN_DA = "ERA5L_vs_DAv8_M36_summary.nc"
FN_OL = "ERA5L_vs_OLv8_M36_summary.nc"

# Metric variables to difference and map, in plotting order
metrics = [
    'anomR_sfc', 'ubRMSE_sfc',
    'anomR_rz', 'ubRMSE_rz',
    'anomR_scf', 'ubRMSE_scf',
]

# Human-readable panel titles, keyed by metric variable name
titles = dict(
    anomR_sfc='Anomaly Corr — Surface SM',
    ubRMSE_sfc='ubRMSE — Surface SM',
    anomR_rz='Anomaly Corr — Root-zone SM',
    ubRMSE_rz='ubRMSE — Root-zone SM',
    anomR_scf='Anomaly Corr — SCF',
    ubRMSE_scf='ubRMSE — SCF',
)

# Colour-bar labels for the differenced fields, keyed by metric variable name
units = dict(
    anomR_sfc='Δ corr',
    anomR_rz='Δ corr',
    anomR_scf='Δ corr',
    ubRMSE_sfc='Δ (m³ m⁻³)',
    ubRMSE_rz='Δ (m³ m⁻³)',
    ubRMSE_scf='Δ (1)',
)

# Open both summary files
da, ol = (xr.open_dataset(path) for path in (FN_DA, FN_OL))

# The two files must share one grid: same lat/lon shapes first, then same values
assert all(da[c].shape == ol[c].shape for c in ('lat', 'lon'))
assert all(np.allclose(da[c], ol[c], equal_nan=True) for c in ('lat', 'lon'))

# Coordinate arrays handed to the flattening helper below
lat2d = da['lat'].values
lon2d = da['lon'].values

# Helper to flatten for your mapper
def to_map_array(vals, lon2d, lat2d):
    """
    Flatten a gridded field into an (N, 3) float64 array of [value, lon, lat].

    Only cells where the value, the longitude and the latitude are all finite
    are kept; N is the number of such cells (possibly 0).
    """
    keep = np.isfinite(vals) & np.isfinite(lon2d) & np.isfinite(lat2d)
    columns = (vals[keep], lon2d[keep], lat2d[keep])
    # astype pins the output dtype to float64 regardless of the input dtypes
    return np.column_stack(columns).astype(np.float64)

# Loop through metrics, compute DA − OL, flatten, and plot (one map per metric)
for key in metrics:
    if key not in da or key not in ol:
        print(f"Skipping {key} (not in both files)")
        continue

    diff = da[key] - ol[key]  # raw difference: DA − OL
    arr = to_map_array(diff.values, lon2d, lat2d)
    values = arr[:, 0]  # finite by construction (to_map_array drops non-finite cells)

    # Nothing to draw when no grid cell has a finite difference; skip it the
    # same way the periodized map loop further down the notebook does.
    if values.size == 0:
        print(f"{key}: no finite values; skipping.")
        continue

    maxval = float(values.max())
    minval = float(values.min())

    # Colour ranges: fixed ±0.5 for correlation differences; for ubRMSE
    # differences, the 2nd/98th percentiles mirrored around zero.
    if key.startswith('anomR'):
        cmin, cmax = -0.5, 0.5
    else:
        p02 = np.nanpercentile(values, 2)
        p98 = np.nanpercentile(values, 98)
        cmax_sym = max(abs(p02), abs(p98))
        cmin, cmax = -cmax_sym, cmax_sym
        if not np.isfinite(cmin) or not np.isfinite(cmax) or cmin == cmax:
            cmin, cmax = None, None  # fall back to plot defaults

    fig, ax = plot_region(
        arr,
        region_bounds=REGION_BOUNDS['global'],
        meanflag=True,
        plot_title=(f"DAv8_M36 − OLv8_M36 • {titles[key]}\n"
                    f"(Max: {maxval:.3g}  Min: {minval:.3g})"),
        units=units[key],
        cmin=cmin,
        cmax=cmax,
    )
    fig.tight_layout()


In [None]:
def anomalies_monthly(da, dim="time"):
    """
    Return `da` with its month-of-year mean along `dim` subtracted.

    The per-month mean is computed with skipna=True, so missing values do not
    poison the climatology.
    """
    by_month = da.groupby(f"{dim}.month")
    monthly_mean = by_month.mean(dim=dim, skipna=True)
    return by_month - monthly_mean

def anom_metrics_monthly(a, b, dim="time", min_pairs=24):
    """
    NaN-safe anomaly correlation and ubRMSE between two anomaly series.

    Parameters
    ----------
    a, b : xarray.DataArray
        Series to compare along `dim`. They are used as given (no mean is
        removed here), so callers are expected to pass anomalies.
    dim : str
        Dimension to reduce over.
    min_pairs : int
        Minimum number of positions along `dim` where both inputs are finite;
        results with fewer pairs are set to NaN.

    Returns
    -------
    (anomR, ubRMSE)
        anomR  = mean(a*b) / sqrt(mean(a²) * mean(b²)) over valid pairs,
                 NaN where the denominator is not positive.
        ubRMSE = sqrt(mean((a − b)²)) over valid pairs.
    """
    # numpy ufuncs dispatch on DataArrays directly; this avoids relying on the
    # `xr.ufuncs` module, which is not available in every xarray release.
    valid = np.isfinite(a) & np.isfinite(b)
    n = valid.sum(dim=dim)

    # Keep only positions where both series are finite
    a_v = a.where(valid)
    b_v = b.where(valid)

    cov   = (a_v * b_v).mean(dim=dim, skipna=True)
    var_a = (a_v**2).mean(dim=dim, skipna=True)
    var_b = (b_v**2).mean(dim=dim, skipna=True)

    denom = np.sqrt(var_a * var_b)
    anomR  = xr.where(denom > 0, cov / denom, np.nan)
    ubRMSE = np.sqrt(((a_v - b_v)**2).mean(dim=dim, skipna=True))

    # require enough valid months
    enough = n >= min_pairs
    anomR  = xr.where(enough, anomR, np.nan)
    ubRMSE = xr.where(enough, ubRMSE, np.nan)
    return anomR, ubRMSE

def corr_monthly(a, b, dim="time", min_pairs=24):
    """
    Pearson R on raw monthly series (no anomaly removal).

    Parameters
    ----------
    a, b : xarray.DataArray
        Series to correlate along `dim`.
    dim : str
        Dimension to reduce over.
    min_pairs : int
        Minimum number of positions along `dim` where both inputs are finite;
        results with fewer pairs are set to NaN.

    Returns
    -------
    R
        Correlation over valid pairs; NaN where either series has zero
        variance over those pairs or where there are too few pairs.
    """
    # numpy ufuncs dispatch on DataArrays directly; this avoids relying on the
    # `xr.ufuncs` module, which is not available in every xarray release.
    valid = np.isfinite(a) & np.isfinite(b)
    n = valid.sum(dim=dim)

    a_v = a.where(valid)
    b_v = b.where(valid)

    # remove means over time (on valid pairs only)
    a_dm = a_v - a_v.mean(dim=dim, skipna=True)
    b_dm = b_v - b_v.mean(dim=dim, skipna=True)

    cov   = (a_dm * b_dm).mean(dim=dim, skipna=True)
    var_a = (a_dm**2).mean(dim=dim, skipna=True)
    var_b = (b_dm**2).mean(dim=dim, skipna=True)

    denom = np.sqrt(var_a * var_b)
    R = xr.where(denom > 0, cov / denom, np.nan)

    # require enough valid months
    R = xr.where(n >= min_pairs, R, np.nan)
    return R

In [None]:
# --- Recompute monthly anomaly metrics over limited time periods ---
# Reuses: anomalies_monthly(), anom_metrics_monthly(), corr_monthly()

import numpy as np
import xarray as xr
import pandas as pd

# Summary files to process, one per experiment (list one or both)
FILES = [f"ERA5L_vs_{run}_M36_summary.nc" for run in ("DAv8", "OLv8")]

# Evaluation windows as (start, end) date strings; both ends are inclusive
PERIODS = [
    ("2000-06-01", "2007-05-31"),  # first sub-period
    ("2007-06-01", "2015-03-31"),  # second sub-period
    ("2015-04-01", "2024-05-31"),  # third sub-period
    ("2000-06-01", "2024-05-31"),  # whole span covered by the three above
]

# Minimum number of valid monthly pairs required before a metric is reported
MIN_PAIRS = 24

def _anom_and_raw_metrics(mod_raw, era_raw, min_pairs):
    """Return (anomR, ubRMSE, R) for one model/reference pair of monthly series."""
    anomR, ubRMSE = anom_metrics_monthly(
        anomalies_monthly(mod_raw), anomalies_monthly(era_raw), min_pairs=min_pairs
    )
    R = corr_monthly(mod_raw, era_raw, min_pairs=min_pairs)
    return anomR, ubRMSE, R


def compute_metrics_for_period(ds, t0, t1, min_pairs=24, snow_eps=0.01):
    """
    Compute anomaly correlation, ubRMSE and raw correlation for one time window.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain SM_model, SM_era, SCF_model and SCF_era with a `time`
        dimension. RZ_model / RZ_era and mask_both are used when present.
    t0, t1 : str or datetime-like
        Window bounds, passed through pd.to_datetime and used as an inclusive
        label slice on `time`.
    min_pairs : int
        Minimum number of valid monthly pairs required per grid cell.
    snow_eps : float
        SCF threshold; a cell takes part in the SCF metrics when either
        product exceeds it at any time in the FULL record (not just the window).

    Returns
    -------
    xarray.Dataset
        anomR_*, ubRMSE_* and R_* for the suffixes sfc, rz and scf. The rz
        fields are all-NaN when the root-zone variables are absent.

    Raises
    ------
    KeyError
        If one of the four required variables is missing.
    """
    # Decode once and use the decoded copy for both the window and the
    # full-record snow mask, so the two are always on the same footing.
    full = xr.decode_cf(ds)
    sub = full.sel(time=slice(pd.to_datetime(t0), pd.to_datetime(t1)))

    # Only the variables that are actually read below are required.
    for v in ("SM_model", "SM_era", "SCF_model", "SCF_era"):
        if v not in sub:
            raise KeyError(f"Required variable '{v}' missing from dataset.")

    # Without an explicit mask, keep every cell
    if "mask_both" in sub:
        mask_both = sub["mask_both"]
    else:
        mask_both = xr.full_like(sub["SM_model"], True, dtype=bool)

    # ----- Surface SM -----
    anomR_sfc, ubRMSE_sfc, R_sfc = _anom_and_raw_metrics(
        sub["SM_model"].where(mask_both).astype("float64"),
        sub["SM_era"].where(mask_both).astype("float64"),
        min_pairs,
    )

    # ----- Root-zone SM (if present) -----
    if ("RZ_model" in sub) and ("RZ_era" in sub):
        anomR_rz, ubRMSE_rz, R_rz = _anom_and_raw_metrics(
            sub["RZ_model"].where(mask_both).astype("float64"),
            sub["RZ_era"].where(mask_both).astype("float64"),
            min_pairs,
        )
    else:
        anomR_rz  = xr.full_like(anomR_sfc, np.nan)
        ubRMSE_rz = xr.full_like(ubRMSE_sfc, np.nan)
        R_rz      = xr.full_like(anomR_sfc, np.nan)

    # ----- SCF (ever-snow mask over FULL record) -----
    snow_mask = (
        (full["SCF_model"] > snow_eps) | (full["SCF_era"] > snow_eps)
    ).any(dim="time")

    anomR_scf, ubRMSE_scf, R_scf = _anom_and_raw_metrics(
        sub["SCF_model"].where(snow_mask).astype("float64"),
        sub["SCF_era"].where(snow_mask).astype("float64"),
        min_pairs,
    )

    return xr.Dataset(dict(
        # anomaly-based metrics
        anomR_sfc=anomR_sfc,   ubRMSE_sfc=ubRMSE_sfc,
        anomR_rz=anomR_rz,     ubRMSE_rz=ubRMSE_rz,
        anomR_scf=anomR_scf,   ubRMSE_scf=ubRMSE_scf,
        # raw correlations
        R_sfc=R_sfc, R_rz=R_rz, R_scf=R_scf,
    ))


# Process each file → stack results along "period"
results = {}
for fn in FILES:
    # open_dataset applies CF decoding by default, so no extra decode is
    # needed. The context manager releases the file handle once this file's
    # metrics have been computed.
    with xr.open_dataset(fn) as ds:
        per_list = []
        per_labels = []
        for (t0, t1) in PERIODS:
            ds_per = compute_metrics_for_period(ds, t0, t1, min_pairs=MIN_PAIRS)
            per_list.append(ds_per)
            per_labels.append(f"{pd.to_datetime(t0).date()}–{pd.to_datetime(t1).date()}")

            print(f"{fn}  [{t0} → {t1}]  "
                f"SM anomR mean={float(ds_per['anomR_sfc'].mean(skipna=True)):.3f}  "
                f"SM R mean={float(ds_per['R_sfc'].mean(skipna=True)):.3f}  "
                f"SCF anomR mean={float(ds_per['anomR_scf'].mean(skipna=True)):.3f}")

        ds_all = xr.concat(per_list, dim=xr.DataArray(per_labels, dims=("period",), name="period"))
        # Pull everything (including any carried-over coordinates) into memory
        # before the file is closed, so `results` never points at a closed file.
        results[fn] = ds_all.load()

# Example: access DA vs OL periodized metrics
# results["ERA5L_vs_DAv8_M36_summary.nc"]  -> Dataset with a leading "period" dim
# results["ERA5L_vs_OLv8_M36_summary.nc"]  -> same structure

# If you want to difference DA − OL per period.
# NOTE: this treats FILES[0] as DA and FILES[1] as OL, so the order of FILES matters.
if len(FILES) == 2:
    da_key, ol_key = FILES
    assert set(results[da_key].data_vars) == set(results[ol_key].data_vars)
    diff_period = xr.Dataset(
        {v: (results[da_key][v] - results[ol_key][v]) for v in results[da_key].data_vars}
    )
    # diff_period holds every metric (anomR_*, ubRMSE_*, R_*) differenced per period.
    # Any single-period slice can be flattened with to_map_array() and mapped.
    print("Built DA−OL periodized metrics in 'diff_period' (dims:", diff_period.dims, ")")


In [None]:
# --- Periodized DA−OL maps: corr fixed [-0.4, 0.4] with RdBu; ubRMSE pooled 2/98 percentiles → symmetric, RdBu_r ---

import numpy as np
import xarray as xr

# Fallback: rebuild DA − OL per period from `results` when this cell is run
# without `diff_period` already present in the notebook namespace.
if 'diff_period' not in globals():
    da_key = "ERA5L_vs_DAv8_M36_summary.nc"
    ol_key = "ERA5L_vs_OLv8_M36_summary.nc"
    diff_period = xr.Dataset({
        v: results[da_key][v] - results[ol_key][v]
        for v in results[da_key].data_vars
    })

# NOTE(review): assumes 'lat' and 'lon' are carried on diff_period (e.g. as
# coordinates of the metric fields) — confirm against the summary files.
lat2d, lon2d = (diff_period[c].values for c in ('lat', 'lon'))
period_labels = list(diff_period['period'].values)

# Panel titles, keyed by metric variable name
metric_titles = {
    'anomR_sfc': 'Anomaly Corr — Surface SM',
    'ubRMSE_sfc': 'ubRMSE — Surface SM',
    'anomR_rz': 'Anomaly Corr — Root-zone SM',
    'ubRMSE_rz': 'ubRMSE — Root-zone SM',
    'anomR_scf': 'Anomaly Corr — SCF',
    'ubRMSE_scf': 'ubRMSE — SCF',
    'R_sfc': 'Raw Corr — Surface SM',
    'R_rz': 'Raw Corr — Root-zone SM',
    'R_scf': 'Raw Corr — SCF',
}

# Colour-bar labels, keyed by metric variable name
metric_units = {
    'anomR_sfc': 'Δ',
    'anomR_rz': 'Δ',
    'anomR_scf': 'Δ',
    'ubRMSE_sfc': 'Δ (m³ m⁻³)',
    'ubRMSE_rz': 'Δ (m³ m⁻³)',
    'ubRMSE_scf': 'Δ (1)',
    'R_sfc': 'Δ',
    'R_rz': 'Δ',
    'R_scf': 'Δ',
}

# Plot order (anomR, R, ubRMSE per variable), limited to what diff_period holds
metrics = [
    m
    for m in (
        'anomR_sfc', 'R_sfc', 'ubRMSE_sfc',
        'anomR_rz', 'R_rz', 'ubRMSE_rz',
        'anomR_scf', 'R_scf', 'ubRMSE_scf',
    )
    if m in diff_period.data_vars
]

def to_map_array(vals, lon2d, lat2d):
    """
    Flatten a gridded field into an (N, 3) float64 array of [value, lon, lat].

    Only cells where the value, the longitude and the latitude are all finite
    are kept; N is the number of such cells (possibly 0).
    """
    keep = np.isfinite(vals) & np.isfinite(lon2d) & np.isfinite(lat2d)
    columns = (vals[keep], lon2d[keep], lat2d[keep])
    # astype pins the output dtype to float64 regardless of the input dtypes
    return np.column_stack(columns).astype(np.float64)

# Pooled symmetric colour limits for the ubRMSE differences: take the 2nd and
# 98th percentiles over all periods together, then mirror the larger magnitude
# around zero so every period of a metric shares one scale.
pooled_sym_limits = {}
for m in metrics:
    if not m.startswith('ubRMSE'):
        continue

    vals = diff_period[m].values  # all periods at once
    finite = vals[np.isfinite(vals)]
    if finite.size == 0:
        # No data at all for this metric: placeholder range
        pooled_sym_limits[m] = (-1.0, 1.0)
        continue

    p_lo, p_hi = np.nanpercentile(finite, [2, 98])
    vmax = float(max(abs(p_lo), abs(p_hi)))
    if not np.isfinite(vmax) or vmax == 0:
        # Percentiles collapsed to zero: widen to the full data range, or 1.0
        # if that is zero as well
        vmax = float(max(abs(np.nanmin(finite)), abs(np.nanmax(finite)))) or 1.0
    pooled_sym_limits[m] = (-vmax, +vmax)

for per in period_labels:
    print(f"\n=== Period: {per} ===")
    for m in metrics:
        # One 2-D DA − OL field for this period, flattened to [value, lon, lat]
        arr = to_map_array(diff_period[m].sel(period=per).values, lon2d, lat2d)

        finite_vals = arr[:, 0][np.isfinite(arr[:, 0])]
        if finite_vals.size == 0:
            print(f"{m}: no finite values; skipping.")
            continue

        # Colour limits + colormap
        is_corr = m.startswith(('anomR', 'R'))
        if is_corr:
            # Correlation differences: fixed range; RdBu maps positive values
            # to the blue end of the scale
            cmin, cmax = -0.4, 0.4
            cmap = 'RdBu'
        else:
            # ubRMSE differences: pooled symmetric range; RdBu_r maps negative
            # values to the blue end of the scale
            cmin, cmax = pooled_sym_limits[m]
            cmap = 'RdBu_r'

        # Actual data extremes, reported in the title (not the colour limits)
        vmax = float(np.nanmax(finite_vals))
        vmin = float(np.nanmin(finite_vals))

        fig, ax = plot_region(
            arr,
            region_bounds=REGION_BOUNDS['global'],
            meanflag=True,
            plot_title=(f"DAv8_M36 − OLv8_M36 • {metric_titles[m]}\n"
                        f"{per}  (Max: {vmax:.3g}  Min: {vmin:.3g})"),
            units=metric_units[m],
            cmin=cmin, cmax=cmax,
            cmap=cmap,
        )
        fig.tight_layout()


In [None]:
# --- Build a tidy summary table of periodized metrics: DA, OL, and Δ = DA−OL ---

# Map filenames → short labels used in the table.
# The label is chosen by substring: "DA" is checked before "OL", so a name
# containing both is labelled "DA".
label_for = {}
for fn in FILES:
    if "DA" in fn or "DAv" in fn:
        label_for[fn] = "DA"
    elif "OL" in fn or "OLv" in fn:
        label_for[fn] = "OL"
    else:
        # fallback: file name with every ".nc" occurrence removed
        label_for[fn] = fn.replace(".nc","")

# Metric variables to tabulate, in column order (rebinds the notebook-level `metrics`)
metrics = [
    'anomR_sfc','R_sfc','ubRMSE_sfc',
    'anomR_rz','R_rz','ubRMSE_rz',
    'anomR_scf','R_scf','ubRMSE_scf'
]

# Column labels, keyed by metric variable name
pretty  = {
    'anomR_sfc':'AnomCorr (Surf SM)',
    'R_sfc'    :'RawCorr (Surf SM)',
    'ubRMSE_sfc':'ubRMSE (Surf SM)',

    'anomR_rz' :'AnomCorr (RZ SM)',
    'R_rz'     :'RawCorr (RZ SM)',
    'ubRMSE_rz':'ubRMSE (RZ SM)',

    'anomR_scf':'AnomCorr (SCF)',
    'R_scf'    :'RawCorr (SCF)',
    'ubRMSE_scf':'ubRMSE (SCF)',
}

# Collect rows: one dict per period, keyed by (group, column label) tuples
rows = []
# Period labels are taken from the first dataset in `results`
periods = list(next(iter(results.values())).coords['period'].values)
for per in periods:
    row = {'Period': str(per)}
    # Per-file spatial means and counts of finite cells
    for fn, ds_per in results.items():
        tag = label_for[fn]
        for m in metrics:
            if m in ds_per:
                da2d = ds_per[m].sel(period=per)  # one period's field
                mean_val = float(da2d.mean(skipna=True).values)
                n_valid  = int(np.isfinite(da2d.values).sum())
                row[(tag, pretty[m])] = mean_val
                row[(tag, f"N valid ({pretty[m]})")] = n_valid
    rows.append(row)

# Build the DataFrame; columns are keyed by the (group, label) tuples above
df = pd.DataFrame(rows).set_index("Period")
# Ensure both DA and OL exist before computing deltas
if len(FILES) >= 2:
    # Add Δ = DA − OL columns for each metric (means only)
    for m in metrics:
        col_da = ("DA",  pretty[m])
        col_ol = ("OL",  pretty[m])
        if col_da in df.columns and col_ol in df.columns:
            df[("Δ (DA−OL)", pretty[m])] = df[col_da] - df[col_ol]

# Sort columns: DA | OL | Δ; columns of any other group are dropped by the reindex
new_cols = []
for group in ["DA", "OL", "Δ (DA−OL)"]:
    sub = [c for c in df.columns if c[0] == group and "N valid" not in c[1]]
    sub_valid = [c for c in df.columns if c[0] == group and "N valid"     in c[1]]
    # keep metric order: each metric's mean column, then its count column
    ordered = []
    for m in metrics:
        for c in sub:
            if c[1] == pretty[m]:
                ordered.append(c)
        for c in sub_valid:
            # substring match against the "N valid (<label>)" column name
            if pretty[m] in c[1]:
                ordered.append(c)
    new_cols.extend(ordered)
df = df.reindex(columns=new_cols)

# Formatting: 3 decimals for metric means, integers for counts
fmt = {c: ("{:d}" if "N valid" in c[1] else "{:.3f}") for c in df.columns}

# Display as a styled table when possible; fall back to plain text otherwise
# (e.g. when IPython is unavailable or the styled rendering fails).
try:
    from IPython.display import display
    display(df.style.format(fmt).set_caption(
        # The table holds spatial means (see the row-building step), not medians.
        "ERA5-Land vs Model: Periodized Metrics (AnomCorr, RawCorr, ubRMSE; spatial means and valid counts)"
    ))
except Exception:
    print(df.to_string())

# Also save to CSV for records
out_csv = "periodized_metrics_summary.csv"
df.to_csv(out_csv)
print(f"Saved summary table → {out_csv}")


In [None]:
# --- OL vs DA bar charts for Surface SM, Root-zone SM, and SCF ---
# Uses `results` (periodized datasets) and `FILES` from earlier cells.

import numpy as np
import matplotlib.pyplot as plt

# Short experiment label per file, chosen by substring ("DA" is checked first);
# anything else falls back to the file name with ".nc" removed.
label_for = {
    fn: ("DA" if "DA" in fn else "OL" if "OL" in fn else fn.replace(".nc", ""))
    for fn in FILES
}

# Period labels, taken from the first dataset in `results`
period_labels = list(
    next(iter(results.values())).coords["period"].values
)

# Bar colours
COL_OL = "#2878B5"   # blue
COL_DA = "#F28E2B"   # orange

def mean_se_2d(da2d):
    """
    Mean, standard error and count over the finite entries of `da2d.values`.

    Returns (mean, se, n). With no finite entries the result is (nan, nan, 0);
    with exactly one, the standard error is nan. The standard error is the
    sample standard deviation (ddof=1) divided by sqrt(n).
    """
    grid = np.asarray(da2d.values)
    finite = grid[np.isfinite(grid)]
    n = int(finite.size)
    if n == 0:
        return np.nan, np.nan, 0

    mean = float(np.nanmean(finite))
    if n > 1:
        se = float(np.nanstd(finite, ddof=1) / np.sqrt(n))
    else:
        se = np.nan
    return mean, se, n

def pooled_ylim(means_errs, pad=0.18, hard=(None, None)):
    """
    Compute one (low, high) y-range covering every mean ± err, plus padding.

    Parameters
    ----------
    means_errs : iterable of (mean, err)
        Pairs with a non-finite mean are ignored; a non-finite err counts as 0.
    pad : float
        Fraction of the (clamped) span added below and above.
    hard : (min, max)
        Optional clamps applied before padding; None disables that side.

    Returns
    -------
    (low, high) with low < high. Falls back to the unit range 0..1 (then
    padded) when the clamped range is empty, degenerate or non-finite, which
    includes the case where all data lie outside the `hard` bounds. Returns
    exactly (0.0, 1.0) when there is no usable input.
    """
    lows, highs = [], []
    for mean, err in means_errs:
        if not np.isfinite(mean):
            continue
        e = err if np.isfinite(err) else 0.0
        lows.append(mean - e)
        highs.append(mean + e)
    if not lows:
        return (0.0, 1.0)

    lo = np.nanmin(lows)
    hi = np.nanmax(highs)
    if hard[0] is not None:
        lo = max(lo, hard[0])
    if hard[1] is not None:
        hi = min(hi, hard[1])

    # Clamping can push lo above hi (data entirely outside the hard bounds);
    # treat that like a degenerate range rather than returning inverted limits.
    if not np.isfinite(lo) or not np.isfinite(hi) or lo >= hi:
        lo, hi = 0.0, 1.0

    pad_abs = pad * (hi - lo)
    return (lo - pad_abs, hi + pad_abs)

def collect_stats(ds_per_all, metric_key):
    """
    Gather per-period spatial statistics of one metric for every dataset.

    Returns a dict keyed by str(period); each value maps the dataset's label
    (from `label_for`) to the (mean, se, n) tuple from mean_se_2d(). Datasets
    that lack `metric_key` get (nan, nan, 0) for every period.
    """
    missing = (np.nan, np.nan, 0)
    stats = {str(per): {} for per in period_labels}
    for fn, ds_all in ds_per_all.items():
        tag = label_for[fn]
        has_metric = metric_key in ds_all.data_vars
        for per in period_labels:
            if has_metric:
                entry = mean_se_2d(ds_all[metric_key].sel(period=per))
            else:
                entry = missing
            stats[str(per)][tag] = entry
    return stats

def plot_group(group_name, metrics, ylabels, hard_corr_bounds=None):
    """
    Draw a grid of OL-vs-DA bar charts: one row per metric, one column per period.

    Parameters
    ----------
    group_name : str
        Used in the figure title.
    metrics : list of str
        Metric variable names, one per row.
    ylabels : list of str
        Y-axis labels, one per row (same length as `metrics`).
    hard_corr_bounds : (min, max) or None
        Clamp applied to the y-range of rows whose metric name starts with
        "R_" or "anomR_"; None means (0.0, 1.0).

    Each bar is the spatial mean with a ± standard-error whisker. All panels
    in a row share one y-range. A label ("OL"/"DA") with no data is drawn as
    a missing bar with n = 0.
    """
    nrows = len(metrics)
    ncols = len(period_labels)
    # squeeze=False keeps `axes` 2-D even for a single row or a single period
    fig, axes = plt.subplots(nrows, ncols, figsize=(4.4*ncols, 4.1*nrows),
                             constrained_layout=True, squeeze=False)

    fig.suptitle(f"{group_name} statistics (spatial means ± SE) for OL vs DA",
                 fontsize=16, weight="bold", y=1.02)

    tags = ("OL", "DA")
    no_data = (np.nan, np.nan, 0)

    # Statistics are gathered once per metric and reused for limits and bars
    stats_by_metric = [collect_stats(results, mkey) for mkey in metrics]

    # Common y-limits per row, pooled across all periods and both labels
    common_lims = []
    for mkey, S in zip(metrics, stats_by_metric):
        all_me = [
            S[str(per)].get(tag, no_data)[:2]
            for per in period_labels
            for tag in tags
        ]
        # clamp correlations between [0,1] or the given bounds
        if mkey.startswith(("R_", "anomR_")):
            hard = hard_corr_bounds if hard_corr_bounds is not None else (0.0, 1.0)
        else:
            hard = (None, None)
        common_lims.append(pooled_ylim(all_me, pad=0.22, hard=hard))

    x = np.arange(len(tags))
    for r, (ylabel, S) in enumerate(zip(ylabels, stats_by_metric)):
        for c, per in enumerate(period_labels):
            ax = axes[r, c]
            ol_mean, ol_se, ol_n = S[str(per)].get("OL", no_data)
            da_mean, da_se, da_n = S[str(per)].get("DA", no_data)

            ax.bar(x, [ol_mean, da_mean], width=0.72,
                   color=[COL_OL, COL_DA], edgecolor="black", linewidth=0.8,
                   yerr=[ol_se, da_se], capsize=6)

            ax.set_xticks(x)
            ax.set_xticklabels(["LS_OL", "LS_DA"])
            ax.set_ylim(*common_lims[r])
            ax.grid(axis="y", alpha=0.3, linestyle="--")

            # Titles: show period + n = min(OL, DA)
            ax.set_title(f"{per} (n = {min(ol_n, da_n)})", fontsize=11)

            # y-label on first column only
            if c == 0:
                ax.set_ylabel(ylabel)

    # Optional save
    # fig.savefig(f"bars_{group_name.replace(' ','_').lower()}.png", dpi=200, bbox_inches="tight")
    plt.show()

# ---- Make the three figures: Surface SM, Root-zone SM, Snow Cover Fraction ----
# Each entry: (figure title, metric-name suffix, unit shown on the ubRMSE axis)
for group_name, suffix, ubrmse_unit in (
    ("Surface SM", "sfc", "m³ m⁻³"),
    ("Root-zone SM", "rz", "m³ m⁻³"),
    ("SCF", "scf", "1"),
):
    plot_group(
        group_name=group_name,
        metrics=[f"R_{suffix}", f"anomR_{suffix}", f"ubRMSE_{suffix}"],
        ylabels=["R Mean", "anomR Mean", f"ubRMSE Mean ({ubrmse_unit})"],
        hard_corr_bounds=(0.0, 1.0),  # keep correlation axes within 0..1
    )
