In [1]:
## import packages
import torch
import numpy as np
import xarray as xr
from model import ORCADLConfig, ORCADLModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Download the contents in stat and ckpt folder

## Prepare input data

# salinity, potential temp, sea surface temp, zonal current, meridional current, sea surface height, zonal wind stress, meridional wind stress
variables = ['salt', 'pottmp', 'sst', 'ucur', 'vcur', 'sshg', 'uflx', 'vflx']

# Normalization statistics: one mean array and one std array per variable.
stat = {
    'mean': {v: np.load(f"./stat/mean/{v}.npy") for v in variables},
    'std': {v: np.load(f"./stat/std/{v}.npy") for v in variables}
}
## Note: the unit for temperature is °C, salinity is g/kg, current is m/s, ssh is m, wind stress is N/m²

# Index into the first axis of the statistics; the statistical values differ for each month.
month = 0


def _load_normalized(name):
    """Load variable `name` from ./example_data/<name>.nc and normalize it.

    The values are standardized with the mean/std entries selected by the
    notebook-level `month`. The dataset is opened in a `with` block so the
    file handle is released as soon as the values have been read.

    Important: the units of the input data should be consistent with the
    units of the statistical values.

    Returns:
        numpy array holding (values - mean) / std for the selected month.
    """
    with xr.open_dataset(f"./example_data/{name}.nc") as ds:
        values = ds[name].values
    return (values - stat['mean'][name][month]) / stat['std'][name][month]


# Ocean variables are every entry except the last two (the wind stress fields).
ocean_vars = []
for v in variables[:-2]:
    normed_data = _load_normalized(v)
    # Arrays that are not already 3-D get a leading axis so everything can be
    # concatenated along axis 0 below.
    ocean_vars.append(normed_data if normed_data.ndim == 3 else normed_data[None])

# Atmospheric forcing (the last two variables) always gets a leading axis.
atmo_vars = [_load_normalized(v)[None] for v in variables[-2:]]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# NaNs are replaced by np.nan_to_num before the arrays are turned into float tensors
# with a leading batch axis.
ocean_vars = torch.from_numpy(np.nan_to_num(np.concatenate(ocean_vars, axis=0)))[None].float().to(device) # (1, 66, 128, 360)
atmo_vars = torch.from_numpy(np.nan_to_num(np.concatenate(atmo_vars, axis=0)))[None].float().to(device) # (1, 2, 128, 360)

print(ocean_vars.shape, atmo_vars.shape)

Using device: cuda
torch.Size([1, 66, 128, 360]) torch.Size([1, 2, 128, 360])


In [3]:
## Setup ORCA-DL
# Build the network from its JSON configuration.
config = ORCADLConfig.from_json_file('./model_config.json')
model = ORCADLModel(config)

# Read the checkpoint onto the CPU first, then copy the weights into the model.
state_dict = torch.load('./ckpt/seed_1.bin', map_location='cpu')
model.load_state_dict(state_dict)

# Move the model to the selected device and switch it to evaluation mode.
model.to(device)
model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


ORCADLModel(
  (enc_ocean): OceanEncoders(
    (encoder_list): ModuleList(
      (0-1): 2 x EncoderModule(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(16, 96, kernel_size=(2, 3), stride=(2, 3))
          (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (stages): ModuleList(
          (0): SwinEncoderStage(
            (blocks): ModuleList(
              (0): SwinLayer(
                (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                (attn): WindowAttention(
                  (pos_embed): RotaryPosEmbed2D()
                  (qkv): Linear(in_features=96, out_features=288, bias=True)
                  (attn_drop): Dropout(p=0.0, inplace=False)
                  (proj): Linear(in_features=96, out_features=96, bias=True)
                  (proj_drop): Dropout(p=0.0, inplace=False)
                  (softmax): Softmax(dim=-1)
                )
                (drop_pa

In [3]:
from netCDF4 import Dataset

import xarray as xr

# 1. Path of the .nc file to inspect (change it to your actual file path,
#    e.g. so_Omon_E3SM-1-0_historical_r1i1p1f1_gr_200501-200912.nc).
nc_file_path = "/mnt/data/zhu.yishun/ORCA-DL-main/data/train_data/2015-2029/zos_Omon_E3SM-1-1_historical_r1i1p1f1_gr_196501-196912.nc"

# 2. Open the .nc file (xarray loads lazily, so large files do not use much memory).
#    On failure the message is printed and the exception is re-raised: this stops the
#    cell with a traceback instead of calling exit(), which would end the kernel session.
try:
    ds = xr.open_dataset(nc_file_path)
except FileNotFoundError:
    print("错误：未找到指定的.nc文件，请检查文件路径是否正确！")
    raise
except Exception as err:
    print(f"错误：读取文件失败，异常信息：{err}")
    raise

# 3. Print the complete structure of the file. The `with` block closes the dataset
#    when the report is finished, and also if any of the steps below raises.
with ds:
    print("=" * 80)
    print("1. .nc文件整体结构概览（维度、变量、全局属性）：")
    print("=" * 80)
    # info() prints the core structure: dimensions, variables, dtypes, attributes.
    ds.info()

    print("\n" + "=" * 80)
    print("2. 详细维度信息：")
    print("=" * 80)
    # Every dimension with its name and size. `sizes` is the name -> length mapping;
    # using `dims` as a mapping is deprecated in recent xarray releases.
    for dim_name, dim_size in ds.sizes.items():
        print(f"维度名称：{dim_name}，维度大小：{dim_size}")

    print("\n" + "=" * 80)
    print("3. 详细变量信息：")
    print("=" * 80)
    # Every variable with its dimensions, shape, dtype and attributes.
    for var_name, var_data in ds.variables.items():
        print(f"变量名称：{var_name}")
        print(f"  - 变量维度：{list(var_data.dims)}")
        print(f"  - 变量形状：{var_data.shape}")
        print(f"  - 数据类型：{var_data.dtype}")
        print(f"  - 变量属性：{dict(var_data.attrs)}")
        print("-" * 40)

    print("\n" + "=" * 80)
    print("4. 全局属性信息（文件描述、来源等）：")
    print("=" * 80)
    # Global attributes (file description, provenance, ...).
    for attr_name, attr_value in ds.attrs.items():
        print(f"{attr_name}：{attr_value}")

1. .nc文件整体结构概览（维度、变量、全局属性）：
xarray.Dataset {
dimensions:
	time = 60 ;
	bnds = 2 ;
	lat = 180 ;
	lon = 360 ;

variables:
	object time(time) ;
		time:bounds = time_bnds ;
		time:axis = T ;
		time:long_name = time ;
		time:standard_name = time ;
	object time_bnds(time, bnds) ;
	float64 lat(lat) ;
		lat:bounds = lat_bnds ;
		lat:units = degrees_north ;
		lat:axis = Y ;
		lat:long_name = Latitude ;
		lat:standard_name = latitude ;
	float64 lat_bnds(lat, bnds) ;
	float64 lon(lon) ;
		lon:bounds = lon_bnds ;
		lon:units = degrees_east ;
		lon:axis = X ;
		lon:long_name = Longitude ;
		lon:standard_name = longitude ;
	float64 lon_bnds(lon, bnds) ;
	float32 zos(time, lat, lon) ;
		zos:standard_name = sea_surface_height_above_geoid ;
		zos:long_name = Sea Surface Height Above Geoid ;
		zos:comment = This is the dynamic sea level, so should have zero global area mean. It should not include inverse barometer depressions from sea ice. ;
		zos:units = m ;
		zos:cell_methods = area: mean where sea ti

In [5]:
## Run the model

# Inference only: no gradients are needed for any of the calls below.
with torch.no_grad():
    # single step
    output = model(ocean_vars=ocean_vars, atmo_vars=atmo_vars, predict_time_steps=1)
    print(output.preds.shape) # (1, 66, 128, 360)

    # Post-process the output
    preds = output.preds.detach().cpu().numpy()
    # salinity, potential temp, sea surface temp, zonal current, meridional current, sea surface height
    pred_all_variables = np.split(preds, model.split_chans, axis=1) # split by channels
    # pred_all_variables contains the prediction of all ocean variables, in the same
    # order as the input ocean variables.

    # Inverse the normalization with the statistics of the *predicted* month.
    # The index wraps around so the last month does not run past the end of the
    # statistics (assumes their first axis is a cyclic month axis, as implied by
    # the per-month indexing used during normalization; confirm for your stats).
    pred_month = (month + 1) % len(stat['mean']['sst'])
    pred_sst = pred_all_variables[2] * stat['std']['sst'][pred_month] + stat['mean']['sst'][pred_month]
    print(pred_sst.shape) # (1, 1, 128, 360)

    # multi steps
    steps = 6
    output = model(ocean_vars=ocean_vars, atmo_vars=atmo_vars, predict_time_steps=steps)
    print(output.preds.shape) # (1, steps, 66, 128, 360)

    # batch input
    # The batched tensors get their own names so the original single-sample inputs
    # stay intact and re-running this cell gives the same shapes every time.
    batch_size = 4
    batch_ocean_vars = ocean_vars.repeat(batch_size, 1, 1, 1)
    batch_atmo_vars = atmo_vars.repeat(batch_size, 1, 1, 1)
    output = model(ocean_vars=batch_ocean_vars, atmo_vars=batch_atmo_vars, predict_time_steps=steps)
    print(output.preds.shape) # (batch_size, steps, 66, 128, 360)



torch.Size([1, 66, 128, 360])
(1, 1, 128, 360)
torch.Size([1, 6, 66, 128, 360])
torch.Size([4, 6, 66, 128, 360])
