Skip to content

Commit

Permalink
Fix 480 (#484)
Browse files Browse the repository at this point in the history
* a tad more precise

* catch errors early if deblended catalog is too large, to avoid unhelpful errors for users
  • Loading branch information
ismael-mendoza committed Apr 2, 2024
1 parent 85b4a23 commit 3eebcbd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
6 changes: 3 additions & 3 deletions btk/blend_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,12 @@ def __post_init__(self) -> None:
def _validate_catalog(self, catalog: Table):
if not ("ra" in catalog.colnames and "dec" in catalog.colnames):
raise ValueError(
"The output catalog of at least one of your measurement functions does"
"The output catalog of at least one of your deblenders does"
"not contain the mandatory 'ra' and 'dec' columns"
)
if not len(catalog) <= self.max_n_sources:
raise ValueError(
"The predicted catalog of at least one of your deblended images "
"The detection catalog of at least one of your deblended images "
"contains more sources than the maximum number of sources specified."
)
return catalog
Expand Down Expand Up @@ -320,7 +320,7 @@ def _validate_catalog(self, catalog_list: List[Table]):
)
if not len(catalog) <= self.max_n_sources:
raise ValueError(
"The predicted catalog of at least one of your deblended images "
"The detections catalog of at least one of your deblended images "
"contains more sources than the maximum number of sources specified."
)
return catalog_list
Expand Down
48 changes: 33 additions & 15 deletions btk/deblend.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class PeakLocalMax(Deblender):
"""This class detects centroids with `skimage.feature.peak_local_max`.
The function performs detection and deblending of the sources based on the provided
band index. If use_mean feature is used, then the measurement function is using
band index. If use_mean feature is used, then the Deblender will use
the average of all the bands.
"""

Expand All @@ -197,14 +197,14 @@ def __init__(
use_mean: bool = False,
use_band: Optional[int] = None,
) -> None:
"""Initializes measurement class. Exactly one of 'use_mean' or 'use_band' must be specified.
"""Initializes Deblender class. Exactly one of 'use_mean' or 'use_band' must be specified.
Args:
max_n_sources: See parent class.
threshold_scale: Minimum intensity of peaks.
min_distance: Minimum distance in pixels between two peaks.
use_mean: Flag to use the band average for the measurement.
use_band: Integer index of the band to use for the measurement.
use_mean: Flag to use the band average for deblending.
use_band: Integer index of the band to use for deblending
"""
super().__init__(max_n_sources)
self.min_distance = min_distance
Expand All @@ -218,7 +218,7 @@ def __init__(
self.use_band = use_band

def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
"""Performs measurement on the ii-th example from the batch."""
"""Performs deblending on the ii-th example from the batch."""
blend_image = blend_batch.blend_images[ii]
image = np.mean(blend_image, axis=0) if self.use_mean else blend_image[self.use_band]

Expand All @@ -240,6 +240,12 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
catalog["ra"], catalog["dec"] = ra, dec
catalog["x_peak"], catalog["y_peak"] = x, y

if len(catalog) > self.max_n_sources:
raise ValueError(
"`PeakLocalMax` detected more sources than `max_n_sources`. Consider increasing"
"`threshold_scale` or `max_n_sources`."
)

return DeblendExample(self.max_n_sources, catalog)


Expand All @@ -261,7 +267,7 @@ def __init__(
use_mean: bool = False,
use_band: Optional[int] = None,
) -> None:
"""Initializes measurement class. Exactly one of 'use_mean' or 'use_band' must be specified.
"""Initializes Deblender class. Exactly one of 'use_mean' or 'use_band' must be specified.
Args:
max_n_sources: See parent class.
Expand All @@ -270,8 +276,8 @@ def __init__(
will be `thresh * err[j, i]` where `err` is set to the global rms of
the background measured by SEP.
min_area: Minimum number of pixels required for an object. Default is 5.
use_mean: Flag to use the band average for the measurement
use_band: Integer index of the band to use for the measurement
use_mean: Flag to use the band average for deblending.
use_band: Integer index of the band to use for deblending.
"""
super().__init__(max_n_sources)
if use_band is None and not use_mean:
Expand All @@ -284,7 +290,7 @@ def __init__(
self.min_area = min_area

def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
"""Performs measurement on the i-th example from the batch."""
"""Performs deblending on the i-th example from the batch."""
# get a 1-channel input for sep
blend_image = blend_batch.blend_images[ii]
image = np.mean(blend_image, axis=0) if self.use_mean else blend_image[self.use_band]
Expand All @@ -299,6 +305,12 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
minarea=self.min_area,
)

if len(catalog) > self.max_n_sources:
raise ValueError(
"SEP predicted more sources than `max_n_sources`. Consider increasing `thresh`"
" or `max_n_sources`."
)

segmentation_exp = np.zeros((self.max_n_sources, *image.shape), dtype=bool)
deblended_images = np.zeros((self.max_n_sources, *image.shape), dtype=image.dtype)
n_objects = len(catalog)
Expand Down Expand Up @@ -339,7 +351,7 @@ class SepMultiband(Deblender):
"""

def __init__(self, max_n_sources: int, matching_threshold: float = 1.0, thresh: float = 1.5):
"""Initialize the SepMultiband measurement function.
"""Initialize the SepMultiband Deblender.
Args:
max_n_sources: See parent class.
Expand All @@ -351,7 +363,7 @@ def __init__(self, max_n_sources: int, matching_threshold: float = 1.0, thresh:
self.thresh = thresh

def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
"""Performs measurement on the ii-th example from the batch."""
"""Performs deblending on the ii-th example from the batch."""
# run source extractor on the first band
wcs = blend_batch.wcs
image = blend_batch.blend_images[ii]
Expand All @@ -361,6 +373,12 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
ra_coordinates *= 3600
dec_coordinates *= 3600

if len(catalog) > self.max_n_sources:
raise ValueError(
"SEP predicted more sources than `max_n_sources`. Consider increasing `thresh`"
" or `max_n_sources`."
)

# iterate over remaining bands and match predictions using KdTree
for band in range(1, image.shape[0]):
# run source extractor
Expand Down Expand Up @@ -449,7 +467,7 @@ def __init__(
def deblend(
self, ii: int, blend_batch: BlendBatch, reference_catalogs: Table = None
) -> DeblendExample:
"""Performs measurement on the ii-th example from the batch.
"""Performs deblending on the ii-th example from the batch.
Args:
ii: The index of the example in the batch.
Expand Down Expand Up @@ -556,14 +574,14 @@ def __init__(
njobs: int = 1,
verbose: bool = False,
):
"""Initialize measurement generator.
"""Initialize deblender generator.
Args:
deblenders: Deblender or a list of Deblender that will be used on the
outputs of the draw_blend_generator.
draw_blend_generator: Instance of subclasses of `DrawBlendsGenerator`.
njobs: The number of parallel processes to run [Default: 1].
verbose: Whether to print information about measurement.
verbose: Whether to print information about deblending.
"""
self.deblenders = self._validate_deblenders(deblenders)
self.deblender_names = self._get_unique_deblender_names()
Expand Down Expand Up @@ -615,7 +633,7 @@ def _get_unique_deblender_names(self) -> List[str]:
return deblender_names

def __next__(self) -> Tuple[BlendBatch, Dict[str, DeblendBatch]]:
"""Return measurement results on a single batch from the draw_blend_generator.
"""Return deblending results on a single batch from the draw_blend_generator.
Returns:
blend_batch: draw_blend_generator output from its `__next__` method.
Expand Down

0 comments on commit 3eebcbd

Please sign in to comment.