In [1]:
import xarray as xr
import numpy as np
from scipy.interpolate import interp1d

In [None]:

import sys
import os
import argparse

import xarray as xr
import numpy as np
from scipy.interpolate import interp1d

# Accepted names for each logical axis, in the order they are tried.
axis_candidates = dict(
    time=['t', 'time', 'time_counter'],
    x=['x', 'lon', 'x_grid_T', 'x_grid_U', 'x_grid_V', 'x_grid_W'],
    y=['y', 'lat', 'y_grid_T', 'y_grid_U', 'y_grid_V', 'y_grid_W'],
    z=['z', 'lev', 'nav_lev', 'depth', 'deptht', 'depthu', 'depthv', 'depthw'],
)


def detect_axis(ds, axis_type, where='dims'):
    """Find the name used in ``ds`` for a logical axis.

    Parameters
    ----------
    ds : object
        Dataset-like object exposing ``dims``, ``coords`` and ``data_vars``.
    axis_type : str
        Key of ``axis_candidates`` ('time', 'x', 'y' or 'z'). An unknown key
        has no candidates, so nothing can match.
    where : str, optional
        'dims' or 'coords' searches the attribute of that name; any other
        value searches ``ds.data_vars``.

    Returns
    -------
    str or None
        The first candidate present in the searched container, or ``None``
        (after printing a message) when no candidate matches.
    """
    names = axis_candidates.get(axis_type, [])
    pool = getattr(ds, where, {}) if where in ('dims', 'coords') else ds.data_vars

    # Candidates are tried in list order; the first hit wins.
    match = next((name for name in names if name in pool), None)
    if match is None:
        print(f"No {axis_type} {where} found among {names}")
    return match

def interpolate(data, old_depths, new_depths, axis):
    """Resample ``data`` from ``old_depths`` onto ``new_depths`` along ``axis``.

    Uses ``scipy.interpolate.interp1d`` with its default interpolation kind.
    Targets outside the range of ``old_depths`` are extrapolated rather than
    rejected or filled with NaN.

    Parameters
    ----------
    data : array-like
        Values to interpolate; its length along ``axis`` must match
        ``old_depths``.
    old_depths : array-like
        1-D coordinate of ``data`` along ``axis``.
    new_depths : array-like
        Coordinate values at which to evaluate.
    axis : int
        Axis of ``data`` that corresponds to ``old_depths``.

    Returns
    -------
    numpy.ndarray
        ``data`` evaluated at ``new_depths`` along ``axis``.
    """
    options = dict(axis=axis, bounds_error=False, fill_value="extrapolate")
    return interp1d(old_depths, data, **options)(new_depths)


# On-disk encoding attached to every vertically interpolated variable so that a
# later ``to_netcdf`` call picks it up.
_INTERP_ENCODING = {
    '_FillValue': 9.96921e+36,
    'missing_value': 9.96921e+36,
    'zlib': True,
    'complevel': 4,
    'dtype': 'float32'
}


def _read_depths(srcdomain_nc, dstdomain_nc):
    """Return ``(old_depths, new_depths)`` read from the two domain files."""
    # The depth arrays are loaded into memory, so both files can be closed
    # as soon as the values have been read.
    with xr.open_dataset(srcdomain_nc) as srcdomain, xr.open_dataset(dstdomain_nc) as dstdomain:
        depth = detect_axis(srcdomain, 'z', where='vars')
        if depth is None:
            raise ValueError(f"No depth variable found in {srcdomain_nc}")
        old_depths = np.asarray(srcdomain[depth].values)
        new_depths = np.asarray(dstdomain[depth].values)

    if old_depths.ndim != 1 or new_depths.ndim != 1:
        raise ValueError(
            f"Depth variable {depth!r} must be 1-D, got shapes "
            f"{old_depths.shape} (source) and {new_depths.shape} (destination)"
        )
    return old_depths, new_depths


def _interpolate_variable(var, z_dim, time_dim, old_depths, new_depths):
    """Interpolate ``var`` along ``z_dim``; the result keeps the dim order of ``var``."""
    if time_dim in var.dims and var.sizes[time_dim] > 0:
        n_times = var.sizes[time_dim]
        print(n_times)
        time_axis = var.dims.index(time_dim)
        # Work one time step at a time so only a single slab is expanded by
        # interp1d at once.
        slabs = []
        for t in range(n_times):
            slab = var.isel({time_dim: t})
            slabs.append(
                interpolate(slab.values, old_depths, new_depths, axis=slab.dims.index(z_dim))
            )
        # Re-insert the time axis where it was, so the dims stay var.dims.
        return np.stack(slabs, axis=time_axis)

    return interpolate(var.values, old_depths, new_depths, axis=var.dims.index(z_dim))


def main(input_nc, srcdomain_nc, dstdomain_nc):
    """Interpolate every depth-dependent variable of a file onto a new vertical grid.

    The dimension names of ``input_nc`` are detected with ``detect_axis``. The
    source and destination depths are read from the first matching depth
    variable of ``srcdomain_nc`` and from the variable of the same name in
    ``dstdomain_nc``. Variables without the vertical dimension are passed
    through unchanged. All others are interpolated along their vertical
    dimension, wherever it sits in the dimension order.

    Parameters
    ----------
    input_nc : str
        Path of the file holding the fields to interpolate.
    srcdomain_nc : str
        Path of the domain file describing the current vertical grid.
    dstdomain_nc : str
        Path of the domain file describing the target vertical grid.

    Returns
    -------
    xarray.Dataset
        Dataset with the same variables and global attributes as the input.
        Interpolated variables keep their attributes and every coordinate that
        does not span the vertical dimension. They also carry the float32,
        compressed ``encoding`` used when the dataset is written.

    Raises
    ------
    ValueError
        If no depth variable is found in ``srcdomain_nc``, if a depth array is
        not 1-D, or if a variable's vertical size differs from the number of
        source depths.

    Notes
    -----
    The input dataset is left open on success because pass-through variables
    may still be lazily loaded from it. It is closed if an error occurs.
    """
    ds = xr.open_dataset(input_nc)
    try:
        # assign axis
        axes = {}
        for ax in ['time', 'x', 'y', 'z']:
            axes[ax] = detect_axis(ds, ax, where='dims')
            print(f"Detected axis {ax}: {axes[ax]}")
        time_dim, z_dim = axes['time'], axes['z']

        # assign depths
        old_depths, new_depths = _read_depths(srcdomain_nc, dstdomain_nc)

        output_vars = {}
        interpolated = []
        for varname in ds.data_vars:
            print(varname)
            var = ds[varname]

            if z_dim not in var.dims:
                output_vars[varname] = var  # Keep the variable as is if no vertical dimension
                continue

            if var.sizes[z_dim] != old_depths.size:
                raise ValueError(
                    f"Variable {varname!r} has {var.sizes[z_dim]} levels along {z_dim!r} "
                    f"but the source domain defines {old_depths.size} depths"
                )

            new_array = _interpolate_variable(var, z_dim, time_dim, old_depths, new_depths)

            # Coordinates defined on the old vertical grid no longer match the
            # data, so they are dropped; every other coordinate is kept.
            stale = [name for name, coord in var.coords.items() if z_dim in coord.dims]
            new_coords = var.drop_vars(stale).coords

            output_vars[varname] = xr.DataArray(
                new_array, dims=var.dims, coords=new_coords, name=varname, attrs=var.attrs
            )
            interpolated.append(varname)

        new_ds = xr.Dataset(output_vars)
        new_ds.attrs = ds.attrs
        for varname in interpolated:
            new_ds[varname].encoding.update(_INTERP_ENCODING)
    except Exception:
        ds.close()
        raise

    return new_ds
    

In [2]:
# Open the input file for interactive inspection (the same path is passed to
# main() as input_nc further down).
ds = xr.open_dataset('/perm/itas/data/nemo/woce/woce_temp_monthly_init_4p2.nc')

In [3]:
# Domain file of the source grid (the same path is passed to main() as srcdomain_nc).
domain1 = xr.open_dataset('/ec/res4/hpcperm/itas/data/ece-4-database/nemo/domain/eORCA1/domain_cfg.nc')

In [4]:
# Domain file of the destination grid (the same path is passed to main() as dstdomain_nc).
domain2 = xr.open_dataset('/ec/res4/hpcperm/itas/data/ece-4-database/nemo/domain/ORCA2/domain_cfg.nc')

In [19]:
# List the data variables of ds together with their dimensions.
ds.data_vars

Data variables:
    contemp  (time, z, y, x) float64 858MB ...
    nav_lat  (y, x) float64 953kB ...
    nav_lev  (z) float64 600B ...
    nav_lon  (y, x) float64 953kB ...

In [None]:
# Pick the (time, z, y, x) variable from the listing above for the checks that follow.
var = ds['contemp']

In [25]:
# Report whether the selected variable carries a 'time' dimension.
print(
    "Variable has a time dimension."
    if 'time' in var.dims
    else "Variable does not have a time dimension."
)

Variable has a time dimension.


In [34]:
# Paths handed to main(); they are the same three files opened individually
# in the cells above.
input_nc='/perm/itas/data/nemo/woce/woce_temp_monthly_init_4p2.nc'
srcdomain_nc='/ec/res4/hpcperm/itas/data/ece-4-database/nemo/domain/eORCA1/domain_cfg.nc'
dstdomain_nc='/ec/res4/hpcperm/itas/data/ece-4-database/nemo/domain/ORCA2/domain_cfg.nc'

# main() prints the detected axes and each variable name as it goes, and
# returns the resulting Dataset.
data = main(input_nc, srcdomain_nc, dstdomain_nc)

Detected axis time: time
Detected axis x: x
Detected axis y: y
Detected axis z: z
contemp
12
nav_lat
nav_lev
nav_lon


In [38]:
# Display the Dataset returned by main().
data

In [30]:
# Depth arrays taken from the 'nav_lev' variable of the source (domain1) and
# destination (domain2) domain files, as plain NumPy arrays.
old_depths = domain1['nav_lev'].values
new_depths = domain2['nav_lev'].values

In [74]:
# Attach the source depths to ds as a 'nav_lev' coordinate on the 'z' dimension.
# NOTE(review): this rebinds ds, so every cell executed afterwards sees the
# modified dataset rather than the file as opened.
ds = ds.assign_coords(nav_lev=("z", old_depths))

In [75]:
# Confirm that the 'nav_lev' coordinate is now present on ds.
print(
    "Dataset contains 'nav_lev' coordinate."
    if 'nav_lev' in ds.coords
    else "Dataset does not contain 'nav_lev' coordinate."
)

Dataset contains 'nav_lev' coordinate.


In [34]:
def vertical_interp(field, old_z, new_z, axis=0):
    """Interpolate ``field`` from the levels ``old_z`` onto ``new_z``.

    Parameters
    ----------
    field : numpy.ndarray
        Array whose dimension ``axis`` is the vertical one; its length along
        that axis must equal ``len(old_z)``.
    old_z : array-like
        1-D vertical coordinate of ``field``.
    new_z : array-like
        Vertical coordinate values to evaluate at.
    axis : int, optional
        Position of the vertical dimension in ``field``. Defaults to 0, which
        matches the previous behaviour where the leading axis was always
        interpolated.

    Returns
    -------
    numpy.ndarray
        ``field`` evaluated at ``new_z`` along ``axis``. Values outside the
        range of ``old_z`` are extrapolated.
    """
    interp_func = interp1d(old_z, field, axis=axis, bounds_error=False, fill_value="extrapolate")
    return interp_func(new_z)

In [39]:
# Select a single time index (e.g., t=0) to match the shape of old_depths along the interpolation axis
# This is the same interp1d call that vertical_interp() wraps: axis 0 of the
# selected slice is interpolated and values outside old_depths are extrapolated.
# NOTE(review): this reads 'thetao' and 'time_counter' from ds, whereas the
# ds.data_vars listing shown earlier in this notebook has only 'contemp' on a
# 'time' dimension - presumably ds pointed at another file when this ran; confirm.
field = interp1d(old_depths, ds['thetao'].isel(time_counter=0).values, axis=0, bounds_error=False, fill_value="extrapolate")(new_depths)

In [52]:
# Interpolate 'thetao' and 'so' onto new_depths one time step at a time and
# collect the results as DataArrays keyed by variable name.
# NOTE(review): 'thetao', 'so' and 'time_counter' do not appear in the
# ds.data_vars listing shown earlier in this notebook ('contemp' on 'time') -
# presumably ds referred to a different file when this cell ran; confirm
# before re-running.
# NOTE: the loop rebinds `var` to a string, replacing the DataArray assigned to
# `var` in an earlier cell.
interp_vars = {}
for var in ['thetao', 'so']:
    interp_data = []
    for t in range(ds.sizes['time_counter']):
        # vertical_interp works along axis 0 of the slice, so the slice with
        # time removed is assumed to have z as its leading axis - TODO confirm.
        slice_t = ds[var].isel(time_counter=t).values
        interp_slice = vertical_interp(slice_t, old_depths, new_depths)
        interp_data.append(interp_slice)
    # Stack the per-time results back into one array with time as axis 0.
    interp_array = np.stack(interp_data, axis=0)
    dims = ("time_counter", "z", "y", "x")
    # The vertical coordinate becomes the target depths; the others are taken from ds.
    coords = {"time_counter": ds["time_counter"], "z": new_depths, "y": ds["y"], "x": ds["x"]}
    interp_vars[var] = xr.DataArray(interp_array, dims=dims, coords=coords, name=var)

# Combine both interpolated variables into a single object.
new_ds = xr.merge([interp_vars['thetao'], interp_vars['so']])