forked from astropy/photutils
-
Notifications
You must be signed in to change notification settings - Fork 0
/
deblend.py
269 lines (225 loc) · 10.9 KB
/
deblend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Functions for deblending sources."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from copy import deepcopy
import warnings
import numpy as np
from astropy.utils.exceptions import AstropyUserWarning
from .core import _convolve_data, detect_sources
from ..segmentation import SegmentationImage
# Only the public entry point is exported; ``_deblend_source`` is a private
# per-source helper called by ``deblend_sources``.
__all__ = ['deblend_sources']
def deblend_sources(data, segment_img, npixels, filter_kernel=None,
                    labels=None, nlevels=32, contrast=0.001,
                    mode='exponential', connectivity=8, relabel=True):
    """
    Deblend overlapping sources labeled in a segmentation image.

    Sources are deblended using a combination of multi-thresholding and
    `watershed segmentation
    <https://en.wikipedia.org/wiki/Watershed_(image_processing)>`_.  In
    order to deblend sources, they must be separated enough such that
    there is a saddle between them.

    .. note::
        This function is experimental.  Please report any issues on the
        `Photutils GitHub issue tracker
        <https://github.com/astropy/photutils/issues>`_

    Parameters
    ----------
    data : array_like
        The 2D array of the image.

    segment_img : `~photutils.segmentation.SegmentationImage` or array_like (int)
        A 2D segmentation image, either as a
        `~photutils.segmentation.SegmentationImage` object or an
        `~numpy.ndarray`, with the same shape as ``data`` where sources
        are labeled by different positive integer values.  A value of
        zero is reserved for the background.

    npixels : int
        The number of connected pixels, each greater than the current
        multi-thresholding level, that an object must have to be
        detected.  ``npixels`` must be a positive integer.

    filter_kernel : array-like (2D) or `~astropy.convolution.Kernel2D`, optional
        The 2D array of the kernel used to filter the image before
        thresholding.  Filtering the image will smooth the noise and
        maximize detectability of objects with a shape similar to the
        kernel.

    labels : int or array-like of int, optional
        The label numbers to deblend.  If `None` (default), then all
        labels in the segmentation image will be deblended.

    nlevels : int, optional
        The number of multi-thresholding levels to use.  Each source
        will be re-thresholded at ``nlevels``, spaced exponentially or
        linearly (see the ``mode`` keyword), between its minimum and
        maximum values within the source segment.

    contrast : float, optional
        The fraction of the total (blended) source flux that a local
        peak must have to be considered as a separate object.
        ``contrast`` must be between 0 and 1, inclusive.  If ``contrast
        = 0`` then every local peak will be made a separate object
        (maximum deblending).  If ``contrast = 1`` then no deblending
        will occur.  The default is 0.001, which will deblend sources
        with a magnitude difference of about 7.5.

    mode : {'exponential', 'linear'}, optional
        The mode used in defining the spacing between the
        multi-thresholding levels (see the ``nlevels`` keyword).

    connectivity : {4, 8}, optional
        The type of pixel connectivity used in determining how pixels
        are grouped into a detected source.  The options are 4 or 8
        (default).  4-connected pixels touch along their edges.
        8-connected pixels touch along their edges or corners.  For
        reference, SExtractor uses 8-connected pixels.

    relabel : bool, optional
        If `True` (default), then the segmentation image will be
        relabeled such that the labels are in sequential order starting
        from 1.

    Returns
    -------
    segment_image : `~photutils.segmentation.SegmentationImage`
        A 2D segmentation image, with the same shape as ``data``, where
        sources are marked by different positive integer values.  A
        value of zero is reserved for the background.  The input
        ``segment_img`` is not modified.

    Raises
    ------
    ValueError
        If ``data`` and ``segment_img`` do not have the same shape.
        Invalid deblending parameters are rejected by the per-source
        deblending step.

    See Also
    --------
    :func:`photutils.detection.detect_sources`
    """
    data = np.asanyarray(data)

    if not isinstance(segment_img, SegmentationImage):
        segment_img = SegmentationImage(segment_img)

    if segment_img.shape != data.shape:
        raise ValueError('The data and segmentation image must have '
                         'the same shape')

    if labels is None:
        labels = segment_img.labels
    labels = np.atleast_1d(labels)

    data = _convolve_data(data, filter_kernel, mode='constant',
                          fill_value=0.0)

    # Accumulate the new labels in a plain array and wrap it in a fresh
    # SegmentationImage at the end.  Writing into the private ``_data``
    # of a copied object could leave derived attributes of the copy
    # (e.g. ``max``/``labels``, which may be cached) describing the old
    # labels when ``relabel_sequential`` later runs.
    deblended_data = np.copy(segment_img.data)
    last_label = segment_img.max

    for label in labels:
        segment_img.check_label(label)
        source_slice = segment_img.slices[label - 1]
        source_data = data[source_slice]
        source_segm = SegmentationImage(np.copy(
            segment_img.data[source_slice]))
        source_segm.keep_labels(label)    # include only one label
        source_deblended = _deblend_source(
            source_data, source_segm, npixels, nlevels=nlevels,
            contrast=contrast, mode=mode, connectivity=connectivity)

        if source_deblended.nlabels > 1:
            # Replace the original source with the deblended source,
            # shifting its labels past every label used so far.
            source_mask = (source_deblended.data > 0)
            deblended_data[source_slice][source_mask] = (
                source_deblended.data[source_mask] + last_label)
            # Advance by the largest label rather than the label count so
            # that non-sequential deblended labels cannot collide with the
            # labels assigned to the next deblended source.
            last_label += source_deblended.max

    segm_deblended = SegmentationImage(deblended_data)
    if relabel:
        segm_deblended.relabel_sequential()

    return segm_deblended
def _deblend_source(data, segment_img, npixels, nlevels=32, contrast=0.001,
                    mode='exponential', connectivity=8):
    """
    Deblend a single labeled source.

    Parameters
    ----------
    data : array_like
        The 2D array of the image.  This should be a cutout for a single
        source.  ``data`` should already be smoothed by the same filter
        used in :func:`~photutils.detect_sources`, if applicable.

    segment_img : `~photutils.segmentation.SegmentationImage`
        A cutout `~photutils.segmentation.SegmentationImage` object with
        the same shape as ``data``.  ``segment_img`` should contain only
        *one* source label.  Pixels of ``data`` outside this label's
        footprint are ignored.

    npixels : int
        The number of connected pixels, each greater than the current
        multi-thresholding level, that an object must have to be
        detected.  ``npixels`` must be a positive integer.

    nlevels : int, optional
        The number of multi-thresholding levels to use.  Each source
        will be re-thresholded at ``nlevels``, spaced exponentially or
        linearly (see the ``mode`` keyword), between its minimum and
        maximum values within the source segment.

    contrast : float, optional
        The fraction of the total (blended) source flux that a local
        peak must have to be considered as a separate object.
        ``contrast`` must be between 0 and 1, inclusive.  If ``contrast
        = 0`` then every local peak will be made a separate object
        (maximum deblending).  If ``contrast = 1`` then no deblending
        will occur.  The default is 0.001, which will deblend sources
        with a magnitude difference of about 7.5.

    mode : {'exponential', 'linear'}, optional
        The mode used in defining the spacing between the
        multi-thresholding levels (see the ``nlevels`` keyword).  If the
        source contains negative values, 'linear' is used regardless.

    connectivity : {4, 8}, optional
        The type of pixel connectivity used in determining how pixels
        are grouped into a detected source.  The options are 4 or 8
        (default).  4-connected pixels touch along their edges.
        8-connected pixels touch along their edges or corners.  For
        reference, SExtractor uses 8-connected pixels.

    Returns
    -------
    segment_image : `~photutils.segmentation.SegmentationImage`
        A 2D segmentation image, with the same shape as ``data``, where
        sources are marked by different positive integer values.  A
        value of zero is reserved for the background.  If the source is
        not deblended, the input ``segment_img`` itself is returned.

    Raises
    ------
    ValueError
        If ``nlevels < 1``, ``contrast`` is outside [0, 1], ``mode`` is
        not 'exponential' or 'linear', or ``connectivity`` is not 4
        or 8.
    """
    from scipy import ndimage
    try:
        # scikit-image >= 0.17; the morphology location was removed in 0.19
        from skimage.segmentation import watershed
    except ImportError:
        from skimage.morphology import watershed

    if nlevels < 1:
        raise ValueError('nlevels must be >= 1, got "{0}"'.format(nlevels))

    if contrast < 0 or contrast > 1:
        raise ValueError('contrast must be >= 0 and <= 1, got '
                         '"{0}"'.format(contrast))

    if mode not in ('exponential', 'linear'):
        raise ValueError('"{0}" is an invalid mode; mode must be '
                         '"exponential" or "linear"'.format(mode))

    # Structuring element matching ``connectivity``; used for relabeling
    # and for the watershed so they group pixels like detect_sources does.
    if connectivity == 4:
        selem = ndimage.generate_binary_structure(2, 1)
    elif connectivity == 8:
        selem = ndimage.generate_binary_structure(2, 2)
    else:
        raise ValueError('connectivity must be 4 or 8, got '
                         '"{0}"'.format(connectivity))

    segm_mask = (segment_img.data > 0)
    source_values = data[segm_mask]
    source_min = np.min(source_values)
    source_max = np.max(source_values)
    if source_min == source_max:
        return segment_img    # no deblending

    # The cutout may contain pixels of neighbouring sources.  Replace
    # everything outside this source's footprint with the source minimum,
    # which lies below every threshold computed below, so those pixels can
    # never be detected as peaks.
    source_data = np.where(segm_mask, data, source_min)

    if source_min < 0:
        warnings.warn('Source "{0}" contains negative values, setting '
                      'deblending mode to "linear"'.format(
                          segment_img.labels[0]), AstropyUserWarning)
        mode = 'linear'

    source_sum = float(np.sum(source_values))

    # Threshold levels strictly between source_min and source_max.
    steps = np.arange(1., nlevels + 1)
    if mode == 'exponential':
        if source_min == 0:
            source_min = source_max * 0.01
        thresholds = source_min * ((source_max / source_min) **
                                   (steps / (nlevels + 1)))
    else:
        thresholds = source_min + ((source_max - source_min) /
                                   (nlevels + 1)) * steps

    # Create a top-down tree of local peaks: keep only levels where at
    # least two segments each carry at least ``contrast`` of the flux.
    segm_tree = []
    for level in thresholds[::-1]:
        segm_tmp = detect_sources(source_data, level, npixels=npixels,
                                  connectivity=connectivity)

        if segm_tmp.nlabels >= 2:
            fluxes = ndimage.sum(source_data, labels=segm_tmp.data,
                                 index=segm_tmp.labels)
            fractions = np.asarray(fluxes) / source_sum
            if np.count_nonzero(fractions >= contrast) >= 2:
                segm_tree.append(segm_tmp)

    nbranch = len(segm_tree)
    if nbranch == 0:
        return segment_img

    for j in range(nbranch - 1, 0, -1):
        lower = segm_tree[j].data
        upper = segm_tree[j - 1].data
        intersect_mask = (lower > 0) & (upper > 0)
        intersect_labels = np.unique(lower[intersect_mask])

        if segm_tree[j - 1].nlabels <= len(intersect_labels):
            segm_tree[j - 1] = segm_tree[j]
        else:
            # If a higher tree level has more peaks than in the
            # intersected label(s) with the level below, then remove
            # the intersected label(s) in the lower level, add the
            # higher level, and relabel.
            segm_tree[j].remove_labels(intersect_labels)
            new_segments = segm_tree[j].data + segm_tree[j - 1].data
            new_segm, _ = ndimage.label(new_segments, structure=selem)
            segm_tree[j - 1] = SegmentationImage(new_segm)

    return SegmentationImage(watershed(-data, segm_tree[0].data,
                                       mask=segm_mask, connectivity=selem))