In [None]:
from datetime import datetime

import torch

from aurora import Batch, Metadata

batch = Batch(
    surf_vars={k: torch.randn(1, 2, 17, 32) for k in ("2t", "10u", "10v", "msl")},

    
    
    static_vars={k: torch.randn(17, 32) for k in ("lsm", "z", "slt")},
    
    
    
    
    atmos_vars={k: torch.randn(1, 2, 4, 17, 32) for k in ("z", "u", "v", "t", "q")},
    
    
    
    
    metadata=Metadata(
        lat=torch.linspace(90, -90, 17),
        lon=torch.linspace(0, 360, 32 + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

In [13]:
print(batch.surf_vars["2t"])  # torch.Size([1, 2, 17, 32])

tensor([[[[-1.1464,  0.6358,  0.4453,  ..., -1.0646, -1.1846, -1.1706],
          [-0.9513,  0.6625,  0.2254,  ..., -1.1913,  1.6600, -1.2473],
          [ 0.4726, -0.5628, -0.6538,  ..., -0.5497,  1.0078,  2.0545],
          ...,
          [-1.9612,  0.3900,  1.2261,  ..., -0.0078, -1.1267,  1.7106],
          [-0.9128, -0.2846, -0.6508,  ...,  1.7923,  1.4557,  0.0425],
          [ 0.5050, -1.0462, -0.4515,  ...,  0.2574,  0.5874,  2.3478]],

         [[-0.5091,  0.7613, -0.0786,  ..., -0.2584, -0.3589,  1.1545],
          [ 1.1219, -0.2821,  1.8236,  ...,  0.3181,  0.1216, -1.1880],
          [ 0.2346,  0.4229,  1.4610,  ...,  0.5245, -0.2199,  0.0477],
          ...,
          [-0.1309, -0.3338, -1.7106,  ..., -0.1362,  0.0769,  0.1425],
          [ 1.5336, -0.0494, -1.5605,  ...,  0.4046,  1.0033,  0.2758],
          [ 0.6112,  1.1969, -0.5227,  ...,  1.4016, -0.0073,  0.2193]]]])


In [2]:
from aurora import AuroraSmallPretrained

model = AuroraSmallPretrained()

In [3]:
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")
model.eval()

AuroraSmallPretrained(
  (encoder): Perceiver3DEncoder(
    (surf_mlp): MLP(
      (net): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=256, bias=True)
        (3): Dropout(p=0.0, inplace=False)
      )
    )
    (surf_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (pos_embed): Linear(in_features=256, out_features=256, bias=True)
    (scale_embed): Linear(in_features=256, out_features=256, bias=True)
    (lead_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (absolute_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (atmos_levels_embed): Linear(in_features=256, out_features=256, bias=True)
    (surf_token_embeds): LevelPatchEmbed(
      (weights): ParameterDict(
          (10u): Parameter containing: [torch.FloatTensor of size 256x1x2x4x4]
          (10v): Parameter containing: [torch.FloatTensor of size 25

In [15]:
model = model.to("cuda")

with torch.inference_mode():
    pred = model.forward(batch)

In [19]:
print(pred)

Batch(surf_vars={'2t': tensor([[[[20.0437, 18.6310, 15.5559, 16.6797, 20.4803, 19.8742, 16.5494,
           17.5925, 20.1460, 18.7401, 16.0305, 17.0197, 20.3183, 19.2618,
           15.5291, 15.7694, 19.9460, 18.7969, 15.5286, 15.8089, 20.4991,
           19.5414, 16.2314, 17.1690, 19.7086, 18.4739, 15.4209, 15.3688,
           20.8304, 19.9917, 16.7790, 17.1962],
          [17.1007, 18.0583, 23.4591, 19.1319, 17.5022, 18.5994, 24.1140,
           19.4763, 17.0396, 18.0845, 23.6106, 19.6529, 17.0598, 17.7908,
           22.6020, 17.9711, 17.0170, 17.9990, 23.0249, 18.6161, 17.4302,
           18.5978, 24.1008, 19.4451, 16.5497, 17.2712, 22.4158, 18.1747,
           17.9031, 18.9515, 24.0762, 19.4297],
          [15.5054, 18.1775, 22.6843, 23.0163, 15.8080, 18.2989, 22.7220,
           23.4379, 14.9272, 17.9347, 22.4323, 22.8665, 15.4889, 17.8208,
           21.8174, 22.3219, 15.6401, 18.2738, 22.4199, 22.7979, 16.0986,
           18.8228, 23.3233, 23.7350, 15.3592, 17.8300, 21.9549, 22

In [24]:
def __repr__(self):
        s = [f"Batch: {len(self.surf_vars)} surf, {len(self.static_vars)} static, {len(self.atmos_vars)} atmos"]
        s.append(f"  Spatial shape: {self.spatial_shape}")
        s.append(f"  Time: {self.metadata.time}")
        s.append(f"  Levels: {self.metadata.atmos_levels}")
        for name, tensor in self.surf_vars.items():
            s.append(f"    surf[{name}]: {tuple(tensor.shape)}, mean={tensor.mean():.2f}")
        for name, tensor in self.atmos_vars.items():
            s.append(f"    atmos[{name}]: {tuple(tensor.shape)}, mean={tensor.mean():.2f}")
        return "\n".join(s)

In [26]:
print(__repr__(pred))

Batch: 4 surf, 3 static, 5 atmos
  Spatial shape: torch.Size([16, 32])
  Time: (datetime.datetime(2020, 6, 1, 18, 0),)
  Levels: (100, 250, 500, 850)
    surf[2t]: (1, 1, 16, 32), mean=21.11
    surf[10u]: (1, 1, 16, 32), mean=11.78
    surf[10v]: (1, 1, 16, 32), mean=8.90
    surf[msl]: (1, 1, 16, 32), mean=7846.57
    atmos[z]: (1, 1, 4, 16, 32), mean=80111.38
    atmos[u]: (1, 1, 4, 16, 32), mean=9.22
    atmos[v]: (1, 1, 4, 16, 32), mean=-0.23
    atmos[t]: (1, 1, 4, 16, 32), mean=240.83
    atmos[q]: (1, 1, 4, 16, 32), mean=0.00
