Skip to content
This repository has been archived by the owner on Jun 16, 2018. It is now read-only.

Commit

Permalink
Minor fixes to plot_coord function and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Dec 2, 2016
1 parent da0b8d0 commit 191ef1b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
11 changes: 5 additions & 6 deletions wcsaxes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,11 @@ def plot_coord(self, *args, **kwargs):
matplotlib.Axes.plot : This method is called from this function with all arguments passed to it.
"""
args = list(args)
coord_instances = (SkyCoord, BaseCoordinateFrame)
if isinstance(args[0], coord_instances):

if isinstance(args[0], (SkyCoord, BaseCoordinateFrame)):

# Extract the frame from the first argument.
frame0 = args.pop(0)
frame0 = args[0]
if isinstance(frame0, SkyCoord):
frame0 = frame0.frame

Expand All @@ -174,9 +173,9 @@ def plot_coord(self, *args, **kwargs):
" as it is automatically determined by the input coordinate frame.")

transform = self.get_transform(frame0)
kwargs.update({'transform':transform})
kwargs.update({'transform': transform})

args = plot_data + args
args = tuple(plot_data) + args[1:]

super(WCSAxes, self).plot(*args, **kwargs)

Expand Down
8 changes: 4 additions & 4 deletions wcsaxes/tests/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def test_cube_slice_image_lonlat(self):
return fig

@remote_data
@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, filename='plot_coord.png', tolerance=1.5)
@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=1.5)
def test_plot_coord(self):
fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0.15, 0.15, 0.8, 0.8],
Expand All @@ -205,13 +205,13 @@ def test_plot_coord(self):
ax.set_xlim(-0.5, 720.5)
ax.set_ylim(-0.5, 720.5)

c = SkyCoord(266*u.deg, -29*u.deg)
c = SkyCoord(266 * u.deg, -29 * u.deg)
ax.plot_coord(c, 'o')

return fig

@remote_data
@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, filename='plot_line.png', tolerance=1.5)
@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=1.5)
def test_plot_line(self):
fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0.15, 0.15, 0.8, 0.8],
Expand All @@ -220,7 +220,7 @@ def test_plot_line(self):
ax.set_xlim(-0.5, 720.5)
ax.set_ylim(-0.5, 720.5)

c = SkyCoord([266, 266.8]*u.deg, [-29, -28.9]*u.deg)
c = SkyCoord([266, 266.8] * u.deg, [-29, -28.9] * u.deg)
ax.plot_coord(c)

return fig
Expand Down

0 comments on commit 191ef1b

Please sign in to comment.