diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 57b32041cd..757362870d 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -12,7 +12,13 @@ import aspire.volume from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp -from aspire.utils import FourierRingCorrelation, anorm, crop_pad_2d, grid_2d +from aspire.utils import ( + FourierRingCorrelation, + anorm, + crop_pad_2d, + grid_2d, + rename_with_timestamp, +) from aspire.volume import SymmetryGroup logger = logging.getLogger(__name__) @@ -484,10 +490,25 @@ def filter(self, filter): def rotate(self): raise NotImplementedError - def save(self, mrcs_filepath, overwrite=False): + def save(self, mrcs_filepath, overwrite=None): + """ + Save Image to disk as mrcs file + + :param filename: Filepath where Image will be saved. + :param overwrite: Options to control overwrite behavior (default is None): + - True: Overwrites the existing file if it exists. + - False: Raises an error if the file exists. + - None: Renames the old file by appending a time/date stamp. + """ if self.stack_ndim > 1: raise NotImplementedError("`save` is currently limited to 1D image stacks.") + if overwrite is None and os.path.exists(mrcs_filepath): + # If the file exists, append a timestamp to the old file and rename it + _ = rename_with_timestamp(mrcs_filepath) + elif overwrite is None: + overwrite = False + with mrcfile.new(mrcs_filepath, overwrite=overwrite) as mrc: # original input format (the image index first) mrc.set_data(self._data.astype(np.float32)) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 84cf26e1f9..0287e9875c 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -28,7 +28,7 @@ PowerFilter, ) from aspire.storage import MrcStats, StarFile -from aspire.utils import Rotation, grid_2d, support_mask, trange +from aspire.utils import Rotation, grid_2d, rename_with_timestamp, support_mask, trange from aspire.volume import IdentitySymmetryGroup, SymmetryGroup logger = logging.getLogger(__name__) @@ -976,7 +976,7 @@ def save( starfile_filepath, batch_size=512, save_mode=None, - overwrite=False, + overwrite=None, ): """ Save the output metadata to STAR file and/or images to MRCS file. @@ -988,10 +988,29 @@ def save( while `batch_size>=1` implies stack MRC extension `.mrcs`. :param save_mode: Whether to save all images in a `single` or multiple files in batch size. Default is multiple, supply `'single'` for single mode. - :param overwrite: Option to overwrite the output MRC files. + :param overwrite: Options to control overwrite behavior (default is None): + - True: Overwrites the existing file if it exists. + - False: Raises an error if the file exists. + - None: Renames the old file by appending a time/date stamp. :return: A dictionary containing "starfile"--the path to the saved starfile-- and "mrcs", a list of the saved particle stack MRC filenames. """ + if overwrite is None and os.path.exists(starfile_filepath): + # If the file exists, append the timestamp to the old file and rename it + renamed_filepath = rename_with_timestamp(starfile_filepath, move=False) + + # Retrieve original ImageSource and save with new starfile name. + from aspire.source import RelionSource + + src = RelionSource(starfile_filepath) + src.save(renamed_filepath, overwrite=False) + + # Allow overwriting old files. + overwrite = True + + elif overwrite is None: + overwrite = False + logger.info("save metadata into STAR file") filename_indices = self.save_metadata( starfile_filepath, diff --git a/src/aspire/source/micrograph.py b/src/aspire/source/micrograph.py index 2d654401b5..5aa35e389c 100644 --- a/src/aspire/source/micrograph.py +++ b/src/aspire/source/micrograph.py @@ -10,7 +10,7 @@ from aspire.source import Simulation from aspire.source.image import _ImageAccessor from aspire.storage import StarFile -from aspire.utils import Random, grid_2d +from aspire.utils import Random, grid_2d, rename_with_timestamp from aspire.volume import Volume logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ def __len__(self): """ return self.micrograph_count - def save(self, path, name_prefix="micrograph", overwrite=True): + def save(self, path, name_prefix="micrograph", overwrite=None): """ Save micrographs to `path`. @@ -54,11 +54,18 @@ def save(self, path, name_prefix="micrograph", overwrite=True): :param path: Directory to save data. :param name_prefix: Optional, name prefix string for micrograph files. - :param overwrite: Optional, bool. Allow writing to existing directory, - and overwriting existing files. + :param overwrite: Options to control overwrite behavior (default is None): + - True: Overwrites the existing path if it exists. + - False: Raises an error if the path exists. + - None: Renames the old path by appending a time/date stamp. :return: List of saved `.mrc` files. """ + if overwrite is None and os.path.exists(path): + # If the directory exists, append a timestamp to existing directory. + _ = rename_with_timestamp(path) + overwrite = True + # Make dir if does not exist. Path(path).mkdir(parents=True, exist_ok=overwrite) diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index e691e0ba5e..c4b46bcdec 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -28,6 +28,7 @@ inverse_r, J_conjugate, powerset, + rename_with_timestamp, sha256sum, support_mask, fuzzy_mask, diff --git a/src/aspire/utils/misc.py b/src/aspire/utils/misc.py index d0c30b9f90..a3f9917024 100644 --- a/src/aspire/utils/misc.py +++ b/src/aspire/utils/misc.py @@ -5,7 +5,9 @@ import hashlib import importlib.resources import logging +import os import sys +from datetime import datetime from itertools import chain, combinations import numpy as np @@ -48,6 +50,31 @@ def importlib_path(package, resource): return p +def rename_with_timestamp(filepath, move=True): + """ + Rename a file by appending a timestamp to the end of the filename. + + :param filepath: Filepath to rename. + :param move: Option to rename the file on disk. + + :return: filepath with timestamp appended. + """ + base, ext = os.path.splitext(filepath) + timestamp = datetime.now().strftime("%y%m%d_%H%M%S") + renamed_filepath = f"{base}_{timestamp}{ext}" + logger.info(f"Renaming {filepath} as {renamed_filepath}.") + + # Rename the existing file by appending the timestamp. + if move: + try: + os.rename(filepath, renamed_filepath) + except FileNotFoundError: + logger.warning(f"File '{filepath}' not found, could not rename.") + return None + + return renamed_filepath + + def abs2(x): """ Compute complex modulus squared. diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 0ed31b8b74..eae985a794 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -1,4 +1,5 @@ import logging +import os import warnings import mrcfile @@ -16,6 +17,7 @@ grid_2d, grid_3d, mat_to_vec, + rename_with_timestamp, vec_to_mat, ) from aspire.volume import IdentitySymmetryGroup, SymmetryGroup @@ -635,20 +637,27 @@ def rotate(self, rot_matrices, zero_nyquist=True): def denoise(self): raise NotImplementedError - def save(self, filename, overwrite=False): + def save(self, filename, overwrite=None): """ Save volume to disk as mrc file - :param filename: Filepath where volume will be saved - - :param overwrite: Option to overwrite file when set to True. - Defaults to overwrite=False. + :param filename: Filepath where volume will be saved. + :param overwrite: Options to control overwrite behavior (default is None): + - True: Overwrites the existing file if it exists. + - False: Raises an error if the file exists. + - None: Renames the old file by appending a time/date stamp. """ if self.stack_ndim > 1: raise NotImplementedError( "`save` is currently limited to 1D Volume stacks." ) + if overwrite is None and os.path.exists(filename): + # If the file exists, append a timestamp to the old file and rename it + _ = rename_with_timestamp(filename) + elif overwrite is None: + overwrite = False + with mrcfile.new(filename, overwrite=overwrite) as mrc: mrc.set_data(self._data.astype(np.float32)) # Note assigning voxel_size must come after `set_data` diff --git a/tests/test_image.py b/tests/test_image.py index 89fbde4a84..593f781e7d 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -1,6 +1,8 @@ import logging import os.path import tempfile +from datetime import datetime +from unittest import mock import mrcfile import numpy as np @@ -353,6 +355,69 @@ def test_asnumpy_readonly(): vw[0, 0, 0] = 123 +def test_save_overwrite(caplog): + """ + Test that the overwrite flag behaves as expected. + - overwrite=True: Overwrites the existing file. + - overwrite=False: Raises an error if the file exists. + - overwrite=None: Renames the existing file and saves the new one. + """ + im1 = Image(np.ones((1, 8, 8), dtype=np.float32)) + im2 = Image(2 * np.ones((1, 8, 8), dtype=np.float32)) + im3 = Image(3 * np.ones((1, 8, 8), dtype=np.float32)) + + # Create a tmp dir for this test output + with tempfile.TemporaryDirectory() as tmpdir_name: + # tmp filename + mrc_path = os.path.join(tmpdir_name, "og.mrc") + base, ext = os.path.splitext(mrc_path) + + # Create and save the first image + im1.save(mrc_path, overwrite=True) + + # Case 1: overwrite=True (should overwrite the existing file) + im2.save(mrc_path, overwrite=True) + + # Load and check if im2 has overwritten im1 + im2_loaded = Image.load(mrc_path) + np.testing.assert_allclose(im2.asnumpy(), im2_loaded.asnumpy()) + + # Case 2: overwrite=False (should raise an overwrite error) + with pytest.raises( + ValueError, + match="File '.*' already exists; set overwrite=True to overwrite it", + ): + im3.save(mrc_path, overwrite=False) + + # Case 3: overwrite=None (should rename the existing file and save im3 with original filename) + # Mock datetime to return a fixed timestamp. + mock_datetime_value = datetime(2024, 10, 18, 12, 0, 0) + with mock.patch("aspire.utils.misc.datetime") as mock_datetime: + mock_datetime.now.return_value = mock_datetime_value + mock_datetime.strftime = datetime.strftime + + with caplog.at_level(logging.INFO): + im3.save(mrc_path, overwrite=None) + + # Check that the existing file was renamed and logged + assert f"Renaming {mrc_path}" in caplog.text + + # Construct the expected renamed filename using the mock timestamp + mock_timestamp = mock_datetime_value.strftime("%y%m%d_%H%M%S") + renamed_file = f"{base}_{mock_timestamp}{ext}" + + # Assert that the renamed file exists + assert os.path.exists(renamed_file), "Renamed file not found" + + # Load and check that im3 was saved to the original path + im3_loaded = Image.load(mrc_path) + np.testing.assert_allclose(im3.asnumpy(), im3_loaded.asnumpy()) + + # Also check that the renamed file still contains im2's data + im2_loaded_renamed = Image.load(renamed_file) + np.testing.assert_allclose(im2.asnumpy(), im2_loaded_renamed.asnumpy()) + + def test_corrupt_mrc_load(caplog): """ Test that corrupt mrc files are logged as expected. diff --git a/tests/test_micrograph_simulation.py b/tests/test_micrograph_simulation.py index 58f77869a9..14f1212e0d 100644 --- a/tests/test_micrograph_simulation.py +++ b/tests/test_micrograph_simulation.py @@ -2,6 +2,8 @@ import logging import os import tempfile +from datetime import datetime +from unittest import mock import numpy as np import pytest @@ -303,6 +305,81 @@ def test_sim_save(): ) +def test_save_overwrite(caplog): + """ + Tests MicrographSimulation.save functionality. + + Specifically tests interoperability with CentersCoordinateSource + """ + + v = AsymmetricVolume(L=16, C=1, dtype=np.float64).generate() + ctfs = [ + RadialCTFFilter( + pixel_size=4, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0 + ) + ] + + mg_sim = MicrographSimulation( + volume=v, + particles_per_micrograph=3, + interparticle_distance=v.resolution, + micrograph_count=2, + micrograph_size=512, + ctf_filters=ctfs, + ) + + mg_sim_new = MicrographSimulation( + volume=v, + particles_per_micrograph=4, + interparticle_distance=v.resolution, + micrograph_count=3, + micrograph_size=512, + ctf_filters=ctfs, + ) + + with tempfile.TemporaryDirectory() as tmp_output_dir: + path = os.path.join(tmp_output_dir, "test") + + # Write MRC and STAR files + save_paths_1 = mg_sim.save(path, overwrite=True) + + # Case 1: overwrite=True (should overwrite the existing file) + save_paths_2 = mg_sim.save(path, overwrite=True) + np.testing.assert_array_equal(save_paths_1, save_paths_2) + + # Case2: overwrite=False (should raise error) + with pytest.raises(FileExistsError): + _ = mg_sim.save(path, overwrite=False) + + # Case 3: overwrite=None (should rename the existing directory) + mock_datetime_value = datetime(2024, 10, 18, 12, 0, 0) + with mock.patch("aspire.utils.misc.datetime") as mock_datetime: + mock_datetime.now.return_value = mock_datetime_value + mock_datetime.strftime = datetime.strftime + + with caplog.at_level(logging.INFO): + _ = mg_sim_new.save(path, overwrite=None) + + # Check that the existing directory was renamed and logged + assert f"Renaming {path}" in caplog.text + assert os.path.exists(path), "Directory not found" + + # Construct the expected renamed directory using the mock timestamp + mock_timestamp = mock_datetime_value.strftime("%y%m%d_%H%M%S") + renamed_dir = f"{path}_{mock_timestamp}" + + # Assert that the renamed file exists + assert os.path.exists(renamed_dir), "Renamed directory not found" + + # Load renamed directory and check images against orignal sim. + mg_src = DiskMicrographSource(renamed_dir) + np.testing.assert_allclose(mg_src.asnumpy(), mg_sim.asnumpy()) + + # Load new directory and check images against orignal sim. + mg_src_new = DiskMicrographSource(path) + np.testing.assert_allclose(mg_src_new.asnumpy(), mg_sim_new.asnumpy()) + + def test_bad_amplitudes(vol_fixture): """ Test incorrect `particle_amplitudes` argument raises. diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 2331339265..944e2e7c06 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -666,6 +666,96 @@ def test_cached_image_accessors(): ) +def test_save_overwrite(caplog): + """ + Test that the overwrite flag behaves as expected. + - overwrite=True: Overwrites the existing file. + - overwrite=False: Raises an error if the file exists. + - overwrite=None: Renames the existing file and saves the new one. + """ + sim1 = Simulation(seed=1) + sim2 = Simulation(seed=2) + sim3 = Simulation(seed=3) + + # Create a tmp dir for this test output + with tempfile.TemporaryDirectory() as tmpdir_name: + # tmp filename + starfile = os.path.join(tmpdir_name, "og.star") + base, ext = os.path.splitext(starfile) + + sim1.save(starfile, overwrite=True) + + # Case 1: overwrite=True (should overwrite the existing file) + sim2.save(starfile, overwrite=True) + + # Load and check if sim2 has overwritten sim1 + sim2_loaded = RelionSource(starfile) + np.testing.assert_allclose( + sim2.images[:].asnumpy(), + sim2_loaded.images[:].asnumpy(), + atol=utest_tolerance(sim2.dtype), + ) + + # Check that metadata is unchanged. + check_metadata(sim2, sim2_loaded) + + # Case 2: overwrite=False (should raise an overwrite error) + with raises( + ValueError, + match="File '.*' already exists; set overwrite=True to overwrite it", + ): + sim2.save(starfile, overwrite=False) + + # case 3: overwrite=None (should rename the existing file and save im3 with original filename) + with caplog.at_level(logging.INFO): + sim3.save(starfile, overwrite=None) + + # Check that the existing file was renamed and logged + assert f"Renaming {starfile}" in caplog.text + + # Find the renamed file by checking the directory contents + renamed_file = None + for filename in os.listdir(tmpdir_name): + if filename.startswith("og_") and filename.endswith(".star"): + renamed_file = os.path.join(tmpdir_name, filename) + break + + assert renamed_file is not None, "Renamed file not found" + + # Load and check that sim3 was saved to the original path + sim3_loaded = RelionSource(starfile) + np.testing.assert_allclose( + sim3.images[:].asnumpy(), + sim3_loaded.images[:].asnumpy(), + atol=utest_tolerance(sim3.dtype), + ) + check_metadata(sim3, sim3_loaded) + + # Also check that the renamed file still contains sim2's data + sim2_loaded_renamed = RelionSource(renamed_file) + np.testing.assert_allclose( + sim2.images[:].asnumpy(), + sim2_loaded_renamed.images[:].asnumpy(), + atol=utest_tolerance(sim2.dtype), + ) + check_metadata(sim2, sim2_loaded_renamed) + + +def check_metadata(sim_src, relion_src): + """ + Helper function to test if metadata fields in a Simulation match + those in a RelionSource. + """ + for k, v in sim_src._metadata.items(): + try: + np.testing.assert_array_equal(v, relion_src._metadata[k]) + except AssertionError: + # Loaded metadata might be strings so recast. + np.testing.assert_allclose( + v, np.array(relion_src._metadata[k]).astype(type(v[0])) + ) + + def test_mismatched_pixel_size(): """ Confirm raises error when explicit Simulation and CTFFilter pixel sizes mismatch. diff --git a/tests/test_utils.py b/tests/test_utils.py index 040f427758..a81ccaf11a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,8 @@ import tempfile import warnings from contextlib import contextmanager +from datetime import datetime +from unittest import mock import matplotlib import numpy as np @@ -20,6 +22,7 @@ num_procs_suggestion, physical_core_cpu_suggestion, powerset, + rename_with_timestamp, utest_tolerance, virtual_core_cpu_suggestion, ) @@ -112,6 +115,55 @@ def test_get_full_version_unexpected(monkeypatch): assert get_full_version() == __version__ + ".x" +def test_rename_with_timestamp(caplog): + with tempfile.TemporaryDirectory() as tmpdir_name: + filepath = os.path.join(tmpdir_name, "test_file.name") + base, ext = os.path.splitext(filepath) + + # Create file on disk. + with open(filepath, "w") as f: + f.write("Test file") + + # Mock datetime to return a fixed timestamp. + mock_datetime_value = datetime(2024, 10, 18, 12, 0, 0) + mock_timestamp = mock_datetime_value.strftime("%y%m%d_%H%M%S") + + with mock.patch("aspire.utils.misc.datetime") as mock_datetime: + mock_datetime.now.return_value = mock_datetime_value + mock_datetime.strftime = datetime.strftime + + # Case 1: move=False should return the new file name with appended timestamp. + renamed_file = rename_with_timestamp(filepath, move=False) + assert renamed_file == f"{base}_{mock_timestamp}{ext}" + + # Case 2: move=True (default) should rename file on disk. + with caplog.at_level(logging.INFO): + renamed_file = rename_with_timestamp(filepath) + + # Check log for renaming operation. + assert f"Renaming {filepath} as {renamed_file}" in caplog.text + + # Check that the original file no longer exists. + assert not os.path.exists(filepath) + + # Check that the new file exists on disk with the expected name. + assert os.path.exists(renamed_file) + + # Case 3: Test when the file does not exist. + non_existent_file = os.path.join(tmpdir_name, "non_existent_file.name") + with caplog.at_level(logging.WARNING): + result = rename_with_timestamp(non_existent_file) + + # Check that None is returned since the file doesn't exist. + assert result is None + + # Check log for the warning about file not found. + assert ( + f"File '{non_existent_file}' not found, could not rename." + in caplog.text + ) + + def test_power_set(): ref = sorted([(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]) s = range(1, 4) diff --git a/tests/test_volume.py b/tests/test_volume.py index 1f55645e10..d220860784 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -2,7 +2,9 @@ import os import tempfile import warnings +from datetime import datetime from itertools import product +from unittest import mock import numpy as np import pytest @@ -296,6 +298,69 @@ def test_save_load(vols_1): assert vols_loaded_double.pixel_size is None, "Pixel size should be None" +def test_save_overwrite(caplog): + """ + Test that the overwrite flag behaves as expected. + - overwrite=True: Overwrites the existing file. + - overwrite=False: Raises an error if the file exists. + - overwrite=None: Renames the existing file and saves the new one. + """ + vol1 = Volume(np.ones((1, 8, 8, 8), dtype=np.float32)) + vol2 = Volume(2 * np.ones((1, 8, 8, 8), dtype=np.float32)) + vol3 = Volume(3 * np.ones((1, 8, 8, 8), dtype=np.float32)) + + # Create a tmp dir for this test output + with tempfile.TemporaryDirectory() as tmpdir_name: + # tmp filename + mrc_path = os.path.join(tmpdir_name, "og.mrc") + base, ext = os.path.splitext(mrc_path) + + # Create and save the first image + vol1.save(mrc_path, overwrite=True) + + # Case 1: overwrite=True (should overwrite the existing file) + vol2.save(mrc_path, overwrite=True) + + # Load and check if vol2 has overwritten vol1 + vol2_loaded = Volume.load(mrc_path) + np.testing.assert_allclose(vol2.asnumpy(), vol2_loaded.asnumpy()) + + # Case 2: overwrite=False (should raise an overwrite error) + with pytest.raises( + ValueError, + match="File '.*' already exists; set overwrite=True to overwrite it", + ): + vol3.save(mrc_path, overwrite=False) + + # Case 3: overwrite=None (should rename the existing file and save vol3 with original filename) + # Mock datetime to return a fixed timestamp + mock_datetime_value = datetime(2024, 10, 18, 12, 0, 0) + with mock.patch("aspire.utils.misc.datetime") as mock_datetime: + mock_datetime.now.return_value = mock_datetime_value + mock_datetime.strftime = datetime.strftime + + with caplog.at_level(logging.INFO): + vol3.save(mrc_path, overwrite=None) + + # Check that the existing file was renamed and logged + assert f"Renaming {mrc_path}" in caplog.text + + # Construct the expected renamed filename using the mock timestamp + mock_timestamp = mock_datetime_value.strftime("%y%m%d_%H%M%S") + renamed_file = f"{base}_{mock_timestamp}{ext}" + + # Assert that the renamed file exists + assert os.path.exists(renamed_file), "Renamed file not found" + + # Load and check that vol3 was saved to the original path + vol3_loaded = Volume.load(mrc_path) + np.testing.assert_allclose(vol3.asnumpy(), vol3_loaded.asnumpy()) + + # Also check that the renamed file still contains vol2's data + vol2_loaded_renamed = Volume.load(renamed_file) + np.testing.assert_allclose(vol2.asnumpy(), vol2_loaded_renamed.asnumpy()) + + def test_volume_pixel_size(vols_2): """ Test volume is storing pixel_size attribute.