
# 🌍 Demo 5: Ground Truth Challenge — NIGERIA DATA vs **ERA5**

## 🧠 Learning Objectives
- Use **weatherbenchX** (metrics) + **xarray** to verify ground-truth datasets.
- Load NIGERIA data (*or any dataset in NC/Zarr*) and compare to **ERA5** ground truth.
- Compute **RMSE** and **MAE** (with ERA5 climatology).
- Focus evaluation on NIGERIA region.
- Visualize the differences between ground-truth datasets.


In [29]:

import os, glob, io, contextlib, logging
import numpy as np
import pandas as pd
import xarray as xr
import ipywidgets as widgets
from IPython.display import display, Markdown

# Placeholder so the name exists; a later cell binds `plt` to matplotlib.pyplot via its own import.
plt = None
# Ask xarray to keep variable attributes (e.g. the `units` read by the helpers below) through operations.
xr.set_options(keep_attrs=True)

<xarray.core.options.set_options at 0x7bb19a7c5050>

In [26]:


@contextlib.contextmanager
def _silence_lib_logs():
    """Temporarily mute noisy storage libraries.

    While the context is active, the ``gcsfs``, ``fsspec``, ``zarr``,
    ``google.auth`` and ``urllib3`` loggers are raised to CRITICAL and
    anything written to stdout/stderr is diverted into a throw-away buffer.

    The previous logger levels are restored on exit, including when the
    body of the ``with`` statement raises.
    """
    names = ("gcsfs", "fsspec", "zarr", "google.auth", "urllib3")
    loggers = [logging.getLogger(n) for n in names]
    saved = [(lg, lg.level) for lg in loggers]
    for lg in loggers:
        lg.setLevel(logging.CRITICAL)
    sink = io.StringIO()
    try:
        with contextlib.redirect_stderr(sink), contextlib.redirect_stdout(sink):
            yield
    finally:
        # Without the finally, an exception thrown into the generator at
        # `yield` would leave every logger stuck at CRITICAL.
        for lg, lvl in saved:
            lg.setLevel(lvl)

def _open_any(path: str):
    """Open a dataset from a NetCDF path or a Zarr store.

    Args:
        path: File path or URL. Surrounding whitespace is ignored.

    Returns:
        ``None`` when ``path`` is empty/blank/``None``; otherwise the opened
        dataset. Paths ending in ``.zarr`` go through ``xr.open_zarr``,
        everything else through ``xr.open_dataset``.

    Raises:
        RuntimeError: if no Zarr attempt succeeds. The last underlying
            exception is attached as ``__cause__`` so the reason is visible.
    """
    p = (path or "").strip()
    if not p:
        return None
    if not p.endswith(".zarr"):
        return xr.open_dataset(p)

    last_error = None
    with _silence_lib_logs():
        # Try each credential mode in turn, then a plain non-consolidated open.
        for token in ("cloud", "anon", None):
            try:
                if token:
                    return xr.open_zarr(
                        p, consolidated=True, chunks={}, storage_options={"token": token}
                    )
                return xr.open_zarr(p, consolidated=False, chunks={})
            except Exception as exc:  # best-effort: fall through to the next mode
                last_error = exc
    raise RuntimeError(f"Could not open Zarr: {p}") from last_error

def _standardize_latlon(ds: xr.Dataset) -> xr.Dataset:
    """Normalise spatial naming and ordering.

    ``lat``/``lon`` (as dimension or coordinate) become
    ``latitude``/``longitude``; if those exist only as data variables they
    are promoted to coordinates; finally the dataset is sorted along each
    of them that is a coordinate.
    """
    short_to_long = {"lat": "latitude", "lon": "longitude"}
    renames = {
        short: full
        for short, full in short_to_long.items()
        if short in ds.dims or short in ds.coords
    }
    if renames:
        ds = ds.rename(renames)

    axes = ("latitude", "longitude")
    # Promote both first, then sort, mirroring the two-pass order.
    for axis in axes:
        if axis in ds and axis not in ds.coords:
            ds = ds.set_coords(axis)
    for axis in axes:
        if axis in ds.coords:
            ds = ds.sortby(axis)
    return ds

def _drop_duplicate_coords_1d(ds: xr.Dataset) -> xr.Dataset:
    """Remove repeated latitude/longitude labels, keeping the first occurrence.

    Only 1-D coordinates that are also dimensions are considered; the
    surviving positions keep their original relative order.
    """
    for axis in ("latitude", "longitude"):
        eligible = axis in ds.dims and axis in ds.coords and ds[axis].ndim == 1
        if not eligible:
            continue
        labels = np.asarray(ds[axis].values)
        # return_index yields the position of each label's first appearance.
        _, first_seen = np.unique(labels, return_index=True)
        if len(first_seen) < len(labels):
            keep = np.sort(first_seen)
            ds = ds.isel({axis: keep})
    return ds

def _span(ds):
    """Return the first and last calendar dates of the dataset as ISO strings.

    Uses the ``time`` coordinate when present, otherwise ``valid_time``.
    """
    axis = "valid_time" if "time" not in ds.coords else "time"
    stamps = ds[axis]
    first_day = pd.to_datetime(stamps.min().values).date()
    last_day = pd.to_datetime(stamps.max().values).date()
    return str(first_day), str(last_day)

def _bounds_str(ds):
    """Format the latitude/longitude extent of ``ds`` with two decimals."""
    lat_vals = ds["latitude"].values
    lon_vals = ds["longitude"].values
    lat_lo, lat_hi = float(lat_vals.min()), float(lat_vals.max())
    lon_lo, lon_hi = float(lon_vals.min()), float(lon_vals.max())
    return f"lat {lat_lo:.2f}..{lat_hi:.2f}, lon {lon_lo:.2f}..{lon_hi:.2f}"



In [None]:
# ---------- widgets ----------
# Folder to scan, the button that triggers the scan, and the dropdown listing the .nc files found.
root_txt = widgets.Text(value="/workspace", description="Folder:", layout=widgets.Layout(width="95%"))
scan_btn = widgets.Button(description="Scan .nc", button_style="info")
ng_dd    = widgets.Dropdown(options=[], description="Nigeria .nc:", layout=widgets.Layout(width="95%"))

# Location of the ERA5 2 m temperature store; editable so another path can be supplied.
era5_t2m_txt = widgets.Text(
    value="gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr",
    description="ERA5 t2m:", layout=widgets.Layout(width="95%"))

# Load trigger plus the output area that receives the loader's messages.
load_btn = widgets.Button(description="Load", button_style="success")
out_load = widgets.Output()

# Heading emoji repaired: the literal was mis-encoded and rendered as garbage.
display(Markdown("### 📦 Pick files"))
display(root_txt, scan_btn, ng_dd, era5_t2m_txt, load_btn, out_load)

# ---------- shared state ----------
# Module-level datasets filled in by _load(); both are None until then.
ds_ng = ds_e5_t2m = None

def _scan(_=None):
    """List the ``*.nc`` files directly inside the chosen folder.

    The sorted matches become the dropdown options and the first one, if
    any, is selected. A blank folder field means the current directory.
    The argument is the (unused) button passed by ``on_click``.
    """
    folder = root_txt.value.strip() or "."
    pattern = os.path.join(folder, "*.nc")
    found = sorted(glob.glob(pattern))
    ng_dd.options = found
    if found:
        ng_dd.value = found[0]

@out_load.capture(clear_output=True)
def _load(_=None):
    """Open the selected Nigeria file and the ERA5 store, then report what was loaded.

    The Nigeria dataset is opened, its lat/lon names standardised and
    duplicate coordinate labels dropped; the ERA5 path is opened with
    ``_open_any``. The module-level ``ds_ng`` / ``ds_e5_t2m`` are replaced
    only after everything (including the summary) succeeded, so a failed
    load never leaves a new Nigeria grid paired with a stale ERA5 dataset.

    The argument is the (unused) button passed by ``on_click``. All output
    goes to the ``out_load`` widget; errors are printed, not raised.
    """
    global ds_ng, ds_e5_t2m
    if not ng_dd.value:
        print("❌ Select a Nigeria .nc file.")
        return
    try:
        with _silence_lib_logs():
            ng = xr.open_dataset(ng_dd.value)
            ng = _standardize_latlon(ng)
            ng = _drop_duplicate_coords_1d(ng)
            era5 = _open_any(era5_t2m_txt.value)

        # Build the report before committing, so a dataset without a usable
        # time axis is rejected instead of half-registered.
        first, last = _span(ng)
        report = [
            f"• Nigeria grid: {os.path.basename(ng_dd.value)} — {first} → {last} — {_bounds_str(ng)}"
        ]
        if era5 is not None:
            first, last = _span(era5)
            report.append(f"• ERA5 t2m: {first} → {last}")
    except Exception as e:
        print(f"❌ Load error: {e}")
        return

    ds_ng, ds_e5_t2m = ng, era5

    if era5 is not None:
        print("✅ All datasets loaded")
    else:
        # _open_any returns None for a blank path; don't claim success for ERA5.
        print("⚠️ Nigeria grid loaded, but the ERA5 path is empty.")
    for line in report:
        print(line)
    print("Region: **Data over Nigeria** (bounds from the file).")

# Hook the buttons up to their handlers, then run one scan immediately so
# the dropdown is populated without needing a click.
scan_btn.on_click(_scan)
load_btn.on_click(_load)
_scan()


### 📦 Pick files

Text(value='/workspace', description='Folder:', layout=Layout(width='95%'))

Button(button_style='info', description='Scan .nc', style=ButtonStyle())

Dropdown(description='Nigeria .nc:', layout=Layout(width='95%'), options=(), value=None)

Text(value='gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr', description='ERA5 t2m:', layo…

Button(button_style='success', description='Load', style=ButtonStyle())

Output()

In [31]:
# === Cell 2: ERA5 vs Nigeria (temperature only) — with mean error printout ===
import numpy as np, pandas as pd, xarray as xr
import matplotlib.pyplot as plt, ipywidgets as widgets
from IPython.display import display, Markdown

# Render figures at 130 dpi.
plt.rcParams.update({"figure.dpi": 130})

# Nigeria var names (no rainfall)
# Maps the short codes offered in the UI to the variable names looked up in the Nigeria dataset.
NIGERIA_VARS = {"tmax":"max_temperature", "tmin":"min_temperature"}

def _f32(da):
    """Cast ``da`` to float32 without forcing a copy; return it untouched if the cast fails."""
    try:
        as_single = da.astype("float32", copy=False)
    except Exception:
        return da
    return as_single

def _normalize_temp_C(da):
    """Return ``da`` expressed in degrees Celsius, leaving the input untouched.

    The ``units`` (or ``unit``) attribute decides whether a conversion is
    needed: a value containing "k" but not "pa" (case-insensitive) is
    treated as Kelvin and 273.15 is subtracted. Anything else, including a
    missing attribute, is assumed to be Celsius already.

    The result always carries ``units == "°C"``. The caller's object and its
    ``attrs`` dict are never modified, so tagging a variable pulled from a
    dataset does not silently rewrite that dataset's metadata.
    """
    # str() guards against non-string unit attributes.
    unit = str(da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    looks_kelvin = "k" in unit and "pa" not in unit
    converted = da - 273.15 if looks_kelvin else da.copy(deep=False)
    # Assign a fresh dict so no attrs mapping is shared with the input.
    converted.attrs = {**converted.attrs, "units": "°C"}
    return converted

def _window(da, y0, m0, y1, m1):
    """Restrict ``da`` to the months from (y0, m0) through (y1, m1), inclusive.

    Args:
        da: Object with a ``time`` coordinate; returned unchanged if it has none.
        y0, m0: First year and month of the window.
        y1, m1: Last year and month of the window.

    Returns:
        ``da.sel(time=slice(start, end))`` where ``start`` is the first
        instant of the first month and ``end`` is the last nanosecond of the
        last month, so timestamps later than midnight on the final day
        (e.g. daily values stamped at noon) are kept.
    """
    if "time" not in da.coords:
        return da
    start = pd.Timestamp(year=y0, month=m0, day=1).to_datetime64()
    next_month = pd.Timestamp(year=y1, month=m1, day=1) + pd.offsets.MonthBegin(1)
    # Label slices are end-inclusive: stop just short of the following month.
    end = next_month.to_datetime64() - np.timedelta64(1, "ns")
    return da.sel(time=slice(start, end))

def _month_agg(s):
    """Average ``s`` into month-start ("MS") bins along ``time``, skipping NaNs."""
    monthly_bins = s.resample(time="MS")
    return monthly_bins.mean(skipna=True)

def _doy_agg(s):
    """Average ``s`` by day of year and expose the result along an integer ``DOY`` axis.

    If the grouped labels cannot be cast to int, the axis falls back to
    1..N where N is the number of groups.
    """
    by_day = s.groupby("time.dayofyear").mean(skipna=True)
    by_day = by_day.rename({"dayofyear": "DOY"})
    try:
        labels = by_day["DOY"].values.astype(int)
        by_day = by_day.assign_coords(DOY=("DOY", labels))
    except Exception:
        count = by_day.sizes["DOY"]
        by_day = by_day.assign_coords(DOY=("DOY", np.arange(1, count + 1, dtype=int)))
    return by_day

def _month_ticks_for_doy(ax):
    """Label a day-of-year x-axis with month abbreviations.

    Ticks sit on the first day of each month, using non-leap-year day numbers.
    """
    month_start_doy = {
        "Jan": 1, "Feb": 32, "Mar": 60, "Apr": 91,
        "May": 121, "Jun": 152, "Jul": 182, "Aug": 213,
        "Sep": 244, "Oct": 274, "Nov": 305, "Dec": 335,
    }
    ax.set_xticks(list(month_start_doy.values()))
    ax.set_xticklabels(list(month_start_doy.keys()))

def _put_pred_on_truth_grid(pred, truth):
    """Sample ``pred`` at the grid points of ``truth`` (nearest neighbour).

    Args:
        pred: Field to resample; must have ``latitude``/``longitude``.
        truth: Field whose ``latitude``/``longitude`` define the target grid.

    Returns:
        ``pred`` taken at the nearest cell for every truth grid point, with
        points in the same order as they appear in ``truth``. Downstream
        code compares the two arrays position by position, so the target
        order must be truth's own, not a re-sorted copy of it.
    """
    # Only the searched index needs sorting for nearest-neighbour lookup;
    # the target labels may be in any order.
    for d in ("latitude", "longitude"):
        if d in pred.dims:
            pred = pred.sortby(d)
    try:
        return pred.sel(latitude=truth.latitude, longitude=truth.longitude, method="nearest")
    except Exception:
        # Deliberate best-effort fallback when label-based selection fails.
        return pred.interp(latitude=truth.latitude, longitude=truth.longitude, method="nearest")

def _common_time(a, b):
    """Restrict both inputs to the timestamps they share.

    Returns the pair ``(a, b)`` selected on the sorted intersection of
    their ``time`` values.
    """
    shared = np.intersect1d(np.asarray(a.time.values), np.asarray(b.time.values))
    a_shared = a.sel(time=shared)
    b_shared = b.sel(time=shared)
    return a_shared, b_shared

def _positional_view(da: xr.DataArray) -> xr.DataArray:
    """Return ``da`` with latitude/longitude labels replaced by 0..N-1 positions.

    Dimensions are ordered ``time, latitude, longitude`` (whichever exist),
    followed by any other dimensions in their original order. The ``time``
    coordinate is carried over when the array has one; spatial coordinates
    become plain integer ranges so two arrays on the "same" grid combine
    cell-by-cell even if their float labels differ slightly.
    """
    leading = [d for d in ("time", "latitude", "longitude") if d in da.dims]
    # The ellipsis keeps unexpected extra dims instead of raising.
    da2 = da.transpose(*leading, ...)
    coords = {}
    # Only touch `time` when it exists; the dims above are all optional.
    if "time" in da2.dims or "time" in da2.coords:
        coords["time"] = da2["time"]
    for d in ("latitude", "longitude"):
        if d in da2.dims:
            coords[d] = np.arange(da2.sizes[d])
    return xr.DataArray(da2.data, dims=da2.dims, coords=coords, attrs=da2.attrs)

def _daily_error_positional(a, b, metric):
    """Spatially averaged error between ``a`` and ``b``, compared cell-by-cell.

    Both inputs are first converted with ``_positional_view``. For
    ``metric == "RMSE"`` the result is the square root of the spatial mean
    of squared differences; for any other value it is the spatial mean of
    absolute differences. NaNs are skipped. What remains after reducing
    latitude/longitude is the per-timestep error series.
    """
    left = _positional_view(a)
    right = _positional_view(b)
    delta = left - right
    spatial_axes = ("latitude", "longitude")
    if metric == "RMSE":
        squared = delta ** 2
        reduce_over = [d for d in spatial_axes if d in squared.dims]
        return np.sqrt(squared.mean(dim=reduce_over, skipna=True))
    reduce_over = [d for d in spatial_axes if d in left.dims]
    return abs(delta).mean(dim=reduce_over, skipna=True)

# -------- UI --------
# Controls read by _run(): which variable and metric to evaluate, whether to
# fold the series onto day-of-year, and the month/year bounds of the window.
var_dd    = widgets.Dropdown(options=["tavg","tmax","tmin"], value="tavg", description="Variable:")
metric_dd = widgets.Dropdown(options=["RMSE","MAE"], value="RMSE", description="Metric:")
doy_cb    = widgets.Checkbox(value=False, description="Group by DOY (month ticks)")
start_m   = widgets.Dropdown(options=list(range(1,13)), value=1, description="Start M")
end_m     = widgets.Dropdown(options=list(range(1,13)), value=12, description="End M")
start_y   = widgets.BoundedIntText(value=2022, min=1900, max=2100, description="Start Y")
end_y     = widgets.BoundedIntText(value=2022, min=1900, max=2100, description="End Y")
# Trigger button and the output area that receives the plot and summary text.
run_btn   = widgets.Button(description="Run (ERA5 vs Nigeria)", button_style="success")
out       = widgets.Output()

# Lay the controls out in two rows above the run button and its output.
display(Markdown("**Region:** Data over Nigeria (auto-detected)."))
display(widgets.HBox([var_dd, metric_dd, doy_cb]))
display(widgets.HBox([start_m, end_m, start_y, end_y]))
display(run_btn, out)

def _nigeria_fields(ds, var):
    """Return ``(fields, error)``: the raw Nigeria variable(s) behind ``var``, or a message."""
    wanted = ("tmax", "tmin") if var == "tavg" else (var,)
    keys = [NIGERIA_VARS[w] for w in wanted]
    if any(k not in ds for k in keys):
        if var == "tavg":
            return None, "❌ Need Nigeria 'max_temperature' and 'min_temperature' to build TAVG."
        return None, f"❌ Nigeria file missing '{keys[0]}'."
    return [ds[k] for k in keys], None


def _era5_daily_mean(ds):
    """Return the raw ERA5 daily-mean temperature field from ``ds``, or None if none is found."""
    for name in ("tavg", "t2m_mean", "daily_mean_temperature", "tas_mean", "tmean"):
        if name in ds:
            return ds[name]
    if "tmax" in ds and "tmin" in ds:
        return (ds["tmax"] + ds["tmin"]) / 2.0
    for name in ("t2m_daily_mean", "t2m"):
        if name in ds:
            return ds[name]
    return None


@out.capture(clear_output=True)
def _run(_):
    """Compare ERA5 against the Nigeria grid for the selected variable, metric and window.

    Steps: pick the Nigeria field (``tavg`` is the mean of max and min
    temperature), pick an ERA5 daily-mean temperature, cut both to the
    requested months, convert to °C, sample ERA5 on the Nigeria grid,
    keep the shared timestamps, compute a per-day spatial RMSE/MAE, then
    aggregate by month or by day of year, plot it and print the means.

    Note that the ERA5 side is always a daily *mean* temperature, also when
    ``tmax`` or ``tmin`` is chosen on the Nigeria side.

    The argument is the (unused) button passed by ``on_click``. All output
    goes to the ``out`` widget; problems are reported as printed messages.
    """
    if ds_ng is None or ds_e5_t2m is None:
        print("❌ Load Nigeria grid and ERA5 in Cell 1.")
        return

    var, metric = var_dd.value, metric_dd.value
    y0, y1 = int(start_y.value), int(end_y.value)
    m0, m1 = int(start_m.value), int(end_m.value)

    # --- truth (Nigeria)
    fields, problem = _nigeria_fields(ds_ng, var)
    if problem:
        print(problem)
        return

    # --- predictor (ERA5 daily mean temperature)
    era_raw = _era5_daily_mean(ds_e5_t2m)
    if era_raw is None:
        print("❌ Could not find a daily mean temperature in ERA5 dataset.")
        return

    # Everything below sorts and intersects on `time`; fail with a message
    # rather than an uncaught KeyError if either side lacks it.
    if any("time" not in f.coords for f in [*fields, era_raw]):
        print("❌ Both datasets need a 'time' coordinate.")
        return

    # --- window first, then convert: the unit shift only touches the
    # requested months instead of the full record.
    def _prep(da):
        return _normalize_temp_C(_window(da, y0, m0, y1, m1))

    parts = [_prep(f) for f in fields]
    truth = parts[0] if len(parts) == 1 else (parts[0] + parts[1]) / 2.0
    truth = truth.sortby("time")
    era = _prep(era_raw).sortby("time")
    if truth.sizes.get("time", 0) == 0 or era.sizes.get("time", 0) == 0:
        print("⚠️ No data in requested window.")
        return

    # --- space align (ERA5 → Nigeria grid), then time intersect
    era_on = _put_pred_on_truth_grid(era, truth)
    truth2, era2 = _common_time(truth, era_on)
    if truth2.sizes.get("time", 0) == 0:
        print("⚠️ Empty overlap after time intersection.")
        return

    # --- daily error (positional), then aggregate
    truth2 = _f32(truth2)
    era2 = _f32(era2)
    daily = _daily_error_positional(truth2, era2, metric)  # one value per day
    mean_daily = float(daily.mean(skipna=True).values)     # mean daily error over the whole window

    by_doy = bool(doy_cb.value)
    if by_doy:
        s, xdim = _doy_agg(daily), "DOY"
    else:
        s, xdim = _month_agg(daily), "time"
    mean_curve = float(s.mean(skipna=True).values)         # mean of the plotted series

    # --- plot
    fig, ax = plt.subplots(figsize=(10, 4))
    s.plot(ax=ax, x=xdim, label="ERA5 vs Nigeria")
    ax.set_ylabel(f"{metric} (°C)")
    ax.set_title(f"{metric} — {var.upper()} — ERA5 vs Nigeria — {y0}-{y1} ({m0:02d}→{m1:02d})")
    ax.grid(True, alpha=0.3)
    ax.legend()
    if by_doy:
        _month_ticks_for_doy(ax)
    plt.show()

    # --- print summary
    print(f"Mean daily {metric} over window: {mean_daily:.2f} °C")
    if by_doy:
        print(f"Mean of DOY curve (mean of daily-of-year means): {mean_curve:.2f} °C")
    else:
        print(f"Mean of monthly curve (mean of monthly means): {mean_curve:.2f} °C")
# Run the comparison each time the button is clicked.
run_btn.on_click(_run)


**Region:** Data over Nigeria (auto-detected).

HBox(children=(Dropdown(description='Variable:', options=('tavg', 'tmax', 'tmin'), value='tavg'), Dropdown(des…

HBox(children=(Dropdown(description='Start M', options=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), value=1), Drop…

Button(button_style='success', description='Run (ERA5 vs Nigeria)', style=ButtonStyle())

Output()

In [33]:
# Load the regridded NetCDF file and leave it as the cell's last expression
# so the notebook displays its summary.
data=xr.load_dataset("/workspace/regridded_weather_data_kriging.nc")
data