# ICAL pipeline processing using Dask workflows.

In [None]:
% matplotlib inline

import os
import sys

sys.path.append(os.path.join('..', '..'))

from data_models.parameters import arl_path

results_dir = arl_path('test_results')

from matplotlib import pylab

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from data_models.polarisation import PolarisationFrame

from wrappers.serial.calibration.calibration import solve_gaintable
from wrappers.serial.calibration.calibration_control import create_calibration_controls
from wrappers.serial.image.operations import show_image, export_image_to_fits, qa_image
from wrappers.serial.simulation.testing_support import create_low_test_skymodel_from_gleam
from wrappers.serial.imaging.base import create_image_from_visibility, advise_wide_field, \
    predict_skycomponent_visibility
from wrappers.arlexecute.image.gather_scatter import image_gather_channels
from wrappers.arlexecute.visibility.coalesce import convert_blockvisibility_to_visibility, \
    convert_visibility_to_blockvisibility

from workflows.arlexecute.imaging.imaging_arlexecute import invert_list_arlexecute_workflow, \
    weight_list_arlexecute_workflow, \
    taper_list_arlexecute_workflow, remove_sumwt

from workflows.arlexecute.simulation.simulation_arlexecute import simulate_list_arlexecute_workflow, \
    corrupt_list_arlexecute_workflow
from workflows.arlexecute.pipelines.pipeline_arlexecute import continuum_imaging_list_arlexecute_workflow, \
    ical_list_arlexecute_workflow
from workflows.arlexecute.skymodel.skymodel_arlexecute import predict_skymodel_list_arlexecute_workflow

from wrappers.arlexecute.execution_support.arlexecute import arlexecute
from wrappers.arlexecute.execution_support.dask_init import get_dask_Client


In [None]:
# Notebook-wide matplotlib defaults: 12 x 12 figures and the 'Greys' colour map.
pylab.rcParams.update({
    'figure.figsize': (12.0, 12.0),
    'image.cmap': 'Greys',
})

### We will use Dask to distribute processing

In [None]:
# Create a Dask client with 7 workers and a memory_limit of 4 GiB (in bytes),
# then register it with arlexecute.
client = get_dask_Client(n_workers=7, memory_limit=4 * 1024 ** 3)
arlexecute.set_client(client)

### All Dask workers log to a common file

In [None]:
import logging

def init_logging():
    """Configure the root logger to append INFO-level records to the shared log file.

    Records go to ``imaging-pipeline.log`` inside ``results_dir`` (append
    mode), formatted as ``HH:MM:SS,msecs name level message``.

    The function takes no arguments and returns nothing so that it can be
    handed as-is to ``arlexecute.run``. ``logging.basicConfig`` is a no-op
    when the root logger already has handlers, so calling this more than once
    in the same process is harmless.
    """
    logging.basicConfig(filename='%s/imaging-pipeline.log' % results_dir,
                        filemode='a',
                        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                        datefmt='%H:%M:%S',
                        level=logging.INFO)
# Root logger handle for this notebook process; a later cell logs through it.
log = logging.getLogger()
# NOTE(review): init_logging() is not called directly in this cell, so unless
# arlexecute.run also executes it in this process, the root logger here is
# still unconfigured and this INFO record will not reach the log file - confirm.
logging.info("Starting ARL demo")
# Hand init_logging to arlexecute to be run by the execution framework
# (presumably on every Dask worker - TODO confirm).
arlexecute.run(init_logging)

### We create a graph to make the visibility. The parameter rmax sets the distance of the furthest antenna/station used. All other parameters are determined from this number.

In [None]:
# Simulation set-up: 7 frequency windows between 0.9e8 and 1.1e8, 5 time
# samples between -pi/3 and +pi/3, and rmax = 750.0.
nfreqwin = 7
ntimes = 5
rmax = 750.0

frequency = numpy.linspace(0.9e8, 1.1e8, nfreqwin)
# Every window gets the same bandwidth: the spacing between adjacent windows.
channel_bandwidth = numpy.array([frequency[1] - frequency[0]] * nfreqwin)
times = numpy.linspace(-numpy.pi / 3.0, numpy.pi / 3.0, ntimes)
phasecentre = SkyCoord(ra=30.0 * u.deg, dec=-60.0 * u.deg,
                       frame='icrs', equinox='J2000')

# Build the simulation graph for the 'LOWBD2' configuration, then execute it
# and wait for the resulting list of block visibilities.
blockvis_list = simulate_list_arlexecute_workflow(
    'LOWBD2',
    rmax=rmax,
    frequency=frequency,
    channel_bandwidth=channel_bandwidth,
    times=times,
    phasecentre=phasecentre,
    order='frequency',
    format='blockvis',
)
blockvis_list = arlexecute.compute(blockvis_list, sync=True)
print('%d rows in block vis_list' % len(blockvis_list))

### Create row-oriented visibility

In [None]:
# Scatter the block visibilities, convert each one to a row-oriented
# visibility through arlexecute, and wait for the converted list.
blockvis_list = arlexecute.scatter(blockvis_list)
vis_list = arlexecute.compute(
    [arlexecute.execute(convert_blockvisibility_to_visibility)(block_vis)
     for block_vis in blockvis_list],
    sync=True,
)
print('%d rows in vis_list' % len(vis_list))

In [None]:
import pprint

pp = pprint.PrettyPrinter()

# Ask for wide-field imaging advice based on the last visibility in the list
# and show everything that comes back.
advice = advise_wide_field(
    vis_list[-1],
    guard_band_image=6.0,
    delA=0.1,
    oversampling_synthesised_beam=4,
)
pp.pprint(advice)

# Imaging parameters reused by the cells that follow.
context = 'wstack'
cellsize = advice['cellsize']
npixel = advice['npixels2']
vis_slices = advice['vis_slices']

### Now make and compute a graph to fill a skymodel drawn from GLEAM 

In [None]:
# One GLEAM-based sky model per frequency window, built through arlexecute.
gleam_skymodel = [
    arlexecute.execute(create_low_test_skymodel_from_gleam)(
        npixel=npixel,
        cellsize=cellsize,
        frequency=[window_frequency],
        phasecentre=phasecentre,
        polarisation_frame=PolarisationFrame("stokesI"),
        flux_limit=0.1,
        flux_threshold=1.0,
        flux_max=5.0,
    )
    for window_frequency in frequency
]
gleam_skymodel = arlexecute.compute(gleam_skymodel, sync=True)
# Scattered copy of the computed sky models.
future_gleam_skymodel = arlexecute.scatter(gleam_skymodel)


### Now predict the visibility from this model

In [None]:
# Scatter the row-oriented visibilities before building the prediction graph.
vis_list = arlexecute.scatter(vis_list)
# Build the graph that predicts visibilities from the sky models, using the
# imaging context and slice count chosen in the advice cell.
# NOTE(review): the computed gleam_skymodel list is passed here, not the
# scattered future_gleam_skymodel created right after the sky models were
# computed - confirm that is intended.
predicted_vislist = predict_skymodel_list_arlexecute_workflow(vis_list, gleam_skymodel, 
                                                              context=context, 
                                                              vis_slices=vis_slices,
                                                              facets=1)
# Execute the graph and wait for the predicted visibilities.
predicted_vislist = arlexecute.compute(predicted_vislist, sync=True)


### Get the local sky model (LSM). This is currently blank.

In [None]:
# One image per frequency window, created from the visibility at the same
# index; the list is computed and then scattered for use as the model images.
model_list = [
    arlexecute.execute(create_image_from_visibility)(
        vis_list[chan],
        npixel=npixel,
        cellsize=cellsize,
        phasecentre=phasecentre,
        frequency=[chan_frequency],
        channel_bandwidth=[channel_bandwidth[chan]],
        polarisation_frame=PolarisationFrame("stokesI"),
    )
    for chan, chan_frequency in enumerate(frequency)
]
model_list = arlexecute.scatter(arlexecute.compute(model_list, sync=True))

### Weight the data

In [None]:
# Scatter the predicted visibilities, keeping the result under its own name.
future_predicted_vislist = arlexecute.scatter(predicted_vislist)
# Build the weighting graph from the scattered visibilities and the model images.
predicted_vislist = weight_list_arlexecute_workflow(future_predicted_vislist, model_list)
# Chain the tapering graph onto the weighting graph.
# NOTE(review): 0.003 is the second positional argument of the taper workflow;
# its meaning and units are not visible here - confirm.
predicted_vislist = taper_list_arlexecute_workflow(predicted_vislist, 0.003)
# Execute both steps and wait for the weighted, tapered visibilities.
predicted_vislist = arlexecute.compute(predicted_vislist, sync=True)

### Add phase errors

In [None]:
# Convert every predicted visibility to block form, feed the list to the
# corruption workflow (phase_error=1.0, fixed seed), and wait for the result.
corrupted_blockvislist = corrupt_list_arlexecute_workflow(
    [arlexecute.execute(convert_visibility_to_blockvisibility, nout=1)(vis)
     for vis in predicted_vislist],
    phase_error=1.0,
    seed=180555,
)
corrupted_blockvislist = arlexecute.compute(corrupted_blockvislist, sync=True)

In [None]:
# Convert the corrupted block visibilities back to row-oriented visibilities
# and wait for the converted list.
corrupted_vislist = arlexecute.compute(
    [arlexecute.execute(convert_blockvisibility_to_visibility, nout=1)(block_vis)
     for block_vis in corrupted_blockvislist],
    sync=True,
)

### Create and execute graph to make the dirty image

In [None]:
# Scatter the corrupted visibilities and build the inversion graph against the
# model images; dopsf=False is passed, i.e. the PSF option is switched off.
future_corrupted_vislist = arlexecute.scatter(corrupted_vislist)
dirty_list = invert_list_arlexecute_workflow(future_corrupted_vislist, model_list,
                                                 context=context,
                                                 vis_slices=vis_slices, dopsf=False)

# Execute the graph and wait for the results.
dirty_list = arlexecute.compute(dirty_list, sync=True)
# Take element [0] of the first entry as the image to display
# (presumably each entry is an (image, sum-of-weights) pair - TODO confirm).
dirty = dirty_list[0][0]
# Show the image, print its QA summary, and render the figure.
show_image(dirty, cm='Greys')
print(qa_image(dirty))
plt.show()

### Now we set up the ICAL pipeline calibration controls

In [None]:
# Start from the default calibration controls, then set for the 'T' and 'G'
# terms the first selfcal iteration (1 and 3) and an automatic timescale.
controls = create_calibration_controls()

controls['T'].update(first_selfcal=1, timescale='auto')
controls['G'].update(first_selfcal=3, timescale='auto')

pp.pprint(controls)

### ICAL with:
- w-stack imaging (`context='wstack'`), with `timeslice='auto'`
- MSMFS distributed clean (8 by 8 subimages overlapped by 16 pixels)
- selfcal for T (first iteration), G (third iteration)

First make the graph.

In [None]:
# Build (but do not yet execute) the ICAL graph.
# NOTE(review): the computed corrupted_vislist is passed here rather than the
# scattered future_corrupted_vislist used for the dirty image - confirm.
ical_list = ical_list_arlexecute_workflow(corrupted_vislist, 
                                          model_imagelist=model_list,  
                                          # Imaging context and slice count from the advice cell.
                                          context=context, 
                                          vis_slices=vis_slices,
                                          # Clean settings: algorithm, scales, moments, iteration
                                          # count, thresholds, major cycles and loop gain.
                                          scales=[0, 3, 10], algorithm='mmclean', 
                                          nmoment=3, niter=1000, 
                                          fractional_threshold=0.1,
                                          threshold=0.01, nmajor=5, gain=0.25,
                                          psf_support=128,
                                          # Deconvolution is split into facets (8) that
                                          # overlap (16) with a 'tukey' taper.
                                          deconvolve_facets=8,
                                          deconvolve_overlap=16,
                                          deconvolve_taper='tukey',
                                          timeslice='auto',
                                          # Self-calibration: enabled, 'TG' context, using the
                                          # controls configured in the previous cell.
                                          global_solution=False, 
                                          do_selfcal=True,
                                          calibration_context = 'TG', 
                                          controls=controls,
                                          tol=1e-6)

### Now run the graph

In [None]:
# Execute the ICAL graph and wait for the result.
log.info('About to run ical')
result = arlexecute.compute(ical_list, sync=True)

# Use the middle frequency window for display; result[1] and result[2] hold
# the per-window residual and restored entries respectively.
centre = nfreqwin // 2
residual = result[1][centre]
restored = result[2][centre]

# Restored image with its QA summary.
f = show_image(restored, title='Restored clean image', cm='Greys',
               vmin=-0.1, vmax=1.0)
print(qa_image(restored, context='Restored clean image'))
plt.show()

# Same image with the display limits reduced by a factor of ten.
f = show_image(restored, title='Restored clean image (10x deeper)', cm='Greys',
               vmin=-0.01, vmax=0.1)
plt.show()

# Gather the restored entries of every window and export the result to FITS.
restored_cube = image_gather_channels(result[2])
export_image_to_fits(
    restored_cube,
    '%s/arl_demo_ical_%s_restored.fits' % (results_dir, context),
)

# Element [0] of the residual entry is the image that is shown and exported.
f = show_image(residual[0], title='Residual clean image (10x deeper)', cm='Greys',
               vmin=-0.01, vmax=0.1)
print(qa_image(residual[0], context='Residual clean image'))
plt.show()
export_image_to_fits(
    residual[0],
    '%s/arl_demo_ical_%s_residual.fits' % (results_dir, context),
)