# Master's thesis - Lukas Meuris - training - loss functions

This notebook contains the code to train the graphCast model. 
We use this notebook to traint the graphcast model with different loss-functions to see the influence of it on the model performance.

 

In [1]:
#Pip install graphcast and dependencies
#!pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip

# Installation and initialisation.


In [2]:
# Imports
import sys
sys.path.append("../")

import dataclasses
import functools

from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import data_utils
from graphcast import graphcast
from graphcast import rollout
from graphcast import normalization
from graphcast import xarray_jax
from graphcast import xarray_tree
import haiku as hk
import jax
import numpy as np
import jax.numpy as jnp
import xarray

import optax

import os
import time
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Access data from GCS

In [3]:
# Authenticate with Google Cloud Storage
# needed to get normalization data.

gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

# Load the Data and initialize the model

## Load the model parameters

We use random parameters for the model initialization. 
We'll get random predictions, but we can change the model architecture.

model parameters:
- mesh size: specifies the internal graph representation of the earth. smaller meshes will run faster but will have worse outputs. The mesh size does not affect the number of parameters of the model. [4 - 6]
- GNN message passing steps: specifies the number of message passing steps through the GNN layer of the model. [1 - 32]
- Latent size: defines the feature size of the MLP's. [16, 32, 64 ,128, 256, 512]
- levels: the amount of pressure levels. [13 or 37]

In [4]:
# choose model parameters
random_mesh_size = 5 
random_gnn_msg_steps = 8 
random_latent_size = 128 
random_levels = 13 


The following section initialises the initial empty parameters and state, and defines the model configuring with the parameters selected above.
the task config defines the input and target variables to model.

In [5]:
# load the model parameters
params = None
state = {}
model_config = graphcast.ModelConfig(
    resolution=0,
    mesh_size=random_mesh_size,
    latent_size=random_latent_size,
    gnn_msg_steps=random_gnn_msg_steps,
    hidden_layers=1,
    radius_query_fraction_edge_length=0.6)
task_config = graphcast.TaskConfig(
    input_variables=graphcast.TASK.input_variables,
    target_variables=graphcast.TASK.target_variables,
    forcing_variables=graphcast.TASK.forcing_variables,
    pressure_levels=graphcast.PRESSURE_LEVELS[random_levels],
    input_duration=graphcast.TASK.input_duration,
)
model_config

ModelConfig(resolution=0, mesh_size=5, latent_size=128, gnn_msg_steps=8, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)

# Load the ERA5 data

ERA5 is used a ground truth to train the model on.
To download the ERA5 data and transform it so that it works with the graphcast model, see 'download_data.ipynb'


In [6]:
# Define the relative path to the file
relative_path = "ERA5_data/obs_data.zarr"

# Get the absolute path by joining the current directory with the relative path
absolute_path = os.path.join(os.path.dirname(os.getcwd()), relative_path)
print(absolute_path)

# Open the Zarr file using xarray
obs_data = xarray.open_zarr(absolute_path)

/home/jupyter-lukas/Masters-Thesis/ERA5_data/obs_data.zarr


## extract the training data
We select the data from 1980 to 2019 to train our model on.

In [7]:
#time slice: 1980-01-01T00:00:00.000000000 to 2019-12-31T00:00:00.000000000 - TRAINING
train_data = obs_data.sel(time=slice('1980-01-01T00:00:00.000000000','2019-12-31T00:00:00.000000000'))

## choose the maximum number of training steps.

In [8]:
train_steps_max = 12  

## extract initial training inputs, targets and forcings.
these values don't matter that much, this is just used to initialise the model parameters with.

In [9]:
train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    train_data, target_lead_times=slice("6h", f"{train_steps_max*6}h"),
    **dataclasses.asdict(task_config))

## Load normalization data

In [10]:
with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
    mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
    stddev_by_level = xarray.load_dataset(f).compute()


## Build jitted functions, and possibly initialize random weights

In [11]:
def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

def grads_fn(params, state, inputs, targets, forcings, model_config, task_config):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is required by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs.compute(),
      targets_template=train_targets.compute(),
      forcings=train_forcings.compute())

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = jax.jit(with_configs(grads_fn))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

# Model training loop

## 1. general training
During general training we will loop over the entire dataset once, with only a single autoregressive step.
the goal is to first learn the model how to do a single step well, before learning it to make multiple steps.

We use the adamW optimiser with a lr with a cosine decay for 1e-3 to 0.

The training loop first select a train_batch, this batch takes three time data steps (2 inputs, 1 target) and loads it into memory.
then the training input and targets are extracted from this batch.
a prediction is done on this batch and a loss and grad is computed based on the error between the prediction and the actual values.
the model params are updates based on this information.

In [12]:
# define number of training steps
train_steps = 1
# data size
N = train_data.sizes['time'] - train_steps - 4

loss_array = []

#setup optimiser
lr = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=N)
optimiser = optax.adamw(lr, b1=0.9, b2=0.95, weight_decay=0.1)
opt_state = optimiser.init(params)

In [13]:
# training loop

for i in range(N):
    train_batch = train_data.isel(time=slice(i, i + train_steps + 2))
    train_batch = train_batch.compute()

    train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    train_batch, target_lead_times=slice("6h", f"{train_steps*6}h"),
    **dataclasses.asdict(task_config))

    # calculate loss and gradients
    loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, train_inputs, train_targets, train_forcings)

    # update
    updates, opt_state = optimiser.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    loss_array.append(loss)
    if i%1000 == 0:
        print("I:", i , " - Loss:", loss)

print("general training finished.")

I: 0  - Loss: 96.61719
I: 1000  - Loss: 387.375
I: 2000  - Loss: 195.78125
I: 3000  - Loss: -908.1719
I: 4000  - Loss: 184.66602
I: 5000  - Loss: 865.84766
I: 6000  - Loss: -160.17383
I: 7000  - Loss: 13.078125
I: 8000  - Loss: -278.88672
I: 9000  - Loss: 934.6094
I: 10000  - Loss: 17023.977
I: 11000  - Loss: 3841.2344
I: 12000  - Loss: -1034.6172
I: 13000  - Loss: 6780.125
I: 14000  - Loss: 10247.266
I: 15000  - Loss: 15559.75
I: 16000  - Loss: 12104.418
I: 17000  - Loss: 19703.07
I: 18000  - Loss: 1652.25
I: 19000  - Loss: 13187.695
I: 20000  - Loss: 49720.766
I: 21000  - Loss: 35824.375
I: 22000  - Loss: 48472.156
I: 23000  - Loss: 35820.47
I: 24000  - Loss: 151281.0
I: 25000  - Loss: 18960.0
I: 26000  - Loss: 45623.125
I: 27000  - Loss: 26488.125
I: 28000  - Loss: 16280.109
I: 29000  - Loss: 36233.01
I: 30000  - Loss: -241765.0
I: 31000  - Loss: 27740.633
I: 32000  - Loss: 45260.562
I: 33000  - Loss: 45971.67
I: 34000  - Loss: -20872.188
I: 35000  - Loss: 41362.5
I: 36000  - Loss: 

## 2.Fine tuning
during fine tuning we select the last 11000 time steps of our whole dataset and use that to finetune the model to work with more autoregressive steps.
we increase the amount of autoregressive steps every 1000 loops from 2 to 12.

In [14]:
# data size
N = train_data.sizes['time'] - train_steps_max - 4
# only take the last 11000 time steps
Ksteps = 11000
Ktime = N - Ksteps
train_steps = 1

#setup optimiser
lr = 1e-7
optimiser = optax.adamw(lr, b1=0.9, b2=0.95, weight_decay=0.1)
opt_state = optimiser.init(params)

In [15]:
# training loop

for i in range(Ksteps):
    
    train_batch = train_data.isel(time=slice(Ktime + i,Ktime +  i + train_steps + 2))
    train_batch = train_batch.compute()

    train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    train_batch, target_lead_times=slice("6h", f"{train_steps*6}h"),
    **dataclasses.asdict(task_config))

    # calculate loss and gradients
    loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, train_inputs, train_targets, train_forcings)

    # update
    updates, opt_state = optimiser.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    loss_array.append(loss)
    if i%1000 == 0:
        train_steps += 1
        print("I:", i ," - steps: ",train_steps," - Loss:", loss)
        

print("finetuning training finished.")

I: 0  - steps:  2  - Loss: -104531.984
I: 1000  - steps:  3  - Loss: 168256.0
I: 2000  - steps:  4  - Loss: -36663.668
I: 3000  - steps:  5  - Loss: 50667.72
I: 4000  - steps:  6  - Loss: -30674.014
I: 5000  - steps:  7  - Loss: 135767.8
I: 6000  - steps:  8  - Loss: -11547.095
I: 7000  - steps:  9  - Loss: -6746.3164
I: 8000  - steps:  10  - Loss: 80205.055
I: 9000  - steps:  11  - Loss: -2357.9688
I: 10000  - steps:  12  - Loss: -13431.103
finetuning training finished.


# Save the model params to file
after training is complete we will save the model parameters to file.
Afterwards the parameters can be loaded in again to do predictions with the trained model.

In [16]:

def flatten_dict(d, parent_key='', sep='//'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

def save_model_params(d, file_path):
    flat_dict = flatten_dict(d)
    # Convert JAX arrays to NumPy for saving
    np_dict = {k: np.array(v) if isinstance(v, jnp.ndarray) else v for k, v in flat_dict.items()}
    np.savez(file_path, **np_dict)

params_path = os.path.join('../models', 'params_64x32_rae.npz')
save_model_params(params, params_path)

Now, our trained model is saved to a file ,which can be used to load and run again.