In [1]:
from pathlib import Path

import fsspec
import xarray as xr

# Data will be downloaded here.
download_path = Path("downloads")
download_path.mkdir(parents=True, exist_ok=True)

# We will download from Google Cloud.
url = "gs://weatherbench2/datasets/hres_t0/2016-2022-6h-1440x721.zarr"
ds = xr.open_zarr(fsspec.get_mapper(url), chunks=None)

# Day to download. This will download all times for that day.
day = "2022-09-17"

# Download the surface-level variables. We write the downloaded data to another file to cache.
if not (download_path / f"{day}-surface-level.nc").exists():
    surface_vars = [
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
        "2m_temperature",
        "mean_sea_level_pressure",
    ]
    ds_surf = ds[surface_vars].sel(time=day).compute()
    ds_surf.to_netcdf(str(download_path / f"{day}-surface-level.nc"))
print("Surface-level variables downloaded!")

# Download the atmospheric variables. We write the downloaded data to another file to cache.
if not (download_path / f"{day}-atmospheric.nc").exists():
    atmos_vars = [
        "temperature",
        "u_component_of_wind",
        "v_component_of_wind",
        "specific_humidity",
        "geopotential",
    ]
    ds_atmos = ds[atmos_vars].sel(time=day).compute()
    ds_atmos.to_netcdf(str(download_path / f"{day}-atmospheric.nc"))
print("Atmos-level variables downloaded!")

Surface-level variables downloaded!
Atmos-level variables downloaded!


In [2]:
import cdsapi

c = cdsapi.Client()

# Download the static variables.
if not (download_path / "static.nc").exists():
    c.retrieve(
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "geopotential",
                "land_sea_mask",
                "soil_type",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": "00:00",
            "format": "netcdf",
        },
        str(download_path / "static.nc"),
    )
print("Static variables downloaded!")

2025-05-29 14:04:14,715 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.
2025-05-29 14:04:15,860 INFO Request ID is cd19f93b-19a0-40c7-b97a-304b544621cf
2025-05-29 14:04:16,253 INFO status has been updated to accepted
2025-05-29 14:04:30,902 INFO status has been updated to running
2025-05-29 14:04:38,762 INFO status has been updated to successful


324c56498527a3da5e71abb324dd717f.nc:   0%|          | 0.00/3.34M [00:00<?, ?B/s]

Static variables downloaded!


In [4]:
import pickle
import torch
import xarray as xr

# Difference of the static variables from hugging-face and ERA5.
with open("/data/jiyilun/typhoon/download/era5/aurora-0.25-static.pickle", "rb") as f:
    static_hf = pickle.load(f)
static_era5 = xr.open_dataset(download_path / "static.nc", engine="netcdf4")

for var_name in ["z", "slt", "lsm"]:
    static_hf_var = static_hf[var_name]
    static_era5_var = static_era5[var_name].values[0]
    print(f"hf {var_name}:")
    print(static_hf_var)
    print(f"era5 {var_name}:")
    print(static_era5_var)
    print(f"diff:")
    print(static_hf_var - static_era5_var)

hf z:
[[-9.7460938e-01 -9.7460938e-01 -9.7460938e-01 ... -9.7460938e-01
  -9.7460938e-01 -9.7460938e-01]
 [ 4.4121094e+00  4.4121094e+00  4.4121094e+00 ...  4.4121094e+00
   4.4121094e+00  4.4121094e+00]
 [-1.8730469e+00 -1.8730469e+00 -1.8730469e+00 ... -1.8730469e+00
  -1.8730469e+00 -1.8730469e+00]
 ...
 [ 2.6957861e+04  2.6961453e+04  2.6965045e+04 ...  2.6949779e+04
   2.6952475e+04  2.6956066e+04]
 [ 2.7163490e+04  2.7165285e+04  2.7167082e+04 ...  2.7159000e+04
   2.7159898e+04  2.7161693e+04]
 [ 2.7718416e+04  2.7718416e+04  2.7718416e+04 ...  2.7718416e+04
   2.7718416e+04  2.7718416e+04]]
era5 z:
[[-1.2397461e+00 -1.2397461e+00 -1.2397461e+00 ... -1.2397461e+00
  -1.2397461e+00 -1.2397461e+00]
 [ 4.6157227e+00  4.6040039e+00  4.5883789e+00 ...  4.7915039e+00
   4.7329102e+00  4.6743164e+00]
 [-1.7709961e+00 -1.7163086e+00 -1.7084961e+00 ... -1.5092773e+00
  -1.6225586e+00 -1.7124023e+00]
 ...
 [ 2.6958170e+04  2.6961174e+04  2.6964666e+04 ...  2.6949771e+04
   2.6952857e+04  

In [3]:
from data.dataset import TyphoonTrajectoryDataset
import torch
import numpy as np
import xarray as xr

from aurora import Batch, Metadata

static_vars_ds = xr.open_dataset(download_path / "static.nc", engine="netcdf4")
surf_vars_ds = xr.open_dataset(download_path / f"{day}-surface-level.nc", engine="netcdf4")
atmos_vars_ds = xr.open_dataset(download_path / f"{day}-atmospheric.nc", engine="netcdf4")


def _prepare(x: np.ndarray) -> torch.Tensor:
    """Prepare a variable.

    This does the following things:
    * Select time points two and three: hours 06:00 and 12:00.
    * Insert an empty batch dimension with `[None]`.
    * Flip along the latitude axis to ensure that the latitudes are decreasing.
    * Copy the data, because the data must be contiguous when converting to PyTorch.
    * Convert to PyTorch.
    """
    return torch.from_numpy(x[[1, 2]][None][..., ::-1, :].copy())


batch = Batch(
    surf_vars={
        "2t": _prepare(surf_vars_ds["2m_temperature"].values),
        "10u": _prepare(surf_vars_ds["10m_u_component_of_wind"].values),
        "10v": _prepare(surf_vars_ds["10m_v_component_of_wind"].values),
        "msl": _prepare(surf_vars_ds["mean_sea_level_pressure"].values),
    },
    static_vars={
        # The static variables are constant, so we just get them for the first time. They
        # don't need to be flipped along the latitude dimension, because they are from
        # ERA5.
        "z": torch.from_numpy(static_vars_ds["z"].values[0]),
        "slt": torch.from_numpy(static_vars_ds["slt"].values[0]),
        "lsm": torch.from_numpy(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        "t": _prepare(atmos_vars_ds["temperature"].values),
        "u": _prepare(atmos_vars_ds["u_component_of_wind"].values),
        "v": _prepare(atmos_vars_ds["v_component_of_wind"].values),
        "q": _prepare(atmos_vars_ds["specific_humidity"].values),
        "z": _prepare(atmos_vars_ds["geopotential"].values),
    },
    metadata=Metadata(
        # Flip the latitudes! We need to copy because converting to PyTorch, because the
        # data must be contiguous.
        lat=torch.from_numpy(surf_vars_ds.latitude.values[::-1].copy()),
        lon=torch.from_numpy(surf_vars_ds.longitude.values),
        # Converting to `datetime64[s]` ensures that the output of `tolist()` gives
        # `datetime.datetime`s. Note that this needs to be a tuple of length one:
        # one value for every batch element. Select the third time point.
        time=(surf_vars_ds.time.values.astype("datetime64[s]").tolist()[2],),
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.level.values),
    ),
)

test_data = TyphoonTrajectoryDataset("/data/jiyilun/typhoon/download", 2022, 2022, 8, 8, with_hres_t0=True)
x, y, batch_my_ds = test_data[173]

  return self.fget.__get__(instance, owner)()


In [4]:
# Difference of the official batch and the batch of my dataset.
print(f"surface var:")
for var_name in ["2t", "10u", "10v", "msl"]:
    print(f"{var_name}:")
    print(batch.surf_vars[var_name] - batch_my_ds.surf_vars[var_name])

print(f"static var:")
for var_name in ["z", "slt", "lsm"]:
    print(f"{var_name}:")
    print(batch.static_vars[var_name] - batch_my_ds.static_vars[var_name])

print(f"atmospheric var:")
for var_name in ["t", "u", "v", "q", "z"]:
    print(f"{var_name}:")
    print(batch.atmos_vars[var_name] - batch_my_ds.atmos_vars[var_name])

print(f"metadata:")
print(f"lat:")
print(batch.metadata.lat - batch_my_ds.metadata.lat)
print(f"lon:")
print(batch.metadata.lon - batch_my_ds.metadata.lon)
print(f"time:")
print(batch.metadata.time[0] - batch_my_ds.metadata.time[0])
print(f"atmos_levels:")
print([batch.metadata.atmos_levels[i] - batch_my_ds.metadata.atmos_levels[i] for i in range(len(batch.metadata.atmos_levels))])


surface var:
2t:
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])
10u:
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 