# AW-Projection

A tweakable end-to-end run-through for testing AW-projection.

In [None]:
import numpy as np
import os
import datetime

import astropy.units as u
import astropy.constants as const
import fastimgproto.imager as imager
import fastimgproto.visibility as visibility
import fastimgproto.gridder.conv_funcs as kfuncs

from astropy.coordinates import Angle, SkyCoord, AltAz, EarthLocation
from astropy.time import Time

from fastimgproto.skymodel.helpers import SkyRegion, SkySource
from fastimgproto.sourcefind.image import SourceFindImage
from fastimgproto.telescope.readymade import Meerkat
from fastimgproto.telescope.readymade import VLA_C

from fastimgproto.gridder.conv_funcs import Pillbox
from fastimgproto.gridder.conv_funcs import Triangle
from fastimgproto.gridder.conv_funcs import Sinc
from fastimgproto.gridder.conv_funcs import Gaussian
from fastimgproto.gridder.conv_funcs import GaussianSinc
from fastimgproto.gridder.conv_funcs import PSWF

from fastimgproto.bindings import cpp_image_visibilities, CppKernelFuncs


In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# Display image arrays with pixel (0, 0) at the bottom-left, so positive y
# points upwards, and use a larger default figure size.
plt.rcParams.update({
    'image.origin': 'lower',
    'figure.figsize': (10, 10),
})

In [None]:

IMAGER_PYTHON=True
IMAGER_CPP=True
WENSS_DATA=True


In [None]:
# Pick the array layout: VLA C-configuration for the WENSS run, MeerKAT otherwise.
telescope = VLA_C() if WENSS_DATA is True else Meerkat()

n_antennae = len(telescope.ant_local_xyz)
n_baselines = len(telescope.baseline_local_xyz)
print("Telescope with {} antennae == {} baselines".format(n_antennae, n_baselines))
print("Centre: {!r}, {!r}".format(telescope.lon, telescope.lat))

In [None]:
# Choose the pointing centre, the observing frequency, and the epoch from which
# to search for the next transit of the pointing centre.
if WENSS_DATA is True:
    pointing_centre = SkyCoord(194.24 * u.deg, 47.339 * u.deg)
    obs_central_frequency = 74 * u.MHz
    search_start = Time('2017-01-01')
else:
    pointing_centre = SkyCoord(0 * u.deg, -30 * u.deg)
    obs_central_frequency = 30. * u.MHz
    search_start = Time('2010-01-01')

wavelength = const.c / obs_central_frequency
transit_time = telescope.next_transit(pointing_centre.ra,
                                      start_time=search_start)

In [None]:
# Sanity check: altitude (degrees) of the pointing centre at transit, as seen
# from the array centre. The last expression is shown as the cell output.
observer_frame = AltAz(obstime=transit_time, location=telescope.centre)
altaz = pointing_centre.transform_to(observer_frame)
altaz.alt.deg

In [None]:
# Generate UVW baselines (in wavelengths) and local hour angles for the
# observation track, build the sky model, then simulate visibilities.
if WENSS_DATA is True:
    # 1440 samples spanning +/-6 h about transit.
    nstep = 1440
    obs_times = transit_time + np.linspace(-6, 6, nstep) * u.hr
else:
    # 10 samples spanning +/-4 h about transit.
    nstep = 10
    obs_times = transit_time + np.linspace(-4, 4, nstep) * u.hr

print("Generating UVW-baselines for {} timesteps".format(nstep))
uvw_m, lha = telescope.uvw_tracking_skycoord(pointing_centre, obs_times)
# From here on we use UVW as multiples of wavelength, lambda:
uvw_lambda = (uvw_m / wavelength).to(u.dimensionless_unscaled).value
# Shift hour angles above 12 h down by 24 h so the local hour angle runs
# continuously through transit (negative before, positive after). This was
# previously done only for the synthetic run, so the WENSS run passed the
# imager an unwrapped ``lha`` of a different type. Assumes ``lha`` is
# expressed in hours, as the 24.0 offset implies -- TODO confirm.
lha = np.array(lha) - np.where(lha > 12.0 * u.hourangle, 24.0, 0.0)

if WENSS_DATA is True:
    # One row per source: [index, RA h, RA m, RA s, Dec d, Dec m, Dec s, flux].
    # NOTE(review): the flux column is scaled by 2.83 mJy below; the origin of
    # that factor is not recorded here.
    sources_array = np.array([[ 1, 11, 47, 52.21, 48, 18, 49.2, 2033], [ 2, 11, 51, 9.26, 47, 28, 55.7, 2889], [ 3, 11, 53, 24.60, 49, 31, 9.6, 4041], [ 4, 11, 53, 46.43, 51, 17, 4.6, 2328], [ 5, 11, 54, 20.56, 45, 23, 29.7, 3012], [ 6, 11, 56, 9.31, 44, 50, 20.0, 2223], [ 7, 11, 59, 13.55, 53, 53, 7.8, 4673], [ 8, 11, 59, 13.83, 53, 53, 7.1, 4595], [ 9, 12, 0, 31.11, 45, 48, 42.3, 3170], [10, 12, 7, 14.14, 54, 7, 54.0, 2677], [11, 12, 9, 13.62, 43, 39, 18.5, 7220], [12, 12, 13, 29.92, 48, 23, 12.9, 2021], [13, 12, 15, 29.88, 53, 35, 50.2, 7827], [14, 12, 22, 51.88, 50, 26, 55.0, 2570], [15, 12, 24, 28.60, 42, 6, 34.5, 5127], [16, 12, 24, 29.19, 42, 6, 49.7, 3565], [17, 12, 30, 34.50, 41, 38, 58.1, 2081], [18, 12, 34, 29.72, 41, 9, 38.6, 2239], [19, 12, 36, 50.00, 36, 55, 18.0, 3277], [20, 12, 41, 56.97, 57, 30, 44.6, 4347], [21, 12, 44, 49.15, 40, 48, 6.5, 2139], [22, 12, 44, 49.46, 36, 9, 23.0, 2803], [23, 12, 46, 38.99, 56, 49, 21.1, 2472], [24, 12, 46, 46.05, 38, 41, 40.0, 2653], [25, 12, 47, 7.40, 49, 0, 19.5, 2430], [26, 12, 49, 23.07, 44, 44, 46.5, 2240], [27, 12, 51, 42.70, 50, 34, 22.8, 4672], [28, 12, 51, 46.60, 50, 34, 34.1, 2960], [29, 12, 52, 8.60, 52, 45, 30.5, 4610], [30, 12, 52, 16.79, 47, 15, 36.4, 4378], [31, 12, 52, 26.33, 56, 34, 20.6, 6792], [32, 12, 56, 17.84, 37, 13, 42.8, 2917], [33, 12, 56, 57.42, 47, 20, 20.8, 16888], [34, 12, 57, 23.56, 36, 44, 26.4, 2438], [35, 12, 57, 23.88, 36, 44, 19.1, 2312], [36, 12, 58, 1.93, 44, 35, 21.7, 3703], [37, 13, 0, 13.75, 38, 4, 32.9, 2654], [38, 13, 0, 32.93, 40, 9, 7.9, 5688], [39, 13, 3, 43.99, 37, 56, 9.9, 2046], [40, 13, 4, 28.85, 53, 50, 0.9, 3118], [41, 13, 4, 57.84, 38, 32, 29.8, 2462], [42, 13, 16, 12.83, 45, 4, 38.7, 2245], [43, 13, 19, 46.03, 51, 48, 11.7, 3071], [44, 13, 21, 13.87, 42, 34, 42.2, 2858], [45, 13, 21, 18.07, 42, 35, 0.6, 6188], [46, 13, 21, 21.65, 42, 35, 16.3, 3329], [47, 13, 23, 24.08, 41, 15, 12.9, 2293], [48, 13, 23, 24.29, 41, 15, 15.1, 2218], [49, 13, 26, 2.25, 36, 47, 59.3, 
        2596], [50, 13, 26, 13.98, 49, 34, 34.4, 2231], [51, 13, 27, 37.17, 55, 4, 6.2, 2926], [52, 13, 29, 54.91, 47, 12, 26.8, 3789], [53, 13, 31, 37.37, 50, 7, 57.4, 3505], [54, 13, 31, 37.38, 50, 7, 57.6, 3478], [55, 13, 35, 19.45, 41, 0, 2.4, 3836], [56, 13, 37, 38.58, 47, 41, 47.7, 2036], [57, 13, 37, 38.65, 47, 41, 49.0, 2070], [58, 13, 38, 49.36, 38, 51, 14.6, 12507], [59, 13, 38, 49.37, 38, 51, 14.5, 12494], [60, 13, 41, 33.60, 53, 44, 50.1, 3607], [61, 13, 41, 44.87, 46, 57, 17.2, 2677], [62, 13, 45, 26.34, 49, 46, 32.7, 8796], [63, 13, 45, 33.08, 42, 50, 14.7, 3242], [64, 13, 46, 23.66, 48, 13, 10.0, 2034], [65, 13, 49, 34.67, 53, 41, 17.4, 2186], [66, 14, 0, 19.00, 53, 36, 59.8, 2900]])
    #sources_array=np.matrix([[29, 12, 52, 8.60, 52, 45, 30.5, 4610], [30, 12, 52, 16.79, 47, 15, 36.4, 4378],[33, 12, 56, 57.42, 47, 20, 20.8, 16888]])
    steady_sources = []
    for row in sources_array:
        ra_hours = row[1] + row[2] / 60 + row[3] / 3600
        # NOTE(review): adding positive arcmin/arcsec is only correct for
        # positive declinations (true for every row above).
        dec_deg = row[4] + row[5] / 60 + row[6] / 3600
        steady_posn = SkyCoord(ra=ra_hours * 15 * u.deg, dec=dec_deg * u.deg)
        steady_sources.append(
            SkySource(position=steady_posn, flux=row[7] * 2.83 * u.mJy))
    transient_sources = steady_sources

else:
    # Additional source to North-East of pointing centre
    extra_src_position = SkyCoord(ra=pointing_centre.ra + 0.01 * u.deg,
                                  dec=pointing_centre.dec + 0.01 * u.deg, )

    steady_sources = [
        SkySource(pointing_centre, flux=1 * u.Jy),
        SkySource(extra_src_position, flux=0.4 * u.Jy),
    ]

    # Transient sources: 1 Jy each, offset south-west of the pointing centre
    # by the same amount in RA and Dec, to probe increasingly distant positions.
    transient_offsets_deg = [0.0, 4.0, 7.0, 10.0, 13.0]
    transient_sources = [
        SkySource(
            position=SkyCoord(ra=pointing_centre.ra - offset * u.deg,
                              dec=pointing_centre.dec - offset * u.deg),
            flux=1 * u.Jy)
        for offset in transient_offsets_deg
    ]

# Sources to simulate. NOTE(review): in the synthetic branch ``steady_sources``
# is built but not included; use ``steady_sources + transient_sources`` to
# simulate both.
all_sources = transient_sources

# Simulate incoming data (noise addition is currently disabled below):
print("Simulating visibilities")
data_vis = visibility.visibilities_for_source_list(
    pointing_centre,
    source_list=all_sources,
    uvw=uvw_lambda)

#vis_noise_level = 0.1 * u.Jy
#data_vis = visibility.add_gaussian_noise(vis_noise_level, data_vis)

# Uniform weights. ``np.float_`` was removed in NumPy 2.0; ``np.float64`` is
# the same dtype.
vis_weights = np.ones(len(data_vis), dtype=np.float64)

In [None]:
# Imaging grid: 1024 x 1024 pixels for both runs; only the cell size differs.
image_size = 1024 * u.pixel
if WENSS_DATA is True:
    # Flip the sign of the v coordinate. NOTE(review): this happens after the
    # visibilities were simulated, so it affects imaging only -- presumably to
    # get the expected image orientation; confirm.
    uvw_lambda[:, 1] *= -1
    cell_size = 60 * u.arcsecond
else:
    cell_size = 100 * u.arcsecond


## Set parameters

In [None]:
import stp_python

# Imager configuration shared by the Python and C++ runs below.
kernel_support = 3
trunc = kernel_support
kernel_func = kfuncs.PSWF(trunc=trunc)      # kernel object for the Python imager
kernel_func_name = CppKernelFuncs.pswf      # matching kernel for the C++ imager
kernel_trunc_radius = kernel_support
kernel_exact = False
kernel_oversampling = 8
generate_beam = True
fft_routine = stp_python.FFTRoutine.FFTW_ESTIMATE_FFT
fft_wisdom_filename = ""
num_wplanes = 20
wplanes_median = False
max_wpconv_support = 50
analytic_gcf = True
kernel_trunc_perc = 0.1
hankel_opt = False
interp_type = "linear"  # "linear" , "cubic"
undersampling_opt = 1
aproj_tinc = 0  # (np.max(lha)-np.min(lha))/9
# A-projection inputs, both in degrees: the pointing declination and the
# latitude of the observatory. The latitude was previously taken from the
# pointing centre's RA, which is not a latitude at all.
obs_dec = pointing_centre.dec.deg
obs_lat = telescope.centre.lat.deg

## Run Python implementation 

In [None]:
if IMAGER_PYTHON is True:
    # Run the Python imager and print its wall-clock time in seconds.
    start = datetime.datetime.now()

    python_imager_options = dict(
        lha=lha,
        image_size=image_size,
        cell_size=cell_size,
        kernel_func=kernel_func,
        kernel_support=kernel_support,
        kernel_exact=kernel_exact,
        kernel_oversampling=kernel_oversampling,
        num_wplanes=num_wplanes,
        wplanes_median=wplanes_median,
        max_wpconv_support=max_wpconv_support,
        analytic_gcf=analytic_gcf,
        hankel_opt=hankel_opt,
        interp_type=interp_type,
        undersampling_opt=undersampling_opt,
        kernel_trunc_perc=kernel_trunc_perc,
        aproj_tinc=aproj_tinc,
        obs_dec=obs_dec,
        obs_lat=obs_lat,
    )
    image, beam = imager.image_visibilities(
        data_vis, vis_weights, uvw_lambda, **python_imager_options)

    stop = datetime.datetime.now()
    duration = (stop - start).total_seconds()
    print(duration)

## Run CPP implementation

In [None]:
if IMAGER_CPP is True:
    # Run the C++ imager through the bindings and print its wall-clock time.
    start = datetime.datetime.now()

    # Map the Python-side interpolation name onto the C++ enum. An unknown
    # name is rejected instead of silently running as LINEAR, because that
    # would compare two differently-configured imagers below.
    interp_types_cpp = {
        "linear": stp_python.InterpType.LINEAR,
        "cubic": stp_python.InterpType.CUBIC,
    }
    try:
        interp_type_cpp = interp_types_cpp[interp_type]
    except KeyError:
        raise ValueError(
            "Unsupported interp_type {!r}; expected one of {}".format(
                interp_type, sorted(interp_types_cpp))) from None

    bind_image, bind_beam = cpp_image_visibilities(
        data_vis,
        vis_weights,
        uvw_lambda,
        image_size=image_size,
        cell_size=cell_size,
        kernel_func_name=kernel_func_name,
        # Use the configured truncation radius; previously kernel_support was
        # passed here and kernel_trunc_radius was ignored.
        kernel_trunc_radius=kernel_trunc_radius,
        kernel_support=kernel_support,
        kernel_exact=kernel_exact,
        kernel_oversampling=kernel_oversampling,
        generate_beam=generate_beam,
        fft_routine=fft_routine,
        fft_wisdom_filename=fft_wisdom_filename,
        num_wplanes=num_wplanes,
        wplanes_median=wplanes_median,
        max_wpconv_support=max_wpconv_support,
        analytic_gcf=analytic_gcf,
        hankel_opt=hankel_opt,
        undersampling_opt=undersampling_opt,
        kernel_trunc_perc=kernel_trunc_perc,
        interp_type=interp_type_cpp
        )

    stop = datetime.datetime.now()

    duration = (stop - start).total_seconds()
    print(duration)

In [None]:
# Both flags must be truthy. The previous ``A and B is True`` parsed as
# ``A and (B is True)``, which only worked by accident for plain booleans.
if IMAGER_PYTHON and IMAGER_CPP:
    # Largest absolute pixel difference between the Python and C++ outputs.
    bind_diffmax = np.max(np.abs(bind_image - image))   # images
    print(bind_diffmax)
    bind_re_diffmax = np.max(np.abs(bind_beam - beam))  # beams
    print(bind_re_diffmax)

    # assert bind_diffmax < 2e-15

## Plot results

In [None]:
if IMAGER_PYTHON and IMAGER_CPP is True:
    # %matplotlib notebook
    %matplotlib inline
    fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(12,12))
    if WENSS_DATA is True:
        clim = (-0.1, 5)
    else:
        clim = (-0.1, 0.7)    
    # xlim = (250,750)
    xlim = (450,550)
    # xlim = (550,800)
    ylim = xlim

    im_plot = axes[0, 0].imshow(image, clim=clim)
    axes[0, 1].imshow(beam, clim=clim)
    axes[1, 0].imshow(bind_image, clim=clim)
    axes[1, 1].imshow(bind_beam, clim=clim)

    #img_ax.set_xlim(*xlim)
    #img_ax.set_ylim(*ylim)
    axes[0, 0].set_title('image python')
    axes[1, 0].set_title('image cpp')

    x_range = xlim[1]-xlim[0]
    y_range = ylim[1]-ylim[0]
    beam_xlim = ( beam.shape[1]/2 - x_range/2, beam.shape[1]/2 + x_range/2)
    beam_ylim = ( beam.shape[0]/2 - y_range/2, beam.shape[0]/2 + y_range/2)
    #bm_ax.set_xlim(beam_xlim)
    #bm_ax.set_ylim(beam_ylim)
    axes[0, 1].set_title('beam python')
    axes[1, 1].set_title('beam cpp')

    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(im_plot, cax=cbar_ax)
    #beam_xlim, beam_ylim

## Sourcefinding

In [None]:
from fastimgproto.sourcefind.image import SourceFindImage

# Run the source finder on the real part of each available image, with the
# same 50-sigma detection and analysis thresholds for both.
if IMAGER_PYTHON is True:
    detection_n_sigma = analysis_n_sigma = 50
    sfimage = SourceFindImage(
        data=np.real(image),
        detection_n_sigma=detection_n_sigma,
        analysis_n_sigma=analysis_n_sigma,
    )

if IMAGER_CPP is True:
    detection_n_sigma = analysis_n_sigma = 50
    bind_sfimage = SourceFindImage(
        data=np.real(bind_image),
        detection_n_sigma=detection_n_sigma,
        analysis_n_sigma=analysis_n_sigma,
    )

In [None]:
# Peak value of each island found in the Python image. Only the Python run
# is required; previously this was also skipped whenever the C++ run was off.
if IMAGER_PYTHON:
    for isl in sfimage.islands:
        print(isl.params.extremum)

In [None]:
# Peak value of each island found in the C++ image. Only the C++ run is
# required; previously this was also skipped whenever the Python run was off.
if IMAGER_CPP:
    for isl in bind_sfimage.islands:
        print(isl.params.extremum)

In [None]:
# Differences
if IMAGER_PYTHON and IMAGER_CPP:
    # Both imagers should find the same islands. Compare peak values pairwise
    # (assumes both island lists come back in the same order -- TODO confirm).
    assert len(sfimage.islands) == len(bind_sfimage.islands), (
        "Python and C++ images produced different numbers of islands")

    for py_isl, cpp_isl in zip(sfimage.islands, bind_sfimage.islands):
        print(abs(py_isl.params.extremum.value - cpp_isl.params.extremum.value))