In [1]:
from pathlib import Path
import pandas as pd
import os
import cdsapi
import torch
import xarray as xr
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

from aurora import Batch, Metadata, Aurora, rollout

In [7]:
# Inclusive date range to fetch from CDS (both endpoints are downloaded).
DAY_START, DAY_END = pd.Timestamp("2025-10-13"), pd.Timestamp("2025-10-14")

# Local cache directory for every ERA5 file this notebook downloads.
DOWNLOAD_PATH = Path("../data/era5")
DOWNLOAD_PATH.mkdir(exist_ok=True, parents=True)

# Download all ERA5 inputs required by Aurora
This can download 4 GB+ of data, depending on the date range.

In [8]:
# Shared CDS client; `sleep_max` caps the wait between retry attempts at 10 s.
c = cdsapi.Client(sleep_max=10)

def download_static():
    """Fetch the time-invariant ERA5 fields used by Aurora into ``DOWNLOAD_PATH/static.nc``.

    Requests geopotential, land-sea mask and soil type from the single-level
    reanalysis as NetCDF, then prints a confirmation.
    """
    request = {
        "product_type": "reanalysis",
        "variable": ["geopotential", "land_sea_mask", "soil_type"],
        # Any single timestamp will do: these fields are constant over time.
        "year": "2023",
        "month": "01",
        "day": "01",
        "time": "00:00",
        "format": "netcdf",
    }
    target = DOWNLOAD_PATH / "static.nc"
    c.retrieve("reanalysis-era5-single-levels", request, str(target))
    print("Static variables downloaded!")

def cds_date_fields(day_start: pd.Timestamp, day_end: pd.Timestamp) -> dict:
    """Build the CDS ``year``/``month``/``day`` request fields for an inclusive date range.

    A CDS request describes every combination of the listed years, months and
    days, so one request can only express a contiguous range that stays inside
    a single calendar month. Wider ranges are rejected instead of silently
    fetching the wrong days.

    Raises:
        ValueError: if ``day_end`` precedes ``day_start`` or the range spans
            more than one calendar month.
    """
    if day_end < day_start:
        raise ValueError(
            f"day_end ({day_end.date()}) is before day_start ({day_start.date()})"
        )
    if (day_start.year, day_start.month) != (day_end.year, day_end.month):
        raise ValueError(
            "A single CDS request cannot span multiple months "
            f"({day_start.date()} to {day_end.date()}); split the range per month."
        )
    return {
        "year": str(day_start.year),
        "month": str(day_start.month).zfill(2),
        "day": [str(d.day).zfill(2) for d in pd.date_range(day_start, day_end, freq="D")],
    }


def download_data(file_name: Path, data_source, vars, pressure_levels,
                  time=("00:00", "06:00", "12:00", "18:00"),
                  day_start=None, day_end=None):
    """Download ERA5 ``vars`` from the CDS dataset ``data_source`` into ``file_name`` (NetCDF).

    Args:
        file_name: destination path; missing parent directories are created.
        data_source: CDS dataset name, e.g. "reanalysis-era5-single-levels".
        vars: list of CDS variable names to request.
        pressure_levels: list of pressure levels to request, or a falsy value
            for single-level datasets.
        time: hour(s) of day to request; a single string or a sequence of strings.
        day_start, day_end: inclusive date range; default to the notebook's
            ``DAY_START`` / ``DAY_END``. Must lie within one calendar month.

    Raises:
        ValueError: if the date range is reversed or spans multiple months.

    The data is written to a ``<name>.part`` sibling first and renamed only after
    ``retrieve`` returns, so an interrupted download never leaves a truncated
    file that a later ``file_name.exists()`` cache check would accept.
    """
    day_start = DAY_START if day_start is None else day_start
    day_end = DAY_END if day_end is None else day_end

    params = {
        "product_type": "reanalysis",
        "variable": vars,
        **cds_date_fields(day_start, day_end),
        # Copy into a fresh list; a bare string is passed through unchanged.
        "time": time if isinstance(time, str) else list(time),
        "format": "netcdf",
    }
    if pressure_levels:
        params["pressure_level"] = pressure_levels

    file_name.parent.mkdir(parents=True, exist_ok=True)
    partial = file_name.with_name(file_name.name + ".part")
    c.retrieve(data_source, params, str(partial))
    partial.replace(file_name)

In [9]:
# Static fields never change, so a single cached copy serves every date range.
if not (DOWNLOAD_PATH / "static.nc").exists():
    download_static()

# Surface-level variables, cached per date range.
surface_path = (
    DOWNLOAD_PATH
    / "surf_vars"
    / f"{DAY_START.strftime('%Y-%m-%d')}_{DAY_END.day}-surface-level.nc"
)
if not surface_path.exists():
    download_data(
        surface_path,
        data_source="reanalysis-era5-single-levels",
        vars=[
            "2m_temperature",
            "10m_u_component_of_wind",
            "10m_v_component_of_wind",
            "mean_sea_level_pressure",
        ],
        pressure_levels=None,
    )
    print("Surface-level variables downloaded!")

# Atmospheric variables on 13 pressure levels, cached per date range.
atmos_path = (
    DOWNLOAD_PATH
    / "atmos_vars"
    / f"{DAY_START.strftime('%Y-%m-%d')}_{DAY_END.day}-atmospheric.nc"
)
if not atmos_path.exists():
    download_data(
        atmos_path,
        data_source="reanalysis-era5-pressure-levels",
        vars=[
            "temperature",
            "u_component_of_wind",
            "v_component_of_wind",
            "specific_humidity",
            "geopotential",
        ],
        pressure_levels=[
            "50", "100", "150", "200", "250", "300", "400",
            "500", "600", "700", "850", "925", "1000",
        ],
    )
    print("Atmospheric variables downloaded!")

2025-10-20 10:14:52,117 INFO Request ID is c9f81d91-ed7e-4794-b61f-1224a4ed0118
2025-10-20 10:14:52,352 INFO status has been updated to accepted
Recovering from connection error [('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))], attempt 1 of 500
Retrying in 10 seconds
Recovering from connection error [('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))], attempt 2 of 500
Retrying in 10 seconds
Recovering from connection error [('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))], attempt 3 of 500
Retrying in 10 seconds
Recovering from connection error [('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))], attempt 4 of 500
Retrying in 10 seconds
Recovering from connection error [('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))], attempt 5 of 500
Retrying in 10 seconds
2025-10-20

afad128a23c5c9d318d3f1ea512ee6fb.nc:   0%|          | 0.00/52.5M [00:00<?, ?B/s]

Surface-level variables downloaded!


2025-10-20 10:21:38,578 INFO Request ID is cd9b239f-9819-4a3e-88af-48537f7f1243
2025-10-20 10:21:38,756 INFO status has been updated to accepted
2025-10-20 10:21:44,089 INFO status has been updated to running
Recovering from connection error [('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))], attempt 1 of 500
Retrying in 10 seconds
2025-10-20 10:25:52,916 INFO status has been updated to successful


3e640c84dc6739939af7a7408c9a7a4f.nc:   0%|          | 0.00/826M [00:00<?, ?B/s]

Atmospheric variables downloaded!


# Convert the data into an Aurora batch

In [10]:
# Open the three downloaded NetCDF files (static, surface, atmospheric) with the netCDF4 engine.
static_vars_ds, surf_vars_ds, atmos_vars_ds = (
    xr.open_dataset(nc_path, engine="netcdf4")
    for nc_path in (DOWNLOAD_PATH / "static.nc", surface_path, atmos_path)
)

In [11]:
# The model input is the first two time points in the files (00:00 and 06:00).
# `[None]` then prepends a batch dimension of size one.
batch = Batch(
    surf_vars={
        aurora_name: torch.from_numpy(surf_vars_ds[era5_name].values[:2][None])
        for aurora_name, era5_name in [
            ("2t", "t2m"),
            ("10u", "u10"),
            ("10v", "v10"),
            ("msl", "msl"),
        ]
    },
    # Static fields are constant, so only their first time slice is used.
    static_vars={
        name: torch.from_numpy(static_vars_ds[name].values[0])
        for name in ("z", "slt", "lsm")
    },
    atmos_vars={
        name: torch.from_numpy(atmos_vars_ds[name].values[:2][None])
        for name in ("t", "u", "v", "q", "z")
    },
    metadata=Metadata(
        lat=torch.from_numpy(surf_vars_ds.latitude.values),
        lon=torch.from_numpy(surf_vars_ds.longitude.values),
        # One timestamp per batch element (hence the 1-tuple): element 1, i.e. 06:00.
        # Casting to `datetime64[s]` makes `tolist()` yield `datetime.datetime` objects.
        time=(surf_vars_ds.valid_time.values.astype("datetime64[s]").tolist()[1],),
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
    ),
)

# Set up the Aurora model and run inference

In [12]:
model = Aurora(use_lora=False)  # The pretrained version does not use LoRA.

# Load the pretrained 0.25° weights. Without this the model keeps its freshly
# initialised parameters and the rollout below would produce meaningless output.
# The checkpoint goes to the default Hugging Face cache; pass `cache_dir=` to override.
model_path = hf_hub_download(
    repo_id=model.default_checkpoint_repo,
    filename="aurora-0.25-pretrained.ckpt",
)
model.load_checkpoint_local(model_path)

model.eval()
model = model.to("cuda")
print("Num parameters: ", sum(p.numel() for p in model.parameters()))

Num parameters:  1256300176


In [13]:
# The first two time steps are the model input; roll out over all remaining ones.
steps = surf_vars_ds["t2m"].shape[0] - 2
with torch.inference_mode():
    preds = [pred.to("cpu") for pred in rollout(model, batch, steps=steps)]

In [20]:
preds[1].atmos_vars["t"].size()  # shape of the temperature field in the second prediction

torch.Size([1, 1, 13, 720, 1440])

In [None]:
import sys

# Make the parent directory (which holds the `backend` package) importable.
# Guarded so re-running this cell does not keep appending duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")
from backend.cds_loader import CDSLoader

# Reuse the notebook's download directory; `as_posix()` + "/" gives exactly "../data/era5/".
loader = CDSLoader(cache_dir=DOWNLOAD_PATH.as_posix() + "/")