Skip to content
This repository has been archived by the owner on Jun 16, 2018. It is now read-only.

Commit

Permalink
Merge pull request #209 from Cadair/plot_sc
Browse files Browse the repository at this point in the history
First pass at a SkyCoord plotting method.
  • Loading branch information
astrofrog committed Dec 2, 2016
2 parents 9e27169 + e9f938e commit da0b8d0
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 0 deletions.
54 changes: 54 additions & 0 deletions wcsaxes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from matplotlib.axes import Axes, subplot_class_factory
from matplotlib.transforms import Affine2D, Bbox, Transform

from astropy.coordinates import SkyCoord, BaseCoordinateFrame
from astropy.wcs import WCS
from astropy.wcs.utils import wcs_to_celestial_frame
from astropy.extern import six
Expand Down Expand Up @@ -126,6 +127,59 @@ def imshow(self, X, *args, **kwargs):

return super(WCSAxes, self).imshow(X, *args, **kwargs)

def plot_coord(self, *args, **kwargs):
"""
Plot `~astropy.coordinates.SkyCoord` or
`~astropy.coordinates.BaseCoordinateFrame` objects onto the axes.
The first argument to `~wcsaxes.WCSAxes.plot_coord` should be a
coordinate, which will then be converted to the first two parameters to
`matplotlib.Axes.plot`. All other arguments are the same as
`matplotlib.Axes.plot`. If not specified a ``transform`` keyword
argument will be created based on the coordinate.
Parameters
----------
coordinate : `~astropy.coordinates.SkyCoord` or `~astropy.coordinates.BaseCoordinateFrame`
The coordinate object to plot on the axes. This is converted to the
first two arguments to `matplotlib.Axes.plot`.
See Also
--------
matplotlib.Axes.plot : This method is called from this function with all arguments passed to it.
"""
args = list(args)
coord_instances = (SkyCoord, BaseCoordinateFrame)
if isinstance(args[0], coord_instances):

# Extract the frame from the first argument.
frame0 = args.pop(0)
if isinstance(frame0, SkyCoord):
frame0 = frame0.frame

plot_data = []
for coord in self.coords:
if coord.coord_type == 'longitude':
plot_data.append(frame0.data.lon.to(coord.coord_unit).value)
elif coord.coord_type == 'latitude':
plot_data.append(frame0.data.lat.to(coord.coord_unit).value)
else:
raise NotImplementedError("Coordinates cannot be plotted with this "
"method because the WCS does not represent longitude/latitude.")

if 'transform' in kwargs.keys():
raise TypeError("The 'transform' keyword argument is not allowed,"
" as it is automatically determined by the input coordinate frame.")

transform = self.get_transform(frame0)
kwargs.update({'transform':transform})

args = plot_data + args

super(WCSAxes, self).plot(*args, **kwargs)

def reset_wcs(self, wcs=None, slices=None, transform=None, coord_meta=None):
"""
Reset the current Axes, to use a new WCS object.
Expand Down
31 changes: 31 additions & 0 deletions wcsaxes/tests/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from astropy.io import fits
from astropy.tests.helper import pytest, remote_data
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord

from ..rc_utils import rc_context

Expand Down Expand Up @@ -194,6 +195,36 @@ def test_cube_slice_image_lonlat(self):

return fig

@remote_data
@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, filename='plot_coord.png', tolerance=1.5)
def test_plot_coord(self):
fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0.15, 0.15, 0.8, 0.8],
projection=WCS(self.twoMASS_k_header),
aspect='equal')
ax.set_xlim(-0.5, 720.5)
ax.set_ylim(-0.5, 720.5)

c = SkyCoord(266*u.deg, -29*u.deg)
ax.plot_coord(c, 'o')

return fig

@remote_data
@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, filename='plot_line.png', tolerance=1.5)
def test_plot_line(self):
fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0.15, 0.15, 0.8, 0.8],
projection=WCS(self.twoMASS_k_header),
aspect='equal')
ax.set_xlim(-0.5, 720.5)
ax.set_ylim(-0.5, 720.5)

c = SkyCoord([266, 266.8]*u.deg, [-29, -28.9]*u.deg)
ax.plot_coord(c)

return fig

@remote_data
@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, filename='changed_axis_units.png', tolerance=1.5)
def test_changed_axis_units(self):
Expand Down
19 changes: 19 additions & 0 deletions wcsaxes/tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import print_function

import os

import numpy as np
import matplotlib.pyplot as plt

import astropy.units as u
from astropy.wcs import WCS
from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy.tests.helper import catch_warnings, pytest

from ..core import WCSAxes
Expand Down Expand Up @@ -81,3 +85,18 @@ def test_invalid_frame_overlay():
with pytest.raises(ValueError) as exc:
get_coord_meta('banana')
assert exc.value.args[0] == 'Unknown frame: banana'

def test_plot_coord_transform():

twoMASS_k_header = os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), 'data')), '2MASS_k_header')
twoMASS_k_header = fits.Header.fromtextfile(twoMASS_k_header)
fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0.15, 0.15, 0.8, 0.8],
projection=WCS(twoMASS_k_header),
aspect='equal')
ax.set_xlim(-0.5, 720.5)
ax.set_ylim(-0.5, 720.5)

c = SkyCoord(359.76045223*u.deg, 0.26876217*u.deg)
with pytest.raises(TypeError):
ax.plot_coord(c, 'o', transform=ax.get_transform('galactic'))

0 comments on commit da0b8d0

Please sign in to comment.