Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/aspire/source/relion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
max_rows=None,
symmetry_group=None,
memory=None,
dtype=None,
):
"""
Load STAR file at given filepath
Expand All @@ -50,6 +51,9 @@ def __init__(
:param symmetry_group: A `SymmetryGroup` object or string corresponding to the symmetry of the molecule.
:param memory: str or None
The path of the base directory to use as a data store or None. If None is given, no caching is performed.
:param dtype: Optional datatype override.
Default `None` infers dtype from underlying image (MRC) files.
Can be used to upcast STAR files for processing in double precision.
"""
logger.info(f"Creating ImageSource from STAR file at path {filepath}")

Expand All @@ -72,12 +76,19 @@ def __init__(

# Get the 'mode' (data type) - TODO: There's probably a more direct way to do this.
mode = int(mrc.header.mode)
dtypes = {0: "int8", 1: "int16", 2: "float32", 6: "uint16"}
mrc_dtypes = {0: "int8", 1: "int16", 2: "float32", 6: "uint16"}
assert (
mode in dtypes
), f"Only modes={list(dtypes.keys())} in MRC files are supported for now."
mode in mrc_dtypes
), f"Only modes={list(mrc_dtypes.keys())} in MRC files are supported for now."

dtype = dtypes[mode]
mrc_dtype = mrc_dtypes[mode]
# Potentially over ride the inferred data type.
if dtype is not None and dtype != np.dtype(mrc_dtype):
logger.warning(
f"Overriding MRC datatype {mrc_dtype} with user supplied {dtype}."
)
elif dtype is None:
dtype = mrc_dtype

shape = mrc.data.shape
# the code below accounts for the case where the first MRCS image in the STAR file has one image
Expand Down
34 changes: 33 additions & 1 deletion tests/test_starfile_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@


class StarFileTestCase(TestCase):
# Default dtype (inferred)
_dtype = None

def setUpStarFile(self, starfile_name):
# set up RelionSource object for tests
with importlib_path(tests.saved_test_data, starfile_name) as starfile:
self.src = RelionSource(starfile, data_folder=DATA_DIR, max_rows=12)
self.src = RelionSource(
starfile, data_folder=DATA_DIR, max_rows=12, dtype=self._dtype
)

def setUp(self):
# this method is used by StarFileMainCase but overridden by StarFileOneImage
Expand Down Expand Up @@ -94,3 +99,30 @@ def testMRCSWithOneParticle(self):
# where there is only one image in the mrcs
single_image = self.src.images[0].asnumpy()[0]
self.assertEqual(single_image.shape, (200, 200))


class StarFileDtypeOverrideCase64(StarFileMainCase):
# Override RelionSource dtype
_dtype = np.float64

def testSourceDtype(self):
"""Test source identifies as _dtype."""
self.assertTrue(self.src.dtype == self._dtype)

def testImageDtype(self):
"""Test image returned as _dtype."""
self.assertTrue(self.src.images[0].dtype == self._dtype)

def testRotationsDtype(self):
"""Test image returned as _dtype."""
self.assertTrue(self.src.rotations.dtype == self._dtype)

def testImageDownsampleDtype(self):
"""Test downsample pipeline operation returns _dtype."""
_src = self.src.downsample(16)
self.assertTrue(_src.images[0].dtype == self._dtype)


class StarFileDtypeOverrideCase32(StarFileMainCase):
# Override RelionSource dtype
_dtype = np.float32