diff --git a/src/metpy/plots/declarative.py b/src/metpy/plots/declarative.py index 4c071c6ea2..62f0d66489 100644 --- a/src/metpy/plots/declarative.py +++ b/src/metpy/plots/declarative.py @@ -1305,12 +1305,14 @@ class PlotObs(HasTraits): * time * fields * locations (optional) - * time_range (optional) + * time_window (optional) * formats (optional) * colors (optional) + * plot_units (optional) * vector_field (optional) * vector_field_color (optional) * vector_field_length (optional) + * vector_plot_units (optional) * reduce_points (optional) """ @@ -1393,6 +1395,20 @@ class PlotObs(HasTraits): reduce_points = Float(default_value=0) reduce_points.__doc__ = """Float to reduce number of points plotted. (optional)""" + plot_units = List(default_value=[None], allow_none=True) + plot_units.__doc__ = """A list of the desired units to plot the fields in. + + Setting this attribute will convert the units of the field variable to the given units for + plotting using the MetPy Units module, provided that units are attached to the DataFrame. + """ + + vector_plot_units = Unicode(default_value=None, allow_none=True) + vector_plot_units.__doc__ = """The desired units to plot the vector field in. + + Setting this attribute will convert the units of the field variable to the given units for + plotting using the MetPy Units module, provided that units are attached to the DataFrame. + """ + def clear(self): """Clear the plot. @@ -1519,34 +1535,57 @@ def _build(self): transform=ccrs.PlateCarree(), fontsize=10) for i, ob_type in enumerate(self.fields): + field_kwargs = {} if len(self.locations) > 1: location = self.locations[i] else: location = self.locations[0] if len(self.colors) > 1: - color = self.colors[i] + field_kwargs['color'] = self.colors[i] else: - color = self.colors[0] + field_kwargs['color'] = self.colors[0] if len(self.formats) > 1: - formats = self.formats[i] + field_kwargs['formatter'] = self.formats[i] + else: + field_kwargs['formatter'] = self.formats[0] + if len(self.plot_units) > 1: + field_kwargs['plot_units'] = self.plot_units[i] else: - formats = self.formats[0] - if formats is not None: - mapper = getattr(wx_symbols, str(formats), None) + field_kwargs['plot_units'] = self.plot_units[0] + if hasattr(self.data, 'units') and (field_kwargs['plot_units'] is not None): + parameter = data[ob_type][subset].values * units(self.data.units[ob_type]) + else: + parameter = data[ob_type][subset] + if field_kwargs['formatter'] is not None: + mapper = getattr(wx_symbols, str(field_kwargs['formatter']), None) if mapper is not None: - self.handle.plot_symbol(location, data[ob_type][subset], - mapper, color=color) + field_kwargs.pop('formatter') + self.handle.plot_symbol(location, parameter, + mapper, **field_kwargs) else: - if formats == 'text': - self.handle.plot_text(location, data[ob_type][subset], color=color) + if self.formats[i] == 'text': + self.handle.plot_text(location, data[ob_type][subset], + color=field_kwargs['color']) else: self.handle.plot_parameter(location, data[ob_type][subset], - color=color, formatter=self.formats[i]) + **field_kwargs) else: - self.handle.plot_parameter(location, data[ob_type][subset], color=color) + field_kwargs.pop('formatter') + self.handle.plot_parameter(location, parameter, **field_kwargs) + if self.vector_field[0] is not None: - kwargs = {'color': self.vector_field_color} + vector_kwargs = {} + vector_kwargs['color'] = self.vector_field_color + vector_kwargs['plot_units'] = self.vector_plot_units + if hasattr(self.data, 'units') and (vector_kwargs['plot_units'] is not None): + u = (data[self.vector_field[0]][subset].values + * units(self.data.units[self.vector_field[0]])) + v = (data[self.vector_field[1]][subset].values + * units(self.data.units[self.vector_field[1]])) + else: + vector_kwargs.pop('plot_units') + u = data[self.vector_field[0]][subset] + v = data[self.vector_field[1]][subset] if self.vector_field_length is not None: - kwargs['length'] = self.vector_field_length - self.handle.plot_barb(data[self.vector_field[0]][subset], - data[self.vector_field[1]][subset], **kwargs) + vector_kwargs['length'] = self.vector_field_length + self.handle.plot_barb(u, v, **vector_kwargs) diff --git a/src/metpy/plots/station_plot.py b/src/metpy/plots/station_plot.py index 54bd7cf7bb..17f5807e26 100644 --- a/src/metpy/plots/station_plot.py +++ b/src/metpy/plots/station_plot.py @@ -185,6 +185,8 @@ def plot_parameter(self, location, parameter, formatter='.0f', **kwargs): How to format the data as a string for plotting. If a string, it should be compatible with the :func:`format` builtin. If a callable, this should take a value and return a string. Defaults to '0.f'. + plot_units: `pint.unit` + Units to plot in (performing conversion if necessary). Defaults to given units. kwargs Additional keyword arguments to use for matplotlib's plotting functions. @@ -194,6 +196,9 @@ def plot_parameter(self, location, parameter, formatter='.0f', **kwargs): plot_barb, plot_symbol, plot_text """ + # If plot_units specified, convert the data to those units + plotting_units = kwargs.pop('plot_units', None) + parameter = self._scalar_plotting_units(parameter, plotting_units) if hasattr(parameter, 'units'): parameter = parameter.magnitude text = self._to_string_list(parameter, formatter) @@ -266,7 +271,7 @@ def plot_barb(self, u, v, **kwargs): # If plot_units specified, convert the data to those units plotting_units = kwargs.pop('plot_units', None) - u, v = self._plotting_units(u, v, plotting_units) + u, v = self._vector_plotting_units(u, v, plotting_units) # Empirically determined pivot = 0.51 * np.sqrt(self.fontsize) @@ -309,7 +314,7 @@ def plot_arrow(self, u, v, **kwargs): # If plot_units specified, convert the data to those units plotting_units = kwargs.pop('plot_units', None) - u, v = self._plotting_units(u, v, plotting_units) + u, v = self._vector_plotting_units(u, v, plotting_units) defaults = {'pivot': 'tail', 'scale': 20, 'scale_units': 'inches', 'width': 0.002} defaults.update(kwargs) @@ -321,7 +326,7 @@ def plot_arrow(self, u, v, **kwargs): self.arrows = self.ax.quiver(self.x, self.y, u, v, **defaults) @staticmethod - def _plotting_units(u, v, plotting_units): + def _vector_plotting_units(u, v, plotting_units): """Handle conversion to plotting units for barbs and arrows.""" if plotting_units: if hasattr(u, 'units') and hasattr(v, 'units'): @@ -336,6 +341,17 @@ def _plotting_units(u, v, plotting_units): v = np.array(v) return u, v + @staticmethod + def _scalar_plotting_units(scalar_value, plotting_units): + """Handle conversion to plotting units for barbs and arrows.""" + if plotting_units: + if hasattr(scalar_value, 'units'): + scalar_value = scalar_value.to(plotting_units) + else: + raise ValueError('To convert to plotting units, units must be attached to ' + 'scalar value being converted.') + return scalar_value + def _make_kwargs(self, kwargs): """Assemble kwargs as necessary. diff --git a/tests/plots/baseline/test_declarative_multiple_sfc_obs_change_units.png b/tests/plots/baseline/test_declarative_multiple_sfc_obs_change_units.png new file mode 100644 index 0000000000..4b3a88d9f4 Binary files /dev/null and b/tests/plots/baseline/test_declarative_multiple_sfc_obs_change_units.png differ diff --git a/tests/plots/baseline/test_declarative_sfc_obs_change_units.png b/tests/plots/baseline/test_declarative_sfc_obs_change_units.png new file mode 100644 index 0000000000..2f29559347 Binary files /dev/null and b/tests/plots/baseline/test_declarative_sfc_obs_change_units.png differ diff --git a/tests/plots/baseline/test_declarative_upa_obs_convert_barb_units.png b/tests/plots/baseline/test_declarative_upa_obs_convert_barb_units.png new file mode 100644 index 0000000000..2d8893e78c Binary files /dev/null and b/tests/plots/baseline/test_declarative_upa_obs_convert_barb_units.png differ diff --git a/tests/plots/baseline/test_stationplot_unit_conversion.png b/tests/plots/baseline/test_stationplot_unit_conversion.png new file mode 100644 index 0000000000..3a973bb809 Binary files /dev/null and b/tests/plots/baseline/test_stationplot_unit_conversion.png differ diff --git a/tests/plots/test_declarative.py b/tests/plots/test_declarative.py index ef5343df53..8142eeea35 100644 --- a/tests/plots/test_declarative.py +++ b/tests/plots/test_declarative.py @@ -18,6 +18,7 @@ from metpy.cbook import get_test_data from metpy.io import GiniFile +from metpy.io.metar import parse_metar_file from metpy.plots import (BarbPlot, ContourPlot, FilledContourPlot, ImagePlot, MapPanel, PanelContainer, PlotObs) # Fixtures to make sure we have the right backend @@ -736,6 +737,48 @@ def test_declarative_upa_obs(): return pc.figure +@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.08) +def test_declarative_upa_obs_convert_barb_units(): + """Test making a full upperair observation plot.""" + data = pd.read_csv(get_test_data('UPA_obs.csv', as_file_obj=False)) + data.units = '' + data.units = {'pressure': 'hPa', 'height': 'meters', 'temperature': 'degC', + 'dewpoint': 'degC', 'direction': 'degrees', 'speed': 'knots', + 'station': None, 'time': None, 'u_wind': 'knots', 'v_wind': 'knots', + 'latitude': 'degrees', 'longitude': 'degrees'} + + obs = PlotObs() + obs.data = data + obs.time = datetime(1993, 3, 14, 0) + obs.level = 500 * units.hPa + obs.fields = ['temperature', 'dewpoint', 'height'] + obs.locations = ['NW', 'SW', 'NE'] + obs.formats = [None, None, lambda v: format(v, '.0f')[:3]] + obs.vector_field = ('u_wind', 'v_wind') + obs.vector_field_length = 7 + obs.vector_plot_units = 'm/s' + obs.reduce_points = 0 + + # Panel for plot with Map features + panel = MapPanel() + panel.layout = (1, 1, 1) + panel.area = (-124, -72, 20, 53) + panel.projection = 'lcc' + panel.layers = ['coastline', 'borders', 'states', 'land'] + panel.plots = [obs] + + # Bringing it all together + pc = PanelContainer() + pc.size = (15, 10) + pc.panels = [panel] + + pc.draw() + + obs.level = 300 * units.hPa + + return pc.figure + + def test_attribute_error_time(): """Make sure we get a useful error when the time variable is not found.""" data = pd.read_csv(get_test_data('SFC_obs.csv', as_file_obj=False), @@ -798,6 +841,76 @@ def test_attribute_error_station(): pc.draw() +@pytest.mark.mpl_image_compare(remove_text=True, + tolerance={'2.1': 0.407}.get(MPL_VERSION, 0.022)) +def test_declarative_sfc_obs_change_units(): + """Test making a surface observation plot.""" + data = parse_metar_file(get_test_data('metar_20190701_1200.txt', as_file_obj=False), + year=2019, month=7) + + obs = PlotObs() + obs.data = data + obs.time = datetime(2019, 7, 1, 12) + obs.time_window = timedelta(minutes=15) + obs.level = None + obs.fields = ['air_temperature'] + obs.color = ['black'] + obs.plot_units = ['degF'] + + # Panel for plot with Map features + panel = MapPanel() + panel.layout = (1, 1, 1) + panel.projection = ccrs.PlateCarree() + panel.area = 'in' + panel.layers = ['states'] + panel.plots = [obs] + + # Bringing it all together + pc = PanelContainer() + pc.size = (10, 10) + pc.panels = [panel] + + pc.draw() + + return pc.figure + + +@pytest.mark.mpl_image_compare(remove_text=True, + tolerance={'2.1': 0.09}.get(MPL_VERSION, 0.022)) +def test_declarative_multiple_sfc_obs_change_units(): + """Test making a surface observation plot.""" + data = parse_metar_file(get_test_data('metar_20190701_1200.txt', as_file_obj=False), + year=2019, month=7) + + obs = PlotObs() + obs.data = data + obs.time = datetime(2019, 7, 1, 12) + obs.time_window = timedelta(minutes=15) + obs.level = None + obs.fields = ['air_temperature', 'dew_point_temperature', 'air_pressure_at_sea_level'] + obs.locations = ['NW', 'W', 'NE'] + obs.colors = ['red', 'green', 'black'] + obs.reduce_points = 0.75 + obs.plot_units = ['degF', 'degF', None] + + # Panel for plot with Map features + panel = MapPanel() + panel.layout = (1, 1, 1) + panel.projection = ccrs.PlateCarree() + panel.area = 'in' + panel.layers = ['states'] + panel.plots = [obs] + + # Bringing it all together + pc = PanelContainer() + pc.size = (12, 12) + pc.panels = [panel] + + pc.draw() + + return pc.figure + + def test_save(): """Test that our saving function works.""" pc = PanelContainer() diff --git a/tests/plots/test_station_plot.py b/tests/plots/test_station_plot.py index 0a49c98b97..58668f3f8c 100644 --- a/tests/plots/test_station_plot.py +++ b/tests/plots/test_station_plot.py @@ -434,3 +434,38 @@ def test_symbol_pandas_timeseries(): ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%-d')) return fig + + +@pytest.mark.mpl_image_compare(tolerance=2.444, savefig_kwargs={'dpi': 300}, remove_text=True) +def test_stationplot_unit_conversion(): + """Test the StationPlot API.""" + fig = plt.figure(figsize=(9, 9)) + + # testing data + x = np.array([1, 5]) + y = np.array([2, 4]) + + # Make the plot + sp = StationPlot(fig.add_subplot(1, 1, 1), x, y, fontsize=16) + sp.plot_barb([20, 0], [0, -50]) + sp.plot_text('E', ['KOKC', 'ICT'], color='blue') + sp.plot_parameter('NW', [10.5, 15] * units.degC, plot_units='degF', color='red') + sp.plot_symbol('S', [5, 7], high_clouds, color='green') + + sp.ax.set_xlim(0, 6) + sp.ax.set_ylim(0, 6) + + return fig + + +def test_scalar_unit_conversion_exception(): + """Test that errors are raise if unit conversion is requested on un-united data.""" + T = 50 + x_pos = np.array([0]) + y_pos = np.array([0]) + + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + stnplot = StationPlot(ax, x_pos, y_pos) + with pytest.raises(ValueError): + stnplot.plot_parameter('C', T, plot_units='degC')