Skip to content

Commit

Permalink
Merge pull request #96 from keflavich/full_cube_reprojection
Browse files Browse the repository at this point in the history
Generalize reprojection in 3 dimensions
  • Loading branch information
astrofrog committed Apr 3, 2016
2 parents 0a06093 + 295192c commit 00342cc
Show file tree
Hide file tree
Showing 3 changed files with 369 additions and 36 deletions.
156 changes: 129 additions & 27 deletions reproject/interpolation/core.py
Expand Up @@ -2,7 +2,10 @@
from __future__ import absolute_import, division, print_function

import numpy as np
from astropy.wcs import WCSSUB_CELESTIAL
from astropy import wcs

from distutils.version import StrictVersion
NP_LT_17 = StrictVersion(np.__version__) < StrictVersion('1.7')

from ..wcs_utils import convert_world_coordinates
from ..array_utils import iterate_over_celestial_slices, pad_edge_1
Expand All @@ -19,19 +22,48 @@ def map_coordinates(image, coords, **kwargs):

from scipy.ndimage import map_coordinates as scipy_map_coordinates

ny, nx = image.shape

image = pad_edge_1(image)

values = scipy_map_coordinates(image, coords + 1, **kwargs)

reset = ((coords[0] < -0.5) | (coords[0] > ny - 0.5) |
(coords[1] < -0.5) | (coords[1] > nx - 0.5))
reset = np.zeros(coords.shape[1], dtype=bool)

for i in range(coords.shape[0]):
reset |= (coords[i] < -0.5)
reset |= (coords[i] > image.shape[i] - 0.5)

values[reset] = kwargs.get('cval', 0.)

return values


def _get_input_pixels_full(wcs_in, wcs_out, shape_out):
"""
Get the pixel coordinates of the pixels in an array of shape ``shape_out``
in the input WCS.
"""
if NP_LT_17:
raise NotImplementedError("The grid determination requires numpy >=1.7")

# Generate pixel coordinates of output image
p_out_ax = []
for size in shape_out:
p_out_ax.append(np.arange(size))

p_out = np.meshgrid(*p_out_ax, indexing='ij')

# Convert output pixel coordinates to pixel coordinates in original image
# (using pixel centers).
args = tuple(p_out[::-1]) + (0,)
w_out = wcs_out.wcs_pix2world(*args)

args = tuple(w_out) + (0,)
p_in = wcs_in.wcs_world2pix(*args)

# return x,y,z for consistency with _get_input_pixels_celestial
return p_in


def _get_input_pixels_celestial(wcs_in, wcs_out, shape_out):
"""
Get the pixel coordinates of the pixels in an array of shape ``shape_out``
Expand All @@ -42,21 +74,37 @@ def _get_input_pixels_celestial(wcs_in, wcs_out, shape_out):
# necessarily the case. Also assuming something about the order of the
# arguments.

if len(shape_out) > 3:
raise ValueError(">3 dimensional cube")

# Generate pixel coordinates of output image
xp_out_ax = np.arange(shape_out[1])
yp_out_ax = np.arange(shape_out[0])
xp_out, yp_out = np.meshgrid(xp_out_ax, yp_out_ax)
# reversed because numpy and wcs index in opposite directions
# z,y,x if ::1
# x,y,z if ::-1
pixels_out = np.indices(shape_out)[::-1].astype('float')

# Convert output pixel coordinates to pixel coordinates in original image
# (using pixel centers).
xw_out, yw_out = wcs_out.wcs_pix2world(xp_out, yp_out, 0)
# x,y,z
args = tuple(pixels_out) + (0,)
out_world = wcs_out.wcs_pix2world(*args)

xw_in, yw_in = convert_world_coordinates(xw_out, yw_out, wcs_out, wcs_in)
args = tuple(out_world[:2]) + (wcs_out.celestial, wcs_in.celestial)
xw_in, yw_in = convert_world_coordinates(*args)

xp_in, yp_in = wcs_in.wcs_world2pix(xw_in, yw_in, 0)
xp_in, yp_in = wcs_in.celestial.wcs_world2pix(xw_in, yw_in, 0)

return xp_in, yp_in
input_pixels = [xp_in, yp_in,]
if pixels_out[0].ndim == 3:
zw_out = out_world[2]
zp_in = wcs_in.sub([wcs.WCSSUB_SPECTRAL]).wcs_world2pix(zw_out.ravel(),
0)[0].reshape(zw_out.shape)
input_pixels += [zp_in]

# x,y,z
retval = np.array(input_pixels)
assert retval.shape == (len(shape_out),)+tuple(shape_out)
return retval

def _reproject_celestial(array, wcs_in, wcs_out, shape_out, order=1):
"""
Expand All @@ -68,13 +116,72 @@ def _reproject_celestial(array, wcs_in, wcs_out, shape_out, order=1):

# For now, assume axes are independent in this routine

# Check that WCSs are equivalent
if ((wcs_in.naxis != wcs_out.naxis or
(list(wcs_in.wcs.axis_types) != list(wcs_out.wcs.axis_types)) or
(list(wcs_in.wcs.cunit) != list(wcs_out.wcs.cunit)))):
raise ValueError("The input and output WCS are not equivalent")

# We create an output array with the required shape, then create an array
# that is in order of [rest, lat, lon] where rest is the flattened
# remainder of the array. We then operate on the view, but this will change
# the original array with the correct shape.

array_new = np.zeros(shape_out)

if len(shape_out)>=3 and (shape_out[0] != array.shape[0]):

if ((list(wcs_in.sub([wcs.WCSSUB_SPECTRAL]).wcs.ctype) !=
list(wcs_out.sub([wcs.WCSSUB_SPECTRAL]).wcs.ctype))):
raise ValueError("The input and output spectral coordinate types "
"are not equivalent.")


# do full 3D interpolation
xp_in, yp_in, zp_in = _get_input_pixels_celestial(wcs_in, wcs_out,
shape_out)
coordinates = np.array([zp_in.ravel(), yp_in.ravel(), xp_in.ravel()])
bad_data = ~np.isfinite(array)
array[bad_data] = 0
array_new = map_coordinates(array, coordinates, order=order,
cval=np.nan,
mode='constant').reshape(shape_out)

else:
xp_in = yp_in = None

# Loop over slices and interpolate
for slice_in, slice_out in iterate_over_celestial_slices(array,
array_new,
wcs_in):

if xp_in is None: # Get position of output pixel centers in input image
xp_in, yp_in = _get_input_pixels_celestial(wcs_in.celestial,
wcs_out.celestial,
slice_out.shape)
coordinates = np.array([yp_in.ravel(), xp_in.ravel()])

slice_out[:,:] = map_coordinates(slice_in,
coordinates,
order=order, cval=np.nan,
mode='constant'
).reshape(slice_out.shape)

return array_new, (~np.isnan(array_new)).astype(float)


def _reproject_full(array, wcs_in, wcs_out, shape_out, order=1):
"""
Reproject n-dimensional data to a new projection using interpolation.
"""

# Make sure image is floating point
array = np.asarray(array, dtype=float)

# Check that WCSs are equivalent
if wcs_in.naxis == wcs_out.naxis and np.any(wcs_in.wcs.axis_types != wcs_out.wcs.axis_types):
raise ValueError("The input and output WCS are not equivalent")

# Extract celestial part of WCS in lon/lat order
wcs_in_celestial = wcs_in.sub([WCSSUB_CELESTIAL])
wcs_out_celestial = wcs_out.sub([WCSSUB_CELESTIAL])

# We create an output array with the required shape, then create an array
# that is in order of [rest, lat, lon] where rest is the flattened
Expand All @@ -83,19 +190,14 @@ def _reproject_celestial(array, wcs_in, wcs_out, shape_out, order=1):

array_new = np.zeros(shape_out)

coordinates = None

# Loop over slices and interpolate
for slice_in, slice_out in iterate_over_celestial_slices(array, array_new, wcs_in):
xp_in, yp_in, zp_in = _get_input_pixels_full(wcs_in, wcs_out, shape_out)

if coordinates is None: # Get position of output pixel centers in input image
xp_in, yp_in = _get_input_pixels_celestial(wcs_in_celestial, wcs_out_celestial, slice_out.shape)
coordinates = np.array([yp_in.ravel(), xp_in.ravel()])
coordinates = np.array([p.ravel() for p in (zp_in, yp_in, xp_in)])

slice_out[:, :] = map_coordinates(slice_in,
coordinates,
order=order, cval=np.nan,
mode='constant'
).reshape(slice_out.shape)
array_new = map_coordinates(array,
coordinates,
order=order, cval=np.nan,
mode='constant'
).reshape(shape_out)

return array_new, (~np.isnan(array_new)).astype(float)
26 changes: 20 additions & 6 deletions reproject/interpolation/high_level.py
Expand Up @@ -5,7 +5,7 @@
from astropy.extern import six

from ..utils import parse_input_data, parse_output_projection
from .core import _reproject_celestial
from .core import _reproject_celestial, _reproject_full

__all__ = ['reproject_interp']

Expand All @@ -16,7 +16,8 @@
ORDER['bicubic'] = 3


def reproject_interp(input_data, output_projection, shape_out=None, hdu_in=None, order='bilinear'):
def reproject_interp(input_data, output_projection, shape_out=None, hdu_in=0,
order='bilinear', independent_celestial_slices=False):
"""
Reproject data to a new projection using interpolation (this is typically
the fastest way to reproject an image).
Expand Down Expand Up @@ -54,7 +55,21 @@ def reproject_interp(input_data, output_projection, shape_out=None, hdu_in=None,
* 'bicubic'
or an integer. A value of ``0`` indicates nearest neighbor
interpolation.
interpolation.
independent_celestial_slices : bool, optional
This can be set to ``True`` for n-dimensional input in the following case
(all conditions have to be fulfilled):
* The number of pixels in each non-celestial dimension is the same
between the input and target header.
* The WCS coordinates along the non-celestial dimensions are the
same between the input and target WCS.
* The celestial WCS component is independent from other WCS
coordinates.
In this special case, we can make things a little faster by
reprojecting each celestial slice independently using the same
transformation.
Returns
-------
Expand All @@ -72,8 +87,7 @@ def reproject_interp(input_data, output_projection, shape_out=None, hdu_in=None,
if isinstance(order, six.string_types):
order = ORDER[order]

# For now only celestial reprojection is supported
if wcs_in.has_celestial:
if (wcs_in.celestial and wcs_in.naxis == 2) or independent_celestial_slices:
return _reproject_celestial(array_in, wcs_in, wcs_out, shape_out=shape_out, order=order)
else:
raise NotImplementedError("Currently only data with a WCS that includes a celestial component can be reprojected")
return _reproject_full(array_in, wcs_in, wcs_out, shape_out=shape_out, order=order)

0 comments on commit 00342cc

Please sign in to comment.