# Dask imaging demonstration

This notebook explores the use of dask for parallelisation.

In [None]:
%matplotlib inline

import os
import sys

from dask import delayed

sys.path.append(os.path.join('..', '..'))

results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

from matplotlib import pylab

pylab.rcParams['figure.figsize'] = (12.0, 12.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.data.polarisation import PolarisationFrame
from arl.data.parameters import get_parameter
from arl.visibility.operations import create_visibility, create_visibility_from_rows, copy_visibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.deconvolution import deconvolve_cube
from arl.image.operations import show_image, export_image_to_fits
from arl.image.iterators import raster_iter
from arl.visibility.iterators import vis_timeslice_iter, vis_wslice_iter
from arl.util.testing_support import create_named_configuration
from arl.fourier_transforms.ftprocessor import predict_2d, invert_2d, create_image_from_visibility, \
    predict_skycomponent_visibility, residual_image, invert_timeslice_single, invert_wslice_single, \
    predict_timeslice_single, predict_wslice_single,residual_image, advise_wide_field

import logging

# Send every log record (DEBUG and above) from the root logger to stdout so it shows up
# in the notebook output.
log = logging.getLogger()
log.setLevel(logging.DEBUG)
# Re-running this cell must not stack another stdout handler on the root logger,
# otherwise each message is printed once per execution of the cell.
if not any(isinstance(handler, logging.StreamHandler) and getattr(handler, 'stream', None) is sys.stdout
           for handler in log.handlers):
    log.addHandler(logging.StreamHandler(sys.stdout))

Construct the SKA1-LOW core configuration

In [None]:
# Look up the configuration registered under the name 'LOWBD2-CORE'; it is passed to
# create_visibility below.
lowcore = create_named_configuration('LOWBD2-CORE')

We create the visibility. 

This just makes the uvw, time, antenna1, antenna2 and weight columns in a table.

In [None]:
# Thirteen evenly spaced samples over [-3, +3], scaled by pi / 12.
sample_scale = numpy.pi / 12.0
times = numpy.linspace(-3, +3, 13) * sample_scale

# A single channel: 1e8 with a bandwidth of 1e7.
frequency = numpy.array([1e8])
channel_bandwidth = numpy.array([1e7])
reffrequency = numpy.max(frequency)

phasecentre = SkyCoord(ra=+15.0 * u.deg,
                       dec=-45.0 * u.deg,
                       frame='icrs',
                       equinox=2000.0)

# Build the (empty) visibility table for the configuration, unit weights, Stokes I only.
vt = create_visibility(lowcore,
                       times,
                       frequency,
                       channel_bandwidth=channel_bandwidth,
                       weight=1.0,
                       phasecentre=phasecentre,
                       polarisation_frame=PolarisationFrame("stokesI"))

Advise on wide field parameters. This returns a dictionary with all the input and calculated variables.

In [None]:
# Keep the result of advise_wide_field for the visibility so it can be inspected later.
advice = advise_wide_field(vt)

Create a grid of components and predict each in turn, using the full phase term including w.

In [None]:
# Imaging parameters gathered in one dictionary (the calls below use the individual variables).
params = dict(npixel=512,
              cellsize=0.001,
              spectral_mode='channel',
              channel_bandwidth=5e7,
              reffrequency=1e8,
              kernel='calculated',
              facets=4)

npixel = 512
cellsize = 0.001
facets = 4
flux = numpy.array([[100.0]])

# Template image; its WCS is used to turn pixel positions into sky directions.
model = create_image_from_visibility(vt, npixel=npixel, cellsize=cellsize, npol=1)

spacing_pixels = npixel // facets
log.info('Spacing in pixels = %s' % spacing_pixels)
spacing = 180.0 * cellsize * spacing_pixels / numpy.pi
centers = -1.5, -0.5, +0.5, +1.5


def grid_pixel(offset):
    """Return the pixel index `offset` grid spacings away from npixel // 2 (shifted by -1)."""
    return int(round(npixel // 2 + offset * spacing_pixels - 1))


# One component per grid point: outer loop over the dec pixel, inner loop over the ra pixel.
# NOTE(review): the components appear to accumulate into vt in place, so re-running this cell
# presumably adds them a second time -- confirm.
for pdec in map(grid_pixel, centers):
    for pra in map(grid_pixel, centers):
        sc = pixel_to_skycoord(pra, pdec, model.wcs)
        log.info("Component at (%f, %f) %s" % (pra, pdec, str(sc)))
        comp = create_skycomponent(flux=flux,
                                   frequency=frequency,
                                   direction=sc,
                                   polarisation_frame=PolarisationFrame("stokesI"))
        predict_skycomponent_visibility(vt, comp)

Define a Dask-enabled invert that looks like invert_2d but takes two additional arguments: the invert function to apply to a single chunk, and the iterator. The iterator is used to split the visibility up into chunks before the single-chunk invert is called on each of them.

In [None]:
def invert_dask(vt, model, dopsf=False, normalize=True, invert_single=invert_2d, iterator=vis_timeslice_iter,
                **kwargs):
    """Build a dask graph that inverts a visibility chunk by chunk and sums the resulting images.

    Nothing is computed here; call ``.compute()`` on the returned object to run the graph.

    :param vt: Visibility to invert
    :param model: Image passed to invert_single for every chunk
    :param dopsf: Forwarded to invert_single as ``dopsf``
    :param normalize: If True, the summed image is divided by the summed weight
    :param invert_single: Invert function applied to one chunk; each call is made with
        ``normalize=False`` and must return ``(image, sumwt)``
    :param iterator: Iterator yielding the row selections that split vt into chunks
    :param kwargs: Forwarded to both iterator and invert_single
    :returns: dask delayed object whose result is ``(image, sumwt)``
    :raises ValueError: (at compute time) if the iterator produced no chunks
    """

    def accumulate_results(results, normalize=normalize):
        """Sum the per-chunk images and weights, optionally normalising by the total weight."""
        if len(results) == 0:
            # Without any chunk there is no image to return.
            raise ValueError("invert_dask: the iterator produced no visibility chunks to invert")

        # The first chunk's image is the accumulator; the remaining chunks are added into it in place.
        acc = results[0][0]
        sumwt = results[0][1]
        for result in results[1:]:
            acc.data += result[0].data
            sumwt += result[1]

        if normalize:
            acc.data /= float(sumwt)

        return acc, sumwt

    results = list()

    for rows in iterator(vt, **kwargs):
        # Each task works on its own copy of the selected rows.
        v = copy_visibility(create_visibility_from_rows(vt, rows))
        result = delayed(invert_single, pure=True)(v, model, dopsf=dopsf, normalize=False, **kwargs)
        results.append(result)

    return delayed(accumulate_results, pure=True)(results, normalize)

In [None]:
# Build (but do not yet run) the graph: invert_2d on each time slice, summed without normalisation.
dirty_2d_dask = invert_dask(vt, model, False, invert_single=invert_2d, iterator=vis_timeslice_iter, normalize=False)


Now we can execute the graph:

In [None]:
# Run the graph, then display, summarise and save the summed image.
dirty_2d, sumwt_2d = dirty_2d_dask.compute()
show_image(dirty_2d)

dirty_2d_max = dirty_2d.data.max()
dirty_2d_min = dirty_2d.data.min()
print("Max, min in dirty image = %.6f, %.6f, sumwt = %s" % (dirty_2d_max, dirty_2d_min, sumwt_2d))

dirty_2d_fits = '%s/imaging-dask_2d.fits' % (results_dir)
export_image_to_fits(dirty_2d, dirty_2d_fits)

Now we do the same thing but with an improved invert that works on w slices (invert_wslice_single together with vis_wslice_iter).

In [None]:
# Same graph construction, now splitting the visibility with vis_wslice_iter and inverting each
# piece with invert_wslice_single; wslice=10.0 is forwarded through **kwargs to both of them.
dirty_wslice_dask = invert_dask(vt, model, False, invert_single=invert_wslice_single, iterator=vis_wslice_iter,
                           normalize=False, wslice=10.0)

In [None]:
# Run the w-slice graph, then display, summarise and save the summed image.
dirty_wslice, sumwt_wslice = dirty_wslice_dask.compute()
show_image(dirty_wslice)

dirty_wslice_max = dirty_wslice.data.max()
dirty_wslice_min = dirty_wslice.data.min()
print("Max, min in dirty image = %.6f, %.6f, sumwt = %s" % (dirty_wslice_max, dirty_wslice_min, sumwt_wslice))

dirty_wslice_fits = '%s/imaging-dask_wslice.fits' % (results_dir)
export_image_to_fits(dirty_wslice, dirty_wslice_fits)

Now do timeslicing

In [None]:
# Graph for the timeslice invert: vis_timeslice_iter splits the visibility and
# invert_timeslice_single handles each slice; the sum is left unnormalised.
dirty_timeslice_dask = invert_dask(vt, model, False, invert_single=invert_timeslice_single, 
                                   iterator=vis_timeslice_iter, normalize=False)

In [None]:
# Run the timeslice graph, then display, summarise and save the summed image.
dirty_timeslice, sumwt_timeslice = dirty_timeslice_dask.compute()
show_image(dirty_timeslice)

dirty_timeslice_max = dirty_timeslice.data.max()
dirty_timeslice_min = dirty_timeslice.data.min()
print("Max, min in dirty image = %.6f, %.6f, sumwt = %s" % (dirty_timeslice_max, dirty_timeslice_min,
                                                            sumwt_timeslice))

dirty_timeslice_fits = '%s/imaging-dask_timeslice.fits' % (results_dir)
export_image_to_fits(dirty_timeslice, dirty_timeslice_fits)

Now do the same for the predict function

In [None]:
def predict_dask(vt, model, predict_single=predict_timeslice_single, iterator=vis_timeslice_iter, **kwargs):
    """Build a dask graph that predicts a visibility chunk by chunk and adds the result into vt.

    Nothing is computed here; call ``.compute()`` on the returned object to run the graph.

    :param vt: Visibility to predict for; its 'vis' column is updated in place at compute time
    :param model: Model image passed to predict_single for every chunk
    :param predict_single: Predict function applied to a copy of one chunk; must return a visibility
    :param iterator: Iterator yielding the row selections that split vt into chunks
    :param kwargs: Forwarded to both iterator and predict_single
    :returns: dask delayed object whose result is vt itself
    """

    def accumulate_results(results):
        """Add each chunk's predicted 'vis' values onto the matching rows of vt."""
        # The iterator is replayed so that results[i] lines up with the rows it was computed from.
        # NOTE(review): the prediction is added (+=) to whatever vt already holds -- confirm that
        # this is intended rather than an overwrite.
        for i, rows in enumerate(iterator(vt, **kwargs)):
            vt.data['vis'][rows] += results[i].data['vis']

        return vt

    results = list()

    for rows in iterator(vt, **kwargs):
        # Each task works on its own copy of the selected rows.
        visslice = copy_visibility(create_visibility_from_rows(vt, rows))
        result = delayed(predict_single, pure=True)(visslice, model, **kwargs)
        results.append(result)

    return delayed(accumulate_results, pure=True)(results)

In [None]:
# Predict into a copy: predict_dask accumulates into the visibility it is given, so vt stays untouched.
vtpred = copy_visibility(vt)
# Build (but do not yet run) the timeslice predict graph.
predict_timeslice_dask = predict_dask(vtpred, model, predict_single=predict_timeslice_single, 
                                      iterator=vis_timeslice_iter)

Execute the graph

In [None]:
# Run the graph; the accumulator in predict_dask returns the visibility it updated in place.
vtpred = predict_timeslice_dask.compute()

Now we will turn to major/minor cycle cleaning. In this case, it turns out that we would benefit from the residual_image function since it does predict/invert on a single chunk rather than doing predict on all chunks and then invert on all chunks.

First we need the corresponding dask function:

In [None]:
def residual_dask(vis, visres, model, iterator=vis_timeslice_iter, **kwargs):
    """Build a dask graph that runs residual_image chunk by chunk and combines the results.

    Nothing is computed here; call ``.compute()`` on the returned object to run the graph.

    :param vis: Visibility that is split into chunks and passed to residual_image
    :param visres: Visibility whose 'vis' column is overwritten in place, chunk by chunk, with
        the first element of each residual_image result
    :param model: Model image passed to residual_image for every chunk
    :param iterator: Iterator yielding the row selections used to split vis and visres
    :param kwargs: Forwarded to both iterator and residual_image
    :returns: dask delayed object whose result is ``(visres, image, sumwt)``; the image is the
        sum of the per-chunk images divided by the summed weight
    :raises ValueError: (at compute time) if the iterator produced no chunks
    """

    def accumulate_results(results):
        """Store the per-chunk residual visibilities in visres and sum the residual images."""
        if len(results) == 0:
            # Without any chunk there is no image to normalise or return.
            raise ValueError("residual_dask: the iterator produced no visibility chunks")

        # Replay the iterator on visres so that results[i] lines up with its rows.
        for i, rows in enumerate(iterator(visres, **kwargs)):
            visres.data['vis'][rows] = results[i][0].data['vis']

        # The first chunk's image is the accumulator; the remaining chunks are added into it in place.
        acc = results[0][1]
        sumwt = results[0][2]
        for result in results[1:]:
            acc.data += result[1].data
            sumwt += result[2]

        acc.data /= float(sumwt)

        return visres, acc, sumwt

    results = list()

    for rows in iterator(vis, **kwargs):
        # Each task works on its own copy of the selected rows.
        visslice = copy_visibility(create_visibility_from_rows(vis, rows))
        result = delayed(residual_image, pure=True)(visslice, model, normalize=False, **kwargs)
        results.append(result)

    return delayed(accumulate_results, pure=True)(results)

We will use the timeslice functions.

In [None]:
# Build the residual graph with the timeslice functions; predict_residual and invert_residual
# travel through **kwargs to the iterator and to residual_image.
residual_timeslice_dask = residual_dask(vt, vtpred, model, 
                                        predict_residual=predict_timeslice_single, 
                                        invert_residual=invert_timeslice_single,
                                        iterator=vis_timeslice_iter)
# Show the task graph instead of executing it.
residual_timeslice_dask.visualize()

Finally we make a version of solve_image adapted to this approach

In [None]:
def solve_image_dask(vis, model, components=None, residual=residual_dask, invert=invert_dask, **kwargs):
    """Solve for image using deconvolve_cube and dask-built residual and invert graphs

    This is the same as a majorcycle/minorcycle algorithm. The components are removed prior to deconvolution.

    See also arguments for residual, invert, deconvolve_cube functions.

    :param vis: Observed visibility
    :param model: Model image; the clean components of every major cycle are added to it in place
    :param components: Optional sky component(s) whose predicted visibility is subtracted from the
        observed visibility before imaging
    :param residual: Function building the dask residual graph e.g. residual_dask
    :param invert: Function building the dask invert graph e.g. invert_dask
    :param kwargs: 'nmajor' (default 5) and 'threshold' (default 0.0) are read here; all keyword
        arguments are also forwarded to residual, invert and deconvolve_cube
    :returns: residual visibility, model, residual (dirty) image
    """
    nmajor = get_parameter(kwargs, 'nmajor', 5)
    log.info("solve_image: Performing %d major cycles" % nmajor)

    if components is not None:
        # Predict the components into a zeroed copy and subtract them from a copy of the data,
        # so the caller's visibility is left alone. The original code referenced an undefined
        # 'vispred' here and failed whenever components were supplied.
        # NOTE(review): relies on predict_skycomponent_visibility adding into the visibility it
        # is given, as the component-grid cell of this notebook does -- confirm.
        compvis = copy_visibility(vis)
        compvis.data['vis'][...] = 0.0
        predict_skycomponent_visibility(compvis, components)
        vis = copy_visibility(vis)
        vis.data['vis'] -= compvis.data['vis']

    # The model is added to each major cycle and then the visibilities are
    # calculated from the full model
    visres = copy_visibility(vis)
    visres.data['vis'][...] = 0.0

    dask_residual = residual(vis, visres, model, **kwargs)
    visres, dirty, sumwt = dask_residual.compute()

    dask_psf = invert(visres, model, dopsf=True, **kwargs)
    psf, psf_sumwt = dask_psf.compute()

    thresh = get_parameter(kwargs, "threshold", 0.0)

    for i in range(nmajor):
        log.info("solve_image: Start of major cycle %d" % i)
        # Minor cycles: only the clean components are used, the second return value is ignored.
        cc, _ = deconvolve_cube(dirty, psf, **kwargs)
        model.data += cc.data
        # Major cycle: recompute the residual against the updated model.
        dask_residual = residual(vis, visres, model, **kwargs)
        visres, dirty, sumwt = dask_residual.compute()
        if numpy.abs(dirty.data).max() < 1.1 * thresh:
            log.info("Reached stopping threshold %.6f Jy" % thresh)
            break
        log.info("solve_image: End of major cycle")

    log.info("solve_image: End of major cycles")
    return visres, model, dirty

Now we can solve for the image

In [None]:
# Scale the model pixels by zero in place, clearing whatever earlier cells left in it.
model.data*=0.0
# Run three major cycles with the timeslice functions; every keyword other than model and invert
# is collected in **kwargs and forwarded by solve_image_dask to residual, invert and deconvolve_cube.
visres, model, residual = solve_image_dask(vt, model=model, invert=invert_dask, 
                                           invert_residual=invert_timeslice_single, 
                                           predict_residual=predict_timeslice_single, 
                                           iterator=vis_timeslice_iter, algorithm='hogbom',
                                           niter=1000, fractional_threshold=0.1,
                                           threshold=1.0, nmajor=3, gain=0.1)

In [None]:
# Display the residual image returned as the third value of solve_image_dask.
show_image(residual)