In [None]:
from dsa2000_common.common.ray_utils import TimerLog
import dataclasses
import gc
import itertools
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple

import astropy.time as at
import jax
import jax.numpy as jnp
import numpy as np
from astropy import units as au, coordinates as ac
from astropy.coordinates import offset_by
from dsa2000_common.common.enu_frame import ENU
from tqdm import tqdm

from dsa2000_assets.content_registry import fill_registries
from dsa2000_assets.registries import array_registry
from dsa2000_common.common.fits_utils import ImageModel, save_image_to_fits
from dsa2000_common.common.wgridder import vis_to_image_np
from dsa2000_common.common.array_types import FloatArray
from dsa2000_common.common.mixed_precision_utils import mp_policy
from dsa2000_common.common.quantity_utils import time_to_jnp, quantity_to_jnp, quantity_to_np
from dsa2000_common.common.types import VisibilityCoords
from dsa2000_common.delay_models.base_far_field_delay_engine import build_far_field_delay_engine, \
    BaseFarFieldDelayEngine
from dsa2000_common.delay_models.base_near_field_delay_engine import BaseNearFieldDelayEngine, \
    build_near_field_delay_engine
from dsa2000_common.gain_models.base_spherical_interpolator import BaseSphericalInterpolatorGainModel
from dsa2000_common.gain_models.beam_gain_model import build_beam_gain_model
from dsa2000_common.gain_models.gain_model import GainModel
from dsa2000_common.geodesics.base_geodesic_model import BaseGeodesicModel, build_geodesic_model
from dsa2000_common.visibility_model.source_models.celestial.base_point_source_model import BasePointSourceModel, \
    build_point_source_model
from dsa2000_fm.imaging.base_imagor import fit_beam
from dsa2000_fm.imaging.utils import get_array_image_parameters
from dsa2000_fm.systematics.dish_aperture_effects import build_dish_aperture_effects, DishApertureEffects



def create_effects(array_name: str, calibration_error_stddev: au.Quantity, pointing_offset_stddev: au.Quantity):
    """
    Build the dish-aperture systematics and the calibration phase-error scale for one trial.

    Only pointing errors are configured: the same ``pointing_offset_stddev`` is used for both the elevation and
    the cross-elevation axis. Other aperture effects accepted by ``build_dish_aperture_effects`` (axial focus,
    feed offsets, astigmatism, surface errors) are not passed and so keep the builder's defaults.

    Args:
        array_name: name matched against the array registry.
        calibration_error_stddev: angular standard deviation of the calibration phase errors.
        pointing_offset_stddev: angular standard deviation of the dish pointing errors.

    Returns:
        ``(dish_aperture_effects, calibration_phase_error_stddev)``, the second item being the numeric value
        of ``calibration_error_stddev`` expressed in radians.
    """
    fill_registries()
    array_spec = array_registry.get_instance(array_registry.get_match(array_name))

    effects = build_dish_aperture_effects(
        dish_diameter=array_spec.get_antenna_diameter(),
        focal_length=array_spec.get_focal_length(),
        elevation_pointing_error_stddev=pointing_offset_stddev,
        cross_elevation_pointing_error_stddev=pointing_offset_stddev,
    )
    phase_error_stddev_rad = calibration_error_stddev.to('rad').value
    return effects, phase_error_stddev_rad


def build_mock_obs_setup(angular_offset: au.Quantity, dec: au.Quantity, array_name: str):
    """
    Assemble a single-integration mock observation containing one 1 Jy point source.

    The phase centre takes the RA of the local zenith at the reference time and the requested declination
    ``dec``. The point source is placed ``angular_offset`` away from the phase centre at a position angle drawn
    uniformly from [0, 2*pi) using the global NumPy RNG.

    Args:
        angular_offset: angular distance of the source from the phase centre.
        dec: declination of the phase centre.
        array_name: name matched against the array registry.

    Returns:
        ``(lmn0, ref_time, obstimes, freqs, chan_width, phase_center, antennas, near_field_delay_engine,
        far_field_delay_engine, beam_model, source_model, geodesic_model)`` where ``lmn0`` is the source's
        far-field direction as computed by the geodesic model.
    """
    fill_registries()
    array_spec = array_registry.get_instance(array_registry.get_match(array_name))
    location = array_spec.get_array_location()

    ref_time = at.Time('2021-01-01T00:00:00', scale='utc')
    num_times = 1
    obstimes = ref_time + np.arange(num_times) * array_spec.get_integration_time()

    # The zenith at ref_time fixes the RA; the declination is replaced by `dec`.
    zenith = ENU(0, 0, 1, location=location, obstime=ref_time).transform_to(ac.ICRS())
    phase_center = ac.ICRS(zenith.ra, dec)

    freqs = array_spec.get_channels()
    antennas = array_spec.get_antennas()

    start_time, end_time = obstimes.min(), obstimes.max()
    far_field_delay_engine = build_far_field_delay_engine(
        antennas=antennas,
        phase_center=phase_center,
        start_time=start_time,
        end_time=end_time,
        ref_time=ref_time
    )
    near_field_delay_engine = build_near_field_delay_engine(
        antennas=antennas,
        start_time=start_time,
        end_time=end_time,
        ref_time=ref_time
    )
    geodesic_model = build_geodesic_model(
        antennas=antennas,
        array_location=location,
        phase_center=phase_center,
        obstimes=obstimes,
        ref_time=ref_time,
        pointings=None
    )

    chan_width = array_spec.get_channel_width()

    # Random position angle for the off-centre source.
    position_angle = np.random.uniform(0, 2 * np.pi) * au.rad
    source_ra, source_dec = offset_by(
        phase_center.ra, phase_center.dec, posang=position_angle, distance=angular_offset
    )
    source = ac.ICRS(source_ra, source_dec)

    lmn0 = geodesic_model.compute_far_field_lmn(source_ra.to('rad').value, source_dec.to('rad').value, 0.)

    # The beam and source models are built from the first/last time and channel only.
    beam_model = build_beam_gain_model(
        array_name=array_name,
        full_stokes=False,
        times=obstimes[[0, -1]],
        ref_time=ref_time,
        freqs=freqs[[0, -1]]
    )
    source_model = build_point_source_model(
        model_freqs=freqs[[0, -1]],
        ra=source.ra[None],
        dec=source.dec[None],
        A=np.ones((1, 2)) * au.Jy
    )

    return (
        lmn0, ref_time, obstimes, freqs, chan_width, phase_center, antennas,
        near_field_delay_engine, far_field_delay_engine, beam_model, source_model, geodesic_model
    )


@jax.jit
def compute_corrupt_visibilties(
        key,
        calibration_phase_error_stddev: FloatArray,
        source_model: BasePointSourceModel,
        beam_model: GainModel,
        visibility_coords: VisibilityCoords,
        near_field_delay_engine: BaseNearFieldDelayEngine,
        far_field_delay_engine: BaseFarFieldDelayEngine,
        geodesic_model: BaseGeodesicModel
):
    """
    Predict model visibilities and corrupt them with random per-antenna calibration phase errors.

    For every (time, antenna, channel) a phase error phi ~ N(0, calibration_phase_error_stddev**2) is drawn, and
    each baseline is multiplied by exp(i * (phi[antenna1] - phi[antenna2])). No thermal noise is added. The
    result is returned; nothing is written to any external buffer.

    Args:
        key: JAX PRNG key used to draw the phase errors.
        calibration_phase_error_stddev: standard deviation of the per-antenna phase errors, in radians.
        source_model: the source model to predict.
        beam_model: the gain (beam) model applied during prediction.
        visibility_coords: the visibility coordinates.
        near_field_delay_engine: the near field delay engine.
        far_field_delay_engine: the far field delay engine.
        geodesic_model: the geodesic model (also supplies the number of antennas).

    Returns:
        [T*B, C] array of corrupted visibilities, flattened over (time, baseline) in row-major order.
    """
    model_vis = source_model.predict(
        visibility_coords=visibility_coords,
        gain_model=beam_model,
        near_field_delay_engine=near_field_delay_engine,
        far_field_delay_engine=far_field_delay_engine,
        geodesic_model=geodesic_model
    )
    num_times, num_baselines, num_chans = np.shape(model_vis)
    num_ants = geodesic_model.num_antennas

    antenna_phase = jax.random.normal(
        key, (num_times, num_ants, num_chans), dtype=jnp.float32
    ) * calibration_phase_error_stddev  # [T, A, C]
    phase_on_ant1 = antenna_phase[:, visibility_coords.antenna1]  # [T, B, C]
    phase_on_ant2 = antenna_phase[:, visibility_coords.antenna2]  # [T, B, C]
    baseline_phase = phase_on_ant1 - phase_on_ant2

    corrupted = model_vis * jax.lax.complex(jnp.cos(baseline_phase), jnp.sin(baseline_phase))
    return jnp.reshape(corrupted, (num_times * num_baselines, num_chans))


@jax.jit
def compute_visibility_coords_and_beam_model(key, far_field_delay_engine, times, freqs, beam_model,
                                             geodesic_model: BaseGeodesicModel,
                                             dish_aperture_effects: DishApertureEffects) -> Tuple[
    VisibilityCoords, BaseSphericalInterpolatorGainModel]:
    """
    Apply dish-aperture effects to the beam and compute the visibility coordinates.

    Args:
        key: JAX PRNG key passed to ``apply_dish_aperture_effects``.
        far_field_delay_engine: engine providing ``compute_visibility_coords``.
        times: observation times.
        freqs: channel frequencies.
        beam_model: the uncorrupted beam model.
        geodesic_model: the geodesic model.
        dish_aperture_effects: the systematics to apply to ``beam_model``.

    Returns:
        ``(visibility_coords, corrupted_beam_model)``; autocorrelations are excluded from the coordinates.
    """
    corrupted_beam_model = dish_aperture_effects.apply_dish_aperture_effects(key, beam_model, geodesic_model)
    visibility_coords = far_field_delay_engine.compute_visibility_coords(
        freqs=freqs,
        times=times,
        with_autocorr=False
    )
    return visibility_coords, corrupted_beam_model


# Jit-compiled beam fitter; `max_central_size` is treated as a static (trace-time constant) argument.
fit_beam_jit = jax.jit(fit_beam, static_argnames=('max_central_size',))


def _image_visibilities(uvw, freqs, vis, dl, dm, center_l, center_m, num_pixel, output_buffer):
    """Zero ``output_buffer`` and grid ``vis`` into it as a ``num_pixel`` x ``num_pixel`` image centred on
    (``center_l``, ``center_m``), using the wgridder settings shared by the PSF and FoV images."""
    output_buffer *= 0.
    vis_to_image_np(
        uvw=uvw,
        freqs=freqs,
        vis=vis,
        pixsize_m=quantity_to_np(dm, 'rad'),
        pixsize_l=quantity_to_np(dl, 'rad'),
        center_l=center_l,
        center_m=center_m,
        npix_l=num_pixel,
        npix_m=num_pixel,
        wgt=None,
        mask=None,
        epsilon=1e-6,
        double_precision_accumulation=False,
        scale_by_n=True,
        normalise=True,
        output_buffer=output_buffer,
        num_threads=72
    )


def _compute_patch_rmss(image, num_patches_per_side: int) -> list:
    """Split ``image`` into a ``num_patches_per_side`` x ``num_patches_per_side`` grid of equal square patches
    and return each patch's RMS in row-major order.

    Trailing pixels beyond ``num_patches_per_side * (num_pixel // num_patches_per_side)`` are ignored."""
    patch_size = np.shape(image)[0] // num_patches_per_side
    patch_rmss = []
    for patch_row in range(num_patches_per_side):
        for patch_col in range(num_patches_per_side):
            patch = image[
                    patch_row * patch_size:(patch_row + 1) * patch_size,
                    patch_col * patch_size:(patch_col + 1) * patch_size
                    ]
            patch_rmss.append(float(np.sqrt(np.mean(patch ** 2))))
    return patch_rmss


def _save_image_fits(file_path, image, phase_center, ref_time, dl, dm, obsfreqs, chan_width,
                     major, minor, posang, object_name):
    """Write a single-channel Stokes-I image (beam parameters in radians) to ``file_path``, overwriting."""
    image_model = ImageModel(
        phase_center=phase_center,
        obs_time=ref_time,
        dl=dl,
        dm=dm,
        freqs=np.mean(obsfreqs)[None],
        bandwidth=len(obsfreqs) * chan_width,
        coherencies=('I',),
        beam_major=np.asarray(major) * au.rad,
        beam_minor=np.asarray(minor) * au.rad,
        beam_pa=np.asarray(posang) * au.rad,
        unit='JY/PIXEL',
        object_name=object_name.upper(),
        image=image[:, :, None, None] * au.Jy  # [num_l, num_m, 1, 1]
    )
    save_image_to_fits(file_path, image_model=image_model, overwrite=True)


def main(plot_folder: str, result_file: str, array_name: str, oversample_factor: float,
         num_patches_per_side: int, save_psf: bool):
    """
    Run an open-ended Monte-Carlo loop measuring sidelobe contamination from a bright off-field source.

    Each iteration draws a random transit declination, source offset from the phase centre, pointing-error
    stddev and calibration phase-error stddev. It simulates corrupted visibilities of a 1 Jy point source and
    images the field. It records the RMS of a grid of patches in that field image. The accumulated results are
    rewritten to ``result_file`` (JSON) after every iteration. The loop never exits on its own.

    Args:
        plot_folder: directory for FITS outputs (created if missing).
        result_file: JSON results file; existing entries are loaded and extended.
        array_name: array registry name.
        oversample_factor: pixel oversampling passed to ``get_array_image_parameters``.
        num_patches_per_side: the field image is split into ``num_patches_per_side**2`` patches.
        save_psf: if True, also image and fit a PSF centred on the source, and save PSF/FoV FITS files.
    """
    os.makedirs(plot_folder, exist_ok=True)
    results = []
    if os.path.exists(result_file):
        with open(result_file, 'r') as f:
            results = json.load(f)
    # Seed from the number of results already stored. A fixed seed would make a resumed run replay exactly
    # the same draws (and PRNG keys) and append duplicate samples. A fresh run still uses seed 0.
    seed = len(results)
    np.random.seed(seed)
    key = jax.random.PRNGKey(seed)

    num_pixel, dl, dm, _, _ = get_array_image_parameters(
        array_name=array_name,
        oversample_factor=oversample_factor,
        threshold=0.5
    )
    fov = num_pixel * dl

    # Reused every iteration: holds the PSF image (if requested) and then the field image.
    image_buffer = np.zeros((num_pixel, num_pixel), dtype=mp_policy.image_dtype, order='C')
    visibilities_buffer = None
    pbar = tqdm(itertools.count())

    rad2arcsec = 3600 * 180 / np.pi

    while True:
        transit_dec = np.random.uniform(-30, 90) * au.deg
        # Offset drawn uniformly in offset**2 (uniform in area on a flat-sky approximation), from the
        # half-diagonal of the imaged square (so the source lies outside the image) out to 90 deg.
        min_offset_deg = np.sqrt(2) * 0.5 * fov.to('deg').value
        angular_offset = np.sqrt(np.random.uniform(min_offset_deg ** 2, 90 ** 2)) * au.deg
        pointing_offset_stddev = np.random.uniform(0, 4) * au.arcmin
        calibration_error_stddev = np.random.uniform(0, 5) * au.deg

        (lmn0, ref_time, obstimes, obsfreqs, chan_width, phase_center, antennas, near_field_delay_engine,
         far_field_delay_engine, beam_model,
         source_model, geodesic_model) = build_mock_obs_setup(
            angular_offset=angular_offset,
            dec=transit_dec,
            array_name=array_name
        )

        dish_aperture_effects, calibration_phase_error_stddev = create_effects(
            array_name=array_name,
            calibration_error_stddev=calibration_error_stddev,
            pointing_offset_stddev=pointing_offset_stddev
        )
        gc.collect()
        key, subkey = jax.random.split(key)

        pbar.set_description(
            f"Angular offset: {angular_offset}, Pointing offset: {pointing_offset_stddev}, \n"
            f"Calibration error: {calibration_error_stddev}, Transit dec: {transit_dec}"
        )
        pbar.display()
        object_name = f"{array_name}_O{angular_offset.to('deg').value:.1f}deg_PE{pointing_offset_stddev.to('arcmin').value:.1f}arcmin_CE{calibration_error_stddev.to('deg').value:.1f}deg"

        times = time_to_jnp(obstimes, ref_time)
        freqs = quantity_to_jnp(obsfreqs, 'Hz')
        with TimerLog(f"Computing visibilty coordinates and applying dish effects"):
            visibility_coords, corrupt_beam_model = jax.block_until_ready(
                compute_visibility_coords_and_beam_model(
                    key=subkey,
                    freqs=freqs,
                    times=times,
                    far_field_delay_engine=far_field_delay_engine,
                    beam_model=beam_model,
                    geodesic_model=geodesic_model,
                    dish_aperture_effects=dish_aperture_effects
                )
            )
        with TimerLog(f"Computing corrupt visibilities"):
            T, B, _ = np.shape(visibility_coords.uvw)
            C = np.shape(visibility_coords.freqs)[0]
            if visibilities_buffer is None or visibilities_buffer.shape != (T * B, C):
                visibilities_buffer = np.zeros((T * B, C), dtype=mp_policy.vis_dtype, order='F')
            else:
                visibilities_buffer *= 0.

            def sliced_kernel(freq_idx, freq_key):
                # Predict one channel at a time and write it into its column of the shared buffer.
                _visibility_coords = VisibilityCoords(
                    uvw=visibility_coords.uvw,
                    freqs=visibility_coords.freqs[freq_idx:freq_idx + 1],
                    times=visibility_coords.times,
                    antenna1=visibility_coords.antenna1,
                    antenna2=visibility_coords.antenna2
                )
                _visibilities = jax.block_until_ready(compute_corrupt_visibilties(
                    key=freq_key,
                    source_model=source_model,
                    beam_model=corrupt_beam_model,
                    visibility_coords=_visibility_coords,
                    near_field_delay_engine=near_field_delay_engine,
                    far_field_delay_engine=far_field_delay_engine,
                    geodesic_model=geodesic_model,
                    calibration_phase_error_stddev=calibration_phase_error_stddev
                ))  # [T*B, 1]
                visibilities_buffer[:, freq_idx:freq_idx + 1] = _visibilities

            with ThreadPoolExecutor(max_workers=10) as executor:
                key, subkey = jax.random.split(key)
                keys = jax.random.split(subkey, len(obsfreqs))
                # Consuming the iterator re-raises any exception raised in a worker thread.
                list(executor.map(sliced_kernel, range(len(obsfreqs)), keys))
            gc.collect()

        print(f"Size of vis: {visibilities_buffer.nbytes / 2 ** 30} GB")
        uvw = np.array(visibility_coords.uvw.reshape((-1, 3)), order='C')  # [T*B, 3]
        freqs = np.array(visibility_coords.freqs)
        visibilities_buffer = np.asarray(visibilities_buffer, order='F')  # [T*B, C]

        if save_psf:
            with TimerLog(
                    f"Computing PSF for lmn={lmn0}"):
                _image_visibilities(uvw, freqs, visibilities_buffer, dl, dm,
                                    center_l=np.array(lmn0[0]), center_m=np.array(lmn0[1]),
                                    num_pixel=num_pixel, output_buffer=image_buffer)
                norm = np.max(image_buffer)
            with TimerLog("Fitting beam"):
                image_buffer /= norm  # Normalise
                # Block before rescaling the buffer in place: JAX dispatch is asynchronous, so the fit may
                # otherwise still be reading its input while it is modified.
                major, minor, posang = jax.block_until_ready(fit_beam_jit(
                    psf=image_buffer,
                    dl=quantity_to_jnp(dl, 'rad'),
                    dm=quantity_to_jnp(dm, 'rad')
                ))
                image_buffer *= norm

                print(
                    f"Beam major: {major * rad2arcsec:.2f}arcsec, "
                    f"minor: {minor * rad2arcsec:.2f}arcsec, "
                    f"posang: {posang * 180 / np.pi:.2f}deg"
                )

            with TimerLog("Saving PSF"):
                _save_image_fits(os.path.join(plot_folder, f"{object_name}_psf.fits"), image_buffer,
                                 phase_center, ref_time, dl, dm, obsfreqs, chan_width,
                                 major, minor, posang, object_name)
        else:
            print("Skipping PSF computation, so beam parameters are not computed.")
            # Nominal beam of ~oversample_factor pixels, in radians like fit_beam's output.
            beam_size_rad = quantity_to_np(dl, 'rad') * oversample_factor
            major, minor, posang = beam_size_rad, beam_size_rad, 0.

        with TimerLog(f"Computing FoV sidelobes"):
            _image_visibilities(uvw, freqs, visibilities_buffer, dl, dm,
                                center_l=0., center_m=0.,
                                num_pixel=num_pixel, output_buffer=image_buffer)
        with TimerLog("Computing patch RMS's"):
            patch_rmss = _compute_patch_rmss(image_buffer, num_patches_per_side)
        results.append(
            dict(
                angular_offset_deg=angular_offset.to('deg').value,
                pointing_offset_stddev_arcmin=pointing_offset_stddev.to('arcmin').value,
                calibration_error_stddev_deg=calibration_error_stddev.to('deg').value,
                transit_dec_deg=transit_dec.to('deg').value,
                patch_rmss_Jy=patch_rmss
            )
        )
        if save_psf:
            with TimerLog("Saving FoV"):
                _save_image_fits(os.path.join(plot_folder, f"{object_name}_fov.fits"), image_buffer,
                                 phase_center, ref_time, dl, dm, obsfreqs, chan_width,
                                 major, minor, posang, object_name)
        pbar.update(1)
        with open(result_file, 'w') as f:
            json.dump(results, f, indent=2)


if __name__ == '__main__':
    # Configuration of the randomised bright-source sidelobe assessment run.
    run_config = dict(
        plot_folder='assess_bright_sources_plots_random',
        result_file='assess_bright_sources_plots_random/results.json',
        array_name='dsa2000_optimal_v1',
        oversample_factor=3.3,
        num_patches_per_side=3,
        save_psf=False,
    )
    main(**run_config)

In [None]:
import pylab as plt

# Load the Monte-Carlo results written by `main`; close the file before plotting.
with open('assess_bright_sources_plots_random/results.json', 'r') as f:
    results = json.load(f)

calibration_error = np.array([r['calibration_error_stddev_deg'] for r in results])
pointing_offset = np.array([r['pointing_offset_stddev_arcmin'] for r in results])
# One value per sample: the mean over that sample's patch RMS values.
patch_rmss = np.array([np.mean(r['patch_rmss_Jy']) for r in results])
angular_offset_deg = np.array([r['angular_offset_deg'] for r in results])

sc = plt.scatter(calibration_error, pointing_offset, c=patch_rmss, cmap='viridis')
plt.colorbar(sc)
plt.xlabel('Calibration error (deg)')
plt.ylabel('Pointing offset (arcmin)')
plt.title('Mean patch RMS (Jy)')
plt.show()

# Maximum of the mean patch RMS in each (calibration error, pointing offset) cell.
cal_error_edges = np.arange(6)  # deg: cells [0,1), [1,2), ..., [4,5)
point_offset_edges = np.arange(5)  # arcmin: cells [0,1), ..., [3,4)
num_cal_bins = len(cal_error_edges) - 1
num_point_bins = len(point_offset_edges) - 1
# np.digitize returns i with edges[i-1] <= x < edges[i]; subtract 1 for 0-based cell indices.
cal_error_idx = np.digitize(calibration_error, cal_error_edges) - 1
point_offset_idx = np.digitize(pointing_offset, point_offset_edges) - 1
# NaN marks empty cells (imshow leaves them blank); an -inf fill would corrupt the automatic colour limits.
patch_max = np.full((num_cal_bins, num_point_bins), np.nan)
for cal_idx, point_idx, patch_rms in zip(cal_error_idx, point_offset_idx, patch_rmss):
    if not (0 <= cal_idx < num_cal_bins and 0 <= point_idx < num_point_bins):
        continue
    # fmax ignores the NaN placeholder of an empty cell.
    patch_max[cal_idx, point_idx] = np.fmax(patch_max[cal_idx, point_idx], patch_rms)

plt.imshow(
    patch_max.T,
    extent=(cal_error_edges[0], cal_error_edges[-1], point_offset_edges[0], point_offset_edges[-1]),
    origin='lower',
    aspect='auto'
)
plt.colorbar()
plt.xlabel('Calibration error (deg)')
plt.ylabel('Pointing offset (arcmin)')
plt.title('Max patch RMS (Jy)')
plt.show()

# Mean patch RMS vs source offset for a few (calibration error, pointing offset) ranges.
bins = [
    ((0., 3), (0, 2)),
    ((0., 3), (2, 4)),
]
for (cal_lower, cal_upper), (point_lower, point_upper) in bins:
    mask = ((calibration_error >= cal_lower) & (calibration_error < cal_upper)
            & (pointing_offset >= point_lower) & (pointing_offset < point_upper))
    plt.scatter(angular_offset_deg[mask], patch_rmss[mask],
                label=f'cal. err. [{cal_lower}, {cal_upper}) deg, pointing [{point_lower}, {point_upper}) arcmin')
plt.xlabel('Angular offset (deg)')
plt.ylabel('Mean patch RMS (Jy)')
plt.legend()
plt.show()

thermal_limit = 2.6e-6  # presumably Jy, matching patch_rmss
rms_over_thermal = patch_rmss / thermal_limit
# NOTE(review): these bins span [0, 1e-3] of the *ratio* to the thermal limit (patch RMS below ~2.6e-9 Jy);
# they look like they were chosen for RMS in Jy before the normalisation was added — confirm the intended range.
plt.hist(rms_over_thermal, bins=np.linspace(0., 0.001, 10), density=True)
plt.xlabel('Patch RMS / thermal limit')
plt.ylabel('Density')

plt.show()
