# Investigation of compression for LOW data

Baseline-dependent averaging (BDA) is a form of data compression. In this script, we use a gridded approach to BDA. We create a critically sampled snapshot of a LOW data set, and then compress and decompress it to see what errors result.

In [None]:
%matplotlib inline

import os
import sys

from time import clock

sys.path.append(os.path.join('..', '..'))

from matplotlib import pylab

pylab.rcParams['agg.path.chunksize'] = 10000
pylab.rcParams['figure.figsize'] = (10.0, 10.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy import constants as const
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt
from matplotlib.pyplot import cm 

from arl.visibility.operations import create_visibility
from arl.skymodel.operations import create_skycomponent, insert_skycomponent
from arl.image.operations import show_image, export_image_to_fits, qa_image, create_image_from_array, reproject_image
from arl.fourier_transforms.fft_support import extract_mid
from arl.visibility.compress import compress_visibility, decompress_visibility
from arl.image.iterators import raster_iter
from arl.visibility.iterators import vis_timeslice_iter
from arl.util.testing_support import create_named_configuration
from arl.fourier_transforms.ftprocessor import *

import logging
log = logging.getLogger()
log.setLevel(logging.DEBUG)
log.addHandler(logging.StreamHandler(sys.stdout))

Construct the SKA1-LOW configuration

In [None]:
# Look up the SKA1-LOW antenna configuration by its registered name.
low_configuration_name = 'LOWBD2'
low = create_named_configuration(low_configuration_name)

In [None]:
# Critical sampling intervals in time (radians of hour angle) and frequency.
# Both scale as scale_length / (oversampling * baseline_length).
oversampling = 2
station_diameter = 35.0  # presumably the LOW station diameter in metres — confirm
max_baseline = 8e4       # presumably the longest baseline in metres — confirm
centre_frequency = 1e8   # Hz

sampling_time = station_diameter / (oversampling * max_baseline)
# 43200 / pi converts radians of hour angle into seconds.
log.info("Critical sampling time = %.5f (radians) %.2f (seconds)" %
         (sampling_time, sampling_time * 43200.0 / numpy.pi))

sampling_frequency = centre_frequency * station_diameter / (oversampling * max_baseline)
log.info("Critical sampling frequency = %.5f (Hz) " % sampling_frequency)

# Three samples at -1, 0 and +1 sampling intervals (the 1.01 keeps the end
# point inside arange's half-open range).
times = numpy.arange(-sampling_time, 1.01 * sampling_time, sampling_time)
# Three channels centred on centre_frequency, one sampling interval apart.
frequency = numpy.linspace(centre_frequency - sampling_frequency,
                           centre_frequency + sampling_frequency, 3)
print("Observing frequencies %s Hz" % (frequency))

We create the visibility set, holding the vis, uvw, time, antenna1, antenna2 and weight columns in a table. The actual visibility values are zero.

In [None]:
# Phase centre of the simulated observation (ICRS frame, equinox 2000.0).
phasecentre = SkyCoord(ra=180.0 * u.deg, dec=-60.0 * u.deg, frame='icrs', equinox=2000.0)
# Highest observing frequency, kept as the reference frequency.
reffrequency = frequency.max()
# Single-polarisation visibility set with unit weights over all times/channels.
vt = create_visibility(low, times, frequency, weight=1.0, phasecentre=phasecentre, npol=1)

In [None]:
# Plot the uv coverage of every channel (red) together with its conjugate
# points at (-u, -v) (blue).
plt.clf()
for chan, _chan_frequency in enumerate(frequency):
    # Fetch the per-channel uvw once rather than four times.
    uvw = vt.uvw_lambda(chan)
    plt.plot(+uvw[:, 0], +uvw[:, 1], '.', color='r')
    plt.plot(-uvw[:, 0], -uvw[:, 1], '.', color='b')
plt.title('Original uv coverage')
# uvw_lambda(chan) is evaluated per channel, i.e. scaled by wavelength, so
# the axes are in wavelengths rather than metres (confirm against arl docs).
plt.xlabel('U (wavelengths)')
plt.ylabel('V (wavelengths)')
plt.show()

Fill the LOW field of view with sources. We create an image only to do the WCS bookkeeping.

In [None]:
# Start from empty visibilities; each component's prediction is applied to vt below.
vt.data['vis'] *= 0.0
npixel = 16384

# This image is used only for its WCS, to turn pixel positions into sky
# coordinates.
# NOTE(review): a 16384^2 image per channel is a large allocation just for
# WCS bookkeeping — confirm it is necessary.
model = create_image_from_visibility(vt, npixel=npixel, cellsize=0.00001, npol=1, nchan=len(frequency))
centre = model.wcs.wcs.crpix - 1  # 0-relative reference pixel
spacing_pixels = npixel // 16
log.info('Spacing in pixels = %s' % spacing_pixels)
spacing = model.wcs.wcs.cdelt * spacing_pixels
locations = [-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]

# Step along each image axis in the direction of the WCS pixel increment.
x_step_sign = numpy.sign(model.wcs.wcs.cdelt[0])
y_step_sign = numpy.sign(model.wcs.wcs.cdelt[1])

# Sources go only on the grid positions with ix >= iy, visited row by row.
grid_offsets = [(ix, iy) for iy in locations for ix in locations if ix >= iy]

# Compute each source position in pixels, convert it to world coordinates
# for the skycomponent, and predict its visibilities into vt.
for ix, iy in grid_offsets:
    p = (int(round(centre[0] + ix * spacing_pixels * x_step_sign)),
         int(round(centre[1] + iy * spacing_pixels * y_step_sign)))
    sc = pixel_to_skycoord(p[0], p[1], model.wcs)
    log.info("Component at (%f, %f) [0-rel] %s" % (p[0], p[1], str(sc)))
    # Flat spectrum: the same flux in every channel, shaped (nchan, 1).
    component_flux = 100.0 + 2.0 * ix + iy * 20.0
    flux = numpy.array([len(frequency) * [component_flux]]).T
    comp = create_skycomponent(flux=flux, frequency=frequency, direction=sc)
    predict_skycomponent_visibility(vt, comp)

Select the rows with the shortest baselines (uv distance below `boundary` times the maximum) and form a new visibility set. This is the set we will compress using a grid in uv space.

In [None]:
# Keep only the rows whose uv distance is below `boundary` times the maximum.
boundary = 0.01
visr = numpy.sqrt(vt.u**2 + vt.v**2)
uvmax = numpy.max(visr)
uv_threshold = boundary * uvmax
inner_rows = numpy.abs(visr) < uv_threshold
vts = create_visibility_from_rows(vt, inner_rows)

# Show the uv coverage of the selected inner rows.
plt.clf()
plt.plot(vts.u, vts.v, '.')
plt.title('Original inner uv coverage')
plt.xlabel('U (m)')
plt.ylabel('V (m)')
plt.show()

Create a model to serve as the image specification for the compression. We make the cellsize larger by a factor of 1/`boundary`, matching the smaller uv extent of the inner selection, so that the uv grid samples that region much more finely.

In [None]:
# Sweep the size of the compression grid. For each npixel, compress the
# inner visibility set onto a uv grid and decompress it back onto the
# original rows. Record the RMS signal, the RMS recovery error and the
# wall-clock times.
import copy                    # for copy.deepcopy below
from time import perf_counter  # time.clock was removed in Python 3.8


def _rms(values):
    """Return the root-mean-square magnitude of (possibly complex) values."""
    return numpy.sqrt(numpy.average(numpy.abs(values) ** 2))


errors = []
signals = []
t_compresses = []
t_decompresses = []

npixels = [128, 256, 512, 1024, 2048, 4096, 8192]

for npixel in npixels:
    # Image specification for the compression grid. The cellsize is scaled
    # by 1/boundary to match the reduced uv extent of the inner selection.
    model = create_image_from_visibility(vts, npixel=npixel, npol=1, image_nchan=1,
                                         cellsize=0.00001 / boundary)

    ts = perf_counter()
    cvts = compress_visibility(vts, model)
    t_compress = perf_counter() - ts
    t_compresses.append(t_compress)
    log.debug("Compression using npixel = %d took %.1f seconds" % (npixel, t_compress))

    plt.clf()
    plt.plot(cvts.u, cvts.v, '.')
    plt.title('Compressed uv coverage')
    plt.xlabel('U (m)')
    plt.ylabel('V (m)')
    plt.show()

    # A zeroed deep copy of vts is the decompression template, so vts itself
    # stays intact for the comparison below.
    template_vis = copy.deepcopy(vts)
    template_vis.data['vis'] *= 0.0
    ts = perf_counter()
    dcvts = decompress_visibility(cvts, template_vis, model)
    t_decompress = perf_counter() - ts
    t_decompresses.append(t_decompress)
    log.debug("Decompression using npixel = %d took %.1f seconds" % (npixel, t_decompress))

    original = vts.vis[..., 0, 0]
    residual = original - dcvts.vis[..., 0, 0]
    # Both quantities are true RMS values: square the magnitude before
    # averaging, so signal and error are directly comparable.
    signal = _rms(original)
    error = _rms(residual)
    signals.append(signal)
    errors.append(error)

    plt.clf()
    plt.plot(original.real, original.imag, '.', color='g', label='Visibility')
    plt.plot(residual.real, residual.imag, '.', color='r', label='Error in recovery')
    plt.title("Error in visibility recovery: npixel %d rms signal, error = %.2f, %.2f" % (npixel, signal, error))
    # Both visibilities and errors are drawn, so the axes are generic.
    plt.xlabel('Real part')
    plt.ylabel('Imaginary part')
    plt.legend()  # the labels passed to plot() are shown only by a legend
    plt.show()



In [None]:
# Summary plots versus compression grid size: RMS error and signal, then
# compression/decompression times.
rms_series = [(errors, 'r', 'RMS Error'), (signals, 'b', 'RMS Visibility')]
plt.clf()
for values, colour, series_label in rms_series:
    plt.semilogy(npixels, values, color=colour, label=series_label)
plt.title('RMS compression/decompression')
plt.xlabel('Number of pixels')
plt.ylabel('RMS')
plt.legend()
plt.show()

timing_series = [(t_compresses, 'r', 'compression'), (t_decompresses, 'b', 'decompression')]
for values, colour, series_label in timing_series:
    plt.semilogy(npixels, values, color=colour, label=series_label)
plt.legend()
plt.title('Time for compression and decompression')
plt.xlabel('Number of pixels')
plt.ylabel('Time (s)')

plt.show()

