diff --git a/reproject/hips/__init__.py b/reproject/hips/__init__.py index 68b722972..ab029c63e 100644 --- a/reproject/hips/__init__.py +++ b/reproject/hips/__init__.py @@ -1 +1,2 @@ from .high_level import * # noqa +from ._dask_array import hips_as_dask_array # noqa diff --git a/reproject/hips/_dask_array.py b/reproject/hips/_dask_array.py new file mode 100644 index 000000000..0fde086fa --- /dev/null +++ b/reproject/hips/_dask_array.py @@ -0,0 +1,146 @@ +import functools +import os +import urllib +import uuid + +import numpy as np +from astropy.io import fits +from astropy.utils.data import download_file +from astropy.wcs import WCS +from astropy_healpix import HEALPix, level_to_nside +from dask import array as da + +from .high_level import VALID_COORD_SYSTEM +from .utils import is_url, load_properties, map_header, tile_filename + +__all__ = ["hips_as_dask_array"] + + +class HiPSArray: + + def __init__(self, directory_or_url, level=None): + + self._directory_or_url = directory_or_url + + self._is_url = is_url(directory_or_url) + + self._properties = load_properties(directory_or_url) + + self._tile_width = int(self._properties["hips_tile_width"]) + self._order = int(self._properties["hips_order"]) + if level is None: + self._level = self._order + else: + if level > self._order: + raise ValueError( + f"HiPS dataset at {directory_or_url} does not contain level {level} data" + ) + elif level < 0: + raise ValueError("level should be positive") + else: + self._level = int(level) + self._level = self._order if level is None else level + self._tile_format = self._properties["hips_tile_format"] + self._frame_str = self._properties["hips_frame"] + self._frame = VALID_COORD_SYSTEM[self._frame_str] + + self._hp = HEALPix(nside=level_to_nside(self._level), frame=self._frame, order="nested") + + self._header = map_header(level=self._level, frame=self._frame, tile_size=self._tile_width) + + self.wcs = WCS(self._header) + self.shape = self.wcs.array_shape + + self.dtype = float + self.ndim = 2 + + self.chunksize = (self._tile_width, self._tile_width) + + self._nan = np.nan * np.ones(self.chunksize, dtype=self.dtype) + + self._blank = np.broadcast_to(np.nan, self.shape) + + def __getitem__(self, item): + + if item[0].start == item[0].stop or item[1].start == item[1].stop: + return self._blank[item] + + # We use two points in different parts of the image because in some + # cases using the exact center or corners can cause issues. + + istart = item[0].start + irange = item[0].stop - item[0].start + imid = np.array([istart + 0.25 * irange, istart + 0.75 * irange]) + + jstart = item[1].start + jrange = item[1].stop - item[1].start + jmid = np.array([jstart + 0.25 * jrange, jstart + 0.75 * jrange]) + + # Convert pixel coordinates to HEALPix indices + + coord = self.wcs.pixel_to_world(jmid, imid) + + if self._frame_str == "equatorial": + lon, lat = coord.ra.deg, coord.dec.deg + elif self._frame_str == "galactic": + lon, lat = coord.l.deg, coord.b.deg + else: + raise NotImplementedError() + + invalid = np.isnan(lon) | np.isnan(lat) + + if np.all(invalid): + return self._nan + elif np.any(invalid): + coord = coord[~invalid] + + index = self._hp.skycoord_to_healpix(coord) + + if np.all(index == -1): + return self._nan + + index = np.max(index) + + return self._get_tile(level=self._level, index=index) + + @functools.lru_cache(maxsize=128) # noqa: B019 + def _get_tile(self, *, level, index): + + filename_or_url = tile_filename( + level=self._level, + index=index, + output_directory=self._directory_or_url, + extension="fits", + ) + + if self._is_url: + try: + filename = download_file(filename_or_url, cache=True) + except urllib.error.HTTPError: + return self._nan + elif not os.path.exists(filename_or_url): + return self._nan + else: + filename = filename_or_url + + with fits.open(filename) as hdulist: + hdu = hdulist[0] + data = hdu.data + + return data + + +def hips_as_dask_array(directory_or_url, *, level=None): + """ + Return a dask array and WCS that represent a HiPS dataset at a particular level. + """ + array_wrapper = HiPSArray(directory_or_url, level=level) + return ( + da.from_array( + array_wrapper, + chunks=array_wrapper.chunksize, + name=str(uuid.uuid4()), + meta=np.array([], dtype=float), + ), + array_wrapper.wcs, + ) diff --git a/reproject/hips/high_level.py b/reproject/hips/high_level.py index 2ca9b7d52..e00092b37 100644 --- a/reproject/hips/high_level.py +++ b/reproject/hips/high_level.py @@ -8,6 +8,7 @@ from pathlib import Path import numpy as np +from astropy import units as u from astropy.coordinates import ICRS, BarycentricTrueEcliptic, Galactic from astropy.io import fits from astropy.nddata import block_reduce @@ -22,7 +23,9 @@ from ..utils import as_transparent_rgb, is_jpeg, is_png, parse_input_data from ..wcs_utils import has_celestial, pixel_scale from .utils import ( + load_properties, make_tile_folders, + save_properties, tile_filename, tile_header, ) @@ -202,6 +205,8 @@ def reproject_to_hips( # Determine center of image and radius to furthest corner, to determine # which HiPS tiles need to be generated + # TODO: this will fail for e.g. allsky maps + ny, nx = array_in.shape[-2:] cen_x, cen_y = (nx - 1) / 2, (ny - 1) / 2 @@ -212,7 +217,30 @@ def reproject_to_hips( cen_world = wcs_in.pixel_to_world(cen_x, cen_y) cor_world = wcs_in.pixel_to_world(cor_x, cor_y) - radius = cor_world.separation(cen_world).max() + separations = cor_world.separation(cen_world) + + if np.any(np.isnan(separations)): + + # At least one of the corners is outside of the region of validity of + # the WCS, so we use a different approach where we randomly sample a + # number of positions in the image and then check the maximum + # separation between any pair of points. + + n_ran = 1000 + ran_x = np.random.uniform(-0.5, nx - 0.5, n_ran) + ran_y = np.random.uniform(-0.5, nx - 0.5, n_ran) + + ran_world = wcs_in.pixel_to_world(ran_x, ran_y) + + separations = ran_world[:, None].separation(ran_world[None, :]) + + max_separation = np.nanmax(separations) + + else: + + max_separation = separations.max() + + radius = 1.5 * max_separation # TODO: in future if astropy-healpix implements polygon searches, we could # use that instead @@ -222,7 +250,10 @@ def reproject_to_hips( nside = level_to_nside(level) hp = HEALPix(nside=nside, order="nested", frame=frame) - indices = hp.cone_search_skycoord(cen_world, radius=radius) + if radius > 120 * u.deg: + indices = np.arange(hp.npix) + else: + indices = hp.cone_search_skycoord(cen_world, radius=radius) logger.info(f"Found {len(indices)} tiles (at most) to generate at level {level}") @@ -234,12 +265,29 @@ def reproject_to_hips( # Iterate over the tiles and generate them def process(index): - header = tile_header(level=level, index=index, frame=frame, tile_size=tile_size) if hasattr(wcs_in, "deepcopy"): wcs_in_copy = wcs_in.deepcopy() else: wcs_in_copy = deepcopy(wcs_in) - array_out, footprint = reproject_function((array_in, wcs_in_copy), header, **kwargs) + + header = tile_header(level=level, index=index, frame=frame, tile_size=tile_size) + + if isinstance(header, tuple): + array_out1, footprint1 = reproject_function( + (array_in, wcs_in_copy), header[0], **kwargs + ) + array_out2, footprint2 = reproject_function( + (array_in, wcs_in_copy), header[1], **kwargs + ) + with np.errstate(invalid="ignore"): + array_out = ( + np.nan_to_num(array_out1) * footprint1 + np.nan_to_num(array_out2) * footprint2 + ) / (footprint1 + footprint2) + footprint = (footprint1 + footprint2) / 2 + header = header[0] + else: + array_out, footprint = reproject_function((array_in, wcs_in_copy), header, **kwargs) + if tile_format != "png": array_out[np.isnan(array_out)] = 0.0 if np.all(footprint == 0): @@ -253,6 +301,7 @@ def process(index): extension=EXTENSION[tile_format], ), array_out, + header, ) else: if tile_format == "png": @@ -288,6 +337,9 @@ def process(index): indices = np.array(generated_indices) # Iterate over higher levels and compute lower resolution tiles + + half_tile_size = tile_size // 2 + for ilevel in range(level - 1, -1, -1): # Find index of tiles to produce at lower-resolution levels @@ -299,6 +351,9 @@ def process(index): header = tile_header(level=ilevel, index=index, frame=frame, tile_size=tile_size) + if isinstance(header, tuple): + header = header[0] + if tile_format == "fits": array = np.zeros((tile_size, tile_size)) elif tile_format == "png": @@ -326,13 +381,13 @@ def process(index): ) if subindex == 0: - array[256:, :256] = data + array[half_tile_size:, :half_tile_size] = data elif subindex == 2: - array[256:, 256:] = data + array[half_tile_size:, half_tile_size:] = data elif subindex == 1: - array[:256, :256] = data + array[:half_tile_size, :half_tile_size] = data elif subindex == 3: - array[:256, 256:] = data + array[:half_tile_size, half_tile_size:] = data if tile_format == "fits": fits.writeto( @@ -403,21 +458,6 @@ def save_index(directory): f.write(INDEX_HTML) -def save_properties(directory, properties): - with open(os.path.join(directory, "properties"), "w") as f: - for key, value in properties.items(): - f.write(f"{key:20s} = {value}\n") - - -def load_properties(directory): - properties = {} - with open(os.path.join(directory, "properties")) as f: - for line in f: - key, value = line.split("=") - properties[key.strip()] = value.strip() - return properties - - def coadd_hips(input_directories, output_directory): """ Given multiple HiPS directories, combine these into a single HiPS. diff --git a/reproject/hips/tests/test_dask_array.py b/reproject/hips/tests/test_dask_array.py new file mode 100644 index 000000000..556a36548 --- /dev/null +++ b/reproject/hips/tests/test_dask_array.py @@ -0,0 +1,89 @@ +import numpy as np +import pytest +from astropy.io import fits +from astropy.utils.data import get_pkg_data_filename +from astropy.wcs import WCS + +from reproject import reproject_interp +from reproject.hips import reproject_to_hips +from reproject.hips._dask_array import hips_as_dask_array + + +class TestHIPSDaskArray: + + def setup_method(self): + # We use an all-sky WCS image as input since this will test all parts + # of the HiPS projection (some issues happen around boundaries for instance) + hdu = fits.open(get_pkg_data_filename("allsky/allsky_rosat.fits"))[0] + self.original_wcs = WCS(hdu.header) + self.original_array = hdu.data.size + np.arange(hdu.data.size).reshape(hdu.data.shape) + + @pytest.mark.parametrize("frame", ("galactic", "equatorial")) + @pytest.mark.parametrize("level", (0, 1)) + def test_roundtrip(self, tmp_path, frame, level): + + output_directory = tmp_path / "roundtrip" + + # Note that we always use level=1 to generate, but use a variable level + # to construct the dask array - this is deliberate and ensure that the + # dask array has a proper separation of maximum and current level. + reproject_to_hips( + (self.original_array, self.original_wcs), + coord_system_out=frame, + level=1, + reproject_function=reproject_interp, + output_directory=output_directory, + tile_size=256, + ) + + # Represent the HiPS as a dask array + dask_array, wcs = hips_as_dask_array(output_directory, level=level) + + # Reproject back to the original WCS + final_array, footprint = reproject_interp( + (dask_array, wcs), + self.original_wcs, + shape_out=self.original_array.shape, + ) + + # FIXME: Due to boundary effects and the fact there are NaN values in + # the whole-map dask array, there are a few NaN pixels in the image in + # the end. For now, we tolerate a small fraction of NaN pixels, and to + # fix this we should modify the dask array so that in empty tiles + # adjacent to non-empty tiles, we set the values to the boundaries of + # the non-empty neighbouring tiles so that interpolation doesn't run + # into any issues. In theory there should be around 90500 pixels inside + # the valid region of the image, so we require at least 90000 valid + # values. + + valid = ~np.isnan(final_array) + assert np.sum(valid) > 90000 + np.testing.assert_allclose(final_array[valid], self.original_array[valid], rtol=0.01) + + def test_level_validation(self, tmp_path): + + output_directory = tmp_path / "levels" + + reproject_to_hips( + (self.original_array, self.original_wcs), + coord_system_out="equatorial", + level=1, + reproject_function=reproject_interp, + output_directory=output_directory, + tile_size=32, + ) + + dask_array, wcs = hips_as_dask_array(output_directory, level=0) + assert dask_array.shape == (160, 160) + + dask_array, wcs = hips_as_dask_array(output_directory, level=1) + assert dask_array.shape == (320, 320) + + dask_array, wcs = hips_as_dask_array(output_directory) + assert dask_array.shape == (320, 320) + + with pytest.raises(Exception, match=r"does not contain level 2 data"): + hips_as_dask_array(output_directory, level=2) + + with pytest.raises(Exception, match=r"should be positive"): + hips_as_dask_array(output_directory, level=-1) diff --git a/reproject/hips/utils.py b/reproject/hips/utils.py index ff3ab588a..b21ed5588 100644 --- a/reproject/hips/utils.py +++ b/reproject/hips/utils.py @@ -1,4 +1,6 @@ import os +import urllib +from pathlib import Path import numpy as np from astropy.wcs.utils import celestial_frame_to_wcs @@ -7,7 +9,48 @@ level_to_nside, ) -__all__ = ["tile_header", "tile_filename", "make_tile_folders"] +__all__ = [ + "map_header", + "tile_header", + "tile_filename", + "make_tile_folders", + "is_url", + "load_properties", + "save_properties", +] + + +def map_header(*, level, frame, tile_size): + """ + Return the WCS for a whole map stored as a 2D array in HPX projection + """ + + nside = level_to_nside(level) + + # Determine image size + image_size = 5 * nside * tile_size + + map_wcs = celestial_frame_to_wcs(frame, projection="HPX") + map_wcs.wcs.crval = 0.0, 0.0 + + # Determine map resolution + res = 45 / tile_size / 2**level + map_wcs.wcs.cd = [[-res, -res], [res, -res]] + + # Set PV parameters to default values + map_wcs.wcs.set_pv([(2, 1, 4), (2, 2, 3)]) + + # Set origin to center of the image + map_wcs.wcs.crpix = image_size / 2 + 0.5, image_size / 2 + 0.5 + + # Construct header + header = map_wcs.to_header() + + header["NAXIS"] = 2 + header["NAXIS1"] = image_size + header["NAXIS2"] = image_size + + return header def tile_header(*, level, index, frame, tile_size): @@ -20,9 +63,11 @@ def tile_header(*, level, index, frame, tile_size): # and then just update values for each tile for performance. nside = level_to_nside(level) + hp = HEALPix(nside=nside, order="nested", frame=frame) tile_wcs = celestial_frame_to_wcs(frame, projection="HPX") + tile_wcs.wcs.crval = 0.0, 0.0 # Determine tile resolution res = 45 / tile_size / 2**level # degrees @@ -33,10 +78,17 @@ def tile_header(*, level, index, frame, tile_size): # Determine CRPIX values by determining the position of the relevant corner # relative to the origin of the projection. - offset_x, offset_y = tile_wcs.world_to_pixel(hp.healpix_to_skycoord(index, dx=1, dy=0)) + offset_x, offset_y = tile_wcs.world_to_pixel( + hp.healpix_to_skycoord(index, dx=[0.5, 0.9, 0.9, 0.1, 0.1], dy=[0.5, 0.1, 0.9, 0.9, 0.1]) + ) + border_tile = ( + np.max(np.hypot(offset_x[1:] - offset_x[0], offset_y[1:] - offset_y[0])) > tile_size + ) + + offset_x, offset_y = tile_wcs.world_to_pixel(hp.healpix_to_skycoord(index, dx=0.75, dy=0.25)) - tile_wcs.wcs.crpix[0] = -offset_x - 0.5 - tile_wcs.wcs.crpix[1] = -offset_y - 0.5 + tile_wcs.wcs.crpix[0] = -offset_x - 0.5 + tile_size / 4 + tile_wcs.wcs.crpix[1] = -offset_y - 0.5 + tile_size / 4 # Construct header header = tile_wcs.to_header() @@ -47,7 +99,13 @@ def tile_header(*, level, index, frame, tile_size): header["NAXIS1"] = tile_size header["NAXIS2"] = tile_size - return header + if border_tile: + header2 = header.copy() + header2["CRPIX1"] = -header["CRPIX2"] + tile_size + 1 + header2["CRPIX2"] = -header["CRPIX1"] + tile_size + 1 + return header, header2 + else: + return header def _rounded_index(index): @@ -72,3 +130,34 @@ def make_tile_folders(*, level, indices, output_directory): ) if not os.path.exists(dirname): os.makedirs(dirname) + + +def is_url(directory): + if isinstance(directory, Path): + return False + else: + return directory.startswith("http://") or directory.startswith("https://") + + +def save_properties(directory, properties): + with open(os.path.join(directory, "properties"), "w") as f: + for key, value in properties.items(): + f.write(f"{key:20s} = {value}\n") + + +def load_properties(directory_or_url): + + if is_url(directory_or_url): + properties_filename, _ = urllib.request.urlretrieve(f"{directory_or_url}/properties") + else: + properties_filename = os.path.join(directory_or_url, "properties") + + properties = {} + with open(properties_filename) as f: + for line in f: + if line.startswith("#") or line.strip() == "": + continue + key, value = line.split("=", 1) + properties[key.strip()] = value.strip() + + return properties diff --git a/reproject/utils.py b/reproject/utils.py index fc4eef086..02046a871 100644 --- a/reproject/utils.py +++ b/reproject/utils.py @@ -31,7 +31,7 @@ def _dask_to_numpy_memmap(dask_array, tmp_dir): # Sometimes compute() has to be called twice to return a Numpy array, # so we need to check here if this is the case and call the first compute() - if isinstance(dask_array.ravel()[0].compute(), da.Array): + if isinstance(dask_array[(slice(0, 0),) * dask_array.ndim].compute(), da.Array): dask_array = dask_array.compute() # Cast the dask array to regular float for two reasons - first, zarr 3.0.0 @@ -83,7 +83,9 @@ def hdu_to_numpy_memmap(hdu): """ if ( - hdu.header.get("BSCALE", 1) != 1 + getattr(hdu, "_orig_bscale", 1) != 1 + or getattr(hdu, "_orig_bzero", 0) != 0 + or hdu.header.get("BSCALE", 1) != 1 or hdu.header.get("BZERO", 0) != 0 or hdu.fileinfo() is None or hdu._data_replaced