# In [None]:

import dataclasses
import gc
import os
import time

import astropy.time as at
import jax
import numpy as np
from astropy import units as au, coordinates as ac
from tomographic_kernel.frames import ENU

from dsa2000_cal.assets.content_registry import fill_registries
from dsa2000_cal.assets.registries import array_registry
from dsa2000_cal.common.fits_utils import ImageModel, save_image_to_fits
from dsa2000_cal.common.mixed_precision_utils import mp_policy
from dsa2000_cal.common.quantity_utils import time_to_jnp, quantity_to_jnp, quantity_to_np
from dsa2000_cal.common.types import VisibilityCoords
from dsa2000_cal.common.wgridder import vis_to_image_np
from dsa2000_common.delay_models import build_far_field_delay_engine
from dsa2000_cal.imaging.base_imagor import fit_beam
from dsa2000_cal.imaging.utils import get_array_image_parameters


@dataclasses.dataclass
class TimerLog:
    """
    Context manager that prints ``msg`` on entry and the elapsed time on exit.

    Usage::

        with TimerLog("Doing work") as timer:
            ...

    The clock is (re)started in ``__enter__``, so only the body of the ``with`` block is timed.
    The instance can also be re-entered.

    Attributes:
        msg: message printed when the block is entered.
        t0: ``time.perf_counter()`` reading at the most recent start.
    """
    msg: str

    def __post_init__(self):
        # Keep ``t0`` available straight after construction (backward compatible);
        # it is reset on every ``__enter__``.
        self.t0 = time.perf_counter()

    def __enter__(self):
        print(f"{self.msg}")
        # Start timing after the message is printed so the print itself is not measured.
        self.t0 = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(f"... took {time.perf_counter() - self.t0:.3f} seconds")
        return False  # never suppress exceptions raised inside the block


def build_mock_obs_setup(array_name: str, num_sol_ints_time: int, num_channels: int = 40,
                         num_times_per_sol_int: int = 4):
    """
    Build a mock observation setup for the registered array ``array_name``.

    Args:
        array_name: name matched against ``array_registry``.
        num_sol_ints_time: number of solution intervals in time.
        num_channels: number of leading channels of ``array.get_channels()`` to keep
            (default 40, the previously hard-coded value).
        num_times_per_sol_int: integrations per solution interval (default 4, the previously
            hard-coded value). Total integrations = ``num_times_per_sol_int * num_sol_ints_time``.

    Returns:
        Tuple ``(ref_time, obstimes, freqs, chan_width, phase_center, antennas,
        far_field_delay_engine)``. ``obstimes`` starts at ``ref_time`` and is spaced by the
        array's integration time.

    Raises:
        ValueError: if the total number of integrations or ``num_channels`` is not positive.
    """
    num_times = num_times_per_sol_int * num_sol_ints_time
    if num_times < 1:
        raise ValueError(f"Need at least one integration, got num_times={num_times}")
    if num_channels < 1:
        raise ValueError(f"Need at least one channel, got num_channels={num_channels}")

    fill_registries()
    array = array_registry.get_instance(array_registry.get_match(array_name))
    array_location = array.get_array_location()

    ref_time = at.Time('2021-01-01T00:00:00', scale='utc')
    obstimes = ref_time + np.arange(num_times) * array.get_integration_time()

    # Take the direction of ENU (0, 0, 1) (the local "up" vector) at ref_time in ICRS,
    # then keep its RA but force the declination to 0 deg.
    phase_center = ENU(0, 0, 1, location=array_location, obstime=ref_time).transform_to(ac.ICRS())
    phase_center = ac.ICRS(phase_center.ra, 0 * au.deg)

    freqs = array.get_channels()[:num_channels]
    antennas = array.get_antennas()

    far_field_delay_engine = build_far_field_delay_engine(
        antennas=antennas,
        phase_center=phase_center,
        start_time=obstimes.min(),
        end_time=obstimes.max(),
        ref_time=ref_time
    )

    chan_width = array.get_channel_width()

    return ref_time, obstimes, freqs, chan_width, phase_center, antennas, far_field_delay_engine


def grid_psf(visibility_coords: VisibilityCoords, num_l, num_m, dl, dm, l0, m0, num_threads: int = 72,
             epsilon: float = 1e-6):
    """
    Grid unit visibilities at the sampled uvw coordinates to form a PSF image.

    Every (row, channel) sample is given visibility 1 and gridded with no weights and no mask.
    The resulting dirty image is therefore the response to the sampling pattern alone.

    Args:
        visibility_coords: provides ``uvw`` of shape [T, B, 3] and ``freqs`` of shape [C].
        num_l: number of image pixels along l.
        num_m: number of image pixels along m.
        dl: pixel size along l (angular quantity, converted to rad).
        dm: pixel size along m (angular quantity, converted to rad).
        l0: image centre in l (angular quantity, converted to rad).
        m0: image centre in m (angular quantity, converted to rad).
        num_threads: gridder thread count (default 72, the previously hard-coded value).
        epsilon: gridder accuracy parameter (default 1e-6, the previously hard-coded value).

    Returns:
        The [num_l, num_m] PSF image (Fortran-ordered, dtype ``mp_policy.image_dtype``).
        A message is printed if it is all zeros or contains non-finite values.
    """
    freqs = np.asarray(visibility_coords.freqs)  # [C]
    num_chan = np.shape(freqs)[0]
    uvw = np.array(visibility_coords.uvw)  # [T, B, 3]
    T, B, _ = np.shape(uvw)
    num_rows = T * B

    # Flatten (time, baseline) into rows; made C-contiguous before handing to the gridder.
    uvw = np.ascontiguousarray(uvw.reshape((num_rows, 3)))
    # Unit visibilities: the dirty image of all-ones data is the PSF.
    visibilities = np.ones((num_rows, num_chan), dtype=mp_policy.vis_dtype, order='F')
    psf_buffer = np.zeros((num_l, num_m), dtype=mp_policy.image_dtype, order='F')

    vis_to_image_np(
        uvw=uvw,
        freqs=freqs,
        vis=visibilities,
        pixsize_m=quantity_to_np(dm, 'rad'),
        pixsize_l=quantity_to_np(dl, 'rad'),
        center_l=quantity_to_np(l0, 'rad'),
        center_m=quantity_to_np(m0, 'rad'),
        npix_l=num_l,
        npix_m=num_m,
        wgt=None,
        mask=None,
        epsilon=epsilon,
        double_precision_accumulation=False,
        scale_by_n=True,
        normalise=True,
        output_buffer=psf_buffer,
        num_threads=num_threads
    )

    if np.all(psf_buffer == 0) or not np.all(np.isfinite(psf_buffer)):
        print("PSF buffer is all zeros or contains NaNs/Infs")
    return psf_buffer


@jax.jit
def compute_visibility_coords(far_field_delay_engine, times, freqs) -> VisibilityCoords:
    """
    JIT-compiled call to ``far_field_delay_engine.compute_visibility_coords`` with
    autocorrelations excluded.

    Args:
        far_field_delay_engine: engine built by ``build_far_field_delay_engine``. It is a traced
            argument of a jitted function, so it presumably must be a JAX pytree (TODO: confirm).
        times: times to evaluate at (as produced by ``time_to_jnp``).
        freqs: channel frequencies (as produced by ``quantity_to_jnp(..., 'Hz')``).

    Returns:
        The visibility coordinates returned by the engine.
    """
    coords = far_field_delay_engine.compute_visibility_coords(
        times=times,
        freqs=freqs,
        with_autocorr=False,
    )
    return coords


# JIT-compiled ``fit_beam``. ``max_central_size`` is a static (compile-time) argument,
# so each distinct value triggers a separate trace.
fit_beam_jit = jax.jit(
    fit_beam,
    static_argnames=('max_central_size',),
)


def main(plot_folder: str, image_name: str, array_name: str, num_sol_ints_time: int, fov: au.Quantity,
         oversample_factor: float = 3.8):
    """
    Build a PSF progressively, one integration at a time.

    After each integration the PSF accumulated so far is peak-normalised, a beam is fitted to it,
    and it is written to ``{plot_folder}/{image_name}_psf.fits``. The file is overwritten on every
    step.

    Args:
        plot_folder: output directory (created if missing).
        image_name: prefix of the output FITS file name.
        array_name: registry name of the array to simulate.
        num_sol_ints_time: number of solution intervals; forwarded to ``build_mock_obs_setup``.
        fov: field of view; forwarded to ``get_array_image_parameters``.
        oversample_factor: forwarded to ``get_array_image_parameters``.
    """
    os.makedirs(plot_folder, exist_ok=True)

    # Create array setup
    (ref_time, obstimes, obsfreqs, chan_width, phase_center, antennas, far_field_delay_engine) = build_mock_obs_setup(
        array_name, num_sol_ints_time)

    num_pixel, dl, dm, l0, m0 = get_array_image_parameters(array_name, fov, oversample_factor)

    print(f"Image size: [{num_pixel}, {num_pixel}], pixel size: {dl.to('arcsec').value:.2f}arcsec")

    # Loop-invariant conversions, done once instead of on every integration.
    freqs = quantity_to_jnp(obsfreqs, 'Hz')
    pixsize_l = quantity_to_np(dl, 'rad')
    pixsize_m = quantity_to_np(dm, 'rad')
    center_l = quantity_to_np(l0, 'rad')
    center_m = quantity_to_np(m0, 'rad')
    dl_rad = quantity_to_jnp(dl, 'rad')
    dm_rad = quantity_to_jnp(dm, 'rad')
    rad2arcsec = 3600 * 180 / np.pi
    rad2deg = 180 / np.pi
    output_path = os.path.join(plot_folder, f"{image_name}_psf.fits")

    visibilities = None  # all-ones gridder input, allocated on the first integration
    psf_buffer = None  # per-integration gridder output, reused across integrations
    psf_accumulate = None  # running sum of the per-integration PSFs

    for time_idx in range(len(obstimes)):

        times = time_to_jnp(obstimes[time_idx:(time_idx + 1)], ref_time)

        with TimerLog(f"Computing visibilty coordinates for time_idx {time_idx}"):
            visibility_coords = jax.block_until_ready(
                compute_visibility_coords(
                    freqs=freqs,
                    times=times,
                    far_field_delay_engine=far_field_delay_engine
                )
            )

        with TimerLog(f"Computing PSF for {time_idx}"):
            C = np.shape(visibility_coords.freqs)[0]
            uvw = np.array(visibility_coords.uvw)  # [T, B, 3]
            T, B, _ = np.shape(uvw)
            num_rows = T * B
            # Buffers are allocated once. This assumes num_rows and C stay the same for every
            # integration (one time step, same baselines and channels each call). TODO: confirm.
            if visibilities is None:
                visibilities = np.ones((num_rows, C), dtype=mp_policy.vis_dtype, order='F')
            if psf_buffer is None:
                psf_buffer = np.zeros((num_pixel, num_pixel), dtype=mp_policy.image_dtype, order='F')

            # Flatten (time, baseline) into rows; C-contiguous for the gridder.
            uvw = np.asarray(uvw.reshape((num_rows, 3)), order='C')

            vis_to_image_np(
                uvw=uvw,
                freqs=np.asarray(visibility_coords.freqs),
                vis=visibilities,
                pixsize_m=pixsize_m,
                pixsize_l=pixsize_l,
                center_l=center_l,
                center_m=center_m,
                npix_l=num_pixel,
                npix_m=num_pixel,
                wgt=None,
                mask=None,
                epsilon=1e-6,
                double_precision_accumulation=False,
                scale_by_n=True,
                normalise=True,
                output_buffer=psf_buffer,
                num_threads=72
            )

        with TimerLog("Plotting image"):
            if psf_accumulate is None:
                psf_accumulate = psf_buffer.copy()
            else:
                psf_accumulate += psf_buffer

            normalisation = np.max(psf_accumulate)
            # Guard the peak normalisation: a zero or non-finite peak would fill the image with
            # NaN/Inf and make the beam fit meaningless.
            if not (np.isfinite(normalisation) and normalisation > 0):
                print(f"Skipping beam fit for time_idx {time_idx}: PSF peak is {normalisation}")
                continue
            psf_normed = psf_accumulate / normalisation

            major, minor, posang = fit_beam_jit(
                psf=psf_normed,
                dl=dl_rad,
                dm=dm_rad
            )
            print(
                f"Beam major: {major * rad2arcsec:.2f}arcsec, "
                f"minor: {minor * rad2arcsec:.2f}arcsec, "
                f"posang: {posang * rad2deg:.2f}deg"  # radians -> degrees
            )

            image_model = ImageModel(
                phase_center=phase_center,
                obs_time=ref_time,
                dl=dl,
                dm=dm,
                freqs=np.mean(obsfreqs)[None],
                bandwidth=len(obsfreqs) * chan_width,
                coherencies=('I',),
                beam_major=np.asarray(major) * au.rad,
                beam_minor=np.asarray(minor) * au.rad,
                beam_pa=np.asarray(posang) * au.rad,
                unit='JY/PIXEL',
                object_name=f'DSA2000_PSF_{time_idx:03d}',
                image=psf_normed[:, :, None, None] * au.Jy  # [num_l, num_m, 1, 1]
            )
            save_image_to_fits(output_path, image_model=image_model, overwrite=True)

            # Release the normalised copy before the next (large) iteration.
            del psf_normed
            gc.collect()


if __name__ == '__main__':
    # Progressive PSF run for the 'dsa2000_optimal_v1' array over a 7 deg field of view.
    run_config = dict(
        plot_folder='plots',
        image_name='dsa2000_optimal_v1_5MHz',
        array_name='dsa2000_optimal_v1',
        num_sol_ints_time=103,
        fov=7 * au.deg,
        oversample_factor=3,
    )
    main(**run_config)
