From 191ef1bbb7b3d85c6e76412a2175744068224138 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 2 Dec 2016 12:44:48 +0000 Subject: [PATCH] Minor fixes to plot_coord function and tests --- wcsaxes/core.py | 11 +++++------ wcsaxes/tests/test_images.py | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/wcsaxes/core.py b/wcsaxes/core.py index 4d3c07a..a39de2e 100644 --- a/wcsaxes/core.py +++ b/wcsaxes/core.py @@ -150,12 +150,11 @@ def plot_coord(self, *args, **kwargs): matplotlib.Axes.plot : This method is called from this function with all arguments passed to it. """ - args = list(args) - coord_instances = (SkyCoord, BaseCoordinateFrame) - if isinstance(args[0], coord_instances): + + if isinstance(args[0], (SkyCoord, BaseCoordinateFrame)): # Extract the frame from the first argument. - frame0 = args.pop(0) + frame0 = args[0] if isinstance(frame0, SkyCoord): frame0 = frame0.frame @@ -174,9 +173,9 @@ def plot_coord(self, *args, **kwargs): " as it is automatically determined by the input coordinate frame.") transform = self.get_transform(frame0) - kwargs.update({'transform':transform}) + kwargs.update({'transform': transform}) - args = plot_data + args + args = tuple(plot_data) + args[1:] super(WCSAxes, self).plot(*args, **kwargs) diff --git a/wcsaxes/tests/test_images.py b/wcsaxes/tests/test_images.py index 122affa..fa73a73 100644 --- a/wcsaxes/tests/test_images.py +++ b/wcsaxes/tests/test_images.py @@ -196,7 +196,7 @@ def test_cube_slice_image_lonlat(self): return fig @remote_data - @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, filename='plot_coord.png', tolerance=1.5) + @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=1.5) def test_plot_coord(self): fig = plt.figure(figsize=(6, 6)) ax = fig.add_axes([0.15, 0.15, 0.8, 0.8], @@ -205,13 +205,13 @@ def test_plot_coord(self): ax.set_xlim(-0.5, 720.5) ax.set_ylim(-0.5, 720.5) - c = SkyCoord(266*u.deg, -29*u.deg) + c = SkyCoord(266 * u.deg, -29 * u.deg) ax.plot_coord(c, 'o') return fig @remote_data - @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, filename='plot_line.png', tolerance=1.5) + @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=1.5) def test_plot_line(self): fig = plt.figure(figsize=(6, 6)) ax = fig.add_axes([0.15, 0.15, 0.8, 0.8], @@ -220,7 +220,7 @@ def test_plot_line(self): ax.set_xlim(-0.5, 720.5) ax.set_ylim(-0.5, 720.5) - c = SkyCoord([266, 266.8]*u.deg, [-29, -28.9]*u.deg) + c = SkyCoord([266, 266.8] * u.deg, [-29, -28.9] * u.deg) ax.plot_coord(c) return fig