# Comparing Models in Earth-2 MIP

The following notebook demonstrates how to use Earth-2 MIP to run different AI weather models and compare their outputs. Specifically, it compares the Pangu weather model and the Deep Learning Weather Prediction (DLWP) model, both started from an initial state pulled from the Climate Data Store (CDS). It also shows how to interact with Earth-2 MIP through its Python APIs for greater control over inference workflows.

In summary, this notebook covers the following topics:

- Configuring a model registry and setting up the Pangu and DLWP model packages in it
- Setting up a basic deterministic inferencer for both models
- Running inference in a Python script
- Post processing results

## Set Up

We start with the imports. This assumes you have already installed Earth-2 MIP from this repository.

In [1]:
# Running inference needs only a few standard imports
import os, json, logging, datetime
import xarray

Before importing Earth-2 MIP, it's critical that we set a few environment variables that configure Earth-2 MIP correctly under the hood. There are a number of global configuration options; the ones we set here are:

- `WORLD_SIZE`: Tells Earth-2 MIP (which uses Modulus under the hood) the number of GPUs present
- `MODEL_REGISTRY`: Tells Earth-2 MIP where to look for a model registry
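Because both variables have to be in place before `earth2mip` is imported, it can help to wrap the setup in one function. The sketch below uses only the two variable names listed above; the helper name `configure_earth2mip_env` and its validation are hypothetical illustrations, not part of Earth-2 MIP.

```python
import os


def configure_earth2mip_env(model_registry: str, world_size: int = 1) -> dict:
    """Set the environment variables this notebook configures before importing Earth-2 MIP.

    Hypothetical helper: the variable names come from the notebook, but the
    validation logic is only an illustration.
    """
    if world_size < 1:
        raise ValueError("WORLD_SIZE must be a positive integer")
    # Create the registry folder so later lookups have somewhere to look
    os.makedirs(model_registry, exist_ok=True)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MODEL_REGISTRY"] = os.path.abspath(model_registry)
    # Return the values actually stored in the environment
    return {k: os.environ[k] for k in ("WORLD_SIZE", "MODEL_REGISTRY")}
```

Calling it once at the top of the notebook, before `from earth2mip import ...`, keeps the ordering requirement in one place.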

For additional information on the concept of a model registry in Earth-2 MIP, have a look at the following notebook:

- [01_ensemble_inference](./01_ensemble_inference.ipynb)

In [2]:
# Set number of GPUs to use to 1
os.environ['WORLD_SIZE'] = '1'
# Set model registry as a local folder
model_registry = os.path.join(os.path.dirname(os.path.realpath(os.getcwd())), "models")
os.makedirs(model_registry, exist_ok=True)
os.environ['MODEL_REGISTRY'] = model_registry

# With the environment variables set, we can now import Earth-2 MIP
from earth2mip import registry, inference_ensemble
from earth2mip.initial_conditions import cds

The cell above created a model registry folder; now we need to populate it with model packages. We start with the Pangu weather model by fetching the ONNX checkpoints and creating a `metadata.json` file. This metadata file tells Earth-2 MIP how to interact with the model checkpoint, specifically through a Python entry point. This is discussed in more detail in later notebooks, but fundamentally it tells Earth-2 MIP which load function to call for this model (here, the load function in `earth2mip/networks/pangu.py`).

In [3]:
# First set up a pangu folder
import subprocess
if not os.path.isdir(os.path.join(model_registry, 'pangu')):
    pangu_registry = os.path.join(model_registry, "pangu")
    os.makedirs(pangu_registry, exist_ok=True)
    # Download the 24-hour and 6-hour Pangu ONNX checkpoints
    print("Downloading model checkpoint, this may take a bit")
    subprocess.run(['wget', '-nc', '-P', f'{pangu_registry}', 'https://get.ecmwf.int/repository/test-data/ai-models/pangu-weather/pangu_weather_24.onnx'], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    subprocess.run(['wget', '-nc', '-P', f'{pangu_registry}', 'https://get.ecmwf.int/repository/test-data/ai-models/pangu-weather/pangu_weather_6.onnx'], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)

    with open(os.path.join(pangu_registry, 'metadata.json'), 'w') as outfile:
        json.dump({"entrypoint": {"name": "earth2mip.networks.pangu:load"}}, outfile, indent=2)
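The entry point string has the form `module:function`. As an illustration of how such a string can be turned into a callable, here is a minimal sketch built on `importlib`; the helper `load_entrypoint` is hypothetical and is not Earth-2 MIP's own loader, which may resolve entry points differently.

```python
import importlib
import json
import os
from typing import Callable


def load_entrypoint(package_dir: str) -> Callable:
    """Resolve the "module:function" entry point stored in a package's metadata.json.

    Hypothetical sketch for illustration; not Earth-2 MIP's implementation.
    """
    with open(os.path.join(package_dir, "metadata.json")) as f:
        metadata = json.load(f)
    module_name, _, attr_path = metadata["entrypoint"]["name"].partition(":")
    if not attr_path:
        raise ValueError("entry point must look like 'package.module:function'")
    obj = importlib.import_module(module_name)
    # Walk dotted attributes after the colon, e.g. "module:Class.method"
    for attr in attr_path.split("."):
        obj = getattr(obj, attr)
    return obj
```

For the Pangu package written above, this would import `earth2mip.networks.pangu` and return its `load` function.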

Next, the DLWP model package needs to be downloaded. This model follows the standard procedure used by most models in Earth-2 MIP: it is served via Modulus and hosted on the NGC model registry. The install process is simple, and all required files are included in the downloaded zip.

In [4]:
# Now set up DLWP folder
if not os.path.isdir(os.path.join(model_registry, 'dlwp')):
    print("Downloading model checkpoint, this may take a bit")
    subprocess.run(['wget', '-nc', '-P', f'{model_registry}', 'https://api.ngc.nvidia.com/v2/models/nvidia/modulus/modulus_dlwp_cubesphere/versions/v0.2/files/dlwp_cubesphere.zip'], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    subprocess.run(['unzip', '-u', f'{model_registry}/dlwp_cubesphere.zip', '-d', f'{model_registry}'])
    subprocess.run(['rm', f'{model_registry}/dlwp_cubesphere.zip'])

Downloading model checkpoint, this may take a bit
Archive:  /code/earth2-mip/examples/models/dlwp_cubesphere.zip
   creating: /code/earth2-mip/examples/models/dlwp/
  inflating: /code/earth2-mip/examples/models/dlwp/map_CS64_LL721x1440.nc  
  inflating: /code/earth2-mip/examples/models/dlwp/geopotential_rs_cs.nc  
  inflating: /code/earth2-mip/examples/models/dlwp/global_stds.npy  
  inflating: /code/earth2-mip/examples/models/dlwp/dlwp.mdlus  
 extracting: /code/earth2-mip/examples/models/dlwp/metadata.json  
  inflating: /code/earth2-mip/examples/models/dlwp/map_LL721x1440_CS64.nc  
  inflating: /code/earth2-mip/examples/models/dlwp/latlon_grid_field_rs_cs.nc  
  inflating: /code/earth2-mip/examples/models/dlwp/land_sea_mask_rs_cs.nc  
  inflating: /code/earth2-mip/examples/models/dlwp/initial_condition_7ch.nc  
  inflating: /code/earth2-mip/examples/models/dlwp/global_means.npy  
  inflating: /code/earth2-mip/examples/models/dlwp/simple_inference.py  
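The cell above shells out to `wget`, `unzip` and `rm`, which may not be installed everywhere. A pure-Python alternative using only the standard library (`urllib.request` and `zipfile`) is sketched below; the helper name `fetch_and_extract` is hypothetical, and unlike the original cell it does not check whether the `dlwp` folder already exists.

```python
import os
import shutil
import tempfile
import urllib.request
import zipfile


def fetch_and_extract(url: str, dest_dir: str) -> list:
    """Download a zip archive and extract it into dest_dir.

    Hypothetical stand-in for the wget/unzip/rm calls; returns the names
    of the extracted archive members.
    """
    os.makedirs(dest_dir, exist_ok=True)
    # Download into a temporary directory that is removed afterwards (replaces `rm`)
    with tempfile.TemporaryDirectory() as tmp:
        archive = os.path.join(tmp, "package.zip")
        with urllib.request.urlopen(url) as response, open(archive, "wb") as out:
            shutil.copyfileobj(response, out)
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(dest_dir)
            return zf.namelist()


# Example (downloads the DLWP package used in this notebook):
# fetch_and_extract("https://api.ngc.nvidia.com/v2/models/nvidia/modulus/"
#                   "modulus_dlwp_cubesphere/versions/v0.2/files/dlwp_cubesphere.zip",
#                   model_registry)
```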


The final setup step is to configure your CDS API key so we can access ERA5 data to use as an initial state. Earth-2 MIP supports a number of different initial state data sources, including HDF5, CDS, and GFS. The CDS source provides a convenient way to access a limited amount of historical weather data. It is recommended for fetching an initial state, but larger data requirements should use locally stored weather datasets.

Enter your CDS API UID and key below (found on your profile page). If you don't have a CDS API key, you can find out more here:
- [https://cds.climate.copernicus.eu/cdsapp#!/home](https://cds.climate.copernicus.eu/cdsapp#!/home)
- [https://cds.climate.copernicus.eu/api-how-to](https://cds.climate.copernicus.eu/api-how-to)

In [5]:
# Run this cell and input your credentials in the notebook
uid = input("Enter in CDS UID (e.g. 123456)")
key = input("Enter your CDS API key (e.g. 12345678-1234-1234-1234-123456123456)")

# Write to config file for CDS library
with open(os.path.join(os.path.expanduser("~"), '.cdsapirc'), 'w') as f:
    f.write('url: https://cds.climate.copernicus.eu/api/v2\n')
    f.write(f'key: {uid}:{key}\n')

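The cell above reads the key with plain `input()`, which echoes it on screen and into saved notebook output. A minimal sketch of the same step as a reusable function is shown below; the name `write_cdsapirc` is hypothetical, the file format simply mirrors the cell above, and the `chmod` assumes a POSIX filesystem.

```python
import os
from typing import Optional


def write_cdsapirc(uid: str, key: str, path: Optional[str] = None) -> str:
    """Write a CDS API config file in the same format as the cell above.

    Hypothetical helper; the URL is the one used in this notebook.
    """
    path = path or os.path.join(os.path.expanduser("~"), ".cdsapirc")
    with open(path, "w") as f:
        f.write("url: https://cds.climate.copernicus.eu/api/v2\n")
        f.write(f"key: {uid}:{key}\n")
    # Restrict the file to the current user since it holds a secret (POSIX only)
    os.chmod(path, 0o600)
    return path


# Interactive use, with getpass hiding the key as it is typed:
# import getpass
# write_cdsapirc(input("CDS UID: "), getpass.getpass("CDS API key: "))
```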


## Running Inference

To run inference with these models, we use some of Earth-2 MIP's Python APIs.
The first step is to load the model from the model registry, which is done using the `registry.get_model` command.
This will look in your `MODEL_REGISTRY` folder for the provided name and use this as a filesystem for loading necessary files.

Each model is then loaded into memory using the load function for that particular network. Earth-2 MIP also provides higher-level abstractions that can automate this step if desired.
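To make the lookup described above concrete, here is a simplified sketch of resolving a model name against the `MODEL_REGISTRY` folder and reading its `metadata.json`. The function `find_model_package` is hypothetical; the real `registry.get_model` returns its own package object rather than a dict and may perform additional checks.

```python
import json
import os


def find_model_package(name: str) -> dict:
    """Locate a model package folder in MODEL_REGISTRY and read its metadata.

    Hypothetical, simplified stand-in for registry.get_model.
    """
    registry_root = os.environ["MODEL_REGISTRY"]
    package_dir = os.path.join(registry_root, name)
    metadata_path = os.path.join(package_dir, "metadata.json")
    if not os.path.isfile(metadata_path):
        raise FileNotFoundError(f"no metadata.json for model '{name}' in {registry_root}")
    with open(metadata_path) as f:
        return {"root": package_dir, "metadata": json.load(f)}
```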

In [7]:
import earth2mip.networks.dlwp as dlwp
import earth2mip.networks.pangu as pangu

# Load DLWP model from registry
package = registry.get_model("dlwp")
dlwp_inference_model = dlwp.load(package)

# Load Pangu model(s) from registry
package = registry.get_model("pangu")
pangu_inference_model = pangu.load(package)


Next, we set up the initial state data source for January 1st, 2018 at 00:00:00 UTC. As mentioned previously, we pull data on the fly from CDS (make sure you set up your API key above). Since DLWP and Pangu require different channels (and time steps), we create two separate data sources.

In [None]:
# Initial state data/time
time = datetime.datetime(2018, 1, 1)

# DLWP datasource
dlwp_data_source = cds.DataSource(dlwp_inference_model.in_channel_names)

# Pangu data source; Pangu uses only a single time step as input
pangu_data_source = cds.DataSource(pangu_inference_model.in_channel_names)

With a data source set up for each model, we can now run deterministic inference for both using the `inference_ensemble.run_basic_inference` method, which produces an Xarray [DataArray](https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html) to work with.

In [None]:
# Run DLWP inference
dlwp_ds = inference_ensemble.run_basic_inference(
    dlwp_inference_model,
    n=12, # Note we run 12 steps here because DLWP is at 12 hour dt
    data_source=dlwp_data_source,
    time=time,
)
print(dlwp_ds)

In [None]:
# Run Pangu inference
pangu_ds = inference_ensemble.run_basic_inference(
    pangu_inference_model,
    n=24, # Note we run 24 steps here because Pangu is at 6 hour dt
    data_source=pangu_data_source,
    time=time,
)
print(pangu_ds)
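Both runs above cover the same 6-day window; only the number of steps differs because the models advance by different time steps (12 hours for DLWP and 6 hours for Pangu, as noted in the cell comments). The arithmetic can be made explicit with a small helper; `steps_for_lead_time` is a hypothetical name, and it rejects lead times that are not a whole number of steps.

```python
import datetime


def steps_for_lead_time(lead: datetime.timedelta, dt: datetime.timedelta) -> int:
    """Number of model steps needed to reach `lead` with time step `dt`."""
    # timedelta % timedelta is a timedelta; a zero remainder is falsy
    if lead % dt:
        raise ValueError(f"lead time {lead} is not a multiple of time step {dt}")
    return lead // dt


six_days = datetime.timedelta(days=6)
dlwp_steps = steps_for_lead_time(six_days, datetime.timedelta(hours=12))  # 12
pangu_steps = steps_for_lead_time(six_days, datetime.timedelta(hours=6))  # 24
```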

## Post Processing

With inference complete, now comes the fun part: post processing and analysis!
Here we simply plot the z500 field from each model at 12-hour intervals.
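To compare the two models frame by frame, each output has to be sampled on a common 12-hour cadence. The sketch below computes the slice stride and the valid time of each plotted frame; the helper `common_cadence` is hypothetical and assumes index 0 of each output is the initial condition, which is consistent with plotting 13 frames from 12 DLWP steps.

```python
import datetime


def common_cadence(init_time, n_frames, model_dt, plot_dt):
    """Return (stride, valid_times) for sampling a model's output every plot_dt.

    Frame i corresponds to init_time + i * plot_dt; frame 0 is the initial state.
    """
    if plot_dt % model_dt:
        raise ValueError("plot interval must be a multiple of the model time step")
    stride = plot_dt // model_dt
    valid_times = [init_time + i * plot_dt for i in range(n_frames)]
    return stride, valid_times


init = datetime.datetime(2018, 1, 1)
twelve_h = datetime.timedelta(hours=12)
pangu_stride, times = common_cadence(init, 13, datetime.timedelta(hours=6), twelve_h)
dlwp_stride, _ = common_cadence(init, 13, twelve_h, twelve_h)
labels = [t.strftime("%Y-%m-%d %H:%M") for t in times]
```

Here `pangu_stride` is 2, which is why the plotting cell slices the Pangu output with `[::2]`, while DLWP needs a stride of 1.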

In [None]:
import matplotlib.pyplot as plt

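# Select z500 at 12-hour intervals; Pangu steps every 6 hours, so keep every other frame
dlwp_arr = dlwp_ds.sel(channel="z500").values
pangu_arr = pangu_ds.sel(channel="z500").values[::2]
# Plot: one row per model, one column per 12-hour frame (13 frames, 0 to 144 hours)
fig, axs = plt.subplots(2, 13, figsize=(13*4, 5))
for i in range(13):
    axs[0,i].imshow(dlwp_arr[i,0])
    axs[1,i].imshow(pangu_arr[i,0])
    valid_time = time + datetime.timedelta(hours=12*i)
    axs[0,i].set_title(valid_time.strftime("%Y-%m-%d %H:%M"))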

axs[0,0].set_ylabel("DLWP")
axs[1,0].set_ylabel("Pangu")
plt.suptitle("z500 DLWP vs Pangu")
plt.show()

And that completes the second notebook detailing how to run deterministic inference of two models using Earth-2 MIP. In the next notebook, we will dive deeper into how a PyTorch model is integrated into Earth-2 MIP. 