Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/aspire/source/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,8 @@ def downsample(self, L):
ds_factor = self.L / L
self.unique_filters = [f.scale(ds_factor) for f in self.unique_filters]
self.offsets /= ds_factor
if self.pixel_size is not None:
self.pixel_size *= ds_factor

self.L = L

Expand Down Expand Up @@ -1519,10 +1521,19 @@ def __init__(self, src, indices, memory=None):
pixel_size=src.pixel_size,
)

# Create filter indices, these are required to pass unharmed through filter eval code
# that is potentially called by other methods later.
self.filter_indices = np.zeros(self.n, dtype=int)
self.unique_filters = [IdentityFilter()]
if src.unique_filters:
# Remap the filter indices to be unique.
# Removes duplicates and filters that are unused in new source.
_filter_indices = src.filter_indices[self.index_map]
# _unq[_inv] reconstructs _filter_indices
_unq, _inv = np.unique(_filter_indices, return_inverse=True)
# Repack unique_filters
self.filter_indices = _inv
self.unique_filters = [copy.copy(src.unique_filters[i]) for i in _unq]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noting that these don't necessarily need to be copied, but copying should be less risky in case they are mutated.

else:
# Pass through the None case
self.unique_filters = src.unique_filters
self.filter_indices = np.zeros(self.n, dtype=int)

# Any further operations should not mutate this instance.
self._mutable = False
Expand Down
157 changes: 156 additions & 1 deletion tests/test_indexed_source.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import logging
import os
import tempfile

import numpy as np
import pytest

from aspire.source import Simulation
from aspire.downloader import emdb_8012
from aspire.operators import CTFFilter
from aspire.source import RelionSource, Simulation
from aspire.utils import Rotation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,3 +54,153 @@ def test_repr(sim_fixture):

# Check index counts are mentioned in the repr
assert f"{sim2.n} of {sim.n}" in repr(sim2)


@pytest.mark.expensive
def test_filter_mapping():
    """
    Check that slicing a source remaps `unique_filters` and
    `filter_indices` consistently.

    The same preprocessing pipeline is applied to a full image stack and
    to its even/odd slices.  The results are compared directly and again
    after each source has been saved to a STAR file and reloaded.
    """

    # The stack holds `n_img` projections built from `n_img // 2` distinct
    # (rotation, CTF) pairs, each pair used for two consecutive images,
    # so that images[0::2] == images[1::2].
    n_img = 100
    seed = 1234
    dtype = np.float64
    ds_res = 129
    n_unique = n_img // 2

    vol = emdb_8012().astype(dtype)

    # One random rotation per unique pair, each repeated twice in a row.
    unique_rots = Rotation.generate_random_rotations(
        n_unique, dtype=dtype, seed=seed
    )
    angles = Rotation(np.repeat(unique_rots, 2, axis=0)).angles

    # One CTF filter per unique pair, with indices repeated to match the
    # rotation layout above.
    defoci = np.linspace(1000, 25000, n_unique)
    ctf_filters = []
    for d in range(n_unique):
        ctf_filters.append(
            CTFFilter(
                vol.pixel_size,
                200,
                defocus_u=defoci[d],
                defocus_v=defoci[-d],
                defocus_ang=np.pi / n_unique * d,
                Cs=2.0,
                alpha=0.1,
            )
        )
    ctf_indices = np.repeat(np.arange(n_unique), 2)

    # Full simulated stack.
    sim = Simulation(
        vols=vol,
        n=n_img,
        dtype=dtype,
        seed=seed,
        unique_filters=ctf_filters,
        filter_indices=ctf_indices,
        angles=angles,
        offsets=0,
        amplitudes=1,
    )
    src = sim.cache()

    # Even and odd slices of the stack.
    src_even = src[0::2]
    src_odd = src[1::2]

    check = np.testing.assert_allclose

    # Slices must reproduce the matching images of the parent stack ...
    check(src_even.images[:], src.images[0::2])
    check(src_odd.images[:], src.images[1::2])
    # ... and, by construction, must agree with each other.
    check(src_odd.images[:], src_even.images[:])

    def _preprocess(source):
        # Shared preprocessing pipeline applied to every source under test.
        flipped = source.phase_flip()
        shrunk = flipped.downsample(ds_res)
        normalized = shrunk.normalize_background()
        whitened = normalized.legacy_whiten()
        inverted = whitened.invert_contrast()
        return inverted.cache()

    # Preprocess the full stack first, then each slice.
    pp = _preprocess(src)
    pp_even = _preprocess(src_even)
    pp_odd = _preprocess(src_odd)

    # Preprocessed slices must match the preprocessed parent stack ...
    check(pp_even.images[:], pp.images[0::2], atol=1e-6)
    check(pp_odd.images[:], pp.images[1::2], atol=1e-6)
    # ... and each other.
    check(pp_odd.images[:], pp_even.images[:], atol=1e-6)

    # Sources keyed by STAR file stem; insertion order fixes the
    # save/reload order (raw sources first, then preprocessed ones).
    saved = {
        "src": src,
        "srcA": src_even,
        "srcB": src_odd,
        "pp": pp,
        "ppA": pp_even,
        "ppB": pp_odd,
    }

    # Scratch directory for this test's output.
    with tempfile.TemporaryDirectory() as tmpdir_name:

        def _star_path(stem):
            # Location of the STAR file for a given source key.
            return os.path.join(tmpdir_name, f"{stem}.star")

        # Write everything out before reading anything back.
        for stem, source in saved.items():
            source.save(_star_path(stem))

        # Reload each source, assigning the `pixel_size` of its origin.
        reloaded = {}
        for stem, source in saved.items():
            reloaded[stem] = RelionSource(
                _star_path(stem), pixel_size=source.pixel_size
            )

        # Each reloaded source must match the source it was saved from,
        # which lets the comparisons below carry over to the originals.
        # When everything works many of these checks are redundant; when
        # something breaks they help pinpoint the fault.
        for stem in ("src", "srcA", "srcB"):
            check(reloaded[stem].images[:], saved[stem].images[:])
        for stem in ("pp", "ppA", "ppB"):
            check(reloaded[stem].images[:], saved[stem].images[:], atol=1e-6)

        # Reloaded slices must match slices of the reloaded full stacks.
        check(reloaded["srcA"].images[:], reloaded["src"].images[0::2])
        check(reloaded["srcB"].images[:], reloaded["src"].images[1::2])
        check(reloaded["ppA"].images[:], reloaded["pp"].images[0::2], atol=1e-5)
        check(reloaded["ppB"].images[:], reloaded["pp"].images[1::2], atol=1e-5)

        # Even and odd slices must still agree after the round trip.
        check(reloaded["srcB"].images[:], reloaded["srcA"].images[:])
        check(reloaded["ppB"].images[:], reloaded["ppA"].images[:])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #1236 for additional tests that might fit here.

Loading