
# 🌍 Demo 5: Ground Truth Challenge


## 🧠 Learning Objectives
- Use **weatherbenchX** (metrics) + **xarray** to verify forecasts.
- Load your **forecast** (*or any forecast in NetCDF/Zarr*) and compare it to **ERA5** ground truth.
- Compute **RMSE** and **MAE** (with ERA5 climatology).
- Focus evaluation on specific regions: **Global, Ethiopia, Nigeria, Kenya, Bangladesh, Chile**.
- Visualize skill vs **lead time**.

> **Note:** If your local truth (e.g., BMD/IMD Zarr) is *forecast/prediction*, don't use it as truth in this notebook. Here, **ERA5** is the ground truth just like Demo 4.



## 📦 Environment Requirements
If you hit backend errors like xarray not finding `netcdf4`/`h5netcdf`, install the deps:
```bash
# pip install xarray netCDF4 h5netcdf zarr fsspec gcsfs ipywidgets matplotlib numpy pandas
# pip install weatherbenchX xesmf
```
For public GCS ERA5, we use anonymous access. No credentials needed.


In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import ipywidgets as widgets
import numpy as np
import xarray as xr
from IPython.display import display, Markdown
plt.rcParams.update({"figure.dpi": 140})
xr.set_options(keep_attrs=True)
from weatherbenchX.metrics import deterministic
from weatherbenchX.metrics import base as metrics_base

plt.rcParams.update({"figure.dpi": 120})

# Paths (edit if needed)


# Name of the precipitation variable used by the cells below.
VAR = "total_precipitation_24hr"

# Evaluation regions as latitude/longitude slices.
# Latitude bounds are written north-to-south (start > stop) and longitudes use
# the 0..360 convention (e.g. Chile is 284..294).
REGIONS = {
    "Global":     {"latitude": slice( 90, -90), "longitude": slice(  0, 360)},
    "Ethiopia":   {"latitude": slice(14.9,  3.4), "longitude": slice(33.0, 48.0)},
    "Nigeria":    {"latitude": slice(14.7,  4.0), "longitude": slice( 2.7, 14.7)},
    "Kenya":      {"latitude": slice( 5.0, -4.7), "longitude": slice(33.9, 41.9)},
    "Bangladesh": {"latitude": slice(26.7, 20.7), "longitude": slice(88.0, 92.7)},
    "Chile":      {"latitude": slice(-17.5,-56.0), "longitude": slice(284.0, 294.0)},
}


# Default dataset locations shown in the path widgets. All but the IMD mask are
# gs:// Zarr stores; the mask is a NetCDF file looked up by relative path.
IMD_PRECIP_DEFAULT   = "gs://aim4scale_training_25/ground_truth/IMD_rainfall_0p25.zarr"
IMD_MASK_DEFAULT     = "IMD_mask.nc"  # your ocean/coverage mask file
BMD_DAILY_DEFAULT    = "gs://aim4scale_training_25/ground_truth/BMD_daily_combined_0p25.zarr"
IMERG_PRECIP_DEFAULT = "gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr"
ERA5_PRECIP_DEFAULT  = "gs://aim4scale_training_25/ground_truth/era5_24hr.zarr"
ERA5_TEMP_DEFAULT    = "gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr"

# Widgets: one editable text box per dataset path, pre-filled with the defaults above.
p_imd   = widgets.Text(value=IMD_PRECIP_DEFAULT,   description="IMD precip:",   layout=widgets.Layout(width="95%"))
p_imdmsk= widgets.Text(value=IMD_MASK_DEFAULT,     description="IMD mask:",     layout=widgets.Layout(width="95%"))
p_bmd   = widgets.Text(value=BMD_DAILY_DEFAULT,    description="BMD daily:",    layout=widgets.Layout(width="95%"))
p_imerg = widgets.Text(value=IMERG_PRECIP_DEFAULT, description="IMERG precip:", layout=widgets.Layout(width="95%"))
p_e5_p  = widgets.Text(value=ERA5_PRECIP_DEFAULT,  description="ERA5 precip:",  layout=widgets.Layout(width="95%"))
p_e5_t  = widgets.Text(value=ERA5_TEMP_DEFAULT,    description="ERA5 temp:",    layout=widgets.Layout(width="95%"))

# Button that triggers loading, and the output area that receives the load report.
btn_load = widgets.Button(description="Load datasets", button_style="primary")
load_out = widgets.Output()

# Render the whole form in the cell output.
display(p_imd, p_imdmsk, p_bmd, p_imerg, p_e5_p, p_e5_t, btn_load, load_out)

def _open_any(path: str):
    """Open the dataset at *path*, picking the reader from the suffix.

    Paths ending in ``.zarr`` are opened with ``xr.open_zarr``; every other
    path is opened with ``xr.open_dataset``.

    Raises:
        ValueError: if *path* is empty (or otherwise falsy).
    """
    if path:
        reader = xr.open_zarr if path.endswith(".zarr") else xr.open_dataset
        return reader(path)
    raise ValueError("Empty path.")

# Module-level dataset handles; they stay None until _load_all assigns them.
ds_imd = ds_imd_mask = ds_bmd = ds_imerg = ds_e5_p = ds_e5_t = None

def _summ(ds):
    """Return a one-line summary of *ds*: variables, sizes and time span.

    Args:
        ds: Dataset-like object exposing ``data_vars``, ``coords`` and ``sizes``.

    Returns:
        A string ``"vars: [...] | sizes: {...} | time: <first> … <last>"``.
        Only the first six variable names are listed. The time span is taken
        from the ``time`` coordinate, or ``valid_time`` when ``time`` is
        absent; it is ``"—"`` when neither exists or the coordinate is empty.
    """
    var_names = list(ds.data_vars)[:6]
    time_name = next((n for n in ("time", "valid_time") if n in ds.coords), None)
    if time_name is None or ds[time_name].size == 0:
        # No usable time axis: min()/max() on an empty coordinate would raise.
        span = "—"
    else:
        first = pd.to_datetime(str(ds[time_name].min().values)).date()
        last = pd.to_datetime(str(ds[time_name].max().values)).date()
        span = f"{first} … {last}"
    return f"vars: {var_names} | sizes: {dict(ds.sizes)} | time: {span}"

def _load_all(_):
    """Button callback: open every dataset named in the path widgets.

    Each widget value is stripped and passed to ``_open_any``; the results are
    stored in the module-level ``ds_*`` handles. On success a Markdown summary
    (one ``_summ`` line per dataset) is written to ``load_out``. Any exception
    is caught and shown as a Markdown error instead; handles assigned before
    the failure keep their new values.

    Args:
        _: the clicked button (unused).
    """
    global ds_imd, ds_imd_mask, ds_bmd, ds_imerg, ds_e5_p, ds_e5_t
    with load_out:
        load_out.clear_output()
        try:
            ds_imd      = _open_any(p_imd.value.strip())
            ds_imd_mask = _open_any(p_imdmsk.value.strip())  # NetCDF with NaN over ocean
            ds_bmd      = _open_any(p_bmd.value.strip())
            ds_imerg    = _open_any(p_imerg.value.strip())
            ds_e5_p     = _open_any(p_e5_p.value.strip())
            ds_e5_t     = _open_any(p_e5_t.value.strip())
            display(Markdown("### ✅ Loaded"))
            loaded = [("IMD", ds_imd), ("IMD mask", ds_imd_mask), ("BMD", ds_bmd),
                      ("IMERG", ds_imerg), ("ERA5 precip", ds_e5_p), ("ERA5 temp", ds_e5_t)]
            for name, ds in loaded:
                display(Markdown(f"- **{name}** → `{_summ(ds)}`"))
        except Exception as e:
            # UI boundary: report the failure in the output widget rather than raising.
            display(Markdown(f"❌ Load error: `{e}`"))

# Register _load_all as the click handler of the "Load datasets" button.
btn_load.on_click(_load_all)


Text(value='gs://aim4scale_training_25/ground_truth/IMD_rainfall_0p25.zarr', description='IMD precip:', layout‚Ä¶

Text(value='IMD_mask.nc', description='IMD mask:', layout=Layout(width='95%'))

Text(value='gs://aim4scale_training_25/ground_truth/BMD_daily_combined_0p25.zarr', description='BMD daily:', l‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr', description='IMERG precip:', l‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/era5_24hr.zarr', description='ERA5 precip:', layout=Layout‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr', description='ERA5 temp:', lay‚Ä¶

Button(button_style='primary', description='Load datasets', style=ButtonStyle())

Output()

In [2]:
def _open_any(path: str):
    """Open the dataset at *path*, picking the reader from the suffix.

    This redefinition keeps the empty-path guard of the earlier ``_open_any``
    so that re-running this cell does not silently remove that check.

    Args:
        path: location of the dataset; ``.zarr`` paths go to ``xr.open_zarr``,
            everything else to ``xr.open_dataset``.

    Raises:
        ValueError: if *path* is empty.
    """
    if not path:
        raise ValueError("Empty path.")
    return xr.open_zarr(path) if path.endswith(".zarr") else xr.open_dataset(path)

def _to_0360(obj):
    """Return *obj* with longitudes wrapped to 0..360 and sorted ascending.

    Objects without a ``longitude`` coordinate are returned unchanged. The
    wrap is only applied when the smallest longitude is negative; the sort by
    longitude is applied either way.
    """
    if "longitude" not in obj.coords:
        return obj
    lon_coord = obj["longitude"]
    try:
        has_negative = float(lon_coord.min()) < 0
        if has_negative:
            obj = obj.assign_coords(longitude=(lon_coord % 360))
    except Exception:
        # Best effort: if the wrap fails, keep the coordinates as they are.
        pass
    return obj.sortby("longitude")

def _ensure_lat_ascending(obj):
    """Return *obj* sorted by latitude when its latitude axis is descending.

    Objects without a ``latitude`` coordinate, or whose first latitude is not
    greater than the last, are returned unchanged.
    """
    if "latitude" not in obj.coords:
        return obj
    lat_values = obj["latitude"].values
    descending = len(lat_values) > 1 and lat_values[0] > lat_values[-1]
    return obj.sortby("latitude") if descending else obj

def _region_to_0360(region):
    """Return *region* with its longitude bounds in the 0..360 convention.

    Args:
        region: mapping with ``"latitude"`` and ``"longitude"`` slices.

    Returns:
        A new dict with the latitude slice copied as-is and the longitude
        slice converted to 0..360 (as floats). A request spanning 360 degrees
        or more becomes ``slice(0.0, 360.0)``, and an upper bound that is an
        exact multiple of 360 is kept as 360 instead of wrapping to 0, so
        e.g. the Global region (0..360) no longer collapses to ``slice(0, 0)``.
    """
    lat = region["latitude"]
    lon_start = float(region["longitude"].start)
    lon_stop = float(region["longitude"].stop)

    if abs(lon_stop - lon_start) >= 360.0:
        # Full circle: modulo would map both bounds onto the same meridian.
        lon = slice(0.0, 360.0)
    else:
        lo = lon_start % 360.0
        hi = lon_stop % 360.0
        # The larger raw bound must not wrap to 0 (e.g. 350..360 -> 350..360).
        if hi == 0.0 and lon_stop > lon_start:
            hi = 360.0
        elif lo == 0.0 and lon_start > lon_stop:
            lo = 360.0
        lon = slice(lo, hi)

    return {"latitude": slice(lat.start, lat.stop), "longitude": lon}

def _apply_region_safe(ds, region):
    """Select *region* from *ds* regardless of longitude convention or latitude order.

    The data is first normalized to 0..360 longitudes with ascending latitude,
    then sliced between the smaller and larger bound on each axis.

    Args:
        ds: object with ``latitude``/``longitude`` coordinates.
        region: mapping with ``"latitude"`` and ``"longitude"`` slices.

    Returns:
        The normalized subset. When the requested longitude range spans 360
        degrees or more, longitude is not sliced at all: taking the bounds
        modulo 360 would otherwise reduce e.g. 0..360 to the single meridian 0.
    """
    out = _ensure_lat_ascending(_to_0360(ds))

    lat_a, lat_b = region["latitude"].start, region["latitude"].stop
    out = out.sel(latitude=slice(min(lat_a, lat_b), max(lat_a, lat_b)))

    lon_a = float(region["longitude"].start)
    lon_b = float(region["longitude"].stop)
    if abs(lon_b - lon_a) >= 360.0:
        # Full-globe request: keep every longitude.
        return out

    r0360 = _region_to_0360(region)
    lo = min(r0360["longitude"].start, r0360["longitude"].stop)
    hi = max(r0360["longitude"].start, r0360["longitude"].stop)
    return out.sel(longitude=slice(lo, hi))

def _normalize_precip_units(da):
    """Scale precipitation given in metres to millimetres and tag units as "mm".

    The unit is read from ``attrs["units"]`` (or ``attrs["unit"]``),
    case-insensitively. Metre spellings are multiplied by 1000; every other
    unit string (kg m-2, mm, missing, unrecognized) is taken to be mm already
    (1 kg m-2 ≈ 1 mm). In that case the input object itself is returned with
    its ``units`` attribute overwritten.
    """
    raw_units = da.attrs.get("units") or da.attrs.get("unit") or ""
    metre_spellings = ("m", "meter", "metre", "m of water equivalent")
    result = da * 1000.0 if raw_units.lower() in metre_spellings else da
    result.attrs["units"] = "mm"
    return result

def _coerce_valid_time(da):
    """Return *da* with its time axis named ``valid_time`` (the name used for WBX).

    An existing ``valid_time`` dimension is kept; a ``time`` dimension is
    renamed. Otherwise a CF decode is attempted as a last resort.

    Raises:
        ValueError: if no ``time``/``valid_time`` dimension can be found.
    """
    dim_names = da.dims
    if "valid_time" in dim_names:
        return da
    if "time" in dim_names:
        return da.rename({"time": "valid_time"})
    try:
        # Last resort: a CF decode may expose a time dimension.
        decoded = xr.decode_cf(da.to_dataset(name="_tmp")).to_array("_tmp")
        if "time" in decoded.dims:
            return decoded.rename({"time": "valid_time"})
    except Exception:
        pass
    raise ValueError("No recognizable time/valid_time dimension in array.")

def _time_intersect(a: xr.DataArray, b: xr.DataArray):
    """Restrict *a* and *b* to the overlap of their ``valid_time`` ranges.

    The overlap runs from the later of the two minima to the earlier of the
    two maxima. When the ranges do not overlap, both arrays are returned with
    an empty ``valid_time`` axis.
    """
    start = max(np.datetime64(a.valid_time.min().values), np.datetime64(b.valid_time.min().values))
    end = min(np.datetime64(a.valid_time.max().values), np.datetime64(b.valid_time.max().values))
    if end < start:
        nothing = slice(0, 0)
        return a.isel(valid_time=nothing), b.isel(valid_time=nothing)
    window = slice(start, end)
    return a.sel(valid_time=window), b.sel(valid_time=window)

def _prep_for_wbx(da: xr.DataArray):
    """Shape *da* for WBX: a ``valid_time`` axis plus a length-1 ``lead_time`` of 0 h.

    A ``time`` dimension is renamed to ``valid_time`` first.

    Raises:
        ValueError: if the array has neither ``time`` nor ``valid_time``.
    """
    renamed = da.rename({"time": "valid_time"}) if "time" in da.dims else da
    if "valid_time" in renamed.dims:
        return renamed.expand_dims({"lead_time": [np.timedelta64(0, "h")]})
    raise ValueError("Need 'time' or 'valid_time' dimension.")

def _align_to_reference_grid(src: xr.DataArray, ref: xr.DataArray):
    """Put *src* on the latitude/longitude labels of *ref*.

    Both arrays are normalized (0..360 longitude, ascending latitude). Three
    strategies are tried in order: return *src* if the grids are already
    identical, nearest-label reindex with a tolerance, and finally
    nearest-neighbour interpolation. *src* is returned untouched when it has
    no latitude/longitude coordinates.
    """
    if not ("latitude" in src.coords and "longitude" in src.coords):
        return src
    src_n = _ensure_lat_ascending(_to_0360(src))
    ref_n = _ensure_lat_ascending(_to_0360(ref))

    # 1) identical grids: nothing to do
    try:
        same_lat = np.array_equal(src_n.latitude.values, ref_n.latitude.values)
        if same_lat and np.array_equal(src_n.longitude.values, ref_n.longitude.values):
            return src_n
    except Exception:
        pass

    # 2) nearest-label reindex
    # NOTE(review): xarray documents `tolerance` as a scalar or list-like; a dict
    # may raise here and be swallowed, so this step may never succeed. Confirm.
    try:
        per_axis_tol = {"latitude":0.125, "longitude":0.125}
        return src_n.reindex_like(ref_n, method="nearest", tolerance=per_axis_tol)
    except Exception:
        pass

    # 3) fallback: nearest-neighbour interpolation onto the reference coordinates
    return src_n.interp(latitude=ref_n["latitude"], longitude=ref_n["longitude"], method="nearest")

def _crop_to_imd_coverage(imd_valid: xr.DataArray):
    """Keep only latitude rows and longitude columns where IMD has data.

    A row/column is kept when at least one of its cells has a finite value at
    any ``valid_time``. Boolean masks are converted to integer positions so
    that no boolean dask indexer is passed to ``isel``.

    Args:
        imd_valid: array with ``valid_time``, ``latitude`` and ``longitude``.

    Returns:
        ``(cropped, latitudes, longitudes)``: the cropped array and the kept
        coordinate values. When nothing is finite, the array comes back with
        an empty ``valid_time`` axis and both coordinates are empty.
    """
    # np.isfinite dispatches to xarray objects directly; the xr.ufuncs
    # namespace is not available in every xarray release.
    covered = np.isfinite(imd_valid).any("valid_time")
    lat_idx = np.flatnonzero(covered.any("longitude").compute().values)
    lon_idx = np.flatnonzero(covered.any("latitude").compute().values)
    if lat_idx.size == 0 or lon_idx.size == 0:
        return (imd_valid.isel(valid_time=slice(0, 0)),
                imd_valid.latitude.isel(latitude=[]),
                imd_valid.longitude.isel(longitude=[]))
    imd_crop = imd_valid.isel(latitude=lat_idx, longitude=lon_idx)
    return (imd_crop,
            imd_valid.latitude.isel(latitude=lat_idx),
            imd_valid.longitude.isel(longitude=lon_idx))

def _series_from_stats(stats, var, metric_name, spatial_dims=("latitude","longitude")):
    """Reduce a per-cell error statistic to a spatial-mean series.

    Args:
        stats: mapping from statistic name (``"SquaredError"`` /
            ``"AbsoluteError"``) to a mapping of variable name -> array.
        var: variable name to pick from the statistic.
        metric_name: ``"RMSE"`` selects the squared-error path; any other
            value is treated as MAE.
        spatial_dims: dimensions to average over (those missing are ignored).

    Returns:
        ``(series, ylabel, xdim)``: the series with NaN steps dropped, an axis
        label, and the name of the remaining time dimension.

    RMSE is the square root of the spatial MEAN squared error. Taking the
    square root before averaging gives mean(|error|), which is MAE.
    """
    is_rmse = metric_name == "RMSE"
    if is_rmse:
        field = stats["SquaredError"][var]
        ylabel = "RMSE (mm/day)"
    else:  # MAE
        field = stats["AbsoluteError"][var]
        ylabel = "MAE (mm/day)"

    reduce_dims = [d for d in spatial_dims if d in field.dims]
    series = field.mean(dim=reduce_dims, skipna=True)
    if is_rmse:
        series = np.sqrt(series)
    series = series.squeeze()

    # drop the dummy lead_time
    if "lead_time" in series.dims and series.sizes.get("lead_time", 1) == 1:
        series = series.isel(lead_time=0)
    # x-dim
    xdim = "valid_time" if "valid_time" in series.dims else "time"
    series = series.dropna(dim=xdim)
    return series, ylabel, xdim

def _summ(ds):
    """Return ``"vars: a, b, ... | sizes: {...}"`` for *ds* (first six variables only)."""
    shown = list(ds.data_vars)[:6]
    return "vars: " + ", ".join(shown) + f" | sizes: {dict(ds.sizes)}"


In [3]:

# Compatibility shims: bind the names used by the cells below (ds_e5_p, ds_e5_t,
# ds_imerg, ds_imd, ds_bmd, VAR) from alternative names, but only when the
# target name does not exist yet as a global.
# NOTE(review): the first cell assigns ds_imd/ds_bmd/ds_imerg/ds_e5_p/ds_e5_t = None
# and VAR at module level, so after running it these `not in globals()` tests
# are all False and the shims do nothing. Confirm that is intended.

# ERA5 precip
if "ds_e5_p" not in globals():
    if "ds_era5_tp" in globals() and ds_era5_tp is not None:
        ds_e5_p = ds_era5_tp
    elif "ds_era5" in globals() and ds_era5 is not None and "total_precipitation_24hr" in ds_era5:
        ds_e5_p = ds_era5
    elif "ds_e5" in globals() and ds_e5 is not None and "total_precipitation_24hr" in ds_e5:
        ds_e5_p = ds_e5

# ERA5 temperature (daily tavg/tmax)
if "ds_e5_t" not in globals():
    if "ds_era5_t2m" in globals() and ds_era5_t2m is not None:
        ds_e5_t = ds_era5_t2m
    elif "ds_era5" in globals() and ds_era5 is not None and ("tavg" in ds_era5 or "tmax" in ds_era5):
        ds_e5_t = ds_era5

# IMERG / IMD / BMD: leave as-is, but create shims if your loader used other names
if "ds_imerg" not in globals() and "ds_imerG" in globals():
    ds_imerg = ds_imerG
if "ds_imd" not in globals() and "ds_IMD" in globals():
    ds_imd = ds_IMD
if "ds_bmd" not in globals() and "ds_BMD" in globals():
    ds_bmd = ds_BMD

# define the precip var name once (used in several cells)
if "VAR" not in globals():
    VAR = "total_precipitation_24hr"
def _apply_region_safe(ds: xr.Dataset, region: dict) -> xr.Dataset:
    """Select *region* from *ds* even when longitude conventions differ.

    The data is normalized to 0..360 longitudes with ascending latitude and
    then sliced. A longitude request whose bounds coincide modulo 360
    (e.g. 0..360 or -180..180) is treated as the full globe and longitude is
    left unsliced. Requests that wrap across 0 degrees are not handled; the
    bounds are simply ordered low-to-high.
    """
    normalized = _ensure_lat_ascending(_to_0360(ds))

    # latitude: order the bounds so the slice matches the ascending axis
    lat_first = float(region["latitude"].start)
    lat_second = float(region["latitude"].stop)
    south, north = min(lat_first, lat_second), max(lat_first, lat_second)
    if "latitude" in normalized.coords:
        normalized = normalized.sel(latitude=slice(south, north))

    if "longitude" in normalized.coords:
        # requested bounds expressed in 0..360
        west = float(region["longitude"].start) % 360.0
        east = float(region["longitude"].stop) % 360.0
        full_circle = np.isclose((east - west) % 360.0, 0.0)
        if not full_circle:
            lon_low, lon_high = min(west, east), max(west, east)
            normalized = normalized.sel(longitude=slice(lon_low, lon_high))
    return normalized

def _coerce_valid_time(da: xr.DataArray) -> xr.DataArray:
    """Return *da* with its time axis named ``valid_time`` where possible.

    An existing ``valid_time`` is kept and ``time`` is renamed. As a last
    resort a CF decode is tried. If nothing works the input is returned
    unchanged rather than raising, leaving the failure to the caller.
    """
    dim_names = da.dims
    if "valid_time" in dim_names:
        return da
    if "time" in dim_names:
        return da.rename({"time": "valid_time"})
    try:
        decoded = xr.decode_cf(da.to_dataset(name="_tmp")).to_array("_tmp")
        if "time" in decoded.dims:
            return decoded.rename({"time":"valid_time"})
    except Exception:
        # the decode is optional; fall through to the unchanged input
        pass
    return da

if "_prep_for_wbx_validtime" not in globals():
    def _prep_for_wbx_validtime(da: xr.DataArray) -> xr.DataArray:
        """Order *da* as (valid_time, latitude, longitude) and add ``lead_time`` = 0 h.

        Spatial dimensions that are absent are skipped in the transpose.
        """
        timed = _coerce_valid_time(da)
        spatial = [name for name in ("latitude", "longitude") if name in timed.dims]
        ordered = timed.transpose("valid_time", *spatial)
        return ordered.expand_dims({"lead_time": [np.timedelta64(0, "h")]})

if "_prep_for_wbx_initlead" not in globals():
    def _prep_for_wbx_initlead(da: xr.DataArray) -> xr.DataArray:
        """Order *da* as (init_time, latitude, longitude) and add ``lead_time`` = 0 h.

        The time axis is first coerced to ``valid_time`` and then renamed to
        ``init_time``. Absent spatial dimensions are skipped in the transpose.
        """
        by_init = _coerce_valid_time(da).rename({"valid_time":"init_time"})
        spatial = [name for name in ("latitude", "longitude") if name in by_init.dims]
        ordered = by_init.transpose("init_time", *spatial)
        return ordered.expand_dims({"lead_time": [np.timedelta64(0, "h")]})

# Optional: if your cells call _wbx_series(pred_ds, truth_ds, "RMSE"/"MAE") with Datasets,
# the wrapper for that is defined right after this mask shim.

# Make sure ``ds_imd_mask`` exists. If no earlier cell defined the name, try to
# open the mask file from the working directory; on any failure bind it to None.
try:
    ds_imd_mask
except NameError:
    try:
        IMD_MASK_PATH = "IMD_mask.nc"  # change if your file lives elsewhere
        ds_imd_mask = xr.open_dataset(IMD_MASK_PATH)
        display(Markdown("✅ **Loaded IMD mask** (oceans/out-of-coverage are NaN)"))
    except Exception as e:
        ds_imd_mask = None
        display(Markdown(f"⚠️ Could not load IMD mask: `{e}` — maps will fall back to truth-coverage mask."))
if "_wbx_series" not in globals():
    from weatherbenchX.metrics import deterministic
    from weatherbenchX.metrics import base as metrics_base

    def _wbx_series(pred: xr.DataArray, truth: xr.DataArray, metric: str):
        """Compute a spatial-mean RMSE or MAE series with WeatherBenchX.

        Args:
            pred: prediction array.
            truth: reference array on the same grid and times.
            metric: ``"RMSE"`` (case-insensitive); anything else means MAE.

        Returns:
            ``(series, ylabel)`` where the series is indexed by ``date``
            (renamed from ``valid_time``) and ``ylabel`` is "RMSE" or "MAE".

        RMSE is sqrt(mean(squared error)) over the spatial dimensions. Taking
        the root before the mean gives mean(|error|), which is MAE.
        """
        is_rmse = metric.upper() == "RMSE"
        ds_p = xr.Dataset({"var": _prep_for_wbx_validtime(pred.astype("float32"))})
        ds_t = xr.Dataset({"var": _prep_for_wbx_validtime(truth.astype("float32"))})
        if is_rmse:
            stats = metrics_base.compute_unique_statistics_for_all_metrics({"rmse": deterministic.RMSE()}, ds_p, ds_t)
            err = list(stats["SquaredError"].values())[0]
            ylabel = "RMSE"
        else:
            stats = metrics_base.compute_unique_statistics_for_all_metrics({"mae": deterministic.MAE()}, ds_p, ds_t)
            err = list(stats["AbsoluteError"].values())[0]
            ylabel = "MAE"
        series = err.mean([d for d in ("latitude","longitude") if d in err.dims], skipna=True)
        if is_rmse:
            series = series ** 0.5
        # drop the dummy lead_time axis added by the packer
        if "lead_time" in series.dims and series.sizes.get("lead_time",1) == 1:
            series = series.isel(lead_time=0)
        return series.rename({"valid_time":"date"}).squeeze(), ylabel
# Cell 3 — Helpers (units, grid, time, masking, WBX packers)

# Precipitation variable name; _apply_imd_mask looks it up in the mask dataset.
VAR_P = "total_precipitation_24hr"

def _to_0360(obj):
    """Wrap negative longitudes of *obj* into 0..360 and sort by longitude.

    Objects without a ``longitude`` coordinate pass through unchanged. If the
    wrap fails the original coordinates are kept, but the sort still runs.
    """
    if "longitude" in obj.coords:
        lons = obj["longitude"]
        try:
            wrapped = obj.assign_coords(longitude=(lons % 360)) if float(lons.min()) < 0 else obj
        except Exception:
            wrapped = obj
        obj = wrapped.sortby("longitude")
    return obj

def _ensure_lat_asc(obj):
    """Sort *obj* by latitude if its latitude axis runs from high to low."""
    if "latitude" not in obj.coords:
        return obj
    lats = obj["latitude"].values
    if len(lats) <= 1 or not lats[0] > lats[-1]:
        return obj
    return obj.sortby("latitude")

def _normalize_precip(da):
    """Convert metre-valued precipitation to millimetres and tag units "mm/day".

    The unit comes from ``attrs["units"]`` or ``attrs["unit"]``
    (case-insensitive). Non-metre inputs are returned as the same object with
    only the ``units`` attribute overwritten.
    """
    unit_text = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    in_metres = unit_text in ("m", "meter", "metre", "m of water equivalent")
    scaled = da * 1000.0 if in_metres else da
    scaled.attrs["units"] = "mm/day"
    return scaled

def _normalize_temp(da):
    """Convert Kelvin temperatures to degrees Celsius and tag units "°C".

    Args:
        da: array-like with an ``attrs`` dict; the unit is read from
            ``attrs["units"]`` or ``attrs["unit"]`` (case-insensitive).

    Returns:
        ``da - 273.15`` when the unit is "K"/"kelvin", otherwise *da* itself.
        In both cases ``attrs["units"]`` is set to "°C"; for inputs that need
        no conversion this overwrites the attribute on the caller's object.
    """
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    if u in ("k", "kelvin"):
        da = da - 273.15
    # The label was previously stored mis-encoded; store the real degree sign.
    da.attrs["units"] = "°C"
    return da

def _coerce_time(da):
    """Return *da* with its time axis named ``time`` where possible.

    ``time`` is kept and ``valid_time`` is renamed. Otherwise a CF decode is
    tried; if that does not yield a ``time`` dimension (or fails), the input
    is returned unchanged.
    """
    dim_names = da.dims
    if "time" in dim_names:
        return da
    if "valid_time" in dim_names:
        return da.rename({"valid_time":"time"})
    try:
        decoded = xr.decode_cf(da.to_dataset(name="_tmp")).to_array("_tmp")
        if "time" in decoded.dims:
            return decoded
    except Exception:
        # decode is best effort only
        pass
    return da

def _align_to_truth_grid(src: xr.DataArray, truth: xr.DataArray) -> xr.DataArray:
    """Put *src* on the latitude/longitude labels of *truth*.

    Both arrays are normalized (0..360 longitude, ascending latitude). In
    order: return *src* if the grids already match, try a nearest-label
    reindex with tolerance, then fall back to nearest-neighbour interpolation.
    """
    src_n = _ensure_lat_asc(_to_0360(src))
    truth_n = _ensure_lat_asc(_to_0360(truth))

    # identical grids need no work
    try:
        same_lat = np.array_equal(src_n.latitude.values, truth_n.latitude.values)
        if same_lat and np.array_equal(src_n.longitude.values, truth_n.longitude.values):
            return src_n
    except Exception:
        pass

    # NOTE(review): xarray documents `tolerance` as a scalar or list-like; a dict
    # may raise here and be swallowed, so this step may never succeed. Confirm.
    try:
        per_axis_tol = {"latitude":0.125, "longitude":0.125}
        return src_n.reindex_like(truth_n, method="nearest", tolerance=per_axis_tol)
    except Exception:
        pass

    # interpolation fallback on explicitly sorted axes
    src_sorted = src_n.sortby(["latitude","longitude"])
    truth_sorted = truth_n.sortby(["latitude","longitude"])
    return src_sorted.interp(latitude=truth_sorted["latitude"],
                             longitude=truth_sorted["longitude"],
                             method="nearest")

def _truth_land_mask(truth_da: xr.DataArray) -> xr.DataArray:
    """Return a boolean mask that is True where truth has any finite value over ``time``.

    ``np.isfinite`` is applied to the array directly because the ``xr.ufuncs``
    namespace is not available in every xarray release.
    """
    return np.isfinite(truth_da).any("time")

def _apply_imd_mask(imd_da: xr.DataArray, mask_ds: xr.Dataset | None) -> xr.DataArray:
    """Blank out IMD cells that the ocean/coverage mask marks as NaN.

    Args:
        imd_da: IMD data array.
        mask_ds: mask dataset whose NaN cells mark ocean / out-of-coverage,
            or ``None``.

    Returns:
        *imd_da* unchanged when *mask_ds* is ``None`` (no fallback mask is
        computed here). Otherwise the normalized *imd_da* (``time`` axis,
        0..360 longitude, ascending latitude) with NaN wherever the mask is NaN.
    """
    if mask_ds is None:
        return imd_da
    # pick a variable from the mask file: the precip name if present, else the first one
    if VAR_P in mask_ds.data_vars:
        m = mask_ds[VAR_P]
    else:
        m = next(iter(mask_ds.data_vars.values()))
    m = _ensure_lat_asc(_to_0360(_coerce_time(m)))
    imd_da = _ensure_lat_asc(_to_0360(_coerce_time(imd_da)))
    # align the (first time step of the) mask to the IMD grid; on failure use it as-is
    try:
        m_on = _align_to_truth_grid(m.isel(time=0) if "time" in m.dims else m, imd_da.isel(time=0))
    except Exception:
        m_on = m
    # np.isnan works on xarray objects directly; xr.ufuncs is not present in every release
    return imd_da.where(~np.isnan(m_on))

# --- WeatherBenchX series packers (valid_time + lead_time=0) ---
def _wbx_series(pred: xr.DataArray, truth: xr.DataArray, metric: str):
    """Compute a spatial-mean RMSE or MAE series with WeatherBenchX.

    Args:
        pred: prediction array.
        truth: reference array on the same grid and times.
        metric: ``"RMSE"`` (case-insensitive); anything else means MAE.

    Returns:
        ``(series, label)`` where the series is indexed by ``date`` (renamed
        from ``valid_time``) and ``label`` is "RMSE" or "MAE".

    RMSE is sqrt(mean(squared error)) over the spatial dimensions. Taking the
    root before the mean gives mean(|error|), which is MAE.
    """
    def _pack(da):
        # WBX layout: (lead_time=0, valid_time, latitude, longitude)
        da = _coerce_time(da).rename({"time":"valid_time"})
        want = ["valid_time"] + [d for d in ("latitude","longitude") if d in da.dims]
        da = da.transpose(*want)
        return da.expand_dims({"lead_time":[np.timedelta64(0,"h")]})

    ds_p = xr.Dataset({"var": _pack(pred.astype("float32"))})
    ds_t = xr.Dataset({"var": _pack(truth.astype("float32"))})

    is_rmse = metric.upper() == "RMSE"
    if is_rmse:
        metrics, stat_key, label = {"rmse": deterministic.RMSE()}, "SquaredError", "RMSE"
    else:
        metrics, stat_key, label = {"mae": deterministic.MAE()}, "AbsoluteError", "MAE"

    stats = metrics_base.compute_unique_statistics_for_all_metrics(metrics, ds_p, ds_t)
    err = list(stats[stat_key].values())[0]
    s = err.mean([d for d in ("latitude","longitude") if d in err.dims], skipna=True)
    if is_rmse:
        s = s ** 0.5
    # drop the dummy lead_time axis added by _pack
    if "lead_time" in s.dims and s.sizes.get("lead_time",1)==1:
        s = s.isel(lead_time=0)
    return s.rename({"valid_time":"date"}).squeeze(), label
    




## BMD/ERA5: RMSE/MAE/SEEPS (SEEPS for precip only) — uses WeatherBenchX
 

In [4]:
# ===========================================================
# BMD truth: RMSE/MAE/SEEPS (SEEPS for precip only)
# - Uses WeatherBenchX
# ===========================================================
import numpy as np, pandas as pd, xarray as xr, matplotlib.pyplot as plt, ipywidgets as widgets
from IPython.display import display, Markdown
from weatherbenchX.metrics import deterministic
from weatherbenchX.metrics import base as metrics_base
from weatherbenchX.metrics.categorical import SEEPS as WBX_SEEPS

plt.rcParams.update({"figure.dpi": 120})

# ---------- helpers ----------
def _as_da(obj):
    """Return *obj* as a DataArray.

    A DataArray is returned as-is; for a Dataset the first data variable is
    returned.

    Raises:
        ValueError: if *obj* is a Dataset without data variables.
        TypeError: if *obj* is neither a DataArray nor a Dataset.
    """
    if isinstance(obj, xr.DataArray):
        return obj
    if not isinstance(obj, xr.Dataset):
        raise TypeError(f"Expected DataArray/Dataset, got {type(obj)}")
    if obj.data_vars:
        first_name = next(iter(obj.data_vars))
        return obj[first_name]
    raise ValueError("Empty Dataset where DataArray expected.")

# Candidate variable names per quantity, in order of preference; the first name
# present in a dataset is used.
ALIASES = {
    "precip": ["total_precipitation_24hr", "tp_24h", "tp_daily", "precip", "rain"],
    "tavg":   ["tavg", "t2m_mean", "t2m_daily_mean", "daily_mean_temperature", "tas_mean", "tmean"],
    "tmax":   ["tmax", "t2m_max",  "t2m_daily_max",  "daily_max_temperature", "tasmax",  "tmax_mean"],
}
# Display label per quantity (the degree sign was previously mis-encoded).
LABEL  = {"precip": "Precip (mm/day)", "tavg": "Temp avg (°C)", "tmax": "Temp max (°C)"}

def _first_present(ds, names):
    """Return the first entry of *names* that is a data variable of *ds*.

    Returns ``None`` when *ds* is ``None`` or none of the names is present.
    """
    if ds is None:
        return None
    return next((candidate for candidate in names if candidate in ds.data_vars), None)

def _normalize_precip_units(da):
    """Convert metre-valued precipitation to millimetres and tag units "mm/day".

    The unit is read case-insensitively from ``attrs["units"]`` or
    ``attrs["unit"]``. Inputs that are not in metres are returned as the same
    object with only the ``units`` attribute overwritten.
    """
    declared = da.attrs.get("units") or da.attrs.get("unit") or ""
    if declared.lower() in ("m", "meter", "metre", "m of water equivalent"):
        converted = da * 1000.0
    else:
        converted = da
    converted.attrs["units"] = "mm/day"
    return converted

def _normalize_temp_units_C(da):
    """Convert Kelvin temperatures to degrees Celsius and tag units "°C".

    Args:
        da: array-like with an ``attrs`` dict; the unit is read from
            ``attrs["units"]`` or ``attrs["unit"]``.

    Returns:
        ``da - 273.15`` when the unit is a recognized Kelvin spelling,
        otherwise *da* itself. ``attrs["units"]`` is set to "°C" either way.

    The unit is matched against a list of Kelvin spellings. Testing for the
    letter "k" anywhere in the string would also convert units such as
    "unknown".
    """
    kelvin_spellings = {
        "k", "kelvin", "kelvins", "degk", "deg_k", "deg k",
        "degree_k", "degrees_k", "degree_kelvin", "degrees_kelvin",
        "degree kelvin", "degrees kelvin",
    }
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").strip().lower()
    if u in kelvin_spellings:
        da = da - 273.15
    da.attrs["units"] = "°C"
    return da

def _norm_by_key(key, da):
    """Normalize units of *da*: precipitation rules for ``"precip"``, temperature rules otherwise."""
    if key == "precip":
        return _normalize_precip_units(da)
    return _normalize_temp_units_C(da)

def _coerce_valid_time(da):
    """Return *da* with its time axis named ``valid_time``.

    ``valid_time`` is kept and ``time`` is renamed. Otherwise a CF decode
    (without cftime) is tried; errors from the decode propagate.

    Raises:
        ValueError: if no time dimension is found even after decoding.
    """
    if "valid_time" in da.dims:
        return da
    if "time" in da.dims:
        return da.rename({"time":"valid_time"})
    decoded = xr.decode_cf(da.to_dataset(name="_tmp"), use_cftime=False).to_array("_tmp")
    if "time" not in decoded.dims:
        raise ValueError("No time/valid_time dimension found.")
    return decoded.rename({"time":"valid_time"})

def _to_0360(obj):
    """Return *obj* on a 0..360 longitude axis, sorted by longitude.

    Nothing happens without a ``longitude`` coordinate. Longitudes are wrapped
    modulo 360 only when at least one is negative; a failure during the wrap
    is ignored and the original values are sorted instead.
    """
    if "longitude" not in obj.coords:
        return obj
    lon_axis = obj["longitude"]
    try:
        needs_wrap = float(lon_axis.min()) < 0
        if needs_wrap:
            obj = obj.assign_coords(longitude=(lon_axis % 360))
    except Exception:
        pass
    return obj.sortby("longitude")

def _ensure_lat_asc(obj):
    """Return *obj* with latitude ascending (sorted only if currently descending)."""
    if "latitude" not in obj.coords:
        return obj
    lat_axis = obj["latitude"].values
    is_descending = len(lat_axis) > 1 and lat_axis[0] > lat_axis[-1]
    return obj.sortby("latitude") if is_descending else obj

def _align_to_reference_grid(src: xr.DataArray, ref: xr.DataArray) -> xr.DataArray:
    """Put *src* on the latitude/longitude labels of *ref*.

    After normalizing both arrays (0..360 longitude, ascending latitude), the
    function returns *src* if the grids already match, else tries a
    nearest-label reindex with tolerance, else interpolates with
    nearest-neighbour.
    """
    src_n = _ensure_lat_asc(_to_0360(src))
    ref_n = _ensure_lat_asc(_to_0360(ref))

    # fast path: same labels on both axes
    try:
        lat_match = np.array_equal(src_n.latitude.values, ref_n.latitude.values)
        if lat_match and np.array_equal(src_n.longitude.values, ref_n.longitude.values):
            return src_n
    except Exception:
        pass

    # NOTE(review): xarray documents `tolerance` as a scalar or list-like; a dict
    # may raise here and be swallowed, so this step may never succeed. Confirm.
    try:
        tol = {"latitude":0.125,"longitude":0.125}
        return src_n.reindex_like(ref_n, method="nearest", tolerance=tol)
    except Exception:
        pass

    # fallback: nearest-neighbour interpolation on explicitly sorted axes
    src_sorted = src_n.sortby(["latitude","longitude"])
    ref_sorted = ref_n.sortby(["latitude","longitude"])
    return src_sorted.interp(latitude=ref_sorted["latitude"],
                             longitude=ref_sorted["longitude"],
                             method="nearest")

def _prep_for_wbx_validtime(da: xr.DataArray) -> xr.DataArray:
    """Order *da* as (valid_time, latitude, longitude) and add ``lead_time`` = 0 h.

    Spatial dimensions that are absent are skipped in the transpose.
    """
    timed = _coerce_valid_time(da)
    spatial = [axis for axis in ("latitude","longitude") if axis in timed.dims]
    ordered = timed.transpose("valid_time", *spatial)
    zero_lead = [np.timedelta64(0,"h")]
    return ordered.expand_dims({"lead_time": zero_lead})

def _prep_for_wbx_initlead(da: xr.DataArray) -> xr.DataArray:
    """Return *da* laid out as (lead_time, init_time, latitude, longitude).

    The time axis is coerced to ``valid_time``, renamed to ``init_time``, and
    a length-1 ``lead_time`` of 0 h is added. Absent spatial dimensions are
    skipped.
    """
    by_init = _coerce_valid_time(da).rename({"valid_time":"init_time"})
    spatial_in = [axis for axis in ("latitude","longitude") if axis in by_init.dims]
    by_init = by_init.transpose("init_time", *spatial_in)
    with_lead = by_init.expand_dims({"lead_time":[np.timedelta64(0,"h")]})
    spatial_out = [axis for axis in ("latitude","longitude") if axis in with_lead.dims]
    return with_lead.transpose("lead_time", "init_time", *spatial_out)

def _wbx_series(pred_da: xr.DataArray, truth_da: xr.DataArray, metric_name: str):
    """Compute a spatial-mean RMSE or MAE series with WeatherBenchX.

    Args:
        pred_da: prediction (DataArray, or Dataset whose first variable is used).
        truth_da: reference, same conventions as *pred_da*.
        metric_name: ``"RMSE"``; any other value means MAE.

    Returns:
        ``(series, unit_label)``. The series is indexed by ``date`` (renamed
        from ``valid_time``). The unit label is the truth array's ``units``
        attribute when set; otherwise "mm/day" if the variable name contains
        "precip", else "°C".

    RMSE is sqrt(mean(squared error)) over the spatial dimensions. Taking the
    root before the mean gives mean(|error|), which is MAE. The unit label is
    now derived the same way for both metrics; RMSE used to be labelled
    "mm/day" even for temperature.
    """
    pred_da  = _as_da(pred_da)
    truth_da = _as_da(truth_da)
    pred_ds  = xr.Dataset({"var": _prep_for_wbx_validtime(pred_da)})
    truth_ds = xr.Dataset({"var": _prep_for_wbx_validtime(truth_da)})

    is_rmse = metric_name == "RMSE"
    metrics = {"rmse": deterministic.RMSE()} if is_rmse else {"mae": deterministic.MAE()}
    stats   = metrics_base.compute_unique_statistics_for_all_metrics(metrics, pred_ds, truth_ds)

    err = stats["SquaredError" if is_rmse else "AbsoluteError"]["var"]
    s = err.mean([d for d in ("latitude","longitude") if d in err.dims], skipna=True)
    if is_rmse:
        s = s ** 0.5

    # unit label: prefer the declared units, else infer from the variable name
    ylab = truth_da.attrs.get("units")
    if not ylab:
        ylab = "mm/day" if "precip" in str(truth_da.name).lower() else "°C"

    s = s.isel(lead_time=0).rename({"valid_time":"date"}).squeeze()
    return s, ylab

def _series_from_seeps(stats: dict):
    """Reduce SEEPS statistics to a spatial-mean series.

    The first statistic whose name starts with "SEEPS" is used, falling back
    to the first statistic of any name; within it the first variable is taken.
    A length-1 ``lead_time`` axis is dropped before averaging over
    latitude/longitude.

    Returns:
        ``(series, xdim)``: the series with NaN steps dropped and the name of
        its time dimension (``init_time``, ``valid_time`` or ``time``).
    """
    seeps_names = [name for name in stats.keys() if name.startswith("SEEPS")]
    chosen = seeps_names[0] if seeps_names else list(stats.keys())[0]
    field = list(stats[chosen].values())[0]

    if "lead_time" in field.dims and field.sizes.get("lead_time",1)==1:
        field = field.isel(lead_time=0)

    spatial = [axis for axis in ("latitude","longitude") if axis in field.dims]
    series = field.mean(spatial, skipna=True)

    if "init_time" in series.dims:
        xdim = "init_time"
    elif "valid_time" in series.dims:
        xdim = "valid_time"
    else:
        xdim = "time"
    return series.squeeze().dropna(dim=xdim), xdim

def _doy_mean(series: xr.DataArray, xdim_hint: str | None = None):
    """Average *series* by day of year.

    Args:
        series: array with a time-like dimension.
        xdim_hint: name of that dimension; when omitted, ``init_time`` and
            then ``valid_time`` are tried.

    Returns:
        ``(averaged, "DOY")`` with an integer ``DOY`` coordinate. If no time
        dimension is identified or it is empty, ``(series, xdim_hint or "time")``
        is returned unchanged.
    """
    xdim = xdim_hint or ("init_time" if "init_time" in series.dims else "valid_time" if "valid_time" in series.dims else None)
    if xdim is None or series.sizes.get(xdim, 0) == 0:
        return series, (xdim_hint or "time")
    s = series.groupby(f"{xdim}.dayofyear").mean(skipna=True)
    # Rename first, then read the coordinate: "DOY" does not exist on the
    # grouped result until after the rename.
    s = s.rename({"dayofyear":"DOY"})
    s = s.assign_coords(DOY=s["DOY"].astype(int))
    return s, "DOY"

def _to_monthly(series: xr.DataArray, time_dim: str) -> xr.DataArray:
    """Resample *series* along *time_dim* to month-start ("MS") means, skipping NaN.

    The time coordinate is first converted with ``pd.to_datetime``.
    """
    stamps = pd.to_datetime(series[time_dim].values)
    stamped = series.assign_coords({time_dim: stamps})
    return stamped.resample({time_dim: "MS"}).mean(skipna=True)

# ---------- inputs UI (unchanged) ----------
# Default dataset locations for this cell's own path widgets.
BMD_PATH_DEFAULT     = "gs://aim4scale_training_25/ground_truth/BMD_daily_combined_0p25.zarr"
ERA5_TP_PATH_DEFAULT = "gs://aim4scale_training_25/ground_truth/era5_24hr.zarr"
ERA5_T2M_PATH_DEFAULT= "gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr"
IMERG_PATH_DEFAULT   = "gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr"

# Path inputs, metric selector, day-of-year averaging toggle, date window
# (free-text dates), run button and the output area used by _run_all.
bmd_txt   = widgets.Text(value=BMD_PATH_DEFAULT,     description="BMD:",    layout=widgets.Layout(width="95%"))
era5p_txt = widgets.Text(value=ERA5_TP_PATH_DEFAULT, description="ERA5 tp:",layout=widgets.Layout(width="95%"))
era5t_txt = widgets.Text(value=ERA5_T2M_PATH_DEFAULT,description="ERA5 t2m:",layout=widgets.Layout(width="95%"))
imerg_txt = widgets.Text(value=IMERG_PATH_DEFAULT,   description="IMERG:",  layout=widgets.Layout(width="95%"))
metric_dd = widgets.Dropdown(options=["RMSE","MAE","SEEPS"], description="Metric:", value="RMSE")
doy_avg_cb= widgets.Checkbox(value=False, description="Average by DOY (multi-year mean line)")
t0_txt    = widgets.Text(value="2018-06-01", description="Start (YYYY-MM-DD):")
t1_txt    = widgets.Text(value="2019-06-30", description="End   (YYYY-MM-DD):")
btn_run   = widgets.Button(description="Run (BMD truth, land-only)", button_style="success")
out_all   = widgets.Output()
display(bmd_txt, era5p_txt, era5t_txt, imerg_txt, metric_dd, doy_avg_cb, t0_txt, t1_txt, btn_run, out_all)

# ---------- main ----------
# Dataset handles for this cell; _run_all loads each one lazily while it is None.
# NOTE: this rebinds ds_bmd and ds_imerg to None, discarding anything the first
# cell's loader stored under those names.
ds_bmd = ds_era5_tp = ds_era5_t2m = ds_imerg = None

def _load(path):
    """Open the dataset at *path* after stripping whitespace.

    Returns ``None`` for a blank path; ``.zarr`` paths are opened with
    ``xr.open_zarr`` and everything else with ``xr.open_dataset``.
    """
    cleaned = path.strip()
    if cleaned:
        reader = xr.open_zarr if cleaned.endswith(".zarr") else xr.open_dataset
        return reader(cleaned)
    return None

@out_all.capture(clear_output=True)
def _run_all(_):
    """Button callback: score ERA5 / IMERG against BMD and plot monthly series.

    Opens the datasets named in the text widgets (cached in module globals, so
    later clicks reuse them), resolves variable names through ``ALIASES`` and,
    for every variable that has a BMD field plus at least one predictor:

    1. normalizes the fields and cuts them to the requested date window,
    2. aligns each predictor to the BMD grid and to the common time span,
    3. masks both to cells where BMD has any finite value in that span,
    4. computes the selected metric and plots its monthly mean.

    SEEPS is evaluated for precipitation only; for other variables the
    function falls back to RMSE and prints a notice.

    Args:
        _: the clicked button (unused).
    """
    global ds_bmd, ds_era5_tp, ds_era5_t2m, ds_imerg
    # Open each store at most once; a handle that is already set is kept.
    if ds_bmd is None:
        ds_bmd = _load(bmd_txt.value)
    if ds_era5_tp is None:
        ds_era5_tp = _load(era5p_txt.value)
    if ds_era5_t2m is None:
        ds_era5_t2m = _load(era5t_txt.value)
    if ds_imerg is None:
        ds_imerg = _load(imerg_txt.value)

    if ds_bmd is None:
        print("‚ùå Need BMD.")
        return
    if ds_era5_tp is None and ds_era5_t2m is None and ds_imerg is None:
        print("‚ùå Need at least one predictor (ERA5 tp / ERA5 t2m / IMERG).")
        return

    # Per logical variable: the concrete variable name in each dataset.
    # ERA5 precipitation and temperature live in different stores; IMERG is
    # only consulted for precipitation.
    mapping = {}
    for key, aliases in ALIASES.items():
        mapping[key] = {
            "bmd":   _first_present(ds_bmd, aliases),
            "era5":  _first_present(ds_era5_tp if key=="precip" else ds_era5_t2m, aliases),
            "imerg": _first_present(ds_imerg, aliases) if key=="precip" else None,
            "_era5_src": "tp" if key=="precip" else "t2m",
        }

    keys = [k for k in ["precip","tavg","tmax"]
            if mapping[k]["bmd"] and (mapping[k]["era5"] is not None or mapping[k]["imerg"] is not None)]
    if not keys:
        print("‚ùå No overlapping variables. Check variable names.")
        return

    metric = metric_dd.value
    t0 = np.datetime64(pd.to_datetime(t0_txt.value).normalize())
    t1 = np.datetime64(pd.to_datetime(t1_txt.value).normalize())

    display(Markdown(f"**Running** {', '.join(LABEL[k] for k in keys)}  \n"
                     f"Metric: `{metric}` ‚Äî Window: `{str(t0)}` ‚Üí `{str(t1)}`  \n"
                     f"Mask: **BMD land/coverage (finite-any in window)**"))

    for key in keys:
        vb = mapping[key]["bmd"]
        ve = mapping[key]["era5"]
        vi = mapping[key]["imerg"]
        era5_src = mapping[key]["_era5_src"]
        display(Markdown(f"### {LABEL[key]}"))

        B = _norm_by_key(key, ds_bmd[vb]).sel(time=slice(t0, t1))
        if B.sizes.get("time",0)==0 and B.sizes.get("valid_time",0)==0:
            display(Markdown("‚ö†Ô∏è No BMD data in this window."))
            continue
        B = _coerce_valid_time(B)

        preds = []
        if ve is not None:
            E = (_norm_by_key(key, (ds_era5_tp if era5_src=="tp" else ds_era5_t2m)[ve])
                 .sel(time=slice(t0, t1)))
            E = _coerce_valid_time(E)
            preds.append(("ERA5", E))
        if vi is not None:
            I = _coerce_valid_time(_norm_by_key(key, ds_imerg[vi]).sel(time=slice(t0, t1)))
            preds.append(("IMERG", I))

        if not preds:
            display(Markdown("‚ö†Ô∏è No predictor for this variable."))
            continue

        for label, P in preds:
            P_on = _align_to_reference_grid(P, B)

            # Restrict both fields to the time span they have in common.
            tmin = max(np.datetime64(B.valid_time.min().values), np.datetime64(P_on.valid_time.min().values))
            tmax = min(np.datetime64(B.valid_time.max().values), np.datetime64(P_on.valid_time.max().values))
            if tmax < tmin:
                print(f"‚ö†Ô∏è No time overlap with {label}.")
                continue
            B_use = B.sel(valid_time=slice(tmin, tmax))
            P_use = P_on.sel(valid_time=slice(tmin, tmax))

            # Keep cells where BMD reports at least one finite value. NumPy's
            # ufunc is applied to the DataArray directly because `xr.ufuncs`
            # is not present in every xarray release.
            mask = np.isfinite(B_use).any(dim="valid_time")
            B_mask = B_use.where(mask)
            P_mask = P_use.where(mask)

            if metric == "SEEPS" and key == "precip":
                P_wbx = _prep_for_wbx_initlead(P_mask)
                B_wbx = _prep_for_wbx_initlead(B_mask)

                # Build SEEPS climatology from FULL BMD on the evaluation grid
                full_bmd_all = _normalize_precip_units(_ensure_lat_asc(_to_0360(ds_bmd[vb])))
                full_bmd_region = full_bmd_all.sel(latitude=B_mask.latitude, longitude=B_mask.longitude)

                # Two climatology fields per day of year: the median of wet
                # days (threshold) and the fraction of days below dry_thr.
                dry_thr = 0.25
                wet = full_bmd_region.where(full_bmd_region >= dry_thr)
                thr = wet.groupby("time.dayofyear").quantile(0.5, dim="time", skipna=True)
                if "quantile" in thr.dims:
                    thr = thr.sel(quantile=0.5, drop=True)
                dry = (full_bmd_region < dry_thr).groupby("time.dayofyear").mean("time", skipna=True)
                # Cover all 366 days; gaps get the dry threshold / zero dry fraction.
                all_doy = np.arange(1, 367, dtype=np.int16)
                thr = thr.rename({d:"dayofyear" for d in thr.dims if d.endswith("dayofyear")}).reindex(dayofyear=all_doy).fillna(dry_thr)
                dry = dry.rename({d:"dayofyear" for d in dry.dims if d.endswith("dayofyear")}).reindex(dayofyear=all_doy).fillna(0.0).clip(0,1)
                # The same daily values are repeated along a 24-entry hour axis.
                hours = np.arange(24, dtype=np.int16)
                thr = thr.expand_dims(hour=hours).astype("float32").rename("total_precipitation_24hr_seeps_threshold")
                dry = dry.expand_dims(hour=hours).astype("float32").rename("total_precipitation_24hr_seeps_dry_fraction")
                clim_ds = xr.Dataset({"total_precipitation_24hr_seeps_threshold": thr,
                                      "total_precipitation_24hr_seeps_dry_fraction": dry})

                seeps_metric = WBX_SEEPS(variables=["total_precipitation_24hr"], climatology=clim_ds,
                                         dry_threshold_mm=dry_thr, min_p1=0.10, max_p1=0.85)

                stats = metrics_base.compute_unique_statistics_for_all_metrics(
                    {"seeps": seeps_metric},
                    xr.Dataset({"total_precipitation_24hr": P_wbx}),
                    xr.Dataset({"total_precipitation_24hr": B_wbx})
                )

                series, xdim = _series_from_seeps(stats)
                series = series.clip(min=0.0, max=1.0)

                # monthly mean for plotting
                series_m = _to_monthly(series, xdim)

                fig, ax = plt.subplots(figsize=(9, 3.6))
                series_m.plot(ax=ax, x=series_m.dims[0])
                ax.set_title(f"SEEPS ‚Äî {label} vs BMD ‚Äî total_precipitation_24hr (monthly)")
                ax.set_ylabel("SEEPS (0 best)")
                ax.grid(True, alpha=0.3)
                plt.show()

                # The summary line is best effort; a failure here must not
                # abort the remaining predictors/variables.
                try:
                    display(Markdown(f"**Mean SEEPS over window ‚Äî {label}:** `{float(series_m.mean().values):.3f}`"))
                except Exception:
                    pass

            else:
                use_metric = metric if metric in ("RMSE","MAE") else "RMSE"
                if metric == "SEEPS" and key != "precip":
                    print("‚ÑπÔ∏è SEEPS is precipitation-only; using RMSE for temperature.")

                series, ylab = _wbx_series(P_mask, B_mask, use_metric)
                # monthly mean
                series_m = _to_monthly(series, "date")

                fig, ax = plt.subplots(figsize=(9, 3.6))
                series_m.plot(ax=ax, x="date")
                ax.set_title(f"{use_metric} ‚Äî {label} vs BMD ‚Äî {(vb if key!='precip' else 'total_precipitation_24hr')} (monthly)")
                ax.set_ylabel(ylab)
                ax.set_xlabel("month")
                ax.grid(True, alpha=0.3)
                plt.show()

                # Best-effort summary line, as in the SEEPS branch.
                try:
                    mval = float(series_m.mean().values)
                    display(Markdown(f"**Mean {use_metric} over window ‚Äî {label}:** `{mval:.3f} {ylab}`"))
                except Exception:
                    pass

btn_run.on_click(_run_all)


Text(value='gs://aim4scale_training_25/ground_truth/BMD_daily_combined_0p25.zarr', description='BMD:', layout=‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/era5_24hr.zarr', description='ERA5 tp:', layout=Layout(wid‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr', description='ERA5 t2m:', layo‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr', description='IMERG:', layout=L‚Ä¶

Dropdown(description='Metric:', options=('RMSE', 'MAE', 'SEEPS'), value='RMSE')

Checkbox(value=False, description='Average by DOY (multi-year mean line)')

Text(value='2018-06-01', description='Start (YYYY-MM-DD):')

Text(value='2019-06-30', description='End   (YYYY-MM-DD):')

Button(button_style='success', description='Run (BMD truth, land-only)', style=ButtonStyle())

Output()

In [5]:
# Cell 3 - Helpers (units, grids, time, masking, WBX series)

# Name of the precipitation variable. _apply_imd_mask() looks for it in the
# mask dataset, and the IMD panel cell re-exports it as VAR.
VAR_P = "total_precipitation_24hr"

def _to_0360(obj):
    if "longitude" in obj.coords:
        lon = obj["longitude"]
        try:
            if float(lon.min()) < 0:
                obj = obj.assign_coords(longitude=(lon % 360))
        except Exception:
            pass
        obj = obj.sortby("longitude")
    return obj

def _ensure_lat_asc(obj):
    for latn in ("latitude","lat"):
        if latn in obj.coords:
            lat = obj[latn].values
            if len(lat) > 1 and lat[0] > lat[-1]:
                obj = obj.sortby(latn)
    return obj

def _normalize_precip(da):
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    out = da
    if u in ["m","meter","metre","m of water equivalent"]:
        out = out * 1000.0
    out.attrs["units"] = "mm/day"
    return out

def _normalize_temp(da):
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    out = da
    if u in ["k","kelvin"]:
        out = out - 273.15
    out.attrs["units"] = "¬∞C"
    return out

def _coerce_time(da):
    if "time" in da.dims: return da
    try:
        dec = xr.decode_cf(da.to_dataset(name="_tmp")).to_array("_tmp")
        if "time" in dec.dims: return dec
    except Exception:
        pass
    if "valid_time" in da.dims:
        return da.rename({"valid_time":"time"})
    return da

def _align_to_truth_grid(src: xr.DataArray, truth: xr.DataArray) -> xr.DataArray:
    """Put ``src`` on the latitude/longitude grid of ``truth``.

    Both inputs are first normalized to 0-360 longitudes and ascending
    latitudes. Three strategies are then tried in order:

    1. grids already identical -> return the normalized ``src`` unchanged;
    2. nearest-neighbour ``reindex_like`` onto ``truth``;
    3. nearest-neighbour ``interp`` onto ``truth``'s latitude/longitude.

    Any exception in step 1 or 2 is swallowed and the next step is tried.
    """
    s = _ensure_lat_asc(_to_0360(src))
    t = _ensure_lat_asc(_to_0360(truth))
    try:
        # Fast path: exact coordinate match, nothing to resample.
        if np.array_equal(s.latitude.values, t.latitude.values) and np.array_equal(s.longitude.values, t.longitude.values):
            return s
    except Exception:
        pass
    try:
        # NOTE(review): `tolerance` is passed as a per-dimension dict here -
        # confirm the installed xarray accepts that form. If it raises, the
        # error is silently swallowed and the interp fallback below is what
        # actually runs. Presumably reindex_like also conforms any other
        # dimension shared with `t` (e.g. time), not only lat/lon - verify.
        return s.reindex_like(t, method="nearest", tolerance={"latitude":0.125, "longitude":0.125})
    except Exception:
        pass
    # Last resort: nearest-neighbour interpolation, with no distance tolerance given.
    s2 = s.sortby(["latitude","longitude"]); t2 = t.sortby(["latitude","longitude"])
    return s2.interp(latitude=t2["latitude"], longitude=t2["longitude"], method="nearest")

def _truth_land_mask(truth_da: xr.DataArray) -> xr.DataArray:
    """Boolean mask of cells where ``truth_da`` is finite at any ``time`` step.

    ``np.isfinite`` is applied to the DataArray directly instead of going
    through ``xr.ufuncs``, which is not present in every xarray release.
    """
    return np.isfinite(truth_da).any("time")

def _apply_imd_mask(imd_da: xr.DataArray, mask_ds: xr.Dataset) -> xr.DataArray:
    """Blank out IMD cells that the mask marks as missing.

    The mask field is ``mask_ds[VAR_P]`` when present, otherwise the first
    data variable of ``mask_ds``. Both inputs are put on a 0-360 longitude /
    ascending-latitude layout, the mask (its first time step if it has a
    ``time`` dimension) is aligned to the IMD grid, and IMD values are kept
    wherever the aligned mask is not NaN.

    Args:
        imd_da: IMD field; its first ``time`` step is used as the target grid.
        mask_ds: dataset holding the coverage mask.

    Returns:
        The layout-normalized ``imd_da`` with NaN where the mask is NaN.
    """
    if VAR_P in mask_ds.data_vars:
        mask = mask_ds[VAR_P]
    else:
        mask = next(iter(mask_ds.data_vars.values()))
    mask = _ensure_lat_asc(_to_0360(_coerce_time(mask)))
    imd_da = _ensure_lat_asc(_to_0360(_coerce_time(imd_da)))
    mask_2d = mask.isel(time=0) if "time" in mask.dims else mask
    mask_on_imd = _align_to_truth_grid(mask_2d, imd_da.isel(time=0))
    # NumPy's ufunc works on the DataArray directly; `xr.ufuncs` is not
    # available in every xarray release.
    return imd_da.where(~np.isnan(mask_on_imd))

# --- WeatherBenchX: RMSE/MAE spatial-mean series (time on x-axis) ---
def _wbx_series(pred: xr.DataArray, truth: xr.DataArray, metric: str):
    """Spatially aggregated RMSE or MAE per valid time, computed through WBX.

    Both inputs are packed into one-variable datasets (variable ``"var"``)
    with a ``valid_time`` axis and a single ``lead_time`` of 0 h, then handed
    to ``compute_unique_statistics_for_all_metrics``. The per-cell error
    statistic is averaged over whichever of ``latitude`` / ``longitude`` are
    present (plain, unweighted mean, NaNs skipped).

    For RMSE the square root is taken *after* the spatial mean of the squared
    error. Taking it per cell first would reduce the result to the mean
    absolute error, i.e. the same number as the MAE branch.

    Args:
        pred: predicted field.
        truth: reference field.
        metric: ``"RMSE"`` (case-insensitive) selects RMSE; anything else
            selects MAE.

    Returns:
        ``(series, label)``: the series has ``valid_time`` renamed to
        ``date``; ``label`` is ``"RMSE"`` or ``"MAE"``.
    """
    def _wbx_pack(da):
        # valid_time first, then the spatial dims, plus a singleton lead_time.
        da = _coerce_time(da).rename({"time":"valid_time"})
        want = ["valid_time"] + [d for d in ("latitude","longitude") if d in da.dims]
        da = da.transpose(*want)
        return da.expand_dims({"lead_time":[np.timedelta64(0,"h")]})

    pred_ds  = xr.Dataset({"var": _wbx_pack(pred.astype("float32"))})
    truth_ds = xr.Dataset({"var": _wbx_pack(truth.astype("float32"))})

    want_rmse = metric.upper() == "RMSE"
    if want_rmse:
        # Use the SquaredError block (stable across WBX versions).
        metrics, stat_key, label = {"rmse": deterministic.RMSE()}, "SquaredError", "RMSE"
    else:
        metrics, stat_key, label = {"mae": deterministic.MAE()}, "AbsoluteError", "MAE"

    stats = metrics_base.compute_unique_statistics_for_all_metrics(metrics, pred_ds, truth_ds)
    err = next(iter(stats[stat_key].values()))
    spatial_dims = [d for d in ("latitude","longitude") if d in err.dims]
    s = err.mean(spatial_dims, skipna=True)
    if want_rmse:
        s = s ** 0.5
    if "lead_time" in s.dims and s.sizes.get("lead_time",1)==1:
        s = s.isel(lead_time=0)
    return s.rename({"valid_time":"date"}).squeeze(), label



## IMD truth vs IMERG and ERA5: RMSE / MAE / SEEPS (computed with WeatherBenchX)
 

In [6]:
# IMD panel: widgets for the per-year IMD verification callback (on_run).

# Precipitation variable name that on_run() looks up in IMERG, IMD and ERA5.
VAR = VAR_P  # precipitation variable key

# widgets (added: month/year splits)
# Single fixed option; on_run() does not read this dropdown.
region_dd = widgets.Dropdown(options=["IMD-Coverage"], value="IMD-Coverage", description="Region:")
# NOTE(review): this rebinds the module-level name `metric_dd`, which
# `_run_all` (earlier in this file) also reads at click time. Once this cell
# has run, that callback sees this dropdown instead of its own - confirm
# whether a distinct name is wanted.
metric_dd = widgets.Dropdown(options=["RMSE","MAE","SEEPS"], value="RMSE", description="Metric:")
# Month span evaluated within each year, and the inclusive range of years.
start_month = widgets.Dropdown(options=list(range(1,13)), value=6, description="Start month")
end_month   = widgets.Dropdown(options=list(range(1,13)), value=6, description="End month")
start_year  = widgets.BoundedIntText(value=2018, min=1900, max=2100, description="Start year")
end_year    = widgets.BoundedIntText(value=2018, min=1900, max=2100, description="End year")
run_btn = widgets.Button(description="Run IMD verification (per-year)", button_style="success")
# Output area that _safe_draw() clears and draws into.
out = widgets.Output()

display(region_dd, metric_dd, widgets.HBox([start_month, end_month, start_year, end_year]), run_btn, out)

def _one_year_range(y, m0, m1):
    t0 = np.datetime64(pd.Timestamp(year=y, month=m0, day=1))
    t1 = np.datetime64(pd.Timestamp(year=y, month=m1, day=1) + pd.offsets.MonthEnd(1))
    return t0, t1

def _safe_draw(fn):
    """Run ``fn`` inside the ``out`` widget after clearing it and closing all figures."""
    with out:
        out.clear_output(wait=True)
        plt.close('all')
        fn()

def on_run(_=None):
    """Button callback: per-year verification of IMERG and ERA5 against IMD.

    For every year in the selected range, the months ``start_month`` ..
    ``end_month`` are cut out of the masked IMD truth. IMERG and ERA5 are
    aligned to the IMD grid, limited to the common time span and masked to
    cells where IMD has any finite value. The chosen metric (RMSE, MAE or
    SEEPS) is then plotted as a time series for each predictor and its mean
    printed. All output goes to the ``out`` widget via ``_safe_draw``.

    Each predictor is scored against the IMD truth that was windowed and
    masked for that predictor, for all three metrics.

    Args:
        _: the clicked button (unused).
    """
    def _inter(a, b):
        # Cut both arrays to the time span they have in common.
        T0 = max(np.datetime64(a.time.min().values), np.datetime64(b.time.min().values))
        T1 = min(np.datetime64(a.time.max().values), np.datetime64(b.time.max().values))
        return a.sel(time=slice(T0, T1)), b.sel(time=slice(T0, T1))

    def _to_initlead(da):
        # Convert to init_time + lead_time=0 for the SEEPS computation.
        da = da.rename({"time":"init_time"})
        return da.expand_dims({"lead_time":[np.timedelta64(0,"h")]})

    def _do():
        if ds_imerg is None or ds_imd is None or ds_e5_p is None:
            print("‚ùå Please load datasets first.")
            return
        if VAR not in ds_imerg or VAR not in ds_imd or VAR not in ds_e5_p:
            print(f"‚ùå `{VAR}` must exist in IMERG, IMD, and ERA5.")
            return

        m0, m1 = int(start_month.value), int(end_month.value)
        y0, y1 = int(start_year.value),  int(end_year.value)
        metric = metric_dd.value

        # Build IMD truth with the coverage mask applied
        imd_full = _normalize_precip(_coerce_time(ds_imd[VAR]))
        imd_full = _apply_imd_mask(imd_full, ds_imd_mask)

        for y in range(y0, y1+1):
            t0, t1 = _one_year_range(y, m0, m1)
            imd = imd_full.sel(time=slice(t0, t1))
            if imd.sizes.get("time",0) == 0:
                print(f"‚ö†Ô∏è IMD empty for {y}-{m0}..{m1}")
                continue

            # region = IMD coverage (already trimmed by mask)
            # align predictors to IMD
            imerg = _normalize_precip(_coerce_time(ds_imerg[VAR])).sel(time=slice(t0, t1))
            era5  = _normalize_precip(_coerce_time(ds_e5_p[VAR])).sel(time=slice(t0, t1))
            imerg_on = _align_to_truth_grid(imerg, imd)
            era5_on  = _align_to_truth_grid(era5,  imd)

            # time overlap, computed separately for each predictor
            imerg_use, imd_i = _inter(imerg_on, imd)
            era5_use,  imd_e = _inter(era5_on,  imd)

            # mask oceans/out-of-coverage by IMD finite-any. NumPy's ufunc is
            # applied directly; `xr.ufuncs` is not in every xarray release.
            land_mask_i = np.isfinite(imd_i).any("time")
            land_mask_e = np.isfinite(imd_e).any("time")
            imerg_use = imerg_use.where(land_mask_i)
            imd_i = imd_i.where(land_mask_i)
            era5_use = era5_use.where(land_mask_e)
            imd_e = imd_e.where(land_mask_e)

            if metric in ("RMSE","MAE"):
                # WBX deterministic series
                s_imerg, _ = _wbx_series(imerg_use, imd_i, metric)
                s_era5,  _ = _wbx_series(era5_use,  imd_e, metric)

                # two separate plots
                fig1, ax1 = plt.subplots(figsize=(9, 3.8))
                ax1.plot(pd.to_datetime(s_imerg["date"].values), s_imerg.values)
                ax1.set_title(f"{metric} ‚Äî IMERG vs IMD ‚Äî {y}-{m0}..{m1}")
                ax1.set_ylabel("mm/day")
                ax1.grid(True, alpha=0.3)
                plt.show()

                fig2, ax2 = plt.subplots(figsize=(9, 3.8))
                ax2.plot(pd.to_datetime(s_era5["date"].values), s_era5.values)
                ax2.set_title(f"{metric} ‚Äî ERA5 vs IMD ‚Äî {y}-{m0}..{m1}")
                ax2.set_ylabel("mm/day")
                ax2.grid(True, alpha=0.3)
                plt.show()

                print(f"Mean {metric} ‚Äî IMERG({y}): {float(s_imerg.mean()):.3f} mm/day")
                print(f"Mean {metric} ‚Äî ERA5 ({y}): {float(s_era5.mean()):.3f} mm/day")

            else:  # SEEPS (precip only)
                imerg_w = _to_initlead(imerg_use)
                era5_w = _to_initlead(era5_use)
                # One truth per predictor: each was windowed and masked with
                # that predictor's own time overlap above.
                imd_w_i = _to_initlead(imd_i)
                imd_w_e = _to_initlead(imd_e)

                # build climatology (dayofyear x hour) from entire masked IMD record on this grid
                fullB = imd_full.sel(latitude=imd_i.latitude, longitude=imd_i.longitude)
                # wet threshold (q=0.5) + dry fraction at 0.25 mm/day
                dry_thr = 0.25
                wet = fullB.where(fullB >= dry_thr)
                thr = wet.groupby("time.dayofyear").quantile(0.5, dim="time", skipna=True)
                if "quantile" in thr.dims:
                    thr = thr.sel(quantile=0.5, drop=True)
                dry = (fullB < dry_thr).groupby("time.dayofyear").mean("time", skipna=True)
                hours = np.arange(24, dtype=np.int16)
                all_doy = np.arange(1,367)
                # Cover all 366 days (gaps -> dry threshold / zero dry
                # fraction) and repeat the daily values along a 24-entry hour axis.
                thr = thr.rename({k:"dayofyear" for k in thr.dims if k.endswith("dayofyear")})
                thr = thr.reindex(dayofyear=all_doy).fillna(dry_thr)
                thr = thr.expand_dims(hour=hours).astype("float32")
                dry = dry.rename({k:"dayofyear" for k in dry.dims if k.endswith("dayofyear")})
                dry = dry.reindex(dayofyear=all_doy).fillna(0.0)
                dry = dry.expand_dims(hour=hours).clip(0,1).astype("float32")
                clim = xr.Dataset({
                    f"{VAR}_seeps_threshold": thr.rename(f"{VAR}_seeps_threshold"),
                    f"{VAR}_seeps_dry_fraction": dry.rename(f"{VAR}_seeps_dry_fraction"),
                })

                seeps = WBX_SEEPS(variables=[VAR], climatology=clim, dry_threshold_mm=dry_thr, min_p1=0.10, max_p1=0.85)

                stats_i = metrics_base.compute_unique_statistics_for_all_metrics({"seeps": seeps},
                          xr.Dataset({VAR: imerg_w}), xr.Dataset({VAR: imd_w_i}))
                stats_e = metrics_base.compute_unique_statistics_for_all_metrics({"seeps": seeps},
                          xr.Dataset({VAR: era5_w}),  xr.Dataset({VAR: imd_w_e}))

                s_imerg = list(stats_i["SEEPS"].values())[0].isel(lead_time=0).mean(["latitude","longitude"], skipna=True)
                s_era5  = list(stats_e["SEEPS"].values())[0].isel(lead_time=0).mean(["latitude","longitude"], skipna=True)

                fig1, ax1 = plt.subplots(figsize=(9, 3.8))
                ax1.plot(pd.to_datetime(s_imerg["init_time"].values), s_imerg.values)
                ax1.set_title(f"SEEPS ‚Äî IMERG vs IMD ‚Äî {y}-{m0}..{m1}")
                ax1.set_ylabel("SEEPS (0 best)")
                ax1.grid(True, alpha=0.3)
                plt.show()

                fig2, ax2 = plt.subplots(figsize=(9, 3.8))
                ax2.plot(pd.to_datetime(s_era5["init_time"].values), s_era5.values)
                ax2.set_title(f"SEEPS ‚Äî ERA5 vs IMD ‚Äî {y}-{m0}..{m1}")
                ax2.set_ylabel("SEEPS (0 best)")
                ax2.grid(True, alpha=0.3)
                plt.show()

                print(f"Mean SEEPS ‚Äî IMERG({y}): {float(s_imerg.mean()):.3f}")
                print(f"Mean SEEPS ‚Äî ERA5 ({y}): {float(s_era5.mean()):.3f}")

    _safe_draw(_do)

run_btn.on_click(on_run)


Dropdown(description='Region:', options=('IMD-Coverage',), value='IMD-Coverage')

Dropdown(description='Metric:', options=('RMSE', 'MAE', 'SEEPS'), value='RMSE')

HBox(children=(Dropdown(description='Start month', index=5, options=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), v‚Ä¶

Button(button_style='success', description='Run IMD verification (per-year)', style=ButtonStyle())

Output()

## Spatial bias maps — IMD truth (May–July mean, masked to IMD coverage)

In [7]:
# IMD bias maps - same cell, but: IMD mask, no smoothing, correct orientation

# Inclusive year range; _run_bias_maps() splits this text on "-".
years_txt   = widgets.Text(value="2018-2021", description="Years (YYYY-YYYY):")
# Symmetric colour-scale limit: the maps are drawn with vmin=-value, vmax=+value.
vmax_txt    = widgets.Text(value="50",        description="Abs max (mm)")
# Which bias panels to draw; at least one must be ticked.
show_era5   = widgets.Checkbox(value=True,    description="ERA5 vs IMD")
show_imerg  = widgets.Checkbox(value=True,    description="IMERG vs IMD")
btn_run_map = widgets.Button(description="Run bias maps", button_style="info")
# Output area captured by _run_bias_maps(); it is cleared on each run.
out_map     = widgets.Output()


display(widgets.HBox([years_txt, vmax_txt, show_era5, show_imerg, btn_run_map]), out_map)

def _mjj_years(da, y0, y1):
    """Keep only May, June and July time steps that fall within years ``y0``..``y1``."""
    in_mjj = da.time.dt.month.isin([5,6,7])
    first_day = f"{y0}-01-01"
    last_day = f"{y1}-12-31"
    return da.sel(time=in_mjj).sel(time=slice(first_day, last_day))

@out_map.capture(clear_output=True)
def _run_bias_maps(_):
    """Button callback: draw May-July mean bias maps of ERA5 / IMERG against IMD.

    IMD is masked with ``ds_imd_mask``. All three fields are restricted to
    May-July of the requested years, the predictors are aligned to the IMD
    grid, and everything is cut to the common time span. The bias
    (predictor minus IMD) is averaged over time on cells where IMD has any
    finite value, and one panel is drawn per ticked predictor with a
    symmetric colour scale.

    If the years text cannot be parsed as ``YYYY-YYYY`` the range falls back
    to 1900-2100 and a notice is printed.

    Args:
        _: the clicked button (unused).
    """
    if any(v is None for v in [ds_imd, ds_imerg, ds_e5_p]):
        print("‚ùå Please load IMD, IMERG, ERA5 first.")
        return
    if VAR not in ds_imd or VAR not in ds_imerg or VAR not in ds_e5_p:
        print(f"‚ùå `{VAR}` must exist in IMD, IMERG, ERA5.")
        return

    try:
        y0, y1 = [int(x) for x in years_txt.value.strip().split("-")]
    except Exception:
        print("‚ö†Ô∏è Could not parse years, using full overlap.")
        y0, y1 = (1900, 2100)

    # IMD truth with mask (excludes oceans)
    imd   = _apply_imd_mask(_normalize_precip(_coerce_time(ds_imd[VAR])), ds_imd_mask)
    imerg = _normalize_precip(_coerce_time(ds_imerg[VAR]))
    era5  = _normalize_precip(_coerce_time(ds_e5_p[VAR]))

    # Select MJJ + years (then auto-overlap)
    imd_mjj   = _mjj_years(imd,   y0, y1)
    imerg_mjj = _mjj_years(imerg, y0, y1)
    era5_mjj  = _mjj_years(era5,  y0, y1)

    # Align to IMD grid + intersection in time
    imerg_on = _align_to_truth_grid(imerg_mjj, imd_mjj)
    era5_on  = _align_to_truth_grid(era5_mjj,  imd_mjj)

    t0 = max(np.datetime64(imd_mjj.time.min().values),
             np.datetime64(imerg_on.time.min().values),
             np.datetime64(era5_on.time.min().values))
    t1 = min(np.datetime64(imd_mjj.time.max().values),
             np.datetime64(imerg_on.time.max().values),
             np.datetime64(era5_on.time.max().values))
    imd_use   = imd_mjj.sel(time=slice(t0, t1))
    imerg_use = imerg_on.sel(time=slice(t0, t1))
    era5_use  = era5_on.sel(time=slice(t0, t1))

    # land/coverage mask (finite-any over window). NumPy's ufunc is applied
    # directly; `xr.ufuncs` is not available in every xarray release.
    land_mask = np.isfinite(imd_use).any("time")
    imerg_bias = (imerg_use - imd_use).where(land_mask).mean("time", skipna=True)
    era5_bias  = (era5_use  - imd_use).where(land_mask).mean("time",  skipna=True)

    # Symmetric colour limits around zero; an empty box means 50.
    vmax = float(vmax_txt.value or 50.0)
    vmin = -vmax
    ncols = int(show_era5.value) + int(show_imerg.value)
    if ncols == 0:
        print("‚ö†Ô∏è Select at least one predictor.")
        return

    fig, axs = plt.subplots(1, ncols, figsize=(6.0*ncols, 4.8), constrained_layout=True)
    if ncols == 1:
        # plt.subplots returns a bare Axes for a single panel; make it indexable.
        axs = [axs]
    i = 0
    if show_era5.value:
        era5_bias.plot(ax=axs[i], x="longitude", y="latitude", vmin=vmin, vmax=vmax,
                       cmap="RdBu_r", cbar_kwargs={"label":"mm/day"})
        axs[i].set_title("Bias: ERA5 ‚àí IMD (mm/day)")
        axs[i].grid(True, alpha=0.2)
        i += 1
    if show_imerg.value:
        imerg_bias.plot(ax=axs[i], x="longitude", y="latitude", vmin=vmin, vmax=vmax,
                        cmap="RdBu_r", cbar_kwargs={"label":"mm/day"})
        axs[i].set_title("Bias: IMERG ‚àí IMD (mm/day)")
        axs[i].grid(True, alpha=0.2)

    plt.show()

btn_run_map.on_click(_run_bias_maps)



HBox(children=(Text(value='2018-2021', description='Years (YYYY-YYYY):'), Text(value='50', description='Abs ma‚Ä¶

Output()