In [9]:
from datetime import datetime

import torch

from aurora import Batch, Metadata

# Seed the global torch RNG so the random inputs (and hence the prediction)
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Grid / level configuration. Every tensor shape below is derived from these
# constants so the data and the metadata cannot drift apart.
N_LAT, N_LON = 17, 32
ATMOS_LEVELS = (100, 250, 500, 850)
N_TIME_STEPS = 2  # size of dim 1 of the surf/atmos tensors; presumably the input time steps — confirm against Aurora's Batch docs

batch = Batch(
    surf_vars={
        k: torch.randn(1, N_TIME_STEPS, N_LAT, N_LON)
        for k in ("2t", "10u", "10v", "msl")
    },
    static_vars={k: torch.randn(N_LAT, N_LON) for k in ("lsm", "z", "slt")},
    atmos_vars={
        # The level axis length is tied to ATMOS_LEVELS instead of a literal 4.
        k: torch.randn(1, N_TIME_STEPS, len(ATMOS_LEVELS), N_LAT, N_LON)
        for k in ("z", "u", "v", "t", "q")
    },
    metadata=Metadata(
        # N_LAT equally spaced latitudes, from 90 down to -90.
        lat=torch.linspace(90, -90, N_LAT),
        # N_LON equally spaced longitudes on [0, 360): generate N_LON + 1 points
        # and drop the last so 360 (the same meridian as 0) is not duplicated.
        lon=torch.linspace(0, 360, N_LON + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=ATMOS_LEVELS,
    ),
)

In [1]:
from aurora import AuroraSmallPretrained

# Build the small pretrained Aurora architecture. Pretrained weights are not
# loaded here — that happens in the `load_checkpoint` cell, which must run
# before any inference.
model = AuroraSmallPretrained()

In [None]:
# Load the small pretrained Aurora weights and switch to evaluation mode.
# This cell must run before the inference cell; otherwise the forward pass uses
# the freshly constructed (not pretrained) parameters.
# NOTE(review): "microsoft/aurora" is presumably a Hugging Face Hub repo id — confirm.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
# `eval()` returns the module itself; the trailing `;` stops Jupyter from
# rendering the entire architecture repr as this cell's output.
model.eval();

In [10]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA (instead of crashing on .to("cuda")).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# inference_mode disables autograd tracking for the forward pass.
with torch.inference_mode():
    # Call the module rather than `.forward` so any registered hooks run.
    pred = model(batch)

In [None]:
# Summarise the prediction instead of printing the whole Batch: its default repr
# prints every tensor value, which buries the result. Show one line per variable.
for group_name, tensors in (
    ("surf_vars", pred.surf_vars),
    ("static_vars", pred.static_vars),
    ("atmos_vars", pred.atmos_vars),
):
    for var_name, tensor in tensors.items():
        print(f"{group_name}[{var_name!r}]: shape={tuple(tensor.shape)}, device={tensor.device}")
print("time:", pred.metadata.time)

In [13]:
# Choose the compute device: the GPU when PyTorch reports CUDA as available,
# the CPU otherwise. Then report the choice.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on {device}")

Running on cuda


In [1]:
# Record the software stack used for this run: torch version, the CUDA version
# reported by the torch build, and the Python interpreter version.
import platform

import torch

print(
    "torch", torch.__version__,
    "cuda", torch.version.cuda,
    "python", platform.python_version(),
)


torch 2.7.1+cu128 cuda 12.8 python 3.13.3


In [2]:
# Fail fast when no CUDA device is visible, then report the name of CUDA device 0.
import os

import torch

assert torch.cuda.is_available(), "No CUDA GPU detected"
gpu_name = torch.cuda.get_device_name(0)
print("GPU:", gpu_name)


GPU: NVIDIA GeForce RTX 4070 Laptop GPU
