<a href="https://colab.research.google.com/github/andrewjustin/anemoi-workflow/blob/main/anemoi-workflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Anemoi Training Workflow Demo

This notebook guides you through training an AI4NWP (AI for numerical weather prediction) model with the Anemoi framework: https://github.com/ecmwf/anemoi-core/tree/main

Datasets used for training are created with the ufs2arco package: https://github.com/NOAA-PSL/ufs2arco/tree/main

# 1) Environment Setup

The environment tested with this notebook utilized Python 3.11.13 on Ubuntu 24.04.
- **There is no guarantee that this notebook will run error-free using a Python installation on Windows**.

We will utilize *pip* for installing required packages. Make sure you have the latest version of *pip* before proceeding:

In [2]:
!pip install --upgrade pip



There are several packages that we need to install through pip.
- *ufs2arco*: module that will be used to generate the datasets. https://github.com/NOAA-PSL/ufs2arco/tree/main
- *anemoi-datasets*: Anemoi package that optimizes and handles datasets. https://anemoi.readthedocs.io/projects/datasets/en/latest/
    - Note that you *can* generate datasets with *anemoi-datasets* instead of *ufs2arco*; however, this is not recommended.
- *anemoi-graphs*: Anemoi package that allows you to design graphs for AI4NWP models. https://anemoi.readthedocs.io/projects/graphs/en/latest/
- *anemoi-models*: provides the rest of the Anemoi packages with core model components. https://anemoi.readthedocs.io/projects/models/en/latest/
- *anemoi-training*: provides the training functionality for Anemoi. https://anemoi.readthedocs.io/projects/training/en/latest/
- *anemoi-inference*: framework for performing model inference with AI4NWP models trained using Anemoi. https://anemoi.readthedocs.io/projects/inference/en/latest/
- *flash-attn*: fast attention implementation used in Anemoi's transformer models.
  - Flash attention ONLY works on **NVIDIA Ampere GPUs *or newer***. A list of Ampere GPUs can be found here: https://en.wikipedia.org/wiki/Ampere_(microarchitecture)#Products_using_Ampere
- *mpi4py*: Python bindings for MPI (Message Passing Interface). This is only required if you plan to retrieve data in parallel (**strongly recommended**, especially for very large datasets).
- *trimesh*: allows models to utilize triangular meshes.
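
Before installing *flash-attn*, it can help to check whether your GPU meets the Ampere-or-newer requirement. The sketch below assumes that `nvidia-smi` is on your PATH and that your driver supports the `compute_cap` query field (older drivers may not). Ampere corresponds to CUDA compute capability 8.0, so any value of 8.0 or higher is treated as supported. The function names are illustrative and are not part of any Anemoi package.

```python
import subprocess


def supports_flash_attention(compute_capability: str) -> bool:
    """Return True for compute capability 8.0 or higher (Ampere and newer)."""
    major = int(compute_capability.strip().split(".")[0])
    return major >= 8


def query_compute_capabilities() -> list[str]:
    """Ask nvidia-smi for the compute capability of each visible GPU.

    Returns an empty list if nvidia-smi is missing or the query fails.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return []
    lines = [line.strip() for line in result.stdout.splitlines()]
    # Keep only values that look like "8.6"; skip anything else (e.g. "[N/A]").
    return [line for line in lines if line.replace(".", "", 1).isdigit()]


for cap in query_compute_capabilities():
    status = "supported" if supports_flash_attention(cap) else "NOT supported"
    print(f"Compute capability {cap}: flash-attn {status}")
```

If nothing is printed, no GPU was detected by this check; that does not necessarily mean no GPU is present.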

In [3]:
!pip install ufs2arco==0.6.0 anemoi-datasets==0.5.25 anemoi-graphs==0.6.2 anemoi-models==0.8.1 anemoi-training==0.5.1 anemoi-inference==0.6.3 flash-attn mpi4py trimesh 'numpy<2.3' 'earthkit-data<0.14.0' --force-reinstall

# 2) Build Dataset Recipes

Datasets for Anemoi are created using the ufs2arco package: https://github.com/NOAA-PSL/ufs2arco/tree/main

Each dataset is described by a YAML 'recipe' file, which is passed to ufs2arco to generate the dataset.

You can create a single, large dataset for training, validation, and testing, or you can separate these into their own datasets.

We will walk through the process of creating a recipe YAML file.

## 2.1) Define data mover

As the name suggests, the data mover moves data from a remote location to a local directory. Two data movers are available: 'datamover' and 'mpidatamover' (requires the *mpi4py* package). The only difference between them is that 'mpidatamover' can use multiple MPI processes, allowing data to be retrieved in parallel.

The data mover is defined in the YAML recipe with the **mover.name** parameter:

In [None]:
mover:
    name: mpidatamover

## 2.2) Define directories

There are three directory parameters that must be specified in the YAML file:
- **directories.zarr**: directory for the dataset in zarr format
- **directories.cache**: directory for dataset cache
- **directories.logs**: directory for logs showing dataset progress. These logs can be useful for monitoring dataset progress and debugging

Note that nested directories will be created automatically if they do not already exist.

An example implementation in a YAML recipe is shown below:

In [None]:
directories:
  zarr: p1/dataset/training.zarr
  cache: p1/dataset/cache
  logs: p1/dataset/logs

## 2.3) Define dataset configuration

The dataset configuration will be within **source**.

### 2.3.1) Data source, time window, ensemble member selection

Choose the data source and time window you would like to use. The parameters required for this are as follows:
- **source.name**: name of the dataset
- **source.uri**: URI of the dataset
- **source.time.start**: beginning of the desired time window with format YYYY-MM-DD[T]HH
- **source.time.end**: end of the time window with format YYYY-MM-DD[T]HH
- **source.time.freq**: timestep frequency

If your dataset has forecast hours (e.g., GFS), you can specify desired forecast hours:
- **source.fh.start**: beginning forecast hour
- **source.fh.end**: end forecast hour
- **source.fh.step**: forecast hour step/interval

If your dataset has ensemble members (e.g., GEFS), you can retrieve specific ensemble members:
- **source.member.start**: beginning member number
- **source.member.end**: end member number
- **source.member.step**: member number step/interval

An example implementation in a YAML recipe is shown below. Note that we will add more **source** parameters in subsequent steps.

In [None]:
source:
  name: gcs_replay_atmosphere
  uri: gs://noaa-ufs-gefsv13replay/ufs-hr1/0.25-degree-subsampled/03h-freq/zarr/fv3.zarr
  time:
    start: 1994-01-01T00
    end: 1994-01-31T21
    freq: 3h
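
To get a feel for how large a dataset a given time window produces, you can count the timesteps it contains. This is a small sketch under the assumption that both **source.time.start** and **source.time.end** are inclusive; the function name is hypothetical and not part of ufs2arco.

```python
from datetime import datetime, timedelta


def count_timesteps(start: str, end: str, freq_hours: int) -> int:
    """Count timesteps from start to end (both inclusive) at a fixed hourly frequency."""
    fmt = "%Y-%m-%dT%H"  # matches the YYYY-MM-DD[T]HH format used in the recipe
    t0 = datetime.strptime(start, fmt)
    t1 = datetime.strptime(end, fmt)
    if t1 < t0:
        raise ValueError("end must not be earlier than start")
    return (t1 - t0) // timedelta(hours=freq_hours) + 1


print(count_timesteps("1994-01-01T00", "1994-01-31T21", 3))  # 248
```

For the example recipe, 31 days at 8 timesteps per day gives 248 timesteps.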

### 2.3.2) Variables

All variables are defined with the **source.variables** parameter in the recipe YAML.

In [None]:
source:
  name: gcs_replay_atmosphere
  uri: gs://noaa-ufs-gefsv13replay/ufs-hr1/0.25-degree-subsampled/03h-freq/zarr/fv3.zarr
  time:
    start: 1994-01-01T00
    end: 1994-01-31T21
    freq: 3h

  variables:
    - tmp2m
    - spfh2m

### 2.3.3) Pressure Levels

Pressure levels can be explicitly defined, or you can select a set of pressure levels through slicing in the recipe YAML.
- **source.levels**: list of all desired pressure levels
- **source.slices.sel.level**: retrieve a 'slice' of all pressure levels between two values (e.g., [200, 1000] grabs all pressure levels between 200 and 1000 hPa).

In [None]:
source:
  name: gcs_replay_atmosphere
  uri: gs://noaa-ufs-gefsv13replay/ufs-hr1/0.25-degree-subsampled/03h-freq/zarr/fv3.zarr
  time:
    start: 1994-01-01T00
    end: 1994-01-31T21
    freq: 3h

  variables:
    - tmp2m
    - spfh2m

  slices:
    sel:
      level: [200, 1000]  # hPa

### 2.3.4) Coordinates and Selecting Subdomains

By default, all lat/lon points in the desired dataset will be obtained and no arguments are required to acquire the entire grid. Lat/lon coordinates in a subdomain can be explicitly defined, or you can select sets of coordinates through slicing in the recipe YAML.
- **source.longitude**: list of all desired longitude points
- **source.latitude**: list of all desired latitude points
- **source.slices.sel.longitude**: retrieve a slice of all longitudes within a range (e.g., [200, 300] grabs all longitudes between 200 and 300 degrees east, using the 360 degree system)
- **source.slices.sel.latitude**: retrieve a slice of all latitudes within a range (e.g., [51, 25] grabs all latitudes between 25 and 51 degrees north)

In [None]:
source:
  name: gcs_replay_atmosphere
  uri: gs://noaa-ufs-gefsv13replay/ufs-hr1/0.25-degree-subsampled/03h-freq/zarr/fv3.zarr
  time:
    start: 1994-01-01T00
    end: 1994-01-31T21
    freq: 3h

  variables:
    - tmp2m
    - spfh2m

  slices:
    sel:
      level: [200, 1000]  # hPa
      latitude: [53, 21]
      longitude: [225, 300]
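
The order of the two values in a slice matters, which is why latitude is written as [53, 21] while level is written as [200, 1000]. The sketch below mimics inclusive, label-based slicing on a monotonic coordinate in plain Python. It assumes that ufs2arco applies slices in the style of xarray's `.sel(coord=slice(start, stop))` and that latitude is stored from north to south; both are assumptions, and the coordinate values are made up for illustration.

```python
def label_slice(coords, start, stop):
    """Select values between start and stop (inclusive), respecting coordinate order.

    For an ascending coordinate, start must be the smaller bound; for a
    descending coordinate, start must be the larger bound. Bounds given in
    the wrong order select nothing.
    """
    if not coords:
        return []
    ascending = coords[0] <= coords[-1]
    if ascending:
        return [c for c in coords if start <= c <= stop]
    return [c for c in coords if stop <= c <= start]


levels = [100, 200, 500, 850, 1000]                # hypothetical, ascending (hPa)
latitudes = [60.0, 53.0, 45.0, 30.0, 21.0, 10.0]   # hypothetical, north to south

print(label_slice(levels, 200, 1000))   # [200, 500, 850, 1000]
print(label_slice(latitudes, 53, 21))   # [53.0, 45.0, 30.0, 21.0]
print(label_slice(latitudes, 21, 53))   # [] because the bounds are in the wrong order
```

If a slice unexpectedly returns an empty domain, checking the storage order of that coordinate in the source dataset is a good first step.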

### 2.3.5) Configure Outputs

All output settings in the recipe YAML are configured in the **target** section.

- **target.name**: target name (unsure of the exact purpose)
- **target.sort_channels_by_levels**: setting this to True will sort the channels by pressure level
- **target.rename**: allows you to rename variables or coordinates (example usage below)
- **target.chunks**: configure chunks by coordinates (example usage below)
- **target.forcings**: list of forcing variables

In [None]:
target:
  name: forecast
  sort_channels_by_levels: True
  rename:
    level: pressure  # rename 'level' to 'pressure'

  chunks:
    time: 1  # one timestep per chunk
    variable: -1  # undefined (all variables in one chunk)
    ensemble: 1  # one ensemble member per chunk

  forcings:
    - cos_latitude
    - sin_latitude
    - cos_longitude
    - sin_longitude
    - cos_julian_day
    - sin_julian_day
    - cos_local_time
    - sin_local_time
    - cos_solar_zenith_angle
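
The forcings come in cosine/sine pairs because position and time of year are periodic: encoding them as a point on the unit circle means that, for example, longitudes 0 and 360 map to the same values. The sketch below illustrates the idea for a single grid point. The exact definitions used by ufs2arco and Anemoi (for instance, the year length or the day-of-year convention) may differ, so treat the formulas and function names here as illustrative assumptions only. The local-time and solar-zenith-angle forcings are omitted.

```python
import math
from datetime import datetime


def cyclic_pair(value: float, period: float) -> tuple[float, float]:
    """Encode a periodic quantity as (cos, sin) so that 0 and `period` coincide."""
    angle = 2.0 * math.pi * value / period
    return math.cos(angle), math.sin(angle)


def example_forcings(lat_deg: float, lon_deg: float, when: datetime) -> dict[str, float]:
    """Illustrative forcing values for one grid point and one time (UTC)."""
    cos_lon, sin_lon = cyclic_pair(lon_deg, 360.0)
    # Fractional day of year, counted from 0 at 00 UTC on 1 January (assumed convention).
    day = when.timetuple().tm_yday - 1 + when.hour / 24.0
    cos_day, sin_day = cyclic_pair(day, 365.25)  # assumed year length
    return {
        "cos_latitude": math.cos(math.radians(lat_deg)),
        "sin_latitude": math.sin(math.radians(lat_deg)),
        "cos_longitude": cos_lon,
        "sin_longitude": sin_lon,
        "cos_julian_day": cos_day,
        "sin_julian_day": sin_day,
    }


print(example_forcings(40.0, 255.0, datetime(1994, 1, 1, 0)))
```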

### 2.3.6) Transforms

One neat feature is support for data transformations prior to saving, all of which are configured in the **transforms** section of the recipe. You can perform mathematical operations on one or more variables in order to get the desired units.

- **transforms.divide**: divide a specified variable by some value (example in the cell below)
- **transforms.multiply**: multiply a specified variable by some value

In [None]:
transforms:
  divide:
    geopotential_at_surface: 9.80665  # converts geopotential (m2/s2) to geopotential height (m)
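
The divide transform above amounts to the following arithmetic, shown here as a short sketch. The value 9.80665 m/s² is standard gravity; the function name is our own and is not part of ufs2arco.

```python
import math

STANDARD_GRAVITY = 9.80665  # m/s^2, the divisor used in the recipe above


def geopotential_to_height(geopotential: float) -> float:
    """Convert geopotential (m^2/s^2) to geopotential height (m)."""
    return geopotential / STANDARD_GRAVITY


height = geopotential_to_height(9806.65)
print(f"{height:.1f} m")  # about 1000 m
assert math.isclose(height, 1000.0)
```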

## 2.4) Generate datasets with ufs2arco

ufs2arco is used to build the datasets. You can keep the training, validation, and testing data in a single zarr file and select time windows from it, or you can make a separate zarr file for each dataset. Each approach has trade-offs, and the choice is largely a matter of user preference.

Given a recipe *training.yaml*, you can generate the dataset with the command below:

In [None]:
!ufs2arco training.yaml

# 3) Generate and Modify Config Files

## 3.1) Generate Config Files

Anemoi has a command that generates some config files which can be utilized during model training.

Note that these generated files have **a lot** of parameters that should be modified in order to streamline your model training workflow.

Run the command below and **carefully** read the instructions/documentation in the remaining subsections of Section 3.

In [None]:
!anemoi-training config generate

## 3.2) Define batch sizes and configure datasets

Batch sizes must be defined for each dataset. The default *dataloader* file *dataloader/native_grid.yaml* has pre-defined batch sizes; however, these can be overridden in *config.yaml*.
- **dataloader.batch_size.training**: training dataset batch size
- **dataloader.batch_size.validation**: validation dataset batch size
- **dataloader.batch_size.test**: testing dataset batch size

For each dataset, the dataset path and start and end dates need to be specified.
- **dataloader.training.dataset**: full path to the training dataset
- **dataloader.training.start**: start date for training dataset (YYYY-MM-DD)
- **dataloader.training.end**: end date for training dataset (YYYY-MM-DD)
- **dataloader.validation.dataset**: full path to the validation dataset
- **dataloader.validation.start**: start date for validation dataset (YYYY-MM-DD)
- **dataloader.validation.end**: end date for validation dataset (YYYY-MM-DD)
- **dataloader.test.dataset**: full path to the test dataset
- **dataloader.test.start**: start date for test dataset (YYYY-MM-DD)
- **dataloader.test.end**: end date for test dataset (YYYY-MM-DD)

Example implementation in *config.yaml*:

In [None]:
dataloader:
  batch_size:
    training: 2
    validation: 2
    test: 2
  training:
    dataset: ${hardware.paths.data}/training.zarr
    start: 1994-01-01
    end: 1994-01-31
  validation:
    dataset: ${hardware.paths.data}/validation.zarr
    start: 1994-02-01
    end: 1994-02-28
  test:
    dataset: ${hardware.paths.data}/testing.zarr
    start: 1994-03-01
    end: 1994-03-31
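
When the training, validation, and test windows are typed in by hand, it is easy to let two of them overlap by accident, which would leak data between splits. The sketch below checks a set of date ranges for overlaps, treating both endpoints as inclusive (an assumption). The `splits` dictionary repeats the dates from the example above; the function name is hypothetical.

```python
from datetime import date


def overlapping_splits(splits: dict[str, tuple[str, str]]) -> list[tuple[str, str]]:
    """Return every pair of split names whose (inclusive) date ranges overlap."""
    ranges = {
        name: (date.fromisoformat(start), date.fromisoformat(end))
        for name, (start, end) in splits.items()
    }
    names = list(ranges)
    overlaps = []
    for i, first in enumerate(names):
        for second in names[i + 1:]:
            (a0, a1), (b0, b1) = ranges[first], ranges[second]
            # Two closed intervals overlap when each starts before the other ends.
            if a0 <= b1 and b0 <= a1:
                overlaps.append((first, second))
    return overlaps


splits = {
    "training": ("1994-01-01", "1994-01-31"),
    "validation": ("1994-02-01", "1994-02-28"),
    "test": ("1994-03-01", "1994-03-31"),
}
print(overlapping_splits(splits))  # []
```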

## 3.3) Configure GPUs and Paths

One of the most important steps for running the Anemoi framework is configuring paths. At the top of *config.yaml*, the 'hardware' parameter should be set to 'example'. This loads the default settings in *hardware/example.yaml*; however, the **data** path is not specified in that file. In addition, you may want to specify different directories for storing outputs and model graphs.

- **hardware.paths.output**: directory for the outputs (checkpoints, plots, etc.). Directory structure will be created if it does not already exist.
- **hardware.paths.data**: directory for the datasets generated with ufs2arco.
- **hardware.paths.graph**: directory for the model graph.

The name of the zarr file containing the training dataset must also be specified.
- **hardware.files.dataset**: name of the training dataset zarr file (do not include absolute path with directory structure)

You can also specify the number of GPUs to use for each model with the **hardware.num_gpus_per_model** parameter.

An example implementation in *config.yaml* is shown below.

In [None]:
hardware:

  num_gpus_per_model: 1

  paths:
    output: p1/training-output/
    data: p1/dataset
    graph: p1/graph

  files:
    dataset: training.zarr
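
A missing or mistyped data path is a common reason for training to fail at startup, so it can be useful to check the paths before launching a run. The sketch below mirrors the **hardware** settings above in a plain dictionary and joins the data directory with the dataset file name. How Anemoi combines these values internally is not shown here, so treat the joining logic as an assumption; the helper names are our own.

```python
from pathlib import Path

# Mirrors the hardware section of config.yaml shown above.
hardware = {
    "paths": {"output": "p1/training-output/", "data": "p1/dataset", "graph": "p1/graph"},
    "files": {"dataset": "training.zarr"},
}


def dataset_path(hw: dict) -> Path:
    """Join the data directory with the dataset file name."""
    return Path(hw["paths"]["data"]) / hw["files"]["dataset"]


def missing_inputs(hw: dict) -> list[Path]:
    """Return the input paths that do not exist yet (data directory and dataset)."""
    candidates = [Path(hw["paths"]["data"]), dataset_path(hw)]
    return [path for path in candidates if not path.exists()]


print(dataset_path(hardware))
for path in missing_inputs(hardware):
    print(f"Missing: {path}")
```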

## 3.4) Configure Model Training

There are a few parameters that should be specified in the main *config.yaml* file so model training configurations can be easily modified.

At the top of *config.yaml*, you will probably see a 'training' parameter that is set to 'default'. This loads the training configuration settings in the *training/default.yaml* file. All of these settings can be overridden in *config.yaml*.

Here are some useful training parameters to include in *config.yaml*:
- **training.max_epochs**: specifies the maximum number of epochs for model training. Training will stop if this limit is reached.
- **training.max_steps**: specifies the maximum number of total steps for model training (*not steps per epoch*). Training will stop if this limit is reached.
- **training.lr.rate**: starting learning rate
- **training.lr.min**: minimum learning rate

An example implementation in *config.yaml* with the aforementioned parameters is shown below.

In [None]:
training:
  max_epochs: 500
  max_steps: 10000
  lr:
    rate: 1e-4
    min: 3e-7
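
Since both **training.max_epochs** and **training.max_steps** can stop training, it helps to work out which limit will be reached first. The sketch below makes several simplifying assumptions: one training sample per timestep, a single GPU, and no gradient accumulation. In practice the number of samples is somewhat smaller than the number of timesteps, because each sample needs input and target times. The function names are hypothetical.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of optimizer steps in one pass over the training data."""
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    return math.ceil(num_samples / batch_size)


def first_limit(num_samples: int, batch_size: int, max_epochs: int, max_steps: int) -> tuple[str, int]:
    """Return which limit stops training first and the total steps taken."""
    epoch_budget = steps_per_epoch(num_samples, batch_size) * max_epochs
    if max_steps < epoch_budget:
        return "max_steps", max_steps
    return "max_epochs", epoch_budget


# 248 timesteps in January 1994 at 3-hourly frequency, batch size 2.
print(steps_per_epoch(248, 2))               # 124
print(first_limit(248, 2, 500, 10000))       # ('max_steps', 10000)
```

With these example values, training would stop on **max_steps** after roughly 81 epochs, long before reaching 500 epochs.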

## 3.5) Configure Diagnostics

During training, it is useful to plot sample model predictions and log other information pertaining to the model output/performance in order to get a good idea if your model is 'working' as intended.

In the *config.yaml* file, the default file for diagnostics is *diagnostics/evaluation.yaml*. There are a couple of empty fields that we will need to define in the following steps.

### 3.5.1) Performance Logging

For now, we will disable Weights and Biases for performance logging (though you may want to configure a WandB workflow in the future). This can be done by setting the **diagnostics.log.wandb.entity** parameter to 'null'.

We will also disable the MLflow tracking server by setting **diagnostics.log.mlflow.tracking_uri** to 'null'.

An example implementation in *config.yaml* is shown below. Note that we will continue to modify **diagnostics** in later steps.

In [None]:
diagnostics:
  log:
    wandb:
      entity: null
    mlflow:
      tracking_uri: null

### 3.5.2) Plotting

With the default settings in *diagnostics/evaluation.yaml*, the following plots will be produced at user-defined frequencies for specified variables:
* Spatial plots of model predictions and errors
* Histograms showing binned model predictions and errors for **every** variable in a single plot

The frequency of plotting can be modified directly in *config.yaml* with the following parameters:
* **diagnostics.plot.frequency.epoch**: plot frequency in epochs
* **diagnostics.plot.frequency.batch**: plot frequency in batches

Adding these to **diagnostics** in *config.yaml*:

In [None]:
diagnostics:
  log:
    wandb:
      entity: null
    mlflow:
      tracking_uri: null
  plot:
    frequency:
      epoch: 5
      batch: 20

The next thing to do is define what variables we want to plot.

First, let's modify a few lines in *diagnostics/evaluation.yaml*.
- Under **callbacks**, ensure that every instance of **parameters** (there should be three instances in total) refers to the user-specified variables in **diagnostics.plot.parameters** (see cell below). This will make sure that plots include every variable that you would like to monitor.
- You can leave the instance of **parameters** near the top of the file unchanged as we will be overriding it in *config.yaml*.

In [None]:
parameters: ${diagnostics.plot.parameters}

Now that the plotting file is configured, we can define the variables we want to plot in *config.yaml*.
* Note that precipitation and related moisture variables need to be defined in **diagnostics.plot.precip_and_related_fields** as well as **diagnostics.plot.parameters**.

Adding our desired variables for plotting to **diagnostics.plot** in *config.yaml*:

In [None]:
diagnostics:
  log:
    wandb:
      entity: null
    mlflow:
      tracking_uri: null
  plot:
    frequency:
      epoch: 1
      batch: 5
    parameters:
      - tmp_825  # 825 hPa temperature
      - tmp2m  # 2-meter temperature
    precip_and_related_fields: []

## 3.6) Model Settings

Several model configuration settings can be changed in *config.yaml*. By default, the *config.yaml* file will use the 'gnn' model. This is the Graph Neural Network architecture (https://arxiv.org/abs/1812.08434). These models are designed for learning relationships between nodes and edges. The two other model configurations available by default are the transformer (https://arxiv.org/abs/1706.03762) and graph transformer (https://arxiv.org/abs/2407.09777). Transformers excel at learning relationships between sequential data, such as sequential forecast timesteps in atmospheric data. The graph transformer combines the ideas of the GNN and transformer to handle sequential data connected through a graph.

The *model* files generated by Anemoi all have ReLU (rectified linear unit) boundings applied to the variable 'tp', or total precipitation. The output of the ReLU function $y$ will be zero for a given input $x$ when $x\leq0$, otherwise $y=x$. This means that the output of ReLU will never be negative, which makes sense for precipitation.
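
As a quick illustration of the ReLU function described above (a plain Python sketch, not Anemoi's own implementation):

```python
def relu(x: float) -> float:
    """Rectified linear unit: zero for x <= 0, otherwise x."""
    return max(0.0, x)


# Negative "precipitation" values are clipped to zero; positive values pass through.
print([relu(value) for value in (-2.5, 0.0, 1.2)])  # [0.0, 0.0, 1.2]
```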

If you do not have precipitation in your dataset, you can disable all boundings in *config.yaml* by passing an empty list to the *model.bounding* parameter:

In [None]:
model:
  bounding: []

There are other model settings that can be configured in *config.yaml*.

The model's graph can also be changed with the **graph** parameter under **defaults** at the top of *config.yaml*.

# 4) Set Environment Variables

Anemoi requires a "base seed" and a SLURM job ID.
- The base seed is used to initialize model weights. Changing the seed will result in different initial model parameters.
- The SLURM job ID is required, even if you are not on SLURM (just leave it as "0").

*Hydra* can be configured to output more complete tracebacks for debugging purposes.

In [None]:
import os

### Required ###
os.environ["ANEMOI_BASE_SEED"] = "42"
os.environ["SLURM_JOB_ID"] = "0"

### Optional ###
os.environ['HYDRA_FULL_ERROR'] = "1"  # for debugging

# 5) Train the Model

In [None]:
!anemoi-training train --config-name=config.yaml

# 6) Model Inference

Model inference with Anemoi is performed with the *anemoi-inference* module: https://anemoi.readthedocs.io/projects/inference/en/latest/index.html#index-page

### 6.1) Retrieve Model Runs and Load Checkpoint
Each model run is saved in its own folder, named with a randomly generated unique identifier (UUID).

In [None]:
import os
model_runs = os.listdir('p1/training-output/checkpoint')
print('Available model runs:')
for run in model_runs:
    print(run + '\n')

Select a model run from the list above and load the checkpoint.

In [None]:
model_run = 'd46e7b66-9ba1-474f-9142-5dd28be63f50'  # model run hash identifier

## Do not change this ##
checkpoint = f'p1/training-output/checkpoint/{model_run}/inference-last.ckpt'

### 6.2) Configure and Run Model Inference
Select a target forecast time (valid time) from the testing dataset and set a forecast lead time.

You can also create a config YAML file that contains the inference settings; however, all settings can easily be passed on the command line.

In [None]:
forecast_time = '1994-03-31T21'  # valid time [YYYY]-[MM]-[DD]T[HH]
lead_time = 12  # hours

## Do not change these ##
inference_dataset = 'p1/dataset/testing.zarr'
output_file = 'forecast.nc'  # output file containing the model forecast

!anemoi-inference run checkpoint={checkpoint} date={forecast_time} lead_time={lead_time} input.dataset={inference_dataset} output.netcdf={output_file}