diff --git a/src/aspire/source/__init__.py b/src/aspire/source/__init__.py index db3e955e78..55afda4833 100644 --- a/src/aspire/source/__init__.py +++ b/src/aspire/source/__init__.py @@ -1,7 +1,7 @@ import logging from aspire.source.coordinates import BoxesCoordinateSource, CentersCoordinateSource -from aspire.source.image import ArrayImageSource, ImageSource +from aspire.source.image import ArrayImageSource, ImageSource, IndexedSource from aspire.source.relion import RelionSource from aspire.source.simulation import Simulation diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 6a176c688e..cf5c5561a8 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -162,6 +162,22 @@ def __init__(self, L, n, dtype="double", metadata=None, memory=None): logger.info(f"Creating {self.__class__.__name__} with {len(self)} images.") + def __getitem__(self, indices): + """ + Check `indices` and return slice of current Source as a new + Source. + + Internally uses `IndexedSource`. + + :param indices: Requested indices as a Python slice object, + 1-D NumPy array, list, or a single integer. Slices default + to a start of 0, an end of self.num_imgs, and a step of 1. + See _ImageAccessor. + :return: Source composed of the images and metadata at `indices`. + """ + + return IndexedSource(self, indices) + @property def n_ctf_filters(self): """ @@ -311,17 +327,29 @@ def has_metadata(self, metadata_fields): metadata_fields = [metadata_fields] return all(f in self._metadata.columns for f in metadata_fields) - def get_metadata(self, metadata_fields, indices=None, default_value=None): - """ - Get metadata field information of this ImageSource for selected indices + def get_metadata(self, metadata_fields=None, indices=None, default_value=None): + """ + Get metadata field information of this ImageSource for a + selection of fields of indices. The default should return the + entire metadata table. + + :param metadata_fields: A string, or list of strings, + representing the metadata field(s) to be queried. + Defaults to None, which yields all populated columns. + :param indices: A list of 0-based indices indicating the + indices for which to get metadata. If indices is None, + then values corresponding to all indices in this Source + object are returned. + :param default_value: Default scalar value to use for any + fields not found in the metadata. If None, no default + value is used, and missing field(s) cause a RuntimeError. + :return: An ndarray of values (any valid np types) + representing metadata info. + """ + # When metadata_fields=None, default to returning all. + if metadata_fields is None: + metadata_fields = self._metadata.columns - :param metadata_fields: A string, of list of strings, representing the metadata field(s) to be queried. - :param indices: A list of 0-based indices indicating the indices for which to get metadata. - If indices is None, then values corresponding to all indices in this Source object are returned. - :param default_value: Default scalar value to use for any fields not found in the metadata. If None, - no default value is used, and missing field(s) cause a RuntimeError. - :return: An ndarray of values (any valid np types) representing metadata info. - """ if isinstance(metadata_fields, str): metadata_fields = [metadata_fields] if indices is None: @@ -803,6 +831,63 @@ def save_images( im.save(mrcs_filepath, overwrite=overwrite) +class IndexedSource(ImageSource): + """ + Map into another into ImageSource. + """ + + def __init__(self, src, indices, memory=None): + """ + Instantiates a new source along given `indices`. + + :param src: ImageSource to be used as the source. + :param index_map: index_map + :param memory: str or None + The path of the base directory to use as a data store or + None. If None is given, no caching is performed. + """ + + self.src = src + if not isinstance(src, ImageSource): + raise TypeError(f"Input src {src} must be an ImageSource.") + + # `_ImageAccessor` performs checking and slicing logic. + # `index_map` sequence forms a natural map from the "new" source -> "self". + # Example, if request=slice(500,1000), + # then new_src[0] ~> old_src[500]; index_map[0] = 500. + self.index_map = _ImageAccessor(lambda x: x, src.n)[indices] + + # Get all the metadata associated with these indices. + # Note, I would have prefered to use our API (get_metadata) + # here, but it returns a Numpy array, which would need to be + # converted back into Pandas for use below. So here we'll just + # use `loc` to return a dataframe. + metadata = self.src._metadata.loc[self.index_map] + + # Construct a fully formed ImageSource with this metadata + super().__init__( + L=src.L, + n=len(self.index_map), + dtype=src.dtype, + metadata=metadata, + memory=memory, + ) + + def _images(self, indices): + """ + Returns images from `self.src` corresponding to `indices` + remapped by `self.index_map`. + + :param indices: A 1-D NumPy array of indices. + :return: An `Image` object. + """ + mapped_indices = self.index_map[indices] + return self.src.images[mapped_indices] + + def __repr__(self): + return f"{self.__class__.__name__} mapping {self.n} of {self.src.n} indices from {self.src.__class__.__name__}." + + class ArrayImageSource(ImageSource): """ An `ImageSource` object that holds a reference to an underlying `Image` object (a thin wrapper on an ndarray) diff --git a/tests/test_indexed_source.py b/tests/test_indexed_source.py new file mode 100644 index 0000000000..30a23ee16b --- /dev/null +++ b/tests/test_indexed_source.py @@ -0,0 +1,44 @@ +import logging + +import numpy as np +import pytest + +from aspire.source import Simulation + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def sim_fixture(): + """ + Generate a very small simulation and slice it. + """ + sim = Simulation(L=8, n=10, C=1) + sim2 = sim[0::2] # Slice the evens + return sim, sim2 + + +def test_remapping(sim_fixture): + sim, sim2 = sim_fixture + + # Check images are served correctly, using internal index. + assert np.allclose(sim.images[sim2.index_map].asnumpy(), sim2.images[:].asnumpy()) + + # Check images are served correctly, using known index (evens). + index = list(range(0, sim.n, 2)) + assert np.allclose(sim.images[index].asnumpy(), sim2.images[:].asnumpy()) + + # Check meta is served correctly. + assert np.all(sim.get_metadata(indices=sim2.index_map) == sim2.get_metadata()) + + +def test_repr(sim_fixture): + sim, sim2 = sim_fixture + + logger.debug(f"repr(IndexedSource): {repr(sim2)}") + + # Check `sim` is mentioned in the repr + assert type(sim).__name__ in repr(sim2) + + # Check index counts are mentioned in the repr + assert f"{sim2.n} of {sim.n}" in repr(sim2) diff --git a/tests/test_simulation_metadata.py b/tests/test_simulation_metadata.py index eab76225d1..0d208297e2 100644 --- a/tests/test_simulation_metadata.py +++ b/tests/test_simulation_metadata.py @@ -82,3 +82,33 @@ def testMetadata5(self): equal_nan=True, ) ) + + def test_get_metadata_all(self): + """ + Test we can get the entire metadata table. + """ + + # Get the metadata via our API. + metadata_api = self.sim.get_metadata() + + # Access the metadata directly in the frame. + metadata_df = self.sim._metadata.to_numpy() + + # Assert we've returned the entire table. + self.assertTrue(np.all(metadata_api == metadata_df)) + + def test_get_metadata_index_slice(self): + """ + Test we can get all columns for a selection of rows. + """ + # Test rows + rows = [0, 1, 42] + + # Get the metadata from our API. + metadata_api = self.sim.get_metadata(indices=rows) + + # Access the metadata directly in the frame. + metadata_df = self.sim._metadata.loc[rows].to_numpy() + + # Assert we've returned the rows + self.assertTrue(np.all(metadata_api == metadata_df))