# Dask bag-based imaging demonstration

This notebook explores the use of dask bags for parallelisation. For the most part we work with the bags directly. Much of this can be hidden in standard functions.

See the imaging-dask notebook for processing with dask delayed.

We create the visibility and fill in values with the transform of a number of point sources. 

In [None]:
%matplotlib inline

import os
import sys

from dask import delayed, bag
from distributed import Client

results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

from matplotlib import pylab

pylab.rcParams['figure.figsize'] = (12.0, 12.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.calibration.operations import apply_gaintable
from arl.data.polarisation import PolarisationFrame
from arl.visibility.base import create_visibility, copy_visibility
from arl.visibility.operations import concatenate_visibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.operations import show_image, qa_image, create_empty_image_like,\
    pad_image
from arl.image.deconvolution import deconvolve_cube, restore_cube
from arl.util.testing_support import create_named_configuration, create_test_image
from arl.imaging import create_image_from_visibility, predict_skycomponent_visibility, \
    advise_wide_field, predict_2d, invert_2d, normalize_sumwt
from arl.imaging.wstack import predict_wstack_single, invert_wstack_single
from arl.imaging.timeslice import predict_timeslice_single, invert_timeslice_single
from arl.visibility.gather_scatter import visibility_gather_w, visibility_scatter_w
from arl.visibility.gather_scatter import visibility_gather_time, visibility_scatter_time
from arl.imaging.weighting import weight_visibility
from arl.graphs.dask_init import get_dask_Client
from arl.pipelines.graphs import create_continuum_imaging_pipeline_graph
from arl.graphs.bags import safe_invert_list, safe_predict_list, sum_invert_bag_results, deconvolve_bag

import logging

# Route INFO-and-above records from the root logger to stdout so they show up
# in the notebook output.
log = logging.getLogger()
log.setLevel(logging.INFO)
# Only attach the stdout handler once: re-running this cell would otherwise
# stack up handlers and print every log record once per run.
if not any(isinstance(handler, logging.StreamHandler)
           and getattr(handler, 'stream', None) is sys.stdout
           for handler in log.handlers):
    log.addHandler(logging.StreamHandler(sys.stdout))

Define a function to create the visibilities

In [None]:
def ingest_visibility(freq=1e8, chan_width=1e6, reffrequency=[1e8], npixel=512,
                      init=False):
    """Create a single-channel Stokes I visibility set for LOWBD2-CORE.

    The visibility is built for seven times evenly spaced between -pi/4 and
    +pi/4 (presumably hour angles in radians -- confirm against
    create_visibility) with a phase centre at RA 15 deg, Dec -26.7 deg.

    :param freq: Frequency of the single channel.
    :param chan_width: Bandwidth of the single channel.
    :param reffrequency: Frequency list handed to the model image and to each
        sky component; only used when ``init`` is true. It is passed through
        as-is.
    :param npixel: Model image size in pixels per axis; only used when
        ``init`` is true.
    :param init: If true, predict a 4 x 4 grid of point sources of flux 100.0
        into the visibility before returning it.
    :return: The visibility created by create_visibility.
    """
    lowcore = create_named_configuration('LOWBD2-CORE')
    times = numpy.linspace(-numpy.pi / 4, numpy.pi / 4, 7)
    frequency = numpy.array([freq])
    channel_bandwidth = numpy.array([chan_width])

    phasecentre = SkyCoord(
        ra=+15.0 * u.deg, dec=-26.7 * u.deg, frame='icrs', equinox='J2000')
    vt = create_visibility(
        lowcore,
        times,
        frequency,
        channel_bandwidth=channel_bandwidth,
        weight=1.0,
        phasecentre=phasecentre,
        polarisation_frame=PolarisationFrame("stokesI"))
    if not init:
        return vt

    # The model image is only needed for its WCS, which converts the pixel
    # positions of the grid into sky directions.
    cellsize = 0.001
    model = create_image_from_visibility(
        vt,
        npixel=npixel,
        cellsize=cellsize,
        npol=1,
        frequency=reffrequency,
        polarisation_frame=PolarisationFrame("stokesI"))
    flux = numpy.array([[100.0]])
    facets = 4
    spacing_pixels = npixel // facets

    # Offsets, in units of spacing_pixels, of the grid points from the image
    # centre; rows (iy) are the outer loop, columns (ix) the inner loop.
    centers = (-1.5, -0.5, +0.5, +1.5)
    comps = []
    for iy in centers:
        for ix in centers:
            pra = int(round(npixel // 2 + ix * spacing_pixels - 1))
            pdec = int(round(npixel // 2 + iy * spacing_pixels - 1))
            sc = pixel_to_skycoord(pra, pdec, model.wcs)
            comps.append(
                create_skycomponent(
                    flux=flux,
                    frequency=reffrequency,
                    direction=sc,
                    polarisation_frame=PolarisationFrame("stokesI")))

    # The return value is ignored; presumably vt is updated in place -- confirm.
    predict_skycomponent_visibility(vt, comps)
    return vt

Now make seven of these spanning 80MHz to 120MHz and put them into a Dask bag.

In [None]:
nfreqwin = 7

# One visibility set per frequency window, each created with the default
# arguments of ingest_visibility apart from the frequency.
vis_bag = bag.from_sequence(
    list(map(ingest_visibility, numpy.linspace(0.8e8, 1.2e8, nfreqwin))))
print(vis_bag)

We need to compute the bag in order to use it. First we just need a representative data set to calculate imaging parameters.

In [None]:
# Notebook-level imaging constants. Neither name is referenced by the cells
# visible in this notebook; the helper functions use their own defaults.
npixel = 512
facets = 4
def get_LSM(vt, cellsize=0.001, reffrequency=[1e8], npixel=512):
    """Build the model image used for prediction and imaging.

    A test image is created for the phase centre of ``vt`` and then padded to
    ``npixel`` x ``npixel`` pixels.

    :param vt: Visibility supplying the phase centre (also passed to
        create_test_image as its first argument).
    :param cellsize: Cell size passed to create_test_image.
    :param reffrequency: Frequency list passed to create_test_image.
    :param npixel: Size in pixels of each of the last two axes of the padded
        image.
    :return: The padded image returned by pad_image.
    """
    test_image = create_test_image(vt, cellsize=cellsize, frequency=reffrequency,
                                   phasecentre=vt.phasecentre,
                                   polarisation_frame=PolarisationFrame("stokesI"))
    # Honour the npixel argument; the pad shape used to be hard-coded to 512,
    # which silently ignored any other requested size.
    return pad_image(test_image, shape=[1, 1, npixel, npixel])

# Turn the bag into a plain list so that one element can be used as the
# representative data set; from here on vis_bag is a list, not a bag.
vis_bag = list(vis_bag)
representative_vis = vis_bag[0]

model = get_LSM(representative_vis)
advice = advise_wide_field(representative_vis, guard_band_image=4.0)

# Number of w slices used when scattering the visibilities over w.
vis_slices = 11

Now we can set up the prediction of the visibility from the model. We scatter over w and then apply the wstack for a single w plane. Then we concatenate the visibilities back together.

To save recomputing this, we compute it now and place it into another bag of the same name.

In [None]:
# Per frequency window: scatter over w, predict each w slice from the model
# with the single-plane w-stack predictor, then concatenate the slices again.
vis_bag = (
    bag.from_sequence(
        list(map(ingest_visibility, numpy.linspace(0.8e8, 1.2e8, nfreqwin))))
    .map(visibility_scatter_w, vis_slices=vis_slices)
    .map(safe_predict_list, model, predict=predict_wstack_single)
    .map(concatenate_visibility)
)

# Compute once and re-wrap the results so later cells do not redo the work.
vis_bag = bag.from_sequence(vis_bag.compute())

Check out the visibility function. To get the result out of the bag, we do need to compute it but this time it's just a lookup.

In [None]:
vt = vis_bag.compute()[0]

# Sanity check of the prediction: visibility amplitude against the length of
# the (u, v) part of the baseline vector.
uvw = vt.data['uvw']
uvdist = numpy.sqrt(uvw[:, 0] ** 2 + uvw[:, 1] ** 2)
plt.clf()
plt.plot(uvdist, numpy.abs(vt.data['vis']), '.')
plt.xlabel('uvdist')
plt.ylabel('Amp Visibility')
plt.show()

Now we can make the dirty images. As before we will scatter each of the 7 frequency windows (partitions) over w, giving a 2 level nested structure. We make a separate image for each frequency window. The image resolution noticeably improves for the high frequencies.

In [None]:
# Dirty images: scatter each frequency window over w, invert every w slice,
# then sum the per-slice results into one entry per window.
dirty_bag = (
    vis_bag
    .map(visibility_scatter_w, vis_slices=vis_slices)
    .map(safe_invert_list, model, invert_wstack_single, dopsf=False, normalize=True)
    .map(sum_invert_bag_results)
)
dirty_bag = bag.from_sequence(dirty_bag.compute())

# Point spread functions: the same pipeline with dopsf=True.
psf_bag = (
    vis_bag
    .map(visibility_scatter_w, vis_slices=vis_slices)
    .map(safe_invert_list, model, invert_wstack_single, dopsf=True, normalize=True)
    .map(sum_invert_bag_results)
)
psf_bag = bag.from_sequence(psf_bag.compute())

# Each entry is indexed as [0] for the image and [1] for the value reported as
# its weight in the plot title.
for i, dirty in enumerate(dirty_bag.compute()):
    print(qa_image(dirty[0], context='dirty'))
    title = 'Dirty image %d, weight %.3f' % (i, dirty[1])
    fig = show_image(dirty[0], title=title)
    plt.show()

In the next step all these seven images will be deconvolved in parallel. In this case we again need to zip the dirty and psf images and then use a simple adapter function.

In [None]:
def bag_deconvolve(dirty_psf_zip, **kwargs):
    """Adapter that runs deconvolve_cube on one zipped (dirty, psf) entry.

    :param dirty_psf_zip: Pair whose first item is a dirty-image entry and
        whose second item is a psf entry; element [0] of each is the image
        handed to deconvolve_cube.
    :param kwargs: Forwarded unchanged to deconvolve_cube.
    :return: The first element of the deconvolve_cube result (presumably the
        component image -- confirm against deconvolve_cube).
    """
    dirty_image = dirty_psf_zip[0][0]
    psf_image = dirty_psf_zip[1][0]
    deconvolved = deconvolve_cube(dirty_image, psf_image, **kwargs)
    return deconvolved[0]

# Pair each dirty image with its psf and deconvolve the pairs.
comp_bag = bag.zip(dirty_bag, psf_bag).map(
    bag_deconvolve,
    niter=1000,
    threshold=0.001,
    fracthresh=0.01,
    window_shape='quarter',
    gain=0.7,
    scales=[0, 3, 10, 30],
)

comp = comp_bag.compute()
fig = show_image(comp[0])

# Re-wrap the computed results so later cells reuse them.
comp_bag = bag.from_sequence(comp)

Now we can calculate the model and residual visibility. To calculate the residual visibility, we will zip the original and model visibilities together and map our adapter across the zipped bag.

In [None]:
# Predict the model visibility for each frequency window from comp_bag, which
# is passed to map as an extra bag argument.
# NOTE(review): this cell scatters over 101 w slices, whereas the other cells
# use 11 -- confirm that the difference is intentional.
model_vis_bag = (
    vis_bag
    .map(visibility_scatter_w, vis_slices=101)
    .map(safe_predict_list, comp_bag, predict=predict_wstack_single)
    .map(concatenate_visibility)
)

model_vis_bag = bag.from_sequence(model_vis_bag.compute())

def subtract_vis(vis_model_zip):
    """Adapter that forms the residual visibility for one zipped pair.

    :param vis_model_zip: Pair of (visibility, model visibility).
    :return: A copy of the first visibility whose 'vis' column has had the
        'vis' column of the second subtracted from it. The subtraction is
        applied to the copy made by copy_visibility.
    """
    observed = vis_model_zip[0]
    predicted = vis_model_zip[1]
    residual_vis = copy_visibility(observed)
    residual_vis.data['vis'] -= predicted.data['vis']
    return residual_vis

residual_vis_bag = bag.zip(vis_bag, model_vis_bag).map(subtract_vis)

residual_vis_bag = bag.from_sequence(residual_vis_bag.compute())

ovt = vis_bag.compute()[0]
vt = residual_vis_bag.compute()[0]

# Compare the amplitudes for the first frequency window: original visibility
# in blue, residual visibility in red, both against the same uv distance.
uvw = vt.data['uvw']
uvdist = numpy.sqrt(uvw[:, 0] ** 2 + uvw[:, 1] ** 2)
plt.clf()
plt.plot(uvdist, numpy.abs(ovt.data['vis']), '.', color='b')
plt.plot(uvdist, numpy.abs(vt.data['vis']), '.', color='r')
plt.xlabel('uvdist')
plt.ylabel('Amp Visibility')
plt.show()

Now we can restore the images

In [None]:
# Image the residual visibilities with the same pipeline as the dirty images.
# The number of w slices comes from the notebook-level vis_slices setting
# instead of a repeated literal, so the two pipelines cannot drift apart.
residual_bag = (
    residual_vis_bag
    .map(visibility_scatter_w, vis_slices=vis_slices)
    .map(safe_invert_list, model, invert_wstack_single, dopsf=False, normalize=True)
    .map(sum_invert_bag_results)
)

# Compute once and re-wrap the results for the restore step.
residual_bag = bag.from_sequence(residual_bag.compute())

def bag_restore(cpr_zip, **kwargs):
    """Adapter that runs restore_cube on one zipped (comp, psf, residual) entry.

    :param cpr_zip: Triple whose first item is passed to restore_cube as-is,
        and whose second and third items are entries from which element [0]
        is passed.
    :param kwargs: Forwarded unchanged to restore_cube.
    :return: Whatever restore_cube returns.
    """
    components = cpr_zip[0]
    psf_image = cpr_zip[1][0]
    residual_image = cpr_zip[2][0]
    return restore_cube(components, psf_image, residual_image, **kwargs)

# Restore each frequency window from its components, psf and residual image.
restore_bag = bag.zip(comp_bag, psf_bag, residual_bag).map(bag_restore)

for i, restored in enumerate(restore_bag.compute()):
    title = 'Restored image %d' % i
    fig = show_image(restored, title=title)
    plt.show()