# Master's thesis - Lukas Meuris - graphCast training - loss functions

This notebook contains the code to train the graphCast model. 
We use this notebook to traint he graphcast model with different loss-functions to see the influence of it on the model performance.

 

In [2]:
# @title Pip install graphcast and dependencies

!pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip


Collecting https://github.com/deepmind/graphcast/archive/master.zip
  Using cached https://github.com/deepmind/graphcast/archive/master.zip
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting cartopy (from graphcast==0.1)
  Downloading Cartopy-0.23.0-cp312-cp312-win_amd64.whl.metadata (8.2 kB)
Collecting chex (from graphcast==0.1)
  Using cached chex-0.1.86-py3-none-any.whl.metadata (17 kB)
Collecting colabtools (from graphcast==0.1)
  Using cached colabtools-0.0.1-py3-none-any.whl.metadata (511 bytes)
Collecting dask (from graphcast==0.1)
  Using cached dask-2024.4.2-py3-none-any.whl.metadata (3.8 kB)
Collecting dm-haiku (from graphcast==0.1)
  Using cached dm_haiku-0.0.12-py3-none-any.whl.metadata (19 kB)
Collecting dm-tree (from graphcast==0.1)
  Downloading dm_tree-0.1.8-cp312-cp312-win_amd64.whl.metadata (2.0 kB)
Collecting jax (from graphcast==0.1)
  Using cached jax-0.4.26-py3-none-any.whl.metadata (23 kB)
Collecting j

# Installation and initialisation.


In [3]:
# @title Imports
import sys
sys.path.append("../")

import dataclasses
import functools

from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import data_utils
from graphcast import graphcast
from graphcast import rollout
from graphcast import normalization
from graphcast import xarray_jax
from graphcast import xarray_tree
import haiku as hk
import jax
import numpy as np
import xarray


def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))


# Access data from GCS

In [4]:
# @title Authenticate with Google Cloud Storage

gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

# Load the Data and initialize the model

## Load the model parameters

We use random parameters for the model initialization. 
We'll get random predictions, but we can change the model architecture.


*Checkpoints vary across a few axes:*
- *The mesh size specifies the internal graph representation of the earth. Smaller meshes will run faster but will have worse outputs. The mesh size does not affect the number of parameters of the model.*
- *The resolution and number of pressure levels must match the data. Lower resolution and fewer levels will run a bit faster. Data resolution only affects the encoder/decoder.*
- *All our models predict precipitation. However, ERA5 includes precipitation, while HRES does not. Our models marked as "ERA5" take precipitation as input and expect ERA5 data as input, while model marked "ERA5-HRES" do not take precipitation as input and are specifically trained to take HRES-fc0 as input (see the data section below).*


In [13]:
# @title choose model parameters
random_mesh_size = 5 # mesh size: 4 - 6
random_gnn_msg_steps = 16 # message passing steps: 1 - 32
random_latent_size = 512 # latent size: 16,32,64,128,256,512
random_levels = 13 # levels: 13 or 37


In [14]:
# @title load the model
params = None  # Filled in below
state = {}
model_config = graphcast.ModelConfig(
    resolution=0,
    mesh_size=random_mesh_size,
    latent_size=random_latent_size,
    gnn_msg_steps=random_gnn_msg_steps,
    hidden_layers=1,
    radius_query_fraction_edge_length=0.6)
task_config = graphcast.TaskConfig(
    input_variables=graphcast.TASK.input_variables,
    target_variables=graphcast.TASK.target_variables,
    forcing_variables=graphcast.TASK.forcing_variables,
    pressure_levels=graphcast.PRESSURE_LEVELS[random_levels],
    input_duration=graphcast.TASK.input_duration,
)
model_config

ModelConfig(resolution=0, mesh_size=5, latent_size=512, gnn_msg_steps=16, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)

## Load the ERA5 data


In [15]:
# @title Get and filter the list of available example datasets, then select one.
#todo: replace this dataset with the ERA5 dataset from weatherbench2 (.zarr file)

with open('ERA5_data/1959-2022-6h-64x32_equiangular_conservative.zarr', 'rb') as f:
    obs_data = xarray.load_dataset(f).compute()

obs_data

FileNotFoundError: [Errno 2] No such file or directory: 'ERA5_data/1959-2022-6h-64x32_equiangular_conservative.zarr'

In [16]:
# @title Choose training and eval data to extract
train_steps = 12  # {1 - example_batch.sizes["time"]-2} | 12 = 3days
eval_steps = 40 # {1 - example_batch.sizes["time"]-2} | 40 = 10days


In [17]:
# @extract the training data (time and variables)
#time slice: 1980-01-01T00:00:00.000000000 to 2019-12-31T06:00:00.000000000 - TRAINING
train_data = obs_data.sel(time=slice('1980-01-01T00:00:00.000000000','2019-12-31T06:00:00.000000000'))
# remove unneeded vars
train_data

NameError: name 'obs_data' is not defined

In [18]:
# @extract the evaluation data (time and variables)

#time slice: 2020-01-01T00:00:00.000000000 to 2022-01-01T00:00:00.000000000 - EVALUATION
eval_data = obs_data.sel(time=slice('2020-01-01T00:00:00.000000000','2021-01-01T00:00:00.000000000'))
# remove unneeded vars
eval_data

NameError: name 'obs_data' is not defined

In [19]:
# @title Extract training and eval data

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    obs_data, target_lead_times=slice("6h", f"{train_steps*6}h"),
    **dataclasses.asdict(task_config))

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    obs_data, target_lead_times=slice("6h", f"{eval_steps*6}h"),
    **dataclasses.asdict(task_config))

print("All Examples:  ", obs_data.dims.mapping)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)

NameError: name 'obs_data' is not defined

In [20]:
# @title Load normalization data

with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
    mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
    stddev_by_level = xarray.load_dataset(f).compute()


In [21]:
# @title Build jitted functions, and possibly initialize random weights

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is required by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets_template=train_targets,
      forcings=train_forcings)

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

NameError: name 'train_inputs' is not defined

In [22]:
# @title config check

assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")


NameError: name 'eval_inputs' is not defined

# Train the model

The following operations require a large amount of memory and, depending on the accelerator being used, will only fit the very small "random" model on low resolution data. It uses the number of training steps selected above.

The first time executing the cell takes more time, as it include the time to jit the function.

In [44]:
# @title Loss computation (autoregressive loss over multiple steps)
loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
print("Loss:", float(loss))

  scan_length = targets.dims['time']


Loss: 224.1053466796875


# @title Gradient computation (backprop through time)

loss, diagnostics, next_state, grads = grads_fn_jitted(
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")


# Run the model

In [46]:
# @title Autoregressive rollout (loop in python)

assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
predictions

Inputs:   {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Targets:  {'batch': 1, 'time': 12, 'lat': 181, 'lon': 360, 'level': 13}
Forcings: {'batch': 1, 'time': 12, 'lat': 181, 'lon': 360}


# Save the evaluation run to file


In [ ]:
# Specify the path where you want to save the Zarr file
zarr_path = "Evaluation_runs/mse_64x32_2020.zarr"

# Save the dataset to the Zarr file
predictions.to_zarr(zarr_path)

# Save the model to file

In [47]:
# @title saving the model
# Save the model
np.savez("models/model_64x32_mse.npz", **hk.data_structures.to_mutable_dict(params))

Now, our trained model is saved to a file ,which can be used to load and run again.