diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 894b979c..e7707e85 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -139,6 +139,7 @@ def from_fits( noise_map_hdu=0, uv_wavelengths_hdu=0, transformer_class=TransformerNUFFT, + raise_error_dft_visibilities_limit: bool = True, ): """ Load an interferometer dataset from multiple .fits files. @@ -176,6 +177,11 @@ def from_fits( transformer_class The class of the Fourier Transform which maps images from real space to Fourier space visibilities. Defaults to `TransformerNUFFT` for efficiency with large datasets. + raise_error_dft_visibilities_limit + If True (default), raise a `DatasetException` when ``transformer_class`` is + `TransformerDFT` and the dataset has more than 10,000 visibilities. Set to False to + opt into the slow DFT path at ALMA-scale (e.g. when profiling the JAX-traceable + DFT path before a JIT-friendly NUFFT is available). Returns ------- @@ -199,6 +205,7 @@ def from_fits( noise_map=noise_map, uv_wavelengths=uv_wavelengths, transformer_class=transformer_class, + raise_error_dft_visibilities_limit=raise_error_dft_visibilities_limit, ) def apply_sparse_operator( diff --git a/test_autoarray/dataset/interferometer/test_dataset.py b/test_autoarray/dataset/interferometer/test_dataset.py index a387c047..3fde438e 100644 --- a/test_autoarray/dataset/interferometer/test_dataset.py +++ b/test_autoarray/dataset/interferometer/test_dataset.py @@ -64,6 +64,52 @@ def test__dirty_signal_to_noise_map__shape_native_matches_real_space_mask( ).all() +def test__from_fits__raise_error_dft_visibilities_limit__threads_kwarg( + tmp_path, mask_2d_7x7 +): + """``from_fits`` must forward ``raise_error_dft_visibilities_limit`` to the + ``Interferometer`` constructor so callers loading large DFT-based datasets can opt out + of the >10,000-visibility safety check (e.g. for profiling the JAX-traceable DFT path).""" + from astropy.io import fits + + n_visibilities = 10_001 + visibilities = np.ones((n_visibilities, 2), dtype=np.float64) + noise_map = np.ones((n_visibilities, 2), dtype=np.float64) + uv_wavelengths = np.zeros((n_visibilities, 2), dtype=np.float64) + + data_path = tmp_path / "data.fits" + noise_map_path = tmp_path / "noise_map.fits" + uv_path = tmp_path / "uv_wavelengths.fits" + + for arr, path in ( + (visibilities, data_path), + (noise_map, noise_map_path), + (uv_wavelengths, uv_path), + ): + fits.PrimaryHDU(data=arr).writeto(path, overwrite=True) + + with pytest.raises(aa.exc.DatasetException): + aa.Interferometer.from_fits( + data_path=data_path, + noise_map_path=noise_map_path, + uv_wavelengths_path=uv_path, + real_space_mask=mask_2d_7x7, + transformer_class=transformer.TransformerDFT, + ) + + dataset = aa.Interferometer.from_fits( + data_path=data_path, + noise_map_path=noise_map_path, + uv_wavelengths_path=uv_path, + real_space_mask=mask_2d_7x7, + transformer_class=transformer.TransformerDFT, + raise_error_dft_visibilities_limit=False, + ) + + assert dataset.uv_wavelengths.shape[0] == n_visibilities + assert type(dataset.transformer) == transformer.TransformerDFT + + def test__from_fits__all_files_in_one_fits__load_using_different_hdus(mask_2d_7x7): dataset = aa.Interferometer.from_fits( real_space_mask=mask_2d_7x7,