
# 🌍 Demo 5 (Kenya): Ground Truth Challenge — **local Kenya data** vs **ERA5/IMERG**

## 🧠 Learning Objectives
- Use **weatherbenchX** metrics with **xarray** to verify local daily NC data.
- Load your **Kenya** daily NetCDF (one file per day) for **tmax / tmin / precip**.
- Compare against **ERA5** (temp & precip) and **IMERG** (precip).
- Compute **RMSE** and **MAE**.
- Evaluate over **Kenya** and visualize **day-of-year curves** (averaged across the selected years, with month ticks).
- Widgets for **start/end month & year**, **metric**, and a **Load Data** button that prints status.



## 📦 Environment Requirements
If xarray raises backend errors (e.g. it cannot find `netcdf4`/`h5netcdf`), install the dependencies:
```bash
pip install xarray netCDF4 h5netcdf zarr fsspec gcsfs ipywidgets matplotlib numpy pandas weatherbenchX
```
The ERA5/IMERG Zarr stores live in public GCS buckets; the loader falls back to anonymous access, so no credentials are needed.


In [1]:
import os, glob, warnings, logging, contextlib
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown
from weatherbenchX.metrics import deterministic
from weatherbenchX.metrics import base as metrics_base

# Sharper inline figures for every plot in this notebook.
plt.rcParams["figure.dpi"] = 130
# A plain call (not a `with` block) leaves the option on for the whole
# session: keep attributes such as "units" through xarray operations.
xr.set_options(keep_attrs=True)


<xarray.core.options.set_options at 0x7c71e88ba290>

In [2]:
# ---------- helpers ----------
def _open_daily_nc_folder(folder: str, var_name: str) -> xr.Dataset:
    """Open a folder of per-day NetCDF files as one time-sorted Dataset.

    Each ``*.nc`` file in *folder* is opened and reduced to a single data
    variable named *var_name*; when a file lacks *var_name*, its first data
    variable is renamed to it.  Files without a ``time`` coordinate get one
    from the first 8 digits of the filename (``YYYYMMDD``), falling back to
    the file's modification date when those digits do not parse.

    Spatial coordinates/dimensions spelled ``Lat``/``Lon`` or ``lat``/``lon``
    are renamed to ``latitude``/``longitude`` (unless that name already
    exists) so the lat/lon helpers below can select on them.

    Raises:
        FileNotFoundError: if *folder* contains no ``.nc`` files.
    """
    files = sorted(glob.glob(os.path.join(folder, "*.nc")))
    if not files:
        raise FileNotFoundError(f"No .nc files found in {folder}")

    def _open_one(fp):
        ds = xr.open_dataset(fp)
        if var_name not in ds:
            # Assume the file's first data variable is the field of interest.
            v = list(ds.data_vars)[0]
            ds = ds.rename({v: var_name})
        if "time" not in ds.coords:
            # Recover the date from the filename, e.g. rr_mrg_20220102_CLM.nc.
            base = os.path.basename(fp)
            digits = "".join(ch for ch in base if ch.isdigit())
            t = pd.to_datetime(digits[:8], format="%Y%m%d", errors="coerce")
            if pd.isna(t):
                t = pd.to_datetime(os.path.getmtime(fp), unit="s").normalize()
            ds = ds.expand_dims(time=[np.datetime64(t)])
        return ds[[var_name]]

    ds = xr.concat([_open_one(f) for f in files], dim="time").sortby("time")

    # Harmonise spatial names.  Lower-case `lat`/`lon` is handled too; a
    # target name that already exists is never clobbered (renaming onto it
    # would raise).
    renames = {}
    for old, new in (("Lat", "latitude"), ("Lon", "longitude"),
                     ("lat", "latitude"), ("lon", "longitude")):
        taken = new in ds.coords or new in ds.dims or new in renames.values()
        if not taken and (old in ds.coords or old in ds.dims):
            renames[old] = new
    return ds.rename(renames) if renames else ds

def _to_0360(obj):
    """Return *obj* with longitudes wrapped to [0, 360) and sorted ascending.

    Objects without a ``longitude`` coordinate are returned unchanged.  The
    wrap (``lon % 360``) is applied only when some longitude is negative; if
    that check fails the wrap is skipped (best effort).  The sort is skipped
    when the 1-D longitude coordinate is already non-decreasing.  A stable
    sort would return the same order anyway, and skipping it avoids an
    index-array selection over a whole (possibly remote, global) dataset.
    """
    if "longitude" not in obj.coords:
        return obj
    lon = obj["longitude"]
    try:
        if float(lon.min()) < 0:
            obj = obj.assign_coords(longitude=(lon % 360))
    except Exception:
        pass  # deliberate best effort: leave longitudes unwrapped
    vals = np.asarray(obj["longitude"].values)
    if vals.ndim == 1 and (vals.size < 2 or bool(np.all(vals[1:] >= vals[:-1]))):
        return obj
    return obj.sortby("longitude")

def _ensure_lat_asc(obj):
    """Return *obj* ordered by ascending latitude.

    Only re-sorts when the first latitude is larger than the last one (the
    coordinate runs north-to-south).  Objects with no ``latitude`` coordinate,
    or with a single latitude, come back untouched.
    """
    if "latitude" not in obj.coords:
        return obj
    lats = obj["latitude"].values
    descending = len(lats) > 1 and lats[0] > lats[-1]
    return obj.sortby("latitude") if descending else obj

def _normalize_precip(da):
    """Return precipitation labelled as mm/day, converting from metres if needed.

    If the ``units`` (or ``unit``) attribute says metres of water, the values
    are multiplied by 1000.  Any other units are assumed to already be mm/day
    and are only relabelled.  The input object is not modified: the
    ``units="mm/day"`` label goes on a new object via ``assign_attrs``.
    """
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").strip().lower()
    if u in ("m", "meter", "metre", "m of water equivalent"):
        da = da * 1000.0
    return da.assign_attrs(units="mm/day")

def _normalize_temp_C(da):
    """Return temperature in degrees Celsius with ``units`` set to ``"°C"``.

    Kelvin is detected from the ``units`` (or ``unit``) attribute by whole
    alphabetic tokens (``K``, ``kelvin``, ``degK``, ``degrees_K`` ...) and
    converted by subtracting 273.15.  Any other units are assumed to already
    be Celsius.  The input object is not modified.
    """
    import re  # local import: the notebook's import cell does not import `re`

    u = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    # Token match rather than a substring test: `"k" in u` also fired on
    # unrelated strings such as "unknown".  Pressure units like "kPa" become
    # the token "kpa" and are not matched.
    tokens = set(re.split(r"[^a-z]+", u))
    if tokens & {"k", "kelvin", "kelvins", "degk", "degreek", "degreesk"}:
        da = da - 273.15
    return da.assign_attrs(units="°C")

def _coerce_valid_time(da):
    """Return *da* with its time dimension named ``valid_time``.

    A ``valid_time`` dimension is returned unchanged; a ``time`` dimension is
    renamed to ``valid_time``.

    Raises:
        ValueError: if *da* has neither a ``time`` nor a ``valid_time`` dimension.
    """
    if "valid_time" in da.dims:
        return da
    if "time" in da.dims:
        return da.rename({"time": "valid_time"})
    # No CF-decoding fallback: decoding converts values/attributes but cannot
    # introduce a `time` dimension, so it could only end in this same error.
    raise ValueError("No time/valid_time dimension found.")

def _align_to_ref(src: xr.DataArray, ref: xr.DataArray, tolerance: float = 0.125) -> xr.DataArray:
    """Put *src* on *ref*'s latitude/longitude grid using nearest neighbours.

    Both inputs are first normalised to 0-360 longitudes and ascending
    latitudes.  If the grids already match exactly, *src* is returned without
    resampling.  Otherwise each *ref* grid point takes the nearest *src* value
    within *tolerance* degrees (NaN beyond it).  Only ``latitude`` and
    ``longitude`` are reindexed; other dimensions such as ``time`` are left
    for xarray's arithmetic alignment downstream.  If that reindex fails, the
    function falls back to ``interp(method="nearest")`` on the sorted grids.

    Args:
        src: field to regrid.
        ref: field whose lat/lon grid is the target.
        tolerance: maximum lat/lon distance (degrees) for a nearest match.
    """
    s = _ensure_lat_asc(_to_0360(src))
    r = _ensure_lat_asc(_to_0360(ref))
    try:
        if (np.array_equal(s.latitude.values, r.latitude.values)
                and np.array_equal(s.longitude.values, r.longitude.values)):
            return s
    except Exception:
        pass  # e.g. missing lat/lon coordinate: fall through to resampling
    try:
        # Reindex the spatial dims only, with a scalar tolerance.  The former
        # `reindex_like(r, tolerance={...})` also tried to match `time` and
        # passed a per-dimension dict, which is not a valid tolerance, so that
        # path could not succeed.
        return s.reindex(
            latitude=r["latitude"].values,
            longitude=r["longitude"].values,
            method="nearest",
            tolerance=tolerance,
        )
    except Exception:
        pass  # e.g. non-monotonic/duplicate index: use interpolation instead
    s2 = s.sortby(["latitude", "longitude"])
    r2 = r.sortby(["latitude", "longitude"])
    return s2.interp(latitude=r2["latitude"], longitude=r2["longitude"], method="nearest")

def _prep_for_wbx_validtime(da: xr.DataArray) -> xr.DataArray:
    """Shape *da* like a lead-time-0 forecast for weatherbenchX.

    The time dimension is renamed to ``valid_time`` (via
    ``_coerce_valid_time``).  Dimensions are ordered ``valid_time``,
    ``latitude``, ``longitude`` (spatial ones only if present).  A length-1
    ``lead_time`` dimension holding 0 h is then added.
    """
    da = _coerce_valid_time(da)
    order = ["valid_time"] + [d for d in ("latitude", "longitude") if d in da.dims]
    # `...` moves any extra dimension to the end; without it transpose
    # raises when da carries dims beyond valid_time/latitude/longitude.
    da = da.transpose(*order, ...)
    return da.expand_dims({"lead_time": [np.timedelta64(0, "h")]})

def _wbx_series(pred: xr.DataArray, truth: xr.DataArray, metric: str):
    """Daily, spatially aggregated error of *pred* against *truth*.

    weatherbenchX computes the per-grid-point statistic.  It is then
    aggregated over ``latitude``/``longitude`` with NaNs skipped:

    * ``"RMSE"`` (case-insensitive): square root of the spatial mean of the
      squared error.
    * anything else: spatial mean of the absolute error (MAE).

    Returns:
        ``(series, ylabel)``: a DataArray indexed by ``date`` (the former
        ``valid_time``) and the label ``"RMSE"`` or ``"MAE"``.
    """
    ds_p = xr.Dataset({"var": _prep_for_wbx_validtime(pred.astype("float32"))})
    ds_t = xr.Dataset({"var": _prep_for_wbx_validtime(truth.astype("float32"))})
    if metric.upper() == "RMSE":
        stats = metrics_base.compute_unique_statistics_for_all_metrics(
            {"rmse": deterministic.RMSE()}, ds_p, ds_t)
        se = stats["SquaredError"]["var"]
        spatial = [d for d in ("latitude", "longitude") if d in se.dims]
        # RMSE = sqrt(mean(SE)).  Taking sqrt before the mean gives
        # mean(|error|), i.e. the MAE under another name.
        s = se.mean(spatial, skipna=True) ** 0.5
        ylabel = "RMSE"
    else:
        stats = metrics_base.compute_unique_statistics_for_all_metrics(
            {"mae": deterministic.MAE()}, ds_p, ds_t)
        ae = stats["AbsoluteError"]["var"]
        spatial = [d for d in ("latitude", "longitude") if d in ae.dims]
        s = ae.mean(spatial, skipna=True)
        ylabel = "MAE"
    if "lead_time" in s.dims and s.sizes["lead_time"] == 1:
        s = s.isel(lead_time=0)
    s = s.rename({"valid_time": "date"})
    # Squeeze leftover length-1 dims, but never `date`: a single-day window
    # must stay 1-D so callers can still `dropna(dim="date")`.
    extra = [d for d in s.dims if d != "date" and s.sizes[d] == 1]
    return (s.squeeze(extra) if extra else s), ylabel

def _to_monthly(series: xr.DataArray, time_dim="date"):
    """Average *series* into calendar months (month-start labels, NaNs skipped).

    The *time_dim* coordinate is first coerced to pandas datetimes so the
    resample works regardless of the datetime-like type it arrived as.
    """
    as_datetimes = pd.to_datetime(series[time_dim].values)
    with_times = series.assign_coords({time_dim: as_datetimes})
    monthly = with_times.resample({time_dim: "MS"})
    return monthly.mean(skipna=True)

def _kenya_box(ds):
    """Crop *ds* to a fixed lat/lon box around Kenya.

    The box spans latitude -4.7..5.0 and longitude 33.9..41.9.  Longitudes
    are first wrapped to 0-360 and latitudes sorted ascending, so that the
    label slices select the intended region.
    """
    lat_range = slice(-4.7, 5.0)
    lon_range = slice(33.9, 41.9)
    normalized = _to_0360(ds)
    normalized = _ensure_lat_asc(normalized)
    return normalized.sel(latitude=lat_range, longitude=lon_range)
# ---- QUIET Zarr opener: try google_default → cloud → anon, without printing errors
@contextlib.contextmanager
def _silence_gcs_logs():
    """Temporarily mute gcsfs/fsspec logging and google.auth default-credential warnings.

    The ``gcsfs.retry``, ``gcsfs`` and ``fsspec`` loggers are set to CRITICAL
    for the duration of the ``with`` block.  Their previous levels are restored
    on exit, including when the block raises; previously the restore ran only
    on a clean exit, so an exception left these loggers muted for the rest of
    the session.  UserWarnings from ``google.auth._default`` are ignored inside
    the block.
    """
    names = ("gcsfs.retry", "gcsfs", "fsspec")
    saved = {name: logging.getLogger(name).level for name in names}
    for name in names:
        logging.getLogger(name).setLevel(logging.CRITICAL)
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", module="google.auth._default", category=UserWarning)
            yield
    finally:
        for name, level in saved.items():
            logging.getLogger(name).setLevel(level)

def _open_zarr_quiet(path: str):
    """Open a Zarr store, trying several gcsfs credential modes without log noise.

    First tries consolidated metadata with ``token`` set to
    ``"google_default"``, ``"cloud"``, ``"anon"`` and finally no
    storage options.  Then it makes one unconsolidated attempt with default
    options.  All attempts run inside ``_silence_gcs_logs``.

    Raises:
        The last exception from the consolidated attempts (chained to the
        unconsolidated attempt's error), or RuntimeError if none was recorded.
    """
    last_err = None
    with _silence_gcs_logs():
        for tok in ("google_default", "cloud", "anon", None):
            so = {} if tok is None else {"token": tok}
            try:
                return xr.open_zarr(path, consolidated=True, chunks={}, storage_options=so)
            except Exception as e:  # try the next credential mode
                last_err = e
    # one more try without consolidated (still quiet)
    fallback_err = None
    with _silence_gcs_logs():
        try:
            return xr.open_zarr(path, consolidated=False)
        except Exception as e:
            fallback_err = e
    # Raise only after the `with` has exited so logger levels are restored
    # even if the silencing context manager is not exception-safe.
    raise (last_err if last_err else RuntimeError(f"Could not open {path}")) from fallback_err

def _open_any(path: str):
    """Open *path* as a Zarr store when it names a ``.zarr`` directory, else via ``xr.open_dataset``.

    A trailing ``/`` (common when pasting bucket paths) still counts as Zarr.
    """
    if path.rstrip("/").endswith(".zarr"):
        return _open_zarr_quiet(path)
    return xr.open_dataset(path)

In [3]:

# ---- QUIET Zarr opener: try google_default → cloud → anon, without printing errors
@contextlib.contextmanager
def _silence_gcs_logs():
    """Temporarily mute gcsfs/fsspec logging and google.auth default-credential warnings.

    The ``gcsfs.retry``, ``gcsfs`` and ``fsspec`` loggers are set to CRITICAL
    for the duration of the ``with`` block.  Their previous levels are restored
    on exit, including when the block raises; previously the restore ran only
    on a clean exit, so an exception left these loggers muted for the rest of
    the session.  UserWarnings from ``google.auth._default`` are ignored inside
    the block.
    """
    names = ("gcsfs.retry", "gcsfs", "fsspec")
    saved = {name: logging.getLogger(name).level for name in names}
    for name in names:
        logging.getLogger(name).setLevel(logging.CRITICAL)
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", module="google.auth._default", category=UserWarning)
            yield
    finally:
        for name, level in saved.items():
            logging.getLogger(name).setLevel(level)

def _open_zarr_quiet(path: str):
    """Open a Zarr store, trying several gcsfs credential modes without log noise.

    First tries consolidated metadata with ``token`` set to
    ``"google_default"``, ``"cloud"``, ``"anon"`` and finally no
    storage options.  Then it makes one unconsolidated attempt with default
    options.  All attempts run inside ``_silence_gcs_logs``.

    Raises:
        The last exception from the consolidated attempts (chained to the
        unconsolidated attempt's error), or RuntimeError if none was recorded.
    """
    last_err = None
    with _silence_gcs_logs():
        for tok in ("google_default", "cloud", "anon", None):
            so = {} if tok is None else {"token": tok}
            try:
                return xr.open_zarr(path, consolidated=True, chunks={}, storage_options=so)
            except Exception as e:  # try the next credential mode
                last_err = e
    # one more try without consolidated (still quiet)
    fallback_err = None
    with _silence_gcs_logs():
        try:
            return xr.open_zarr(path, consolidated=False)
        except Exception as e:
            fallback_err = e
    # Raise only after the `with` has exited so logger levels are restored
    # even if the silencing context manager is not exception-safe.
    raise (last_err if last_err else RuntimeError(f"Could not open {path}")) from fallback_err

def _open_any(path: str):
    """Open *path* as a Zarr store when it names a ``.zarr`` directory, else via ``xr.open_dataset``.

    A trailing ``/`` (common when pasting bucket paths) still counts as Zarr.
    """
    if path.rstrip("/").endswith(".zarr"):
        return _open_zarr_quiet(path)
    return xr.open_dataset(path)

# ---------- widgets ----------
# Kenya inputs: one folder of daily NetCDF files per variable.
tmax_dir_txt = widgets.Text(
    value="kenya_data/tmax-2022-2025",
    description="Kenya tmax dir:",
    layout=widgets.Layout(width="95%"),
)
tmin_dir_txt = widgets.Text(
    value="kenya_data/tmin-2022-2025",
    description="Kenya tmin dir:",
    layout=widgets.Layout(width="95%"),
)
prcp_dir_txt = widgets.Text(
    value="kenya_data/precip-2022-2025",
    description="Kenya precip dir:",
    layout=widgets.Layout(width="95%"),
)
# Variable names the loader stores each Kenya field under.
tmax_var_txt = widgets.Text(value="tmax", description="tmax var:")
tmin_var_txt = widgets.Text(value="tmin", description="tmin var:")
prcp_var_txt = widgets.Text(value="precip", description="precip var:")

# Gridded reference datasets (Zarr stores on GCS).
era5_tp_txt = widgets.Text(
    value="gs://aim4scale_training_25/ground_truth/era5_24hr.zarr",
    description="ERA5 tp:",
    layout=widgets.Layout(width="95%"),
)
era5_t2_txt = widgets.Text(
    value="gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr",
    description="ERA5 t2m:",
    layout=widgets.Layout(width="95%"),
)
imerg_txt = widgets.Text(
    value="gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr",
    description="IMERG:",
    layout=widgets.Layout(width="95%"),
)

load_btn = widgets.Button(description="Load data", button_style="primary")
out_load = widgets.Output()

display(
    Markdown("### üì• Load Kenya + ERA5/IMERG"),
    tmax_dir_txt,
    tmin_dir_txt,
    prcp_dir_txt,
    widgets.HBox([tmax_var_txt, tmin_var_txt, prcp_var_txt]),
    era5_tp_txt,
    era5_t2_txt,
    imerg_txt,
    load_btn,
    out_load,
)

# ---------- state ----------
# Assigned by the "Load data" callback; None means "not loaded yet".
ds_k_tmax = None
ds_k_tmin = None
ds_k_prcp = None
ds_e5_t2m = None
ds_e5_tp = None
ds_imerg = None

@out_load.capture(clear_output=True)
def _on_load(_):
    """"Load data" button callback: open all six datasets and publish them as globals.

    The three Kenya folders and the ERA5 t2m / ERA5 tp / IMERG stores are
    first opened into locals.  The module-level ``ds_*`` globals are assigned
    only after every open has succeeded, so a failed reload never leaves a
    mix of freshly loaded and stale datasets.  On failure a one-line error is
    shown.  On success, each dataset's time span is listed where it can be
    determined.
    """
    global ds_k_tmax, ds_k_tmin, ds_k_prcp, ds_e5_t2m, ds_e5_tp, ds_imerg
    try:
        k_tmax = _open_daily_nc_folder(tmax_dir_txt.value.strip(), tmax_var_txt.value.strip())
        k_tmin = _open_daily_nc_folder(tmin_dir_txt.value.strip(), tmin_var_txt.value.strip())
        k_prcp = _open_daily_nc_folder(prcp_dir_txt.value.strip(), prcp_var_txt.value.strip())
        e5_t2m = _open_any(era5_t2_txt.value.strip())
        e5_tp = _open_any(era5_tp_txt.value.strip())
        imerg = _open_any(imerg_txt.value.strip())
    except Exception as e:
        # only show a short one-line failure if *anything* truly fails to open
        display(Markdown(f"❌ Load error: `{e}`"))
        return

    ds_k_tmax, ds_k_tmin, ds_k_prcp = k_tmax, k_tmin, k_prcp
    ds_e5_t2m, ds_e5_tp, ds_imerg = e5_t2m, e5_tp, imerg

    display(Markdown("✅ **All datasets loaded**"))

    def _span(ds):
        t = "time" if "time" in ds.coords else "valid_time"
        return (str(pd.to_datetime(ds[t].min().values).date()),
                str(pd.to_datetime(ds[t].max().values).date()))

    for name, ds in [
        ("Kenya tmax", ds_k_tmax), ("Kenya tmin", ds_k_tmin), ("Kenya precip", ds_k_prcp),
        ("ERA5 t2m", ds_e5_t2m), ("ERA5 tp", ds_e5_tp), ("IMERG", ds_imerg),
    ]:
        try:
            a, b = _span(ds)
        except Exception:
            continue  # no usable time/valid_time coordinate: skip this span line
        display(Markdown(f"- **{name}**: `{a}` → `{b}`"))

load_btn.on_click(_on_load)


### 📥 Load Kenya + ERA5/IMERG

Text(value='kenya_data/tmax-2022-2025', description='Kenya tmax dir:', layout=Layout(width='95%'))

Text(value='kenya_data/tmin-2022-2025', description='Kenya tmin dir:', layout=Layout(width='95%'))

Text(value='kenya_data/precip-2022-2025', description='Kenya precip dir:', layout=Layout(width='95%'))

HBox(children=(Text(value='tmax', description='tmax var:'), Text(value='tmin', description='tmin var:'), Text(‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/era5_24hr.zarr', description='ERA5 tp:', layout=Layout(wid‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr', description='ERA5 t2m:', layo‚Ä¶

Text(value='gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr', description='IMERG:', layout=L‚Ä¶

Button(button_style='primary', description='Load data', style=ButtonStyle())

Output()

In [6]:
# Quick look at one raw Kenya precip file, to check its variable/coordinate
# names before loading the whole folder.  The path is built from the "Kenya
# precip dir" widget (the same folder the loader reads) instead of a
# machine-specific absolute path.
ds = xr.open_dataset(os.path.join(prcp_dir_txt.value.strip(), "rr_mrg_20220102_CLM.nc"))
ds

In [4]:
# === Cell 2: Kenya (truth) vs ERA5 — day-of-year (DOY) curves with month ticks ===
# Controls: variable and metric, an inclusive start/end month & year window,
# and a button that draws into `out_plot_e5`.
metric_dd_e5 = widgets.Dropdown(
    options=["RMSE", "MAE"], value="RMSE", description="Metric:"
)
var_dd_e5 = widgets.Dropdown(
    options=["tmax", "tmin", "precip"], value="tmax", description="Variable:"
)
start_month_e5 = widgets.Dropdown(
    options=list(range(1, 13)), value=1, description="Start M"
)
end_month_e5 = widgets.Dropdown(
    options=list(range(1, 13)), value=12, description="End M"
)
start_year_e5 = widgets.BoundedIntText(
    value=2022, min=1900, max=2100, description="Start Y"
)
end_year_e5 = widgets.BoundedIntText(
    value=2023, min=1900, max=2100, description="End Y"
)
run_btn_e5 = widgets.Button(description="Run (Kenya vs ERA5)", button_style="success")
out_plot_e5 = widgets.Output()

display(
    widgets.HBox([var_dd_e5, metric_dd_e5]),
    widgets.HBox([start_month_e5, end_month_e5, start_year_e5, end_year_e5]),
    run_btn_e5,
    out_plot_e5,
)

# Month tick helpers on a 365-day (non-leap) calendar: the day of year of
# each month's 1st (2023 is a non-leap year), plus matching English labels.
_MONTH_START_DOY = [pd.Timestamp(2023, month, 1).dayofyear for month in range(1, 13)]
_MONTH_LABELS = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()

@out_plot_e5.capture(clear_output=True)
def _run_e5(_):
    """"Run (Kenya vs ERA5)" callback: plot the day-of-year error curve of ERA5 vs Kenya.

    Kenya data is treated as truth.  For the selected variable both sources
    are processed in this order:

    1. cropped to the Kenya box;
    2. converted to common units (°C, or mm/day for precip);
    3. limited to the selected month/year window;
    4. ERA5 put on the Kenya grid;
    5. scored per day with weatherbenchX;
    6. averaged by day of year across the selected years, on a 365-day
       calendar.

    The month range is inclusive and may wrap the year end (e.g. start
    M=11, end M=2 keeps Nov-Feb).
    """
    if any(x is None for x in [ds_k_tmax, ds_k_tmin, ds_k_prcp, ds_e5_t2m, ds_e5_tp]):
        print("❌ Load data first (Cell 1).")
        return

    var, metric = var_dd_e5.value, metric_dd_e5.value
    m0, m1 = int(start_month_e5.value), int(end_month_e5.value)
    y0, y1 = int(start_year_e5.value), int(end_year_e5.value)

    # ----- pick truth & ERA5 -----
    if var in ("tmax", "tmin"):
        # The loader stored each Kenya field under the name typed in the
        # "tmax var"/"tmin var" widget, so read it back by that same name.
        k_name = (tmax_var_txt if var == "tmax" else tmin_var_txt).value.strip()
        kenya = _kenya_box(ds_k_tmax if var == "tmax" else ds_k_tmin)[k_name]
        # prefer the exact variable name in the ERA5 t2m Zarr; then common fallbacks
        alts = (["t2m_max", "maximum_2m_air_temperature", "mx2t"] if var == "tmax"
                else ["t2m_min", "minimum_2m_air_temperature", "mn2t"])
        e5_name = next((c for c in [var] + alts if c in ds_e5_t2m), None)
        if e5_name is None:
            print("❌ ERA5 t2m variable not found.")
            return
        era5 = _kenya_box(ds_e5_t2m)[e5_name]
        kenya, era5 = _normalize_temp_C(kenya), _normalize_temp_C(era5)
    else:  # precip
        kenya = _normalize_precip(_kenya_box(ds_k_prcp)[prcp_var_txt.value.strip()])
        era5 = _normalize_precip(_kenya_box(ds_e5_tp)["total_precipitation_24hr"])

    # ----- time window: first instant of (y0, m0) to the last instant of (y1, m1) -----
    t0 = np.datetime64(pd.Timestamp(year=y0, month=m0, day=1))
    t1 = np.datetime64(pd.Timestamp(year=y1, month=m1, day=1)
                       + pd.offsets.MonthBegin(1) - pd.Timedelta(1, "ns"))
    kenya = kenya.sel(time=slice(t0, t1))
    era5 = era5.sel(time=slice(t0, t1))

    # ----- align + WBX metric (daily) -----
    era5_on = _align_to_ref(era5, kenya)
    s_e5, ylab = _wbx_series(_coerce_valid_time(era5_on), _coerce_valid_time(kenya), metric)
    s = s_e5.dropna(dim="date")

    # keep only the selected years and the (possibly year-wrapping) month range
    years, months = s["date"].dt.year, s["date"].dt.month
    if m0 <= m1:
        in_months = (months >= m0) & (months <= m1)
    else:
        in_months = (months >= m0) | (months <= m1)
    s = s.sel(date=(years >= y0) & (years <= y1) & in_months)

    # ----- DOY mean across years on a 365-day calendar -----
    # In leap years every date after Feb 28 has a day-of-year one higher than
    # in other years.  Drop Feb 29 and shift those dates back so e.g. Mar 1 is
    # DOY 60 in every year and lines up with _MONTH_START_DOY.  (Dropping
    # DOY 366 alone only removed 31 Dec of leap years.)
    dates = pd.DatetimeIndex(s["date"].values)
    not_feb29 = ~((dates.month == 2) & (dates.day == 29))
    s = s.isel(date=np.flatnonzero(not_feb29))
    dates = dates[not_feb29]
    if len(dates) == 0:
        print("❌ No overlapping Kenya/ERA5 data in the selected window.")
        return
    doy = dates.dayofyear - (dates.is_leap_year & (dates.month > 2)).astype(int)
    s_doy = s.assign_coords(DOY=("date", np.asarray(doy))).groupby("DOY").mean(skipna=True)

    # ----- plot with month ticks -----
    fig, ax = plt.subplots(figsize=(10, 4))
    s_doy.plot(ax=ax, x="DOY", label="ERA5 vs Kenya")
    ax.set_xticks(_MONTH_START_DOY)
    ax.set_xticklabels(_MONTH_LABELS)
    ax.set_xlim(1, 365)
    ax.set_ylabel(f"{ylab} " + ("(mm/day)" if var == "precip" else "(°C)"))
    ax.set_title(f"{metric} — {var.upper()} — Kenya vs ERA5 — DOY mean over {y0}-{y1} "
                 f"(months {m0:02d}→{m1:02d})")
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.show()

run_btn_e5.on_click(_run_e5)


HBox(children=(Dropdown(description='Variable:', options=('tmax', 'tmin', 'precip'), value='tmax'), Dropdown(d‚Ä¶

HBox(children=(Dropdown(description='Start M', options=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), value=1), Drop‚Ä¶

Button(button_style='success', description='Run (Kenya vs ERA5)', style=ButtonStyle())

Output()

In [None]:
# === Kenya (truth) vs IMERG (precip only) — day-of-year (DOY) curves with month ticks ===
# Controls: metric, an inclusive start/end month & year window, and a button
# that draws into `out_plot_img`.
metric_dd_img = widgets.Dropdown(
    options=["RMSE", "MAE"], value="RMSE", description="Metric:"
)
start_month_img = widgets.Dropdown(
    options=list(range(1, 13)), value=1, description="Start M"
)
end_month_img = widgets.Dropdown(
    options=list(range(1, 13)), value=12, description="End M"
)
start_year_img = widgets.BoundedIntText(
    value=2022, min=1900, max=2100, description="Start Y"
)
end_year_img = widgets.BoundedIntText(
    value=2023, min=1900, max=2100, description="End Y"
)
run_btn_img = widgets.Button(description="Run (Kenya vs IMERG)", button_style="success")
out_plot_img = widgets.Output()

display(
    widgets.HBox([metric_dd_img]),
    widgets.HBox([start_month_img, end_month_img, start_year_img, end_year_img]),
    run_btn_img,
    out_plot_img,
)

@out_plot_img.capture(clear_output=True)
def _run_img(_):
    """"Run (Kenya vs IMERG)" callback: plot the day-of-year precip error of IMERG vs Kenya.

    Kenya precipitation is treated as truth.  Both sources are processed in
    this order:

    1. cropped to the Kenya box;
    2. converted to mm/day;
    3. limited to the selected month/year window;
    4. IMERG put on the Kenya grid;
    5. scored per day with weatherbenchX;
    6. averaged by day of year across the selected years, on a 365-day
       calendar.

    The month range is inclusive and may wrap the year end (e.g. start
    M=11, end M=2 keeps Nov-Feb).
    """
    if any(x is None for x in [ds_k_prcp, ds_imerg]):
        print("❌ Load data first (Cell 1).")
        return

    metric = metric_dd_img.value
    m0, m1 = int(start_month_img.value), int(end_month_img.value)
    y0, y1 = int(start_year_img.value), int(end_year_img.value)

    kenya = _normalize_precip(_kenya_box(ds_k_prcp)[prcp_var_txt.value.strip()])
    imerg = _normalize_precip(_kenya_box(ds_imerg)["total_precipitation_24hr"])

    # time window: first instant of (y0, m0) to the last instant of (y1, m1)
    t0 = np.datetime64(pd.Timestamp(year=y0, month=m0, day=1))
    t1 = np.datetime64(pd.Timestamp(year=y1, month=m1, day=1)
                       + pd.offsets.MonthBegin(1) - pd.Timedelta(1, "ns"))
    kenya = kenya.sel(time=slice(t0, t1))
    imerg = imerg.sel(time=slice(t0, t1))

    imerg_on = _align_to_ref(imerg, kenya)
    s_img, ylab = _wbx_series(_coerce_valid_time(imerg_on), _coerce_valid_time(kenya), metric)
    s = s_img.dropna(dim="date")

    # keep only the selected years and the (possibly year-wrapping) month range
    years, months = s["date"].dt.year, s["date"].dt.month
    if m0 <= m1:
        in_months = (months >= m0) & (months <= m1)
    else:
        in_months = (months >= m0) | (months <= m1)
    s = s.sel(date=(years >= y0) & (years <= y1) & in_months)

    # DOY mean across years on a 365-day calendar: drop Feb 29 and shift
    # later leap-year dates back one day so every year shares the same DOY
    # axis as _MONTH_START_DOY.  (Dropping DOY 366 alone only removed 31 Dec
    # of leap years.)
    dates = pd.DatetimeIndex(s["date"].values)
    not_feb29 = ~((dates.month == 2) & (dates.day == 29))
    s = s.isel(date=np.flatnonzero(not_feb29))
    dates = dates[not_feb29]
    if len(dates) == 0:
        print("❌ No overlapping Kenya/IMERG data in the selected window.")
        return
    doy = dates.dayofyear - (dates.is_leap_year & (dates.month > 2)).astype(int)
    s_doy = s.assign_coords(DOY=("date", np.asarray(doy))).groupby("DOY").mean(skipna=True)

    fig, ax = plt.subplots(figsize=(10, 4))
    s_doy.plot(ax=ax, x="DOY", label="IMERG vs Kenya")
    ax.set_xticks(_MONTH_START_DOY)
    ax.set_xticklabels(_MONTH_LABELS)
    ax.set_xlim(1, 365)
    ax.set_ylabel(f"{ylab} (mm/day)")
    ax.set_title(f"{metric} — PRECIP — Kenya vs IMERG — DOY mean over {y0}-{y1} "
                 f"(months {m0:02d}→{m1:02d})")
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.show()

run_btn_img.on_click(_run_img)


HBox(children=(Dropdown(description='Metric:', options=('RMSE', 'MAE'), value='RMSE'),))

HBox(children=(Dropdown(description='Start M', options=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), value=1), Drop‚Ä¶

Button(button_style='success', description='Run (Kenya vs IMERG)', style=ButtonStyle())

Output()