In [None]:
import dask.array as da
import healpy as hp
import numpy as np

import healpix_convolution as hc

In [None]:
import matplotlib.pyplot as plt

## Experimentation

Prototype a Gaussian kernel by hand with `numpy` and assemble it into a `sparse.COO` matrix.

In [None]:
# HEALPix grid for the hand-rolled prototype: order `resolution`
# (nside = 2**resolution), covering all 12 * 4**resolution cells.
resolution = 4
indexing_scheme = "nested"
cell_ids = np.arange(12 * 4**resolution)

In [None]:
truncate = 4  # truncation radius in units of sigma
sigma = 0.1  # in radians
cell_distance = hp.nside2resol(2**resolution, arcmin=False)
# NOTE(review): `// 2` treats the truncation radius `truncate * sigma` like a
# kernel diameter, so the neighbourhood only reaches about half of it
# (here 3 rings for a ~6.25-cell radius) — confirm this is intended.
ring = int((truncate * sigma / cell_distance) // 2)

neighbours = hc.neighbours(
    cell_ids, resolution=resolution, indexing_scheme=indexing_scheme, ring=ring
)
distances = hc.angular_distances(
    neighbours, resolution=resolution, indexing_scheme=indexing_scheme
)
# Neighbour id -1 is masked out; the name spells out the polarity (True = drop).
invalid = neighbours == -1

# Gaussian weights, zeroed for masked neighbours, then normalised per row so
# each cell's weights sum to one.
sigma2 = sigma * sigma
phi_x = np.where(invalid, 0, np.exp(-0.5 / sigma2 * distances**2))
kernel = phi_x / phi_x.sum(axis=-1, keepdims=True)
kernel.shape

In [None]:
import sparse

In [None]:
# Flatten the (cell, neighbour) grid into COO coordinates: row = source cell id,
# column = neighbour id. Entries whose neighbour id is -1 are dropped.
valid = np.reshape(neighbours, -1) != -1
coords = np.reshape(
    np.stack(
        [
            np.repeat(cell_ids[:, None], repeats=neighbours.shape[-1], axis=-1),
            neighbours,
        ],
        axis=0,
    ),
    (2, -1),
)

# `coords` is already (2, n_entries), so it can be filtered directly.
kernel_ = np.reshape(kernel, -1)[valid]
coords_ = coords[:, valid]

In [None]:
# Square (n_cells x n_cells) sparse kernel; every position not listed in
# `coords_` takes the fill value 0.
kernel_matrix = sparse.COO(
    coords=coords_,
    data=kernel_,
    fill_value=0,
    shape=(cell_ids.size,) * 2,
)
kernel_matrix

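## Dask awareness

Repeat the construction on a chunked `dask` array of cell ids and build the sparse matrix block-wise with `da.map_blocks`.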

In [None]:
resolution = 4
# kernel width (presumably in cells); set to None to derive the ring size from
# `sigma` and `truncate` instead
kernel_size = 3
indexing_scheme = "ring"
sigma = 0.1
# Defined here so the `kernel_size is None` branch below does not silently
# depend on the experimentation section having been run first.
truncate = 4

cell_ids = da.arange(12 * 4**resolution, chunks=(1000,))
cell_ids

In [None]:
cell_ids_ = np.reshape(cell_ids, (-1,))

# TODO: figure out whether there is a better way of defining the units of `sigma`
# An explicit `kernel_size` wins; otherwise the ring is derived from the
# truncation radius `truncate * sigma` (radians) and the cell resolution.
# NOTE(review): `// 2` treats that radius like a diameter — confirm intended.
if kernel_size is None:
    cell_distance = hp.nside2resol(2**resolution, arcmin=False)
    ring = int((truncate * sigma / cell_distance) // 2)
else:
    ring = int(kernel_size / 2)

nb = hc.neighbours(
    cell_ids_, ring=ring, resolution=resolution, indexing_scheme=indexing_scheme
)
d = hc.angular_distances(nb, indexing_scheme=indexing_scheme, resolution=resolution)

# Unnormalised Gaussian weights, zeroed where the neighbour id is -1 ...
sigma2 = sigma * sigma
phi_x = np.exp(-0.5 / sigma2 * d**2)
masked = np.where(nb == -1, 0, phi_x)

# ... then normalised so every row (one per cell) sums to one.
normalized = masked / masked.sum(axis=1, keepdims=True)
normalized

In [None]:
import sparse

In [None]:
# Repeat each cell id across its row so the array lines up element-wise with `nb`.
cell_ids__ = np.repeat(
    cell_ids_[:, None],
    repeats=nb.shape[1],
    axis=-1,
)
cell_ids__

In [None]:
# Reference for the next cell (replaces an interactive `?da.map_blocks` pager call):
# `dask.array.map_blocks` applies a function to every block. `drop_axis` / `new_axis`
# remove / add output dimensions, `chunks` gives the output chunk shape, and `meta`
# is an empty instance of the output block type. Extra keyword arguments (e.g.
# `shape`) are forwarded to the mapped function.

In [None]:
# Row chunking of the broadcast ids; the map_blocks call below reuses chunks[0]
# as the output row chunks.
cell_ids__.chunks

In [None]:
# Build one sparse block per row chunk with `create_sparse` and let dask stitch
# them into a lazily evaluated (n_cells, n_cells) matrix.
# NOTE(review): every block receives the same `shape` with 1000 rows, but the
# last row chunk holds only 72 rows (12 * 4**4 = 3072 = 3 * 1000 + 72), and the
# ids in `cell_ids__` are global cell ids, not block-local offsets — confirm
# `create_sparse` handles both, otherwise `.compute()` may fail.
shape = (1000, cell_ids.size)
matrix = da.map_blocks(
    hc.kernels.common.create_sparse,
    cell_ids__,
    nb,
    normalized,
    chunks=(cell_ids__.chunks[0], cell_ids.size),
    drop_axis=1,
    new_axis=1,
    meta=sparse.COO.from_numpy(np.array([], dtype="float64")),
    shape=shape,
)
matrix

In [None]:
# Show the three arrays passed to `create_sparse` in the map_blocks call.
display(cell_ids__, nb, normalized)

In [None]:
# Evaluate the lazy graph block by block.
# NOTE(review): each block is built with a fixed 1000-row `shape` while the last
# chunk holds 72 rows — confirm this computes without a shape error.
matrix.compute()

## Module version

Build the same Gaussian kernel with the packaged `hc.kernels.gaussian_kernel`.

In [None]:
resolution = 3
indexing_scheme = "nested"
cell_ids = np.arange(12 * 4**resolution)  # every cell on the sphere

# Gaussian width (radians) and truncation radius in units of sigma
sigma = 0.1
truncate = 4.0

In [None]:
%%time
kernel = hc.kernels.gaussian_kernel(
    cell_ids,
    resolution=resolution,
    indexing_scheme=indexing_scheme,
    sigma=sigma,
    truncate=truncate,
)
kernel

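### Verification

Each kernel row should sum to one. The matrix and a single row on the sphere are then inspected visually.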

In [None]:
# Row sums of the kernel. Rows are expected to be normalised, so fail loudly
# instead of relying on a visual check of the printed array.
norm = np.sum(kernel, axis=1).todense()
np.testing.assert_allclose(norm, 1.0)
norm

In [None]:
fig, ax = plt.subplots(figsize=(14, 14))

# Dense view of the kernel: one row per cell, one column per neighbouring cell.
mappable = ax.imshow(kernel.todense())
ax.set(title="Gaussian kernel matrix", xlabel="neighbour cell id", ylabel="cell id")
# trailing `;` hides the Colorbar repr so only the figure is shown
fig.colorbar(mappable, ax=ax, label="weight");

In [None]:
# Dense copy of the first kernel row (presumably the weights around cell 0).
# Note: this rebinds `kernel_`, which earlier held the flattened prototype weights.
kernel_ = kernel[0, :].todense()

In [None]:
import healpy as hp

In [None]:
# Plot the first kernel row on the sphere. `nest` follows `indexing_scheme`, so the
# map stays correct if the configuration is switched to "ring".
hp.newvisufunc.projview(kernel_, nest=indexing_scheme == "nested")

**Next steps (TODO):**

- subdomain convolution
- image pyramid (up- and downgrading between resolutions)
- neighbour ordering
- chunked kernel