From cb85ed0e10e95838ccfe5a747f83bc207d17f84c Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 26 Sep 2025 14:01:34 +0100 Subject: [PATCH 1/6] Added dask class for 2D HiPS, and added round-tripping test via Dask class and fix a number of issues that came up including in the generation of HiPS --- reproject/hips/_dask_array.py | 133 +++++++++++++++++ reproject/hips/high_level.py | 187 ++++++++++++------------ reproject/hips/tests/test_dask_array.py | 56 +++++++ reproject/hips/utils.py | 102 ++++++++++++- reproject/utils.py | 6 +- 5 files changed, 387 insertions(+), 97 deletions(-) create mode 100644 reproject/hips/_dask_array.py create mode 100644 reproject/hips/tests/test_dask_array.py diff --git a/reproject/hips/_dask_array.py b/reproject/hips/_dask_array.py new file mode 100644 index 000000000..558f1ed7d --- /dev/null +++ b/reproject/hips/_dask_array.py @@ -0,0 +1,133 @@ +import os +import struct +import urllib +import uuid +import functools + +import numpy as np +from astropy import units as u +from astropy.io import fits +from astropy.wcs import WCS +from astropy_healpix import HEALPix, level_to_nside +from dask import array as da +from astropy.utils.data import download_file +from astropy.wcs.utils import celestial_frame_to_wcs + +from .utils import is_url, load_properties, tile_filename, tile_header, map_header +from .high_level import VALID_COORD_SYSTEM + +__all__ = ['hips_as_dask_and_wcs'] + + +class HiPSArray: + + def __init__(self, directory_or_url, level=None): + + self._directory_or_url = directory_or_url + + self._is_url = is_url(directory_or_url) + + self._properties = load_properties(directory_or_url) + + self._tile_width = int(self._properties["hips_tile_width"]) + self._order = int(self._properties["hips_order"]) + self._level = self._order if level is None else level + self._tile_format = self._properties["hips_tile_format"] + self._frame_str = self._properties["hips_frame"] + self._frame = VALID_COORD_SYSTEM[self._frame_str] + + self._hp = HEALPix(nside=level_to_nside(self._level), frame=self._frame, order="nested") + + self._header = map_header(level=self._level, frame=self._frame, tile_size=self._tile_width) + + self.wcs = WCS(self._header) + self.shape = self.wcs.array_shape + + self.dtype = float + self.ndim = 2 + + self.chunksize = (self._tile_width, self._tile_width) + + self._nan = np.nan * np.ones(self.chunksize, dtype=self.dtype) + + self._blank = np.broadcast_to(np.nan, self.shape) + + def __getitem__(self, item): + + if item[0].start == item[0].stop or item[1].start == item[1].stop: + return self._blank[item] + + # For now assume item is a list of slices. Find + + # imid = (item[0].start + item[0].stop) // 2 + # jmid = (item[1].start + item[1].stop) // 2 + + istart = item[0].start + irange = item[0].stop - item[0].start + imid = np.array([istart + 0.25 * irange, istart + 0.75 * irange]) + + jstart = item[1].start + jrange = item[1].stop - item[1].start + jmid = np.array([jstart + 0.25 * jrange, jstart + 0.75 * jrange]) + + # Convert pixel coordinates to HEALPix indices + + coord = self.wcs.pixel_to_world(jmid, imid) + + if self._frame_str == "equatorial": + lon, lat = coord.ra.deg, coord.dec.deg + elif self._frame_str == "galactic": + lon, lat = coord.l.deg, coord.b.deg + else: + raise NotImplementedError() + + if np.all(np.isnan(lon) | np.isnan(lat)): + return self._nan + + index = self._hp.skycoord_to_healpix(coord) + + if np.all(index == -1): + return self._nan + + index = np.max(index) + + return self._get_tile(level=self._level, index=index) + + @functools.lru_cache(maxsize=128) + def _get_tile(self, *, level, index): + + filename_or_url = tile_filename( + level=self._level, + index=index, + output_directory=self._directory_or_url, + extension="fits", + ) + + if self._is_url: + try: + filename = download_file(filename_or_url, cache=True) + except urllib.error.HTTPError: + return self._nan + elif not os.path.exists(filename_or_url): + return self._nan + else: + filename = filename_or_url + + with fits.open(filename) as hdulist: + hdu = hdulist[0] + # data = hdu.data[::-1] + data = hdu.data + + # return np.ones(hdu.data.shape) * index + + return data + + +def hips_as_dask_and_wcs(directory_or_url, *, level=None): + array_wrapper = HiPSArray(directory_or_url, level=level) + return da.from_array( + array_wrapper, + chunks=array_wrapper.chunksize, + name=str(uuid.uuid4()), + meta=np.array([], dtype=float) + ), array_wrapper.wcs diff --git a/reproject/hips/high_level.py b/reproject/hips/high_level.py index 2ca9b7d52..64ca89221 100644 --- a/reproject/hips/high_level.py +++ b/reproject/hips/high_level.py @@ -10,7 +10,6 @@ import numpy as np from astropy.coordinates import ICRS, BarycentricTrueEcliptic, Galactic from astropy.io import fits -from astropy.nddata import block_reduce from astropy_healpix import ( HEALPix, level_to_nside, @@ -22,7 +21,9 @@ from ..utils import as_transparent_rgb, is_jpeg, is_png, parse_input_data from ..wcs_utils import has_celestial, pixel_scale from .utils import ( + load_properties, make_tile_folders, + save_properties, tile_filename, tile_header, ) @@ -202,6 +203,8 @@ def reproject_to_hips( # Determine center of image and radius to furthest corner, to determine # which HiPS tiles need to be generated + # TODO: this will fail for e.g. allsky maps + ny, nx = array_in.shape[-2:] cen_x, cen_y = (nx - 1) / 2, (ny - 1) / 2 @@ -212,7 +215,7 @@ def reproject_to_hips( cen_world = wcs_in.pixel_to_world(cen_x, cen_y) cor_world = wcs_in.pixel_to_world(cor_x, cor_y) - radius = cor_world.separation(cen_world).max() + radius = cor_world.separation(cen_world).max() * 2 # TODO: in future if astropy-healpix implements polygon searches, we could # use that instead @@ -222,7 +225,9 @@ def reproject_to_hips( nside = level_to_nside(level) hp = HEALPix(nside=nside, order="nested", frame=frame) - indices = hp.cone_search_skycoord(cen_world, radius=radius) + # indices = hp.cone_search_skycoord(cen_world, radius=radius) + + indices = np.arange(hp.npix) logger.info(f"Found {len(indices)} tiles (at most) to generate at level {level}") @@ -234,12 +239,28 @@ def reproject_to_hips( # Iterate over the tiles and generate them def process(index): - header = tile_header(level=level, index=index, frame=frame, tile_size=tile_size) if hasattr(wcs_in, "deepcopy"): wcs_in_copy = wcs_in.deepcopy() else: wcs_in_copy = deepcopy(wcs_in) - array_out, footprint = reproject_function((array_in, wcs_in_copy), header, **kwargs) + + header = tile_header(level=level, index=index, frame=frame, tile_size=tile_size) + + if isinstance(header, tuple): + array_out1, footprint1 = reproject_function( + (array_in, wcs_in_copy), header[0], **kwargs + ) + array_out2, footprint2 = reproject_function( + (array_in, wcs_in_copy), header[1], **kwargs + ) + array_out = ( + np.nan_to_num(array_out1) * footprint1 + np.nan_to_num(array_out2) * footprint2 + ) / (footprint1 + footprint2) + footprint = (footprint1 + footprint2) / 2 + header = header[0] + else: + array_out, footprint = reproject_function((array_in, wcs_in_copy), header, **kwargs) + if tile_format != "png": array_out[np.isnan(array_out)] = 0.0 if np.all(footprint == 0): @@ -253,6 +274,7 @@ def process(index): extension=EXTENSION[tile_format], ), array_out, + header, ) else: if tile_format == "png": @@ -287,76 +309,76 @@ def process(index): indices = np.array(generated_indices) - # Iterate over higher levels and compute lower resolution tiles - for ilevel in range(level - 1, -1, -1): - - # Find index of tiles to produce at lower-resolution levels - indices = np.sort(np.unique(indices // 4)) - - make_tile_folders(level=ilevel, indices=indices, output_directory=output_directory) - - for index in indices: - - header = tile_header(level=ilevel, index=index, frame=frame, tile_size=tile_size) - - if tile_format == "fits": - array = np.zeros((tile_size, tile_size)) - elif tile_format == "png": - array = np.zeros((tile_size, tile_size, 4), dtype=np.uint8) - else: - array = np.zeros((tile_size, tile_size, 3), dtype=np.uint8) - - for subindex in range(4): - - current_index = 4 * index + subindex - subtile_filename = tile_filename( - level=ilevel + 1, - index=current_index, - output_directory=output_directory, - extension=EXTENSION[tile_format], - ) - - if os.path.exists(subtile_filename): - - if tile_format == "fits": - data = block_reduce(fits.getdata(subtile_filename), 2, func=np.mean) - else: - data = block_reduce( - np.array(Image.open(subtile_filename))[::-1], (2, 2, 1), func=np.mean - ) - - if subindex == 0: - array[256:, :256] = data - elif subindex == 2: - array[256:, 256:] = data - elif subindex == 1: - array[:256, :256] = data - elif subindex == 3: - array[:256, 256:] = data - - if tile_format == "fits": - fits.writeto( - tile_filename( - level=ilevel, - index=index, - output_directory=output_directory, - extension=EXTENSION[tile_format], - ), - array, - header, - ) - else: - image = as_transparent_rgb(array.transpose(2, 0, 1)) - if tile_format == "jpeg": - image = image.convert("RGB") - image.save( - tile_filename( - level=ilevel, - index=index, - output_directory=output_directory, - extension=EXTENSION[tile_format], - ) - ) + # # Iterate over higher levels and compute lower resolution tiles + # for ilevel in range(level - 1, -1, -1): + + # # Find index of tiles to produce at lower-resolution levels + # indices = np.sort(np.unique(indices // 4)) + + # make_tile_folders(level=ilevel, indices=indices, output_directory=output_directory) + + # for index in indices: + + # header = tile_header(level=ilevel, index=index, frame=frame, tile_size=tile_size) + + # if tile_format == "fits": + # array = np.zeros((tile_size, tile_size)) + # elif tile_format == "png": + # array = np.zeros((tile_size, tile_size, 4), dtype=np.uint8) + # else: + # array = np.zeros((tile_size, tile_size, 3), dtype=np.uint8) + + # for subindex in range(4): + + # current_index = 4 * index + subindex + # subtile_filename = tile_filename( + # level=ilevel + 1, + # index=current_index, + # output_directory=output_directory, + # extension=EXTENSION[tile_format], + # ) + + # if os.path.exists(subtile_filename): + + # if tile_format == "fits": + # data = block_reduce(fits.getdata(subtile_filename), 2, func=np.mean) + # else: + # data = block_reduce( + # np.array(Image.open(subtile_filename))[::-1], (2, 2, 1), func=np.mean + # ) + + # if subindex == 0: + # array[256:, :256] = data + # elif subindex == 2: + # array[256:, 256:] = data + # elif subindex == 1: + # array[:256, :256] = data + # elif subindex == 3: + # array[:256, 256:] = data + + # if tile_format == "fits": + # fits.writeto( + # tile_filename( + # level=ilevel, + # index=index, + # output_directory=output_directory, + # extension=EXTENSION[tile_format], + # ), + # array, + # header, + # ) + # else: + # image = as_transparent_rgb(array.transpose(2, 0, 1)) + # if tile_format == "jpeg": + # image = image.convert("RGB") + # image.save( + # tile_filename( + # level=ilevel, + # index=index, + # output_directory=output_directory, + # extension=EXTENSION[tile_format], + # ) + # ) # Generate properties file @@ -403,21 +425,6 @@ def save_index(directory): f.write(INDEX_HTML) -def save_properties(directory, properties): - with open(os.path.join(directory, "properties"), "w") as f: - for key, value in properties.items(): - f.write(f"{key:20s} = {value}\n") - - -def load_properties(directory): - properties = {} - with open(os.path.join(directory, "properties")) as f: - for line in f: - key, value = line.split("=") - properties[key.strip()] = value.strip() - return properties - - def coadd_hips(input_directories, output_directory): """ Given multiple HiPS directories, combine these into a single HiPS. diff --git a/reproject/hips/tests/test_dask_array.py b/reproject/hips/tests/test_dask_array.py new file mode 100644 index 000000000..ee6027aa5 --- /dev/null +++ b/reproject/hips/tests/test_dask_array.py @@ -0,0 +1,56 @@ +import pytest +import numpy as np + +from astropy.wcs import WCS +from astropy.io import fits + +from reproject import reproject_interp +from reproject.hips import reproject_to_hips +from reproject.hips._dask_array import hips_as_dask_and_wcs +from astropy.utils.data import get_pkg_data_filename + +class TestHIPSDaskArray: + + def setup_method(self): + + hdu = fits.open(get_pkg_data_filename('allsky/allsky_rosat.fits'))[0] + self.original_header = hdu.header + self.original_wcs = WCS(hdu.header) + self.original_array = hdu.data.size + np.arange(hdu.data.size).reshape(hdu.data.shape) + + @pytest.mark.parametrize('frame', ('galactic', 'equatorial')) + @pytest.mark.parametrize('level', (0, 1)) + def test_roundtrip(self, tmp_path, frame, level): + + self.output_directory = tmp_path / 'roundtrip' + + reproject_to_hips( + (self.original_array, self.original_wcs), + coord_system_out=frame, + level=level, + reproject_function=reproject_interp, + output_directory=self.output_directory, + ) + + dask_array, wcs = hips_as_dask_and_wcs(self.output_directory, level=level) + + final_array, footprint = reproject_interp((dask_array, wcs), self.original_wcs, shape_out=self.original_array.shape) + + # FIXME: Due to boundary effects and the fact there are NaN values in + # the whole-map dask array, there are a few NaN pixels in the image in + # the end. For now, we tolerate a small fraction of NaN pixels, and to + # fix this we should modify the dask array so that in empty tiles + # adjacent to non-empty tiles, we set the values to the boundaries of + # the non-empty neighbouring tiles so that interpolation doesn't run + # into any issues. In theory there should be around 90500 pixels inside + # the valid region of the image, so we require at least 90400 valid + # values. + + valid = ~np.isnan(final_array) + + assert np.sum(valid) > 90400 + + np.testing.assert_allclose(final_array[valid], self.original_array[valid], rtol=0.01) + + + # VALIDATE LEVEL diff --git a/reproject/hips/utils.py b/reproject/hips/utils.py index ff3ab588a..60d07798e 100644 --- a/reproject/hips/utils.py +++ b/reproject/hips/utils.py @@ -1,4 +1,6 @@ import os +import urllib +from pathlib import Path import numpy as np from astropy.wcs.utils import celestial_frame_to_wcs @@ -7,7 +9,48 @@ level_to_nside, ) -__all__ = ["tile_header", "tile_filename", "make_tile_folders"] +__all__ = [ + "map_header", + "tile_header", + "tile_filename", + "make_tile_folders", + "is_url", + "load_properties", + "save_properties", +] + + +def map_header(*, level, frame, tile_size): + """ + Return the WCS for a whole map stored as a 2D array in HPX projection + """ + + nside = level_to_nside(level) + + # Determine image size + image_size = 5 * nside * tile_size + + map_wcs = celestial_frame_to_wcs(frame, projection="HPX") + map_wcs.wcs.crval = 0.0, 0.0 + + # Determine map resolution + res = 45 / tile_size / 2**level + map_wcs.wcs.cd = [[-res, -res], [res, -res]] + + # Set PV parameters to default values + map_wcs.wcs.set_pv([(2, 1, 4), (2, 2, 3)]) + + # Set origin to center of the image + map_wcs.wcs.crpix = image_size / 2 + 0.5, image_size / 2 + 0.5 + + # Construct header + header = map_wcs.to_header() + + header["NAXIS"] = 2 + header["NAXIS1"] = image_size + header["NAXIS2"] = image_size + + return header def tile_header(*, level, index, frame, tile_size): @@ -20,9 +63,13 @@ def tile_header(*, level, index, frame, tile_size): # and then just update values for each tile for performance. nside = level_to_nside(level) + + # threshold = np.hypot(image_size, image_size) / 2. + hp = HEALPix(nside=nside, order="nested", frame=frame) tile_wcs = celestial_frame_to_wcs(frame, projection="HPX") + tile_wcs.wcs.crval = 0.0, 0.0 # Determine tile resolution res = 45 / tile_size / 2**level # degrees @@ -33,10 +80,18 @@ def tile_header(*, level, index, frame, tile_size): # Determine CRPIX values by determining the position of the relevant corner # relative to the origin of the projection. - offset_x, offset_y = tile_wcs.world_to_pixel(hp.healpix_to_skycoord(index, dx=1, dy=0)) + offset_x, offset_y = tile_wcs.world_to_pixel( + hp.healpix_to_skycoord(index, dx=[0.5, 0.9, 0.9, 0.1, 0.1], dy=[0.5, 0.1, 0.9, 0.9, 0.1]) + ) + border_tile = ( + np.max(np.hypot(offset_x[1:] - offset_x[0], offset_y[1:] - offset_y[0])) > tile_size + ) + + # offset_x, offset_y = tile_wcs.world_to_pixel(hp.healpix_to_skycoord(index, dx=1, dy=0)) + offset_x, offset_y = tile_wcs.world_to_pixel(hp.healpix_to_skycoord(index, dx=0.75, dy=0.25)) - tile_wcs.wcs.crpix[0] = -offset_x - 0.5 - tile_wcs.wcs.crpix[1] = -offset_y - 0.5 + tile_wcs.wcs.crpix[0] = -offset_x - 0.5 + tile_size / 4 + tile_wcs.wcs.crpix[1] = -offset_y - 0.5 + tile_size / 4 # Construct header header = tile_wcs.to_header() @@ -47,7 +102,13 @@ def tile_header(*, level, index, frame, tile_size): header["NAXIS1"] = tile_size header["NAXIS2"] = tile_size - return header + if border_tile: + header2 = header.copy() + header2["CRPIX1"] = -header["CRPIX2"] + tile_size + 1 + header2["CRPIX2"] = -header["CRPIX1"] + tile_size + 1 + return header, header2 + else: + return header def _rounded_index(index): @@ -72,3 +133,34 @@ def make_tile_folders(*, level, indices, output_directory): ) if not os.path.exists(dirname): os.makedirs(dirname) + + +def is_url(directory): + if isinstance(directory, Path): + return False + else: + return directory.startswith("http://") or directory.startswith("https://") + + +def save_properties(directory, properties): + with open(os.path.join(directory, "properties"), "w") as f: + for key, value in properties.items(): + f.write(f"{key:20s} = {value}\n") + + +def load_properties(directory_or_url): + + if is_url(directory_or_url): + properties_filename, _ = urllib.request.urlretrieve(f"{directory_or_url}/properties") + else: + properties_filename = os.path.join(directory_or_url, "properties") + + properties = {} + with open(properties_filename) as f: + for line in f: + if line.startswith("#") or line.strip() == "": + continue + key, value = line.split("=", 1) + properties[key.strip()] = value.strip() + + return properties diff --git a/reproject/utils.py b/reproject/utils.py index fc4eef086..02046a871 100644 --- a/reproject/utils.py +++ b/reproject/utils.py @@ -31,7 +31,7 @@ def _dask_to_numpy_memmap(dask_array, tmp_dir): # Sometimes compute() has to be called twice to return a Numpy array, # so we need to check here if this is the case and call the first compute() - if isinstance(dask_array.ravel()[0].compute(), da.Array): + if isinstance(dask_array[(slice(0, 0),) * dask_array.ndim].compute(), da.Array): dask_array = dask_array.compute() # Cast the dask array to regular float for two reasons - first, zarr 3.0.0 @@ -83,7 +83,9 @@ def hdu_to_numpy_memmap(hdu): """ if ( - hdu.header.get("BSCALE", 1) != 1 + getattr(hdu, "_orig_bscale", 1) != 1 + or getattr(hdu, "_orig_bzero", 0) != 0 + or hdu.header.get("BSCALE", 1) != 1 or hdu.header.get("BZERO", 0) != 0 or hdu.fileinfo() is None or hdu._data_replaced From 3db64703c109e928a70d0c3429e680f64441f3ce Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 26 Sep 2025 14:20:29 +0100 Subject: [PATCH 2/6] Added comments and added a test for level validation --- reproject/hips/_dask_array.py | 43 ++++++++++------ reproject/hips/high_level.py | 9 ++-- reproject/hips/tests/test_dask_array.py | 67 ++++++++++++++++++------- 3 files changed, 84 insertions(+), 35 deletions(-) diff --git a/reproject/hips/_dask_array.py b/reproject/hips/_dask_array.py index 558f1ed7d..35deb229e 100644 --- a/reproject/hips/_dask_array.py +++ b/reproject/hips/_dask_array.py @@ -1,22 +1,19 @@ +import functools import os -import struct import urllib import uuid -import functools import numpy as np -from astropy import units as u from astropy.io import fits +from astropy.utils.data import download_file from astropy.wcs import WCS from astropy_healpix import HEALPix, level_to_nside from dask import array as da -from astropy.utils.data import download_file -from astropy.wcs.utils import celestial_frame_to_wcs -from .utils import is_url, load_properties, tile_filename, tile_header, map_header from .high_level import VALID_COORD_SYSTEM +from .utils import is_url, load_properties, map_header, tile_filename -__all__ = ['hips_as_dask_and_wcs'] +__all__ = ["hips_as_dask_and_wcs"] class HiPSArray: @@ -31,6 +28,17 @@ def __init__(self, directory_or_url, level=None): self._tile_width = int(self._properties["hips_tile_width"]) self._order = int(self._properties["hips_order"]) + if level is None: + self._level = self._order + else: + if level > self._order: + raise ValueError( + f"HiPS dataset at {directory_or_url} does not contain level {level} data" + ) + elif level < 0: + raise ValueError("level should be positive") + else: + self._level = int(level) self._level = self._order if level is None else level self._tile_format = self._properties["hips_tile_format"] self._frame_str = self._properties["hips_frame"] @@ -81,8 +89,12 @@ def __getitem__(self, item): else: raise NotImplementedError() - if np.all(np.isnan(lon) | np.isnan(lat)): + invalid = np.isnan(lon) | np.isnan(lat) + + if np.all(invalid): return self._nan + elif np.any(invalid): + coord = coord[~invalid] index = self._hp.skycoord_to_healpix(coord) @@ -125,9 +137,12 @@ def _get_tile(self, *, level, index): def hips_as_dask_and_wcs(directory_or_url, *, level=None): array_wrapper = HiPSArray(directory_or_url, level=level) - return da.from_array( - array_wrapper, - chunks=array_wrapper.chunksize, - name=str(uuid.uuid4()), - meta=np.array([], dtype=float) - ), array_wrapper.wcs + return ( + da.from_array( + array_wrapper, + chunks=array_wrapper.chunksize, + name=str(uuid.uuid4()), + meta=np.array([], dtype=float), + ), + array_wrapper.wcs, + ) diff --git a/reproject/hips/high_level.py b/reproject/hips/high_level.py index 64ca89221..99e4c8a95 100644 --- a/reproject/hips/high_level.py +++ b/reproject/hips/high_level.py @@ -253,10 +253,11 @@ def process(index): array_out2, footprint2 = reproject_function( (array_in, wcs_in_copy), header[1], **kwargs ) - array_out = ( - np.nan_to_num(array_out1) * footprint1 + np.nan_to_num(array_out2) * footprint2 - ) / (footprint1 + footprint2) - footprint = (footprint1 + footprint2) / 2 + with np.errstate(invalid="ignore"): + array_out = ( + np.nan_to_num(array_out1) * footprint1 + np.nan_to_num(array_out2) * footprint2 + ) / (footprint1 + footprint2) + footprint = (footprint1 + footprint2) / 2 header = header[0] else: array_out, footprint = reproject_function((array_in, wcs_in_copy), header, **kwargs) diff --git a/reproject/hips/tests/test_dask_array.py b/reproject/hips/tests/test_dask_array.py index ee6027aa5..0309a88b6 100644 --- a/reproject/hips/tests/test_dask_array.py +++ b/reproject/hips/tests/test_dask_array.py @@ -1,40 +1,50 @@ -import pytest import numpy as np - -from astropy.wcs import WCS +import pytest from astropy.io import fits +from astropy.utils.data import get_pkg_data_filename +from astropy.wcs import WCS from reproject import reproject_interp from reproject.hips import reproject_to_hips from reproject.hips._dask_array import hips_as_dask_and_wcs -from astropy.utils.data import get_pkg_data_filename + class TestHIPSDaskArray: def setup_method(self): - - hdu = fits.open(get_pkg_data_filename('allsky/allsky_rosat.fits'))[0] - self.original_header = hdu.header + # We use an all-sky WCS image as input since this will test all parts + # of the HiPS projection (some issues happen around boundaries for instance) + hdu = fits.open(get_pkg_data_filename("allsky/allsky_rosat.fits"))[0] self.original_wcs = WCS(hdu.header) self.original_array = hdu.data.size + np.arange(hdu.data.size).reshape(hdu.data.shape) - @pytest.mark.parametrize('frame', ('galactic', 'equatorial')) - @pytest.mark.parametrize('level', (0, 1)) + @pytest.mark.parametrize("frame", ("galactic", "equatorial")) + @pytest.mark.parametrize("level", (0, 1)) def test_roundtrip(self, tmp_path, frame, level): - self.output_directory = tmp_path / 'roundtrip' + output_directory = tmp_path / "roundtrip" + # Note that we always use level=1 to generate, but use a variable level + # to construct the dask array - this is deliberate and ensure that the + # dask array has a proper separation of maximum and current level. reproject_to_hips( (self.original_array, self.original_wcs), coord_system_out=frame, - level=level, + level=1, reproject_function=reproject_interp, - output_directory=self.output_directory, + output_directory=output_directory, + tile_size=256, ) - dask_array, wcs = hips_as_dask_and_wcs(self.output_directory, level=level) + # Represent the HiPS as a dask array + dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=level) - final_array, footprint = reproject_interp((dask_array, wcs), self.original_wcs, shape_out=self.original_array.shape) + # Reproject back to the original WCS + final_array, footprint = reproject_interp( + (dask_array, wcs), + self.original_wcs, + shape_out=self.original_array.shape, + ) # FIXME: Due to boundary effects and the fact there are NaN values in # the whole-map dask array, there are a few NaN pixels in the image in @@ -47,10 +57,33 @@ def test_roundtrip(self, tmp_path, frame, level): # values. valid = ~np.isnan(final_array) - assert np.sum(valid) > 90400 - np.testing.assert_allclose(final_array[valid], self.original_array[valid], rtol=0.01) + def test_level_validation(self, tmp_path): + + output_directory = tmp_path / "levels" + + reproject_to_hips( + (self.original_array, self.original_wcs), + coord_system_out="equatorial", + level=1, + reproject_function=reproject_interp, + output_directory=output_directory, + tile_size=32, + ) + + dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=0) + assert dask_array.shape == (160, 160) + + dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=1) + assert dask_array.shape == (320, 320) + + dask_array, wcs = hips_as_dask_and_wcs(output_directory) + assert dask_array.shape == (320, 320) + + with pytest.raises(Exception, match=r"does not contain level 2 data"): + hips_as_dask_and_wcs(output_directory, level=2) - # VALIDATE LEVEL + with pytest.raises(Exception, match=r"should be positive"): + hips_as_dask_and_wcs(output_directory, level=-1) From 517c360e35429bc8fc0ce5271c619859e77e7e31 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 26 Sep 2025 14:54:04 +0100 Subject: [PATCH 3/6] Generalize logic of determining field of view that works for all-sky images and re-instate generation of lower resolution tiles --- reproject/hips/_dask_array.py | 2 +- reproject/hips/high_level.py | 180 ++++++++++++++---------- reproject/hips/tests/test_dask_array.py | 4 +- 3 files changed, 109 insertions(+), 77 deletions(-) diff --git a/reproject/hips/_dask_array.py b/reproject/hips/_dask_array.py index 35deb229e..898321437 100644 --- a/reproject/hips/_dask_array.py +++ b/reproject/hips/_dask_array.py @@ -105,7 +105,7 @@ def __getitem__(self, item): return self._get_tile(level=self._level, index=index) - @functools.lru_cache(maxsize=128) + @functools.lru_cache(maxsize=128) # noqa: B019 def _get_tile(self, *, level, index): filename_or_url = tile_filename( diff --git a/reproject/hips/high_level.py b/reproject/hips/high_level.py index 99e4c8a95..e00092b37 100644 --- a/reproject/hips/high_level.py +++ b/reproject/hips/high_level.py @@ -8,8 +8,10 @@ from pathlib import Path import numpy as np +from astropy import units as u from astropy.coordinates import ICRS, BarycentricTrueEcliptic, Galactic from astropy.io import fits +from astropy.nddata import block_reduce from astropy_healpix import ( HEALPix, level_to_nside, @@ -215,7 +217,30 @@ def reproject_to_hips( cen_world = wcs_in.pixel_to_world(cen_x, cen_y) cor_world = wcs_in.pixel_to_world(cor_x, cor_y) - radius = cor_world.separation(cen_world).max() * 2 + separations = cor_world.separation(cen_world) + + if np.any(np.isnan(separations)): + + # At least one of the corners is outside of the region of validity of + # the WCS, so we use a different approach where we randomly sample a + # number of positions in the image and then check the maximum + # separation between any pair of points. + + n_ran = 1000 + ran_x = np.random.uniform(-0.5, nx - 0.5, n_ran) + ran_y = np.random.uniform(-0.5, nx - 0.5, n_ran) + + ran_world = wcs_in.pixel_to_world(ran_x, ran_y) + + separations = ran_world[:, None].separation(ran_world[None, :]) + + max_separation = np.nanmax(separations) + + else: + + max_separation = separations.max() + + radius = 1.5 * max_separation # TODO: in future if astropy-healpix implements polygon searches, we could # use that instead @@ -225,9 +250,10 @@ def reproject_to_hips( nside = level_to_nside(level) hp = HEALPix(nside=nside, order="nested", frame=frame) - # indices = hp.cone_search_skycoord(cen_world, radius=radius) - - indices = np.arange(hp.npix) + if radius > 120 * u.deg: + indices = np.arange(hp.npix) + else: + indices = hp.cone_search_skycoord(cen_world, radius=radius) logger.info(f"Found {len(indices)} tiles (at most) to generate at level {level}") @@ -310,76 +336,82 @@ def process(index): indices = np.array(generated_indices) - # # Iterate over higher levels and compute lower resolution tiles - # for ilevel in range(level - 1, -1, -1): - - # # Find index of tiles to produce at lower-resolution levels - # indices = np.sort(np.unique(indices // 4)) - - # make_tile_folders(level=ilevel, indices=indices, output_directory=output_directory) - - # for index in indices: - - # header = tile_header(level=ilevel, index=index, frame=frame, tile_size=tile_size) - - # if tile_format == "fits": - # array = np.zeros((tile_size, tile_size)) - # elif tile_format == "png": - # array = np.zeros((tile_size, tile_size, 4), dtype=np.uint8) - # else: - # array = np.zeros((tile_size, tile_size, 3), dtype=np.uint8) - - # for subindex in range(4): - - # current_index = 4 * index + subindex - # subtile_filename = tile_filename( - # level=ilevel + 1, - # index=current_index, - # output_directory=output_directory, - # extension=EXTENSION[tile_format], - # ) - - # if os.path.exists(subtile_filename): - - # if tile_format == "fits": - # data = block_reduce(fits.getdata(subtile_filename), 2, func=np.mean) - # else: - # data = block_reduce( - # np.array(Image.open(subtile_filename))[::-1], (2, 2, 1), func=np.mean - # ) - - # if subindex == 0: - # array[256:, :256] = data - # elif subindex == 2: - # array[256:, 256:] = data - # elif subindex == 1: - # array[:256, :256] = data - # elif subindex == 3: - # array[:256, 256:] = data - - # if tile_format == "fits": - # fits.writeto( - # tile_filename( - # level=ilevel, - # index=index, - # output_directory=output_directory, - # extension=EXTENSION[tile_format], - # ), - # array, - # header, - # ) - # else: - # image = as_transparent_rgb(array.transpose(2, 0, 1)) - # if tile_format == "jpeg": - # image = image.convert("RGB") - # image.save( - # tile_filename( - # level=ilevel, - # index=index, - # output_directory=output_directory, - # extension=EXTENSION[tile_format], - # ) - # ) + # Iterate over higher levels and compute lower resolution tiles + + half_tile_size = tile_size // 2 + + for ilevel in range(level - 1, -1, -1): + + # Find index of tiles to produce at lower-resolution levels + indices = np.sort(np.unique(indices // 4)) + + make_tile_folders(level=ilevel, indices=indices, output_directory=output_directory) + + for index in indices: + + header = tile_header(level=ilevel, index=index, frame=frame, tile_size=tile_size) + + if isinstance(header, tuple): + header = header[0] + + if tile_format == "fits": + array = np.zeros((tile_size, tile_size)) + elif tile_format == "png": + array = np.zeros((tile_size, tile_size, 4), dtype=np.uint8) + else: + array = np.zeros((tile_size, tile_size, 3), dtype=np.uint8) + + for subindex in range(4): + + current_index = 4 * index + subindex + subtile_filename = tile_filename( + level=ilevel + 1, + index=current_index, + output_directory=output_directory, + extension=EXTENSION[tile_format], + ) + + if os.path.exists(subtile_filename): + + if tile_format == "fits": + data = block_reduce(fits.getdata(subtile_filename), 2, func=np.mean) + else: + data = block_reduce( + np.array(Image.open(subtile_filename))[::-1], (2, 2, 1), func=np.mean + ) + + if subindex == 0: + array[half_tile_size:, :half_tile_size] = data + elif subindex == 2: + array[half_tile_size:, half_tile_size:] = data + elif subindex == 1: + array[:half_tile_size, :half_tile_size] = data + elif subindex == 3: + array[:half_tile_size, half_tile_size:] = data + + if tile_format == "fits": + fits.writeto( + tile_filename( + level=ilevel, + index=index, + output_directory=output_directory, + extension=EXTENSION[tile_format], + ), + array, + header, + ) + else: + image = as_transparent_rgb(array.transpose(2, 0, 1)) + if tile_format == "jpeg": + image = image.convert("RGB") + image.save( + tile_filename( + level=ilevel, + index=index, + output_directory=output_directory, + extension=EXTENSION[tile_format], + ) + ) # Generate properties file diff --git a/reproject/hips/tests/test_dask_array.py b/reproject/hips/tests/test_dask_array.py index 0309a88b6..60be12396 100644 --- a/reproject/hips/tests/test_dask_array.py +++ b/reproject/hips/tests/test_dask_array.py @@ -53,11 +53,11 @@ def test_roundtrip(self, tmp_path, frame, level): # adjacent to non-empty tiles, we set the values to the boundaries of # the non-empty neighbouring tiles so that interpolation doesn't run # into any issues. In theory there should be around 90500 pixels inside - # the valid region of the image, so we require at least 90400 valid + # the valid region of the image, so we require at least 90000 valid # values. valid = ~np.isnan(final_array) - assert np.sum(valid) > 90400 + assert np.sum(valid) > 90000 np.testing.assert_allclose(final_array[valid], self.original_array[valid], rtol=0.01) def test_level_validation(self, tmp_path): From 8334f1e9120a6394e2a9eb82a57b954a24412ac8 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 26 Sep 2025 14:56:23 +0100 Subject: [PATCH 4/6] Cleanup --- reproject/hips/__init__.py | 1 + reproject/hips/_dask_array.py | 12 +++++------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/reproject/hips/__init__.py b/reproject/hips/__init__.py index 68b722972..5e206c3ee 100644 --- a/reproject/hips/__init__.py +++ b/reproject/hips/__init__.py @@ -1 +1,2 @@ from .high_level import * # noqa +from ._dask_array import hips_as_dask_and_wcs diff --git a/reproject/hips/_dask_array.py b/reproject/hips/_dask_array.py index 898321437..f5ae8c942 100644 --- a/reproject/hips/_dask_array.py +++ b/reproject/hips/_dask_array.py @@ -65,10 +65,8 @@ def __getitem__(self, item): if item[0].start == item[0].stop or item[1].start == item[1].stop: return self._blank[item] - # For now assume item is a list of slices. Find - - # imid = (item[0].start + item[0].stop) // 2 - # jmid = (item[1].start + item[1].stop) // 2 + # We use two points in different parts of the image because in some + # cases using the exact center or corners can cause issues. istart = item[0].start irange = item[0].stop - item[0].start @@ -127,15 +125,15 @@ def _get_tile(self, *, level, index): with fits.open(filename) as hdulist: hdu = hdulist[0] - # data = hdu.data[::-1] data = hdu.data - # return np.ones(hdu.data.shape) * index - return data def hips_as_dask_and_wcs(directory_or_url, *, level=None): + """ + Return a dask array and WCS that represent a HiPS dataset at a particular level. + """ array_wrapper = HiPSArray(directory_or_url, level=level) return ( da.from_array( From ae9770d0b7ec301b7d1fd2dd32bbe15fe9cf159c Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 26 Sep 2025 14:57:16 +0100 Subject: [PATCH 5/6] More cleanup --- reproject/hips/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/reproject/hips/utils.py b/reproject/hips/utils.py index 60d07798e..b21ed5588 100644 --- a/reproject/hips/utils.py +++ b/reproject/hips/utils.py @@ -64,8 +64,6 @@ def tile_header(*, level, index, frame, tile_size): nside = level_to_nside(level) - # threshold = np.hypot(image_size, image_size) / 2. - hp = HEALPix(nside=nside, order="nested", frame=frame) tile_wcs = celestial_frame_to_wcs(frame, projection="HPX") @@ -87,7 +85,6 @@ def tile_header(*, level, index, frame, tile_size): np.max(np.hypot(offset_x[1:] - offset_x[0], offset_y[1:] - offset_y[0])) > tile_size ) - # offset_x, offset_y = tile_wcs.world_to_pixel(hp.healpix_to_skycoord(index, dx=1, dy=0)) offset_x, offset_y = tile_wcs.world_to_pixel(hp.healpix_to_skycoord(index, dx=0.75, dy=0.25)) tile_wcs.wcs.crpix[0] = -offset_x - 0.5 + tile_size / 4 From fdc25ea679d21712de5e221fb50c16d72f6c9fad Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Tue, 30 Sep 2025 09:56:35 +0100 Subject: [PATCH 6/6] Renamed function and fixed pre-commit --- reproject/hips/__init__.py | 2 +- reproject/hips/_dask_array.py | 4 ++-- reproject/hips/tests/test_dask_array.py | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/reproject/hips/__init__.py b/reproject/hips/__init__.py index 5e206c3ee..ab029c63e 100644 --- a/reproject/hips/__init__.py +++ b/reproject/hips/__init__.py @@ -1,2 +1,2 @@ from .high_level import * # noqa -from ._dask_array import hips_as_dask_and_wcs +from ._dask_array import hips_as_dask_array # noqa diff --git a/reproject/hips/_dask_array.py b/reproject/hips/_dask_array.py index f5ae8c942..0fde086fa 100644 --- a/reproject/hips/_dask_array.py +++ b/reproject/hips/_dask_array.py @@ -13,7 +13,7 @@ from .high_level import VALID_COORD_SYSTEM from .utils import is_url, load_properties, map_header, tile_filename -__all__ = ["hips_as_dask_and_wcs"] +__all__ = ["hips_as_dask_array"] class HiPSArray: @@ -130,7 +130,7 @@ def _get_tile(self, *, level, index): return data -def hips_as_dask_and_wcs(directory_or_url, *, level=None): +def hips_as_dask_array(directory_or_url, *, level=None): """ Return a dask array and WCS that represent a HiPS dataset at a particular level. """ diff --git a/reproject/hips/tests/test_dask_array.py b/reproject/hips/tests/test_dask_array.py index 60be12396..556a36548 100644 --- a/reproject/hips/tests/test_dask_array.py +++ b/reproject/hips/tests/test_dask_array.py @@ -6,7 +6,7 @@ from reproject import reproject_interp from reproject.hips import reproject_to_hips -from reproject.hips._dask_array import hips_as_dask_and_wcs +from reproject.hips._dask_array import hips_as_dask_array class TestHIPSDaskArray: @@ -37,7 +37,7 @@ def test_roundtrip(self, tmp_path, frame, level): ) # Represent the HiPS as a dask array - dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=level) + dask_array, wcs = hips_as_dask_array(output_directory, level=level) # Reproject back to the original WCS final_array, footprint = reproject_interp( @@ -73,17 +73,17 @@ def test_level_validation(self, tmp_path): tile_size=32, ) - dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=0) + dask_array, wcs = hips_as_dask_array(output_directory, level=0) assert dask_array.shape == (160, 160) - dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=1) + dask_array, wcs = hips_as_dask_array(output_directory, level=1) assert dask_array.shape == (320, 320) - dask_array, wcs = hips_as_dask_and_wcs(output_directory) + dask_array, wcs = hips_as_dask_array(output_directory) assert dask_array.shape == (320, 320) with pytest.raises(Exception, match=r"does not contain level 2 data"): - hips_as_dask_and_wcs(output_directory, level=2) + hips_as_dask_array(output_directory, level=2) with pytest.raises(Exception, match=r"should be positive"): - hips_as_dask_and_wcs(output_directory, level=-1) + hips_as_dask_array(output_directory, level=-1)