Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions gallery/tutorials/pipeline_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,16 @@
# We apply ``phase_flip()`` to correct for CTF effects.

src = src.phase_flip()
src.images[0:10].show()

# %%
# Cache
# -----
# We apply ``cache`` to store the results of the ``ImageSource``
# pipeline up to this point. This is optional, but can provide
# benefit when used intently on machines with adequate memory.

src = src.cache()
src.images[0:10].show()

# %%
# Class Averaging
Expand Down Expand Up @@ -172,8 +180,14 @@
classifier=rir,
)

# We'll continue our pipeline with the first ``n_classes`` from ``avgs``.
avgs = avgs[:n_classes]
# We'll continue our pipeline using only the first ``n_classes`` from
# ``avgs``. The ``cache()`` call is used here to precompute results
# for the ``:n_classes`` slice. This avoids recomputing the same
# images twice when peeking in the next cell then requesting them in
# the following ``CLSyncVoting`` algorithm. Outside of demonstration
# purposes, where we are repeatedly peeking at various stage results,
# such caching can be dropped allowing for more lazy evaluation.
avgs = avgs[:n_classes].cache()


# %%
Expand Down
2 changes: 1 addition & 1 deletion src/aspire/denoising/class_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _images(self, indices):
# Check if this src cached images.
if self._cached_im is not None:
logger.debug(f"Loading {len(indices)} images from image cache")
im = Image(self._cached_im[indices, :, :])
im = self._cached_im[indices, :, :]

# Check for heap cached image sets from class_selector.
elif heap_inds:
Expand Down
20 changes: 12 additions & 8 deletions src/aspire/source/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,14 +1498,18 @@ def _images(self, indices):
:param indices: A 1-D NumPy array of indices.
:return: An `Image` object.
"""
mapped_indices = self.index_map[indices]
# Load previous source image data and apply any transforms
# belonging to this IndexedSource. Note the previous source
# requires remapped indices, while the current source uses the
# `indices` arg directly.
return self.generation_pipeline.forward(
self.src.images[mapped_indices], indices
)

if self._cached_im is not None:
im = self._cached_im[indices]
else:
mapped_indices = self.index_map[indices]
# Load previous source image data and apply any transforms
# belonging to this IndexedSource. Note the previous source
# requires remapped indices, while the current source uses the
# `indices` arg directly.
im = self.src.images[mapped_indices]

return self.generation_pipeline.forward(im, indices)

def __repr__(self):
return f"{self.__class__.__name__} mapping {self.n} of {self.src.n} indices from {self.src.__class__.__name__}."
Expand Down