# Illustration of w-stacking

In [None]:
%matplotlib inline

import sys
sys.path.append('../..')

from matplotlib import pylab

pylab.rcParams['figure.figsize'] = 16, 10

import functools
import numpy
import scipy
import scipy.special
import time

from crocodile.clean import *
from crocodile.synthesis import *
from crocodile.simulate import *
from util.visualize import *
from arl.test_support import create_named_configuration

Generate baseline coordinates for an observation with the VLA (A configuration) over 6 hours, with a visibility recorded every 10 minutes. The phase center is fixed at a declination of 45 degrees. We assume that the imaged sky stays at that position over the course of the observation.

Note how this gives rise to fairly large $w$-values.

In [None]:
# Antenna layout of the VLA in its A configuration.
vlas = create_named_configuration('VLAA')
# Hour angles from 0 to 90 degrees in 2.5 degree steps: 6 hours of
# observation with one sample every 10 minutes (15 degrees per hour).
ha_range = numpy.arange(numpy.radians(0),
                        numpy.radians(90),
                        numpy.radians(90 / 36))
dec = numpy.radians(45)
vobs = xyz_to_baselines(vlas.data['xyz'], ha_range, dec)

# Wavelength: 5 metres. Dividing converts baseline coordinates into wavelengths.
wvl = 5
uvw = vobs / wvl

from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
import matplotlib.pyplot as plt
ax = plt.figure().add_subplot(121, projection='3d')
ax.scatter(uvw[:,0], uvw[:,1], uvw[:,2])
# The axis limits are symmetric, so they must cover the largest *absolute*
# coordinate, not just the largest positive one.
max_uvw = numpy.amax(numpy.abs(uvw))
# Raw strings: '\l' is not a valid Python escape sequence (SyntaxWarning on 3.12+).
ax.set_xlabel(r'U [$\lambda$]'); ax.set_xlim((-max_uvw, max_uvw))
ax.set_ylabel(r'V [$\lambda$]'); ax.set_ylim((-max_uvw, max_uvw))
ax.set_zlabel(r'W [$\lambda$]'); ax.set_zlim((-max_uvw, max_uvw))
ax.view_init(20, 20)
pylab.show()

## We can now generate visibilities for these baselines by simulation. We place a 7×7 grid of 49 equal-amplitude point sources, spaced 0.01 apart in $l$ and $m$.

In [None]:
import itertools
import matplotlib.pyplot as plt

# Sum of equal-amplitude point sources on a 7x7 grid with 0.01 spacing.
# The two offsets passed to simulate_point are presumably the source's
# direction cosines (l, m) -- TODO confirm against crocodile.simulate.
vis = numpy.zeros(len(uvw), dtype=complex)
for i, j in itertools.product(range(-3, 4), range(-3, 4)):
    vis += 1.0*simulate_point(uvw, 0.010*i, 0.010*j)

# Visibility amplitude against projected (u, v) baseline length.
plt.clf()
uvdist = numpy.sqrt(uvw[:,0]**2 + uvw[:,1]**2)
plt.plot(uvdist, numpy.abs(vis), '.', color='r')
plt.xlabel(r'uv distance [$\lambda$]')
plt.ylabel('|V|')
plt.title('Simulated visibility amplitudes');

Using imaging, we can now reconstruct the image. We split the visibilities into a number of w-bins:

In [None]:
# Imaging parameterisation.
# Image field of view in direction-cosine units (2 x 0.05).
theta = 2*0.05
# uv-grid extent in wavelengths; the grid has theta*lam cells per side.
lam = 18000
# Spacing between neighbouring w-planes, in wavelengths.
wstep = 100
# Convolution support in grid cells, used as padding around sub-grids.
npixkern = 31
# Number of grid cells along each side of the full grid.
grid_size = int(numpy.ceil(lam * theta))

In [None]:
# Determine weights (globally), once, for all visibilities.
wt = doweight(theta, lam, uvw, numpy.ones(uvw.shape[0]))

# Depending on algorithm we are going to prefer different uvw-distributions,
# so make decision about conjugation of visibilities flexible.
def flip_conj(where):
    """Swap the selected visibilities for their conjugate counterparts.

    The selected rows of the global ``uvw`` array are negated and the matching
    entries of the global ``vis`` array are complex-conjugated, both in place.
    Conjugating a visibility together with negating its coordinate does not
    change its meaning.

    :param where: boolean mask (or index) selecting the visibilities to flip
    :returns: tuple ``(uvw, vis, wplanes, wplane)`` -- the (mutated) global
        arrays; every w-plane index from the smallest to the largest occupied
        one (some may be empty); and the per-visibility w-plane index, i.e.
        ``w / wstep`` rounded to the nearest integer.
    """
    uvw[where] = -uvw[where]
    vis[where] = vis[where].conj()

    # Assign every visibility to its nearest w-plane.
    plane_of_vis = numpy.around(uvw[:, 2] / wstep).astype(int)
    lowest, highest = plane_of_vis.min(), plane_of_vis.max()
    return uvw, vis, numpy.arange(lowest, highest + 1), plane_of_vis

## Simple w-stacking

Now we can image each w-plane separately, and divide the w-term out in the image plane. This method requires us to do a lot of FFTs:

In [None]:
image_sum = numpy.zeros((grid_size, grid_size), dtype=complex)
# Per-plane corrected images, kept for inspection (one full grid per plane).
w_grids = {}
uvw,vis,wplanes,wplane = flip_conj(uvw[:,2] < 0.0)
start_time = time.time()

# Image-plane coordinates are identical for every w-plane: compute them once.
l,m = theta*coordinates2(grid_size)

for wp in wplanes:

    # Filter out w-plane
    in_plane = wplane == wp
    puvw = uvw[in_plane]
    if len(puvw) == 0: continue
    pvis = vis[in_plane]
    pwt = wt[in_plane]
    midw = numpy.mean(puvw[:,2])
    print("w-plane %d: %d visibilities, %.1f average w" % (wp, len(puvw), midw))

    # The w-coordinates are not shifted to the plane centre here: simple
    # imaging grids full-size planes and the w-term is corrected below
    # using the plane's mean w.

    # Zero-column source table with one row per visibility of this plane.
    src = numpy.zeros((len(pvis), 0))

    # Make image
    cdrt = simple_imaging(theta, lam, puvw, src, pvis * pwt)

    # Divide out the Fresnel pattern for this plane in image space, then add
    wkern = w_kernel_function(l, m, midw)
    w_grids[wp] = ifft(cdrt) / wkern
    image_sum += w_grids[wp]

print("Done in %.1fs" % (time.time() - start_time))

# We only used half of the visibilities, so the image is not going to
# end up real-valued. However, we can easily just remove the unused imaginary
# parts and multiply by 2 to arrive at the correct result.
show_image(2.0*numpy.real(image_sum), "image", theta)

This was the easiest version of w-stacking. Clearly a lot of w-planes are mostly empty, which is wasteful both in terms of FFT complexity and especially in terms of memory (bandwidth).

## Optimised w-planes

We can actually reduce the size of these w-planes: use a sub-grid that is just large enough to contain the plane's visibilities plus the convolution/w-pattern support. We then FFT it back into grid space and add it into the w=0 grid at the corresponding offset. This costs two FFTs per w-plane instead of one, but it is worth it if the sub-grids are small enough.

In [None]:
start_time = time.time()
uvw,vis,wplanes,wplane = flip_conj(uvw[:,1] < 0.0)
grid_sum = numpy.zeros((grid_size, grid_size), dtype=complex)

# Index of the uv origin within grid_sum. Derived from grid_size (the array
# we actually index) rather than recomputing it from lam*theta.
mid = grid_size // 2

for wp in wplanes:

    # Filter out w-plane
    in_plane = wplane == wp
    puvw = uvw[in_plane]
    if len(puvw) == 0: continue
    pvis = vis[in_plane]
    pwt = wt[in_plane]
    midw = numpy.mean(puvw[:,2])

    # Zero-column source table with one row per visibility of this plane.
    # Built here so this cell neither depends on, nor mis-sizes, a `src`
    # left over from another cell.
    src = numpy.zeros((len(pvis), 0))

    # w=0 plane? Just grid directly, skipping the Fresnel correction and both
    # FFTs (the residual w-term for |w| < wstep/2 is neglected).
    if abs(midw) < wstep / 2:
        grid_sum += simple_imaging(theta, lam, puvw, src, pvis * pwt)
        continue

    # Determine uv bounds, round to grid cell
    xy_min = numpy.floor(numpy.amin(puvw[:,:2], axis=0) * theta).astype(int)
    xy_max = numpy.ceil(numpy.amax(puvw[:,:2], axis=0) * theta).astype(int)

    # Make sure we have enough space for convolution.
    xy_min -= (npixkern + 1) // 2
    xy_max += npixkern // 2
    xy_size = numpy.max(xy_max - xy_min)
    print("w-plane %d: %d visibilities, %.1f average w, %dx%d cells" %
          (wp, len(puvw), midw, xy_size, xy_size))

    # Force quadratic - TODO: unneeded, strictly speaking
    xy_maxq = numpy.amax([xy_max, xy_min + xy_size], axis=0)

    # Determine the uvw size and mid-point
    uvw_size = xy_size / theta
    uvw_mid = numpy.hstack([(xy_maxq + xy_min) // 2 / theta, midw])

    # Grid relative to the sub-grid mid-point
    pgrid = simple_imaging(theta, uvw_size, puvw - uvw_mid, src, pvis * pwt)

    # Generate Fresnel pattern
    l,m = theta*coordinates2(xy_size)
    wkern = w_kernel_function(l, m, midw)

    # Divide Fresnel pattern in image plane, then FFT right back
    pgrid_w = fft(ifft(pgrid) / wkern)

    # Add to original grid at offset. Clip the sub-grid against the grid
    # edges: unclipped negative offsets would wrap around, and offsets past
    # the end would produce mismatched slice shapes.
    x0, y0 = mid + xy_min
    x1, y1 = mid + xy_max
    x0b, y0b = numpy.maximum([x0, y0], 0)
    x1b, y1b = numpy.minimum([x1, y1], grid_size)
    if x1b > x0b and y1b > y0b:
        grid_sum[y0b:y1b, x0b:x1b] += pgrid_w[y0b-y0:y1b-y0, x0b-x0:x1b-x0]

image_sum = ifft(grid_sum)
print("Done in %.1fs" % (time.time() - start_time))
show_image(2.0*numpy.real(image_sum), "image", theta)

As you might notice, this is actually slower overall, because for lower w doing two FFTs per w-plane adds quite a bit of extra work.

## Choosing uv-bins with w-stacking

However, it should now be clear that we can choose what parts of the w-planes to generate entirely independently, so we can especially choose to generate the same uv-chunks on all w-planes. This not only allows us to share the FFT back to the w=0 grid, but also makes the FFT cheaper once we are considering large grids.

In [None]:
uvbin_size = 256 - npixkern # Choose it so we get a nice 2^x size below
# Side length (in cells) of each uv-bin sub-grid including convolution padding.
bin_grid_size = uvbin_size + npixkern
start_time = time.time()
uvw,vis,wplanes,wplane = flip_conj(uvw[:,1] < 0.0)
grid_sum = numpy.zeros((grid_size, grid_size), dtype=complex)
ubin = numpy.floor(uvw[:,0]*theta/uvbin_size).astype(int)
vbin = numpy.floor(uvw[:,1]*theta/uvbin_size).astype(int)

# Generate Fresnel pattern for shifting between two w-planes
# As this is the same between all w-planes, we can share it
# between the whole loop.
l,m = theta*coordinates2(bin_grid_size)
wkern = w_kernel_function(l, m, wstep)

# Index of the uv origin within grid_sum.
mid = grid_size // 2

for ub in range(numpy.min(ubin), numpy.max(ubin)+1):
    for vb in range(numpy.min(vbin), numpy.max(vbin)+1):

        # Find visibilities
        bin_sel = numpy.logical_and(ubin == ub, vbin == vb)
        if not numpy.any(bin_sel):
            continue

        # Determine bin dimensions (grid cells) and the bin's uv mid-point
        xy_min = uvbin_size * numpy.array([ub, vb], dtype=int)
        xy_max = uvbin_size * numpy.array([ub+1, vb+1], dtype=int)
        uv_mid = (xy_max + xy_min) // 2 / theta

        # Make sure we have enough space for convolution.
        xy_min -= (npixkern + 1) // 2
        xy_max += npixkern // 2
        assert numpy.all(xy_max - xy_min == bin_grid_size)
        uvw_size = bin_grid_size / theta

        # Make grid for uv-bin. bin_image_sum is kept "referenced" to plane
        # last_wp: it holds sum_p ifft(pgrid_p) * wkern**(last_wp - p).
        # None means nothing has been accumulated yet.
        bin_image_sum = numpy.zeros((bin_grid_size, bin_grid_size), dtype=complex)
        nvis = 0; midws = []
        last_wp = None

        # Visit only the w-planes that hold visibilities of this bin, in
        # ascending order (numpy.unique returns sorted values).
        for wp in numpy.unique(wplane[bin_sel]):

            # Filter out visibilities for u/v-bin and w-plane
            slc = numpy.logical_and(bin_sel, wplane == wp)
            puvw = uvw[slc]
            pvis = vis[slc]
            pwt = wt[slc]
            # Zero-column source table, one row per visibility.
            src = numpy.zeros((len(pvis), 0))

            # Statistics
            nvis += len(puvw)
            midws.append(wp*wstep)

            # w=0 plane? Just grid directly, as before
            if wp == 0:
                grid_sum += simple_imaging(theta, lam, puvw, src, pvis * pwt)
                continue

            # Bring image sum into this w-plane. This must happen for every
            # plane after the first accumulated one -- including when that
            # first plane is wplanes[0]. Otherwise its contribution ends up
            # with the wrong w-correction.
            if last_wp is not None:
                bin_image_sum *= wkern**(wp-last_wp)
            last_wp = wp

            # Grid relative to mid-point
            uvw_mid = numpy.hstack([uv_mid, [wp*wstep]])
            pgrid = simple_imaging(theta, uvw_size, puvw - uvw_mid, src, pvis * pwt)

            # Add to bin grid
            bin_image_sum += ifft(pgrid)

        print("uv-bin %d,%d: %d visibilities, %s w-bins" % (ub, vb, nvis, numpy.array(midws, dtype=int)))

        # Only w=0 visibilities (already gridded directly)? Nothing to transfer.
        if last_wp is None:
            continue

        # Transfer into w=0 plane, FFT image sum
        bin_image_sum /= wkern**last_wp
        bin_grid = fft(bin_image_sum)

        # Add to grid, clipping the sub-grid against the grid edges
        x0, y0 = mid + xy_min
        x1, y1 = mid + xy_max
        x0b, y0b = numpy.maximum([x0, y0], 0)
        x1b, y1b = numpy.minimum([x1, y1], grid_size)
        if x1b > x0b and y1b > y0b:
            grid_sum[y0b:y1b, x0b:x1b] += \
               bin_grid[y0b-y0:y1b-y0, x0b-x0:x1b-x0]

image_sum = ifft(grid_sum)
print("Done in %.1fs" % (time.time() - start_time))
show_image(2.0 * numpy.real(image_sum), "image", theta)

By zooming in we can confirm output quality:

In [None]:
# Factor 2 compensates for imaging only half of the (conjugate-symmetric)
# visibilities, matching the images shown above.
image_show = 2.0 * numpy.real(image_sum)
step = int(grid_size/10)

def zoom(x, y=step):
    """Display a (2*step) x (2*step) window of image_show starting at column x, row y."""
    pylab.matshow(image_show[y:y+2*step, x:x+2*step])
    pylab.colorbar(shrink=.4, pad=0.025)
    pylab.show()

from ipywidgets import interact
# x is a column offset (axis 1) and y a row offset (axis 0), so each slider's
# range must come from the matching axis.
interact(zoom,
         x=(0, image_show.shape[1]-2*step, step),
         y=(0, image_show.shape[0]-2*step, step));