In [None]:
from datetime import datetime

import torch

from aurora import Batch, Metadata

batch = Batch(
    surf_vars={k: torch.randn(1, 2, 17, 32) for k in ("2t", "10u", "10v", "msl")},

    
    
    static_vars={k: torch.randn(17, 32) for k in ("lsm", "z", "slt")},
    
    
    
    
    atmos_vars={k: torch.randn(1, 2, 4, 17, 32) for k in ("z", "u", "v", "t", "q")},
    
    
    
    
    metadata=Metadata(
        lat=torch.linspace(90, -90, 17),
        lon=torch.linspace(0, 360, 32 + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

In [13]:
print(batch.surf_vars["2t"])  # torch.Size([1, 2, 17, 32])

tensor([[[[-1.1464,  0.6358,  0.4453,  ..., -1.0646, -1.1846, -1.1706],
          [-0.9513,  0.6625,  0.2254,  ..., -1.1913,  1.6600, -1.2473],
          [ 0.4726, -0.5628, -0.6538,  ..., -0.5497,  1.0078,  2.0545],
          ...,
          [-1.9612,  0.3900,  1.2261,  ..., -0.0078, -1.1267,  1.7106],
          [-0.9128, -0.2846, -0.6508,  ...,  1.7923,  1.4557,  0.0425],
          [ 0.5050, -1.0462, -0.4515,  ...,  0.2574,  0.5874,  2.3478]],

         [[-0.5091,  0.7613, -0.0786,  ..., -0.2584, -0.3589,  1.1545],
          [ 1.1219, -0.2821,  1.8236,  ...,  0.3181,  0.1216, -1.1880],
          [ 0.2346,  0.4229,  1.4610,  ...,  0.5245, -0.2199,  0.0477],
          ...,
          [-0.1309, -0.3338, -1.7106,  ..., -0.1362,  0.0769,  0.1425],
          [ 1.5336, -0.0494, -1.5605,  ...,  0.4046,  1.0033,  0.2758],
          [ 0.6112,  1.1969, -0.5227,  ...,  1.4016, -0.0073,  0.2193]]]])


In [2]:
from aurora import AuroraSmallPretrained

model = AuroraSmallPretrained()

In [3]:
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
model.eval()

AuroraSmallPretrained(
  (encoder): Perceiver3DEncoder(
    (surf_mlp): MLP(
      (net): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=256, bias=True)
        (3): Dropout(p=0.0, inplace=False)
      )
    )
    (surf_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (pos_embed): Linear(in_features=256, out_features=256, bias=True)
    (scale_embed): Linear(in_features=256, out_features=256, bias=True)
    (lead_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (absolute_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (atmos_levels_embed): Linear(in_features=256, out_features=256, bias=True)
    (surf_token_embeds): LevelPatchEmbed(
      (weights): ParameterDict(
          (10u): Parameter containing: [torch.FloatTensor of size 256x1x2x4x4]
          (10v): Parameter containing: [torch.FloatTensor of size 25

In [4]:
model = model.to("cuda")

with torch.inference_mode():
    pred = model.forward(batch)

In [5]:
print(pred)

Batch(surf_vars={'2t': tensor([[[[19.9701, 19.0364, 15.7767, 16.6024, 19.8266, 19.0175, 15.9789,
           17.0966, 21.2377, 20.6009, 17.4163, 17.6907, 21.6143, 21.0737,
           17.9236, 18.2475, 20.4469, 19.1351, 15.8234, 16.8593, 21.4810,
           20.4736, 16.9594, 17.3869, 20.8130, 19.6667, 16.6641, 18.1117,
           22.0461, 21.0992, 17.8457, 18.5716],
          [16.7538, 17.6721, 22.6890, 18.5505, 17.1299, 18.2172, 23.7073,
           19.6230, 18.1968, 19.0616, 23.8449, 19.1516, 18.6934, 19.6096,
           24.6696, 20.3122, 17.6131, 18.3799, 23.2815, 19.0904, 18.1641,
           18.9083, 23.9024, 19.3856, 17.8072, 18.6640, 24.1617, 20.0804,
           18.9976, 20.0550, 25.3648, 21.0304],
          [15.1945, 17.3554, 21.6204, 22.5040, 15.7787, 18.4281, 22.8770,
           23.5318, 16.8994, 19.0523, 22.6462, 23.0343, 16.9967, 19.4036,
           23.3286, 23.9851, 15.7477, 17.9239, 22.0143, 22.5868, 16.7141,
           18.9467, 23.1377, 23.7746, 16.0686, 18.5231, 22.9580, 23