
# 🌍 Demo 5: Ground Truth Challenge — TAHMO Data

## 🧠 Learning Objectives
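- Use **weatherbenchX** (metrics) + **xarray** to verify gridded products (ERA5, IMERG) against TAHMO station observations.

- Compute **RMSE** and **MAE** (with weatherbenchX).
- Focus the evaluation on a single station: each gridded product is sampled at the grid cell nearest to it.
- Visualize skill over **time** (monthly means or day-of-year averages).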





## 📦 Environment Requirements
If you hit backend errors like xarray not finding `netcdf4`/`h5netcdf`, install the deps:
```bash
# pip install xarray netCDF4 h5netcdf zarr fsspec gcsfs ipywidgets matplotlib numpy pandas
# pip install weatherbenchX xesmf
```
The loader first tries Google Cloud credentials and then falls back to anonymous access, so the public GCS data (ERA5, IMERG) needs no credentials.


In [None]:

import os, glob, io, contextlib, logging
import numpy as np
import pandas as pd
import xarray as xr
import ipywidgets as widgets
from IPython.display import display, Markdown
import matplotlib.pyplot as plt

from weatherbenchX.metrics import deterministic
from weatherbenchX.metrics import base as metrics_base

# Render inline figures at 130 dpi.
plt.rcParams["figure.dpi"] = 130
# Keep attributes (e.g. units) on results of xarray arithmetic and reductions.
xr.set_options(keep_attrs=True)


In [None]:
# ------------------------- quiet logging helper -------------------------
@contextlib.contextmanager
def _silence_lib_logs():
    """Temporarily mute noisy storage/auth libraries for the duration of a ``with`` block.

    The ``gcsfs``, ``fsspec``, ``zarr``, ``google.auth`` and ``urllib3`` loggers are
    raised to CRITICAL, and stdout/stderr are redirected into a discarded buffer.
    The previous logger levels are restored on exit, also when the block raises.
    """
    names = ["gcsfs", "fsspec", "zarr", "google.auth", "urllib3"]
    old = {n: logging.getLogger(n).level for n in names}
    for n in names:
        logging.getLogger(n).setLevel(logging.CRITICAL)
    buf = io.StringIO()
    try:
        with contextlib.redirect_stderr(buf), contextlib.redirect_stdout(buf):
            yield
    finally:
        # Without `finally`, an exception in the body would leave these loggers
        # muted for the rest of the session.
        for n, lvl in old.items():
            logging.getLogger(n).setLevel(lvl)

# ------------------------- helpers -------------------------
def _open_any(path: str):
    """Open a Zarr store (local or GCS) or a NetCDF file, keeping library output quiet.

    Paths ending in ``.zarr`` (a trailing ``/`` is tolerated) are tried in order:
    consolidated metadata with gcsfs ``token="cloud"``, consolidated with
    ``token="anon"`` (anonymous access), then unconsolidated with default storage
    options. Any other path is opened with ``xr.open_dataset``.

    Raises:
        RuntimeError: if every Zarr attempt fails. The message includes the last
            underlying error, which is also chained as ``__cause__``.
    """
    if path.rstrip("/").endswith(".zarr"):
        last_err = None
        with _silence_lib_logs():
            for token in ("cloud", "anon", None):
                try:
                    if token:
                        return xr.open_zarr(path, consolidated=True, chunks={}, storage_options={"token": token})
                    return xr.open_zarr(path, consolidated=False, chunks={})
                except Exception as err:  # auth / missing-metadata / backend errors all vary
                    last_err = err
        raise RuntimeError(
            f"Could not open Zarr: {path} (last error: {type(last_err).__name__}: {last_err})"
        ) from last_err
    return xr.open_dataset(path)

def _open_station_nc(path: str) -> xr.Dataset:
    """Open a station NetCDF file and standardise its latitude/longitude names.

    ``Lat``/``lat`` are renamed to ``latitude`` and ``Lon``/``lon`` to ``longitude``.
    This applies whether they appear as variables (coordinates or data variables)
    or only as dimensions. A rename is skipped when the standard name already
    exists, so no name clash can occur.
    """
    ds = xr.open_dataset(path)
    renames = {}
    for target, aliases in (("latitude", ("Lat", "lat")), ("longitude", ("Lon", "lon"))):
        if target in ds.variables or target in ds.dims:
            continue
        for alias in aliases:
            if alias in ds.variables or alias in ds.dims:
                renames[alias] = target
                break
    # A single rename call handles a name that is both a variable and a dimension.
    return ds.rename(renames) if renames else ds

def _find_precip_name(ds: xr.Dataset) -> str|None:
    """Return the first known precipitation variable name present in ``ds``, or None.

    Names are checked in priority order: total_precipitation_24hr, tp_24h, precip,
    total_precipitation, rr, rain.
    """
    known = ("total_precipitation_24hr", "tp_24h", "precip", "total_precipitation", "rr", "rain")
    return next((name for name in known if name in ds.data_vars), None)

def _find_temp_name(ds: xr.Dataset) -> str|None:
    """Return the first known daily-mean temperature variable name in ``ds``, or None.

    Names are checked in priority order: tavg, t2m_mean, daily_mean_temperature,
    tas_mean, tmean. A file that only provides tmax/tmin yields None; no
    averaging of those is done here.
    """
    known = ("tavg", "t2m_mean", "daily_mean_temperature", "tas_mean", "tmean")
    return next((name for name in known if name in ds.data_vars), None)


In [None]:


# ------------------------- widgets (loader) -------------------------
# Station picker: folder to scan for *.nc files, the scan button and the resulting file list.
station_dir_txt = widgets.Text(value="/workspace", description="Station folder:", layout=widgets.Layout(width="95%"))
scan_btn   = widgets.Button(description="Scan stations", button_style="info")
station_dd = widgets.Dropdown(options=[], description="Station file:", layout=widgets.Layout(width="95%"))

# Paths of the gridded reference datasets (opened with _open_any by _load_all).
era5_t2m_txt = widgets.Text(
    value="gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr",
    description="ERA5 t2m:", layout=widgets.Layout(width="95%"))
era5_tp_txt  = widgets.Text(
    value="gs://aim4scale_training_25/ground_truth/era5_24hr.zarr",
    description="ERA5 tp:", layout=widgets.Layout(width="95%"))
imerg_txt    = widgets.Text(
    value="gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr",
    description="IMERG:", layout=widgets.Layout(width="95%"))

load_btn = widgets.Button(description="Load all", button_style="primary")
out_load = widgets.Output()

display(Markdown("### 📄 Pick station (NetCDF) & paths"),
        station_dir_txt, scan_btn, station_dd,
        era5_t2m_txt, era5_tp_txt, imerg_txt, load_btn, out_load)

# ------------------------- state (shared with later cells) -------------------------
# Filled in by _load_all; None means "not loaded yet".
ds_station = ds_era5_t2m = ds_era5_tp = ds_imerg = None
station_name = ""
precip_key_station = None  # station precipitation variable name, if one was detected
temp_key_station   = None  # station daily-mean temperature variable name, if one was detected

# ------------------------- actions -------------------------
def _scan(_):
    """Fill ``station_dd`` with the ``*.nc`` files of the folder typed into ``station_dir_txt``.

    A leading ``~`` in the folder is expanded to the user's home directory. Files
    are sorted by path and the first one is preselected. If no files match, the
    dropdown is emptied.
    """
    folder = os.path.expanduser(station_dir_txt.value.strip())
    files = sorted(glob.glob(os.path.join(folder, "*.nc")))
    station_dd.options = files
    if files:
        station_dd.value = files[0]
scan_btn.on_click(_scan)

@out_load.capture(clear_output=True)
def _load_all(_):
    """Open the selected station file and the three gridded datasets, then print a summary.

    Only after all four datasets opened successfully are the module-level globals
    updated together: ``ds_station``, ``ds_era5_t2m``, ``ds_era5_tp``, ``ds_imerg``,
    ``station_name``, ``precip_key_station`` and ``temp_key_station``. If any open
    fails, the error is printed and the previously loaded state is left untouched,
    so later cells never see a mix of old and new datasets.
    """
    global ds_station, ds_era5_t2m, ds_era5_tp, ds_imerg, station_name, precip_key_station, temp_key_station
    if not station_dd.value:
        print("❌ Pick a station file first (Scan stations)."); return
    try:
        with _silence_lib_logs():
            new_station = _open_station_nc(station_dd.value)
            new_t2m     = _open_any(era5_t2m_txt.value.strip())
            new_tp      = _open_any(era5_tp_txt.value.strip())
            new_imerg   = _open_any(imerg_txt.value.strip())
    except Exception as e:
        print(f"❌ Load error: {e}")
        return

    # Commit all datasets at once, only after every open succeeded.
    ds_station, ds_era5_t2m, ds_era5_tp, ds_imerg = new_station, new_t2m, new_tp, new_imerg
    station_name = os.path.basename(station_dd.value)
    precip_key_station = _find_precip_name(ds_station)
    temp_key_station   = _find_temp_name(ds_station)

    def _span_txt(ds):
        """Return "first → last" dates of the time axis, or a placeholder if unavailable."""
        for t in ("time", "valid_time"):
            if t in ds.coords:
                try:
                    a = str(pd.to_datetime(ds[t].min().values).date())
                    b = str(pd.to_datetime(ds[t].max().values).date())
                    return f"{a} → {b}"
                except Exception:  # summary is best-effort (e.g. cftime or remote read issues)
                    break
        return "(time range unavailable)"

    # Summary (quiet – no GCS noise); each line is independent of the others.
    print("✅ All datasets loaded")
    print(f"• Station: {station_name} — vars: {list(ds_station.data_vars)[:4]} — {_span_txt(ds_station)}")
    print(f"• ERA5 t2m: {_span_txt(ds_era5_t2m)}")
    print(f"• ERA5 tp : {_span_txt(ds_era5_tp)}")
    print(f"• IMERG   : {_span_txt(ds_imerg)}")
load_btn.on_click(_load_all)

# ------------------------- utilities used by later cells -------------------------
def _normalize_precip_mmday(da: xr.DataArray) -> xr.DataArray:
    """Return precipitation in mm/day, with the result tagged ``units="mm/day"``.

    If the ``units`` (or ``unit``) attribute is metres of water ("m", "meter",
    "metre", "m of water equivalent"; case and surrounding spaces ignored), values
    are multiplied by 1000. Otherwise values pass through unscaled. The input
    array's attributes are never modified.
    """
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").strip().lower()
    if u in ("m", "meter", "metre", "m of water equivalent"):
        da = da * 1000.0
    # assign_attrs works on a shallow copy. Setting da.attrs directly would rewrite
    # the units of the source Dataset, because ds[var] shares its attrs dict.
    return da.assign_attrs(units="mm/day")

def _normalize_temp_C(da: xr.DataArray) -> xr.DataArray:
    """Return temperature in degrees Celsius, with the result tagged ``units="°C"``.

    If the lower-cased ``units``/``unit`` attribute contains "k" but not "pa",
    273.15 is subtracted; otherwise values pass through unchanged. The "pa"
    exclusion presumably guards against pressure units such as "kPa". The input
    array's attributes are never modified.
    """
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    if "k" in u and "pa" not in u:
        da = da - 273.15
    # Shallow copy via assign_attrs so the source Dataset keeps its own units.
    return da.assign_attrs(units="°C")

def _coerce_valid_time(da: xr.DataArray) -> xr.DataArray:
    """Return ``da`` with its time dimension named ``valid_time``.

    A ``valid_time`` dimension is returned as is; a ``time`` dimension is renamed.

    Raises:
        ValueError: if ``da`` has neither a ``time`` nor a ``valid_time`` dimension.
    """
    if "valid_time" in da.dims:
        return da
    if "time" in da.dims:
        return da.rename({"time": "valid_time"})
    # CF decoding only decodes values/attributes and never creates a dimension,
    # so there is no further fallback to try.
    raise ValueError("No time/valid_time dimension found.")

def _prep_for_wbx_validtime(da: xr.DataArray) -> xr.DataArray:
    """Shape ``da`` as the WeatherBench-X statistics in this notebook expect.

    The time dimension is renamed to ``valid_time`` (via ``_coerce_valid_time``).
    Dimensions are ordered ``valid_time`` [, ``latitude``] [, ``longitude``], and a
    length-1 ``lead_time`` dimension holding 0 h is inserted in front.
    """
    coerced = _coerce_valid_time(da)
    order = ["valid_time"]
    order.extend(name for name in ("latitude", "longitude") if name in coerced.dims)
    zero_lead = [np.timedelta64(0, "h")]
    return coerced.transpose(*order).expand_dims({"lead_time": zero_lead})

def _wbx_series(pred: xr.DataArray, truth: xr.DataArray, metric: str):
    """Compute a per-time error series of ``pred`` against ``truth`` with WeatherBench-X.

    Both inputs are cast to float32, stored as variable ``"var"`` and prepared with
    ``_prep_for_wbx_validtime``.
    - For ``metric`` "RMSE" (case-insensitive), the squared error is averaged over
      any latitude/longitude dimensions and then square-rooted.
    - For any other value, the absolute error is averaged the same way (MAE).

    To aggregate the RMSE series over time, average ``series**2`` and take the root.
    Averaging the per-time RMSE values directly is not an RMSE.

    Returns:
        (series, ylabel): the error along a ``time`` dimension, and "RMSE" or "MAE".
    """
    ds_p = xr.Dataset({"var": _prep_for_wbx_validtime(pred.astype("float32"))})
    ds_t = xr.Dataset({"var": _prep_for_wbx_validtime(truth.astype("float32"))})
    if metric.upper() == "RMSE":
        stats = metrics_base.compute_unique_statistics_for_all_metrics({"rmse": deterministic.RMSE()}, ds_p, ds_t)
        err = stats["SquaredError"]["var"]
        ylabel = "RMSE"
    else:
        stats = metrics_base.compute_unique_statistics_for_all_metrics({"mae": deterministic.MAE()}, ds_p, ds_t)
        err = stats["AbsoluteError"]["var"]
        ylabel = "MAE"
    spatial = [d for d in ("latitude", "longitude") if d in err.dims]
    s = err.mean(spatial, skipna=True)
    if ylabel == "RMSE":
        # Root of the mean squared error. The mean of per-point roots would be a mean |error|.
        s = s ** 0.5
    if "lead_time" in s.dims and s.sizes.get("lead_time", 1) == 1:
        s = s.isel(lead_time=0)
    return s.rename({"valid_time": "time"}).squeeze(), ylabel

def _nearest_on_grid(grid_da: xr.DataArray, lat: float, lon: float) -> xr.DataArray:
    """Select the cell of ``grid_da`` nearest to (``lat``, ``lon``).

    ``lon`` is first wrapped into the grid's longitude convention:
    - into [0, 360) when the grid has longitudes above 180;
    - into [-180, 180) when the grid has negative longitudes;
    - otherwise it is used as given.
    """
    if "longitude" in grid_da.coords and grid_da["longitude"].size:
        g_lon = grid_da["longitude"]
        if float(g_lon.max()) > 180.0:
            # 0..360 grid: e.g. -1° E becomes 359° E.
            lon = lon % 360.0
        elif float(g_lon.min()) < 0.0:
            # -180..180 grid: e.g. 359° E becomes -1° E.
            lon = (lon + 180.0) % 360.0 - 180.0
    return grid_da.sel(latitude=lat, longitude=lon, method="nearest")

def _to_monthly(series: xr.DataArray, time_dim="time"):
    """Average ``series`` into calendar months, labelled by month start ("MS").

    The ``time_dim`` coordinate is converted with ``pandas.to_datetime`` before
    resampling. NaNs are skipped in the mean.
    """
    as_datetime = pd.to_datetime(series[time_dim].values)
    with_datetime = series.assign_coords({time_dim: as_datetime})
    monthly = with_datetime.resample({time_dim: "MS"})
    return monthly.mean(skipna=True)

def _to_doy_mean(series: xr.DataArray) -> xr.DataArray:
    """Average ``series`` by day of year along a ``DOY`` dimension with integer labels.

    NaNs are skipped in the mean.
    """
    grouped = series.groupby("time.dayofyear").mean(skipna=True).rename({"dayofyear": "DOY"})
    # Attach DOY explicitly as an int coordinate, so that .plot(x="DOY") never hits a KeyError.
    if "DOY" in grouped.indexes:
        labels = grouped.indexes["DOY"]
    else:
        labels = np.arange(1, grouped.sizes.get("DOY", 0) + 1)
    return grouped.assign_coords(DOY=("DOY", np.asarray(labels, dtype=int)))

def _month_ticks_for_doy(ax):
    """Put a month-abbreviation tick on the first day of each month of a day-of-year axis.

    Day numbers follow a non-leap year.
    """
    month_starts = (
        ("Jan", 1), ("Feb", 32), ("Mar", 60), ("Apr", 91),
        ("May", 121), ("Jun", 152), ("Jul", 182), ("Aug", 213),
        ("Sep", 244), ("Oct", 274), ("Nov", 305), ("Dec", 335),
    )
    ticks = [doy for _, doy in month_starts]
    labels = [name for name, _ in month_starts]
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels)

def _time_slice(da: xr.DataArray, y0, m0, y1, m1) -> xr.DataArray:
    """Select ``da`` along ``time`` from the start of month y0-m0 through the end of month y1-m1.

    The whole last day of the end month is included, so sub-daily timestamps on that
    day (e.g. 12:00) are kept. Requires a ``time`` dimension with a datetime index.
    """
    t0 = np.datetime64(pd.Timestamp(year=y0, month=m0, day=1))
    # Last instant before the following month begins (inclusive slice end).
    t1 = np.datetime64(
        pd.Timestamp(year=y1, month=m1, day=1) + pd.offsets.MonthBegin(1) - pd.Timedelta(1, "ns")
    )
    return da.sel(time=slice(t0, t1))


In [None]:
# === Cell 2: Station (truth) vs ERA5 (predictor) ===
metric_dd_e5 = widgets.Dropdown(options=["RMSE","MAE"], value="RMSE", description="Metric:")

# Offer the station variables detected by Cell 1 (temperature first). Use the
# conventional names when nothing has been detected yet.
var_opts = [
    key
    for key in (globals().get("temp_key_station"), globals().get("precip_key_station"))
    if key
]
var_dd_e5 = widgets.Dropdown(
    options=var_opts if var_opts else ["tavg", "total_precipitation_24hr"],
    value=var_opts[0] if var_opts else "tavg",
    description="Variable:",
)

# Evaluation window (start month/year through end month/year) and grouping mode.
start_month_e5 = widgets.Dropdown(options=list(range(1, 13)), value=1, description="Start M")
end_month_e5 = widgets.Dropdown(options=list(range(1, 13)), value=12, description="End M")
start_year_e5 = widgets.BoundedIntText(value=2022, min=1900, max=2100, description="Start Y")
end_year_e5 = widgets.BoundedIntText(value=2024, min=1900, max=2100, description="End Y")
doy_cb_e5 = widgets.Checkbox(value=False, description="Group by DOY (month ticks)")

run_btn_e5 = widgets.Button(description="Run (Station vs ERA5)", button_style="success")
out_plot_e5 = widgets.Output()

display(
    widgets.HBox([var_dd_e5, metric_dd_e5, doy_cb_e5]),
    widgets.HBox([start_month_e5, end_month_e5, start_year_e5, end_year_e5]),
    run_btn_e5,
    out_plot_e5,
)

@out_plot_e5.capture(clear_output=True)
def _run_e5(_):
    """Plot the error of ERA5 against the loaded station series for the chosen window.

    Steps:
    - Read the Cell 2 widgets and choose the matching ERA5 field (24 h precipitation
      or daily-mean temperature).
    - Sample that field at the grid cell nearest the station.
    - Compute the daily error with ``_wbx_series``.
    - Plot it aggregated by calendar month or by day of year.

    RMSE is aggregated per bin as sqrt(mean(error**2)); MAE as mean(|error|).
    Problems (data not loaded, missing variables, empty window) are printed, not raised.
    """
    if any(x is None for x in [ds_station, ds_era5_t2m, ds_era5_tp]):
        print("❌ Load data first (Cell 1)."); return

    var, metric = var_dd_e5.value, metric_dd_e5.value
    m0, m1 = int(start_month_e5.value), int(end_month_e5.value)
    y0, y1 = int(start_year_e5.value),  int(end_year_e5.value)

    # Classify precip vs temperature from the selected name BEFORE any fallback.
    # Both default names start with "t", so a first-letter test would send a missing
    # precip variable to the temperature fallback.
    is_precip = (var == precip_key_station) or ("precip" in var) or ("tp" in var) or ("rain" in var)

    # station variable + lat/lon
    if var in ds_station:
        st = ds_station[var]
    else:
        if is_precip:
            fallback = precip_key_station or "total_precipitation_24hr"
        else:
            fallback = temp_key_station or "tavg"
        st = ds_station.get(fallback, None)
        if st is None:
            print("❌ Selected variable not found in station file."); return
    lat = float(st.latitude.values) if "latitude" in st.coords and st.latitude.size else float(ds_station.latitude)
    lon = float(st.longitude.values) if "longitude" in st.coords and st.longitude.size else float(ds_station.longitude)

    # normalize & choose ERA5 field
    if is_precip:
        if "total_precipitation_24hr" not in ds_era5_tp:
            print("❌ ERA5 tp dataset has no 'total_precipitation_24hr' variable."); return
        st = _normalize_precip_mmday(st)
        e5 = _normalize_precip_mmday(ds_era5_tp["total_precipitation_24hr"])
    else:
        st = _normalize_temp_C(st)
        # common ERA5 daily-mean-temperature names
        for cand in ["tavg","t2m_mean","daily_mean_temperature","tas_mean","tmean"]:
            if cand in ds_era5_t2m:
                e5 = _normalize_temp_C(ds_era5_t2m[cand]); break
        else:
            if "tmax" in ds_era5_t2m and "tmin" in ds_era5_t2m:
                e5 = _normalize_temp_C((ds_era5_t2m["tmax"] + ds_era5_t2m["tmin"])/2.0)
            else:
                print("❌ Could not find ERA5 daily tavg field."); return

    e5_pt = _nearest_on_grid(e5, lat, lon)

    st_s = _time_slice(st,    y0, m0, y1, m1)
    e5_s = _time_slice(e5_pt, y0, m0, y1, m1)
    if st_s.sizes.get("time",0)==0 or e5_s.sizes.get("time",0)==0:
        print("⚠️ No data in requested window."); return

    series, ylab = _wbx_series(e5_s, st_s, metric)

    # For this single-point comparison the RMSE series is |error| per day. Square it
    # back to squared error, average per bin, then take the root. Averaging the
    # daily values directly would just reproduce the MAE.
    is_rmse = (ylab == "RMSE")
    to_aggregate = series ** 2 if is_rmse else series

    fig, ax = plt.subplots(figsize=(10,4))
    if doy_cb_e5.value:
        s = _to_doy_mean(to_aggregate)
        if is_rmse:
            s = s ** 0.5
        s.plot(ax=ax, x="DOY", label="ERA5 vs Station")
        _month_ticks_for_doy(ax); ax.set_xlabel("month")
    else:
        s = _to_monthly(to_aggregate, "time")
        if is_rmse:
            s = s ** 0.5
        s.plot(ax=ax, x="time", label="ERA5 vs Station")

    units = "(mm/day)" if is_precip else "(°C)"
    ax.set_ylabel(f"{ylab} {units}")
    title_var = "PRECIP" if is_precip else var.upper()
    period_txt = f"{y0}-{y1} ({m0:02d}→{m1:02d})"
    ax.set_title(f"{metric} — {title_var} — Station vs ERA5 — {period_txt}")
    ax.grid(True, alpha=0.3); ax.legend(); plt.show()

run_btn_e5.on_click(_run_e5)


In [None]:
# === Cell 3: Station (truth) vs IMERG (predictor) — precip only ===
metric_dd_img = widgets.Dropdown(options=["RMSE", "MAE"], value="RMSE", description="Metric:")

# Evaluation window (start month/year through end month/year) and grouping mode.
start_month_img = widgets.Dropdown(options=list(range(1, 13)), value=1, description="Start M")
end_month_img = widgets.Dropdown(options=list(range(1, 13)), value=12, description="End M")
start_year_img = widgets.BoundedIntText(value=2022, min=1900, max=2100, description="Start Y")
end_year_img = widgets.BoundedIntText(value=2024, min=1900, max=2100, description="End Y")
doy_cb_img = widgets.Checkbox(value=False, description="Group by DOY (month ticks)")

run_btn_img = widgets.Button(description="Run (Station vs IMERG)", button_style="warning")
out_plot_img = widgets.Output()

display(
    widgets.HBox([metric_dd_img, doy_cb_img]),
    widgets.HBox([start_month_img, end_month_img, start_year_img, end_year_img]),
    run_btn_img,
    out_plot_img,
)

@out_plot_img.capture(clear_output=True)
def _run_img(_):
    """Plot the error of IMERG against the station precipitation for the chosen window.

    Steps:
    - Read the Cell 3 widgets.
    - Sample IMERG ``total_precipitation_24hr`` at the grid cell nearest the station.
    - Compute the daily error with ``_wbx_series``.
    - Plot it aggregated by calendar month or by day of year.

    RMSE is aggregated per bin as sqrt(mean(error**2)); MAE as mean(|error|).
    Problems (data not loaded, missing variables, empty window) are printed, not raised.
    """
    if any(x is None for x in [ds_station, ds_imerg]):
        print("❌ Load data first (Cell 1)."); return
    if precip_key_station is None and "total_precipitation_24hr" not in ds_station:
        print("❌ Station has no precipitation variable."); return
    if "total_precipitation_24hr" not in ds_imerg:
        print("❌ IMERG dataset has no 'total_precipitation_24hr' variable."); return

    metric = metric_dd_img.value
    m0, m1 = int(start_month_img.value), int(end_month_img.value)
    y0, y1 = int(start_year_img.value),  int(end_year_img.value)

    var = precip_key_station or "total_precipitation_24hr"
    st = _normalize_precip_mmday(ds_station[var])

    lat = float(st.latitude.values) if "latitude" in st.coords and st.latitude.size else float(ds_station.latitude)
    lon = float(st.longitude.values) if "longitude" in st.coords and st.longitude.size else float(ds_station.longitude)

    img = _normalize_precip_mmday(ds_imerg["total_precipitation_24hr"])
    img_pt = _nearest_on_grid(img, lat, lon)

    st_s  = _time_slice(st,     y0, m0, y1, m1)
    img_s = _time_slice(img_pt, y0, m0, y1, m1)
    if st_s.sizes.get("time",0)==0 or img_s.sizes.get("time",0)==0:
        print("⚠️ No data in requested window."); return

    series, ylab = _wbx_series(img_s, st_s, metric)

    # For this single-point comparison the RMSE series is |error| per day. Square it
    # back to squared error, average per bin, then take the root. Averaging the
    # daily values directly would just reproduce the MAE.
    is_rmse = (ylab == "RMSE")
    to_aggregate = series ** 2 if is_rmse else series

    fig, ax = plt.subplots(figsize=(10,4))
    if doy_cb_img.value:
        s = _to_doy_mean(to_aggregate)
        if is_rmse:
            s = s ** 0.5
        s.plot(ax=ax, x="DOY", label="IMERG vs Station")
        _month_ticks_for_doy(ax); ax.set_xlabel("month")
    else:
        s = _to_monthly(to_aggregate, "time")
        if is_rmse:
            s = s ** 0.5
        s.plot(ax=ax, x="time", label="IMERG vs Station")

    ax.set_ylabel(f"{ylab} (mm/day)")
    period_txt = f"{y0}-{y1} ({m0:02d}→{m1:02d})"
    ax.set_title(f"{metric} — PRECIP — Station vs IMERG — {period_txt}")
    ax.grid(True, alpha=0.3); ax.legend(); plt.show()

run_btn_img.on_click(_run_img)
