In [33]:
# Cell 1 — Imports & basic config
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown

# Figure resolution for every plot in this notebook (later cells may override it)
plt.rcParams["figure.dpi"] = 140
# Keep attrs (units etc.) through xarray operations for the whole session.
# The trailing ";" stops the set_options object's repr from being echoed as cell output.
xr.set_options(keep_attrs=True);


<xarray.core.options.set_options at 0x7052ee9ef050>

In [34]:
# Cell 3 — Helpers (normalization, grid, time, masking)

def _to_0360(obj):
    if "longitude" in obj.coords:
        lon = obj["longitude"]
        try:
            if float(lon.min()) < 0:
                obj = obj.assign_coords(longitude=(lon % 360))
        except Exception:
            pass
        obj = obj.sortby("longitude")
    return obj

def _ensure_lat_asc(obj):
    if "latitude" in obj.coords:
        lat = obj["latitude"].values
        if len(lat) > 1 and lat[0] > lat[-1]:
            obj = obj.sortby("latitude")
    return obj

def _normalize_precip(da):
    units = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    out = da
    if units in ["m", "meter", "metre", "m of water equivalent"]:
        out = out * 1000.0
    out.attrs["units"] = "mm/day"
    return out

def _normalize_temp(da):
    units = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    out = da
    if units in ["k", "kelvin"]:
        out = out - 273.15
        out.attrs["units"] = "°C"
    elif units in ["degc", "c", "°c"]:
        out.attrs["units"] = "°C"
    return out

def _coerce_time(da):
    """Return ``da`` if it has a ``time`` dimension, otherwise try a CF decode.

    Best effort: any failure in the decode path returns ``da`` unchanged.
    NOTE(review): ``xr.decode_cf`` decodes values/attributes and is not
    expected to create a new ``time`` dimension, so this fallback will rarely
    change anything. Confirm whether a ``valid_time`` -> ``time`` rename was
    the real intent.
    """
    if "time" in da.dims:
        return da
    try:
        # Select the variable back out by name. ``Dataset.to_array("_tmp")``
        # treats "_tmp" as a new *dimension* and would add a length-1 axis.
        dec = xr.decode_cf(da.to_dataset(name="_tmp"))["_tmp"]
        if "time" in dec.dims:
            dec.name = da.name
            return dec
    except Exception:
        pass
    return da

def _align_to_truth_grid(src: xr.DataArray, truth: xr.DataArray, tolerance: float = 0.125) -> xr.DataArray:
    """Put ``src`` on the latitude/longitude grid of ``truth``.

    Both inputs are first normalised to 0..360 longitudes and ascending
    latitudes. The strategy is:

    1. If the lat/lon labels are already identical, return ``src`` (normalised).
    2. Nearest-label reindex on the *spatial* dims only. Cells further than
       ``tolerance`` degrees from any source point become NaN.
    3. Fall back to nearest-neighbour ``interp`` (e.g. if the index is
       non-unique).

    Other dims (such as ``time``) are left untouched.
    """
    s = _ensure_lat_asc(_to_0360(src))
    t = _ensure_lat_asc(_to_0360(truth))
    # identical labels?
    try:
        if (np.array_equal(s.latitude.values, t.latitude.values) and
                np.array_equal(s.longitude.values, t.longitude.values)):
            return s
    except Exception:
        pass
    # Nearest-label reindex with a scalar tolerance. The old
    # reindex_like(..., tolerance={dict}) always raised: tolerance must be a
    # number, and reindex_like would also have tried to conform `time`.
    try:
        return s.reindex(latitude=t["latitude"].values, longitude=t["longitude"].values,
                         method="nearest", tolerance=tolerance)
    except Exception:
        pass
    # fallback interp
    s2 = s.sortby(["latitude", "longitude"])
    t2 = t.sortby(["latitude", "longitude"])
    return s2.interp(latitude=t2["latitude"], longitude=t2["longitude"], method="nearest")

def _time_intersect(*arrs):
    t0 = max([np.datetime64(a.time.min().values) for a in arrs])
    t1 = min([np.datetime64(a.time.max().values) for a in arrs])
    return [a.sel(time=slice(t0, t1)) for a in arrs]

def _land_mask_from_truth(truth_da: xr.DataArray, dim: str = "time") -> xr.DataArray:
    """Boolean mask: True where ``truth_da`` has *any* finite value along ``dim``.

    Used to drop ocean / out-of-country cells, which never hold finite truth
    values in the window. ``dim`` defaults to ``"time"``, as before.
    """
    # np.isfinite dispatches through DataArray.__array_ufunc__ on every xarray
    # version. The xr.ufuncs namespace is missing from some releases.
    return np.isfinite(truth_da).any(dim)

def _rmse_series(pred, truth):
    """Root-mean-square error of ``pred`` vs ``truth``, averaged over lat/lon.

    Only the spatial dims that are present are reduced, and NaNs are skipped.
    Length-1 dimensions are squeezed out of the result.
    """
    sq_err = (pred - truth) ** 2
    spatial = [dim for dim in ("latitude", "longitude") if dim in sq_err.dims]
    mse = sq_err.mean(dim=spatial, skipna=True)
    return np.sqrt(mse).squeeze()

def _mae_series(pred, truth):
    """Mean absolute error of ``pred`` vs ``truth``, averaged over lat/lon.

    Only the spatial dims that are present are reduced, and NaNs are skipped.
    Length-1 dimensions are squeezed out of the result.
    """
    # xarray.DataArray has no ``.abs()`` method (that is a pandas API).
    # np.abs dispatches via __array_ufunc__ and returns a DataArray.
    abs_err = np.abs(pred - truth)
    spatial = [d for d in ("latitude", "longitude") if d in abs_err.dims]
    return abs_err.mean(dim=spatial, skipna=True).squeeze()

def _bias_map(pred, truth):
    """Time-mean bias map ``mean_t(pred - truth)``, skipping NaNs.

    Expects a ``time`` dimension (typically ``(time, lat, lon)``). The result
    keeps the remaining dims.
    """
    error = pred - truth
    return error.mean("time", skipna=True)


In [36]:
def process_step(output_state, runcount):
    """Turn one model output state into a single-step gridded ``xr.Dataset``.

    Each entry of ``output_state['fields']`` is reshaped to a column and
    combined with ``TFM_N320_LATLON``, then reshaped to ``(721, 1440)`` and
    stored as float32 on dims ``("lat", "lon")``. NOTE(review):
    ``TFM_N320_LATLON`` is presumably an interpolation matrix from the native
    model grid to the regular lat/lon grid; it is defined outside the visible
    cells, so confirm.

    The returned dataset gains a leading ``step`` dim whose single label is
    ``int(runcount)``.
    """
    logging.info(f"Processing step {runcount}")
    data_vars = {}
    for name, flat_field in output_state['fields'].items():
        gridded = (TFM_N320_LATLON * flat_field.reshape(-1, 1)).reshape(721, 1440)
        data_vars[name] = (["lat", "lon"], gridded.astype(np.float32))

    step_ds = xr.Dataset(data_vars, coords={"lat": LATITUDES, "lon": LONGITUDES})
    step_ds = step_ds.expand_dims('step')
    step_ds['step'] = [int(runcount)]
    return step_ds

In [37]:
# ---- PATCHED run_inference: include save_vars in output filename ----
def run_inference(init_date=None, lead_time=360, save_vars=None):
    """Run the forecast model and cache the result as a Zarr store.

    Parameters
    ----------
    init_date : datetime-like or None
        ERA5 initial-condition date. ``None`` means "start from the latest IFS
        analysis" (``get_latest_IFS_data()``). In that case the file name uses
        the analysis's own ``date``.
    lead_time : int
        Forecast length in hours; must be a multiple of 6 and at least 6.
    save_vars : list of str, str, or None
        Output fields to keep. ``None`` keeps every field. A single string is
        treated as a one-element list.

    Returns
    -------
    xarray.Dataset
        Either the freshly computed forecast (leading ``time`` dim = initial
        date, then ``step`` in hours) or the existing store opened with
        ``xr.open_zarr`` when a file with the same name already exists.

    Raises
    ------
    ValueError
        If ``lead_time`` is not a positive multiple of 6.

    Notes
    -----
    Relies on names defined outside the visible cells: ``logging``, ``os``,
    ``time``, ``OUTPUT_STATE_PATH``, ``SimpleRunner``, ``CHECKPOINT``,
    ``get_latest_IFS_data``, ``get_ERA5`` and ``print_state``.
    """
    if lead_time < 6 or lead_time % 6 != 0:
        raise ValueError("Lead time must be a multiple of 6 hours and at least 6 hours.")

    if isinstance(save_vars, str):
        # "-".join("2t") would give "2-t" and the field loop would pick characters
        save_vars = [save_vars]

    # lead_time is in hours (120 h = 20 six-hour steps)
    if save_vars is None and lead_time > 120:
        logging.warning("Running this model for a lead time of more than 120 hours while saving all variables is not recommended.")

    ic_src = "IFS" if init_date is None else "ERA5"

    # encode which vars are saved so different runs don't overwrite each other
    vars_tag = "ALL" if save_vars is None else "-".join(save_vars)

    # The IFS path has no init_date (the previous ``init_date.strftime`` raised
    # AttributeError on None). Fetch the analysis first and name the file by
    # its own date. ERA5 runs can still hit the cache before any download.
    input_state = None
    if ic_src == "IFS":
        input_state = get_latest_IFS_data()
        init_stamp = pd.to_datetime(input_state['date'])
    else:
        init_stamp = pd.to_datetime(init_date)
    save_path = (f"{OUTPUT_STATE_PATH}/init_{ic_src}_{init_stamp.strftime('%Y%m%dT%H')}"
                 f"_lead_{lead_time}_vars_{vars_tag}.zarr")

    if os.path.exists(save_path):
        logging.info(f"Output file {save_path} already exists. Skipping inference.")
        logging.info("Loading existing dataset...")
        return xr.open_zarr(save_path)

    # build input state (already fetched for IFS)
    if input_state is None:
        input_state = get_ERA5(init_date)

    runner = SimpleRunner(CHECKPOINT, device="cuda")
    current_step = 0
    start_time = time.perf_counter()
    print("Starting the inference session...")
    current_step_time = time.perf_counter()
    states = []
    for state in runner.run(input_state=input_state, lead_time=lead_time):
        print_state(state)
        current_step += 6
        if save_vars is None:
            processed_state = process_step(state, current_step)
        else:
            selected_data = {
                'date': state['date'],
                'fields': {var: state['fields'][var] for var in save_vars},
                'latitudes': state['latitudes'],
                'longitudes': state['longitudes'],
            }
            processed_state = process_step(selected_data, current_step)
        states.append(processed_state)
        logging.info(f"Step {current_step} completed.")
        step_time = time.perf_counter()
        logging.info(f"Time taken for step {current_step}: {step_time - current_step_time:.2f} s.")
        current_step_time = step_time

    logging.info("Inference session completed.")
    logging.info(f"Total time: {time.perf_counter() - start_time:.2f} s.")

    logging.info("Concatenating all steps into a single dataset.")
    ds = xr.concat(states, dim='step')
    del states
    ds = ds.expand_dims("time")
    ds["time"] = [pd.to_datetime(input_state['date'])]

    logging.info(f"Saving output dataset to {save_path}")
    ds.to_zarr(save_path, mode='w', zarr_format=2)

    return ds


In [38]:
# PATCH: replace _rmse_mae_series to avoid .abs() (use np.abs)
import numpy as np
import xarray as xr

SPATIAL_DIMS = ("latitude", "longitude")

def _rmse_mae_series(pred: xr.DataArray, truth: xr.DataArray):
    """Spatially averaged RMSE and MAE time series of ``pred`` against ``truth``.

    Both inputs are expected as ``(time, latitude, longitude)`` and already
    masked to the truth coverage. The difference is cast to float32 (some
    sources store integers). Each dim in ``SPATIAL_DIMS`` that is present is
    averaged with NaNs skipped.

    Returns ``(rmse_ts, mae_ts)``, each squeezed to a 1-D time series.
    """
    err = (pred - truth).astype("float32")
    reduce_over = [d for d in SPATIAL_DIMS if d in err.dims]

    rmse_ts = (err ** 2).mean(reduce_over, skipna=True) ** 0.5
    mae_ts = np.abs(err).mean(reduce_over, skipna=True)   # np.abs: DataArray has no .abs()
    return rmse_ts.squeeze(), mae_ts.squeeze()

print("✅ Patched _rmse_mae_series (uses np.abs). Re-run your time-series now.")


✅ Patched _rmse_mae_series (uses np.abs). Re-run your time-series now.


In [None]:
# --- Cell 0: Load datasets via widgets (BMD optional) ---
import xarray as xr
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, Markdown

def _open_any(path: str):
    """Open ``path`` as an xarray Dataset, using the Zarr reader for ``*.zarr`` stores.

    A trailing slash is ignored when detecting the ``.zarr`` suffix, because
    paths pasted from a bucket browser often end in ``/``. Every other path
    goes through ``xr.open_dataset`` and its engine auto-detection.
    """
    if path.rstrip("/").endswith(".zarr"):
        return xr.open_zarr(path)
    return xr.open_dataset(path)

# Known defaults (edit if you have different ERA5/IMERG/IMD)
DEFAULT_ERA5 = "gs://aim4scale_training_25/ground_truth/era5_24hr.zarr"
DEFAULT_IMERG = "gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr"
DEFAULT_IMD = "gs://aim4scale_training_25/ground_truth/IMD_rainfall_0p25.zarr"
DEFAULT_BMD = "gs://aim4scale_training_25/ground_truth/BMD_daily_combined_0p25.zarr"

# One full-width path box per source, each with its own Layout, pre-filled with the defaults
bmd_txt = widgets.Text(description="BMD:", value=DEFAULT_BMD, layout=widgets.Layout(width="95%"))
era5_txt = widgets.Text(description="ERA5:", value=DEFAULT_ERA5, layout=widgets.Layout(width="95%"))
imerg_txt = widgets.Text(description="IMERG:", value=DEFAULT_IMERG, layout=widgets.Layout(width="95%"))
imd_txt = widgets.Text(description="IMD:", value=DEFAULT_IMD, layout=widgets.Layout(width="95%"))
btn_load = widgets.Button(button_style="primary", description="Load datasets")
out_load = widgets.Output()
display(bmd_txt, era5_txt, imerg_txt, imd_txt, btn_load, out_load)

# Globals filled in by the "Load datasets" callback (None until loaded or on failure)
ds_bmd = ds_era5 = ds_imerg = ds_imd = None

def _summ(ds, name):
    vars_ = list(ds.data_vars)[:6]
    tcoord = "valid_time" if "valid_time" in ds.coords else ("time" if "time" in ds.coords else None)
    tspan = "—"
    if tcoord:
        tspan = f"{pd.to_datetime(str(ds[tcoord].values.min())).date()} … {pd.to_datetime(str(ds[tcoord].values.max())).date()}"
    return f"**{name}**  \nvars → {vars_} | sizes → {dict(ds.sizes)}  \n" + (f"time → {tspan}" if tcoord else "")

def _try_open(label, path):
    p = path.strip()
    if p == "":
        return None, f"**{label}**: (skipped — empty path)"
    try:
        ds = _open_any(p)
        return ds, _summ(ds, label)
    except Exception as e:
        return None, f"**{label}**: ❌ load failed → `{e}`"

def _on_load(_):
    """Button callback: open every configured source and report what loaded.

    Rebinds the module globals ``ds_bmd``, ``ds_era5``, ``ds_imerg`` and
    ``ds_imd`` (``None`` for skipped or failed sources). It then renders a
    Markdown summary, plus a readiness note, into ``out_load``.
    """
    global ds_bmd, ds_era5, ds_imerg, ds_imd
    with out_load:
        out_load.clear_output()
        boxes = (("BMD", bmd_txt), ("ERA5", era5_txt), ("IMERG", imerg_txt), ("IMD", imd_txt))
        opened = [_try_open(label, box.value) for label, box in boxes]
        ds_bmd, ds_era5, ds_imerg, ds_imd = (ds for ds, _msg in opened)

        display(Markdown("### Load summary"))
        for _ds, msg in opened:
            display(Markdown(msg))

        # ERA5 is mandatory; any one truth/reference source is enough
        have_reference = any(d is not None for d in (ds_imerg, ds_bmd, ds_imd))
        if ds_era5 is not None and have_reference:
            display(Markdown("✅ **Datasets ready.** Now run the next cells."))
        else:
            display(Markdown("⚠️ **Please load at least ERA5 plus one of BMD/IMERG/IMD.**"))

btn_load.on_click(_on_load)


Text(value='', description='BMD (opt):', layout=Layout(width='95%'))

Text(value='gs://aim4scale_training_25/ground_truth/era5_24hr.zarr', description='ERA5:', layout=Layout(width=…

Text(value='gs://aim4scale_training_25/ground_truth/IMERG_0p25_2000_2025.zarr', description='IMERG:', layout=L…

Text(value='gs://aim4scale_training_25/ground_truth/IMD_rainfall_0p25.zarr', description='IMD:', layout=Layout…

Button(button_style='primary', description='Load datasets', style=ButtonStyle())

Output()

In [29]:
# --- ERA5 loader/merger: combine precip + temperature daily into a single ds_era5 ---

import xarray as xr
import numpy as np
import pandas as pd
from IPython.display import Markdown, display

# Adjust paths if yours are different
PATH_ERA5_TP  = "gs://aim4scale_training_25/ground_truth/era5_24hr.zarr"            # precip (total_precipitation_24hr)
PATH_ERA5_T2M = "gs://aim4scale_training_25/ground_truth/era5_t2m_1D_1981_2024.zarr" # tavg, tmax, tmin

def _open_any(path: str):
    """Open ``path`` with ``xr.open_zarr`` if it ends in ``.zarr``, otherwise with ``xr.open_dataset``."""
    if path.endswith(".zarr"):
        return xr.open_zarr(path)
    return xr.open_dataset(path)

def _to_0360(ds):
    if "longitude" in ds.coords:
        lon = ds["longitude"]
        try:
            if float(lon.min()) < 0:
                ds = ds.assign_coords(longitude=(lon % 360))
        except Exception:
            pass
        ds = ds.sortby("longitude")
    return ds

def _lat_asc(ds):
    if "latitude" in ds.coords:
        lat = ds["latitude"].values
        if len(lat) > 1 and lat[0] > lat[-1]:
            ds = ds.sortby("latitude")
    return ds

def _prep(ds):
    """Normalise a dataset's grid: 0..360 sorted longitudes, then ascending latitudes."""
    wrapped = _to_0360(ds)
    return _lat_asc(wrapped)

def _time_bounds(ds):
    if "time" not in ds.coords:
        return "—", "—"
    # robust formatting that works even with cftime/dtypes
    tmin = pd.to_datetime(str(ds.time.min().values)).date()
    tmax = pd.to_datetime(str(ds.time.max().values)).date()
    return str(tmin), str(tmax)

# Load & normalize
era5_tp, era5_t2m = (_prep(_open_any(p)) for p in (PATH_ERA5_TP, PATH_ERA5_T2M))

# Outer join on the shared indexes (notably time) so dates present in either store survive.
# NOTE(review): the recorded summary reports 67208 time steps for 1979-01-01 … 2024-12-31
# (exactly 4 per day), so the daily t2m variables are presumably NaN at the
# sub-daily timestamps that only the precip store has — confirm this is intended.
ds_era5 = xr.merge([era5_tp, era5_t2m], compat="override", join="outer")

# Keep only the variables the comparison cells look for
keep_vars = [v for v in ("total_precipitation_24hr", "tavg", "tmax", "tmin") if v in ds_era5.data_vars]
ds_era5 = ds_era5[keep_vars]

# Summary
vars_list = ", ".join(ds_era5.data_vars)
t0, t1 = _time_bounds(ds_era5)
summary = f"vars: {vars_list} | sizes: {dict(ds_era5.sizes)} | time: {t0} … {t1}"
display(Markdown("**ERA5 merged dataset**  \n" + summary))


**ERA5 merged dataset**  
vars: total_precipitation_24hr, tavg, tmax, tmin | sizes: {'time': 67208, 'latitude': 721, 'longitude': 1440} | time: 1979-01-01 … 2024-12-31

In [30]:
# === BMD truth vs ERA5 & IMERG — plot ALL variables (precip / tavg / tmax) ===
import numpy as np, xarray as xr, matplotlib.pyplot as plt, ipywidgets as widgets
from IPython.display import display, Markdown
plt.rcParams.update({"figure.dpi": 120})

# ---------------- helpers ----------------
def _norm_precip_mm(da):
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    if u in ["m","meter","metre","m of water equivalent"]:
        da = da * 1000.0
    # kg m-2 ≈ mm already
    da.attrs["units"] = "mm"
    return da

def _norm_temp_K(da):
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    if u in ["c","degc","celsius","°c"]:
        da = da + 273.15
    da.attrs["units"] = "K"
    return da

def _norm(var_key, da):
    """Dispatch to the unit normaliser: mm for ``"precip"``, Kelvin for everything else."""
    if var_key == "precip":
        return _norm_precip_mm(da)
    return _norm_temp_K(da)

def _ensure_valid_time(da):
    """Return ``da`` with its time dimension named ``valid_time``.

    * Already has ``valid_time``: returned as-is.
    * Has ``time``: renamed to ``valid_time``.
    * Otherwise: CF-decode and check again. NOTE(review): ``decode_cf`` is not
      expected to create a new dimension, so this branch will usually end in
      the ``ValueError``.

    Raises
    ------
    ValueError
        If no time dimension can be found.
    """
    if "valid_time" in da.dims:
        return da
    if "time" in da.dims:
        return da.rename({"time": "valid_time"})
    # Select the variable back out by name. ``Dataset.to_array("_tmp")`` would
    # instead add a new length-1 dimension called "_tmp".
    dec = xr.decode_cf(da.to_dataset(name="_tmp"), use_cftime=False)["_tmp"]
    if "time" in dec.dims:
        dec.name = da.name
        return dec.rename({"time": "valid_time"})
    raise ValueError("No time/valid_time dimension.")

def _align_to_truth(src, truth, tolerance=0.125):
    """Put ``src`` on ``truth``'s latitude/longitude grid (spatial dims only).

    Both inputs are sorted by lat/lon. The primary path is a nearest-label
    ``reindex`` in which truth cells further than ``tolerance`` degrees from
    any source point become NaN. If that fails (e.g. a non-unique index), the
    fallback is nearest-neighbour ``interp``.

    The time axis is not touched; arithmetic with ``truth`` later aligns it
    by label.
    """
    s = src.sortby(["latitude", "longitude"])
    r = truth.sortby(["latitude", "longitude"])
    try:
        # Scalar tolerance on spatial dims only. The old
        # reindex_like(..., tolerance={dict}) always raised (tolerance must be
        # numeric) and would also have tried to conform valid_time.
        return s.reindex(latitude=r["latitude"].values, longitude=r["longitude"].values,
                         method="nearest", tolerance=tolerance)
    except Exception:
        return s.interp(latitude=r.latitude, longitude=r.longitude, method="nearest")

def _rmse_mae_series(pred, truth):
    """Return ``(rmse, mae)`` of ``pred`` vs ``truth``, each averaged over lat/lon.

    The error is cast to float32. Only the spatial dims that are present are
    reduced (NaNs skipped), and both results are squeezed.
    """
    err = (pred - truth).astype("float32")
    spatial = [d for d in ("latitude", "longitude") if d in err.dims]
    rmse = (err ** 2).mean(spatial, skipna=True) ** 0.5
    mae = np.abs(err).mean(spatial, skipna=True)
    return rmse.squeeze(), mae.squeeze()

def _mask_to_truth(truth, *others, dim="valid_time"):
    """Mask ``truth`` and every array in ``others`` to the truth's data coverage.

    A grid cell is kept if ``truth`` has at least one finite value along
    ``dim`` in the window ("ever finite"). This is how oceans and
    out-of-country cells are masked. Returns the list
    ``[truth_masked, *others_masked]``.
    """
    # np.isfinite dispatches through DataArray.__array_ufunc__ on every
    # xarray version; the xr.ufuncs namespace is missing from some releases.
    mask = np.isfinite(truth).any(dim)
    return [truth.where(mask)] + [other.where(mask) for other in others]

# ---------------- auto-detect concrete var names ----------------
# Logical variable -> candidate names, checked in order (first match wins)
ALIASES = dict(
    precip=["total_precipitation_24hr", "tp_24h", "tp_daily", "precip", "rain"],
    tavg=["tavg", "t2m_mean", "t2m_daily_mean", "daily_mean_temperature", "tas_mean", "tmean"],
    tmax=["tmax", "t2m_max", "t2m_daily_max", "daily_max_temperature", "tasmax", "tmax_mean"],
)
# Human-readable headings per logical variable
LABELS = dict(precip="Precip (mm/day)", tavg="Temp avg (K)", tmax="Temp max (K)")

def _first_present(ds, names):
    if ds is None: 
        return None
    for n in names:
        if n in ds.data_vars:
            return n
    return None

def _autoguess_mapping(ds_bmd, ds_era5, ds_imerg=None):
    """Guess the concrete variable name of each logical variable in each source.

    For every key of ``ALIASES`` the first alias present in each dataset is
    chosen (``None`` if nothing matches or the dataset is missing).

    Returns ``(mapping, note)``. ``mapping[key]`` is
    ``{"bmd": ..., "era5": ..., "imerg": ...}`` and ``note`` is a Markdown
    bullet list of the choices.
    """
    guessed = {}
    report = []
    for key, aliases in ALIASES.items():
        bmd_name = _first_present(ds_bmd, aliases)
        era5_name = _first_present(ds_era5, aliases)
        imerg_name = None if ds_imerg is None else _first_present(ds_imerg, aliases)
        guessed[key] = {"bmd": bmd_name, "era5": era5_name, "imerg": imerg_name}
        report.append(f"- **{key}** → BMD=`{bmd_name}` | ERA5=`{era5_name}` | IMERG=`{imerg_name}`")
    return guessed, "\n".join(report)

mapping, mapping_note = _autoguess_mapping(
    ds_bmd,
    ds_era5,
    ds_imerg if "ds_imerg" in globals() else None,
)
display(Markdown("**Auto-detected variables (first alias match):**\n" + mapping_note))

# Optional overrides, pre-filled with the detected names (blank when nothing was found)
txt_bmd_precip = widgets.Text(description="BMD precip:", value=mapping["precip"]["bmd"] or "")
txt_era_precip = widgets.Text(description="ERA5 precip:", value=mapping["precip"]["era5"] or "")
txt_bmd_tavg = widgets.Text(description="BMD tavg:", value=mapping["tavg"]["bmd"] or "")
txt_era_tavg = widgets.Text(description="ERA5 tavg:", value=mapping["tavg"]["era5"] or "")
txt_bmd_tmax = widgets.Text(description="BMD tmax:", value=mapping["tmax"]["bmd"] or "")
txt_era_tmax = widgets.Text(description="ERA5 tmax:", value=mapping["tmax"]["era5"] or "")
display(
    Markdown("**(Optional) Override detected variable names** — leave blank to keep auto:"),
    txt_bmd_precip, txt_era_precip,
    txt_bmd_tavg, txt_era_tavg,
    txt_bmd_tmax, txt_era_tmax,
)

def _current_mapping():
    """Combine the override text boxes with the auto-detected ``mapping``.

    A non-empty override wins; otherwise the detected name is kept. IMERG
    names are never overridden (IMERG is optional; only BMD and ERA5 are
    needed for a variable to be plottable).
    """
    boxes = {
        "precip": (txt_bmd_precip, txt_era_precip),
        "tavg":   (txt_bmd_tavg,   txt_era_tavg),
        "tmax":   (txt_bmd_tmax,   txt_era_tmax),
    }
    resolved = {}
    for key, (bmd_box, era_box) in boxes.items():
        auto = mapping[key]
        resolved[key] = {
            "bmd": bmd_box.value or auto["bmd"],
            "era5": era_box.value or auto["era5"],
            "imerg": auto["imerg"],
        }
    return resolved

def _plottable_vars(curmap):
    # Only keep variables present in BOTH BMD and ERA5
    out = []
    for key in ["precip", "tavg", "tmax"]:
        if curmap[key]["bmd"] and curmap[key]["era5"]:
            out.append(key)
    return out

# ---------------- UI ----------------
# Controls for the BMD-truth comparison: metric, inclusive date window, run button, output area
metric_dd = widgets.Dropdown(options=["RMSE", "MAE"], description="Metric:", value="RMSE")
t0_txt = widgets.Text(description="Start (YYYY-MM-DD):", value="2018-06-01")
t1_txt = widgets.Text(description="End   (YYYY-MM-DD):", value="2018-06-30")
btn_run = widgets.Button(button_style="success", description="Run ALL (BMD truth)")
out_all = widgets.Output()
display(metric_dd, t0_txt, t1_txt, btn_run, out_all)

# ---------------- main action ----------------
def _one_variable_plot(key, cm, metric, t0, t1):
    """Plot the daily spatial RMSE/MAE of ERA5 (and IMERG, if available) against BMD truth.

    Parameters
    ----------
    key : str
        Logical variable (``"precip"``, ``"tavg"`` or ``"tmax"``). It selects
        the unit normaliser and the labels.
    cm : dict
        ``key -> {"bmd": name, "era5": name, "imerg": name}``, as built by
        ``_current_mapping``.
    metric : str
        ``"RMSE"``; anything else plots MAE.
    t0, t1 : numpy.datetime64
        Inclusive analysis window.

    Reads the globals ``ds_bmd``, ``ds_era5`` and, optionally, ``ds_imerg``.
    Problems (missing dataset or variable, empty window, no overlap) are
    reported as Markdown in the current output instead of raising.
    """
    vb, ve, vi = cm[key]["bmd"], cm[key]["era5"], cm[key]["imerg"]
    unit = "mm/day" if key == "precip" else "K"
    ylabel = "RMSE" if metric == "RMSE" else "MAE"

    def _window(da):
        # Rename time -> valid_time *before* slicing, so sources indexed by
        # either name are windowed the same way.
        return _ensure_valid_time(_norm(key, da)).sel(valid_time=slice(t0, t1))

    def _metric_series(pred, truth):
        rmse_ts, mae_ts = _rmse_mae_series(pred, truth)
        return rmse_ts if metric == "RMSE" else mae_ts

    # sanity: both required sources must be loaded and contain the variable
    for src, ds, vn in [("BMD", ds_bmd, vb), ("ERA5", ds_era5, ve)]:
        if ds is None:
            display(Markdown(f"❌ **{src} dataset not loaded** (needed for `{key}`)"))
            return
        if vn is None or vn not in ds.data_vars:
            display(Markdown(f"❌ **{src} variable not found**: `{vn}` for `{key}`"))
            return

    imerg_ds = globals().get("ds_imerg")
    has_imerg = (imerg_ds is not None) and (vi is not None) and (vi in imerg_ds.data_vars)

    truth = _window(ds_bmd[vb])
    if truth.sizes.get("valid_time", 0) == 0:
        display(Markdown(f"⚠️ **No BMD data** for {LABELS[key]} in this window."))
        return
    era5 = _window(ds_era5[ve])
    if era5.sizes.get("valid_time", 0) == 0:
        display(Markdown(f"⚠️ **No ERA5 data** for {LABELS[key]} in this window."))
        return
    imerg = _window(imerg_ds[vi]) if has_imerg else None
    if imerg is not None and imerg.sizes.get("valid_time", 0) == 0:
        imerg = None  # IMERG is optional: drop its panel instead of aborting

    era5_on = _align_to_truth(era5, truth)
    imerg_on = _align_to_truth(imerg, truth) if imerg is not None else None

    # common overlap of every participating source
    sources = [truth, era5_on] + ([imerg_on] if imerg_on is not None else [])
    tmin = max(np.datetime64(a.valid_time.min().values) for a in sources)
    tmax = min(np.datetime64(a.valid_time.max().values) for a in sources)

    truth_c = truth.sel(valid_time=slice(tmin, tmax))
    era5_c = era5_on.sel(valid_time=slice(tmin, tmax))
    imerg_c = imerg_on.sel(valid_time=slice(tmin, tmax)) if imerg_on is not None else None
    if truth_c.sizes.get("valid_time", 0) == 0:
        display(Markdown(f"⚠️ **No overlapping days** for {LABELS[key]}."))
        return

    # ocean mask by BMD coverage
    if imerg_c is not None:
        truth_m, era5_m, imerg_m = _mask_to_truth(truth_c, era5_c, imerg_c)
    else:
        truth_m, era5_m = _mask_to_truth(truth_c, era5_c)
        imerg_m = None

    s_era5 = _metric_series(era5_m, truth_m)
    s_imerg = _metric_series(imerg_m, truth_m) if imerg_m is not None else None

    nrows = 2 if s_imerg is not None else 1
    fig, axs = plt.subplots(nrows, 1, figsize=(10, 6 if nrows == 2 else 3), sharex=True)

    ax0 = axs[0] if nrows == 2 else axs
    s_era5.rename({"valid_time": "date"}).plot(ax=ax0)
    # Title shows the ERA5 variable actually plotted (previously the BMD name / a hard-coded one)
    ax0.set_title(f"{ylabel} — ERA5 vs BMD — {ve}")
    ax0.set_ylabel(unit)
    ax0.grid(True, alpha=0.3)

    if s_imerg is not None:
        axs[1].plot(s_imerg["valid_time"].values, s_imerg.values)
        axs[1].set_title(f"{ylabel} — IMERG vs BMD — {vi}")
        axs[1].set_ylabel(unit)
        axs[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # window means of the plotted series (report failures instead of hiding them)
    try:
        mean_era5 = float(s_era5.mean().values)
        if s_imerg is not None:
            mean_imerg = float(s_imerg.mean().values)
            display(Markdown(f"**Window mean {ylabel}** — ERA5: `{mean_era5:.3f} {unit}` | IMERG: `{mean_imerg:.3f} {unit}`"))
        else:
            display(Markdown(f"**Window mean {ylabel}** — ERA5: `{mean_era5:.3f} {unit}` (IMERG not available)"))
    except Exception as exc:
        display(Markdown(f"⚠️ Could not compute window means: `{exc}`"))

def _run_all(_, metric_w=metric_dd, t0_w=t0_txt, t1_w=t1_txt):
    """Button callback: plot every variable present in both BMD and ERA5.

    The metric and date widgets are bound as default arguments *when this cell
    runs*. The IMD-truth cell also assigns the globals ``metric_dd``,
    ``t0_txt`` and ``t1_txt``, so reading the globals at click time would
    silently use that cell's controls instead of the ones displayed here.

    Parameters
    ----------
    _ : ipywidgets.Button
        The clicked button (unused).
    metric_w, t0_w, t1_w
        Metric dropdown and start/end date boxes; default to this cell's widgets.
    """
    with out_all:
        out_all.clear_output()
        cm = _current_mapping()
        keys = _plottable_vars(cm)
        if not keys:
            display(Markdown("❌ No variables available in **both** BMD and ERA5. Adjust overrides above."))
            return
        metric = metric_w.value
        try:
            t0 = np.datetime64(t0_w.value.strip())
            t1 = np.datetime64(t1_w.value.strip())
        except ValueError as exc:
            display(Markdown(f"❌ **Invalid date** (expected YYYY-MM-DD): `{exc}`"))
            return
        # np.datetime64("") parses to NaT instead of raising
        if np.isnat(t0) or np.isnat(t1) or t0 > t1:
            display(Markdown(f"❌ **Invalid window**: `{t0}` → `{t1}`"))
            return
        display(Markdown(f"**Running** {', '.join(LABELS[k] for k in keys)}  \n"
                         f"Metric: `{metric}` — Window: `{str(t0)}` → `{str(t1)}`"))
        for k in keys:
            display(Markdown(f"### {LABELS[k]}"))
            _one_variable_plot(k, cm, metric, t0, t1)

btn_run.on_click(_run_all)


**Auto-detected variables (first alias match):**
- **precip** → BMD=`total_precipitation_24hr` | ERA5=`total_precipitation_24hr` | IMERG=`total_precipitation_24hr`
- **tavg** → BMD=`tavg` | ERA5=`tavg` | IMERG=`None`
- **tmax** → BMD=`tmax` | ERA5=`tmax` | IMERG=`None`

**(Optional) Override detected variable names** — leave blank to keep auto:

Text(value='total_precipitation_24hr', description='BMD precip:')

Text(value='total_precipitation_24hr', description='ERA5 precip:')

Text(value='tavg', description='BMD tavg:')

Text(value='tavg', description='ERA5 tavg:')

Text(value='tmax', description='BMD tmax:')

Text(value='tmax', description='ERA5 tmax:')

Dropdown(description='Metric:', options=('RMSE', 'MAE'), value='RMSE')

Text(value='2018-06-01', description='Start (YYYY-MM-DD):')

Text(value='2018-06-30', description='End   (YYYY-MM-DD):')

Button(button_style='success', description='Run ALL (BMD truth)', style=ButtonStyle())

Output()

In [32]:
# === Cell 2: IMD truth — ERA5 & IMERG (precip only), multi-region ===
import numpy as np, xarray as xr, matplotlib.pyplot as plt, ipywidgets as widgets
from IPython.display import display
plt.rcParams.update({"figure.dpi": 120})

VAR = "total_precipitation_24hr"

# name -> lat/lon slices; longitudes use the 0..360 convention
REGIONS = {
    name: {"latitude": slice(*lat), "longitude": slice(*lon)}
    for name, lat, lon in (
        ("Global",     (-90, 90),      (0, 360)),
        ("Ethiopia",   (3.4, 14.9),    (33.0, 48.0)),
        ("Nigeria",    (4.0, 14.7),    (2.7, 14.7)),
        ("Kenya",      (-4.7, 5.0),    (33.9, 41.9)),
        ("Bangladesh", (20.7, 26.7),   (88.0, 92.7)),
        ("Chile",      (-56.0, -17.5), (284.0, 294.0)),
    )
}

def _norm_mm(da):
    u = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    if u in ["m","meter","metre","m of water equivalent"]:
        da = da * 1000.0
    da.attrs["units"] = "mm"
    return da

def _ensure_valid_time(da):
    """Return ``da`` with its time dimension named ``valid_time``.

    ``valid_time`` is kept as-is and ``time`` is renamed. Otherwise a CF
    decode is attempted first. NOTE(review): ``decode_cf`` is not expected to
    add a dimension, so that path normally ends in the error below.

    Raises
    ------
    ValueError
        If no time dimension can be found.
    """
    if "valid_time" in da.dims:
        return da
    if "time" in da.dims:
        return da.rename({"time": "valid_time"})
    # ["_tmp"] selects the variable; .to_array("_tmp") would add a "_tmp" dimension
    dec = xr.decode_cf(da.to_dataset(name="_tmp"), use_cftime=False)["_tmp"]
    if "time" in dec.dims:
        dec.name = da.name
        return dec.rename({"time": "valid_time"})
    raise ValueError("No time/valid_time dimension.")

def _apply_region(ds, region):
    lat_lo = min(region["latitude"].start,  region["latitude"].stop)
    lat_hi = max(region["latitude"].start,  region["latitude"].stop)
    lon_lo = min(region["longitude"].start, region["longitude"].stop)
    lon_hi = max(region["longitude"].start, region["longitude"].stop)
    return ds.sel(latitude=slice(lat_lo, lat_hi), longitude=slice(lon_lo, lon_hi))

def _align_to_truth(src, truth, tolerance=0.125):
    """Put ``src`` on ``truth``'s latitude/longitude grid (spatial dims only).

    Both inputs are sorted by lat/lon. The primary path is a nearest-label
    ``reindex`` that leaves NaN wherever no source point lies within
    ``tolerance`` degrees. If that raises, nearest-neighbour ``interp`` is
    used instead. Time is left for label alignment in later arithmetic.
    """
    s = src.sortby(["latitude", "longitude"])
    r = truth.sortby(["latitude", "longitude"])
    try:
        # Scalar tolerance: a per-dim dict is not accepted and always fell
        # through to interp; reindex_like would also have touched valid_time.
        return s.reindex(latitude=r["latitude"].values, longitude=r["longitude"].values,
                         method="nearest", tolerance=tolerance)
    except Exception:
        return s.interp(latitude=r.latitude, longitude=r.longitude, method="nearest")

def _rmse_mae_series(pred, truth):
    """Return ``(rmse, mae)`` of ``pred`` vs ``truth``, averaged over lat/lon.

    The difference is cast to float32, NaNs are skipped and results are squeezed.
    """
    delta = (pred - truth).astype("float32")
    over = [dim for dim in ("latitude", "longitude") if dim in delta.dims]
    mse = (delta ** 2).mean(over, skipna=True)
    mae = np.abs(delta).mean(over, skipna=True)
    return (mse ** 0.5).squeeze(), mae.squeeze()

def _mask_to_truth(truth, *others, dim="valid_time"):
    """Mask ``truth`` and every array in ``others`` to the truth's coverage.

    A cell is kept if ``truth`` is finite at least once along ``dim``. This
    removes ocean / outside-coverage cells. Returns
    ``[truth_masked, *others_masked]``.
    """
    # np.isfinite works with DataArrays on all xarray versions (xr.ufuncs does not exist in every release)
    m = np.isfinite(truth).any(dim)
    return [truth.where(m)] + [o.where(m) for o in others]

# ── UI ───────────────────────────────────────────────────────────────────────
# Controls for the IMD-truth comparison: region, metric, inclusive date window, run button, output
region_dd = widgets.Dropdown(options=list(REGIONS), description="Region:", value="Global")
metric_dd = widgets.Dropdown(options=["RMSE", "MAE"], description="Metric:", value="RMSE")
t0_txt = widgets.Text(description="Start (YYYY-MM-DD):", value="2018-06-01")
t1_txt = widgets.Text(description="End   (YYYY-MM-DD):", value="2018-06-30")
btn_run2 = widgets.Button(button_style="success", description="Run (IMD truth)")
out_imd = widgets.Output()
display(region_dd, metric_dd, t0_txt, t1_txt, btn_run2, out_imd)

def run_imd(_, region_w=region_dd, metric_w=metric_dd, t0_w=t0_txt, t1_w=t1_txt):
    """Button callback: daily spatial RMSE/MAE of ERA5 (and IMERG) precip against IMD truth.

    The widgets are bound as default arguments *when this cell runs*. The BMD
    cell also assigns the globals ``metric_dd``, ``t0_txt`` and ``t1_txt``, so
    reading the globals at click time could pick up that cell's controls.

    Parameters
    ----------
    _ : ipywidgets.Button
        The clicked button (unused).
    region_w, metric_w, t0_w, t1_w
        Region/metric dropdowns and start/end date boxes; default to this cell's widgets.

    Reads the globals ``ds_imd``, ``ds_era5`` and, optionally, ``ds_imerg``.
    Problems are printed into ``out_imd`` instead of raising.
    """
    with out_imd:
        out_imd.clear_output()

        imd = globals().get("ds_imd")
        era = globals().get("ds_era5")
        imerg_ds = globals().get("ds_imerg")

        # availability (``VAR in None`` would raise TypeError, so check None first)
        if imd is None or era is None or (VAR not in imd) or (VAR not in era):
            print(f"❌ '{VAR}' missing in IMD or ERA5."); return
        has_imerg_candidate = (imerg_ds is not None) and (VAR in imerg_ds)

        region_name = region_w.value
        region = REGIONS[region_name]
        try:
            t0 = np.datetime64(t0_w.value.strip()); t1 = np.datetime64(t1_w.value.strip())
        except ValueError as exc:
            print(f"❌ Invalid date (expected YYYY-MM-DD): {exc}"); return
        # np.datetime64("") parses to NaT instead of raising
        if np.isnat(t0) or np.isnat(t1) or t0 > t1:
            print(f"❌ Invalid window: {t0} → {t1}"); return

        def _window(ds):
            # region + units first, then rename time -> valid_time before slicing
            da = _norm_mm(_apply_region(ds[[VAR]], region)[VAR])
            return _ensure_valid_time(da).sel(valid_time=slice(t0, t1))

        truth = _window(imd)
        if (truth.sizes.get("valid_time", 0) == 0 or truth.sizes.get("latitude", 0) == 0
                or truth.sizes.get("longitude", 0) == 0):
            print("⚠️ IMD empty in this window/region."); return
        era5 = _window(era)
        if era5.sizes.get("valid_time", 0) == 0:
            print("⚠️ ERA5 empty in this window/region."); return
        imerg = _window(imerg_ds) if has_imerg_candidate else None
        if imerg is not None and imerg.sizes.get("valid_time", 0) == 0:
            imerg = None  # IMERG is optional: drop its panel

        era5_on = _align_to_truth(era5, truth)
        imerg_on = _align_to_truth(imerg, truth) if imerg is not None else None

        # intersect time
        tmin = max(np.datetime64(truth.valid_time.min().values),
                   np.datetime64(era5_on.valid_time.min().values))
        tmax = min(np.datetime64(truth.valid_time.max().values),
                   np.datetime64(era5_on.valid_time.max().values))
        if imerg_on is not None:
            tmin = max(tmin, np.datetime64(imerg_on.valid_time.min().values))
            tmax = min(tmax, np.datetime64(imerg_on.valid_time.max().values))

        truth_c = truth.sel(valid_time=slice(tmin, tmax))
        era5_c = era5_on.sel(valid_time=slice(tmin, tmax))
        imerg_c = imerg_on.sel(valid_time=slice(tmin, tmax)) if imerg_on is not None else None
        if truth_c.sizes.get("valid_time", 0) == 0:
            print("⚠️ No overlap among datasets."); return

        # mask ocean via IMD coverage
        if imerg_c is not None:
            truth_m, era5_m, imerg_m = _mask_to_truth(truth_c, era5_c, imerg_c)
        else:
            truth_m, era5_m = _mask_to_truth(truth_c, era5_c)
            imerg_m = None

        # if IMERG ended up empty after masking/intersection, drop the bottom panel
        imerg_valid = (imerg_m is not None) and (imerg_m.sizes.get("valid_time", 0) > 0)

        metric = metric_w.value
        nrows = 2 if imerg_valid else 1
        fig, axs = plt.subplots(nrows, 1, figsize=(10, 6 if nrows == 2 else 3), sharex=True)

        if metric == "RMSE":
            s_era5, _ = _rmse_mae_series(era5_m, truth_m); ylabel = "RMSE (mm/day)"
        else:
            _, s_era5 = _rmse_mae_series(era5_m, truth_m); ylabel = "MAE (mm/day)"

        ax0 = axs[0] if nrows == 2 else axs
        s_era5.rename({"valid_time": "date"}).plot(ax=ax0)
        ax0.set_title(f"{ylabel} — ERA5 vs IMD — {VAR} — {region_name}")
        ax0.set_ylabel("mm/day")
        ax0.grid(True, alpha=0.3)

        if imerg_valid:
            if metric == "RMSE":
                s_imerg, _ = _rmse_mae_series(imerg_m, truth_m)
            else:
                _, s_imerg = _rmse_mae_series(imerg_m, truth_m)
            axs[1].plot(s_imerg["valid_time"].values, s_imerg.values)
            axs[1].set_title(f"{ylabel} — IMERG vs IMD — {VAR} — {region_name}")
            axs[1].set_ylabel("mm/day")
            axs[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

btn_run2.on_click(run_imd)


Dropdown(description='Region:', options=('Global', 'Ethiopia', 'Nigeria', 'Kenya', 'Bangladesh', 'Chile'), val…

Dropdown(description='Metric:', options=('RMSE', 'MAE'), value='RMSE')

Text(value='2018-06-01', description='Start (YYYY-MM-DD):')

Text(value='2018-06-30', description='End   (YYYY-MM-DD):')

Button(button_style='success', description='Run (IMD truth)', style=ButtonStyle())

Output()

In [15]:
# Cell — Bias maps over India (IMD truth) with safe widget creation & wiring
import numpy as np, pandas as pd, xarray as xr, matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown

# ---------- Use existing datasets if present ----------
# Reuse a dataset when an earlier cell already defined it; otherwise start from None
ds_imd = globals().get("ds_imd")
ds_imerg = globals().get("ds_imerg")
ds_e5_p = globals().get("ds_e5_p")   # ERA5 daily precip

VAR = "total_precipitation_24hr"   # precipitation variable in all three

# ---------- Minimal helpers (self-contained) ----------
def _normalize_mm(da):
    units = (da.attrs.get("units") or da.attrs.get("unit") or "").lower()
    out = da
    if units in ["m","meter","metre","m of water equivalent"]:
        out = out * 1000.0
    out.attrs["units"] = "mm"
    return out

def _to_0360(obj):
    if "longitude" in obj.coords:
        lon = obj["longitude"]
        try:
            if float(lon.min()) < 0:
                obj = obj.assign_coords(longitude=(lon % 360))
        except Exception:
            pass
        obj = obj.sortby("longitude")
    return obj

def _ensure_lat_asc(obj):
    if "latitude" in obj.coords:
        lat = obj["latitude"].values
        if len(lat) > 1 and lat[0] > lat[-1]:
            obj = obj.sortby("latitude")
    return obj

def _coerce_time(da):
    if "time" in da.dims: return da
    # try CF decode
    dec = xr.decode_cf(da.to_dataset(name="_tmp")).to_array("_tmp")
    if "time" in dec.dims: return dec
    raise ValueError("No time dimension.")

def _align_to_reference_grid(src: xr.DataArray, ref: xr.DataArray, tolerance: float = 0.125) -> xr.DataArray:
    """Put `src` on the latitude/longitude labels of `ref` (nearest neighbour).

    Both inputs are first normalized to 0–360° longitudes and ascending
    latitudes. The cheapest applicable strategy is then used:

    1. identical lat/lon labels: the normalized `src` is returned as-is;
    2. nearest-label ``reindex`` of latitude/longitude, taking a source point
       only if it lies within `tolerance` degrees (farther target cells -> NaN);
    3. if that reindex raises, nearest-neighbour ``interp`` onto the same labels.

    Only latitude/longitude labels are changed; other dimensions (e.g. time)
    keep `src`'s labels.

    Parameters
    ----------
    src : field to regrid.
    ref : field whose lat/lon grid is the target.
    tolerance : maximum lat/lon distance in degrees for step 2 (default 0.125).

    Returns
    -------
    The regridded field, or `src` unchanged when it has no lat or lon points.

    Raises
    ------
    RuntimeError
        If `ref` has no lat or lon points.
    """
    # early guards
    if src.sizes.get("latitude", 0) == 0 or src.sizes.get("longitude", 0) == 0:
        return src
    if ref.sizes.get("latitude", 0) == 0 or ref.sizes.get("longitude", 0) == 0:
        raise RuntimeError("reference grid has zero points (lat or lon size = 0).")
    s = _ensure_lat_asc(_to_0360(src))
    r = _ensure_lat_asc(_to_0360(ref))
    target_lat = r["latitude"].values
    target_lon = r["longitude"].values
    # identical labels?
    try:
        if np.array_equal(s["latitude"].values, target_lat) and np.array_equal(s["longitude"].values, target_lon):
            return s
    except Exception:
        pass
    # Nearest-label reindex of the spatial dims only. `tolerance` must be a
    # scalar: pandas rejects a per-dimension dict as a tolerance.
    try:
        return s.reindex(latitude=target_lat, longitude=target_lon,
                         method="nearest", tolerance=tolerance)
    except Exception:
        pass  # e.g. duplicate labels after the 0–360 wrap -> fall back to interp
    # fallback nearest interp
    s2 = s.sortby(["latitude", "longitude"])
    r2 = r.sortby(["latitude", "longitude"])
    return s2.interp(latitude=r2["latitude"], longitude=r2["longitude"], method="nearest")

def _india_box_from_imd(imd_da: xr.DataArray, fallback=(6.5, 37.0, 68.0, 98.0)):
    """Bounding box of every grid cell where IMD is finite at least once.

    Parameters
    ----------
    imd_da : IMD field with ``time``, ``latitude`` and ``longitude`` dimensions.
    fallback : (lat_min, lat_max, lon_min, lon_max) returned when `imd_da` has
        no finite value at all.

    Returns
    -------
    tuple of four floats: (lat_min, lat_max, lon_min, lon_max).
    """
    # np.isfinite dispatches through DataArray.__array_ufunc__ and returns a
    # DataArray; the `xr.ufuncs` namespace is missing from some xarray releases.
    mask_any = np.isfinite(imd_da).any("time")
    lat_any = mask_any.any("longitude")
    lon_any = mask_any.any("latitude")
    # Keep only the coordinate labels whose row/column has any coverage.
    lat_vals = imd_da.latitude.where(lat_any).dropna("latitude")
    lon_vals = imd_da.longitude.where(lon_any).dropna("longitude")
    if lat_vals.size == 0 or lon_vals.size == 0:
        lat_min, lat_max, lon_min, lon_max = fallback
        return float(lat_min), float(lat_max), float(lon_min), float(lon_max)
    return float(lat_vals.min()), float(lat_vals.max()), float(lon_vals.min()), float(lon_vals.max())

def _slice_box(da, lat0, lat1, lon0, lon1):
    """Cut `da` down to a latitude/longitude rectangle.

    Corner order does not matter: each pair of bounds is sorted first, and
    longitudes are compared after wrapping with ``% 360``. `da` itself is
    normalized with `_to_0360` / `_ensure_lat_asc` so the ascending slices
    select correctly. A box straddling the 0°/360° seam (e.g. 350..10) is not
    supported; the sorted bounds select the complementary band instead.
    """
    lon_lo, lon_hi = sorted((lon0 % 360, lon1 % 360))
    lat_lo, lat_hi = sorted((lat0, lat1))
    normalized = _ensure_lat_asc(_to_0360(da))
    return normalized.sel(latitude=slice(lat_lo, lat_hi), longitude=slice(lon_lo, lon_hi))

def _time_filter_mjj_and_years(da, years):
    da = da.sel(time=da.time.dt.month.isin([5,6,7]))
    if years:
        y0, y1 = years
        da = da.sel(time=slice(f"{y0}-01-01", f"{y1}-12-31"))
    return da

# ---------- Widgets (create if missing) ----------
# A widget that already exists in the kernel (from an earlier run of this cell)
# is reused as-is, keeping whatever value it currently holds; only missing
# widgets are built, using the defaults below.
if "btn_run_map" not in globals():
    btn_run_map = widgets.Button(description="Run bias maps", button_style="info")

if "years_txt" not in globals():
    years_txt = widgets.Text(value="2018-2021", description="Years (YYYY-YYYY):")

if "smooth_chk" not in globals():
    smooth_chk = widgets.Checkbox(value=True, description="Smooth (gaussian)")

if "vmax_txt" not in globals():
    vmax_txt = widgets.Text(value="50", description="Abs max (mm)")

if "show_era5" not in globals():
    show_era5 = widgets.Checkbox(value=True, description="ERA5 vs IMD")

if "show_imerg" not in globals():
    show_imerg = widgets.Checkbox(value=True, description="IMERG vs IMD")

if "out_map" not in globals():
    out_map = widgets.Output()

# One row of controls, shown under the heading and above the output area.
ui = widgets.HBox([years_txt, smooth_chk, vmax_txt, show_era5, show_imerg, btn_run_map])
display(Markdown("### Spatial bias maps — IMD truth (May–July, mean over years)"))
display(ui, out_map)

# ---------- Bias-map runner ----------
def _parse_year_range(text):
    """Parse the Years widget text into an inclusive ``(first, last)`` year pair.

    Accepts "YYYY-YYYY" with a hyphen, en dash or em dash as the separator, or a
    single "YYYY" (meaning that one year). Returns None for blank input.

    Raises
    ------
    ValueError
        If the text is not one or two integers separated by a dash.
    """
    cleaned = (text or "").strip().replace("–", "-").replace("—", "-")
    if not cleaned:
        return None
    parts = [piece.strip() for piece in cleaned.split("-")]
    if len(parts) == 1:
        parts = parts * 2
    if len(parts) != 2:
        raise ValueError(f"expected YYYY-YYYY, got {text!r}")
    return int(parts[0]), int(parts[1])


def _parse_abs_max(text, default=50.0):
    """Half-range of the symmetric colour scale, taken from the widget text.

    Blank text gives `default`. The sign is ignored. Returns None when the text
    is not a finite, non-zero number.
    """
    raw = (text or "").strip()
    if not raw:
        return float(default)
    try:
        value = abs(float(raw))
    except ValueError:
        return None
    return value if np.isfinite(value) and value > 0 else None


def _describe_extent(label, da):
    """One progress-log line with the lat/lon/time coverage of `da`."""
    t_first = pd.to_datetime(da.time.min().values).date()
    t_last = pd.to_datetime(da.time.max().values).date()
    return (f"  {label:<5} lat: {float(da.latitude.min()):.2f} … {float(da.latitude.max()):.2f}"
            f"  |  lon: {float(da.longitude.min()):.2f} … {float(da.longitude.max()):.2f}"
            f"  |  time: {t_first} … {t_last} (n={int(da.sizes.get('time', 0))})")


def _common_time_window(arrays):
    """Clip each array in `arrays` (name -> DataArray) to the time span covered by all of them.

    Returns the clipped dict, or None when any array has no time steps or the
    spans do not overlap.
    """
    if any(da.sizes.get("time", 0) == 0 for da in arrays.values()):
        return None
    t0 = max(np.datetime64(da.time.min().values) for da in arrays.values())
    t1 = min(np.datetime64(da.time.max().values) for da in arrays.values())
    if t0 > t1:
        return None
    return {name: da.sel(time=slice(t0, t1)) for name, da in arrays.items()}


def _nan_gaussian(values, sigma=0.8):
    """Gaussian-smooth a 2-D field containing NaNs without letting the NaNs spread.

    A plain ``gaussian_filter`` makes every pixel within the kernel radius of a
    NaN into NaN, eroding the map along the IMD mask edge. Normalized
    convolution smooths the zero-filled data and the validity mask separately,
    then divides them. Cells that were NaN on input stay NaN.

    Raises ImportError if scipy is unavailable.
    """
    from scipy.ndimage import gaussian_filter
    field = np.asarray(values, dtype=float)
    valid = np.isfinite(field)
    numer = gaussian_filter(np.where(valid, field, 0.0), sigma=sigma, mode="nearest")
    denom = gaussian_filter(valid.astype(float), sigma=sigma, mode="nearest")
    with np.errstate(divide="ignore", invalid="ignore"):
        smoothed = numer / denom
    smoothed[~valid] = np.nan
    return smoothed


def _cell_edge_extent(lon, lat):
    """``imshow`` extent [x0, x1, y0, y1] that puts pixel centres on the grid points.

    Assumes regularly spaced coordinates. An axis with a single point gets a
    1°-wide pixel.
    """
    def _edges(coord):
        c = np.asarray(coord, dtype=float)
        half = (c.max() - c.min()) / (c.size - 1) / 2.0 if c.size > 1 else 0.5
        return float(c.min() - half), float(c.max() + half)

    x0, x1 = _edges(lon)
    y0, y1 = _edges(lat)
    return [x0, x1, y0, y1]


def run_bias_maps(_):
    """Button callback: plot the time-mean precipitation bias of ERA5 and/or IMERG against IMD.

    Inputs are the module-level datasets (`ds_imd`, `ds_e5_p`, `ds_imerg`,
    variable `VAR`) and the widgets (`years_txt`, `smooth_chk`, `vmax_txt`,
    `show_era5`, `show_imerg`). Everything printed or plotted goes into
    `out_map`. Bad input or unusable data is reported as a message there, and
    the callback returns early instead of raising.

    Pipeline:
    1. convert units to mm;
    2. derive the India box from IMD coverage and slice to it;
    3. keep May–July of the requested years, clipped to the common time window;
    4. regrid onto the IMD grid;
    5. mask the bias to IMD coverage and average it over time;
    6. optionally apply NaN-aware smoothing, then plot.
    """
    with out_map:
        out_map.clear_output()
        # sanity
        if ds_imd is None or ds_e5_p is None or ds_imerg is None:
            print("❌ Please load IMD, ERA5 (precip), and IMERG datasets first.")
            return
        if VAR not in ds_imd or VAR not in ds_e5_p or VAR not in ds_imerg:
            print(f"❌ `{VAR}` must exist in all datasets.")
            return

        # Validate every widget value before any heavy computation.
        try:
            years = _parse_year_range(years_txt.value)
        except ValueError:
            years = None
            print("⚠️ Could not parse years; using full overlap.")
        vmax = _parse_abs_max(vmax_txt.value)
        if vmax is None:
            print(f"⚠️ Could not use abs max {vmax_txt.value!r}; using 50 mm/day.")
            vmax = 50.0
        predictors = [name for name, chk in (("ERA5", show_era5), ("IMERG", show_imerg)) if chk.value]
        if not predictors:
            print("⚠️ Select at least one predictor to plot.")
            return

        print("Step 1/7: Prepping variables …")
        prepped = {}
        for name, ds in (("IMD", ds_imd), ("IMERG", ds_imerg), ("ERA5", ds_e5_p)):
            try:
                prepped[name] = _normalize_mm(_coerce_time(ds[VAR]))
            except Exception as e:  # e.g. "No time dimension." — report it, don't raise
                print(f"❌ {name}: {e}")
                return

        print("Step 2/7: Building dynamic India box from IMD coverage …")
        lat0, lat1, lon0, lon1 = _india_box_from_imd(prepped["IMD"])
        print(f"  India box (lat, lon): [{lat0:.2f}, {lat1:.2f}] × [{lon0:.2f}, {lon1:.2f}]")

        print("Step 3/7: Slicing all products to the India box …")
        boxed = {name: _slice_box(da, lat0, lat1, lon0, lon1) for name, da in prepped.items()}
        for name, da in boxed.items():
            if min(da.sizes.get(d, 0) for d in ("time", "latitude", "longitude")) == 0:
                print(f"❌ {name} has no data inside the India box.")
                return
            print(_describe_extent(name, da))

        print("Step 4/7: Selecting May–July and requested years …")
        seasonal = {name: _time_filter_mjj_and_years(da, years) for name, da in boxed.items()}
        windowed = _common_time_window(seasonal)
        if windowed is None:
            print("❌ No common May–July period across IMD, IMERG and ERA5 for the requested years.")
            return

        print("Step 5/7: Aligning IMERG/ERA5 to IMD grid …")
        imd_grid = windowed["IMD"]
        try:
            aligned = {
                "IMD": imd_grid,
                "IMERG": _align_to_reference_grid(windowed["IMERG"], imd_grid),
                "ERA5": _align_to_reference_grid(windowed["ERA5"], imd_grid),
            }
        except Exception as e:
            print(f"⚠️ Align error: {e}")
            return

        print("Step 6/7: Intersecting time …")
        # re-check the common window after alignment
        use = _common_time_window(aligned)
        if use is None:
            print("❌ No overlapping time steps left after alignment.")
            return
        imd_use = use["IMD"]

        print("Step 7/7: Computing time-mean bias (May–July, selected years) …")
        # Mask: IMD finite anywhere in the window (land/coverage). np.isfinite
        # works on DataArrays; the `xr.ufuncs` namespace is missing from some
        # xarray releases.
        m_imd = np.isfinite(imd_use).any("time")
        fields = {}
        for name in predictors:
            bias = (use[name] - imd_use).where(m_imd).mean("time", skipna=True)
            # imshow needs rows = latitude, columns = longitude, whatever the source dim order.
            fields[name] = np.asarray(bias.transpose("latitude", "longitude", ...), dtype=float)

        # optional smoothing
        if smooth_chk.value:
            try:
                fields = {name: _nan_gaussian(vals, sigma=0.8) for name, vals in fields.items()}
            except ImportError:
                print("⚠️ scipy is not installed; showing unsmoothed maps.")

        # plotting
        extent = _cell_edge_extent(imd_use.longitude.values, imd_use.latitude.values)
        fig, axs = plt.subplots(1, len(predictors), figsize=(5.8 * len(predictors), 5),
                                constrained_layout=True, squeeze=False)
        for ax, name in zip(axs[0], predictors):
            im = ax.imshow(fields[name], origin="lower", extent=extent,
                           vmin=-vmax, vmax=vmax, cmap="RdBu_r", aspect="auto")
            ax.set_title(f"Bias: {name} − IMD (mm/day)")
            ax.set_xlabel("Longitude"); ax.set_ylabel("Latitude")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="mm/day")

        # caption
        y0 = pd.to_datetime(str(imd_use.time.min().values)).year
        y1 = pd.to_datetime(str(imd_use.time.max().values)).year
        display(Markdown(f"*Mean bias over **May–July**, years **{y0}–{y1}**; masked to IMD coverage (ocean excluded).*"))

        plt.show()

# ---------- Wire the button (idempotent across re-runs) ----------
# `btn_run_map` is reused when this cell is re-executed, but `run_bias_maps` is a
# new function object on every run. A plain `on_click` would therefore stack
# handlers, and one click would run the analysis once per earlier execution.
# Detach the handler recorded on the previous run before attaching the current one.
if globals().get("_bias_map_handler") is not None:
    btn_run_map.on_click(_bias_map_handler, remove=True)
_bias_map_handler = run_bias_maps
btn_run_map.on_click(_bias_map_handler)
print("✅ Bias-map UI ready. Click **Run bias maps**.")


### Spatial bias maps — IMD truth (May–July, mean over years)

HBox(children=(Text(value='2018-2021', description='Years (YYYY–YYYY):'), Checkbox(value=True, description='Sm…

Output()

✅ Bias-map UI ready. Click **Run bias maps**.
