diff --git a/gallery/tutorials/aspire_introduction.py b/gallery/tutorials/aspire_introduction.py index 648750ac01..cffe6d544e 100644 --- a/gallery/tutorials/aspire_introduction.py +++ b/gallery/tutorials/aspire_introduction.py @@ -571,7 +571,7 @@ def noise_function(x, y): # Generate several CTFs. ctf_filters = [ - RadialCTFFilter(pixel_size=5, defocus=d) + RadialCTFFilter(pixel_size=vol_ds.pixel_size, defocus=d) for d in np.linspace(defocus_min, defocus_max, defocus_ct) ] diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 8910436de2..77d304b156 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -63,7 +63,7 @@ defocus_ct = 7 ctf_filters = [ - RadialCTFFilter(pixel_size=5, defocus=d) + RadialCTFFilter(pixel_size=vol.pixel_size, defocus=d) for d in np.linspace(defocus_min, defocus_max, defocus_ct) ] diff --git a/gallery/tutorials/tutorials/cov3d_simulation.py b/gallery/tutorials/tutorials/cov3d_simulation.py index 5fced70fbb..741a47de99 100644 --- a/gallery/tutorials/tutorials/cov3d_simulation.py +++ b/gallery/tutorials/tutorials/cov3d_simulation.py @@ -47,7 +47,9 @@ L=img_size, n=num_imgs, vols=vols, - unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)], + unique_filters=[ + RadialCTFFilter(pixel_size=10, defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + ], dtype=dtype, ) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index cbab268cf6..1bbe2ba160 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -80,7 +80,7 @@ def load_mrc(filepath): Load raw data from `.mrc` into an array. :param filepath: File path (string). - :return: numpy array of image data. + :return: (numpy array of image data, pixel_size) """ # mrcfile tends to yield many warnings about EMPIAR datasets being corrupt @@ -92,6 +92,7 @@ def load_mrc(filepath): with mrcfile.open(filepath, mode="r", permissive=True) as mrc: im = mrc.data + pixel_size = Image._vx_array_to_size(mrc.voxel_size) # Log each mrcfile warning to debug log, noting the associated file for w in ws: @@ -110,19 +111,29 @@ def load_mrc(filepath): f" Will attempt to continue processing {filepath}" ) - return im + return im, pixel_size def load_tiff(filepath): """ Load raw data from `.tiff` into an array. + Note, TIFF does not natively provide equivalent to pixel/voxel_size, + so users of TIFF files may need to manually assign `pixel_size` to + `Image` instances when required. Defaults to `pixel_size=None`. + :param filepath: File path (string). - :return: numpy array of image data. + :return: (numpy array of image data, pixel_size=None) """ + # Use PIL to open `filepath` + img = PILImage.open(filepath) + + # Future todo, extract `voxel_size` if available in TIFF tags (custom tag?) + # For now, default to `None`. + pixel_size = None - # Use PIL to open `filepath` and cast to numpy array. - return np.array(PILImage.open(filepath)) + # Cast image data as numpy array + return np.array(img), pixel_size class Image: @@ -133,7 +144,7 @@ class Image: ".tiff": load_tiff, } - def __init__(self, data, dtype=None): + def __init__(self, data, pixel_size=None, dtype=None): """ A stack of one or more images. @@ -149,6 +160,10 @@ def __init__(self, data, dtype=None): :param data: Numpy array containing image data with shape `(..., resolution, resolution)`. + :param pixel_size: Optional pixel size in angstroms. + When provided will be saved with `mrc` metadata. + Default of `None` will not write to file, + but will be considered unit pixels (1) for FSC. :param dtype: Optionally cast `data` to this dtype. Defaults to `data.dtype`. @@ -180,6 +195,9 @@ def __init__(self, data, dtype=None): self.stack_shape = self._data.shape[:-2] self.n_images = np.prod(self.stack_shape) self.resolution = self._data.shape[-1] + self.pixel_size = None + if pixel_size is not None: + self.pixel_size = float(pixel_size) # Numpy interop # https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol @@ -238,7 +256,7 @@ def _check_key_dims(self, key): def __getitem__(self, key): self._check_key_dims(key) - return self.__class__(self._data[key]) + return self.__class__(self._data[key], pixel_size=self.pixel_size) def __setitem__(self, key, value): self._check_key_dims(key) @@ -266,31 +284,34 @@ def stack_reshape(self, *args): f"Number of images {self.n_images} cannot be reshaped to {shape}." ) - return self.__class__(self._data.reshape(*shape, *self._data.shape[-2:])) + return self.__class__( + self._data.reshape(*shape, *self._data.shape[-2:]), + pixel_size=self.pixel_size, + ) def __add__(self, other): if isinstance(other, Image): other = other._data - return self.__class__(self._data + other) + return self.__class__(self._data + other, pixel_size=self.pixel_size) def __sub__(self, other): if isinstance(other, Image): other = other._data - return self.__class__(self._data - other) + return self.__class__(self._data - other, pixel_size=self.pixel_size) def __mul__(self, other): if isinstance(other, Image): other = other._data - return self.__class__(self._data * other) + return self.__class__(self._data * other, pixel_size=self.pixel_size) def __neg__(self): - return self.__class__(-self._data) + return self.__class__(-self._data, pixel_size=self.pixel_size) def sqrt(self): - return self.__class__(np.sqrt(self._data)) + return self.__class__(np.sqrt(self._data), pixel_size=self.pixel_size) @property def T(self): @@ -312,7 +333,9 @@ def transpose(self): im = self.stack_reshape(-1) imt = np.transpose(im._data, (0, -1, -2)) - return self.__class__(imt).stack_reshape(original_stack_shape) + return self.__class__(imt, pixel_size=self.pixel_size).stack_reshape( + original_stack_shape + ) def flip(self, axis=-2): """ @@ -335,11 +358,15 @@ def flip(self, axis=-2): f"Cannot flip axis {ax}: stack axis. Did you mean {ax-3}?" ) - return self.__class__(np.flip(self._data, axis)) + return self.__class__(np.flip(self._data, axis), pixel_size=self.pixel_size) def __repr__(self): + px_msg = "." + if self.pixel_size is not None: + px_msg = f" with pixel_size={self.pixel_size} angstroms." + msg = f"{self.n_images} {self.dtype} images arranged as a {self.stack_shape} stack" - msg += f" each of size {self.resolution}x{self.resolution}." + msg += f" each of size {self.resolution}x{self.resolution}{px_msg}" return msg def asnumpy(self): @@ -355,7 +382,7 @@ def asnumpy(self): return view def copy(self): - return self.__class__(self._data.copy()) + return self.__class__(self._data.copy(), pixel_size=self.pixel_size) def shift(self, shifts): """ @@ -412,7 +439,14 @@ def downsample(self, ds_res, zero_nyquist=True): out = fft.centered_ifft2(crop_fx).real * (ds_res**2 / self.resolution**2) out = xp.asnumpy(out) - return self.__class__(out).stack_reshape(original_stack_shape) + # Optionally scale pixel size + ds_pixel_size = self.pixel_size + if ds_pixel_size is not None: + ds_pixel_size *= self.resolution / ds_res + + return self.__class__(out, pixel_size=ds_pixel_size).stack_reshape( + original_stack_shape + ) def filter(self, filter): """ @@ -441,7 +475,9 @@ def filter(self, filter): im = xp.asnumpy(im.real) - return self.__class__(im).stack_reshape(original_stack_shape) + return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( + original_stack_shape + ) def rotate(self): raise NotImplementedError @@ -453,6 +489,9 @@ def save(self, mrcs_filepath, overwrite=False): with mrcfile.new(mrcs_filepath, overwrite=overwrite) as mrc: # original input format (the image index first) mrc.set_data(self._data.astype(np.float32)) + # Note assigning voxel_size must come after `set_data` + if self.pixel_size is not None: + mrc.voxel_size = self.pixel_size @staticmethod def load(filepath, dtype=None): @@ -477,14 +516,14 @@ def load(filepath, dtype=None): ) # Call the appropriate file reader - im = Image.extensions[ext](filepath) + im, pixel_size = Image.extensions[ext](filepath) # Attempt casting when user provides dtype if dtype is not None: im = im.astype(dtype, copy=False) # Return as Image instance - return Image(im) + return Image(im, pixel_size=pixel_size) def _im_translate(self, shifts): """ @@ -535,7 +574,9 @@ def _im_translate(self, shifts): im_translated = xp.asnumpy(im_translated.real) # Reshape to stack shape - return self.__class__(im_translated).stack_reshape(stack_shape) + return self.__class__(im_translated, pixel_size=self.pixel_size).stack_reshape( + stack_shape + ) def norm(self): return anorm(self._data) @@ -602,7 +643,9 @@ def backproject(self, rot_matrices, symmetry_group=None, zero_nyquist=True): vol /= L - return aspire.volume.Volume(vol, symmetry_group=symmetry_group) + return aspire.volume.Volume( + vol, pixel_size=self.pixel_size, symmetry_group=symmetry_group + ) def show(self, columns=5, figsize=(20, 10), colorbar=True): """ @@ -645,7 +688,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() - def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): + def frc(self, other, cutoff=None, method="fft", plot=False): r""" Compute the Fourier ring correlation between two images. @@ -663,8 +706,6 @@ def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): Default `None` implies `cutoff=1` and excludes plotting cutoff line. - :param pixel_size: Pixel size in angstrom. Default `None` - implies unit in pixels, equivalent to pixel_size=1. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. :param plot: Optionally plot to screen or file. @@ -684,7 +725,7 @@ def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): frc = FourierRingCorrelation( a=self.asnumpy(), b=other.asnumpy(), - pixel_size=pixel_size, + pixel_size=self.pixel_size, method=method, ) @@ -695,6 +736,32 @@ def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): return frc.analyze_correlations(cutoff), frc.correlations + @staticmethod + def _vx_array_to_size(vx): + """ + Utility to convert from several possible `mrcfile.voxel_size` + representations to a single (float) value or None. + """ + + # Convert from recarray to single values, + # checks uniformity. + if isinstance(vx, np.recarray): + if vx.x != vx.y: + raise ValueError(f"Voxel sizes are not uniform: {vx}") + vx = vx.x + + # Convert `0` to `None` + if ( + isinstance(vx, int) or isinstance(vx, float) or isinstance(vx, np.ndarray) + ) and vx == 0: + vx = None + + # Consistently return a `float` when not None + if vx is not None: + vx = float(vx) + + return vx + class CartesianImage(Image): def expand(self, basis): diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index bb7491c780..e75187fb4a 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -403,7 +403,7 @@ def __init__(self, dim=None): class CTFFilter(Filter): def __init__( self, - pixel_size=10, + pixel_size=1, voltage=200, defocus_u=15000, defocus_v=15000, @@ -415,7 +415,7 @@ def __init__( """ A CTF (Contrast Transfer Function) Filter - :param pixel_size: Pixel size in angstrom + :param pixel_size: Pixel size in angstrom, default 1. :param voltage: Electron voltage in kV :param defocus_u: Defocus depth along the u-axis in angstrom :param defocus_v: Defocus depth along the v-axis in angstrom @@ -425,7 +425,7 @@ def __init__( :param B: Envelope decay in inverse square angstrom (default 0) """ super().__init__(dim=2, radial=defocus_u == defocus_v) - self.pixel_size = pixel_size + self.pixel_size = float(pixel_size) self.voltage = voltage self.wavelength = voltage_to_wavelength(self.voltage) self.defocus_u = defocus_u @@ -482,7 +482,7 @@ def scale(self, c=1): class RadialCTFFilter(CTFFilter): def __init__( - self, pixel_size=10, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0 + self, pixel_size=1, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0 ): super().__init__( pixel_size=pixel_size, diff --git a/src/aspire/source/coordinates.py b/src/aspire/source/coordinates.py index dca7aaf873..299422df70 100644 --- a/src/aspire/source/coordinates.py +++ b/src/aspire/source/coordinates.py @@ -490,7 +490,9 @@ def _images(self, indices): cropped = self._crop_micrograph(arr, next(coord)) im[i] = cropped # Finally, apply transforms to resulting Image - return self.generation_pipeline.forward(Image(im), indices) + return self.generation_pipeline.forward( + Image(im, pixel_size=self.pixel_size), indices + ) @staticmethod def _is_number(text): diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index fa5be1f7f7..4d256bd414 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -150,7 +150,14 @@ class ImageSource(ABC): _mutable = True def __init__( - self, L, n, dtype="double", metadata=None, memory=None, symmetry_group=None + self, + L, + n, + dtype="double", + metadata=None, + memory=None, + symmetry_group=None, + pixel_size=None, ): """ A cryo-EM ImageSource object that supplies images along with other parameters for image manipulation. @@ -163,6 +170,7 @@ def __init__( The path of the base directory to use as a data store or None. If None is given, no caching is performed. :param symmetry_group: A SymmetryGroup instance or string indicating the underlying symmetry of the molecule. Defaults to the `IdentitySymmetryGroup`, which represents an asymmetric particle, if none provided. + :param pixel_size: Pixel size of the images in angstroms, default `None`. """ # Instantiate the accessor for the `images` property @@ -172,6 +180,9 @@ def __init__( self._n = None self.n = n self.dtype = np.dtype(dtype) + if pixel_size is not None: + pixel_size = float(pixel_size) + self.pixel_size = pixel_size # The private attribute '_cached_im' can be populated by calling this object's cache() method explicitly self._cached_im = None @@ -736,7 +747,7 @@ def _apply_filters( f"_apply_filters() passed {type(im_orig)} instead of Image instance" ) # for now just convert it - im_orig = Image(im_orig) + im_orig = Image(im_orig, pixel_size=self.pixel_size) im = im_orig.copy() @@ -1481,6 +1492,7 @@ def __init__(self, src, indices, memory=None): dtype=src.dtype, metadata=metadata, memory=memory, + pixel_size=src.pixel_size, ) # Create filter indices, these are required to pass unharmed through filter eval code @@ -1650,7 +1662,9 @@ class ArrayImageSource(ImageSource): if available, is consulted directly by the parent class, bypassing `_images`. """ - def __init__(self, im, metadata=None, angles=None, symmetry_group=None): + def __init__( + self, im, metadata=None, angles=None, symmetry_group=None, pixel_size=None + ): """ Initialize from an `Image` object. @@ -1659,12 +1673,13 @@ def __init__(self, im, metadata=None, angles=None, symmetry_group=None): :param metadata: A Dataframe of metadata information corresponding to this ImageSource's images :param angles: Optional n-by-3 array of rotation angles corresponding to `im`. :param symmetry_group: A SymmetryGroup instance or string indicating the underlying symmetry of the molecule. + :param pixel_size: Pixel size of the images in angstroms, default `None`. """ if not isinstance(im, Image): logger.info("Attempting to create an Image object from Numpy array.") try: - im = Image(im) + im = Image(im, pixel_size=pixel_size) except Exception as e: raise RuntimeError( "Creating Image object from Numpy array failed." @@ -1678,6 +1693,7 @@ def __init__(self, im, metadata=None, angles=None, symmetry_group=None): metadata=metadata, memory=None, symmetry_group=symmetry_group, + pixel_size=im.pixel_size, ) self._cached_im = im diff --git a/src/aspire/source/micrograph.py b/src/aspire/source/micrograph.py index 182133d982..2d654401b5 100644 --- a/src/aspire/source/micrograph.py +++ b/src/aspire/source/micrograph.py @@ -17,11 +17,14 @@ class MicrographSource(ABC): - def __init__(self, micrograph_count, micrograph_size, dtype): + def __init__(self, micrograph_count, micrograph_size, dtype, pixel_size=None): """ """ self.micrograph_count = int(micrograph_count) self.micrograph_size = int(micrograph_size) self.dtype = np.dtype(dtype) + if pixel_size is not None: + pixel_size = float(pixel_size) + self.pixel_size = pixel_size self._images_accessor = _ImageAccessor(self._images, self.micrograph_count) @@ -85,7 +88,7 @@ def show(self, *args, **kwargs): """ Helper function to display micrograph. See Image.show(). """ - Image(self.asnumpy()).show(*args, **kwargs) + Image(self.asnumpy(), pixel_size=self.pixel_size).show(*args, **kwargs) @property def images(self): @@ -107,7 +110,7 @@ def _images(self, indices): class ArrayMicrographSource(MicrographSource): - def __init__(self, micrographs, dtype=None): + def __init__(self, micrographs, dtype=None, pixel_size=None): """ Instantiate a `MicrographSource` with `micrographs`. @@ -119,6 +122,7 @@ def __init__(self, micrographs, dtype=None): Currently only `float32` and `float64` are supported. Note, due to limitations of common MRC implementations, saving is limited to single precision. + :param pixel_size: Pixel size of the images in angstroms, default `None`. """ # Check micrographs is an array @@ -140,6 +144,7 @@ def __init__(self, micrographs, dtype=None): micrograph_count=micrographs.shape[0], micrograph_size=micrographs.shape[-1], dtype=dtype or micrographs.dtype, + pixel_size=pixel_size, ) # We're already backed by an array, access it directly. @@ -152,11 +157,11 @@ def _images(self, indices): :param indices: A 1-D Numpy array of integer indices. :return: An array backed `MicrographSource` object representing the micrographs for `indices`. """ - return Image(self._data[indices]) + return Image(self._data[indices], pixel_size=self.pixel_size) class DiskMicrographSource(MicrographSource): - def __init__(self, micrographs_path, dtype=None): + def __init__(self, micrographs_path, dtype=None, pixel_size=None): """ Instantiate a `MicrographSource` with `micrographs_path`. @@ -190,11 +195,16 @@ def __init__(self, micrographs_path, dtype=None): # Load the first micrograph to infer shape/type # Size will be checked during on-the-fly loading of subsequent micrographs. micrograph0 = Image.load(self.micrograph_files[0]) + if micrograph0.pixel_size is not None and micrograph0.pixel_size != pixel_size: + raise ValueError( + f"Mismatched pixel size. {micrograph0.pixel_size} angstroms defined in {self.micrograph_files[0]}, but provided {pixel_size} angstroms." + ) super().__init__( micrograph_count=len(self.micrograph_files), micrograph_size=micrograph0.resolution, dtype=dtype or micrograph0.dtype, + pixel_size=pixel_size, ) # Prepare accessor to load files from disk on the fly. @@ -262,8 +272,16 @@ def _images(self, indices): ) # Assign to array, implicitly performs casting to dtype micrographs[i] = micrograph.asnumpy() + # Assert pixel_size + if ( + micrograph.pixel_size is not None + and micrograph.pixel_size != self.pixel_size + ): + raise ValueError( + f"Mismatched pixel size. {micrograph.pixel_size} angstroms defined in {self.micrograph_files[ind]}, but provided {self.pixel_size} angstroms." + ) - return Image(micrographs) + return Image(micrographs, pixel_size=self.pixel_size) class MicrographSimulation(MicrographSource): @@ -557,7 +575,7 @@ def _clean_images(self, indices): self.pad : self.micrograph_size + self.pad, self.pad : self.micrograph_size + self.pad, ] - return Image(clean_micrograph) + return Image(clean_micrograph, pixel_size=self.pixel_size) def get_micrograph_index(self, particle_index): """ diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index 99907cbf6a..bd6d660dd3 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -59,7 +59,6 @@ def __init__( self.filepath = filepath self.data_folder = data_folder - self.pixel_size = pixel_size self.B = B self.n_workers = n_workers self.max_rows = max_rows @@ -112,6 +111,7 @@ def __init__( metadata=metadata, symmetry_group=symmetry_group, memory=memory, + pixel_size=pixel_size, ) # CTF estimation parameters coming from Relion @@ -272,4 +272,6 @@ def load_single_mrcs(filepath, indices): logger.debug(f"Loading {len(indices)} images complete") # Finally, apply transforms to resulting Image - return self.generation_pipeline.forward(Image(im), indices) + return self.generation_pipeline.forward( + Image(im, pixel_size=self.pixel_size), indices + ) diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py index 304d5be56d..331b86e442 100644 --- a/src/aspire/source/simulation.py +++ b/src/aspire/source/simulation.py @@ -50,6 +50,7 @@ def __init__( memory=None, noise_adder=None, symmetry_group=None, + pixel_size=None, ): """ A `Simulation` object that supplies images along with other parameters for image manipulation. @@ -79,6 +80,7 @@ def __init__( :param noise_adder: Optionally append instance of `NoiseAdder` to generation pipeline. :param symmetry_group: A SymmetryGroup instance or string indicating symmetry of the molecule. + :param pixel_size: Pixel size of the images in angstroms, default `None`. :return: A Simulation object. """ @@ -91,6 +93,7 @@ def __init__( self.vols = AsymmetricVolume( L=L or 8, C=C, + pixel_size=pixel_size, seed=self.seed, dtype=dtype or np.float32, ).generate() @@ -122,6 +125,7 @@ def __init__( dtype=self.vols.dtype, memory=memory, symmetry_group=symmetry_group, + pixel_size=self.vols.pixel_size, ) # If a user provides both `L` and `vols`, resolution should match. @@ -153,6 +157,7 @@ def __init__( if unique_filters is None: unique_filters = [] self.unique_filters = unique_filters + self._check_filter_pixel_size(unique_filters) # sim_filters must be a deep copy so that it is not changed # when unique_filters is changed self.sim_filters = copy.deepcopy(unique_filters) @@ -231,6 +236,29 @@ def _populate_ctf_metadata(self, filter_indices): filter_values, ) + def _check_filter_pixel_size(self, unique_filters): + """ + Private method to ensure user provided filters match `Simulation` pixel size. + + When `Simulation.pixel_size` is not `None`, any + `unique_filters` having a non-matching `pixel_size` attribute + will raise. + """ + + # Skip when Simulation pixel_size is not explicitly provided. + if self.pixel_size is None: + return + + for f in unique_filters: + f_pixel_size = getattr(f, "pixel_size", None) + if f_pixel_size is not None and not np.isclose( + f_pixel_size, self.pixel_size + ): + raise ValueError( + f"`Simulation.pixel_size` {self.pixel_size} does not match filter {f} pixel size {f_pixel_size}." + "Ensure provided `pixel_size` attributes match." + ) + @property def projections(self): """ @@ -260,7 +288,7 @@ def _projections(self, indices): im_k = self.vols[k - 1].project(rot_matrices=rot) im[idx_k, :, :] = im_k.asnumpy() - return Image(im) + return Image(im, pixel_size=self.pixel_size) @property def clean_images(self): diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 7e131a0190..5e2212e958 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -58,7 +58,7 @@ class Volume: Volume is an (N1 x ...) x L x L x L array, along with associated utility methods. """ - def __init__(self, data, dtype=None, symmetry_group=None): + def __init__(self, data, dtype=None, pixel_size=None, symmetry_group=None): """ A stack of one or more volumes. @@ -76,6 +76,10 @@ def __init__(self, data, dtype=None, symmetry_group=None): `(..., resolution, resolution, resolution)`. :param dtype: Optionally cast `data` to this dtype. Defaults to `data.dtype`. + :param pixel_size: Optional voxel_size in angstroms. + When provided will be saved with `map`/`mrc` metadata. + Default of `None` will not write to file, + but will be considered unit pixels (1) for FSC. :param symmetry_group: A SymmetryGroup instance or string indicating symmetry of the Volume. :return: A Volume instance holding `data`. @@ -107,6 +111,9 @@ def __init__(self, data, dtype=None, symmetry_group=None): self.n_vols = np.prod(self.stack_shape) self.resolution = self._data.shape[-1] self.size = self._data.size + self.pixel_size = None + if pixel_size is not None: + self.pixel_size = float(pixel_size) # Set symmetry_group. If None, default to 'C1'. self._set_symmetry_group(symmetry_group) @@ -140,7 +147,9 @@ def astype(self, dtype, copy=True): :return: Volume instance """ return self.__class__( - self.asnumpy().astype(dtype, copy=copy), symmetry_group=self.symmetry_group + self.asnumpy().astype(dtype, copy=copy), + pixel_size=self.pixel_size, + symmetry_group=self.symmetry_group, ) def _check_key_dims(self, key): @@ -151,7 +160,11 @@ def _check_key_dims(self, key): def __getitem__(self, key): self._check_key_dims(key) - return self.__class__(self._data[key], symmetry_group=self.symmetry_group) + return self.__class__( + self._data[key], + pixel_size=self.pixel_size, + symmetry_group=self.symmetry_group, + ) def __setitem__(self, key, value): self._check_key_dims(key) @@ -242,14 +255,19 @@ def stack_reshape(self, *args): return self.__class__( self._data.reshape(*shape, *self._data.shape[-3:]), + pixel_size=self.pixel_size, symmetry_group=self.symmetry_group, ) def __repr__(self): + px_msg = "." + if self.pixel_size is not None: + px_msg = f" with pixel_size={self.pixel_size} angstroms." + msg = ( f"{self.n_vols} {self.dtype} volumes arranged as a {self.stack_shape} stack" ) - msg += f" each of size {self.resolution}x{self.resolution}x{self.resolution}." + msg += f" each of size {self.resolution}x{self.resolution}x{self.resolution}{px_msg}" return msg def __len__(self): @@ -258,9 +276,15 @@ def __len__(self): def __add__(self, other): symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data + other.asnumpy(), symmetry_group=symmetry) + res = self.__class__( + self._data + other.asnumpy(), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) else: - res = self.__class__(self._data + other, symmetry_group=symmetry) + res = self.__class__( + self._data + other, pixel_size=self.pixel_size, symmetry_group=symmetry + ) return res @@ -270,21 +294,37 @@ def __radd__(self, otherL): def __sub__(self, other): symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data - other.asnumpy(), symmetry_group=symmetry) + res = self.__class__( + self._data - other.asnumpy(), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) else: - res = self.__class__(self._data - other, symmetry_group=symmetry) + res = self.__class__( + self._data - other, pixel_size=self.pixel_size, symmetry_group=symmetry + ) return res def __rsub__(self, otherL): - return self.__class__(otherL - self._data) + return self.__class__( + otherL - self._data, + pixel_size=self.pixel_size, + symmetry_group=self.symmetry_group, + ) def __mul__(self, other): symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data * other.asnumpy(), symmetry_group=symmetry) + res = self.__class__( + self._data * other.asnumpy(), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) else: - res = self.__class__(self._data * other, symmetry_group=symmetry) + res = self.__class__( + self._data * other, pixel_size=self.pixel_size, symmetry_group=symmetry + ) return res @@ -297,9 +337,15 @@ def __truediv__(self, other): """ symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data / other.asnumpy(), symmetry_group=symmetry) + res = self.__class__( + self._data / other.asnumpy(), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) else: - res = self.__class__(self._data / other, symmetry_group=symmetry) + res = self.__class__( + self._data / other, pixel_size=self.pixel_size, symmetry_group=symmetry + ) return res @@ -307,7 +353,10 @@ def __rtruediv__(self, otherL): """ Right scalar division, follows numpy semantics. """ - return otherL * Volume(1.0 / self._data) + return otherL * Volume( + 1.0 / self._data, + pixel_size=self.pixel_size, + ) def project(self, rot_matrices, zero_nyquist=True): """ @@ -374,8 +423,7 @@ def project(self, rot_matrices, zero_nyquist=True): im_f[:, :, 0] = 0 im_f = fft.centered_ifft2(im_f) - - return aspire.image.Image(xp.asnumpy(im_f.real)) + return aspire.image.Image(xp.asnumpy(im_f.real), pixel_size=self.pixel_size) def to_vec(self): """Returns an N x resolution ** 3 array.""" @@ -419,7 +467,7 @@ def transpose(self): v = self._data.reshape(-1, *self._data.shape[-3:]) vt = np.transpose(v, (0, -1, -2, -3)) vt = vt.reshape(*original_stack_shape, *self._data.shape[-3:]) - return self.__class__(vt, symmetry_group=symmetry) + return self.__class__(vt, pixel_size=self.pixel_size, symmetry_group=symmetry) @property def T(self): @@ -462,7 +510,11 @@ def flip(self, axis=-3): f"Cannot flip axis {ax}: stack axis. Did you mean {ax-4}?" ) - return self.__class__(np.flip(self._data, axis), symmetry_group=symmetry) + return self.__class__( + np.flip(self._data, axis), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) def downsample(self, ds_res, mask=None, zero_nyquist=True): """ @@ -497,9 +549,15 @@ def downsample(self, ds_res, mask=None, zero_nyquist=True): out = fft.centered_ifftn(fx) out = out.real * (ds_res**3 / self.resolution**3) + # Optionally scale pixel size + ds_pixel_size = self.pixel_size + if ds_pixel_size is not None: + ds_pixel_size *= self.resolution / ds_res + # returns a new Volume object return self.__class__( xp.asnumpy(out), + pixel_size=ds_pixel_size, symmetry_group=self.symmetry_group, ).stack_reshape(original_stack_shape) @@ -572,7 +630,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): np.real(fft.centered_ifftn(xp.asarray(vol_f), axes=(-3, -2, -1))) ) - return self.__class__(vol, symmetry_group=symmetry) + return self.__class__(vol, pixel_size=self.pixel_size, symmetry_group=symmetry) def denoise(self): raise NotImplementedError @@ -593,6 +651,9 @@ def save(self, filename, overwrite=False): with mrcfile.new(filename, overwrite=overwrite) as mrc: mrc.set_data(self._data.astype(np.float32)) + # Note assigning voxel_size must come after `set_data` + if self.pixel_size is not None: + mrc.voxel_size = self.pixel_size if self.dtype != np.float32: logger.info(f"Volume with dtype {self.dtype} saved with dtype float32") @@ -612,6 +673,7 @@ def load(cls, filename, permissive=True, dtype=None, symmetry_group=None): """ with mrcfile.open(filename, permissive=permissive) as mrc: loaded_data = mrc.data + pixel_size = Volume._vx_array_to_size(mrc.voxel_size) # FINUFFT work around if loaded_data.dtype == np.float32: @@ -622,9 +684,14 @@ def load(cls, filename, permissive=True, dtype=None, symmetry_group=None): if loaded_data.dtype != dtype: logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") - return cls(loaded_data, symmetry_group=symmetry_group, dtype=dtype) + return cls( + loaded_data, + pixel_size=pixel_size, + symmetry_group=symmetry_group, + dtype=dtype, + ) - def fsc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): + def fsc(self, other, cutoff=None, method="fft", plot=False): r""" Compute the Fourier shell correlation between two volumes. @@ -641,8 +708,6 @@ def fsc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): :param cutoff: Cutoff value, traditionally `.143`. Default `None` implies `cutoff=1` and excludes plotting cutoff line. - :param pixel_size: Pixel size in angstrom. Default `None` - implies unit in pixels, equivalent to pixel_size=1. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. :param plot: Optionally plot to screen or file. @@ -662,7 +727,7 @@ def fsc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): fsc = FourierShellCorrelation( a=self.asnumpy(), b=other.asnumpy(), - pixel_size=pixel_size, + pixel_size=self.pixel_size, method=method, ) @@ -681,7 +746,7 @@ def empty_like(v): :param v: Volume instance :return: Volume instance """ - return Volume(np.empty(v.shape, dtype=v.dtype)) + return Volume(np.empty(v.shape, dtype=v.dtype), pixel_size=v.pixel_size) @staticmethod def zeros_like(v): @@ -691,7 +756,33 @@ def zeros_like(v): :param v: Volume instance :return: Volume instance """ - return Volume(np.zeros(v.shape, dtype=v.dtype)) + return Volume(np.zeros(v.shape, dtype=v.dtype), pixel_size=v.pixel_size) + + @staticmethod + def _vx_array_to_size(vx): + """ + Utility to convert from several possible `mrcfile.voxel_size` + representations to a single (float) value or None. + """ + + # Convert from recarray to single values, + # checks uniformity. + if isinstance(vx, np.recarray): + if vx.x != vx.y or vx.x != vx.z: + raise ValueError(f"Voxel sizes are not uniform: {vx}") + vx = vx.x + + # Convert `0` to `None` + if ( + isinstance(vx, int) or isinstance(vx, float) or isinstance(vx, np.ndarray) + ) and vx == 0: + vx = None + + # Consistently return a `float` when not None + if vx is not None: + vx = float(vx) + + return vx class CartesianVolume(Volume): diff --git a/src/aspire/volume/volume_synthesis.py b/src/aspire/volume/volume_synthesis.py index b9514df5ea..43f794bfaf 100644 --- a/src/aspire/volume/volume_synthesis.py +++ b/src/aspire/volume/volume_synthesis.py @@ -16,11 +16,12 @@ class SyntheticVolumeBase(abc.ABC): - def __init__(self, L, C, seed=None, dtype=np.float64): + def __init__(self, L, C, pixel_size=None, seed=None, dtype=np.float64): self.L = L self.C = C self.seed = seed self.dtype = dtype + self.pixel_size = pixel_size @abc.abstractmethod def generate(self): @@ -39,18 +40,24 @@ class GaussianBlobsVolume(SyntheticVolumeBase): A base class for all volumes which are generated with randomized 3D Gaussians. """ - def __init__(self, L, C, K=16, alpha=1, seed=None, dtype=np.float64): + def __init__( + self, L, C, K=16, alpha=1, pixel_size=None, seed=None, dtype=np.float64 + ): """ :param L: Resolution of the Volume(s) in pixels. :param C: Number of Volumes to generate. :param K: Number of Gaussian blobs used to construct the Volume(s). :param alpha: Scaling factor for variance of Gaussian blobs. Default=1. + :param pixel_size: Optional voxel_size in angstroms. + When provided will be saved with `map`/`mrc` metadata. + Default of `None` will not write to file, + but will be considered unit pixels (1) for FSC. :param seed: Random seed for generating random Gaussian blobs. :param dtype: dtype for Volume(s) """ self.K = int(K) self.alpha = float(alpha) - super().__init__(L=L, C=C, seed=seed, dtype=dtype) + super().__init__(L=L, C=C, pixel_size=pixel_size, seed=seed, dtype=dtype) self._set_symmetry_group() @abc.abstractproperty @@ -75,7 +82,11 @@ def generate(self): """ vol = self._gaussian_blob_vols() bump_mask = bump_3d(self.L, spread=5, dtype=self.dtype) - return Volume(bump_mask * vol, symmetry_group=self.symmetry_group) + return Volume( + bump_mask * vol, + symmetry_group=self.symmetry_group, + pixel_size=self.pixel_size, + ) def _gaussian_blob_vols(self): """ @@ -168,18 +179,26 @@ class CnSymmetricVolume(GaussianBlobsVolume): A Volume object with cyclically symmetric volumes constructed of random 3D Gaussian blobs. """ - def __init__(self, L, C, order, K=16, alpha=1, seed=None, dtype=np.float64): + def __init__( + self, L, C, order, K=16, alpha=1, pixel_size=None, seed=None, dtype=np.float64 + ): """ :param L: Resolution of the Volume(s) in pixels. :param C: Number of Volumes to generate. :param order: An integer representing the cyclic order of the Volume(s). :param K: Number of Gaussian blobs used to construct the Volume(s). + :param pixel_size: Optional voxel_size in angstroms. + When provided will be saved with `map`/`mrc` metadata. + Default of `None` will not write to file, + but will be considered unit pixels (1) for FSC. :param seed: Random seed for generating random Gaussian blobs. :param dtype: dtype for Volume(s) """ self.order = int(order) self._check_order() - super().__init__(L=L, C=C, K=K, alpha=alpha, seed=seed, dtype=dtype) + super().__init__( + L=L, C=C, K=K, alpha=alpha, pixel_size=pixel_size, seed=seed, dtype=dtype + ) def _check_order(self): if self.order < 2: @@ -239,8 +258,10 @@ class AsymmetricVolume(CnSymmetricVolume): An asymmetric Volume constructed of random 3D Gaussian blobs with compact support in the unit sphere. """ - def __init__(self, L, C, K=64, seed=None, dtype=np.float64): - super().__init__(L=L, C=C, K=K, order=1, seed=seed, dtype=dtype) + def __init__(self, L, C, K=64, pixel_size=None, seed=None, dtype=np.float64): + super().__init__( + L=L, C=C, K=K, order=1, pixel_size=pixel_size, seed=seed, dtype=dtype + ) def _check_order(self): if self.order != 1: @@ -260,8 +281,8 @@ class LegacyVolume(AsymmetricVolume): An asymmetric Volume object used for testing of legacy code. """ - def __init__(self, L, C=2, K=16, seed=0, dtype=np.float64): - super().__init__(L=L, C=C, K=K, seed=seed, dtype=dtype) + def __init__(self, L, C=2, K=16, pixel_size=None, seed=0, dtype=np.float64): + super().__init__(L=L, C=C, K=K, pixel_size=pixel_size, seed=seed, dtype=dtype) def generate(self): """ @@ -272,4 +293,4 @@ def generate(self): # Swap axes to retain Legacy xyz-indexing. vols = np.swapaxes(vols, 1, 3) - return Volume(vols) + return Volume(vols, pixel_size=self.pixel_size) diff --git a/tests/test_anisotropic_noise.py b/tests/test_anisotropic_noise.py index 2fd1d13ca8..caaedc4aff 100644 --- a/tests/test_anisotropic_noise.py +++ b/tests/test_anisotropic_noise.py @@ -20,7 +20,9 @@ def setUp(self): n=1024, vols=self.vol, unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + # Set legacy pixel size + RadialCTFFilter(pixel_size=10, defocus=d) + for d in np.linspace(1.5e4, 2.5e4, 7) ], dtype=self.dtype, ) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 305f6e4ec4..4c990212ce 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -90,6 +90,9 @@ def test_downsample_2d_case(L, L_ds): assert (N, L_ds, L_ds) == imgs_ds.shape # check center points for all images assert checkCenterPoint(imgs_org, imgs_ds) + # Confirm default `pixel_size` + assert imgs_org.pixel_size is None + assert imgs_ds.pixel_size is None @pytest.mark.parametrize("L", [65, 66]) @@ -103,6 +106,9 @@ def test_downsample_3d_case(L, L_ds): assert checkCenterPoint(vols_org, vols_ds) # check signal energy is conserved assert checkSignalEnergy(vols_org, vols_ds) + # Confirm default `pixel_size` + assert vols_org.pixel_size is None + assert vols_ds.pixel_size is None def test_integer_offsets(): @@ -155,3 +161,25 @@ def test_downsample_project(volume, res_ds): if volume.dtype == np.float64: tol = 1e-09 np.testing.assert_allclose(im_ds_proj, im_proj_ds, atol=tol) + + +def test_pixel_size(): + """ + Test downsampling is rescaling the `pixel_size` attribute. + """ + # Image sizes in pixels + L = 8 # original + dsL = 5 # downsampled + + # Construct a small test Image + img = Image(np.random.random((1, L, L)).astype(DTYPE, copy=False), pixel_size=1.23) + + # Downsample the image + result = img.downsample(dsL) + + # Confirm the pixel size is scaled + np.testing.assert_approx_equal( + result.pixel_size, + img.pixel_size * L / dsL, + err_msg="Incorrect pixel size.", + ) diff --git a/tests/test_filters.py b/tests/test_filters.py index 911e3b347b..b0b23bb74f 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -118,7 +118,8 @@ def testRadialCTFFilter(self): self.assertEqual(result.shape, (256,)) def testRadialCTFFilterGrid(self): - filter = RadialCTFFilter(defocus=2.5e4) + # Set legacy pixel size + filter = RadialCTFFilter(pixel_size=10, defocus=2.5e4) result = filter.evaluate_grid(8, dtype=self.dtype) self.assertEqual(result.shape, (8, 8)) @@ -218,7 +219,10 @@ def testRadialCTFFilterGrid(self): ) def testRadialCTFFilterMultiplierGrid(self): - filter = RadialCTFFilter(defocus=2.5e4) * RadialCTFFilter(defocus=2.5e4) + # Set legacy pixel size + filter = RadialCTFFilter(pixel_size=10, defocus=2.5e4) * RadialCTFFilter( + pixel_size=10, defocus=2.5e4 + ) result = filter.evaluate_grid(8, dtype=self.dtype) self.assertEqual(result.shape, (8, 8)) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 282089ce12..79240572f8 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -115,7 +115,7 @@ def volume_fixture(img_size, dtype): def test_frc_id(image_fixture, method): img, _, _ = image_fixture - frc_resolution, frc = img.frc(img, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img.frc(img, cutoff=0.143, method=method) assert np.isclose(frc_resolution[0], 2, rtol=0.02) assert np.allclose(frc, 1, rtol=0.01) @@ -123,14 +123,14 @@ def test_frc_id(image_fixture, method): def test_frc_trunc(image_fixture, method): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype - frc_resolution, frc = img_a.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img_a.frc(img_b, cutoff=0.143, method=method) assert frc_resolution[0] > 3.0 def test_frc_noise(image_fixture, method): img_a, _, img_n = image_fixture - frc_resolution, frc = img_a.frc(img_n, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img_a.frc(img_n, cutoff=0.143, method=method) assert frc_resolution[0] > 3.5 @@ -142,13 +142,13 @@ def test_frc_img_plot(image_fixture): # Plot to screen with matplotlib_no_gui(): - _ = img_a.frc(img_n, pixel_size=1, cutoff=0.143, plot=True) + _ = img_a.frc(img_n, cutoff=0.143, plot=True) # Plot to file # Also tests `cutoff=None` with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") - img_a.frc(img_n, pixel_size=1, cutoff=None, plot=file_path) + img_a.frc(img_n, cutoff=None, plot=file_path) assert os.path.exists(file_path) @@ -160,9 +160,7 @@ def test_frc_plot(image_fixture, method): """ img_a, img_b, _ = image_fixture - frc = FourierRingCorrelation( - img_a.asnumpy(), img_b.asnumpy(), pixel_size=1, method=method - ) + frc = FourierRingCorrelation(img_a.asnumpy(), img_b.asnumpy(), method=method) with matplotlib_no_gui(): frc.plot(cutoff=0.5) @@ -178,7 +176,7 @@ def test_frc_plot(image_fixture, method): def test_fsc_id(volume_fixture, method): vol, _ = volume_fixture - fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, cutoff=0.143, method=method) + fsc_resolution, fsc = vol.fsc(vol, cutoff=0.143, method=method) assert np.isclose(fsc_resolution[0], 2, rtol=0.02) assert np.allclose(fsc, 1, rtol=0.01) @@ -186,11 +184,11 @@ def test_fsc_id(volume_fixture, method): def test_fsc_trunc(volume_fixture, method): vol_a, vol_b = volume_fixture - fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) + fsc_resolution, fsc = vol_a.fsc(vol_b, cutoff=0.143, method=method) assert fsc_resolution[0] > 3.0 # The follow should correspond to the test_fsc_plot below. - fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, method=method) + fsc_resolution, fsc = vol_a.fsc(vol_b, cutoff=0.5, method=method) assert fsc_resolution[0] > 3.9 @@ -202,13 +200,13 @@ def test_fsc_vol_plot(volume_fixture): # Plot to screen with matplotlib_no_gui(): - _ = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, plot=True) + _ = vol_a.fsc(vol_b, cutoff=0.5, plot=True) # Plot to file # Also tests `cutoff=None` with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "vol_fsc_curve.png") - vol_a.fsc(vol_b, pixel_size=1, cutoff=None, plot=file_path) + vol_a.fsc(vol_b, cutoff=None, plot=file_path) assert os.path.exists(file_path) @@ -218,9 +216,7 @@ def test_fsc_plot(volume_fixture, method): """ vol_a, vol_b = volume_fixture - fsc = FourierShellCorrelation( - vol_a.asnumpy(), vol_b.asnumpy(), pixel_size=1, method=method - ) + fsc = FourierShellCorrelation(vol_a.asnumpy(), vol_b.asnumpy(), method=method) with matplotlib_no_gui(): fsc.plot(cutoff=0.5) @@ -306,7 +302,7 @@ def test_img_type_mismatch(): b = a.asnumpy() with pytest.raises(TypeError, match=r"`other` image must be an `Image` instance"): - _ = a.frc(b, pixel_size=1, cutoff=0.143) + _ = a.frc(b, cutoff=0.143) def test_vol_type_mismatch(): @@ -314,7 +310,7 @@ def test_vol_type_mismatch(): b = a.asnumpy() with pytest.raises(TypeError, match=r"`other` volume must be an `Volume` instance"): - _ = a.fsc(b, pixel_size=1, cutoff=0.143) + _ = a.fsc(b, cutoff=0.143) # Broadcasting @@ -329,7 +325,7 @@ def test_frc_id_bcast(image_fixture, method): k = 3 img_b = Image(np.tile(img, (3, 1, 1))) - frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img.frc(img_b, cutoff=0.143, method=method) assert np.allclose( frc_resolution, [ @@ -344,7 +340,7 @@ def test_frc_id_bcast(image_fixture, method): # (1) x (1,3) img_b = img_b.stack_reshape(1, 3) - frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img.frc(img_b, cutoff=0.143, method=method) assert np.allclose( frc_resolution, [ @@ -359,7 +355,7 @@ def test_frc_id_bcast(image_fixture, method): # (1) x (3,1) img_b = img_b.stack_reshape(3, 1) - frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img.frc(img_b, cutoff=0.143, method=method) assert np.allclose( frc_resolution, [ @@ -378,7 +374,7 @@ def test_fsc_id_bcast(volume_fixture, method): k = 3 vol_b = Volume(np.tile(vol.asnumpy(), (3, 1, 1, 1))) - fsc_resolution, fsc = vol.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) + fsc_resolution, fsc = vol.fsc(vol_b, cutoff=0.143, method=method) assert np.allclose( fsc_resolution, [ @@ -400,12 +396,12 @@ def test_frc_img_plot_bcast(image_fixture): # Plot to screen, one:many with matplotlib_no_gui(): - _ = img_a.frc(img_b, pixel_size=1, cutoff=0.143, plot=True) + _ = img_a.frc(img_b, cutoff=0.143, plot=True) # Plot to file, many elementwise with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") - img_b.frc(img_b, pixel_size=1, cutoff=0.143, plot=file_path) + img_b.frc(img_b, cutoff=0.143, plot=file_path) assert os.path.exists(file_path) diff --git a/tests/test_image.py b/tests/test_image.py index 688d4169ec..887e726c0d 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -24,8 +24,22 @@ n = 3 mdim = 2 +PARITY = [0, 1] +DTYPES = [np.float32, np.float64] -def get_images(parity=0, dtype=np.float32): + +@pytest.fixture(params=PARITY, ids=lambda x: f"parity={x}", scope="module") +def parity(request): + return request.param + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def get_images(parity, dtype): size = 768 - parity # numpy array for top-level functions that directly expect it im_np = face(gray=True).astype(dtype)[np.newaxis, :size, :size] @@ -33,36 +47,40 @@ def get_images(parity=0, dtype=np.float32): im_np /= denom # Normalize test image data to 0,1 # Independent Image object for testing Image methods - im = Image(im_np.copy()) + im = Image(im_np.copy(), pixel_size=1.23) return im_np, im -def get_stacks(parity=0, dtype=np.float32): - im_np, im = get_images(parity, dtype) +@pytest.fixture(scope="module") +def get_stacks(get_images, dtype): + im_np, im = get_images # Construct a simple stack of Images - ims_np = np.empty((n, *im_np.shape[1:]), dtype=dtype) + ims_np = np.empty((n, *im_np.shape[1:]), dtype=im_np.dtype) for i in range(n): ims_np[i] = im_np * (i + 1) / float(n) # Independent Image stack object for testing Image methods - ims = Image(ims_np) + ims = Image(ims_np.copy()) return ims_np, ims -def get_mdim_images(parity=0, dtype=np.float32): - ims_np, im = get_stacks(parity, dtype) +# Note that `get_mdim_images` is mutated by some tests, +# force per function scope. +@pytest.fixture(scope="function") +def get_mdim_images(get_stacks): + ims_np, im = get_stacks # Multi dimensional stack Image object mdim = 2 mdim_ims_np = np.concatenate([ims_np] * mdim).reshape(mdim, *ims_np.shape) # Independent multidimensional Image stack object for testing Image methods - mdim_ims = Image(mdim_ims_np) + mdim_ims = Image(mdim_ims_np.copy()) return mdim_ims_np, mdim_ims -def testRepr(): - _, mdim_ims = get_mdim_images() +def testRepr(get_mdim_images): + _, mdim_ims = get_mdim_images r = repr(mdim_ims) logger.info(f"Image repr:\n{r}") @@ -73,9 +91,8 @@ def testNonSquare(): _ = Image(np.empty((4, 5))) -@pytest.mark.parametrize("parity,dtype", params) -def testImShift(parity, dtype): - im_np, im = get_images(parity, dtype) +def testImShift(get_images, dtype): + im_np, im = get_images # Note that the _im_translate method can handle float input shifts, as it # computes the shifts in Fourier space, rather than performing a roll # However, NumPy's roll() only accepts integer inputs @@ -101,10 +118,8 @@ def testImShift(parity, dtype): np.testing.assert_allclose(im0.asnumpy()[0, :, :], im3, atol=atol) -@pytest.mark.parametrize("parity,dtype", params) -def testImShiftStack(parity, dtype): - ims_np, ims = get_stacks(parity, dtype) - +def testImShiftStack(get_stacks, dtype): + ims_np, ims = get_stacks # test stack of shifts (same number as Image.num_img) # mix of odd and even shifts = np.array([[100, 200], [203, 150], [55, 307]]) @@ -131,8 +146,8 @@ def testImShiftStack(parity, dtype): np.testing.assert_allclose(im0.asnumpy(), im3, atol=atol) -def testImageShiftErrors(): - _, im = get_images(0, np.float32) +def testImageShiftErrors(get_images): + _, im = get_images # test bad shift shape with pytest.raises(ValueError, match="Input shifts must be of shape"): _ = im.shift(np.array([100, 100, 100])) @@ -141,18 +156,16 @@ def testImageShiftErrors(): _ = im.shift(np.array([[100, 200], [100, 200]])) -@pytest.mark.parametrize("parity,dtype", params) -def testImageSqrt(parity, dtype): - im_np, im = get_images(parity, dtype) - ims_np, ims = get_stacks(parity, dtype) +def testImageSqrt(get_images, get_stacks): + im_np, im = get_images + ims_np, ims = get_stacks assert np.allclose(im.sqrt().asnumpy(), np.sqrt(im_np)) assert np.allclose(ims.sqrt().asnumpy(), np.sqrt(ims_np)) -@pytest.mark.parametrize("parity,dtype", params) -def testImageTranspose(parity, dtype): - im_np, im = get_images(parity, dtype) - ims_np, ims = get_stacks(parity, dtype) +def testImageTranspose(get_images, get_stacks): + im_np, im = get_images + ims_np, ims = get_stacks # test method and abbreviation assert np.allclose(im.T.asnumpy(), np.transpose(im_np, (0, 2, 1))) assert np.allclose(im.transpose().asnumpy(), np.transpose(im_np, (0, 2, 1))) @@ -163,10 +176,9 @@ def testImageTranspose(parity, dtype): assert np.allclose(ims.transpose()[i], ims_np[i].T) -@pytest.mark.parametrize("parity,dtype", params) -def testImageFlip(parity, dtype): - im_np, im = get_images(parity, dtype) - ims_np, ims = get_stacks(parity, dtype) +def testImageFlip(get_images, get_stacks): + im_np, im = get_images + ims_np, ims = get_stacks for axis in powerset(range(1, 3)): if not axis: # test default @@ -188,31 +200,31 @@ def testImageFlip(parity, dtype): _ = im.flip(axis) -def testShape(): - ims_np, ims = get_stacks() +def testShape(get_stacks): + ims_np, ims = get_stacks assert ims.shape == ims_np.shape assert ims.stack_shape == ims_np.shape[:-2] assert ims.stack_ndim == 1 -def testMultiDimShape(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimShape(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images assert mdim_ims.shape == mdim_ims_np.shape assert mdim_ims.stack_shape == mdim_ims_np.shape[:-2] assert mdim_ims.stack_ndim == mdim assert mdim_ims.n_images == mdim * ims.n_images -def testBadKey(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testBadKey(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images with pytest.raises(ValueError, match="slice length must be"): _ = mdim_ims[tuple(range(mdim_ims.ndim + 1))] -def testMultiDimGets(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimGets(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images for X in mdim_ims: assert np.allclose(ims_np, X) @@ -220,9 +232,9 @@ def testMultiDimGets(): assert np.allclose(mdim_ims[:, 1:], ims[1:]) -def testMultiDimSets(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimSets(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images mdim_ims[0, 1] = 123 # Check the values changed assert np.allclose(mdim_ims[0, 1], 123) @@ -232,9 +244,9 @@ def testMultiDimSets(): assert np.allclose(mdim_ims[1, :], ims_np) -def testMultiDimSetsSlice(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimSetsSlice(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images # Test setting a slice mdim_ims[0, 1:] = 456 # Check the values changed @@ -244,9 +256,9 @@ def testMultiDimSetsSlice(): assert np.allclose(mdim_ims[1, :], ims_np) -def testMultiDimReshape(): +def testMultiDimReshape(get_mdim_images): # Try mdim reshape - mdim_ims_np, mdim_ims = get_mdim_images() + mdim_ims_np, mdim_ims = get_mdim_images X = mdim_ims.stack_reshape(*mdim_ims.stack_shape[::-1]) assert X.stack_shape == mdim_ims.stack_shape[::-1] # Compare with direct np.reshape of axes of ndarray @@ -254,22 +266,22 @@ def testMultiDimReshape(): assert np.allclose(X.asnumpy(), mdim_ims_np.reshape(shape)) -def testMultiDimFlattens(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimFlattens(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images # Try flattening X = mdim_ims.stack_reshape(mdim_ims.n_images) assert X.stack_shape, (mdim_ims.n_images,) -def testMultiDimFlattensTrick(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimFlattensTrick(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images # Try flattening with -1 X = mdim_ims.stack_reshape(-1) assert X.stack_shape == (mdim_ims.n_images,) -def testMultiDimReshapeTuples(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimReshapeTuples(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images # Try flattening with (-1,) X = mdim_ims.stack_reshape((-1,)) assert X.stack_shape, (mdim_ims.n_images,) @@ -279,8 +291,8 @@ def testMultiDimReshapeTuples(): assert X.stack_shape == mdim_ims.stack_shape[::-1] -def testMultiDimBadReshape(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimBadReshape(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images # Incorrect flat shape with pytest.raises(ValueError, match="Number of images"): _ = mdim_ims.stack_reshape(8675309) @@ -290,11 +302,11 @@ def testMultiDimBadReshape(): _ = mdim_ims.stack_reshape(42, 8675309) -def testMultiDimBroadcast(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimBroadcast(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images X = mdim_ims + ims - assert np.allclose(X[0], 2 * ims.asnumpy()) + np.testing.assert_allclose(X[0], 2 * ims.asnumpy()) @matplotlib_dry_run @@ -306,12 +318,12 @@ def testShow(): im.show() -def test_backproject_symmetry_group(): +def test_backproject_symmetry_group(dtype): """ Test backproject SymmetryGroup pass through and error message. """ ary = np.random.random((5, 8, 8)) - im = Image(ary) + im = Image(ary, dtype=dtype) rots = Rotation.generate_random_rotations(5).matrices # Attempt backproject with bad symmetry group. @@ -324,9 +336,7 @@ def test_backproject_symmetry_group(): assert isinstance(vol.symmetry_group, CnSymmetryGroup) # Symmetry from instance. - vol = im.backproject( - rots, symmetry_group=CnSymmetryGroup(order=3, dtype=np.float32) - ) + vol = im.backproject(rots, symmetry_group=CnSymmetryGroup(order=3, dtype=dtype)) assert isinstance(vol.symmetry_group, CnSymmetryGroup) @@ -381,7 +391,7 @@ def test_load_bad_ext(): _ = Image.load("bad.ext") -def test_load_mrc(): +def test_load_mrc(dtype): """ Test `Image.load` round-trip. """ @@ -390,27 +400,19 @@ def test_load_mrc(): filepath = os.path.join(DATA_DIR, "sample.mrc") # Load data from file - im = Image.load(filepath) - im_64 = Image.load(filepath, dtype=np.float64) + im = Image.load(filepath, dtype=dtype) with tempfile.TemporaryDirectory() as tmpdir_name: # tmp filename test_filepath = os.path.join(tmpdir_name, "test.mrc") - test_filepath_64 = os.path.join(tmpdir_name, "test_64.mrc") im.save(test_filepath) - im_64.save(test_filepath_64) - im2 = Image.load(test_filepath) - im2_64 = Image.load(test_filepath_64, dtype=np.float64) + im2 = Image.load(test_filepath, dtype) # Check the single precision round-trip assert np.array_equal(im, im2) - assert im2.dtype == np.float32 - - # check the double precision round-trip - assert np.array_equal(im_64, im2_64) - assert im2_64.dtype == np.float64 + assert im2.dtype == dtype def test_load_tiff(): @@ -436,3 +438,30 @@ def test_load_tiff(): # Check contents assert np.array_equal(im, im2) + + +def test_save_load_pixel_size(get_images, dtype): + """ + Test saving and loading an MRC with pixel size attribute + """ + + im_np, im = get_images + + with tempfile.TemporaryDirectory() as tmpdir_name: + # tmp filename + test_filepath = os.path.join(tmpdir_name, "test.mrc") + + # Save image to file + im.save(test_filepath) + + # Load image from file + im2 = Image.load(test_filepath, dtype) + + # Check we've loaded the image data + np.testing.assert_allclose(im2, im) + # Check we've loaded the image dtype + assert im2.dtype == im.dtype, "Image dtype mismatched on save-load" + # Check we've loaded the pixel size + np.testing.assert_almost_equal( + im2.pixel_size, im.pixel_size, err_msg="Image pixel_size incorrect save-load" + ) diff --git a/tests/test_mean_estimator_boosting.py b/tests/test_mean_estimator_boosting.py index 9251dee09e..6eac159115 100644 --- a/tests/test_mean_estimator_boosting.py +++ b/tests/test_mean_estimator_boosting.py @@ -122,7 +122,7 @@ def weighted_source(weighted_volume): def test_fsc(source, estimated_volume): """Compare estimated volume to source volume with FSC.""" # Fourier Shell Correlation - fsc_resolution, fsc = source.vols.fsc(estimated_volume, pixel_size=1, cutoff=0.5) + fsc_resolution, fsc = source.vols.fsc(estimated_volume, cutoff=0.5) # Check that resolution is less than 2.1 pixels. np.testing.assert_array_less(fsc_resolution, 2.1) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index ad8a7ff4e1..92a29e225e 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -21,15 +21,19 @@ class SingleSimTestCase(TestCase): """Test we can construct a length 1 Sim.""" def setUp(self): - self.sim = Simulation( - n=1, - L=8, - ) + self._pixel_size = 1.23 # Test value + + self.sim = Simulation(n=1, L=8, pixel_size=self._pixel_size) def testImage(self): """Test we can get an Image from a length 1 Sim.""" _ = self.sim.images[0] + def testPixelSize(self): + """Test pixel_size is passing through Simulation.""" + self.assertTrue(self.sim.pixel_size == self._pixel_size) + self.assertTrue(self.sim.pixel_size == self.sim.vols.pixel_size) + @matplotlib_dry_run def testImageShow(self): self.sim.images[:].show() @@ -106,9 +110,12 @@ def setUp(self): self.n = 1024 self.L = 8 self.dtype = np.float32 + # Set legacy pixel_size + self._pixel_size = 10 self.vols = LegacyVolume( L=self.L, + pixel_size=self._pixel_size, dtype=self.dtype, ).generate() @@ -117,7 +124,8 @@ def setUp(self): L=self.L, vols=self.vols, unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + RadialCTFFilter(pixel_size=self._pixel_size, defocus=d) + for d in np.linspace(1.5e4, 2.5e4, 7) ], noise_adder=WhiteNoiseAdder(var=1), dtype=self.dtype, @@ -168,7 +176,9 @@ def testSimulationCached(self): vols=self.vols, offsets=self.sim.offsets, unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + # Set legacy pixel size + RadialCTFFilter(pixel_size=self._pixel_size, defocus=d) + for d in np.linspace(1.5e4, 2.5e4, 7) ], noise_adder=WhiteNoiseAdder(var=1), dtype=self.dtype, @@ -659,3 +669,15 @@ def test_cached_image_accessors(): np.testing.assert_allclose(cached_src.projections[:], src.projections[:]) np.testing.assert_allclose(cached_src.images[:], src.images[:]) np.testing.assert_allclose(cached_src.clean_images[:], src.clean_images[:]) + + +def test_mismatched_pixel_size(): + """ + Confirm raises error when explicit Simulation and CTFFilter pixel sizes mismatch. + """ + # Create a CTF with a pixel_size + filts = [RadialCTFFilter(pixel_size=5)] + + # Try to create a Simulation with a different pixel_size + with raises(ValueError, match=r"pixel_size.*does not match filter.*"): + _ = Simulation(L=8, n=1, C=1, pixel_size=10, unique_filters=filts) diff --git a/tests/test_synthetic_volume.py b/tests/test_synthetic_volume.py index ddcdcbcab5..fec7591764 100644 --- a/tests/test_synthetic_volume.py +++ b/tests/test_synthetic_volume.py @@ -20,6 +20,9 @@ # dtype fixture to pass into volume fixture. DTYPES = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] +# Pixel sized used to test assignment +PXSZ = 3.0 + @pytest.fixture(params=DTYPES) def dtype_fixture(request): @@ -85,6 +88,10 @@ def vol_fixture(request, dtype_fixture): if len(params) > 2: vol_kwargs["order"] = params[2] + # Assign some volumes a pixel_size, leave others as default. + if res % 2: + vol_kwargs["pixel_size"] = PXSZ + return vol_class(**vol_kwargs) @@ -96,8 +103,15 @@ def test_volume_repr(vol_fixture): def test_volume_generate(vol_fixture): - """Test that a volume is generated""" - _ = vol_fixture.generate() + """ + Test that a volume is generated + and stores pixel_size when provided. + """ + v = vol_fixture.generate() + + # In vol_fixture, we assign pixel_size to volumes having odd voxel sizes. + if vol_fixture.L % 2: + np.testing.assert_approx_equal(v.pixel_size, PXSZ) def test_simulation_init(vol_fixture): diff --git a/tests/test_volume.py b/tests/test_volume.py index 991cc3288e..ac86c4096b 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -30,6 +30,7 @@ def res_id(params): RES = [42, 43] +TEST_PX_SZ = 4.56 @pytest.fixture(params=RES, ids=res_id, scope="module") @@ -75,7 +76,7 @@ def vols_1(data_1): @pytest.fixture def vols_2(data_2): - return Volume(data_2) + return Volume(data_2, pixel_size=TEST_PX_SZ) @pytest.fixture @@ -291,6 +292,39 @@ def test_save_load(vols_1): assert np.allclose(vols_1, vols_loaded_single) assert isinstance(vols_loaded_double, Volume) assert np.allclose(vols_1, vols_loaded_double) + assert vols_loaded_single.pixel_size is None, "Pixel size should be None" + assert vols_loaded_double.pixel_size is None, "Pixel size should be None" + + +def test_volume_pixel_size(vols_2): + """ + Test volume is storing pixel_size attribute. + """ + assert np.isclose(TEST_PX_SZ, vols_2.pixel_size), "Incorrect Volume pixel_size" + + +def test_save_load_pixel_size(vols_2): + # Create a tmpdir in a context. It will be cleaned up on exit. + with tempfile.TemporaryDirectory() as tmpdir: + # Save the Volume object into an MRC files + mrcs_filepath = os.path.join(tmpdir, "test.mrc") + vols_2.save(mrcs_filepath) + + # Load saved MRC file as a Volume of dtypes single and double. + vols_loaded_single = Volume.load(mrcs_filepath, dtype=np.float32) + vols_loaded_double = Volume.load(mrcs_filepath, dtype=np.float64) + + # Confirm the pixel size is loaded + np.testing.assert_approx_equal( + vols_loaded_single.pixel_size, + vols_2.pixel_size, + err_msg="Incorrect pixel size in singles.", + ) + np.testing.assert_approx_equal( + vols_loaded_double.pixel_size, + vols_2.pixel_size, + err_msg="Incorrect pixel size in doubles.", + ) def test_project(vols_hot_cold): @@ -545,11 +579,20 @@ def test_flip(vols_1, data_1): def test_downsample(res): - vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy"))) + vols = Volume( + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")), pixel_size=1.23 + ) result = vols.downsample(res) og_res = vols.resolution ds_res = result.resolution + # Confirm the pixel size is scaled + np.testing.assert_approx_equal( + result.pixel_size, + vols.pixel_size * og_res / ds_res, + err_msg="Incorrect pixel size.", + ) + # check signal energy np.testing.assert_allclose( anorm(vols.asnumpy(), axes=(1, 2, 3)) / og_res,