
# 🌍 Kenya Verification — Curves Only (RMSE / MAE / SEEPS)

This notebook loads your **daily NetCDF** Kenya datasets (**tmax, tmin, precip**), compares them to **ERA5** (t2m/precip) and **IMERG** (precip), and plots **monthly curves** of the selected **metric**.

**Controls**
- **Load Data**: enter folders for `tmax`, `tmin`, `precip` (one file/day) then click **Load Data**. It prints a success message.
- **Metric**: RMSE / MAE / SEEPS (SEEPS is precipitation-only).
- **Date Window**: start (month/year) and end (month/year). All years in the span are plotted as separate lines (x=month, y=metric).
- **Precip reference**: ERA5 or IMERG.

> Notes:
> - Handles coordinate names `Lat`/`lat`/`Latitude` and `Lon`/`lon`/`Longitude` automatically (they are renamed to `latitude`/`longitude`).
> - Reads date from filenames using `YYYYMMDD` (e.g., `tmin_mrg_20220101_ALL.nc`, `rr_mrg_20220101_CLM.nc`).

## 🛠️ Setup

In [None]:

# If needed, install deps (uncomment):
# %pip install -q xarray dask[complete] zarr fsspec gcsfs xesmf ipywidgets numpy pandas matplotlib weatherbenchx

import os, glob, re, warnings
import numpy as np, pandas as pd, xarray as xr, matplotlib.pyplot as plt, ipywidgets as widgets
from IPython.display import display, Markdown
# Silence every warning category so the widget output areas stay uncluttered.
warnings.filterwarnings("ignore")
# Render all figures at 130 dpi.
plt.rcParams.update({"figure.dpi": 130})
# Ask xarray to keep variable attributes when operations are applied.
xr.set_options(keep_attrs=True)


## ⚙️ Controls

In [None]:

# Reference datasets on GCS
ERA5_T2M_ZARR = "gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr"  # daily 2m temp
ERA5_TP_ZARR  = "gs://aim4scale_training_25/ground_truth/era5_24hr.zarr"              # daily precip
IMERG_ZARR    = "gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr"   # daily precip

# Kenya bbox, passed as keyword arguments to ``Dataset.sel``.
# The latitude bounds are written north-to-south (5.0 down to -4.7).
# NOTE(review): a label slice in this order presumably only matches a latitude
# axis stored in descending order — confirm against how _subset_kenya sorts it.
KENYA = {"latitude": slice(5.0, -4.7), "longitude": slice(33.9, 41.9)}

# Folder inputs (leave blank to auto-search /mnt/data)
tmax_dir = widgets.Text(value="", description="tmax dir:", layout=widgets.Layout(width="95%"))
tmin_dir = widgets.Text(value="", description="tmin dir:", layout=widgets.Layout(width="95%"))
prcp_dir = widgets.Text(value="", description="precip dir:", layout=widgets.Layout(width="95%"))

# Metric & date controls
# The metric and precipitation reference are restricted to the listed options.
metric_dd  = widgets.Dropdown(options=["RMSE","MAE","SEEPS"], value="RMSE", description="Metric:")
prec_ref   = widgets.Dropdown(options=["ERA5","IMERG"], value="ERA5", description="Precip ref:")
# Date window: months are 1-12, years are bounded to 1980-2100.
start_m    = widgets.Dropdown(options=list(range(1,13)), value=1, description="Start M")
start_y    = widgets.BoundedIntText(value=2022, min=1980, max=2100, description="Start Y")
end_m      = widgets.Dropdown(options=list(range(1,13)), value=12, description="End M")
end_y      = widgets.BoundedIntText(value=2024, min=1980, max=2100, description="End Y")

# Action buttons plus two output areas: one for load status, one for plots.
load_btn   = widgets.Button(description="Load Data", button_style="info")
run_btn    = widgets.Button(description="Run Metrics", button_style="success")
out_info   = widgets.Output()
out_plot   = widgets.Output()

# Lay the controls out top to bottom and show them under this cell.
display(widgets.VBox([
    widgets.HTML("<b>Local daily NetCDF folders (one .nc per day). If blank, I will auto-search /mnt/data.</b>"),
    tmax_dir, tmin_dir, prcp_dir,
    widgets.HBox([metric_dd, prec_ref]),
    widgets.HBox([start_m, start_y, end_m, end_y]),
    widgets.HBox([load_btn, run_btn]),
    out_info, out_plot
]))


## 🧰 Helpers

In [None]:

def _autosearch(base="/mnt/data"):
    """Guess the tmax, tmin and precip folders below *base*.

    A folder qualifies when its own name contains the keyword
    (case-insensitively) and it directly holds at least one ``.nc``/``.nc4``
    file. The first qualifying folder in ``os.walk`` order wins.

    Returns:
        A ``(tmax, tmin, precip)`` tuple of folder paths; a slot is ``""``
        when nothing matched.
    """
    netcdf_suffixes = (".nc", ".nc4")

    def _first_match(token):
        for folder, _subdirs, names in os.walk(base):
            if token not in os.path.basename(folder).lower():
                continue
            if any(name.lower().endswith(netcdf_suffixes) for name in names):
                return folder
        return ""

    return tuple(_first_match(token) for token in ("tmax", "tmin", "precip"))

def _standardize_coords(ds):
    """Rename known latitude/longitude spellings to ``latitude``/``longitude``.

    Only names present in ``ds.coords`` are considered. When none of the
    known spellings is present, *ds* is returned untouched.
    """
    aliases = (
        ("Lat", "latitude"),
        ("lat", "latitude"),
        ("Latitude", "latitude"),
        ("Lon", "longitude"),
        ("lon", "longitude"),
        ("Longitude", "longitude"),
    )
    renames = {old: new for old, new in aliases if old in ds.coords}
    return ds.rename(renames) if renames else ds

def _open_daily_folder(folder, var_candidates):
    """Open a folder of one-file-per-day NetCDF files as a single dataset.

    Each file's date is read from the first run of eight digits in its name
    (``YYYYMMDD``) and becomes that file's ``time`` coordinate. Only the first
    variable from *var_candidates* that exists in a file is kept.

    Args:
        folder: Directory holding the daily files.
        var_candidates: Variable names to look for, in order of preference.

    Returns:
        The per-file datasets concatenated along ``time`` and sorted by time.

    Raises:
        FileNotFoundError: No ``*.nc``, ``*.NC`` or ``*.nc*`` file in *folder*.
        ValueError: A filename has no ``YYYYMMDD`` stamp, or a file holds more
            than one time step.
        KeyError: A file contains none of *var_candidates*.
    """
    # Patterns are tried in order; the first one that matches anything wins.
    files = (sorted(glob.glob(os.path.join(folder, "*.nc"))) or
             sorted(glob.glob(os.path.join(folder, "*.NC"))) or
             sorted(glob.glob(os.path.join(folder, "*.nc*"))))
    if not files:
        raise FileNotFoundError(f"No NetCDF files (*.nc*) found in {folder}")

    def _open_one(fp):
        # Validate the filename before touching the file, so a badly named
        # file is rejected without opening it.
        m = re.search(r"(\d{8})", os.path.basename(fp))
        if not m:
            raise ValueError(f"Filename missing YYYYMMDD: {fp}")
        t = pd.to_datetime(m.group(1), format="%Y%m%d")

        ds = _standardize_coords(xr.open_dataset(fp))
        # choose the first available candidate
        name = next((v for v in var_candidates if v in ds.data_vars), None)
        if name is None:
            raise KeyError(f"{fp} variables {list(ds.data_vars)}; expected one of {var_candidates}")
        picked = ds[[name]]

        if "time" in picked.dims:
            # expand_dims refuses an existing dimension, so for files that
            # already carry a single time step just overwrite its label with
            # the date taken from the filename.
            if picked.sizes["time"] != 1:
                raise ValueError(
                    f"{fp} holds {picked.sizes['time']} time steps; expected one file per day"
                )
            return picked.assign_coords(time=[t])
        return picked.expand_dims({"time": [t]})

    out = xr.concat([_open_one(f) for f in files], dim="time").sortby("time")
    return out

def _subset_kenya(ds):
    """Clip *ds* to the Kenya bounding box defined by ``KENYA``.

    Every bounded axis that is stored in descending order is first sorted
    ascending, and the slice bounds are ordered low-to-high to match. A label
    slice whose bounds run against the direction of the axis selects nothing,
    which is what happened when the north-to-south latitude slice was applied
    to an ascending latitude axis.

    Returns:
        The dataset restricted to the bounding box, with the bounded axes
        ascending.
    """
    window = {}
    for axis, bounds in KENYA.items():
        if axis in ds.coords and ds[axis].size > 1 and ds[axis][0] > ds[axis][-1]:
            ds = ds.sortby(axis)
        low, high = bounds.start, bounds.stop
        # Open-ended bounds (None) are left alone; only a fully specified
        # pair can be out of order.
        if low is not None and high is not None and low > high:
            low, high = high, low
        window[axis] = slice(low, high)
    return ds.sel(**window)

def _open_zarr(path):
    """Open a consolidated Zarr store and return its Kenya subset.

    Coordinate names are normalised with ``_standardize_coords`` before the
    bounding box is applied by ``_subset_kenya``.
    """
    store = xr.open_zarr(path, consolidated=True)
    renamed = _standardize_coords(store)
    return _subset_kenya(renamed)

def _pick_var(ds, candidates):
    """Return the first entry of *candidates* that is a data variable of *ds*.

    Raises:
        KeyError: None of the candidates is in ``ds.data_vars``.
    """
    not_found = object()
    chosen = next((name for name in candidates if name in ds.data_vars), not_found)
    if chosen is not_found:
        raise KeyError(f"None of {candidates} found in {list(ds.data_vars)}")
    return chosen

def _align_to_grid(model_da, ref_da):
    """Put *model_da* on the grid of *ref_da* and trim both to a common period.

    When the latitude/longitude values of the two arrays differ, *model_da* is
    regridded bilinearly onto the reference grid with xESMF. Both arrays are
    then cut to the overlap of their time ranges.

    Args:
        model_da: Array to evaluate, with ``latitude``, ``longitude``, ``time``.
        ref_da: Reference array that defines the target grid.

    Returns:
        ``(model_on_ref_grid, ref)``, both sliced to the shared time range.

    Raises:
        ValueError: Either input has no time steps, so there is nothing to
            align.
    """
    # Fail early and clearly instead of letting min()/max() on an empty axis
    # raise an opaque zero-size reduction error after the costly regrid.
    if model_da.time.size == 0 or ref_da.time.size == 0:
        raise ValueError(
            "Nothing to align: the selected date window contains no time steps "
            "in the model data or in the reference data."
        )

    # Regrid (bilinear) to reference grid if latitude/longitude differ
    s = model_da
    try:
        same = (np.array_equal(s.latitude.values, ref_da.latitude.values) and
                np.array_equal(s.longitude.values, ref_da.longitude.values))
    except (AttributeError, KeyError):
        # A missing coordinate means the grids cannot be compared directly.
        same = False
    if not same:
        import xesmf as xe
        source = model_da.to_dataset(name="var")
        templ = xr.Dataset(coords={"latitude": ref_da.latitude, "longitude": ref_da.longitude})
        rg = xe.Regridder(source, templ, "bilinear", periodic=False, reuse_weights=False)
        s = rg(source)["var"]
        # NOTE(review): clean_weight_file() may not exist in every xESMF
        # release, so it is only called when the regridder provides it.
        cleanup = getattr(rg, "clean_weight_file", None)
        if callable(cleanup):
            cleanup()

    # intersect times
    t0 = max(np.datetime64(s.time.min().values), np.datetime64(ref_da.time.min().values))
    t1 = min(np.datetime64(s.time.max().values), np.datetime64(ref_da.time.max().values))
    return s.sel(time=slice(t0, t1)), ref_da.sel(time=slice(t0, t1))

# metrics
from weatherbenchx.metrics.seeps import seeps_score as wbx_seeps
def daily_rmse(a, b):
    """Root-mean-square difference of *a* and *b* over latitude/longitude.

    NaNs are skipped in the spatial mean; all other dimensions are kept.
    """
    squared_error = (a - b) ** 2
    mean_squared = squared_error.mean(dim=("latitude", "longitude"), skipna=True)
    return np.sqrt(mean_squared)
def daily_mae(a, b):
    """Mean absolute difference of *a* and *b* over latitude/longitude.

    NaNs are skipped in the spatial mean; all other dimensions are kept.
    """
    # Builtin abs() is used rather than an ``.abs()`` method so that this
    # does not depend on the array type providing one.
    return abs(a - b).mean(dim=("latitude", "longitude"), skipna=True)
def daily_seeps(a, b):
    """SEEPS score of *a* against *b*, reduced over latitude/longitude.

    Thin wrapper that forwards to ``wbx_seeps`` with the spatial dimensions.
    NOTE(review): the signature of ``wbx_seeps`` is not visible here — confirm
    it accepts ``(forecast, reference, dim=...)``.
    """
    spatial_dims = ("latitude", "longitude")
    return wbx_seeps(a, b, dim=spatial_dims)

def monthly_curve(series, year):
    """Average a daily metric series into monthly values for one calendar year.

    Args:
        series: Array with a ``time`` dimension.
        year: Calendar year to extract.

    Returns:
        The month-start resampled mean for *year*. When the series has no
        data in that year, the empty selection is returned unchanged so the
        caller can skip it by checking ``size == 0``.
    """
    sel = series.sel(time=slice(f"{year}-01-01", f"{year}-12-31"))
    # Resampling needs at least one time stamp to build groups from; an empty
    # year is a normal case (data gaps), not an error.
    if sel.sizes["time"] == 0:
        return sel
    return sel.resample(time="1MS").mean()

def _year_span(m0, y0, m1, y1):
    """Turn a start/end month-year pair into timestamps and a list of years.

    Returns:
        ``(start, end, years)`` where *start* is the first day of the start
        month, *end* is the last day of the end month, and *years* lists every
        calendar year from ``start.year`` through ``end.year`` inclusive.
    """
    first_day = pd.Timestamp(year=y0, month=m0, day=1)
    # Day 1 plus one MonthEnd offset lands on the last day of that same month.
    last_day = pd.Timestamp(year=y1, month=m1, day=1) + pd.offsets.MonthEnd(1)
    return first_day, last_day, [*range(first_day.year, last_day.year + 1)]

def plot_curves(curves, metric, title):
    """Draw one line per year of a monthly metric and show the figure.

    Args:
        curves: Mapping of year to a monthly series with a ``time`` coordinate.
        metric: Label for the y axis.
        title: Figure title.

    Empty series are skipped. The x axis is the calendar month (1-12).
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    for year, monthly in curves.items():
        if monthly.size == 0:
            continue
        months = monthly["time"].dt.month
        ax.plot(months, monthly, marker="o", label=str(year))
    ax.set_title(title)
    ax.set_xlabel("Month")
    ax.set_ylabel(metric)
    ax.set_xticks(range(1, 13))
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    plt.show()


## 📥 Load Data

In [None]:

# Shared cache written by the Load Data handler and read by the Run Metrics
# handler. Every entry starts as None, which the run handler treats as
# "not loaded yet".
state = {"tmax":None, "tmin":None, "prec":None, "era5_t2m":None, "era5_tp":None, "imerg":None}

def _load_local_var(folder, candidates, target):
    """Open one daily folder once, clip it to Kenya and name its variable *target*."""
    ds = _open_daily_folder(folder, candidates)
    source_name = next(iter(ds.data_vars))
    clipped = _subset_kenya(ds)
    if source_name == target:
        return clipped
    return clipped.rename({source_name: target})


@out_info.capture(clear_output=True)
def _on_load(_):
    """Button callback: load the local datasets and the GCS references into ``state``.

    Blank folder fields are filled from ``_autosearch("/mnt/data")``. Local
    variables are renamed to ``tmax``/``tmin``/``precip`` and the references to
    ``t2m``/``tp``/``pr``. ``state`` is only updated after every dataset has
    been opened, so a failure part-way through leaves it untouched. Messages
    printed here are captured by the ``out_info`` output widget.
    """
    # Guess folders if blank
    if not (tmax_dir.value and tmin_dir.value and prcp_dir.value):
        guess_tmax, guess_tmin, guess_pr = _autosearch("/mnt/data")
        if not tmax_dir.value: tmax_dir.value = guess_tmax or ""
        if not tmin_dir.value: tmin_dir.value = guess_tmin or ""
        if not prcp_dir.value: prcp_dir.value = guess_pr or ""
    if not (tmax_dir.value and tmin_dir.value and prcp_dir.value):
        print("❌ Please provide folders for tmax, tmin, precip (or place them under /mnt/data).")
        return

    # Open local daily files (variable candidates include your naming).
    # Each folder is read exactly once; the variable name is taken from the
    # dataset that was just opened rather than from a second pass over the files.
    tmax = _load_local_var(tmax_dir.value, ["tmax","tmax_mrg","tmax_mean","Tmax"], "tmax")
    tmin = _load_local_var(tmin_dir.value, ["tmin","tmin_mrg","tmin_mean","Tmin"], "tmin")
    prec = _load_local_var(prcp_dir.value, ["precip","pr","tp","rain","prec"], "precip")

    # Open references, pick vars & convert units
    e5t = _open_zarr(ERA5_T2M_ZARR)
    e5p = _open_zarr(ERA5_TP_ZARR)
    im  = _open_zarr(IMERG_ZARR)

    v_t2m = _pick_var(e5t, ["t2m","2m_temperature","temperature_2m","t2m_daily"])
    v_tp  = _pick_var(e5p, ["tp","total_precipitation_24hr","total_precipitation","precip"])
    v_im  = _pick_var(im,  ["precip","pr","total_precipitation_24hr"])

    e5t = e5t[[v_t2m]].rename({v_t2m:"t2m"})
    # Heuristic: values above 200 are taken to be Kelvin and converted to Celsius.
    if float(e5t["t2m"].max()) > 200: e5t["t2m"] = e5t["t2m"] - 273.15
    e5p = e5p[[v_tp]].rename({v_tp:"tp"})
    # convert m->mm if needed (heuristic: a maximum below 10 is taken to be metres)
    if float(e5p["tp"].max()) < 10: e5p["tp"] = e5p["tp"] * 1000.0
    im  = im[[v_im]].rename({v_im:"pr"})

    # Commit everything at once so state is never half-populated.
    state.update({
        "tmax": tmax, "tmin": tmin, "prec": prec,
        "era5_t2m": e5t, "era5_tp": e5p, "imerg": im,
    })

    print("✅ All data files loaded.")
    print(f"   tmax:  {tmax_dir.value}")
    print(f"   tmin:  {tmin_dir.value}")
    print(f"   precip:{prcp_dir.value}")
    print("   + ERA5/IMERG references from GCS")


## 📈 Run Metrics & Plot

In [None]:

@out_plot.capture(clear_output=True)
def _on_run(_):
    """Button callback: compute the selected metric and plot monthly curves.

    For RMSE and MAE three figures are drawn (tmax and tmin against ERA5 t2m,
    precipitation against the chosen reference). For SEEPS only the
    precipitation figure is drawn. Output is captured by the ``out_plot``
    widget.
    """
    required = ("tmax", "tmin", "prec", "era5_t2m", "era5_tp", "imerg")
    if any(state[k] is None for k in required):
        print("❌ Please click **Load Data** first.")
        return

    metric = metric_dd.value
    start, end, years = _year_span(start_m.value, start_y.value, end_m.value, end_y.value)
    if start > end:
        print("❌ The start month/year must not be after the end month/year.")
        return
    window = slice(start, end)

    # RMSE/MAE apply to every variable; any other choice is SEEPS, which is
    # precipitation-only.
    spatial_scores = {"RMSE": daily_rmse, "MAE": daily_mae}
    score = spatial_scores.get(metric)
    ylabel = metric if score is not None else "SEEPS"

    # Temperature: only aligned/regridded when the metric actually uses it.
    temp_curves = []
    if score is not None:
        era5_t2m = state["era5_t2m"]["t2m"].sel(time=window)
        for key in ("tmax", "tmin"):
            model_on, ref_on = _align_to_grid(state[key][key].sel(time=window), era5_t2m)
            daily = score(model_on, ref_on)
            temp_curves.append((key.upper(), {y: monthly_curve(daily, y) for y in years}))

    # Precip with chosen reference. The local dataset lives under the "prec"
    # state key; its variable is named "precip".
    ref_label = "ERA5" if prec_ref.value == "ERA5" else "IMERG"
    ref_choice = state["era5_tp"]["tp"] if ref_label == "ERA5" else state["imerg"]["pr"]
    pr_on, refp_on = _align_to_grid(state["prec"]["precip"].sel(time=window), ref_choice.sel(time=window))
    pr_daily = (score or daily_seeps)(pr_on, refp_on)
    pr_curves = {y: monthly_curve(pr_daily, y) for y in years}

    # Plots
    plt.close("all")
    for label, curves in temp_curves:
        plot_curves(curves, ylabel, f"{label} vs ERA5 t2m — {ylabel} (monthly)")
    plot_curves(pr_curves, ylabel, f"PRECIP vs {ref_label} — {ylabel} (monthly)")

# Connect each button to its callback.
# NOTE(review): re-running this cell presumably registers the callbacks again
# on the same button objects — confirm whether duplicate invocations matter.
load_btn.on_click(_on_load)
run_btn.on_click(_on_run)
print("Ready. Enter folders or leave blank for auto-search, then click **Load Data**.")
