Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,18 +516,7 @@ def backproject(self, rot_matrices, symmetry_group=None):
), "Number of rotation matrices must match the number of images"

# Get symmetry rotations from SymmetryGroup.
if symmetry_group is None:
symmetry_rots = np.eye(3, dtype=self.dtype)[None]
else:
if isinstance(symmetry_group, str):
symmetry_group = SymmetryGroup.from_string(
symmetry_group, dtype=self.dtype
)
if not isinstance(symmetry_group, SymmetryGroup):
raise TypeError(
f"`symmetry_group` must be a `SymmetryGroup` instance. Found {type(symmetry_group)}."
)
symmetry_rots = symmetry_group.matrices
symmetry_rots = SymmetryGroup.parse(symmetry_group, dtype=self.dtype).matrices

# Compute Fourier transform of images.
im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self._data))) / (L**2)
Expand Down
11 changes: 3 additions & 8 deletions src/aspire/source/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,8 @@ def symmetry_group(self, value):
raise RuntimeError(
f"This source is no longer mutable. Try new_source = source.update(symmetry_group='{value}')."
)
if isinstance(value, str):
value = SymmetryGroup.from_string(value, dtype=self.dtype)
if not isinstance(value, SymmetryGroup):
raise ValueError(
"`symmetry_group` must be an instance of the SymmetryGroup class"
)
self._symmetry_group = value

self._symmetry_group = SymmetryGroup.parse(value, dtype=self.dtype)
self.set_metadata(["_rlnSymmetryGroup"], str(self.symmetry_group))

def _populate_symmetry_group(self, symmetry_group):
Expand All @@ -275,7 +270,7 @@ def _populate_symmetry_group(self, symmetry_group):
f"Overriding metadata with supplied symmetry group {symmetry_group}"
)
else:
symmetry_group = SymmetryGroup.from_string(
symmetry_group = SymmetryGroup.parse(
symmetry=self.get_metadata(["_rlnSymmetryGroup"])[0],
dtype=self.dtype,
)
Expand Down
28 changes: 24 additions & 4 deletions src/aspire/volume/symmetry_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,37 @@ def __str__(self):
return f"{self.to_string}"

@staticmethod
def from_string(symmetry, dtype):
def parse(symmetry, dtype):
"""
Takes a string, ie. 'C1', 'C7', 'D3', 'T', 'O', and returns a concrete
SymmetryGroup object.
Takes a SymmetryGroup instance or a string, ie. 'C1', 'C7', 'D3', 'T', 'O', and returns a concrete
SymmetryGroup object with the specified dtype.

:param symmetry: A string indicating the symmetry of a molecule.
:param symmetry: A string (or SymmetryGroup instance) indicating the symmetry of a molecule.
:param dtype: dtype for rotation matrices.
:return: Concrete SymmetryGroup object.
"""

if symmetry is None:
return IdentitySymmetryGroup(dtype=dtype)

if isinstance(symmetry, SymmetryGroup):
if symmetry.dtype != dtype:
logger.warning(f"Recasting SymmetryGroup with dtype {dtype}.")
group_kwargs = dict(dtype=dtype)
if getattr(symmetry, "order", False) and symmetry.order > 1:
group_kwargs["order"] = symmetry.order
symmetry = symmetry.__class__(**group_kwargs)
return symmetry

if not isinstance(symmetry, str):
raise TypeError(
f"`symmetry` must be a string or `SymmetryGroup` instance. Found {type(symmetry)}"
)

symmetry = symmetry.upper()
if symmetry == "C1":
return IdentitySymmetryGroup(dtype=dtype)

symmetry_type = symmetry[0]
symmetric_order = symmetry[1:]

Expand Down
11 changes: 1 addition & 10 deletions src/aspire/volume/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,7 @@ def _set_symmetry_group(self, value):

:param value: A `SymmetryGroup` instance or string indicating symmetry, ie. "C5", "D7", "T", etc.
"""
# If value not provided set symmetry to the `IdentitySymmetryGroup`.
value = value or IdentitySymmetryGroup(dtype=self.dtype)
if isinstance(value, str):
value = SymmetryGroup.from_string(value, dtype=self.dtype)
if not isinstance(value, SymmetryGroup):
raise ValueError(
"`symmetry_group` must be an instance of the SymmetryGroup class"
" or a string indicating the symmetry, ie. 'C5', 'D7', 'T', etc."
)
self._symmetry_group = value
self._symmetry_group = SymmetryGroup.parse(value, dtype=self.dtype)

def _symmetry_group_warning(self, stacklevel):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def test_backproject_symmetry_group():

# Attempt backproject with bad symmetry group.
not_a_symmetry_group = []
with raises(TypeError, match=r"`symmetry_group` must be a `SymmetryGroup`"):
with raises(TypeError, match=r"`symmetry` must be a string or `SymmetryGroup`"):
_ = im.backproject(rots, symmetry_group=not_a_symmetry_group)

# Symmetry from string.
Expand Down
29 changes: 27 additions & 2 deletions tests/test_symmetry_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aspire.volume import (
CnSymmetryGroup,
DnSymmetryGroup,
IdentitySymmetryGroup,
OSymmetryGroup,
SymmetryGroup,
TSymmetryGroup,
Expand Down Expand Up @@ -79,9 +80,33 @@ def test_group_rotations(group_fixture):
assert isinstance(rotations, Rotation)


def test_from_string_error():
def test_parser_identity():
result = SymmetryGroup.parse("C1", dtype=np.float32)
assert isinstance(result, IdentitySymmetryGroup)


def test_parser_with_group(group_fixture):
"""Test SymmetryGroup instance are parsed correctly."""
result = SymmetryGroup.parse(group_fixture, group_fixture.dtype)
assert result == group_fixture
assert result.dtype == group_fixture.dtype


def test_parser_dtype_casting(group_fixture, caplog):
"""Test that dtype gets re-cast and warns."""
dtype = np.float32
if group_fixture.dtype == np.float32:
dtype = np.float64

caplog.clear()
msg = f"Recasting SymmetryGroup with dtype {dtype}."
_ = SymmetryGroup.parse(group_fixture, dtype)
assert msg in caplog.text


def test_parser_error():
junk_symmetry = "P12"
with pytest.raises(
ValueError, match=f"Symmetry type {junk_symmetry[0]} not supported.*"
):
_ = SymmetryGroup.from_string(junk_symmetry, dtype=np.float32)
_ = SymmetryGroup.parse(junk_symmetry, dtype=np.float32)
7 changes: 0 additions & 7 deletions tests/test_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,13 +727,6 @@ def test_symmetry_group_set_get(sym_group, sym_string):
assert isinstance(vol.symmetry_group, SymmetryGroup)
assert str(vol.symmetry_group) == sym_string

# Check for expected error when symmetry_group is not a SymmetryGroup object.
with raises(
ValueError,
match=r"`symmetry_group` must be an instance of the SymmetryGroup class.*",
):
_ = Volume(data, symmetry_group=123, dtype=dtype)


def test_symmetry_group_pass_through(symmetric_vols):
vol_c3, _ = symmetric_vols
Expand Down