In [1]:
import os
# GPU selection for this process. These are set in the first cell, before
# `torch` is imported by the next one.
os.environ.update(
    {
        # Enumerate CUDA devices by PCI bus ID.
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        # Expose only the device with index 2 to this process.
        "CUDA_VISIBLE_DEVICES": "2",
    }
)

In [2]:
import time
import numpy as np
import torch

from aurora import AuroraSmall
from aurora.batch import Batch
from aurora.model.encoder import Perceiver3DEncoder
from aurora.model.swin3d import Swin3DTransformerBackbone
from aurora.model.decoder import Perceiver3DDecoder
from aurora.model.encoder import Perceiver3DEncoder
from aurora.model.swin3d import Swin3DTransformerBackbone
from aurora.model.decoder import Perceiver3DDecoder

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [None]:
# Instantiate AuroraSmall with `use_lora=False`; the checkpoint is loaded in the
# following cell.
model = AuroraSmall(use_lora=False)

In [4]:

# Load the pretrained weights, switch to eval mode and move the model to the
# device selected in the imports cell (instead of a hard-coded "cuda", so the
# CPU fallback computed there is honoured).
# NOTE(review): the recorded output of this cell is an AttributeError on
# `load_checkpoint` — confirm which `aurora` package the kernel resolves.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
model.eval()
model = model.to(device)

i = 1  # Select this time index in the downloaded data.

# NOTE(review): `jnp`, `Metadata`, `surf_vars_ds`, `static_vars_ds` and
# `atmos_vars_ds` are not imported or defined in any visible cell — confirm where
# they come from before relying on Restart & Run All.
# NOTE(review): the arrays below are built with `jnp.array`, while the model and
# the benchmark cells use torch (`model.to`, `torch.cuda.synchronize`) — confirm
# which array type `Batch` expects.
batch = Batch(
    surf_vars={
        # First select time points `i` and `i - 1`. Afterwards, `[None]` inserts a
        # batch dimension of size one.
        "2t": jnp.array(surf_vars_ds["t2m"].values[[i - 1, i]][None]),
        "10u": jnp.array(surf_vars_ds["u10"].values[[i - 1, i]][None]),
        "10v": jnp.array(surf_vars_ds["v10"].values[[i - 1, i]][None]),
        "msl": jnp.array(surf_vars_ds["msl"].values[[i - 1, i]][None]),
    },
    static_vars={
        # The static variables are constant, so we just get them for the first time.
        "z": jnp.array(static_vars_ds["z"].values[0]),
        "slt": jnp.array(static_vars_ds["slt"].values[0]),
        "lsm": jnp.array(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        "t": jnp.array(atmos_vars_ds["t"].values[[i - 1, i]][None]),
        "u": jnp.array(atmos_vars_ds["u"].values[[i - 1, i]][None]),
        "v": jnp.array(atmos_vars_ds["v"].values[[i - 1, i]][None]),
        "q": jnp.array(atmos_vars_ds["q"].values[[i - 1, i]][None]),
        "z": jnp.array(atmos_vars_ds["z"].values[[i - 1, i]][None]),
    },
    metadata=Metadata(
        lat=jnp.array(surf_vars_ds.latitude.values, dtype=jnp.float32),
        lon=jnp.array(surf_vars_ds.longitude.values, dtype=jnp.float32),
        # Converting to `datetime64[s]` ensures that the output of `tolist()` gives
        # `datetime.datetime`s. Note that this needs to be a tuple of length one:
        # one value for every batch element.
        # The entry is indexed with `i` so it stays in step with the time index
        # used for every other field (it was a hard-coded 1).
        # NOTE(review): `.timestamp()` on a naive datetime interprets it as local
        # time — confirm that is intended for `valid_time`.
        time=(
            jnp.array(
                surf_vars_ds.valid_time.values.astype("datetime64[s]").tolist()[i].timestamp(),
                dtype=jnp.int64,
            ),
        ),
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
    ),
)

# Patch-grid resolution passed to the backbone and decoder below: the encoder's
# latent level count, then height and width in units of `model.patch_size`.
H, W = batch.spatial_shape
patch_res = (
    model.encoder.latent_levels,
    H // model.patch_size,
    W // model.patch_size,
)
rollout_step = batch.metadata.rollout_step

AttributeError: "Aurora" object has no attribute "load_checkpoint". If "load_checkpoint" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

In [None]:
def benchmark_encoder(model, batch, n_warmup=1, n_runs=20):
    """Time repeated calls to ``model.encoder`` and print mean ± std in milliseconds.

    Args:
        model: Object exposing an ``encoder`` callable and a ``timestep`` attribute.
        batch: Input forwarded unchanged to ``model.encoder``.
        n_warmup: Number of untimed calls made before measuring.
        n_runs: Number of timed calls.

    Returns:
        None. A single summary line is printed.
    """
    # Wait for the device after each call so the wall-clock interval covers the
    # call's GPU work. On a host without CUDA, `torch.cuda.synchronize()` raises,
    # so fall back to a no-op there.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    times = []
    # This is an inference benchmark: disable gradient tracking so autograd
    # bookkeeping does not add to the measured time or to memory use.
    with torch.no_grad():
        # warm-up
        for _ in range(n_warmup):
            _ = model.encoder(batch, lead_time=model.timestep)
        sync()

        # timed
        for _ in range(n_runs):
            t0 = time.perf_counter()
            _ = model.encoder(batch, lead_time=model.timestep)
            sync()
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)  # seconds -> milliseconds

    arr = np.array(times)
    print(f"Encoder: mean {arr.mean():.2f} ms ± {arr.std():.2f} ms over {n_runs} runs")


In [None]:
def benchmark_backbone(model, x, patch_res, rollout_step, n_warmup=1, n_runs=20):
    """Time repeated calls to ``model.backbone`` and print mean ± std in milliseconds.

    Args:
        model: Object exposing a ``backbone`` callable and a ``timestep`` attribute.
        x: Input forwarded unchanged to ``model.backbone``.
        patch_res: Forwarded as the ``patch_res`` keyword argument.
        rollout_step: Forwarded as the ``rollout_step`` keyword argument.
        n_warmup: Number of untimed calls made before measuring.
        n_runs: Number of timed calls.

    Returns:
        None. A single summary line is printed.
    """
    # Wait for the device after each call so the wall-clock interval covers the
    # call's GPU work. On a host without CUDA, `torch.cuda.synchronize()` raises,
    # so fall back to a no-op there.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    times = []
    # This is an inference benchmark: disable gradient tracking so autograd
    # bookkeeping does not add to the measured time or to memory use.
    with torch.no_grad():
        # warm-up
        for _ in range(n_warmup):
            _ = model.backbone(
                x, lead_time=model.timestep, patch_res=patch_res, rollout_step=rollout_step
            )
        sync()

        # timed
        for _ in range(n_runs):
            t0 = time.perf_counter()
            _ = model.backbone(
                x, lead_time=model.timestep, patch_res=patch_res, rollout_step=rollout_step
            )
            sync()
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)  # seconds -> milliseconds

    arr = np.array(times)
    print(f"Backbone: mean {arr.mean():.2f} ms ± {arr.std():.2f} ms over {n_runs} runs")

In [None]:
def benchmark_decoder(model, x, batch, patch_res, n_warmup=1, n_runs=20):
    """Time repeated calls to ``model.decoder`` and print mean ± std in milliseconds.

    Args:
        model: Object exposing a ``decoder`` callable and a ``timestep`` attribute.
        x: First positional input forwarded unchanged to ``model.decoder``.
        batch: Second positional input forwarded unchanged to ``model.decoder``.
        patch_res: Forwarded as the ``patch_res`` keyword argument.
        n_warmup: Number of untimed calls made before measuring.
        n_runs: Number of timed calls.

    Returns:
        None. A single summary line is printed.
    """
    # Wait for the device after each call so the wall-clock interval covers the
    # call's GPU work. On a host without CUDA, `torch.cuda.synchronize()` raises,
    # so fall back to a no-op there.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    times = []
    # This is an inference benchmark: disable gradient tracking so autograd
    # bookkeeping does not add to the measured time or to memory use.
    with torch.no_grad():
        # warm-up
        for _ in range(n_warmup):
            _ = model.decoder(
                x, batch, lead_time=model.timestep, patch_res=patch_res
            )
        sync()

        # timed
        for _ in range(n_runs):
            t0 = time.perf_counter()
            _ = model.decoder(
                x, batch, lead_time=model.timestep, patch_res=patch_res
            )
            sync()
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000.0)  # seconds -> milliseconds

    arr = np.array(times)
    print(f"Decoder: mean {arr.mean():.2f} ms ± {arr.std():.2f} ms over {n_runs} runs")

In [None]:
# Benchmark the three stages in order. Each stage is first timed, then run once
# more to produce the input for the next stage.
# Everything runs under `no_grad`: this is inference only, so the intermediate
# activations `x_enc` / `x_back` do not need to carry autograd state.
with torch.no_grad():
    # 1) Encoder → x_enc
    benchmark_encoder(model, batch, n_warmup=1, n_runs=20)
    x_enc = model.encoder(batch, lead_time=model.timestep)
    # Only synchronise when CUDA is present; the call raises on CPU-only hosts.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # 2) Backbone → x_back
    benchmark_backbone(model, x_enc, patch_res, rollout_step, n_warmup=1, n_runs=20)
    x_back = model.backbone(
        x_enc,
        lead_time=model.timestep,
        patch_res=patch_res,
        rollout_step=rollout_step,
    )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # 3) Decoder
    benchmark_decoder(model, x_back, batch, patch_res, n_warmup=1, n_runs=20)