In [None]:
import numpy as np
import os
import datetime

import astropy.units as u
import astropy.constants as const
import fastimgproto.imager as imager
import fastimgproto.visibility as visibility
import fastimgproto.gridder.conv_funcs as kfuncs

from astropy.coordinates import Angle, SkyCoord, AltAz, EarthLocation
from astropy.time import Time

from fastimgproto.skymodel.helpers import SkyRegion, SkySource
from fastimgproto.sourcefind.image import SourceFindImage
from fastimgproto.telescope.readymade import Meerkat

from fastimgproto.gridder.conv_funcs import Pillbox
from fastimgproto.gridder.conv_funcs import Triangle
from fastimgproto.gridder.conv_funcs import Sinc
from fastimgproto.gridder.conv_funcs import Gaussian
from fastimgproto.gridder.conv_funcs import GaussianSinc
from fastimgproto.gridder.conv_funcs import PSWF

from fastimgproto.bindings import cpp_image_visibilities, CppKernelFuncs


In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
# Plot image pixels in cartesian ordering (i.e. y-positive == upwards):
plt.rcParams['image.origin'] = 'lower'
# Make plots bigger
plt.rcParams['figure.figsize'] = 10, 10

In [None]:
telescope = Meerkat()

# Array summary: number of antennas and the number of baselines they form.
print(
    "Telescope with {} antennae == {} baselines".format(
        len(telescope.ant_local_xyz),
        len(telescope.baseline_local_xyz),
    )
)
# Longitude / latitude of the array centre.
print("Centre: {!r}, {!r}".format(telescope.lon, telescope.lat))

In [None]:
# Pointing centre of the observation (RA = 0 deg, Dec = -30 deg, default SkyCoord frame).
pointing_centre = SkyCoord(ra=0 * u.deg, dec=-30 * u.deg)
# Central observing frequency and its wavelength, lambda = c / nu.
obs_central_frequency = 30.0 * u.MHz
wavelength = const.c / obs_central_frequency
# Transit of the pointing centre's RA, searched from the start of 2010
# (presumably the first transit after `start_time` — confirm in Meerkat.next_transit).
transit_time = telescope.next_transit(
    pointing_centre.ra,
    start_time=Time('2010-01-01'),
)

In [None]:
# Alt-az coordinates of the pointing centre at transit, as seen from the array centre.
transit_frame = AltAz(obstime=transit_time, location=telescope.centre)
altaz = pointing_centre.transform_to(transit_frame)
# Altitude in degrees (shown as the cell output).
altaz.alt.deg

In [None]:
nstep = 10
# Observation spans +/- 4 hours around transit, sampled at `nstep` evenly spaced times.
obs_times = transit_time + np.linspace(-4, 4, nstep) * u.hr
print("Generating UVW-baselines for {} timesteps".format(nstep))
uvw_m = telescope.uvw_tracking_skycoord(pointing_centre, obs_times)
# From here on we use UVW as multiples of wavelength, lambda:
uvw_lambda = (uvw_m / wavelength).to(u.dimensionless_unscaled).value


# Additional source to North-East of pointing centre
extra_src_position = SkyCoord(ra=pointing_centre.ra + 0.01 * u.deg,
                              dec=pointing_centre.dec + 0.01 * u.deg, )

steady_sources = [
    SkySource(pointing_centre, flux=1 * u.Jy),
    SkySource(extra_src_position, flux=0.4 * u.Jy),
]

# Transient sources: 1 Jy each, offset by equal amounts in RA and Dec
# (i.e. towards the South-West), from on-axis out to 13 deg.
transient_offsets_deg = [0.0, 4.0, 7.0, 10.0, 13.0]
transient_positions = [
    SkyCoord(ra=pointing_centre.ra - offset * u.deg,
             dec=pointing_centre.dec - offset * u.deg)
    for offset in transient_offsets_deg
]
# Keep the individual position names available for later / interactive use.
(transient_posn, transient_posn2, transient_posn3,
 transient_posn4, transient_posn5) = transient_positions

transient_sources = [SkySource(position=posn, flux=1 * u.Jy)
                     for posn in transient_positions]

# Sources used for the simulation.
# NOTE(review): only the transient sources are simulated; `steady_sources` is
# defined above but not included. Use `steady_sources + transient_sources`
# if the steady sources are meant to be present.
all_sources = transient_sources

# Simulate incoming data from `all_sources`. Noise is currently disabled;
# uncomment the two lines below to add Gaussian noise to the visibilities.
print("Simulating visibilities")
data_vis = visibility.visibilities_for_source_list(
    pointing_centre,
    source_list=all_sources,
    uvw=uvw_lambda)

#vis_noise_level = 0.1 * u.Jy
#data_vis = visibility.add_gaussian_noise(vis_noise_level, data_vis)

# Unit weight for every visibility. `np.float_` was removed in NumPy 2.0;
# `np.float64` is the same dtype and works on every NumPy version.
vis_weights = np.ones(len(data_vis), dtype=np.float64)

## Set imaging parameters

In [None]:
import stp_python

# --- Parameters shared by the Python and C++ imagers ---
image_size = 1024 * u.pixel
cell_size = 100 * u.arcsecond
kernel_support = 3
kernel_exact = False
kernel_oversampling = 8
num_wplanes = 80
wplanes_median = True
max_wpconv_support = 50
analytic_gcf = False
hankel_opt = True
undersampling_opt = 1
interp_type = "linear"  # "linear" , "cubic"

# --- Python imager only: kernel object, truncated at the kernel support ---
trunc = kernel_support
kernel_func = kfuncs.PSWF(trunc=trunc)

# --- C++ imager only ---
# Kernel selected by name; its truncation radius is likewise tied to the support.
kernel_func_name = CppKernelFuncs.pswf
kernel_trunc_radius = kernel_support
kernel_trunc_perc = 1.0
generate_beam = True
fft_routine = stp_python.FFTRoutine.FFTW_ESTIMATE_FFT
fft_wisdom_filename = ""

## Run the Python implementation

In [None]:
# Keyword options for the pure-Python imager, gathered from the parameter cell.
python_imager_kwargs = dict(
    image_size=image_size,
    cell_size=cell_size,
    kernel_func=kernel_func,
    kernel_support=kernel_support,
    kernel_exact=kernel_exact,
    kernel_oversampling=kernel_oversampling,
    num_wplanes=num_wplanes,
    wplanes_median=wplanes_median,
    max_wpconv_support=max_wpconv_support,
    analytic_gcf=analytic_gcf,
    hankel_opt=hankel_opt,
    interp_type=interp_type,
    undersampling_opt=undersampling_opt,
)

# Wall-clock timing brackets only the imaging call.
start = datetime.datetime.now()
image, beam = imager.image_visibilities(data_vis, vis_weights, uvw_lambda,
                                        **python_imager_kwargs)
stop = datetime.datetime.now()

In [None]:
# Elapsed wall-clock time between `start` and `stop`, in seconds (cell output).
duration = (stop - start) / datetime.timedelta(seconds=1)
duration

## Run the C++ implementation (via `cpp_image_visibilities` bindings)

In [None]:
# Translate the string interpolation option into the C++ binding's enum.
# Unknown values raise instead of silently falling back to linear interpolation,
# which would make the Python/C++ comparison below misleading.
_CPP_INTERP_TYPES = {
    "linear": stp_python.InterpType.LINEAR,
    "cubic": stp_python.InterpType.CUBIC,
}
try:
    interp_type_cpp = _CPP_INTERP_TYPES[interp_type]
except KeyError:
    raise ValueError("Unsupported interp_type {!r}; expected one of {}".format(
        interp_type, sorted(_CPP_INTERP_TYPES))) from None

# Wall-clock timing brackets only the C++ imaging call.
start = datetime.datetime.now()

bind_image, bind_beam = cpp_image_visibilities(
    data_vis,
    vis_weights,
    uvw_lambda,
    image_size=image_size,
    cell_size=cell_size,
    kernel_func_name=kernel_func_name,
    # Use the configured truncation radius (previously `kernel_support` was
    # passed here, so editing `kernel_trunc_radius` had no effect).
    kernel_trunc_radius=kernel_trunc_radius,
    kernel_support=kernel_support,
    kernel_exact=kernel_exact,
    kernel_oversampling=kernel_oversampling,
    generate_beam=generate_beam,
    fft_routine=fft_routine,
    fft_wisdom_filename=fft_wisdom_filename,
    num_wplanes=num_wplanes,
    wplanes_median=wplanes_median,
    max_wpconv_support=max_wpconv_support,
    analytic_gcf=analytic_gcf,
    hankel_opt=hankel_opt,
    undersampling_opt=undersampling_opt,
    kernel_trunc_perc=kernel_trunc_perc,
    interp_type=interp_type_cpp
    )

stop = datetime.datetime.now()

In [None]:
# Elapsed wall-clock time between `start` and `stop`, in seconds (cell output).
duration = (stop - start) / datetime.timedelta(seconds=1)
duration

In [None]:
# Largest absolute pixel difference between the C++ and Python outputs:
# first for the image, then for the beam.
bind_diffmax, bind_re_diffmax = (
    np.max(np.abs(cpp_arr - py_arr))
    for cpp_arr, py_arr in ((bind_image, image), (bind_beam, beam))
)
print(bind_diffmax)
print(bind_re_diffmax)

# NOTE(review): no tolerance is enforced yet; add e.g.
# `assert bind_diffmax < tol` once an acceptable tolerance is agreed.

## Plot results

In [None]:
# %matplotlib notebook
%matplotlib inline
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(12,12))
clim = (-0.1, 0.7)
# xlim = (250,750)
xlim = (450,550)
# xlim = (550,800)
ylim = xlim

im_plot = axes[0, 0].imshow(image, clim=clim)
axes[0, 1].imshow(beam, clim=clim)
axes[1, 0].imshow(bind_image, clim=clim)
axes[1, 1].imshow(bind_beam, clim=clim)

#img_ax.set_xlim(*xlim)
#img_ax.set_ylim(*ylim)
axes[0, 0].set_title('image python')
axes[1, 0].set_title('image cpp')

x_range = xlim[1]-xlim[0]
y_range = ylim[1]-ylim[0]
beam_xlim = ( beam.shape[1]/2 - x_range/2, beam.shape[1]/2 + x_range/2)
beam_ylim = ( beam.shape[0]/2 - y_range/2, beam.shape[0]/2 + y_range/2)
#bm_ax.set_xlim(beam_xlim)
#bm_ax.set_ylim(beam_ylim)
axes[0, 1].set_title('beam python')
axes[1, 1].set_title('beam cpp')

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im_plot, cax=cbar_ax)
#beam_xlim, beam_ylim

## Source finding

In [None]:
# Source-finding thresholds, shared by both runs so the Python and C++ images
# are analysed identically.
# NOTE(review): presumably multiples of the image noise level (sigma) — confirm
# against SourceFindImage.
detection_n_sigma = 15
analysis_n_sigma = 15


def run_sourcefind(img, detection_n_sigma, analysis_n_sigma):
    """Run ``SourceFindImage`` on the real part of ``img``.

    Parameters
    ----------
    img : array-like
        Image array (possibly complex); only its real part is analysed.
    detection_n_sigma, analysis_n_sigma : number
        Thresholds forwarded unchanged to ``SourceFindImage``.

    Returns
    -------
    SourceFindImage
        The source-finder result for ``img``.
    """
    return SourceFindImage(data=np.real(img),
                           detection_n_sigma=detection_n_sigma,
                           analysis_n_sigma=analysis_n_sigma)


sfimage = run_sourcefind(image, detection_n_sigma, analysis_n_sigma)
bind_sfimage = run_sourcefind(bind_image, detection_n_sigma, analysis_n_sigma)

In [None]:
# Extremum parameters of each island found in the Python-generated image.
py_extrema = [island.params.extremum for island in sfimage.islands]
for extremum in py_extrema:
    print(extremum)

In [None]:
# Extremum parameters of each island found in the C++-generated image.
cpp_extrema = [island.params.extremum for island in bind_sfimage.islands]
for extremum in cpp_extrema:
    print(extremum)

In [None]:
# Differences: compare the Python and C++ source-finding results island by island.
assert len(sfimage.islands) == len(bind_sfimage.islands), (
    "Island count mismatch: python={} cpp={}".format(
        len(sfimage.islands), len(bind_sfimage.islands)))

# Absolute difference of each island's extremum value.
# NOTE(review): islands are paired by list position, which assumes both source
# finders return islands in the same order — confirm.
for py_isl, cpp_isl in zip(sfimage.islands, bind_sfimage.islands):
    print(abs(py_isl.params.extremum.value - cpp_isl.params.extremum.value))