Skip to content

Commit

Permalink
Last few tweaks to the quickstart & advanced-analysis tutorial (#477)
Browse files Browse the repository at this point in the history
* update poetry

* add colorblind cycle

* add authors and remove autoreload stuff

* working on tidying up plots and adding more useful comments

* fix aperture functions, fix docstrings

* update packages

* more ignore

* use sky level instead

* update git ignore with cache

* fix error in tp counts

* docstring

* author list

* wrapping up plots

* more gitignore

* will remove later

* latest poetry from main branch

* add docstring to peak_local_max

* update docstring

* more error catching

* exclude import errors in notebooks

* fix various errors in matching code

There was an issue where matches containing -1 indices could be added to the final catalog or array in the Matching object (the non-zero matches were not always at the front!). In the end I realized that keeping the -1 is not really useful, so I removed it altogether; the arrays of matched indices no longer contain any -1, just the true matches (accounting for distance).

* we fix this given the new convention

* new images after fixing matching

* recall figure with new matching procedure

* make blendedness robust to empty arrays from max_n_sources

* all figures redone and pushed

* remove cache as figures are finalized

* add authors

* typo

* correct notebook and figure with correct blendedness definition

* delete notebooks that will not make it for the release
  • Loading branch information
ismael-mendoza committed Apr 19, 2024
1 parent 180bf61 commit 115a0ae
Show file tree
Hide file tree
Showing 14 changed files with 494 additions and 797 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
.coverage
/dist/
/docs/build
/outputs
/data/cache
20 changes: 17 additions & 3 deletions btk/deblend.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class PeakLocalMax(Deblender):
def __init__(
self,
max_n_sources: int,
sky_level: float,
threshold_scale: int = 5,
min_distance: int = 2,
use_mean: bool = False,
Expand All @@ -201,14 +202,16 @@ def __init__(
Args:
max_n_sources: See parent class.
threshold_scale: Minimum intensity of peaks.
sky_level: Background intensity in images to be detected (assumed constant).
threshold_scale: Minimum number of sigmas above noise level for detections.
min_distance: Minimum distance in pixels between two peaks.
use_mean: Flag to use the band average for deblending.
use_band: Integer index of the band to use for deblending
"""
super().__init__(max_n_sources)
self.min_distance = min_distance
self.threshold_scale = threshold_scale
self.sky_level = sky_level

if use_band is None and not use_mean:
raise ValueError("Either set 'use_mean=True' OR indicate a 'use_band' index")
Expand All @@ -223,10 +226,15 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
image = np.mean(blend_image, axis=0) if self.use_mean else blend_image[self.use_band]

# compute threshold value
threshold = self.threshold_scale * np.std(image)
threshold = self.threshold_scale * np.sqrt(self.sky_level)

# calculate coordinates
coordinates = peak_local_max(image, min_distance=self.min_distance, threshold_abs=threshold)
coordinates = peak_local_max(
image,
min_distance=self.min_distance,
threshold_abs=threshold,
num_peaks=self.max_n_sources,
)
x, y = coordinates[:, 1], coordinates[:, 0]

# convert coordinates to ra, dec
Expand Down Expand Up @@ -388,6 +396,12 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
band_image, self.thresh, err=bkg.globalrms, segmentation_map=False
)

if len(catalog) > self.max_n_sources:
raise ValueError(
"SEP predicted more sources than `max_n_sources`. Consider increasing `thresh`"
" or `max_n_sources`."
)

# convert predictions to arcseconds
ra_detections, dec_detections = wcs.pixel_to_world_values(catalog["x"], catalog["y"])
ra_detections *= 3600
Expand Down
89 changes: 63 additions & 26 deletions btk/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ def __init__(
"""Initialize MatchInfo.
Args:
true_matches: a list of 1D array, each entry corresponds to a numpy array
containing the index of detected object in the truth catalog that
got matched with the i-th truth object in the blend.
pred_matches: a list of 1D array, where the j-th entry of i-th array
corresponds to the index of truth object in the i-th blend
that got matched with the j-th detected object in that blend.
If no match, value is -1.
true_matches: a list of 1D arrays, each array containing the matching indices of
truth catalog rows for a given element of a batch. These indices are in 1-1
correspondence to the `pred_matches` indices.
pred_matches: a list of 1D arrays, each array containing the matching indices of
predicted catalog rows for a given element of a batch. These indices are in 1-1
correspondence to the `true_matches` indices.
n_true: a 1D array of length N, where each entry is the number of truth objects.
n_pred: a 1D array of length N, where each entry is the number of detected objects.
"""
Expand Down Expand Up @@ -66,12 +65,11 @@ def _match_arrays(self, *arrs: np.ndarray, true_or_pred: str) -> tuple:
new_arrs = []
for arr in arrs:
assert len(arr) == self.batch_size
new_arr = np.zeros_like(arr)
new_arr = np.zeros((self.batch_size, self.max_n_sources, *arr.shape[2:]))
for ii in range(self.batch_size):
n_sources = len(matches[ii])
assert n_sources <= self.max_n_sources
new_arr[ii, :n_sources] = arr[ii][matches[ii]]
new_arrs.append(new_arr[:, : self.max_n_sources])
for jj, m in enumerate(matches[ii]):
new_arr[ii, jj] = arr[ii, m]
new_arrs.append(new_arr)
return tuple(new_arrs) if len(new_arrs) > 1 else new_arrs[0]

def match_true_catalogs(self, catalog_list: Table) -> List[Table]:
Expand Down Expand Up @@ -115,6 +113,31 @@ def filter_by_true(self, mask: List[np.ndarray]) -> "Matching":
new_true_matches, new_pred_matches, np.array(new_n_true), np.array(new_n_pred)
)

def filter_by_pred(self, mask: List[np.ndarray]) -> "Matching":
    """Build a new Matching that keeps only the detections selected by `mask`.

    Args:
        mask: One array per batch element, indexed by predicted-catalog row
            (assumed boolean, since it is negated with `~` and summed to count
            the surviving detections). Truthy entries are kept.

    Returns:
        A new `Matching` in which every matched pair whose predicted index is
        masked out has been dropped. `n_true` is carried over unchanged and
        `n_pred` becomes the number of entries of each mask that pass.
    """
    kept_true, kept_pred = [], []
    truth_counts, pred_counts = [], []

    for ii in range(self.batch_size):
        keep = mask[ii]
        matched_true = self.true_matches[ii]
        matched_pred = self.pred_matches[ii]

        # rows of the predicted catalog that are rejected by the mask
        rejected_rows = np.nonzero(~keep)[0]

        # positions inside the match arrays that point at a rejected row
        drop = np.flatnonzero(np.isin(matched_pred, rejected_rows))

        # both arrays are in 1-1 correspondence, so drop the same positions
        kept_true.append(np.delete(matched_true, drop, axis=0))
        kept_pred.append(np.delete(matched_pred, drop, axis=0))

        truth_counts.append(self.n_true[ii])
        pred_counts.append(np.sum(keep))

    return Matching(kept_true, kept_pred, np.array(truth_counts), np.array(pred_counts))

@property
def tp(self) -> np.ndarray:
"""Returns true positive array."""
Expand Down Expand Up @@ -154,6 +177,10 @@ def __call__(self, true_catalog_list: List[Table], pred_catalog_list: List[Table
true_match, pred_match = np.array([]).astype(int), np.array([]).astype(int)
else:
true_match, pred_match = self.match_catalogs(true_catalog, pred_catalog)

if -1 in true_match or -1 in pred_match:
raise ValueError("Matcher should return only matching indices, not dummy -1 index.")

match_true.append(true_match)
match_pred.append(pred_match)
n_true.append(len(true_catalog))
Expand Down Expand Up @@ -182,6 +209,11 @@ class IdentityMatcher(Matcher):

def match_catalogs(self, truth_catalog, predicted_catalog) -> tuple:
    """Return the trivial identity matching between two equally sized catalogs.

    Args:
        truth_catalog: Truth catalog; only its length is used.
        predicted_catalog: Predicted catalog; only its length is used.

    Returns:
        Tuple `(true_indx, pred_indx)` of integer arrays, both equal to
        `0, 1, ..., n - 1`, where `n` is the common length of the catalogs.

    Raises:
        ValueError: If the two catalogs do not have the same length.
    """
    n_true, n_pred = len(truth_catalog), len(predicted_catalog)
    if n_true != n_pred:
        raise ValueError(
            "IdentityMatcher cannot be used because the given pair of truth "
            "and predicted catalogs do not have the same size."
        )
    # `np.arange` keeps an integer dtype even for empty catalogs, whereas
    # `np.array(range(0))` falls back to float64, which is not a valid index dtype.
    true_indx = np.arange(n_true)
    pred_indx = np.arange(n_pred)
    return true_indx, pred_indx
Expand All @@ -208,25 +240,28 @@ def match_catalogs(self, truth_catalog: Table, predicted_catalog: Table) -> np.n
https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
Args:
truth_catalog: truth catalog containing relevant detecion information
predicted_catalog: predicted catalog to compare with the ground truth
truth_catalog: truth catalog containing true source centroid information.
predicted_catalog: predicted catalog containing detection information.
Returns:
matched_indx: a 1D array where j-th entry is the index of the target row
that matched with the j-th detected row. If no match, value is -1.
A tuple of two arrays, each has length equal to the total number of matches. The i-th
index in the first array is the row of the truth catalog that was matched with the row
of the predicted catalog corresponding to the i-th index of the second array. The index
in each array.
"""
dist = self.compute_distance_matrix(truth_catalog, predicted_catalog)
# solve optimization problem using Hungarian matching algorithm
# truth_catalog[true_indx[i]] is matched with predicted_catalog[matched_indx[i]]
# len(true_indx) = len(detect_indx) = min(len(true_table), len(detected_table))
# len(true_indx) = len(pred_indx) = min(len(true_table), len(pred_table))
true_indx, pred_indx = linear_sum_assignment(dist)

# if the distance is greater than max_sep then mark detection as -1
true_mask = dist.T[pred_indx, true_indx] > self.max_sep
true_indx[true_mask] = -1
pred_mask = dist[true_indx, pred_indx] > self.max_sep
pred_indx[pred_mask] = -1
# if the distance is greater than `max_sep` then remove the matching.
true_mask = dist.T[pred_indx, true_indx] < self.max_sep
true_match = true_indx[true_mask]
pred_mask = dist[true_indx, pred_indx] < self.max_sep
pred_match = pred_indx[pred_mask]

return true_indx, pred_indx
return true_match, pred_match


class SkyClosestNeighbourMatcher(Matcher):
Expand Down Expand Up @@ -284,7 +319,8 @@ def match_catalogs(self, truth_catalog: Table, predicted_catalog: Table) -> np.n
pred_indx[match_id] = target_idx

# if the matched distance exceeds max_sep, we discard that detection
pred_indx[d2d.to(units.arcsec) > self.max_sep * units.arcsec] = -1
pred_mask = d2d.to(units.arcsec) < self.max_sep * units.arcsec
pred_match = pred_indx[pred_mask]

# now for true indices
idx, d2d, _ = true_coordinates.match_to_catalog_sky(pred_coordinates)
Expand All @@ -295,9 +331,10 @@ def match_catalogs(self, truth_catalog: Table, predicted_catalog: Table) -> np.n
match_id = np.argmin(masked_d2d)
true_indx[match_id] = target_idx

true_indx[d2d.to(units.arcsec) > self.max_sep * units.arcsec] = -1
true_mask = d2d.to(units.arcsec) < self.max_sep * units.arcsec
true_match = true_indx[true_mask]

return true_indx, pred_indx
return true_match, pred_match


def pixel_l2_distance_matrix(
Expand Down
48 changes: 34 additions & 14 deletions btk/measure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Module for measuring galaxy properties."""
"""Module for measuring galaxy properties from images."""

from typing import Tuple

Expand Down Expand Up @@ -56,36 +56,42 @@ def get_ksb_ellipticity(
return ellipticities


def get_blendedness(iso_image: np.ndarray):
def get_blendedness(iso_image: np.ndarray) -> np.ndarray:
"""Calculate blendedness given isolated images of each galaxy in a blend.
Args:
iso_image: Array of shape = (..., N, H, W) corresponding to images of the isolated
galaxy you are calculating blendedness for.
iso_image: Array of shape = (..., N, H, W) corresponding to images of isolated
galaxies you are calculating blendedness for.
Returns:
Array of size (..., N) corresponding to blendedness values for each individual galaxy.
"""
assert iso_image.ndim >= 3
num = np.sum(iso_image * iso_image, axis=(-1, -2))
blend = np.sum(iso_image, axis=-3)[..., None, :, :]
denom = np.sum(blend * iso_image, axis=(-1, -2))
return 1 - num / denom
return 1 - np.divide(num, denom, out=np.zeros_like(num), where=(num != 0))


def get_snr(iso_image: np.ndarray, sky_level: float) -> float:
def get_snr(iso_image: np.ndarray, sky_level: float) -> np.ndarray:
"""Calculate SNR of a set of isolated galaxies with same sky level.
Args:
iso_image: Array of shape = (..., H, W) corresponding to image of the isolated
galaxy you are calculating SNR for.
sky_level: Background level of all images. Images are assumed to be
background-subtracted.
Returns:
Array of size (...) corresponding to SNR values for each individual galaxy.
"""
images = iso_image + sky_level
return np.sqrt(np.sum(iso_image * iso_image / images, axis=(-1, -2)))


def _get_single_aperture_flux(
image: np.ndarray, x: np.ndarray, y: np.ndarray, radius: float, sky_level: float
) -> np.ndarray:
) -> Tuple[np.ndarray, np.ndarray]:
"""Utility function to measure flux using fixed circular aperture with sep.
Args:
Expand All @@ -95,15 +101,19 @@ def _get_single_aperture_flux(
sky_level (float): Background level of all images.
Images are assumed to be background subtracted.
radius (float): Radius of the aperture in pixels.
Returns:
Tuple of flux and fluxerr.
"""
assert image.ndim == 2
flux, _, _ = sep.sum_circle(image, x, y, radius, err=sky_level)
return flux[0]
assert x.ndim == 1 and y.ndim == 1
flux, fluxerr, _ = sep.sum_circle(image, x, y, radius, var=sky_level)
return flux, fluxerr


def get_aperture_fluxes(
images: np.ndarray, xs: np.ndarray, ys: np.ndarray, radius: float, sky_level: float
) -> np.ndarray:
) -> Tuple[np.ndarray, np.ndarray]:
"""Utility function to measure flux using fixed circular aperture with sep.
Args:
Expand All @@ -113,13 +123,23 @@ def get_aperture_fluxes(
sky_level (float): Background level of all images.
Images are assumed to be background subtracted.
radius (float): Radius of the aperture in pixels.
Returns:
fluxes (np.array): Array of shape (B, N) corresponding to the measured aperture fluxes
in each given position for each of the B batches.
fluxerr (np.array): Array of same shape with corresponding flux errors.
"""
assert images.ndim == 3
batch_size = images.shape[0]
fluxes = np.zeros((batch_size, len(xs)))
assert xs.ndim == 2 and ys.ndim == 2
batch_size, max_n_sources = xs.shape
fluxes = np.zeros((batch_size, max_n_sources))
fluxerrs = np.zeros((batch_size, max_n_sources))
for ii in range(batch_size):
fluxes[ii] = _get_single_aperture_flux(images[ii], xs[ii], ys[ii], radius, sky_level)
return fluxes
n_sources = np.sum((xs[ii] > 0) & (ys[ii] > 0)).astype(int)
flux, err = _get_single_aperture_flux(images[ii], xs[ii], ys[ii], radius, sky_level)
fluxes[ii, :n_sources] = flux[:n_sources]
fluxerrs[ii, :n_sources] = err[:n_sources]
return fluxes, fluxerrs


def get_residual_images(iso_images: np.ndarray, blend_images: np.ndarray) -> np.ndarray:
Expand Down
12 changes: 12 additions & 0 deletions btk/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@

import numpy as np

# Colorblind-friendly color cycle, given as hex RGB strings.
CB_color_cycle = [
"#377eb8",
"#ff7f00",
"#4daf4a",
"#f781bf",
"#a65628",
"#984ea3",
"#999999",
"#e41a1c",
"#dede00",
]


def get_rgb(image: np.ndarray, min_val: Optional[float] = None, max_val: Optional[float] = None):
"""Function to normalize 3 band input image to RGB 0-255 image.
Expand Down
7 changes: 2 additions & 5 deletions notebooks/00-quickstart.ipynb
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
"*Authors:* Ismael Mendoza, Andrii Torchylo, Thomas Sainrat"
]
},
{
Expand Down

0 comments on commit 115a0ae

Please sign in to comment.