# Dask bag-based imaging demonstration

This notebook explores the use of dask bags for parallelisation. 

See the imaging-pipelines notebook for pipeline processing with dask.

We create the visibility and fill in values with the transform of a number of point sources. 

In [None]:
%matplotlib inline

import os
import sys

from dask import delayed, bag
from distributed import Client

# Output directory used by this notebook; it is created on the first run and
# silently reused on later runs.
results_dir = './results'
os.makedirs(name=results_dir, exist_ok=True)

from matplotlib import pylab

# Plot defaults for every figure in this notebook: figure size and the
# colour map used for images.
pylab.rcParams.update({
    'figure.figsize': (12.0, 12.0),
    'image.cmap': 'rainbow',
})

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.calibration.operations import apply_gaintable
from arl.data.polarisation import PolarisationFrame
from arl.visibility.base import create_visibility, copy_visibility
from arl.visibility.operations import concat_visibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.operations import show_image, qa_image, create_empty_image_like
from arl.util.testing_support import create_named_configuration, create_test_image
from arl.imaging import create_image_from_visibility, predict_skycomponent_visibility, \
    advise_wide_field, predict_2d, invert_2d, normalize_sumwt
from arl.imaging.wstack import predict_wstack_single, invert_wstack_single
from arl.imaging.timeslice import predict_timeslice_single, invert_timeslice_single
from arl.visibility.gather_scatter import visibility_gather_w, visibility_scatter_w
from arl.visibility.gather_scatter import visibility_gather_time, visibility_scatter_time
from arl.imaging.weighting import weight_visibility
from arl.graphs.dask_init import get_dask_Client
from arl.graphs.graphs import sum_invert_results
from arl.pipelines.graphs import create_continuum_imaging_pipeline_graph
from arl.graphs.bags import safe_invert_list, safe_predict_list, sum_invert_results

import logging

# Route INFO-level (and above) messages from the root logger to the notebook
# output via stdout.
log = logging.getLogger()
log.setLevel(logging.INFO)
# Re-running this cell must not stack another stdout handler on the root
# logger, otherwise every message is printed once per execution of the cell.
if not any(isinstance(handler, logging.StreamHandler)
           and getattr(handler, 'stream', None) is sys.stdout
           for handler in log.handlers):
    log.addHandler(logging.StreamHandler(sys.stdout))

In [None]:
def ingest_visibility(freq=1e8, chan_width=1e6, reffrequency=[1e8],
                      init=False):
    """Create a single-channel stokesI visibility for the LOWBD2-CORE configuration.

    The visibility is built for seven values evenly spaced between -pi/4 and
    +pi/4 (presumably hour angles in radians -- confirm against
    create_visibility), with unit weight and a phase centre at
    RA 15 deg, Dec -26.7 deg (ICRS).

    :param freq: Frequency of the single channel (presumably Hz -- confirm).
    :param chan_width: Bandwidth of the single channel (same unit as freq).
    :param reffrequency: Frequency list passed to the model image and to each
        sky component. Only used when init is True. The default list is shared
        between calls; this function itself never mutates it.
    :param init: When True, a 4 x 4 grid of point components of flux 100.0 is
        created and handed to predict_skycomponent_visibility together with
        the visibility (the return value of that call is discarded, so it is
        presumably applied to the visibility in place -- confirm). When False
        the visibility is returned exactly as created.
    :return: The visibility object produced by create_visibility.
    """
    lowcore = create_named_configuration('LOWBD2-CORE')
    times = numpy.linspace(-numpy.pi / 4, numpy.pi / 4, 7)
    frequency = numpy.array([freq])
    channel_bandwidth = numpy.array([chan_width])

    phasecentre = SkyCoord(
        ra=+15.0 * u.deg, dec=-26.7 * u.deg, frame='icrs', equinox='J2000')
    vt = create_visibility(
        lowcore,
        times,
        frequency,
        channel_bandwidth=channel_bandwidth,
        weight=1.0,
        phasecentre=phasecentre,
        polarisation_frame=PolarisationFrame("stokesI"))
    if init:
        # The model image is only needed for its WCS, which converts the
        # component pixel positions below into sky coordinates.
        npixel = 256
        cellsize = 0.001
        model = create_image_from_visibility(
            vt,
            npixel=npixel,
            cellsize=cellsize,
            npol=1,
            frequency=reffrequency,
            polarisation_frame=PolarisationFrame("stokesI"))
        flux = numpy.array([[100.0]])
        facets = 4

        # Components sit at the centres of a facets x facets tiling of the
        # image: offsets are expressed in units of one tile width.
        spacing_pixels = npixel // facets
        centers = -1.5, -0.5, +0.5, +1.5
        comps = list()
        for iy in centers:
            for ix in centers:
                pra = int(round(npixel // 2 + ix * spacing_pixels - 1))
                pdec = int(round(npixel // 2 + iy * spacing_pixels - 1))
                sc = pixel_to_skycoord(pra, pdec, model.wcs)
                comps.append(
                    create_skycomponent(
                        flux=flux,
                        frequency=reffrequency,
                        direction=sc,
                        polarisation_frame=PolarisationFrame("stokesI")))
        predict_skycomponent_visibility(vt, comps)

    return vt

In [None]:
# Number of frequency windows; one visibility set is ingested per window.
nfreqwin = 7
# The visibilities are built here, up front, and then wrapped in a bag.
vis_bag = bag.from_sequence(
    list(map(ingest_visibility, numpy.linspace(0.8e8, 1.2e8, nfreqwin))))

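At this point we have a bag holding one visibility set per frequency window. Note that the list comprehension above has already constructed the visibilities; the bag simply wraps them so that later operations can be mapped over them in parallel.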

In [None]:
# Show the printed form of the bag; nothing in this cell modifies it.
print(vis_bag)

We need to compute the graph in order to use it. First we just need a representative data set to calculate imaging parameters.

In [None]:
# Image size and facet count for this notebook.
npixel, facets = 256, 4
def get_LSM(vt, cellsize=0.001, reffrequency=[1e8]):
    """Return the test image used as the sky model for a visibility.

    :param vt: Visibility whose phase centre is used for the image; it is
        also forwarded as the first positional argument of create_test_image.
    :param cellsize: Cell size handed to create_test_image.
    :param reffrequency: Frequency list handed to create_test_image.
    :return: The image returned by create_test_image.
    """
    # NOTE(review): vt is passed positionally as the first argument of
    # create_test_image -- confirm that parameter is meant to receive a
    # visibility.
    return create_test_image(
        vt,
        cellsize=cellsize,
        frequency=reffrequency,
        phasecentre=vt.phasecentre,
        polarisation_frame=PolarisationFrame("stokesI"))

# Materialise the bag into a plain list; from here on vis_bag is a list.
vis_bag = [vis for vis in vis_bag]
# The first visibility set serves as the representative data set for both
# the sky model and the wide-field advice.
model = get_LSM(vt=vis_bag[0])
advice = advise_wide_field(vis_bag[0], guard_band_image=4.0)

Now we can set up the prediction of the visibility from the model

In [None]:
# Pipeline per frequency window: scatter the visibility in w into 11 slices,
# predict each slice from the model, then concatenate the slices again.
predicted_vis_bag = (
    bag.from_sequence([ingest_visibility(freq)
                       for freq in numpy.linspace(0.8e8, 1.2e8, nfreqwin)])
    .map(visibility_scatter_w, vis_slices=11)
    .map(safe_predict_list, model, predict=predict_wstack_single)
    .map(concat_visibility)
)

In [None]:
# Compute the whole bag and keep its first element for inspection.
vt = predicted_vis_bag.compute()[0]

# Check the prediction by plotting the visibility amplitude against the
# uv distance, taken from the first two columns of uvw.
uvdist = numpy.sqrt(numpy.square(vt.data['uvw'][:, 0])
                    + numpy.square(vt.data['uvw'][:, 1]))
plt.clf()
plt.plot(uvdist, numpy.abs(vt.data['vis']), '.')
plt.xlabel('uvdist')
plt.ylabel('Amp Visibility')
plt.show()

Now we can make the dirty image. For this we will scatter over w, as we did for the prediction.

In [None]:
# Scatter each predicted visibility in w into 11 slices, invert every slice,
# flatten the per-window result lists and sum all partial results into one.
dirty_bag = (
    predicted_vis_bag
    .map(visibility_scatter_w, vis_slices=11)
    .map(safe_invert_list, model, invert_wstack_single, dopsf=False, normalize=True)
    .flatten()
    .reduction(sum_invert_results, sum_invert_results)
)

# Computing the reduction yields the summed image and its sum of weights.
dirty, sumwt = dirty_bag.compute()
print(qa_image(dirty, context='Dirty'))
show_image(dirty, title='Dirty')
plt.show()