Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
8d54d43
stub in BFT work from notebook
garrettwrong Mar 27, 2025
ada14e4
fixup mixing with translations
garrettwrong Mar 27, 2025
db24bac
vector fast polar align
garrettwrong Mar 28, 2025
94819c4
shift base image and commute shift
garrettwrong Mar 28, 2025
6fbfd8d
cleanup
garrettwrong Mar 28, 2025
e25502c
hack in gpu code, dirty
garrettwrong Mar 28, 2025
978aec8
factor out the pft
garrettwrong Mar 28, 2025
56fc198
begin batching, two places to broadcast
garrettwrong Mar 28, 2025
2b9b0b5
table broadcast polar cross corr
garrettwrong Mar 28, 2025
4d99d33
table broadcast shifts, resuse arrays, reduce mem cost some speed
garrettwrong Mar 28, 2025
680ba93
Cleanup unit test for broadcast case
garrettwrong Apr 23, 2025
6a57f73
cleanup pft interop
garrettwrong Apr 23, 2025
ef144f9
A little more cleanup
garrettwrong Apr 23, 2025
6c1930e
stash
garrettwrong Apr 23, 2025
cabc95b
add fine interp and optimize methods
garrettwrong Apr 24, 2025
17b8a6b
add BFTAverager2D to test suite
garrettwrong Apr 25, 2025
0835b38
intial add BFT to source wrappers, remove 110
garrettwrong Apr 28, 2025
a91b8e8
tox checks
garrettwrong Apr 29, 2025
b4b66f0
flip bug fix
garrettwrong Apr 29, 2025
24de7d5
update shift grid to return array of tuples
garrettwrong Apr 30, 2025
4459bbb
cleanup
garrettwrong May 1, 2025
d06f41f
reversed the index mapping, whoops
garrettwrong May 1, 2025
b81f69e
copy syntax
garrettwrong May 1, 2025
1b10494
remove interp option from polar cross cor align
garrettwrong May 8, 2025
752ee40
cleanup comment
garrettwrong May 13, 2025
291ca3d
update //16 to //32 in shift search
garrettwrong May 15, 2025
1b14c81
default to self.n_radial
garrettwrong May 15, 2025
21092a6
typo ceates -> creates
garrettwrong May 15, 2025
d002430
docstring updates
garrettwrong May 15, 2025
fa55950
use L//2 for n_radial
garrettwrong May 16, 2025
b3527e8
len(shifts) ~> len(test_shifts)
garrettwrong May 21, 2025
1b62472
cleanup minor review remarks
garrettwrong May 21, 2025
8d55a1e
sub pixel review change bug
garrettwrong May 23, 2025
2d18693
stub in PolarFT shifting
garrettwrong May 27, 2025
ab8ee7b
stub in PolarFT shift test 2d
garrettwrong May 28, 2025
984acc6
add broadcast polar shift test
garrettwrong May 28, 2025
9de7acf
add multiple shift broadcast polar code
garrettwrong May 28, 2025
9f49b49
Use PolarFT.shift in BFT class source
garrettwrong May 28, 2025
a20dbcf
exted PolarFT.shift to xp
garrettwrong May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/aspire/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
BFRAverager2D,
BFSRAverager2D,
BFSReddyChatterjiAverager2D,
BFTAverager2D,
EMAverager2D,
FTKAverager2D,
ReddyChatterjiAverager2D,
Expand Down
268 changes: 249 additions & 19 deletions src/aspire/classification/averager2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,29 @@
from aspire.basis import Coef
from aspire.classification.reddy_chatterji import reddy_chatterji_register
from aspire.image import Image, ImageStacker, MeanImageStacker
from aspire.numeric import xp
from aspire.utils import tqdm, trange
from aspire.numeric import fft, xp
from aspire.operators import PolarFT
from aspire.utils import complex_type, tqdm, trange
from aspire.utils.coor_trans import grid_2d

logger = logging.getLogger(__name__)


def commute_shift_rot(shifts, rots):
"""
Rotate `shifts` points by `rots` ccw radians.

:param shifts: Array of shift points shaped (..., 2)
:param rots: Array of rotations (radians)
:returns: Array of rotated shift points shaped (..., 2)
"""
sx = shifts[:, 0]
sy = shifts[:, 1]
x = sx * np.cos(rots) - sy * np.sin(rots)
y = sx * np.sin(rots) + sy * np.cos(rots)
return np.stack((x, y), axis=1)


class Averager2D(ABC):
"""
Base class for 2D Image Averaging methods.
Expand Down Expand Up @@ -234,27 +250,34 @@

return Image(avgs)

def _shift_search_grid(self, L, radius, roll_zero=False):
def _shift_search_grid(self, L, radius, roll_zero=False, sub_pixel=1):
"""
Returns two 1-D arrays representing the X and Y grid points in the defined
shift search space (disc <= self.radius).

:param radius: Disc radius in pixels
:returns: Grid points as 2-tuple of vectors X,Y.
:param roll_zero: Roll (0,0) to zero'th element. Defaults to False.
:param sub_pixel: Sub-pixel decimation . 1 yields 1 pixel, 10 yields 1/10 pixel, etc.
Values will be cast to integers.
:returns: Grid points as array of 2-tuples [(x0,y0),... (xi,yi)].
"""
sub_pixel = int(sub_pixel)

# We'll brute force all shifts in a grid.
g = grid_2d(L, normalized=False)
disc = g["r"] <= radius
g = grid_2d(sub_pixel * L, normalized=False)
disc = g["r"] <= (sub_pixel * radius)
X, Y = g["x"][disc], g["y"][disc]
X, Y = X / sub_pixel, Y / sub_pixel

# Optionally roll arrays so 0 is first.
if roll_zero:
zero_ind = np.argwhere(X * X + Y * Y == 0).flatten()[0]
X, Y = np.roll(X, -zero_ind), np.roll(Y, -zero_ind)
assert (X[0], Y[0]) == (0, 0), (radius, zero_ind, X, Y)

return X, Y
shifts = np.stack((X, Y), axis=1)

return shifts


class BFSRAverager2D(AligningAverager2D):
Expand Down Expand Up @@ -283,7 +306,7 @@

:params n_angles: Number of brute force rotations to attempt, defaults 360.
:param radius: Brute force translation search radius.
Defaults to src.L//16.
Defaults to src.L//32.
"""
super().__init__(
composite_basis,
Expand All @@ -300,7 +323,7 @@
f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `rotate` method."
)

self.radius = radius if radius is not None else src.L // 16
self.radius = radius if radius is not None else src.L // 32

if self.radius != 0:

Expand Down Expand Up @@ -337,9 +360,7 @@

# Create a search grid and force initial pair to (0,0)
# This is done primarily in case of a tie later, we would take unshifted.
x_shifts, y_shifts = self._shift_search_grid(
self.src.L, self.radius, roll_zero=True
)
test_shifts = self._shift_search_grid(self.src.L, self.radius, roll_zero=True)

for k in trange(n_classes, desc="Rotationally aligning classes"):
# We want to locally cache the original images,
Expand Down Expand Up @@ -370,10 +391,10 @@

# Loop over shift search space, updating best result
for x, y in tqdm(
zip(x_shifts, y_shifts),
total=len(x_shifts),
test_shifts,
total=len(test_shifts),
desc="\tmaximizing over shifts",
disable=len(x_shifts) == 1,
disable=len(test_shifts) == 1,
leave=False,
):
shift = np.array([x, y], dtype=int)
Expand Down Expand Up @@ -439,6 +460,12 @@


class BFRAverager2D(BFSRAverager2D):
"""
Brute Force Rotation only reference implementation.

See BFT with `radius=0` for a more performant implementation using a fast rotational alignment.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, radius=0, **kwargs)

Expand Down Expand Up @@ -660,7 +687,7 @@
dot_products = np.ones(classes.shape, dtype=self.dtype) * -np.inf
shifts = np.zeros((*classes.shape, 2), dtype=int)

X, Y = self._shift_search_grid(self.alignment_src.L, self.radius)
test_shifts = self._shift_search_grid(self.alignment_src.L, self.radius)

def _innerloop(k):
unshifted_images = self._cls_images(classes[k])
Expand All @@ -670,10 +697,10 @@
_shifts = np.zeros((*classes.shape[1:], 2), dtype=int)

for xs, ys in tqdm(
zip(X, Y),
total=len(X),
test_shifts,
total=len(test_shifts),
desc="\tmaximizing over shifts",
disable=len(X) == 1,
disable=len(test_shifts) == 1,
leave=False,
):

Expand Down Expand Up @@ -725,6 +752,209 @@
return AligningAverager2D.average(self, classes, reflections, coefs)


class BFTAverager2D(AligningAverager2D):
"""
This perfoms a Brute Force Translations and fast rotational alignment.

For each shift,
Perform polar Fourier cross correlation based rotational alignment.

Return the rotation and shift yielding the best results.
"""

def __init__(
self,
composite_basis,
src,
alignment_basis=None,
n_angles=360,
n_radial=None,
radius=None,
sub_pixel=10,
batch_size=512,
dtype=None,
):
"""
See AligningAverager2D. Adds `n_angles`, `n_radial`, `radius`, `sub_pixel`.

:params n_angles: Number of PFT angular components, defaults 360.
:param n_radial: Number of PFT radial components, defaults `self.src.L//2`.
:param radius: Brute force translation search radius.
`0` disables translation search, rotations only.
Defaults to `src.L//32`.
:param sub_pixel: Sub-pixel decimation used in brute force shift search.
Defaults to 10 sub-pixel to pixel, ie 0.1 spaced sub-pixel.
"""
super().__init__(
composite_basis,
src,
alignment_basis,
batch_size=batch_size,
dtype=dtype,
)

self.n_angles = n_angles

self.radius = radius if radius is not None else src.L // 32

if self.radius != 0 and not hasattr(self.alignment_basis, "shift"):
raise RuntimeError(

Check warning on line 801 in src/aspire/classification/averager2d.py

View check run for this annotation

Codecov / codecov/patch

src/aspire/classification/averager2d.py#L801

Added line #L801 was not covered by tests
f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `shift` method."
)

self.sub_pixel = sub_pixel

# Configure number of radial points
self.n_radial = n_radial or self.src.L // 2

# Setup Polar Transform
self._pft = PolarFT(
self.src.L, ntheta=n_angles, nrad=self.n_radial, dtype=self.dtype
)
self._mask = xp.asarray(grid_2d(self.src.L, normalized=True)["r"] < 1)

def _fast_rotational_alignment(self, pfA, pfB):
"""
Perform fast rotational alignment using Polar Fourier cross correlation.

Note broadcasting is specialized for this problem.
pfA.shape (m, ntheta, nrad)
pfB.shape (n, ntheta, nrad)
yields thetas (m,n), peaks (m,n)

"""

if pfA.ndim == 2:
pfA = pfA[None]

Check warning on line 828 in src/aspire/classification/averager2d.py

View check run for this annotation

Codecov / codecov/patch

src/aspire/classification/averager2d.py#L828

Added line #L828 was not covered by tests
if pfB.ndim == 2:
pfB = pfB[None]

Check warning on line 830 in src/aspire/classification/averager2d.py

View check run for this annotation

Codecov / codecov/patch

src/aspire/classification/averager2d.py#L830

Added line #L830 was not covered by tests

# 2 hats one sum
pfA = fft.fft(pfA, axis=-2)
pfB = fft.fft(pfB, axis=-2)
# Tabulate elements of pfA cross pfB.conj() using broadcast multiply
x = xp.expand_dims(pfA, 1) * xp.expand_dims(pfB.conj(), 0)
angular = xp.sum(xp.abs(fft.ifft2(x)), axis=-1) # sum all radial contributions

# Resolve the angle maximizing the correlation through the angular dimension
inds = xp.argmax(angular, axis=-1)

max_thetas = 2 * np.pi / self._pft.ntheta * inds
peaks = xp.take_along_axis(angular, inds[..., None], axis=-1).squeeze(-1)

return xp.asnumpy(max_thetas), xp.asnumpy(peaks)

def align(self, classes, reflections, basis_coefficients=None):
"""
See `AligningAverager2D.align`
"""

# Admit simple case of single case alignment
classes = np.atleast_2d(classes)
reflections = np.atleast_2d(reflections)

# Result arrays
# These arrays will incrementally store our best alignment.
n_classes, n_nbor = classes.shape
rotations = np.zeros((n_classes, n_nbor), dtype=self.dtype)
dot_products = np.ones((n_classes, n_nbor), dtype=self.dtype) * -np.inf
shifts = np.zeros((*classes.shape, 2), dtype=self.dtype)

# Create a search grid and force initial pair to (0,0)
# This is done primarily in case of a tie later, we would prefer unshifted.
test_shifts = self._shift_search_grid(
self.src.L,
self.radius,
roll_zero=True,
sub_pixel=self.sub_pixel,
)

# Work arrays
bs = min(self.batch_size, len(test_shifts))
_rotations = np.zeros((bs, n_nbor), dtype=self.dtype)
_dot_products = np.ones((bs, n_nbor), dtype=self.dtype) * -np.inf
template_images = xp.empty(
(bs, self._pft.ntheta // 2, self._pft.nrad), dtype=complex_type(self.dtype)
)
_images = xp.empty((n_nbor - 1, self.src.L, self.src.L), dtype=self.dtype)

for k in trange(n_classes, desc="Rotationally aligning classes"):
# We want to locally cache the original images,
# because we will mutate them with shifts in the next loop.
# This avoids recomputing them before each shift
# The coefficient for the base images are also computed here.
if basis_coefficients is None:
original_images = Image(self._cls_images(classes[k], src=self.src))
else:
original_coef = basis_coefficients[classes[k], :]
original_images = self.alignment_basis.evaluate(original_coef)

_img0 = original_images[0].asnumpy().copy()
_images[:] = xp.asarray(original_images[1:].asnumpy().copy())

# Handle reflections
refl = reflections[k][1:] # skips original_image 0
_images[refl] = xp.flip(_images[refl], axis=-2)

# Mask off
_images[:] = _images[:] * self._mask

# Convert to polar Fourier
pf_img0 = self._pft._transform(_img0)
pf_images = self._pft.half_to_full(self._pft._transform(_images))

# Batch over shift search space, updating best results
pbar = tqdm(
total=len(test_shifts),
desc="\tmaximizing over shifts",
disable=len(test_shifts) == 1,
leave=False,
)
for start in range(0, len(test_shifts), self.batch_size):
end = min(start + self.batch_size, len(test_shifts))
bs = end - start # handle a small last batch
batch_shifts = test_shifts[start:end]

# Shift the base, pf_img0, for each shift in this batch
# Note this includes shifting for the zero shift case
template_images[:bs] = xp.asarray(
self._pft.shift(pf_img0, batch_shifts)
)

pf_template_images = self._pft.half_to_full(template_images)

# Compute and assign the best rotation found with this translation
# note offset of 1 for skipped original_image 0
_rotations[:bs, 1:], _dot_products[:bs, 1:] = (
self._fast_rotational_alignment(pf_template_images[:bs], pf_images)
)

# Note, these could be vectorized, but the code block
# wasn't appreciably faster when I compared them for
# current problem sizes.
for i in range(bs):

# Test and update
# Each base-neighbor pair may have a best shift+rot from a different shift iteration.
improved_indices = _dot_products[i] > dot_products[k]
rotations[k, improved_indices] = -_rotations[i, improved_indices]
dot_products[k, improved_indices] = _dot_products[
i, improved_indices
]
# base shifts assigned here, commutation resolved end of loop
shifts[k, improved_indices] = -batch_shifts[i]

pbar.update(bs)

# Completed batching over shifts
pbar.close()

# Commute the rotation and shift (code shifted the base image instead of all class members)
shifts[k] = commute_shift_rot(shifts[k], rotations[k])

return rotations, shifts, dot_products


class EMAverager2D(Averager2D):
"""
Citation needed.
Expand Down
Loading
Loading