## Facetted subgrids - Implementation

This notebook is about implementation of the algorithm sketched out in [facet-subgrid.ipynb](facet-subgrid.ipynb).

In [None]:
%matplotlib inline

from matplotlib import pylab
import matplotlib.patches as patches
import matplotlib.path as path

from ipywidgets import interact
import numpy
import sys
import random
import itertools
import time
import scipy.special
import math
# Notebook-wide plotting defaults: large figures and the viridis colour map
pylab.rcParams.update({'figure.figsize': (16, 10), 'image.cmap': 'viridis'})

try:
    sys.path.append('../..')
    from crocodile.synthesis import *
    from util.visualize import *
    print("Crocodile mode")
except ImportError:
    print("Stand-alone mode")
    # Minimal stand-ins for the crocodile helpers used in this notebook.
    # Arrays are "centred": index N//2 along each axis is the origin.

    def conv(a, b):
        """Circular convolution of *a* and *b* via the centred FFT."""
        return ifft(fft(a) * fft(b))

    def coordinates(N):
        """Return N centred sample coordinates in [-0.5, 0.5)."""
        return numpy.fft.fftshift(numpy.fft.fftfreq(N))

    def fft(a):
        """Centred forward FFT over all axes of *a* (any number of dimensions).

        For 1D/2D input this equals fft/fft2; higher dimensions are handled
        too instead of silently returning None.
        """
        return numpy.fft.fftshift(numpy.fft.fftn(numpy.fft.ifftshift(a)))

    def ifft(a):
        """Centred inverse FFT over all axes of *a* (any number of dimensions)."""
        return numpy.fft.fftshift(numpy.fft.ifftn(numpy.fft.ifftshift(a)))

    def pad_mid(a, N):
        """Zero-pad every axis of *a* to size N, keeping the centre element centred.

        Assumes all axes have the same length as axis 0.
        """
        N0 = a.shape[0]
        assert N >= N0
        return numpy.pad(a, len(a.shape) * [(N//2-N0//2, (N+1)//2-(N0+1)//2)],
                         mode='constant', constant_values=0.0)

    def extract_mid(a, N):
        """Return the centred N-sized window along every axis of *a* (a view).

        The index is built as a tuple: NumPy no longer accepts a list of
        slices as a multi-dimensional index.
        """
        assert N <= a.shape[0]
        cx = a.shape[0] // 2
        s = N // 2
        if N % 2 == 0:
            window = slice(cx - s, cx + s)
        else:
            window = slice(cx - s, cx + s + 1)
        return a[(window,) * len(a.shape)]

    def anti_aliasing_function(shape, m, c):
        """Sample the prolate spheroidal wave function (order m, parameter c).

        A scalar *shape* gives a 1D array of that length; a 2-element shape
        gives the outer product of two 1D functions.
        """
        if len(numpy.array(shape).shape) == 0:
            mult = 2 - 1/shape/4
            return scipy.special.pro_ang1(m, m, c, mult*coordinates(shape))[0]
        return numpy.outer(anti_aliasing_function(shape[0], m, c),
                           anti_aliasing_function(shape[1], m, c))

    def coordinates2(N):
        """Return centred 2D coordinate grids (x, y) of size N x N, scaled by 1/N."""
        N2 = N // 2
        if N % 2 == 0:
            return numpy.mgrid[-N2:N2, -N2:N2][::-1] / N
        else:
            return numpy.mgrid[-N2:N2+1, -N2:N2+1][::-1] / N

    def _show(a, name, scale, axes=None, norm=None):
        """Plot 2D array *a* with centred extents scaled by *scale*.

        axes: target axes; defaults to the current pylab axes.
        norm: optional (vmin, vmax) colour range.
        """
        if axes is None:
            axes = pylab.gca()
        size = a.shape[0]
        if size % 2 == 0:
            low,high = -0.5, 0.5 * (size - 2) / size
        else:
            low,high = -0.5 * (size - 1) / size, 0.5 * (size - 1) / size
        low = (low - 1/size/2) * scale
        high = (high - 1/size/2) * scale
        vmin, vmax = (None, None) if norm is None else norm
        cax = axes.imshow(a, extent=(low,high,low,high), vmin=vmin, vmax=vmax)
        axes.set_title(name)
        axes.figure.colorbar(cax,shrink=.4,pad=0.025)

    def show_grid(grid, name, theta, axes=None, norm=None):
        """Plot a grid, with extents scaled by *theta*."""
        return _show(grid, name, theta, axes, norm)

    def show_image(img, name, theta, axes=None, norm=None):
        """Plot an image, with extents scaled by img.shape[0] / theta."""
        return _show(img, name, img.shape[0] / theta, axes, norm)

    def extract_oversampled(a, Qpx, N):
        """Split an oversampled kernel *a* into Qpx sub-kernels of N samples each."""
        result = numpy.empty((Qpx, N), dtype=complex)
        for xf in range(Qpx):
            # Determine start offset.
            mx = a.shape[0]//2 - Qpx*(N//2) + xf
            # Extract every Qpx-th pixel
            result[xf] = a[mx : mx+Qpx*N : Qpx]
        return result

    def kernel_oversample(ff, Qpx, s=None):
        """Return the Qpx-times oversampled grid kernel of far field *ff*, support s."""
        # Pad the far field to the required pixel size
        N = ff.shape[0]
        if s is None: s = N
        padff = pad_mid(ff, N*Qpx)
        # Obtain oversampled uv-grid
        af = fft(padff)
        # Extract kernels
        return extract_oversampled(af, Qpx, s)

# Helper for marking ranges in a graph
def mark_range(lbl, x0, x1, y0=None, y1=None, ax=None):
    """Draw two dashed vertical lines at x0 and x1 and label the range.

    lbl: annotation text, placed just to the right of x1.
    x0, x1: x positions of the two markers.
    y0, y1: vertical extent of the markers; default to the current top and
        bottom y-limits of *ax* (read before anything is drawn).
    ax: axes to draw on; defaults to pylab.gca().
    """
    if ax is None:
        ax = pylab.gca()
    ylim_bottom, ylim_top = ax.get_ylim()
    if y0 is None:
        y0 = ylim_top
    if y1 is None:
        y1 = ylim_bottom
    xlim_left, xlim_right = ax.get_xlim()
    wdt = xlim_right - xlim_left
    for x in (x0, x1):
        ax.add_patch(patches.PathPatch(patches.Path([(x, y0), (x, y1)]), linestyle="dashed"))
    # Put the label 1/8 of the way from y0 to y1. Use geometric interpolation
    # on log axes. Check the scale of *ax* itself, not of whatever axes happen
    # to be current.
    if ax.get_yscale() == 'linear':
        lbl_y = (y0*7 + y1) / 8
    else: # Some type of log scale
        lbl_y = (y0**7 * y1)**(1/8)
    ax.annotate(lbl, (x1 + wdt/200, lbl_y))

In [None]:
def error_approx(yB, yN, xN, alpha=0, dim=1, hexagon=False, xM=None):
    """Estimate the approximation error for the given extents.

    Builds a PSWF of order *alpha* and parameter 2*pi*yN*xN on int(yN)*2
    samples. From it, two terms are computed:
      - a gridding error: the magnitude of the alternating sum of the samples;
      - a correction error: the PSWF magnitude at yB, which the correction
        divides by.
    The function returns ``grid_error / (2*xM) / b_error``.

    dim: number of dimensions; the correction error is raised to this power.
    hexagon: for dim >= 2, evaluate the extra dimensions at yB/2 instead of yB.
    xM: extent used to normalise the gridding error. If omitted, it falls
        back to the module-level ``xM`` (the parameter search sets it before
        each call). Pass it explicitly to avoid depending on that global.

    Raises ValueError if yB >= yN, or if xM is neither passed nor defined globally.
    """
    if not yB < yN:
        raise ValueError("error_approx requires yB < yN (got yB=%g, yN=%g)" % (yB, yN))
    if xM is None:
        xM = globals().get('xM')
        if xM is None:
            raise ValueError("error_approx: xM not given and no global xM defined")
    # gridding error
    pswf = anti_aliasing_function(int(yN)*2, alpha, 2*numpy.pi*yN*xN)
    pswf /= numpy.prod(numpy.arange(2*alpha-1,0,-2, dtype=float)) # double factorial (2*alpha-1)!!
    grid_error = numpy.abs(numpy.sum(pswf[::2] - pswf[1::2]))
    # correction error
    b_error = numpy.abs(pswf[int(yN) + int(yB)])
    if dim >= 2 and hexagon:
        b_error *= numpy.abs(pswf[int(yN) + int(yB/2)])**(dim-1)
    else:
        b_error **= dim
    return grid_error / (2*xM) / b_error

## Parametrisation

Let us have another look at the parameters that go into our algorithm. How general can we be, and how can we optimise values to give us more speed and flexibility?

Firstly, note that we do not actually care too much about $x_A$ and $y_B$ here: They need to be bounds on $A_i$ and $B_j$ respectively, but they can always be higher. There is also no reason for all $A_i$ and $B_j$ to be the same pattern shifted in image/grid space: They can be entirely arbitrary.

This means that given the sampling rate $N$, we can simply round up the sizes used to represent $A$ and $B$, as well as the number of facets/subgrids needed to cover the space:

In [None]:
import math
N = 512
# Half-widths of the subgrid region (fraction of the grid) and of the facet
# region (image pixels); the full widths are twice these
yB = N / 8
xA = 1 / 8
# Round the full widths up to whole pixels
xA_size, yB_size = (int(math.ceil(width)) for width in (xA*2*N, yB*2))
print("image_size=%d, xA_size=%d, yB_size=%d" % (N, xA_size, yB_size))

### Choose "gridding" function extents

On the other hand, $x_N$ and $x_M$ are a bit more critical. The accuracy of our approximation depends on:

1. The $2\pi x_Ny_N$ product, which becomes the prolate spheroidal wave function parameter. Too low, and we get a bad approximation, too large, and we hit numeric instabilities
2. The $\frac{y_B}{y_N}$ margin, which decides how much we allow $b$ to magnify our error

This is further complicated by the fact that $x_M$, which must satisfy $x_M \ge x_A + 2x_N$, should also:

1. be as small as possible, as it directly scales the exchanged data size and multiple FFTs,
2. satisfy $\frac 1{2x_M} \in \mathbb N$ to allow us to down-sample cheaply, and
3. also satisfy $2x_MN \in \mathbb N$, as otherwise the set of allowable subgrid offsets will shrink significantly (see below)

The easiest way to explore the parameter space with this many restrictions is to settle on a given overhead ($x_My_N$ product) and enumerate the options:

In [None]:
overhead = 2.2
alpha = 0
max_par = 100
# Best candidate so far; start at +inf so any finite error estimate is accepted
err_best = numpy.inf
xM_step_best = xM_best = xN_best = yN_best = None
# Try xM = 1/(2*xM_step) for decreasing xM_step, i.e. increasing xM
for xM_step in numpy.arange(int(numpy.ceil(1 / 2 / xA)), 1, -1):
    # xM_step must divide N so that 2*xM*N = N/xM_step is integral
    if N % xM_step != 0:
        continue
    xM = 1 / 2 / xM_step
    xN = (xM - xA) / 2
    # Keep the xM*yN product at roughly overhead * xA * yB
    yN = numpy.floor(overhead * xA * yB / xM)
    par = 2 * numpy.pi * xN * yN
    if xN < 1 / N:
        continue
    print("xM = 1/%d, xN = %.4f, yN = %d, par = %.1f" % (xM_step*2, xN, yN, par), flush=True, end="")
    if yN <= yB or par > max_par:
        print(", par too high")
        break
    # NOTE: error_approx reads the global xM assigned above
    err = error_approx(yB, yN, xN, alpha=alpha)
    print(", err = %g" % (err))
    if err < err_best:
        err_best = err; xM_step_best = xM_step; xM_best = xM; xN_best = xN; yN_best = yN
if xM_step_best is None:
    raise ValueError("No usable xM found for N=%d, xA=%g, yB=%g (overhead=%g)"
                     % (N, xA, yB, overhead))
xM_step = xM_step_best; xM = xM_best; xN = xN_best; yN = yN_best
print("Chose xM=1/%d" % (xM_step*2))

# Derived sizes: subgrid window in pixels, yN extent sampled at the xM rate,
# and the full yN extent (a multiple of xM_step)
xM_size = int(2*xM*N)
xM_yN_size = int(numpy.ceil(xM*2*yN*2))
yN_size = xM_yN_size * xM_step
print("xM_size=%d, xM_yN_size=%d, yN_size=%d" % (xM_size, xM_yN_size, yN_size))
print("xM_step=%d" % xM_step)

print(yB, yN)
print("(Parameters for Sze-Tan: x0=%f, R=%d)" % (yB / yN / 2, int(numpy.floor(xN*2*yN))))

### Placing facets and subgrids

The next step is to decide where to place facets and subgrids. All facet/subgrid centre locations must align with the grids on both sides. This means that with subgrid centres $x_{0,i}$ and facet centres $y_{0,j}$ we want:

$$\forall i,j: \quad x_{0,i} y_{0,j} \in \mathbb Z$$

What we need to do here is to separate the sampling rate $N$ into two factors $N_xN_y = N$ with $N_x,N_y \in \mathbb N$. This allows us to define a "grid" of safe facet and subgrid centres:

$$\frac{Nx_{0,i}}{N_x} = N_yx_{0,i} \in \mathbb Z, \quad \frac{y_{0,j}}{N_y} \in \mathbb Z$$

as clearly the product of those two terms is in $\mathbb Z$ and equal to $x_{0,i}y_{0,j}$. An additional side condition here is that we want
$$N_yx_M \in \mathbb N $$
In other words, we only permit facet offset steps $N_y$ that lie on the coarser $x_M$ grid. This selects for parameters where the grid size reduction (equivalent to convolution with a comb function) becomes a simple selection operator in image space, which will permit an optimisation later.

What we therefore have to do is find subgrid/facet locations satisfying these conditions by rounding them to the next permissible centre for a number of $N_x$ and $N_y$ options. We have to keep in mind that for the purpose of this exercise we still want all facets and subgrids to cover the entire image / grid collectively. This might even require us to increase the number of facets/subgrids past the mathematical minimum, introducing further "rounding" overhead for unfriendly configurations. Note that for real applications we might actually have a lot more wiggle room here, as we could skip known-zero parts of the image/grid.

In [None]:
# Start from the minimum number of subgrids/facets needed to cover N pixels
# with blocks of xA_size (grid side) and yB_size (image side).
nsubgrid = int(math.ceil(N / xA_size))
nfacet = int(math.ceil(N / yB_size))

# Search state. The selection below minimises best_ny; best_maxdxdy only
# records the worst rounding shift of the chosen configuration.
best_maxdxdy = N; best_subgrid_off = None; best_facet_off = None; best_ny = N
while best_subgrid_off is None:

    # Overhead = (nsubgrid*xA_size/N) * (nfacet*yB_size/N) - 1, as a percentage
    print("Trying %d sub-grids, %d facets (%.1f%% overhead):" 
         % (nsubgrid, nfacet, 100 * (1 / (N / nsubgrid / xA_size * N / nfacet / yB_size) - 1)))
    # Ideal, evenly spaced centres: facets in image pixels, subgrids as grid fractions
    facet_y0s = numpy.arange(nfacet) * N / nfacet
    subgrid_x0s = numpy.arange(nsubgrid) / nsubgrid

    warnx_count = 0; warny_count = 0
    # Ny runs over multiples of xM_step; only divisors of N are usable
    for Ny in xM_step * numpy.arange(1, N // xM_step):
        if N % Ny != 0:
            continue
        Nx = N // Ny
        # Round subgrid centres (in pixels) to multiples of Nx, facet centres to multiples of Ny
        subgrid_off = Nx * numpy.round(N * subgrid_x0s / Nx)
        facet_off = Ny * numpy.round(facet_y0s / Ny)
        # Largest shift introduced by rounding
        maxdx = numpy.max(numpy.abs(subgrid_off - N * subgrid_x0s))
        maxdy = numpy.max(numpy.abs(facet_off - facet_y0s))
        warnx = warny = ""
        # Flag shifts larger than half the spare room (block size minus its share of N)
        if maxdx > (xA_size - N / nsubgrid) / 2:
            warnx = " (> %.1f!)" % ((xA_size - N / nsubgrid) / 2)
            warnx_count+=1
        if maxdy > (yB_size - N / nfacet) / 2:
            warny = " (> %.1f!)" % ((yB_size - N / nfacet) / 2)
            warny_count+=1
        print("Nx=%d, Ny=%d, maxdx=%.1f%s, maxdy=%.1f%s" % (Nx, Ny, maxdx, warnx, maxdy, warny))
        # Select the smallest warning-free Ny
        if warnx == "" and warny == "" and best_ny > Ny:
            # (disabled alternative criterion) best_maxdxdy > max(maxdx, maxdy):
            best_maxdxdy = max(maxdx, maxdy)
            best_subgrid_off = subgrid_off.astype(int)
            best_facet_off = facet_off.astype(int)
            best_ny = Ny
    # No solution found? Crudely use number of warnings as indicator what we need more of
    if best_subgrid_off is None:
        if warnx_count >= warny_count:
            nsubgrid += 1
        else:
            nfacet += 1
# Sanity check: a configuration was selected
assert best_maxdxdy != N
subgrid_off = best_subgrid_off
facet_off = best_facet_off
Nx = N // best_ny; Ny = best_ny;
print ("Chose Nx=%d, Ny=%d, N*x0s=%s, y0s=%s" % (N // Ny, Ny, subgrid_off, facet_off))
def whole(xs):
    """Return True when every entry of *xs* lies within 1e-13 of an integer."""
    distance_to_int = numpy.abs(xs - numpy.around(xs))
    return numpy.all(distance_to_int < 1e-13)
# subgrid_off is in pixels (N*x_0), so outer(...)/N = x_{0,i} y_{0,j} must be integral
assert whole(numpy.outer(subgrid_off, facet_off) / N)
# Facet offsets must also be whole when measured on the xM_size grid
assert whole(facet_off*xM_size/N)

In [None]:
# Subgrids: chosen centres (blue, as grid fractions wrapped to [-0.5, 0.5)),
# permitted centres (green, spacing Nx pixels), and range markers from 0 to
# each ideal centre
pylab.figure(figsize=(16, 1)); pylab.title("Subgrid Centre Offsets")
pylab.plot(((subgrid_off+N//2)%N-N//2)/N,numpy.zeros_like(subgrid_off), "b|", markersize=30);
pylab.plot(coordinates(N//Nx), numpy.zeros(N//Nx), "g|")
pylab.xlim(-.5,.5); pylab.yticks([])
for i, x in enumerate(subgrid_x0s):
    mark_range("$x_{0,%d}$"%i, 0, (x + .5) % 1 - .5)
# Facets: same in image pixels; permitted centres have spacing Ny
pylab.figure(figsize=(16, 1)); pylab.title("Facet Centre Offsets")
pylab.plot((facet_off+N//2)%N-N//2,numpy.zeros_like(facet_off), "b|", markersize=30);
pylab.plot(coordinates(N//Ny)*N, numpy.zeros(N//Ny), "g|")
pylab.xlim(-N/2,N/2); pylab.yticks([])
for i, y in enumerate(facet_y0s):
    mark_range("$y_{0,%d}$"%i, 0, (y + N//2) % N - N//2)
pylab.show();

### Sizing the intermediate grid

Next, we need to choose $y_P$. We are relatively free in doing so, we just need to make sure that

1. we satisfy $y_P \ge y_N + \frac 12 y_B$ and
2. $2x_My_P \in \mathbb N$ so we can down-sample easily.

However note that down the road, solving FFTs of size $y_PN$ and $x_My_P$ is going to represent most of our computation. Therefore we want to make sure that we do not choose a size here that is either much larger than we need - or gets factored into large prime factors.

In [None]:
def greatest_prime_factor(x):
    """Return the largest prime factor of a positive integer *x*.

    Primes map to themselves and 1 maps to 1; values below 2 are returned
    unchanged.
    """
    remaining = x
    candidate = 2
    while candidate * candidate <= remaining:
        if remaining % candidate:
            candidate += 1
        else:
            # Strip one factor and stay on this candidate in case it repeats
            remaining //= candidate
    return remaining

# Candidates: 12 consecutive multiples of Ny, starting at the smallest one
# >= int(yB + 2*yN). Being multiples of Ny keeps subgrid offsets whole at
# yP_size resolution (checked below). Pick the candidate whose largest prime
# factor is smallest (first one on ties), i.e. the most FFT-friendly size.
yP_size_options = (int(numpy.ceil( int(yB+yN*2) / Ny )) + numpy.arange(0,12)) * Ny 
yP_size_primes = numpy.vectorize(greatest_prime_factor)(yP_size_options)
print(", ".join(["%d: %d" % yp for yp in zip(yP_size_options, yP_size_primes)]))
yP_size = yP_size_options[numpy.argmin(yP_size_primes)]
print("Chose yP_size = %d (%.1f %% overhead)" % (yP_size, 100 * yP_size / (yB_size/2+yN_size) - 100))

assert whole(subgrid_off*yP_size/N)

# Sizes of the xM window, and of the xM window plus xN margins, at yP_size resolution
xM_yP_size = int(xM*2*yP_size)
xMxN_yP_size = xM_yP_size + 2*int(numpy.ceil(xN*yP_size)) # same margin both sides
print("yP_size=%d, xM_yP_size=%d, xMxN_yP_size=%d" % (yP_size, xM_yP_size, xMxN_yP_size))

In [None]:
# 1D test data: random real grid G in [-0.5, 0.5) and its image FG
G = numpy.random.rand(N)-0.5
FG = fft(G)

# Subgrids: xA_size windows of G centred on subgrid_off. The mask subgrid_A
# keeps only the part between the subgrid's two borders.
subgrid = numpy.empty((nsubgrid, xA_size), dtype=complex)
subgrid_A = numpy.zeros_like(subgrid, dtype=int)
# Midpoints between consecutive offsets; the last is between the final offset and N
subgrid_border = (subgrid_off + numpy.hstack([subgrid_off[1:],[N]])) // 2
print(subgrid_border)
for i in range(nsubgrid):
    # Borders relative to the window start; for i == 0, border[-1] is the last
    # border, and % N wraps it around
    left = (subgrid_border[i-1] - subgrid_off[i] + xA_size//2) % N
    right = subgrid_border[i] - subgrid_off[i] + xA_size//2
    assert left >= 0 and right <= xA_size, "xA not large enough to cover subgrids!"
    subgrid_A[i,left:right] = 1
    subgrid[i] = subgrid_A[i] * extract_mid(numpy.roll(G, -subgrid_off[i]), xA_size)

# Facets: same construction in image space, with yB_size windows of FG
facet = numpy.empty((nfacet, yB_size), dtype=complex)
facet_B = numpy.zeros_like(facet, dtype=bool)
# NOTE(review): facet_split is not used in the visible code
facet_split = numpy.array_split(range(N), nfacet)
facet_border = (facet_off + numpy.hstack([facet_off[1:],[N]])) // 2
print(facet_border)
for j in range(nfacet):
    left = (facet_border[j-1] - facet_off[j] + yB_size//2) % N
    right = facet_border[j] - facet_off[j] + yB_size//2
    assert left >= 0 and right <= yB_size, "yB not large enough to cover facets!"
    facet_B[j,left:right] = 1
    facet[j] = facet_B[j] * extract_mid(numpy.roll(FG, -facet_off[j]), yB_size)

We need a bunch of array constants derived from the gridding function:
 * $\mathcal Fb$ ($y_B$ size)
 * $\mathcal Fn$ ($y_N$ size, sampled at $x_M$ rate), as well as 
 * $\mathcal Fm' = \mathcal Fn\mathcal Fm$ term ($y_P$ size, sampled at $x_M+x_N$).
 
These are used for the convolution with $b$ and $n$, and for the cheap multiplication with $m$ at $y_P$ image size, respectively.

In [None]:
# PSWF of parameter 2*pi*yN*xN on yN_size samples (real part only)
pswf = anti_aliasing_function(yN_size, alpha, 2*numpy.pi*yN*xN).real
# Show the central part of its transform on a log scale
pylab.semilogy(extract_mid(numpy.abs(fft(pswf)), int(numpy.floor(xN*2*yN*2))), "s")
# Fb: reciprocal of the PSWF over the facet extent (the correction b)
Fb = 1/extract_mid(pswf, yB_size)
# Fn: PSWF sampled every 1/(2*xM) samples, phased so the centre sample is kept
Fn = pswf[(yN_size//2)%int(1/2/xM)::int(1/2/xM)]
# m' term: PSWF times a sinc (presumably the transform of the xM window), padded
# to yP_size, moved to grid space, cut to the xM+xN window and scaled
facet_m0_trunc = pswf * numpy.sinc(coordinates(yN_size)*xM_size/N*yN_size)
facet_m0_trunc = xM_size*yP_size/N * extract_mid(ifft(pad_mid(facet_m0_trunc, yP_size)), xMxN_yP_size).real

## Facet $\rightarrow$ Subgrid

With a few more slight optimisations we arrive at a compact representation for our algorithm:

In [None]:
print(N / 5, yP_size / 5, yN_size/ 5)
# Combined width of the two xN margins at yP_size resolution
xN_yP_size = xMxN_yP_size - xM_yP_size
# RNjMiBjFj[i,j]: contribution of facet j to subgrid i (xM_yN_size samples)
RNjMiBjFj = numpy.empty((nsubgrid, nfacet, xM_yN_size), dtype=complex)
for j in range(nfacet):
    # Apply correction b to facet j, pad to yP_size, go to grid space
    BjFj = ifft(pad_mid(facet[j] * Fb, yP_size))
    for i in range(nsubgrid):
        # Centre on subgrid i, cut the xM+xN window and multiply by m'
        MiBjFj = facet_m0_trunc * extract_mid(numpy.roll(BjFj, -subgrid_off[i]*yP_size//N), xMxN_yP_size)
        # Copy of the xM part (extract_mid may return a view)
        MiBjFj_sum = numpy.array(extract_mid(MiBjFj, xM_yP_size))
        # Fold each margin onto the opposite end of the xM window (wrap-around)
        MiBjFj_sum[:xN_yP_size//2] += MiBjFj[-xN_yP_size//2:]
        MiBjFj_sum[-xN_yP_size//2:] += MiBjFj[:xN_yP_size//2:]
        # Back to image space, keep xM_yN_size samples and apply n
        RNjMiBjFj[i,j] = Fn * extract_mid(fft(MiBjFj_sum), xM_yN_size)

# - redistribution of RNjMiBjFj here -

# Reassemble every subgrid from all facet contributions and compare with the
# reference: grid-space error (top) and image-space error (bottom)
fig = pylab.figure(figsize=(16, 8))
ax1, ax2 = fig.add_subplot(211), fig.add_subplot(212)
err_sum = err_sum_img = 0
for i in range(nsubgrid):
    approx = numpy.zeros(xM_size, dtype=complex)
    for j in range(nfacet):
        # Place facet j's contribution at its offset on the xM_size grid
        approx += numpy.roll(pad_mid(RNjMiBjFj[i,j], xM_size), facet_off[j]*xM_size//N)
    approx = subgrid_A[i] * extract_mid(ifft(approx), xA_size)
    
    ax1.semilogy(xA*2*coordinates(xA_size), numpy.abs( approx - subgrid[i] ))
    ax2.semilogy(N*coordinates(xA_size), numpy.abs( fft(approx - subgrid[i]) ))    
    # Squared errors summed over subgrids
    err_sum += numpy.abs(approx - subgrid[i])**2
    err_sum_img += numpy.abs(fft(approx - subgrid[i]))**2
mark_range("$x_A$", -xA, xA, ax=ax1); mark_range("$N/2$", -N/2, N/2, ax=ax2)
print("RMSE:", numpy.sqrt(numpy.mean(err_sum)), "(image:", numpy.sqrt(numpy.mean(err_sum_img)), ")")

## Subgrid $\rightarrow$ facet

The other direction works similarly, now we want:
$$F_j \approx b_j \ast \sum_i m_i (n_j \ast S_i)$$

We run into a very similar problem with $m$ as when reconstructing subgrids, except this time it happens because we want to construct:
$$ b_j \left( m_i (n_j \ast S_i)\right)
  = b_j \left( \mathcal F^{-1}\left[\Pi_{2y_P} \mathcal F m_i\right] (n_j \ast S_i)\right)$$

As usual, this is entirely dual: In the previous case we had a signal limited by $y_B$ and needed the result of the convolution up to $y_N$, whereas now we have a signal bounded by $y_N$, but need the convolution result up to $y_B$. This cancels out - therefore we are okay with the same choice of $y_P$.

In [None]:
# FNjSi[i,j]: subgrid i in image space, centred on facet j, cut to xM_yN_size
FNjSi = numpy.empty((nsubgrid, nfacet, xM_yN_size), dtype=complex)
for i in range(nsubgrid):
    FSi = fft(pad_mid(subgrid[i], xM_size))
    for j in range(nfacet):
        FNjSi[i,j] = extract_mid(numpy.roll(FSi, -facet_off[j]*xM_size//N), xM_yN_size)

# - redistribution of FNjSi here -

# Reassemble every facet from all subgrid contributions and compare with the
# reference: grid-space error (top) and image-space error (bottom)
fig = pylab.figure(figsize=(16, 8))
ax1, ax2 = fig.add_subplot(211), fig.add_subplot(212)
err_sum = err_sum_img = 0
for j in range(nfacet):
    approx = numpy.zeros(yB_size, dtype=complex)
    for i in range(nsubgrid):
        NjSi = numpy.zeros(xMxN_yP_size, dtype=complex)
        NjSi_mid = extract_mid(NjSi, xM_yP_size)
        NjSi_mid[:] = ifft(pad_mid(Fn * FNjSi[i,j], xM_yP_size)) # updates NjSi via reference (NjSi_mid is a view)!
        # Fill the xN margins with the opposite ends of the xM part (periodic extension)
        NjSi[-xN_yP_size//2:] = NjSi_mid[:xN_yP_size//2]
        NjSi[:xN_yP_size//2:] = NjSi_mid[-xN_yP_size//2:]
        # Multiply by m', pad to yP_size, shift to subgrid i's offset, back to image space
        FMiNjSi = fft(numpy.roll(pad_mid(facet_m0_trunc * NjSi, yP_size), subgrid_off[i]*yP_size//N))
        approx += extract_mid(FMiNjSi, yB_size)
    # Apply correction b and restrict to this facet's region
    approx *= Fb * facet_B[j]

    err_sum += numpy.abs(ifft(approx - facet[j]))**2
    err_sum_img += numpy.abs(approx - facet[j])**2
    ax1.semilogy(coordinates(yB_size), numpy.abs(ifft(facet[j] - approx)))
    ax2.semilogy(yB_size*coordinates(yB_size), numpy.abs(facet[j] - approx))
print("RMSE:", numpy.sqrt(numpy.mean(err_sum)), "(image:", numpy.sqrt(numpy.mean(err_sum_img)), ")")
mark_range("$x_A$", -xA, xA, ax=ax1)
mark_range("$x_M$", -xM, xM, ax=ax1)
mark_range("$y_B$", -yB, yB, ax=ax2)
mark_range("$0.5$", -.5, .5, ax=ax1)
pylab.show(fig)

## 2D case

All of this generalises to two dimensions in the way you would expect. Let us set up test data:

In [None]:
print(nsubgrid,"x",nsubgrid,"subgrids,",nfacet,"x", nfacet,"facets")
subgrid_2 = numpy.empty((nsubgrid, nsubgrid, xA_size, xA_size), dtype=complex)
facet_2 = numpy.empty((nfacet, nfacet, yB_size, yB_size), dtype=complex)

# Random complex grid: uniform phase, magnitude uniform in [0, 0.5)
G_2 = numpy.exp(2j*numpy.pi*numpy.random.rand(N,N))*numpy.random.rand(N,N)/2
# Subgrids/facets are the 1D construction applied to both axes (outer-product masks)
for i0,i1 in itertools.product(range(nsubgrid), range(nsubgrid)):
    subgrid_2[i0,i1] = extract_mid(numpy.roll(G_2, (-subgrid_off[i0], -subgrid_off[i1]), (0,1)), xA_size)
    subgrid_2[i0,i1] *= numpy.outer(subgrid_A[i0], subgrid_A[i1])
FG_2 = fft(G_2)
for j0,j1 in itertools.product(range(nfacet), range(nfacet)):
    facet_2[j0,j1] = extract_mid(numpy.roll(FG_2, (-facet_off[j0], -facet_off[j1]), (0,1)), yB_size)
    facet_2[j0,j1] *= numpy.outer(facet_B[j0], facet_B[j1])

Given that the amount of data has been squared, performance is a bit more of a concern now. Fortunately, the entire procedure is completely separable, so let us first re-define the operations to work on one array axis exclusively:

In [None]:
def slice_a(fill_val, axis_val, dims, axis):
    """Return a *dims*-tuple with *axis_val* at position *axis* and *fill_val* elsewhere.

    Used both as a multi-dimensional index and as a numpy.pad width spec.
    A tuple (not a list) is returned because NumPy >= 1.23 rejects a list of
    slices as a multi-dimensional index.
    """
    return tuple(axis_val if i == axis else fill_val for i in range(dims))

def pad_mid_a(a, N, axis):
    """Zero-pad *a* along *axis* to size N, keeping the centre element centred.

    Returns *a* itself (not a copy) if it already has size N along *axis*.
    """
    N0 = a.shape[axis]
    if N == N0: return a
    pad = slice_a((0,0), (N//2-N0//2, (N+1)//2-(N0+1)//2),
                  len(a.shape), axis)
    return numpy.pad(a, pad, mode='constant', constant_values=0.0)

def extract_mid_a(a, N, axis):
    """Return the centred N-sized window of *a* along *axis* (a view)."""
    assert N <= a.shape[axis]
    cx = a.shape[axis] // 2
    if N % 2 != 0:
        slc = slice(cx - N // 2, cx + N // 2 + 1)
    else:
        slc = slice(cx - N // 2, cx + N // 2)
    return a[slice_a(slice(None), slc, len(a.shape), axis)]

def fft_a(a, axis):
    """Centred forward FFT of *a* along a single *axis*."""
    return numpy.fft.fftshift(numpy.fft.fft(numpy.fft.ifftshift(a, axis),axis=axis),axis)

def ifft_a(a, axis):
    """Centred inverse FFT of *a* along a single *axis*."""
    return numpy.fft.fftshift(numpy.fft.ifft(numpy.fft.ifftshift(a, axis),axis=axis),axis)

def broadcast_a(a, dims, axis):
    """Reshape 1D *a* so it lies along *axis* of a *dims*-dimensional array.

    All other axes get length 1, so the result broadcasts along them.
    """
    return a[slice_a(numpy.newaxis, slice(None), dims, axis)]

This allows us to define the two fundamental operations - going from $F$ to $b\ast F$ and from $b\ast F$ to $n\ast m(b\ast F)$ separately:

In [None]:
def prepare_facet(facet, axis):
    """Apply the facet correction along *axis* and move that axis to grid space.

    Multiplies *facet* by Fb along *axis*, zero-pads that axis to yP_size and
    inverse-FFTs along it. Reads the globals ``Fb`` and ``yP_size``.
    """
    weighted = facet * broadcast_a(Fb, facet.ndim, axis)
    return ifft_a(pad_mid_a(weighted, yP_size, axis), axis)
def extract_subgrid(BF, i, axis):
    """Extract the contribution to subgrid *i* along *axis* from a prepared facet.

    BF: output of prepare_facet along *axis* (yP_size samples on that axis).
    i: subgrid index into the global ``subgrid_off``.
    Returns an array with xM_yN_size samples along *axis*, multiplied by Fn.
    """
    dims = len(BF.shape)
    # Centre on subgrid i and cut out the xM + xN window
    BF_mid = extract_mid_a(numpy.roll(BF, -subgrid_off[i]*yP_size//N, axis), xMxN_yP_size, axis)
    MBF = broadcast_a(facet_m0_trunc,dims,axis) * BF_mid
    MBF_sum = numpy.array(extract_mid_a(MBF, xM_yP_size, axis))
    xN_yP_size = xMxN_yP_size - xM_yP_size
    # Fold each xN margin onto the opposite end of the xM window:
    # [:xN_yP_size//2] / [-xN_yP_size//2:] for axis, [:] otherwise.
    # Explicit tuples: NumPy rejects lists of slices as multi-dimensional indices.
    slc1 = tuple(slice_a(slice(None), slice(xN_yP_size//2), dims, axis))
    slc2 = tuple(slice_a(slice(None), slice(-xN_yP_size//2,None), dims, axis))
    MBF_sum[slc1] += MBF[slc2]
    MBF_sum[slc2] += MBF[slc1]
    return broadcast_a(Fn, dims, axis) * \
           extract_mid_a(fft_a(MBF_sum, axis), xM_yN_size, axis)

Having those operations separately means that we can shuffle things around quite a bit without affecting the result. The obvious first choice might be to do all facet-preparation up-front, as this allows us to share the computation across all subgrids:

In [None]:
# Variant 1: prepare each facet along both axes, then extract all subgrids.
# Array names read <axis-0 state>_<axis-1 state>: F = facet, BF = prepared,
# NMBF = subgrid extracted.
t = time.time()
# NMBF_NMBF[i0,i1,j0,j1]: contribution of facet (j0,j1) to subgrid (i0,i1)
NMBF_NMBF = numpy.empty((nsubgrid, nsubgrid, nfacet, nfacet, xM_yN_size, xM_yN_size), dtype=complex)
for j0,j1 in itertools.product(range(nfacet), range(nfacet)):
    BF_F = prepare_facet(facet_2[j0,j1], 0)
    BF_BF = prepare_facet(BF_F, 1)
    for i0 in range(nsubgrid):
        NMBF_BF = extract_subgrid(BF_BF, i0, 0)
        for i1 in range(nsubgrid):
            NMBF_NMBF[i0,i1,j0,j1] = extract_subgrid(NMBF_BF, i1, 1)
print(time.time() - t, "s")

However, remember that `prepare_facet` increases the amount of data involved, which in turn means that we need to shuffle more data through subsequent computations.

Therefore it is actually more efficient to first do the subgrid-specific reduction, and *then* continue with the (constant) facet preparation along the other axis. We can tackle both axes in whatever order we like, it doesn't make a difference for the result:

In [None]:
# Variant 2: reduce axis 0 per subgrid first, then prepare axis 1 on the
# (smaller) reduced data
t = time.time()
for j0,j1 in itertools.product(range(nfacet), range(nfacet)):
    BF_F = prepare_facet(facet_2[j0,j1], 0)
    for i0 in range(nsubgrid):
        NMBF_F = extract_subgrid(BF_F, i0, 0)
        NMBF_BF = prepare_facet(NMBF_F, 1)
        for i1 in range(nsubgrid):
            NMBF_NMBF[i0,i1,j0,j1] = extract_subgrid(NMBF_BF, i1, 1)
print(time.time() - t, "s")

In [None]:
# Variant 3: same as variant 2 with the axis order swapped (axis 1 first)
t = time.time()
for j0,j1 in itertools.product(range(nfacet), range(nfacet)):
    F_BF = prepare_facet(facet_2[j0,j1], 1)
    for i1 in range(nsubgrid):
        F_NMBF = extract_subgrid(F_BF, i1, 1)
        BF_NMBF = prepare_facet(F_NMBF, 0)
        for i0 in range(nsubgrid):
            NMBF_NMBF[i0,i1,j0,j1] = extract_subgrid(BF_NMBF, i0, 0)
print(time.time() - t, "s")

In [None]:
pylab.rcParams['figure.figsize'] = 16, 8
# Reassemble every 2D subgrid and accumulate the mean (over subgrids) squared
# error per pixel, in grid space and in image space
err_sum = err_sum_img = 0
for i0,i1 in itertools.product(range(nsubgrid), range(nsubgrid)):
    approx = numpy.zeros((xM_size, xM_size), dtype=complex)
    for j0,j1 in itertools.product(range(nfacet), range(nfacet)):
        approx += numpy.roll(pad_mid(NMBF_NMBF[i0,i1,j0,j1], xM_size),
                             (facet_off[j0]*xM_size//N, facet_off[j1]*xM_size//N), (0,1))
    approx = extract_mid(ifft(approx), xA_size)
    approx *= numpy.outer(subgrid_A[i0], subgrid_A[i1])
    err_sum += numpy.abs(approx - subgrid_2[i0,i1])**2 / nsubgrid**2
    err_sum_img += numpy.abs(fft(approx - subgrid_2[i0,i1]))**2 / nsubgrid**2
# log10 of the RMS error maps
pylab.imshow(numpy.log(numpy.sqrt(err_sum)) / numpy.log(10)); pylab.colorbar(); pylab.show()
pylab.imshow(numpy.log(numpy.sqrt(err_sum_img)) / numpy.log(10)); pylab.colorbar(); pylab.show()
print("RMSE:", numpy.sqrt(numpy.mean(err_sum)), "(image:", numpy.sqrt(numpy.mean(err_sum_img)), ")")

## Degridding

To use this for radio astronomy, our goal in this context is to (de)grid visibilities from subgrids. This uses very similar machinery - in fact, what we described so far can simply be re-expressed as gridding or degridding all points of a sub-grid using facets. The difference is that our method is a lot faster and requires less data movement.

However, this similarity does not actually buy us much: While for the recombination we consider the fields of view of facets, for gridding visibilities we are interested in the "global" field of view. Therefore we need a different grid correction and gridder that gets applied before and after we have done the combination, respectively.

The full size of the considered image is fixed to $N$, therefore our effective image size is $2x_0N$:

In [None]:
pylab.rcParams['figure.figsize'] = 16, 4

# Gridding kernel parameters: support gc_support = 2*xGp*N pixels; the image
# is kept over the central fraction 2*gc_x0 (x0_size pixels)
gc_alpha = 0; xGp = 5/N; gc_x0 = 0.35
gc_support = int(2*xGp*N)
print("parameter:", numpy.pi*gc_support/2, "x0:", gc_x0)
x0_size = int(N*gc_x0*2)
gc_pswf = anti_aliasing_function(N, gc_alpha, numpy.pi*gc_support/2)
# Grid correction: 1/PSWF over the central x0_size pixels, zero outside
gc = pad_mid(extract_mid(1 / gc_pswf, x0_size), N)
pylab.semilogy(x0_size*coordinates(x0_size), numpy.abs(extract_mid(gc, x0_size))); pylab.legend(["F[n]"]);
pylab.xlim((-N/1.8, N/1.8))
mark_range("$x_0N$", -gc_x0*N,gc_x0*N);
mark_range("$N/2$", -N/2,N/2); pylab.title("Grid correction"); pylab.show();

From this we derive the new $\mathcal F G$ that we are going to feed to the recombination algorithm:

In [None]:
# Apply the 2D grid correction in image space
FG_2_gc = FG_2 * numpy.outer(gc, gc)
# NOTE(review): this plots FG_2_gc, although the title says "FG_2_cropped"
show_image(numpy.log(numpy.maximum(1e-15, numpy.abs(FG_2_gc))) / numpy.log(10), "FG_2_cropped", N)
# Grid-corrected grid, fed to the recombination below
G_2_gc = ifft(FG_2_gc)
# Reference grid: the image cropped to the x0 region without correction
crop = pad_mid(numpy.ones(x0_size), N)
G_2_cropped = ifft(FG_2 * numpy.outer(crop,crop))

Which in turn leads to new facets. Note how the grid correction pattern is clearly larger than any individual facet.

The other thing to notice here is that due to the grid correction margin a significant portion of the image is now zero, which translates to entire facets being zero. Due to the linearity of the method this means we could simply skip those. For the purpose of this notebook we do not use this, but it is a good optimisation to keep in mind.

In [None]:
# Rebuild subgrids and facets from the grid-corrected data (same construction as before)
subgrid_2 = numpy.empty((nsubgrid, nsubgrid, xA_size, xA_size), dtype=complex)
facet_2 = numpy.empty((nfacet, nfacet, yB_size, yB_size), dtype=complex)
for i0,i1 in itertools.product(range(nsubgrid), range(nsubgrid)):
    subgrid_2[i0,i1] = extract_mid(numpy.roll(G_2_gc, (-subgrid_off[i0], -subgrid_off[i1]), (0,1)), xA_size)
    subgrid_2[i0,i1] *= numpy.outer(subgrid_A[i0], subgrid_A[i1])
fig = pylab.figure(figsize=(32,32))
for j0,j1 in itertools.product(range(nfacet), range(nfacet)):
    facet_2[j0,j1] = extract_mid(numpy.roll(FG_2_gc, (-facet_off[j0], -facet_off[j1]), (0,1)), yB_size)
    facet_2[j0,j1] *= numpy.outer(facet_B[j0], facet_B[j1])
    # Subplot grid: column j1, row counted from the bottom by j0
    show_image(numpy.log(numpy.maximum(1e-15, numpy.abs(facet_2[j0,j1]))) / numpy.log(10),
               "facet_%d%d" % (j0,j1), N, axes=fig.add_subplot(nfacet,nfacet,j1+(nfacet-j0-1)*nfacet+1),
              norm=(-15,8))
pylab.show(fig)

The recombination algorithm again, using the new data.

In [None]:
# Facet -> subgrid recombination of the grid-corrected facets (axis 1 first)
NMBF_NMBF = numpy.empty((nsubgrid, nsubgrid, nfacet, nfacet, xM_yN_size, xM_yN_size), dtype=complex)
for j0,j1 in itertools.product(range(nfacet), range(nfacet)):
    F_BF = prepare_facet(facet_2[j0,j1], 1)
    for i1 in range(nsubgrid):
        F_NMBF = extract_subgrid(F_BF, i1, 1)
        BF_NMBF = prepare_facet(F_NMBF, 0)
        for i0 in range(nsubgrid):
            NMBF_NMBF[i0,i1,j0,j1] = extract_subgrid(BF_NMBF, i0, 0)

import functools

# Cache recent recombinations (the interactive test below requests the same
# subgrids repeatedly). functools.lru_cache replaces the third-party pylru.
@functools.lru_cache(maxsize=100)
def make_approx_subgrid(i0,i1):
    """Recombine subgrid (i0, i1) from all facet contributions in NMBF_NMBF.

    Returns ``(approx, rel_rmse)``:
      - approx: the full xM_size x xM_size recombined subgrid in grid space;
      - rel_rmse: RMS error of its masked xA_size centre against
        subgrid_2[i0, i1], divided by the mean magnitude of that centre.
    Results are cached, so the returned array is shared between calls and
    must not be modified in place.
    """
    approx = numpy.zeros((xM_size, xM_size), dtype=complex)
    for j0,j1 in itertools.product(range(nfacet), range(nfacet)):
        approx += numpy.roll(pad_mid(NMBF_NMBF[i0,i1,j0,j1], xM_size),
                             (facet_off[j0]*xM_size//N, facet_off[j1]*xM_size//N), (0,1))
    approx_grid = ifft(approx)
    # Extract region that is set in subgrid for comparison. The multiplication
    # creates a new array, so approx_grid itself stays unmasked.
    approx_compare = extract_mid(approx_grid, xA_size) * numpy.outer(subgrid_A[i0], subgrid_A[i1])
    # RMSE of this one subgrid. Dividing by nsubgrid**2 only belongs in the
    # average over all subgrids, so it is not applied here.
    rmse = numpy.sqrt(numpy.mean(numpy.abs(approx_compare - subgrid_2[i0,i1])**2))
    # Return full approximation. We degrid from it, so bounds don't matter
    return approx_grid, rmse / numpy.mean(numpy.abs(approx_compare))

In order to obtain visibilities at non-integer positions we need an oversampled gridding function, as usual:

In [None]:
# Oversampled gridding kernel derived from the grid-correction PSWF
oversample = 2**14
print("grid support:", gc_support)
print("oversampling:", oversample)
# kernel[k]: gc_support kernel samples for the k-th sub-pixel offset
kernel = kernel_oversample(gc_pswf, oversample, gc_support).real
# Normalise so that the first sub-pixel kernel sums to 1
kernel /= numpy.sum(kernel[0])
# Pixel positions of the interleaved oversampled kernel
r = numpy.arange(-oversample*(gc_support//2), oversample*((gc_support+1)//2)) / oversample
# Transpose + flatten interleaves the sub-pixel kernels into one fine curve
pylab.semilogy(r, numpy.transpose(kernel).flatten().real); mark_range("$Nx_G$", -N*xGp,N*xGp);
pylab.title("Gridding kernel (oversampled x%d)" % oversample); pylab.show();

In [None]:
@interact(iu=(0, N, 0.01),iv=(0, N, 0.01))
def test_degrid_accuracy(iu,iv, show_subgrid=False):
    """Compare degridding from the full grid with degridding from a recombined subgrid.

    iu, iv: position on the N x N grid in pixels (sliders from 0 to N).
    show_subgrid: if True, also plot the recombined subgrid (log10 of its
        magnitude) with the gridding kernel footprint marked.

    Prints the directly degridded value, the value degridded from the
    recombined subgrid, and their difference. At integer positions, it also
    compares both against G_2_cropped.
    """
    u = (iu - N//2) / N; v = (iv - N//2) / N
    # Index of the subgrid whose border range contains the position, per axis
    su = numpy.sum((iu+N//2)%N >= subgrid_border) % nsubgrid
    sv = numpy.sum((iv+N//2)%N >= subgrid_border) % nsubgrid

    # Reference: degrid directly from the full grid-corrected grid
    deg = conv_predict(N, 1, numpy.array([(u,v,0)]), None, G_2_gc, kernel)[0]
    if whole(iu) and whole(iv):
        actual = G_2_cropped[int(iv),int(iu)]
        print("actual:       ", actual)
        print("degridded:    ", deg)
        print("degrid error: ", numpy.abs(deg-actual))
    else:
        print("degridded:    ", deg)

    # Rows are v and columns are u, so the subgrid index is (sv, su)
    approx_subgrid, rmse = make_approx_subgrid(sv, su)
    print("subgrid:       (%d/%d), rmse: %g" % (su, sv, rmse))

    # Subgrid centres as grid fractions, wrapped to [-0.5, 0.5)
    sou = (((subgrid_off[su] + N//2) % N) - N//2) / N
    sov = (((subgrid_off[sv] + N//2) % N) - N//2) / N
    # Degrid from the recombined subgrid, relative to its centre
    # (theta = 2*xM, presumably the subgrid's extent; confirm against conv_predict)
    deg_ap = conv_predict(N, 2*xM, numpy.array([(u-sou,v-sov,0)]), None, approx_subgrid, kernel)[0]
    print("recomb+degrid:", deg_ap);
    print("recomb error: ", numpy.abs(deg_ap-deg))
    if whole(iu) and whole(iv):
        print("total error:  ", numpy.abs(deg_ap-actual))

    if show_subgrid:
        fig = pylab.figure()
        ax = fig.add_subplot(111)
        # log10 of the magnitude, clipped at 1e-15 like the other image plots.
        # Taking log of the complex values directly would mix the phase into the plot.
        show_grid(numpy.log(numpy.maximum(1e-15, numpy.abs(approx_subgrid))) / numpy.log(10),
                  "subgrid_%d%d" % (su,sv), N, axes=ax)
        # Footprint of the gridding kernel around the degridded position
        ax.add_patch(patches.Rectangle((u-sou-gc_support//2/N, v-sov-gc_support//2/N),
                                       gc_support/N, gc_support/N, fill=False))
        pylab.show(fig)

# Write out results

In [None]:
import os.path
import h5py

out_prefix = "../../data/grid/T05_"
# open() does not create missing directories, so make sure the target exists
os.makedirs(os.path.dirname(out_prefix), exist_ok=True)

# The .in files are raw binary dumps (ndarray.tofile), so they are opened in
# binary mode. ifftshift moves each array's centre element to index 0.
with open(out_prefix + "pswf.in", "wb") as f:
    numpy.fft.ifftshift(pswf).tofile(f)

# Per facet: the facet data, plus its contribution to every subgrid
for j0,j1 in itertools.product(range(nfacet), range(nfacet)):

    with open(out_prefix + "facet%d%d.in" % (j0,j1), "wb") as f:
        numpy.fft.ifftshift(facet_2[j0,j1]).tofile(f)
    for i0,i1 in itertools.product(range(nsubgrid), range(nsubgrid)):
        with open(out_prefix + "nmbf%d%d%d%d.in" % (i0,i1,j0,j1), "wb") as f:
            numpy.fft.ifftshift(NMBF_NMBF[i0,i1,j0,j1]).tofile(f)

# Recombined subgrids (full xM_size extent)
for i0,i1 in itertools.product(range(nsubgrid), range(nsubgrid)):
    with open(out_prefix + "approx%d%d.in" % (i0,i1), "wb") as f:
        numpy.fft.ifftshift(make_approx_subgrid(i0, i1)[0]).tofile(f)

# Oversampled gridding kernel
with h5py.File(out_prefix + "kernel.h5",'w') as f:
    f['sepkern/kern'] = kernel