Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
4a6976e
Use only half_theta in PolarBasis2D.
j-c-c Jul 31, 2023
8bef843
refactor _evaluate and testPolarBasis2DAdjoint.
j-c-c Jul 31, 2023
e176624
Check that x is real-valued input of _evaluate_t.
j-c-c Jul 31, 2023
525f47e
Initial refactor from PolarBasis2D ~~> PolarFT
j-c-c Aug 2, 2023
871b307
Add rots_to_clmatrix method.
j-c-c Aug 2, 2023
df93e8a
Add test_orient_sync_voting.py. Add test_build_clmatrix.
j-c-c Aug 2, 2023
f5110ec
test_estimated_rotations
j-c-c Aug 2, 2023
be3f4c8
test_estimate_shifts
j-c-c Aug 2, 2023
b3cecae
additional sync_voting tests
j-c-c Aug 2, 2023
1e7e6f2
testing for PolarFT.
j-c-c Aug 3, 2023
2dae2f7
tox
j-c-c Aug 3, 2023
f50af6b
Fix generate_shift_phases. Properly notate common line Low Pass Filte…
j-c-c Aug 9, 2023
c93059b
Move offsets param into correct argument.
j-c-c Aug 9, 2023
ff61545
refactor common_line_from_rots. Small edits to estimate_shifts.
j-c-c Aug 10, 2023
e10948f
Add offsets to tests. Revert common_line_from_rots.
j-c-c Aug 10, 2023
7a88b82
Small cleanup.
j-c-c Aug 10, 2023
ffb3b23
test for mean_angular_distance.
j-c-c Aug 14, 2023
418040d
refactor test parametrization.
j-c-c Aug 14, 2023
9c33ac7
typo
j-c-c Aug 14, 2023
97a0155
update test tolerance.
j-c-c Aug 14, 2023
f0fa11c
update code comments.
j-c-c Aug 14, 2023
0b7316c
switch to high pass filter as in matlab.
j-c-c Aug 15, 2023
d234cf2
adjust max_shift for 0/random offsets.
j-c-c Aug 16, 2023
4a90c90
Increase resolution for tighter test tols.
j-c-c Aug 16, 2023
0722f12
Seed the Volume AAAAHHH.
j-c-c Aug 16, 2023
3a0b103
Use mean_angular_distance method.
j-c-c Aug 16, 2023
08676e2
test radial modes of pft.
j-c-c Aug 17, 2023
df0b9c2
evaluate_t ~~> transform
j-c-c Aug 17, 2023
21a17dc
remove evaluate
j-c-c Aug 17, 2023
168d7f2
Move polar_ft.py to operators dir.
j-c-c Aug 17, 2023
b255fa1
Fix imports.
j-c-c Aug 17, 2023
b911610
Missed import.
j-c-c Aug 17, 2023
d6db07f
angular_radial_mode test.
j-c-c Aug 18, 2023
69905f1
Set CLSyncVoting shift_step to less than 1. Improves commonline detec…
j-c-c Aug 18, 2023
ab9cc6c
Use proper reshape instead of order F with transpose.
j-c-c Aug 23, 2023
fce7d7b
clarifying comment.
j-c-c Aug 23, 2023
ce158c7
evaluate_t ~~> transform in warning message.
j-c-c Aug 23, 2023
d21ce71
Test complex image error.
j-c-c Aug 23, 2023
8899b36
Clean up rots_to_clmatrix.
j-c-c Aug 23, 2023
33abd7e
Clean up test_orient_symmetric.py.
j-c-c Aug 23, 2023
1d7da2e
typo
j-c-c Aug 23, 2023
66bb935
remove angular_radial_mode test.
j-c-c Aug 24, 2023
18eaf92
remove breakpoint()
j-c-c Aug 24, 2023
8a260ce
Clean up _precomp()
j-c-c Aug 24, 2023
6cf8295
Add PolarFT.full() method and tests.
j-c-c Aug 24, 2023
c0ad31a
change warnings to raises. Add tests.
j-c-c Aug 24, 2023
7babb9d
mark random offsets test case as expensive.
j-c-c Aug 24, 2023
cd562a9
remove unnecessary reshape of pf.
j-c-c Aug 25, 2023
978ed44
address review comments for polar_ft
j-c-c Aug 28, 2023
c8ef88c
Vectorize Rotation.angle_dist() to accept sets of rotations.
j-c-c Aug 28, 2023
507c337
tox
j-c-c Aug 28, 2023
6148daa
full ~~> half_to_full.
j-c-c Aug 29, 2023
6bfb1d0
Oooops forgot to change the test.
j-c-c Aug 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions src/aspire/abinitio/commonline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import scipy.sparse as sparse

from aspire.basis import PolarBasis2D
from aspire.operators import PolarFT
from aspire.utils import common_line_from_rots
from aspire.utils.random import choice

Expand Down Expand Up @@ -69,27 +69,16 @@ def _build(self):

imgs = self.src.images[:]

# Obtain coefficients in polar Fourier basis for input 2D images
self.basis = PolarBasis2D(
# Obtain coefficients of polar Fourier transform for input 2D images
self.pft = PolarFT(
(self.n_res, self.n_res), self.n_rad, self.n_theta, dtype=self.dtype
)
self.pf = self.basis.evaluate_t(imgs)
self.pf = self.pf.reshape(self.n_img, self.n_theta, self.n_rad)
self.pf = self.pft.transform(imgs)

n_theta_half = self.n_theta // 2

# The last two dimension of pf is of size n_theta x n_rad. We will convert pf
# into an array of size (n_theta/2) x (n_rad-1), that is, take half of each ray
# through the origin except the DC part, and also take the angles only up to PI.
# This is due to the fact that the original images are real, and thus each ray
# is conjugate symmetric. We therefore gain nothing by taking longer correlations
# (of length 2*n_rad-1 instead of n_rad-1). In the Matlab version, pf is converted to
# the size of (n_theta/2) x (2*n_rad-1) but most of the calculations of build_clmatrix
# and estimate_shifts below only use the size of (n_theta/2) x (n_rad-1). In the
# Python version we will use the size of (n_theta/2) x (n_rad-1) directly and make
# sure every part is using it. By taking shorter correlations we can speed the
# computation by a factor of two.
self.pf = np.flip(self.pf[:, n_theta_half:, 1:], 2)
# We remove the DC the component. pf has size (n_img) x (n_theta/2) x (n_rad-1),
# with pf[:, :, 0] containing low frequency content and pf[:, :, -1] containing
# high frequency content.
self.pf = self.pf[:, :, 1:]

def estimate_rotations(self):
"""
Expand Down Expand Up @@ -241,7 +230,7 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000):
show = True
# Negative sign comes from using -i conversion of Fourier transformation
est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0]
est_shifts = est_shifts.reshape((2, self.n_img), order="F")
est_shifts = est_shifts.reshape((self.n_img, 2))

return est_shifts

Expand Down Expand Up @@ -273,6 +262,10 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000):

n_theta_half = self.n_theta // 2
n_img = self.n_img

# `estimate_shifts()` requires that rotations have already been estimated.
if self.rotations is None:
self.estimate_rotations()
rotations = self.rotations

pf = self.pf.copy()
Expand Down Expand Up @@ -444,14 +437,15 @@ def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step):
# Number of shifts to try
n_shifts = int(np.ceil(2 * max_shift / shift_step + 1))

# only half of ray
rk = np.arange(-r_max, 0)
# only half of ray, excluding the DC component.
rk = np.arange(1, r_max + 1)

# Generate all shift phases
shifts = -max_shift + shift_step * np.arange(n_shifts)
shift_phases = np.exp(np.outer(shifts, -2 * np.pi * 1j * rk / (2 * r_max + 1)))
# Set filter for common-line detection
h = np.sqrt(np.abs(rk)) * np.exp(-np.square(rk) / (2 * (r_max / 4) ** 2))

return shifts, shift_phases, h

def _generate_index_pairs(self, n_equations):
Expand Down Expand Up @@ -512,7 +506,10 @@ def _apply_filter_and_norm(self, subscripts, pf, r_max, h):
# Note if we'd rather not have the dtype and casting args,
# we can control h.dtype instead.
np.einsum(subscripts, pf, h, out=pf, dtype=pf.dtype, casting="same_kind")
pf[..., r_max - 1 : r_max + 2] = 0

# This is a high pass filter, cutting out the lowest frequency
# (DC has already been removed).
pf[..., 0] = 0
pf /= np.linalg.norm(pf, axis=-1)[..., np.newaxis]

return pf
5 changes: 3 additions & 2 deletions src/aspire/abinitio/commonline_c3_c4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numpy.linalg import eigh, norm, svd

from aspire.abinitio import CLOrient3D, SyncVotingMixin
from aspire.operators import PolarFT
from aspire.utils import (
J_conjugate,
Rotation,
Expand Down Expand Up @@ -294,7 +295,7 @@ def _estimate_inplane_rotations(self, vis):

# Reconstruct the full polar Fourier for use in correlation. self.pf only consists of
# rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1).
pf = np.concatenate((pf, np.conj(pf)), axis=1)
pf = PolarFT.half_to_full(pf)

# Normalize rays.
pf /= norm(pf, axis=-1)[..., np.newaxis]
Expand Down Expand Up @@ -444,7 +445,7 @@ def _self_clmatrix_c3_c4(self):

# Reconstruct the full polar Fourier for use in correlation. self.pf only consists of
# rays in the range [180, 360), with shape (n_img, n_theta//2, n_rad-1).
pf_full = np.concatenate((pf, np.conj(pf)), axis=1)
pf_full = PolarFT.half_to_full(pf)

# The self-common-lines matrix holds two indices per image that represent
# the two self common-lines in the image.
Expand Down
3 changes: 2 additions & 1 deletion src/aspire/abinitio/commonline_cn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numpy.linalg import norm

from aspire.abinitio import CLSymmetryC3C4
from aspire.operators import PolarFT
from aspire.utils import (
J_conjugate,
Rotation,
Expand Down Expand Up @@ -123,7 +124,7 @@ def _estimate_relative_viewing_directions(self):

# Reconstruct full polar Fourier for use in correlation.
pf /= norm(pf, axis=2)[..., np.newaxis] # Normalize each ray.
pf_full = np.concatenate((pf, np.conj(pf)), axis=1)
pf_full = PolarFT.half_to_full(pf)

# Pre-compute shifted pf's.
pf_shifted = (pf * shift_phases[:, None, None]).swapaxes(0, 1)
Expand Down
1 change: 0 additions & 1 deletion src/aspire/basis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@
from .fpswf_2d import FPSWFBasis2D
from .fpswf_3d import FPSWFBasis3D
from .fspca import FSPCABasis
from .polar_2d import PolarBasis2D
from .pswf_2d import PSWFBasis2D
from .pswf_3d import PSWFBasis3D
1 change: 1 addition & 0 deletions src/aspire/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
ZeroFilter,
evaluate_src_filters_on_grid,
)
from .polar_ft import PolarFT
from .wemd import wemd_embed, wemd_norm
113 changes: 61 additions & 52 deletions src/aspire/basis/polar_2d.py → src/aspire/operators/polar_ft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,43 @@

import numpy as np

from aspire.basis import Basis
from aspire.nufft import anufft, nufft
from aspire.image import Image
from aspire.nufft import nufft
from aspire.utils import complex_type

logger = logging.getLogger(__name__)


class PolarBasis2D(Basis):
class PolarFT:
"""
Define a derived class for polar Fourier representation for 2D images
"""

def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32):
"""
Initialize an object for the 2D polar Fourier grid class
Initialize an object for the polar Fourier transform class. `PolarFT` expects that
images are real and uses only half of the `ntheta` values.

:param size: The shape of the vectors for which to define the grid.
:param size: The shape of the vectors for which to define the transform.
May be a 2-tuple or an integer, in which case a square basis is assumed.
Currently only square images are supported.
:param nrad: The number of points in the radial dimension. Default is resolution // 2.
:param ntheta: The number of points in the angular dimension. Default is 8 * nrad.
:param dtype: dtype used to compute a polar frequency grid for evaluating the transform..
"""
if isinstance(size, int):
size = (size, size)
ndim = len(size)
assert ndim == 2, "Only two-dimensional grids are supported."
assert len(set(size)) == 1, "Only square domains are supported."

self.ndim = ndim
self.sz = size
self.nrad = nrad
self.ntheta = ntheta
self.dtype = dtype

super().__init__(size, dtype=dtype)
self._build()

# this basis has complex coefficients
self.coefficient_dtype = complex_type(self.dtype)
Expand All @@ -45,7 +50,7 @@ def _build(self):
logger.info("Represent 2D image in a polar Fourier grid")

if self.nrad is None:
self.nrad = self.nres // 2
self.nrad = self.sz[0] // 2

if self.ntheta is None:
# try to use the same number as Fast FB basis
Expand All @@ -56,57 +61,30 @@ def _build(self):
logger.error(msg)
raise NotImplementedError(msg)

self.count = self.nrad * self.ntheta
self.count = self.nrad * (self.ntheta // 2)
self._sz_prod = self.sz[0] * self.sz[1]

# precompute the basis functions in 2D grids
self.freqs = self._precomp()

def _precomp(self):
"""
Precomute the polar Fourier grid
Precompute the polar Fourier grid.
"""
omega0 = 2 * np.pi / (2 * self.nrad - 1)
dtheta = 2 * np.pi / self.ntheta

# only need half size of ntheta
freqs = np.zeros((2, self.nrad * self.ntheta // 2), dtype=self.dtype)
freqs = np.zeros((2, self.ntheta // 2, self.nrad), dtype=self.dtype)
for i in range(self.ntheta // 2):
freqs[0, i * self.nrad : (i + 1) * self.nrad] = np.arange(
self.nrad
) * np.cos(i * dtheta)
freqs[1, i * self.nrad : (i + 1) * self.nrad] = np.arange(
self.nrad
) * np.sin(i * dtheta)
freqs[0, i] = np.cos(i * dtheta)
freqs[1, i] = np.sin(i * dtheta)

freqs *= omega0
return freqs
freqs *= omega0 * np.arange(self.nrad)

def _evaluate(self, v):
"""
Evaluate coefficients in standard 2D coordinate basis from those in polar Fourier basis
return freqs.reshape(2, -1)

:param v: A coefficient vector (or an array of coefficient vectors)
in polar Fourier basis to be evaluated. The last dimension must equal to
`self.count`.
:return x: Image instance in standard 2D coordinate basis with
resolution of `self.sz`.
"""
v = v.reshape(-1, self.ntheta, self.nrad)

nimgs = v.shape[0]

half_size = self.ntheta // 2

v = v[:, :half_size, :] + v[:, half_size:, :].conj()

v = v.reshape(nimgs, self.nrad * half_size)

x = anufft(v, self.freqs, self.sz, real=True)

return x

def _evaluate_t(self, x):
def transform(self, x):
"""
Evaluate coefficient in polar Fourier grid from those in standard 2D coordinate basis

Expand All @@ -116,15 +94,46 @@ def _evaluate_t(self, x):
Fourier grid. This is an array of vectors whose first dimension
corresponds to `x.shape[0]`, and last dimension equals `self.count`.
"""
nimgs = x.shape[0]

half_size = self.ntheta // 2

pf = nufft(x, self.freqs)
if x.dtype != self.dtype:
raise TypeError(
f"{self.__class__.__name__}.transform"
f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}"
)

if not isinstance(x, Image):
raise TypeError(
f"{self.__class__.__name__}.transform"
f" passed numpy array instead of {Image}."
)
else:
x = x.asnumpy()

# Flatten stack
stack_shape = x.shape[: -self.ndim]
x = x.reshape(-1, *x.shape[-self.ndim :])

# We expect the Image `x` to be real in order to take advantage of the conjugate
# symmetry of the Fourier transform of a real valued image.
if not np.isreal(x).all():
raise TypeError(
f"The Image `x` must be real valued. Found dtype {x.dtype}."
)

resolution = x.shape[-1]

pf = nufft(x, self.freqs) / resolution**2

return pf.reshape(*stack_shape, self.ntheta // 2, self.nrad)

@staticmethod
def half_to_full(pf):
"""
Use the conjugate symmetry of pf to construct the full polar Fourier transform
over all rays in [0, 360).

pf = pf.reshape((nimgs, self.nrad, half_size))
v = np.concatenate((pf, pf.conj()), axis=1)
:param pf: The precomputed half polar Fourier transform
with shape (*stack_shape, ntheta//2, nrad)
:return: The full polar Fourier transform with shape (*stack_shape, ntheta, nrad)
"""

# return v coefficients with the last dimension size of self.count
v = v.reshape(nimgs, -1)
return v
return np.concatenate((pf, np.conj(pf)), axis=-2)
1 change: 1 addition & 0 deletions src/aspire/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
grid_2d,
grid_3d,
register_rotations,
rots_to_clmatrix,
uniform_random_angles,
)

Expand Down
34 changes: 34 additions & 0 deletions src/aspire/utils/coor_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,40 @@ def common_line_from_rots(r1, r2, ell):
return ell_ij, ell_ji


def rots_to_clmatrix(rots, n_theta):
"""
Compute the common lines matrix induced by all pairs of rotation
matrices, `rots`, provided.

:param rots: n_rotsx3x3 array of rotation matrices.
:param n_theta: Number of theta values for common lines indices.

:return: n_rots x n_rots common lines matrix.
"""
n_rots = len(rots)
cl_matrix = -np.ones((n_rots, n_rots))
for i in range(n_rots):
Ri = rots[i]
Ri3 = Ri[:, 2]
for j in range(i + 1, n_rots):
Rj = rots[j]
Rj3 = Rj[:, 2]
common_axis = np.cross(Ri3, Rj3) / np.linalg.norm(np.cross(Ri3, Rj3))
xij = Ri.T @ common_axis
xji = Rj.T @ common_axis
theta_ij = np.rad2deg(np.arctan2(xij[1], xij[0])) % 360
theta_ji = np.rad2deg(np.arctan2(xji[1], xji[0])) % 360

if theta_ij > 180:
theta_ij -= 180
theta_ji -= 180

cl_matrix[i, j] = round(theta_ij * n_theta / 360)
cl_matrix[j, i] = round((theta_ji % 360) * n_theta / 360)

return cl_matrix


def crop_pad_2d(im, size, fill_value=0):
"""
:param im: A 2-dimensional numpy array
Expand Down
Loading