
# 🧪 Demo 5: Ground Truth Challenge — Validate Your Forecast with Your Data (Demo 4–style)

This notebook mirrors **Demo 4** (widgets, regions, WeatherBench-X) and adds the ability to plug in **your own ground truth**:
- **Default**: ERA5 from Google Cloud Storage (no downloads)
- **Optional**: Point **CSV** (`time,lat,lon,value`) with time tolerance
- **Optional**: **Gridded** NetCDF/Zarr file (local path or `gs://…`)

You can compare **Global vs Local** regions and compute **RMSE** (and **ACC** if GT supports climatology).

> Tip: start with ERA5 as truth; once it runs, try your CSV/NetCDF.


In [2]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown, clear_output
import fsspec, gcsfs, os, warnings
warnings.filterwarnings("ignore")

# WeatherBench-X
from weatherbenchX.metrics import deterministic
from weatherbenchX.metrics import base as metrics_base
from weatherbenchX import aggregation
from weatherbenchX import time_chunks

# Zarr store with the ERA5 reanalysis used as the default ground truth (read from GCS).
ERA5_PATH = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
# Zarr store with the ERA5 climatology that is handed to the ACC metric.
CLIM_PATH = "gs://weatherbench2/datasets/era5-hourly-climatology/1990-2019_6h_1440x721.zarr"
# Gravitational acceleration; geopotential is divided by this when building z_500.
G = 9.80665  # m s^-2

# Regions (same spirit as Demo-4)
# Each entry maps a region name to the latitude/longitude slices that bound it.
# Latitude bounds are written from the larger to the smaller value, and the
# longitudes use the 0-360 convention (e.g. Chile spans 284-294).
DOMAIN_DEFINITIONS = {
    "Global": {"latitude": slice(90, -90), "longitude": slice(0, 360)},
    "Northern Hemisphere": {"latitude": slice(90, 0), "longitude": slice(0, 360)},
    "Tropics": {"latitude": slice(23.5, -23.5), "longitude": slice(0, 360)},
    "Bangladesh": {"latitude": slice(26.7, 20.7), "longitude": slice(88.0, 92.7)},
    "Chile": {"latitude": slice(-17.5, -56.0), "longitude": slice(284.0, 294.0)},
    "Nigeria": {"latitude": slice(14.7, 4.0), "longitude": slice(2.7, 14.7)},
    "Ethiopia": {"latitude": slice(14.9, 3.4), "longitude": slice(33.0, 48.0)},
    "Kenya": {"latitude": slice(5.0, -4.7), "longitude": slice(33.9, 41.9)},
}


In [3]:
def ensure_lon360(ds: xr.Dataset) -> xr.Dataset:
    """Put longitudes on the 0-360 convention and sort the grid.

    The longitude coordinate is taken modulo 360 when it contains negative
    values or when its maximum does not exceed 180. The dataset is then
    sorted by ``latitude`` and ``longitude`` (ascending), whether or not the
    longitudes were rewritten.
    """
    lon = ds.longitude
    wrap_needed = float(lon.min()) < 0 or float(lon.max()) <= 180
    if wrap_needed:
        ds = ds.assign_coords(longitude=lon % 360)
    return ds.sortby(["latitude", "longitude"])

def ensure_ecmwf_names(ds: xr.Dataset) -> xr.Dataset:
    """Rename ``lat``/``lon`` coordinates to ``latitude``/``longitude``.

    Only the short names actually present among the coordinates are renamed.
    The result is then passed through ``ensure_lon360``.
    """
    aliases = {"lat": "latitude", "lon": "longitude"}
    renames = {short: full for short, full in aliases.items() if short in ds.coords}
    if renames:
        ds = ds.rename(renames)
    return ensure_lon360(ds)

def ensure_lead_time(ds: xr.Dataset) -> xr.Dataset:
    """Normalize the forecast to ``init_time`` / ``lead_time`` naming.

    * A ``time`` coordinate is renamed to ``init_time`` when no ``init_time``
      coordinate exists yet.
    * If ``lead_time`` is not already a dimension, it is derived from the
      ``step`` dimension: numeric steps are interpreted as hours, anything
      else is cast to ``timedelta64[ns]``. ``step`` is then swapped out for
      ``lead_time`` as the dimension.

    Raises:
        ValueError: if the dataset has neither a ``lead_time`` nor a ``step``
            dimension.
    """
    if "init_time" not in ds.coords and "time" in ds.coords:
        ds = ds.rename({"time": "init_time"})

    # Nothing to derive when the dimension is already there.
    if "lead_time" in ds.dims:
        return ds

    if "step" not in ds.dims:
        raise ValueError("Dataset has neither 'lead_time' nor 'step' dimension.")

    step = ds["step"]
    if np.issubdtype(step.dtype, np.number):
        # Plain numbers are taken to be hours.
        lead_values = pd.to_timedelta(step.values, unit="h").astype("timedelta64[ns]")
    else:
        # Already timedelta-like; just normalize the resolution.
        lead_values = step.values.astype("timedelta64[ns]")

    # Raw values (not a DataArray) are attached along the existing 'step' dim.
    with_lead = ds.assign_coords(lead_time=("step", lead_values))
    return with_lead.swap_dims({"step": "lead_time"})

def open_era5_subset(needed_vars):
    """Open the ERA5 store at ``ERA5_PATH`` and keep only the requested variables.

    Args:
        needed_vars: ERA5 variable names to look for in the store.

    Returns:
        The store restricted to those requested names that exist in it.

    Raises:
        KeyError: if none of the requested names exist in the store.
    """
    store = xr.open_zarr(ERA5_PATH, chunks=None, storage_options={"token":"anon"})
    present = [name for name in needed_vars if name in store]
    if present:
        return store[present]
    raise KeyError(f"None of {needed_vars} exist in ERA5 store.")

def build_era5_truth(forecast_ds: xr.Dataset, fvars):
    """Build an ERA5 truth dataset, named like the forecast, over its time span.

    Forecast variable names (``2t``, ``tp``, ``z_500``) are mapped to their
    ERA5 names, the matching ERA5 variables are opened, and they are stored
    back under the forecast names. ``z_500`` is the 500 level of geopotential
    divided by ``G``. The result is cut to the window from the earliest
    ``init_time`` to that time plus the largest ``lead_time``.

    NOTE(review): ``tp`` is copied as-is here, whereas ``open_era5_subset_min``
    applies a 6-step rolling sum to it -- confirm which one callers expect.
    """
    aifs_to_era5 = {"2t":"2m_temperature", "tp":"total_precipitation", "z_500":"geopotential"}
    source = open_era5_subset([aifs_to_era5.get(name, name) for name in fvars])

    # Window covered by the forecast: first init time .. first init + longest lead.
    start = forecast_ds["init_time"].min().values
    horizon = forecast_ds["lead_time"].max().astype("timedelta64[ns]")
    stop = (start + horizon).astype("datetime64[ns]")

    truth = xr.Dataset()
    if "2m_temperature" in source:
        truth["2t"] = source["2m_temperature"]
    if "total_precipitation" in source:
        truth["tp"] = source["total_precipitation"]
    if "geopotential" in source:
        z500 = source["geopotential"].sel(level=500).drop_vars("level")
        truth["z_500"] = z500 / G

    return truth.sel(time=slice(start, stop))

def load_climatology_for_acc(fvars):
    """Open the climatology at ``CLIM_PATH`` under forecast variable names.

    Climatology variables are renamed to the forecast names (``2t``,
    ``z_500``, ``tp``) where present, the result is restricted to the entries
    of ``fvars`` that exist, and ``z_500`` is reduced to level 500 when it
    still carries a ``level`` dimension.
    """
    raw = xr.open_zarr(CLIM_PATH, storage_options={"token":"anon"})

    aliases = {
        "2m_temperature": "2t",
        "geopotential": "z_500",
        "total_precipitation_6hr": "tp",
    }
    renamed = raw.rename_vars({old: new for old, new in aliases.items() if old in raw})

    clim = renamed[[name for name in fvars if name in renamed]]
    needs_level_pick = "z_500" in clim and "level" in clim["z_500"].dims
    if needs_level_pick:
        clim["z_500"] = clim["z_500"].sel(level=500).drop_vars("level")
    return clim


In [None]:
# Forecast-loading cell: declares the variables of interest and shows a text
# box, a button and an output area for loading the Demo-2 forecast file.
forecast_vars = ['2t', 'z_500', 'tp']

# Mapping from forecast variable names -> ERA5 names
var_map = {
    "2t": "2m_temperature",
    "tp": "total_precipitation",
    "z_500": "geopotential",
}
forecast_path = widgets.Text(
    description='Forecast file:',
    value="init_ERA5_20230630T00_lead_360.nc",  # <-- change if needed
    layout=widgets.Layout(width='80%')
)

load_button = widgets.Button(description="Load Forecast", button_style="primary")
# Messages and the dataset preview from the load callback are shown here.
load_out = widgets.Output()

display(forecast_path, load_button, load_out)

# Filled in by the load callback; stays None until a forecast is loaded.
forecast_ds = None

def on_load(_):
    """Button callback: open and normalize the forecast named in the text box.

    On success the normalized dataset is stored in the global ``forecast_ds``
    and shown in ``load_out``; any exception is reported there instead.
    """
    global forecast_ds
    load_out.clear_output()
    with load_out:
        try:
            opened = xr.open_dataset(forecast_path.value)
            # The global is only replaced once both normalization steps succeed.
            forecast_ds = ensure_lead_time(ensure_ecmwf_names(opened))
            display(Markdown("‚úÖ Forecast loaded & normalized."))
            display(forecast_ds)
        except Exception as err:
            display(Markdown(f"‚ùå Error loading forecast: `{err}`"))

# Wire the callback to the button.
load_button.on_click(on_load)


Text(value='init_ERA5_20230630T00_lead_360.nc', description='Forecast file:', layout=Layout(width='80%'))

Button(button_style='primary', description='Load Forecast', style=ButtonStyle())

Output()

In [6]:
# --- Cell 5 (REPLACE ME) ---
# Build ERA5 truth aligned to forecast, but only for the variables you will verify.
# This keeps memory low and avoids kernel crashes.

# Status messages and the ERA5 dataset preview are shown in this output area.
era5_out = widgets.Output()
display(era5_out)

# Choose which variables to pull NOW (start with just ["2t"]; add "z_500" or "tp" later)
vars_for_this_run = ["2t"]   # <<< IMPORTANT: keep only "2t" first; then try ["z_500"], then ["tp"]

# Map AIFS (forecast) variable names -> ERA5 variable names
VAR_MAP = {"2t": "2m_temperature", "z_500": "geopotential", "tp": "total_precipitation"}

def open_era5_subset_min(forecast_ds: xr.Dataset, wanted: list[str]) -> xr.Dataset:
    """Open only the requested ERA5 variables, cut to the forecast's time span.

    Args:
        forecast_ds: forecast providing ``init_time`` and ``lead_time``.
        wanted: forecast-style variable names (``2t``, ``z_500``, ``tp``).

    Returns:
        A dataset keyed by the forecast-style names. ``z_500`` is the 500
        level of geopotential divided by ``G``; ``tp`` is a 6-step rolling
        sum of ``total_precipitation`` (complete windows only).

    Raises:
        KeyError: if none of the requested variables map to an ERA5 variable
            present in the store.
    """
    # Time window from the forecast coordinates: first init .. first init + longest lead.
    start = forecast_ds["init_time"].min().values
    horizon = forecast_ds["lead_time"].max().astype("timedelta64[ns]")
    stop = (start + horizon).astype("datetime64[ns]")

    store = xr.open_zarr(ERA5_PATH, chunks={"time": 24}, storage_options={"token": "anon"})
    era5_names = [VAR_MAP.get(name, name) for name in wanted]
    keep = [name for name in era5_names if name in store]
    if not keep:
        raise KeyError(f"No matching ERA5 variables for {wanted}")

    # Restrict variables and time before any derived quantity is built.
    window = store[keep].sel(time=slice(start, stop))

    truth = xr.Dataset()
    if "2m_temperature" in window and "2t" in wanted:
        truth["2t"] = window["2m_temperature"]

    if "geopotential" in window and "z_500" in wanted:
        at_500 = window["geopotential"].sel(level=500).drop_vars("level")
        truth["z_500"] = at_500 / G

    if "total_precipitation" in window and "tp" in wanted:
        # Rolling runs on the already-sliced series; windows with fewer than
        # six values (the first five steps) come out as NaN.
        hourly = window["total_precipitation"]
        truth["tp"] = hourly.rolling(time=6, min_periods=6).sum()

    return truth

# ERA5 truth for this run; stays None when loading fails or no forecast is loaded.
era5_ds = None
with era5_out:
    try:
        if forecast_ds is None:
            raise RuntimeError("Load the forecast first (previous cell).")

        # Safety: check the forecast actually contains the variable(s) to verify later.
        missing = [v for v in vars_for_this_run if v not in forecast_ds.data_vars]
        if missing:
            display(Markdown(f"‚ö†Ô∏è Forecast does not have {missing}. I will pull only the overlap."))

        # Honour the message above: request only variables the forecast really has,
        # so nothing is opened from ERA5 that cannot be verified afterwards.
        overlap = [v for v in vars_for_this_run if v in forecast_ds.data_vars]
        if not overlap:
            raise KeyError(f"None of {vars_for_this_run} are present in the forecast.")

        era5_ds = open_era5_subset_min(forecast_ds, overlap)
        display(Markdown(f"‚úÖ ERA5 truth opened for {list(era5_ds.data_vars)} and time-aligned."))
        display(era5_ds)
    except Exception as e:
        # Notebook boundary: report the problem in the output area instead of a traceback.
        display(Markdown(f"‚ùå ERA5 load error: `{e}`"))


Output()

In [13]:
def compute_statistics_in_chunks(forecast_ds, era5_ds, metrics, fvars, chunk_size=48):
    """Compute the metrics' statistics chunk by chunk along ``init_time``.

    Only the variables in ``fvars`` are used on both sides. For every chunk of
    forecast initialisations the truth is aligned to the forecast *valid
    times* (``init_time + lead_time``), so each lead time is compared with
    the truth at the moment it is valid for.

    Args:
        forecast_ds: forecast with ``init_time`` and ``lead_time`` dimensions.
        era5_ds: truth. When it has a ``time`` dimension it is selected at the
            forecast valid times; when it is already laid out along
            ``init_time`` it is sliced with the same chunk; otherwise it is
            passed through unchanged.
        metrics: mapping of metric name to metric object, handed to
            ``metrics_base.compute_unique_statistics_for_all_metrics``.
        fvars: variable names to evaluate.
        chunk_size: number of initialisations per chunk.

    Returns:
        Mapping of statistic name to values. A single chunk's result is
        returned as computed; several chunks are concatenated along
        ``init_time``.

    Raises:
        KeyError: if a forecast valid time is missing from the truth's
            ``time`` coordinate.
    """
    from collections.abc import Mapping

    forecast = forecast_ds[fvars]
    truth = era5_ds[fvars]
    n = forecast_ds.sizes["init_time"]

    parts = []
    for start in range(0, n, chunk_size):
        window = slice(start, min(start + chunk_size, n))
        f_chunk = forecast.isel(init_time=window)

        if "time" in truth.dims:
            # Label-based, vectorized selection: the 2-D (init_time, lead_time)
            # indexer replaces the truth's 'time' axis with the forecast layout.
            valid_time = f_chunk["init_time"] + f_chunk["lead_time"]
            t_chunk = truth.sel(time=valid_time)
        elif "init_time" in truth.dims:
            t_chunk = truth.isel(init_time=window)
        else:
            t_chunk = truth

        parts.append(
            metrics_base.compute_unique_statistics_for_all_metrics(metrics, f_chunk, t_chunk)
        )

    # Group the per-chunk results by statistic name.
    grouped = {}
    for stats in parts:
        for stat_name, value in stats.items():
            grouped.setdefault(stat_name, []).append(value)

    final = {}
    for stat_name, values in grouped.items():
        if len(values) == 1:
            final[stat_name] = values[0]
        elif isinstance(values[0], Mapping):
            # Per-variable mappings are joined variable by variable.
            final[stat_name] = {
                var: xr.concat([v[var] for v in values], dim="init_time")
                for var in values[0]
            }
        else:
            final[stat_name] = xr.concat(values, dim="init_time")
    return final

def run_verification(forecast_ds, era5_ds, metric_name, fvars):
    """Compute the statistics for one metric on the full grid.

    ``"RMSE"`` selects the RMSE metric; any other ``metric_name`` selects ACC
    with the climatology loaded for ``fvars``. No region slicing happens here.
    """
    if metric_name != "RMSE":
        climatology = load_climatology_for_acc(fvars)
        selected = {"ACC": deterministic.ACC(climatology=climatology)}
    else:
        selected = {"rmse": deterministic.RMSE()}
    return compute_statistics_in_chunks(
        forecast_ds, era5_ds, selected, fvars=fvars, chunk_size=48
    )


In [None]:
# Verification UI: pick a region, metric and variable, then plot the score
# against forecast lead time.
if forecast_ds is None or era5_ds is None:
    display(Markdown("‚ÑπÔ∏è Load forecast (Cell 4) and ERA5 truth (Cell 5) first."))
else:
    region_selector = widgets.Dropdown(options=list(DOMAIN_DEFINITIONS.keys()),
                                       value="Global", description="Region:")
    metric_selector = widgets.Dropdown(options=["RMSE","ACC"],
                                       value="RMSE", description="Metric:")
    available = [v for v in vars_for_this_run if v in forecast_ds.data_vars and v in era5_ds.data_vars]
    if not available:
        available = [v for v in forecast_ds.data_vars if v in ["2t","z_500","tp"] and v in era5_ds.data_vars]
    # An empty list must not be indexed; the run callback reports the problem instead.
    variable_selector = widgets.Dropdown(options=available,
                                         value=available[0] if available else None,
                                         description="Variable:")
    run_btn = widgets.Button(description="Run Verification", button_style="success")
    out = widgets.Output()
    display(region_selector, metric_selector, variable_selector, run_btn, out)

    def _region_lat_slice(da, box):
        """Return the region's latitude slice ordered like ``da``'s latitude axis.

        DOMAIN_DEFINITIONS lists latitude bounds from the larger to the smaller
        value, while ensure_lon360 sorts latitude ascending, so the bounds are
        oriented to whatever order the data actually uses.
        """
        south, north = sorted((box["latitude"].start, box["latitude"].stop))
        lat = da["latitude"].values
        ascending = lat.size < 2 or lat[0] <= lat[-1]
        return slice(south, north) if ascending else slice(north, south)

    def _regional_mean(da, box, region_name):
        """Average ``da`` over the region box and over ``init_time`` when present."""
        sub = da.sel(latitude=_region_lat_slice(da, box), longitude=box["longitude"])
        if sub.sizes["latitude"] == 0 or sub.sizes["longitude"] == 0:
            raise ValueError(f"Region '{region_name}' does not overlap the data grid.")
        # Averaging over init_time as well leaves a single curve along lead_time.
        # NOTE(review): this is an unweighted mean over the lat/lon grid; consider
        # latitude weighting for large regions.
        dims = [d for d in ("init_time", "latitude", "longitude") if d in sub.dims]
        return sub.mean(dim=dims)

    def _plot_curve(curve, label, var, region_name):
        """Plot ``curve`` against lead time in days on a fresh 8x4 figure."""
        curve = curve.assign_coords(lead_time_days=curve["lead_time"] / np.timedelta64(1, "D"))
        # A 1-D DataArray.plot() returns the drawn lines, not the Axes, so the
        # Axes is created explicitly and passed in.
        _, ax = plt.subplots(figsize=(8, 4))
        curve.plot(x="lead_time_days", ax=ax)
        ax.set_title(f"{label} ‚Äî {var} ‚Äî {region_name}")
        ax.set_xlabel("Forecast time (days)")
        ax.set_ylabel(label)
        plt.show()

    def on_run(_):
        """Button callback: compute the chosen metric and plot it for the region."""
        out.clear_output(wait=True)
        with out:
            try:
                region_name = region_selector.value
                box = DOMAIN_DEFINITIONS[region_name]
                metric = metric_selector.value
                var = variable_selector.value
                if var is None:
                    raise ValueError("No variable is available in both the forecast and the ERA5 truth.")

                # 1) Compute statistics on the full grid
                stats = run_verification(forecast_ds, era5_ds, metric, [var])

                # 2) Region slicing done HERE (after the statistics), then averaging
                if metric == "RMSE":
                    se_reg = _regional_mean(stats["SquaredError"][var], box, region_name)
                    _plot_curve(np.sqrt(se_reg), "RMSE", var, region_name)
                else:
                    pa = _regional_mean(stats['SquaredPredictionAnomaly'][var], box, region_name)
                    ta = _regional_mean(stats['SquaredTargetAnomaly'][var], box, region_name)
                    co = _regional_mean(stats['AnomalyCovariance'][var], box, region_name)
                    _plot_curve(co / np.sqrt(pa * ta), "ACC", var, region_name)

                display(Markdown("‚úÖ Done"))
            except Exception as e:
                # Notebook boundary: show the failure in the output area.
                display(Markdown(f"‚ùå Error: `{e}`"))

    run_btn.on_click(on_run)


Dropdown(description='Region:', options=('Global', 'Northern Hemisphere', 'Tropics', 'Bangladesh', 'Chile', 'N‚Ä¶

Dropdown(description='Metric:', options=('RMSE', 'ACC'), value='RMSE')

Dropdown(description='Variable:', options=('2t',), value='2t')

Button(button_style='success', description='Run Verification', style=ButtonStyle())

Output()