diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 31146317bb..8910436de2 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -132,8 +132,16 @@ # We apply ``phase_flip()`` to correct for CTF effects. src = src.phase_flip() -src.images[0:10].show() +# %% +# Cache +# ----- +# We apply ``cache`` to store the results of the ``ImageSource`` +# pipeline up to this point. This is optional, but can provide +# benefit when used intently on machines with adequate memory. + +src = src.cache() +src.images[0:10].show() # %% # Class Averaging @@ -172,8 +180,14 @@ classifier=rir, ) -# We'll continue our pipeline with the first ``n_classes`` from ``avgs``. -avgs = avgs[:n_classes] +# We'll continue our pipeline using only the first ``n_classes`` from +# ``avgs``. The ``cache()`` call is used here to precompute results +# for the ``:n_classes`` slice. This avoids recomputing the same +# images twice when peeking in the next cell then requesting them in +# the following ``CLSyncVoting`` algorithm. Outside of demonstration +# purposes, where we are repeatedly peeking at various stage results, +# such caching can be dropped allowing for more lazy evaluation. +avgs = avgs[:n_classes].cache() # %% diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index c9d3f7dade..586e0a08a9 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -284,7 +284,7 @@ def _images(self, indices): # Check if this src cached images. if self._cached_im is not None: logger.debug(f"Loading {len(indices)} images from image cache") - im = Image(self._cached_im[indices, :, :]) + im = self._cached_im[indices, :, :] # Check for heap cached image sets from class_selector. elif heap_inds: diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 4f4a499a0a..afaf26fc23 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1498,14 +1498,18 @@ def _images(self, indices): :param indices: A 1-D NumPy array of indices. :return: An `Image` object. """ - mapped_indices = self.index_map[indices] - # Load previous source image data and apply any transforms - # belonging to this IndexedSource. Note the previous source - # requires remapped indices, while the current source uses the - # `indices` arg directly. - return self.generation_pipeline.forward( - self.src.images[mapped_indices], indices - ) + + if self._cached_im is not None: + im = self._cached_im[indices] + else: + mapped_indices = self.index_map[indices] + # Load previous source image data and apply any transforms + # belonging to this IndexedSource. Note the previous source + # requires remapped indices, while the current source uses the + # `indices` arg directly. + im = self.src.images[mapped_indices] + + return self.generation_pipeline.forward(im, indices) def __repr__(self): return f"{self.__class__.__name__} mapping {self.n} of {self.src.n} indices from {self.src.__class__.__name__}."