Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def from_fits(
noise_map_hdu=0,
uv_wavelengths_hdu=0,
transformer_class=TransformerNUFFT,
raise_error_dft_visibilities_limit: bool = True,
):
"""
Load an interferometer dataset from multiple .fits files.
Expand Down Expand Up @@ -176,6 +177,11 @@ def from_fits(
transformer_class
The class of the Fourier Transform which maps images from real space to Fourier space
visibilities. Defaults to `TransformerNUFFT` for efficiency with large datasets.
raise_error_dft_visibilities_limit
If True (default), raise a `DatasetException` when ``transformer_class`` is
`TransformerDFT` and the dataset has more than 10,000 visibilities. Set to False to
opt into the slow DFT path at ALMA-scale (e.g. when profiling the JAX-traceable
DFT path before a JIT-friendly NUFFT is available).

Returns
-------
Expand All @@ -199,6 +205,7 @@ def from_fits(
noise_map=noise_map,
uv_wavelengths=uv_wavelengths,
transformer_class=transformer_class,
raise_error_dft_visibilities_limit=raise_error_dft_visibilities_limit,
)

def apply_sparse_operator(
Expand Down
46 changes: 46 additions & 0 deletions test_autoarray/dataset/interferometer/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,52 @@ def test__dirty_signal_to_noise_map__shape_native_matches_real_space_mask(
).all()


def test__from_fits__raise_error_dft_visibilities_limit__threads_kwarg(
tmp_path, mask_2d_7x7
):
"""``from_fits`` must forward ``raise_error_dft_visibilities_limit`` to the
``Interferometer`` constructor so callers loading large DFT-based datasets can opt out
of the >10,000-visibility safety check (e.g. for profiling the JAX-traceable DFT path)."""
from astropy.io import fits

n_visibilities = 10_001
visibilities = np.ones((n_visibilities, 2), dtype=np.float64)
noise_map = np.ones((n_visibilities, 2), dtype=np.float64)
uv_wavelengths = np.zeros((n_visibilities, 2), dtype=np.float64)

data_path = tmp_path / "data.fits"
noise_map_path = tmp_path / "noise_map.fits"
uv_path = tmp_path / "uv_wavelengths.fits"

for arr, path in (
(visibilities, data_path),
(noise_map, noise_map_path),
(uv_wavelengths, uv_path),
):
fits.PrimaryHDU(data=arr).writeto(path, overwrite=True)

with pytest.raises(aa.exc.DatasetException):
aa.Interferometer.from_fits(
data_path=data_path,
noise_map_path=noise_map_path,
uv_wavelengths_path=uv_path,
real_space_mask=mask_2d_7x7,
transformer_class=transformer.TransformerDFT,
)

dataset = aa.Interferometer.from_fits(
data_path=data_path,
noise_map_path=noise_map_path,
uv_wavelengths_path=uv_path,
real_space_mask=mask_2d_7x7,
transformer_class=transformer.TransformerDFT,
raise_error_dft_visibilities_limit=False,
)

assert dataset.uv_wavelengths.shape[0] == n_visibilities
assert type(dataset.transformer) == transformer.TransformerDFT


def test__from_fits__all_files_in_one_fits__load_using_different_hdus(mask_2d_7x7):
dataset = aa.Interferometer.from_fits(
real_space_mask=mask_2d_7x7,
Expand Down
Loading