In [7]:
import os

# Configure CUDA device enumeration and visibility. This has to happen before
# the next cell imports jax, so it lives in its own cell at the top.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

In [8]:
from pathlib import Path

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
import xarray as xr
from aurora import Aurora, AuroraSmall
from aurora.batch import Batch, Metadata
from aurora.model.encoder import Perceiver3DEncoder
from flax.core.frozen_dict import freeze, FrozenDict

# 1) Seed the PRNG and instantiate the model (LoRA disabled).
rng = jax.random.PRNGKey(0)
model = AuroraSmall(use_lora=False)

# 2) Open the three NetCDF inputs from the local download directory. They are
#    opened in the order static -> surface-level -> atmospheric.
download_path = Path("datasetEnviousScratch")
static_vars_ds, surf_vars_ds, atmos_vars_ds = (
    xr.open_dataset(download_path / file_name, engine="netcdf4")
    for file_name in (
        "static.nc",
        "2023-01-01-surface-level.nc",
        "2023-01-01-atmospheric.nc",
    )
)

# Time index to select from the downloaded data. The batch construction below
# reads both `i - 1` and `i`.
i = 1

def _history_pair(dataset, name, index):
    """Return `dataset[name]` at steps `index - 1` and `index` with a batch axis.

    The fancy index `[[index - 1, index]]` picks the two consecutive entries
    along the first axis of the variable, and `[None]` prepends a batch
    dimension of size one.
    """
    return jnp.array(dataset[name].values[[index - 1, index]][None])


# Seconds since the Unix epoch for every entry of `valid_time`. Converting the
# `datetime64[s]` values straight to integers avoids `datetime.timestamp()`,
# which interprets naive datetimes as *local* time and would therefore make the
# result depend on the time zone of the machine running the notebook.
valid_time_epoch_s = surf_vars_ds.valid_time.values.astype("datetime64[s]").astype("int64")

batch = Batch(
    surf_vars={
        # Keys are the names the batch uses; the second argument is the name of
        # the corresponding variable in the surface-level dataset.
        "2t": _history_pair(surf_vars_ds, "t2m", i),
        "10u": _history_pair(surf_vars_ds, "u10", i),
        "10v": _history_pair(surf_vars_ds, "v10", i),
        "msl": _history_pair(surf_vars_ds, "msl", i),
    },
    static_vars={
        # The static variables are constant, so only the first entry is taken.
        "z": jnp.array(static_vars_ds["z"].values[0]),
        "slt": jnp.array(static_vars_ds["slt"].values[0]),
        "lsm": jnp.array(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        "t": _history_pair(atmos_vars_ds, "t", i),
        "u": _history_pair(atmos_vars_ds, "u", i),
        "v": _history_pair(atmos_vars_ds, "v", i),
        "q": _history_pair(atmos_vars_ds, "q", i),
        "z": _history_pair(atmos_vars_ds, "z", i),
    },
    metadata=Metadata(
        lat=jnp.array(surf_vars_ds.latitude.values, dtype=jnp.float32),
        lon=jnp.array(surf_vars_ds.longitude.values, dtype=jnp.float32),
        # A tuple of length one: one value for every batch element. The time is
        # taken at index `i` so it always matches the most recent step selected
        # for the variables above (previously this was hard-coded to 1).
        time=(jnp.array(valid_time_epoch_s[i], dtype=jnp.int64),),
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
    ),
)
# Normalise the batch using an empty surface-statistics mapping, then crop it
# with the given patch size.
# NOTE(review): presumably `crop` trims the grid so it is compatible with the
# patch size -- confirm against `Batch.crop`.
surf_state = FrozenDict({})
patch_size = 4

batch = batch.normalise(surf_stats=surf_state).crop(patch_size=patch_size)



# Restore the three checkpoints with a single checkpointer instance. Each
# checkpoint contributes exactly one sub-tree ("encoder", "backbone" or
# "decoder") to the combined parameter dictionary.
# NOTE(review): these are absolute, user-specific paths; consider deriving them
# from a configurable checkpoint root so the notebook runs on other machines.
checkpointer = ocp.StandardCheckpointer()
params_encoder = checkpointer.restore("/home1/a/akaush/aurora/checkpoints")
params_backbone = checkpointer.restore("/home1/a/akaush/aurora/checkpointsTillBackbone")
params_decoder = checkpointer.restore("/home1/a/akaush/aurora/checkpointsTillDecoder")

params = {
    "encoder": params_encoder["encoder"],
    "backbone": params_backbone["backbone"],
    "decoder": params_decoder["decoder"],
}

# Look the target device up once and move the whole parameter tree onto it.
# `jax.devices("gpu")` raises if no GPU backend is available.
gpu_device = jax.devices("gpu")[0]
params = jax.device_put(params, device=gpu_device)

# The encoder sub-tree is part of `params`, which was just placed on the GPU,
# so no second transfer is needed.
encoder_params = params["encoder"]


# Split the key once, up front, so that the full-model call, the encoder-only
# call and any later consumer each get their own key. Previously `rng` was
# passed to the model and then split afterwards, i.e. the same key was used
# twice.
model_rng, encoder_rng, rng = jax.random.split(rng, 3)

# Full forward pass through the jitted model.
# NOTE(review): `training` is passed as an ordinary (traced) argument to the
# jitted function. If the model branches on it in Python, it will need
# `static_argnames=("training",)` -- confirm against the model's `__call__`.
apply_fn = jax.jit(model.apply)
pred = apply_fn({"params": params}, batch, training=False, rng=model_rng)

# Jitted entry point for running the encoder on its own.
# NOTE(review): this assumes `model.encoder` is accessible on the unbound
# module (e.g. not defined only inside `setup()`) -- verify.
encoder_apply_jit = jax.jit(model.encoder.apply)

# Lead time handed to the encoder; 21600 = 6 * 60 * 60.
# NOTE(review): presumably seconds, i.e. a six-hour step -- confirm the unit.
timestep = 21600

ModuleNotFoundError: No module named 'jax'

In [None]:
# Run only the encoder on the prepared batch, using the encoder parameters and
# the dedicated encoder PRNG key defined in the previous cell.
encoder_variables = {"params": encoder_params}
x = encoder_apply_jit(
    encoder_variables, batch, lead_time=timestep, training=False, rng=encoder_rng
)