# Master's thesis - Lukas Meuris - graphCast training - loss functions

This notebook contains the code to train the graphCast model. 
We use this notebook to traint he graphcast model with different loss-functions to see the influence of it on the model performance.

 

In [1]:
# @title Pip install graphcast and dependencies

#!pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip


# Installation and initialisation.


In [2]:
# @title Imports
import sys
sys.path.append("../")

import dataclasses
import functools

from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import data_utils
from graphcast import graphcast
from graphcast import rollout
from graphcast import normalization
from graphcast import xarray_jax
from graphcast import xarray_tree
import haiku as hk
import jax
import numpy as np
import xarray

import optax

import os
import time
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Access data from GCS

In [3]:
# @title Authenticate with Google Cloud Storage

gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

# Load the Data and initialize the model

## Load the model parameters

We use random parameters for the model initialization. 
We'll get random predictions, but we can change the model architecture.


*Checkpoints vary across a few axes:*
- *The mesh size specifies the internal graph representation of the earth. Smaller meshes will run faster but will have worse outputs. The mesh size does not affect the number of parameters of the model.*
- *The resolution and number of pressure levels must match the data. Lower resolution and fewer levels will run a bit faster. Data resolution only affects the encoder/decoder.*
- *All our models predict precipitation. However, ERA5 includes precipitation, while HRES does not. Our models marked as "ERA5" take precipitation as input and expect ERA5 data as input, while model marked "ERA5-HRES" do not take precipitation as input and are specifically trained to take HRES-fc0 as input (see the data section below).*


In [4]:
# choose model parameters
random_mesh_size = 5 # mesh size: 4 - 6
random_gnn_msg_steps = 8 # message passing steps: 1 - 32
random_latent_size = 128 # latent size: 16,32,64,128,256,512
random_levels = 13 # levels: 13 or 37


In [5]:
# load the model parameters
params = None
state = {}
model_config = graphcast.ModelConfig(
    resolution=0,
    mesh_size=random_mesh_size,
    latent_size=random_latent_size,
    gnn_msg_steps=random_gnn_msg_steps,
    hidden_layers=1,
    radius_query_fraction_edge_length=0.6)
task_config = graphcast.TaskConfig(
    input_variables=graphcast.TASK.input_variables,
    target_variables=graphcast.TASK.target_variables,
    forcing_variables=graphcast.TASK.forcing_variables,
    pressure_levels=graphcast.PRESSURE_LEVELS[random_levels],
    input_duration=graphcast.TASK.input_duration,
)
model_config

ModelConfig(resolution=0, mesh_size=5, latent_size=128, gnn_msg_steps=8, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)

# Load the ERA5 data


In [6]:
# Define the relative path to the file
relative_path = "ERA5_data/obs_data.zarr"

# Get the absolute path by joining the current directory with the relative path
absolute_path = os.path.join(os.path.dirname(os.getcwd()), relative_path)
print(absolute_path)

# Open the Zarr file using xarray
obs_data = xarray.open_zarr(absolute_path)

/home/jupyter-lukas/Masters-Thesis/ERA5_data/obs_data.zarr


In [7]:
# view obs_data:
obs_data

Unnamed: 0,Array,Chunk
Bytes,719.09 kiB,179.77 kiB
Shape,"(1, 92044)","(1, 23011)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,datetime64[ns] numpy.ndarray,datetime64[ns] numpy.ndarray
"Array Chunk Bytes 719.09 kiB 179.77 kiB Shape (1, 92044) (1, 23011) Dask graph 4 chunks in 2 graph layers Data type datetime64[ns] numpy.ndarray",92044  1,

Unnamed: 0,Array,Chunk
Bytes,719.09 kiB,179.77 kiB
Shape,"(1, 92044)","(1, 23011)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,datetime64[ns] numpy.ndarray,datetime64[ns] numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 719.09 MiB 800.00 kiB Shape (1, 92044, 64, 32) (1, 100, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  32  64  92044,

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 719.09 MiB 800.00 kiB Shape (1, 92044, 64, 32) (1, 100, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  32  64  92044,

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 719.09 MiB 800.00 kiB Shape (1, 92044, 64, 32) (1, 100, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  32  64  92044,

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 9.13 GiB 10.16 MiB Shape (1, 92044, 13, 64, 32) (1, 100, 13, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",92044  1  32  64  13,

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.00 kiB,8.00 kiB
Shape,"(1, 64, 32)","(1, 64, 32)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 8.00 kiB 8.00 kiB Shape (1, 64, 32) (1, 64, 32) Dask graph 1 chunks in 2 graph layers Data type float32 numpy.ndarray",32  64  1,

Unnamed: 0,Array,Chunk
Bytes,8.00 kiB,8.00 kiB
Shape,"(1, 64, 32)","(1, 64, 32)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.00 kiB,8.00 kiB
Shape,"(1, 64, 32)","(1, 64, 32)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 8.00 kiB 8.00 kiB Shape (1, 64, 32) (1, 64, 32) Dask graph 1 chunks in 2 graph layers Data type float32 numpy.ndarray",32  64  1,

Unnamed: 0,Array,Chunk
Bytes,8.00 kiB,8.00 kiB
Shape,"(1, 64, 32)","(1, 64, 32)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 719.09 MiB 800.00 kiB Shape (1, 92044, 64, 32) (1, 100, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  32  64  92044,

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 9.13 GiB 10.16 MiB Shape (1, 92044, 13, 64, 32) (1, 100, 13, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",92044  1  32  64  13,

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 9.13 GiB 10.16 MiB Shape (1, 92044, 13, 64, 32) (1, 100, 13, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",92044  1  32  64  13,

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 719.09 MiB 800.00 kiB Shape (1, 92044, 64, 32) (1, 100, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  32  64  92044,

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 719.09 MiB 800.00 kiB Shape (1, 92044, 64, 32) (1, 100, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  32  64  92044,

Unnamed: 0,Array,Chunk
Bytes,719.09 MiB,800.00 kiB
Shape,"(1, 92044, 64, 32)","(1, 100, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 9.13 GiB 10.16 MiB Shape (1, 92044, 13, 64, 32) (1, 100, 13, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",92044  1  32  64  13,

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 9.13 GiB 10.16 MiB Shape (1, 92044, 13, 64, 32) (1, 100, 13, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",92044  1  32  64  13,

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 9.13 GiB 10.16 MiB Shape (1, 92044, 13, 64, 32) (1, 100, 13, 64, 32) Dask graph 921 chunks in 2 graph layers Data type float32 numpy.ndarray",92044  1  32  64  13,

Unnamed: 0,Array,Chunk
Bytes,9.13 GiB,10.16 MiB
Shape,"(1, 92044, 13, 64, 32)","(1, 100, 13, 64, 32)"
Dask graph,921 chunks in 2 graph layers,921 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


## extract the training data

In [8]:
#time slice: 1980-01-01T00:00:00.000000000 to 2019-12-31T00:00:00.000000000 - TRAINING
train_data = obs_data.sel(time=slice('1980-01-01T00:00:00.000000000','2019-12-31T00:00:00.000000000'))

## extract the evaluation data

In [9]:
#time slice: 2020-01-01T00:00:00.000000000 to 2022-01-01T00:00:00.000000000 - EVALUATION
eval_data = obs_data.sel(time=slice('2020-01-01T00:00:00.000000000','2021-01-01T00:00:00.000000000'))

## choose number of training and evaluation steps.

In [10]:
# @title Choose training and eval data to extract
train_steps_max = 12  # {1 - obs_data.sizes["time"]-2} | 12 = 3days
eval_steps = 40 # {1 - obs_data.sizes["time"]-2} | 40 = 10days

## extract training and eval inputs, targets and forcings.

In [11]:

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    train_data, target_lead_times=slice("6h", f"{train_steps_max*6}h"),
    **dataclasses.asdict(task_config))

## Load normalization data

In [12]:

with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
    mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
    stddev_by_level = xarray.load_dataset(f).compute()


## Build jitted functions, and possibly initialize random weights

In [13]:

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

def grads_fn(params, state, inputs, targets, forcings, model_config, task_config):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is required by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs.compute(),
      targets_template=train_targets.compute(),
      forcings=train_forcings.compute())

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = jax.jit(with_configs(grads_fn))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

# Model training loop

## 1. general training

In [14]:
# define number of training steps
train_steps = 1
# data size
N = train_data.sizes['time'] - train_steps - 4

loss_array = []

#setup optimiser
lr = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=N)
optimiser = optax.adamw(lr, b1=0.9, b2=0.95, weight_decay=0.1)
opt_state = optimiser.init(params)

In [15]:
# training loop

for i in range(N):
    train_batch = train_data.isel(time=slice(i, i + train_steps + 2))
    train_batch = train_batch.compute()

    train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    train_batch, target_lead_times=slice("6h", f"{train_steps*6}h"),
    **dataclasses.asdict(task_config))

    # calculate loss and gradients
    loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, train_inputs, train_targets, train_forcings)

    # update
    updates, opt_state = optimiser.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    loss_array.append(loss)
    if i%1000 == 0:
        print("I:", i , " - Loss:", loss)

print("general training finished.")

I: 0  - Loss: 8.272949
I: 1000  - Loss: 1.354065
I: 2000  - Loss: 1.2044067
I: 3000  - Loss: 1.3353882
I: 4000  - Loss: 1.4089661
I: 5000  - Loss: 1.1546631
I: 6000  - Loss: 1.1311035
I: 7000  - Loss: 1.2196045
I: 8000  - Loss: 1.1618042
I: 9000  - Loss: 1.1314697
I: 10000  - Loss: 1.1262817
I: 11000  - Loss: 1.0549316
I: 12000  - Loss: 1.175354
I: 13000  - Loss: 1.0533447
I: 14000  - Loss: 1.0359497
I: 15000  - Loss: 1.0219727
I: 16000  - Loss: 0.9800415
I: 17000  - Loss: 0.9372864
I: 18000  - Loss: 0.8530884
I: 19000  - Loss: 0.9413147
I: 20000  - Loss: 0.77182007
I: 21000  - Loss: 0.8383484
I: 22000  - Loss: 0.83303833
I: 23000  - Loss: 0.77038574
I: 24000  - Loss: 0.81570435
I: 25000  - Loss: 0.91464233
I: 26000  - Loss: 0.74453735
I: 27000  - Loss: 0.7640381
I: 28000  - Loss: 0.68673706
I: 29000  - Loss: 0.72280884
I: 30000  - Loss: 0.72576904
I: 31000  - Loss: 0.6869812
I: 32000  - Loss: 0.7284851
I: 33000  - Loss: 0.68400574
I: 34000  - Loss: 0.746521
I: 35000  - Loss: 0.7279968

## 2.Fine tuning

In [16]:
# data size
N = train_data.sizes['time'] - train_steps_max - 4
# only take the last 11000 time steps
Ksteps = 11000
Ktime = N - Ksteps
train_steps = 1
loss_array = []

#setup optimiser
lr = 1e-7
optimiser = optax.adamw(lr, b1=0.9, b2=0.95, weight_decay=0.1)
opt_state = optimiser.init(params)

In [17]:
# training loop

for i in range(Ksteps):
    
    train_batch = train_data.isel(time=slice(Ktime + i,Ktime +  i + train_steps + 2))
    train_batch = train_batch.compute()

    train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    train_batch, target_lead_times=slice("6h", f"{train_steps*6}h"),
    **dataclasses.asdict(task_config))

    # calculate loss and gradients
    loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, train_inputs, train_targets, train_forcings)

    # update
    updates, opt_state = optimiser.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    loss_array.append(loss)
    if i%1000 == 0:
        train_steps += 1
        print("I:", i ," - steps: ",train_steps," - Loss:", loss)
        

print("finetuning training finished.")

I: 0  - steps:  2  - Loss: 0.66973877
I: 1000  - steps:  3  - Loss: 1.078125
I: 2000  - steps:  4  - Loss: 1.3567709
I: 3000  - steps:  5  - Loss: 2.09375
I: 4000  - steps:  6  - Loss: 2.6304688
I: 5000  - steps:  7  - Loss: 3.280599
I: 6000  - steps:  8  - Loss: 2.9866073
I: 7000  - steps:  9  - Loss: 3.9746094
I: 8000  - steps:  10  - Loss: 4.2421875
I: 9000  - steps:  11  - Loss: 4.466016
I: 10000  - steps:  12  - Loss: 4.7979403


ValueError: 'grid2mesh_gnn/~_networks_builder/encoder_nodes_grid_nodes_mlp/~/linear_0/w' with retrieved shape (186, 128) does not match shape=[98, 128] dtype=dtype(bfloat16)

# Save the model params to file

In [19]:
import jax
import numpy as np
import jax.numpy as jnp
import os 

def flatten_dict(d, parent_key='', sep='//'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

def save_model_params(d, file_path):
    flat_dict = flatten_dict(d)
    # Convert JAX arrays to NumPy for saving
    np_dict = {k: np.array(v) if isinstance(v, jnp.ndarray) else v for k, v in flat_dict.items()}
    np.savez(file_path, **np_dict)

params_path = os.path.join('../models', 'params_64x32_mse_2.npz')
save_model_params(params, params_path)

Now, our trained model is saved to a file ,which can be used to load and run again.