Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,16 @@ def legacy_whiten(self, psd, delta):

return Image(res)

def downsample(self, ds_res, zero_nyquist=True):
def downsample(self, ds_res, zero_nyquist=True, legacy=False):
"""
Downsample Image to a specific resolution. This method returns a new Image.

:param ds_res: int - new resolution, should be <= the current resolution
of this Image
:param zero_nyquist: Option to keep or remove Nyquist frequency for even resolution.
Defaults to zero_nyquist=True, removing the Nyquist frequency.
:param zero_nyquist: Option to keep or remove Nyquist frequency for even
resolution (boolean). Defaults to zero_nyquist=True, removing the Nyquist frequency.
:param legacy: Option to match legacy Matlab downsample method (boolean).
Default of False uses `centered_fft` to maintain ASPIRE-Python centering conventions.
:return: The downsampled Image object.
"""

Expand All @@ -528,19 +530,26 @@ def downsample(self, ds_res, zero_nyquist=True):
# because all of the subsequent calls until `asnumpy` are GPU
# when xp and fft in `cupy` mode.

# compute FT with centered 0-frequency
fx = fft.centered_fft2(xp.asarray(im._data))
if legacy:
fx = fft.fftshift(fft.fft2(xp.asarray(im._data)))
else:
# compute FT with centered 0-frequency
fx = fft.centered_fft2(xp.asarray(im._data))

# crop 2D Fourier transform for each image
crop_fx = crop_pad_2d(fx, ds_res)

# If downsampled resolution is even, optionally zero out the nyquist frequency.
if ds_res % 2 == 0 and zero_nyquist is True:
if ds_res % 2 == 0 and zero_nyquist and not legacy:
crop_fx[:, 0, :] = 0
crop_fx[:, :, 0] = 0

# take back to real space, discard complex part, and scale
out = fft.centered_ifft2(crop_fx).real * (ds_res**2 / self.resolution**2)
out = xp.asnumpy(out)
if legacy:
out = fft.ifft2(fft.ifftshift(crop_fx))
else:
out = fft.centered_ifft2(crop_fx)
out = xp.asnumpy(out.real * ds_res**2 / self.resolution**2)

# Optionally scale pixel size
ds_pixel_size = self.pixel_size
Expand Down
18 changes: 16 additions & 2 deletions src/aspire/image/xform.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,26 @@ class Downsample(LinearXform):
A Xform that downsamples an Image object to a resolution specified by this Xform's resolution.
"""

def __init__(self, resolution):
def __init__(self, resolution, zero_nyquist=True, legacy=False):
"""
Initialize Xform to downsample Image to a specific resolution.

:param resolution: int - new resolution, should be <= the current resolution
of this Image
:param zero_nyquist: Option to keep or remove Nyquist frequency for even
resolution (boolean). Defaults to zero_nyquist=True, removing the Nyquist frequency.
:param legacy: Option to match legacy Matlab downsample method (boolean).
Default of False uses `centered_fft` to maintain ASPIRE-Python centering conventions.
"""
self.resolution = resolution
self.zero_nyquist = zero_nyquist
self.legacy = legacy
super().__init__()

def _forward(self, im, indices):
return im.downsample(self.resolution)
return im.downsample(
self.resolution, zero_nyquist=self.zero_nyquist, legacy=self.legacy
)

def _adjoint(self, im, indices):
# TODO: Implement up-sampling with zero-padding
Expand Down
6 changes: 4 additions & 2 deletions src/aspire/source/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,14 +769,16 @@ def _images(self, indices):
"""

@_as_copy
def downsample(self, L):
def downsample(self, L, zero_nyquist=True, legacy=False):
if L > self.L:
raise ValueError(
"Max desired resolution {L} should be less than the current resolution {self.L}."
)
logger.info(f"Setting max. resolution of source = {L}")

self.generation_pipeline.add_xform(Downsample(resolution=L))
self.generation_pipeline.add_xform(
Downsample(resolution=L, zero_nyquist=zero_nyquist, legacy=legacy)
)

ds_factor = self.L / L
self.unique_filters = [f.scale(ds_factor) for f in self.unique_filters]
Expand Down
21 changes: 15 additions & 6 deletions src/aspire/volume/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,27 +517,33 @@ def flip(self, axis=-3):
symmetry_group=symmetry,
)

def downsample(self, ds_res, mask=None, zero_nyquist=True):
def downsample(self, ds_res, mask=None, zero_nyquist=True, legacy=False):
"""
Downsample each volume to a desired resolution (only cubic supported).

:param ds_res: Desired resolution.
:param zero_nyquist: Option to keep or remove Nyquist frequency for even resolution.
Defaults to zero_nyquist=True, removing the Nyquist frequency.
:param zero_nyquist: Option to keep or remove Nyquist frequency for even
resolution (boolean). Defaults to zero_nyquist=True, removing the Nyquist frequency.
:param mask: Optional NumPy array mask to multiply in Fourier space.
:param legacy: Option to match legacy Matlab downsample method (boolean).
Default of False uses `centered_fft` to maintain ASPIRE-Python centering conventions.
:return: The downsampled Volume object.
"""

original_stack_shape = self.stack_shape
v = self.stack_reshape(-1)

# take 3D Fourier transform of each volume in the stack
fx = fft.centered_fftn(xp.asarray(v._data))
if legacy:
fx = fft.fftshift(fft.fftn(xp.asarray(v._data)))
else:
fx = fft.centered_fftn(xp.asarray(v._data))

# crop each volume to the desired resolution in frequency space
fx = crop_pad_3d(fx, ds_res)

# If downsample resolution is even, optionally zero out the nyquist frequency.
if ds_res % 2 == 0 and zero_nyquist is True:
if ds_res % 2 == 0 and zero_nyquist and not legacy:
fx[:, 0, :, :] = 0
fx[:, :, 0, :] = 0
fx[:, :, :, 0] = 0
Expand All @@ -547,7 +553,10 @@ def downsample(self, ds_res, mask=None, zero_nyquist=True):
fx = fx * xp.asarray(mask)

# inverse Fourier transform of each volume
out = fft.centered_ifftn(fx)
if legacy:
out = fft.ifftn(fft.ifftshift(fx))
else:
out = fft.centered_ifftn(fx)
out = out.real * (ds_res**3 / self.resolution**3)

# Optionally scale pixel size
Expand Down
57 changes: 51 additions & 6 deletions tests/test_downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_integer_offsets():
DTYPES = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)]
RES = [65, 66]
RES_DS = [32, 33]
LEGACY = [True, False]


@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module")
Expand All @@ -141,6 +142,11 @@ def res_ds(request):
return request.param


@pytest.fixture(params=LEGACY, ids=lambda x: f"legacy={x}", scope="module")
def legacy(request):
return request.param


@pytest.fixture(scope="module")
def emdb_vol():
return emdb_2660()
Expand All @@ -153,20 +159,59 @@ def volume(emdb_vol, res, dtype):
return vol


def test_downsample_project(volume, res_ds):
def test_downsample_project(volume, res_ds, legacy):
"""
Test that vol.downsample.project == vol.project.downsample.
"""
rot = np.eye(3, dtype=volume.dtype) # project along z-axis
im_ds_proj = volume.downsample(res_ds).project(rot)
im_proj_ds = volume.project(rot).downsample(res_ds)
im_ds_proj = volume.downsample(res_ds, legacy=legacy).project(rot)
im_proj_ds = volume.project(rot).downsample(res_ds, legacy=legacy)

tol = 1e-09
if volume.dtype == np.float32:
tol = 1e-07
if legacy:
# project does not enforce legacy centering convention,
# so this property will not hold up to allclose tolerance.
tol = 1e-03

tol = 1e-07
if volume.dtype == np.float64:
tol = 1e-09
np.testing.assert_allclose(im_ds_proj, im_proj_ds, atol=tol)


def test_downsample_legacy(volume, res_ds):
"""
The legay Matlab downsample method differs from ASPIRE-Python
downsample in that is uses a different centering convention,
off by a half pixel for odd images, and does not zero out
the nyquist frequency. By making these alterations to the
ASPIRE-Python downsampled images we can match legacy downsample
upt to `allclose`.
"""
n_img = 10
dtype = volume.dtype
src = Simulation(
n=n_img,
vols=volume,
amplitudes=1,
dtype=dtype,
seed=1980,
)
ims = src.images[:]

# Legacy downsampled images.
ims_ds_legacy = ims.downsample(res_ds, legacy=True)

# ASPIRE-Python downsample with centering adjustments for odd resolution images.
shifts = 0.5 * np.ones((n_img, 2), dtype=dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we replace this with the formula that we found for the shifts? I'm worried that this is only half-pixel in certain special cases like this one.

Copy link
Collaborator Author

@j-c-c j-c-c Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My notes from our meeting are a little incomplete, but what I have working locally for each case is this.

even -> even: no shift needed
even -> odd: shift by 1/2
odd -> odd: shift by 1/2 - res_ds/(2 * res_og)
odd -> even: shift by -res_ds/(2 * res_og)

I'm not getting allclose on the odd -> even case, so maybe I have something wrong here. For what it's worth, the original test works for different downsample ratios since the shifting occurs before and after downsampling depending on where the odd resolution is.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But shouldn't the amount of shifting required differ depending on the downsampling ratio (per the formulas you give above)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Edit Above: I had a typo in the formulas above... was missing a factor of 2 in the denominators.]

Yes, but the 1/2 pixel shift (on line 205) which happens prior to downsampling is converted in the downsampled image into a shift of res_ds/(2 * res_og) pixels by the act of downsampling. So the ratio ends up being accounted for. I believe this is correct and have confirmed it for several different downsample ratios, but it's possible I'm misunderstanding something.

I suspect that I am not getting allclose for the last case due to some difference in our Image.shift method on even vs odd images.

if src.L % 2 == 1:
ims = ims.shift(shifts)
ims_ds_py = ims.downsample(res_ds, zero_nyquist=False)
if res_ds % 2 == 1:
ims_ds_py = ims_ds_py.shift(-shifts)

np.testing.assert_allclose(ims_ds_legacy, ims_ds_py, atol=1e-08)


def test_simulation_relion_downsample():
"""
Test that Simulation.downsample corresponds to RelionSource.downsample
Expand Down
30 changes: 30 additions & 0 deletions tests/test_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from numpy import pi
from pytest import raises, skip

from aspire.downloader import emdb_2660
from aspire.source import _LegacySimulation
from aspire.utils import Rotation, anorm, grid_2d, powerset, utest_tolerance
from aspire.volume import (
Expand All @@ -32,6 +33,7 @@ def res_id(params):


RES = [42, 43]
RES_DS = [32, 33]
TEST_PX_SZ = 4.56


Expand All @@ -40,6 +42,11 @@ def res(request):
return request.param


@pytest.fixture(params=RES_DS, ids=lambda x: f"resolution_ds={x}", scope="module")
def res_ds(request):
return request.param


def dtype_id(params):
return f"dtype={params}"

Expand Down Expand Up @@ -122,6 +129,18 @@ def vols_hot_cold(res, dtype):
return vols, hot_cold_locs, vol_center


@pytest.fixture(scope="module")
def emdb_vol():
return emdb_2660()


@pytest.fixture(scope="module")
def volume(emdb_vol, res, dtype):
vol = emdb_vol.astype(dtype, copy=False)
vol = vol.downsample(res)
return vol


@pytest.fixture
def random_data(res, dtype):
return np.random.randn(res, res, res).astype(dtype)
Expand Down Expand Up @@ -673,6 +692,17 @@ def test_downsample(res):
)


def test_downsample_legacy(volume, res_ds):
"""
Smoke test for the downsample legacy flag.
"""
# Legacy downsampled images.
vol_ds = volume.downsample(res_ds, legacy=True)

# Check downsampled volume resolution.
np.testing.assert_equal(vol_ds.resolution, res_ds)


def test_shape(vols_1, res):
assert vols_1.shape == (N, res, res, res)
assert vols_1.stack_shape == (N,)
Expand Down
Loading