Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Estimate DZU, DZT #44

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions ci/environment-upstream-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies:
- pyyaml>=5.3.1
- scipy
- toolz
- zarr
- pip:
- git+https://github.com/pydata/xarray.git#egg=xarray
- git+https://github.com/dask/dask.git#egg=dask
Expand Down
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dependencies:
- xarray>=0.16.1
- xgcm
- watermark
- zarr
4 changes: 4 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Grid
~~~~

.. autosummary::
calc_dzu_dzt
get_grid


Expand Down Expand Up @@ -42,6 +43,7 @@ Utilities
~~~~~~~~~

.. autosummary::
four_point_min
lateral_fill


Expand All @@ -54,6 +56,8 @@ xgcm utilities

.. currentmodule:: pop_tools

.. autofunction:: calc_dzu_dzt

.. autofunction:: get_grid

.. autofunction:: eos
Expand Down
157 changes: 156 additions & 1 deletion pop_tools/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pooch
import xarray as xr
import yaml
from numba import jit, prange
from numba import double, float_, guvectorize, int_, jit, prange

try:
from tqdm import tqdm
Expand Down Expand Up @@ -463,3 +463,158 @@ def _compute_corners(ULAT, ULONG):
corner_lon[0, :, 3] = corner_lon[1, :, 3] - (corner_lon[2, :, 3] - corner_lon[1, :, 3])

return corner_lat, corner_lon


@guvectorize(
    [
        (int_[:, :], int_[:, :]),
        (float_[:, :], float_[:, :]),
        (double[:, :], double[:, :]),
    ],
    '(n,m)->(n,m)',
    nopython=True,
    cache=True,
)
def numba_4pt_min(var, out):
    """
    gufunc that writes, for every point of a 2D slice, the minimum over that
    point and its three neighbours at the next index along each axis::

        (i, j+1) ----- (i+1, j+1)
            |              |
            |              |
         (i, j)  ----- (i+1, j)

    The core signature is ``(n,m)->(n,m)``, so the kernel below always sees a
    single 2D slice (e.g. one depth level) at a time.

    Parameters
    ----------
    var : 2d numpy array
        Input slice.
    out : 2d numpy array
        Output slice with the same shape as ``var``. The last row and the last
        column have no complete 4-point window and are left at 0.
    """
    nrows, ncols = var.shape

    # Zero everything first so the trailing row/column, which the loops
    # below never visit, hold a defined value.
    out[:] = 0

    for row in prange(nrows - 1):
        for col in prange(ncols - 1):
            window = np.array(
                [var[row, col], var[row + 1, col], var[row, col + 1], var[row + 1, col + 1]]
            )
            out[row, col] = np.min(window)


def four_point_min(array, dims=('nlat', 'nlon')):
    """
    Utility function that calculates the minimum over 4 neighbouring points
    in 2D slices taken along dimensions ``dims``.

    Output at (i,j) is the minimum over the following 4 points::

        (i, j+1) ----- (i+1, j+1)
            |              |
            |              |
         (i, j)  ----- (i+1, j)

    Parameters
    ----------
    array: DataArray
        A 2D or 3D DataArray

    dims: tuple or list
        two element tuple or list of dimension names

    Returns
    -------
    DataArray
        Copy of ``array`` (transposed so that ``dims`` are the trailing
        dimensions) holding the 4-point minimum.

    Raises
    ------
    ValueError
        If ``dims`` does not contain exactly two entries.
    """

    import dask

    if len(dims) != 2:
        raise ValueError(f'Expected 2 dimensions. Received {dims} instead.')

    # The gufunc operates on the two trailing axes, so move ``dims`` there.
    transposed = array.transpose(..., *dims)
    values = transposed.data

    if not dask.is_dask_collection(values):
        return transposed.copy(data=numba_4pt_min(values))

    # map_overlap does not support negative axes, so spell out the positive
    # positions of the two trailing axes. Each block needs one extra point on
    # its upper side along both axes to complete the 4-point window.
    last_axis = transposed.ndim - 1
    overlap = {last_axis - 1: (0, 1), last_axis: (0, 1)}
    minimum = values.map_overlap(
        numba_4pt_min, depth=overlap, boundary='none', meta=values._meta
    )
    return transposed.copy(data=minimum)


def calc_dzu_dzt(grid):
    """
    Calculates DZT and DZU from a dataset containing dz, KMT and DZBC

    .. warning::

        This function does not do the right thing at the tripole grid seam.

    Parameters
    ----------
    grid: Dataset
        An xarray Dataset containing grid variables. This *must* contain partial bottom
        cell information: KMT and DZBC. Datasets with dimensions renamed for xgcm are not
        allowed.

    Returns
    -------
    DZT, DZU: DataArray
        Thickness of T cells and thickness of U cells, in that order.

    Raises
    ------
    ValueError
        If ``grid`` is not an xarray Dataset, if any of ``dz``, ``KMT`` or ``DZBC``
        is missing, or if the result carries an ``nlon_t`` dimension (i.e. the
        dataset was renamed for xgcm).

    Notes
    -----
    From Frank's zulip convo
    https://zulip.cloud.ucar.edu/#narrow/stream/9-CGD-OCE/topic/pop-tools/near/2864

        DZT[:,:,k] = dz[k] if k < KMT-1  # converting from Fortran to python indexing
        DZT[i,j,KMT[i,j]-1] = DZBC[i,j]
        DZU = min of 4 surrounding DZT

    NOTE(review): a reviewer pointed out that the tripole grid is the only one
    currently used with partial bottom cells, and suggested handling the seam by
    replacing the bottom row of a copy of KMT with its top row flipped
    left-to-right (``kmt[0, :] = kmt[-1, ::-1]``) before taking the 4-point
    minimum. That treatment is not implemented here.
    """

    if not isinstance(grid, xr.Dataset):
        raise ValueError(
            f'Expected xarray Dataset with grid variables. Received {type(grid).__name__} instead.'
        )

    expected_vars = ['dz', 'KMT', 'DZBC']
    missing_vars = set(expected_vars) - set(grid.variables)
    if missing_vars:
        raise ValueError(f'Variables {missing_vars} are missing in the provided dataset.')

    dz, KMT, DZBC = grid.dz, grid.KMT, grid.DZBC

    dzunit = dz.attrs.get('units', None)

    def _thickness_attrs(long_name, grid_loc):
        # Shared attribute layout for DZT and DZU; 'units' is only included
        # when dz itself carries a units attribute.
        attrs = {'standard_name': 'cell_thickness', 'long_name': long_name}
        if dzunit is not None:
            attrs['units'] = dzunit
        attrs['grid_loc'] = grid_loc
        return attrs

    # 1-based (Fortran-style) vertical level index, laid out like dz, so it can
    # be compared directly against KMT / KMU.
    nlevels = grid.sizes['z_t']
    fortran_zindex = dz.copy(data=np.arange(1, nlevels + 1))

    # The level equal to KMT takes the DZBC thickness; every other level keeps
    # the nominal dz.
    at_bottom_cell = fortran_zindex == KMT
    DZT = xr.where(at_bottom_cell, DZBC, dz)
    DZT.name = 'DZT'
    DZT.attrs = _thickness_attrs('Thickness of T cells', '3111')

    if 'nlon_t' in DZT.dims:
        raise ValueError('datasets renamed for xgcm are not allowed.')

    # now make DZU
    dzt_min = four_point_min(DZT)
    KMU = four_point_min(KMT)

    # In Fortran-like code, DZU is computed using a WORK variable that has DZT values.
    # Then only values above KMU are modified, so we replicate that here
    # so that we can run tests and users can check against existing code
    DZU = xr.where(fortran_zindex >= KMU, DZT, dzt_min)
    DZU.name = 'DZU'
    DZU.attrs = _thickness_attrs('Thickness of U cells', '3221')

    return DZT, DZU
1 change: 1 addition & 0 deletions pop_tools/xgcm_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _label_coord_grid_locs(ds):
'DZU': '3221',
'DZT': '3111',
'HT': '2110',
'DZBC': '2110',
dcherian marked this conversation as resolved.
Show resolved Hide resolved
'HU': '2220',
'HTE': '2210',
'HTN': '2120',
Expand Down
42 changes: 42 additions & 0 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import pytest
import xarray as xr
from xarray.testing import assert_equal

import pop_tools
from pop_tools import DATASETS
from pop_tools.datasets import UnzipZarr

from .util import ds_compare, is_ncar_host

Expand Down Expand Up @@ -43,3 +46,42 @@ def test_get_grid_to_netcdf():
gridfile = f'{grid}_{format}.nc'
ds.to_netcdf(gridfile, format=format)
os.system(f'rm -f {gridfile}')


def test_four_point_min_kmu():
    """four_point_min applied to KMT must reproduce the stored KMU field."""
    zstore = DATASETS.fetch('comp-grid.tx9.1v3.20170718.zarr.zip', processor=UnzipZarr())
    ds = xr.open_zarr(zstore)

    # topmost row is wrong because we need to account for tripole seam
    # rightmost nlon is wrong because it doesn't matter
    interior = {'nlat': slice(-1), 'nlon': slice(-1)}
    expected = ds.KMU.isel(**interior)

    # make sure dask & numpy results check out: first KMT as opened from the
    # zarr store, then the same field loaded into memory
    for load_first in (False, True):
        kmt = ds.KMT.compute() if load_first else ds.KMT
        actual = pop_tools.grid.four_point_min(kmt).isel(**interior)
        assert_equal(expected, actual)


def test_dzu_dzt():
    """Check calc_dzu_dzt against the stored DZT/DZU fields and its input validation."""

    zstore = DATASETS.fetch('comp-grid.tx9.1v3.20170718.zarr.zip', processor=UnzipZarr())
    # chunk size is 300 along nlat; make sure we cross at least
    # one chunk boundary to test map_overlap
    ds = xr.open_zarr(zstore).sel(nlat=slice(100, 350))

    # calc_dzu_dzt returns (DZT, DZU) in that order; unpack accordingly so each
    # result is compared against the matching reference field.
    dzt, dzu = pop_tools.grid.calc_dzu_dzt(ds)
    # northernmost row will be wrong since we are working on a subset
    assert_equal(dzu.isel(nlat=slice(-1)), ds['DZU'].isel(nlat=slice(-1)))
    assert_equal(dzt, ds['DZT'])

    # datasets renamed for xgcm must be rejected
    _, xds = pop_tools.to_xgcm_grid_dataset(ds)
    with pytest.raises(ValueError):
        pop_tools.grid.calc_dzu_dzt(xds)

    # each required input variable must be reported when missing
    expected_vars = ['dz', 'KMT', 'DZBC']
    for var in expected_vars:
        dsc = ds.copy()
        del dsc[var]
        with pytest.raises(ValueError):
            pop_tools.grid.calc_dzu_dzt(dsc)