Skip to content

Commit

Permalink
Generalize testing version comparison (#3271)
Browse files Browse the repository at this point in the history
* Generalize version comparison

* Add xfail condition for old pint

* Limit xfail to appropriate scipy
  • Loading branch information
dcamron committed Nov 14, 2023
1 parent 40d7305 commit 8d6a48b
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 38 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ markers = "xfail_dask: marks tests as expected to fail with Dask arrays"
norecursedirs = "build docs .idea"
doctest_optionflags = "NORMALIZE_WHITESPACE"
mpl-results-path = "test_output"
xfail_strict = true

[tool.ruff]
line-length = 95
Expand Down
80 changes: 66 additions & 14 deletions src/metpy/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
"""
import contextlib
import functools
from importlib.metadata import requires, version
import operator as op
import re

import matplotlib
import numpy as np
import numpy.testing
from packaging.version import Version
Expand All @@ -23,37 +25,87 @@
from .deprecation import MetpyDeprecationWarning
from .units import units

MPL_VERSION = Version(matplotlib.__version__)

def version_check(version_spec):
"""Return comparison between the active module and a requested version number.
def mpl_version_before(ver):
"""Return whether the active matplotlib is before a certain version.
Will also validate specification against package metadata to alert if spec is irrelevant.
Parameters
----------
ver : str
The version string for a certain release
version_spec : str
Module version specification to validate against installed package. Must take the form
of `f'{module_name}{comparison_operator}{version_number}'` where `comparison_operator`
must be one of `['==', '=', '!=', '<', '<=', '>', '>=']`, eg `'metpy>1.0'`.
Returns
-------
bool : whether the current version was released before the passed in one
bool : Whether the installed package validates against the provided specification
"""
return MPL_VERSION < Version(ver)
comparison_operators = {
'==': op.eq, '=': op.eq, '!=': op.ne, '<': op.lt, '<=': op.le, '>': op.gt, '>=': op.ge,
}

# Match version_spec for groups of module name,
# comparison operator, and requested module version
module_name, comparison, version_number = _parse_version_spec(version_spec)

def mpl_version_equal(ver):
"""Return whether the active matplotlib is equal to a certain version.
# Check MetPy metadata for minimum required version of same package
metadata_spec = _get_metadata_spec(module_name)
_, _, minimum_version_number = _parse_version_spec(metadata_spec)

installed_version = Version(version(module_name))
specified_version = Version(version_number)
minimum_version = Version(minimum_version_number)

if specified_version < minimum_version:
raise ValueError(
f'Specified {version_spec} outdated according to MetPy minimum {metadata_spec}.')

try:
return comparison_operators[comparison](installed_version, specified_version)
except KeyError:
raise ValueError(
f'Comparison operator {comparison} not one of {list(comparison_operators)}.'
) from None


def _parse_version_spec(version_spec):
"""Parse module name, comparison, and version from pip-style package spec string.
Parameters
----------
version_spec : str
Package spec to parse
Returns
-------
tuple of str : Parsed specification groups of package name, comparison, and version
"""
pattern = re.compile(r'(\w+)\s*([<>!=]+)\s*([\d.]+)')
match = pattern.match(version_spec)

if not match:
raise ValueError(f'Invalid version specification {version_spec}.'
f'See version_check documentation for more information.')
else:
return match.groups()


def _get_metadata_spec(module_name):
"""Get package spec string for requested module from package metadata.
Parameters
----------
ver : str
The version string for a certain release
module_name : str
Name of MetPy required package to look up
Returns
-------
bool : whether the current version is equal to the passed in one
str : Package spec string for request module
"""
return MPL_VERSION == Version(ver)
return [entry for entry in requires('metpy') if module_name.lower() in entry.lower()][0]


def needs_module(module):
Expand Down
7 changes: 4 additions & 3 deletions tests/calc/test_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from metpy.calc import (bulk_shear, bunkers_storm_motion, critical_angle,
mean_pressure_weighted, precipitable_water, significant_tornado,
supercell_composite, weighted_continuous_average)
from metpy.testing import assert_almost_equal, assert_array_almost_equal, get_upper_air_data
from metpy.testing import (assert_almost_equal, assert_array_almost_equal, get_upper_air_data,
version_check)
from metpy.units import concatenate, units


Expand Down Expand Up @@ -130,15 +131,15 @@ def test_weighted_continuous_average():
assert_almost_equal(v, 6.900543760612305 * units('m/s'), 7)


@pytest.mark.xfail(reason='hgrecco/pint#1593')
@pytest.mark.xfail(condition=version_check('pint<0.21'), reason='hgrecco/pint#1593')
def test_weighted_continuous_average_temperature():
"""Test pressure-weighted mean temperature function with vertical interpolation."""
data = get_upper_air_data(datetime(2016, 5, 22, 0), 'DDC')
t, = weighted_continuous_average(data['pressure'],
data['temperature'],
height=data['height'],
depth=6000 * units('meter'))
assert_almost_equal(t, 279.3275828240889 * units('kelvin'), 7)
assert_almost_equal(t, 279.07450928270185 * units('kelvin'), 7)


def test_weighted_continuous_average_elevated():
Expand Down
10 changes: 5 additions & 5 deletions tests/calc/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
import warnings

import numpy as np
import packaging.version
import pytest
import scipy
import xarray as xr

from metpy.calc import (brunt_vaisala_frequency, brunt_vaisala_frequency_squared,
Expand Down Expand Up @@ -39,7 +37,8 @@
virtual_temperature, virtual_temperature_from_dewpoint,
wet_bulb_temperature)
from metpy.calc.thermo import _find_append_zero_crossings
from metpy.testing import assert_almost_equal, assert_array_almost_equal, assert_nan
from metpy.testing import (assert_almost_equal, assert_array_almost_equal, assert_nan,
version_check)
from metpy.units import is_quantity, masked_array, units


Expand Down Expand Up @@ -201,8 +200,9 @@ def test_moist_lapse_starting_points(start, direction):
@pytest.mark.xfail(platform.machine() == 'aarch64',
reason='ValueError is not raised on aarch64')
@pytest.mark.xfail(platform.machine() == 'arm64', reason='ValueError is not raised on Mac M2')
@pytest.mark.xfail(sys.platform == 'win32', reason='solve_ivp() does not error on Windows')
@pytest.mark.xfail(packaging.version.parse(scipy.__version__) < packaging.version.parse('1.7'),
@pytest.mark.xfail((sys.platform == 'win32') and version_check('scipy<1.11.3'),
reason='solve_ivp() does not error on Windows + SciPy < 1.11.3')
@pytest.mark.xfail(version_check('scipy<1.7'),
reason='solve_ivp() does not error on Scipy < 1.7')
def test_moist_lapse_failure():
"""Test moist_lapse under conditions that cause the ODE solver to fail."""
Expand Down
12 changes: 7 additions & 5 deletions tests/plots/test_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from metpy.io.metar import parse_metar_file
from metpy.plots import (ArrowPlot, BarbPlot, ContourPlot, FilledContourPlot, ImagePlot,
MapPanel, PanelContainer, PlotGeometry, PlotObs, RasterPlot)
from metpy.testing import mpl_version_before, needs_cartopy
from metpy.testing import needs_cartopy, version_check
from metpy.units import units


Expand Down Expand Up @@ -334,8 +334,9 @@ def test_declarative_contour_cam():
return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True,
tolerance=3.71 if mpl_version_before('3.8') else 0.74)
@pytest.mark.mpl_image_compare(
remove_text=True,
tolerance=3.71 if version_check('matplotlib<3.8') else 0.74)
@needs_cartopy
def test_declarative_contour_options():
"""Test making a contour plot."""
Expand Down Expand Up @@ -428,8 +429,9 @@ def test_declarative_additional_layers_plot_options():
return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True,
tolerance=2.74 if mpl_version_before('3.8') else 1.91)
@pytest.mark.mpl_image_compare(
remove_text=True,
tolerance=2.74 if version_check('matplotlib<3.8') else 1.91)
@needs_cartopy
def test_declarative_contour_convert_units():
"""Test making a contour plot."""
Expand Down
12 changes: 7 additions & 5 deletions tests/plots/test_skewt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

from metpy.plots import Hodograph, SkewT
from metpy.testing import mpl_version_before, mpl_version_equal
from metpy.testing import version_check
from metpy.units import units


Expand Down Expand Up @@ -155,8 +155,10 @@ def test_skewt_units():
skew.ax.axvline(-10, color='orange')

# On Matplotlib <= 3.6, ax[hv]line() doesn't trigger unit labels
assert skew.ax.get_xlabel() == ('degree_Celsius' if mpl_version_equal('3.7.0') else '')
assert skew.ax.get_ylabel() == ('hectopascal' if mpl_version_equal('3.7.0') else '')
assert skew.ax.get_xlabel() == (
'degree_Celsius' if version_check('matplotlib==3.7.0') else '')
assert skew.ax.get_ylabel() == (
'hectopascal' if version_check('matplotlib==3.7.0') else '')

# Clear them for the image test
skew.ax.set_xlabel('')
Expand Down Expand Up @@ -318,8 +320,8 @@ def test_hodograph_api():
return fig


@pytest.mark.mpl_image_compare(remove_text=True,
tolerance=0.6 if mpl_version_before('3.5') else 0.)
@pytest.mark.mpl_image_compare(
remove_text=True, tolerance=0.6 if version_check('matplotlib==3.5') else 0.)
def test_hodograph_units():
"""Test passing quantities to Hodograph."""
fig = plt.figure(figsize=(9, 9))
Expand Down
12 changes: 7 additions & 5 deletions tests/plots/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import xarray as xr

from metpy.plots import add_metpy_logo, add_timestamp, add_unidata_logo, convert_gempak_color
from metpy.testing import get_test_data, mpl_version_before
from metpy.testing import get_test_data, version_check


@pytest.mark.mpl_image_compare(tolerance=2.638, remove_text=True)
Expand Down Expand Up @@ -91,8 +91,9 @@ def test_add_logo_invalid_size():
add_metpy_logo(fig, size='jumbo')


@pytest.mark.mpl_image_compare(tolerance=1.072 if mpl_version_before('3.5') else 0,
remove_text=True)
@pytest.mark.mpl_image_compare(
tolerance=1.072 if version_check('matplotlib<3.5') else 0,
remove_text=True)
def test_gempak_color_image_compare():
"""Test creating a plot with all the GEMPAK colors."""
c = range(32)
Expand All @@ -111,8 +112,9 @@ def test_gempak_color_image_compare():
return fig


@pytest.mark.mpl_image_compare(tolerance=1.215 if mpl_version_before('3.5') else 0,
remove_text=True)
@pytest.mark.mpl_image_compare(
tolerance=1.215 if version_check('matplotlib<3.5') else 0,
remove_text=True)
def test_gempak_color_xw_image_compare():
"""Test creating a plot with all the GEMPAK colors using xw style."""
c = range(32)
Expand Down
26 changes: 25 additions & 1 deletion tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from metpy.deprecation import MetpyDeprecationWarning
from metpy.testing import (assert_array_almost_equal, check_and_drop_units,
check_and_silence_deprecation)
check_and_silence_deprecation, version_check)


# Test #1183: numpy.testing.assert_array* ignores any masked value, so work-around
Expand Down Expand Up @@ -42,3 +42,27 @@ def test_check_and_drop_units_with_dataarray():
assert isinstance(actual, np.ndarray)
assert isinstance(desired, np.ndarray)
np.testing.assert_array_almost_equal(actual, desired)


def test_module_version_check():
"""Test parsing and version comparison of installed package."""
numpy_version = np.__version__
assert version_check(f'numpy >={numpy_version}')


def test_module_version_check_outdated_spec():
"""Test checking test version specs against package metadata."""
with pytest.raises(ValueError, match='Specified numpy'):
version_check('numpy>0.0.0')


def test_module_version_check_nonsense():
"""Test failed pattern match of package specification."""
with pytest.raises(ValueError, match='Invalid version '):
version_check('thousands of birds picking packages')


def test_module_version_check_invalid_comparison():
"""Test invalid operator in version comparison."""
with pytest.raises(ValueError, match='Comparison operator << '):
version_check('numpy << 36')

0 comments on commit 8d6a48b

Please sign in to comment.