Skip to content

Commit

Permalink
Merge pull request #1087 from jthielen/pint-0.9
Browse files Browse the repository at this point in the history
Fix pint 0.9 errors from units.wraps and iterable
  • Loading branch information
dopplershift committed Jul 6, 2019
2 parents 754f7f2 + 1a5dcc6 commit b1e68bc
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 7 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Expand Up @@ -8,7 +8,7 @@ dependencies:
- numpy
- scipy
- matplotlib
- pint!=0.9
- pint
- netcdf4
- xarray
- pandas
Expand Down
9 changes: 8 additions & 1 deletion metpy/cbook.py
Expand Up @@ -6,7 +6,6 @@
import os

import numpy as np
from numpy import iterable
import pooch

from . import __version__
Expand Down Expand Up @@ -102,4 +101,12 @@ def broadcast_indices(x, minv, ndim, axis):
return tuple(ret)


def iterable(value):
"""Determine if value can be iterated over."""
# Special case for pint Quantities
if hasattr(value, 'magnitude'):
value = value.magnitude
return np.iterable(value)


__all__ = ('Registry', 'broadcast_indices', 'get_test_data', 'is_string_like', 'iterable')
21 changes: 18 additions & 3 deletions metpy/interpolate/one_dimension.py
Expand Up @@ -10,7 +10,6 @@

from ..cbook import broadcast_indices
from ..package_tools import Exporter
from ..units import units
from ..xarray import preprocess_xarray

exporter = Exporter(globals())
Expand Down Expand Up @@ -53,7 +52,6 @@ def interpolate_nans_1d(x, y, kind='linear'):

@exporter.export
@preprocess_xarray
@units.wraps(None, ('=A', '=A'))
def interpolate_1d(x, xp, *args, **kwargs):
r"""Interpolates data with any shape over a specified axis.
Expand Down Expand Up @@ -100,6 +98,9 @@ def interpolate_1d(x, xp, *args, **kwargs):
fill_value = kwargs.pop('fill_value', np.nan)
axis = kwargs.pop('axis', 0)

# Handle units
x, xp = _strip_matching_units(x, xp)

# Make x an array
x = np.asanyarray(x).reshape(-1)

Expand Down Expand Up @@ -175,7 +176,6 @@ def interpolate_1d(x, xp, *args, **kwargs):

@exporter.export
@preprocess_xarray
@units.wraps(None, ('=A', '=A'))
def log_interpolate_1d(x, xp, *args, **kwargs):
r"""Interpolates data with logarithmic x-scale over a specified axis.
Expand Down Expand Up @@ -222,7 +222,22 @@ def log_interpolate_1d(x, xp, *args, **kwargs):
fill_value = kwargs.pop('fill_value', np.nan)
axis = kwargs.pop('axis', 0)

# Handle units
x, xp = _strip_matching_units(x, xp)

# Log x and xp
log_x = np.log(x)
log_xp = np.log(xp)
return interpolate_1d(log_x, log_xp, *args, axis=axis, fill_value=fill_value)


def _strip_matching_units(*args):
"""Ensure arguments have same units and return with units stripped.
Replaces `@units.wraps(None, ('=A', '=A'))`, which breaks with `*args` handling for
pint>=0.9.
"""
if all(hasattr(arr, 'units') for arr in args):
return [arr.to(args[0].units).magnitude for arr in args]
else:
return args
5 changes: 5 additions & 0 deletions metpy/tests/test_units.py
Expand Up @@ -8,6 +8,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pint
import pytest

from metpy.testing import assert_array_almost_equal, assert_array_equal
Expand Down Expand Up @@ -35,6 +36,8 @@ def test_concatenate_masked():
assert_array_equal(result.mask, np.array([False, True, False, False]))


@pytest.mark.skipif(pint.__version__ == '0.9', reason=('Currently broken upstream (see '
'pint#751'))
@pytest.mark.mpl_image_compare(tolerance=0, remove_text=True)
def test_axhline():
r"""Ensure that passing a quantity to axhline does not error."""
Expand All @@ -45,6 +48,8 @@ def test_axhline():
return fig


@pytest.mark.skipif(pint.__version__ == '0.9', reason=('Currently broken upstream (see '
'pint#751'))
@pytest.mark.mpl_image_compare(tolerance=0, remove_text=True)
def test_axvline():
r"""Ensure that passing a quantity to axvline does not error."""
Expand Down
7 changes: 6 additions & 1 deletion metpy/units.py
Expand Up @@ -18,6 +18,7 @@

import functools
import logging
import warnings

import numpy as np
import pint
Expand Down Expand Up @@ -46,6 +47,10 @@
except AttributeError:
log.warning('Failed to add gpm alias to meters.')

# Silence UnitStrippedWarning
if hasattr(pint, 'UnitStrippedWarning'):
warnings.simplefilter('ignore', category=pint.UnitStrippedWarning)


def pandas_dataframe_to_unit_arrays(df, column_units=None):
"""Attach units to data in pandas dataframes and return united arrays.
Expand Down Expand Up @@ -317,7 +322,7 @@ def dec(func):
try:
# Try to enable pint's built-in support
units.setup_matplotlib()
except (AttributeError, RuntimeError): # Pint's not available, try to enable our own
except (AttributeError, RuntimeError, ImportError): # Pint's not available, try our own
import matplotlib.units as munits

# Inheriting from object fixes the fact that matplotlib 1.4 doesn't
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -51,7 +51,7 @@

python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*',
install_requires=['matplotlib>=2.0.0', 'numpy>=1.12.0', 'scipy>=0.17.0',
'pint!=0.9', 'xarray>=0.10.7', 'enum34;python_version<"3.4"',
'pint', 'xarray>=0.10.7', 'enum34;python_version<"3.4"',
'contextlib2;python_version<"3.6"',
'pooch>=0.1', 'traitlets>=4.3.0'],
extras_require={
Expand Down

0 comments on commit b1e68bc

Please sign in to comment.