# Dask imaging demonstration with progress bars

This notebook explores the use of Dask for parallelisation. We work through the steps of imaging and graph generation, ending with a major/minor cycle deconvolution algorithm executed with Dask. When run interactively, progress bars show the state of each computation.

In [None]:
%matplotlib inline

import os
import sys

from distributed import Client, progress
from dask import delayed
import dask

sys.path.append(os.path.join('..', '..'))

results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

from matplotlib import pylab

pylab.rcParams['figure.figsize'] = (12.0, 12.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.data.polarisation import PolarisationFrame
from arl.data.parameters import get_parameter
from arl.visibility.operations import create_blockvisibility, create_visibility_from_rows, \
    copy_visibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.deconvolution import deconvolve_cube, restore_cube
from arl.image.operations import show_image, export_image_to_fits, qa_image
from arl.image.iterators import raster_iter
from arl.visibility.iterators import vis_timeslice_iter, vis_timeslice_iter
from arl.util.testing_support import create_named_configuration
from arl.fourier_transforms.ftprocessor import predict_2d, invert_2d, \
    create_image_from_visibility, predict_skycomponent_visibility, residual_image, \
    invert_timeslice_single, invert_timeslice_single, \
    predict_timeslice_single, predict_timeslice_single, advise_wide_field

from arl.fourier_transforms.ftprocessor import predict_2d, invert_2d, \
    create_image_from_visibility, \
    predict_skycomponent_visibility, residual_image, invert_timeslice_single, \
    predict_timeslice_single, advise_wide_field

from arl.pipelines.dask_graphs import *
from arl.pipelines.generic_dask_graphs import *

import logging

log = logging.getLogger()
log.setLevel(logging.INFO)
log.addHandler(logging.StreamHandler(sys.stdout))

Create the block visibility data for the LOWBD2-CORE configuration, observed in a single frequency channel.

In [None]:
# Antenna layout used for the observation.
lowcore = create_named_configuration('LOWBD2-CORE')

# 13 samples from -3 to +3 (presumably hours of hour angle), scaled by pi/12 into radians.
times = numpy.linspace(-3, +3, 13) * (numpy.pi / 12.0)

# A single channel (values presumably in Hz).
frequency = numpy.array([1e8])
channel_bandwidth = numpy.array([1e7])
reffrequency = numpy.max(frequency)

phasecentre = SkyCoord(ra=+15.0 * u.deg, dec=-45.0 * u.deg,
                       frame='icrs', equinox=2000.0)

# Unit-weighted Stokes I block visibility, phased to the centre defined above.
vt = create_blockvisibility(lowcore, times, frequency,
                            channel_bandwidth=channel_bandwidth,
                            weight=1.0,
                            phasecentre=phasecentre,
                            polarisation_frame=PolarisationFrame("stokesI"))

Create a regular grid of point components, one per facet. Then predict the visibility of the components, using the full phase term including w.

In [None]:
npixel = 512
cellsize = 0.001
flux = numpy.array([[100.0]])
facets = 4

model = create_image_from_visibility(vt, npixel=npixel, cellsize=cellsize, npol=1,
                                     polarisation_frame=PolarisationFrame("stokesI"))
# One component is placed at the centre of each facet of a facets x facets grid.
spacing_pixels = npixel // facets
log.info('Spacing in pixels = %s', spacing_pixels)
# Facet spacing converted to degrees (cellsize is treated as radians by this conversion).
spacing = 180.0 * cellsize * spacing_pixels / numpy.pi
# Facet-centre offsets in units of spacing_pixels, symmetric about the image centre
# (-1.5, -0.5, +0.5, +1.5 for facets = 4). They are derived from `facets` so the grid
# stays consistent if `facets` is changed.
centers = tuple(i - (facets - 1) / 2.0 for i in range(facets))
comps = list()
for iy in centers:
    for ix in centers:
        pra = int(round(npixel // 2 + ix * spacing_pixels - 1))
        pdec = int(round(npixel // 2 + iy * spacing_pixels - 1))
        sc = pixel_to_skycoord(pra, pdec, model.wcs)
        # Pixel positions are integers, so log them with %d rather than %f.
        log.info("Component at (%d, %d) %s", pra, pdec, sc)
        comps.append(create_skycomponent(flux=flux, frequency=frequency, direction=sc,
                                         polarisation_frame=PolarisationFrame("stokesI")))
# NOTE(review): predict_skycomponent_blockvisibility is not imported explicitly in the setup
# cell; presumably it comes from a wildcard import — TODO confirm. The result replaces vt,
# so re-running this cell starts from the already-predicted vt.
vt = predict_skycomponent_blockvisibility(vt, comps)

In [None]:
# Delayed model image using the same geometry and polarisation frame as `model` above.
# npixel and cellsize are reused so the two images cannot drift apart.
model_graph = delayed(create_image_from_visibility)(vt, npixel=npixel, cellsize=cellsize, npol=1,
                                                    polarisation_frame=PolarisationFrame("stokesI"))
# PSF graph (dopsf=True) built with the timeslice invert function and timeslice iterator.
# NOTE(review): the units of timeslice=10.0 are defined in arl.pipelines — presumably seconds.
psf_graph = create_invert_graph(vt, model_graph, dopsf=True,
                                invert_single=invert_timeslice_single,
                                iterator=vis_timeslice_iter,
                                normalize=False, timeslice=10.0)
psf_graph.visualize()
# Run the graph locally; the result unpacks into the PSF image and the sum of weights.
psf, sumwt = psf_graph.compute()

In [None]:
# Display the PSF computed by the graph above.
show_image(psf)
plt.show()

In [None]:
# Major/minor cycle deconvolution: Hogbom CLEAN minor cycles (up to niter=1000, gain 0.1)
# inside nmajor=3 major cycles, using the timeslice predict/invert functions.
solution_graph = create_solve_image_graph(vt, model_graph=model_graph,
                                          invert_residual=invert_timeslice_single,
                                          predict_residual=predict_timeslice_single,
                                          iterator=vis_timeslice_iter, algorithm='hogbom',
                                          niter=1000, fractional_threshold=0.1,
                                          threshold=1.0, nmajor=3, gain=0.1)
# Residual relative to the deconvolved model. Only this residual feeds the restore step,
# so no residual graph is built for the initial model_graph.
residual_timeslice_graph = create_residual_graph(vt, solution_graph,
                                                 predict_residual=predict_timeslice_single,
                                                 invert_residual=invert_timeslice_single,
                                                 iterator=vis_timeslice_iter)

# Combine the solution, PSF and final residual into the restored-image graph.
restore_graph = create_restore_graph(solution_graph, psf_graph, residual_timeslice_graph)
restore_graph.visualize()

Create a Dask `Client` to talk to a scheduler and a number of workers. With no arguments, the scheduler and workers are created automatically on the local machine.

In [None]:
# With no address given, Client() starts a local scheduler and workers for this session.
c = Client()
# Last expression: display the scheduler and worker details.
c.scheduler_info()

Now we can use the client to execute the restore graph; `progress` shows a progress bar while it runs.

In [None]:
# Submit the restore graph to the cluster; a future is returned straight away.
future = c.compute(restore_graph)
# Last expression: display a progress bar tracking the future.
progress(future)

`future.result()` blocks until the result is ready.

In [None]:
# Block until the computation finishes, then fetch the restored image.
restored_image = future.result()

In [None]:
# Display the final restored image produced by the cluster.
show_image(restored_image)
plt.show()