diff --git a/src/aspire/volume/symmetry_groups.py b/src/aspire/volume/symmetry_groups.py index 95b1954af9..171b77605e 100644 --- a/src/aspire/volume/symmetry_groups.py +++ b/src/aspire/volume/symmetry_groups.py @@ -26,6 +26,11 @@ def generate_rotations(self): Method for generating a Rotation object for the symmetry group. """ + def __eq__(self, other): + if isinstance(other, self.__class__): + return str(self) == str(other) + return False + @property def matrices(self): return self.rotations.matrices diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index cc29710045..22b8cc8624 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -1,4 +1,5 @@ import logging +import warnings import mrcfile import numpy as np @@ -110,6 +111,9 @@ def __init__(self, data, dtype=None, symmetry_group=None): # Set symmetry_group. If None, default to 'C1'. self._set_symmetry_group(symmetry_group) + # Flag to ensure T and transpose have the same symmetry_group warning stacklevel. + self._called_from_T = False + # Numpy interop # https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol self.__array_interface__ = self._data.__array_interface__ @@ -135,7 +139,9 @@ def astype(self, dtype, copy=True): Defaults to True. :return: Volume instance """ - return self.__class__(self.asnumpy().astype(dtype, copy=copy)) + return self.__class__( + self.asnumpy().astype(dtype, copy=copy), symmetry_group=self.symmetry_group + ) def _check_key_dims(self, key): if isinstance(key, tuple) and (len(key) > self._data.ndim): @@ -145,7 +151,7 @@ def _check_key_dims(self, key): def __getitem__(self, key): self._check_key_dims(key) - return self.__class__(self._data[key]) + return self.__class__(self._data[key], symmetry_group=self.symmetry_group) def __setitem__(self, key, value): self._check_key_dims(key) @@ -176,6 +182,51 @@ def _set_symmetry_group(self, value): ) self._symmetry_group = value + def _symmetry_group_warning(self, stacklevel): + """ + Warn when an operation has the potential to change the symmetry + of a volume or the alignment of symmetry axes. + + :param stacklevel: Warning stacklevel. + """ + msg = ( + f"`symmetry_group` attribute is being set to `C1`. This operation may" + f" effect the symmetry (or symmetric alignment) of the volume. To reset the" + f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`." + ) + warnings.warn(msg, UserWarning, stacklevel=stacklevel) + + def _result_symmetry(self, other=None, stacklevel=4): + """ + Check if `other` will alter symmetry of `self` and return resulting symmetry_group. + + :param other: Other operand to self. Default is None for self transformations. + :param stacklevel: Warning stacklevel. Default of 4 indicates line of code where + operation was performed. + :return: SymmetryGroup instance. + """ + # No need to check for C1. + if isinstance(self.symmetry_group, IdentitySymmetryGroup): + return self.symmetry_group + + # Warn and set to IdentitySymmetryGroup if `other` is symmetrically incompatible. + result_symmetry = self.symmetry_group + + # Conditions of incompatibility. + axes_altering_transformation = other is None + incompat_syms = ( + isinstance(other, Volume) and self.symmetry_group != other.symmetry_group + ) + arbitrary_array = ( + not isinstance(other, Volume) and getattr(other, "size", 1) > 1 + ) + + if any([axes_altering_transformation, incompat_syms, arbitrary_array]): + self._symmetry_group_warning(stacklevel=stacklevel) + result_symmetry = IdentitySymmetryGroup(dtype=self.dtype) + + return result_symmetry + def stack_reshape(self, *args): """ Reshape the stack axis. @@ -198,7 +249,10 @@ def stack_reshape(self, *args): f"Number of volumes {self.n_vols} cannot be reshaped to {shape}." ) - return self.__class__(self._data.reshape(*shape, *self._data.shape[-3:])) + return self.__class__( + self._data.reshape(*shape, *self._data.shape[-3:]), + symmetry_group=self.symmetry_group, + ) def __repr__(self): msg = ( @@ -211,10 +265,11 @@ def __len__(self): return self.n_vols def __add__(self, other): + symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data + other.asnumpy()) + res = self.__class__(self._data + other.asnumpy(), symmetry_group=symmetry) else: - res = self.__class__(self._data + other) + res = self.__class__(self._data + other, symmetry_group=symmetry) return res @@ -222,10 +277,11 @@ def __radd__(self, otherL): return self + otherL def __sub__(self, other): + symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data - other.asnumpy()) + res = self.__class__(self._data - other.asnumpy(), symmetry_group=symmetry) else: - res = self.__class__(self._data - other) + res = self.__class__(self._data - other, symmetry_group=symmetry) return res @@ -233,10 +289,11 @@ def __rsub__(self, otherL): return self.__class__(otherL - self._data) def __mul__(self, other): + symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data * other.asnumpy()) + res = self.__class__(self._data * other.asnumpy(), symmetry_group=symmetry) else: - res = self.__class__(self._data * other) + res = self.__class__(self._data * other, symmetry_group=symmetry) return res @@ -247,10 +304,11 @@ def __truediv__(self, other): """ Scalar division, follows numpy semantics. """ + symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data / other.asnumpy()) + res = self.__class__(self._data / other.asnumpy(), symmetry_group=symmetry) else: - res = self.__class__(self._data / other) + res = self.__class__(self._data / other, symmetry_group=symmetry) return res @@ -352,14 +410,22 @@ def from_vec(cls, vec): def transpose(self): """ - Returns a new Volume instance with volume data axes tranposed. + Returns a new Volume instance with volume data axes transposed. :return: Volume instance. """ + # Ensures warning stacklevel is the same for `vol.T` and `vol.transpose()`. + stacklevel = 4 + if self._called_from_T: + stacklevel = 5 + self._called_from_T = False + symmetry = self._result_symmetry(stacklevel=stacklevel) + original_stack_shape = self.stack_shape - v = self.stack_reshape(-1) - vt = np.transpose(v._data, (0, -1, -2, -3)) - return self.__class__(vt).stack_reshape(original_stack_shape) + v = self._data.reshape(-1, *self._data.shape[-3:]) + vt = np.transpose(v, (0, -1, -2, -3)) + vt = vt.reshape(*original_stack_shape, *self._data.shape[-3:]) + return self.__class__(vt, symmetry_group=symmetry) @property def T(self): @@ -368,7 +434,7 @@ def T(self): :return: Volume instance. """ - + self._called_from_T = True return self.transpose() def flatten(self): @@ -389,6 +455,8 @@ def flip(self, axis=-3): :return: Volume instance. """ + symmetry = self._result_symmetry() + # Convert integer to tuple, so we can always loop. if isinstance(axis, int): axis = (axis,) @@ -400,7 +468,7 @@ def flip(self, axis=-3): f"Cannot flip axis {ax}: stack axis. Did you mean {ax-4}?" ) - return self.__class__(np.flip(self._data, axis)) + return self.__class__(np.flip(self._data, axis), symmetry_group=symmetry) def downsample(self, ds_res, mask=None): """ @@ -427,7 +495,9 @@ def downsample(self, ds_res, mask=None): ds_res**3 / self.resolution**3 ) # returns a new Volume object - return self.__class__(np.real(out)).stack_reshape(original_stack_shape) + return self.__class__( + np.real(out), symmetry_group=self.symmetry_group + ).stack_reshape(original_stack_shape) def shift(self): raise NotImplementedError @@ -444,6 +514,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): :return: `Volume` instance. """ + symmetry = self._result_symmetry() if self.stack_ndim > 1: raise NotImplementedError( "`rotation` is currently limited to 1D Volume stacks." @@ -497,7 +568,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): np.real(fft.centered_ifftn(xp.asarray(vol_f), axes=(-3, -2, -1))) ) - return self.__class__(vol) + return self.__class__(vol, symmetry_group=symmetry) def denoise(self): raise NotImplementedError diff --git a/src/aspire/volume/volume_synthesis.py b/src/aspire/volume/volume_synthesis.py index c10c466081..b9514df5ea 100644 --- a/src/aspire/volume/volume_synthesis.py +++ b/src/aspire/volume/volume_synthesis.py @@ -8,6 +8,7 @@ from aspire.volume import ( CnSymmetryGroup, DnSymmetryGroup, + IdentitySymmetryGroup, OSymmetryGroup, TSymmetryGroup, Volume, @@ -247,6 +248,9 @@ def _check_order(self): f"An {self.__class__.__name__} must have order=1. Provided order was {self.order}" ) + def _set_symmetry_group(self): + self._symmetry_group = IdentitySymmetryGroup(dtype=self.dtype) + def _symmetrize_gaussians(self, Q, D, mu): return Q, D, mu diff --git a/tests/test_symmetry_groups.py b/tests/test_symmetry_groups.py index d55d37c566..34f64543ab 100644 --- a/tests/test_symmetry_groups.py +++ b/tests/test_symmetry_groups.py @@ -66,6 +66,14 @@ def test_group_str(group_fixture): logger.debug(f"String for {group_fixture}: {sym_string}.") +def test_group_equivalence(group_fixture): + C2_symmetry_group = CnSymmetryGroup(order=2, dtype=group_fixture.dtype) + if str(group_fixture) == "C2": + assert C2_symmetry_group == group_fixture + else: + assert C2_symmetry_group != group_fixture + + def test_group_rotations(group_fixture): rotations = group_fixture.rotations assert isinstance(rotations, Rotation) diff --git a/tests/test_synthetic_volume.py b/tests/test_synthetic_volume.py index 57fdc25116..8860b3532d 100644 --- a/tests/test_synthetic_volume.py +++ b/tests/test_synthetic_volume.py @@ -119,6 +119,8 @@ def test_compact_support(vol_fixture): assert (vol.asnumpy()[0][inside] > 0).all() +# Supress expected warnings due to rotation of symmetric volume. +@pytest.mark.filterwarnings("ignore:`symmetry_group` attribute is being set to `C1`") def test_volume_symmetry(vol_fixture): """Test that volumes have intended symmetry.""" vol = vol_fixture.generate() diff --git a/tests/test_volume.py b/tests/test_volume.py index 9d039c4c9f..a0aaa8c4b0 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -1,6 +1,7 @@ import logging import os import tempfile +import warnings from itertools import product import numpy as np @@ -12,6 +13,7 @@ from aspire.utils import Rotation, anorm, grid_2d, powerset, utest_tolerance from aspire.volume import ( AsymmetricVolume, + CnSymmetricVolume, CnSymmetryGroup, SymmetryGroup, TSymmetryGroup, @@ -87,6 +89,13 @@ def asym_vols(res, dtype): return vols +@pytest.fixture(scope="module") +def symmetric_vols(res, dtype): + vol_c3 = CnSymmetricVolume(L=res, C=1, order=3, dtype=dtype, seed=0).generate() + vol_c4 = CnSymmetricVolume(L=res, C=1, order=4, dtype=dtype, seed=0).generate() + return vol_c3, vol_c4 + + @pytest.fixture(scope="module") def vols_hot_cold(res, dtype): L = res @@ -726,6 +735,88 @@ def test_symmetry_group_set_get(sym_group, sym_string): _ = Volume(data, symmetry_group=123, dtype=dtype) +def test_symmetry_group_pass_through(symmetric_vols): + vol_c3, _ = symmetric_vols + sym_group = str(vol_c3.symmetry_group) + assert sym_group == "C3" + + # Check symmetry_group pass-through for various transformations. + assert str(vol_c3.astype(np.float64).symmetry_group) == sym_group # astype + assert str(vol_c3[0].symmetry_group) == sym_group # getitem + assert ( + str(vol_c3.stack_reshape((1, 1)).symmetry_group) == sym_group + ) # stack_reshape + assert ( + str(vol_c3.downsample(vol_c3.resolution // 2).symmetry_group) == sym_group + ) # downsample + + +def test_transformation_symmetry_warnings(symmetric_vols): + """ + A warning should be emitted for transpose, flip, and rotate. + """ + vol_c3, _ = symmetric_vols + sym_group = str(vol_c3.symmetry_group) + assert sym_group == "C3" + + # Check we get warning for each transformation. + with pytest.warns( + UserWarning, match=r".*`symmetry_group` attribute is being set to `C1`.*" + ) as record: + vol_t = vol_c3.T + vol_f = vol_c3.flip() + vol_r = vol_c3.rotate(Rotation.about_axis("x", np.pi, dtype=vol_c3.dtype)) + assert len(record) == 3 + + # Check symmetry_group has been set to C1. + assert str(vol_t.symmetry_group) == "C1" + assert str(vol_f.symmetry_group) == "C1" + assert str(vol_r.symmetry_group) == "C1" + + # Check original volume has retained C3 symmetry. + assert str(vol_c3.symmetry_group) == "C3" + + +def test_aglebraic_ops_symmetry_warnings(symmetric_vols): + """ + A warning should be emitted for add, sub, mult, and div. + """ + vol_c3, vol_c4 = symmetric_vols + + # Compatible symmetry should retain symmetry_group and emit no warning. + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert (vol_c3 + vol_c3).symmetry_group == vol_c3.symmetry_group + assert (vol_c3 - vol_c3).symmetry_group == vol_c3.symmetry_group + assert (vol_c3 * vol_c3).symmetry_group == vol_c3.symmetry_group + assert ( + vol_c3 / (vol_c3 + 1) + ).symmetry_group == vol_c3.symmetry_group # plus 1 to avoid division by 0. + + # Incompatible symmetry should warn and set symmetry_group to C1. + with pytest.warns( + UserWarning, match=r".*`symmetry_group` attribute is being set to `C1`.*" + ) as record: + vols_sum = vol_c3 + vol_c4 + vol_array_diff = vol_c3 - vol_c4.asnumpy() + vols_mult = vol_c3 * vol_c4 + vol_array_div = vol_c3 / (vol_c4.asnumpy() + 1) + + assert str(vols_sum.symmetry_group) == "C1" + assert str(vol_array_diff.symmetry_group) == "C1" + assert str(vols_mult.symmetry_group) == "C1" + assert str(vol_array_div.symmetry_group) == "C1" + + # Should have 4 warnings on record. + assert len(record) == 4 + + # Check that warning occurs only once per line. + with warnings.catch_warnings(record=True) as record: + for _ in range(5): + vol_c3 + vol_c4 + assert len(record) == 1 + + def test_volume_load_with_symmetry(): # Check we can load a Volume with symmetry_group. vol = Volume(