Skip to content

Commit

Permalink
Merge 6693db5 into 85ab093
Browse files Browse the repository at this point in the history
  • Loading branch information
hamogu committed Feb 27, 2015
2 parents 85ab093 + 6693db5 commit 4b05c7f
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Expand Up @@ -52,6 +52,9 @@ New Features

- ``astropy.utils``

- extract array now offers different options to deal with array boundaries
[#3333]

- ``astropy.visualization``

- ``astropy.vo``
Expand Down
103 changes: 99 additions & 4 deletions astropy/nddata/tests/test_utils.py
Expand Up @@ -2,10 +2,10 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import numpy as np
from numpy.testing import assert_allclose

from ...tests.helper import pytest
from ..utils import extract_array, add_array, subpixel_indices
from ..utils import extract_array, add_array, subpixel_indices, \
overlap_slices, NoOverlapError, PartialOverlapError

test_positions = [(10.52, 3.12), (5.62, 12.97), (31.33, 31.77),
(0.46, 0.94), (20.45, 12.12), (42.24, 24.42)]
Expand All @@ -19,8 +19,103 @@

subsampling = 5

test_pos_bad = [(-1, -4), (-1, 0), (6, 2), (6, 6)]

def test_extract_array():

def test_slices_different_dim():
'''Overlap from arrays with different number of dim is undefined.'''
with pytest.raises(ValueError) as e:
temp = overlap_slices((4, 5, 6), (1, 2), (0, 0))
assert "the same number of dimensions" in str(e.value)


def test_slices_pos_different_dim():
'''Position must have same dim as arrays.'''
with pytest.raises(ValueError) as e:
temp = overlap_slices((4, 5), (1, 2), (0, 0, 3))
assert "the same number of dimensions" in str(e.value)


@pytest.mark.parametrize('pos', test_pos_bad)
def test_slices_no_overlap(pos):
'''If there is no overlap between arrays, an error should be raised.'''
with pytest.raises(NoOverlapError):
temp = overlap_slices((5, 5), (2, 2), pos)


def test_slices_partial_overlap():
'''Compute a slice for partially overlapping arrays.'''
temp = overlap_slices((5,), (3,), (0,))
assert temp == ((slice(0, 2, None),), (slice(1, 3, None),))

temp = overlap_slices((5,), (3,), (0,), mode='partial')
assert temp == ((slice(0, 2, None),), (slice(1, 3, None),))

for pos in [0, 4]:
with pytest.raises(PartialOverlapError) as e:
temp = overlap_slices((5,), (3,), (pos,), mode='strict')
assert 'Arrays overlap only partially.' in str(e.value)


def test_slices_overlap_wrong_mode():
'''Call overlap_slices with non-existing mode.'''
with pytest.raises(ValueError) as e:
temp = overlap_slices((5,), (3,), (0,), mode='full')
assert "Mode can only be" in str(e.value)


def test_extract_array_wrong_mode():
'''Call extract_array with non-existing mode.'''
with pytest.raises(ValueError) as e:
temp = extract_array(np.arange(4), (2, ), (0, ), mode='full')
assert "Valid modes are 'partial', 'trim', and 'strict'." == str(e.value)


def test_extract_array_1d_even():
'''Extract 1 d arrays.
All dimensions are treated the same, so we can test in 1 dim.
'''
assert np.all(extract_array(np.arange(4), (2, ), (0, ), fill_value=-99) == np.array([-99, 0]))
for i in [1, 2, 3]:
assert np.all(extract_array(np.arange(4), (2, ), (i, )) == np.array([i -1 , i]))
assert np.all(extract_array(np.arange(4.), (2, ), (4, ), fill_value=np.inf) == np.array([3, np.inf]))


def test_extract_array_1d_odd():
'''Extract 1 d arrays.
All dimensions are treated the same, so we can test in 1 dim.
The first few lines test the most error-prone part: Extraction of an
array on the boundaries.
Additional tests (e.g. dtype of return array) are done for the last
case only.
'''
assert np.all(extract_array(np.arange(4), (3,), (-1, ), fill_value=-99) == np.array([-99, -99, 0]))
assert np.all(extract_array(np.arange(4), (3,), (0, ), fill_value=-99) == np.array([-99, 0, 1]))
for i in [1,2]:
assert np.all(extract_array(np.arange(4), (3,), (i, )) == np.array([i-1, i, i+1]))
assert np.all(extract_array(np.arange(4), (3,), (3, ), fill_value=-99) == np.array([2, 3, -99]))
arrayin = np.arange(4.)
extracted = extract_array(arrayin, (3,), (4, ))
assert extracted[0] == 3
assert np.isnan(extracted[1]) # since I cannot use `==` to test for nan
assert extracted.dtype == arrayin.dtype


def test_extract_array_1d_trim():
'''Extract 1 d arrays.
All dimensions are treated the same, so we can test in 1 dim.
'''
assert np.all(extract_array(np.arange(4), (2, ), (0, ), mode='trim') == np.array([0]))
for i in [1, 2, 3]:
assert np.all(extract_array(np.arange(4), (2, ), (i, ), mode='trim') == np.array([i -1 , i]))
assert np.all(extract_array(np.arange(4.), (2, ), (4, ), mode='trim') == np.array([3]))


@pytest.mark.parametrize('mode', ['partial', 'trim', 'strict'])
def test_extract_array_easy(mode):
"""
Test extract_array utility function.
Expand All @@ -29,7 +124,7 @@ def test_extract_array():
large_test_array = np.zeros((11, 11))
small_test_array = np.ones((5, 5))
large_test_array[3:8, 3:8] = small_test_array
extracted_array = extract_array(large_test_array, (5, 5), (5, 5))
extracted_array = extract_array(large_test_array, (5, 5), (5, 5), mode=mode)
assert np.all(extracted_array == small_test_array)


Expand Down
100 changes: 81 additions & 19 deletions astropy/nddata/utils.py
Expand Up @@ -8,10 +8,18 @@
import numpy as np

__all__ = ['extract_array', 'add_array', 'subpixel_indices',
'overlap_slices']
'overlap_slices', 'NoOverlapError', 'PartialOverlapError']


def overlap_slices(large_array_shape, small_array_shape, position):
class NoOverlapError(ValueError):
'''Raised when determining the overlap of non-overlapping arrays.'''
pass

class PartialOverlapError(ValueError):
'''Raised when arrays only partially overlap.'''
pass

def overlap_slices(large_array_shape, small_array_shape, position, mode='partial'):
"""
Get slices for the overlapping part of a small and a large array.
Expand All @@ -27,9 +35,18 @@ def overlap_slices(large_array_shape, small_array_shape, position):
Shape of the large array.
small_array_shape : tuple
Shape of the small array.
position : tuple
position : tuple of integers
Position of the small array's center, with respect to the large array.
Coordinates should be in the same order as the array shape.
When determining the center for a coordinate with an even number of
elements, the position is rounded up. So the coordinates of the
center of an array with shape=(2,3) will be (1,1).
mode : ['partial', 'strict']
In "partial" mode, a partial overlap of the small and the large
array is sufficient. In the "strict" mode, the small array has to be
fully contained in the large array, otherwise an
`~astropy.nddata.utils.PartialOverlapError` is raised. In both modes,
non-overlapping arrays will raise a `~astropy.nddata.utils.NoOverlapError`.
Returns
-------
Expand All @@ -42,11 +59,34 @@ def overlap_slices(large_array_shape, small_array_shape, position):
``small_array[slices_small]`` extracts the region that is inside the
large array.
"""
if mode not in ['partial', 'strict']:
raise ValueError('Mode can only be "partial" or "strict".')
if len(small_array_shape) != len(large_array_shape):
raise ValueError("Both arrays must have the same number of dimensions.")

if len(small_array_shape) != len(position):
raise ValueError("Position must have the same number of dimensions as array.")

# Get edge coordinates
edges_min = [int(pos + 0.5 - small_shape / 2.) for (pos, small_shape) in
zip(position, small_array_shape)]
edges_max = [int(pos + 0.5 + small_shape / 2.) for (pos, small_shape) in
zip(position, small_array_shape)]
edges_min = [int(np.floor(pos + 0.5 - small_shape / 2.))
for (pos, small_shape) in zip(position, small_array_shape)]
edges_max = [int(np.floor(pos + 0.5 + small_shape / 2.))
for (pos, small_shape) in zip(position, small_array_shape)]

for e_max in edges_max:
if e_max <= 0:
raise NoOverlapError('Arrays do not overlap.')
for e_min, large_shape in zip(edges_min, large_array_shape):
if e_min >= large_shape:
raise NoOverlapError('Arrays do not overlap.')

if mode == 'strict':
for e_min in edges_min:
if e_min < 0:
raise PartialOverlapError('Arrays overlap only partially.')
for e_max, large_shape in zip(edges_max, large_array_shape):
if e_max >= large_shape:
raise PartialOverlapError('Arrays overlap only partially.')

# Set up slices
slices_large = tuple(slice(max(0, edge_min), min(large_shape, edge_max))
Expand All @@ -60,7 +100,7 @@ def overlap_slices(large_array_shape, small_array_shape, position):
return slices_large, slices_small


def extract_array(array_large, shape, position):
def extract_array(array_large, shape, position, mode='partial', fill_value=np.nan):
"""
Extract smaller array of given shape and position out of a larger array.
Expand All @@ -73,11 +113,26 @@ def extract_array(array_large, shape, position):
position : tuple
Position of the small array's center, with respect to the large array.
Coordinates should be in the same order as the array shape.
mode : ['partial', 'trim', 'strict']
In "partial" and "trim" mode, a partial overlap of the small and the large
array is sufficient. In the "strict" mode, the small array has to be
fully contained in the large array, otherwise an
`~astropy.nddata.utils.PartialOverlapError` is raised. In both modes,
non-overlapping arrays will raise a `~astropy.nddata.utils.NoOverlapError`.
In "partial" mode, positions in the extracted array, that do not overlap
with the original array, will be filled with ``fill_value``. In "trim" mode
only the overlapping elements are returned, thus the resulting array may
be smaller than requested.
fill_value : object of type array_large.dtype
In "partial" mode ``fill_value`` set the values in the extracted array that
do not overlap with ``large_array``.
Returns
-------
array_small : `~numpy.ndarray`
The extracted array
The extracted array. Values that do not overlap with ``array_large``
are masked.
Examples
--------
Expand All @@ -87,19 +142,26 @@ def extract_array(array_large, shape, position):
>>> import numpy as np
>>> from astropy.nddata.utils import extract_array
>>> large_array = np.arange(110).reshape((11, 10))
>>> large_array[4:9, 4:9] = np.ones((5, 5))
>>> extract_array(large_array, (3, 5), (7, 7))
array([[ 1, 1, 1, 1, 69],
[ 1, 1, 1, 1, 79],
[ 1, 1, 1, 1, 89]])
array([[65, 66, 67, 68, 69],
[75, 76, 77, 78, 79],
[85, 86, 87, 88, 89]])
"""
# Check if larger array is really larger
if all(large_shape > small_shape for (large_shape, small_shape)
in zip(array_large.shape, shape)):
large_slices, _ = overlap_slices(array_large.shape, shape, position)
return array_large[large_slices]
if mode in ['partial', 'trim']:
slicemode = 'partial'
elif mode == 'strict':
slicemode = mode
else:
raise ValueError("Can't extract array. Shape too large.")
raise ValueError("Valid modes are 'partial', 'trim', and 'strict'.")
large_slices, small_slices = overlap_slices(array_large.shape,
shape, position, mode=slicemode)
extracted_array = array_large[large_slices]
# Extracting on the edges is presumably a rare case, so treat special here.
if (extracted_array.shape != shape) and (mode=='partial'):
extracted_array = np.zeros(shape, dtype=array_large.dtype)
extracted_array[:] = fill_value
extracted_array[small_slices] = array_large[large_slices]
return extracted_array


def add_array(array_large, array_small, position):
Expand Down

0 comments on commit 4b05c7f

Please sign in to comment.