Add a function to get the full-size of xg and yg from input files #195

Open · wants to merge 3 commits into base: master
43 changes: 43 additions & 0 deletions xmitgcm/test/test_utils.py
@@ -973,6 +973,49 @@ def test_get_grid_from_input(all_grid_datadirs, usedask):
dtype=np.dtype('d'), endian='>',
use_dask=False,
extra_metadata=None)



@pytest.mark.parametrize("usedask", [True, False])
def test_get_xg_yg_from_input(all_grid_datadirs, usedask):
from xmitgcm.utils import get_xg_yg_from_input, get_extra_metadata
from xmitgcm.utils import read_raw_data
dirname, expected = all_grid_datadirs
md = get_extra_metadata(domain=expected['domain'], nx=expected['nx'])
tx = 30
ty = 30
bl = [1, 2, 3]
ds = get_xg_yg_from_input(dirname + '/' + expected['gridfile'],
geometry=expected['geometry'],
dtype=np.dtype('d'), endian='>',
use_dask=usedask,
extra_metadata=md,
tilex=tx, tiley=ty,
blankList=bl)
# test types
assert type(ds) == xarray.Dataset
assert type(ds['XG']) == xarray.core.dataarray.DataArray
Member

use isinstance here

if usedask:
ds.load()

# check all variables are in
expected_variables = ['XG', 'YG']

for var in expected_variables:
assert type(ds[var]) == xarray.core.dataarray.DataArray
assert ds[var].values.shape[1] == tx+1
assert ds[var].values.shape[2] == ty+1


# passing llc without metadata should fail
if expected['geometry'] == 'llc':
with pytest.raises(ValueError):
ds = get_xg_yg_from_input(dirname + '/' + expected['gridfile'],
geometry=expected['geometry'],
dtype=np.dtype('d'), endian='>',
use_dask=False,
extra_metadata=None)


@pytest.mark.parametrize("dtype", [np.dtype('d'), np.dtype('f')])
224 changes: 191 additions & 33 deletions xmitgcm/utils.py
@@ -756,6 +756,46 @@ def _llc_data_shape(llc_id, nz=None):
return data_shape


def _file_metadata(endian='>', dtype=np.dtype('d'), extra_metadata=None):
# build the metadata dict describing the MITgcm grid input files
file_metadata = {}
# grid variables are stored in this order
file_metadata['fldList'] = ['XC', 'YC', 'DXF', 'DYF', 'RAC',
'XG', 'YG', 'DXV', 'DYU', 'RAZ',
'DXC', 'DYC', 'RAW', 'RAS', 'DXG', 'DYG']

file_metadata['vars'] = file_metadata['fldList']
dims_vars_list = []
for var in file_metadata['fldList']:
dims_vars_list.append(('ny', 'nx'))
file_metadata['dims_vars'] = dims_vars_list

# no vertical levels or time records
file_metadata['nz'] = 1
file_metadata['nt'] = 1

# for curvilinear non-facet grids (TO DO)
# if nx is not None:
# file_metadata['nx'] = nx
# if ny is not None:
# file_metadata['ny'] = ny
if extra_metadata is not None:
file_metadata.update(extra_metadata)

# numeric representation
file_metadata['endian'] = endian
file_metadata['dtype'] = dtype
return file_metadata

def _nxgrid_nygrid(file_metadata, kfacet):
# grid files hold one extra point in each direction;
# the two dimensions are swapped for 'F' ordered facets
if file_metadata['facet_orders'][kfacet] == 'C':
nxgrid = file_metadata['nx'] + 1
nygrid = file_metadata['ny_facets'][kfacet] + 1
elif file_metadata['facet_orders'][kfacet] == 'F':
nxgrid = file_metadata['ny_facets'][kfacet] + 1
nygrid = file_metadata['nx'] + 1
return nxgrid, nygrid
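For illustration only (not part of the diff): a sketch of how _nxgrid_nygrid resolves the per-facet grid shape. The LLC-90 facet sizes and orders below are assumptions chosen for the example.

example_metadata = {
'nx': 90,                               # assumed face width
'ny_facets': [270, 270, 90, 270, 270],  # assumed LLC-90 facet heights
'facet_orders': ['C', 'C', 'C', 'F', 'F'],
}
# 'C' ordered facet: grid file holds nx+1 by ny_facet+1 points
nxgrid, nygrid = _nxgrid_nygrid(example_metadata, kfacet=0)  # -> (91, 271)
# 'F' ordered facet: the two dimensions are swapped
nxgrid, nygrid = _nxgrid_nygrid(example_metadata, kfacet=3)  # -> (271, 91)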


def read_all_variables(variable_list, file_metadata, use_mmap=False,
use_dask=False, chunks="3D"):
"""
@@ -1278,33 +1318,7 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
all grid variables
"""

file_metadata = {}
# grid variables are stored in this order
file_metadata['fldList'] = ['XC', 'YC', 'DXF', 'DYF', 'RAC',
'XG', 'YG', 'DXV', 'DYU', 'RAZ',
'DXC', 'DYC', 'RAW', 'RAS', 'DXG', 'DYG']

file_metadata['vars'] = file_metadata['fldList']
dims_vars_list = []
for var in file_metadata['fldList']:
dims_vars_list.append(('ny', 'nx'))
file_metadata['dims_vars'] = dims_vars_list

# no vertical levels or time records
file_metadata['nz'] = 1
file_metadata['nt'] = 1

# for curvilinear non-facet grids (TO DO)
# if nx is not None:
# file_metadata['nx'] = nx
# if ny is not None:
# file_metadata['ny'] = ny
if extra_metadata is not None:
file_metadata.update(extra_metadata)

# numeric representation
file_metadata['endian'] = endian
file_metadata['dtype'] = dtype
file_metadata = _file_metadata(endian, dtype, extra_metadata)

if geometry == 'llc':
nfacets = 5
@@ -1328,12 +1342,8 @@
fname = gridfile.replace('<NFACET>', str(kfacet+1).zfill(3))
grid_metadata['filename'] = fname

if file_metadata['facet_orders'][kfacet] == 'C':
nxgrid = file_metadata['nx'] + 1
nygrid = file_metadata['ny_facets'][kfacet] + 1
elif file_metadata['facet_orders'][kfacet] == 'F':
nxgrid = file_metadata['ny_facets'][kfacet] + 1
nygrid = file_metadata['nx'] + 1
nxgrid, nygrid = _nxgrid_nygrid(file_metadata, kfacet)


grid_metadata.update({'nx': nxgrid, 'ny': nygrid,
'has_faces': False})
@@ -1437,6 +1447,154 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',

return grid

def get_xg_yg_from_input(gridfile, nx=None, ny=None, geometry='llc',
dtype=np.dtype('d'), endian='>', use_dask=False,
extra_metadata=None, tilex=30, tiley=30, blankList=None):
"""
Read XG and YG from grid input files and tile them according to the
tile sizes given by the user, skipping blank tiles.
This function reads only xg and yg, and it outputs all the values stored
in the input grid file (including the rightmost and uppermost xg and yg
values). It is useful for finding where a lat/lon point falls on the
llc grid.

PARAMETERS
----------
gridfile : str
gridfile must contain <NFACET> as wildcard (e.g. tile<NFACET>.mitgrid)
nx : int
size of the face in the x direction
ny : int
size of the face in the y direction
geometry : str
domain geometry; 'llc' is supported, 'cs' and 'cartesian' are not
supported yet
dtype : np.dtype
numeric precision (single/double) of input data
endian : string
endianness of input data
use_dask : bool
use dask or not
extra_metadata : dict
dictionary of extra metadata, needed for llc configurations
tilex : int
size of tile in the x direction
tiley : int
size of tile in the y direction
blankList : arraylike
List of blank tiles (indexing starts at 1 so that you can copy
directly from data.exch2)
RETURNS
-------
grid : xarray.Dataset
dataset of tiled XG and YG variables
"""

file_metadata = _file_metadata(endian, dtype, extra_metadata)

if geometry == 'llc':
nfacets = 5
try:
nfaces = len(file_metadata['face_facets'])
except KeyError:
raise ValueError('metadata must contain face_facets')
if geometry == 'cs': # pragma: no cover
raise NotImplementedError("'cs' geometry is not supported yet")

# create placeholders for data
gridfields = {}
for field in ['XG', 'YG']:
gridfields.update({field: None})

if geometry == 'llc':
tileno = 0  # global tile counter across facets
dummy = 0   # flag: set once the first non-blank tile has been stored
for kfacet in range(nfacets):
# we need to adapt the metadata to the grid file
grid_metadata = file_metadata.copy()

fname = gridfile.replace('<NFACET>', str(kfacet+1).zfill(3))
grid_metadata['filename'] = fname

nxgrid,nygrid=_nxgrid_nygrid(file_metadata,kfacet)

grid_metadata.update({'nx': nxgrid, 'ny': nygrid,
'has_faces': False})

raw = read_all_variables(grid_metadata['vars'], grid_metadata,
use_dask=use_dask)

rawfields = {}
for kfield in np.arange(len(file_metadata['fldList'])):

rawfields.update(
{file_metadata['fldList'][kfield]: raw[kfield]})

tiles_on_facet = (nxgrid-1) * (nygrid-1) // tilex // tiley
tile_in_x = (nxgrid-1) // tilex
tile_in_y = (nygrid-1) // tiley

for field in ['XG', 'YG']:
# XG and YG must run through the same tile numbers on this facet,
# so remember the counter after XG and restore it for YG
if field == 'XG':
save_tile = tileno
else:
tileno = save_tile
if kfacet == 0:
dummy = 0
# drop the singleton nz/nt dimensions
tmp = rawfields[field][:, :, :, :].squeeze()
# transpose
if grid_metadata['facet_orders'][kfacet] == 'F':
tmp = tmp.transpose()

for tileon in range(0, tiles_on_facet):
tileno = tileno + 1
# treat blankList=None as "no blank tiles"
if blankList is None or tileno not in blankList:
offsety = tileon // tile_in_x
offsetx = tileon - offsety * tile_in_x
# transpose facet if needed
tmpt = tmp
if file_metadata['facet_orders'][kfacet] == 'F':
tmpt = tmp.transpose()
# extract the tile, keeping the extra corner row/column
dataface = tmpt[offsety*tiley:(offsety+1)*tiley+1,
offsetx*tilex:(offsetx+1)*tilex+1]
# assign values
dataface = dsa.stack([dataface], axis=0)
if dummy == 0:
gridfields[field] = dataface
dummy = 1
else:
gridfields[field] = dsa.concatenate(
[gridfields[field], dataface], axis=0)

elif geometry == 'cs': # pragma: no cover
raise NotImplementedError("'cs' geometry is not supported yet")

# create the dataset
if geometry in ['llc', 'cs']:
ntile = gridfields['XG'].shape[0]
grid = xr.Dataset({'XG': (['tile', 'j_g', 'i_g'], gridfields['XG']),
'YG': (['tile', 'j_g', 'i_g'], gridfields['YG']),
},
coords={'i_g': (['i_g'],
np.arange(tilex+1)),
'j_g': (['j_g'],
np.arange(tiley+1)),
'tile': (['tile'], np.arange(ntile))
}
)
else: # pragma: no cover
grid = xr.Dataset({'XG': (['j_g', 'i_g'], gridfields['XG']),
'YG': (['j_g', 'i_g'], gridfields['YG']),
},
coords={'i_g': (['i_g'],
np.arange(tilex+1)),
'j_g': (['j_g'],
np.arange(tiley+1))
}
)

return grid
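
Hypothetical usage sketch (not part of the PR): the grid-file path, resolution and blank-tile indices below are assumptions, chosen only to show the call pattern and the resulting shapes.

from xmitgcm.utils import get_extra_metadata, get_xg_yg_from_input

md = get_extra_metadata(domain='llc', nx=90)   # assumed LLC-90 setup
ds = get_xg_yg_from_input('./grid/tile<NFACET>.mitgrid', geometry='llc',
extra_metadata=md, tilex=30, tiley=30,
blankList=[1, 2, 3])   # hypothetical blank tiles
# each kept tile carries the extra corner row/column:
# ds['XG'] and ds['YG'] have shape (ntile, tiley+1, tilex+1)
print(ds['XG'].shape)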


########## WRITING BINARIES #############################

Expand Down