Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/aspire/denoising/class_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
L=self.averager.src.L,
n=self.averager.src.n,
dtype=self.averager.src.dtype,
symmetry_group=self.src.symmetry_group,
)

# Any further operations should not mutate this instance.
Expand Down
12 changes: 11 additions & 1 deletion tests/test_class_src.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,14 @@ def class_sim_fixture(dtype, img_size):
# Note using a single volume via C=1 is critical to matching
# alignment without the complexity of remapping via states etc.
src = Simulation(
L=img_size, n=n, vols=v, offsets=0, amplitudes=1, C=1, angles=true_rots.angles
L=img_size,
n=n,
vols=v,
offsets=0,
amplitudes=1,
C=1,
angles=true_rots.angles,
symmetry_group="C4", # For testing symmetry_group pass-through.
)
# Prefetch all the images
src = src.cache()
Expand Down Expand Up @@ -193,6 +200,9 @@ class averages.
k = len(src2.class_indices)
np.testing.assert_equal(src2.class_indices, test_src.class_indices[::3][:k])

# Check symmetry_group pass-through.
assert test_src.symmetry_group == class_sim_fixture.symmetry_group


# Test the _HeapItem helper class
def test_heap_helper():
Expand Down
5 changes: 4 additions & 1 deletion tests/test_indexed_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def sim_fixture():
"""
Generate a very small simulation and slice it.
"""
sim = Simulation(L=8, n=10, C=1)
sim = Simulation(L=8, n=10, C=1, symmetry_group="D3")
sim2 = sim[0::2] # Slice the evens
return sim, sim2

Expand All @@ -31,6 +31,9 @@ def test_remapping(sim_fixture):
# Check meta is served correctly.
assert np.all(sim.get_metadata(indices=sim2.index_map) == sim2.get_metadata())

# Check symmetry_group pass-through.
assert sim.symmetry_group == sim2.symmetry_group


def test_repr(sim_fixture):
sim, sim2 = sim_fixture
Expand Down