Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/aspire/source/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from aspire.source.coordinates import BoxesCoordinateSource, CentersCoordinateSource
from aspire.source.image import ArrayImageSource, ImageSource
from aspire.source.image import ArrayImageSource, ImageSource, IndexedSource
from aspire.source.relion import RelionSource
from aspire.source.simulation import Simulation

Expand Down
105 changes: 95 additions & 10 deletions src/aspire/source/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,22 @@ def __init__(self, L, n, dtype="double", metadata=None, memory=None):

logger.info(f"Creating {self.__class__.__name__} with {len(self)} images.")

def __getitem__(self, indices):
"""
Check `indices` and return slice of current Source as a new
Source.

Internally uses `IndexedSource`.

:param indices: Requested indices as a Python slice object,
1-D NumPy array, list, or a single integer. Slices default
to a start of 0, an end of self.num_imgs, and a step of 1.
See _ImageAccessor.
:return: Source composed of the images and metadata at `indices`.
"""

return IndexedSource(self, indices)

@property
def n_ctf_filters(self):
"""
Expand Down Expand Up @@ -311,17 +327,29 @@ def has_metadata(self, metadata_fields):
metadata_fields = [metadata_fields]
return all(f in self._metadata.columns for f in metadata_fields)

def get_metadata(self, metadata_fields, indices=None, default_value=None):
"""
Get metadata field information of this ImageSource for selected indices
def get_metadata(self, metadata_fields=None, indices=None, default_value=None):
"""
Get metadata field information of this ImageSource for a
selection of fields of indices. The default should return the
entire metadata table.

:param metadata_fields: A string, or list of strings,
representing the metadata field(s) to be queried.
Defaults to None, which yields all populated columns.
:param indices: A list of 0-based indices indicating the
indices for which to get metadata. If indices is None,
then values corresponding to all indices in this Source
object are returned.
:param default_value: Default scalar value to use for any
fields not found in the metadata. If None, no default
value is used, and missing field(s) cause a RuntimeError.
:return: An ndarray of values (any valid np types)
representing metadata info.
"""
# When metadata_fields=None, default to returning all.
if metadata_fields is None:
metadata_fields = self._metadata.columns

:param metadata_fields: A string, of list of strings, representing the metadata field(s) to be queried.
:param indices: A list of 0-based indices indicating the indices for which to get metadata.
If indices is None, then values corresponding to all indices in this Source object are returned.
:param default_value: Default scalar value to use for any fields not found in the metadata. If None,
no default value is used, and missing field(s) cause a RuntimeError.
:return: An ndarray of values (any valid np types) representing metadata info.
"""
if isinstance(metadata_fields, str):
metadata_fields = [metadata_fields]
if indices is None:
Expand Down Expand Up @@ -803,6 +831,63 @@ def save_images(
im.save(mrcs_filepath, overwrite=overwrite)


class IndexedSource(ImageSource):
"""
Map into another into ImageSource.
"""

def __init__(self, src, indices, memory=None):
"""
Instantiates a new source along given `indices`.

:param src: ImageSource to be used as the source.
:param index_map: index_map
:param memory: str or None
The path of the base directory to use as a data store or
None. If None is given, no caching is performed.
"""

self.src = src
if not isinstance(src, ImageSource):
raise TypeError(f"Input src {src} must be an ImageSource.")

# `_ImageAccessor` performs checking and slicing logic.
# `index_map` sequence forms a natural map from the "new" source -> "self".
# Example, if request=slice(500,1000),
# then new_src[0] ~> old_src[500]; index_map[0] = 500.
self.index_map = _ImageAccessor(lambda x: x, src.n)[indices]

# Get all the metadata associated with these indices.
# Note, I would have prefered to use our API (get_metadata)
# here, but it returns a Numpy array, which would need to be
# converted back into Pandas for use below. So here we'll just
# use `loc` to return a dataframe.
metadata = self.src._metadata.loc[self.index_map]

# Construct a fully formed ImageSource with this metadata
super().__init__(
L=src.L,
n=len(self.index_map),
dtype=src.dtype,
metadata=metadata,
memory=memory,
)

def _images(self, indices):
"""
Returns images from `self.src` corresponding to `indices`
remapped by `self.index_map`.

:param indices: A 1-D NumPy array of indices.
:return: An `Image` object.
"""
mapped_indices = self.index_map[indices]
return self.src.images[mapped_indices]

def __repr__(self):
return f"{self.__class__.__name__} mapping {self.n} of {self.src.n} indices from {self.src.__class__.__name__}."


class ArrayImageSource(ImageSource):
"""
An `ImageSource` object that holds a reference to an underlying `Image` object (a thin wrapper on an ndarray)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_indexed_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging

import numpy as np
import pytest

from aspire.source import Simulation

logger = logging.getLogger(__name__)


@pytest.fixture
def sim_fixture():
"""
Generate a very small simulation and slice it.
"""
sim = Simulation(L=8, n=10, C=1)
sim2 = sim[0::2] # Slice the evens
return sim, sim2


def test_remapping(sim_fixture):
sim, sim2 = sim_fixture

# Check images are served correctly, using internal index.
assert np.allclose(sim.images[sim2.index_map].asnumpy(), sim2.images[:].asnumpy())

# Check images are served correctly, using known index (evens).
index = list(range(0, sim.n, 2))
assert np.allclose(sim.images[index].asnumpy(), sim2.images[:].asnumpy())

# Check meta is served correctly.
assert np.all(sim.get_metadata(indices=sim2.index_map) == sim2.get_metadata())


def test_repr(sim_fixture):
sim, sim2 = sim_fixture

logger.debug(f"repr(IndexedSource): {repr(sim2)}")

# Check `sim` is mentioned in the repr
assert type(sim).__name__ in repr(sim2)

# Check index counts are mentioned in the repr
assert f"{sim2.n} of {sim.n}" in repr(sim2)
30 changes: 30 additions & 0 deletions tests/test_simulation_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,33 @@ def testMetadata5(self):
equal_nan=True,
)
)

def test_get_metadata_all(self):
"""
Test we can get the entire metadata table.
"""

# Get the metadata via our API.
metadata_api = self.sim.get_metadata()

# Access the metadata directly in the frame.
metadata_df = self.sim._metadata.to_numpy()

# Assert we've returned the entire table.
self.assertTrue(np.all(metadata_api == metadata_df))

def test_get_metadata_index_slice(self):
"""
Test we can get all columns for a selection of rows.
"""
# Test rows
rows = [0, 1, 42]

# Get the metadata from our API.
metadata_api = self.sim.get_metadata(indices=rows)

# Access the metadata directly in the frame.
metadata_df = self.sim._metadata.loc[rows].to_numpy()

# Assert we've returned the rows
self.assertTrue(np.all(metadata_api == metadata_df))