# Two GenCast autoregressions with interpolation

This notebook explores one way to reach an hourly resolution: launch two (or more) GenCast rollouts with a time offset between them, merge the results, and interpolate between the merged time steps.

# Installation and Initialization

In [None]:
# !pip install ml_dtypes --prefer-binary
# # Install conda packages
# !conda install -y python=3.10 numpy scipy pandas xarray=2023.12.0 netCDF4 dask gcsfs cdsapi boto3 xarray-datatree -c conda-forge -c defaults
# # Install pip packages
# !pip install --upgrade "numexpr>=2.8.4"
# !pip install dm-haiku git+https://github.com/deepmind/graphcast.git@main "jax[cuda12]"

In [1]:
# Standard libraries
import dataclasses
import datetime
import glob
import math
import os
import zipfile
from typing import Optional



# IPython and widgets
from IPython.display import HTML, display
from IPython import display
import ipywidgets as widgets


# Visualization
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import cartopy.crs as ccrs

# Numerical and array operations
import numpy as np
import xarray as xr
import xarray
from datatree import DataTree
xarray.DataTree = DataTree  # Patch xarray to include DataTree

# JAX-related
import jax
import haiku as hk


# GraphCast modules
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import normalization
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import xarray_tree
from graphcast import gencast
from graphcast import denoiser
from graphcast import nan_cleaning

# CDS API
import cdsapi

# Suppress warnings
import warnings
warnings.filterwarnings("ignore")


In [2]:
# @title Plotting functions

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
  """Pick a single variable out of `data` and narrow it down for plotting.

  Args:
    data: Container indexed by variable name.
    variable: Name of the variable to extract.
    level: If given and the variable has a "level" coordinate, select that
      level by label.
    max_steps: If given and smaller than the length of the "time" dimension,
      keep only the first `max_steps` time steps.

  Returns:
    The extracted variable. When a "batch" dimension is present only its first
    element is kept.
  """
  selected = data[variable]

  # Only the first batch element is ever shown.
  if "batch" in selected.dims:
    selected = selected.isel(batch=0)

  needs_truncation = (
      max_steps is not None
      and "time" in selected.sizes
      and max_steps < selected.sizes["time"])
  if needs_truncation:
    selected = selected.isel(time=range(max_steps))

  if level is not None and "level" in selected.coords:
    selected = selected.sel(level=level)

  return selected

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  """Compute colour normalisation and a colormap name for `data`.

  Args:
    data: Values to be plotted; NaNs are ignored when computing the range.
    center: If given, the colour range is made symmetric around this value and
      a diverging colormap is used.
    robust: If True, use the 2nd/98th percentiles as range limits instead of
      the minimum/maximum.

  Returns:
    A `(data, norm, cmap_name)` tuple, with `data` passed through unchanged.
  """
  lower_pct, upper_pct = (2, 98) if robust else (0, 100)
  vmin = np.nanpercentile(data, lower_pct)
  vmax = np.nanpercentile(data, upper_pct)

  if center is None:
    cmap_name = "viridis"
  else:
    # Widen the range so that `center` sits exactly in the middle.
    half_range = max(vmax - center, center - vmin)
    vmin, vmax = center - half_range, center + half_range
    cmap_name = "RdBu_r"

  return (data, matplotlib.colors.Normalize(vmin, vmax), cmap_name)

def plot_data(
    data: dict[str, tuple[xarray.Dataset, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> HTML:
  """Render one animated panel per entry of `data` and return it as HTML.

  Args:
    data: Maps a panel title to a `(values, norm, cmap)` tuple, i.e. the tuple
      returned by `scale`. All entries must have the same number of "time"
      steps (entries without a "time" dimension count as one step).
    fig_title: Title shown above all panels; the current time offset is
      appended to it on every frame when a "time" dimension is present.
    plot_size: Height of one panel; each panel is twice as wide as it is high.
    robust: If True, colorbars are drawn with extension arrows at both ends.
    cols: Maximum number of panels per row.

  Returns:
    An `IPython.display.HTML` object wrapping the JavaScript animation.

  Raises:
    ValueError: If `data` is empty.
  """
  if not data:
    raise ValueError("plot_data() needs at least one entry in `data`.")

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  # `panel` (rather than `plot_data`) avoids shadowing this function's name.
  for i, (title, (panel, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        panel.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    """Redraw every panel for animation frame `frame`."""
    if "time" in first_data.dims:
      # The time value is divided by 1000 and interpreted as microseconds,
      # i.e. it is assumed to be a nanosecond count — TODO confirm.
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (panel, _, _) in zip(images, data.values()):
      im.set_data(panel.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  # Close the static figure so only the animation is displayed.
  plt.close(figure.number)
  return HTML(ani.to_jshtml())


# Load the Data and initialize the model

In [3]:
# @title Set paths

# Locations of the checkpoint, the example dataset and the normalisation
# statistics, all given relative to the notebook's working directory.
# DATA_PATH's file name encodes source/date/res/levels/steps as `key-value`
# fields separated by underscores; these fields are parsed further below.
MODEL_PATH = "./model/GenCast 1p0deg _2019.npz"  # E.g. "GenCast 1p0deg _2019.npz"
DATA_PATH = "./evaluation_data/source-era5_date-2019-03-29_res-1.0_levels-13_steps-12.nc"  # E.g. "source-era5_date-2019-03-29_res-1.0_levels-13_steps-04.nc"
STATS_DIR = "./stats/"  # E.g. "stats/"

In [4]:
# Show the directory against which the relative MODEL_PATH / DATA_PATH /
# STATS_DIR values are resolved.
import os
print("Current working directory:", os.getcwd())

Current working directory: /home/ec2-user/SageMaker/demo_cloud


In [5]:
# @title Load the model

# Deserialize the checkpoint; the file is only needed while loading.
with open(MODEL_PATH, "rb") as f:
    ckpt = checkpoint.load(f, gencast.CheckPoint)

# Override the attention settings stored in the checkpoint.
# NOTE(review): presumably required to run on GPU instead of TPU — confirm.
denoiser_architecture_config = ckpt.denoiser_architecture_config
denoiser_architecture_config.sparse_transformer_config.attention_type = "triblockdiag_mha"
denoiser_architecture_config.sparse_transformer_config.mask_type = "full"

# Weights come from the checkpoint; the Haiku state starts out empty.
params, state = ckpt.params, {}

# Unpack the remaining configs used to build the predictor later on.
task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config

print(f"Model description:\n {ckpt.description} \n")
print(f"Model license:\n {ckpt.license} \n")

Model description:
 
        GenCast model at lower, 1deg, resolution, with 13 pressure levels and a
        5 times refined icosahedral mesh. This model is trained on ERA5 data
        from 1979 to 2018, and can be causally evaluated on 2019 and later years.
        This model has a smaller memory footprint than the 0.25deg models.
         

Model license:
 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 



## Load the example data

Example ERA5 datasets are available at 0.25 degree and 1 degree resolution.

Example HRES-fc0 datasets are available at 0.25 degree resolution.

Some transformations were done from the base datasets:
- We accumulated precipitation over 12 hours instead of the default 1 hour.
- For HRES-fc0 sea surface temperature, we assigned NaNs to grid cells in which sea surface temperature was NaN in the ERA5 dataset (this remains fixed at all times).

The data resolution must match the model that is loaded.



In [6]:
# @title Check example dataset matches model

def parse_file_parts(file_name):
    """Parse the `key-value` fields encoded in a dataset file name.

    Fields are separated by underscores and each field is split on its first
    hyphen, e.g. "source-era5_res-1.0" -> {"source": "era5", "res": "1.0"}.
    Parts without a hyphen are ignored.

    Only the final path component is parsed. Directory names may themselves
    contain underscores or hyphens (e.g. "./evaluation_data/source-era5_...")
    and would otherwise end up inside the keys ("data/source").

    Args:
        file_name: File name or path, with or without leading directories.

    Returns:
        Dict mapping each field name to its string value.
    """
    base_name = os.path.basename(file_name)
    parts = {}
    for part in base_name.split("_"):
        if "-" not in part:
            continue
        key, value = part.split("-", 1)
        # Kept from the original: drop any leading '.' or '/' characters.
        parts[key.lstrip("./")] = value
    return parts

def data_valid_for_model(file_name: str, params_file_name: str) -> bool:
  """Check data type and resolution matches.

  Both arguments may be full paths: only their final path component is
  inspected, so directory names (e.g. "evaluation_data") can neither corrupt
  the parsed fields nor accidentally match the resolution/"Operational" test.

  Args:
    file_name: Dataset file name or path; must encode `res-...` and
      `source-...` fields.
    params_file_name: Checkpoint file name or path.

  Returns:
    True when the dataset resolution (with "." written as "p") appears in the
    checkpoint name and the source matches: "era5" data requires a checkpoint
    without "Operational" in its name, any other source requires one with it.

  Raises:
    KeyError: If the dataset name lacks a `res` or `source` field.
  """
  data_name = os.path.basename(file_name).removesuffix(".nc")
  model_name = os.path.basename(params_file_name)

  data_file_parts = parse_file_parts(data_name)
  res_matches = data_file_parts["res"].replace(".", "p") in model_name.lower()

  is_operational_model = "Operational" in model_name
  is_era5_data = data_file_parts["source"] == "era5"
  # ERA5 data pairs with non-operational checkpoints and vice versa.
  source_matches = is_operational_model != is_era5_data
  return res_matches and source_matches

# Fail fast if the configured dataset does not match the loaded checkpoint.
assert data_valid_for_model(DATA_PATH, MODEL_PATH)


In [7]:
# @title Load weather data

with open(DATA_PATH, "rb") as f:
  example_batch = xarray.load_dataset(f).compute()

# `sizes` is the stable name -> length mapping; `Dataset.dims` is slated to
# return only the set of dimension names in future xarray releases.
assert example_batch.sizes["time"] >= 3  # 2 for input, >=1 for targets

# Summarise the fields encoded in the file name. Only the base name is parsed
# so that the "evaluation_data/" directory does not leak into the keys.
print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(os.path.basename(DATA_PATH).removesuffix(".nc")).items()]))

example_batch

source: era5, date: 2019-03-29, res: 1.0, levels: 13, steps: 12


In [None]:
# @title Extract training and eval data

# Training split: a single 12h-ahead target.
train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", "12h"), # Only 1AR training.
    **dataclasses.asdict(task_config))

# Evaluation split: every frame after the two input frames, in 12h steps.
# `sizes` replaces `dims[...]`, whose mapping behaviour is deprecated in xarray.
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", f"{(example_batch.sizes['time']-2)*12}h"), # All but 2 input frames.
    **dataclasses.asdict(task_config))


# dict(ds.sizes) prints the same {name: length} dict as the former
# `ds.dims.mapping` without relying on a private attribute.
print("All Examples:  ", dict(example_batch.sizes))
print("Train Inputs:  ", dict(train_inputs.sizes))
print("Train Targets: ", dict(train_targets.sizes))
print("Train Forcings:", dict(train_forcings.sizes))
print("Eval Inputs:   ", dict(eval_inputs.sizes))
print("Eval Targets:  ", dict(eval_targets.sizes))
print("Eval Forcings: ", dict(eval_forcings.sizes))


In [20]:
# @title Load normalization data

# Per-level statistics used by the normalisation and NaN-cleaning wrappers.
# os.path.join keeps the paths valid whether or not STATS_DIR ends in "/".
with open(os.path.join(STATS_DIR, "diffs_stddev_by_level.nc"), "rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with open(os.path.join(STATS_DIR, "mean_by_level.nc"), "rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with open(os.path.join(STATS_DIR, "stddev_by_level.nc"), "rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()
with open(os.path.join(STATS_DIR, "min_by_level.nc"), "rb") as f:
  min_by_level = xarray.load_dataset(f).compute()

In [21]:
# @title Build jitted functions, and possibly initialize random weights


def construct_wrapped_gencast():
  """Build the GenCast predictor together with its two wrappers.

  The wrappers are applied from the inside out:
    1. `gencast.GenCast`, configured from the checkpoint configs.
    2. `normalization.InputsAndResiduals`, fed with the per-level statistics.
    3. `nan_cleaning.NaNCleaner`, which handles NaNs in
       'sea_surface_temperature' using `min_by_level` as fill value and
       reintroduces them afterwards.

  Returns:
    The outermost (NaN-cleaning) predictor.
  """
  core_model = gencast.GenCast(
      sampler_config=sampler_config,
      task_config=task_config,
      denoiser_architecture_config=denoiser_architecture_config,
      noise_config=noise_config,
      noise_encoder_config=noise_encoder_config,
  )

  normalized_model = normalization.InputsAndResiduals(
      core_model,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level,
  )

  return nan_cleaning.NaNCleaner(
      predictor=normalized_model,
      reintroduce_nans=True,
      fill_value=min_by_level,
      var_to_clean='sea_surface_temperature',
  )


@hk.transform_with_state
def run_forward(inputs, targets_template, forcings):
  """Run the wrapped GenCast predictor once (Haiku-transformed).

  A fresh predictor is constructed on every call and applied to `inputs`,
  with `targets_template` and `forcings` passed by keyword.
  """
  wrapped_model = construct_wrapped_gencast()
  prediction = wrapped_model(
      inputs, targets_template=targets_template, forcings=forcings)
  return prediction


@hk.transform_with_state
def loss_fn(inputs, targets, forcings):
  """Evaluate the predictor's loss and diagnostics (Haiku-transformed).

  Returns:
    `(loss, diagnostics)` where every leaf has been averaged with `.mean()`
    and unwrapped to a raw JAX value.
  """
  def _mean_as_jax(x):
    return xarray_jax.unwrap_data(x.mean(), require_jax=True)

  wrapped_model = construct_wrapped_gencast()
  loss, diagnostics = wrapped_model.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(_mean_as_jax, (loss, diagnostics))


def grads_fn(params, state, inputs, targets, forcings):
  """Compute the loss and its gradient with respect to `params`.

  The loss is evaluated through `loss_fn.apply` with a fixed
  `jax.random.PRNGKey(0)`.

  Returns:
    `(loss, diagnostics, next_state, grads)`.
  """
  def _loss_with_aux(p, s, i, t, f):
    # value_and_grad(has_aux=True) expects (scalar, aux); pack the rest as aux.
    (loss_value, diag), new_state = loss_fn.apply(
        p, s, jax.random.PRNGKey(0), i, t, f
    )
    return loss_value, (diag, new_state)

  value_and_grad = jax.value_and_grad(_loss_with_aux, has_aux=True)
  (loss, aux), grads = value_and_grad(params, state, inputs, targets, forcings)
  diagnostics, next_state = aux
  return loss, diagnostics, next_state, grads


# Only reached when no weights were loaded from a checkpoint: initialise
# params/state by tracing the loss once on the training example.
if params is None:
  init_jitted = jax.jit(loss_fn.init)
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets=train_targets,
      forcings=train_forcings,
  )


# The lambdas close over the module-level `params` and `state`, so callers
# only pass (rng, inputs, targets/targets_template, forcings). `[0]` keeps the
# function output and drops the Haiku state returned by `.apply`.
loss_fn_jitted = jax.jit(
    lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0]
)
grads_fn_jitted = jax.jit(grads_fn)
run_forward_jitted = jax.jit(
    lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]
)
# We also produce a pmapped version for running in parallel.
# It is mapped over the "sample" dimension.
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")

# Run the model (2 autoregressions)

The `chunked_prediction_generator_multiple_runs` iterates over forecast steps, where the 1 step forecast is jitted and samples are pmapped across the chips.
This allows us to make efficient use of all devices and parallelise generating an ensemble across them. We then combine the chunks at the end to form our final forecast.

Note that the `Autoregressive rollout` cell will take longer than the standard inference time to run when executed for the first time, as this will include code compilation time. This cost does not increase with the number of devices, it is a fixed-cost one time operation whose result can be reused across any number of devices.

In [22]:
# The number of ensemble members should be a multiple of the number of devices,
# since the rollout below passes all local devices as `pmap_devices`.
print(f"Number of local devices {len(jax.local_devices())}")

Number of local devices 4


In [23]:
# @title Autoregressive rollout (loop in python)

def run_autoregression(pred_input, num_ensemble_members=8, seed=0):
    """Roll GenCast out autoregressively from `pred_input`.

    The rollout is produced one step per chunk by
    `rollout.chunked_prediction_generator_multiple_runs`, using the pmapped
    forward function, and the chunks are merged with `combine_by_coords`.
    The target template and forcings come from the module-level
    `eval_targets` / `eval_forcings`.

    Args:
        pred_input: Input frames to start the rollout from.
        num_ensemble_members: Number of ensemble samples to generate. Should
            be a multiple of the number of local devices.
        seed: Seed of the base PRNG key from which per-member keys are derived.

    Returns:
        The combined predictions of all chunks.
    """
    # Report the dimensions of the data actually used for this rollout
    # (previously the global `eval_inputs` was printed instead of `pred_input`).
    print("Inputs:  ", dict(pred_input.sizes))
    print("Targets: ", dict(eval_targets.sizes))
    print("Forcings:", dict(eval_forcings.sizes))

    rng = jax.random.PRNGKey(seed)
    # We fold in the ensemble member index; this way the first N members
    # always match across different runs which take the same inputs,
    # regardless of total ensemble size.
    rngs = np.stack(
        [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)

    chunks = []
    for chunk in rollout.chunked_prediction_generator_multiple_runs(
        # Use pmapped version to parallelise across devices.
        predictor_fn=run_forward_pmap,
        rngs=rngs,
        inputs=pred_input,
        # All-NaN template: only its structure/coordinates are meaningful.
        targets_template=eval_targets * np.nan,
        forcings=eval_forcings,
        num_steps_per_chunk=1,
        num_samples=num_ensemble_members,
        pmap_devices=jax.local_devices()
        ):
        chunks.append(chunk)
    predictions = xarray.combine_by_coords(chunks)
    return predictions


# NOTE(review): `pred_input_1` and `pred_input_2` are not defined in any cell
# visible in this notebook, so this cell fails on a fresh kernel
# (Restart & Run All). Presumably they are two sets of input frames whose
# initial times are 6 hours apart — TODO: add the cell that builds them.

# Run autoregression with pred_input_1
print("-------autoregression 1----------------")
predictions_1 = run_autoregression(pred_input_1)

# Run autoregression with pred_input_2
print("-------autoregression 2----------------")
predictions_2 = run_autoregression(pred_input_2)


-------autoregression 1----------------
Inputs:   {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Targets:  {'batch': 1, 'time': 12, 'lat': 181, 'lon': 360, 'level': 13}
Forcings: {'batch': 1, 'time': 12, 'lon': 360}
-------autoregression 2----------------
Inputs:   {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Targets:  {'batch': 1, 'time': 12, 'lat': 181, 'lon': 360, 'level': 13}
Forcings: {'batch': 1, 'time': 12, 'lon': 360}


In [56]:
import xarray as xr
import numpy as np

def combine_autoregressive_predictions(
    predictions_1: xr.Dataset,
    predictions_2: xr.Dataset,
    offset: np.timedelta64 = np.timedelta64(6, 'h'),
) -> xr.Dataset:
    """
    Merges two time-offset rollouts into a single, time-sorted dataset.

    Each rollout is first averaged over its ensemble ("sample") dimension.
    The time coordinate of the second rollout is then shifted by `offset`
    and both are concatenated along "time" and sorted chronologically.

    Parameters:
    - predictions_1 (xr.Dataset): Rollout whose time coordinate is kept as is.
    - predictions_2 (xr.Dataset): Rollout whose time coordinate is shifted by
      `offset`. It is assumed to have been started `offset` later than
      `predictions_1` — TODO confirm against how its inputs are built.
    - offset (np.timedelta64): Shift applied to `predictions_2`'s time
      coordinate. Defaults to 6 hours.

    Returns:
    - mean_predictions (xr.Dataset): Combined ensemble-mean dataset, sorted by
      time.
    """

    # Step 1: Average predictions across ensemble samples
    mean_predictions_1 = predictions_1.mean(dim="sample")
    mean_predictions_2 = predictions_2.mean(dim="sample")

    # Step 2: Shift the second rollout's OWN time axis. Reusing the first
    # rollout's coordinates would silently relabel (or fail on) a second
    # rollout whose time axis differs in length or values.
    shifted_time = mean_predictions_2.coords['time'] + offset
    mean_predictions_2 = mean_predictions_2.assign_coords(time=shifted_time)

    # Step 3: Concatenate both datasets along the time dimension
    mean_predictions = xr.concat([mean_predictions_1, mean_predictions_2], dim="time")

    # Step 4: Sort by time so the two rollouts interleave chronologically
    return mean_predictions.sortby("time")
    
# Interleave the two ensemble-mean rollouts on a common time axis.
combined_predictions = combine_autoregressive_predictions(predictions_1, predictions_2)


In [None]:
import xarray as xr
import numpy as np

def interpolate_and_save_forecast(
    mean_predictions: xr.Dataset,
    pred_input_1: xr.Dataset,
    output_path: Optional[str] = None,
    num_steps: int = 12,
    freq: str = '1H',
) -> xr.Dataset:
    """
    Interpolates a combined forecast to a finer time resolution, attaches
    absolute datetimes, and optionally saves the result.

    Parameters:
    - mean_predictions (xr.Dataset): Combined forecast whose "time" coordinate
      holds lead times (timedeltas).
    - pred_input_1 (xr.Dataset): Input dataset of the first rollout; its
      "datetime" coordinate provides the absolute time reference.
    - output_path (str, optional): Path to save the interpolated forecast as
      NetCDF. Missing parent directories are created.
    - num_steps (int): Number of leading time steps of `mean_predictions` to
      keep before interpolating. Defaults to 12.
    - freq (str): Target resampling frequency. Defaults to '1H' (hourly).

    Returns:
    - mean_prediction (xr.Dataset): Linearly interpolated forecast with a
      "datetime" coordinate along "time".
    """

    # Step 1: Keep only the first `num_steps` time steps.
    # (No datetime coordinate is attached beforehand: it would be dropped by
    # the resampling below and is recomputed afterwards anyway.)
    mean_prediction = mean_predictions.isel(time=slice(0, num_steps))

    # Step 2: Linearly interpolate onto the finer time grid
    mean_prediction = mean_prediction.resample(time=freq).interpolate('linear')

    # Step 3: Attach absolute datetimes = forecast start + lead time.
    # Assumes datetime[0, 0] is the first input frame and the forecast start is
    # 12 hours after it — TODO confirm against the input layout.
    start_datetime = pred_input_1["datetime"].values[0, 0] + np.timedelta64(12, 'h')
    datetime_values = start_datetime + mean_prediction.coords['time'].values
    mean_prediction = mean_prediction.assign_coords(datetime=("time", datetime_values))

    # Step 4: Optionally save to NetCDF, creating the target directory first
    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        mean_prediction.to_netcdf(output_path, format="NETCDF4", engine="netcdf4")

    return mean_prediction


# Interpolate the combined forecast to hourly steps, write it to ./results/,
# and display the resulting dataset.
hourly_forecast = interpolate_and_save_forecast(combined_predictions, pred_input_1, output_path="./results/predictions_3_days_hourly_2G_linear.nc")
hourly_forecast