diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 1bbe2ba160..57b32041cd 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -8,6 +8,7 @@ from PIL import Image as PILImage from scipy.linalg import lstsq +import aspire.sinogram import aspire.volume from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp @@ -221,24 +222,25 @@ def project(self, angles): original_stack = self.stack_shape # 2-D grid - radial_idx = np.fft.rfftfreq(n_points) * np.pi * 2 + radial_idx = fft.rfftfreq(n_points) * xp.pi * 2 n_real_points = len(radial_idx) n_angles = len(angles) + angles = xp.asarray(angles) - pts = np.empty((2, n_angles, n_real_points), dtype=self.dtype) - pts[0] = radial_idx[np.newaxis, :] * np.sin(angles)[:, np.newaxis] - pts[1] = radial_idx[np.newaxis, :] * np.cos(angles)[:, np.newaxis] + pts = xp.empty((2, n_angles, n_real_points), dtype=self.dtype) + pts[0] = radial_idx[xp.newaxis, :] * xp.sin(angles)[:, xp.newaxis] + pts[1] = radial_idx[xp.newaxis, :] * xp.cos(angles)[:, xp.newaxis] pts = pts.reshape(2, n_real_points * n_angles) # compute the polar nufft (NUFFT) - image_ft = nufft(self.stack_reshape(-1)._data, pts).reshape( + image_ft = nufft(xp.asarray(self.stack_reshape(-1)._data), pts).reshape( self.n_images, n_angles, n_real_points ) # Radon transform, output: (stack size, angles, points) - image_rt = np.fft.fftshift(np.fft.irfft(image_ft, n=n_points, axis=-1), axes=-1) + image_rt = fft.fftshift(fft.irfft(image_ft, n=n_points, axis=-1), axes=-1) image_rt = image_rt.reshape(*original_stack, n_angles, n_points) - return image_rt + return aspire.sinogram.Sinogram(xp.asnumpy(image_rt)) @property def res(self): diff --git a/src/aspire/numeric/cupy_fft.py b/src/aspire/numeric/cupy_fft.py index 6ad6a4e9da..b491a0dcd1 100644 --- a/src/aspire/numeric/cupy_fft.py +++ b/src/aspire/numeric/cupy_fft.py @@ -100,3 +100,14 @@ def dct(self, x, **kwargs): @_preserve_host def idct(self, x, **kwargs): return cufft.idct(x, **kwargs) + + def rfftfreq(self, n, **kwargs): + return cufft.rfftfreq(n, **kwargs) + + @_preserve_host + def irfft(self, x, **kwargs): + return cufft.irfft(x, **kwargs) + + @_preserve_host + def rfft(self, x, **kwargs): + return cufft.rfft(x, **kwargs) diff --git a/src/aspire/numeric/scipy_fft.py b/src/aspire/numeric/scipy_fft.py index 3891d45671..0ef5c95f16 100644 --- a/src/aspire/numeric/scipy_fft.py +++ b/src/aspire/numeric/scipy_fft.py @@ -39,3 +39,12 @@ def dct(self, x, **kwargs): def idct(self, x, **kwargs): return sp.fft.idct(x, **kwargs) + + def rfftfreq(self, x, **kwargs): + return sp.fft.rfftfreq(x, **kwargs) + + def irfft(self, x, **kwargs): + return sp.fft.irfft(x, **kwargs) + + def rfft(self, x, **kwargs): + return sp.fft.rfft(x, **kwargs) diff --git a/src/aspire/sinogram/__init__.py b/src/aspire/sinogram/__init__.py new file mode 100644 index 0000000000..98e489eedf --- /dev/null +++ b/src/aspire/sinogram/__init__.py @@ -0,0 +1 @@ +from .sinogram import Sinogram diff --git a/src/aspire/sinogram/sinogram.py b/src/aspire/sinogram/sinogram.py new file mode 100644 index 0000000000..34451a396d --- /dev/null +++ b/src/aspire/sinogram/sinogram.py @@ -0,0 +1,155 @@ +import logging + +import numpy as np + +import aspire.image +from aspire.nufft import anufft +from aspire.numeric import fft, xp + +logger = logging.getLogger(__name__) + + +class Sinogram: + def __init__(self, data, dtype=None): + """ + Initialize a Sinogram Object. This is a stack of one or more line projections or sinograms. + + The stack can be multidimensional with 'self.n' equal to the product + of the stack dimensions. Singletons will be expanded into a stack + with one entry. + + :param data: Numpy array containing image data with shape + `(..., angles, radial points)`. + :param dtype: Optionally cast `data` to this dtype. + Defaults to `data.dtype`. + + :return: Sinogram instance holding `data`. + """ + if dtype is None: + self.dtype = data.dtype + else: + self.dtype = np.dtype(dtype) + + if data.ndim == 2: + data = data[np.newaxis, :, :] + if data.ndim < 3: + raise ValueError( + f"Invalid data shape: {data.shape}. Expected shape: (..., angles, radial_points), where '...' is the stack number." + ) + + self._data = data.astype(self.dtype, copy=False) + self.ndim = self._data.ndim + self.shape = self._data.shape + self.stack_shape = self._data.shape[:-2] + self.stack_n_dim = self._data.ndim - 2 + self.n = np.product(self.stack_shape) + self.n_angles = self._data.shape[-2] + self.n_radial_points = self._data.shape[-1] + + # Numpy interop + # https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol + self.__array_interface__ = self._data.__array_interface__ + self.__array__ = self._data + + def _check_key_dims(self, key): + if isinstance(key, tuple) and (len(key) > self._data.ndim): + raise ValueError( + f"Sinogram stack_dim is {self.stack_n_dim}, slice length must be =< {self.n_dim}" + ) + + def __getitem__(self, key): + self._check_key_dims(key) + return self.__class__(self._data[key]) + + def __setitem__(self, key, value): + self._check_key_dims(key) + self._data[key] = value + + def stack_reshape(self, *args): + """ + Reshape the stack axis. + + :*args: Integer(s) or tuple describing the intended shape. + + :return: Sinogram instance + """ + + # If we're passed a tuple, use that + if len(args) == 1 and isinstance(args[0], tuple): + shape = args[0] + else: + # Otherwise use the variadic args + shape = args + + # Sanity check the size + if shape != (-1,) and np.prod(shape) != self.n: + raise ValueError( + f"Number of sinogram images {self.n} cannot be reshaped to {shape}." + ) + + return self.__class__(self._data.reshape(*shape, *self._data.shape[-2:])) + + def asnumpy(self): + """ + Return image data as a (, angles, radians) + read-only array view. + + :return: read-only ndarray view + """ + + view = self._data.view() + view.flags.writeable = False + return view + + def copy(self): + return self.__class__(self._data.copy()) + + def __str__(self): + return f"Sinogram(n_images = {self.n}, n_angles = {self.n_angles}, n_radial_points = {self.n_radial_points})" + + def __repr__(self): + msg = f"Sinogram: {self.n} images of dtype {self.dtype}, " + msg += f"arranged as a stack with shape {self.stack_shape}. " + msg += f"Each image has {self.n_angles} angles and {self.n_radial_points} radial points." + return msg + + def backproject(self, angles): + """ + Backprojection method for a single stack of lines. + + :param angles: np.ndarray + 1D array of angles in radians. Each entry in the array + corresponds to different angles which are used to + reconstruct the image. + :return: An Image object containing the original stack size + with a newly reconstructed numpy array of the images. + Expected return shape should be (..., n_radial_points, n_radial_points) + """ + if len(angles) != self.n_angles: + raise ValueError("Number of angles must match the number of projections.") + + original_stack_shape = self.stack_shape + sinogram = xp.asarray(self.stack_reshape(-1)._data) + L = self.n_radial_points + sinogram = fft.ifftshift(sinogram, axes=-1) + sinogram_ft = fft.rfft(sinogram, axis=-1) + sinogram_ft *= xp.pi # Fix scale to match + sinogram_ft[..., 0] /= 2 # Fix DC + angles = xp.asarray(angles) + + # grid generation with real points + y_idx = fft.rfftfreq(self.n_radial_points) * xp.pi * 2 + n_real_points = len(y_idx) + pts = xp.empty((2, len(angles), n_real_points), dtype=self.dtype) + pts[0] = y_idx[xp.newaxis, :] * xp.sin(angles)[:, xp.newaxis] + pts[1] = y_idx[xp.newaxis, :] * xp.cos(angles)[:, xp.newaxis] + + imgs = anufft( + sinogram_ft.reshape(self.n, -1), + pts.reshape(2, n_real_points * len(angles)), + sz=(L, L), + real=True, + ).reshape(self.n, L, L) + + imgs = imgs / (self.n_radial_points * len(angles)) + return aspire.image.Image(xp.asnumpy(imgs)).stack_reshape(original_stack_shape) diff --git a/tests/test_sinogram.py b/tests/test_sinogram.py index 56aa6776e0..bbf448db97 100644 --- a/tests/test_sinogram.py +++ b/tests/test_sinogram.py @@ -1,14 +1,16 @@ import numpy as np import pytest from skimage import data -from skimage.transform import radon +from skimage.transform import iradon, radon from aspire.image import Image from aspire.utils import grid_2d -# Relative tolerance comparing line projections to scikit -# The same tolerance will be used in all scikit comparisons -SK_TOL = 0.005 +# Relative tolerance comparing sinogram projections to scikit +# The same tolerance will be used in all scikit forward and backward comparisons +SK_TOL_FORWARDPROJECT = 0.005 + +SK_TOL_BACKPROJECT = 0.0025 IMG_SIZES = [ 511, @@ -60,7 +62,7 @@ def masked_image(dtype, img_size): Creates a masked image fixture using camera data from Scikit-Image. """ g = grid_2d(img_size, normalized=True, shifted=True) - mask = g["r"] < 1 + mask = g["r"] < 0.99 image = data.camera().astype(dtype) image = image[:img_size, :img_size] @@ -68,7 +70,7 @@ def masked_image(dtype, img_size): # Image.project and compare results to skimage.radon -def test_image_project(masked_image, num_ang): +def test_project_single(masked_image, num_ang): """ Test Image.project on a single stack of images. Compares project method output with skimage project. """ @@ -86,18 +88,21 @@ def test_image_project(masked_image, num_ang): # the original author of this method. # # Note, transpose sk output to match (angles, points) - reference_sinogram = radon(masked_image._data[0], theta=angles[::-1]).T + # Note, `radon` does not admit read only views, so the slice is copied. + reference_sinogram = radon(masked_image.asnumpy()[0].copy(), theta=angles[::-1]).T assert reference_sinogram.shape == (len(angles), ny), "Incorrect Shape" # compare project method on ski-image reference - nrms = np.sqrt(np.mean((s[0] - reference_sinogram) ** 2, axis=-1)) / np.linalg.norm( - reference_sinogram, axis=-1 - ) + nrms = np.sqrt( + np.mean((s[0].asnumpy() - reference_sinogram) ** 2, axis=-1) + ) / np.linalg.norm(reference_sinogram, axis=-1) - np.testing.assert_array_less(nrms, SK_TOL, "Error in image projections.") + np.testing.assert_array_less( + nrms, SK_TOL_FORWARDPROJECT, err_msg="Error in image projections." + ) -def test_multidim(num_ang): +def test_project_multidim(num_ang): """ Test Image.project on stacks of images. Extension of test_image_project but for multi-dimensional stacks. """ @@ -130,9 +135,138 @@ def test_multidim(num_ang): np.testing.assert_allclose(s[i, j : j + 1], single_sinogram) # Next individually compute sk's radon transform for each image. - reference_sinograms[i, j] = radon(img._data[0], theta=angles[::-1]).T + # Note, `radon` does not admit read only views, so the slice is copied. + reference_sinograms[i, j] = radon( + img.asnumpy()[0].copy(), theta=angles[::-1] + ).T _nrms = np.sqrt(np.mean((s - reference_sinograms) ** 2, axis=-1)) / np.linalg.norm( reference_sinograms, axis=-1 ) - np.testing.assert_array_less(_nrms, SK_TOL, "Error in image projections.") + np.testing.assert_array_less( + _nrms, SK_TOL_FORWARDPROJECT, err_msg="Error in image projections." + ) + + +def test_backproject_single(masked_image, num_ang): + """ + Test Sinogram.backproject on a single stack of line projections (sinograms). + + This test compares the reconstructed image from the `backproject` method to + the skimage method `iradon.` + """ + angles = np.linspace(0, 360, num_ang, endpoint=False) + rads = angles / 180 * np.pi + sinogram = masked_image.project(rads) + sinogram_np = sinogram.asnumpy() + back_project = sinogram.backproject(rads) + + assert masked_image.shape == back_project.shape, "The shape must be the same." + + # generate circular mask w/ radius 1 to reconstructed image + # aim to remove discrepencies for the edges of the image + g = grid_2d(back_project.resolution, normalized=True, shifted=True) + mask = g["r"] < 0.99 + our_back_project = back_project.asnumpy()[0] * mask + + # generating sci-kit image backproject method w/ no filter + sk_image_iradon = iradon(sinogram_np[0].T, theta=-angles, filter_name=None) * mask + + # we apply a normalized root mean square error on the images to find relative error to range of ref. image + nrmse = np.sqrt(np.mean((our_back_project - sk_image_iradon) ** 2)) / ( + np.max(sk_image_iradon) - np.min(sk_image_iradon) + ) + np.testing.assert_array_less( + nrmse, + SK_TOL_BACKPROJECT, + err_msg=f"NRMSE is too high: {nrmse}, expected less than {SK_TOL_BACKPROJECT}", + ) + + +def test_backproject_multidim(num_ang): + """ + Test Sinogram.backproject on a stack of line projections. + + Extension of the `backproject_single` test but checks for multi-dimensional stacks. + """ + L = 512 # pixels + n = 3 + m = 2 + + g = grid_2d(L, normalized=True, shifted=True) + mask = g["r"] < 0.99 + + # Generate images + imgs = Image(np.random.random((m, n, L, L))) * mask + angles = np.linspace(0, 360, num_ang, endpoint=False) + rads = angles / 180 * np.pi + + # apply a forward project on the image, then backwards + ours_forward = imgs.project(rads) + ours_backward = ours_forward.backproject(rads) + + # Compare + reference_back_projects = np.empty((m, n, L, L)) + for i in range(m): + for j in range(n): + img = imgs[i, j] + # Compute the singleton case, and compare with stack. + single_sinogram = img.project(rads) + back_project = single_sinogram.backproject(rads) + + # These should be allclose up to determinism. + np.testing.assert_allclose(ours_backward[i, j : j + 1], back_project[0]) + + # Next individually compute sk's iradon transform for each image. + reference_back_projects[i, j] = ( + iradon( + single_sinogram.asnumpy()[0].T, theta=-1 * angles, filter_name=None + ) + * mask + ) + + # apply a mask, then find the NRMSE on the collection of images + # similar tolerance level to single project test + nrmse = np.sqrt( + np.mean( + (ours_backward.asnumpy() * mask - reference_back_projects), axis=(-2, -1) + ) + ** 2 + ) / ( + np.max(reference_back_projects, axis=(-2, -1)) + - np.min(reference_back_projects, axis=(-2, -1)) + ) + + np.testing.assert_array_less( + nrmse, SK_TOL_BACKPROJECT, err_msg="Error with the reconstructed images." + ) + + +# testing the str method +def test_sinogram_str_method(masked_image, num_ang): + angles = np.linspace(0, 360, num_ang, endpoint=False) + rads = angles / 180 * np.pi + sinogram = masked_image.project(rads) + n_images = sinogram.n + n_angles = sinogram.n_angles + n_radial_points = sinogram.n_radial_points + expected_str = f"Sinogram(n_images = {n_images}, n_angles = {n_angles}, n_radial_points = {n_radial_points})" + assert str(sinogram) == expected_str + + +# testing the repr method +def test_sinogram_repr_method(masked_image, num_ang): + angles = np.linspace(0, 360, num_ang, endpoint=False) + rads = angles / 180 * np.pi + sinogram = masked_image.project(rads) + n_images = sinogram.n + dtype = sinogram.dtype + stack_shape = sinogram.stack_shape + n_angles = sinogram.n_angles + n_radial_points = sinogram.n_radial_points + expected_repr = ( + f"Sinogram: {n_images} images of dtype {dtype}, " + f"arranged as a stack with shape {stack_shape}. " + f"Each image has {n_angles} angles and {n_radial_points} radial points." + ) + assert repr(sinogram) == expected_repr