# Dask bag-based imaging demonstration

This notebook explores the use of dask bags for parallelisation. 

See the imaging-dask notebook for the equivalent processing with dask.delayed.

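We create the visibility data. `ingest_visibility` can optionally (`init=True`) fill in values with the transform of a grid of point sources. In this notebook, however, the visibilities are filled later by predicting from a model image.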

In [None]:
%matplotlib inline

import os
import sys

from dask import delayed, bag
from distributed import Client

# Output directory for anything this notebook saves. It is created up front
# (an existing directory is tolerated) so later cells can write into it directly.
results_dir = './results'
os.makedirs(name=results_dir, exist_ok=True)

from matplotlib import pylab

# Large square figures with a rainbow colour map for the image displays below.
pylab.rcParams.update({'figure.figsize': (12.0, 12.0),
                       'image.cmap': 'rainbow'})

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.calibration.operations import apply_gaintable
from arl.data.polarisation import PolarisationFrame
from arl.visibility.base import create_visibility, copy_visibility
from arl.visibility.operations import concatenate_visibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.operations import show_image, qa_image, create_empty_image_like,\
    pad_image, show_image
from arl.image.deconvolution import deconvolve_cube, restore_cube
from arl.util.testing_support import create_named_configuration, create_test_image
from arl.image.operations import create_empty_image_like, create_image_from_array
from arl.imaging import create_image_from_visibility, predict_skycomponent_visibility, \
    advise_wide_field, predict_2d, invert_2d, normalize_sumwt
from arl.imaging.imaging_context import imaging_context
from arl.imaging.wstack import predict_wstack_single, invert_wstack_single
from arl.imaging.timeslice import predict_timeslice_single, invert_timeslice_single
from arl.visibility.gather_scatter import visibility_gather_w, visibility_scatter_w
from arl.visibility.gather_scatter import visibility_gather_time, visibility_scatter_time
from arl.imaging.weighting import weight_visibility
from arl.graphs.dask_init import get_dask_Client

from arl.graphs.bags import predict_bag, predict_record, scatter_record, invert_bag, invert_record,\
    invert_record_add, create_empty_image_record, create_empty_visibility_record, folded_to_image_record, \
    folded_to_visibility_record, deconvolve_bag

import pprint
pp = pprint.PrettyPrinter(indent=4)
import logging


def attach_stdout_handler(logger, name='notebook-stdout'):
    """Attach a ``StreamHandler`` writing to ``sys.stdout`` to ``logger``, at most once.

    Re-running a cell that calls ``logger.addHandler(...)`` unconditionally
    stacks another handler on every execution, so each log record ends up
    printed once per run. The handler added here is tagged with ``name``.
    If a handler with that name is already attached, it is reused.

    :param logger: the logger to configure
    :param name: tag identifying the handler added by this notebook
    :return: the attached handler (newly created or pre-existing)
    """
    for existing in logger.handlers:
        if existing.get_name() == name:
            return existing
    handler = logging.StreamHandler(sys.stdout)
    handler.set_name(name)
    logger.addHandler(handler)
    return handler


# Configure the root logger at DEBUG and echo its records to stdout.
log = logging.getLogger()
log.setLevel(logging.DEBUG)
# Bound to a name so the handler repr is not displayed as the cell's output.
stdout_handler = attach_stdout_handler(log)

Create a Dask client to perform the computation of the bags.

In [None]:
# Dask client obtained through ARL's helper; every bag graph below is submitted via `c`.
c = get_dask_Client()

Define a function to create the visibilities.

In [None]:
def ingest_visibility(freq=[1e8],
                      chan_width=[1e6],
                      times=[0.0],
                      reffrequency=[1e8],
                      npixel=512,
                      init=False):
    """Create a LOWBD2-CORE visibility set, optionally filled from point sources.

    Parameters
    ----------
    freq : sequence of float
        Channel frequencies; converted with ``numpy.array``.
    chan_width : sequence of float
        Channel bandwidths, one per entry of ``freq``.
    times : sequence of float
        Passed straight to ``create_visibility`` (presumably hour angles in
        radians -- the notebook uses -pi/4..pi/4; confirm).
    reffrequency : sequence of float
        Frequency for the model image and the components when ``init`` is true.
    npixel : int
        Side length of the model image used to place the components.
    init : bool
        If true, predict a ``facets`` x ``facets`` grid of Stokes I components
        (flux 100.0 each) into the visibilities.

    Returns
    -------
    The visibility produced by ``create_visibility`` at phase centre
    RA=+15 deg, Dec=-26.7 deg (ICRS), Stokes I, unit weights.
    """
    lowcore = create_named_configuration('LOWBD2-CORE')

    frequency = numpy.array(freq)
    channel_bandwidth = numpy.array(chan_width)

    phasecentre = SkyCoord(
        ra=+15.0 * u.deg, dec=-26.7 * u.deg, frame='icrs', equinox='J2000')
    vt = create_visibility(
        lowcore,
        times,
        frequency,
        channel_bandwidth=channel_bandwidth,
        weight=1.0,
        phasecentre=phasecentre,
        polarisation_frame=PolarisationFrame("stokesI"))
    if not init:
        return vt

    cellsize = 0.001
    # The model image only supplies the WCS used to turn pixel positions into
    # sky directions for the components.
    model = create_image_from_visibility(
        vt,
        npixel=npixel,
        cellsize=cellsize,
        npol=1,
        frequency=reffrequency,
        polarisation_frame=PolarisationFrame("stokesI"))
    flux = numpy.array([[100.0]])
    facets = 4

    spacing_pixels = npixel // facets
    # Facet-centre offsets in units of spacing_pixels, symmetric about the image
    # centre and derived from `facets` (for facets=4: -1.5, -0.5, +0.5, +1.5).
    centers = [i - (facets - 1) / 2.0 for i in range(facets)]
    comps = list()
    for iy in centers:
        for ix in centers:
            # NOTE(review): the "- 1" offset is kept from the original placement;
            # presumably a 1-based -> 0-based pixel adjustment -- confirm.
            pra = int(round(npixel // 2 + ix * spacing_pixels - 1))
            pdec = int(round(npixel // 2 + iy * spacing_pixels - 1))
            sc = pixel_to_skycoord(pra, pdec, model.wcs)
            comps.append(
                create_skycomponent(
                    flux=flux,
                    frequency=reffrequency,
                    direction=sc,
                    polarisation_frame=PolarisationFrame("stokesI")))
    # The return value is unused; assumes this fills vt in place -- confirm.
    predict_skycomponent_visibility(vt, comps)

    return vt

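Now make five frequency windows spanning 80 MHz to 120 MHz and put them into Dask bags. We do this with varying levels of aggregation: all data, per frequency, per time, and per frequency and time. The rest of the notebook processes f_vis_bag and uses all_vis_bag only to set up the model.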

The visibilities are wrapped in meta-data required to enable the bag operations.

In [None]:
# Observation set-up: nfreqwin frequency windows and ntimes time samples.
nfreqwin = 5
ntimes = 5
times = numpy.linspace(-numpy.pi / 4, numpy.pi / 4, ntimes)
frequencies = numpy.linspace(0.8e8, 1.2e8, nfreqwin)
# Equal-width windows that together cover 4e7.
chan_width = numpy.array([4e7 / nfreqwin] * nfreqwin)


def show_bag_contents(heading, records_bag):
    """Pretty-print ``heading`` followed by every computed record of ``records_bag``."""
    pp.pprint(heading)
    pp.pprint(list(records_bag))


# Coarsest aggregation: a single record holding every frequency and time.
all_vis_bag = bag.from_sequence(
    [{'vis': ingest_visibility(frequencies, chan_width=chan_width, times=times)}])
show_bag_contents('All data:', all_vis_bag)

# One record per frequency window, each spanning all times.
f_vis_bag = bag.from_sequence(
    [{'freqwin': fw,
      'vis': ingest_visibility([fc], chan_width=[chan_width[fw]], times=times)}
     for fw, fc in enumerate(frequencies)])
show_bag_contents('Per frequency:', f_vis_bag)

# One record per time sample, each spanning all frequencies.
t_vis_bag = bag.from_sequence(
    [{'timewin': tw,
      'vis': ingest_visibility(frequencies, chan_width=chan_width, times=[tv])}
     for tw, tv in enumerate(times)])
show_bag_contents('Per time: ', t_vis_bag)

# Finest aggregation: one record per (frequency window, time sample) pair,
# ordered frequency-major.
ft_vis_bag = bag.from_sequence(
    [{'freqwin': fw,
      'timewin': tw,
      'vis': ingest_visibility([fc], chan_width=[chan_width[fw]], times=[tv])}
     for fw, fc in enumerate(frequencies)
     for tw, tv in enumerate(times)])
show_bag_contents('Per frequency and time: ', ft_vis_bag)

Make a model image.

In [None]:
npixel = 512


def get_LSM(vt, cellsize=0.001, reffrequency=[1e8], npixel=512):
    """Return the padded local sky model (LSM) used as the imaging model.

    A Stokes I test image is generated with ``cellsize`` at ``reffrequency``,
    centred on ``vt.phasecentre``. It is then padded to shape
    ``[1, 1, npixel, npixel]``.

    :param vt: visibility whose ``phasecentre`` sets the image centre
    :param cellsize: cell size handed to ``create_test_image``
    :param reffrequency: frequency list handed to ``create_test_image``
    :param npixel: side length of the padded output image
    :return: the padded image
    """
    # NOTE(review): `vt` is passed positionally as create_test_image's first
    # argument -- confirm that parameter is meant to receive a visibility.
    test_image = create_test_image(
        vt,
        cellsize=cellsize,
        frequency=reffrequency,
        phasecentre=vt.phasecentre,
        polarisation_frame=PolarisationFrame("stokesI"))
    return pad_image(test_image, shape=[1, 1, npixel, npixel])

# Materialise the all-data bag on the cluster once and reuse the result.
# Previously the re-wrapped bag was computed twice more, once for each call below.
future = c.compute(all_vis_bag)
all_vis_records = future.result()
all_vis_bag = bag.from_sequence(all_vis_records)
all_vis = all_vis_records[0]['vis']

# Wide-field imaging advice for the full data set, and the model image built
# at the lowest frequency window.
advice = advise_wide_field(all_vis, guard_band_image=4.0)
model = get_LSM(all_vis, npixel=npixel, reffrequency=[frequencies[0]])

For the processing to come, we need to attach some metadata to aid in selecting the data.

Now we can set up the prediction of the visibilities from the model. We scatter over w, apply w-stacking to each single w plane, and then concatenate the visibilities back together.

In [None]:
# Predict visibilities for every frequency window from the model with the
# 'wstack_single' context over 101 w slices, then re-wrap the computed
# records as a new bag.
predicted_bag = predict_bag(f_vis_bag, model, context='wstack_single', vis_slices=101)
future = c.compute(predicted_bag)
vis_bag = bag.from_sequence(future.result())

We will form a dirty image and a PSF for each frequency window from all times.

In [None]:
# Build graphs that invert the predicted visibilities into a dirty image and a
# PSF per frequency window. vis_slices matches the 101 w slices used for the
# prediction above, so the w-stacking scatter is the same in both directions.
# NOTE(review): assumes invert_bag forwards vis_slices to its scatter step as
# predict_bag does -- confirm against arl.graphs.bags.
dirty_bag = invert_bag(vis_bag, model, context='wstack_single', vis_slices=101)
psf_bag = invert_bag(vis_bag, model, dopsf=True, context='wstack_single', vis_slices=101)

In [None]:
# Run the dirty-image graph on the cluster and block until it finishes;
# `result` is the list of records produced by dirty_bag.
future = c.compute(dirty_bag)
result = future.result()

The bags contain the Image objects wrapped in metadata.

In [None]:
# Show the raw records; each presumably wraps an (image, sumwt) pair under the
# 'image' key together with its metadata.
pp.pprint(result)

In [None]:
# Submit the dirty-image and PSF bags in one call so they run concurrently on
# the cluster, instead of recomputing dirty_bag and then the PSF one after the
# other. Separate weight names keep the dirty sumwt from being overwritten.
dirty_future, psf_future = c.compute([dirty_bag, psf_bag])

# Display the first frequency window of each; records carry (image, sumwt) under 'image'.
dirty, dirty_sumwt = dirty_future.result()[0]['image']
show_image(dirty)
plt.show()

psf, psf_sumwt = psf_future.result()[0]['image']
show_image(psf)
plt.show()

Now we can do the deconvolution for each frequency window.

In [None]:
# Deconvolve each frequency window: CLEAN with scales [0, 3, 10], loop gain 0.7,
# at most 1000 iterations, stopping at threshold 0.01 / fractional threshold 0.1.
comp_bag = deconvolve_bag(dirty_bag, psf_bag, model, niter=1000, threshold=0.01,
                          fracthresh=0.1, window_shape='quarter',
                          gain=0.7, scales=[0, 3, 10])

# Run the deconvolution once on the cluster and use that result. Previously the
# future was discarded and comp_bag.compute() ran the whole graph a second time.
future = c.compute(comp_bag)
comp = future.result()
for clean in comp:
    show_image(clean)
    plt.show()

# Re-wrap the computed components as a bag for downstream steps.
comp_bag = bag.from_sequence(comp)