From 80f17f550a3857a7396434e3b621bd011cfc95a3 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 30 Jan 2024 14:08:01 -0500 Subject: [PATCH 1/3] Add ability to override RelionSource dtypes --- src/aspire/source/relion.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index 02b8ab9363..b370f4cfe6 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -34,6 +34,7 @@ def __init__( max_rows=None, symmetry_group=None, memory=None, + dtype=None, ): """ Load STAR file at given filepath @@ -50,6 +51,9 @@ def __init__( :param symmetry_group: A `SymmetryGroup` object or string corresponding to the symmetry of the molecule. :param memory: str or None The path of the base directory to use as a data store or None. If None is given, no caching is performed. + :param dtype: Optional datatype override. + Default `None` infers dtype from underlying image (MRC) files. + Can be used to upcast STAR files for processing in double precision. """ logger.info(f"Creating ImageSource from STAR file at path {filepath}") @@ -72,12 +76,19 @@ def __init__( # Get the 'mode' (data type) - TODO: There's probably a more direct way to do this. mode = int(mrc.header.mode) - dtypes = {0: "int8", 1: "int16", 2: "float32", 6: "uint16"} + mrc_dtypes = {0: "int8", 1: "int16", 2: "float32", 6: "uint16"} assert ( - mode in dtypes - ), f"Only modes={list(dtypes.keys())} in MRC files are supported for now." + mode in mrc_dtypes + ), f"Only modes={list(mrc_dtypes.keys())} in MRC files are supported for now." - dtype = dtypes[mode] + mrc_dtype = mrc_dtypes[mode] + # Potentially over ride the inferred data type. + if dtype is not None and dtype != mrc_dtype: + logger.warning( + f"Overriding MRC datatype {mrc_dtype} with user supplied {dtype}." 
+ ) + elif dtype is None: + dtype = mrc_dtype shape = mrc.data.shape # the code below accounts for the case where the first MRCS image in the STAR file has one image From 6a9e5ff493c2c07eb125f2bab7a962e62ddd8838 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 30 Jan 2024 15:22:21 -0500 Subject: [PATCH 2/3] Test some RelionSource features are returning override dtype --- tests/test_starfile_stack.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/test_starfile_stack.py b/tests/test_starfile_stack.py index 886bceb8f1..c7729fb18f 100644 --- a/tests/test_starfile_stack.py +++ b/tests/test_starfile_stack.py @@ -14,10 +14,15 @@ class StarFileTestCase(TestCase): + # Default dtype (inferred) + _dtype = None + def setUpStarFile(self, starfile_name): # set up RelionSource object for tests with importlib_path(tests.saved_test_data, starfile_name) as starfile: - self.src = RelionSource(starfile, data_folder=DATA_DIR, max_rows=12) + self.src = RelionSource( + starfile, data_folder=DATA_DIR, max_rows=12, dtype=self._dtype + ) def setUp(self): # this method is used by StarFileMainCase but overridden by StarFileOneImage @@ -94,3 +99,30 @@ def testMRCSWithOneParticle(self): # where there is only one image in the mrcs single_image = self.src.images[0].asnumpy()[0] self.assertEqual(single_image.shape, (200, 200)) + + +class StarFileDtypeOverrideCase64(StarFileMainCase): + # Override RelionSource dtype + _dtype = np.float64 + + def testSourceDtype(self): + """Test source identifies as _dtype.""" + self.assertTrue(self.src.dtype == self._dtype) + + def testImageDtype(self): + """Test image returned as _dtype.""" + self.assertTrue(self.src.images[0].dtype == self._dtype) + + def testRotationsDtype(self): + """Test rotations returned as _dtype.""" + self.assertTrue(self.src.rotations.dtype == self._dtype) + + def testImageDownsampleDtype(self): + """Test downsample pipeline operation returns _dtype.""" + _src = 
self.src.downsample(16) + self.assertTrue(_src.images[0].dtype == self._dtype) + + +class StarFileDtypeOverrideCase32(StarFileMainCase): + # Override RelionSource dtype + _dtype = np.float32 From 20fa66d85a6aba4b30d11c4aafe93a1010b0c83e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 7 Feb 2024 11:32:09 -0500 Subject: [PATCH 3/3] Cast mrc dtype strings --- src/aspire/source/relion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index b370f4cfe6..99907cbf6a 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -83,7 +83,7 @@ def __init__( mrc_dtype = mrc_dtypes[mode] # Potentially over ride the inferred data type. - if dtype is not None and dtype != mrc_dtype: + if dtype is not None and dtype != np.dtype(mrc_dtype): logger.warning( f"Overriding MRC datatype {mrc_dtype} with user supplied {dtype}." )