# Dask imaging demonstration

This notebook explores the use of dask for parallelisation. We work through the steps of imaging using dask, ending with a major/minor cycle algorithm, first without and then with self-calibration.

The graph-building functions used are imported from `arl.graphs.dask_graphs`, `arl.util.dask_graph_support` and `arl.pipelines.pipeline_dask_graphs`.

In [None]:
%matplotlib inline

import os
import sys

# Output directory for the FITS images exported later in this notebook.
# exist_ok=True makes the call a no-op when the directory is already there.
results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

from functools import partial
from dask import delayed
from distributed import progress
import dask.bag as bag

# Add the directory two levels up to the import path
# (presumably so that the `arl` imports below resolve -- verify against the repo layout).
sys.path.append(os.path.join('..', '..'))

from matplotlib import pylab

# Notebook-wide plotting defaults: 12x12 figures and the 'rainbow' colour map for images.
pylab.rcParams['figure.figsize'] = (12.0, 12.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.calibration.solvers import solve_gaintable
from arl.calibration.operations import apply_gaintable, create_gaintable_from_blockvisibility
from arl.data.data_models import Image, BlockVisibility, Visibility
from arl.data.polarisation import PolarisationFrame
from arl.data.parameters import get_parameter
from arl.visibility.operations import create_blockvisibility, create_visibility_from_rows, \
    copy_visibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.deconvolution import deconvolve_cube, restore_cube
from arl.image.operations import show_image, export_image_to_fits, qa_image, copy_image, create_empty_image_like
from arl.image.gather_scatter import image_gather, image_scatter
from arl.image.iterators import raster_iter
from arl.visibility.iterators import vis_timeslice_iter
from arl.util.testing_support import create_named_configuration, simulate_gaintable, \
    create_low_test_image_from_gleam, create_low_test_beam
from arl.fourier_transforms.ftprocessor import predict_2d, invert_2d, predict_timeslice, invert_timeslice, \
    normalize_sumwt, create_image_from_visibility, \
    predict_skycomponent_blockvisibility, residual_image, invert_timeslice_single, \
    predict_timeslice_single, predict_timeslice_single, advise_wide_field
from arl.graphs.dask_init import get_dask_Client, kill_dask_Client
from arl.graphs.dask_graphs import create_invert_wstack_graph, create_deconvolve_facet_graph

from arl.graphs.generic_dask_graphs import create_generic_image_graph
from arl.util.dask_graph_support import create_simulate_vis_graph, \
    create_predict_gleam_model_graph, create_corrupt_vis_graph, \
    create_dump_vis_graph, create_load_vis_graph
from arl.pipelines.pipeline_dask_graphs import create_continuum_imaging_pipeline_graph, \
    create_ical_pipeline_graph    
from arl.graphs.vis import simple_vis

import logging

# Configure the root logger to emit DEBUG-and-above records on stdout.
log = logging.getLogger()
log.setLevel(logging.DEBUG)
# Re-running this cell must not stack up handlers, otherwise every message is printed
# once per execution; attach a stdout handler only if one is not already present.
if not any(isinstance(handler, logging.StreamHandler)
           and getattr(handler, 'stream', None) is sys.stdout
           for handler in log.handlers):
    log.addHandler(logging.StreamHandler(sys.stdout))

In [None]:
# Obtain the dask client that the later cells use to execute graphs (c.compute) and
# print it so the connection details are visible in the notebook output.
c=get_dask_Client()
print(c)

We create a list of graphs to simulate the visibilities.

In [None]:
# Simulation set-up: nfreqwin frequency windows spanning 0.8e8..1.2e8 (presumably Hz) and
# ntimes samples spanning -pi/3..+pi/3 (presumably hour angle in radians -- TODO confirm).
nfreqwin = 7
ntimes = 11
frequency = numpy.linspace(0.8e8, 1.2e8, nfreqwin)
# Every window gets the same bandwidth: the spacing between adjacent frequencies.
channel_bandwidth = numpy.array([frequency[1] - frequency[0]] * nfreqwin)
times = numpy.linspace(-numpy.pi / 3.0, numpy.pi / 3.0, ntimes)
phasecentre = SkyCoord(ra=30.0 * u.deg, dec=-60.0 * u.deg, frame='icrs', equinox=2000.0)

# Build the list of graphs that simulate the visibilities for this set-up.
vis_graph_list = create_simulate_vis_graph(
    frequency=frequency,
    channel_bandwidth=channel_bandwidth,
    times=times,
    phasecentre=phasecentre,
)

Now make graphs that fill the visibilities with a model drawn from GLEAM.

In [None]:
# Ask advise_wide_field for imaging parameters, based on the first visibility set
# (which is computed eagerly here), then pull out the values used by the later cells.
advice = advise_wide_field(
    vis_graph_list[0].compute(),
    guard_band_image=4.0,
    delA=0.2,
    wprojection_planes=1,
)
wstep, vis_slices = advice['w_sampling_primary_beam'], advice['vis_slices']
wprojection_planes, kernel = advice['wprojection_planes'], advice['kernel']

In [None]:
# Graphs that fill the simulated visibilities with the GLEAM model prediction.
predicted_vis_graph_list = create_predict_gleam_model_graph(vis_graph_list,
                                                            vis_slices=vis_slices,
                                                            wstep=wstep,
                                                            kernel=kernel)

# The phase-corrupted visibilities (phase_error=1.0) are built in the self-calibration
# cell further down, immediately before they are used, so they are not built here too.

# Draw the task graph of the first predicted visibility set.
simple_vis(predicted_vis_graph_list[0])

Get the local sky model (LSM). This is currently a blank image.

In [None]:
def get_LSM(vt, npixel = 512, cellsize=0.001, reffrequency=None):
    """Create the model image (LSM) matching a visibility set.

    The image is built by create_image_from_visibility with a single polarisation
    (npol=1) in the "stokesI" polarisation frame.

    :param vt: Visibility passed through to create_image_from_visibility
    :param npixel: Number of pixels, passed through as ``npixel``
    :param cellsize: Cell size, passed through as ``cellsize`` (presumably radians -- TODO confirm)
    :param reffrequency: List of reference frequencies, passed through as ``frequency``;
        defaults to ``[1e8]`` when omitted
    :return: The image returned by create_image_from_visibility
    """
    # Use a None sentinel instead of a mutable list default so that every call gets its
    # own fresh list and nothing downstream can alter the default for later calls.
    if reffrequency is None:
        reffrequency = [1e8]
    model = create_image_from_visibility(vt, npixel=npixel, cellsize=cellsize,
                                         npol=1, frequency=reffrequency,
                                         polarisation_frame=PolarisationFrame("stokesI"))
    return model

# Lazily build the model image from the visibility graph in the middle of the list.
_middle = len(vis_graph_list) // 2
model_graph = delayed(get_LSM)(vis_graph_list[_middle])

Create a graph to make the dirty image.

In [None]:
# Graph that inverts the predicted visibilities onto the model image grid;
# dopsf=False requests the dirty image rather than the PSF.
dirty_graph = create_invert_wstack_graph(
    predicted_vis_graph_list,
    model_graph,
    vis_slices=vis_slices,
    wstep=wstep,
    kernel=kernel,
    dopsf=False,
)
# Draw the task graph before executing it.
simple_vis(dirty_graph)

In [None]:
# Execute the dirty-image graph and display the first element of its result
# (presumably the image itself, with the second element being the sum of weights -- TODO confirm).
dirty=dirty_graph.compute()
show_image(dirty[0])
plt.show()

In [None]:
# Major/minor cycle continuum imaging without self-calibration: w-stack inversion,
# single-facet deconvolution with the 'hogbom' algorithm, five major cycles.
continuum_imaging_graph = create_continuum_imaging_pipeline_graph(
    predicted_vis_graph_list,
    model_graph=model_graph,
    c_deconvolve_graph=create_deconvolve_facet_graph,
    facets=1,
    c_invert_graph=create_invert_wstack_graph,
    vis_slices=vis_slices,
    wstep=wstep,
    kernel=kernel,
    algorithm='hogbom',
    niter=1000,
    fractional_threshold=0.1,
    threshold=0.1,
    nmajor=5,
    gain=0.1,
)

In [None]:
# Submit the continuum imaging graph to the dask client; the returned future is
# resolved with future.result() in the next cell.
future=c.compute(continuum_imaging_graph)

In [None]:
# Fetch the pipeline output once and reuse it: every Future.result() call gathers the
# complete result to the local process again.
continuum_result = future.result()
deconvolved = continuum_result[0]
residual = continuum_result[1]
restored = continuum_result[2]

# Deconvolved (clean component) image: show it and print its quality assessment.
f = show_image(deconvolved, title='Clean image - no selfcal')
print(qa_image(deconvolved, context='Clean image - no selfcal'))
plt.show()

# Restored image: show, assess and export to FITS.
f = show_image(restored, title='Restored clean image - no selfcal')
print(qa_image(restored, context='Restored clean image - no selfcal'))
plt.show()
export_image_to_fits(restored, '%s/imaging-dask_continuum_imaging_restored.fits'
                     % (results_dir))

# Residual: only the first element of the residual result is an image to display
# (presumably the second is the sum of weights -- TODO confirm).
f = show_image(residual[0], title='Residual clean image - no selfcal')
print(qa_image(residual[0], context='Residual clean image - no selfcal'))
plt.show()
export_image_to_fits(residual[0], '%s/imaging-dask_continuum_imaging_residual.fits'
                     % (results_dir))

In [None]:
# Same GLEAM prediction as before, but with a phase error applied, to give the
# self-calibration something to correct.
corrupted_vis_graph_list = create_predict_gleam_model_graph(
    vis_graph_list,
    vis_slices=vis_slices,
    wstep=wstep,
    kernel=kernel,
    phase_error=1.0,
)

# ICAL pipeline: the same imaging settings as the continuum pipeline, run on the
# corrupted visibilities with first_selfcal=1.
ical_graph = create_ical_pipeline_graph(
    corrupted_vis_graph_list,
    model_graph=model_graph,
    c_deconvolve_graph=create_deconvolve_facet_graph,
    c_invert_graph=create_invert_wstack_graph,
    vis_slices=vis_slices,
    wstep=wstep,
    kernel=kernel,
    algorithm='hogbom',
    niter=1000,
    fractional_threshold=0.1,
    threshold=0.1,
    nmajor=5,
    gain=0.1,
    first_selfcal=1,
)

In [None]:
# Submit the ICAL graph to the dask client; the returned future is resolved with
# future.result() in the next cell.
future=c.compute(ical_graph)

In [None]:
# Fetch the pipeline output once and reuse it: every Future.result() call gathers the
# complete result to the local process again.
ical_result = future.result()
deconvolved = ical_result[0]
residual = ical_result[1]
restored = ical_result[2]

# Deconvolved (clean component) image: show it and print its quality assessment.
f = show_image(deconvolved, title='Clean image')
print(qa_image(deconvolved, context='Clean image'))
plt.show()

# Restored image: show, assess and export to FITS.
f = show_image(restored, title='Restored clean image')
print(qa_image(restored, context='Restored clean image'))
plt.show()
export_image_to_fits(restored, '%s/imaging-dask_ical_restored.fits'
                     % (results_dir))

# Residual: only the first element of the residual result is an image to display
# (presumably the second is the sum of weights -- TODO confirm).
f = show_image(residual[0], title='Residual clean image')
print(qa_image(residual[0], context='Residual clean image'))
plt.show()
export_image_to_fits(residual[0], '%s/imaging-dask_ical_residual.fits'
                     % (results_dir))

In [None]:
# Shut down the dask client obtained from get_dask_Client() at the start of the notebook.
c.shutdown()