# 🌍 Demo 4: Putting the Forecast to the Test

## 🧠 Learning Objectives
- Use the `weatherbenchX` Python library for forecast verification.
- Load forecast and ERA5 data for comparison.
- Compute standard metrics, including root-mean-square error (RMSE) and anomaly correlation coefficient (ACC).
- Focus the evaluation on specific countries or regions.
- Visualize how forecast skill changes with lead time.

## 🎯 Objective

In this demo, we rigorously **grade the forecast** using the professional verification library `weatherbenchX`.

You’ll compute metrics such as RMSE and ACC over custom regions (for example, Kenya or Chile), visualize model skill, and learn why evaluating forecasts over specific regions can be powerful.


### Import Necessary Libraries

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import apache_beam as beam
import weatherbenchX
from weatherbenchX.data_loaders import xarray_loaders
from weatherbenchX.metrics import deterministic
from weatherbenchX import aggregation
from weatherbenchX import weighting
from weatherbenchX import binning
from weatherbenchX import time_chunks
from weatherbenchX import beam_pipeline
import ipywidgets as widgets
from IPython.display import display
import warnings
import os
import matplotlib.pyplot as plt
import logging
import glob

# Keep notebook output readable: silence Apache Beam's runner logging below
# ERROR and suppress all Python warnings.
logging.getLogger("apache_beam").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

### Define global variables

These constants are used throughout the notebook.

In [None]:
# Verification regions as (name, (north, south) latitude, (west, east) longitude).
# Longitudes are degrees east on a 0-360 convention (e.g. Chile is 284-294).
_REGION_BOUNDS = [
    ("Global", (90, -90), (0, 360)),
    ("Northern Hemisphere", (90, 0), (0, 360)),
    ("Tropics", (23.5, -23.5), (0, 360)),
    ("Bangladesh", (26.7, 20.7), (88.0, 92.7)),
    ("Chile", (-17.5, -56.0), (284.0, 294.0)),
    ("Nigeria", (14.7, 4.0), (2.7, 14.7)),
    ("Ethiopia", (14.9, 3.4), (33.0, 48.0)),
    ("Kenya", (5.0, -4.7), (33.9, 41.9)),
]

# Region name -> {"latitude": (north, south), "longitude": (west, east)}
DOMAIN_DEFINITIONS = {
    name: {"latitude": lat_bounds, "longitude": lon_bounds}
    for name, lat_bounds, lon_bounds in _REGION_BOUNDS
}

# Variables evaluated throughout the notebook.
VARIABLES = ["2m_temperature", "total_precipitation_6hr", "z_500"]

# Load the 1990-2019 ERA5 climatology (used as the reference for ACC below)
def load_clim():
    """Open the WeatherBench 2 ERA5 climatology restricted to VARIABLES.

    The store's multi-level ``geopotential`` field is renamed to ``z_500``
    and reduced to the 500 hPa level, so its variable names match the
    forecasts and targets.
    """
    clim_path = "gs://weatherbench2/datasets/era5-hourly-climatology/1990-2019_6h_64x32_equiangular_conservative.zarr"
    opened = xr.open_zarr(clim_path, decode_timedelta=True)
    clim = opened.rename_vars({"geopotential": "z_500"})[VARIABLES]
    z500_single_level = clim["z_500"].sel(level=500).drop_vars("level")
    clim["z_500"] = z500_single_level
    return clim


CLIM = load_clim().compute()

### Define core benchmarking functions

In [None]:
# Main WBX Benchmarking function

# 64x32 forecast stores on the WeatherBench 2 bucket, keyed by lower-cased model name.
_FORECAST_PATHS = {
    "graphcast": "gs://weatherbench2/datasets/graphcast/2020/date_range_2019-11-16_2021-02-01_12_hours-64x32_equiangular_conservative.zarr",
    "fuxi": "gs://weatherbench2/datasets/fuxi/2020-64x32_equiangular_conservative.zarr",
    "ifs": "gs://weatherbench2/datasets/hres/2016-2022-0012-64x32_equiangular_conservative.zarr",
}
# ERA5 store used as ground truth for every model.
_TARGET_PATH = "gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr"


def _select_init_times(available_inits, start, end, weekdays=(0, 3)):
    """Return the inits in [start, end) that fall on ``weekdays`` (Monday=0), as datetime64[ns]."""
    in_range = (available_inits >= start) & (available_inits < end)
    on_weekday = available_inits.weekday.isin(list(weekdays))
    return available_inits[in_range & on_weekday].values.astype("datetime64[ns]")


def _select_lead_times(available_leads, desired_hours):
    """Return the ``desired_hours`` leads that exist in ``available_leads``, as timedelta64[ns]."""
    wanted = np.asarray(desired_hours).astype("timedelta64[h]").astype("timedelta64[ns]")
    return wanted[np.isin(wanted, available_leads)]


def run_benchmark(model_name):
    """Score one model against ERA5 with weatherbenchX and write the results to netCDF.

    Forecasts initialised on Mondays and Thursdays in the 2020 window are
    evaluated at lead times 0-240 h (24 h steps, restricted to those the model
    provides) for every variable in VARIABLES. RMSE and ACC (against CLIM) are
    reduced over init_time/latitude/longitude with grid-area weighting,
    separately for each region in DOMAIN_DEFINITIONS.

    Args:
        model_name: "GraphCast", "FuXi" or "IFS" (case-insensitive).

    Returns:
        Path of the results file under ``benchmark_results/``. If that file
        already exists, the Beam pipeline is skipped and the path is returned.

    Raises:
        ValueError: if ``model_name`` is unknown, or no init/lead times match
            the selection.
    """
    forecast_path = _FORECAST_PATHS.get(model_name.lower())
    if forecast_path is None:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of "
            f"{sorted(_FORECAST_PATHS)} (case-insensitive)"
        )
    target_path = _TARGET_PATH

    # Add a single-level "z_500" variable from 3-D "geopotential" so forecasts
    # and targets expose the variable names listed in VARIABLES.
    def preprocess_forecast(ds):
        if "geopotential" in ds.data_vars and "level" in ds["geopotential"].dims:
            ds["z_500"] = ds["geopotential"].sel(level=500).drop_vars("level")
        return ds

    # Opened here only to discover which init and lead times the store holds.
    ds = xr.open_zarr(forecast_path, decode_timedelta=True)

    # Init-time window (Mondays and Thursdays).
    # NOTE(review): the upper bound is exclusive (< end), so inits on
    # 2020-12-31 itself are not selected — confirm this is intended.
    start = pd.Timestamp("2020-01-01T00:00:00")
    end = pd.Timestamp("2020-12-31T00:00:00")
    init_times = _select_init_times(pd.to_datetime(ds["time"].values), start, end)

    # Lead times 0-240 h in 24 h steps, keeping only those the model provides.
    lead_times = _select_lead_times(ds["prediction_timedelta"].values, np.arange(0, 241, 24))
    lead_times_hr = lead_times.astype("timedelta64[h]")

    print(f"Benchmarking model: {model_name}")
    print(f"Selected {len(init_times)} init times (Mondays & Thursdays between {start} and {end}):")
    print(init_times[:6], "..." if len(init_times) > 6 else "")
    print("Lead times (hours):", lead_times_hr)

    # Fail clearly here rather than deep inside min() / TimeChunks.
    if len(init_times) == 0 or len(lead_times) == 0:
        raise ValueError(
            f"No matching init times ({len(init_times)}) or lead times "
            f"({len(lead_times)}) found for {model_name}"
        )

    # Metrics of interest (their names are also used in the output filename).
    metrics = {
        'rmse': deterministic.RMSE(),
        'acc': deterministic.ACC(climatology=CLIM),
    }

    # Output filename format: wbx_{model}_{metrics}_{num_inits}_{min_lead}hr_{max_lead}hr.nc
    os.makedirs('benchmark_results', exist_ok=True)
    metrics_fname = "_".join(sorted(metrics))
    out_fname = f"wbx_{model_name.lower()}_{metrics_fname}_{len(init_times)}_{min(lead_times_hr).astype(int)}hr_{max(lead_times_hr).astype(int)}hr.nc"
    out_path = f"benchmark_results/{out_fname}"
    # Reuse results from an earlier run before building any pipeline pieces.
    if os.path.exists(out_path):
        print("Results already exist; skipping pipeline. Check:", out_path)
        return out_path

    # Data loaders for forecasts and targets
    forecast_loader = xarray_loaders.PredictionsFromXarray(
        path=forecast_path,
        variables=VARIABLES,
        rename_dimensions='ecmwf',
        preprocessing_fn=preprocess_forecast,
    )
    target_loader = xarray_loaders.TargetsFromXarray(
        path=target_path,
        variables=VARIABLES,
        rename_dimensions='ecmwf',
        preprocessing_fn=preprocess_forecast,
    )

    # Evaluate all init and lead times as a single chunk.
    times = time_chunks.TimeChunks(
        init_times=init_times,
        lead_times=lead_times,
        init_time_chunk_size=len(init_times),
        lead_time_chunk_size=len(lead_times),
    )

    # Regions are passed as ((south, north), (west, east)); sorting the latitude
    # pair makes this independent of the (north, south) order in DOMAIN_DEFINITIONS.
    regions = {
        name: (
            tuple(sorted(bounds['latitude'])),
            (bounds['longitude'][0], bounds['longitude'][1]),
        )
        for name, bounds in DOMAIN_DEFINITIONS.items()
    }

    # Reduce over init_time, latitude and longitude, per region, weighting by
    # grid-cell area to avoid a bias towards the poles.
    aggregator = aggregation.Aggregator(
        reduce_dims=['init_time', 'latitude', 'longitude'],
        bin_by=[binning.Regions(regions)],
        weigh_by=[weighting.GridAreaWeighting()],
    )

    # Beam parallelises the work; DirectRunner runs locally. For larger
    # datasets consider DataflowRunner or another distributed runner.
    root = beam.Pipeline(runner="DirectRunner")
    beam_pipeline.define_pipeline(
        root=root,
        times=times,
        predictions_loader=forecast_loader,
        targets_loader=target_loader,
        metrics=metrics,
        aggregator=aggregator,
        out_path=out_path,
    )

    result = root.run()
    result.wait_until_finish()

    print("Pipeline finished. Check:", out_path)
    return out_path

In [None]:
models = ['GraphCast', 'FuXi', 'IFS']
out_paths = {}

# Benchmark each model in turn; run_benchmark returns the path of its results
# file (reusing an existing file when one is found).
for model in models:
    out_paths.update({model: run_benchmark(model)})

## View the benchmark results

In [None]:
# Display names for the model token embedded in result filenames
# (run_benchmark writes "wbx_{model_name.lower()}_...nc").
_MODEL_DISPLAY_NAMES = {"graphcast": "GraphCast", "fuxi": "FuXi", "ifs": "IFS"}

wbx_results = {}

print("Loading results...")
try:
    # Preferred: the files produced by the benchmarking cell above.
    for model, out_path in out_paths.items():
        # load_dataset reads everything into memory and closes the file.
        wbx_results[model] = xr.load_dataset(out_path)
        display(wbx_results[model])
except (NameError, OSError) as err:
    # Fallback (e.g. after a kernel restart, when out_paths is undefined):
    # rediscover results written by earlier runs. Start from an empty dict so
    # partially loaded results are not duplicated under a second key.
    print(f"Could not use out_paths ({err!r}); searching benchmark_results/ instead...")
    wbx_results = {}
    # Sorted for a deterministic order; if several files exist for one model,
    # the last one in sorted order is kept.
    for out_path in sorted(glob.glob(os.path.join("benchmark_results", "wbx_*.nc"))):
        token = os.path.basename(out_path).split("_")[1]
        model = _MODEL_DISPLAY_NAMES.get(token, token)
        wbx_results[model] = xr.load_dataset(out_path)
        print(f"Loaded {model} results from {out_path}")
        display(wbx_results[model])



## Plot the results

### Plotting function

In [None]:
def plot_metric_over_region(wbx_results, region, metric, variable):
    """Plot one metric for one variable and region against lead time, one line per model.

    Args:
        wbx_results: mapping of model name -> results Dataset, with data
            variables named "<metric>.<variable>" and "lead_time" / "region"
            coordinates.
        region: region name, e.g. a key of DOMAIN_DEFINITIONS.
        metric: "rmse" or "acc" (case-insensitive).
        variable: e.g. "2m_temperature", "total_precipitation_6hr", "z_500".
    """
    # Normalise once so the ACC check and the variable lookup agree.
    metric_key = metric.lower()

    fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
    for model, results in wbx_results.items():
        # Timedelta leads -> days (fractional for non-24 h leads, no truncation).
        lead_days = results['lead_time'].values / np.timedelta64(1, 'D')
        ax.plot(
            lead_days,
            results[f"{metric_key}.{variable}"].sel(region=region),
            label=model,
        )

    if metric_key == 'acc':
        ax.set_ylim(0, 1)
        # Dashed reference line at ACC = 0.6, presumably the conventional
        # "useful skill" threshold.
        ax.axhline(0.6, color='gray', linestyle='--')
    else:
        ax.set_ylim(bottom=0)

    ax.set_title(f"{variable} {metric.upper()} over {region} in 2020 (ERA5 ground truth)")
    ax.set_xlabel("Lead time (days)")
    ax.set_ylabel(metric.upper())
    ax.grid()
    ax.legend()

    plt.show()

### Select the region, metric, and variable to plot

In [None]:
# Choose what to plot.
# Region options: 'Global', 'Northern Hemisphere', 'Tropics',
#                 'Bangladesh', 'Chile', 'Nigeria', 'Ethiopia', 'Kenya'
region = 'Global'
# Metric options: 'rmse', 'acc'
metric = 'acc'
# Variable options: '2m_temperature', 'total_precipitation_6hr', 'z_500'
variable = '2m_temperature'

plot_metric_over_region(wbx_results, region=region, metric=metric, variable=variable)

### Discussion Questions

Examine the plot:  
What do you notice? Which model is best at a 5-day lead time?  
Which model is best at the final lead time?  
Is the relative skill of the models consistent across lead times?

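## Plot skill through time: RMSE by initialization date at fixed lead times

Here we use GraphCast as a test case; other models should show similar patterns. For the selected region and variable, the plot shows the spatially averaged RMSE of each forecast against ERA5 as a function of its initialization time, for lead times of 120 h and 240 h.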

### Setup code

In [None]:
# Map dropdown keys -> possible variable names in the Zarrs (first match wins)
VAR_MAP = {
    "2t":   ["2m_temperature", "2t", "t2m"],
    "tp6h": ["total_precipitation_6hr", "total_precipitation", "tp"],
    "z500": ["z_500", "z500", "geopotential"],
}

# The regions of DOMAIN_DEFINITIONS as slice objects for label-based .sel().
# Deriving them, instead of repeating every number, guarantees the weatherbenchX
# benchmark and the time-series plots always use identical region boxes.
DOMAIN_SLICES = {
    name: {dim: slice(*bounds) for dim, bounds in dims.items()}
    for name, dims in DOMAIN_DEFINITIONS.items()
}

def _pick_var(ds, candidates):
    """Return the first name in ``candidates`` that is a data variable of ``ds``.

    Raises:
        KeyError: if none of the candidate names is present.
    """
    not_found = object()
    chosen = next((name for name in candidates if name in ds.data_vars), not_found)
    if chosen is not_found:
        raise KeyError(f"None of {candidates} found in dataset. Have: {list(ds.data_vars)}")
    return chosen

def _add_z500(ds):
    """Add a ``z_500`` variable (``geopotential`` at level 500) to ``ds`` in place.

    Nothing is added unless ``ds`` has a ``geopotential`` variable with a
    ``level`` dimension. ``ds`` is returned so calls can be chained.
    """
    has_levels = "geopotential" in ds.data_vars and "level" in ds["geopotential"].dims
    if not has_levels:
        return ds
    ds["z_500"] = ds["geopotential"].sel(level=500).drop_vars("level")
    return ds

def _slice_region(ds, region_name):
    """Subset ``ds`` to the lat/lon box ``DOMAIN_SLICES[region_name]``.

    Accepts either ``latitude``/``longitude`` or ``lat``/``lon`` coordinate
    names. Label-based slicing follows the coordinate's order, so the latitude
    slice is re-ordered to match the latitude coordinate's direction. This
    works whichever way round the region's bounds were written.
    """
    r = DOMAIN_SLICES[region_name]
    lat_name = "latitude" if "latitude" in ds.coords else "lat"
    lon_name = "longitude" if "longitude" in ds.coords else "lon"

    lat = ds[lat_name]
    south, north = sorted((r["latitude"].start, r["latitude"].stop))
    descending = bool(lat[0] > lat[-1])
    lat_slice = slice(north, south) if descending else slice(south, north)

    return ds.sel({lat_name: lat_slice, lon_name: r["longitude"]})

def load_aligned(forecast_path, target_path, var_key, region_name):
    """Load a forecast variable and its matching target over one region.

    Args:
        forecast_path: Zarr store with ``time`` (init time) and
            ``prediction_timedelta`` (lead) dimensions.
        target_path: Zarr store of targets indexed by ``time``.
        var_key: key of VAR_MAP (e.g. "2t", "tp6h", "z500").
        region_name: key of DOMAIN_SLICES.

    Returns:
        ``(forecast, target, lead_dim)``. The forecast and target are
        DataArrays over the region with matching (time, prediction_timedelta,
        lat, lon) layout. Each target value is taken at valid time =
        init + lead. ``lead_dim`` is the name of the lead dimension,
        ``"prediction_timedelta"``.

    Raises:
        KeyError: if the variable is missing from either store, or if some
            valid time has no target within 3 hours.
    """
    f = _add_z500(xr.open_zarr(forecast_path, decode_timedelta=True))
    t = _add_z500(xr.open_zarr(target_path, decode_timedelta=True))

    fv = _pick_var(f, VAR_MAP[var_key])
    tv = _pick_var(t, VAR_MAP[var_key])

    f = _slice_region(f[[fv]], region_name)
    t = _slice_region(t[[tv]], region_name)

    # (init time x lead) matrix of valid times. Selecting with this DataArray
    # indexer (vectorised indexing) returns targets laid out like the forecast.
    valid_time = f["time"] + f["prediction_timedelta"]
    try:
        t_valid = t[tv].sel(time=valid_time)  # exact match
    except KeyError:
        # Some valid times are not exactly on the target's time axis: fall back
        # to the nearest target time within 3 hours. Other errors (I/O, etc.)
        # propagate instead of triggering a pointless retry.
        t_valid = t[tv].sel(time=valid_time, method="nearest", tolerance=np.timedelta64(3, "h"))

    return f[fv], t_valid.rename(tv), "prediction_timedelta"

def metric_rmse(f_da, t_da, area_weighted=True):
    """Spatial root-mean-square error, one value per (init time, lead time).

    The squared error is averaged over every dimension except ``time`` (init
    time) and ``prediction_timedelta`` (lead time), so the result keeps both
    of those dimensions.

    Args:
        f_da: forecast DataArray.
        t_da: target DataArray aligned with ``f_da`` (as from load_aligned).
        area_weighted: weight grid points by cos(latitude) so that the small
            high-latitude grid cells are not over-weighted. This is consistent
            with the grid-area weighting used in the weatherbenchX benchmark.
            It has no effect when there is no ``latitude``/``lat`` dimension.
    """
    err2 = (f_da - t_da) ** 2
    spatial_dims = [d for d in err2.dims if d not in ("time", "prediction_timedelta")]
    lat_name = next((d for d in ("latitude", "lat") if d in spatial_dims), None)

    if area_weighted and lat_name is not None:
        weights = np.cos(np.deg2rad(err2[lat_name]))
        mse = err2.weighted(weights).mean(dim=spatial_dims, skipna=True)
    else:
        mse = err2.mean(dim=spatial_dims, skipna=True)
    return np.sqrt(mse)
# -------------------------------------------------------------------------------


### Interactive Plotting code

In [None]:
# ───────── Widgets + time-series plot (uses objects/functions defined in earlier cells) ─────────

# Dropdown (label, value) pairs; the values are keys of VAR_MAP.
VAR_KEYS = [("2t", "2t"), ("tp6h", "tp6h"), ("z500", "z500")]

# ── UI
w_region = widgets.Dropdown(options=list(DOMAIN_DEFINITIONS.keys()),
                            value="Ethiopia", description="Region:")
w_var    = widgets.Dropdown(options=VAR_KEYS, value="2t", description="Variable:")
w_run    = widgets.Button(description="Run Verification", button_style="success")
w_log    = widgets.Output()

ui_box = widgets.VBox([w_region, w_var, w_run, w_log])
display(ui_box)


def on_run(_):
    """Button callback: plot RMSE vs init time for the chosen region and variable.

    Uses GraphCast forecasts verified against ERA5 at lead times of 240 h and
    120 h. All printing and displaying happens inside the ``w_log`` Output
    widget. Output produced by widget callbacks outside an Output widget is
    not shown in some front ends (e.g. JupyterLab).
    """
    with w_log:
        w_log.clear_output()
        print(f"Selected region: {w_region.value}")
        print(f"Selected variable: {w_var.value}")
        print("Loading data...")

        try:
            forecast_path = "gs://weatherbench2/datasets/graphcast/2020/date_range_2019-11-16_2021-02-01_12_hours-64x32_equiangular_conservative.zarr"
            target_path   = "gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr"
            f, t, lead_name = load_aligned(forecast_path, target_path, w_var.value, w_region.value)

            print("Computing metrics...")
            lead_hours = [240, 120]
            leads = np.array(lead_hours, dtype="timedelta64[h]").astype("timedelta64[ns]")
            # Restrict to the plotted leads and compute once, rather than once per line.
            series = metric_rmse(f, t).sel({lead_name: leads}).compute()

            init_times = pd.to_datetime(series["time"].values)
            fig, ax = plt.subplots(figsize=(9, 4))
            try:
                for hours, lead in zip(lead_hours, leads):
                    ax.plot(init_times, series.sel({lead_name: lead}), label=f"Lead {hours}h")
                ax.set_title(f"RMSE of {w_var.value} over {w_region.value} vs Time")
                ax.set_ylabel("RMSE")
                ax.set_xlabel("Init time (UTC)")
                ax.legend()
                ax.grid(True, alpha=0.3)

                print("Done ✅")
                display(fig)
            finally:
                # Always release the figure, even if plotting failed.
                plt.close(fig)

        except Exception as e:
            # Top-level UI boundary: report the error in the log widget.
            print("❌ Error:", repr(e))


w_run.on_click(on_run)


### Discussion Questions

What patterns in forecast skill do you observe?  
Are there months in which the model shows better skill than in others?  
Is there a relationship between lead time and the months with better skill, or is the pattern consistent across lead times?