Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1ed4de6
add pixel_size attribute to Volume class
garrettwrong Jul 18, 2024
f4ff4db
volume load pixel size
garrettwrong Jul 18, 2024
aa93106
volume load pixel size
garrettwrong Jul 18, 2024
f0ad627
volume load pixel size
garrettwrong Jul 18, 2024
5195b2d
rm old pixel_size attr from FSC calls
garrettwrong Jul 18, 2024
3b2a6a6
Add Image.voxel_size handling
garrettwrong Jul 18, 2024
78adb4a
cleanup image tests round 1
garrettwrong Jul 22, 2024
6a75a5f
add image save-load test
garrettwrong Jul 22, 2024
98d72d9
add Image.downsample pixel_size test
garrettwrong Jul 22, 2024
da4befe
Minimally add pixel_size to sources
garrettwrong Jul 22, 2024
3d86861
tox caught incorrect var
garrettwrong Jul 22, 2024
60d6b26
lint
garrettwrong Jul 26, 2024
8fbdc86
Change default CTFFilter pixel_size from 10 to 1
garrettwrong Jul 26, 2024
f3620a9
add pixel_size synth volume classes
garrettwrong Jul 29, 2024
1e9543f
self review cleanup
garrettwrong Jul 29, 2024
dc60bed
add units to error messages
garrettwrong Aug 1, 2024
112f4ac
add units to error messages
garrettwrong Aug 1, 2024
794b57a
add pixel_size to image repr
garrettwrong Aug 1, 2024
525b407
correct synth vol doc strings
garrettwrong Aug 1, 2024
96e271f
add direct volume pixel_size attr test
garrettwrong Aug 1, 2024
5f01347
Use ValueError for px sz mismatch
garrettwrong Aug 26, 2024
bfd4476
Test px sz assignment with assert_approx_equal
garrettwrong Aug 26, 2024
76ac0b8
merge conflict lint
garrettwrong Aug 26, 2024
393b089
merge conflict test
garrettwrong Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gallery/tutorials/aspire_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def noise_function(x, y):

# Generate several CTFs.
ctf_filters = [
RadialCTFFilter(pixel_size=5, defocus=d)
RadialCTFFilter(pixel_size=vol_ds.pixel_size, defocus=d)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]

Expand Down
2 changes: 1 addition & 1 deletion gallery/tutorials/pipeline_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
defocus_ct = 7

ctf_filters = [
RadialCTFFilter(pixel_size=5, defocus=d)
RadialCTFFilter(pixel_size=vol.pixel_size, defocus=d)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]

Expand Down
4 changes: 3 additions & 1 deletion gallery/tutorials/tutorials/cov3d_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
L=img_size,
n=num_imgs,
vols=vols,
unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)],
unique_filters=[
RadialCTFFilter(pixel_size=10, defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)
],
dtype=dtype,
)

Expand Down
121 changes: 94 additions & 27 deletions src/aspire/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def load_mrc(filepath):
Load raw data from `.mrc` into an array.

:param filepath: File path (string).
:return: numpy array of image data.
:return: (numpy array of image data, pixel_size)
"""

# mrcfile tends to yield many warnings about EMPIAR datasets being corrupt
Expand All @@ -92,6 +92,7 @@ def load_mrc(filepath):

with mrcfile.open(filepath, mode="r", permissive=True) as mrc:
im = mrc.data
pixel_size = Image._vx_array_to_size(mrc.voxel_size)

# Log each mrcfile warning to debug log, noting the associated file
for w in ws:
Expand All @@ -110,19 +111,29 @@ def load_mrc(filepath):
f" Will attempt to continue processing {filepath}"
)

return im
return im, pixel_size


def load_tiff(filepath):
"""
Load raw data from `.tiff` into an array.

Note, TIFF does not natively provide equivalent to pixel/voxel_size,
so users of TIFF files may need to manually assign `pixel_size` to
`Image` instances when required. Defaults to `pixel_size=None`.

:param filepath: File path (string).
:return: numpy array of image data.
:return: (numpy array of image data, pixel_size=None)
"""
# Use PIL to open `filepath`
img = PILImage.open(filepath)

# Future todo, extract `voxel_size` if available in TIFF tags (custom tag?)
# For now, default to `None`.
pixel_size = None

# Use PIL to open `filepath` and cast to numpy array.
return np.array(PILImage.open(filepath))
# Cast image data as numpy array
return np.array(img), pixel_size


class Image:
Expand All @@ -133,7 +144,7 @@ class Image:
".tiff": load_tiff,
}

def __init__(self, data, dtype=None):
def __init__(self, data, pixel_size=None, dtype=None):
"""
A stack of one or more images.

Expand All @@ -149,6 +160,10 @@ def __init__(self, data, dtype=None):

:param data: Numpy array containing image data with shape
`(..., resolution, resolution)`.
:param pixel_size: Optional pixel size in angstroms.
When provided will be saved with `mrc` metadata.
Default of `None` will not write to file,
but will be considered unit pixels (1) for FSC.
:param dtype: Optionally cast `data` to this dtype.
Defaults to `data.dtype`.

Expand Down Expand Up @@ -180,6 +195,9 @@ def __init__(self, data, dtype=None):
self.stack_shape = self._data.shape[:-2]
self.n_images = np.prod(self.stack_shape)
self.resolution = self._data.shape[-1]
self.pixel_size = None
if pixel_size is not None:
self.pixel_size = float(pixel_size)

# Numpy interop
# https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol
Expand Down Expand Up @@ -238,7 +256,7 @@ def _check_key_dims(self, key):

def __getitem__(self, key):
self._check_key_dims(key)
return self.__class__(self._data[key])
return self.__class__(self._data[key], pixel_size=self.pixel_size)

def __setitem__(self, key, value):
self._check_key_dims(key)
Expand Down Expand Up @@ -266,31 +284,34 @@ def stack_reshape(self, *args):
f"Number of images {self.n_images} cannot be reshaped to {shape}."
)

return self.__class__(self._data.reshape(*shape, *self._data.shape[-2:]))
return self.__class__(
self._data.reshape(*shape, *self._data.shape[-2:]),
pixel_size=self.pixel_size,
)

def __add__(self, other):
if isinstance(other, Image):
other = other._data

return self.__class__(self._data + other)
return self.__class__(self._data + other, pixel_size=self.pixel_size)

def __sub__(self, other):
if isinstance(other, Image):
other = other._data

return self.__class__(self._data - other)
return self.__class__(self._data - other, pixel_size=self.pixel_size)

def __mul__(self, other):
if isinstance(other, Image):
other = other._data

return self.__class__(self._data * other)
return self.__class__(self._data * other, pixel_size=self.pixel_size)

def __neg__(self):
return self.__class__(-self._data)
return self.__class__(-self._data, pixel_size=self.pixel_size)

def sqrt(self):
return self.__class__(np.sqrt(self._data))
return self.__class__(np.sqrt(self._data), pixel_size=self.pixel_size)

@property
def T(self):
Expand All @@ -312,7 +333,9 @@ def transpose(self):
im = self.stack_reshape(-1)
imt = np.transpose(im._data, (0, -1, -2))

return self.__class__(imt).stack_reshape(original_stack_shape)
return self.__class__(imt, pixel_size=self.pixel_size).stack_reshape(
original_stack_shape
)

def flip(self, axis=-2):
"""
Expand All @@ -335,11 +358,15 @@ def flip(self, axis=-2):
f"Cannot flip axis {ax}: stack axis. Did you mean {ax-3}?"
)

return self.__class__(np.flip(self._data, axis))
return self.__class__(np.flip(self._data, axis), pixel_size=self.pixel_size)

def __repr__(self):
px_msg = "."
if self.pixel_size is not None:
px_msg = f" with pixel_size={self.pixel_size} angstroms."

msg = f"{self.n_images} {self.dtype} images arranged as a {self.stack_shape} stack"
msg += f" each of size {self.resolution}x{self.resolution}."
msg += f" each of size {self.resolution}x{self.resolution}{px_msg}"
return msg

def asnumpy(self):
Expand All @@ -355,7 +382,7 @@ def asnumpy(self):
return view

def copy(self):
return self.__class__(self._data.copy())
return self.__class__(self._data.copy(), pixel_size=self.pixel_size)

def shift(self, shifts):
"""
Expand Down Expand Up @@ -412,7 +439,14 @@ def downsample(self, ds_res, zero_nyquist=True):
out = fft.centered_ifft2(crop_fx).real * (ds_res**2 / self.resolution**2)
out = xp.asnumpy(out)

return self.__class__(out).stack_reshape(original_stack_shape)
# Optionally scale pixel size
ds_pixel_size = self.pixel_size
if ds_pixel_size is not None:
ds_pixel_size *= self.resolution / ds_res

return self.__class__(out, pixel_size=ds_pixel_size).stack_reshape(
original_stack_shape
)

def filter(self, filter):
"""
Expand Down Expand Up @@ -441,7 +475,9 @@ def filter(self, filter):

im = xp.asnumpy(im.real)

return self.__class__(im).stack_reshape(original_stack_shape)
return self.__class__(im, pixel_size=self.pixel_size).stack_reshape(
original_stack_shape
)

def rotate(self):
raise NotImplementedError
Expand All @@ -453,6 +489,9 @@ def save(self, mrcs_filepath, overwrite=False):
with mrcfile.new(mrcs_filepath, overwrite=overwrite) as mrc:
# original input format (the image index first)
mrc.set_data(self._data.astype(np.float32))
# Note assigning voxel_size must come after `set_data`
if self.pixel_size is not None:
mrc.voxel_size = self.pixel_size

@staticmethod
def load(filepath, dtype=None):
Expand All @@ -477,14 +516,14 @@ def load(filepath, dtype=None):
)

# Call the appropriate file reader
im = Image.extensions[ext](filepath)
im, pixel_size = Image.extensions[ext](filepath)

# Attempt casting when user provides dtype
if dtype is not None:
im = im.astype(dtype, copy=False)

# Return as Image instance
return Image(im)
return Image(im, pixel_size=pixel_size)

def _im_translate(self, shifts):
"""
Expand Down Expand Up @@ -535,7 +574,9 @@ def _im_translate(self, shifts):
im_translated = xp.asnumpy(im_translated.real)

# Reshape to stack shape
return self.__class__(im_translated).stack_reshape(stack_shape)
return self.__class__(im_translated, pixel_size=self.pixel_size).stack_reshape(
stack_shape
)

def norm(self):
return anorm(self._data)
Expand Down Expand Up @@ -602,7 +643,9 @@ def backproject(self, rot_matrices, symmetry_group=None, zero_nyquist=True):

vol /= L

return aspire.volume.Volume(vol, symmetry_group=symmetry_group)
return aspire.volume.Volume(
vol, pixel_size=self.pixel_size, symmetry_group=symmetry_group
)

def show(self, columns=5, figsize=(20, 10), colorbar=True):
"""
Expand Down Expand Up @@ -645,7 +688,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True):

plt.show()

def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False):
def frc(self, other, cutoff=None, method="fft", plot=False):
r"""
Compute the Fourier ring correlation between two images.

Expand All @@ -663,8 +706,6 @@ def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False):
Default `None` implies `cutoff=1` and excludes
plotting cutoff line.

:param pixel_size: Pixel size in angstrom. Default `None`
implies unit in pixels, equivalent to pixel_size=1.
:param method: Selects either 'fft' (on cartesian grid),
or 'nufft' (on polar grid). Defaults to 'fft'.
:param plot: Optionally plot to screen or file.
Expand All @@ -684,7 +725,7 @@ def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False):
frc = FourierRingCorrelation(
a=self.asnumpy(),
b=other.asnumpy(),
pixel_size=pixel_size,
pixel_size=self.pixel_size,
method=method,
)

Expand All @@ -695,6 +736,32 @@ def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False):

return frc.analyze_correlations(cutoff), frc.correlations

@staticmethod
def _vx_array_to_size(vx):
"""
Utility to convert from several possible `mrcfile.voxel_size`
representations to a single (float) value or None.
"""

# Convert from recarray to single values,
# checks uniformity.
if isinstance(vx, np.recarray):
if vx.x != vx.y:
raise ValueError(f"Voxel sizes are not uniform: {vx}")
vx = vx.x

# Convert `0` to `None`
if (
isinstance(vx, int) or isinstance(vx, float) or isinstance(vx, np.ndarray)
) and vx == 0:
vx = None

# Consistently return a `float` when not None
if vx is not None:
vx = float(vx)

return vx


class CartesianImage(Image):
def expand(self, basis):
Expand Down
8 changes: 4 additions & 4 deletions src/aspire/operators/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def __init__(self, dim=None):
class CTFFilter(Filter):
def __init__(
self,
pixel_size=10,
pixel_size=1,
voltage=200,
defocus_u=15000,
defocus_v=15000,
Expand All @@ -415,7 +415,7 @@ def __init__(
"""
A CTF (Contrast Transfer Function) Filter

:param pixel_size: Pixel size in angstrom
:param pixel_size: Pixel size in angstrom, default 1.
:param voltage: Electron voltage in kV
:param defocus_u: Defocus depth along the u-axis in angstrom
:param defocus_v: Defocus depth along the v-axis in angstrom
Expand All @@ -425,7 +425,7 @@ def __init__(
:param B: Envelope decay in inverse square angstrom (default 0)
"""
super().__init__(dim=2, radial=defocus_u == defocus_v)
self.pixel_size = pixel_size
self.pixel_size = float(pixel_size)
self.voltage = voltage
self.wavelength = voltage_to_wavelength(self.voltage)
self.defocus_u = defocus_u
Expand Down Expand Up @@ -482,7 +482,7 @@ def scale(self, c=1):

class RadialCTFFilter(CTFFilter):
def __init__(
self, pixel_size=10, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0
self, pixel_size=1, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0
):
super().__init__(
pixel_size=pixel_size,
Expand Down
4 changes: 3 additions & 1 deletion src/aspire/source/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def _images(self, indices):
cropped = self._crop_micrograph(arr, next(coord))
im[i] = cropped
# Finally, apply transforms to resulting Image
return self.generation_pipeline.forward(Image(im), indices)
return self.generation_pipeline.forward(
Image(im, pixel_size=self.pixel_size), indices
)

@staticmethod
def _is_number(text):
Expand Down
Loading