# 🌍 Demo 4: Putting Your Forecast to the Test

## 🧠 Learning Objectives
- Use the `weatherbench2` Python library for forecast verification.
- Load custom forecast and ERA5 data for comparison.
- Compute standard metrics like RMSE and ACC.
- Focus evaluation on specific countries or regions.
- Visualize how forecast skill changes with lead time.
- Understand the power of localized hindcasting for tailored insights.


## 🎯 Objective

In this demo, we scientifically **grade the forecast** you created in **Demo 2**, using professional tools like `weatherbench2`. Instead of relying on generic operational forecast products, we’ll do a **custom local evaluation**.

**Theme:** *We are doing this locally to get custom answers that operational websites can't provide.*

You’ll compute metrics like RMSE and ACC over custom regions (like Kenya or Chile), visualize model skill, and learn why localized hindcasts are powerful.


In [1]:
import xarray as xr
import pandas as pd
import numpy as np
from weatherbench2 import regions
from weatherbenchX.metrics.deterministic import RMSE, MAE
from weatherbenchX.metrics import base as metrics_base
import apache_beam as beam
import numpy as np
import xarray as xr
import weatherbenchX
from weatherbenchX.data_loaders import xarray_loaders
from weatherbenchX.metrics import deterministic
from weatherbenchX.metrics import base as metrics_base
from weatherbenchX import aggregation
from weatherbenchX import weighting

from weatherbenchX import binning
from weatherbenchX import time_chunks
from weatherbenchX import beam_pipeline
import ipywidgets as widgets
from IPython.display import display, Markdown

# Text box in which the user types the location of the NetCDF forecast
# produced in Demo 2; its ``value`` is read when the load button is clicked.
forecast_path = widgets.Text(
    layout=widgets.Layout(width='80%'),
    placeholder='Enter path to NetCDF forecast from Demo 2',
    description='Forecast file:',
)

# Render the input box in the cell output.
display(forecast_path)
def load_forecast(path):
    """Open the forecast file at ``path`` and report the outcome in the notebook.

    Args:
        path: Location of the forecast file, passed unchanged to
            ``xr.open_dataset``.

    Returns:
        The opened dataset, or ``None`` when opening failed. Failures are
        shown to the user as a Markdown message instead of being raised.
    """
    try:
        forecast_ds = xr.open_dataset(path)
    except Exception as e:  # notebook boundary: show any failure to the user
        display(Markdown(f"❌ Error loading forecast: {e}"))
        return None
    display(Markdown("✅ Forecast data loaded successfully."))
    return forecast_ds

# Dataset shared with the following cells; stays ``None`` until a forecast has
# been loaded successfully through the button below.
forecast_ds = None

# Button to trigger loading
load_button = widgets.Button(description="Load Forecast")
load_output = widgets.Output()


def on_load_clicked(b):
    """Load the file named in the ``forecast_path`` box into ``forecast_ds``.

    Args:
        b: The clicked button (supplied by ipywidgets, unused).

    The global ``forecast_ds`` is replaced by the result of ``load_forecast``,
    which is ``None`` when loading fails. All messages go to ``load_output``.
    """
    global forecast_ds
    # Use the same whitespace-stripped path for validation and for loading,
    # so that a stray leading/trailing space does not make the open fail.
    path = forecast_path.value.strip()
    with load_output:
        load_output.clear_output()
        if not path:
            display(Markdown("❌ Please enter a valid file path."))
        else:
            forecast_ds = load_forecast(path)


load_button.on_click(on_load_clicked)
display(load_button, load_output)
# Example path: /mnt/default/jad/agg/init_ERA5_20230630T00_lead_360.nc

Text(value='', description='Forecast file:', layout=Layout(width='80%'), placeholder='Enter path to NetCDF for…

Button(description='Load Forecast', style=ButtonStyle())

Output()

In [None]:
import xarray as xr


# Rename the Demo 2 forecast to the dimension names used by the verification
# code (init_time / lead_time / latitude / longitude) and open the ERA5
# reanalysis that serves as ground truth for the same period.
if forecast_ds is not None:
    forecast_vars = ['2t', 'tp', 'z_500']
    # Forecast short name -> variable name in the ERA5 store.
    var_map = {
        "2t": "2m_temperature",
        "tp": "total_precipitation",
        "z_500": "geopotential",
    }

    era5_vars = [var_map.get(fv, fv) for fv in forecast_vars]

    # post-process forecast
    # Every step checks that the old name is still present, so re-running this
    # cell on the already converted dataset is a no-op instead of a KeyError.
    if "step" in forecast_ds.coords:
        # Turn ``step`` into a timedelta coordinate and make it the
        # ``lead_time`` dimension.
        # NOTE(review): a numeric ``step`` is interpreted as hours - confirm.
        forecast_ds = (
            forecast_ds
            .assign_coords(lead_time=forecast_ds["step"].astype("timedelta64[h]"))
            .drop_vars("step")
            .swap_dims({"step": "lead_time"})
        )

    pending_renames = {
        old: new
        for old, new in {
            "lat": "latitude",
            "lon": "longitude",
            "time": "init_time",
        }.items()
        if old in forecast_ds.variables or old in forecast_ds.dims
    }
    if pending_renames:
        forecast_ds = forecast_ds.rename(pending_renames)

    # A single-initialisation file may carry its time as a scalar coordinate;
    # the verification code indexes along an ``init_time`` dimension.
    if "init_time" in forecast_ds.coords and "init_time" not in forecast_ds.dims:
        forecast_ds = forecast_ds.expand_dims("init_time")

    # Period covered by the forecast: first initialisation up to the last
    # initialisation plus the longest lead time.
    init_times = forecast_ds["init_time"]
    forecast_start_time = init_times.min().values
    forecast_end_time = (
        init_times.max().values + forecast_ds["lead_time"].max().values
    ).astype("datetime64[ns]")

    # ERA5
    era5_path = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
    full_era5 = xr.open_zarr(
        era5_path,
        chunks=None,
        storage_options={"token": "anon"}
    )[era5_vars]
    # Geopotential is stored on pressure levels; keep the 500 hPa surface only.
    era5_z500 = full_era5["geopotential"].sel(level=500).drop_vars("level")

    # Re-key the truth with the forecast's short variable names.
    # NOTE(review): the store name says "1h", so ``total_precipitation`` is
    # presumably an hourly accumulation - confirm it matches the forecast's
    # ``tp`` accumulation period before trusting precipitation scores.
    era5_ds = xr.Dataset({
        "2t": full_era5["2m_temperature"],
        "tp": full_era5["total_precipitation"],
        "z_500": era5_z500
    })

    era5_ds = era5_ds.sel(time=slice(forecast_start_time, forecast_end_time))

    display(Markdown(f"✅ ERA5 ground truth variable `{era5_vars}` loaded and time-aligned."))

✅ ERA5 ground truth variable `['2m_temperature', 'total_precipitation', 'geopotential']` loaded and time-aligned.

In [4]:
from weatherbench2 import regions
# Verification regions keyed by display name. Each entry maps the
# ``latitude`` / ``longitude`` coordinate names to the label slice that
# selects the region. Latitude slices run from the northern to the southern
# bound (presumably to match a north-to-south latitude axis - confirm against
# the datasets used); longitudes are given in the 0-360 convention.
DOMAIN_DEFINITIONS = {
    region_name: {
        "latitude": slice(north, south),
        "longitude": slice(west, east),
    }
    for region_name, north, south, west, east in (
        # name, north, south, west, east
        ("Global", 90, -90, 0, 360),
        ("Northern Hemisphere", 90, 0, 0, 360),
        ("Tropics", 23.5, -23.5, 0, 360),
        ("Bangladesh", 26.7, 20.7, 88.0, 92.7),
        ("Chile", -17.5, -56.0, 284.0, 294.0),
        ("Nigeria", 14.7, 4.0, 2.7, 14.7),
        ("Ethiopia", 14.9, 3.4, 33.0, 48.0),
        ("Kenya", 5.0, -4.7, 33.9, 41.9),
    )
}

In [None]:
def load_clim():
    """Open the ERA5 climatology and re-key it with the forecast's short names.

    The climatology store is opened from the public WeatherBench 2 bucket, its
    variables are renamed from the long ERA5 names to the short forecast names,
    the result is restricted to the module-level ``forecast_vars`` and
    geopotential is reduced to the 500 level.

    Returns:
        The climatology dataset holding the variables in ``forecast_vars``.
    """
    short_to_long = {
        "2t": "2m_temperature",
        "z_500": "geopotential",
        'tp': 'total_precipitation_6hr',
    }
    # The store uses the long names, so the mapping is applied in reverse.
    long_to_short = {long: short for short, long in short_to_long.items()}

    climatology = xr.open_zarr("gs://weatherbench2/datasets/era5-hourly-climatology/1990-2019_6h_1440x721.zarr")
    climatology = climatology.rename_vars(long_to_short)[forecast_vars]

    # Geopotential comes on pressure levels; keep level 500 and drop the
    # now-scalar ``level`` coordinate.
    z_500 = climatology["z_500"].sel(level=500)
    climatology["z_500"] = z_500.drop_vars("level")
    return climatology


def _targets_for_chunk(forecast_chunk, era5_ds, start, stop):
    """Return the part of ``era5_ds`` that verifies ``forecast_chunk``."""
    can_align = (
        'time' in era5_ds.dims
        and 'init_time' in forecast_chunk.coords
        and 'lead_time' in forecast_chunk.coords
    )
    if can_align:
        # Truth for a forecast is the analysis at its valid time. The 2-D
        # (init_time, lead_time) indexer makes the selection come back on the
        # forecast's own init_time / lead_time dimensions.
        valid_time = forecast_chunk['init_time'] + forecast_chunk['lead_time']
        return era5_ds.sel(time=valid_time)
    # Fallback when valid times cannot be derived: positional slice along time.
    return era5_ds.isel(time=slice(start, stop))


def _concat_statistic(parts, dim):
    """Join the per-chunk values of one statistic along ``dim``."""
    if len(parts) == 1:
        return parts[0]
    first = parts[0]
    if isinstance(first, (xr.DataArray, xr.Dataset)):
        return xr.concat(parts, dim=dim)
    # Otherwise treat the value as a mapping of variable name -> array and
    # concatenate variable by variable.
    return {
        var: xr.concat([part[var] for part in parts], dim=dim)
        for var in first
    }


def compute_statistics_in_chunks(forecast_ds, era5_ds, metrics, chunk_size=50):
    """Compute the raw statistics behind ``metrics`` chunk by chunk.

    The forecast is processed in groups of ``chunk_size`` initialisation times.
    For each group the ERA5 fields at the forecast valid times
    (``init_time + lead_time``) are selected as targets, so each lead time is
    scored against the truth valid at that time.

    Args:
        forecast_ds: Forecast dataset with an ``init_time`` dimension and,
            for valid-time alignment, a ``lead_time`` coordinate.
        era5_ds: Ground-truth dataset indexed by ``time``.
        metrics: Mapping of metric name to metric object, passed to
            ``metrics_base.compute_unique_statistics_for_all_metrics``.
        chunk_size: Number of initialisation times per chunk (positive).

    Returns:
        Dict of statistic name -> statistic values, with the chunks joined
        along ``init_time``. Empty when the forecast has no init times.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        KeyError: If a forecast valid time is missing from ``era5_ds``.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    n_init = forecast_ds.sizes['init_time']
    per_statistic = {}

    for start in range(0, n_init, chunk_size):
        stop = min(start + chunk_size, n_init)
        forecast_chunk = forecast_ds.isel(init_time=slice(start, stop))
        targets_chunk = _targets_for_chunk(forecast_chunk, era5_ds, start, stop)

        stats_chunk = metrics_base.compute_unique_statistics_for_all_metrics(
            metrics, forecast_chunk, targets_chunk
        )
        for stat_name, stat_value in stats_chunk.items():
            per_statistic.setdefault(stat_name, []).append(stat_value)

    # Concatenate once per statistic, along the dimension that was chunked.
    return {
        stat_name: _concat_statistic(parts, 'init_time')
        for stat_name, parts in per_statistic.items()
    }


def run_verification(forecast_ds, era_5, region, metrics, chunk_size=50):
    """Compute verification statistics for one region.

    Args:
        forecast_ds: Forecast dataset.
        era_5: Ground-truth dataset.
        region: Mapping of coordinate name to label slice, e.g. an entry of
            ``DOMAIN_DEFINITIONS``.
        metrics: Mapping of metric name to metric object.
        chunk_size: Number of initialisation times processed per chunk.

    Returns:
        The statistics returned by ``compute_statistics_in_chunks`` for the
        selected region.
    """
    forecast = forecast_ds.sel(**region)
    # Crop the truth to the same box so that only the region's data is carried
    # into the computation, rather than relying on implicit alignment to
    # discard the rest. Only keys that are dimensions of the truth are applied.
    truth_region = {name: sel for name, sel in region.items() if name in era_5.dims}
    truth = era_5.sel(**truth_region)
    return compute_statistics_in_chunks(forecast, truth, metrics, chunk_size=chunk_size)


In [11]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Open the climatology once at cell execution; the resulting dataset is bound
# as the default ``clim`` argument of ``verify_interactive`` defined below.
clim = load_clim()

def _area_weighted_mean(field):
    """Average ``field`` over init time and space with cos(latitude) weights."""
    # Grid cells shrink towards the poles on a regular lat/lon grid, so each
    # row is weighted by the cosine of its latitude.
    weights = np.cos(np.deg2rad(field['latitude'])).clip(min=0)
    reduce_dims = [
        dim for dim in ('init_time', 'latitude', 'longitude') if dim in field.dims
    ]
    return field.weighted(weights).mean(dim=reduce_dims)


def _plot_skill_curve(curve, metric, variable):
    """Draw ``curve`` against lead time in days and show it in the output."""
    if 'lead_time' not in curve.dims and 'step' in curve.dims:
        curve = curve.rename({'step': 'lead_time'})
    # Drop leftover length-1 dimensions so that a single line is drawn.
    singleton_dims = [
        dim for dim in curve.dims if dim != 'lead_time' and curve.sizes[dim] == 1
    ]
    curve = curve.squeeze(singleton_dims)
    curve = curve.assign_coords(
        lead_time_days=curve['lead_time'] / np.timedelta64(1, 'D')
    )

    fig, ax = plt.subplots(figsize=(8, 4))
    curve.plot(ax=ax, x='lead_time_days')
    ax.set_title(f"{metric} of {variable} vs Time")
    ax.set_xlabel("Forecast time [days]")
    ax.set_ylabel(metric)

    display(fig)
    plt.close(fig)  # avoid a second, automatic rendering of the same figure


def verify_interactive(forecast_ds, era5_ds, domain_definitions, clim=clim):
    """Show widgets that compute and plot a skill curve for a chosen setup.

    The user picks a region, a metric (RMSE or ACC) and a variable; pressing
    "Run Verification" computes the statistics with ``run_verification`` and
    plots the metric against lead time.

    Args:
        forecast_ds: Forecast dataset; its data variables fill the variable
            dropdown. When ``None`` a hint is printed and nothing is shown.
        era5_ds: Ground-truth dataset.
        domain_definitions: Mapping of region name to coordinate selection;
            must contain the key ``"Global"`` (the dropdown's initial value).
        clim: Climatology used for ACC. When ``None`` it is loaded with
            ``load_clim`` the first time ACC is requested.
    """
    if forecast_ds is None:
        print("No forecast loaded - load a forecast before running the verification.")
        return

    region_selector = widgets.Dropdown(
        options=list(domain_definitions.keys()),
        description="Region:",
        value="Global"
    )

    metric_selector = widgets.Dropdown(
        options=["RMSE", "ACC"],
        description="Metric:",
        value="RMSE"
    )

    variable_names = list(forecast_ds.data_vars.keys())
    variable_selector = widgets.Dropdown(
        options=variable_names,
        description="Variable:",
        value="2t" if "2t" in forecast_ds.data_vars else variable_names[0]
    )
    run_button = widgets.Button(description="Run Verification", button_style="success")
    output = widgets.Output()

    # Holds the climatology so it is opened at most once per widget set.
    climatology_cache = {"value": clim}

    def on_run_clicked(b):
        output.clear_output(wait=True)

        with output:
            region = region_selector.value
            metric = metric_selector.value
            ftr = variable_selector.value
            print(f"Selected region: {region}")
            print(f"Selected metric: {metric}")
            print(f"Selected variable: {ftr}")
            print("Loading metrics...")

            if metric == "RMSE":
                metrics = {"rmse": deterministic.RMSE()}
            elif metric == "ACC":
                if climatology_cache["value"] is None:
                    print("Loading climatology data...")
                    climatology_cache["value"] = load_clim()
                    print("Climatology data loaded ✅")
                metrics = {"acc": deterministic.ACC(climatology=climatology_cache["value"])}

            print("Computing statistics...")
            final_statistics = run_verification(
                forecast_ds, era5_ds, domain_definitions[region], metrics
            )
            print("Done ✅")

            if metric == "RMSE":
                # RMSE is the square root of the *mean* squared error; taking
                # the root first would yield a mean absolute error instead.
                mean_squared_error = _area_weighted_mean(
                    final_statistics['SquaredError'][ftr]
                )
                skill = np.sqrt(mean_squared_error)
            elif metric == "ACC":
                # ACC is the ratio of the averaged anomaly covariance to the
                # root of the averaged anomaly variances. Forming the ratio at
                # each grid point first would collapse to +/-1 per point.
                cov = _area_weighted_mean(final_statistics['AnomalyCovariance'][ftr])
                pred_var = _area_weighted_mean(
                    final_statistics['SquaredPredictionAnomaly'][ftr]
                )
                target_var = _area_weighted_mean(
                    final_statistics['SquaredTargetAnomaly'][ftr]
                )
                skill = cov / np.sqrt(pred_var * target_var)

            _plot_skill_curve(skill, metric, ftr)

    run_button.on_click(on_run_clicked)

    display(region_selector, metric_selector, variable_selector, run_button, output)

# ``era5_ds`` only exists once a forecast has been loaded and pre-processed,
# so show a hint instead of failing with a NameError until then.
if forecast_ds is not None:
    verify_interactive(forecast_ds, era5_ds, DOMAIN_DEFINITIONS)
else:
    print("Load a forecast and run the pre-processing cell before starting the verification.")


Dropdown(description='Region:', options=('Global', 'Northern Hemisphere', 'Tropics', 'Bangladesh', 'Chile', 'N…

Dropdown(description='Metric:', options=('RMSE', 'ACC'), value='RMSE')

Dropdown(description='Variable:', options=('2t', 'tp', 'z_500'), value='2t')

Button(button_style='success', description='Run Verification', style=ButtonStyle())

Output()

## 📊 Interpret Your Results

- **Look at the RMSE plot**: How does the error change as the forecast lead time increases? Why is this expected?
- **Try changing the region** from "Global" to "Kenya" and re-run the analysis.
    - Does the model's **ACC score** degrade faster or slower in Kenya?
    - What does this imply about model performance for **East African** weather?

Play with different regions and metrics to gain insights!


## 🔍 Key Takeaways

- Large operational centers provide broad forecasts — but can't offer **custom regional analysis**.
- By running a **local hindcast** and benchmarking it with tools like `weatherbench2`, you can answer **targeted, high-impact questions**:
    - *How accurate is this model over Bangladesh during monsoon?*
    - *How reliable is the 5-day forecast for heatwaves in Chile?*
- This is the true power of open science and AI: **empowering local experts** to run meaningful evaluations for their region.

🧪 Keep exploring. Try different models, dates, and regions!
