# Weather forecasting models

In [1]:
import torch
import json
import numpy as np
import xarray as xr
from functools import partial

from typing import Optional, Dict, Any

from era5.dataloader_era5 import *
from models.utils import pop_prefix_from_state_dict

from aurora import Aurora, rollout
from aurora import Batch, Metadata

from makani.models.networks.sfnonet import SphericalFourierNeuralOperatorNet as SFNO


  from .autonotebook import tqdm as notebook_tqdm
  def forward(ctx, X, weight, bias, inp_group_name, out_group_name):
  def backward(ctx, grad_out):


# Data setup

We created a dataloader for loading the ERA5 input data. Each weather forecasting model we include in our setup here is trained on ERA5 data, but needs to have a specific input format. The dataloader loads the data and provides it, depending on which model is used, in the correct format. The dataloader has to be initialized at the beginning, and then you can pick a date and a model, and it outputs the corresponding data at that time in the correct format. Use the ```get_data``` method for this. 

We've created a custom dataloader to streamline the process of preparing ERA5 input data for our weather forecasting models.

### How It Works
This dataloader is designed to handle the specific input format requirements of each model in our setup. Instead of manually reformatting the data for every model, you simply initialize the dataloader once at the start of your program.

After initialization, you can use the ```get_data``` method. This method takes two arguments: a date and the specific model you wish to use. The dataloader then automatically fetches the corresponding ERA5 data and delivers it in the correct format for that particular model. This ensures a consistent and efficient data pipeline for the included models. You might want to extend the dataloader to cover other models' requirements.

In [2]:
# Input files for the Aurora and Pangu-Weather examples below; both paths are
# handed to `dataloader_era5`.
# metadata_path: JSON metadata file.
# data_path: HDF5 file with the ERA5 fields (presumably a 3-day subset of 2018,
#            judging by the file name — confirm).
metadata_path = "/era5/2018/73varQ/data.json"
data_path = "/era5/2018/73varQ/restricted_3days_2018.h5"

# AURORA

This is example code showing how to run inference with the pretrained AURORA model.

- [Paper Link](https://www.nature.com/articles/s41586-025-09005-y)
- [Github Link](https://github.com/microsoft/aurora/tree/main)

In [3]:
# Build the ERA5 dataloader in Aurora mode, so that `get_data` returns the
# input in the format this model expects.
dataloader = dataloader_era5(
    data_path=data_path,
    metadata_path=metadata_path,
    in_channels=None,   # NOTE(review): presumably None means "use all channels" — confirm
    out_channels=None,  # NOTE(review): same assumption as in_channels — confirm
    model='aurora',
    normalize=True
)

In [4]:
# Instantiate Aurora and load the `aurora-0.25-pretrained.ckpt` checkpoint.
aurora_model = Aurora(use_lora=False)  # The pretrained version does not use LoRA.
aurora_model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

# Inference only: switch the module to evaluation mode.
aurora_model.eval()

# Use the GPU when one is present; otherwise stay on the CPU instead of
# raising on a machine without CUDA.
aurora_device = "cuda" if torch.cuda.is_available() else "cpu"
aurora_model = aurora_model.to(aurora_device)

In [5]:
# Timestamp of the sample to load; the dataloader log reports it as UTC.
date = "2018-01-02T06:00:00"
# With the loader in Aurora mode, the result is the input later passed to `rollout`.
batch = dataloader.get_data(date)

Date: 2018-01-02 06:00:00+00:00
Reading input file from /era5/2018/73varQ/restricted_3days_2018.h5...
Shape of data_handle: (13, 75, 721, 1440)
Shape of data: torch.Size([2, 75, 721, 1440])
Channel list: ['u10m', 'v10m', 'u100m', 'v100m', 't2m', 'sp', 'msl', 'tcwv', 'u50', 'u100', 'u150', 'u200', 'u250', 'u300', 'u400', 'u500', 'u600', 'u700', 'u850', 'u925', 'u1000', 'v50', 'v100', 'v150', 'v200', 'v250', 'v300', 'v400', 'v500', 'v600', 'v700', 'v850', 'v925', 'v1000', 'z50', 'z100', 'z150', 'z200', 'z250', 'z300', 'z400', 'z500', 'z600', 'z700', 'z850', 'z925', 'z1000', 't50', 't100', 't150', 't200', 't250', 't300', 't400', 't500', 't600', 't700', 't850', 't925', 't1000', 'q50', 'q100', 'q150', 'q200', 'q250', 'q300', 'q400', 'q500', 'q600', 'q700', 'q850', 'q925', 'q1000', 'sst', 'tp']
Shape of static variables: torch.Size([3, 721, 1440])
Shape of data: torch.Size([2, 75, 721, 1440])
Time: 2018-01-02 06:00:00


In [None]:
# Device used for the rollout. This example runs on the CPU; set it to "cuda"
# to run on a GPU instead.
inference_device = "cpu"

aurora_model.eval()
aurora_model = aurora_model.to(inference_device)
batch = batch.to(inference_device)

# inference_mode disables autograd tracking for the forward pass.
with torch.inference_mode():
    # One forecast step; each prediction is moved to the CPU so results are
    # always host-side, whatever `inference_device` is.
    preds = [pred.to("cpu") for pred in rollout(aurora_model, batch, steps=1)]

# Park the model on the CPU afterwards (a no-op when it is already there).
aurora_model = aurora_model.to("cpu")

print(f"Preds: {preds}")

# PANGU-WEATHER

This is example code showing how to run inference with the pretrained Pangu-Weather model.

- [Paper Link](https://www.nature.com/articles/s41586-023-06185-3)
- [Github Link](https://github.com/198808xc/Pangu-Weather)

In [None]:
# Rebuild the ERA5 dataloader in Pangu mode. This rebinds `dataloader`, so
# `get_data` now returns Pangu-formatted input instead of an Aurora batch.
dataloader = dataloader_era5(
    data_path=data_path,
    metadata_path=metadata_path,
    in_channels=None,   # NOTE(review): presumably None means "use all channels" — confirm
    out_channels=None,  # NOTE(review): same assumption as in_channels — confirm
    model='pangu',
    normalize=True
)

In [None]:
# The import cell only does `from models.utils import ...`, which does not bind
# the name `models`. Import the submodule explicitly so the call below does not
# rely on the star import happening to expose it.
# Kept local to this cell so that a missing Pangu dependency cannot break the
# shared import cell.
import models.pangu

# Create the Pangu-Weather inference session.
# NOTE(review): model_type=6 presumably selects the 6-hour model — confirm in models.pangu.
pangu_model_sess = models.pangu.define_pangu_onnx(model_type=6)
print(f"Pangu model {pangu_model_sess}")

'pangu_model_sess = models.pangu.define_pangu_onnx(model_type=6)\nprint(f"Pangu model {pangu_model_sess}")'

In [None]:
# Same timestamp as in the Aurora example.
date = "2018-01-02T06:00:00"
# In Pangu mode `get_data` returns two objects: the upper-air fields and the
# surface fields.
upper_data, surface_data = dataloader.get_data(date)
print(f"Shape of upper data: {upper_data.shape}")
print(f"Shape of surface data: {surface_data.shape}")

Date: 2018-01-02 06:00:00+00:00
Reading input file from /era5/2018/73varQ/restricted_3days_2018.h5...
Shape of data_handle: (13, 75, 721, 1440)


Shape of data: (75, 721, 1440)
Channel list: ['u10m', 'v10m', 'u100m', 'v100m', 't2m', 'sp', 'msl', 'tcwv', 'u50', 'u100', 'u150', 'u200', 'u250', 'u300', 'u400', 'u500', 'u600', 'u700', 'u850', 'u925', 'u1000', 'v50', 'v100', 'v150', 'v200', 'v250', 'v300', 'v400', 'v500', 'v600', 'v700', 'v850', 'v925', 'v1000', 'z50', 'z100', 'z150', 'z200', 'z250', 'z300', 'z400', 'z500', 'z600', 'z700', 'z850', 'z925', 'z1000', 't50', 't100', 't150', 't200', 't250', 't300', 't400', 't500', 't600', 't700', 't850', 't925', 't1000', 'q50', 'q100', 'q150', 'q200', 'q250', 'q300', 'q400', 'q500', 'q600', 'q700', 'q850', 'q925', 'q1000', 'sst', 'tp']
Shape of upper data: (5, 13, 721, 1440)
Shape of surface data: (4, 721, 1440)


In [None]:
# Run the inference session
output_upper, output_surface = pangu_model_sess.run(None, {'input':upper_data, 'input_surface':surface_data})
print(f"Shape of output_upper: {output_upper.shape}")
print(f"Shape of output_surface: {output_surface.shape}")

'# Run the inference session\noutput_upper, output_surface = pangu_model_sess.run(None, {\'input\':upper_data, \'input_surface\':surface_data})\nprint(f"Shape of output_upper: {output_upper.shape}")\nprint(f"Shape of output_surface: {output_surface.shape}")'

# SFNO

This is example code showing how to run inference with the pretrained SFNO model.

- [Paper Link](https://arxiv.org/abs/2306.03838)
- [Github Link](https://github.com/NVIDIA/makani/tree/v0.1.1)

The model weights were taken from the public NVIDIA Modulus release [here](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/models/sfno_73ch_small). Note that this is not the final model, but a smaller version of the SFNO model. Please download the data; the dataloader currently assumes it is located under ```/era5/sfno```.




In [None]:
# Input files for the SFNO example. This rebinds `data_path` to a different
# HDF5 file than the Aurora/Pangu examples use, so re-run the earlier path cell
# before going back to those sections.
metadata_path = "/era5/2018/73varQ/data.json"
data_path = "/era5/2018/73varQ/restricted_1stday_jan_2018.h5"

In [None]:
# Rebuild the ERA5 dataloader in SFNO mode. This rebinds `dataloader`, so
# `get_data` now returns SFNO-formatted input.
dataloader = dataloader_era5(
    data_path=data_path,
    metadata_path=metadata_path,
    in_channels=None,   # NOTE(review): presumably None means "use all channels" — confirm
    out_channels=None,  # NOTE(review): same assumption as in_channels — confirm
    model='sfno',
    normalize=True
)



Loaded stats for the SFNO model.
Shape of mean: (1, 75, 1, 1)
Shape of std: (1, 75, 1, 1)


In [None]:
# Timestamp of the sample to load from the SFNO data file.
date = "2018-01-01T06:00:00"
# Despite its name, `output_data` is the dataloader's output, i.e. the *input*
# that is later fed to the SFNO model.
output_data = dataloader.get_data(date)
print(f"Shape of data: {output_data.shape}")

Date: 2018-01-01 06:00:00+00:00
Reading input file from /era5/2018/73varQ/restricted_1stday_jan_2018.h5...
Shape of data_handle: (5, 75, 721, 1440)


Shape of data: (75, 721, 1440)
Channel list: ['u10m', 'v10m', 'u100m', 'v100m', 't2m', 'sp', 'msl', 'tcwv', 'u50', 'u100', 'u150', 'u200', 'u250', 'u300', 'u400', 'u500', 'u600', 'u700', 'u850', 'u925', 'u1000', 'v50', 'v100', 'v150', 'v200', 'v250', 'v300', 'v400', 'v500', 'v600', 'v700', 'v850', 'v925', 'v1000', 'z50', 'z100', 'z150', 'z200', 'z250', 'z300', 'z400', 'z500', 'z600', 'z700', 'z850', 'z925', 'z1000', 't50', 't100', 't150', 't200', 't250', 't300', 't400', 't500', 't600', 't700', 't850', 't925', 't1000', 'q50', 'q100', 'q150', 'q200', 'q250', 'q300', 'q400', 'q500', 'q600', 'q700', 'q850', 'q925', 'q1000', 'sst', 'tp']
Shape of static variables: (1, 4, 721, 1440)
Shape of data: (1, 73, 721, 1440)
Shape of final data: (1, 77, 721, 1440)
Shape of data: torch.Size([1, 77, 721, 1440])


In [None]:
# Read the SFNO model configuration from its JSON file.
config_path = "/era5/sfno/sfno_73ch_small_config.json"
with open(config_path, "r") as config_file:
    config = json.load(config_file)

# Grid dimensions: the config's x axis is used as latitude, y as longitude.
nlat, nlon = config['img_shape_x'], config['img_shape_y']

# Echo the settings that determine the model's input/output layout.
print(f"num lat: {nlat}")
print(f"num lon: {nlon}")
print(f'Num in channels: {config["N_in_channels"]}')
print(f'Num out channels: {config["N_out_channels"]}')
print(f'Add zenith: {config["add_zenith"]}')
print(f'Add landmask: {config["add_landmask"]}')

num lat: 721
num lon: 1440
Num in channels: 77
Num out channels: 73
Add zenith: True
Add landmask: True


In [None]:
# Instantiate the SFNO network directly from the configuration values, so that
# `model` only ever names the network itself.
model = SFNO(
    img_size=(nlat, nlon),
    grid=config["data_grid_type"],
    num_layers=config['num_layers'],
    scale_factor=config['scale_factor'],
    inp_chans=config["N_in_channels"],
    out_chans=config["N_out_channels"],
    embed_dim=config['embed_dim'],
    big_skip=True,
    pos_embed=config["pos_embed"],
    use_mlp=config["use_mlp"],
    normalization_layer=config["normalization_layer"],
)

In [None]:
# load checkpoint
ckpt_path = "/era5/sfno/checkpoints/sfno_73ch_small_training_checkpoints_best_ckpt_mp0.tar"
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
state_dict = checkpoint["model_state"]
pop_prefix_from_state_dict(state_dict, "module.model.")
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [None]:
# Inference only: put the network in evaluation mode so training-time layer
# behaviour does not affect the forecast.
model.eval()

# `get_data` already returns a torch tensor (its shape prints as torch.Size),
# and torch.tensor() on a tensor emits a copy-construct warning. as_tensor
# converts dtype without the warning, also accepts NumPy input, and makes this
# cell safe to re-run.
output_data = torch.as_tensor(output_data, dtype=torch.float32)

# No gradients are needed for a forecast; skip autograd bookkeeping to save
# memory on the full-resolution grid.
with torch.inference_mode():
    pred = model(output_data)
print(f"Shape of output: {pred.shape}")

Shape of output: torch.Size([1, 73, 721, 1440])


# Further reading

- Benchmark - [WeatherBench](https://sites.research.google/gr/weatherbench/)
- IBTrACS - [Cyclone tracking data](https://www.ncei.noaa.gov/products/international-best-track-archive)