From 06617f27d005128c9f1c21369cc85bf9af2468c6 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 13 Mar 2023 15:02:03 -0400 Subject: [PATCH 1/3] Initial addition of RemappedSource --- src/aspire/source/__init__.py | 2 +- src/aspire/source/image.py | 105 +++++++++++++++++++++++++++--- tests/test_remapped_source.py | 40 ++++++++++++ tests/test_simulation_metadata.py | 30 +++++++++ 4 files changed, 166 insertions(+), 11 deletions(-) create mode 100644 tests/test_remapped_source.py diff --git a/src/aspire/source/__init__.py b/src/aspire/source/__init__.py index db3e955e78..494bba6ccf 100644 --- a/src/aspire/source/__init__.py +++ b/src/aspire/source/__init__.py @@ -1,7 +1,7 @@ import logging from aspire.source.coordinates import BoxesCoordinateSource, CentersCoordinateSource -from aspire.source.image import ArrayImageSource, ImageSource +from aspire.source.image import ArrayImageSource, ImageSource, RemappedSource from aspire.source.relion import RelionSource from aspire.source.simulation import Simulation diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 6a176c688e..01881503ef 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -162,6 +162,22 @@ def __init__(self, L, n, dtype="double", metadata=None, memory=None): logger.info(f"Creating {self.__class__.__name__} with {len(self)} images.") + def __getitem__(self, indices): + """ + Check `indices` and return slice of current Source as a new + Source. + + Internally uses `RemappedSource`. + + :param indices: Requested indices as a Python slice object, + 1-D NumPy array, list, or a single integer. Slices default + to a start of 0, an end of self.num_imgs, and a step of 1. + See _ImageAccessor. + :return: Source composed of the images and metadata at `indices`. 
+ """ + + return RemappedSource(self, indices) + @property def n_ctf_filters(self): """ @@ -311,17 +327,29 @@ def has_metadata(self, metadata_fields): metadata_fields = [metadata_fields] return all(f in self._metadata.columns for f in metadata_fields) - def get_metadata(self, metadata_fields, indices=None, default_value=None): - """ - Get metadata field information of this ImageSource for selected indices + def get_metadata(self, metadata_fields=None, indices=None, default_value=None): + """ + Get metadata field information of this ImageSource for a + selection of fields and indices. The default should return the + entire metadata table. + + :param metadata_fields: A string, or list of strings, + representing the metadata field(s) to be queried. + Defaults to None, which yields all populated columns. + :param indices: A list of 0-based indices indicating the + indices for which to get metadata. If indices is None, + then values corresponding to all indices in this Source + object are returned. + :param default_value: Default scalar value to use for any + fields not found in the metadata. If None, no default + value is used, and missing field(s) cause a RuntimeError. + :return: An ndarray of values (any valid np types) + representing metadata info. + """ + # When metadata_fields=None, default to returning all. + if metadata_fields is None: + metadata_fields = self._metadata.columns - :param metadata_fields: A string, of list of strings, representing the metadata field(s) to be queried. - :param indices: A list of 0-based indices indicating the indices for which to get metadata. - If indices is None, then values corresponding to all indices in this Source object are returned. - :param default_value: Default scalar value to use for any fields not found in the metadata. If None, - no default value is used, and missing field(s) cause a RuntimeError. - :return: An ndarray of values (any valid np types) representing metadata info.
- """ if isinstance(metadata_fields, str): metadata_fields = [metadata_fields] if indices is None: @@ -803,6 +831,63 @@ def save_images( im.save(mrcs_filepath, overwrite=overwrite) +class RemappedSource(ImageSource): + """ + Map into another into ImageSource. + """ + + def __init__(self, src, indices, memory=None): + """ + Instantiates a new source along given `indices`. + + :param src: ImageSource to be used as the source. + :param indices: Requested indices into `src`; see _ImageAccessor. + :param memory: str or None + The path of the base directory to use as a data store or + None. If None is given, no caching is performed. + """ + + self.src = src + if not isinstance(src, ImageSource): + raise TypeError(f"Input src {src} must be an ImageSource.") + + # `_ImageAccessor` performs checking and slicing logic. + # `index_map` sequence forms a natural map from the "new" source -> "self". + # Example, if request=slice(500,1000), + # then new_src[0] ~> old_src[500]; index_map[0] = 500. + self.index_map = _ImageAccessor(lambda x: x, src.n)[indices] + + # Get all the metadata associated with these indices. + # Note, I would have preferred to use our API (get_metadata) + # here, but it returns a Numpy array, which would need to be + # converted back into Pandas for use below. So here we'll just + # use `loc` to return a dataframe. + metadata = self.src._metadata.loc[self.index_map] + + # Construct a fully formed ImageSource with this metadata + super().__init__( + L=src.L, + n=len(self.index_map), + dtype=src.dtype, + metadata=metadata, + memory=memory, + ) + + def _images(self, indices): + """ + Returns images from `self.src` corresponding to `indices` + remapped by `self.index_map`. + + :param indices: A 1-D NumPy array of indices. + :return: An `Image` object. + """ + mapped_indices = self.index_map[indices] + return self.src.images[mapped_indices] + + def __repr__(self): + return f"{self.__class__.__name__} mapping {self.n} of {self.src.n} indices from {self.src.__class__.__name__}.
+ + class ArrayImageSource(ImageSource): """ An `ImageSource` object that holds a reference to an underlying `Image` object (a thin wrapper on an ndarray) diff --git a/tests/test_remapped_source.py b/tests/test_remapped_source.py new file mode 100644 index 0000000000..ac7ee7910f --- /dev/null +++ b/tests/test_remapped_source.py @@ -0,0 +1,40 @@ +import logging + +import numpy as np +import pytest + +from aspire.source import Simulation + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def sim_fixture(): + """ + Generate a very small simulation and slice it. + """ + sim = Simulation(L=8, n=10, C=1) + sim2 = sim[0::2] # Slice the evens + return sim, sim2 + + +def test_remapping(sim_fixture): + sim, sim2 = sim_fixture + + # Check images are served correctly. + assert np.allclose(sim.images[sim2.index_map].asnumpy(), sim2.images[:].asnumpy()) + + # Check meta is served correctly. + assert np.all(sim.get_metadata(indices=sim2.index_map) == sim2.get_metadata()) + + +def test_repr(sim_fixture): + sim, sim2 = sim_fixture + + logger.debug(f"repr(RemappedSource): {repr(sim2)}") + + # Check `sim` is mentioned in the repr + assert type(sim).__name__ in repr(sim2) + + # Check index counts are mentioned in the repr + assert f"{sim2.n} of {sim.n}" in repr(sim2) diff --git a/tests/test_simulation_metadata.py b/tests/test_simulation_metadata.py index eab76225d1..0d208297e2 100644 --- a/tests/test_simulation_metadata.py +++ b/tests/test_simulation_metadata.py @@ -82,3 +82,33 @@ def testMetadata5(self): equal_nan=True, ) ) + + def test_get_metadata_all(self): + """ + Test we can get the entire metadata table. + """ + + # Get the metadata via our API. + metadata_api = self.sim.get_metadata() + + # Access the metadata directly in the frame. + metadata_df = self.sim._metadata.to_numpy() + + # Assert we've returned the entire table. 
+ self.assertTrue(np.all(metadata_api == metadata_df)) + + def test_get_metadata_index_slice(self): + """ + Test we can get all columns for a selection of rows. + """ + # Test rows + rows = [0, 1, 42] + + # Get the metadata from our API. + metadata_api = self.sim.get_metadata(indices=rows) + + # Access the metadata directly in the frame. + metadata_df = self.sim._metadata.loc[rows].to_numpy() + + # Assert we've returned the rows + self.assertTrue(np.all(metadata_api == metadata_df)) From 603f8d04329c936b0ecc18fba7d5cd8254a4a4d1 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 16 Mar 2023 11:59:33 -0400 Subject: [PATCH 2/3] Rename RemappedSource to IndexedSource --- src/aspire/source/__init__.py | 2 +- src/aspire/source/image.py | 6 +++--- tests/test_remapped_source.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/aspire/source/__init__.py b/src/aspire/source/__init__.py index 494bba6ccf..55afda4833 100644 --- a/src/aspire/source/__init__.py +++ b/src/aspire/source/__init__.py @@ -1,7 +1,7 @@ import logging from aspire.source.coordinates import BoxesCoordinateSource, CentersCoordinateSource -from aspire.source.image import ArrayImageSource, ImageSource, RemappedSource +from aspire.source.image import ArrayImageSource, ImageSource, IndexedSource from aspire.source.relion import RelionSource from aspire.source.simulation import Simulation diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 01881503ef..cf5c5561a8 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -167,7 +167,7 @@ def __getitem__(self, indices): Check `indices` and return slice of current Source as a new Source. - Internally uses `RemappedSource`. + Internally uses `IndexedSource`. :param indices: Requested indices as a Python slice object, 1-D NumPy array, list, or a single integer. Slices default @@ -176,7 +176,7 @@ def __getitem__(self, indices): :return: Source composed of the images and metadata at `indices`. 
""" - return RemappedSource(self, indices) + return IndexedSource(self, indices) @property def n_ctf_filters(self): @@ -831,7 +831,7 @@ def save_images( im.save(mrcs_filepath, overwrite=overwrite) -class RemappedSource(ImageSource): +class IndexedSource(ImageSource): """ Map into another into ImageSource. """ diff --git a/tests/test_remapped_source.py b/tests/test_remapped_source.py index ac7ee7910f..75b029ee3b 100644 --- a/tests/test_remapped_source.py +++ b/tests/test_remapped_source.py @@ -31,7 +31,7 @@ def test_remapping(sim_fixture): def test_repr(sim_fixture): sim, sim2 = sim_fixture - logger.debug(f"repr(RemappedSource): {repr(sim2)}") + logger.debug(f"repr(IndexedSource): {repr(sim2)}") # Check `sim` is mentioned in the repr assert type(sim).__name__ in repr(sim2) From a8c54881cddf6c1019c1d297569698d426fa160a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 16 Mar 2023 12:03:24 -0400 Subject: [PATCH 3/3] Update and renamed IndexedSource unit test file --- tests/{test_remapped_source.py => test_indexed_source.py} | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) rename tests/{test_remapped_source.py => test_indexed_source.py} (78%) diff --git a/tests/test_remapped_source.py b/tests/test_indexed_source.py similarity index 78% rename from tests/test_remapped_source.py rename to tests/test_indexed_source.py index 75b029ee3b..30a23ee16b 100644 --- a/tests/test_remapped_source.py +++ b/tests/test_indexed_source.py @@ -21,9 +21,13 @@ def sim_fixture(): def test_remapping(sim_fixture): sim, sim2 = sim_fixture - # Check images are served correctly. + # Check images are served correctly, using internal index. assert np.allclose(sim.images[sim2.index_map].asnumpy(), sim2.images[:].asnumpy()) + # Check images are served correctly, using known index (evens). + index = list(range(0, sim.n, 2)) + assert np.allclose(sim.images[index].asnumpy(), sim2.images[:].asnumpy()) + # Check meta is served correctly. 
assert np.all(sim.get_metadata(indices=sim2.index_map) == sim2.get_metadata())