# In [None]:
import dataclasses
import gc
import itertools
import json
import os
import time

import astropy.time as at
import jax
import numpy as np
from astropy import units as au, coordinates as ac
from astropy.coordinates import offset_by
from tomographic_kernel.frames import ENU
from tqdm import tqdm

from dsa2000_cal.assets.content_registry import fill_registries
from dsa2000_cal.assets.registries import array_registry
from dsa2000_cal.common.fits_utils import ImageModel, save_image_to_fits
from dsa2000_cal.common.mixed_precision_utils import mp_policy
from dsa2000_cal.common.quantity_utils import time_to_jnp, quantity_to_jnp, quantity_to_np
from dsa2000_cal.common.types import VisibilityCoords
from dsa2000_cal.common.wgridder import vis_to_image_np
from dsa2000_cal.delay_models.base_far_field_delay_engine import build_far_field_delay_engine, BaseFarFieldDelayEngine
from dsa2000_cal.delay_models.base_near_field_delay_engine import BaseNearFieldDelayEngine, \
    build_near_field_delay_engine
from dsa2000_cal.gain_models.beam_gain_model import build_beam_gain_model
from dsa2000_cal.gain_models.gain_model import GainModel
from dsa2000_cal.geodesics.base_geodesic_model import BaseGeodesicModel, build_geodesic_model
from dsa2000_cal.imaging.base_imagor import fit_beam
from dsa2000_cal.imaging.utils import get_array_image_parameters
from dsa2000_cal.systematics.dish_aperture_effects import build_dish_aperture_effects
from dsa2000_cal.visibility_model.source_models.celestial.base_point_source_model import BasePointSourceModel, \
    build_point_source_model


@dataclasses.dataclass
class TimerLog:
    """
    Context manager that prints ``msg`` on entry and the elapsed wall-clock time on exit.

    Example:
        with TimerLog("Doing work"):
            ...
    """
    msg: str

    def __post_init__(self):
        # Initialised here so ``t0`` exists right after construction; it is reset in ``__enter__``.
        self.t0 = time.time()

    def __enter__(self):
        print(f"{self.msg}")
        # Start the clock on entry rather than at construction, so an instance created ahead of
        # the ``with`` statement (or reused) measures only the body of the block.
        self.t0 = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(f"... took {time.time() - self.t0:.3f} seconds")
        # Never suppress exceptions raised inside the block.
        return False


def build_mock_obs_setup(angular_offset: au.Quantity, dec: au.Quantity, calibration_error_stddev: au.Quantity,
                         pointing_offset_stddev: au.Quantity, array_name: str):
    """
    Assemble every model needed to simulate a single-point-source observation with one time step.

    The phase centre takes the RA of the ENU "up" direction at the reference time and the given
    declination. A unit-flux (1 Jy at the first and last channel) point source is placed
    ``angular_offset`` away from the phase centre at a random position angle.

    Args:
        angular_offset: separation between the phase centre and the point source.
        dec: declination of the phase centre.
        calibration_error_stddev: angle converted to radians and returned (not applied here).
        pointing_offset_stddev: used for both the elevation and cross-elevation pointing error
            stddev of the dish aperture effects.
        array_name: name looked up in the array registry.

    Returns:
        Tuple ``(lmn0, ref_time, obstimes, freqs, chan_width, phase_center, antennas,
        near_field_delay_engine, far_field_delay_engine, calibration_phase_error_stddev,
        dish_aperture_effects, beam_model, source_model, geodesic_model)``.
    """
    fill_registries()
    telescope = array_registry.get_instance(array_registry.get_match(array_name))
    site = telescope.get_array_location()

    # Single-integration observation starting at a fixed epoch.
    reference_epoch = at.Time('2021-01-01T00:00:00', scale='utc')
    n_steps = 1
    observation_times = reference_epoch + np.arange(n_steps) * telescope.get_integration_time()

    # Take RA from the local "up" direction at the reference epoch, then impose the requested dec.
    zenith_icrs = ENU(0, 0, 1, location=site, obstime=reference_epoch).transform_to(ac.ICRS())
    pointing_center = ac.ICRS(zenith_icrs.ra, dec)

    channels = telescope.get_channels()
    dish_positions = telescope.get_antennas()

    first_time = observation_times.min()
    last_time = observation_times.max()

    far_engine = build_far_field_delay_engine(
        antennas=dish_positions,
        phase_center=pointing_center,
        start_time=first_time,
        end_time=last_time,
        ref_time=reference_epoch
    )
    near_engine = build_near_field_delay_engine(
        antennas=dish_positions,
        start_time=first_time,
        end_time=last_time,
        ref_time=reference_epoch
    )
    geodesics = build_geodesic_model(
        antennas=dish_positions,
        array_location=site,
        phase_center=pointing_center,
        obstimes=observation_times,
        ref_time=reference_epoch,
        pointings=None
    )

    channel_width = telescope.get_channel_width()

    # Place the point source at a random position angle around the phase centre.
    position_angle = np.random.uniform(0, 2 * np.pi) * au.rad
    src_ra, src_dec = offset_by(pointing_center.ra, pointing_center.dec, posang=position_angle,
                                distance=angular_offset)
    source_icrs = ac.ICRS(src_ra, src_dec)

    source_lmn = geodesics.compute_far_field_lmn(src_ra.to('rad').value, src_dec.to('rad').value, 0.)

    beam = build_beam_gain_model(array_name=array_name, full_stokes=False, times=observation_times,
                                 ref_time=reference_epoch)

    point_source = build_point_source_model(
        model_freqs=channels[[0, -1]],
        ra=source_icrs.ra[None],
        dec=source_icrs.dec[None],
        A=np.ones((1, 2)) * au.Jy
    )

    aperture_effects = build_dish_aperture_effects(
        dish_diameter=telescope.get_antenna_diameter(),
        focal_length=telescope.get_focal_length(),
        elevation_pointing_error_stddev=pointing_offset_stddev,
        cross_elevation_pointing_error_stddev=pointing_offset_stddev,
    )

    phase_error_stddev_rad = calibration_error_stddev.to('rad').value

    return (
        source_lmn,
        reference_epoch,
        observation_times,
        channels,
        channel_width,
        pointing_center,
        dish_positions,
        near_engine,
        far_engine,
        phase_error_stddev_rad,
        aperture_effects,
        beam,
        point_source,
        geodesics,
    )


@jax.jit
def compute_corrupt_visibilties(
        source_model: BasePointSourceModel,
        beam_model: GainModel,
        visibility_coords: VisibilityCoords,
        near_field_delay_engine: BaseNearFieldDelayEngine,
        far_field_delay_engine: BaseFarFieldDelayEngine,
        geodesic_model: BaseGeodesicModel
):
    """
    Predict the (noise-free) visibilities of ``source_model`` and flatten the leading two axes.

    Nothing is written to an external buffer; the flattened array is returned.

    Args:
        source_model: point-source model whose ``predict`` produces the visibilities.
        beam_model: gain model passed to ``predict`` as ``gain_model``.
        visibility_coords: visibility coordinates passed through to ``predict``.
        near_field_delay_engine: passed through to ``predict``.
        far_field_delay_engine: passed through to ``predict``.
        geodesic_model: passed through to ``predict``.

    Returns:
        [T*B, C] array of visibilities (no noise), where ``predict`` is expected to return a
        rank-3 [T, B, C] array (presumably time, baseline, channel — confirm against predict).
    """
    visibilities = source_model.predict(
        visibility_coords=visibility_coords,
        gain_model=beam_model,
        near_field_delay_engine=near_field_delay_engine,
        far_field_delay_engine=far_field_delay_engine,
        geodesic_model=geodesic_model
    )
    # Shapes are static under jit, so this unpack also enforces rank 3 at trace time.
    T, B, C = np.shape(visibilities)
    return jax.lax.reshape(visibilities, (T * B, C))


@jax.jit
def compute_visibility_coords(far_field_delay_engine, times, freqs) -> VisibilityCoords:
    """
    Jitted wrapper around ``far_field_delay_engine.compute_visibility_coords``.

    Auto-correlations are always excluded (``with_autocorr=False``).

    Args:
        far_field_delay_engine: engine whose ``compute_visibility_coords`` method is called.
        times: observation times, forwarded unchanged.
        freqs: channel frequencies, forwarded unchanged.

    Returns:
        The engine's result (annotated as ``VisibilityCoords``).
    """
    coords = far_field_delay_engine.compute_visibility_coords(
        times=times,
        freqs=freqs,
        with_autocorr=False,
    )
    return coords


# Compiled beam fitter; ``max_central_size`` is marked static so it may be a plain Python value.
fit_beam_jit = jax.jit(fit_beam, static_argnames=('max_central_size',))


def _image_visibilities(uvw, freqs, visibilities, pixsize_l, pixsize_m, center_l, center_m, num_pixel,
                        output_buffer):
    """Clear ``output_buffer`` and grid ``visibilities`` into it in place with ``vis_to_image_np``."""
    # `fill` rather than `*= 0.`: multiplying by zero leaves NaN/inf from a previous image in place.
    output_buffer.fill(0.)
    vis_to_image_np(
        uvw=uvw,
        freqs=freqs,
        vis=visibilities,
        pixsize_m=pixsize_m,
        pixsize_l=pixsize_l,
        center_l=center_l,
        center_m=center_m,
        npix_l=num_pixel,
        npix_m=num_pixel,
        wgt=None,
        mask=None,
        epsilon=1e-6,
        double_precision_accumulation=False,
        scale_by_n=True,
        normalise=True,
        output_buffer=output_buffer,
        num_threads=72
    )


def _save_image(path, image, object_name, phase_center, ref_time, dl, dm, obsfreqs, chan_width,
                major, minor, posang):
    """Wrap a 2D Stokes-I ``image`` (Jy/pixel) with its fitted beam in an ``ImageModel`` and write FITS."""
    image_model = ImageModel(
        phase_center=phase_center,
        obs_time=ref_time,
        dl=dl,
        dm=dm,
        freqs=np.mean(obsfreqs)[None],
        bandwidth=len(obsfreqs) * chan_width,
        coherencies=('I',),
        beam_major=np.asarray(major) * au.rad,
        beam_minor=np.asarray(minor) * au.rad,
        beam_pa=np.asarray(posang) * au.rad,
        unit='JY/PIXEL',
        object_name=object_name.upper(),
        image=image[:, :, None, None] * au.Jy  # [num_l, num_m, 1, 1]
    )
    save_image_to_fits(path, image_model=image_model, overwrite=True)


def _patch_rms(image, num_patches_per_side):
    """
    Return the RMS of each square patch on a ``num_patches_per_side`` x ``num_patches_per_side`` grid.

    Patches are visited row-major. Pixels past the last whole patch (when the image size is not
    divisible by ``num_patches_per_side``) are ignored.
    """
    if num_patches_per_side <= 0:
        return []
    patch_size = image.shape[0] // num_patches_per_side
    rms_values = []
    for patch_row, patch_col in itertools.product(range(num_patches_per_side), repeat=2):
        patch = image[
                patch_row * patch_size:(patch_row + 1) * patch_size,
                patch_col * patch_size:(patch_col + 1) * patch_size
                ]
        rms_values.append(float(np.sqrt(np.mean(patch ** 2))))
    return rms_values


def main(plot_folder: str, result_file: str, array_name: str, oversample_factor: float,
         num_patches_per_side: int):
    """
    Sweep angular offset, pointing error stddev and calibration error stddev for one point source.

    For each combination (with a transit declination drawn uniformly from [-31, 90) deg):
      * simulate noise-free visibilities of a single point source;
      * image them centred on the source ("PSF"), fit a beam, and save ``<name>_psf.fits``;
      * image them centred on the phase centre ("FoV"), record per-patch RMS, and save
        ``<name>_fov.fits``.
    The accumulated results are rewritten to ``result_file`` (JSON) after every run.

    Args:
        plot_folder: directory for the FITS outputs (created if missing).
        result_file: path of the JSON summary.
        array_name: array registry name.
        oversample_factor: forwarded to ``get_array_image_parameters``.
        num_patches_per_side: the FoV image is split into this many patches per side for RMS.
    """
    os.makedirs(plot_folder, exist_ok=True)

    results = []
    num_pixel, dl, dm, _, _ = get_array_image_parameters(array_name, oversample_factor, threshold=0.5)

    # Loop-invariant unit conversions.
    pixsize_l = quantity_to_np(dl, 'rad')
    pixsize_m = quantity_to_np(dm, 'rad')
    dl_rad = quantity_to_jnp(dl, 'rad')
    dm_rad = quantity_to_jnp(dm, 'rad')
    rad2arcsec = 3600 * 180 / np.pi
    rad2deg = 180 / np.pi

    psf_buffer = np.zeros((num_pixel, num_pixel), dtype=mp_policy.image_dtype, order='F')

    angular_offsets = [1.5, 3, 4.5, 6, 7.5, 9, 10.5, 12, 13.5, 15] * au.deg
    pointing_offset_stddevs = [0, 0.5, 1, 1.5, 2, 2.5, 3] * au.arcmin
    calibration_error_stddevs = [1, 2] * au.deg
    num_runs = len(angular_offsets) * len(pointing_offset_stddevs) * len(calibration_error_stddevs)

    with tqdm(total=num_runs) as pbar:
        for angular_offset, pointing_offset_stddev, calibration_error_stddev in itertools.product(
                angular_offsets, pointing_offset_stddevs, calibration_error_stddevs):
            gc.collect()
            transit_dec = np.random.uniform(-31, 90) * au.deg
            pbar.set_description(
                f"Angular offset: {angular_offset}, Pointing offset: {pointing_offset_stddev}, Calibration error: {calibration_error_stddev}, Transit dec: {transit_dec}"
            )
            object_name = f"{array_name}_O{angular_offset.to('deg').value:.1f}deg_PE{pointing_offset_stddev.to('arcmin').value:.1f}arcmin_CE{calibration_error_stddev.to('deg').value:.1f}deg"

            (lmn0, ref_time, obstimes, obsfreqs, chan_width, phase_center, antennas, near_field_delay_engine,
             far_field_delay_engine, calibration_phase_error_stddev, dish_aperture_effects, beam_model,
             source_model, geodesic_model) = build_mock_obs_setup(
                angular_offset=angular_offset,
                dec=transit_dec,
                calibration_error_stddev=calibration_error_stddev,
                pointing_offset_stddev=pointing_offset_stddev,
                array_name=array_name)
            # NOTE(review): dish_aperture_effects and calibration_phase_error_stddev are never applied
            # to the visibilities below, so the pointing/calibration sweep currently only changes the
            # output names. Confirm whether corruptions should be applied here.

            times = time_to_jnp(obstimes, ref_time)
            jnp_freqs = quantity_to_jnp(obsfreqs, 'Hz')
            with TimerLog("Computing visibilty coordinates"):
                visibility_coords = jax.block_until_ready(
                    compute_visibility_coords(
                        freqs=jnp_freqs,
                        times=times,
                        far_field_delay_engine=far_field_delay_engine
                    )
                )
            with TimerLog("Computing corrupt visibilities"):
                visibilities = jax.block_until_ready(compute_corrupt_visibilties(
                    source_model=source_model,
                    beam_model=beam_model,
                    visibility_coords=visibility_coords,
                    near_field_delay_engine=near_field_delay_engine,
                    far_field_delay_engine=far_field_delay_engine,
                    geodesic_model=geodesic_model
                ))  # [T*B, C]

            print(f"Size of vis: {visibilities.nbytes / 2 ** 30} GB")
            uvw = np.array(visibility_coords.uvw.reshape((-1, 3)), order='C')  # [T*B, 3]
            # A plain array: the previous trailing comma turned this into a 1-tuple.
            image_freqs = np.array(visibility_coords.freqs)
            visibilities = np.asarray(visibilities, order='F')  # [T*B, C]

            with TimerLog(
                    f"Computing PSF for {angular_offset}, {pointing_offset_stddev}, {calibration_error_stddev}, {transit_dec}"):
                _image_visibilities(
                    uvw=uvw,
                    freqs=image_freqs,
                    visibilities=visibilities,
                    pixsize_l=pixsize_l,
                    pixsize_m=pixsize_m,
                    center_l=quantity_to_np(lmn0[0], 'rad'),
                    center_m=quantity_to_np(lmn0[1], 'rad'),
                    num_pixel=num_pixel,
                    output_buffer=psf_buffer
                )
            with TimerLog("Fitting beam"):
                major, minor, posang = fit_beam_jit(
                    psf=psf_buffer,
                    dl=dl_rad,
                    dm=dm_rad
                )
                print(
                    f"Beam major: {major * rad2arcsec:.2f}arcsec, "
                    f"minor: {minor * rad2arcsec:.2f}arcsec, "
                    f"posang: {posang * rad2deg:.2f}deg"
                )

            image_kwargs = dict(
                object_name=object_name,
                phase_center=phase_center,
                ref_time=ref_time,
                dl=dl,
                dm=dm,
                obsfreqs=obsfreqs,
                chan_width=chan_width,
                major=major,
                minor=minor,
                posang=posang,
            )

            with TimerLog("Saving PSF"):
                _save_image(os.path.join(plot_folder, f"{object_name}_psf.fits"), psf_buffer, **image_kwargs)

            # The FoV image reuses (and overwrites) the PSF buffer, which has already been saved.
            with TimerLog(
                    f"Computing FoV for {angular_offset}, {pointing_offset_stddev}, {calibration_error_stddev}, {transit_dec}"):
                _image_visibilities(
                    uvw=uvw,
                    freqs=image_freqs,
                    visibilities=visibilities,
                    pixsize_l=pixsize_l,
                    pixsize_m=pixsize_m,
                    center_l=0.,
                    center_m=0.,
                    num_pixel=num_pixel,
                    output_buffer=psf_buffer
                )
            with TimerLog("Computing patch RMS's"):
                patch_rmss = _patch_rms(psf_buffer, num_patches_per_side)
            results.append(
                dict(
                    angular_offset_deg=angular_offset.to('deg').value,
                    pointing_offset_stddev_arcmin=pointing_offset_stddev.to('arcmin').value,
                    calibration_error_stddev_deg=calibration_error_stddev.to('deg').value,
                    transit_dec_deg=transit_dec.to('deg').value,
                    patch_rmss_Jy=patch_rmss
                )
            )
            with TimerLog("Saving FoV"):
                _save_image(os.path.join(plot_folder, f"{object_name}_fov.fits"), psf_buffer, **image_kwargs)
            pbar.update(1)
            # Rewrite the summary after every run so partial sweeps are not lost.
            with open(result_file, 'w') as f:
                json.dump(results, f, indent=2)


if __name__ == '__main__':
    # Default sweep configuration for running this file as a script.
    run_config = dict(
        plot_folder='plots',
        result_file='results.json',
        array_name='dsa2000_optimal_v1',
        oversample_factor=3.3,
        num_patches_per_side=3,
    )
    main(**run_config)