In [3]:
from datetime import datetime

import torch

from aurora import Batch, Metadata

batch = Batch(
    surf_vars={k: torch.randn(1, 2, 17, 32) for k in ("2t", "10u", "10v", "msl")},
    static_vars={k: torch.randn(17, 32) for k in ("lsm", "z", "slt")},
    atmos_vars={k: torch.randn(1, 2, 4, 17, 32) for k in ("z", "u", "v", "t", "q")},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 17),
        lon=torch.linspace(0, 360, 32 + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

In [4]:
from aurora import Aurora

model = Aurora()

In [5]:
from aurora import AuroraSmallPretrained

model = AuroraSmallPretrained()

In [6]:
model.load_checkpoint()


In [7]:
model.eval()


AuroraSmallPretrained(
  (encoder): Perceiver3DEncoder(
    (surf_mlp): MLP(
      (net): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=256, bias=True)
        (3): Dropout(p=0.0, inplace=False)
      )
    )
    (surf_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (pos_embed): Linear(in_features=256, out_features=256, bias=True)
    (scale_embed): Linear(in_features=256, out_features=256, bias=True)
    (lead_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (absolute_time_embed): Linear(in_features=256, out_features=256, bias=True)
    (atmos_levels_embed): Linear(in_features=256, out_features=256, bias=True)
    (surf_token_embeds): LevelPatchEmbed(
      (weights): ParameterDict(
          (10u): Parameter containing: [torch.FloatTensor of size 256x1x2x4x4]
          (10v): Parameter containing: [torch.FloatTensor of size 25

In [8]:
model = model.to("cuda")

In [9]:
with torch.inference_mode():
    pred = model.forward(batch)

In [10]:
type(pred)

aurora.batch.Batch

In [11]:
pred.surf_vars.keys()

dict_keys(['2t', '10u', '10v', 'msl'])

In [12]:
pred

Batch(surf_vars={'2t': tensor([[[[20.9424, 19.1772, 15.9650, 16.4379, 20.7231, 19.6562, 16.3069,
           16.7527, 21.0247, 19.6238, 16.0983, 16.3456, 21.7314, 20.2208,
           16.9228, 17.7034, 20.9521, 19.4392, 16.7328, 17.6951, 21.2990,
           20.0245, 17.0677, 17.8823, 21.9993, 20.1649, 16.9025, 17.8705,
           22.3823, 20.9948, 17.7374, 18.7185],
          [17.7776, 18.0994, 22.9549, 18.8170, 17.3239, 18.0355, 23.0388,
           18.9840, 17.7986, 18.5142, 23.0349, 18.5105, 18.2674, 18.8969,
           23.9762, 19.7894, 18.3739, 18.9633, 24.1160, 20.1490, 18.5201,
           19.1021, 24.4223, 20.2393, 19.2858, 19.7445, 24.6481, 20.7181,
           19.8323, 20.4846, 25.4977, 21.0825],
          [16.1193, 18.3776, 22.4140, 23.0987, 15.8212, 18.1243, 22.1546,
           22.9271, 15.9934, 18.3437, 22.1802, 22.3493, 16.0675, 18.7025,
           23.0259, 23.5865, 16.4751, 18.8062, 22.9919, 23.4730, 17.0847,
           19.4280, 23.8191, 24.5496, 17.5213, 19.8416, 23.9122, 24

In [13]:
from aurora import AuroraWave


In [14]:
model=AuroraWave()

In [15]:
model.load_checkpoint()

In [18]:
#1= scalar variable, 2= number of timesteps in input, 17= number of latitudes, 32= number of longitudes

batch = Batch(
    surf_vars={k: torch.randn(1, 2, 17, 32) for k in ("2t", "10u", "10v", "swh", "mwd", "mwp", "pp1d", "shww", "mdww", "mpww", "shts", "mdts", "mpts", "swh1", "mwd1", "mwp1", "swh2", "mwd2", "mwp2", "10u_wave", "10v_wave", "wind")},
    static_vars={k: torch.randn(17, 32) for k in ("lsm", "z", "slt","z","wmb")},
    atmos_vars={k: torch.randn(1, 2, 4, 17, 32) for k in ("z", "u", "v", "t", "q")},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 17),
        lon=torch.linspace(0, 360, 32 + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

In [20]:
pred= model.forward(batch)


AssertionError: 