Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,15 @@
logger.info("Perform phase flip to input images.")
src = src.phase_flip().cache()

# Legacy MATLAB cropped the images to an odd resolution.
src = src.crop_pad(src.L - 1).cache()

# Downsample the images.
logger.info(f"Set the resolution to {img_size} X {img_size}")
src = src.downsample(img_size).cache()
src = src.legacy_downsample(img_size).cache()

# Normalize the background of the images.
src = src.normalize_background().cache()
src = src.legacy_normalize_background().cache()

# Estimate the noise and whiten based on the estimated noise.
src = src.legacy_whiten().cache()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,15 @@
logger.info("Perform phase flip to input images.")
src = src.phase_flip().cache()

# Legacy MATLAB cropped the images to an odd resolution.
src = src.crop_pad(src.L - 1).cache()

# Downsample the images.
logger.info(f"Set the resolution to {img_size} X {img_size}")
src = src.downsample(img_size).cache()
src = src.legacy_downsample(img_size).cache()

# Normalize the background of the images.
src = src.normalize_background().cache()
src = src.legacy_normalize_background().cache()

# Estimate the noise and whiten based on the estimated noise.
src = src.legacy_whiten().cache()
Expand Down
74 changes: 37 additions & 37 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,33 @@
logger = logging.getLogger(__name__)


def normalize_bg(imgs, bg_radius=1.0, do_ramp=True, legacy=False):
def normalize_bg(imgs, bg_radius=1.0, do_ramp=True, shifted=False, ddof=0):
"""
Normalize backgrounds and apply to a stack of images
Normalize backgrounds and apply to a stack of images.

To recreate legacy MATLAB workflow results, review parameters used in
`ImageSource.legacy_normalize_background`.

:param imgs: A stack of images in N-by-L-by-L array
:param bg_radius: Radius cutoff to be considered as background (in image size)
:param do_ramp: When it is `True`, fit a ramping background to the data
and subtract. Namely perform normalization based on values from each image.
Otherwise, a constant background level from all images is used.
:param legacy: Option to match Matlab legacy normalize_background. Default, False,
uses ASPIRE-Python implementation. When True, ramping is disabled, a shifted
2d grid and alternative `bg_radius` is used to generate the background mask,
and standard deviation is computed using N - 1 degrees of freedom.
and subtract. Namely perform normalization based on values from each image.
Otherwise, a constant background level from all images is used.
:param shifted: Optionally shifts 2d grid by 1/2 pixel for even
resolution to replicate MATLAB.
:param ddof: Degrees of freedom for standard deviation.
:return: The modified images
"""
if imgs.ndim > 3:
raise NotImplementedError(
"`normalize_bg` is currently limited to 1D image stacks."
)
L = imgs.shape[-1]

# Make adjustments for legacy mode
shifted = False
ddof = 0 # Degrees of freedom for standard deviation
if legacy:
do_ramp = False
shifted = True # Shifts 2d grid by 1/2 pixel for even resolution
bg_radius = 2 * (L // 2) / L
ddof = 1
input_dtype = imgs.dtype

# Generate background mask
input_dtype = imgs.dtype
grid = grid_2d(L, shifted=shifted, indexing="yx", dtype=input_dtype)
grid_dtype = np.float64 # Use doubles for accuracy and MATLAB repro
grid = grid_2d(L, shifted=shifted, indexing="yx", dtype=grid_dtype)
mask = grid["r"] > bg_radius

if do_ramp:
Expand All @@ -66,14 +60,14 @@ def normalize_bg(imgs, bg_radius=1.0, do_ramp=True, legacy=False):
(
grid["x"][mask].flatten(),
grid["y"][mask].flatten(),
np.ones(grid["y"][mask].flatten().size, dtype=input_dtype),
np.ones(grid["y"][mask].flatten().size, dtype=grid_dtype),
)
).T
ramp_all = np.vstack(
(
grid["x"].flatten(),
grid["y"].flatten(),
np.ones(L * L, dtype=input_dtype),
np.ones(L * L, dtype=grid_dtype),
)
).T
mask_reshape = mask.reshape((L * L))
Expand All @@ -85,10 +79,14 @@ def normalize_bg(imgs, bg_radius=1.0, do_ramp=True, legacy=False):
imgs = imgs.reshape((-1, L, L))

# Apply mask images and calculate mean and std values of background
mean = np.mean(imgs[:, mask], axis=1)
std = np.std(imgs[:, mask], ddof=ddof, axis=1)
# These should be computed and normalized as doubles
bg_pixels = imgs[:, mask].astype(np.float64, copy=False)
mean = np.mean(bg_pixels, axis=1)
std = np.std(bg_pixels, ddof=ddof, axis=1)
imgs = (imgs - mean[:, None, None]) / std[:, None, None]

return (imgs - mean[:, None, None]) / std[:, None, None]
# Restore input dtype
return imgs.astype(input_dtype, copy=False)


def load_mrc(filepath):
Expand Down Expand Up @@ -490,10 +488,11 @@ def legacy_whiten(self, psd, delta):
else:
slc = slice(k - L_half - 1, k + L_half - 1)

# Note these computations should be in double precision
for i, proj in enumerate(self.asnumpy()):

# Zero pad the image to twice the size
padded_proj[slc, slc] = xp.asarray(proj)
padded_proj[slc, slc] = xp.asarray(proj, dtype=np.float64)

# Take the Fourier Transform of the padded image.
fpadded_proj = fft.centered_fft2(padded_proj)
Expand All @@ -515,21 +514,22 @@ def legacy_whiten(self, psd, delta):

filtered_proj = filtered_proj[slc, slc].real

# Assign the resulting image.
res[i] = xp.asnumpy(filtered_proj)
# Assign the resulting image, cast if required.
res[i] = xp.asnumpy(filtered_proj).astype(res.dtype, copy=False)

return Image(res)

def downsample(self, ds_res, zero_nyquist=True, legacy=False):
def downsample(self, ds_res, zero_nyquist=True, centered_fft=True):
"""
Downsample Image to a specific resolution. This method returns a new Image.

:param ds_res: int - new resolution, should be <= the current resolution
of this Image
:param zero_nyquist: Option to keep or remove Nyquist frequency for even
resolution (boolean). Defaults to zero_nyquist=True, removing the Nyquist frequency.
:param legacy: Option to match legacy Matlab downsample method (boolean).
Default of False uses `centered_fft` to maintain ASPIRE-Python centering conventions.
:param centered_fft: Default of True uses `centered_fft` to
maintain ASPIRE-Python centering conventions.

:return: The downsampled Image object.
"""

Expand All @@ -540,25 +540,25 @@ def downsample(self, ds_res, zero_nyquist=True, legacy=False):
# because all of the subsequent calls until `asnumpy` are GPU
# when xp and fft in `cupy` mode.

if legacy:
fx = fft.fftshift(fft.fft2(xp.asarray(im._data)))
else:
if centered_fft:
# compute FT with centered 0-frequency
fx = fft.centered_fft2(xp.asarray(im._data))
else:
fx = fft.fftshift(fft.fft2(xp.asarray(im._data)))

# crop 2D Fourier transform for each image
crop_fx = crop_pad_2d(fx, ds_res)

# If downsampled resolution is even, optionally zero out the nyquist frequency.
if ds_res % 2 == 0 and zero_nyquist and not legacy:
if ds_res % 2 == 0 and zero_nyquist:
crop_fx[:, 0, :] = 0
crop_fx[:, :, 0] = 0

# take back to real space, discard complex part, and scale
if legacy:
out = fft.ifft2(fft.ifftshift(crop_fx))
else:
if centered_fft:
out = fft.centered_ifft2(crop_fx)
else:
out = fft.ifft2(fft.ifftshift(crop_fx))

# The parenths are required because dtype casting semantics
# differs between Numpy 1, 2, and CuPy.
Expand Down
40 changes: 33 additions & 7 deletions src/aspire/image/xform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from joblib import Memory

from aspire.image import Image
from aspire.utils import crop_pad_2d

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -199,33 +200,58 @@ class Downsample(LinearXform):
A Xform that downsamples an Image object to a resolution specified by this Xform's resolution.
"""

def __init__(self, resolution, zero_nyquist=True, legacy=False):
def __init__(self, resolution, zero_nyquist=True, centered_fft=True):
"""
Initialize Xform to downsample Image to a specific resolution.

:param resolution: int - new resolution, should be <= the current resolution
of this Image
:param zero_nyquist: Option to keep or remove Nyquist frequency for even
resolution (boolean). Defaults to zero_nyquist=True, removing the Nyquist frequency.
:param legacy: Option to match legacy Matlab downsample method (boolean).
Default of False uses `centered_fft` to maintain ASPIRE-Python centering conventions.
:param centered_fft: Default of True uses `centered_fft` to
maintain ASPIRE-Python centering conventions.
"""
self.resolution = resolution
self.zero_nyquist = zero_nyquist
self.legacy = legacy
self.centered_fft = centered_fft
super().__init__()

def _forward(self, im, indices):
return im.downsample(
self.resolution, zero_nyquist=self.zero_nyquist, legacy=self.legacy
self.resolution,
zero_nyquist=self.zero_nyquist,
centered_fft=self.centered_fft,
)

def _adjoint(self, im, indices):
# TODO: Implement up-sampling with zero-padding
raise NotImplementedError("Adjoint of downsampling not implemented yet.")

def __str__(self):
return f"Downsample (Resolution {self.resolution})"
return f"Downsample (resolution={self.resolution}, zero_nyquist={self.zero_nyquist}, centered_fft={self.centered_fft}) Xform"


class CropPad(Xform):
"""
A Xform that crops or pads an Image object to a specified size.
"""

def __init__(self, L, fill_value=0):
"""
Initialize Xform to crop Image to a specific size.

:param L: int - new size
:param fill_value: Optional value for padding, default 0.
"""
self.L = L
self.fill_value = fill_value
super().__init__()

def _forward(self, im, indices):
return crop_pad_2d(im, self.L, self.fill_value)

def __str__(self):
return f"CropPad({self.L}, {self.fill_value}) Xform"


class LegacyWhiten(Xform):
Expand Down Expand Up @@ -256,7 +282,7 @@ def _forward(self, im, indices):
return im.legacy_whiten(self.psd, self.delta)

def __str__(self):
return "Legacy Whitening Xform."
return "LegacyWhiten() Xform"


class FilterXform(SymmetricXform):
Expand Down
Loading
Loading