diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index 02b8ab9363..99907cbf6a 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -34,6 +34,7 @@ def __init__( max_rows=None, symmetry_group=None, memory=None, + dtype=None, ): """ Load STAR file at given filepath @@ -50,6 +51,9 @@ def __init__( :param symmetry_group: A `SymmetryGroup` object or string corresponding to the symmetry of the molecule. :param memory: str or None The path of the base directory to use as a data store or None. If None is given, no caching is performed. + :param dtype: Optional datatype override. + Default `None` infers dtype from underlying image (MRC) files. + Can be used to upcast STAR files for processing in double precision. """ logger.info(f"Creating ImageSource from STAR file at path {filepath}") @@ -72,12 +76,19 @@ def __init__( # Get the 'mode' (data type) - TODO: There's probably a more direct way to do this. mode = int(mrc.header.mode) - dtypes = {0: "int8", 1: "int16", 2: "float32", 6: "uint16"} + mrc_dtypes = {0: "int8", 1: "int16", 2: "float32", 6: "uint16"} assert ( - mode in dtypes - ), f"Only modes={list(dtypes.keys())} in MRC files are supported for now." + mode in mrc_dtypes + ), f"Only modes={list(mrc_dtypes.keys())} in MRC files are supported for now." - dtype = dtypes[mode] + mrc_dtype = mrc_dtypes[mode] + # Potentially override the inferred data type. + if dtype is not None and dtype != np.dtype(mrc_dtype): + logger.warning( + f"Overriding MRC datatype {mrc_dtype} with user supplied {dtype}." 
+ ) + elif dtype is None: + dtype = mrc_dtype shape = mrc.data.shape # the code below accounts for the case where the first MRCS image in the STAR file has one image diff --git a/tests/test_starfile_stack.py b/tests/test_starfile_stack.py index 886bceb8f1..c7729fb18f 100644 --- a/tests/test_starfile_stack.py +++ b/tests/test_starfile_stack.py @@ -14,10 +14,15 @@ class StarFileTestCase(TestCase): + # Default dtype (inferred) + _dtype = None + def setUpStarFile(self, starfile_name): # set up RelionSource object for tests with importlib_path(tests.saved_test_data, starfile_name) as starfile: - self.src = RelionSource(starfile, data_folder=DATA_DIR, max_rows=12) + self.src = RelionSource( + starfile, data_folder=DATA_DIR, max_rows=12, dtype=self._dtype + ) def setUp(self): # this method is used by StarFileMainCase but overridden by StarFileOneImage @@ -94,3 +99,30 @@ def testMRCSWithOneParticle(self): # where there is only one image in the mrcs single_image = self.src.images[0].asnumpy()[0] self.assertEqual(single_image.shape, (200, 200)) + + +class StarFileDtypeOverrideCase64(StarFileMainCase): + # Override RelionSource dtype + _dtype = np.float64 + + def testSourceDtype(self): + """Test source identifies as _dtype.""" + self.assertTrue(self.src.dtype == self._dtype) + + def testImageDtype(self): + """Test image returned as _dtype.""" + self.assertTrue(self.src.images[0].dtype == self._dtype) + + def testRotationsDtype(self): + """Test rotations returned as _dtype.""" + self.assertTrue(self.src.rotations.dtype == self._dtype) + + def testImageDownsampleDtype(self): + """Test downsample pipeline operation returns _dtype.""" + _src = self.src.downsample(16) + self.assertTrue(_src.images[0].dtype == self._dtype) + + +class StarFileDtypeOverrideCase32(StarFileMainCase): + # Override RelionSource dtype + _dtype = np.float32