# Calibration using graphs

In [None]:
%matplotlib inline

import os
import sys

results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

from functools import partial
from dask import delayed
from distributed import progress
import dask.bag as bag

sys.path.append(os.path.join('..', '..'))

results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

from matplotlib import pylab

pylab.rcParams['figure.figsize'] = (12.0, 12.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.calibration.solvers import solve_gaintable
from arl.calibration.operations import apply_gaintable
from arl.data.data_models import Image
from arl.data.polarisation import PolarisationFrame
from arl.data.parameters import get_parameter
from arl.visibility.base import create_blockvisibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.deconvolution import deconvolve_cube
from arl.image.operations import show_image
from arl.image.iterators import  image_raster_iter
from arl.visibility.iterators import vis_timeslice_iter
from arl.util.testing_support import create_named_configuration
from arl.imaging import predict_2d, advise_wide_field
    
from arl.graphs.dask_init import get_dask_Client
from arl.graphs.graphs import create_invert_wstack_graph, create_predict_wstack_graph, \
    create_selfcal_graph_list
from arl.graphs.generic_graphs import create_generic_image_graph
from arl.util.graph_support import create_simulate_vis_graph, \
    create_predict_gleam_model_graph, create_corrupt_vis_graph, \
    create_gleam_model_graph
from arl.pipelines.graphs import create_continuum_imaging_pipeline_graph
from arl.graphs.vis import simple_vis

import logging

log = logging.getLogger()
log.setLevel(logging.DEBUG)
log.addHandler(logging.StreamHandler(sys.stdout))

In [None]:
c = get_dask_Client()  # dask distributed client used by every c.compute(...) call in this notebook

We create a graph to make the visibilities.

In [None]:
# Observation set-up:
#  - nfreqwin channels evenly spaced from 0.8e8 to 1.2e8 (presumably Hz, i.e. 80-120 MHz);
#  - ntimes samples evenly spaced in [-pi/3, +pi/3] (presumably hour angle in radians).
nfreqwin = 3
ntimes = 5
frequency = numpy.linspace(0.8e8, 1.2e8, nfreqwin)
# With several channels, each gets the uniform channel spacing as its bandwidth.
# A single channel falls back to a fixed 1e7.
channel_bandwidth = (numpy.full(nfreqwin, frequency[1] - frequency[0])
                     if nfreqwin > 1
                     else numpy.array([1e7]))
times = numpy.linspace(-numpy.pi / 3.0, numpy.pi / 3.0, ntimes)
phasecentre = SkyCoord(ra=30.0 * u.deg, dec=-60.0 * u.deg,
                       frame='icrs', equinox='J2000')

# Build (but do not yet compute) the graph of simulated visibilities for the
# LOWBD2-CORE configuration.
vis_graph_list = create_simulate_vis_graph(
    'LOWBD2-CORE',
    frequency=frequency,
    channel_bandwidth=channel_bandwidth,
    times=times,
    phasecentre=phasecentre,
)

Find the optimum values for wide-field imaging.

In [None]:
# Ask ARL for wide-field imaging parameters. The first visibility graph is
# materialised and used as the representative sample.
wprojection_planes = 1
advice = advise_wide_field(
    vis_graph_list[0].compute(),
    guard_band_image=4.0,
    delA=0.02,
    wprojection_planes=wprojection_planes,
)
vis_slices = advice['vis_slices']  # reused by every w-stack predict/invert below

Now make a graph to fill the visibilities with a model drawn from GLEAM. We then add phase errors of 1 radian rms to each station. We compute this graph now so that its processing is not confused with the imaging.

In [None]:
# Three steps, innermost first:
#  1. predict GLEAM model visibilities (w-stack) using only the central frequency channel;
#  2. corrupt them with random phase errors (phase_error=1.0);
#  3. submit to the cluster immediately.
# After this cell the name holds whatever c.compute returns (futures for a
# distributed client), not an unevaluated graph.
corrupted_vis_graph_list = c.compute(
    create_corrupt_vis_graph(
        create_predict_gleam_model_graph(
            vis_graph_list,
            frequency=[frequency[len(frequency) // 2]],
            channel_bandwidth=[channel_bandwidth[len(frequency) // 2]],
            c_predict_graph=create_predict_wstack_graph,
            vis_slices=vis_slices,
        ),
        phase_error=1.0,
    )
)

Now make a graph to construct the local sky model (LSM). The LSM is also drawn from GLEAM, but it only includes sources brighter than 1 Jy.

In [None]:
# Local sky model: GLEAM sources above flux_limit=1.0, built against the middle
# visibility graph at the central frequency channel.
# It is computed right away, so despite its name LSM_graph holds a computed
# result rather than a graph.
LSM_graph = create_gleam_model_graph(
    vis_graph_list[len(vis_graph_list) // 2],
    frequency=[frequency[len(frequency) // 2]],
    channel_bandwidth=[channel_bandwidth[len(frequency) // 2]],
    flux_limit=1.0,
).compute()

Now make a dirty image to see the effect of the introduced phase errors.

In [None]:
# Dirty image of the corrupted (uncalibrated) visibilities, with LSM_graph as
# the image template.
dirty_graph = create_invert_wstack_graph(
    corrupted_vis_graph_list,
    LSM_graph,
    vis_slices=vis_slices,
    dopsf=False,
)
future = c.compute(dirty_graph)
# The computed result is indexable. Its first element is the image that gets
# displayed (remaining elements, if any, are not used here).
dirty = future.result()[0]
show_image(dirty, title='No selfcal')
plt.show()

First make a selfcal graph in which the different Visibilities are self-calibrated independently. We look at the graph for just one Visibility.

In [None]:
# Selfcal graph with independent (per-Visibility) gain solutions.
# Only the first graph is drawn.
selfcal_vis_graph_list = create_selfcal_graph_list(
    corrupted_vis_graph_list,
    LSM_graph,
    c_predict_graph=create_predict_wstack_graph,
    vis_slices=vis_slices,
    global_solution=False,
)
simple_vis(selfcal_vis_graph_list[0])

Now make a global solution. Note that all Visibilities are now coupled.

In [None]:
# Same selfcal construction, but with a single global gain solution.
# This replaces the independent-solution graph list of the same name.
selfcal_vis_graph_list = create_selfcal_graph_list(
    corrupted_vis_graph_list,
    LSM_graph,
    c_predict_graph=create_predict_wstack_graph,
    vis_slices=vis_slices,
    global_solution=True,
)
simple_vis(selfcal_vis_graph_list[0])

The graph for making the dirty image now shows a global synchronisation point. We alleviate this by sending only averaged visibilities to the gather step: before averaging over the solution interval, the model visibility is divided out. Only the gaintable is sent back for application.

In [None]:
# Dirty-image graph for the globally self-calibrated visibilities.
# NOTE(review): this call passes facets=2 while the uncalibrated dirty image was
# made without it; confirm this is intended if the two images are compared.
dirty_graph = create_invert_wstack_graph(
    selfcal_vis_graph_list,
    LSM_graph,
    facets=2,
    vis_slices=vis_slices,
    dopsf=False,
)
simple_vis(dirty_graph)

In [None]:
# Run the selfcal + imaging graph on the cluster and show the resulting image
# (first element of the computed result).
future = c.compute(dirty_graph)
dirty = future.result()[0]
show_image(dirty, title='With selfcal')
plt.show()

In [None]:
c.shutdown()  # release the dask client created at the start of the notebook