Skip to content

Commit

Permalink
Merge pull request #1329 from kgoebber/convert_obs_units
Browse files Browse the repository at this point in the history
Convert obs units
  • Loading branch information
dopplershift committed Aug 8, 2020
2 parents 3569494 + e03c05f commit 345ddd0
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 20 deletions.
73 changes: 56 additions & 17 deletions src/metpy/plots/declarative.py
Expand Up @@ -1305,12 +1305,14 @@ class PlotObs(HasTraits):
* time
* fields
* locations (optional)
* time_range (optional)
* time_window (optional)
* formats (optional)
* colors (optional)
* plot_units (optional)
* vector_field (optional)
* vector_field_color (optional)
* vector_field_length (optional)
* vector_plot_units (optional)
* reduce_points (optional)
"""

Expand Down Expand Up @@ -1393,6 +1395,20 @@ class PlotObs(HasTraits):
reduce_points = Float(default_value=0)
reduce_points.__doc__ = """Float to reduce number of points plotted. (optional)"""

plot_units = List(default_value=[None], allow_none=True)
plot_units.__doc__ = """A list of the desired units to plot the fields in.
Setting this attribute will convert the units of the field variable to the given units for
plotting using the MetPy Units module, provided that units are attached to the DataFrame.
"""

vector_plot_units = Unicode(default_value=None, allow_none=True)
vector_plot_units.__doc__ = """The desired units to plot the vector field in.
Setting this attribute will convert the units of the field variable to the given units for
plotting using the MetPy Units module, provided that units are attached to the DataFrame.
"""

def clear(self):
"""Clear the plot.
Expand Down Expand Up @@ -1519,34 +1535,57 @@ def _build(self):
transform=ccrs.PlateCarree(), fontsize=10)

for i, ob_type in enumerate(self.fields):
field_kwargs = {}
if len(self.locations) > 1:
location = self.locations[i]
else:
location = self.locations[0]
if len(self.colors) > 1:
color = self.colors[i]
field_kwargs['color'] = self.colors[i]
else:
color = self.colors[0]
field_kwargs['color'] = self.colors[0]
if len(self.formats) > 1:
formats = self.formats[i]
field_kwargs['formatter'] = self.formats[i]
else:
field_kwargs['formatter'] = self.formats[0]
if len(self.plot_units) > 1:
field_kwargs['plot_units'] = self.plot_units[i]
else:
formats = self.formats[0]
if formats is not None:
mapper = getattr(wx_symbols, str(formats), None)
field_kwargs['plot_units'] = self.plot_units[0]
if hasattr(self.data, 'units') and (field_kwargs['plot_units'] is not None):
parameter = data[ob_type][subset].values * units(self.data.units[ob_type])
else:
parameter = data[ob_type][subset]
if field_kwargs['formatter'] is not None:
mapper = getattr(wx_symbols, str(field_kwargs['formatter']), None)
if mapper is not None:
self.handle.plot_symbol(location, data[ob_type][subset],
mapper, color=color)
field_kwargs.pop('formatter')
self.handle.plot_symbol(location, parameter,
mapper, **field_kwargs)
else:
if formats == 'text':
self.handle.plot_text(location, data[ob_type][subset], color=color)
if self.formats[i] == 'text':
self.handle.plot_text(location, data[ob_type][subset],
color=field_kwargs['color'])
else:
self.handle.plot_parameter(location, data[ob_type][subset],
color=color, formatter=self.formats[i])
**field_kwargs)
else:
self.handle.plot_parameter(location, data[ob_type][subset], color=color)
field_kwargs.pop('formatter')
self.handle.plot_parameter(location, parameter, **field_kwargs)

if self.vector_field[0] is not None:
kwargs = {'color': self.vector_field_color}
vector_kwargs = {}
vector_kwargs['color'] = self.vector_field_color
vector_kwargs['plot_units'] = self.vector_plot_units
if hasattr(self.data, 'units') and (vector_kwargs['plot_units'] is not None):
u = (data[self.vector_field[0]][subset].values
* units(self.data.units[self.vector_field[0]]))
v = (data[self.vector_field[1]][subset].values
* units(self.data.units[self.vector_field[1]]))
else:
vector_kwargs.pop('plot_units')
u = data[self.vector_field[0]][subset]
v = data[self.vector_field[1]][subset]
if self.vector_field_length is not None:
kwargs['length'] = self.vector_field_length
self.handle.plot_barb(data[self.vector_field[0]][subset],
data[self.vector_field[1]][subset], **kwargs)
vector_kwargs['length'] = self.vector_field_length
self.handle.plot_barb(u, v, **vector_kwargs)
22 changes: 19 additions & 3 deletions src/metpy/plots/station_plot.py
Expand Up @@ -185,6 +185,8 @@ def plot_parameter(self, location, parameter, formatter='.0f', **kwargs):
How to format the data as a string for plotting. If a string, it should be
compatible with the :func:`format` builtin. If a callable, this should take a
value and return a string. Defaults to '0.f'.
plot_units: `pint.unit`
Units to plot in (performing conversion if necessary). Defaults to given units.
kwargs
Additional keyword arguments to use for matplotlib's plotting functions.
Expand All @@ -194,6 +196,9 @@ def plot_parameter(self, location, parameter, formatter='.0f', **kwargs):
plot_barb, plot_symbol, plot_text
"""
# If plot_units specified, convert the data to those units
plotting_units = kwargs.pop('plot_units', None)
parameter = self._scalar_plotting_units(parameter, plotting_units)
if hasattr(parameter, 'units'):
parameter = parameter.magnitude
text = self._to_string_list(parameter, formatter)
Expand Down Expand Up @@ -266,7 +271,7 @@ def plot_barb(self, u, v, **kwargs):

# If plot_units specified, convert the data to those units
plotting_units = kwargs.pop('plot_units', None)
u, v = self._plotting_units(u, v, plotting_units)
u, v = self._vector_plotting_units(u, v, plotting_units)

# Empirically determined
pivot = 0.51 * np.sqrt(self.fontsize)
Expand Down Expand Up @@ -309,7 +314,7 @@ def plot_arrow(self, u, v, **kwargs):

# If plot_units specified, convert the data to those units
plotting_units = kwargs.pop('plot_units', None)
u, v = self._plotting_units(u, v, plotting_units)
u, v = self._vector_plotting_units(u, v, plotting_units)

defaults = {'pivot': 'tail', 'scale': 20, 'scale_units': 'inches', 'width': 0.002}
defaults.update(kwargs)
Expand All @@ -321,7 +326,7 @@ def plot_arrow(self, u, v, **kwargs):
self.arrows = self.ax.quiver(self.x, self.y, u, v, **defaults)

@staticmethod
def _plotting_units(u, v, plotting_units):
def _vector_plotting_units(u, v, plotting_units):
"""Handle conversion to plotting units for barbs and arrows."""
if plotting_units:
if hasattr(u, 'units') and hasattr(v, 'units'):
Expand All @@ -336,6 +341,17 @@ def _plotting_units(u, v, plotting_units):
v = np.array(v)
return u, v

@staticmethod
def _scalar_plotting_units(scalar_value, plotting_units):
"""Handle conversion to plotting units for barbs and arrows."""
if plotting_units:
if hasattr(scalar_value, 'units'):
scalar_value = scalar_value.to(plotting_units)
else:
raise ValueError('To convert to plotting units, units must be attached to '
'scalar value being converted.')
return scalar_value

def _make_kwargs(self, kwargs):
"""Assemble kwargs as necessary.
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
113 changes: 113 additions & 0 deletions tests/plots/test_declarative.py
Expand Up @@ -18,6 +18,7 @@

from metpy.cbook import get_test_data
from metpy.io import GiniFile
from metpy.io.metar import parse_metar_file
from metpy.plots import (BarbPlot, ContourPlot, FilledContourPlot, ImagePlot, MapPanel,
PanelContainer, PlotObs)
# Fixtures to make sure we have the right backend
Expand Down Expand Up @@ -736,6 +737,48 @@ def test_declarative_upa_obs():
return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True, tolerance=0.08)
def test_declarative_upa_obs_convert_barb_units():
"""Test making a full upperair observation plot."""
data = pd.read_csv(get_test_data('UPA_obs.csv', as_file_obj=False))
data.units = ''
data.units = {'pressure': 'hPa', 'height': 'meters', 'temperature': 'degC',
'dewpoint': 'degC', 'direction': 'degrees', 'speed': 'knots',
'station': None, 'time': None, 'u_wind': 'knots', 'v_wind': 'knots',
'latitude': 'degrees', 'longitude': 'degrees'}

obs = PlotObs()
obs.data = data
obs.time = datetime(1993, 3, 14, 0)
obs.level = 500 * units.hPa
obs.fields = ['temperature', 'dewpoint', 'height']
obs.locations = ['NW', 'SW', 'NE']
obs.formats = [None, None, lambda v: format(v, '.0f')[:3]]
obs.vector_field = ('u_wind', 'v_wind')
obs.vector_field_length = 7
obs.vector_plot_units = 'm/s'
obs.reduce_points = 0

# Panel for plot with Map features
panel = MapPanel()
panel.layout = (1, 1, 1)
panel.area = (-124, -72, 20, 53)
panel.projection = 'lcc'
panel.layers = ['coastline', 'borders', 'states', 'land']
panel.plots = [obs]

# Bringing it all together
pc = PanelContainer()
pc.size = (15, 10)
pc.panels = [panel]

pc.draw()

obs.level = 300 * units.hPa

return pc.figure


def test_attribute_error_time():
"""Make sure we get a useful error when the time variable is not found."""
data = pd.read_csv(get_test_data('SFC_obs.csv', as_file_obj=False),
Expand Down Expand Up @@ -798,6 +841,76 @@ def test_attribute_error_station():
pc.draw()


@pytest.mark.mpl_image_compare(remove_text=True,
tolerance={'2.1': 0.407}.get(MPL_VERSION, 0.022))
def test_declarative_sfc_obs_change_units():
"""Test making a surface observation plot."""
data = parse_metar_file(get_test_data('metar_20190701_1200.txt', as_file_obj=False),
year=2019, month=7)

obs = PlotObs()
obs.data = data
obs.time = datetime(2019, 7, 1, 12)
obs.time_window = timedelta(minutes=15)
obs.level = None
obs.fields = ['air_temperature']
obs.color = ['black']
obs.plot_units = ['degF']

# Panel for plot with Map features
panel = MapPanel()
panel.layout = (1, 1, 1)
panel.projection = ccrs.PlateCarree()
panel.area = 'in'
panel.layers = ['states']
panel.plots = [obs]

# Bringing it all together
pc = PanelContainer()
pc.size = (10, 10)
pc.panels = [panel]

pc.draw()

return pc.figure


@pytest.mark.mpl_image_compare(remove_text=True,
tolerance={'2.1': 0.09}.get(MPL_VERSION, 0.022))
def test_declarative_multiple_sfc_obs_change_units():
"""Test making a surface observation plot."""
data = parse_metar_file(get_test_data('metar_20190701_1200.txt', as_file_obj=False),
year=2019, month=7)

obs = PlotObs()
obs.data = data
obs.time = datetime(2019, 7, 1, 12)
obs.time_window = timedelta(minutes=15)
obs.level = None
obs.fields = ['air_temperature', 'dew_point_temperature', 'air_pressure_at_sea_level']
obs.locations = ['NW', 'W', 'NE']
obs.colors = ['red', 'green', 'black']
obs.reduce_points = 0.75
obs.plot_units = ['degF', 'degF', None]

# Panel for plot with Map features
panel = MapPanel()
panel.layout = (1, 1, 1)
panel.projection = ccrs.PlateCarree()
panel.area = 'in'
panel.layers = ['states']
panel.plots = [obs]

# Bringing it all together
pc = PanelContainer()
pc.size = (12, 12)
pc.panels = [panel]

pc.draw()

return pc.figure


def test_save():
"""Test that our saving function works."""
pc = PanelContainer()
Expand Down
35 changes: 35 additions & 0 deletions tests/plots/test_station_plot.py
Expand Up @@ -434,3 +434,38 @@ def test_symbol_pandas_timeseries():
ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%-d'))

return fig


@pytest.mark.mpl_image_compare(tolerance=2.444, savefig_kwargs={'dpi': 300}, remove_text=True)
def test_stationplot_unit_conversion():
"""Test the StationPlot API."""
fig = plt.figure(figsize=(9, 9))

# testing data
x = np.array([1, 5])
y = np.array([2, 4])

# Make the plot
sp = StationPlot(fig.add_subplot(1, 1, 1), x, y, fontsize=16)
sp.plot_barb([20, 0], [0, -50])
sp.plot_text('E', ['KOKC', 'ICT'], color='blue')
sp.plot_parameter('NW', [10.5, 15] * units.degC, plot_units='degF', color='red')
sp.plot_symbol('S', [5, 7], high_clouds, color='green')

sp.ax.set_xlim(0, 6)
sp.ax.set_ylim(0, 6)

return fig


def test_scalar_unit_conversion_exception():
"""Test that errors are raise if unit conversion is requested on un-united data."""
T = 50
x_pos = np.array([0])
y_pos = np.array([0])

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
stnplot = StationPlot(ax, x_pos, y_pos)
with pytest.raises(ValueError):
stnplot.plot_parameter('C', T, plot_units='degC')

0 comments on commit 345ddd0

Please sign in to comment.