From a35b7e882419f877eee8ff3485215dd82421158b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 8 Feb 2024 13:31:13 -0500 Subject: [PATCH 01/24] Add symmetry_group attribute to volume transformation methods. --- src/aspire/volume/volume.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index cc29710045..e358a59348 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -135,7 +135,9 @@ def astype(self, dtype, copy=True): Defaults to True. :return: Volume instance """ - return self.__class__(self.asnumpy().astype(dtype, copy=copy)) + return self.__class__( + self.asnumpy().astype(dtype, copy=copy), symmetry_group=self.symmetry_group + ) def _check_key_dims(self, key): if isinstance(key, tuple) and (len(key) > self._data.ndim): @@ -145,7 +147,7 @@ def _check_key_dims(self, key): def __getitem__(self, key): self._check_key_dims(key) - return self.__class__(self._data[key]) + return self.__class__(self._data[key], symmetry_group=self.symmetry_group) def __setitem__(self, key, value): self._check_key_dims(key) @@ -198,7 +200,10 @@ def stack_reshape(self, *args): f"Number of volumes {self.n_vols} cannot be reshaped to {shape}." ) - return self.__class__(self._data.reshape(*shape, *self._data.shape[-3:])) + return self.__class__( + self._data.reshape(*shape, *self._data.shape[-3:]), + symmetry_group=self.symmetry_group, + ) def __repr__(self): msg = ( @@ -359,7 +364,9 @@ def transpose(self): original_stack_shape = self.stack_shape v = self.stack_reshape(-1) vt = np.transpose(v._data, (0, -1, -2, -3)) - return self.__class__(vt).stack_reshape(original_stack_shape) + return self.__class__(vt, symmetry_group=self.symmetry_group).stack_reshape( + original_stack_shape + ) @property def T(self): @@ -400,7 +407,9 @@ def flip(self, axis=-3): f"Cannot flip axis {ax}: stack axis. Did you mean {ax-4}?" ) - return self.__class__(np.flip(self._data, axis)) + return self.__class__( + np.flip(self._data, axis), symmetry_group=self.symmetry_group + ) def downsample(self, ds_res, mask=None): """ @@ -427,7 +436,9 @@ def downsample(self, ds_res, mask=None): ds_res**3 / self.resolution**3 ) # returns a new Volume object - return self.__class__(np.real(out)).stack_reshape(original_stack_shape) + return self.__class__( + np.real(out), symmetry_group=self.symmetry_group + ).stack_reshape(original_stack_shape) def shift(self): raise NotImplementedError @@ -497,7 +508,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): np.real(fft.centered_ifftn(xp.asarray(vol_f), axes=(-3, -2, -1))) ) - return self.__class__(vol) + return self.__class__(vol, symmetry_group=self.symmetry_group) def denoise(self): raise NotImplementedError From 9130df25b174d6d4cb1b81630f87c7ef398a6c6d Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 8 Feb 2024 13:57:27 -0500 Subject: [PATCH 02/24] Add pass-through test --- tests/test_volume.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_volume.py b/tests/test_volume.py index 9d039c4c9f..72c33ceb5c 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -726,6 +726,28 @@ def test_symmetry_group_set_get(sym_group, sym_string): _ = Volume(data, symmetry_group=123, dtype=dtype) +def test_symmetry_group_pass_through(): + sym_group = "C5" + vol = Volume( + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")), + symmetry_group=sym_group, + ) + + # Chack symmetry_group pass-through for various transformations. + assert str(vol.astype(np.float64).symmetry_group) == sym_group # astype + assert str(vol[0].symmetry_group) == sym_group # getitem + assert str(vol.stack_reshape((1, 1)).symmetry_group) == sym_group # stack_reshape + assert str(vol.T.symmetry_group) == sym_group # transpose + assert str(vol.flip().symmetry_group) == sym_group # flip + assert ( + str(vol.downsample(vol.resolution // 2).symmetry_group) == sym_group + ) # downsample + assert ( + str(vol.rotate(Rotation.about_axis("x", np.pi, dtype=vol.dtype)).symmetry_group) + == sym_group + ) # rotate + + def test_volume_load_with_symmetry(): # Check we can load a Volume with symmetry_group. vol = Volume( From 1584178f646c22d68ea5e67550b352ff3f680747 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 9 Feb 2024 16:03:14 -0500 Subject: [PATCH 03/24] symmetry_group warning --- src/aspire/volume/volume.py | 36 ++++++++++++++++++++-------- tests/test_volume.py | 47 +++++++++++++++++++++++++++++++------ 2 files changed, 66 insertions(+), 17 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index e358a59348..f70000180f 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -1,4 +1,5 @@ import logging +import warnings import mrcfile import numpy as np @@ -178,6 +179,19 @@ def _set_symmetry_group(self, value): ) self._symmetry_group = value + def _symmetry_group_warning(self): + """ + Warn when a transformation has the potential to change the symmetry + of a volume or the alignment of symmetry axes. + """ + if str(self.symmetry_group) != "C1": + msg = ( + f"`symmetry_group` attribute is being set to `C1`. This transformation may" + f" effect the symmetry (or symmetric alignment) of the volume. To reset the" + f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`." + ) + warnings.warn(msg, UserWarning) + def stack_reshape(self, *args): """ Reshape the stack axis. @@ -357,16 +371,17 @@ def from_vec(cls, vec): def transpose(self): """ - Returns a new Volume instance with volume data axes tranposed. + Returns a new Volume instance with volume data axes transposed. :return: Volume instance. """ + self._symmetry_group_warning() + original_stack_shape = self.stack_shape - v = self.stack_reshape(-1) - vt = np.transpose(v._data, (0, -1, -2, -3)) - return self.__class__(vt, symmetry_group=self.symmetry_group).stack_reshape( - original_stack_shape - ) + v = self._data.reshape(-1, *self._data.shape[-3:]) + vt = np.transpose(v, (0, -1, -2, -3)) + vt = vt.reshape(*original_stack_shape, *self._data.shape[-3:]) + return self.__class__(vt) @property def T(self): @@ -396,6 +411,8 @@ def flip(self, axis=-3): :return: Volume instance. """ + self._symmetry_group_warning() + # Convert integer to tuple, so we can always loop. if isinstance(axis, int): axis = (axis,) @@ -407,9 +424,7 @@ def flip(self, axis=-3): f"Cannot flip axis {ax}: stack axis. Did you mean {ax-4}?" ) - return self.__class__( - np.flip(self._data, axis), symmetry_group=self.symmetry_group - ) + return self.__class__(np.flip(self._data, axis)) def downsample(self, ds_res, mask=None): """ @@ -455,6 +470,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): :return: `Volume` instance. """ + self._symmetry_group_warning() if self.stack_ndim > 1: raise NotImplementedError( "`rotation` is currently limited to 1D Volume stacks." @@ -508,7 +524,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): np.real(fft.centered_ifftn(xp.asarray(vol_f), axes=(-3, -2, -1))) ) - return self.__class__(vol, symmetry_group=self.symmetry_group) + return self.__class__(vol) def denoise(self): raise NotImplementedError diff --git a/tests/test_volume.py b/tests/test_volume.py index 72c33ceb5c..a489c48feb 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -1,6 +1,7 @@ import logging import os import tempfile +import warnings from itertools import product import numpy as np @@ -733,19 +734,51 @@ def test_symmetry_group_pass_through(): symmetry_group=sym_group, ) - # Chack symmetry_group pass-through for various transformations. + # Check symmetry_group pass-through for various transformations. assert str(vol.astype(np.float64).symmetry_group) == sym_group # astype assert str(vol[0].symmetry_group) == sym_group # getitem assert str(vol.stack_reshape((1, 1)).symmetry_group) == sym_group # stack_reshape - assert str(vol.T.symmetry_group) == sym_group # transpose - assert str(vol.flip().symmetry_group) == sym_group # flip assert ( str(vol.downsample(vol.resolution // 2).symmetry_group) == sym_group ) # downsample - assert ( - str(vol.rotate(Rotation.about_axis("x", np.pi, dtype=vol.dtype)).symmetry_group) - == sym_group - ) # rotate + + +def test_symmetry_group_reset_warning(): + sym_group = "C5" + vol = Volume( + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")), + symmetry_group=sym_group, + ) + + # Check we get warning on first transformation. + with pytest.warns( + UserWarning, match=r".*`symmetry_group` attribute is being set to `C1`.*" + ) as record: + vol_t = vol.T + vol_f = vol.flip() + vol_r = vol.rotate(Rotation.about_axis("x", np.pi, dtype=vol.dtype)) + assert len(record) == 3 + + # Check symmetry_group has been set to C1. + assert str(vol_t.symmetry_group) == "C1" + assert str(vol_f.symmetry_group) == "C1" + assert str(vol_r.symmetry_group) == "C1" + + # Check we get no warnings on second transformation. + with pytest.warns() as record: + _ = vol_t.T + _ = vol_f.flip() + _ = vol_r.rotate(Rotation.about_axis("x", np.pi, dtype=vol.dtype)) + + # Throw single test warning. + warnings.warn("test", Warning) + + # Should only have test warning + assert len(record) == 1 + assert str(record[0].message) == "test" + + # Check original volume has retained C5 symmetry. + assert str(vol.symmetry_group) == "C5" def test_volume_load_with_symmetry(): From 354b12d50bc718858d217252ff24214bd8f65fbb Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 12 Feb 2024 14:25:43 -0500 Subject: [PATCH 04/24] add method name to warning --- src/aspire/volume/volume.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index f70000180f..f4f6599b89 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -179,18 +179,20 @@ def _set_symmetry_group(self, value): ) self._symmetry_group = value - def _symmetry_group_warning(self): + def _symmetry_group_warning(self, method): """ Warn when a transformation has the potential to change the symmetry of a volume or the alignment of symmetry axes. + + :param method: Name of method (string). """ if str(self.symmetry_group) != "C1": msg = ( - f"`symmetry_group` attribute is being set to `C1`. This transformation may" - f" effect the symmetry (or symmetric alignment) of the volume. To reset the" + f"`symmetry_group` attribute is being set to `C1`. {self.__class__.__name__}.{method}()" + f" may effect the symmetry (or symmetric alignment) of the volume. To reset the" f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`." ) - warnings.warn(msg, UserWarning) + warnings.warn(msg, UserWarning, stacklevel=3) def stack_reshape(self, *args): """ @@ -230,6 +232,7 @@ def __len__(self): return self.n_vols def __add__(self, other): + self._symmetry_group_warning("add") if isinstance(other, Volume): res = self.__class__(self._data + other.asnumpy()) else: @@ -375,7 +378,7 @@ def transpose(self): :return: Volume instance. """ - self._symmetry_group_warning() + self._symmetry_group_warning("transpose") original_stack_shape = self.stack_shape v = self._data.reshape(-1, *self._data.shape[-3:]) @@ -411,7 +414,7 @@ def flip(self, axis=-3): :return: Volume instance. """ - self._symmetry_group_warning() + self._symmetry_group_warning("flip") # Convert integer to tuple, so we can always loop. if isinstance(axis, int): @@ -470,7 +473,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): :return: `Volume` instance. """ - self._symmetry_group_warning() + self._symmetry_group_warning("rotate") if self.stack_ndim > 1: raise NotImplementedError( "`rotation` is currently limited to 1D Volume stacks." From 82b233507353c3b7bcc345ef585a80928b948bf8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 12 Feb 2024 15:12:52 -0500 Subject: [PATCH 05/24] handle algebraic operations --- src/aspire/volume/volume.py | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index f4f6599b89..9ef7d47fb5 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -232,10 +232,16 @@ def __len__(self): return self.n_vols def __add__(self, other): - self._symmetry_group_warning("add") if isinstance(other, Volume): - res = self.__class__(self._data + other.asnumpy()) + if str(self.symmetry_group) == str(other.symmetry_group): + res = self.__class__( + self._data + other.asnumpy(), symmetry_group=self.symmetry_group + ) + else: + self._symmetry_group_warning("add") + res = self.__class__(self._data + other.asnumpy()) else: + self._symmetry_group_warning("add") res = self.__class__(self._data + other) return res @@ -245,8 +251,15 @@ def __radd__(self, otherL): def __sub__(self, other): if isinstance(other, Volume): - res = self.__class__(self._data - other.asnumpy()) + if str(self.symmetry_group) == str(other.symmetry_group): + res = self.__class__( + self._data - other.asnumpy(), symmetry_group=self.symmetry_group + ) + else: + self._symmetry_group_warning("subtract") + res = self.__class__(self._data - other.asnumpy()) else: + self._symmetry_group_warning("subtract") res = self.__class__(self._data - other) return res @@ -256,8 +269,15 @@ def __rsub__(self, otherL): def __mul__(self, other): if isinstance(other, Volume): - res = self.__class__(self._data * other.asnumpy()) + if str(self.symmetry_group) == str(other.symmetry_group): + res = self.__class__( + self._data * other.asnumpy(), symmetry_group=self.symmetry_group + ) + else: + self._symmetry_group_warning("multiply") + res = self.__class__(self._data * other.asnumpy()) else: + self._symmetry_group_warning("multiply") res = self.__class__(self._data * other) return res @@ -270,8 +290,15 @@ def __truediv__(self, other): Scalar division, follows numpy semantics. """ if isinstance(other, Volume): - res = self.__class__(self._data / other.asnumpy()) + if str(self.symmetry_group) == str(other.symmetry_group): + res = self.__class__( + self._data / other.asnumpy(), symmetry_group=self.symmetry_group + ) + else: + self._symmetry_group_warning("divide") + res = self.__class__(self._data / other.asnumpy()) else: + self._symmetry_group_warning("divide") res = self.__class__(self._data / other) return res From 5ac4e9ba3d1e60021a88f362a3da4376ad9b1e83 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 12 Feb 2024 16:06:41 -0500 Subject: [PATCH 06/24] drop method name from warning --- src/aspire/volume/volume.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 9ef7d47fb5..1f2d4d6295 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -179,7 +179,7 @@ def _set_symmetry_group(self, value): ) self._symmetry_group = value - def _symmetry_group_warning(self, method): + def _symmetry_group_warning(self): """ Warn when a transformation has the potential to change the symmetry of a volume or the alignment of symmetry axes. @@ -188,8 +188,8 @@ def _symmetry_group_warning(self, method): """ if str(self.symmetry_group) != "C1": msg = ( - f"`symmetry_group` attribute is being set to `C1`. {self.__class__.__name__}.{method}()" - f" may effect the symmetry (or symmetric alignment) of the volume. To reset the" + f"`symmetry_group` attribute is being set to `C1`. This transformation may" + f" effect the symmetry (or symmetric alignment) of the volume. To reset the" f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`." ) warnings.warn(msg, UserWarning, stacklevel=3) @@ -238,10 +238,10 @@ def __add__(self, other): self._data + other.asnumpy(), symmetry_group=self.symmetry_group ) else: - self._symmetry_group_warning("add") + self._symmetry_group_warning() res = self.__class__(self._data + other.asnumpy()) else: - self._symmetry_group_warning("add") + self._symmetry_group_warning() res = self.__class__(self._data + other) return res @@ -256,10 +256,10 @@ def __sub__(self, other): self._data - other.asnumpy(), symmetry_group=self.symmetry_group ) else: - self._symmetry_group_warning("subtract") + self._symmetry_group_warning() res = self.__class__(self._data - other.asnumpy()) else: - self._symmetry_group_warning("subtract") + self._symmetry_group_warning() res = self.__class__(self._data - other) return res @@ -274,10 +274,10 @@ def __mul__(self, other): self._data * other.asnumpy(), symmetry_group=self.symmetry_group ) else: - self._symmetry_group_warning("multiply") + self._symmetry_group_warning() res = self.__class__(self._data * other.asnumpy()) else: - self._symmetry_group_warning("multiply") + self._symmetry_group_warning() res = self.__class__(self._data * other) return res @@ -295,10 +295,10 @@ def __truediv__(self, other): self._data / other.asnumpy(), symmetry_group=self.symmetry_group ) else: - self._symmetry_group_warning("divide") + self._symmetry_group_warning() res = self.__class__(self._data / other.asnumpy()) else: - self._symmetry_group_warning("divide") + self._symmetry_group_warning() res = self.__class__(self._data / other) return res From 5d5b399fba41233affe204db5d4d2b062f6d0562 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Feb 2024 14:59:33 -0500 Subject: [PATCH 07/24] Check for incompatible symmetry. --- src/aspire/volume/volume.py | 71 +++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 1f2d4d6295..1e42ccc57a 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -183,8 +183,6 @@ def _symmetry_group_warning(self): """ Warn when a transformation has the potential to change the symmetry of a volume or the alignment of symmetry axes. - - :param method: Name of method (string). """ if str(self.symmetry_group) != "C1": msg = ( @@ -194,6 +192,27 @@ def _symmetry_group_warning(self): ) warnings.warn(msg, UserWarning, stacklevel=3) + def _result_symmetry(self, other): + """ + Check if `other` will alter symmetry of `self` and return appropriate symmetry_group. + """ + result_symmetry = self.symmetry_group + + # Warn and set to Identity if incompatible symmetries or non-saclar. + if isinstance(other, Volume) and str(self.symmetry_group) != str( + other.symmetry_group + ): + self._symmetry_group_warning() + result_symmetry = IdentitySymmetryGroup(dtype=self.dtype) + + elif not isinstance(other, Volume) and hasattr( + other, "__len__" + ): # ie. is not a scalar + self._symmetry_group_warning() + result_symmetry = IdentitySymmetryGroup(dtype=self.dtype) + + return result_symmetry + def stack_reshape(self, *args): """ Reshape the stack axis. @@ -232,17 +251,11 @@ def __len__(self): return self.n_vols def __add__(self, other): + symmetry = self._result_symmetry(other) if isinstance(other, Volume): - if str(self.symmetry_group) == str(other.symmetry_group): - res = self.__class__( - self._data + other.asnumpy(), symmetry_group=self.symmetry_group - ) - else: - self._symmetry_group_warning() - res = self.__class__(self._data + other.asnumpy()) + res = self.__class__(self._data + other.asnumpy(), symmetry_group=symmetry) else: - self._symmetry_group_warning() - res = self.__class__(self._data + other) + res = self.__class__(self._data + other, symmetry_group=symmetry) return res @@ -250,17 +263,11 @@ def __radd__(self, otherL): return self + otherL def __sub__(self, other): + symmetry = self._result_symmetry(other) if isinstance(other, Volume): - if str(self.symmetry_group) == str(other.symmetry_group): - res = self.__class__( - self._data - other.asnumpy(), symmetry_group=self.symmetry_group - ) - else: - self._symmetry_group_warning() - res = self.__class__(self._data - other.asnumpy()) + res = self.__class__(self._data - other.asnumpy(), symmetry_group=symmetry) else: - self._symmetry_group_warning() - res = self.__class__(self._data - other) + res = self.__class__(self._data - other, symmetry_group=symmetry) return res @@ -268,17 +275,11 @@ def __rsub__(self, otherL): return self.__class__(otherL - self._data) def __mul__(self, other): + symmetry = self._result_symmetry(other) if isinstance(other, Volume): - if str(self.symmetry_group) == str(other.symmetry_group): - res = self.__class__( - self._data * other.asnumpy(), symmetry_group=self.symmetry_group - ) - else: - self._symmetry_group_warning() - res = self.__class__(self._data * other.asnumpy()) + res = self.__class__(self._data * other.asnumpy(), symmetry_group=symmetry) else: - self._symmetry_group_warning() - res = self.__class__(self._data * other) + res = self.__class__(self._data * other, symmetry_group=symmetry) return res @@ -289,17 +290,11 @@ def __truediv__(self, other): """ Scalar division, follows numpy semantics. """ + symmetry = self._result_symmetry(other) if isinstance(other, Volume): - if str(self.symmetry_group) == str(other.symmetry_group): - res = self.__class__( - self._data / other.asnumpy(), symmetry_group=self.symmetry_group - ) - else: - self._symmetry_group_warning() - res = self.__class__(self._data / other.asnumpy()) + res = self.__class__(self._data / other.asnumpy(), symmetry_group=symmetry) else: - self._symmetry_group_warning() - res = self.__class__(self._data / other) + res = self.__class__(self._data / other, symmetry_group=symmetry) return res From dd52df6ff9bd950d61ee92a0400a825b30505535 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Feb 2024 15:44:32 -0500 Subject: [PATCH 08/24] Use _result_symmetry to handle all warnings. --- src/aspire/volume/volume.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 1e42ccc57a..d27971fec6 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -190,15 +190,18 @@ def _symmetry_group_warning(self): f" effect the symmetry (or symmetric alignment) of the volume. To reset the" f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`." ) - warnings.warn(msg, UserWarning, stacklevel=3) + warnings.warn(msg, UserWarning, stacklevel=4) - def _result_symmetry(self, other): + def _result_symmetry(self, other=None): """ Check if `other` will alter symmetry of `self` and return appropriate symmetry_group. """ - result_symmetry = self.symmetry_group + if other is None: + self._symmetry_group_warning() + return IdentitySymmetryGroup(dtype=self.dtype) - # Warn and set to Identity if incompatible symmetries or non-saclar. + result_symmetry = self.symmetry_group + # Warn and set to IdentitySymmetryGroup if incompatible symmetries or non-scalar. if isinstance(other, Volume) and str(self.symmetry_group) != str( other.symmetry_group ): @@ -400,13 +403,13 @@ def transpose(self): :return: Volume instance. """ - self._symmetry_group_warning("transpose") + symmetry = self._result_symmetry() original_stack_shape = self.stack_shape v = self._data.reshape(-1, *self._data.shape[-3:]) vt = np.transpose(v, (0, -1, -2, -3)) vt = vt.reshape(*original_stack_shape, *self._data.shape[-3:]) - return self.__class__(vt) + return self.__class__(vt, symmetry_group=symmetry) @property def T(self): @@ -436,7 +439,7 @@ def flip(self, axis=-3): :return: Volume instance. """ - self._symmetry_group_warning("flip") + symmetry = self._result_symmetry() # Convert integer to tuple, so we can always loop. if isinstance(axis, int): @@ -449,7 +452,7 @@ def flip(self, axis=-3): f"Cannot flip axis {ax}: stack axis. Did you mean {ax-4}?" ) - return self.__class__(np.flip(self._data, axis)) + return self.__class__(np.flip(self._data, axis), symmetry_group=symmetry) def downsample(self, ds_res, mask=None): """ @@ -495,7 +498,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): :return: `Volume` instance. """ - self._symmetry_group_warning("rotate") + symmetry = self._result_symmetry() if self.stack_ndim > 1: raise NotImplementedError( "`rotation` is currently limited to 1D Volume stacks." @@ -549,7 +552,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): np.real(fft.centered_ifftn(xp.asarray(vol_f), axes=(-3, -2, -1))) ) - return self.__class__(vol) + return self.__class__(vol, symmetry_group=symmetry) def denoise(self): raise NotImplementedError From b39326e39e6ead238d3ea1be5c4db59cbddb4c41 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 14 Feb 2024 15:50:00 -0500 Subject: [PATCH 09/24] missing stacklevel in warn. --- tests/test_volume.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index a489c48feb..576ca44754 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -771,7 +771,7 @@ def test_symmetry_group_reset_warning(): _ = vol_r.rotate(Rotation.about_axis("x", np.pi, dtype=vol.dtype)) # Throw single test warning. - warnings.warn("test", Warning) + warnings.warn("test", Warning, stacklevel=2) # Should only have test warning assert len(record) == 1 From e6b2e1d06afdbfcd4c31d1d9ed7e1d4368eb003b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 15 Feb 2024 15:23:39 -0500 Subject: [PATCH 10/24] add tests for algebraic ops warnings. --- tests/test_volume.py | 89 +++++++++++++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 30 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 576ca44754..176a10c3e5 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -13,6 +13,7 @@ from aspire.utils import Rotation, anorm, grid_2d, powerset, utest_tolerance from aspire.volume import ( AsymmetricVolume, + CnSymmetricVolume, CnSymmetryGroup, SymmetryGroup, TSymmetryGroup, @@ -88,6 +89,13 @@ def asym_vols(res, dtype): return vols +@pytest.fixture(scope="module") +def symmetric_vols(res, dtype): + vol_c3 = CnSymmetricVolume(L=res, C=1, order=3, dtype=dtype, seed=0).generate() + vol_c4 = CnSymmetricVolume(L=res, C=1, order=4, dtype=dtype, seed=0).generate() + return vol_c3, vol_c4 + + @pytest.fixture(scope="module") def vols_hot_cold(res, dtype): L = res @@ -727,36 +735,37 @@ def test_symmetry_group_set_get(sym_group, sym_string): _ = Volume(data, symmetry_group=123, dtype=dtype) -def test_symmetry_group_pass_through(): - sym_group = "C5" - vol = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")), - symmetry_group=sym_group, - ) +def test_symmetry_group_pass_through(symmetric_vols): + vol_c3, _ = symmetric_vols + sym_group = str(vol_c3.symmetry_group) + assert sym_group == "C3" # Check symmetry_group pass-through for various transformations. - assert str(vol.astype(np.float64).symmetry_group) == sym_group # astype - assert str(vol[0].symmetry_group) == sym_group # getitem - assert str(vol.stack_reshape((1, 1)).symmetry_group) == sym_group # stack_reshape + assert str(vol_c3.astype(np.float64).symmetry_group) == sym_group # astype + assert str(vol_c3[0].symmetry_group) == sym_group # getitem assert ( - str(vol.downsample(vol.resolution // 2).symmetry_group) == sym_group + str(vol_c3.stack_reshape((1, 1)).symmetry_group) == sym_group + ) # stack_reshape + assert ( + str(vol_c3.downsample(vol_c3.resolution // 2).symmetry_group) == sym_group ) # downsample -def test_symmetry_group_reset_warning(): - sym_group = "C5" - vol = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")), - symmetry_group=sym_group, - ) +def test_transformation_symmetry_warnings(symmetric_vols): + """ + A warning should be emitted (once) for transpose, flip, rotate, add, sub, mult, div. + """ + vol_c3, _ = symmetric_vols + sym_group = str(vol_c3.symmetry_group) + assert sym_group == "C3" # Check we get warning on first transformation. with pytest.warns( UserWarning, match=r".*`symmetry_group` attribute is being set to `C1`.*" ) as record: - vol_t = vol.T - vol_f = vol.flip() - vol_r = vol.rotate(Rotation.about_axis("x", np.pi, dtype=vol.dtype)) + vol_t = vol_c3.T + vol_f = vol_c3.flip() + vol_r = vol_c3.rotate(Rotation.about_axis("x", np.pi, dtype=vol_c3.dtype)) assert len(record) == 3 # Check symmetry_group has been set to C1. @@ -764,21 +773,41 @@ def test_symmetry_group_reset_warning(): assert str(vol_f.symmetry_group) == "C1" assert str(vol_r.symmetry_group) == "C1" - # Check we get no warnings on second transformation. - with pytest.warns() as record: - _ = vol_t.T - _ = vol_f.flip() - _ = vol_r.rotate(Rotation.about_axis("x", np.pi, dtype=vol.dtype)) - - # Throw single test warning. - warnings.warn("test", Warning, stacklevel=2) - # Should only have test warning assert len(record) == 1 assert str(record[0].message) == "test" - # Check original volume has retained C5 symmetry. - assert str(vol.symmetry_group) == "C5" + # Check original volume has retained C3 symmetry. + assert str(vol_c3.symmetry_group) == "C3" + + +def test_aglebraic_ops_symmetry_warnings(symmetric_vols): + vol_c3, vol_c4 = symmetric_vols + + # Compatible symmetry should retain symmetry_group. + assert (vol_c3 + vol_c3).symmetry_group == vol_c3.symmetry_group + assert (vol_c3 - vol_c3).symmetry_group == vol_c3.symmetry_group + assert (vol_c3 * vol_c3).symmetry_group == vol_c3.symmetry_group + assert ( + vol_c3 / (vol_c3 + 1) + ).symmetry_group == vol_c3.symmetry_group # plus 1 to avoid division by 0. + + # Incompatible symmetry should warn and set symmetry_group to C1. + with pytest.warns( + UserWarning, match=r".*`symmetry_group` attribute is being set to `C1`.*" + ) as record: + vols_sum = vol_c3 + vol_c4 + vol_array_diff = vol_c3 - vol_c4.asnumpy() + vols_mult = vol_c3 * vol_c4 + vol_array_div = vol_c3 / (vol_c4.asnumpy() + 1) + + assert str(vols_sum.symmetry_group) == "C1" + assert str(vol_array_diff.symmetry_group) == "C1" + assert str(vols_mult.symmetry_group) == "C1" + assert str(vol_array_div.symmetry_group) == "C1" + + # Should have 4 warnings on record. + assert len(record) == 4 def test_volume_load_with_symmetry(): From d317fbbbef1416f044261144cca2466de1270315 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 15 Feb 2024 15:29:01 -0500 Subject: [PATCH 11/24] remove unused import --- tests/test_volume.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 176a10c3e5..4536dd6d06 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -1,7 +1,6 @@ import logging import os import tempfile -import warnings from itertools import product import numpy as np From 5a357aa77374003ea61e9cf3ccef3e45307459a7 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 15 Feb 2024 15:58:14 -0500 Subject: [PATCH 12/24] remove deprecated line --- tests/test_volume.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 4536dd6d06..ed898288a6 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -772,10 +772,6 @@ def test_transformation_symmetry_warnings(symmetric_vols): assert str(vol_f.symmetry_group) == "C1" assert str(vol_r.symmetry_group) == "C1" - # Should only have test warning - assert len(record) == 1 - assert str(record[0].message) == "test" - # Check original volume has retained C3 symmetry. assert str(vol_c3.symmetry_group) == "C3" From 2a8bff0cc46023e9fac0ec214fd199d3173a1773 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 16 Feb 2024 09:51:26 -0500 Subject: [PATCH 13/24] Don't check for C1. --- src/aspire/volume/volume.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index d27971fec6..bd71f7a8fa 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -181,27 +181,35 @@ def _set_symmetry_group(self, value): def _symmetry_group_warning(self): """ - Warn when a transformation has the potential to change the symmetry + Warn when an operation has the potential to change the symmetry of a volume or the alignment of symmetry axes. """ - if str(self.symmetry_group) != "C1": - msg = ( - f"`symmetry_group` attribute is being set to `C1`. This transformation may" - f" effect the symmetry (or symmetric alignment) of the volume. To reset the" - f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`." - ) - warnings.warn(msg, UserWarning, stacklevel=4) + msg = ( + f"`symmetry_group` attribute is being set to `C1`. This operation may" + f" effect the symmetry (or symmetric alignment) of the volume. To reset the" + f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`." + ) + warnings.warn(msg, UserWarning, stacklevel=4) def _result_symmetry(self, other=None): """ - Check if `other` will alter symmetry of `self` and return appropriate symmetry_group. + Check if `other` will alter symmetry of `self` and return resulting symmetry_group. + + :return: SymmetryGroup instance. """ + # No need to check for C1. + if str(self.symmetry_group) == "C1": + return self.symmetry_group + + # Handle case of transformations of `self`. if other is None: self._symmetry_group_warning() return IdentitySymmetryGroup(dtype=self.dtype) + # Handle binary operations of `self` and `other`. result_symmetry = self.symmetry_group - # Warn and set to IdentitySymmetryGroup if incompatible symmetries or non-scalar. + + # Warn and set to IdentitySymmetryGroup if `other` is symmetrically incompatible. if isinstance(other, Volume) and str(self.symmetry_group) != str( other.symmetry_group ): From 99d89ea06514a04b0f226138567ad7192b0360f0 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 16 Feb 2024 10:06:20 -0500 Subject: [PATCH 14/24] supress warnings in test. --- tests/test_synthetic_volume.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_synthetic_volume.py b/tests/test_synthetic_volume.py index 57fdc25116..8860b3532d 100644 --- a/tests/test_synthetic_volume.py +++ b/tests/test_synthetic_volume.py @@ -119,6 +119,8 @@ def test_compact_support(vol_fixture): assert (vol.asnumpy()[0][inside] > 0).all() +# Supress expected warnings due to rotation of symmetric volume. +@pytest.mark.filterwarnings("ignore:`symmetry_group` attribute is being set to `C1`") def test_volume_symmetry(vol_fixture): """Test that volumes have intended symmetry.""" vol = vol_fixture.generate() From d0f4112f1b05eb04b4c4c8570b3217e3fc1c64a1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 16 Feb 2024 11:24:43 -0500 Subject: [PATCH 15/24] refactor result_symmetry. --- src/aspire/volume/volume.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index bd71f7a8fa..90dc91722f 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -198,27 +198,20 @@ def _result_symmetry(self, other=None): :return: SymmetryGroup instance. """ # No need to check for C1. - if str(self.symmetry_group) == "C1": + if isinstance(self.symmetry_group, IdentitySymmetryGroup): return self.symmetry_group - # Handle case of transformations of `self`. - if other is None: - self._symmetry_group_warning() - return IdentitySymmetryGroup(dtype=self.dtype) - - # Handle binary operations of `self` and `other`. + # Warn and set to IdentitySymmetryGroup if `other` is symmetrically incompatible. result_symmetry = self.symmetry_group - # Warn and set to IdentitySymmetryGroup if `other` is symmetrically incompatible. - if isinstance(other, Volume) and str(self.symmetry_group) != str( + # Conditions of incompatibility. + self_transformation = other is None + incompat_syms = isinstance(other, Volume) and str(self.symmetry_group) != str( other.symmetry_group - ): - self._symmetry_group_warning() - result_symmetry = IdentitySymmetryGroup(dtype=self.dtype) + ) + arbitrary_array = not isinstance(other, Volume) and hasattr(other, "__len__") - elif not isinstance(other, Volume) and hasattr( - other, "__len__" - ): # ie. is not a scalar + if any([self_transformation, incompat_syms, arbitrary_array]): self._symmetry_group_warning() result_symmetry = IdentitySymmetryGroup(dtype=self.dtype) From 1ec33f371500e20c604421954388ef733d7493fb Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 16 Feb 2024 12:03:48 -0500 Subject: [PATCH 16/24] Override SymmetryGroup equivalence. --- src/aspire/volume/symmetry_groups.py | 5 +++++ tests/test_symmetry_groups.py | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/src/aspire/volume/symmetry_groups.py b/src/aspire/volume/symmetry_groups.py index 95b1954af9..171b77605e 100644 --- a/src/aspire/volume/symmetry_groups.py +++ b/src/aspire/volume/symmetry_groups.py @@ -26,6 +26,11 @@ def generate_rotations(self): Method for generating a Rotation object for the symmetry group. """ + def __eq__(self, other): + if isinstance(other, self.__class__): + return str(self) == str(other) + return False + @property def matrices(self): return self.rotations.matrices diff --git a/tests/test_symmetry_groups.py b/tests/test_symmetry_groups.py index d55d37c566..34f64543ab 100644 --- a/tests/test_symmetry_groups.py +++ b/tests/test_symmetry_groups.py @@ -66,6 +66,14 @@ def test_group_str(group_fixture): logger.debug(f"String for {group_fixture}: {sym_string}.") +def test_group_equivalence(group_fixture): + C2_symmetry_group = CnSymmetryGroup(order=2, dtype=group_fixture.dtype) + if str(group_fixture) == "C2": + assert C2_symmetry_group == group_fixture + else: + assert C2_symmetry_group != group_fixture + + def test_group_rotations(group_fixture): rotations = group_fixture.rotations assert isinstance(rotations, Rotation) From 2f94889e7656e3d22db4f73453fa3f671e16de66 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 16 Feb 2024 12:23:23 -0500 Subject: [PATCH 17/24] Use symmetry_group eq. --- src/aspire/volume/volume.py | 4 ++-- src/aspire/volume/volume_synthesis.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 90dc91722f..ee7a15d093 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -206,8 +206,8 @@ def _result_symmetry(self, other=None): # Conditions of incompatibility. self_transformation = other is None - incompat_syms = isinstance(other, Volume) and str(self.symmetry_group) != str( - other.symmetry_group + incompat_syms = ( + isinstance(other, Volume) and self.symmetry_group != other.symmetry_group ) arbitrary_array = not isinstance(other, Volume) and hasattr(other, "__len__") diff --git a/src/aspire/volume/volume_synthesis.py b/src/aspire/volume/volume_synthesis.py index c10c466081..b9514df5ea 100644 --- a/src/aspire/volume/volume_synthesis.py +++ b/src/aspire/volume/volume_synthesis.py @@ -8,6 +8,7 @@ from aspire.volume import ( CnSymmetryGroup, DnSymmetryGroup, + IdentitySymmetryGroup, OSymmetryGroup, TSymmetryGroup, Volume, @@ -247,6 +248,9 @@ def _check_order(self): f"An {self.__class__.__name__} must have order=1. Provided order was {self.order}" ) + def _set_symmetry_group(self): + self._symmetry_group = IdentitySymmetryGroup(dtype=self.dtype) + def _symmetrize_gaussians(self, Q, D, mu): return Q, D, mu From aa3177b6e99cfbee550c0abfa8f3862ad8fbc893 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 19 Feb 2024 10:02:46 -0500 Subject: [PATCH 18/24] ensure no warnings emitted for compatible symmetries. --- tests/test_volume.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index ed898288a6..63972c2d72 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -1,6 +1,7 @@ import logging import os import tempfile +import warnings from itertools import product import numpy as np @@ -779,13 +780,15 @@ def test_transformation_symmetry_warnings(symmetric_vols): def test_aglebraic_ops_symmetry_warnings(symmetric_vols): vol_c3, vol_c4 = symmetric_vols - # Compatible symmetry should retain symmetry_group. - assert (vol_c3 + vol_c3).symmetry_group == vol_c3.symmetry_group - assert (vol_c3 - vol_c3).symmetry_group == vol_c3.symmetry_group - assert (vol_c3 * vol_c3).symmetry_group == vol_c3.symmetry_group - assert ( - vol_c3 / (vol_c3 + 1) - ).symmetry_group == vol_c3.symmetry_group # plus 1 to avoid division by 0. + # Compatible symmetry should retain symmetry_group and emit no warning. + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert (vol_c3 + vol_c3).symmetry_group == vol_c3.symmetry_group + assert (vol_c3 - vol_c3).symmetry_group == vol_c3.symmetry_group + assert (vol_c3 * vol_c3).symmetry_group == vol_c3.symmetry_group + assert ( + vol_c3 / (vol_c3 + 1) + ).symmetry_group == vol_c3.symmetry_group # plus 1 to avoid division by 0. # Incompatible symmetry should warn and set symmetry_group to C1. with pytest.warns( From d80121c2c9b0bd735f3384699eaf1337023cb58a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 19 Feb 2024 10:36:29 -0500 Subject: [PATCH 19/24] use getattr in place of hasattr. --- src/aspire/volume/volume.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index ee7a15d093..0660bf2d03 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -209,7 +209,7 @@ def _result_symmetry(self, other=None): incompat_syms = ( isinstance(other, Volume) and self.symmetry_group != other.symmetry_group ) - arbitrary_array = not isinstance(other, Volume) and hasattr(other, "__len__") + arbitrary_array = not isinstance(other, Volume) and getattr(other, size, 1) > 1 if any([self_transformation, incompat_syms, arbitrary_array]): self._symmetry_group_warning() From 2799bd002e96af8cef6ebad1841fb2303f6ccb30 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 19 Feb 2024 11:51:34 -0500 Subject: [PATCH 20/24] Fix test docstrings. Fix warning stacklevel for T and transpose(). --- src/aspire/volume/volume.py | 12 +++++++----- tests/test_volume.py | 7 +++++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 0660bf2d03..2b6b1c392e 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -209,7 +209,7 @@ def _result_symmetry(self, other=None): incompat_syms = ( isinstance(other, Volume) and self.symmetry_group != other.symmetry_group ) - arbitrary_array = not isinstance(other, Volume) and getattr(other, size, 1) > 1 + arbitrary_array = not isinstance(other, Volume) and getattr(other, 'size', 1) > 1 if any([self_transformation, incompat_syms, arbitrary_array]): self._symmetry_group_warning() @@ -398,13 +398,15 @@ def from_vec(cls, vec): return cls(data) - def transpose(self): + def transpose(self, symmetry=None): """ Returns a new Volume instance with volume data axes transposed. :return: Volume instance. """ - symmetry = self._result_symmetry() + # Ensures warning stacklevel is the same for `vol.T` and `vol.transpose()`. + if symmetry is None: + symmetry = self._result_symmetry() original_stack_shape = self.stack_shape v = self._data.reshape(-1, *self._data.shape[-3:]) @@ -419,8 +421,8 @@ def T(self): :return: Volume instance. """ - - return self.transpose() + symmetry = self._result_symmetry() + return self.transpose(symmetry=symmetry) def flatten(self): """ diff --git a/tests/test_volume.py b/tests/test_volume.py index 63972c2d72..b058dcc162 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -753,13 +753,13 @@ def test_symmetry_group_pass_through(symmetric_vols): def test_transformation_symmetry_warnings(symmetric_vols): """ - A warning should be emitted (once) for transpose, flip, rotate, add, sub, mult, div. + A warning should be emitted for transpose, flip, and rotate. """ vol_c3, _ = symmetric_vols sym_group = str(vol_c3.symmetry_group) assert sym_group == "C3" - # Check we get warning on first transformation. + # Check we get warning for each transformation. with pytest.warns( UserWarning, match=r".*`symmetry_group` attribute is being set to `C1`.*" ) as record: @@ -778,6 +778,9 @@ def test_transformation_symmetry_warnings(symmetric_vols): def test_aglebraic_ops_symmetry_warnings(symmetric_vols): + """ + A warning should be emitted for add, sub, mult, and div. + """ vol_c3, vol_c4 = symmetric_vols # Compatible symmetry should retain symmetry_group and emit no warning. From 34eb26d45832f10501737ecaf1ba22d5d4bbd659 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 19 Feb 2024 14:23:24 -0500 Subject: [PATCH 21/24] black --- src/aspire/volume/volume.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 2b6b1c392e..8f28de9a7b 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -209,7 +209,9 @@ def _result_symmetry(self, other=None): incompat_syms = ( isinstance(other, Volume) and self.symmetry_group != other.symmetry_group ) - arbitrary_array = not isinstance(other, Volume) and getattr(other, 'size', 1) > 1 + arbitrary_array = ( + not isinstance(other, Volume) and getattr(other, "size", 1) > 1 + ) if any([self_transformation, incompat_syms, arbitrary_array]): self._symmetry_group_warning() From 93ff85f1310900ebab8423ff339aa3d8c1151c64 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 19 Feb 2024 15:52:01 -0500 Subject: [PATCH 22/24] Test we only warn once per line --- tests/test_volume.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_volume.py b/tests/test_volume.py index b058dcc162..a0aaa8c4b0 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -810,6 +810,12 @@ def test_aglebraic_ops_symmetry_warnings(symmetric_vols): # Should have 4 warnings on record. assert len(record) == 4 + # Check that warning occurs only once per line. + with warnings.catch_warnings(record=True) as record: + for _ in range(5): + vol_c3 + vol_c4 + assert len(record) == 1 + def test_volume_load_with_symmetry(): # Check we can load a Volume with symmetry_group. From 6d1a17c254faac0188693dce579c707d0f3c6ba6 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 6 Mar 2024 12:22:38 -0500 Subject: [PATCH 23/24] Handle transpose warning stacklevel with private flag. --- src/aspire/volume/volume.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 8f28de9a7b..e60405c895 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -111,6 +111,9 @@ def __init__(self, data, dtype=None, symmetry_group=None): # Set symmetry_group. If None, default to 'C1'. self._set_symmetry_group(symmetry_group) + # Flag to ensure T and transpose have the same symmetry_group warning stacklevel. + self._called_from_T = False + # Numpy interop # https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol self.__array_interface__ = self._data.__array_interface__ @@ -179,22 +182,27 @@ def _set_symmetry_group(self, value): ) self._symmetry_group = value - def _symmetry_group_warning(self): + def _symmetry_group_warning(self, stacklevel): """ Warn when an operation has the potential to change the symmetry of a volume or the alignment of symmetry axes. + + :param stacklevel: Warning stacklevel. """ msg = ( f"`symmetry_group` attribute is being set to `C1`. This operation may" f" effect the symmetry (or symmetric alignment) of the volume. To reset the" f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`." ) - warnings.warn(msg, UserWarning, stacklevel=4) + warnings.warn(msg, UserWarning, stacklevel=stacklevel) - def _result_symmetry(self, other=None): + def _result_symmetry(self, other=None, stacklevel=4): """ Check if `other` will alter symmetry of `self` and return resulting symmetry_group. + :param other: Other operand to self. Default is None for self transformations. + :param stacklevel: Warning stacklevel. Default of 4 indicates line of code where + operation was performed. :return: SymmetryGroup instance. """ # No need to check for C1. @@ -214,7 +222,7 @@ def _result_symmetry(self, other=None): ) if any([self_transformation, incompat_syms, arbitrary_array]): - self._symmetry_group_warning() + self._symmetry_group_warning(stacklevel=stacklevel) result_symmetry = IdentitySymmetryGroup(dtype=self.dtype) return result_symmetry @@ -400,15 +408,18 @@ def from_vec(cls, vec): return cls(data) - def transpose(self, symmetry=None): + def transpose(self): """ Returns a new Volume instance with volume data axes transposed. :return: Volume instance. """ # Ensures warning stacklevel is the same for `vol.T` and `vol.transpose()`. - if symmetry is None: - symmetry = self._result_symmetry() + stacklevel = 4 + if self._called_from_T: + stacklevel = 5 + self._called_from_T = False + symmetry = self._result_symmetry(stacklevel=stacklevel) original_stack_shape = self.stack_shape v = self._data.reshape(-1, *self._data.shape[-3:]) @@ -423,8 +434,8 @@ def T(self): :return: Volume instance. """ - symmetry = self._result_symmetry() - return self.transpose(symmetry=symmetry) + self._called_from_T = True + return self.transpose() def flatten(self): """ From 96c37423b124cabb62ef11edd646070dab05186d Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 8 Mar 2024 08:38:09 -0500 Subject: [PATCH 24/24] change variable name. --- src/aspire/volume/volume.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index e60405c895..22b8cc8624 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -213,7 +213,7 @@ def _result_symmetry(self, other=None, stacklevel=4): result_symmetry = self.symmetry_group # Conditions of incompatibility. - self_transformation = other is None + axes_altering_transformation = other is None incompat_syms = ( isinstance(other, Volume) and self.symmetry_group != other.symmetry_group ) @@ -221,7 +221,7 @@ def _result_symmetry(self, other=None, stacklevel=4): not isinstance(other, Volume) and getattr(other, "size", 1) > 1 ) - if any([self_transformation, incompat_syms, arbitrary_array]): + if any([axes_altering_transformation, incompat_syms, arbitrary_array]): self._symmetry_group_warning(stacklevel=stacklevel) result_symmetry = IdentitySymmetryGroup(dtype=self.dtype)