diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 81589b1472..9d2911ad7a 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -516,18 +516,7 @@ def backproject(self, rot_matrices, symmetry_group=None): ), "Number of rotation matrices must match the number of images" # Get symmetry rotations from SymmetryGroup. - if symmetry_group is None: - symmetry_rots = np.eye(3, dtype=self.dtype)[None] - else: - if isinstance(symmetry_group, str): - symmetry_group = SymmetryGroup.from_string( - symmetry_group, dtype=self.dtype - ) - if not isinstance(symmetry_group, SymmetryGroup): - raise TypeError( - f"`symmetry_group` must be a `SymmetryGroup` instance. Found {type(symmetry_group)}." - ) - symmetry_rots = symmetry_group.matrices + symmetry_rots = SymmetryGroup.parse(symmetry_group, dtype=self.dtype).matrices # Compute Fourier transform of images. im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self._data))) / (L**2) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index afaf26fc23..473585acb9 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -255,13 +255,8 @@ def symmetry_group(self, value): raise RuntimeError( f"This source is no longer mutable. Try new_source = source.update(symmetry_group='{value}')." ) - if isinstance(value, str): - value = SymmetryGroup.from_string(value, dtype=self.dtype) - if not isinstance(value, SymmetryGroup): - raise ValueError( - "`symmetry_group` must be an instance of the SymmetryGroup class" - ) - self._symmetry_group = value + + self._symmetry_group = SymmetryGroup.parse(value, dtype=self.dtype) self.set_metadata(["_rlnSymmetryGroup"], str(self.symmetry_group)) def _populate_symmetry_group(self, symmetry_group): @@ -275,7 +270,7 @@ def _populate_symmetry_group(self, symmetry_group): f"Overriding metadata with supplied symmetry group {symmetry_group}" ) else: - symmetry_group = SymmetryGroup.from_string( + symmetry_group = SymmetryGroup.parse( symmetry=self.get_metadata(["_rlnSymmetryGroup"])[0], dtype=self.dtype, ) diff --git a/src/aspire/volume/symmetry_groups.py b/src/aspire/volume/symmetry_groups.py index 171b77605e..6916e27cad 100644 --- a/src/aspire/volume/symmetry_groups.py +++ b/src/aspire/volume/symmetry_groups.py @@ -46,17 +46,37 @@ def __str__(self): return f"{self.to_string}" @staticmethod - def from_string(symmetry, dtype): + def parse(symmetry, dtype): """ - Takes a string, ie. 'C1', 'C7', 'D3', 'T', 'O', and returns a concrete - SymmetryGroup object. + Takes a SymmetryGroup instance or a string, ie. 'C1', 'C7', 'D3', 'T', 'O', and returns a concrete + SymmetryGroup object with the specified dtype. - :param symmetry: A string indicating the symmetry of a molecule. + :param symmetry: A string (or SymmetryGroup instance) indicating the symmetry of a molecule. :param dtype: dtype for rotation matrices. :return: Concrete SymmetryGroup object. """ + if symmetry is None: + return IdentitySymmetryGroup(dtype=dtype) + + if isinstance(symmetry, SymmetryGroup): + if symmetry.dtype != dtype: + logger.warning(f"Recasting SymmetryGroup with dtype {dtype}.") + group_kwargs = dict(dtype=dtype) + if getattr(symmetry, "order", False) and symmetry.order > 1: + group_kwargs["order"] = symmetry.order + symmetry = symmetry.__class__(**group_kwargs) + return symmetry + + if not isinstance(symmetry, str): + raise TypeError( + f"`symmetry` must be a string or `SymmetryGroup` instance. Found {type(symmetry)}" + ) + symmetry = symmetry.upper() + if symmetry == "C1": + return IdentitySymmetryGroup(dtype=dtype) + symmetry_type = symmetry[0] symmetric_order = symmetry[1:] diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 22b8cc8624..b6c100db36 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -171,16 +171,7 @@ def _set_symmetry_group(self, value): :param value: A `SymmetryGroup` instance or string indicating symmetry, ie. "C5", "D7", "T", etc. """ - # If value not provided set symmetry to the `IdentitySymmetryGroup`. - value = value or IdentitySymmetryGroup(dtype=self.dtype) - if isinstance(value, str): - value = SymmetryGroup.from_string(value, dtype=self.dtype) - if not isinstance(value, SymmetryGroup): - raise ValueError( - "`symmetry_group` must be an instance of the SymmetryGroup class" - " or a string indicating the symmetry, ie. 'C5', 'D7', 'T', etc." - ) - self._symmetry_group = value + self._symmetry_group = SymmetryGroup.parse(value, dtype=self.dtype) def _symmetry_group_warning(self, stacklevel): """ diff --git a/tests/test_image.py b/tests/test_image.py index ab616d04e6..d9a062bbb7 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -307,7 +307,7 @@ def test_backproject_symmetry_group(): # Attempt backproject with bad symmetry group. not_a_symmetry_group = [] - with raises(TypeError, match=r"`symmetry_group` must be a `SymmetryGroup`"): + with raises(TypeError, match=r"`symmetry` must be a string or `SymmetryGroup`"): _ = im.backproject(rots, symmetry_group=not_a_symmetry_group) # Symmetry from string. diff --git a/tests/test_symmetry_groups.py b/tests/test_symmetry_groups.py index 34f64543ab..236f1c60de 100644 --- a/tests/test_symmetry_groups.py +++ b/tests/test_symmetry_groups.py @@ -8,6 +8,7 @@ from aspire.volume import ( CnSymmetryGroup, DnSymmetryGroup, + IdentitySymmetryGroup, OSymmetryGroup, SymmetryGroup, TSymmetryGroup, @@ -79,9 +80,33 @@ def test_group_rotations(group_fixture): assert isinstance(rotations, Rotation) -def test_from_string_error(): +def test_parser_identity(): + result = SymmetryGroup.parse("C1", dtype=np.float32) + assert isinstance(result, IdentitySymmetryGroup) + + +def test_parser_with_group(group_fixture): + """Test SymmetryGroup instance are parsed correctly.""" + result = SymmetryGroup.parse(group_fixture, group_fixture.dtype) + assert result == group_fixture + assert result.dtype == group_fixture.dtype + + +def test_parser_dtype_casting(group_fixture, caplog): + """Test that dtype gets re-cast and warns.""" + dtype = np.float32 + if group_fixture.dtype == np.float32: + dtype = np.float64 + + caplog.clear() + msg = f"Recasting SymmetryGroup with dtype {dtype}." + _ = SymmetryGroup.parse(group_fixture, dtype) + assert msg in caplog.text + + +def test_parser_error(): junk_symmetry = "P12" with pytest.raises( ValueError, match=f"Symmetry type {junk_symmetry[0]} not supported.*" ): - _ = SymmetryGroup.from_string(junk_symmetry, dtype=np.float32) + _ = SymmetryGroup.parse(junk_symmetry, dtype=np.float32) diff --git a/tests/test_volume.py b/tests/test_volume.py index a0aaa8c4b0..ea52d1d67f 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -727,13 +727,6 @@ def test_symmetry_group_set_get(sym_group, sym_string): assert isinstance(vol.symmetry_group, SymmetryGroup) assert str(vol.symmetry_group) == sym_string - # Check for expected error when symmetry_group is not a SymmetryGroup object. - with raises( - ValueError, - match=r"`symmetry_group` must be an instance of the SymmetryGroup class.*", - ): - _ = Volume(data, symmetry_group=123, dtype=dtype) - def test_symmetry_group_pass_through(symmetric_vols): vol_c3, _ = symmetric_vols