
# 🧪 Demo 5: Ground Truth Challenge — Validate Your Forecast with Your Data (Demo 4–style)

This notebook mirrors **Demo 4** (widgets, regions, WeatherBench-X) and adds the ability to plug in **your own ground truth**:
- **Default**: ERA5 from Google Cloud Storage (no downloads)
- **Optional**: Point **CSV** (`time,lat,lon,value`) with time tolerance
- **Optional**: **Gridded** NetCDF/Zarr file (local path or `gs://…`)

You can compare **Global vs Local** regions and compute **RMSE** (and **ACC** if GT supports climatology).

> Tip: start with ERA5 as truth; once it runs, try your CSV/NetCDF.


In [1]:
# --- Imports & config ---
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown

import weatherbenchX
from weatherbenchX.metrics import deterministic
from weatherbenchX.metrics import base as metrics_base

plt.rcParams.update({"figure.dpi": 120})

# --- YOUR DEFAULT PATHS (edit if needed) ---
# Both defaults are pre-filled into the loader text boxes below.
# NOTE(review): the "forecast" default points at an IMERG store under
# ground_truth/ — presumably IMERG/IMD stands in for a forecast here; confirm.
FORECAST_DEFAULT = "gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr"   # or IMD_rainfall_0p25.zarr
ERA5_DAILY_DEFAULT = "gs://aim4scale_training_25/ground_truth/era5_24hr.zarr"            # your ERA5-daily Zarr

# Name of the data variable looked up in BOTH the forecast and the ERA5 dataset.
FORECAST_VAR = "total_precipitation_24hr"   # both IMERG/IMD expose this name in your bucket

# --- Regions you asked for ---
# Each region is a pair of slices keyed by coordinate name.
# Latitude bounds are written north -> south; longitude bounds use the
# 0..360 convention (e.g. Chile is 284..294 rather than -76..-66).
REGIONS = {
    "Global":     {"latitude": slice( 90, -90), "longitude": slice(  0, 360)},
    "Ethiopia":   {"latitude": slice(14.9,  3.4), "longitude": slice(33.0, 48.0)},
    "Nigeria":    {"latitude": slice(14.7,  4.0), "longitude": slice( 2.7, 14.7)},
    "Kenya":      {"latitude": slice( 5.0, -4.7), "longitude": slice(33.9, 41.9)},
    "Bangladesh": {"latitude": slice(26.7, 20.7), "longitude": slice(88.0, 92.7)},
    "Chile":      {"latitude": slice(-17.5,-56.0), "longitude": slice(284.0, 294.0)},
}

# ---------- small utilities ----------
def _open_any(path: str):
    """Open ``path`` as an xarray dataset.

    Paths whose last component ends in ``.zarr`` are opened with
    ``xr.open_zarr``; everything else goes through ``xr.open_dataset``.
    A trailing slash (common on ``gs://bucket/store.zarr/`` URLs) is ignored
    when detecting the store type, but the path is passed on unchanged.
    """
    if path.rstrip("/").endswith(".zarr"):
        return xr.open_zarr(path)
    return xr.open_dataset(path)

def _to_0360(obj):
    """Return ``obj`` with longitudes in 0..360 and sorted ascending.

    Objects without a ``longitude`` coordinate are returned untouched.
    The wrap to 0..360 is only applied when the smallest longitude is
    negative; the sort by longitude always happens.
    """
    if "longitude" not in obj.coords:
        return obj

    lon_coord = obj["longitude"]
    try:
        needs_wrap = float(lon_coord.min()) < 0
        if needs_wrap:
            obj = obj.assign_coords(longitude=(lon_coord % 360))
    except Exception:
        # Best effort: if the minimum cannot be evaluated or the coordinate
        # cannot be reassigned, keep the longitudes as they are.
        pass
    return obj.sortby("longitude")

def _ensure_lat_ascending(obj):
    """Return ``obj`` sorted so that latitude increases along its axis.

    Nothing happens when there is no ``latitude`` coordinate, when the axis
    has fewer than two points, or when it is already ascending.
    """
    if "latitude" not in obj.coords:
        return obj

    lat_values = obj["latitude"].values
    is_descending = len(lat_values) > 1 and lat_values[0] > lat_values[-1]
    return obj.sortby("latitude") if is_descending else obj

def _region_to_0360(region):
    """Return a copy of ``region`` with longitude bounds expressed in 0..360.

    Parameters
    ----------
    region : mapping with ``"latitude"`` and ``"longitude"`` slices.

    Returns
    -------
    dict with the latitude slice copied as-is and the longitude slice
    converted to the 0..360 convention.

    Notes
    -----
    A plain ``% 360`` maps an eastern bound of 360 back to 0, which turned the
    "Global" box (0..360) into ``slice(0, 0)``. Boxes spanning the full circle
    are therefore returned as ``slice(0.0, 360.0)``, and an eastern bound that
    lands exactly on a multiple of 360 is kept as 360 instead of 0.
    """
    lon = region["longitude"]
    west, east = float(lon.start), float(lon.stop)

    if east - west >= 360.0:
        # Full circle (e.g. 0..360 or -180..180): select every longitude.
        lon_0360 = slice(0.0, 360.0)
    else:
        start, stop = west % 360, east % 360
        if stop == 0.0 and east > west:
            # e.g. 300..360 must stay 300..360, not become 300..0.
            stop = 360.0
        lon_0360 = slice(start, stop)

    return {
        "latitude": slice(region["latitude"].start, region["latitude"].stop),
        "longitude": lon_0360,
    }

def _apply_region_safe(ds, region):
    """Cut ``ds`` down to ``region`` regardless of its longitude convention.

    The data is first normalised (longitudes to 0..360, latitude ascending);
    the region's longitude bounds are converted to 0..360 as well. Both pairs
    of bounds are then ordered low -> high before the label-based selection.
    """
    normalised = _ensure_lat_ascending(_to_0360(ds))

    lon_bounds = _region_to_0360(region)["longitude"]
    lat_bounds = region["latitude"]
    lon_lo, lon_hi = sorted((lon_bounds.start, lon_bounds.stop))
    lat_lo, lat_hi = sorted((lat_bounds.start, lat_bounds.stop))

    return normalised.sel(
        latitude=slice(lat_lo, lat_hi),
        longitude=slice(lon_lo, lon_hi),
    )

def _normalize_precip_units(da):
    """Return precipitation expressed in millimetres (label ``"mm"``).

    The unit is read from ``attrs["units"]`` (or ``attrs["unit"]``),
    case-insensitively:

    * metres -> values multiplied by 1000 and relabelled ``"mm"``;
    * ``kg m-2`` spellings -> relabelled ``"mm"`` (1 kg m-2 is about 1 mm);
    * anything else, including a missing unit -> returned unchanged.

    The input object and its ``attrs`` are never modified; when a relabel is
    needed a shallow copy carries the new attributes. The remaining attributes
    of the input are preserved on the result.
    """
    units = (da.attrs.get("units") or da.attrs.get("unit") or "").strip().lower()

    if units in ["m", "meter", "metre", "m of water equivalent"]:
        out = da * 1000.0
    elif units in ["kg m-2", "kg/m^2", "kg m**-2"]:
        # Same numbers, new label: copy so the caller's attrs stay untouched.
        out = da.copy(deep=False)
    else:
        # If units are empty we assume mm (the bundled datasets are mm already).
        return da

    out.attrs = {**da.attrs, "units": "mm"}
    return out

def _coerce_valid_time(da):
    """Return ``da`` with its time dimension named ``valid_time``.

    WeatherBenchX accepts either ``valid_time`` or (``init_time``,
    ``lead_time``); this notebook uses ``valid_time``.

    * ``valid_time`` already a dimension -> ``da`` is returned as-is.
    * ``time`` is a dimension -> it is renamed to ``valid_time``.

    Raises
    ------
    ValueError
        If neither dimension is present. (A CF-decoding fallback used to sit
        here, but decoding never adds or renames dimensions, so it could not
        succeed and only masked errors.)
    """
    if "valid_time" in da.dims:
        return da
    if "time" in da.dims:
        return da.rename({"time": "valid_time"})
    raise ValueError(
        "No recognizable time/valid_time dimension in array."
        f" Found dims: {tuple(da.dims)}"
    )


In [None]:
# -------- loader widgets --------
# Two text boxes pre-filled with the default store paths, a "Load" button and
# an Output area that receives the load report / error message.
fc_path = widgets.Text(
    value=FORECAST_DEFAULT,
    description="Forecast (IMERG/IMD):",
    layout=widgets.Layout(width="95%")
)
era5_path = widgets.Text(
    value=ERA5_DAILY_DEFAULT,
    description="ERA5 daily:",
    layout=widgets.Layout(width="95%")
)
load_btn = widgets.Button(description="Load", button_style="primary")
load_out = widgets.Output()
display(fc_path, era5_path, load_btn, load_out)

# Module-level dataset handles; they stay None until the Load handler
# assigns them (it declares both names as global).
forecast_ds = None
era5_daily_ds = None

def _summ(ds):
    """One-line summary of ``ds``: its first six variable names and its sizes."""
    shown_names = list(ds.data_vars)[:6]
    return "vars: " + ", ".join(shown_names) + f" | sizes: {dict(ds.sizes)}"

# --- Patch: re-launch verification UI after datasets load ---

def _on_load(_):
    """Click handler for the Load button.

    Opens both stores named in the text boxes, stores them in the module-level
    ``forecast_ds`` / ``era5_daily_ds`` and prints a summary into ``load_out``.
    On any failure the error is shown there instead and nothing else happens;
    on success the verification UI is (re)launched afterwards.
    """
    global forecast_ds, era5_daily_ds

    loaded = False
    with load_out:
        load_out.clear_output()
        try:
            forecast_ds = _open_any(fc_path.value.strip())
            era5_daily_ds = _open_any(era5_path.value.strip())
            display(Markdown("‚úÖ **Loaded datasets**"))
            display(Markdown(f"- Forecast ‚Üí `{_summ(forecast_ds)}`"))
            display(Markdown(f"- ERA5 daily ‚Üí `{_summ(era5_daily_ds)}`"))
            loaded = True
        except Exception as err:
            display(Markdown(f"‚ùå Load error: `{err}`"))

    if loaded:
        # Launched outside the Output context, after a successful load only.
        verify_interactive()

# Register the load handler on the button.
# NOTE: Button.on_click ADDS a callback; it does not replace existing ones.
# There is no duplicate here only because `load_btn` was created fresh above.
load_btn.on_click(_on_load)

print("Patch applied. Click 'Load' again, then use 'Run verification'.")

# -------- verification UI (precip only) --------
# Selectors for region, metric and variable (the variable list holds just
# FORECAST_VAR).
region_dd = widgets.Dropdown(options=list(REGIONS.keys()), description="Region:", value="Global")
metric_dd = widgets.Dropdown(options=["RMSE", "MAE", "ACC"], description="Metric:", value="RMSE")
var_dd    = widgets.Dropdown(options=[FORECAST_VAR], description="Variable:", value=FORECAST_VAR)

# simple time pickers (free-text dates, parsed with pandas when a run starts)
t0 = widgets.Text(value="2023-01-01", description="Start (YYYY-MM-DD):")
t1 = widgets.Text(value="2023-01-31", description="End   (YYYY-MM-DD):")

# Run button plus the Output area that receives progress text and the plot.
run_btn = widgets.Button(description="Run verification", button_style="success")
out = widgets.Output()
display(region_dd, metric_dd, var_dd, t0, t1, run_btn, out)

def _nearest_regrid_truth_to_forecast(truth_da, forecast_da):
    """Put the truth field on the forecast's latitude/longitude points.

    Both inputs are first normalised (longitudes to 0..360, latitude
    ascending); the truth is then sampled at the forecast grid points with
    nearest-neighbour interpolation.
    """
    truth_norm = _ensure_lat_ascending(_to_0360(truth_da))
    forecast_norm = _ensure_lat_ascending(_to_0360(forecast_da))
    target_grid = {
        "latitude": forecast_norm["latitude"],
        "longitude": forecast_norm["longitude"],
    }
    return truth_norm.interp(method="nearest", **target_grid)

def _acc_from_stats(stats_ds, var):
    """Combine the three ACC statistics for ``var`` into the ACC ratio.

    Reads ``SquaredPredictionAnomaly``, ``SquaredTargetAnomaly`` and
    ``AnomalyCovariance`` from ``stats_ds``, squeezes each, and returns
    covariance / sqrt(prediction anomaly * target anomaly).
    """
    def _component(key):
        return stats_ds[key][var].squeeze()

    covariance = _component('AnomalyCovariance')
    normaliser = np.sqrt(
        _component('SquaredPredictionAnomaly') * _component('SquaredTargetAnomaly')
    )
    return covariance / normaliser

def verify_interactive():
    """Wire the "Run verification" button to the precipitation verification.

    Reads the module-level ``forecast_ds`` / ``era5_daily_ds`` and the widgets
    created above. If the datasets are not loaded yet a hint is displayed and
    nothing is wired. Otherwise a click handler is registered on ``run_btn``;
    a handler registered by an earlier call is removed first so that repeated
    loads do not make one click run the verification several times.

    On each run the handler:
      1. slices forecast and ERA5 to the chosen region and date window,
      2. normalises units and renames the time axis to ``valid_time``,
      3. regrids ERA5 onto the forecast grid (nearest neighbour),
      4. computes the WeatherBenchX statistics for the chosen metric,
      5. reduces them to one value per day over the region and plots it.
    """
    if forecast_ds is None or era5_daily_ds is None:
        display(Markdown("➡️ Load the datasets above first."))
        return

    def _area_mean(da):
        """Average ``da`` over whichever of latitude/longitude it still has."""
        return da.mean(dim=[d for d in ["latitude", "longitude"] if d in da.dims])

    def on_run(_):
        with out:
            out.clear_output()

            region = REGIONS[region_dd.value]
            var = var_dd.value
            metric = metric_dd.value
            start = np.datetime64(pd.to_datetime(t0.value).date())
            end   = np.datetime64(pd.to_datetime(t1.value).date())

            print(f"Selected region: {region_dd.value}")
            print(f"Selected metric: {metric}")
            print(f"Selected variable: {var}")
            print(f"Time window: {str(start)} ‚Üí {str(end)}")

            # --- slice to region and time; normalize lon/lat conventions ---
            # forecast: IMERG / IMD (already daily)
            if var not in forecast_ds:
                print(f"‚ùå '{var}' not found in forecast dataset vars: {list(forecast_ds.data_vars)[:6]} ‚Ä¶")
                return
            f = _apply_region_safe(forecast_ds[[var]], region)[var].sel(time=slice(start, end))
            if f.sizes.get("time", 0) == 0:
                print("‚ö†Ô∏è Region/time window has no points on the forecast grid."); return
            f = _normalize_precip_units(f)
            f = _coerce_valid_time(f)

            # ERA5 daily truth (same variable name)
            if var not in era5_daily_ds:
                print(f"‚ùå '{var}' not found in ERA5 dataset vars: {list(era5_daily_ds.data_vars)[:6]} ‚Ä¶")
                return
            print("Preparing ERA5 daily truth...")
            t = _apply_region_safe(era5_daily_ds[[var]], region)[var].sel(time=slice(start, end))
            if t.sizes.get("time", 0) == 0:
                print("‚ö†Ô∏è Region/time window has no points in ERA5."); return
            t = _normalize_precip_units(t)
            t = _coerce_valid_time(t)

            # regrid ERA5 -> forecast grid (guard if lat/lon differ)
            t_on_f = _nearest_regrid_truth_to_forecast(t, f)

            # intersect time just in case
            tmin = max(np.datetime64(f["valid_time"].min().values), np.datetime64(t_on_f["valid_time"].min().values))
            tmax = min(np.datetime64(f["valid_time"].max().values), np.datetime64(t_on_f["valid_time"].max().values))
            f_use = f.sel(valid_time=slice(tmin, tmax))
            t_use = t_on_f.sel(valid_time=slice(tmin, tmax))

            if f_use.sizes.get("valid_time", 0) == 0:
                print("‚ö†Ô∏è No overlapping times between forecast and ERA5."); return

            # WeatherBenchX expects datasets; give it valid_time dimension
            preds = xr.Dataset({var: f_use})
            targs = xr.Dataset({var: t_use})

            # Metrics (ACC needs a DOY climatology on the same grid)
            if metric == "ACC":
                print("Building daily day-of-year climatology for ACC‚Ä¶")
                # The climatology must come from the FULL ERA5 record, not from
                # the verification window: inside a short window every
                # day-of-year occurs once, so the "climatology" would equal the
                # truth, all target anomalies would be 0 and ACC would be NaN.
                t_full = _apply_region_safe(era5_daily_ds[[var]], region)[var]
                t_full = _coerce_valid_time(_normalize_precip_units(t_full))
                # Keep only the days-of-year present in the window to limit
                # how much of the record has to be read.
                wanted_doys = np.unique(t_use["valid_time"].dt.dayofyear.values)
                keep = np.flatnonzero(
                    np.isin(t_full["valid_time"].dt.dayofyear.values, wanted_doys)
                )
                t_full = t_full.isel(valid_time=keep)
                t_full_on_f = _nearest_regrid_truth_to_forecast(t_full, f_use)
                clim_doy = (
                    t_full_on_f.groupby("valid_time.dayofyear")
                               .mean("valid_time", keep_attrs=True)
                               .rename("clim")
                )
                # WBX requires a dataset where variable name matches:
                clim = xr.Dataset({var: clim_doy})
                metrics = {"acc": deterministic.ACC(climatology=clim)}
            elif metric == "RMSE":
                metrics = {"rmse": deterministic.RMSE()}
            else:  # MAE
                metrics = {"mae": deterministic.MAE()}

            print("Computing statistics‚Ä¶")
            stats = metrics_base.compute_unique_statistics_for_all_metrics(metrics, preds, targs)
            print("Done ‚úÖ")

            # reduce to a time series (area-mean over region)
            if metric == "RMSE":
                # Root of the MEAN squared error; taking the root first would
                # yield the mean absolute error instead.
                series = np.sqrt(_area_mean(stats["SquaredError"][var]))
                ylabel = "RMSE (mm/day)"
            elif metric == "MAE":
                series = _area_mean(stats["AbsoluteError"][var])
                ylabel = "MAE (mm/day)"
            else:
                # Aggregate the three components over the region BEFORE forming
                # the ratio; a per-grid-point ratio is just the sign (+/-1).
                pa  = _area_mean(stats["SquaredPredictionAnomaly"][var])
                ta  = _area_mean(stats["SquaredTargetAnomaly"][var])
                cov = _area_mean(stats["AnomalyCovariance"][var])
                series = cov / np.sqrt(pa * ta)
                ylabel = "ACC"

            # report number of valid days
            print(f"{int(series.sizes.get('valid_time', 0))} valid days in window.")

            # plot
            fig, ax = plt.subplots(figsize=(8,4))
            series.rename({"valid_time": "time"}).plot(ax=ax)
            ax.set_title(f"{metric} ‚Äî {('tp_24h' if var=='total_precipitation_24hr' else var)} ‚Äî {region_dd.value}")
            ax.set_ylabel(ylabel)
            ax.grid(True, alpha=0.3)
            plt.show()

    # Button.on_click appends handlers, so drop the one installed by a previous
    # launch before adding the new one.
    previous = getattr(verify_interactive, "_registered_handler", None)
    if previous is not None:
        run_btn.on_click(previous, remove=True)
    run_btn.on_click(on_run)
    verify_interactive._registered_handler = on_run

# launch interactive
verify_interactive()


Text(value='gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr', description='Forecast (IMERG/I…

Text(value='gs://aim4scale_training_25/ground_truth/era5_24hr.zarr', description='ERA5 daily:', layout=Layout(…

Button(button_style='primary', description='Load', style=ButtonStyle())

Output()

Patch applied. Click 'Load' again, then use 'Run verification'.


Dropdown(description='Region:', options=('Global', 'Ethiopia', 'Nigeria', 'Kenya', 'Bangladesh', 'Chile'), val…

Dropdown(description='Metric:', options=('RMSE', 'MAE', 'ACC'), value='RMSE')

Dropdown(description='Variable:', options=('total_precipitation_24hr',), value='total_precipitation_24hr')

Text(value='2023-01-01', description='Start (YYYY-MM-DD):')

Text(value='2023-01-31', description='End   (YYYY-MM-DD):')

Button(button_style='success', description='Run verification', style=ButtonStyle())

Output()

➡️ Load the datasets above first.

In [None]:
# Loader UI for the three-dataset variant: one forecast store plus separate
# ERA5 temperature and ERA5 precipitation stores.
# NOTE: this rebinds `fc_path`, `load_btn`, `load_out` and `forecast_ds`,
# which an earlier cell also defines.
fc_path = widgets.Text(
    value=FORECAST_DEFAULT,
    description="Forecast Zarr:",
    layout=widgets.Layout(width="95%")
)
# NOTE(review): ERA5_TEMP_DAILY_DEFAULT and ERA5_PRECIP_DAILY_DEFAULT are not
# defined in the cells shown here — confirm they are set before this cell runs.
era5_t_path = widgets.Text(
    value=ERA5_TEMP_DAILY_DEFAULT,
    description="ERA5 temp:",
    layout=widgets.Layout(width="95%")
)
era5_p_path = widgets.Text(
    value=ERA5_PRECIP_DAILY_DEFAULT,
    description="ERA5 precip:",
    layout=widgets.Layout(width="95%")
)
load_btn = widgets.Button(description="Load datasets", button_style="primary")
load_out = widgets.Output()

display(fc_path, era5_t_path, era5_p_path, load_btn, load_out)

# Dataset handles; None until the load handler assigns them.
forecast_ds = None
era5_temp_ds = None
era5_precip_ds = None

def _open_any(path: str):
    """Open ``path`` with ``xr.open_zarr`` for Zarr stores, else ``xr.open_dataset``.

    A store is recognised by a ``.zarr`` suffix; a trailing slash (as in
    ``gs://bucket/store.zarr/``) is ignored for that check, while the path
    handed to xarray is left unchanged.
    """
    is_zarr_store = path.rstrip("/").endswith(".zarr")
    return xr.open_zarr(path) if is_zarr_store else xr.open_dataset(path)

def _summ(ds):
    """One-line summary of ``ds``: first six variable names and dimension lengths.

    Dimension lengths are read from ``ds.sizes``. ``Dataset.dims`` is being
    changed by xarray to return only names, which would break ``dict(ds.dims)``.
    """
    vs = ", ".join(list(ds.data_vars)[:6])
    return f"vars: {vs} ... | dims: {dict(ds.sizes)}"

def _load_all(_):
    """Click handler: open the three stores and report the result in ``load_out``.

    The datasets are stored in the module-level ``forecast_ds``,
    ``era5_temp_ds`` and ``era5_precip_ds``. Any exception raised while
    opening or summarising is shown as a load error instead of propagating.
    """
    global forecast_ds, era5_temp_ds, era5_precip_ds

    with load_out:
        load_out.clear_output()
        try:
            forecast_ds = _open_any(fc_path.value.strip())
            era5_temp_ds = _open_any(era5_t_path.value.strip())
            era5_precip_ds = _open_any(era5_p_path.value.strip())

            report_lines = [
                "‚úÖ Loaded:",
                f"- Forecast ‚Üí `{_summ(forecast_ds)}`",
                f"- ERA5 temp daily ‚Üí `{_summ(era5_temp_ds)}`",
                f"- ERA5 precip daily ‚Üí `{_summ(era5_precip_ds)}`",
            ]
            display(Markdown("\n".join(report_lines)))
        except Exception as err:
            display(Markdown(f"‚ùå Load error: `{err}`"))

load_btn.on_click(_load_all)


In [None]:
# --- VERIFY INTERACTIVE (daily ERA5 <-> your daily forecast) ---
import numpy as np
import pandas as pd
import xarray as xr
import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display, Markdown
from weatherbenchX.metrics import deterministic
from weatherbenchX.metrics import base as metrics_base

# Regions (same as before)
# Each entry maps a region name to latitude/longitude slices. Latitude bounds
# are written north -> south and longitudes use the 0..360 convention.
DOMAIN_DEFINITIONS = {
    "Global": {"latitude": slice(90, -90), "longitude": slice(0, 360)},
    "Ethiopia": {"latitude": slice(14.9, 3.4), "longitude": slice(33.0, 48.0)},
    "Nigeria": {"latitude": slice(14.7, 4.0), "longitude": slice(2.7, 14.7)},
    "Kenya": {"latitude": slice(5.0, -4.7), "longitude": slice(33.9, 41.9)},
    "Bangladesh": {"latitude": slice(26.7, 20.7), "longitude": slice(88.0, 92.7)},
    "Chile": {"latitude": slice(-17.5, -56.0), "longitude": slice(284.0, 294.0)},
}

# Daily variable names the verification UI will offer, provided the forecast
# dataset actually contains them.
DAILY_ALLOWED = ["tmin", "tmax", "tavg", "total_precipitation_24hr"]

# ---- helpers: coords/units/time ----
def _has_latlon(da):
    """Tell whether ``da`` carries both ``latitude`` and ``longitude`` coordinates."""
    return all(name in da.coords for name in ("latitude", "longitude"))

def _to0360(obj):
    """Wrap negative longitudes into 0..360 and re-sort; otherwise return ``obj`` as-is.

    Unlike a plain sort, nothing is touched when there is no ``longitude``
    coordinate or when its minimum is not negative.
    """
    if "longitude" not in obj.coords:
        return obj

    lon_coord = obj["longitude"]
    if not float(lon_coord.min()) < 0:
        return obj
    wrapped = obj.assign_coords(longitude=(lon_coord % 360))
    return wrapped.sortby("longitude")

def _lat_ascending(obj):
    """Return ``obj`` sorted by ascending latitude when its axis is descending.

    Objects without a ``latitude`` coordinate, and latitude axes with fewer
    than two points (including an empty axis left over from a region slice),
    are returned unchanged instead of raising on ``lat[0]``.
    """
    if "latitude" in obj.coords:
        lat = obj["latitude"].values
        if len(lat) > 1 and lat[0] > lat[-1]:
            obj = obj.sortby("latitude")
    return obj

def _norm_grid(obj):
    """Normalise the grid: longitudes to 0..360 first, then latitude ascending."""
    wrapped = _to0360(obj)
    return _lat_ascending(wrapped)

def _slice_region_safe(da_or_ds, region):
    """Slice to ``region`` even if the data uses a different longitude convention.

    The data is normalised (0..360 longitudes, ascending latitude) and the
    region bounds are converted to 0..360 and ordered low -> high. Objects
    without latitude/longitude coordinates are returned unchanged.

    A bound of exactly 360 must not be reduced with ``% 360`` (it would become
    0 and collapse the "Global" 0..360 box to a single meridian). A box that
    spans the full circle therefore selects 0..360, and an upper bound on a
    multiple of 360 is kept as 360.
    """
    if not _has_latlon(da_or_ds):
        return da_or_ds
    ds = _norm_grid(da_or_ds)

    # 0..360 region bounds
    a = float(region["longitude"].start); b = float(region["longitude"].stop)
    west, east = min(a, b), max(a, b)
    if east - west >= 360.0:
        lo_lon, hi_lon = 0.0, 360.0
    else:
        lo_lon, hi_lon = west % 360, east % 360
        if hi_lon == 0.0 and east > west:
            hi_lon = 360.0
        lo_lon, hi_lon = (min(lo_lon, hi_lon), max(lo_lon, hi_lon))

    # latitude bounds (ascending)
    la = float(region["latitude"].start); lb = float(region["latitude"].stop)
    lo_lat, hi_lat = (min(la, lb), max(la, lb))
    return ds.sel(latitude=slice(lo_lat, hi_lat), longitude=slice(lo_lon, hi_lon))

def _coerce_time(da: xr.DataArray) -> xr.DataArray:
    """Return ``da`` with its time-like dimension named ``time``.

    The first dimension found among ``time``, ``date``, ``day`` and
    ``valid_time`` (in that order) is used. An array that already has a
    ``time`` dimension is returned as-is.

    Raises
    ------
    ValueError
        If none of those dimensions exists. (The former ``decode_cf`` fallback
        was removed: decoding does not add or rename dimensions, so it could
        never find a time axis the first scan had missed.)
    """
    for k in ["time", "date", "day", "valid_time"]:
        if k in da.dims:
            return da if k == "time" else da.rename({k: "time"})
    raise ValueError("Daily variable is missing a recognizable time dimension.")

def _normalize_units(da: xr.DataArray, name: str) -> xr.DataArray:
    """Return ``da`` in the notebook's working units for variable ``name``.

    * ``tmin`` / ``tmax`` / ``tavg``: Celsius inputs are shifted by 273.15;
      the result is labelled ``"K"``.
    * ``total_precipitation_24hr``: metre inputs are multiplied by 1000; the
      result is labelled ``"mm"``.
    * any other name: returned unchanged.

    Units are read from ``attrs["units"]`` (or ``attrs["unit"]``),
    case-insensitively. The input array and its ``attrs`` are never modified;
    other attributes are carried over to the result.

    NOTE(review): data whose unit string is not recognised is relabelled
    without conversion (as before) — confirm that this is intended.
    """
    units = (da.attrs.get("units") or da.attrs.get("unit") or "").strip().lower()

    if name in {"tmin", "tmax", "tavg"}:
        # "¬∞c" is kept for backward compatibility; "°c" is the real degree sign.
        if units in ["c", "degc", "celsius", "¬∞c", "°c"]:
            out = da + 273.15
        else:
            out = da.copy(deep=False)
        out.attrs = {**da.attrs, "units": "K"}
        return out

    if name == "total_precipitation_24hr":
        # Expect mm. If meters, convert.
        if units in ["m", "meter", "metre", "m of water equivalent"]:
            out = da * 1000.0
        else:
            out = da.copy(deep=False)
        out.attrs = {**da.attrs, "units": "mm"}
        return out

    return da

def _match_lon_system(src: xr.DataArray, tgt: xr.DataArray) -> xr.DataArray:
    """Convert ``src`` longitudes to the convention used by ``tgt`` and sort them.

    If either array lacks a ``longitude`` coordinate ``src`` is returned
    untouched. A signed source against an unsigned target is wrapped to
    0..360; an unsigned source against a signed target is shifted to
    -180..180. The result is always sorted by longitude.
    """
    if "longitude" not in src.coords or "longitude" not in tgt.coords:
        return src

    src_min = np.nanmin(src.longitude.values)
    tgt_min = np.nanmin(tgt.longitude.values)

    if src_min < 0 <= tgt_min:
        src = src.assign_coords(longitude=(src.longitude % 360))
    elif tgt_min < 0 <= src_min:
        src = src.assign_coords(longitude=((src.longitude + 180) % 360) - 180)
    return src.sortby("longitude")

def _align_truth_to_forecast_grid(truth_da: xr.DataArray, forecast_da: xr.DataArray) -> xr.DataArray:
    """Sample the truth field at the forecast grid points (nearest neighbour).

    Both arrays are normalised to 0..360 longitudes and ascending latitude,
    the truth is put on the forecast's longitude convention, and the result
    is interpolated onto the forecast latitude/longitude coordinates.
    """
    forecast_norm = _lat_ascending(_to0360(forecast_da))
    truth_norm = _lat_ascending(_to0360(truth_da))
    # ensure same lon system
    truth_norm = _match_lon_system(truth_norm, forecast_norm)
    # nearest-neighbour regrid (fast, avoids xesmf dependency)
    return truth_norm.interp(
        latitude=forecast_norm.latitude,
        longitude=forecast_norm.longitude,
        method="nearest",
    )

def _forecast_extent_region(f_ds) -> dict:
    """Build a region dict covering the forecast's own latitude/longitude extent.

    Latitude bounds come straight from ``f_ds``; longitude bounds are taken
    after converting the dataset to the 0..360 convention.
    """
    lat_values = np.array(f_ds.latitude.values)
    lon_values = np.array(_to0360(f_ds).longitude.values)

    lat_extent = slice(float(lat_values.min()), float(lat_values.max()))
    lon_extent = slice(float(lon_values.min()), float(lon_values.max()))
    return {"latitude": lat_extent, "longitude": lon_extent}

# ---- UI + engine ----
def verify_interactive(forecast_ds, era5_temp_ds, era5_precip_ds):
    """Display the verification widgets and wire the run button.

    Parameters
    ----------
    forecast_ds : dataset holding the daily forecast variables.
    era5_temp_ds : dataset used as truth for ``tmin`` / ``tmax`` / ``tavg``.
    era5_precip_ds : dataset used as truth for ``total_precipitation_24hr``.

    Only variables from ``DAILY_ALLOWED`` that exist in ``forecast_ds`` are
    offered. Each run verifies the latest N days of the forecast against the
    truth regridded onto the forecast grid and plots one value per day.

    Forecast and truth go through the same time-axis and unit normalisation,
    and region aggregation happens before the non-linear step of each metric
    (square root for RMSE, the ratio for ACC).
    """
    # only show vars that actually exist in forecast
    options = [v for v in DAILY_ALLOWED if v in forecast_ds.data_vars]
    if not options:
        display(Markdown("‚ùå No daily variables found in your forecast (expected any of "
                         "`tmin`, `tmax`, `tavg`, `total_precipitation_24hr`)."))
        return

    region_dd  = widgets.Dropdown(options=list(DOMAIN_DEFINITIONS.keys()),
                                  description="Region:", value="Global")
    metric_dd  = widgets.Dropdown(options=["RMSE","MAE","ACC"],
                                  description="Metric:", value="RMSE")
    var_dd     = widgets.Dropdown(options=options, description="Variable:", value=options[0])
    days_box   = widgets.IntSlider(description="Days (latest):", value=30, min=7, max=365, step=1)
    run_btn    = widgets.Button(description="Run verification", button_style="success")
    out        = widgets.Output()

    display(region_dd, metric_dd, var_dd, days_box, run_btn, out)

    def _truth_source(var):
        """Pick the ERA5 array that serves as truth for ``var``, with a ``time`` axis."""
        if var in ("tmin","tmax","tavg"):
            return _coerce_time(era5_temp_ds[var])
        return _coerce_time(era5_precip_ds["total_precipitation_24hr"])

    def _area_mean(da):
        """Average ``da`` over whichever of latitude/longitude it still has."""
        return da.mean(dim=[d for d in ["latitude","longitude"] if d in da.dims])

    def on_run(_):
        with out:
            out.clear_output()
            var    = var_dd.value
            metric = metric_dd.value
            region_name = region_dd.value
            print(f"Selected region: {region_name}")
            print(f"Selected metric: {metric}")
            print(f"Selected variable: {var}")
            print("Preparing ERA5 daily truth...")

            # Forecast slice (region + last N days)
            f_da = forecast_ds[var]
            if not _has_latlon(f_da):
                display(Markdown("‚ùå Forecast variable has no `latitude/longitude` coordinates.")); return
            f_da = _coerce_time(f_da)

            # Region logic
            if region_name == "Global":
                # Use forecast native extent to avoid huge loads
                region = _forecast_extent_region(forecast_ds)
            else:
                region = DOMAIN_DEFINITIONS[region_name]

            f_da = _slice_region_safe(f_da, region)
            if f_da.sizes.get("latitude",0) == 0 or f_da.sizes.get("longitude",0) == 0:
                display(Markdown("‚ö†Ô∏è Region has no points on the forecast grid.")); return

            # restrict to the latest N days present in the forecast
            tmax_f = pd.to_datetime(f_da.time.max().item())
            tmin_f = tmax_f - pd.Timedelta(days=int(days_box.value)-1)
            f_da = f_da.sel(time=slice(tmin_f, tmax_f))

            # Units
            f_da = _normalize_units(f_da, var)

            # ERA5 "truth" (already DAILY in your stores)
            t_da = _slice_region_safe(_truth_source(var), region).sel(time=slice(tmin_f, tmax_f))
            if t_da.sizes.get("latitude",0) == 0 or t_da.sizes.get("longitude",0) == 0:
                display(Markdown("‚ö†Ô∏è Region has no points in the ERA5 dataset.")); return
            # Same unit handling as the forecast, otherwise e.g. a Celsius
            # truth would be compared against a Kelvin forecast.
            t_da = _normalize_units(t_da, var)

            # grid normalisation & regrid ERA5 -> forecast grid
            f_da = _norm_grid(f_da)
            t_da = _norm_grid(t_da)
            t_da_on_f = _align_truth_to_forecast_grid(t_da, f_da)

            # Intersect times safely
            t0 = max(np.datetime64(f_da.time.min().values), np.datetime64(t_da_on_f.time.min().values))
            t1 = min(np.datetime64(f_da.time.max().values), np.datetime64(t_da_on_f.time.max().values))
            f_da = f_da.sel(time=slice(t0, t1))
            t_da_on_f = t_da_on_f.sel(time=slice(t0, t1))
            if f_da.sizes.get("time",0) == 0:
                display(Markdown("‚ö†Ô∏è No overlapping time between forecast and ERA5 in that window.")); return

            # Build predictions/targets for weatherbenchX
            # (ACC expects 'valid_time' or init/lead; we use valid_time)
            pred = xr.Dataset({var: f_da.rename({"time":"valid_time"})})
            targ = xr.Dataset({var: t_da_on_f.rename({"time":"valid_time"})})

            # Metrics
            if metric == "ACC":
                # Day-of-year climatology from ERA5 daily (same region, same grid), across ALL years
                # Slice ERA5 full time for this region, then regrid to forecast grid and compute DOY mean.
                t_full = _slice_region_safe(_truth_source(var), region)
                t_full = _normalize_units(_norm_grid(t_full), var)
                t_full_on_f = _align_truth_to_forecast_grid(t_full, f_da)
                clim_doy = t_full_on_f.groupby("time.dayofyear").mean("time")  # (dayofyear, lat, lon)
                clim = xr.Dataset({var: clim_doy})
                metrics = {"acc": deterministic.ACC(climatology=clim)}
            elif metric == "RMSE":
                metrics = {"rmse": deterministic.RMSE()}
            else:  # MAE
                metrics = {"mae": deterministic.MAE()}

            # Compute
            stats = metrics_base.compute_unique_statistics_for_all_metrics(metrics, pred, targ)

            # Plot
            if "rmse" in metrics:
                # Root of the area-MEAN squared error; sqrt before the mean
                # would produce the mean absolute error.
                series = np.sqrt(_area_mean(stats["SquaredError"][var]))
                ylabel = "RMSE"
            elif "mae" in metrics:
                series = _area_mean(stats["AbsoluteError"][var])
                ylabel = "MAE"
            else:  # ACC
                # Aggregate each component over the region first; the ratio of
                # un-aggregated components is only the sign of the anomalies.
                pa  = _area_mean(stats["SquaredPredictionAnomaly"][var])
                ta  = _area_mean(stats["SquaredTargetAnomaly"][var])
                cov = _area_mean(stats["AnomalyCovariance"][var])
                series = cov / np.sqrt(pa * ta)
                ylabel = "ACC"

            fig, ax = plt.subplots(figsize=(8,4))
            series.to_series().plot(ax=ax)
            ax.set_title(f"{metric} ‚Äî {var} ({region_name})  [{pd.to_datetime(str(t0)).date()} ‚Ä¶ {pd.to_datetime(str(t1)).date()}]")
            ax.set_ylabel(ylabel)
            ax.grid(True, alpha=0.3)
            plt.show()

    run_btn.on_click(on_run)


In [None]:
def load_clim():
    """Open the WeatherBench2 ERA5 hourly climatology and keep three variables.

    The store's long variable names are renamed to the short names ``2t``,
    ``z_500`` and ``tp``; only those are kept, and ``z_500`` is reduced to the
    500 level with the ``level`` coordinate dropped.
    """
    short_to_long = {
        "2t": "2m_temperature",
        "z_500": "geopotential",
        "tp": "total_precipitation_6hr",
    }
    store = "gs://weatherbench2/datasets/era5-hourly-climatology/1990-2019_6h_1440x721.zarr"

    clim = xr.open_zarr(store, consolidated=True)
    long_to_short = {long_name: short for short, long_name in short_to_long.items()}
    clim = clim.rename_vars(long_to_short)[list(short_to_long)]
    clim["z_500"] = clim["z_500"].sel(level=500).drop_vars("level")
    return clim


In [None]:
def _concat_stat_chunks(pieces):
    """Join the per-chunk results of one statistic into a single object.

    ``pieces`` holds one entry per chunk, each either an xarray object or a
    mapping of variable name -> DataArray. A single piece is returned as-is.
    Arrays are joined along ``init_time`` (the dimension the chunks were cut
    along) when they have it, otherwise along ``time`` as before.
    """
    if len(pieces) == 1:
        return pieces[0]

    def _join(arrays):
        dim = "init_time" if "init_time" in arrays[0].dims else "time"
        return xr.concat(arrays, dim=dim)

    first = pieces[0]
    if isinstance(first, (xr.Dataset, xr.DataArray)):
        return _join(pieces)
    return {name: _join([piece[name] for piece in pieces]) for name in first}


def compute_statistics_in_chunks(forecast_ds, era5_ds, metrics, fvars, chunk_size=48):
    """Chunked compute, but only for the selected variables (fvars).

    The forecast is processed ``chunk_size`` initialisation times at a time;
    the truth is cut with the SAME positional slice along its ``time`` axis.
    NOTE(review): this pairs forecast ``init_time`` index i with truth
    ``time`` index i — confirm the two axes are aligned upstream.

    Returns a dict keyed by statistic name. Chunk results are collected and
    concatenated once per statistic at the end, rather than re-concatenating
    the growing result after every chunk.
    """
    n = forecast_ds.sizes["init_time"]
    pieces_by_stat = {}
    for i in range(0, n, chunk_size):
        sl = slice(i, min(i + chunk_size, n))
        # subset to the exact variables on BOTH sides
        f_chunk = forecast_ds[fvars].isel(init_time=sl)
        t_chunk = era5_ds[fvars].isel(time=sl)
        stats = metrics_base.compute_unique_statistics_for_all_metrics(metrics, f_chunk, t_chunk)
        for k, v in stats.items():
            pieces_by_stat.setdefault(k, []).append(v)
    return {k: _concat_stat_chunks(pieces) for k, pieces in pieces_by_stat.items()}

def run_verification(forecast_ds, era5_ds, metric_name, fvars):
    """Compute the statistics for one metric on the full grid (no region slicing).

    ``"RMSE"`` selects the RMSE metric; every other ``metric_name`` is treated
    as ACC and needs a climatology.
    NOTE(review): ``load_climatology_for_acc`` is not defined in the cells
    shown here — confirm it exists before the ACC path is used.
    """
    if metric_name != "RMSE":
        climatology = load_climatology_for_acc(fvars)
        selected = {"ACC": deterministic.ACC(climatology=climatology)}
    else:
        selected = {"rmse": deterministic.RMSE()}
    return compute_statistics_in_chunks(
        forecast_ds, era5_ds, selected, fvars=fvars, chunk_size=48
    )


In [None]:
# --- Cell 7: WB-X RMSE/ACC with safe reduction & plotting ---

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown

def _lat_slice_for(da: xr.DataArray, box):
    """Return the latitude slice of ``box`` ordered to match ``da``'s latitude axis.

    An ascending axis gets ``slice(low, high)``; otherwise ``slice(high, low)``.
    """
    lat = da["latitude"]
    low, high = sorted((box["latitude"].start, box["latitude"].stop))
    if lat[0] < lat[-1]:
        return slice(low, high)
    return slice(high, low)

def get_lead_coord(da: xr.DataArray):
    """Return the name of ``da``'s lead dimension: ``lead_time``, else ``step``.

    Raises ``ValueError`` when neither is a dimension of ``da``.
    """
    for candidate in ("lead_time", "step"):
        if candidate in da.dims:
            return candidate
    raise ValueError(f"No lead dimension in {da.dims}")

def to_1d_over_lead(da: xr.DataArray) -> xr.DataArray:
    """Reduce ``da`` to a 1-D array over its lead dimension.

    ``time`` / ``init_time`` and then ``latitude`` / ``longitude`` are averaged
    away one after another (NaNs skipped). Any remaining length-1 dimensions
    other than the lead dimension are squeezed out.

    The lead dimension itself is never squeezed, so a single-lead input stays
    1-D (a bare ``squeeze()`` used to turn it into a scalar).

    Raises
    ------
    ValueError
        If there is no lead dimension, or if a dimension other than lead with
        more than one element is left after averaging.
    """
    # average away time/init_time & any lat/lon left
    for d in ("time", "init_time", "latitude", "longitude"):
        if d in da.dims:
            da = da.mean(dim=d, keep_attrs=True, skipna=True)

    lead = get_lead_coord(da)
    leftovers = [d for d in da.dims if d != lead]
    too_big = [d for d in leftovers if da.sizes[d] != 1]
    if too_big:
        raise ValueError(
            f"Cannot reduce to 1-D over '{lead}': extra dimensions {too_big} remain"
        )
    return da.squeeze(leftovers) if leftovers else da

def lead_days_from(da_1d: xr.DataArray) -> np.ndarray:
    """Return the lead coordinate of ``da_1d`` as a flat float array in days.

    ``timedelta64`` leads are divided by one day; any other dtype is assumed
    to be a number of hours and divided by 24.
    """
    lead_values = np.asarray(da_1d[get_lead_coord(da_1d)].values)
    if np.issubdtype(lead_values.dtype, np.timedelta64):
        days = (lead_values / np.timedelta64(1, "D")).astype(float)
    else:
        # assume hours
        days = lead_values.astype(float) / 24.0
    return days.reshape(-1)

# Verification UI for lead-time curves: pick region / metric / variable, run
# the full-grid statistics, aggregate over the region, plot against lead time.
# NOTE(review): `era5_ds` and `vars_for_this_run` are not defined in the cells
# shown here — confirm they are set by earlier cells.
if forecast_ds is None or era5_ds is None:
    display(Markdown("‚ÑπÔ∏è Load forecast (Cell 4) and ERA5 truth (Cell 5) first."))
else:
    # Variables offered: those requested for this run that exist on both
    # sides, falling back to the standard short names.
    available = [v for v in vars_for_this_run if v in forecast_ds.data_vars and v in era5_ds.data_vars]
    if not available:
        available = [v for v in ["2t","z_500","tp"] if v in forecast_ds.data_vars and v in era5_ds.data_vars]

    if not available:
        # Without this guard `available[0]` below raised IndexError.
        display(Markdown("No variables are shared by the forecast and ERA5 datasets; nothing to verify."))
    else:
        region_selector = widgets.Dropdown(options=list(DOMAIN_DEFINITIONS.keys()),
                                           value="Global", description="Region:")
        metric_selector  = widgets.Dropdown(options=["RMSE","ACC"],
                                            value="ACC", description="Metric:")
        variable_selector = widgets.Dropdown(options=available, value=available[0], description="Variable:")
        run_btn = widgets.Button(description="Run Verification", button_style="success")
        out = widgets.Output()
        display(region_selector, metric_selector, variable_selector, run_btn, out)

        def on_run(_):
            """Run one verification and draw the metric as a function of lead time."""
            out.clear_output(wait=True)
            with out:
                try:
                    region_name = region_selector.value
                    box         = DOMAIN_DEFINITIONS[region_name]
                    metric      = metric_selector.value
                    var         = variable_selector.value

                    # 1) WB-X metrics on full grid (no pre-slicing)
                    stats = run_verification(forecast_ds, era5_ds, metric, [var])

                    # 2) Region aggregation AFTER metric computation
                    if metric == "RMSE":
                        se = stats["SquaredError"][var]  # (time|init_time, lead, lat, lon)
                        se_reg = se.sel(
                            latitude=_lat_slice_for(se, box),
                            longitude=box["longitude"]
                        ).mean(dim=["latitude","longitude"], skipna=True)
                        rmse = np.sqrt(se_reg)                 # (time|init_time, lead)
                        line = to_1d_over_lead(rmse)           # -> 1D over lead
                        ylab, title = "RMSE", f"RMSE ‚Äî {var} ‚Äî {region_name}"
                    else:
                        pa = stats["SquaredPredictionAnomaly"][var]
                        ta = stats["SquaredTargetAnomaly"][var]
                        co = stats["AnomalyCovariance"][var]
                        lat_sl = _lat_slice_for(pa, box)
                        pa = pa.sel(latitude=lat_sl, longitude=box["longitude"]).mean(dim=["latitude","longitude"], skipna=True)
                        ta = ta.sel(latitude=lat_sl, longitude=box["longitude"]).mean(dim=["latitude","longitude"], skipna=True)
                        co = co.sel(latitude=lat_sl, longitude=box["longitude"]).mean(dim=["latitude","longitude"], skipna=True)
                        acc  = co / np.sqrt(pa * ta)           # (time|init_time, lead)
                        line = to_1d_over_lead(acc)            # -> 1D over lead
                        ylab, title = "ACC",  f"ACC ‚Äî {var} ‚Äî {region_name}"

                    # 3) Build x, y and mask non-finite values
                    x = lead_days_from(line)                              # (n_leads,)
                    y = np.asarray(line.values, dtype=float).reshape(-1)  # (n_leads,)
                    m = np.isfinite(x) & np.isfinite(y)
                    x_plot, y_plot = x[m], y[m]

                    fig, ax = plt.subplots(figsize=(8,4))
                    if x_plot.size > 0:
                        ax.plot(x_plot, y_plot)
                    else:
                        ax.text(0.5, 0.5, "No finite values to plot", ha="center", va="center", transform=ax.transAxes)
                    ax.set_title(title)
                    ax.set_xlabel("Forecast time (days)")
                    ax.set_ylabel(ylab)
                    ax.grid(True, alpha=0.3)
                    plt.show()

                    # Optional quick debug readout. The range is taken from the
                    # plotted (finite) values so an empty or all-NaN result
                    # prints "n/a" instead of raising after the plot is shown.
                    if y_plot.size > 0:
                        y_range = f"{y_plot.min():.3g}..{y_plot.max():.3g}"
                    else:
                        y_range = "n/a"
                    print(f"lead len: {x.size}, finite points: {x_plot.size}, y range: {y_range}")

                    display(Markdown("‚úÖ Done"))
                except Exception as e:
                    display(Markdown(f"‚ùå Error: `{e}`"))

        run_btn.on_click(on_run)
