In [None]:
import os

# Configure JAX through the environment before `jax` is imported further down:
# request the CUDA platform and set the XLA client GPU memory fraction to 1.0.
os.environ.update({
    'JAX_PLATFORMS': 'cuda',
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '1.0',
})

import astropy.coordinates as ac
import astropy.time as at
import astropy.units as au
import jax
import jax.numpy as jnp
import numpy as np

from dsa2000_assets.content_registry import fill_registries
from dsa2000_assets.registries import array_registry
from dsa2000_common.common.logging import dsa_logger
from dsa2000_common.common.enu_frame import ENU
from dsa2000_common.common.fits_utils import ImageModel, save_image_to_fits
from dsa2000_common.common.quantity_utils import quantity_to_jnp
from dsa2000_common.common.ray_utils import TimerLog
from dsa2000_fm.array_layout.fast_psf_evaluation_evolving import compute_psf_from_gcrs
from dsa2000_fm.imaging.base_imagor import fit_beam

# JIT-compiled PSF evaluator. `with_autocorr` and `accumulate_dtype` are marked
# static, so they are treated as compile-time constants rather than traced.
compute_psf_from_gcrs_jit = jax.jit(
    compute_psf_from_gcrs,
    static_argnames=('with_autocorr', 'accumulate_dtype'),
)


def main(save_folder: str, save_name: str, array_name: str, fov: au.Quantity, pixel_size: au.Quantity,
         transit_dec: au.Quantity, duration: au.Quantity,
         dt: au.Quantity,
         num_freqs: int):
    """Evaluate a PSF for an array, fit a beam to it and save it as a FITS image.

    The PSF is evaluated directly (no gridding) on a square l/m grid, over
    ``num_freqs`` frequencies and ``int(duration / dt)`` time samples. The
    pointing uses the right ascension of the array zenith at a fixed
    observation time and the requested declination.

    Args:
        save_folder: output directory; created if it does not exist.
        save_name: stem used for the FITS object name (``{save_name}_PSF``)
            and the output file (``{save_name}_psf.fits``).
        array_name: name matched against ``array_registry`` to select the array.
        fov: angular width of the image.
        pixel_size: angular pixel size, used for both l and m.
        transit_dec: declination of the pointing.
        duration: total time span sampled.
        dt: spacing between time samples.
        num_freqs: number of frequencies, linearly spaced from 700 MHz to 2000 MHz.
    """
    os.makedirs(save_folder, exist_ok=True)
    fill_registries()
    array = array_registry.get_instance(array_registry.get_match(array_name))
    antennas = array.get_antennas()
    array_location = array.get_array_location()

    # Fixed reference epoch used for the antenna frame, the zenith and the FITS header.
    obstime = at.Time('2025-06-10T16:00:00', scale='utc')

    antennas_gcrs = quantity_to_jnp(antennas.get_gcrs(obstime=obstime).cartesian.xyz.T)

    # Pixels per side: int() truncates the ratio, then odd counts are bumped to even.
    n = int(fov / pixel_size)
    if n % 2 == 1:
        n += 1

    if fov > 4 * au.arcmin:
        dsa_logger.warning(
            f"FOV {fov} is larger than 4 arcmin, and this will be slow. Better use gridding."
        )

    # Direction cosines in radians; the same 1-D axis serves both l and m,
    # with index n / 2 sitting at zero offset.
    lvec = (-n / 2 + np.arange(n)) * pixel_size.to(au.rad).value
    L, M = np.meshgrid(lvec, lvec, indexing='ij')
    N = np.sqrt(1 - L ** 2 - M ** 2)
    lmn = jnp.stack([L.flatten(), M.flatten(), N.flatten()], axis=-1)  # [n * n, 3]

    freqs = np.linspace(700e6, 2000e6, num_freqs) * au.Hz
    freqs_jax = quantity_to_jnp(freqs, 'Hz')

    # Only the right ascension of the zenith is used below; the declination
    # comes from `transit_dec`.
    zenith = ENU(0, 0, 1, obstime=obstime, location=array_location).transform_to(ac.ICRS())

    # Time offsets in seconds, starting at 0 and spaced by dt (int() truncates).
    num_times = int(duration / dt)
    times = jnp.arange(num_times) * quantity_to_jnp(dt, 's')

    with TimerLog("PSF evaluation"):
        # Block until the device result is ready so the timer covers the
        # computation, then bring it to host as an [n, n] array.
        psf = np.asarray(
            jax.block_until_ready(
                compute_psf_from_gcrs_jit(
                    antennas_gcrs=antennas_gcrs,
                    ra=quantity_to_jnp(zenith.ra, 'rad'),
                    dec=quantity_to_jnp(transit_dec, 'rad'),
                    lmn=lmn,
                    times=times,
                    freqs=freqs_jax,
                    with_autocorr=False,
                    accumulate_dtype=jnp.float32
                ).reshape(L.shape)
            )
        )

    with TimerLog("Fitting beam and saving"):
        major, minor, posang = fit_beam(
            psf=psf,
            dl=quantity_to_jnp(pixel_size, 'rad'),
            dm=quantity_to_jnp(pixel_size, 'rad')
        )
        # The fit results are treated as radians here and below; the axes are
        # logged in arcsec and the position angle in degrees.
        dsa_logger.info(
            f"Beam fit: {major * 3600 * 180 / np.pi:.2f}arcsec, "
            f"{minor * 3600 * 180 / np.pi:.2f}arcsec, "
            f"{posang * 180 / np.pi:.2f}deg"
        )

        image_model = ImageModel(
            phase_center=ac.ICRS(zenith.ra, transit_dec),
            obs_time=obstime,
            dl=pixel_size,
            dm=pixel_size,
            freqs=np.mean(freqs)[None],  # single channel at the mean frequency
            bandwidth=(freqs[-1] - freqs[0]),
            coherencies=('I',),
            beam_major=np.asarray(major) * au.rad,
            beam_minor=np.asarray(minor) * au.rad,
            beam_pa=np.asarray(posang) * au.rad,
            unit='JY/BEAM',
            object_name=f'{save_name}_PSF',
            image=psf[:, :, None, None] * au.Jy  # [num_l, num_m, 1, 1]
        )
        save_image_to_fits(
            file_path=os.path.join(save_folder, f'{save_name}_psf.fits'),
            image_model=image_model,
            overwrite=True
        )


if __name__ == '__main__':
    # Sweep over pointing declinations and integration lengths for one array
    # layout, writing one PSF FITS file per combination into 'cadenced_psfs'.
    array_name = "dsa1650_P305_v2.4.6"
    for transit_dec in [0, -30, 30, 60, 90] * au.deg:
        for duration in [7 * au.min, 28 * au.min]:
            # Both values are formatted as floats, e.g. '..._7.0min_dec0.0'.
            save_name = f"{array_name}_{duration.to('min').value}min_dec{transit_dec.to('deg').value}"
            with TimerLog(f"Working on {save_name}"):
                main(
                    save_folder='cadenced_psfs',
                    save_name=save_name,
                    array_name=array_name,
                    pixel_size=0.8 * au.arcsec,
                    fov=3 * au.arcmin,
                    transit_dec=transit_dec,
                    duration=duration,
                    dt=1.5 * au.s,
                    num_freqs=10000
                )
