# Wide-field imaging demonstration

### This script makes a fake data set, fills it with a number of point components, and then images it. 

In [None]:
%matplotlib inline

import os
import sys

# Add the directory two levels above the working directory to the import path.
# NOTE(review): presumably this is what lets the `arl` imports below resolve
# when the notebook is run from its own directory — confirm.
sys.path.append(os.path.join('..', '..'))

from matplotlib import pylab

# Notebook-wide matplotlib defaults: 8x8 figures and the 'rainbow' colour map
# for images.
pylab.rcParams['figure.figsize'] = (8.0, 8.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy import constants as const
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.visibility.operations import create_visibility
from arl.skymodel.operations import create_skycomponent
from arl.image.operations import show_image, export_image_to_fits
from arl.image.iterators import raster_iter
from arl.visibility.iterators import vis_timeslice_iter
from arl.util.testing_support import create_named_configuration
from arl.fourier_transforms.ftprocessor import invert_2d, create_image_from_visibility, \
    weight_visibility, predict_skycomponent_visibility, create_w_term_image, invert_by_image_partitions, \
    invert_timeslice, invert_wprojection

import logging

# Route every record of level DEBUG and above from the root logger to stdout.
log = logging.getLogger()
log.addHandler(logging.StreamHandler(stream=sys.stdout))
log.setLevel(logging.DEBUG)

# Master switch for the plotting sections guarded by `if doplot:` below.
doplot = True

### Set up imaging parameters

In [None]:
# Imaging parameters passed as `params=` to the visibility and imaging calls
# in the cells that follow.
params = dict(
    npixel=256,
    npol=1,
    cellsize=0.001,
    spectral_mode='channel',
    channelwidth=5e7,
    reffrequency=1e8,
    kernel='calculated',
)

### Construct the SKA1-LOW core configuration

In [None]:
# Look up the antenna configuration registered under the name 'LOWBD2-CORE';
# it is handed to create_visibility() when the visibility set is built.
lowcore = create_named_configuration('LOWBD2-CORE')

### We create the visibility. This just makes the uvw, time, antenna1, antenna2, weight columns in a table

In [None]:
# A single frequency channel at 1e8; the reference frequency is simply the
# highest (here: only) channel.
frequency = numpy.array([1e8])
reffrequency = frequency.max()

# Eight samples from -pi up to (but excluding) +pi in steps of pi/4.
# NOTE(review): presumably hour angles in radians — confirm against
# create_visibility.
times = numpy.arange(-numpy.pi, numpy.pi, numpy.pi / 4)

# Phase centre at RA 15 deg, Dec -60 deg (ICRS).
phasecentre = SkyCoord(ra=15.0 * u.deg, dec=-60.0 * u.deg,
                       frame='icrs', equinox=2000.0)

vt = create_visibility(lowcore, times, frequency,
                       weight=1.0, phasecentre=phasecentre, params=params)

### Plot the synthesized UV coverage.

In [None]:
if doplot:
    plt.clf()
    for f in frequency:
        # Scale factor frequency / c. The bare SI value of c (m/s) is used so
        # that plain floats, not astropy Quantity objects carrying s/m units,
        # are handed to matplotlib.
        # NOTE(review): assumes the 'uvw' column is in metres, so that the
        # product is in wavelengths as the axis labels state — confirm.
        x = f / const.c.value
        # Each sample (u, v) in blue and its mirror point (-u, -v) in red.
        plt.plot(x * vt.data['uvw'][:, 0], x * vt.data['uvw'][:, 1], '.', color='b')
        plt.plot(-x * vt.data['uvw'][:, 0], -x * vt.data['uvw'][:, 1], '.', color='r')
    # The axis labels do not depend on the frequency: set them once.
    plt.xlabel('U (wavelengths)')
    plt.ylabel('V (wavelengths)')

### Look at the phase term due to w. Evaluate this for the median absolute w. 

In [None]:
if doplot:
    # Build the w-term image for the current visibilities and imaging
    # parameters, then display it.
    wterm = create_w_term_image(vt, params=params)
    show_image(wterm)
    plt.show()

### Create a grid of components and predict each in turn, using the full phase term including w.

In [None]:
# Same imaging parameters as before, plus the number of image partitions that
# sets the spacing of the component grid.
params = dict(
    npixel=256,
    npol=1,
    cellsize=0.001,
    spectral_mode='channel',
    channelwidth=5e7,
    reffrequency=1e8,
    kernel='calculated',
    image_partitions=4,
)

# Start from zeroed visibilities; every component has a flux of 100.
vt.data['vis'] *= 0.0
flux = numpy.array([[100.0]])

# The model image is only needed for its WCS (pixel -> sky conversion).
model = create_image_from_visibility(vt, params=params)

spacing_pixels = params['npixel'] // params['image_partitions']
log.info('Spacing in pixels = %s', spacing_pixels)
# The same spacing expressed as cellsize * pixels * 180 / pi.
spacing = 180.0 * params['cellsize'] * spacing_pixels / numpy.pi

# Grid offsets from the image centre, in units of spacing_pixels.
centers = -1.5, -0.5, +0.5, +1.5
for iy in centers:
    pdec = int(round(iy * spacing_pixels + params['npixel'] // 2 - 1))
    for ix in centers:
        pra = int(round(ix * spacing_pixels + params['npixel'] // 2 - 1))
        sc = pixel_to_skycoord(pra, pdec, model.wcs)
        log.info("Component at (%f, %f) %s", pra, pdec, sc)
        comp = create_skycomponent(flux=flux, frequency=frequency, direction=sc)
        # The return value is not used here.
        # NOTE(review): presumably the prediction is accumulated into vt in
        # place — confirm against predict_skycomponent_visibility.
        predict_skycomponent_visibility(vt, comp)

### Make the dirty image and point spread function. Note that the shape of the sources varies with position in the image. This space-variant property of the PSF arises from the w-term.

In [None]:
# Plain 2D imaging parameters (no image partitioning).
params = {'npixel': 256,
          'npol': 1,
          'cellsize': 0.001,
          'spectral_mode': 'channel',
          'channelwidth': 5e7,
          'reffrequency': 1e8,
          'kernel': 'calculated'}

# Separate template images for the dirty image and the PSF.
dirty = create_image_from_visibility(vt, params=params)
psf = create_image_from_visibility(vt, params=params)

# Weight the visibilities using the dirty-image template, then invert.
vt = weight_visibility(vt, dirty)
dirty = invert_2d(vt, dirty, params=params)
# Use the dedicated psf template here. Passing the already-inverted dirty
# image (as was done before) left the psf template unused and risks the PSF
# and dirty image sharing or overwriting the same data.
psf = invert_2d(vt, psf, dopsf=True, params=params)

# Normalise both images by the PSF peak.
psfmax = psf.data.max()
dirty.data = dirty.data / psfmax
psf.data = psf.data / psfmax

if doplot:
    show_image(dirty)

print("Max, min in dirty image = %.6f, %.6f, psfmax = %f" % (dirty.data.max(), dirty.data.min(), psfmax))
print("Max, min in PSF = %.6f, %.6f, psfmax = %f" % (psf.data.max(), psf.data.min(), psfmax))

export_image_to_fits(dirty, 'imaging-wterm_dirty.fits')
export_image_to_fits(psf, 'imaging-wterm_psf.fits')

### Use image plane partitioning (faceting) to make the image

In [None]:
# Imaging parameters with a partition count of 4.
params = dict(
    npixel=256,
    npol=1,
    cellsize=0.001,
    spectral_mode='channel',
    channelwidth=5e7,
    reffrequency=1e8,
    kernel='calculated',
    image_partitions=4,
)

# Invert partition by partition, walking the image with raster_iter.
dirtyFacet = invert_by_image_partitions(
    vt,
    create_image_from_visibility(vt, params=params),
    image_iterator=raster_iter,
    params=params,
)
# Normalise by the PSF peak from the 2D imaging step times the square of the
# partition count.
# NOTE(review): presumably each facet contributes its own normalisation —
# confirm against invert_by_image_partitions.
dirtyFacet.data = dirtyFacet.data / (psfmax * params['image_partitions'] ** 2)

if doplot:
    show_image(dirtyFacet)

print("Max, min in dirty image = %.6f, %.6f, psfmax = %f"
      % (dirtyFacet.data.max(), dirtyFacet.data.min(), psfmax))
export_image_to_fits(dirtyFacet, 'imaging-wterm_dirtyFacet.fits')

### That was the best case. This time, we will not arrange for the points to be at the center of the partitions.

In [None]:
# Same as the previous facet run but with a partition count of only 2, so the
# components no longer sit at the partition centres.
params = dict(
    npixel=256,
    npol=1,
    cellsize=0.001,
    spectral_mode='channel',
    channelwidth=5e7,
    reffrequency=1e8,
    kernel='calculated',
    image_partitions=2,
)

dirtyFacet2 = invert_by_image_partitions(
    vt,
    create_image_from_visibility(vt, params=params),
    image_iterator=raster_iter,
    params=params,
)
# Normalise by the PSF peak times the square of the partition count.
dirtyFacet2.data = dirtyFacet2.data / (psfmax * params['image_partitions'] ** 2)

if doplot:
    show_image(dirtyFacet2)

print("Max, min in dirty image = %.6f, %.6f, psfmax = %f"
      % (dirtyFacet2.data.max(), dirtyFacet2.data.min(), psfmax))
export_image_to_fits(dirtyFacet2, 'imaging-wterm_dirtyFacet2.fits')

### Look at images as a function of time. Show difference from the best facet image

In [None]:
# Imaging parameters with the time-slice width used by vis_timeslice_iter.
params = {'npixel': 256,
          'npol': 1,
          'cellsize': 0.001,
          'spectral_mode': 'channel',
          'channelwidth': 5e7,
          'reffrequency': 1e8,
          'kernel': 'calculated',
          'timeslice': 3600.0}

for visslice in vis_timeslice_iter(vt, params):
    # Dirty image of this time slice alone.
    dirtySnapshot = create_image_from_visibility(visslice, params=params)
    dirtySnapshot = invert_2d(visslice, dirtySnapshot, params=params)

    # PSF of this time slice, built from its own template. Previously the
    # freshly made dirty snapshot was passed here, leaving the psf template
    # unused.
    psfSnapshot = create_image_from_visibility(visslice, params=params)
    psfSnapshot = invert_2d(visslice, psfSnapshot, dopsf=True, params=params)
    # Keep the per-snapshot peak in its own name so that the global psfmax
    # from the full-synthesis imaging is not overwritten by the loop.
    psfmaxSnapshot = psfSnapshot.data.max()

    dirtySnapshot.data /= psfmaxSnapshot

    print("Max, min in dirty image = %.6f, %.6f, psfmax = %f" %
          (dirtySnapshot.data.max(), dirtySnapshot.data.min(), psfmaxSnapshot))
    if doplot:
        # Show the difference from the 4-partition facet image.
        dirtySnapshot.data -= dirtyFacet.data
        show_image(dirtySnapshot)



### Image by correcting each time slice and summing

In [None]:
# Imaging parameters for time-slice imaging.
params = {'npixel': 256,
          'npol': 1,
          'cellsize': 0.001,
          'spectral_mode': 'channel',
          'channelwidth': 5e7,
          'reffrequency': 1e8,
          'timeslice': 1.0,
          'nprocessor': 1}

# Dirty image. Bind the result of invert_timeslice back to dirtyTimeslice:
# previously it was stored in an otherwise unused name, while the untouched
# template was the image that got normalised, displayed and exported.
dirtyTimeslice = create_image_from_visibility(vt, params=params)
dirtyTimeslice = invert_timeslice(vt, dirtyTimeslice, params=params)

# Matching PSF, used only for normalisation.
psfTimeslice = create_image_from_visibility(vt, params=params)
psfTimeslice = invert_timeslice(vt, psfTimeslice, dopsf=True, params=params)
psfmax = psfTimeslice.data.max()

dirtyTimeslice.data /= psfmax
psfTimeslice.data /= psfmax

# Honour the notebook-wide plotting switch, like every other plotting cell.
if doplot:
    show_image(dirtyTimeslice)
    plt.show()

export_image_to_fits(dirtyTimeslice, 'imaging-wterm_dirty_Timeslice.fits')

### Now try w projection. WProjection is implemented via invert_2d with a specialised kernel.

In [None]:
# Imaging parameters for w-projection; 'wloss' is passed through to
# invert_wprojection via params.
params = dict(
    npixel=256,
    npol=1,
    cellsize=0.001,
    spectral_mode='channel',
    channelwidth=5e7,
    reffrequency=1e8,
    wloss=0.05,
)

# Dirty image first, then the PSF, each from a fresh template image.
dirtyWProjection = invert_wprojection(
    vt,
    create_image_from_visibility(vt, params=params),
    params=params,
)
psfWProjection = invert_wprojection(
    vt,
    create_image_from_visibility(vt, params=params),
    dopsf=True,
    params=params,
)

# Normalise both images by the PSF peak.
psfmax = psfWProjection.data.max()
dirtyWProjection.data = dirtyWProjection.data / psfmax
psfWProjection.data = psfWProjection.data / psfmax

if doplot:
    show_image(dirtyWProjection)

print("Max, min in dirty image = %.6f, %.6f, psfmax = %f"
      % (dirtyWProjection.data.max(), dirtyWProjection.data.min(), psfmax))

export_image_to_fits(dirtyWProjection, 'imaging-wterm_dirty_WProjection.fits')
export_image_to_fits(psfWProjection, 'imaging-wterm_psf_WProjection.fits')