# ICAL pipeline processing using Dask workflows.

In [None]:
%matplotlib inline

import os
import sys

# Make the ARL top-level packages (data_models, wrappers, workflows) importable.
# This has to run before any of those packages is imported below.
sys.path.append(os.path.join(os.pardir, os.pardir))

from data_models.parameters import arl_path

# Directory into which the log file and FITS products of this notebook are written.
results_dir = arl_path('test_results')

from matplotlib import pylab

pylab.rcParams.update({'figure.figsize': (12.0, 12.0), 'image.cmap': 'rainbow'})

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from data_models.polarisation import PolarisationFrame

from wrappers.serial.calibration.calibration import solve_gaintable
from wrappers.serial.calibration.operations import apply_gaintable
from wrappers.serial.calibration.calibration_control import create_calibration_controls
from wrappers.serial.visibility.base import create_blockvisibility
from wrappers.serial.skycomponent.operations import create_skycomponent, filter_skycomponents_by_flux
from wrappers.serial.image.deconvolution import deconvolve_cube
from wrappers.serial.image.operations import show_image, export_image_to_fits, qa_image
from wrappers.serial.visibility.iterators import vis_timeslice_iter
from wrappers.serial.simulation.testing_support import create_named_configuration, create_low_test_image_from_gleam
from wrappers.serial.imaging.base import predict_2d, create_image_from_visibility, advise_wide_field

from workflows.arlexecute.imaging.imaging_arlexecute import invert_list_arlexecute_workflow, \
    predict_list_arlexecute_workflow, deconvolve_list_arlexecute_workflow, weight_list_arlexecute_workflow
from workflows.arlexecute.simulation.simulation_arlexecute import simulate_list_arlexecute_workflow, \
    corrupt_list_arlexecute_workflow
from workflows.arlexecute.pipelines.pipeline_arlexecute import continuum_imaging_list_arlexecute_workflow, \
    ical_list_arlexecute_workflow

from wrappers.arlexecute.execution_support.arlexecute import arlexecute
from wrappers.arlexecute.execution_support.dask_init import get_dask_Client



In [None]:
# Square figures and a greyscale colour map for all images shown below.
pylab.rcParams.update({'figure.figsize': (12.0, 12.0), 'image.cmap': 'Greys'})

### We will use Dask to distribute processing

In [None]:
# Seven Dask workers, each given a memory limit of 4 * 1024**3
# (4 GiB, assuming the limit is expressed in bytes).
client = get_dask_Client(n_workers=7, memory_limit=4 * 1024 ** 3)
arlexecute.set_client(client)


### All Dask workers log to a common file

In [None]:
import logging


def init_logging():
    """Configure the root logger to append INFO-level records to the pipeline log.

    The log file is ``imaging-pipeline.log`` inside ``results_dir``. The function
    is run in this (client) process and, via ``arlexecute.run``, on the Dask
    workers, so all of them append to the same file. ``logging.basicConfig``
    does nothing if the root logger of a process already has handlers.
    """
    logging.basicConfig(filename='%s/imaging-pipeline.log' % results_dir,
                        filemode='a',
                        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                        datefmt='%H:%M:%S',
                        level=logging.INFO)


# Configure the client process as well. Otherwise its root logger keeps the
# default WARNING level, and the INFO messages logged from this notebook are
# silently dropped.
init_logging()
log = logging.getLogger()
logging.info("Starting ARL demo")
arlexecute.run(init_logging)

### We create a graph to make the visibilities. The parameter rmax determines the distance of the furthest antennas/stations used. The image size and cell size used later are derived from the resulting visibilities (via advise_wide_field).

In [None]:
# Observation setup: nfreqwin channels spanning 0.9e8..1.1e8 (presumably Hz),
# and ntimes samples over -pi/3..+pi/3 (presumably hour angle in radians).
nfreqwin = 7
ntimes = 5
rmax = 750.0
frequency = numpy.linspace(0.9e8, 1.1e8, nfreqwin)
# Every channel gets the same width: the spacing between adjacent channels
# (requires nfreqwin >= 2).
channel_bandwidth = numpy.array(nfreqwin * [frequency[1] - frequency[0]])
times = numpy.linspace(-numpy.pi / 3.0, numpy.pi / 3.0, ntimes)
phasecentre = SkyCoord(ra=+30.0 * u.deg, dec=-60.0 * u.deg, frame='icrs', equinox='J2000')

vis_list = simulate_list_arlexecute_workflow('LOWBD2',
                                             frequency=frequency,
                                             channel_bandwidth=channel_bandwidth,
                                             times=times,
                                             phasecentre=phasecentre,
                                             order='frequency',
                                             rmax=rmax)
print('%d elements in vis_list' % len(vis_list))
vis_list = arlexecute.compute(vis_list, sync=True)
# numpy.product is deprecated (removed in NumPy 2.0); use numpy.prod. Sum over
# every element instead of assuming they all share the first element's shape.
print('Total number of visibilities = %d'
      % sum(numpy.prod(v.vis.shape) for v in vis_list))

In [None]:
import pprint

pp = pprint.PrettyPrinter()

# Ask ARL for wide-field imaging parameters, based on the last element of vis_list.
advice = advise_wide_field(vis_list[-1], guard_band_image=6.0, delA=0.05,
                           oversampling_synthesised_beam=4)
pp.pprint(advice)

# Timeslice imaging with vis_slices equal to the number of simulated times.
context = 'timeslice'
vis_slices = ntimes
npixel, cellsize = advice['npixels2'], advice['cellsize']

### Now make and compute a graph that fills a model image with sources drawn from GLEAM

In [None]:
# One lazily evaluated GLEAM model image per frequency channel.
gleam_model = [arlexecute.execute(create_low_test_image_from_gleam)(
                   npixel=npixel, frequency=[freq], channel_bandwidth=[bandwidth],
                   cellsize=cellsize, phasecentre=phasecentre,
                   polarisation_frame=PolarisationFrame("stokesI"),
                   flux_limit=0.3, applybeam=True, kind='cubic', flux_max=10.0)
               for freq, bandwidth in zip(frequency, channel_bandwidth)]
log.info('About to make GLEAM model')
gleam_model = arlexecute.compute(gleam_model, sync=True)
future_gleam_model = arlexecute.scatter(gleam_model)

### Now predict the visibilities from this model, and then apply some phase errors

In [None]:
# Scatter the simulated visibilities to the workers, then predict visibilities
# from the GLEAM model images using the imaging context chosen above.
future_vis_graph = arlexecute.scatter(vis_list)
predicted_vislist = arlexecute.compute(
    predict_list_arlexecute_workflow(future_vis_graph, future_gleam_model,
                                     context=context, vis_slices=vis_slices),
    sync=True)


In [None]:
# Apply phase errors (with a fixed seed), compute the result, and scatter the
# corrupted visibilities so the following cells can reuse them.
corrupted_vislist = arlexecute.compute(
    corrupt_list_arlexecute_workflow(predicted_vislist, phase_error=1.0, seed=180555),
    sync=True)
future_corrupted_vislist = arlexecute.scatter(corrupted_vislist)

### Get the LSM (local sky model). This is currently blank. The ARL style is to create template images (e.g. blank model images) and pass those into the workflows

In [None]:
# Template (currently blank) model images, one per frequency channel.
model_list = [arlexecute.execute(create_image_from_visibility)(
                  vis_list[chan],
                  npixel=npixel,
                  frequency=[freq],
                  channel_bandwidth=[bandwidth],
                  cellsize=cellsize,
                  phasecentre=phasecentre,
                  polarisation_frame=PolarisationFrame("stokesI"))
              for chan, (freq, bandwidth) in enumerate(zip(frequency, channel_bandwidth))]
future_model_graph = arlexecute.persist(model_list)

### Weight the data

In [None]:
# Weight the corrupted visibilities against the model template images (no
# weighting argument is passed, so the workflow's default scheme applies).
# The result is computed and stored back into ``corrupted_vislist`` because
# the later cells re-scatter that list. Keeping the weighting only as an
# unevaluated graph in ``future_corrupted_vislist`` would silently discard it.
corrupted_vislist = arlexecute.compute(
    weight_list_arlexecute_workflow(future_corrupted_vislist, future_model_graph),
    sync=True)
future_corrupted_vislist = arlexecute.scatter(corrupted_vislist)

### Create and execute graph to make the dirty image

In [None]:
future_corrupted_vislist = arlexecute.scatter(corrupted_vislist)
dirty_list = invert_list_arlexecute_workflow(future_corrupted_vislist, future_model_graph,
                                             context=context,
                                             vis_slices=vis_slices, dopsf=False)

dirty_list = arlexecute.compute(dirty_list, sync=True)
# Each element of dirty_list is a pair whose first item is the image. Show the
# centre channel so the dirty image is directly comparable with the ICAL
# restored/residual images, which are taken from channel nfreqwin // 2.
dirty = dirty_list[len(dirty_list) // 2][0]
show_image(dirty, cm='Greys', vmax=1.0, vmin=-0.1)
print(qa_image(dirty))
plt.show()

### Now we set up the ICAL pipeline calibration controls

In [None]:
controls = create_calibration_controls()

# Self-calibration iteration at which each term starts being solved, and its
# solution timescale. Only terms named in calibration_context are used by ICAL.
controls['T'].update(first_selfcal=1, timescale='auto')
controls['G'].update(first_selfcal=3, timescale='auto')
controls['B'].update(first_selfcal=4, timescale=1e5)

pp.pprint(controls)

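### ICAL with:
- timeslice imaging
- msmfs (mmclean) distributed clean (8 by 8 subimages overlapped by 16 pixels)
- selfcal for T (first iteration) and G (third iteration); B is configured to start at the fourth iteration, but it is not included in calibration_context='TG'

First make the graph.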

In [None]:
# Build (but do not yet run) the ICAL graph.
future_corrupted_vislist = arlexecute.scatter(corrupted_vislist)
ical_list = ical_list_arlexecute_workflow(
    future_corrupted_vislist,
    model_imagelist=model_list,
    # imaging
    context=context,
    vis_slices=vis_slices,
    # deconvolution (multi-scale, multi-frequency clean)
    scales=[0, 3, 10],
    algorithm='mmclean',
    nmoment=3,
    niter=1000,
    fractional_threshold=0.1,
    threshold=0.01,
    nmajor=5,
    gain=0.25,
    psf_support=128,
    # distributed deconvolution over overlapping sub-images
    deconvolve_facets=8,
    deconvolve_overlap=16,
    deconvolve_taper='tukey',
    # self-calibration: only the T and G terms are named in calibration_context
    timeslice='auto',
    global_solution=False,
    do_selfcal=True,
    calibration_context='TG',
    controls=controls)

### Now run the graph

In [None]:
log.info('About to run ical')
result = arlexecute.compute(ical_list, sync=True)

# Inspect and export the centre frequency channel.
centre = nfreqwin // 2
residual = result[1][centre]  # a pair whose first item is the residual image
restored = result[2][centre]

f = show_image(restored, title='Restored clean image', cm='Greys',
               vmax=1.0, vmin=-0.1)
print(qa_image(restored, context='Restored clean image'))
plt.show()
f = show_image(restored, title='Restored clean image (10x deeper)', cm='Greys',
               vmax=0.1, vmin=-0.01)
plt.show()
export_image_to_fits(restored, '%s/imaging-dask_ical_restored.fits' % results_dir)

f = show_image(residual[0], title='Residual clean image (10x deeper)', cm='Greys',
               vmax=0.1, vmin=-0.01)
print(qa_image(residual[0], context='Residual clean image'))
plt.show()
export_image_to_fits(residual[0], '%s/imaging-dask_ical_residual.fits' % results_dir)