# AW-Projection

A tweakable end-to-end run-through for testing AW-projection imaging.

In [None]:
import numpy as np
import os
import datetime

import astropy.units as u
import astropy.constants as const
import fastimgproto.imager as imager
import fastimgproto.visibility as visibility
import fastimgproto.gridder.conv_funcs as kfuncs

from astropy.coordinates import Angle, SkyCoord, AltAz, EarthLocation
from astropy.time import Time

from fastimgproto.skymodel.helpers import SkyRegion, SkySource
from fastimgproto.sourcefind.image import SourceFindImage
from fastimgproto.telescope.readymade import Meerkat
from fastimgproto.telescope.readymade import VLA_C

from fastimgproto.gridder.conv_funcs import Pillbox
from fastimgproto.gridder.conv_funcs import Triangle
from fastimgproto.gridder.conv_funcs import Sinc
from fastimgproto.gridder.conv_funcs import Gaussian
from fastimgproto.gridder.conv_funcs import GaussianSinc
from fastimgproto.gridder.conv_funcs import PSWF

from fastimgproto.gridder import akernel_generation

from fastimgproto.bindings import cpp_image_visibilities, CppKernelFuncs, CppFFTRoutines, CppInterpolation
from fastimgproto.bindings import PYTHON_KERNELS

from astropy.io import fits
from astropy.table import Table
from astropy.wcs import WCS


In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# Global plot defaults:
#  - draw image pixels in cartesian order (y-positive == upwards);
#  - use larger figures.
plt.rcParams.update({
    'image.origin': 'lower',
    'figure.figsize': (10, 10),
})

## Notebook settings

In [None]:
# Notebook switches (each is either True or False).
IMAGER_PYTHON = True    # run the Python imager cell
IMAGER_CPP = True       # run the C++ imager (fastimgproto.bindings) cell
WENSS_DATA = False      # True: WENSS source list; False: five synthetic 1 Jy sources
PRIMARY_BEAM = False    # simulate visibilities with the primary-beam coefficients


In [None]:
# Primary-beam coefficients shared by the visibility simulation, both imagers
# and the A-kernel generation (presumably polynomial coefficients — confirm
# against akernel_generation).
pbeam_coefs = np.array((0, 0, 0, 0, 1, 0, 1, 0, 0))

In [None]:
# Array layout: VLA in C configuration.
telescope = VLA_C()

print("Telescope with {n_ant} antennae == {n_bl} baselines".format(
    n_ant=len(telescope.ant_local_xyz),
    n_bl=len(telescope.baseline_local_xyz),
))
print("Centre: {lon!r}, {lat!r}".format(lon=telescope.lon, lat=telescope.lat))

In [None]:
# Field centre of the simulated observation.
pointing_centre = SkyCoord(ra=194.24 * u.deg, dec=47.339 * u.deg)

# Observing frequency and the corresponding wavelength, lambda = c / nu.
obs_central_frequency = 74 * u.MHz
wavelength = const.c / obs_central_frequency

# First transit of the field's RA after the start of 2017.
transit_time = telescope.next_transit(pointing_centre.ra, start_time=Time('2017-01-01'))

In [None]:
# Horizontal coordinates of the pointing centre at transit, as seen from the
# array centre; the final expression displays the altitude in degrees.
altaz = pointing_centre.transform_to(AltAz(location=telescope.centre, obstime=transit_time))
altaz.alt.deg

In [None]:
# Sample the track at nstep evenly spaced times from 6 h before to 6 h after transit.
nstep = 10
obs_times = transit_time + np.linspace(-6, 6, nstep) * u.hr
print("Generating UVW-baselines for {n} timesteps".format(n=nstep))

uvw_m, lha = telescope.uvw_tracking_skycoord(pointing_centre, obs_times)

# Express UVW in wavelengths (multiples of lambda) for the rest of the notebook.
uvw_lambda = (uvw_m / wavelength).to(u.dimensionless_unscaled).value

# Build the simulated sky: either a WENSS source list or a few synthetic
# 1 Jy point sources near the pointing centre.
if WENSS_DATA:
    # Columns: id, RA (h, m, s), Dec (deg, arcmin, arcsec), flux (multiplied
    # by 2.83 mJy below). A plain ndarray is used because np.matrix is no
    # longer recommended by NumPy.
    sources_array = np.array([
        [ 1, 11, 47, 52.21, 48, 18, 49.2, 2033], [ 2, 11, 51, 9.26, 47, 28, 55.7, 2889],
        [ 3, 11, 53, 24.60, 49, 31, 9.6, 4041], [ 4, 11, 53, 46.43, 51, 17, 4.6, 2328],
        [ 5, 11, 54, 20.56, 45, 23, 29.7, 3012], [ 6, 11, 56, 9.31, 44, 50, 20.0, 2223],
        [ 7, 11, 59, 13.55, 53, 53, 7.8, 4673], [ 8, 11, 59, 13.83, 53, 53, 7.1, 4595],
        [ 9, 12, 0, 31.11, 45, 48, 42.3, 3170], [10, 12, 7, 14.14, 54, 7, 54.0, 2677],
        [11, 12, 9, 13.62, 43, 39, 18.5, 7220], [12, 12, 13, 29.92, 48, 23, 12.9, 2021],
        [13, 12, 15, 29.88, 53, 35, 50.2, 7827], [14, 12, 22, 51.88, 50, 26, 55.0, 2570],
        [15, 12, 24, 28.60, 42, 6, 34.5, 5127], [16, 12, 24, 29.19, 42, 6, 49.7, 3565],
        [17, 12, 30, 34.50, 41, 38, 58.1, 2081], [18, 12, 34, 29.72, 41, 9, 38.6, 2239],
        [19, 12, 36, 50.00, 36, 55, 18.0, 3277], [20, 12, 41, 56.97, 57, 30, 44.6, 4347],
        [21, 12, 44, 49.15, 40, 48, 6.5, 2139], [22, 12, 44, 49.46, 36, 9, 23.0, 2803],
        [23, 12, 46, 38.99, 56, 49, 21.1, 2472], [24, 12, 46, 46.05, 38, 41, 40.0, 2653],
        [25, 12, 47, 7.40, 49, 0, 19.5, 2430], [26, 12, 49, 23.07, 44, 44, 46.5, 2240],
        [27, 12, 51, 42.70, 50, 34, 22.8, 4672], [28, 12, 51, 46.60, 50, 34, 34.1, 2960],
        [29, 12, 52, 8.60, 52, 45, 30.5, 4610], [30, 12, 52, 16.79, 47, 15, 36.4, 4378],
        [31, 12, 52, 26.33, 56, 34, 20.6, 6792], [32, 12, 56, 17.84, 37, 13, 42.8, 2917],
        [33, 12, 56, 57.42, 47, 20, 20.8, 16888], [34, 12, 57, 23.56, 36, 44, 26.4, 2438],
        [35, 12, 57, 23.88, 36, 44, 19.1, 2312], [36, 12, 58, 1.93, 44, 35, 21.7, 3703],
        [37, 13, 0, 13.75, 38, 4, 32.9, 2654], [38, 13, 0, 32.93, 40, 9, 7.9, 5688],
        [39, 13, 3, 43.99, 37, 56, 9.9, 2046], [40, 13, 4, 28.85, 53, 50, 0.9, 3118],
        [41, 13, 4, 57.84, 38, 32, 29.8, 2462], [42, 13, 16, 12.83, 45, 4, 38.7, 2245],
        [43, 13, 19, 46.03, 51, 48, 11.7, 3071], [44, 13, 21, 13.87, 42, 34, 42.2, 2858],
        [45, 13, 21, 18.07, 42, 35, 0.6, 6188], [46, 13, 21, 21.65, 42, 35, 16.3, 3329],
        [47, 13, 23, 24.08, 41, 15, 12.9, 2293], [48, 13, 23, 24.29, 41, 15, 15.1, 2218],
        [49, 13, 26, 2.25, 36, 47, 59.3, 2596], [50, 13, 26, 13.98, 49, 34, 34.4, 2231],
        [51, 13, 27, 37.17, 55, 4, 6.2, 2926], [52, 13, 29, 54.91, 47, 12, 26.8, 3789],
        [53, 13, 31, 37.37, 50, 7, 57.4, 3505], [54, 13, 31, 37.38, 50, 7, 57.6, 3478],
        [55, 13, 35, 19.45, 41, 0, 2.4, 3836], [56, 13, 37, 38.58, 47, 41, 47.7, 2036],
        [57, 13, 37, 38.65, 47, 41, 49.0, 2070], [58, 13, 38, 49.36, 38, 51, 14.6, 12507],
        [59, 13, 38, 49.37, 38, 51, 14.5, 12494], [60, 13, 41, 33.60, 53, 44, 50.1, 3607],
        [61, 13, 41, 44.87, 46, 57, 17.2, 2677], [62, 13, 45, 26.34, 49, 46, 32.7, 8796],
        [63, 13, 45, 33.08, 42, 50, 14.7, 3242], [64, 13, 46, 23.66, 48, 13, 10.0, 2034],
        [65, 13, 49, 34.67, 53, 41, 17.4, 2186], [66, 14, 0, 19.00, 53, 36, 59.8, 2900],
    ])

    # Alternative, smaller source lists:
    #sources_array = np.array([[33, 12, 56, 57.42, 47, 20, 20.8, 16888], [13, 12, 15, 29.88, 53, 35, 50.2, 7827]])
    #                          [11, 12, 9, 13.62, 43, 39, 18.5, 7220]])

    # Sexagesimal -> decimal degrees (RA hours * 15). Dec is assembled as
    # d + m/60 + s/3600, which is only correct for non-negative declinations
    # (true for every row above).
    ra_deg = (sources_array[:, 1] + sources_array[:, 2] / 60 + sources_array[:, 3] / 3600) * 15
    dec_deg = sources_array[:, 4] + sources_array[:, 5] / 60 + sources_array[:, 6] / 3600

    # NOTE(review): the 2.83 factor is presumably a spectral extrapolation of
    # the catalogue flux to 74 MHz — confirm.
    transient_sources = [
        SkySource(position=SkyCoord(ra=ra * u.deg, dec=dec * u.deg),
                  flux=flux * 2.83 * u.mJy)
        for ra, dec, flux in zip(ra_deg, dec_deg, sources_array[:, 7])
    ]

    steady_sources = []

else:
    # No steady sources by default. To add one 2 deg south-west of the
    # pointing centre (RA and Dec both decreased), uncomment below.
    steady_sources = []
    #extra_src_position = SkyCoord(ra=pointing_centre.ra - 2.0 * u.deg,
    #                              dec=pointing_centre.dec - 2.0 * u.deg, )
    #steady_sources = [
    #    SkySource(extra_src_position, flux=1.0 * u.Jy),
    #]

    # Transient sources: five 1 Jy sources on a diagonal running south-west
    # from the pointing centre, offset by 0, 3, 6, 9 and 12 deg in both RA and Dec.
    transient_posn = SkyCoord(
        ra=pointing_centre.ra - 0.0 * u.deg,
        dec=pointing_centre.dec - 0.0 * u.deg)
    transient_posn2 = SkyCoord(
        ra=pointing_centre.ra - 3.0 * u.deg,
        dec=pointing_centre.dec - 3.0 * u.deg)
    transient_posn3 = SkyCoord(
        ra=pointing_centre.ra - 6.0 * u.deg,
        dec=pointing_centre.dec - 6.0 * u.deg)
    transient_posn4 = SkyCoord(
        ra=pointing_centre.ra - 9.0 * u.deg,
        dec=pointing_centre.dec - 9.0 * u.deg)
    transient_posn5 = SkyCoord(
        ra=pointing_centre.ra - 12.0 * u.deg,
        dec=pointing_centre.dec - 12.0 * u.deg)

    transient_sources = [
        SkySource(position=transient_posn, flux=1 * u.Jy),
        SkySource(position=transient_posn2, flux=1 * u.Jy),
        SkySource(position=transient_posn3, flux=1 * u.Jy),
        SkySource(position=transient_posn4, flux=1 * u.Jy),
        SkySource(position=transient_posn5, flux=1 * u.Jy),
    ]

# Sources to simulate. Only the transient list is used; steady sources are
# not included here.
all_sources = transient_sources

# Shift local hour angles above 12 h down by 24 h so they vary between
# negative and positive values.
# NOTE(review): np.array(lha) drops the unit, so subtracting 24.0 is only
# correct if lha is expressed in hourangle — confirm.
lha = np.array(lha) - np.where(lha > 12.0 * u.hourangle, 24.0, 0.0)

# Simulate incoming data; includes transient sources, noise:
print("Simulating visibilities")
if PRIMARY_BEAM:
    data_vis = visibility.visibilities_for_source_list_and_pbeam(
        pointing_centre,
        source_list=all_sources,
        uvw=uvw_lambda,
        lha=lha,
        pbeam_coefs=pbeam_coefs)
else:
    data_vis = visibility.visibilities_for_source_list(
        pointing_centre,
        source_list=all_sources,
        uvw=uvw_lambda)

# Optional noise (disabled):
#vis_noise_level = 0.1 * u.Jy
#data_vis = visibility.add_gaussian_noise(vis_noise_level, data_vis)

# Uniform weights, one per visibility. np.float64 is used because the
# np.float_ alias was removed in NumPy 2.0.
vis_weights = np.ones(len(data_vis), dtype=np.float64)

In [None]:
# Parallactic angle for every (shifted) local hour angle of the observation.
# NOTE(review): the third argument is the pointing-centre RA; a parallactic
# angle normally depends on hour angle, declination and observatory latitude —
# confirm the signature of visibility.parallatic_angle.
dec_rad = np.deg2rad(pointing_centre.dec.value)
ra_rad = np.deg2rad(pointing_centre.ra.value)

uvw_parangles = np.empty_like(lha)
for step, hour_angle in enumerate(lha):
    uvw_parangles[step] = visibility.parallatic_angle(hour_angle, dec_rad, ra_rad)

In [None]:
# Invert the sign of v in place. This runs after the visibilities were
# simulated, so only the imagers see the flipped v. Because it is in place,
# re-running this cell flips v back.
uvw_lambda[:, 1] *= -1

# Image geometry (kept per dataset so each can be tweaked independently).
if WENSS_DATA:
    image_size = 2048 * u.pixel
    cell_size = 60 * u.arcsecond
    # Finer alternative:
    #image_size = 16384 * u.pixel
    #cell_size = 5 * u.arcsecond
else:
    image_size = 2048 * u.pixel
    cell_size = 60 * u.arcsecond

# Expected (x pixel, y pixel, flux in Jy) of every simulated source, relative
# to the image centre. Flux is converted to Jy so WENSS sources (given in mJy)
# and the synthetic sources (given in Jy) end up in the same unit.
cell_sin = np.sin(cell_size.to(u.rad).value)
centre_pix = image_size.value / 2
expected_sources = np.empty(shape=(len(all_sources), 3))
for idx, src in enumerate(all_sources):
    l, m = visibility.calculate_direction_cosines(pointing_centre, src)
    # x uses -l (presumably so RA increases to the left, as on the sky — confirm).
    expected_sources[idx] = [-l / cell_sin + centre_pix,
                             m / cell_sin + centre_pix,
                             src.flux.to(u.Jy).value]



## Read the CASA (casapy) FITS image

In [None]:
# Read the CASA-simulated reference image; its header supplies the WCS used
# for all the sky plots below.
if WENSS_DATA:
    fits_image_filename = 'STP_simulated_wenss_corrected_wplanes.image.fits'
else:
    fits_image_filename = 'STP_simulated_vlac_5src_corrected_wplanes.image.fits'

# Use a context manager so the file handle is closed. The pixel data are
# copied because astropy may memory-map them from the (now closed) file.
with fits.open(fits_image_filename) as hdul:
    casa_header = hdul[0].header
    # Take the 2-D plane at index 0 of the two leading cube axes
    # (presumably Stokes and frequency — confirm from the header).
    casa_image = np.array(hdul[0].data[0][0])

# Setup WCS
wcs = WCS(casa_header)

## Set parameters

In [None]:
#############################################################################################
#############################################################################################
#############################################################################################
#############################################################################################

import stp_python

# Convolution kernel / gridding
kernel_func_name = 'gaussian'
kernel_support = 3
kernel_exact = False
kernel_oversampling = 8
kernel_trunc_perc = 1.0
padding_factor = 1.0
gridding_correction = True
analytic_gcf = True
generate_beam = False

# FFT (only passed to the C++ imager)
fft_routine = 'estimate'
fft_wisdom_filename = "WisdomFile_STP.fftw"

# W-planes and W-kernel generation
num_wplanes = 128
wplanes_median = False
max_wpconv_support = 127
hankel_opt = True
hankel_proj_slice = False
interp_type = 'cubic'  # 'linear' , 'cubic'
undersampling_opt = 1

# A-projection and pointing
aproj_opt = False
aproj_numtimesteps = 0
aproj_mask_perc = 0.0
obs_dec = pointing_centre.dec.value
obs_ra = pointing_centre.ra.value

#############################################################################################
#############################################################################################
#############################################################################################
#############################################################################################

## Run Python implementation 

In [None]:
import time

if IMAGER_PYTHON:
    # Time with a monotonic, high-resolution clock; wall-clock
    # datetime.now() can jump if the system clock is adjusted mid-run.
    start = time.perf_counter()

    # The Python kernel is built with its truncation radius equal to the
    # kernel support.
    trunc = kernel_support
    kernel_func = PYTHON_KERNELS[kernel_func_name](trunc=trunc)

    # Python reference imager. padding_factor, generate_beam, fft_routine,
    # fft_wisdom_filename and hankel_proj_slice are not passed here; they are
    # only given to the C++ imager call.
    image, beam = imager.image_visibilities(
        data_vis,
        vis_weights,
        uvw_lambda,
        image_size=image_size,
        cell_size=cell_size,
        kernel_func=kernel_func,
        kernel_support=kernel_support,
        kernel_exact=kernel_exact,
        kernel_oversampling=kernel_oversampling,
        gridding_correction=gridding_correction,
        analytic_gcf=analytic_gcf,
        num_wplanes=num_wplanes,
        wplanes_median=wplanes_median,
        max_wpconv_support=max_wpconv_support,
        hankel_opt=hankel_opt,
        interp_type=interp_type,
        undersampling_opt=undersampling_opt,
        kernel_trunc_perc=kernel_trunc_perc,
        aproj_numtimesteps=aproj_numtimesteps,
        obs_dec=obs_dec,
        obs_ra=obs_ra,
        lha=lha,
        pbeam_coefs=pbeam_coefs,
        aproj_opt=aproj_opt,
        aproj_mask_perc=aproj_mask_perc,
        )

    stop = time.perf_counter()
    duration = stop - start  # seconds
    print(duration)

## Run CPP implementation

In [None]:
import time

if IMAGER_CPP:
    # Time with a monotonic, high-resolution clock; wall-clock
    # datetime.now() can jump if the system clock is adjusted mid-run.
    start = time.perf_counter()

    # C++ imager via the Python bindings, using the shared parameter set.
    fft_image, fft_beam = cpp_image_visibilities(
        data_vis,
        vis_weights,
        uvw_lambda,
        image_size=image_size,
        cell_size=cell_size,
        padding_factor=padding_factor,
        kernel_func_name=kernel_func_name,
        kernel_support=kernel_support,
        kernel_exact=kernel_exact,
        kernel_oversampling=kernel_oversampling,
        generate_beam=generate_beam,
        gridding_correction=gridding_correction,
        analytic_gcf=analytic_gcf,
        fft_routine=fft_routine,
        fft_wisdom_filename=fft_wisdom_filename,
        num_wplanes=num_wplanes,
        wplanes_median=wplanes_median,
        max_wpconv_support=max_wpconv_support,
        hankel_opt=hankel_opt,
        hankel_proj_slice=hankel_proj_slice,
        undersampling_opt=undersampling_opt,
        kernel_trunc_perc=kernel_trunc_perc,
        interp_type=interp_type,
        aproj_numtimesteps=aproj_numtimesteps,
        obs_dec=obs_dec,
        obs_ra=obs_ra,
        aproj_opt=aproj_opt,
        aproj_mask_perc=aproj_mask_perc,
        lha=lha,
        pbeam_coefs=pbeam_coefs,
        )

    stop = time.perf_counter()
    duration = stop - start  # seconds
    print(duration)

    # fftshift swaps the array halves so index 0 moves to the image centre
    # (presumably the C++ output is origin-at-corner — confirm). The Python
    # imager's output is not shifted.
    fft_image = np.fft.fftshift(fft_image)
    #fft_beam = np.fft.fftshift(fft_beam)

## Plot results

In [None]:
%matplotlib inline

def _plot_wcs_image(data, title=None, output_file=None):
    """Plot an image on RA/Dec axes taken from the CASA image's WCS.

    Uses the notebook-level ``wcs`` and ``WENSS_DATA``. Colour limits depend on
    ``WENSS_DATA``, and a colour bar is drawn to the right of the image.

    Parameters
    ----------
    data : 2-D array
        Image pixels to display (presumably on the same pixel grid as the
        CASA image — confirm).
    title : str, optional
        Axes title; no title is drawn when ``None``.
    output_file : str, optional
        If given, the current figure is saved to this path.

    Returns
    -------
    (fig, ax) : the created figure and WCS axes.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = plt.subplot(projection=wcs, slices=('x', 'y', 0, 0))
    ax.set_xlabel("J2000 Right Ascension")
    ax.set_ylabel("J2000 Declination")
    if title is not None:
        ax.set_title(title)

    lon = ax.coords[0]
    lat = ax.coords[1]
    lon.set_major_formatter('hh:mm')
    lat.set_ticks(number=7)
    lon.set_ticks(number=9)

    clim = (-0.1, 4) if WENSS_DATA else (-0.1, 0.3)

    # NOTE: interpolation=None means "use rcParams['image.interpolation']",
    # not 'none' (no interpolation) — confirm which is intended.
    im_plot = ax.imshow(data, cmap='Greys', clim=clim, interpolation=None, origin='lower')

    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.17, 0.05, 0.66])
    fig.colorbar(im_plot, cax=cbar_ax)

    if output_file is not None:
        plt.savefig(output_file)
    return fig, ax


if IMAGER_CPP:
    # C++ dirty image: no title, saved to PDF.
    if WENSS_DATA:
        output_file = 'dirty_image_awproj_wenss.pdf'
    else:
        output_file = 'dirty_image_awproj_5diag_pbeam_corrected.pdf'
    _plot_wcs_image(fft_image, output_file=output_file)

if IMAGER_PYTHON:
    # Python dirty image: titled, shown only (not saved).
    _plot_wcs_image(image, title='STP-Python AW-projection Dirty Image')





In [None]:
# A-kernel generated from the primary-beam coefficients on a grid the size of
# the padded image. Multiply before truncating so a fractional padding_factor
# (e.g. 1.5) is honoured; int(padding_factor) alone would round it down to 1.
kernel_size = int(image_size.value * padding_factor)
# NOTE(review): the second argument is the unpadded image width in radians
# (pixels * cell size); confirm whether it should also scale with padding_factor.
akernel = akernel_generation.generate_akernel(pbeam_coefs, image_size.value * cell_size.to(u.rad).value, kernel_size)

In [None]:
%matplotlib inline

# Show the primary beam as the reciprocal of the A-kernel, normalised so its
# maximum is 1, and save it to PDF.
fig = plt.figure(figsize=(8, 8))
ax = plt.subplot(projection=wcs, slices=('x', 'y', 0, 0))
ax.set_title('Normalised Primary Beam')
ax.set_xlabel("J2000 Right Ascension")
ax.set_ylabel("J2000 Declination")

lon, lat = ax.coords[0], ax.coords[1]
lon.set_major_formatter('hh:mm')
lon.set_ticks(number=9)
lat.set_ticks(number=7)

# Compute the reciprocal once and reuse it for the normalisation.
inv_akernel = 1 / akernel
clim = (0, 1)
im_plot = ax.imshow(inv_akernel / np.max(inv_akernel), cmap='gray', clim=clim,
                    interpolation=None, origin='lower')

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.17, 0.05, 0.66])
fig.colorbar(im_plot, cax=cbar_ax)

output_file = 'primary_beam.pdf'
plt.savefig(output_file)


## Sourcefinding

In [None]:
from fastimgproto.sourcefind.image import SourceFindImage

# Run the source finder on the real part of each dirty image, using the same
# detection/analysis thresholds (in sigma) for both imagers.
if IMAGER_PYTHON is True:
    detection_n_sigma = analysis_n_sigma = 20
    sfimage = SourceFindImage(
        data=np.real(image).astype(np.float64),
        detection_n_sigma=detection_n_sigma,
        analysis_n_sigma=analysis_n_sigma,
    )

if IMAGER_CPP is True:
    detection_n_sigma = analysis_n_sigma = 20
    bind_sfimage = SourceFindImage(
        data=np.real(fft_image).astype(np.float64),
        detection_n_sigma=detection_n_sigma,
        analysis_n_sigma=analysis_n_sigma,
    )

In [None]:
# List the extremum (peak) of every island found in the Python image.
if IMAGER_PYTHON is True:
    for island in sfimage.islands:
        print(island.params.extremum)

In [None]:
# List the extremum (peak) of every island found in the C++ image.
if IMAGER_CPP is True:
    for island in bind_sfimage.islands:
        print(island.params.extremum)

In [None]:
# Collect the C++ image's detections as rows of
# (moments-fit x centre, moments-fit y centre, extremum value).
if IMAGER_CPP:
    found_sources = np.empty(shape=(len(bind_sfimage.islands), 3))
    for idx, isl in enumerate(bind_sfimage.islands):
        # Read the extremum through isl.params, as the other source-finding
        # cells do.
        found_sources[idx] = [isl.params.moments_fit.x_centre,
                              isl.params.moments_fit.y_centre,
                              isl.params.extremum.value]

print("Expected:", len(expected_sources))
# found_sources only exists when the C++ imager ran; printing it
# unconditionally raised NameError otherwise.
if IMAGER_CPP:
    print("Found:", len(found_sources))

In [None]:
# Match each expected source to its nearest C++ detection and report the
# (expected - found) differences in x, y (pixels) and flux.
if IMAGER_CPP:
    # A detection only counts as a match when its squared pixel distance from
    # the expected position is strictly below this value (within sqrt(2) px).
    max_match_dist2 = 2
    # One row per expected source; unmatched rows stay NaN instead of the
    # uninitialised memory np.empty would leave behind.
    diff_sources = np.full((len(expected_sources), 3), np.nan)

    for i, expsrc in enumerate(expected_sources):
        best_idx = -1
        if len(found_sources) > 0:
            # Squared distance in the (x, y) plane only; flux plays no part in matching.
            dist2 = np.sum((found_sources[:, :2] - expsrc[:2]) ** 2, axis=1)
            nearest = int(np.argmin(dist2))
            if dist2[nearest] < max_match_dist2:
                best_idx = nearest

        if best_idx > -1:
            diff_sources[i] = expsrc - found_sources[best_idx]
            print(i, ": ", np.round(expsrc, 3), np.round(found_sources[best_idx], 3), np.round(diff_sources[i], 5))
        else:
            print(i, ": ", expsrc, "Not found")

In [None]:
# Compare the island peak values found in the Python and C++ images.
# The original test `IMAGER_PYTHON and IMAGER_CPP is True` parsed as
# `IMAGER_PYTHON and (IMAGER_CPP is True)`; plain truthiness says what is meant.
if IMAGER_PYTHON and IMAGER_CPP:
    assert len(sfimage.islands) == len(bind_sfimage.islands), (
        "Python and C++ images yielded different numbers of islands")

    # NOTE(review): pairing by position assumes both source-finder runs return
    # their islands in the same order — confirm.
    for py_isl, cpp_isl in zip(sfimage.islands, bind_sfimage.islands):
        print(abs(py_isl.params.extremum.value - cpp_isl.params.extremum.value))