Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup cbook #1520

Merged
merged 2 commits into from Oct 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/metpy/calc/kinematics.py
Expand Up @@ -7,7 +7,6 @@
from . import coriolis_parameter
from .tools import first_derivative, get_layer_heights, gradient
from .. import constants as mpconsts
from ..cbook import iterable
from ..package_tools import Exporter
from ..units import check_units, concatenate, units
from ..xarray import add_grid_arguments_from_xarray, preprocess_and_wrap
Expand All @@ -16,7 +15,7 @@


def _stack(arrs):
return concatenate([a[np.newaxis] if iterable(a) else a for a in arrs], axis=0)
return concatenate([a[np.newaxis] if np.iterable(a) else a for a in arrs], axis=0)


@exporter.export
Expand Down
4 changes: 2 additions & 2 deletions src/metpy/calc/tools.py
Expand Up @@ -11,7 +11,7 @@
from scipy.spatial import cKDTree
import xarray as xr

from ..cbook import broadcast_indices, result_type
from ..cbook import broadcast_indices
from ..interpolate import interpolate_1d, log_interpolate_1d
from ..package_tools import Exporter
from ..units import check_units, concatenate, units
Expand Down Expand Up @@ -385,7 +385,7 @@ def _get_bound_pressure_height(pressure, bound, height=None, interpolate=True):
# 1.13 always returns float64. This can cause upstream users problems,
# resulting in something like np.append() to upcast.
bound_pressure = (np.interp(np.atleast_1d(bound.m), height.m,
pressure.m).astype(result_type(bound))
pressure.m).astype(np.result_type(bound))
* pressure.units)
else:
idx = (np.abs(height - bound)).argmin()
Expand Down
23 changes: 1 addition & 22 deletions src/metpy/cbook.py
Expand Up @@ -90,25 +90,4 @@ def broadcast_indices(x, minv, ndim, axis):
return tuple(ret)


def iterable(value):
"""Determine if value can be iterated over."""
# Special case for pint Quantities
if hasattr(value, 'magnitude'):
value = value.magnitude
return np.iterable(value)


def result_type(value):
"""Determine the type for numpy type casting in a pint-version-safe way."""
try:
return np.result_type(value)
except TypeError:
if hasattr(value, 'dtype'):
return value.dtype
elif hasattr(value, 'magnitude'):
return np.result_type(value.magnitude)
else:
raise TypeError(f'Cannot determine dtype for type {type(value)}')


__all__ = ('Registry', 'broadcast_indices', 'get_test_data', 'iterable', 'result_type')
__all__ = ('Registry', 'broadcast_indices', 'get_test_data')
3 changes: 1 addition & 2 deletions src/metpy/xarray.py
Expand Up @@ -661,14 +661,13 @@ def parse_cf(self, varname=None, coordinates=None):
Parsed DataArray (if varname is a string) or Dataset

"""
from .cbook import iterable
from .plots.mapping import CFProjection

if varname is None:
# If no varname is given, parse all variables in the dataset
varname = list(self._dataset.data_vars)

if iterable(varname) and not isinstance(varname, str):
if np.iterable(varname) and not isinstance(varname, str):
# If non-string iterable is given, apply recursively across the varnames
subset = xr.merge([self.parse_cf(single_varname, coordinates=coordinates)
for single_varname in varname])
Expand Down
30 changes: 1 addition & 29 deletions tests/test_cbook.py
Expand Up @@ -3,12 +3,7 @@
# SPDX-License-Identifier: BSD-3-Clause
"""Test functionality of MetPy's utility code."""

import numpy as np
import pytest
import xarray as xr

from metpy.cbook import Registry, result_type
from metpy.units import units
from metpy.cbook import Registry


def test_registry():
Expand All @@ -19,26 +14,3 @@ def test_registry():
reg.register('mine')(a)

assert reg['mine'] is a


@pytest.mark.parametrize(
'test_input, expected_type_match, custom_dtype',
[(1.0, 1.0, None),
(1, 1, None),
(np.array(1.0), 1.0, None),
(np.array(1), 1, None),
(np.array([1, 2, 3], dtype=np.int32), 1, 'int32'),
(units.Quantity(1, units.m), 1, None),
(units.Quantity(1.0, units.m), 1.0, None),
(units.Quantity([1, 2.0], units.m), 1.0, None),
([1, 2, 3] * units.m, 1, None),
(xr.DataArray(data=[1, 2.0]), 1.0, None)])
def test_result_type(test_input, expected_type_match, custom_dtype):
"""Test result_type on the kinds of things common in MetPy."""
assert result_type(test_input) == np.array(expected_type_match, dtype=custom_dtype).dtype


def test_result_type_failure():
"""Test result_type failure on non-numeric types."""
with pytest.raises(TypeError):
result_type([False])