Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/aspire/volume/symmetry_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def generate_rotations(self):
Method for generating a Rotation object for the symmetry group.
"""

def __eq__(self, other):
if isinstance(other, self.__class__):
return str(self) == str(other)
return False

@property
def matrices(self):
return self.rotations.matrices
Expand Down
109 changes: 90 additions & 19 deletions src/aspire/volume/volume.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings

import mrcfile
import numpy as np
Expand Down Expand Up @@ -110,6 +111,9 @@ def __init__(self, data, dtype=None, symmetry_group=None):
# Set symmetry_group. If None, default to 'C1'.
self._set_symmetry_group(symmetry_group)

# Flag to ensure T and transpose have the same symmetry_group warning stacklevel.
self._called_from_T = False

# Numpy interop
# https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol
self.__array_interface__ = self._data.__array_interface__
Expand All @@ -135,7 +139,9 @@ def astype(self, dtype, copy=True):
Defaults to True.
:return: Volume instance
"""
return self.__class__(self.asnumpy().astype(dtype, copy=copy))
return self.__class__(
self.asnumpy().astype(dtype, copy=copy), symmetry_group=self.symmetry_group
)

def _check_key_dims(self, key):
if isinstance(key, tuple) and (len(key) > self._data.ndim):
Expand All @@ -145,7 +151,7 @@ def _check_key_dims(self, key):

def __getitem__(self, key):
self._check_key_dims(key)
return self.__class__(self._data[key])
return self.__class__(self._data[key], symmetry_group=self.symmetry_group)

def __setitem__(self, key, value):
self._check_key_dims(key)
Expand Down Expand Up @@ -176,6 +182,51 @@ def _set_symmetry_group(self, value):
)
self._symmetry_group = value

def _symmetry_group_warning(self, stacklevel):
"""
Warn when an operation has the potential to change the symmetry
of a volume or the alignment of symmetry axes.

:param stacklevel: Warning stacklevel.
"""
msg = (
f"`symmetry_group` attribute is being set to `C1`. This operation may"
f" effect the symmetry (or symmetric alignment) of the volume. To reset the"
f" symmetry group run `self._set_symmetry_group('{self.symmetry_group}')`."
)
warnings.warn(msg, UserWarning, stacklevel=stacklevel)

def _result_symmetry(self, other=None, stacklevel=4):
"""
Check if `other` will alter symmetry of `self` and return resulting symmetry_group.

:param other: Other operand to self. Default is None for self transformations.
:param stacklevel: Warning stacklevel. Default of 4 indicates line of code where
operation was performed.
:return: SymmetryGroup instance.
"""
# No need to check for C1.
if isinstance(self.symmetry_group, IdentitySymmetryGroup):
return self.symmetry_group

# Warn and set to IdentitySymmetryGroup if `other` is symmetrically incompatible.
result_symmetry = self.symmetry_group

# Conditions of incompatibility.
axes_altering_transformation = other is None
incompat_syms = (
isinstance(other, Volume) and self.symmetry_group != other.symmetry_group
)
arbitrary_array = (
not isinstance(other, Volume) and getattr(other, "size", 1) > 1
)

if any([axes_altering_transformation, incompat_syms, arbitrary_array]):
self._symmetry_group_warning(stacklevel=stacklevel)
result_symmetry = IdentitySymmetryGroup(dtype=self.dtype)

return result_symmetry

def stack_reshape(self, *args):
"""
Reshape the stack axis.
Expand All @@ -198,7 +249,10 @@ def stack_reshape(self, *args):
f"Number of volumes {self.n_vols} cannot be reshaped to {shape}."
)

return self.__class__(self._data.reshape(*shape, *self._data.shape[-3:]))
return self.__class__(
self._data.reshape(*shape, *self._data.shape[-3:]),
symmetry_group=self.symmetry_group,
)

def __repr__(self):
msg = (
Expand All @@ -211,32 +265,35 @@ def __len__(self):
return self.n_vols

def __add__(self, other):
symmetry = self._result_symmetry(other)
if isinstance(other, Volume):
res = self.__class__(self._data + other.asnumpy())
res = self.__class__(self._data + other.asnumpy(), symmetry_group=symmetry)
else:
res = self.__class__(self._data + other)
res = self.__class__(self._data + other, symmetry_group=symmetry)

return res

def __radd__(self, otherL):
return self + otherL

def __sub__(self, other):
symmetry = self._result_symmetry(other)
if isinstance(other, Volume):
res = self.__class__(self._data - other.asnumpy())
res = self.__class__(self._data - other.asnumpy(), symmetry_group=symmetry)
else:
res = self.__class__(self._data - other)
res = self.__class__(self._data - other, symmetry_group=symmetry)

return res

def __rsub__(self, otherL):
return self.__class__(otherL - self._data)

def __mul__(self, other):
symmetry = self._result_symmetry(other)
if isinstance(other, Volume):
res = self.__class__(self._data * other.asnumpy())
res = self.__class__(self._data * other.asnumpy(), symmetry_group=symmetry)
else:
res = self.__class__(self._data * other)
res = self.__class__(self._data * other, symmetry_group=symmetry)

return res

Expand All @@ -247,10 +304,11 @@ def __truediv__(self, other):
"""
Scalar division, follows numpy semantics.
"""
symmetry = self._result_symmetry(other)
if isinstance(other, Volume):
res = self.__class__(self._data / other.asnumpy())
res = self.__class__(self._data / other.asnumpy(), symmetry_group=symmetry)
else:
res = self.__class__(self._data / other)
res = self.__class__(self._data / other, symmetry_group=symmetry)

return res

Expand Down Expand Up @@ -352,14 +410,22 @@ def from_vec(cls, vec):

def transpose(self):
"""
Returns a new Volume instance with volume data axes tranposed.
Returns a new Volume instance with volume data axes transposed.

:return: Volume instance.
"""
# Ensures warning stacklevel is the same for `vol.T` and `vol.transpose()`.
stacklevel = 4
if self._called_from_T:
stacklevel = 5
self._called_from_T = False
symmetry = self._result_symmetry(stacklevel=stacklevel)

original_stack_shape = self.stack_shape
v = self.stack_reshape(-1)
vt = np.transpose(v._data, (0, -1, -2, -3))
return self.__class__(vt).stack_reshape(original_stack_shape)
v = self._data.reshape(-1, *self._data.shape[-3:])
vt = np.transpose(v, (0, -1, -2, -3))
vt = vt.reshape(*original_stack_shape, *self._data.shape[-3:])
return self.__class__(vt, symmetry_group=symmetry)

@property
def T(self):
Expand All @@ -368,7 +434,7 @@ def T(self):

:return: Volume instance.
"""

self._called_from_T = True
return self.transpose()

def flatten(self):
Expand All @@ -389,6 +455,8 @@ def flip(self, axis=-3):

:return: Volume instance.
"""
symmetry = self._result_symmetry()

# Convert integer to tuple, so we can always loop.
if isinstance(axis, int):
axis = (axis,)
Expand All @@ -400,7 +468,7 @@ def flip(self, axis=-3):
f"Cannot flip axis {ax}: stack axis. Did you mean {ax-4}?"
)

return self.__class__(np.flip(self._data, axis))
return self.__class__(np.flip(self._data, axis), symmetry_group=symmetry)

def downsample(self, ds_res, mask=None):
"""
Expand All @@ -427,7 +495,9 @@ def downsample(self, ds_res, mask=None):
ds_res**3 / self.resolution**3
)
# returns a new Volume object
return self.__class__(np.real(out)).stack_reshape(original_stack_shape)
return self.__class__(
np.real(out), symmetry_group=self.symmetry_group
).stack_reshape(original_stack_shape)

def shift(self):
raise NotImplementedError
Expand All @@ -444,6 +514,7 @@ def rotate(self, rot_matrices, zero_nyquist=True):

:return: `Volume` instance.
"""
symmetry = self._result_symmetry()
if self.stack_ndim > 1:
raise NotImplementedError(
"`rotation` is currently limited to 1D Volume stacks."
Expand Down Expand Up @@ -497,7 +568,7 @@ def rotate(self, rot_matrices, zero_nyquist=True):
np.real(fft.centered_ifftn(xp.asarray(vol_f), axes=(-3, -2, -1)))
)

return self.__class__(vol)
return self.__class__(vol, symmetry_group=symmetry)

def denoise(self):
raise NotImplementedError
Expand Down
4 changes: 4 additions & 0 deletions src/aspire/volume/volume_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aspire.volume import (
CnSymmetryGroup,
DnSymmetryGroup,
IdentitySymmetryGroup,
OSymmetryGroup,
TSymmetryGroup,
Volume,
Expand Down Expand Up @@ -247,6 +248,9 @@ def _check_order(self):
f"An {self.__class__.__name__} must have order=1. Provided order was {self.order}"
)

def _set_symmetry_group(self):
self._symmetry_group = IdentitySymmetryGroup(dtype=self.dtype)

def _symmetrize_gaussians(self, Q, D, mu):
return Q, D, mu

Expand Down
8 changes: 8 additions & 0 deletions tests/test_symmetry_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def test_group_str(group_fixture):
logger.debug(f"String for {group_fixture}: {sym_string}.")


def test_group_equivalence(group_fixture):
C2_symmetry_group = CnSymmetryGroup(order=2, dtype=group_fixture.dtype)
if str(group_fixture) == "C2":
assert C2_symmetry_group == group_fixture
else:
assert C2_symmetry_group != group_fixture


def test_group_rotations(group_fixture):
rotations = group_fixture.rotations
assert isinstance(rotations, Rotation)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_synthetic_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def test_compact_support(vol_fixture):
assert (vol.asnumpy()[0][inside] > 0).all()


# Supress expected warnings due to rotation of symmetric volume.
@pytest.mark.filterwarnings("ignore:`symmetry_group` attribute is being set to `C1`")
def test_volume_symmetry(vol_fixture):
"""Test that volumes have intended symmetry."""
vol = vol_fixture.generate()
Expand Down
Loading