# Dask imaging demonstration

This notebook explores the use of Dask for parallelisation. We show the Dask graphs for several types of predict and invert: facets, w-stacking, and facets combined with w-stacking.

See the imaging-pipelines notebook for pipeline processing with Dask.

In [None]:
%matplotlib inline

import os
import sys

from dask import delayed

sys.path.append(os.path.join('..', '..'))

results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

from matplotlib import pylab

pylab.rcParams['figure.figsize'] = (12.0, 12.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.calibration.operations import apply_gaintable
from arl.data.polarisation import PolarisationFrame
from arl.visibility.base import create_blockvisibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.operations import show_image, qa_image
from arl.util.testing_support import create_named_configuration
from arl.imaging import create_image_from_visibility, predict_skycomponent_blockvisibility, \
    advise_wide_field
from arl.imaging.weighting import weight_visibility
from arl.graphs.dask_init import get_dask_Client
from arl.graphs.graphs import create_deconvolve_facet_graph, create_invert_facet_graph, \
    create_invert_wstack_graph, create_predict_facet_graph, compute_list, \
    create_predict_wstack_graph, create_invert_facet_wstack_graph
from arl.graphs.vis import simple_vis
from arl.pipelines.graphs import create_continuum_imaging_pipeline_graph


import logging

# Send all log records (DEBUG and above) from the root logger to the notebook's stdout.
log = logging.getLogger()
log.setLevel(logging.DEBUG)
# Attach the stdout handler only once: re-running this cell would otherwise stack another
# handler on the root logger and print every record one extra time per re-run.
if not any(isinstance(handler, logging.StreamHandler) and getattr(handler, 'stream', None) is sys.stdout
           for handler in log.handlers):
    log.addHandler(logging.StreamHandler(sys.stdout))

We create the visibility data and fill in its values with the transform of a grid of point sources: one source of flux 100 at the centre of each of the 4 × 4 facets.

In [None]:
def ingest_visibility(freq=1e8, chan_width=1e6, time=0.0, reffrequency=None,
                      npixel=256, cellsize=0.001, facets=4):
    """Create a single-channel Stokes I block visibility filled with a grid of point sources.

    The LOWBD2-CORE configuration is observed at one time and one frequency, phased to
    RA=15 deg, Dec=-26.7 deg. One component of flux 100.0 is placed at the centre of each
    cell of a ``facets`` x ``facets`` grid laid over an ``npixel`` x ``npixel`` image, and
    the visibilities are then predicted from those components.

    :param freq: channel frequency (the single entry of the frequency array)
    :param chan_width: channel bandwidth
    :param time: the single time passed to create_blockvisibility (units as that function expects)
    :param reffrequency: frequency list for the image used to place the sources; defaults to [1e8]
    :param npixel: size in pixels of the image used to place the sources
    :param cellsize: cell size of that image (units as create_image_from_visibility expects)
    :param facets: number of facets along each axis; one source is placed per facet
    :return: the block visibility, after predict_skycomponent_blockvisibility has been applied
    """
    # Avoid a shared mutable default list across calls.
    if reffrequency is None:
        reffrequency = [1e8]
    lowcore = create_named_configuration('LOWBD2-CORE')
    times = [time]
    frequency = numpy.array([freq])
    channel_bandwidth = numpy.array([chan_width])

    phasecentre = SkyCoord(ra=+15.0 * u.deg, dec=-26.7 * u.deg, frame='icrs', equinox='J2000')
    vt = create_blockvisibility(lowcore, times, frequency, channel_bandwidth=channel_bandwidth,
                                weight=1.0, phasecentre=phasecentre,
                                polarisation_frame=PolarisationFrame("stokesI"))
    # This image is only used for its WCS, to turn facet-centre pixels into sky directions.
    model = create_image_from_visibility(vt, npixel=npixel, cellsize=cellsize, npol=1, frequency=reffrequency,
                                         polarisation_frame=PolarisationFrame("stokesI"))
    flux = numpy.array([[100.0]])

    spacing_pixels = npixel // facets
    # Facet-centre offsets in units of the facet spacing, symmetric about the image centre
    # (facets=4 gives -1.5, -0.5, +0.5, +1.5).
    centers = [i - (facets - 1) / 2.0 for i in range(facets)]
    comps = list()
    for iy in centers:
        for ix in centers:
            # NOTE(review): the -1 looks like a 1-based -> 0-based pixel shift; pixel_to_skycoord
            # uses origin=0 by default — confirm the intended offset.
            pra = int(round(npixel // 2 + ix * spacing_pixels - 1))
            pdec = int(round(npixel // 2 + iy * spacing_pixels - 1))
            sc = pixel_to_skycoord(pra, pdec, model.wcs)
            comps.append(create_skycomponent(flux=flux, frequency=vt.frequency, direction=sc,
                                             polarisation_frame=PolarisationFrame("stokesI")))
    # Return value is not used; presumably this fills vt's visibilities in place.
    predict_skycomponent_blockvisibility(vt, comps)

    return vt

Define the Local Sky Model (LSM), an image created from the visibility data. It is empty.

In [None]:
def get_LSM(vt, npixel=256, cellsize=0.001, reffrequency=None):
    """Return the Local Sky Model: a Stokes I model image created from ``vt``.

    The notebook treats this model as empty.

    :param vt: visibility from which the image is created
    :param npixel: image size in pixels
    :param cellsize: image cell size (units as create_image_from_visibility expects)
    :param reffrequency: frequency list for the image; defaults to [1e8]
    :return: the image returned by create_image_from_visibility
    """
    # A fresh list per call, so images never alias one shared default list.
    if reffrequency is None:
        reffrequency = [1e8]
    return create_image_from_visibility(vt, npixel=npixel, cellsize=cellsize, npol=1, frequency=reffrequency,
                                        polarisation_frame=PolarisationFrame("stokesI"))

In [None]:
# Obtain the Dask client `c`; every compute_list / c.compute call below runs through it.
c=get_dask_Client()

In [None]:
# One delayed ingest per frequency window, evenly spaced from 0.8e8 to 1.2e8.
nfreqwin = 7
vis_graph_list = [delayed(ingest_visibility)(freq, time=0.0)
                  for freq in numpy.linspace(0.8e8, 1.2e8, nfreqwin)]
nvis = len(vis_graph_list)
# Run the ingest graphs on the cluster; the name now refers to the compute_list
# results (presumably the visibilities themselves) rather than the delayed objects.
vis_graph_list = compute_list(c, vis_graph_list)

npixel = 256
# NOTE(review): `facets` is not referenced by later cells, which pass facets=2/4 literally.
facets = 4
# Build the sky model lazily from the middle frequency window.
model_graph = delayed(get_LSM)(vis_graph_list[nvis // 2], npixel=npixel)

Calculate the optimum parameters for wide-field imaging.

In [None]:
# Wide-field imaging advice for the first frequency window.
# NOTE(review): `advice` is not used below; later cells hardcode vis_slices and wstep.
advice = advise_wide_field(vis_graph_list[0],
                           guard_band_image=4.0)

Make and display a graph for prediction using facets.

In [None]:
# Facet-based predict with 2 facets per axis (other cells use 4, presumably to keep
# this drawing small). Only the first graph in the returned list is drawn.
predict_graph = create_predict_facet_graph(vis_graph_list, model_graph, facets=2)
simple_vis(predict_graph[0])

Make and display a graph for prediction using w-stacking.

In [None]:
# w-stacking predict with 11 w-slices; only the first graph in the list is drawn.
predict_graph = create_predict_wstack_graph(vis_graph_list, model_graph,
                                            vis_slices=11)
simple_vis(predict_graph[0])

In [None]:
# Build the same w-stacking predict graphs and execute them on the cluster.
predict_graph = create_predict_wstack_graph(vis_graph_list, model_graph,
                                            vis_slices=11)
predicted_vis_graph_list = compute_list(c, predict_graph)

Do the same for invert. Note the difference in the structure of the graphs.

In [None]:
# w-stacking invert (11 slices, wstep=8.0) using the 'wprojection' kernel; dopsf=False.
dirty_graph = create_invert_wstack_graph(
    vis_graph_list, model_graph,
    vis_slices=11, wstep=8.0, kernel='wprojection', dopsf=False,
)
simple_vis(dirty_graph)

In [None]:
# Facet-based invert with 4 facets per axis; dopsf=False.
dirty_graph = create_invert_facet_graph(vis_graph_list, model_graph,
                                        facets=4, dopsf=False)
simple_vis(dirty_graph)

In [None]:
# Faceting (4 per axis) combined with w-stacking (11 slices, wstep=8.0); dopsf=False.
dirty_graph = create_invert_facet_wstack_graph(
    vis_graph_list, model_graph,
    facets=4, vis_slices=11, wstep=8.0, dopsf=False,
)
simple_vis(dirty_graph)

Now compute the dirty image. At this scale, the FFTs are so cheap that we can make the graph with many more w-slices than we need. There is little immediate overhead in using too many slices, but do not try to make a diagram of it!

In [None]:
# Deliberately far more w-slices than needed: cheap to run, but far too large to draw.
dirty_graph = create_invert_wstack_graph(vis_graph_list, model_graph,
                                         vis_slices=1000, dopsf=False)

In [None]:
future = c.compute(dirty_graph)   # submit the invert graph to the cluster
dirty, sumwt = future.result()    # block until the (image, sumwt) pair is available

# Quality-assessment summary, then the image itself.
print(qa_image(dirty, context='Dirty image'))
show_image(dirty, title='Dirty')
plt.show()

In [None]:
# Shut down the Dask client created by get_dask_Client(); run this cell last.
c.shutdown()