# Weather forecasting models

In [None]:
import torch
import json
import numpy as np
import xarray as xr
from functools import partial

from typing import Optional, Dict, Any

from era5.dataloader_era5 import *

from aurora import Aurora, rollout
from aurora import Batch, Metadata

from makani.models.networks.sfnonet import SphericalFourierNeuralOperatorNet as SFNO


# Data setup

We created a dataloader for the ERA5 input data. Every weather forecasting model in this setup is trained on ERA5 data, but each one expects its own input format. The dataloader loads the data and returns it in the format required by the selected model. Initialize the dataloader once at the beginning; afterwards, pick a date and a model and call the ```get_data``` method to obtain the corresponding data for that time in the correct format.

We've created a custom dataloader to streamline the process of preparing ERA5 input data for our weather forecasting models.

### How It Works
This dataloader is designed to handle the specific input format requirements of each model in our setup. Instead of manually reformatting the data for every model, you simply initialize the dataloader once at the start of your program.

After initialization, you can use the ```get_data``` method. This method takes two arguments: a date and the specific model you wish to use. The dataloader then automatically fetches the corresponding ERA5 data and delivers it in the correct format for that particular model. This ensures a consistent and efficient data pipeline for the included models. You may want to extend the dataloader to support the input requirements of other models.

In [None]:
# Locations of the ERA5 inputs used by the dataloader below.
# Only the basic global mean/std statistics files are used for now.
ERA5_ROOT = "/era5"
stats_mean_path = f"{ERA5_ROOT}/stats_era5/global_means.npy"
stats_std_path = f"{ERA5_ROOT}/stats_era5/global_stds.npy"
metadata_path = f"{ERA5_ROOT}/data.json"
data_path = f"{ERA5_ROOT}/2018/restricted_3days_2018.h5"

In [None]:
# Build the ERA5 dataloader once; each model section below pulls its inputs from it
# via `dataloader.get_data(date, model=...)`.
# NOTE(review): in_channels/out_channels=None presumably means "use all channels" — TODO confirm.
dataloader_kwargs = dict(
    data_path=data_path,
    stats_mean_path=stats_mean_path,
    stats_std_path=stats_std_path,
    metadata_path=metadata_path,
    in_channels=None,
    out_channels=None,
    normalize=True,
)
dataloader = dataloader_era5(**dataloader_kwargs)

# AURORA

This section shows example code for running inference with the pretrained AURORA model.

- [Paper Link](https://www.nature.com/articles/s41586-025-09005-y)
- [Github Link](https://github.com/microsoft/aurora/tree/main)

In [None]:
# Load the pretrained `aurora-0.25-pretrained` checkpoint and prepare it for inference.
aurora_model = Aurora(use_lora=False)  # The pretrained version does not use LoRA.
aurora_model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

# Fall back to CPU so the notebook still runs end-to-end on machines without a GPU.
aurora_device = "cuda" if torch.cuda.is_available() else "cpu"

aurora_model.eval()
aurora_model = aurora_model.to(aurora_device)

In [None]:
# NOTE(review): this cell is disabled. Its code is wrapped in a string literal, so running
# the cell only displays the string. When re-enabled, it fetches the Aurora input for the
# given date (presumably an aurora `Batch`, since the next cell calls `batch.to(...)`).
# The commented print references `data`, which is undefined here; `batch` was likely meant.
"""date = "2018-01-02T06:00:00"
batch = dataloader.get_data(date, model="aurora")
#print(f"Shape of data: {data}")"""

In [None]:
# NOTE(review): this cell is disabled (its code is a string literal that is only displayed).
# It depends on `batch` from the previous, also disabled, cell. When re-enabling it:
#   * it moves the model to "cpu", overriding the device chosen in the model-loading cell;
#   * the second `aurora_model.to("cpu")` inside `inference_mode` is redundant.
# It iterates over `rollout(aurora_model, batch, steps=1)` and moves each prediction to CPU.
"""aurora_model.eval()
aurora_model = aurora_model.to("cpu")
batch = batch.to("cpu")

with torch.inference_mode():
    preds = [pred.to("cpu") for pred in rollout(aurora_model, batch, steps=1)]
    aurora_model = aurora_model.to("cpu")

print(f"Preds: {preds}")"""

# PANGU

This section shows example code for running inference with the pretrained Pangu-Weather model.

- [Paper Link](https://www.nature.com/articles/s41586-023-06185-3)
- [Github Link](https://github.com/198808xc/Pangu-Weather)

In [None]:
# Create the Pangu-Weather ONNX session that is run further below via `.run(...)`.
# NOTE(review): `pangu` is never imported explicitly in this notebook; it presumably comes
# from `from era5.dataloader_era5 import *` — TODO confirm and import it explicitly.
# NOTE(review): the meaning of model_type=6 is not visible here (perhaps the 6-hour
# lead-time variant) — TODO confirm.
pangu_model_sess = pangu.define_pangu_onnx(
    model_type=6,
)
print(f"Pangu model {pangu_model_sess}")

In [None]:
# Fetch the Pangu inputs (upper-level and surface arrays) for a single timestamp.
date = "2018-01-02T06:00:00"
pangu_inputs = dataloader.get_data(date, model="pangu")
upper_data, surface_data = pangu_inputs

print(f"Shape of upper data: {upper_data.shape}")
print(f"Shape of surface data: {surface_data.shape}")

In [None]:
# Run the inference session
output_upper, output_surface = pangu_model_sess.run(None, {'input':upper_data, 'input_surface':surface_data})
print(f"Shape of output_upper: {output_upper.shape}")
print(f"Shape of output_surface: {output_surface.shape}")

# SFNO

This section shows example code for running inference with the pretrained SFNO model.

- [Paper Link](https://arxiv.org/abs/2306.03838)
- [Github Link](https://github.com/NVIDIA/makani/tree/v0.1.1)

The model weights are taken from the public NVIDIA Modulus release [here](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/models/sfno_73ch_small). Note that this is not the final model but a smaller version of the SFNO model (`sfno_73ch_small`).




In [None]:
# Fetch the SFNO input for a single timestamp.
# NOTE: despite its name, `output_data` is the *input* that is fed to the SFNO model below.
date = "2018-01-02T06:00:00"
output_data = dataloader.get_data(date, model="sfno")
# Label fixed: this is the SFNO input, not Pangu's "upper data".
print(f"Shape of SFNO input data: {output_data.shape}")

In [None]:
# Load the SFNO hyper-parameters that ship alongside the checkpoint.
config_path = "/bids_weather_forecasting_hackathon/checkpoints/sfno_checkpoints/sfno_73ch_small_config.json"
with open(config_path) as config_file:
    config = json.load(config_file)

# The grid size is stored as img_shape_x (used as latitudes) and img_shape_y (longitudes).
nlat, nlon = config['img_shape_x'], config['img_shape_y']
print(f"num lat: {nlat}")
print(f"num lon: {nlon}")

In [None]:
# Instantiate SFNO from the checkpoint config. `big_skip` is fixed to True here
# rather than read from the config.
sfno_kwargs = dict(
    img_size=(nlat, nlon),
    grid=config["data_grid_type"],
    num_layers=config['num_layers'],
    scale_factor=config['scale_factor'],
    inp_chans=config["N_in_channels"],
    out_chans=config["N_out_channels"],
    embed_dim=config['embed_dim'],
    big_skip=True,
    pos_embed=config["pos_embed"],
    use_mlp=config["use_mlp"],
    normalization_layer=config["normalization_layer"],
)
model = SFNO(**sfno_kwargs)

In [None]:
def pop_prefix_from_state_dict(
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    r"""Strip ``prefix`` from the keys of ``state_dict`` in place.

    Keys that do not start with ``prefix`` are left unchanged. If ``state_dict``
    carries a ``_metadata`` mapping (as ``torch.nn.Module.state_dict()`` results do),
    its module-name keys are stripped the same way. The key naming the wrapped
    module itself (``prefix`` without its trailing ``"."``) becomes the new root
    key ``""``.

    ..note::
        Given a `state_dict` saved from a wrapped (e.g. DP/DDP) model, the bare
        model can load it by applying
        `pop_prefix_from_state_dict(state_dict, "module.")` before calling
        :meth:`torch.nn.Module.load_state_dict`.

    Args:
        state_dict (Dict[str, Any]): a state-dict to be loaded into the model;
            modified in place.
        prefix (str): prefix to remove, e.g. ``"module.model."``. An empty prefix
            is a no-op.

    Returns:
        None. ``state_dict`` (and its ``_metadata``, if any) is mutated in place.
    """
    if not prefix:
        return

    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    # Also strip the prefix from the per-module metadata, if present. Its keys are
    # module names; e.g. for prefix "module.model.":
    #   ''                 -> outermost wrapper; overwritten by the wrapped model's entry
    #   'module'           -> intermediate wrapper; left as is
    #   'module.model'     -> the wrapped model itself; becomes the new root ''
    #   'module.model.xx'  -> submodules; become 'xx'
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is None:
        return
    module_name = prefix[:-1] if prefix.endswith(".") else prefix
    for key in list(metadata.keys()):
        if key == module_name:
            newkey = ""
        elif key.startswith(prefix):
            newkey = key[len(prefix):]
        else:
            continue
        metadata[newkey] = metadata.pop(key)

In [None]:
# load checkpoint
ckpt_path = "/bids_weather_forecasting_hackathon/checkpoints/sfno_checkpoints/checkpoints/sfno_73ch_small_training_checkpoints_best_ckpt_mp0.tar"
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
state_dict = checkpoint["model_state"]
pop_prefix_from_state_dict(state_dict, "module.model.")
model.load_state_dict(state_dict, strict=True)

In [None]:
# Run a single SFNO forward step for inference.
# eval() disables training-only layer behaviour, and inference_mode() stops autograd
# from recording the graph, which would otherwise keep large activations in memory.
# NOTE(review): assumes `output_data` is a torch tensor in the layout SFNO expects — TODO confirm.
model.eval()
with torch.inference_mode():
    pred = model(output_data)
print(f"Shape of output: {pred.shape}")