# GraphCast

This colab lets you run several versions of GraphCast.

The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast).

A Colab runtime with TPU/GPU acceleration will substantially speed up generating predictions and computing the loss/gradients. If you're using a CPU-only runtime, you can switch using the menu "Runtime > Change runtime type".

> <p><small><small>Copyright 2023 DeepMind Technologies Limited.</small></p>
> <p><small><small>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at <a href="http://www.apache.org/licenses/LICENSE-2.0">http://www.apache.org/licenses/LICENSE-2.0</a>.</small></small></p>
> <p><small><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></small></p>

# Installation and Initialization


In [1]:
# @title Imports

import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray as xr
import cfgrib

import glob
import os

def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))

In [2]:
# @title Plotting functions

def select(
    data: xr.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xr.Dataset:
  data = data[variable]
  if "batch" in data.dims:
    data = data.isel(batch=0)
  if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
    data = data.isel(time=range(0, max_steps))
  if level is not None and "level" in data.coords:
    data = data.sel(level=level)
  return data

def scale(
    data: xr.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xr.Dataset, matplotlib.colors.Normalize, str]:
  vmin = np.nanpercentile(data, (2 if robust else 0))
  vmax = np.nanpercentile(data, (98 if robust else 100))
  if center is not None:
    diff = max(vmax - center, center - vmin)
    vmin = center - diff
    vmax = center + diff
  return (data, matplotlib.colors.Normalize(vmin, vmax),
          ("RdBu_r" if center is not None else "viridis"))

def plot_data(
    data: dict[str, xr.Dataset],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> tuple[xr.Dataset, matplotlib.colors.Normalize, str]:

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        plot_data.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    if "time" in first_data.dims:
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (plot_data, norm, cmap) in zip(images, data.values()):
      im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  plt.close(figure.number)
  return HTML(ani.to_jshtml())

# Load the Data and initialize the model

## Load the model params

Choose one of the two ways of getting model params:
- **random**: You'll get random predictions, but you can change the model architecture, which may run faster or fit on your device.
- **checkpoint**: You'll get sensible predictions, but are limited to the model architecture that it was trained with, which may not fit on your device. In particular generating gradients uses a lot of memory, so you'll need at least 25GB of ram (TPUv4 or A100).

Checkpoints vary across a few axes:
- The mesh size specifies the internal graph representation of the earth. Smaller meshes will run faster but will have worse outputs. The mesh size does not affect the number of parameters of the model.
- The resolution and number of pressure levels must match the data. Lower resolution and fewer levels will run a bit faster. Data resolution only affects the encoder/decoder.
- All our models predict precipitation. However, ERA5 includes precipitation, while HRES does not. Our models marked as "ERA5" take precipitation as input and expect ERA5 data as input, while model marked "ERA5-HRES" do not take precipitation as input and are specifically trained to take HRES-fc0 as input (see the data section below).

We provide three pre-trained models.
1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,

2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from 1979 to 2015, useful to run a model with lower memory and compute constraints,

3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on HRES data from 2016 to 2021. This model can be initialized from HRES data (does not require precipitation inputs).


In [3]:
# @title Choose the model
file_paths_1 = glob.glob('./weights/**/*.npz', recursive=True)
# Extract just the file names
file_names = [os.path.basename(path) for path in file_paths_1]

random_mesh_size = widgets.IntSlider(
    value=4, min=4, max=6, description="Mesh size:")
random_gnn_msg_steps = widgets.IntSlider(
    value=4, min=1, max=32, description="GNN message steps:")
random_latent_size = widgets.Dropdown(
    options=[int(2**i) for i in range(4, 10)], value=32,description="Latent size:")
random_levels = widgets.Dropdown(
    options=[13, 37], value=13, description="Pressure levels:")

params_file = widgets.Dropdown(
    options=file_names,
    description="Params file:",
    layout={"width": "max-content"})

source_tab = widgets.Tab([
    widgets.VBox([
        random_mesh_size,
        random_gnn_msg_steps,
        random_latent_size,
        random_levels,
    ]),
    params_file,
])
source_tab.set_title(0, "Random")
source_tab.set_title(1, "Checkpoint")
widgets.VBox([
    source_tab,
    widgets.Label(value="Run the next cell to load the model. Rerunning this cell clears your selection.")
])

VBox(children=(Tab(children=(VBox(children=(IntSlider(value=4, description='Mesh size:', max=6, min=4), IntSli…

In [4]:
local_ckpt_path = file_paths_1[0]

# Open and load the checkpoint locally
with open(local_ckpt_path, "rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

params = ckpt.params
model_config = ckpt.model_config
task_config = ckpt.task_config
state = {}

## Load the example data

Several example datasets are available, varying across a few axes:
- **Source**: fake, era5, hres
- **Resolution**: 0.25deg, 1deg, 6deg
- **Levels**: 13, 37
- **Steps**: How many timesteps are included

Not all combinations are available.
- Higher resolution is only available for fewer steps due to the memory requirements of loading them.
- HRES is only available in 0.25 deg, with 13 pressure levels.

The data resolution must match the model that is loaded.

Some transformations were done from the base datasets:
- We accumulated precipitation over 6 hours instead of the default 1 hour.
- For HRES data, each time step corresponds to the HRES forecast at leadtime 0, essentially providing an "initialisation" from HRES. See HRES-fc0 in the GraphCast paper for further description. Note that a 6h accumulation of precipitation is not available from HRES, so our model taking HRES inputs does not depend on precipitation. However, because our models predict precipitation, we include the ERA5 precipitation in the example data so it can serve as an illustrative example of ground truth.
- We include ERA5 `toa_incident_solar_radiation` in the data. Our model uses the radiation at -6h, 0h and +6h as a forcing term for each 1-step prediction. If the radiation is missing from the data (e.g. in an operational setting), it will be computed using a custom implementation that produces values similar to those in ERA5.

In [5]:
# @title Choose the model
file_paths = glob.glob('./dataset/**/*.nc', recursive=True)

# Extract just the file names
file_names = [os.path.basename(path) for path in file_paths]

# @title Get and filter the list of available example datasets
dataset_file_options = file_names

def data_valid_for_model(
    file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):

  file_parts = parse_file_parts(file_name.removesuffix(".nc"))
  return (
      model_config.resolution in (0, float(file_parts["res"])) and
      len(task_config.pressure_levels) == int(file_parts["levels"]) and
      (
          ("total_precipitation_6hr" in task_config.input_variables and
           file_parts["source"] in ("era5", "fake")) or
          ("total_precipitation_6hr" not in task_config.input_variables and
           file_parts["source"] in ("hres", "fake"))
      )
  )

dataset_file = widgets.Dropdown(
    options=[
        (", ".join([f"{k}: {v}" for k, v in parse_file_parts(option.removesuffix(".nc")).items()]), option)
        for option in dataset_file_options
        if data_valid_for_model(option, model_config, task_config)
    ],
    description="Dataset file:",
    layout={"width": "max-content"})
widgets.VBox([
    dataset_file,
    widgets.Label(value="Run the next cell to load the dataset. Rerunning this cell clears your selection and refilters the datasets that match your model.")
])

VBox(children=(Dropdown(description='Dataset file:', layout=Layout(width='max-content'), options=(('source: er…

In [6]:
import xarray as xr
import os

# Validate dataset
if not data_valid_for_model(dataset_file.value, model_config, task_config):
    raise ValueError("Invalid dataset file, rerun the cell above and choose a valid dataset file.")

# Load from local path
dataset_path = os.path.join("dataset", dataset_file.value)
example_batch = xr.load_dataset(dataset_path).compute()

# Check time dimension
assert example_batch.dims["time"] >= 3

# Print info from file name
file_info = parse_file_parts(dataset_file.value.removesuffix(".nc"))
print(", ".join([f"{k}: {v}" for k, v in file_info.items()]))

example_batch

source: era5, date: 2022-01-01, res: 1.0, levels: 13, steps: 01


  assert example_batch.dims["time"] >= 3


# GraphCast FineTuning

## Load and process local data

### Load surface and atmospherique ERA5 dataset

### Missing data

geopotential_at_surface; land_sea_mask; toa_incident_solar_radiation; 

In [7]:
ds_surface = xr.open_dataset("graphcast_ft_data/data_era5_surface.grib", engine="cfgrib")
ds_surface = ds_surface.rename({
    't2m': '2m_temperature',
    'u10': '10m_u_component_of_wind',
    'v10': '10m_v_component_of_wind',
    'msl': 'mean_sea_level_pressure',
})

ds_tp = xr.open_dataset("graphcast_ft_data/data_era5_surface.grib", engine="cfgrib", backend_kwargs={'filter_by_keys': {'shortName': 'tp'}})
ds_tp = ds_tp.rename({'tp': 'total_precipitation_6hr'})

ds_tisr = xr.open_dataset("graphcast_ft_data/data_era5_surface.grib", engine="cfgrib", backend_kwargs={'filter_by_keys': {'shortName': 'tisr'}})
ds_tisr = ds_tisr.rename({'tisr': 'toa_incident_solar_radiation'})

ds_atmo = xr.open_dataset("graphcast_ft_data/atmo_2days_jan.grib", engine="cfgrib")
ds_atmo = ds_atmo.rename({
    't': 'temperature',
    'z': 'geopotential',
    'u': 'u_component_of_wind',
    'v': 'v_component_of_wind',
    'w': 'vertical_velocity',
    'q': 'specific_humidity',
})

ds_atmo = ds_atmo.drop_vars(['step'], errors='ignore')
ds_atmo = ds_atmo.drop_vars(['number'], errors='ignore')
ds_atmo = ds_atmo.drop_vars(['valid_time'], errors='ignore')
ds_atmo = ds_atmo.rename({'isobaricInhPa': 'level'})

ds_surface = ds_surface.drop_vars(['step'], errors='ignore')
ds_surface = ds_surface.drop_vars(['surface'], errors='ignore')
ds_surface = ds_surface.drop_vars(['number'], errors='ignore')
ds_surface = ds_surface.drop_vars(['valid_time'], errors='ignore')

ds_tp = ds_tp.drop_vars(['step'], errors='ignore')
ds_tp = ds_tp.drop_vars(['surface'], errors='ignore')
ds_tp = ds_tp.drop_vars(['number'], errors='ignore')
ds_tp = ds_tp.drop_vars(['valid_time'], errors='ignore')


skipping variable: paramId==228 shortName='tp'
Traceback (most recent call last):
  File "/opt/homebrew/anaconda3/lib/python3.12/site-packages/cfgrib/dataset.py", line 725, in build_dataset_components
    dict_merge(variables, coord_vars)
  File "/opt/homebrew/anaconda3/lib/python3.12/site-packages/cfgrib/dataset.py", line 641, in dict_merge
    raise DatasetBuildError(
cfgrib.dataset.DatasetBuildError: key present and new value is different: key='time' value=Variable(dimensions=('time',), data=array([1704067200, 1704088800, 1704110400, 1704132000, 1704153600,
       1704175200, 1704196800, 1704218400])) new_value=Variable(dimensions=('time',), data=array([1704045600, 1704088800, 1704132000, 1704175200]))
skipping variable: paramId==212 shortName='tisr'
Traceback (most recent call last):
  File "/opt/homebrew/anaconda3/lib/python3.12/site-packages/cfgrib/dataset.py", line 725, in build_dataset_components
    dict_merge(variables, coord_vars)
  File "/opt/homebrew/anaconda3/lib/python3.

In [8]:
start = np.datetime64("2024-01-01T00:00:00.000000000")
end = np.datetime64("2024-01-02T18:00:00.000000000")

ds_tp = ds_tp.sel(time=slice(start, end))
ds_tisr = ds_tisr.sel(time=slice(start, end))

ds_atmo = ds_atmo.sel(time=slice(start, end))

ds_surface = ds_surface.sel(time=slice(start, end))
ds_surface = ds_surface.sel(time=ds_atmo.time)

In [9]:
from utils import quick_extend_times

target_times = np.array([
    '2024-01-01T00:00:00.000000000', 
    '2024-01-01T06:00:00.000000000',
    '2024-01-01T12:00:00.000000000', 
    '2024-01-01T18:00:00.000000000',
    '2024-01-02T00:00:00.000000000', 
    '2024-01-02T06:00:00.000000000',
    '2024-01-02T12:00:00.000000000', 
    '2024-01-02T18:00:00.000000000'
], dtype='datetime64[ns]')

# Extend your existing dataset
extended_ds = quick_extend_times(ds_tp, target_times)
extended_ds = extended_ds.drop_vars('time').rename({'new_time': 'time'})

extended_ds_toa = quick_extend_times(ds_tisr, target_times)
extended_ds_toa = extended_ds_toa.drop_vars('time').rename({'new_time': 'time'})

Current times: ['2024-01-01T06:00:00.000000000' '2024-01-01T18:00:00.000000000'
 '2024-01-02T06:00:00.000000000']
Target times: ['2024-01-01T00:00:00.000000000' '2024-01-01T06:00:00.000000000'
 '2024-01-01T12:00:00.000000000' '2024-01-01T18:00:00.000000000'
 '2024-01-02T00:00:00.000000000' '2024-01-02T06:00:00.000000000'
 '2024-01-02T12:00:00.000000000' '2024-01-02T18:00:00.000000000']
For 2024-01-01 00:00:00, using forecast 2024-01-01 06:00:00 (step 0)
For 2024-01-01 06:00:00, using forecast 2024-01-01 06:00:00 (step 0)
For 2024-01-01 12:00:00, using forecast 2024-01-01 06:00:00 (step 0)
For 2024-01-01 18:00:00, using forecast 2024-01-01 18:00:00 (step 0)
For 2024-01-02 00:00:00, using forecast 2024-01-01 18:00:00 (step 0)
For 2024-01-02 06:00:00, using forecast 2024-01-02 06:00:00 (step 0)
For 2024-01-02 12:00:00, using forecast 2024-01-02 06:00:00 (step 0)
For 2024-01-02 18:00:00, using forecast 2024-01-02 06:00:00 (step 0)
Current times: ['2024-01-01T06:00:00.000000000' '2024-01-01

In [10]:
from utils import clean_and_restructure_for_graphcast

ds_atmo_clean = clean_and_restructure_for_graphcast(ds_atmo)
ds_precip_clean = clean_and_restructure_for_graphcast(extended_ds)
ds_toa_clean = clean_and_restructure_for_graphcast(extended_ds_toa)
ds_surface_clean = clean_and_restructure_for_graphcast(ds_surface)

=== CLEANING AND RESTRUCTURING FOR GRAPHCAST ===
Original coordinates: ['time', 'level', 'latitude', 'longitude']
Original time values: ['2024-01-01T00:00:00.000000000' '2024-01-01T06:00:00.000000000'
 '2024-01-01T12:00:00.000000000' '2024-01-01T18:00:00.000000000'
 '2024-01-02T00:00:00.000000000' '2024-01-02T06:00:00.000000000'
 '2024-01-02T12:00:00.000000000' '2024-01-02T18:00:00.000000000']
Created time deltas: [              0  21600000000000  43200000000000  64800000000000
  86400000000000 108000000000000 129600000000000 151200000000000]
New coordinates: ['datetime', 'level', 'latitude', 'longitude', 'time']
New time values: [              0  21600000000000  43200000000000  64800000000000
  86400000000000 108000000000000 129600000000000 151200000000000]
Datetime values: ['2024-01-01T00:00:00.000000000' '2024-01-01T06:00:00.000000000'
 '2024-01-01T12:00:00.000000000' '2024-01-01T18:00:00.000000000'
 '2024-01-02T00:00:00.000000000' '2024-01-02T06:00:00.000000000'
 '2024-01-02T12:00:

In [11]:
combined = xr.merge([ds_atmo_clean, ds_precip_clean, ds_surface_clean, ds_toa_clean])
combined = combined.expand_dims('batch')
combined = combined.rename({'latitude': 'lat'})
combined = combined.rename({'longitude': 'lon'})
combined = combined.swap_dims({'datetime': 'time'})
combined = combined.sortby('lat')
combined = combined.transpose('lon', 'lat', 'level', 'time', 'batch')
combined = combined.drop_vars(['valid_time'], errors='ignore')
combined['datetime'] = combined['datetime'].expand_dims('batch')

In [12]:
combined['land_sea_mask'] = example_batch['land_sea_mask']
combined['geopotential_at_surface'] = example_batch['geopotential_at_surface']

In [13]:
combined

In [14]:
# Variables that should be (batch, time, lat, lon)
surface_vars = ['2m_temperature', 'mean_sea_level_pressure', '10m_v_component_of_wind', 
                '10m_u_component_of_wind', 'total_precipitation_6hr', 'toa_incident_solar_radiation']

for var in surface_vars:
    if var in combined:
        combined[var] = combined[var].transpose('batch', 'time', 'lat', 'lon')

# Variables that should be (batch, time, level, lat, lon)  
level_vars = ['temperature', 'geopotential', 'u_component_of_wind', 'v_component_of_wind', 
              'vertical_velocity', 'specific_humidity']

for var in level_vars:
    if var in combined:
        combined[var] = combined[var].transpose('batch', 'time', 'level', 'lat', 'lon')

# Static variables (lat, lon) are already correct

### Flip latitudes if necessary

########################################################################################################################

In [15]:
rec = example_batch.copy()

In [16]:
example_batch = combined

In [17]:
# @title Choose data to plot

plot_example_variable = widgets.Dropdown(
    options=example_batch.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=example_batch.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=example_batch.dims["time"], value=example_batch.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])

  min=1, max=example_batch.dims["time"], value=example_batch.dims["time"],


VBox(children=(Dropdown(description='Variable', index=9, options=('geopotential', 'specific_humidity', 'temper…

In [18]:
# @title Plot example data

plot_size = 7

data = {
    " ": scale(select(example_batch, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),
              robust=plot_example_robust.value),
}
fig_title = plot_example_variable.value
if "level" in example_batch[plot_example_variable.value].coords:
  fig_title += f" at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value)

In [19]:
# @title Choose training and eval data to extract
train_steps = widgets.IntSlider(
    value=1, min=1, max=example_batch.sizes["time"]-2, description="Train steps")
eval_steps = widgets.IntSlider(
    value=example_batch.sizes["time"]-2, min=1, max=example_batch.sizes["time"]-2, description="Eval steps")

widgets.VBox([
    train_steps,
    eval_steps,
    widgets.Label(value="Run the next cell to extract the data. Rerunning this cell clears your selection.")
])

VBox(children=(IntSlider(value=1, description='Train steps', max=6, min=1), IntSlider(value=6, description='Ev…

In [20]:
# @title Extract training and eval data
train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{train_steps.value*6}h"),
    **dataclasses.asdict(task_config))

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{eval_steps.value*6}h"),
    **dataclasses.asdict(task_config))

print("All Examples:  ", example_batch.dims.mapping)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)

All Examples:   {'level': 13, 'lat': 721, 'lon': 1440, 'time': 8, 'batch': 1}
Train Inputs:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Train Targets:  {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440, 'level': 13}
Train Forcings: {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440}
Eval Inputs:    {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Eval Targets:   {'batch': 1, 'time': 6, 'lat': 721, 'lon': 1440, 'level': 13}
Eval Forcings:  {'batch': 1, 'time': 6, 'lat': 721, 'lon': 1440}


array([[1640995200, 1641016800, 1641038400]])
<xarray.DataArray 'datetime' (batch: 1, time: 3)> Size: 24B
array([['2022-01-01T00:00:00.000000000', '2022-01-01T06:00:00.000000000',
        '2022-01-01T12:00:00.000000000']], dtype='datetime64[ns]')
Coordinates:
  * time      (time) timedelta64[ns] 24B 00:00:00 06:00:00 12:00:00
    datetime  (batch, time) datetime64[ns] 24B 2022-01-01 ... 2022-01-01T12:0...
Dimensions without coordinates: batch

In [21]:
# Define local paths
stats_dir = "stats"

diffs_stddev_by_level = xr.load_dataset(os.path.join(stats_dir, "diffs_stddev_by_level.nc")).compute()
mean_by_level         = xr.load_dataset(os.path.join(stats_dir, "mean_by_level.nc")).compute()
stddev_by_level       = xr.load_dataset(os.path.join(stats_dir, "stddev_by_level.nc")).compute()

In [22]:
# @title Build jitted functions, and possibly initialize random weights

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)
  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is requiredy by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets_template=train_targets,
      forcings=train_forcings)

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

# Run the model

Note that the cell below may take a while (possibly minutes) to run the first time you execute them, because this will include the time it takes for the code to compile. The second time running will be significantly faster.

This use the python loop to iterate over prediction steps, where the 1-step prediction is jitted. This has lower memory requirements than the training steps below, and should enable making prediction with the small GraphCast model on 1 deg resolution data for 4 steps.

In [23]:
local_ckpt_path = file_paths_1[1]

# Open and load the checkpoint locally
with open(local_ckpt_path, "rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

params = ckpt.params
model_config = ckpt.model_config
task_config = ckpt.task_config
state = {}

In [24]:
# @title Autoregressive rollout (loop in python)

assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

print("Inputs:  ", eval_inputs.dims.mapping)  
print("Targets: ", eval_targets.dims.mapping) 
print("Forcings:", eval_forcings.dims.mapping)

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings
    )
predictions

Inputs:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Targets:  {'batch': 1, 'time': 6, 'lat': 721, 'lon': 1440, 'level': 13}
Forcings: {'batch': 1, 'time': 6, 'lat': 721, 'lon': 1440}


  num_target_steps = targets_template.dims["time"]
  scan_length = targets_template.dims['time']


KeyboardInterrupt: 

In [None]:
# @title Choose predictions to plot
plot_pred_variable = widgets.Dropdown(
    options=predictions.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_pred_level = widgets.Dropdown(
    options=predictions.coords["level"].values,
    value=500,
    description="Level")
plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
plot_pred_max_steps = widgets.IntSlider(
    min=1,
    max=predictions.dims["time"],
    value=predictions.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_pred_variable,
    plot_pred_level,
    plot_pred_robust,
    plot_pred_max_steps,
    widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
])

NameError: name 'predictions' is not defined

In [None]:
1/0

ZeroDivisionError: division by zero

In [None]:
# @title Plot predictions
plot_size = 5
plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in predictions[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)


  plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps.value)


In [None]:
import haiku as hk
from graphcast import graphcast
import jax

# Create a function to inspect GraphCast's actual internal structure
@hk.transform_with_state
def inspect_graphcast_structure(model_config, task_config, inputs, targets_template, forcings):
    # Instantiate the core GraphCast model
    model = graphcast.GraphCast(model_config, task_config)
    
    # GraphCast has three main GNN components:
    print("GraphCast components:")
    print(f"1. Grid2Mesh GNN: {model._grid2mesh_gnn}")
    print(f"2. Mesh GNN: {model._mesh_gnn}")  
    print(f"3. Mesh2Grid GNN: {model._mesh2grid_gnn}")
    
    # Initialize the model (this creates the graph structures)
    model._maybe_init(inputs)
    
    # The actual flow is:
    # 1. Convert inputs to grid node features
    grid_node_features = model._inputs_to_grid_node_features(inputs, forcings)
    print(f"\nGrid node features shape: {grid_node_features.shape}")
    
    # 2. Grid2Mesh: Transfer data from grid to mesh
    latent_mesh_nodes, latent_grid_nodes = model._run_grid2mesh_gnn(grid_node_features)
    print(f"Latent mesh nodes shape: {latent_mesh_nodes.shape}")
    print(f"Latent grid nodes shape: {latent_grid_nodes.shape}")
    
    # 3. Mesh GNN: Process on the mesh
    updated_latent_mesh_nodes = model._run_mesh_gnn(latent_mesh_nodes)
    print(f"Updated mesh nodes shape: {updated_latent_mesh_nodes.shape}")
    
    # 4. Mesh2Grid: Transfer back to grid
    output_grid_nodes = model._run_mesh2grid_gnn(updated_latent_mesh_nodes, latent_grid_nodes)
    print(f"Output grid nodes shape: {output_grid_nodes.shape}")
    
    # 5. Convert back to xarray format
    outputs = model._grid_node_outputs_to_prediction(output_grid_nodes, targets_template)
    
    return outputs

# To use it with your existing code:
# Initialize and run
params, state = inspect_graphcast_structure.init(
    jax.random.PRNGKey(0),
    model_config,
    task_config, 
    train_inputs,
    train_targets,
    train_forcings
)

# This will print the model structure and component shapes
outputs, state = inspect_graphcast_structure.apply(
    params,
    state,
    jax.random.PRNGKey(0),
    model_config,
    task_config,
    train_inputs,
    train_targets,
    train_forcings
)

# To access just the GNN components without running the full model:
@hk.transform
def get_gnn_components(model_config, task_config):
    model = graphcast.GraphCast(model_config, task_config)
    return {
        'grid2mesh_gnn': model._grid2mesh_gnn,
        'mesh_gnn': model._mesh_gnn,
        'mesh2grid_gnn': model._mesh2grid_gnn
    }

GraphCast components:
1. Grid2Mesh GNN: DeepTypedGraphNet(
    embed_nodes=True,
    embed_edges=True,
    edge_latent_size={'grid2mesh': 512},
    node_latent_size={'grid_nodes': 512, 'mesh_nodes': 512},
    mlp_hidden_size=512,
    mlp_num_hidden_layers=1,
    num_message_passing_steps=1,
    use_layer_norm=True,
    include_sent_messages_in_node_update=False,
    activation='swish',
    f32_aggregation=True,
    aggregate_normalization=None,
    name='grid2mesh_gnn',
)
2. Mesh GNN: DeepTypedGraphNet(
    embed_nodes=False,
    embed_edges=True,
    node_latent_size={'mesh_nodes': 512},
    edge_latent_size={'mesh': 512},
    mlp_hidden_size=512,
    mlp_num_hidden_layers=1,
    num_message_passing_steps=16,
    use_layer_norm=True,
    include_sent_messages_in_node_update=False,
    activation='swish',
    f32_aggregation=False,
    name='mesh_gnn',
)
3. Mesh2Grid GNN: DeepTypedGraphNet(
    node_output_size={'grid_nodes': 83},
    embed_nodes=False,
    embed_edges=True,
    edge_l

KeyboardInterrupt: 

In [None]:
import haiku as hk
import jax
import jax.numpy as jnp
from graphcast import graphcast
from graphcast import xarray_jax
import numpy as np
import xarray

@hk.transform_with_state
def step_by_step_inference(model_config, task_config, inputs, targets_template, forcings):
    """Run GraphCast inference step by step through each component."""

    # 1. Create the model
    model = graphcast.GraphCast(model_config, task_config)

    # 2. Initialize the model (creates graph structures)
    model._maybe_init(inputs)

    print("=== Step 1: Convert inputs to grid node features ===")
    # This stacks all variables and reshapes to [num_grid_nodes, batch, channels]
    grid_node_features = model._inputs_to_grid_node_features(inputs, forcings)
    print(f"Grid node features shape: {grid_node_features.shape}")
    print(f"  - num_grid_nodes: {grid_node_features.shape[0]}")
    print(f"  - batch_size: {grid_node_features.shape[1]}")
    print(f"  - num_channels: {grid_node_features.shape[2]}")

    print("\n=== Step 2: Grid2Mesh GNN ===")
    # Transfer data from grid to mesh nodes
    latent_mesh_nodes, latent_grid_nodes = model._run_grid2mesh_gnn(grid_node_features)
    print(f"Latent mesh nodes shape: {latent_mesh_nodes.shape}")
    print(f"Latent grid nodes shape: {latent_grid_nodes.shape}")
    print(f"  - Mesh has {model._num_mesh_nodes} nodes")
    print(f"  - Grid has {model._num_grid_nodes} nodes")
    print(f"  - Latent size: {model_config.latent_size}")

    print("\n=== Step 3: Mesh GNN ===")
    # Process on the multi-scale mesh
    updated_latent_mesh_nodes = model._run_mesh_gnn(latent_mesh_nodes)
    print(f"Updated mesh nodes shape: {updated_latent_mesh_nodes.shape}")
    print(f"  - {model_config.gnn_msg_steps} message passing steps")

    print("\n=== Step 4: Mesh2Grid GNN ===")
    # Transfer back from mesh to grid
    output_grid_nodes = model._run_mesh2grid_gnn(
        updated_latent_mesh_nodes, latent_grid_nodes)
    print(f"Output grid nodes shape: {output_grid_nodes.shape}")
    print(f"  - Output channels: {output_grid_nodes.shape[2]}")
    
    print("\n=== Step 5: Convert outputs back to xarray ===")
    # Reshape back to lat/lon grid format
    predictions = model._grid_node_outputs_to_prediction(
        output_grid_nodes, targets_template)
    print(f"Predictions shape: {predictions.sizes}")

    return predictions

# Create a wrapper for autoregressive rollout that shows each step
def manual_autoregressive_rollout(
    run_forward_jitted=run_forward_jitted,  # Use the pre-compiled function from notebook
    eval_inputs=eval_inputs,
    eval_targets=eval_targets,
    eval_forcings=eval_forcings,
    num_steps=None):
    """Manually perform autoregressive rollout to see each step."""

    if num_steps is None:
        num_steps = eval_targets.dims['time']

    # Initialize with the input
    current_inputs = eval_inputs
    predictions_list = []

    print(f"Starting autoregressive rollout for {num_steps} steps\n")

    for step in range(num_steps):
        print(f"Step {step + 1}/{num_steps}")

        # Get forcings for this timestep
        step_forcings = eval_forcings.isel(time=slice(step, step+1))

        # Create targets template for this step
        step_targets_template = eval_targets.isel(time=slice(step, step+1)) * np.nan

        # Run one forward pass using the pre-compiled function
        step_predictions = run_forward_jitted(
            rng=jax.random.PRNGKey(0),
            inputs=current_inputs, 
            targets_template=step_targets_template, 
            forcings=step_forcings
        )

        predictions_list.append(step_predictions)
        
        # For autoregressive rollout, use predictions as next input
        # (keeping the first time from original input for context)
        if step < num_steps - 1:
            # Concatenate original first frame with new prediction
            current_inputs = xarray.concat([
                eval_inputs.isel(time=slice(0, 1)),
                step_predictions
            ], dim='time')
        
        print(f"  Prediction shape: {step_predictions.sizes}")
    
    # Concatenate all predictions
    all_predictions = xarray.concat(predictions_list, dim='time')
    print(f"\nFinal predictions shape: {all_predictions.sizes}")
    
    return all_predictions


# Example usage with your notebook:
# """
# IMPORTANT: Use the pre-trained params from your checkpoint, not new ones!
# The error occurs because you're trying to initialize new parameters
# instead of using the loaded checkpoint parameters.

# First, let's see what happens in a single forward pass using loaded params:
single_prediction, _ = step_by_step_inference.apply(
    params,  # Use the loaded params from checkpoint
    state,   # Use the loaded state
    model_config,
    task_config,
    eval_inputs,
    eval_targets.isel(time=slice(0, 1)) * np.nan,
    eval_forcings.isel(time=slice(0, 1)),
    rng=jax.random.PRNGKey(0),
)

# # Now run the full autoregressive rollout manually:
# predictions_manual = manual_autoregressive_rollout(
#     # params, # Use loaded params
#     # state , # Use loaded state
#     # model_config,
#     # task_config,
#     eval_inputs=eval_inputs, 
#     eval_targets=eval_targets, 
#     eval_forcings=eval_forcings,
#     num_steps=3, # Just do 3 steps for demonstration
#     run_forward_jitted=run_forward_jitted,  # Use the pre-compiled function from notebook
# )

# # Alternative: If you want to see the input feature count:
# print(f"Model expects {params['grid2mesh_gnn']['~_networks_builder']['encoder_nodes_grid_nodes_mlp']['~']['linear_0']['w'].shape[0]} input features")

# # Check your actual input features:
# grid_features = model_utils.dataset_to_stacked(eval_inputs)
# forcings_features = model_utils.dataset_to_stacked(eval_forcings.isel(time=0))
# total_features = grid_features.sizes['channels'] + forcings_features.sizes['channels']
# print(f"You're providing {total_features} input features")
# # """

TypeError: transform_with_state.<locals>.apply_fn() got multiple values for argument 'rng'

In [None]:
# 1. Create the model
model = graphcast.GraphCast(model_config, task_config)

# 2. Initialize the model (creates graph structures)
model._maybe_init(inputs)

ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

In [None]:
import haiku as hk
import jax
import jax.numpy as jnp
from graphcast import graphcast
from graphcast import xarray_jax
import numpy as np
import xarray

@hk.transform_with_state
def step_by_step_inference(model_config, task_config, inputs, targets_template, forcings):
    """Run GraphCast inference step by step through each component."""

    # 1. Create the model
    model = graphcast.GraphCast(model_config, task_config)
    # import pdb; pdb.set_trace()
    # 2. Initialize the model (creates graph structures)
    model._maybe_init(inputs)
    
    print("=== Step 1: Convert inputs to grid node features ===")
    # This stacks all variables and reshapes to [num_grid_nodes, batch, channels]
    grid_node_features = model._inputs_to_grid_node_features(inputs, forcings)
    print(f"Grid node features shape: {grid_node_features.shape}")
    print(f"  - num_grid_nodes: {grid_node_features.shape[0]}")
    print(f"  - batch_size: {grid_node_features.shape[1]}")
    print(f"  - num_channels: {grid_node_features.shape[2]}")
    
    print("\n=== Step 2: Grid2Mesh GNN ===")
    # Transfer data from grid to mesh nodes
    latent_mesh_nodes, latent_grid_nodes = model._run_grid2mesh_gnn(grid_node_features)
    print(f"Latent mesh nodes shape: {latent_mesh_nodes.shape}")
    print(f"Latent grid nodes shape: {latent_grid_nodes.shape}")
    print(f"  - Mesh has {model._num_mesh_nodes} nodes")
    print(f"  - Grid has {model._num_grid_nodes} nodes")
    print(f"  - Latent size: {model_config.latent_size}")
    
    print("\n=== Step 3: Mesh GNN ===")
    # Process on the multi-scale mesh
    updated_latent_mesh_nodes = model._run_mesh_gnn(latent_mesh_nodes)
    print(f"Updated mesh nodes shape: {updated_latent_mesh_nodes.shape}")
    print(f"  - {model_config.gnn_msg_steps} message passing steps")
    
    print("\n=== Step 4: Mesh2Grid GNN ===")
    # Transfer back from mesh to grid
    output_grid_nodes = model._run_mesh2grid_gnn(
        updated_latent_mesh_nodes, latent_grid_nodes)
    print(f"Output grid nodes shape: {output_grid_nodes.shape}")
    print(f"  - Output channels: {output_grid_nodes.shape[2]}")
    
    print("\n=== Step 5: Convert outputs back to xarray ===")
    # Reshape back to lat/lon grid format
    predictions = model._grid_node_outputs_to_prediction(
        output_grid_nodes, targets_template)
    print(f"Predictions shape: {predictions.sizes}")
    
    return predictions


# Create a wrapper for autoregressive rollout that shows each step
def manual_autoregressive_rollout(
    run_forward_jitted,  # Use the pre-compiled function from notebook
    eval_inputs, eval_targets, eval_forcings, num_steps=None):
    """Manually perform autoregressive rollout to see each step."""
    
    if num_steps is None:
        num_steps = eval_targets.dims['time']
    
    # Initialize with the input
    current_inputs = eval_inputs
    predictions_list = []
    
    print(f"Starting autoregressive rollout for {num_steps} steps\n")
    
    for step in range(num_steps):
        import pdb; pdb.set_trace()
        print(f"Step {step + 1}/{num_steps}")
        
        # Get forcings for this timestep
        step_forcings = eval_forcings.isel(time=slice(step, step+1))
        
        # Create targets template for this step
        step_targets_template = eval_targets.isel(time=slice(step, step+1)) * np.nan
        
        # Run one forward pass using the pre-compiled function
        step_predictions = run_forward_jitted(
            rng=jax.random.PRNGKey(0),  # Add the required rng parameter
            inputs=current_inputs, 
            targets_template=step_targets_template, 
            forcings=step_forcings
        )
        
        predictions_list.append(step_predictions)
        
        # For autoregressive rollout, use predictions as next input
        # (keeping the first time from original input for context)
        if step < num_steps - 1:
            # Get static variables from original inputs (they don't change)
            static_vars = ["geopotential_at_surface", "land_sea_mask"]
            
            # Create next input by combining:
            # 1. Original first frame for context
            # 2. Current predictions for target variables
            # 3. Static variables from original inputs
            
            next_inputs = eval_inputs.isel(time=slice(0, 1)).copy()
            
            # Update with predictions for target variables only
            for var in step_predictions.data_vars:
                if var in next_inputs:
                    next_inputs[var] = step_predictions[var]
            
            # Concatenate context frame with updated frame
            current_inputs = xarray.concat([
                eval_inputs.isel(time=slice(0, 1)),
                next_inputs
            ], dim='time')
        
        print(f"  Prediction shape: {step_predictions.sizes}")
    
    # Concatenate all predictions
    all_predictions = xarray.concat(predictions_list, dim='time')
    print(f"\nFinal predictions shape: {all_predictions.sizes}")
    
    return all_predictions

# First, let's see what happens in a single forward pass using loaded params:
single_prediction, _ = step_by_step_inference.apply(
    params,  # Use the loaded params from checkpoint
    state,   # Use the loaded state
    jax.random.PRNGKey(0),
    model_config,
    task_config,
    eval_inputs,
    targets_template=eval_targets * np.nan,  #.isel(time=slice(0,1)) * np.nan Just 1 step
    forcings=eval_forcings # .isel(time=slice(0,1))
)

# # Now run the full autoregressive rollout manually:
# predictions_manual = manual_autoregressive_rollout(
#     run_forward_jitted = run_forward_jitted,
#     eval_inputs=eval_inputs,
#     eval_targets=eval_targets,
#     eval_forcings=eval_forcings,
#     num_steps=3,
# )

# # Alternative: If you want to see the input feature count:
# print(f"Model expects {params['grid2mesh_gnn']['~_networks_builder']['encoder_nodes_grid_nodes_mlp']['~']['linear_0']['w'].shape[0]} input features")

# # Check your actual input features:
# grid_features = model_utils.dataset_to_stacked(eval_inputs)
# forcings_features = model_utils.dataset_to_stacked(eval_forcings.isel(time=0))
# total_features = grid_features.sizes['channels'] + forcings_features.sizes['channels']
# print(f"You're providing {total_features} input features")

Split level 0: 38 vertices, 66 faces
Split level 1: 158 vertices, 264 faces
=== Step 1: Convert inputs to grid node features ===
Grid node features shape: (65160, 1, 183)
  - num_grid_nodes: 65160
  - batch_size: 1
  - num_channels: 183

=== Step 2: Grid2Mesh GNN ===


In [None]:
# @title Plot predictions
plot_size = 5
plot_max_steps = min(single_prediction.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(single_prediction, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(single_prediction, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in single_prediction[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)

In [None]:
# @title Plot predictions
plot_size = 5
plot_max_steps = min(single_prediction.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(single_prediction, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(single_prediction, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in single_prediction[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)

  plot_max_steps = min(single_prediction.dims["time"], plot_pred_max_steps.value)


# Train the model

Finetuning GraphCast to get Extended GraphCast works as follows:
* 

In [None]:
# @title Loss computation (autoregressive loss over multiple steps)
loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
print("Loss:", float(loss))

In [None]:
# @title Gradient computation (backprop through time)
loss, diagnostics, next_state, grads = grads_fn_jitted(
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

In [None]:
# @title Autoregressive rollout (keep the loop in JAX)
print("Inputs:  ", train_inputs.dims.mapping)
print("Targets: ", train_targets.dims.mapping)
print("Forcings:", train_forcings.dims.mapping)

predictions =   (
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)
predictions