Skip to content

Commit

Permalink
remaining tutorial notebooks (#492)
Browse files Browse the repository at this point in the history
* BUG the psf_func argument was not being used as expected, it was being called once when the Survey object was created, rather than at each batch inside the DrawBlendsGenerator

* type annotation eror

* spacing

* draft generation tutorial notebook

* draft /outline notebooks

* SepMultiband -> SepMultiBand typo fixed

* how to use cosmos galaxies added

* remove whitespace

* docstring fixed

* outline

* will add later (custom sampling)

* deblending notebook skeleton

* nbstripout

* actually finished this tutorial

* BUG was using exactly the same array for iou by mistake

* REF delete trailing whitespace

* BUG matched arrays are actually 0 by default, so it's safer to check if they are non-zero for segmentation

* still working on metrics notebook

* advanced deblending notebook added

* fix a few typos

* typos

* missing example

* image changed

* wrap up first version of metrics tutotrial

---------

Co-authored-by: Andrii Torchylo <andrii.torchylo@gmail.com>
  • Loading branch information
ismael-mendoza and atorchylo committed May 2, 2024
1 parent 8939751 commit 98ed3e5
Show file tree
Hide file tree
Showing 11 changed files with 2,033 additions and 41 deletions.
14 changes: 7 additions & 7 deletions btk/deblend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@


class Deblender(ABC):
"""Abstract base class containing the measure class for BTK.
"""Abstract base class containing the deblender class for BTK.
Each new measure class should be a subclass of Measure.
Each new deblender class should be a subclass of Deblender.
"""

def __init__(self, max_n_sources: int) -> None:
Expand Down Expand Up @@ -71,7 +71,7 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
"""

def batch_call(self, blend_batch: BlendBatch, njobs: int = 1, **kwargs) -> DeblendBatch:
"""Implements the call of a measure function on the entire batch.
"""Implements the call of the deblender on the entire batch.
Overwrite this function if you perform measurments on the batch.
The default fucntionality is to use multiprocessing to speed up
Expand Down Expand Up @@ -160,7 +160,7 @@ def deblend(self, ii: int, mr_batch: MultiResolutionBlendBatch) -> DeblendExampl
"""

def batch_call(self, mr_batch: MultiResolutionBlendBatch, njobs: int = 1) -> DeblendBatch:
"""Implements the call of a measure function on the entire batch.
"""Implements the call of the deblender on the entire batch.
Overwrite this function if you perform measurments on a batch.
The default fucntionality is to use multiprocessing to speed up
Expand Down Expand Up @@ -348,7 +348,7 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample:
)


class SepMultiband(Deblender):
class SepMultiBand(Deblender):
"""This class returns centers detected with SEP by combining predictions in different bands.
For each band in the input image we run `sep` for detection and append new detections
Expand All @@ -359,7 +359,7 @@ class SepMultiband(Deblender):
"""

def __init__(self, max_n_sources: int, matching_threshold: float = 1.0, thresh: float = 1.5):
"""Initialize the SepMultiband Deblender.
"""Initialize the SepMultiBand Deblender.
Args:
max_n_sources: See parent class.
Expand Down Expand Up @@ -672,6 +672,6 @@ def __next__(self) -> Tuple[BlendBatch, Dict[str, DeblendBatch]]:
available_deblenders = {
"PeakLocalMax": PeakLocalMax,
"SepSingleBand": SepSingleBand,
"SepMultiBand": SepMultiband,
"SepMultiBand": SepMultiBand,
"Scarlet": Scarlet,
}
3 changes: 2 additions & 1 deletion btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ def _get_psf_from_survey(self, survey: Survey) -> List[galsim.GSObject]:
for band in survey.available_filters:
filt = survey.get_filter(band)
if callable(filt.psf):
generated_psf = filt.psf() # generate the PSF with the provided function
# generate the PSF with the provided function
generated_psf = filt.psf(survey, filt)
if isinstance(generated_psf, galsim.GSObject):
psf.append(generated_psf)
else:
Expand Down
10 changes: 5 additions & 5 deletions btk/metrics/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ def _get_data(
) -> Dict[str, np.ndarray]:
assert seg1.shape == seg2.shape
assert seg1.ndim == 4 # batch, max_n_sources, x, y
ious = np.zeros(self.batch_size)
ious = np.full((self.batch_size, seg1.shape[1]), fill_value=np.nan)
for ii in range(self.batch_size):
n_sources = np.sum(~np.isnan(seg1[ii].sum(axis=(-1, -2))))
n_sources1 = np.sum(np.sum(seg1[ii], axis=(-1, -2)) > 0)
n_sources2 = np.sum(np.sum(seg2[ii], axis=(-1, -2)) > 0)
n_sources = min(n_sources1, n_sources2)
if n_sources > 0:
seg1_ii = seg1[ii, :n_sources]
seg2_ii = seg2[ii, :n_sources]
ious[ii] = iou(seg1_ii, seg2_ii)
else:
ious[ii] = np.nan
ious[ii, :n_sources] = iou(seg1_ii, seg2_ii)

return {"iou": ious}

Expand Down
2 changes: 1 addition & 1 deletion btk/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def iou(seg1: np.ndarray, seg2: np.ndarray) -> np.ndarray:
"""
assert not np.any(np.isnan(seg1)) and not np.any(np.isnan(seg2))
seg1 = seg1.astype(bool)
seg2 = seg1.astype(bool)
seg2 = seg2.astype(bool)
i = np.logical_and(seg1, seg2).sum(axis=(-1, -2))
u = np.logical_or(seg1, seg2).sum(axis=(-1, -2))
return i / u
Expand Down
2 changes: 1 addition & 1 deletion btk/sampling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ class PairSampling(SamplingFunction):
def __init__(
self,
stamp_size: float = 24.0,
max_shift: float = Optional[None],
max_shift: Optional[float] = None,
mag_name: str = "i_ab",
seed: int = DEFAULT_SEED,
bright_cut: float = 25.3,
Expand Down
5 changes: 2 additions & 3 deletions btk/survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ def get_surveys(
for band in survey.available_filters:
filtr = survey.get_filter(band)
if psf_func is None:
psf = _get_default_psf_with_galcheat_info(survey, filtr)
filtr.psf = _get_default_psf_with_galcheat_info(survey, filtr)
else:
psf = psf_func(survey, filtr)
filtr.psf = psf
filtr.psf = psf_func
surveys.append(survey)

if len(surveys) == 1:
Expand Down
43 changes: 22 additions & 21 deletions notebooks/00-quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"from tqdm import tqdm "
"from tqdm import tqdm"
]
},
{
Expand Down Expand Up @@ -147,7 +147,7 @@
"seed = 0 # random seed for reproducibility purposes # seed = 8\n",
"sampling_function = btk.sampling_functions.DefaultSampling(\n",
" max_number=max_number, min_number=max_number, # always get `max_number` galaxies\n",
" stamp_size=stamp_size, max_shift=max_shift, \n",
" stamp_size=stamp_size, max_shift=max_shift,\n",
" min_mag = 24, max_mag = 25,\n",
" seed = seed)"
]
Expand Down Expand Up @@ -314,7 +314,7 @@
}
],
"source": [
"blend_batch.blend_images.shape \n",
"blend_batch.blend_images.shape\n",
"# shape = (batch_size, n_bands, stamp_size, stamp_size)"
]
},
Expand Down Expand Up @@ -359,7 +359,7 @@
}
],
"source": [
"blend_batch.catalog_list[0] \n",
"blend_batch.catalog_list[0]\n",
"# blend_list is a list of astropy tables, one for each blend in the batch."
]
},
Expand Down Expand Up @@ -437,7 +437,7 @@
"plt.imshow(blend_batch.blend_images[0, 2, :, :], cmap=\"gray\")\n",
"\n",
"# plot centers\n",
"plt.scatter(blend_batch.catalog_list[0][\"x_peak\"], \n",
"plt.scatter(blend_batch.catalog_list[0][\"x_peak\"],\n",
" blend_batch.catalog_list[0][\"y_peak\"], c=\"r\", marker=\"x\")"
]
},
Expand Down Expand Up @@ -478,7 +478,7 @@
"from astropy.visualization import make_lupton_rgb\n",
"\n",
"im = blend_batch.blend_images[0]\n",
"bands = [1, 2, 3] # g, r, i \n",
"bands = [1, 2, 3] # g, r, i\n",
"stretch = np.max(im) - np.min(im)\n",
"Q = 0.1\n",
"\n",
Expand Down Expand Up @@ -750,11 +750,10 @@
"plt.imshow(blend_batch.blend_images[0, 2, :, :], cmap=\"gray\")\n",
"\n",
"# plot centers of truth\n",
"plt.scatter(blend_batch.catalog_list[0][\"x_peak\"], blend_batch.catalog_list[0][\"y_peak\"], c=\"r\", marker=\"x\")\n",
"plt.scatter(blend_batch.catalog_list[0][\"x_peak\"],\n",
" blend_batch.catalog_list[0][\"y_peak\"], c=\"r\", marker=\"x\")\n",
"\n",
"# plot centers of prediction\n",
"\n",
"# need to use wcs to convert ra and dec to x and y\n",
"x, y = deblend_batch.catalog_list[0]['x_peak'], deblend_batch.catalog_list[0]['y_peak']\n",
"plt.scatter(x, y, c=\"b\", marker=\"+\", s=50)"
]
Expand All @@ -765,7 +764,7 @@
"source": [
"We can also inspect the deblended images and compare with the isolated ones. \n",
"\n",
"**Note:** Images do not necessarily line up in general as below, need to match (see next section)"
"**Note:** Images do not necessarily line up in general, need to match (see next section)"
]
},
{
Expand Down Expand Up @@ -795,7 +794,7 @@
}
],
"source": [
"# plot isolated images from truth \n",
"# plot isolated images from truth\n",
"fig, axes = plt.subplots(2, 3, figsize=(12, 8))\n",
"\n",
"clip = 40 # zoom into images\n",
Expand All @@ -809,19 +808,21 @@
"\n",
"for ii in range(3):\n",
" ax = axes.flatten()[ii]\n",
" ax.imshow(blend_batch.isolated_images[0, ii, 2, clip:-clip, clip:-clip], cmap=\"gray\", \n",
" ax.imshow(blend_batch.isolated_images[0, ii, 2, clip:-clip, clip:-clip], cmap=\"gray\",\n",
" vmin=vmin, vmax=vmax)\n",
"\n",
"# plot isolated images from prediction\n",
"for ii in range(3, 6):\n",
" ax = axes.flatten()[ii]\n",
" ax.imshow(deblend_batch.deblended_images[0, ii-3, 0, clip:-clip, clip:-clip], cmap=\"gray\", \n",
" ax.imshow(deblend_batch.deblended_images[0, ii-3, 0, clip:-clip, clip:-clip], cmap=\"gray\",\n",
" vmin=vmin, vmax=vmax)\n",
"\n",
"# add colorbar\n",
"fig.subplots_adjust(right=0.8)\n",
"cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n",
"fig.colorbar(axes[0, 0].imshow(blend_batch.isolated_images[0, 0, 2, clip:-clip, clip:-clip], cmap=\"gray\", vmin=vmin, vmax=vmax), cax=cbar_ax)"
"fig.colorbar(axes[0, 0].imshow(blend_batch.isolated_images[0, 0, 2, clip:-clip, clip:-clip],\n",
" cmap=\"gray\", vmin=vmin, vmax=vmax),\n",
" cax=cbar_ax)"
]
},
{
Expand Down Expand Up @@ -863,7 +864,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Before matching, we need to add pixel centroids to the catalogs of `deblend_batch`. As it currently only contains `ra` and `dec`. We can do it easily with the utility function `btk.utils.add_pixel_columns`:"
"Note that for the `SepSingleBand` deblender, the catalogs already contain an `x_peak` and `y_peak` columns. The `PixelHungarianMatcher` requires these columns, otherwise they would need to be added separately."
]
},
{
Expand Down Expand Up @@ -892,7 +893,7 @@
"metadata": {},
"outputs": [],
"source": [
"# matchers operate on `catalog_lists`, so we need to extract them from our batch classes. \n",
"# matchers operate on `catalog_lists`, so we need to extract them from our batch classes.\n",
"true_catalog_list = blend_batch.catalog_list\n",
"pred_catalog_list = deblend_batch.catalog_list\n",
"matching = matcher(true_catalog_list, pred_catalog_list) # matching object"
Expand Down Expand Up @@ -922,10 +923,10 @@
}
],
"source": [
"(matching.true_matches[0], # index of each true source that is matched with predicted \n",
"(matching.true_matches[0], # index of each true source that is matched with predicted\n",
" matching.pred_matches[0], #index of each pred source that is matched with truth\n",
" matching.n_true[0], # number of truth total\n",
" matching.n_pred[0], # number of predicted total \n",
" matching.n_pred[0], # number of predicted total\n",
")"
]
},
Expand Down Expand Up @@ -1175,7 +1176,7 @@
"source": [
"# for reconstruction and segmentation, need to match first and then apply metric\n",
"iso_images1 = blend_batch.isolated_images[:, :, 2] # only r-band\n",
"iso_images2 = deblend_batch.deblended_images[:, :, 0] \n",
"iso_images2 = deblend_batch.deblended_images[:, :, 0]\n",
"iso_images_matched1 = matching.match_true_arrays(iso_images1)\n",
"iso_images_matched2 = matching.match_pred_arrays(iso_images2)"
]
Expand Down Expand Up @@ -1312,7 +1313,7 @@
"ellips1 = get_ksb_ellipticity(iso_images_matched1, psf_r, pixel_scale=0.2)\n",
"\n",
"# NOTE: assumed deblended images are psf convolved with same psf\n",
"ellips2 = get_ksb_ellipticity(iso_images_matched2, psf_r, pixel_scale=0.2) \n",
"ellips2 = get_ksb_ellipticity(iso_images_matched2, psf_r, pixel_scale=0.2)\n",
"\n",
"\n",
"# mask nan's (non-matches), look at first component only\n",
Expand Down Expand Up @@ -1351,7 +1352,7 @@
}
],
"source": [
"# look at first component \n",
"# look at first component\n",
"plt.scatter(e11, e12)\n",
"plt.plot([-1, 1], [-1, 1])"
]
Expand Down

0 comments on commit 98ed3e5

Please sign in to comment.