In [1]:
import os

# Enumerate GPUs by PCI bus ID and expose only device "1" to this process.
# These are set here, before `jax` is imported in the next cell.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

In [2]:
import dataclasses
import time
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
import xarray as xr
from flax.core.frozen_dict import FrozenDict
from jax import config, tree_util

from aurora import AuroraSmall, Batch, Metadata
from aurora.model.decoder import Perceiver3DDecoder
from aurora.model.encoder import Perceiver3DEncoder
from aurora.model.swin3d import Swin3DTransformerBackbone

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
config.update("jax_enable_x64", True)
config.update("jax_log_compiles", True)

model = AuroraSmall(use_lora=False)
rng = jax.random.PRNGKey(0)

download_path = Path("datasetEnviousScratch")
static_vars_ds = xr.open_dataset(download_path / "static.nc", engine="netcdf4")
surf_vars_ds = xr.open_dataset(download_path / "2023-01-01-surface-level.nc", engine="netcdf4")
atmos_vars_ds = xr.open_dataset(download_path / "2023-01-01-atmospheric.nc", engine="netcdf4")

i = 1  # Select this time index in the downloaded data.


def _two_steps(ds, name: str, idx: int):
    """Select `name` at time points `idx - 1` and `idx`, then add a batch axis of size one."""
    return jnp.array(ds[name].values[[idx - 1, idx]][None])


# Seconds since the Unix epoch of the current time index `i` (previously hard-coded to 1,
# which desynchronised the metadata from the data whenever `i` changed). Casting the
# `datetime64` value straight to an integer does not depend on the machine's timezone,
# unlike `datetime.timestamp()` on the naive datetimes produced by `tolist()`, which is
# interpreted in local time. On a UTC machine both give the same value.
current_time_s = int(surf_vars_ds.valid_time.values[i].astype("datetime64[s]").astype(np.int64))

batch = Batch(
    surf_vars={
        "2t": _two_steps(surf_vars_ds, "t2m", i),
        "10u": _two_steps(surf_vars_ds, "u10", i),
        "10v": _two_steps(surf_vars_ds, "v10", i),
        "msl": _two_steps(surf_vars_ds, "msl", i),
    },
    static_vars={
        # The static variables are constant, so we just get them for the first time.
        "z": jnp.array(static_vars_ds["z"].values[0]),
        "slt": jnp.array(static_vars_ds["slt"].values[0]),
        "lsm": jnp.array(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        "t": _two_steps(atmos_vars_ds, "t", i),
        "u": _two_steps(atmos_vars_ds, "u", i),
        "v": _two_steps(atmos_vars_ds, "v", i),
        "q": _two_steps(atmos_vars_ds, "q", i),
        "z": _two_steps(atmos_vars_ds, "z", i),
    },
    metadata=Metadata(
        lat=jnp.array(surf_vars_ds.latitude.values, dtype=jnp.float32),
        lon=jnp.array(surf_vars_ds.longitude.values, dtype=jnp.float32),
        # This needs to be a tuple of length one: one value for every batch element.
        time=(jnp.array(current_time_s, dtype=jnp.int64),),
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
    ),
)
patch_size = 4
surf_state = FrozenDict({})

batch = batch.normalise(surf_stats=surf_state)
batch = batch.crop(patch_size=patch_size)
B, T = next(iter(batch.surf_vars.values())).shape[:2]
# Repeat every static field over the batch and time axes of the surface variables.
# NOTE(review): assumes each static field is 2-D (presumably (H, W)) — confirm.
batch = dataclasses.replace(
    batch,
    static_vars={
        k: jnp.tile(jnp.expand_dims(v, (0, 1)), (B, T, 1, 1)) for k, v in batch.static_vars.items()
    },
)

# Checkpoint root; set AURORA_CHECKPOINT_ROOT to point at a different location.
checkpoint_root = Path(os.environ.get("AURORA_CHECKPOINT_ROOT", "/home1/a/akaush/aurora"))
checkpointer = ocp.StandardCheckpointer()
params_encoder = checkpointer.restore(checkpoint_root / "checkpoints")
params_backbone = checkpointer.restore(checkpoint_root / "checkpointsTillBackbone")
params_decoder = checkpointer.restore(checkpoint_root / "checkpointsTillDecoder")
params = {
    "encoder": params_encoder["encoder"],
    "backbone": params_backbone["backbone"],
    "decoder": params_decoder["decoder"],
}
params = jax.device_put(params, device=jax.devices("gpu")[0])
# `encoder_rng` is not used below; the split is kept so `rng` advances exactly as before.
encoder_rng, rng = jax.random.split(rng, 2)
# Lead time, presumably in seconds (21600 s = 6 h). NOTE(review): the encoder benchmark
# uses this value while later cells pass `model.timestep` — confirm the two agree.
timestep = 21600











In [4]:
def benchmark_encoder(timestep, encoder_module, encoder_vars, batch, rng, n_warmup=1, n_runs=20):
    """Time the jitted encoder forward pass and print mean/std latency in milliseconds.

    Args:
        timestep: Lead time bound (as ``lead_time``) into every ``apply`` call.
        encoder_module: Module whose ``apply(variables, batch, *, lead_time, training, rng)``
            is jitted and timed.
        encoder_vars: Variables passed as the first argument of ``apply``.
        batch: Input forwarded to every call.
        rng: PRNG key; a fresh subkey is split off before each call.
        n_warmup: Untimed calls made first (the first one triggers compilation).
        n_runs: Number of timed calls.
    """
    forward = jax.jit(
        partial(encoder_module.apply, encoder_vars, lead_time=timestep, training=False)
    )
    key = rng

    # Untimed warm-up calls.
    for _ in range(n_warmup):
        key, subkey = jax.random.split(key)
        forward(batch, rng=subkey).block_until_ready()

    print("completed all compilation")

    # Timed calls: the key split stays outside the timed region, and we wait for the
    # device result so the host-side timer measures the actual computation.
    elapsed_ms = []
    for _ in range(n_runs):
        key, subkey = jax.random.split(key)
        start = time.perf_counter()
        forward(batch, rng=subkey).block_until_ready()
        stop = time.perf_counter()
        elapsed_ms.append((stop - start) * 1000.0)

    stats = np.array(elapsed_ms)
    print(f"\nEncoder over {n_runs} runs: mean {stats.mean():.2f} ms  ± {stats.std():.2f} ms\n")

In [5]:
def benchmark_backbone(
    backbone_module, backbone_vars, x, patch_res, rollout_step, timestep, rng, n_warmup=1, n_runs=20
):
    """Time the jitted backbone forward pass and print warm-up and steady-state latency.

    Warm-up calls (the first one includes JIT compilation) are timed and reported
    separately, so the compile cost is visible but excluded from the steady-state stats.

    Args:
        backbone_module: Module whose ``apply(variables, x, *, lead_time, patch_res,
            rollout_step, training, rng)`` is jitted and timed.
        backbone_vars: Variables passed as the first argument of ``apply``.
        x: Input forwarded to every call.
        patch_res: Bound into ``apply`` as ``patch_res``.
        rollout_step: Bound into ``apply`` as ``rollout_step``.
        timestep: Bound into ``apply`` as ``lead_time``.
        rng: PRNG key; a fresh subkey is split off before each call.
        n_warmup: Number of warm-up calls. With 0, the warm-up line is not printed.
        n_runs: Number of timed steady-state calls.
    """
    jitted_back = jax.jit(
        partial(
            backbone_module.apply,
            backbone_vars,
            lead_time=timestep,
            patch_res=patch_res,
            rollout_step=rollout_step,
            training=False,
        )
    )

    # Warm-up runs. The key split sits outside the timed region, as in the timed loop.
    warmup_times = []
    for _ in range(n_warmup):
        rng, brng = jax.random.split(rng)
        t0 = time.perf_counter()
        out = jitted_back(x, rng=brng)
        out.block_until_ready()
        warmup_times.append((time.perf_counter() - t0) * 1000.0)
    if warmup_times:  # avoid nan statistics and RuntimeWarnings when n_warmup == 0
        arr = np.array(warmup_times)
        # Report the number of warm-up runs actually averaged (previously n_runs).
        print(
            f"backbone ➔ mean warmup {arr.mean():.2f} ms  ± {arr.std():.2f} ms over {n_warmup} runs"
        )

    print("completed all compilation")

    # Timed runs.
    times = []
    for _ in range(n_runs):
        rng, brng = jax.random.split(rng)
        t0 = time.perf_counter()
        out = jitted_back(x, rng=brng)
        out.block_until_ready()
        times.append((time.perf_counter() - t0) * 1000.0)

    arr = np.array(times)
    print(f"Backbone ➔ mean {arr.mean():.2f} ms  ± {arr.std():.2f} ms over {n_runs} runs")

In [6]:
def benchmark_decoder(
    decoder_module,
    decoder_vars,
    x,
    batch,
    patch_res,
    timestep,
    rng,
    n_warmup=1,
    n_runs=20,
):
    """
    JIT-compile only the Perceiver3DDecoder, then measure its execution time.

    Args:
        decoder_module: Module whose ``apply(variables, x, batch, *, lead_time, patch_res,
            training, rng)`` is jitted and timed.
        decoder_vars: Variables passed as the first argument of ``apply``.
        x: First positional input of every call.
        batch: Second positional input of every call.
        patch_res: Bound into ``apply`` as ``patch_res``.
        timestep: Bound into ``apply`` as ``lead_time``.
        rng: PRNG key; a fresh subkey is split off before each call.
        n_warmup: Untimed calls made first (the first one triggers compilation).
        n_runs: Number of timed calls.
    """
    # 1) Bind static kwargs and JIT it.
    jitted_dec = jax.jit(
        partial(
            decoder_module.apply,
            decoder_vars,
            lead_time=timestep,
            patch_res=patch_res,
            training=False,
        )
    )

    # The output may be a pytree. `jax.block_until_ready` waits on every array leaf and
    # passes through leaves without a `block_until_ready` method (the previous
    # hand-rolled tree_map would raise AttributeError on those).

    # 2) Warm-up (compile + first run), not timed.
    for _ in range(n_warmup):
        rng, drng = jax.random.split(rng)
        jax.block_until_ready(jitted_dec(x, batch, rng=drng))

    print("completed all compilations")

    # 3) Timed runs. The key split stays outside the timed region.
    timings = []
    for _ in range(n_runs):
        rng, drng = jax.random.split(rng)
        t0 = time.perf_counter()
        jax.block_until_ready(jitted_dec(x, batch, rng=drng))
        timings.append((time.perf_counter() - t0) * 1000.0)

    arr = np.array(timings)
    print(f"Decoder ➔ mean {arr.mean():.2f} ms  ± {arr.std():.2f} ms over {n_runs} runs")

In [7]:
# Perceiver encoder configured from the AuroraSmall hyper-parameters.
encoder_kwargs = dict(
    surf_vars_temp=model.surf_vars,
    static_vars=model.static_vars,
    atmos_vars=model.atmos_vars,
    patch_size=model.patch_size,
    embed_dim=model.embed_dim,
    num_heads=model.num_heads,
    drop_rate=model.drop_rate,
    mlp_ratio=model.mlp_ratio,
    head_dim=model.embed_dim // model.num_heads,
    depth=model.enc_depth,
    latent_levels=model.latent_levels,
    max_history_size=model.max_history_size,
    perceiver_ln_eps=model.perceiver_ln_eps,
    stabilise_level_agg=model.stabilise_level_agg,
)
encoder = Perceiver3DEncoder(**encoder_kwargs)
encoder_vars = {"params": params["encoder"]}
# NOTE(review): this passes the module-level `timestep` (21600), while the later cells
# pass `model.timestep` as the lead time; confirm the two agree.
benchmark_encoder(
    timestep=timestep,
    encoder_module=encoder,
    encoder_vars=encoder_vars,
    batch=batch,
    rng=rng,
)

















completed all compilation

Encoder over 20 runs: mean 79.52 ms  ± 1.86 ms



In [8]:
# Run the encoder once, outside the benchmark, to obtain `x` for the backbone cells.
jitted_enc = jax.jit(
    partial(encoder.apply, encoder_vars, lead_time=model.timestep, training=False)
)
rng, enc_rng = jax.random.split(rng)
x = jitted_enc(batch, rng=enc_rng).block_until_ready()

# Backbone inputs: latent level count plus the patch grid of the (cropped) batch.
H, W = batch.spatial_shape
n_patch_h, n_patch_w = (extent // model.patch_size for extent in (H, W))
patch_res = (model.latent_levels, n_patch_h, n_patch_w)
rollout_step = batch.metadata.rollout_step





In [9]:
# Swin3D backbone configured from the AuroraSmall hyper-parameters.
backbone_kwargs = {
    "window_size_temp": model.window_size,
    "encoder_depths": model.encoder_depths,
    "encoder_num_heads": model.encoder_num_heads,
    "decoder_depths": model.decoder_depths,
    "decoder_num_heads": model.decoder_num_heads,
    "embed_dim": model.embed_dim,
    "mlp_ratio": model.mlp_ratio,
    "drop_path_rate": model.drop_path,
    "drop_rate": model.drop_rate,
    "use_lora": model.use_lora,
    "lora_steps": model.lora_steps,
    "lora_mode": model.lora_mode,
}
backbone = Swin3DTransformerBackbone(**backbone_kwargs)
backbone_vars = {"params": params["backbone"]}

In [10]:
# Benchmark the backbone on the encoder output `x` produced by the encoder cell.
benchmark_backbone(
    backbone_module=backbone,
    backbone_vars=backbone_vars,
    x=x,
    patch_res=patch_res,
    rollout_step=rollout_step,
    timestep=model.timestep,
    rng=rng,
    n_warmup=1,
    n_runs=20,
)



































backbone ➔ mean warmup 25531.36 ms  ± 0.00 ms over 20 runs
completed all compilation
Backbone ➔ mean 1039.23 ms  ± 12.81 ms over 20 runs


In [11]:
# This cell overwrites `x` (the encoder output) with the backbone output. Running it a
# second time would therefore push backbone output through the backbone again. The
# decoder built below is configured for `model.embed_dim * 2` features, so an `x` whose
# last axis already has that width is refused.
# NOTE(review): assumes the feature axis of `x` is last — confirm against the encoder.
if x.shape[-1] == model.embed_dim * 2:
    raise RuntimeError(
        "`x` already has decoder-width features (backbone output?); "
        "re-run the encoder cell before re-running this one."
    )

jitted_back = jax.jit(
    partial(
        backbone.apply,
        backbone_vars,
        lead_time=model.timestep,
        patch_res=patch_res,
        rollout_step=rollout_step,
        training=False,
    )
)

rng, brng = jax.random.split(rng)
x = jitted_back(x, rng=brng)
x.block_until_ready()

# Decoder configured with twice the encoder's embed_dim (presumably matching the width
# of the backbone output computed above).
decoder = Perceiver3DDecoder(
    surf_vars=model.surf_vars,
    atmos_vars=model.atmos_vars,
    patch_size=model.patch_size,
    embed_dim=model.embed_dim * 2,
    head_dim=(model.embed_dim * 2) // model.num_heads,
    num_heads=model.num_heads,
    depth=model.dec_depth,
    mlp_ratio=model.dec_mlp_ratio,
    perceiver_ln_eps=model.perceiver_ln_eps,
)
decoder_vars = {"params": params["decoder"]}

















In [12]:
# Benchmark the decoder on the backbone output `x` from the previous cell.
benchmark_decoder(
    decoder_module=decoder,
    decoder_vars=decoder_vars,
    x=x,
    batch=batch,
    patch_res=patch_res,
    timestep=model.timestep,
    rng=rng,
    n_warmup=1,
    n_runs=20,
)







completed all compilations
Decoder ➔ mean 151.92 ms  ± 1.60 ms over 20 runs
