Skip to content

Commit

Permalink
Merge pull request #971 from jthielen/unit-aware-selection
Browse files Browse the repository at this point in the history
Add unit- and axis-aware selection to the metpy accessors
  • Loading branch information
jrleeman committed Dec 18, 2018
2 parents 3e8341b + d0748fe commit 113de63
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 10 deletions.
49 changes: 49 additions & 0 deletions metpy/tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,52 @@ def test_find_axis_name_bad_identifier(test_var):
def test_cf_parse_with_grid_mapping(test_var):
"""Test cf_parse dont delete grid_mapping attribute."""
assert test_var.grid_mapping == 'Lambert_Conformal'


def test_data_array_loc_get_with_units(test_var):
"""Test the .loc indexer on the metpy accessor."""
truth = test_var.loc[:, 850.]
assert truth.identical(test_var.metpy.loc[:, 8.5e4 * units.Pa])


def test_data_array_loc_set_with_units(test_var):
"""Test the .loc indexer on the metpy accessor for setting."""
temperature = test_var.copy()
temperature.metpy.loc[:, 8.5e4 * units.Pa] = np.nan
assert np.isnan(temperature.loc[:, 850.]).all()
assert not np.isnan(temperature.loc[:, 700.]).any()


def test_data_array_sel_dict_with_units(test_var):
"""Test .sel on the metpy accessor with dictionary."""
truth = test_var.squeeze().loc[500.]
assert truth.identical(test_var.metpy.sel({'time': '1987-04-04T18:00:00',
'isobaric': 5e4 * units.Pa}))


def test_data_array_sel_kwargs_with_units(test_var):
"""Test .sel on the metpy accessor with kwargs and axis type."""
truth = test_var.loc[:, 500.][..., 122]
assert truth.identical(test_var.metpy.sel(vertical=5e4 * units.Pa, x=-16.569 * units.km,
tolerance=1., method='nearest'))


def test_dataset_loc_with_units(test_ds):
"""Test .loc on the metpy accessor for Datasets using slices."""
truth = test_ds[{'isobaric': slice(6, 17)}]
assert truth.identical(test_ds.metpy.loc[{'isobaric': slice(8.5e4 * units.Pa,
5e4 * units.Pa)}])


def test_dataset_sel_kwargs_with_units(test_ds):
"""Test .sel on the metpy accessor for Datasets with kwargs."""
truth = test_ds[{'time': 0, 'y': 50, 'x': 122}]
assert truth.identical(test_ds.metpy.sel(time='1987-04-04T18:00:00', y=-1.464e6 * units.m,
x=-17. * units.km, tolerance=1.,
method='nearest'))


def test_dataset_loc_without_dict(test_ds):
"""Test that .metpy.loc for Datasets raises error when used with a non-dict."""
with pytest.raises(TypeError):
test_ds.metpy.loc[:, 700 * units.hPa]
86 changes: 86 additions & 0 deletions metpy/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import xarray as xr
from xarray.core.accessors import DatetimeAccessor
from xarray.core.indexing import expanded_indexer
from xarray.core.utils import either_dict_or_kwargs, is_dict_like

from .units import DimensionalityError, units

Expand Down Expand Up @@ -143,6 +145,38 @@ def find_axis_name(self, axis):
raise ValueError('Given axis is not valid. Must be an axis number, a dimension '
'coordinate name, or a standard axis type.')

class _LocIndexer(object):
"""Provide the unit-wrapped .loc indexer for data arrays."""

def __init__(self, data_array):
self.data_array = data_array

def expand(self, key):
"""Parse key using xarray utils to ensure we have dimension names."""
if not is_dict_like(key):
labels = expanded_indexer(key, self.data_array.ndim)
key = dict(zip(self.data_array.dims, labels))
return key

def __getitem__(self, key):
key = _reassign_quantity_indexer(self.data_array, self.expand(key))
return self.data_array.loc[key]

def __setitem__(self, key, value):
key = _reassign_quantity_indexer(self.data_array, self.expand(key))
self.data_array.loc[key] = value

@property
def loc(self):
"""Make the LocIndexer available as a property."""
return self._LocIndexer(self._data_array)

def sel(self, indexers=None, method=None, tolerance=None, drop=False, **indexers_kwargs):
"""Wrap DataArray.sel to handle units."""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel')
indexers = _reassign_quantity_indexer(self._data_array, indexers)
return self._data_array.sel(indexers, method=method, tolerance=tolerance, drop=drop)


@xr.register_dataset_accessor('metpy')
class CFConventionHandler(object):
Expand Down Expand Up @@ -369,6 +403,27 @@ def _resolve_axis_conflict(self, axis, coord_lists):
+ ' coordinate. Specify the unique axes using the coordinates argument.')
coord_lists[axis] = []

class _LocIndexer(object):
"""Provide the unit-wrapped .loc indexer for datasets."""

def __init__(self, dataset):
self.dataset = dataset

def __getitem__(self, key):
parsed_key = _reassign_quantity_indexer(self.dataset, key)
return self.dataset.loc[parsed_key]

@property
def loc(self):
"""Make the LocIndexer available as a property."""
return self._LocIndexer(self._dataset)

def sel(self, indexers=None, method=None, tolerance=None, drop=False, **indexers_kwargs):
"""Wrap Dataset.sel to handle units."""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel')
indexers = _reassign_quantity_indexer(self._dataset, indexers)
return self._dataset.sel(indexers, method=method, tolerance=tolerance, drop=drop)


def preprocess_xarray(func):
"""Decorate a function to convert all DataArray arguments to pint.Quantities.
Expand Down Expand Up @@ -410,3 +465,34 @@ def strftime(self, date_format):
return strs.values.reshape(values.shape)

DatetimeAccessor.strftime = strftime


def _reassign_quantity_indexer(data, indexers):
"""Reassign a units.Quantity indexer to units of relevant coordinate."""
def _to_magnitude(val, unit):
try:
return val.to(unit).m
except AttributeError:
return val

for coord_name in indexers:
# Handle axis types for DataArrays
if (isinstance(data, xr.DataArray) and coord_name not in data.dims
and coord_name in readable_to_cf_axes):
axis = coord_name
coord_name = next(data.metpy.coordinates(axis)).name
indexers[coord_name] = indexers[axis]
del indexers[axis]

# Handle slices of quantities
if isinstance(indexers[coord_name], slice):
start = _to_magnitude(indexers[coord_name].start, data[coord_name].metpy.units)
stop = _to_magnitude(indexers[coord_name].stop, data[coord_name].metpy.units)
step = _to_magnitude(indexers[coord_name].step, data[coord_name].metpy.units)
indexers[coord_name] = slice(start, stop, step)

# Handle quantities
indexers[coord_name] = _to_magnitude(indexers[coord_name],
data[coord_name].metpy.units)

return indexers
36 changes: 26 additions & 10 deletions tutorials/xarray_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# Any import of metpy will activate the accessors
import metpy.calc as mpcalc
from metpy.testing import get_test_data
from metpy.units import units

#########################################################################
# Getting Data
Expand Down Expand Up @@ -75,10 +76,9 @@
# MetPy's DataArray accessor has a ``unit_array`` property to obtain a ``pint.Quantity`` array
# of just the data from the DataArray (metadata is removed) and a ``convert_units`` method to
# convert the the data from one unit to another (keeping it as a DataArray). For now, we'll
# just use ``convert_units`` to convert our pressure coordinates to ``hPa``.
# just use ``convert_units`` to convert our temperature to ``degC``.

data['isobaric1'].metpy.convert_units('hPa')
data['isobaric3'].metpy.convert_units('hPa')
data['temperature'].metpy.convert_units('degC')

#########################################################################
# Coordinates
Expand Down Expand Up @@ -115,6 +115,21 @@
# To verify, we can inspect all their names
print([coord.name for coord in (x, y, vertical, time)])

#########################################################################
# Indexing and Selecting Data
# ---------------------------
#
# MetPy provides wrappers for the usual xarray indexing and selection routines that can handle
# quantities with units. For DataArrays, MetPy also allows using the coordinate axis types
# mentioned above as aliases for the coordinates. And so, if we wanted 850 hPa heights,
# we would take:

print(data['height'].metpy.sel(vertical=850 * units.hPa))

#########################################################################
# For full details on xarray indexing/selection, see
# `xarray's documentation <http://xarray.pydata.org/en/stable/indexing.html>`_.

#########################################################################
# Projections
# -----------
Expand Down Expand Up @@ -149,7 +164,7 @@
lat, lon = xr.broadcast(y, x)
f = mpcalc.coriolis_parameter(lat)
dx, dy = mpcalc.lat_lon_grid_deltas(lon, lat, initstring=data_crs.proj4_init)
heights = data['height'].loc[time[0]].loc[{vertical.name: 500.}]
heights = data['height'].metpy.loc[{'time': time[0], 'vertical': 500. * units.hPa}]
u_geo, v_geo = mpcalc.geostrophic_wind(heights, f, dx, dy)
print(u_geo)
print(v_geo)
Expand Down Expand Up @@ -177,7 +192,7 @@
# takes a ``DataArray`` input, but returns unit arrays for use in other calculations. We could
# rewrite the above geostrophic wind example using this helper function as follows:

heights = data['height'].loc[time[0]].loc[{vertical.name: 500.}]
heights = data['height'].metpy.loc[{'time': time[0], 'vertical': 500. * units.hPa}]
lat, lon = xr.broadcast(y, x)
f = mpcalc.coriolis_parameter(lat)
dx, dy = mpcalc.grid_deltas_from_dataarray(heights)
Expand All @@ -198,14 +213,15 @@
# <http://xarray.pydata.org/en/stable/plotting.html>`_.)

# A very simple example example of a plot of 500 hPa heights
data['height'].loc[time[0]].loc[{vertical.name: 500.}].plot()
data['height'].metpy.loc[{'time': time[0], 'vertical': 500. * units.hPa}].plot()
plt.show()

#########################################################################

# Let's add a projection and coastlines to it
ax = plt.axes(projection=ccrs.LambertConformal())
data['height'].loc[time[0]].loc[{vertical.name: 500.}].plot(ax=ax, transform=data_crs)
data['height'].metpy.loc[{'time': time[0],
'vertical': 500. * units.hPa}].plot(ax=ax, transform=data_crs)
ax.coastlines()
plt.show()

Expand All @@ -214,7 +230,7 @@
# Or, let's make a full 500 hPa map with heights, temperature, winds, and humidity

# Select the data for this time and level
data_level = data.loc[{vertical.name: 500., time.name: time[0]}]
data_level = data.metpy.loc[{time.name: time[0], vertical.name: 500. * units.hPa}]

# Create the matplotlib figure and axis
fig, ax = plt.subplots(1, 1, figsize=(12, 8), subplot_kw={'projection': data_crs})
Expand All @@ -235,7 +251,7 @@
h_contour.clabel(fontsize=8, colors='k', inline=1, inline_spacing=8,
fmt='%i', rightside_up=True, use_clabeltext=True)
t_contour = ax.contour(x, y, data_level['temperature'], colors='xkcd:deep blue',
levels=range(248, 276, 2), alpha=0.8, linestyles='--')
levels=range(-26, 4, 2), alpha=0.8, linestyles='--')
t_contour.clabel(fontsize=8, colors='xkcd:deep blue', inline=1, inline_spacing=8,
fmt='%i', rightside_up=True, use_clabeltext=True)

Expand All @@ -247,7 +263,7 @@
edgecolor='#c7c783', zorder=0)

# Set a title and show the plot
ax.set_title('500 hPa Heights (m), Temperature (K), Humidity (%) at '
ax.set_title('500 hPa Heights (m), Temperature (\u00B0C), Humidity (%) at '
+ time[0].dt.strftime('%Y-%m-%d %H:%MZ'))
plt.show()

Expand Down

0 comments on commit 113de63

Please sign in to comment.