In [6]:
import xarray as xr
import xgcm
import numpy as np

In [165]:
# Load the domain configuration file (grid coordinates + scale factors) from the test data
domcfg = xr.open_dataset('../xbasin/tests/data/xnemogcm.domcfg_fr.nc')

# Scale-factor variables for every grid point type, keyed by the axis they measure:
# e1* along X, e2* along Y, e3*_0 along Z (Z additionally has the w-point factor).
metrics = {
    ('X',): [f'e1{point}' for point in 'tuvf'],
    ('Y',): [f'e2{point}' for point in 'tuvf'],
    ('Z',): [f'e3{point}_0' for point in 'tuvfw'],
}
grid = xgcm.Grid(domcfg, metrics=metrics)
print(domcfg)

<xarray.Dataset>
Dimensions:    (x_c: 20, x_f: 20, y_c: 40, y_f: 40, z_c: 36, z_f: 36)
Coordinates:
  * x_c        (x_c) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * y_c        (y_c) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
  * x_f        (x_f) float64 0.5 1.5 2.5 3.5 4.5 ... 15.5 16.5 17.5 18.5 19.5
  * y_f        (y_f) float64 0.5 1.5 2.5 3.5 4.5 ... 35.5 36.5 37.5 38.5 39.5
  * z_c        (z_c) float32 5.0 15.0 25.0 35.0 ... 3379.915 3786.7092 4219.6006
  * z_f        (z_f) float32 4.5 14.5 24.5 34.5 ... 3379.415 3786.2092 4219.1006
Data variables:
    glamt      (y_c, x_c) float64 ...
    glamf      (y_f, x_f) float64 ...
    gphit      (y_c, x_c) float64 ...
    gphif      (y_f, x_f) float64 ...
    e1t        (y_c, x_c) float64 ...
    e1u        (y_c, x_f) float64 ...
    e1v        (y_f, x_c) float64 ...
    e1f        (y_f, x_f) float64 ...
    e2t        (y_c, x_c) float64 ...
    e2u        (y_c, x_f) float64 ...
    e2v        (y_f, x_c) fl

In [113]:
# Start from an array located at T-points (cell centers along Z)
da = domcfg['e3t_0'].copy(deep=True)

# Fetch the Z scale factor at da's location and integrate it from the surface;
# the cumulative sum lands on the shifted (w) points
e3 = grid.get_metric(da, axes='Z')
cumsum_kwargs = dict(boundary='fill', fill_value=0)
depths = grid.cumsum(e3, axis='Z', **cumsum_kwargs)

# Sanity check: the reconstruction should match the stored gdepw_0 exactly
depths_match = depths == domcfg['gdepw_0']
depths_match.all()

In [177]:
# Start from an array located at W-points (z_f)
da = domcfg['e3w_0'].copy(deep=True)

# Z scale factor at da's location; its coords show it lives on (z_f, y_c, x_c)
e3 = grid.get_metric(da, axes='Z')
print(e3.coords)

# The cumulative sum of e3w overshoots the t-point depths by half of the
# uppermost e3w, so remove that half-cell (checked against gdept_0 below)
top_half_cell = e3.isel(z_f=0).drop_vars('z_f') / 2
depths = grid.cumsum(e3, axis='Z', boundary='fill', fill_value=0) - top_half_cell

# Check: largest absolute deviation from gdept_0, and how many points differ at all
residual = depths - domcfg['gdept_0']
print(np.abs(residual).max())
print((residual != 0).sum())

Coordinates:
  * x_c      (x_c) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * y_c      (y_c) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
  * z_f      (z_f) float32 4.5 14.5 24.5 34.5 ... 3379.415 3786.2092 4219.1006
<xarray.DataArray ()>
array(4.54747351e-13)
<xarray.DataArray ()>
array(6)


In [125]:
# Handle on the xgcm Z axis, used in the next cell to inspect where an array sits along it
z = grid.axes['Z']

In [129]:
# Query the (position, dimension name) of `da` along Z, e.g. ('left', 'z_f') -- private xgcm API.
# NOTE(review): `da` is whatever the most recently executed cell assigned (execution counts are
# out of order), so this output depends on run order; TODO pass an explicit array instead.
z._get_axis_coord(da)

('left', 'z_f')

In [166]:
def compute_depth_of_shifted_array(grid, da, axis):
    """Integrate the scale factor of ``da`` along ``axis`` onto the shifted grid position.

    The metric of ``da`` along ``axis`` is cumulatively summed with ``grid.cumsum``,
    which places the result at the axis's default shifted position. When that
    position is the cell center, half of the first scale factor is removed so the
    result lands on the centers (e.g. t-point depths from e3w).

    Parameters
    ----------
    grid : xgcm.Grid
        Grid that owns ``axis`` and the metrics of ``da``.
    da : xarray.DataArray
        Array located on ``axis``; only its position and metric are used.
    axis : str
        Name of the grid axis to integrate along (e.g. 'Z').

    Returns
    -------
    xarray.DataArray
        Cumulative sum of the metric, located at the shifted position along ``axis``.

    Raises
    ------
    AssertionError
        If ``da`` or its shifted position is 'inner' or 'outer' (not supported).
    """
    axe = grid.axes[axis]
    # Private xgcm API: position ('center', 'left', ...) and dimension name of `da` on this axis
    old_pos, old_dim = axe._get_axis_coord(da)
    assert old_pos not in ('inner', 'outer'), f"unsupported position {old_pos!r}"
    new_pos = axe._default_shifts[old_pos]
    assert new_pos not in ('inner', 'outer'), f"unsupported shifted position {new_pos!r}"

    e3 = grid.get_metric(da, axes=axis)
    # Integrate along the same axis the metric was taken from (this was hard-coded to 'Z')
    depths = grid.cumsum(e3, axis=axis, boundary='fill', fill_value=0)
    # If the shifted position is a center point, remove half of the first scale factor
    # so the sum refers to the centers rather than the following faces
    if new_pos == 'center':
        depths = depths - e3.isel({old_dim: 0}).drop_vars(old_dim) / 2
    return depths

# Depths at the vertically shifted points, reconstructed from each scale factor
deptht = compute_depth_of_shifted_array(grid, domcfg.e3w_0, 'Z')
depthw = compute_depth_of_shifted_array(grid, domcfg.e3t_0, 'Z')
depthuw = compute_depth_of_shifted_array(grid, domcfg.e3u_0, 'Z')
depthvw = compute_depth_of_shifted_array(grid, domcfg.e3v_0, 'Z')
depthfw = compute_depth_of_shifted_array(grid, domcfg.e3f_0, 'Z')

assert (depthw == domcfg.gdepw_0).all(), "w-point depths differ from gdepw_0"
# deptht differs from gdept_0 by ~1e-13 on a few points (round-off), so compare the
# magnitude of the difference: a signed `<=` would let any negative deviation pass
assert (np.abs(deptht - domcfg.gdept_0) <= 1e-12).all(), "t-point depths differ from gdept_0"

In [175]:
# Largest absolute mismatch between depths integrated at u-points and w-point depths
# interpolated onto u-points, skipping the first/last index of the 2nd and 3rd dims.
# The abs matters: .max() of the signed difference would hide negative mismatches.
# NOTE(review): positional slicing assumes the dims are ordered (z, y, x) -- TODO confirm or use .isel
np.abs(depthuw - grid.interp(depthw, axis='X'))[:, 1:-1, 1:-1].max()

In [None]:
# NOTE(review): this unexecuted cell duplicates the previous one -- delete it or repurpose it
# (e.g. to check depthvw against grid.interp(depthw, axis='Y')).
# Absolute value so negative mismatches are not hidden by .max(); positional slicing assumes
# (z, y, x) dim order -- TODO confirm.
np.abs(depthuw - grid.interp(depthw, axis='X'))[:, 1:-1, 1:-1].max()