From 9ea9306fcbe265af89d2dd87876771d69465d486 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 21 Mar 2024 13:27:49 -0400 Subject: [PATCH 1/7] SymmetryGroup.from_string extra parsing logic. --- src/aspire/image/image.py | 15 +++------------ src/aspire/source/image.py | 9 ++------- src/aspire/volume/symmetry_groups.py | 17 +++++++++++++++++ tests/test_image.py | 2 +- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 81589b1472..2e898f6bbc 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -516,18 +516,9 @@ def backproject(self, rot_matrices, symmetry_group=None): ), "Number of rotation matrices must match the number of images" # Get symmetry rotations from SymmetryGroup. - if symmetry_group is None: - symmetry_rots = np.eye(3, dtype=self.dtype)[None] - else: - if isinstance(symmetry_group, str): - symmetry_group = SymmetryGroup.from_string( - symmetry_group, dtype=self.dtype - ) - if not isinstance(symmetry_group, SymmetryGroup): - raise TypeError( - f"`symmetry_group` must be a `SymmetryGroup` instance. Found {type(symmetry_group)}." - ) - symmetry_rots = symmetry_group.matrices + symmetry_rots = SymmetryGroup.from_string( + symmetry_group, dtype=self.dtype + ).matrices # Compute Fourier transform of images. im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self._data))) / (L**2) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index afaf26fc23..b144497a5c 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -255,13 +255,8 @@ def symmetry_group(self, value): raise RuntimeError( f"This source is no longer mutable. Try new_source = source.update(symmetry_group='{value}')." ) - if isinstance(value, str): - value = SymmetryGroup.from_string(value, dtype=self.dtype) - if not isinstance(value, SymmetryGroup): - raise ValueError( - "`symmetry_group` must be an instance of the SymmetryGroup class" - ) - self._symmetry_group = value + + self._symmetry_group = SymmetryGroup.from_string(value, dtype=self.dtype) self.set_metadata(["_rlnSymmetryGroup"], str(self.symmetry_group)) def _populate_symmetry_group(self, symmetry_group): diff --git a/src/aspire/volume/symmetry_groups.py b/src/aspire/volume/symmetry_groups.py index 171b77605e..68e9de1e8a 100644 --- a/src/aspire/volume/symmetry_groups.py +++ b/src/aspire/volume/symmetry_groups.py @@ -56,6 +56,23 @@ def from_string(symmetry, dtype): :return: Concrete SymmetryGroup object. """ + if symmetry is None: + return IdentitySymmetryGroup(dtype=dtype) + + if isinstance(symmetry, SymmetryGroup): + if symmetry.dtype != dtype: + logger.warning(f"Recasting SymmetryGroup with dtype {dtype}.") + group_kwargs = dict(dtype=dtype) + if getattr(symmetry, "order", False) and symmetry.order > 1: + group_kwargs["order"] = symmetry.order + symmetry = symmetry.__class__(**group_kwargs) + return symmetry + + if not isinstance(symmetry, str): + raise TypeError( + f"`symmetry` must be a string or `SymmetryGroup` instance. Found {type(symmetry)}" + ) + symmetry = symmetry.upper() symmetry_type = symmetry[0] symmetric_order = symmetry[1:] diff --git a/tests/test_image.py b/tests/test_image.py index ab616d04e6..d9a062bbb7 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -307,7 +307,7 @@ def test_backproject_symmetry_group(): # Attempt backproject with bad symmetry group. not_a_symmetry_group = [] - with raises(TypeError, match=r"`symmetry_group` must be a `SymmetryGroup`"): + with raises(TypeError, match=r"`symmetry` must be a string or `SymmetryGroup`"): _ = im.backproject(rots, symmetry_group=not_a_symmetry_group) # Symmetry from string. From 88a2881a6303e00fbe04325a1708e1fd1b7853ed Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 21 Mar 2024 14:25:13 -0400 Subject: [PATCH 2/7] Add tests. Fix C1 from string to be IdentitySymmetryGroup. --- src/aspire/volume/symmetry_groups.py | 3 +++ tests/test_symmetry_groups.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/aspire/volume/symmetry_groups.py b/src/aspire/volume/symmetry_groups.py index 68e9de1e8a..eb37d76e89 100644 --- a/src/aspire/volume/symmetry_groups.py +++ b/src/aspire/volume/symmetry_groups.py @@ -74,6 +74,9 @@ def from_string(symmetry, dtype): ) symmetry = symmetry.upper() + if symmetry == "C1": + return IdentitySymmetryGroup(dtype=dtype) + symmetry_type = symmetry[0] symmetric_order = symmetry[1:] diff --git a/tests/test_symmetry_groups.py b/tests/test_symmetry_groups.py index 34f64543ab..8727f9a9a0 100644 --- a/tests/test_symmetry_groups.py +++ b/tests/test_symmetry_groups.py @@ -8,6 +8,7 @@ from aspire.volume import ( CnSymmetryGroup, DnSymmetryGroup, + IdentitySymmetryGroup, OSymmetryGroup, SymmetryGroup, TSymmetryGroup, @@ -79,6 +80,30 @@ def test_group_rotations(group_fixture): assert isinstance(rotations, Rotation) +def test_from_string_identity(): + result = SymmetryGroup.from_string("C1", dtype=np.float32) + assert isinstance(result, IdentitySymmetryGroup) + + +def test_from_string_with_group(group_fixture): + """Test SymmetryGroup instance are parsed correctly.""" + result = SymmetryGroup.from_string(group_fixture, group_fixture.dtype) + assert result == group_fixture + assert result.dtype == group_fixture.dtype + + +def test_from_string_dtype_casting(group_fixture, caplog): + """Test that dtype gets re-cast and warns.""" + dtype = np.float32 + if group_fixture.dtype == np.float32: + dtype = np.float64 + + caplog.clear() + msg = f"Recasting SymmetryGroup with dtype {dtype}." + result = SymmetryGroup.from_string(group_fixture, dtype) + assert msg in caplog.text + + def test_from_string_error(): junk_symmetry = "P12" with pytest.raises( From 5e31b9f2ffe4314386c68ffb34c9e381d5f7c8f7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 21 Mar 2024 14:30:14 -0400 Subject: [PATCH 3/7] tox --- tests/test_symmetry_groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_symmetry_groups.py b/tests/test_symmetry_groups.py index 8727f9a9a0..249fe7023f 100644 --- a/tests/test_symmetry_groups.py +++ b/tests/test_symmetry_groups.py @@ -100,7 +100,7 @@ def test_from_string_dtype_casting(group_fixture, caplog): caplog.clear() msg = f"Recasting SymmetryGroup with dtype {dtype}." - result = SymmetryGroup.from_string(group_fixture, dtype) + _ = SymmetryGroup.from_string(group_fixture, dtype) assert msg in caplog.text From 40caba5d5ca54a1f2d5c82db9c8ebb22b39228c8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 22 Mar 2024 08:50:55 -0400 Subject: [PATCH 4/7] remove parsing from volume.py --- src/aspire/volume/volume.py | 11 +---------- tests/test_volume.py | 7 ------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 22b8cc8624..9eefd0bd27 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -171,16 +171,7 @@ def _set_symmetry_group(self, value): :param value: A `SymmetryGroup` instance or string indicating symmetry, ie. "C5", "D7", "T", etc. """ - # If value not provided set symmetry to the `IdentitySymmetryGroup`. - value = value or IdentitySymmetryGroup(dtype=self.dtype) - if isinstance(value, str): - value = SymmetryGroup.from_string(value, dtype=self.dtype) - if not isinstance(value, SymmetryGroup): - raise ValueError( - "`symmetry_group` must be an instance of the SymmetryGroup class" - " or a string indicating the symmetry, ie. 'C5', 'D7', 'T', etc." - ) - self._symmetry_group = value + self._symmetry_group = SymmetryGroup.from_string(value, dtype=self.dtype) def _symmetry_group_warning(self, stacklevel): """ diff --git a/tests/test_volume.py b/tests/test_volume.py index a0aaa8c4b0..ea52d1d67f 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -727,13 +727,6 @@ def test_symmetry_group_set_get(sym_group, sym_string): assert isinstance(vol.symmetry_group, SymmetryGroup) assert str(vol.symmetry_group) == sym_string - # Check for expected error when symmetry_group is not a SymmetryGroup object. - with raises( - ValueError, - match=r"`symmetry_group` must be an instance of the SymmetryGroup class.*", - ): - _ = Volume(data, symmetry_group=123, dtype=dtype) - def test_symmetry_group_pass_through(symmetric_vols): vol_c3, _ = symmetric_vols From 0813161ceb4d3aa9f3eb9d06b3e286bb917e4379 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 22 Mar 2024 09:00:25 -0400 Subject: [PATCH 5/7] Change name from_string --> parser --- src/aspire/image/image.py | 2 +- src/aspire/source/image.py | 4 ++-- src/aspire/volume/symmetry_groups.py | 2 +- src/aspire/volume/volume.py | 2 +- tests/test_symmetry_groups.py | 16 ++++++++-------- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 2e898f6bbc..9a675b0aff 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -516,7 +516,7 @@ def backproject(self, rot_matrices, symmetry_group=None): ), "Number of rotation matrices must match the number of images" # Get symmetry rotations from SymmetryGroup. - symmetry_rots = SymmetryGroup.from_string( + symmetry_rots = SymmetryGroup.parser( symmetry_group, dtype=self.dtype ).matrices diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index b144497a5c..dfa82ba06d 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -256,7 +256,7 @@ def symmetry_group(self, value): f"This source is no longer mutable. Try new_source = source.update(symmetry_group='{value}')." ) - self._symmetry_group = SymmetryGroup.from_string(value, dtype=self.dtype) + self._symmetry_group = SymmetryGroup.parser(value, dtype=self.dtype) self.set_metadata(["_rlnSymmetryGroup"], str(self.symmetry_group)) def _populate_symmetry_group(self, symmetry_group): @@ -270,7 +270,7 @@ def _populate_symmetry_group(self, symmetry_group): f"Overriding metadata with supplied symmetry group {symmetry_group}" ) else: - symmetry_group = SymmetryGroup.from_string( + symmetry_group = SymmetryGroup.parser( symmetry=self.get_metadata(["_rlnSymmetryGroup"])[0], dtype=self.dtype, ) diff --git a/src/aspire/volume/symmetry_groups.py b/src/aspire/volume/symmetry_groups.py index eb37d76e89..000a65dd69 100644 --- a/src/aspire/volume/symmetry_groups.py +++ b/src/aspire/volume/symmetry_groups.py @@ -46,7 +46,7 @@ def __str__(self): return f"{self.to_string}" @staticmethod - def from_string(symmetry, dtype): + def parser(symmetry, dtype): """ Takes a string, ie. 'C1', 'C7', 'D3', 'T', 'O', and returns a concrete SymmetryGroup object. diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 9eefd0bd27..0082ba4c6b 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -171,7 +171,7 @@ def _set_symmetry_group(self, value): :param value: A `SymmetryGroup` instance or string indicating symmetry, ie. "C5", "D7", "T", etc. """ - self._symmetry_group = SymmetryGroup.from_string(value, dtype=self.dtype) + self._symmetry_group = SymmetryGroup.parser(value, dtype=self.dtype) def _symmetry_group_warning(self, stacklevel): """ diff --git a/tests/test_symmetry_groups.py b/tests/test_symmetry_groups.py index 249fe7023f..23008d3b39 100644 --- a/tests/test_symmetry_groups.py +++ b/tests/test_symmetry_groups.py @@ -80,19 +80,19 @@ def test_group_rotations(group_fixture): assert isinstance(rotations, Rotation) -def test_from_string_identity(): - result = SymmetryGroup.from_string("C1", dtype=np.float32) +def test_parser_identity(): + result = SymmetryGroup.parser("C1", dtype=np.float32) assert isinstance(result, IdentitySymmetryGroup) -def test_from_string_with_group(group_fixture): +def test_parser_with_group(group_fixture): """Test SymmetryGroup instance are parsed correctly.""" - result = SymmetryGroup.from_string(group_fixture, group_fixture.dtype) + result = SymmetryGroup.parser(group_fixture, group_fixture.dtype) assert result == group_fixture assert result.dtype == group_fixture.dtype -def test_from_string_dtype_casting(group_fixture, caplog): +def test_parser_dtype_casting(group_fixture, caplog): """Test that dtype gets re-cast and warns.""" dtype = np.float32 if group_fixture.dtype == np.float32: @@ -100,13 +100,13 @@ def test_from_string_dtype_casting(group_fixture, caplog): caplog.clear() msg = f"Recasting SymmetryGroup with dtype {dtype}." - _ = SymmetryGroup.from_string(group_fixture, dtype) + _ = SymmetryGroup.parser(group_fixture, dtype) assert msg in caplog.text -def test_from_string_error(): +def test_parser_error(): junk_symmetry = "P12" with pytest.raises( ValueError, match=f"Symmetry type {junk_symmetry[0]} not supported.*" ): - _ = SymmetryGroup.from_string(junk_symmetry, dtype=np.float32) + _ = SymmetryGroup.parser(junk_symmetry, dtype=np.float32) From 3d12612892683833d6eaf7239117410c66de631a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 22 Mar 2024 09:12:00 -0400 Subject: [PATCH 6/7] update docstring. --- src/aspire/image/image.py | 4 +--- src/aspire/volume/symmetry_groups.py | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 9a675b0aff..c9ef3939c0 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -516,9 +516,7 @@ def backproject(self, rot_matrices, symmetry_group=None): ), "Number of rotation matrices must match the number of images" # Get symmetry rotations from SymmetryGroup. - symmetry_rots = SymmetryGroup.parser( - symmetry_group, dtype=self.dtype - ).matrices + symmetry_rots = SymmetryGroup.parser(symmetry_group, dtype=self.dtype).matrices # Compute Fourier transform of images. im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self._data))) / (L**2) diff --git a/src/aspire/volume/symmetry_groups.py b/src/aspire/volume/symmetry_groups.py index 000a65dd69..e24047722e 100644 --- a/src/aspire/volume/symmetry_groups.py +++ b/src/aspire/volume/symmetry_groups.py @@ -48,10 +48,10 @@ def __str__(self): @staticmethod def parser(symmetry, dtype): """ - Takes a string, ie. 'C1', 'C7', 'D3', 'T', 'O', and returns a concrete - SymmetryGroup object. + Takes a SymmetryGroup instance or a string, ie. 'C1', 'C7', 'D3', 'T', 'O', and returns a concrete + SymmetryGroup object with the specified dtype. - :param symmetry: A string indicating the symmetry of a molecule. + :param symmetry: A string (or SymmetryGroup instance) indicating the symmetry of a molecule. :param dtype: dtype for rotation matrices. :return: Concrete SymmetryGroup object. """ From c0f4ba1e2abee88e280c95d66c7e91a4970a6c21 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 22 Mar 2024 13:07:36 -0400 Subject: [PATCH 7/7] parser --> parse. oops. --- src/aspire/image/image.py | 2 +- src/aspire/source/image.py | 4 ++-- src/aspire/volume/symmetry_groups.py | 2 +- src/aspire/volume/volume.py | 2 +- tests/test_symmetry_groups.py | 8 ++++---- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index c9ef3939c0..9d2911ad7a 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -516,7 +516,7 @@ def backproject(self, rot_matrices, symmetry_group=None): ), "Number of rotation matrices must match the number of images" # Get symmetry rotations from SymmetryGroup. - symmetry_rots = SymmetryGroup.parser(symmetry_group, dtype=self.dtype).matrices + symmetry_rots = SymmetryGroup.parse(symmetry_group, dtype=self.dtype).matrices # Compute Fourier transform of images. im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self._data))) / (L**2) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index dfa82ba06d..473585acb9 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -256,7 +256,7 @@ def symmetry_group(self, value): f"This source is no longer mutable. Try new_source = source.update(symmetry_group='{value}')." ) - self._symmetry_group = SymmetryGroup.parser(value, dtype=self.dtype) + self._symmetry_group = SymmetryGroup.parse(value, dtype=self.dtype) self.set_metadata(["_rlnSymmetryGroup"], str(self.symmetry_group)) def _populate_symmetry_group(self, symmetry_group): @@ -270,7 +270,7 @@ def _populate_symmetry_group(self, symmetry_group): f"Overriding metadata with supplied symmetry group {symmetry_group}" ) else: - symmetry_group = SymmetryGroup.parser( + symmetry_group = SymmetryGroup.parse( symmetry=self.get_metadata(["_rlnSymmetryGroup"])[0], dtype=self.dtype, ) diff --git a/src/aspire/volume/symmetry_groups.py b/src/aspire/volume/symmetry_groups.py index e24047722e..6916e27cad 100644 --- a/src/aspire/volume/symmetry_groups.py +++ b/src/aspire/volume/symmetry_groups.py @@ -46,7 +46,7 @@ def __str__(self): return f"{self.to_string}" @staticmethod - def parser(symmetry, dtype): + def parse(symmetry, dtype): """ Takes a SymmetryGroup instance or a string, ie. 'C1', 'C7', 'D3', 'T', 'O', and returns a concrete SymmetryGroup object with the specified dtype. diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 0082ba4c6b..b6c100db36 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -171,7 +171,7 @@ def _set_symmetry_group(self, value): :param value: A `SymmetryGroup` instance or string indicating symmetry, ie. "C5", "D7", "T", etc. """ - self._symmetry_group = SymmetryGroup.parser(value, dtype=self.dtype) + self._symmetry_group = SymmetryGroup.parse(value, dtype=self.dtype) def _symmetry_group_warning(self, stacklevel): """ diff --git a/tests/test_symmetry_groups.py b/tests/test_symmetry_groups.py index 23008d3b39..236f1c60de 100644 --- a/tests/test_symmetry_groups.py +++ b/tests/test_symmetry_groups.py @@ -81,13 +81,13 @@ def test_group_rotations(group_fixture): def test_parser_identity(): - result = SymmetryGroup.parser("C1", dtype=np.float32) + result = SymmetryGroup.parse("C1", dtype=np.float32) assert isinstance(result, IdentitySymmetryGroup) def test_parser_with_group(group_fixture): """Test SymmetryGroup instance are parsed correctly.""" - result = SymmetryGroup.parser(group_fixture, group_fixture.dtype) + result = SymmetryGroup.parse(group_fixture, group_fixture.dtype) assert result == group_fixture assert result.dtype == group_fixture.dtype @@ -100,7 +100,7 @@ def test_parser_dtype_casting(group_fixture, caplog): caplog.clear() msg = f"Recasting SymmetryGroup with dtype {dtype}." - _ = SymmetryGroup.parser(group_fixture, dtype) + _ = SymmetryGroup.parse(group_fixture, dtype) assert msg in caplog.text @@ -109,4 +109,4 @@ def test_parser_error(): with pytest.raises( ValueError, match=f"Symmetry type {junk_symmetry[0]} not supported.*" ): - _ = SymmetryGroup.parser(junk_symmetry, dtype=np.float32) + _ = SymmetryGroup.parse(junk_symmetry, dtype=np.float32)