From cf14ea5c78e48cd3af6c8bb3961a701a64abc70e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 1 Mar 2024 09:49:32 -0500 Subject: [PATCH 1/4] Use _cached_im when available in IndexedSource --- src/aspire/source/image.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 4f4a499a0a..afaf26fc23 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1498,14 +1498,18 @@ def _images(self, indices): :param indices: A 1-D NumPy array of indices. :return: An `Image` object. """ - mapped_indices = self.index_map[indices] - # Load previous source image data and apply any transforms - # belonging to this IndexedSource. Note the previous source - # requires remapped indices, while the current source uses the - # `indices` arg directly. - return self.generation_pipeline.forward( - self.src.images[mapped_indices], indices - ) + + if self._cached_im is not None: + im = self._cached_im[indices] + else: + mapped_indices = self.index_map[indices] + # Load previous source image data and apply any transforms + # belonging to this IndexedSource. Note the previous source + # requires remapped indices, while the current source uses the + # `indices` arg directly. + im = self.src.images[mapped_indices] + + return self.generation_pipeline.forward(im, indices) def __repr__(self): return f"{self.__class__.__name__} mapping {self.n} of {self.src.n} indices from {self.src.__class__.__name__}." From 9832cd06e3b730c5391dd270e9fcaebfb98d8f71 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 1 Mar 2024 10:31:23 -0500 Subject: [PATCH 2/4] Fix regression bug in class_avg caching, no need to call Image --- src/aspire/denoising/class_avg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index c9d3f7dade..586e0a08a9 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -284,7 +284,7 @@ def _images(self, indices): # Check if this src cached images. if self._cached_im is not None: logger.debug(f"Loading {len(indices)} images from image cache") - im = Image(self._cached_im[indices, :, :]) + im = self._cached_im[indices, :, :] # Check for heap cached image sets from class_selector. elif heap_inds: From 1ea2a47c7e68263d95cd35c833d27001785eb8a0 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 1 Mar 2024 11:15:00 -0500 Subject: [PATCH 3/4] pipeline updates --- gallery/tutorials/pipeline_demo.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 31146317bb..f75f28b08a 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -132,8 +132,15 @@ # We apply ``phase_flip()`` to correct for CTF effects. src = src.phase_flip() -src.images[0:10].show() +# %% +# CTF Correction +# -------------- +# We apply ``cache`` to store the results of the ``ImageSource`` +# pipeline up to this point. + +src = src.cache() +src.images[0:10].show() # %% # Class Averaging @@ -172,8 +179,14 @@ classifier=rir, ) -# We'll continue our pipeline with the first ``n_classes`` from ``avgs``. -avgs = avgs[:n_classes] +# We'll continue our pipeline using only the first ``n_classes`` from +# ``avgs``. The ``cache()`` call is used here to precompute results +# for the ``:n_classes`` slice. This avoids recomputing the same +# images twice when peeking in the next cell then requesting them in +# the following ``CLSyncVoting`` algorithm. Outside of demonstration +# purposes, where we are repeatedly peeking at various stage results, +# such caching can be dropped allowing for more lazy evaluation. +avgs = avgs[:n_classes].cache() # %% From 193b8a9bf01d94267a27746f2397a3ba7ae0458f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 4 Mar 2024 09:31:50 -0500 Subject: [PATCH 4/4] Fixup pipeline cell --- gallery/tutorials/pipeline_demo.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index f75f28b08a..8910436de2 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -134,10 +134,11 @@ src = src.phase_flip() # %% -# CTF Correction -# -------------- +# Cache +# ----- # We apply ``cache`` to store the results of the ``ImageSource`` -# pipeline up to this point. +# pipeline up to this point. This is optional, but can provide +# benefit when used intently on machines with adequate memory. src = src.cache() src.images[0:10].show()