In [None]:
import os

os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"

from dsa2000_common.common.enu_frame import ENU

import itertools
import astropy.coordinates as ac
import astropy.time as at
import astropy.units as au
import jax
import jax.numpy as jnp

from dsa2000_assets.content_registry import fill_registries
from dsa2000_assets.registries import array_registry
from dsa2000_common.common.astropy_utils import mean_itrs
from dsa2000_common.common.logging import dsa_logger
from dsa2000_common.common.quantity_utils import quantity_to_np, time_to_jnp
from dsa2000_common.delay_models.uvw_utils import geometric_uvw_from_gcrs, perley_lmn_from_icrs
from dsa2000_fm.systematics.ionosphere import evolve_gcrs

import numpy as np
from scipy.special import j0
from scipy.integrate import simpson
import pylab as plt


@jax.jit
def compute_uvw(antennas_gcrs, time, ra0, dec0):
    """Return the uvw coordinates of every unordered antenna pair at ``time``.

    Antenna positions are evolved from their reference GCRS positions to ``time``
    with ``evolve_gcrs``, projected onto uvw for the phase centre ``(ra0, dec0)``
    by ``geometric_uvw_from_gcrs``, and differenced for each pair ``i < j``
    (row-major order: (0, 1), (0, 2), ..., (1, 2), ...), i.e.
    ``uvw[p] = antennas_uvw[i] - antennas_uvw[j]``.

    Args:
        antennas_gcrs: per-antenna GCRS positions, antennas along axis 0.
        time: observation time in the representation ``evolve_gcrs`` expects
            (``main`` passes values from ``time_to_jnp(obstimes, ref_time)``).
        ra0: phase-centre right ascension (``main`` passes radians).
        dec0: phase-centre declination (``main`` passes radians).

    Returns:
        Array with one row per antenna pair, ``N * (N - 1) // 2`` rows; presumably
        ``(num_pairs, 3)`` for (u, v, w) — confirm against ``geometric_uvw_from_gcrs``.
        Empty (zero rows) when fewer than two antennas are given.
    """
    antennas_uvw = geometric_uvw_from_gcrs(evolve_gcrs(antennas_gcrs, time), ra0, dec0)
    # The antenna count is a static shape under jit, so the pair indices are
    # concrete NumPy arrays built once at trace time.
    # np.triu_indices(n, k=1) gives the same (i, j), i < j, ordering as
    # itertools.combinations(range(n), 2). It avoids materialising a Python list
    # of ~n^2/2 tuples, and it still returns two (empty) index arrays when n < 2.
    i_idxs, j_idxs = np.triu_indices(antennas_uvw.shape[0], k=1)
    return antennas_uvw[i_idxs] - antennas_uvw[j_idxs]


def main(duration, freq_block_size,
         with_earth_rotation: bool, with_freq_synthesis: bool,
         num_reduced_obsfreqs: int | None,
         num_reduced_obstimes: int | None):
    """Plot the azimuthally averaged PSF of the ``dsa1650_P305`` array at transit.

    The array points at ENU (0, 0, 1) at ``2025-06-10T00:00:00 UTC``. For every
    (time, frequency) sample, the projected baseline lengths are histogrammed into
    radial bins and turned into a surface density of uv samples. That density is
    Hankel-transformed with ``az_avg``, and the transforms are averaged. The curve
    is normalised to 1 at zero offset and plotted on log-log axes.

    Args:
        duration: total observation length (astropy time Quantity).
        freq_block_size: unused; kept for call-site compatibility.
        with_earth_rotation: if False, collapse to a single snapshot at ``ref_time``.
        with_freq_synthesis: if False, collapse to the mean channel frequency.
        num_reduced_obsfreqs: if given, keep every ``len // n``-th channel (at least
            every channel) and widen the channel width by the thinning factor.
        num_reduced_obstimes: likewise for integration times / integration time.

    Raises:
        ValueError: if a ``num_reduced_*`` value is < 1, or if no (time, frequency)
            samples remain to average.
    """
    fill_registries()
    array = array_registry.get_instance(array_registry.get_match('dsa1650_P305'))
    antennas = array.get_antennas()

    channel_width = array.get_channel_width()
    integration_time = array.get_integration_time()
    array_location = mean_itrs(antennas.get_itrs()).earth_location

    ref_time = at.Time("2025-06-10T00:00:00", format='isot', scale='utc')
    pointing = ENU(0, 0, 1, obstime=ref_time, location=array_location).transform_to(ac.ICRS())

    obstimes = ref_time + np.arange(int(duration / integration_time)) * integration_time
    if num_reduced_obstimes is not None:
        obstimes, widen = _decimate(obstimes, num_reduced_obstimes)
        # Rebind rather than `*=`. An in-place multiply would mutate the Quantity
        # returned by the array object if it hands out a shared reference.
        integration_time = integration_time * widen
        dsa_logger.info(f"Adjusted integration time: {integration_time}")

    obsfreqs = array.get_channels()
    if num_reduced_obsfreqs is not None:
        obsfreqs, widen = _decimate(obsfreqs, num_reduced_obsfreqs)
        channel_width = channel_width * widen
        dsa_logger.info(f"Adjusted channel width: {channel_width}")

    if not with_earth_rotation:
        # NOTE(review): the original logged an undefined name here (NameError).
        # Assumed meaning: one snapshot stands in for len(obstimes) samples, so the
        # averaging it forgoes is worth sqrt(len(obstimes)). Confirm.
        simulated_noise_reduction = np.sqrt(len(obstimes))
        obstimes = ref_time + np.arange(1) * integration_time
        dsa_logger.info(f"Simulated earth rotation noise reduction: {simulated_noise_reduction:.2f}")

    if not with_freq_synthesis:
        # NOTE(review): same assumed meaning as above, over channels. Confirm.
        simulated_noise_reduction = np.sqrt(len(obsfreqs))
        obsfreqs = np.mean(obsfreqs)[None]
        dsa_logger.info(f"Simulated freq synth noise reduction: {simulated_noise_reduction:.2f}")

    dsa_logger.info(f"Observing {pointing} at {ref_time} (Transit).")

    antennas_gcrs = quantity_to_np(antennas.get_gcrs(ref_time).cartesian.xyz.T)
    ra0 = pointing.ra.rad
    dec0 = pointing.dec.rad
    freqs = quantity_to_np(obsfreqs, 'Hz')
    times = time_to_jnp(obstimes, ref_time)

    Nr = 4096
    rb = np.linspace(0, 16e3, Nr)  # radial bin edges for baseline length, metres
    rad2arcsec = 3600 * 180 / np.pi
    k = np.linspace(0, 1000 / rad2arcsec, Nr)  # angular offsets, radians (0..1000 arcsec)

    radial_psf = _averaged_radial_psf(antennas_gcrs, times, freqs, ra0, dec0, rb, k)

    plt.plot(k * rad2arcsec, radial_psf)
    plt.ylim(1e-6, 1)
    plt.xscale('log')
    plt.yscale('log')
    # Reference marker at 3.05 / 2 arcsec (presumably half the expected PSF FWHM; confirm).
    plt.axvline(3.05 / 2)
    plt.xlabel('l (arcsec)')
    plt.grid()
    plt.show()


def _decimate(samples, num_reduced):
    """Keep every ``len(samples) // num_reduced``-th sample; return (kept, widening factor)."""
    if num_reduced < 1:
        raise ValueError(f"num_reduced must be >= 1, got {num_reduced}")
    before = len(samples)
    if before == 0:
        return samples, 1.0
    # max(1, ...) avoids a zero slice step when more samples are requested than exist.
    kept = samples[::max(1, before // num_reduced)]
    return kept, before / len(kept)


def _averaged_radial_psf(antennas_gcrs, times, freqs, ra0, dec0, rb, k):
    """Average ``az_avg`` PSFs over all (time, frequency) samples, normalised to 1 at ``k[0]``.

    ``rb`` holds radial bin edges in metres, ``k`` the angular offsets in radians,
    and ``freqs`` the frequencies in Hz.
    """
    psf_sum = 0.0
    num_samples = 0
    for t in times:
        uvw = np.asarray(compute_uvw(antennas_gcrs, t, ra0, dec0))
        # Projected baseline lengths in metres (w ignored). Each unordered pair appears
        # once. Its conjugate (-u, -v) has the same length, so only an overall factor
        # of 2 is missing, and it cancels in the normalisation below.
        baseline_lengths = np.hypot(uvw[:, 0], uvw[:, 1])
        for freq in freqs:
            wavelength = 299792458 / freq  # c in m/s over Hz -> metres
            rb_lambda = rb / wavelength
            counts, _ = np.histogram(baseline_lengths / wavelength, rb_lambda)
            # az_avg integrates f(r) * r dr, so it needs a surface density.
            # Dividing each bin's count by its annulus area makes the transform
            # approximate sum(counts * J0(2*pi*k*r)).
            annulus_area = np.pi * (rb_lambda[1:] ** 2 - rb_lambda[:-1] ** 2)
            psf_sum = psf_sum + az_avg(k, rb_lambda, counts / annulus_area)
            num_samples += 1
    if num_samples == 0:
        raise ValueError("No (time, frequency) samples to average; check duration and channel selection.")
    radial_psf = psf_sum / num_samples
    return radial_psf / radial_psf[0]


def az_avg(k, rb, f):
    """Hankel-transform a radially binned, azimuthally symmetric 2-D function.

    Evaluates ``F(k) = 2*pi * integral f(r) * J0(2*pi*k*r) * r dr`` with Simpson's
    rule over the bin centres. This is the 2-D Fourier transform of an azimuthally
    symmetric ``f``, evaluated at radius ``k`` in the conjugate domain.

    Args:
        k: radii at which to evaluate the transform (array-like; flattened).
        rb: radial bin edges, length ``M + 1`` (array-like).
        f: function value in each bin, length ``M``, treated as a surface density
            sampled at the bin centres.

    Returns:
        1-D array in which ``Fbar[j]`` is the transform at ``k[j]``.

    Raises:
        ValueError: if ``f`` does not have exactly one entry per bin.
    """
    # asarray so list inputs work: on lists, rb[:-1] + rb[1:] would concatenate.
    rb = np.asarray(rb)
    f = np.asarray(f)
    rc = 0.5 * (rb[:-1] + rb[1:])  # bin centres
    if f.shape != rc.shape:
        raise ValueError(f"f must have one value per bin: expected shape {rc.shape}, got {f.shape}")
    # Rows index k and columns index r. The trailing factor r is the polar-coordinate
    # Jacobian in the integrand.
    kernel = j0(2 * np.pi * np.outer(k, rc)) * rc[np.newaxis, :]
    return 2 * np.pi * simpson(f[np.newaxis, :] * kernel, x=rc, axis=1)


if __name__ == "__main__":
    # Guarded so importing this module does not trigger the full computation.
    # Running it as a script or notebook cell still does, since __name__ == "__main__" there.
    main(
        duration=7 * au.min,
        freq_block_size=10,  # not used by main
        with_earth_rotation=True,
        with_freq_synthesis=True,
        num_reduced_obsfreqs=10,
        num_reduced_obstimes=1,
    )