In [166]:
## import packages
import torch
import numpy as np
import xarray as xr
from model import ORCADLConfig, ORCADLModel

In [167]:
## Download the contents in stat and ckpt folder

## Prepare input data

# salinity, potential temp, sea surface temp, zonal current, meridional current, sea surface height, zonal wind stress, meridional wind stress
variables = ['salt', 'pottmp', 'sst', 'ucur', 'vcur', 'sshg', 'uflx', 'vflx']

# load mean and std (one .npy file per variable; the first axis is indexed by month below)
stat = {
    'mean': {v: np.load(f"./stat/mean/{v}.npy") for v in variables},
    'std': {v: np.load(f"./stat/std/{v}.npy") for v in variables}
}


def load_normalized(name, month, stat=stat):
    """Read `./example_data/<name>.nc` and standardize it with the monthly statistics.

    Args:
        name: variable name; used both as the file stem and as the key inside the dataset.
        month: index into the first axis of the mean/std arrays.
        stat: dict with 'mean' and 'std' sub-dicts keyed by variable name.

    Returns:
        numpy array `(values - mean[month]) / std[month]`.
    """
    # The context manager closes the file handle once the values are in memory.
    with xr.open_dataset(f"./example_data/{name}.nc") as ds:
        raw_values = ds[name].values
    return (raw_values - stat['mean'][name][month]) / stat['std'][name][month]


# load data
month = 0  # the corresponding statistical values for each month are different

# The last two variables are the atmospheric forcings; everything before them is ocean state.
ocean_fields = []
for v in variables[:-2]:
    field = load_normalized(v, month)
    # Fields that are not 3-D get a leading axis of size 1 so that all of them
    # can be concatenated along axis 0.
    ocean_fields.append(field if field.ndim == 3 else field[None])
atmo_fields = [load_normalized(v, month)[None] for v in variables[-2:]]

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NaNs are replaced by np.nan_to_num before conversion; [None] adds the batch axis.
ocean_vars = torch.from_numpy(np.nan_to_num(np.concatenate(ocean_fields, axis=0)))[None].float().to(device) # (1, 66, 128, 360)
atmo_vars = torch.from_numpy(np.nan_to_num(np.concatenate(atmo_fields, axis=0)))[None].float().to(device) # (1, 2, 128, 360)

print(ocean_vars.shape, atmo_vars.shape)

torch.Size([1, 66, 128, 360]) torch.Size([1, 2, 128, 360])


In [168]:
## Setup ORCA-DL
# Build the network from its JSON configuration file.
orca_config = ORCADLConfig.from_json_file('./model_config.json')
model = ORCADLModel(orca_config)

# The checkpoint is deserialized onto the CPU first; the model is moved to `device` afterwards.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the torch version —
# only load checkpoints from a trusted source.
checkpoint_state = torch.load('./ckpt/seed_1.bin', map_location='cpu')
model.load_state_dict(checkpoint_state)

# Move to the target device and switch to evaluation mode.
model.to(device).eval()

In [169]:
## Run the model

with torch.no_grad():
    # single step
    output = model(ocean_vars=ocean_vars, atmo_vars=atmo_vars, predict_time_steps=1)
    print(output.preds.shape) # (1, 66, 128, 360)

    # Post-process the output
    preds = output.preds.detach().cpu().numpy()
    # salinity, potential temp, sea surface temp, zonal current, meridional current, sea surface height
    pred_all_variables = np.split(preds, model.split_chans, axis=1) # split by channels
    # The pred_all_variables contains the prediction of all ocean variables and the order is the same as the input ocean variables.

    # inverse the normalization
    # The statistics must be those of the predicted month. Wrap around the end of the
    # statistics' month axis so the last month rolls over to index 0 instead of
    # indexing out of bounds.
    pred_month = (month + 1) % len(stat['mean']['sst'])
    pred_sst = pred_all_variables[2] * stat['std']['sst'][pred_month] + stat['mean']['sst'][pred_month]
    print(pred_sst.shape) # (1, 1, 128, 360)

    # multi steps
    steps = 6
    output = model(ocean_vars=ocean_vars, atmo_vars=atmo_vars, predict_time_steps=steps)
    print(output.preds.shape) # (1, steps, 66, 128, 360)

    # batch input
    # The batched copies get their own names so the original single-sample inputs are
    # left untouched and re-running this cell gives the same shapes every time.
    batch_size = 4
    ocean_vars_batch = ocean_vars.repeat(batch_size, 1, 1, 1)
    atmo_vars_batch = atmo_vars.repeat(batch_size, 1, 1, 1)
    output = model(ocean_vars=ocean_vars_batch, atmo_vars=atmo_vars_batch, predict_time_steps=steps)
    print(output.preds.shape) # (batch_size, steps, 66, 128, 360)



torch.Size([1, 66, 128, 360])
(1, 1, 128, 360)
torch.Size([1, 6, 66, 128, 360])
torch.Size([4, 6, 66, 128, 360])
