In [None]:
import os

os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"

import itertools
import astropy.coordinates as ac
import astropy.time as at
import astropy.units as au
from astropy.coordinates import offset_by

import jax
import jax.numpy as jnp

from dsa2000_assets.content_registry import fill_registries
from dsa2000_assets.registries import array_registry
from dsa2000_common.common.astropy_utils import get_time_of_local_meridean
from dsa2000_common.common.fits_utils import ImageModel, save_image_to_fits
from dsa2000_common.common.quantity_utils import quantity_to_np, time_to_jnp
from dsa2000_common.common.wgridder import vis_to_image_np
from dsa2000_common.delay_models.uvw_utils import geometric_uvw_from_gcrs, perley_lmn_from_icrs
from dsa2000_fm.imaging.base_imagor import fit_beam
from dsa2000_fm.systematics.ionosphere import evolve_gcrs

import os

import numba
import numpy as np

from dsa2000_common.common.logging import dsa_logger


@jax.jit
def compute_uvw(antennas_gcrs, time, ra0, dec0):
    """Compute the UVW coordinates of every antenna pair (baseline) at one time.

    The antenna positions are evolved to ``time`` with ``evolve_gcrs`` and projected onto the
    (ra0, dec0) phase centre with ``geometric_uvw_from_gcrs``. Each baseline is then formed as
    ``uvw[i] - uvw[j]`` for ``i < j``.

    Args:
        antennas_gcrs: antenna positions; presumably shape (num_antennas, 3) — TODO confirm.
        time: the time to evolve the positions to (a single element of the relative time array
            used by the caller; units are whatever ``evolve_gcrs`` expects).
        ra0: phase-centre right ascension in radians.
        dec0: phase-centre declination in radians.

    Returns:
        Baseline UVWs, one row per antenna pair ``(i, j)`` with ``i < j``. Rows are ordered
        row-major: (0,1), (0,2), ..., (1,2), ..., i.e. the order of
        ``itertools.combinations(range(num_antennas), 2)``. The result is empty when there are
        fewer than two antennas.
    """
    antennas_uvw = geometric_uvw_from_gcrs(evolve_gcrs(antennas_gcrs, time), ra0, dec0)
    # The number of antennas is a static shape under jit, so the pair indices become trace-time constants.
    # np.triu_indices(k=1) yields the same ordering as itertools.combinations, but computes it
    # in C instead of materialising N*(N-1)/2 Python tuples. It also returns empty index arrays
    # (rather than failing to unpack) when N < 2.
    i_idxs, j_idxs = np.triu_indices(antennas_uvw.shape[0], k=1)
    return antennas_uvw[i_idxs] - antennas_uvw[j_idxs]


def _decimate(values, num_target: int):
    """Keep every ``len(values) // num_target``-th element of *values*, with the stride clamped to >= 1.

    Clamping the stride means asking for more samples than exist keeps everything, instead of
    failing with a zero slice step. Because the stride is rounded down, slightly more than
    *num_target* elements may be returned.

    Raises:
        ValueError: if *num_target* is not positive or *values* is empty.
    """
    if num_target <= 0:
        raise ValueError(f"num_target must be a positive integer, got {num_target}.")
    if len(values) == 0:
        raise ValueError("Cannot decimate an empty sequence.")
    step = max(1, len(values) // num_target)
    return values[::step]


def main(array_name, plot_folder, dec: au.Quantity, pixel_size: au.Quantity, fov: au.Quantity,
         center_offset: au.Quantity, num_threads: int,
         duration: au.Quantity,
         freq_block_size: int,
         with_earth_rotation: bool,
         with_freq_synthesis: bool, num_reduced_obsfreqs: int | None,
         num_reduced_obstimes: int | None):
    """Grid unit visibilities for an array's baselines into a PSF image and save it as FITS.

    The phase centre is (RA=0, Dec=``dec``), observed starting at the local-meridian transit time
    found by ``get_time_of_local_meridean`` relative to 2025-06-10T00:00:00 UTC. A fitted beam
    (major, minor, position angle) is logged and stored in the FITS header via ``ImageModel``.

    Args:
        array_name: array name, resolved through ``array_registry.get_match``. It is also used as
            the output sub-folder and file-name base.
        plot_folder: output root. The PSF is written to
            ``plot_folder/<array_name>/<array_name>_psf.fits``.
        dec: declination of the phase centre.
        pixel_size: image pixel size (same for l and m).
        fov: image field of view. The pixel count per side is ``fov / pixel_size`` rounded down
            to an even number.
        center_offset: offset of the image centre from the phase centre, applied along a
            position angle of 90 deg.
        num_threads: thread count passed to the gridder.
        duration: observation length. Times are spaced by the array's integration time.
        freq_block_size: number of channels gridded per ``vis_to_image_np`` call.
        with_earth_rotation: if False, only a single time (the transit time) is used.
        with_freq_synthesis: if False, all channels are collapsed to their mean frequency.
        num_reduced_obsfreqs: if given, decimate the channels to roughly this many, scaling the
            channel width up by the reduction factor.
        num_reduced_obstimes: if given, decimate the times to roughly this many, scaling the
            integration time up by the reduction factor.

    Raises:
        ValueError: if a reduction target is not positive, or if no times or channels remain to
            image (e.g. ``duration`` shorter than one integration with earth rotation enabled).
    """
    image_name_base = f"{array_name}"
    plot_folder = os.path.join(plot_folder, image_name_base)
    os.makedirs(plot_folder, exist_ok=True)

    numba.set_num_threads(os.cpu_count())

    fill_registries()
    array = array_registry.get_instance(array_registry.get_match(array_name))
    antennas = array.get_antennas()
    array_location = array.get_array_location()

    channel_width = array.get_channel_width()
    integration_time = array.get_integration_time()

    pointing = ac.ICRS(ra=0 * au.rad, dec=dec)
    # Image centre: the phase centre shifted by `center_offset` along position angle 90 deg.
    center_pointing = ac.ICRS(*offset_by(pointing.ra, pointing.dec, 90 * au.deg, center_offset))

    ref_time = get_time_of_local_meridean(coord=pointing, location=array_location,
                                          ref_time=at.Time("2025-06-10T00:00:00", format='isot', scale='utc'))

    obstimes = ref_time + np.arange(int(duration / integration_time)) * integration_time
    if num_reduced_obstimes is not None:
        times_before = len(obstimes)
        obstimes = _decimate(obstimes, num_reduced_obstimes)
        times_after = len(obstimes)
        dsa_logger.info(f"Reduced number of times from {times_before} to {times_after}")
        # Each kept sample stands in for the dropped ones, so stretch the integration time to match.
        integration_time *= times_before / times_after
        dsa_logger.info(f"Adjusted integration time: {integration_time}")

    obsfreqs = array.get_channels()
    if num_reduced_obsfreqs is not None:
        chans_before = len(obsfreqs)
        obsfreqs = _decimate(obsfreqs, num_reduced_obsfreqs)
        chans_after = len(obsfreqs)
        dsa_logger.info(f"Reduced number of channels from {chans_before} to {chans_after}")
        # Each kept channel stands in for the dropped ones, so widen the channel to match.
        channel_width *= chans_before / chans_after
        dsa_logger.info(f"Adjusted channel width: {channel_width}")

    if not with_earth_rotation:
        # Snapshot: a single time at transit.
        obstimes = ref_time + np.arange(1) * integration_time

    # Bandwidth is computed before any collapse to a single mean frequency below.
    bandwidth = channel_width * len(obsfreqs)

    if not with_freq_synthesis:
        obsfreqs = np.mean(obsfreqs)[None]

    dsa_logger.info(f"Observing {pointing} at {ref_time} (Transit).")

    antennas_gcrs = quantity_to_np(antennas.get_gcrs(ref_time).cartesian.xyz.T)

    # We grid(degrid(f) + noise)

    # Pixel count per side, rounded down to an even number.
    num_l = num_m = (int(fov / pixel_size) // 2) * 2
    dl = dm = np.array(pixel_size.to('rad').value)
    freqs = quantity_to_np(obsfreqs, 'Hz')
    times = time_to_jnp(obstimes, ref_time)
    ra0 = pointing.ra.rad
    dec0 = pointing.dec.rad

    # Without this check an empty time/channel set would divide the accumulated PSF by zero and
    # silently fit and save a NaN image.
    if len(times) == 0 or len(freqs) == 0:
        raise ValueError(
            f"Nothing to image: {len(times)} times and {len(freqs)} channels "
            f"(duration={duration}, integration_time={integration_time})."
        )

    l0, m0, _ = perley_lmn_from_icrs(
        center_pointing.ra.rad,
        center_pointing.dec.rad,
        ra0,
        dec0
    )

    output_psf_buffer = np.zeros((num_l, num_m), dtype=np.float32, order='F')
    output_psf_accum = np.zeros((num_l, num_m), dtype=np.float32, order='F')

    # Each gridder call normalises its own output (normalise=True). Blocks are therefore combined
    # with a weight equal to their channel count, so that a short trailing frequency block does
    # not count as much as a full one. Every time step has the same number of baselines, so the
    # channel count is proportional to the visibility count.
    # NOTE(review): assumes vis_to_image_np overwrites output_buffer on each call (the original
    # accumulation pattern relies on this too) — confirm.
    total_weight = 0
    for t in times:
        uvw = np.array(compute_uvw(antennas_gcrs, t, ra0, dec0), order='F')
        num_rows = uvw.shape[0]
        for nu_start_idx in range(0, len(freqs), freq_block_size):
            nu_end_idx = min(nu_start_idx + freq_block_size, len(freqs))
            num_chans = nu_end_idx - nu_start_idx

            vis_to_image_np(
                uvw=uvw,
                freqs=freqs[nu_start_idx:nu_end_idx],
                vis=np.ones((num_rows, num_chans), dtype=np.complex64),
                pixsize_l=dl,
                pixsize_m=dm,
                center_l=l0,
                center_m=m0,
                npix_m=num_m,
                npix_l=num_l,
                wgt=None,
                mask=None,
                epsilon=1e-5,
                double_precision_accumulation=False,
                scale_by_n=True,
                normalise=True,
                output_buffer=output_psf_buffer,
                num_threads=num_threads
            )
            # Scale in place so no extra full-size temporary image is allocated.
            output_psf_buffer *= num_chans
            output_psf_accum += output_psf_buffer
            total_weight += num_chans

    output_psf_accum /= total_weight
    psf = output_psf_accum
    rad2arcsec = 3600 * 180 / np.pi
    # Pixel sizes are given to fit_beam in arcsec. The axes it returns are divided by rad2arcsec
    # below to get radians, i.e. they are treated as arcsec.
    major_beam, minor_beam, pa_beam = fit_beam(
        psf=psf,
        dl=dl * rad2arcsec,
        dm=dm * rad2arcsec
    )

    major_beam /= rad2arcsec
    minor_beam /= rad2arcsec

    dsa_logger.info(
        f"Beam major: {major_beam * rad2arcsec:.2f}arcsec, "
        f"minor: {minor_beam * rad2arcsec:.2f}arcsec, "
        f"posang: {pa_beam * 180 / np.pi:.2f}deg"
    )

    image_model = ImageModel(
        phase_center=pointing,
        obs_time=ref_time,
        dl=dl * au.rad,
        dm=dm * au.rad,
        freqs=np.mean(obsfreqs)[None],
        bandwidth=bandwidth,
        coherencies=('I',),
        beam_major=np.asarray(major_beam) * au.rad,
        beam_minor=np.asarray(minor_beam) * au.rad,
        beam_pa=np.asarray(pa_beam) * au.rad,
        unit='JY/BEAM',
        object_name=f'{image_name_base.upper()}_PSF',
        image=psf[:, :, None, None] * au.Jy  # [num_l, num_m, 1, 1]
    )
    psf_save_file = os.path.join(plot_folder, f"{image_name_base}_psf.fits")
    save_image_to_fits(psf_save_file, image_model=image_model,
                       overwrite=True, radian_angles=True, casa_compat_center_location=True)
    dsa_logger.info(f"PSF saved to {psf_save_file}")


if __name__ == '__main__':
    # Plain string: the original f-prefix had no placeholders (ruff F541).
    array_name = "dsa1650_P305_v2.4.6"
    # `weighting` only labels the output folder; it is not passed to main().
    for weighting in ['natural']:
        for transit_dec in [0, -30, 30, 60, 90] * au.deg:
            for duration in [7 * au.min, 28 * au.min]:
                # Convert explicitly to minutes so the "min" label stays correct even if a
                # duration is ever given in other units.
                plot_folder = f"psf_dec{transit_dec.to('deg').value}_{duration.to('min').value}min_{weighting}"
                main(
                    array_name=array_name,
                    plot_folder=plot_folder,
                    dec=transit_dec,
                    pixel_size=0.6 * au.arcsec,
                    fov=5 * au.deg,
                    center_offset=0 * au.deg,
                    num_threads=os.cpu_count(),
                    freq_block_size=200,
                    duration=duration,
                    with_earth_rotation=True,
                    with_freq_synthesis=True,
                    num_reduced_obsfreqs=None,
                    num_reduced_obstimes=None
                )



2025-05-14 14:19:00,862 [INFO] (standardise_fits.py:145->standardize_fits) dsa2000_cal: Original BUNIT: JY/PIXEL
2025-05-14 14:19:00,864 [INFO] (base_fits_source_model.py:668->build_fits_source_model_from_wsclean_components) dsa2000_cal: Found 1 fits files. 1 valid.
2025-05-14 14:19:00,865 [INFO] (base_fits_source_model.py:682->build_fits_source_model_from_wsclean_components) dsa2000_cal: Selecting frequencies: [1.35e+09] Hz
2025-05-14 14:19:00,872 [INFO] (standardise_fits.py:145->standardize_fits) dsa2000_cal: Original BUNIT: JY/PIXEL
2025-05-14 14:19:00,890 [INFO] (base_fits_source_model.py:797->build_fits_source_model_from_wsclean_components) dsa2000_cal: freq=1350.0 MHz, dl=-3.878509448876249e-06 rad, dm=3.878509448876249e-06 rad
centre_ra=7.498798913309288e-33 rad, centre_dec=0.0 rad
centre_l_pix=256.0, centre_m_pix=256.0
num_l=512, num_m=512, num_stokes=1
2025-05-14 14:19:00,907 [INFO] (2196341534.py:239->main) dsa2000_cal: Reduced number of times from 280 to 2
2025-05-14 14:19:0

Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/arrays/dsa1650_9P/.sync_cert
Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/point_sources/.sync_cert


2025-05-14 14:20:59,616 [INFO] (2196341534.py:403->main) dsa2000_cal: Beam major: 3.07arcsec, minor: 3.04arcsec, posang: 90.37deg
2025-05-14 14:20:59,630 [INFO] (2196341534.py:435->main) dsa2000_cal: Image saved to plots_100freqs_2times_7.0min_natural_no_noise/point_sources_dsa1650_a_P305_v2.4.6/point_sources_dsa1650_a_P305_v2.4.6.fits
2025-05-14 14:20:59,644 [INFO] (2196341534.py:455->main) dsa2000_cal: PSF saved to plots_100freqs_2times_7.0min_natural_no_noise/point_sources_dsa1650_a_P305_v2.4.6/point_sources_dsa1650_a_P305_v2.4.6_psf.fits
2025-05-14 14:21:00,013 [INFO] (2196341534.py:110->deconvolve_image) dsa2000_cal: Cleaned 1 x 1 planes with 878 / 1000000 iterations
2025-05-14 14:21:00,325 [INFO] (standardise_fits.py:145->standardize_fits) dsa2000_cal: Original BUNIT: JY/PIXEL
2025-05-14 14:21:00,327 [INFO] (base_fits_source_model.py:668->build_fits_source_model_from_wsclean_components) dsa2000_cal: Found 1 fits files. 1 valid.
2025-05-14 14:21:00,328 [INFO] (base_fits_source_mod

Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/arrays/dsa1650_9P/.sync_cert
Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/skamid_b1_1000h/.sync_cert


2025-05-14 14:22:56,918 [INFO] (2196341534.py:403->main) dsa2000_cal: Beam major: 3.08arcsec, minor: 2.35arcsec, posang: 90.82deg
2025-05-14 14:22:57,099 [INFO] (2196341534.py:435->main) dsa2000_cal: Image saved to plots_100freqs_2times_7.0min_natural_no_noise/skamid_b1_1000h_dsa1650_a_P305_v2.4.6/skamid_b1_1000h_dsa1650_a_P305_v2.4.6.fits
2025-05-14 14:22:57,278 [INFO] (2196341534.py:455->main) dsa2000_cal: PSF saved to plots_100freqs_2times_7.0min_natural_no_noise/skamid_b1_1000h_dsa1650_a_P305_v2.4.6/skamid_b1_1000h_dsa1650_a_P305_v2.4.6_psf.fits
2025-05-14 14:42:14,221 [INFO] (2196341534.py:110->deconvolve_image) dsa2000_cal: Cleaned 1 x 1 planes with 111006 / 1000000 iterations
2025-05-14 14:42:15,153 [INFO] (base_content.py:122->sync_content) dsa2000_cal: Syncing /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/ncg_5194 assets. This happens at most once per day, and may take a few minutes.


Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/arrays/dsa1650_9P/.sync_cert
Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/ncg_5194/.sync_cert
rsync -e ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o CheckHostIP=no -a --partial mario.caltech.edu:/safepool/fmcal_data/source_models/ncg_5194/NGC_5194_RO_MOM0_THINGS.FITS /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/ncg_5194/


2025-05-14 14:42:19,182 [INFO] (standardise_fits.py:145->standardize_fits) dsa2000_cal: Original BUNIT: JY/PIXEL
2025-05-14 14:42:19,185 [INFO] (base_fits_source_model.py:668->build_fits_source_model_from_wsclean_components) dsa2000_cal: Found 1 fits files. 1 valid.
2025-05-14 14:42:19,187 [INFO] (base_fits_source_model.py:682->build_fits_source_model_from_wsclean_components) dsa2000_cal: Selecting frequencies: [1.41701197e+09] Hz
2025-05-14 14:42:19,273 [INFO] (standardise_fits.py:145->standardize_fits) dsa2000_cal: Original BUNIT: JY/PIXEL
2025-05-14 14:42:19,325 [INFO] (base_fits_source_model.py:797->build_fits_source_model_from_wsclean_components) dsa2000_cal: freq=1417.01197043 MHz, dl=-7.272205393503071e-06 rad, dm=7.272205393503071e-06 rad
centre_ra=3.5337099588770036 rad, centre_dec=0.8237954069413236 rad
centre_l_pix=511.0, centre_m_pix=512.0
num_l=1024, num_m=1024, num_stokes=1
2025-05-14 14:42:19,383 [INFO] (2196341534.py:239->main) dsa2000_cal: Reduced number of times from 

Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/arrays/dsa1650_9P/.sync_cert
Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/galactic_center/.sync_cert
rsync -e ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o CheckHostIP=no -a --partial mario.caltech.edu:/safepool/fmcal_data/source_models/galactic_center/KATGC-model.fits /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/galactic_center/


2025-05-14 14:50:49,895 [INFO] (standardise_fits.py:145->standardize_fits) dsa2000_cal: Original BUNIT: JY/PIXEL
2025-05-14 14:50:49,899 [INFO] (base_fits_source_model.py:668->build_fits_source_model_from_wsclean_components) dsa2000_cal: Found 1 fits files. 1 valid.
2025-05-14 14:50:49,902 [INFO] (base_fits_source_model.py:682->build_fits_source_model_from_wsclean_components) dsa2000_cal: Selecting frequencies: [1.28e+09] Hz
2025-05-14 14:50:49,937 [INFO] (standardise_fits.py:145->standardize_fits) dsa2000_cal: Original BUNIT: JY/PIXEL
2025-05-14 14:50:50,304 [INFO] (base_fits_source_model.py:710->build_fits_source_model_from_wsclean_components) dsa2000_cal: Image has odd number of l pixels, removing last pixel
2025-05-14 14:50:50,305 [INFO] (base_fits_source_model.py:713->build_fits_source_model_from_wsclean_components) dsa2000_cal: Image has odd number of m pixels, removing last pixel
2025-05-14 14:50:50,316 [INFO] (base_fits_source_model.py:797->build_fits_source_model_from_wsclean_

Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/arrays/dsa1650_9P/.sync_cert
Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/point_sources/.sync_cert


2025-05-14 14:58:11,672 [INFO] (2196341534.py:403->main) dsa2000_cal: Beam major: 3.07arcsec, minor: 3.04arcsec, posang: 90.37deg
2025-05-14 14:58:11,686 [INFO] (2196341534.py:435->main) dsa2000_cal: Image saved to plots_100freqs_2times_7.0min_natural_with_noise/point_sources_dsa1650_a_P305_v2.4.6/point_sources_dsa1650_a_P305_v2.4.6.fits
2025-05-14 14:58:11,699 [INFO] (2196341534.py:455->main) dsa2000_cal: PSF saved to plots_100freqs_2times_7.0min_natural_with_noise/point_sources_dsa1650_a_P305_v2.4.6/point_sources_dsa1650_a_P305_v2.4.6_psf.fits
2025-05-14 14:58:12,026 [INFO] (2196341534.py:110->deconvolve_image) dsa2000_cal: Cleaned 1 x 1 planes with 869 / 1000000 iterations
2025-05-14 14:58:12,292 [INFO] (standardise_fits.py:145->standardize_fits) dsa2000_cal: Original BUNIT: JY/PIXEL
2025-05-14 14:58:12,294 [INFO] (base_fits_source_model.py:668->build_fits_source_model_from_wsclean_components) dsa2000_cal: Found 1 fits files. 1 valid.
2025-05-14 14:58:12,295 [INFO] (base_fits_source

Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/arrays/dsa1650_9P/.sync_cert
Searching for sync certificate: /home/albert/git/DSA2000-Cal/dsa2000_cal/src/dsa2000_assets/source_models/skamid_b1_1000h/.sync_cert


2025-05-14 15:00:01,221 [INFO] (2196341534.py:403->main) dsa2000_cal: Beam major: 3.08arcsec, minor: 2.35arcsec, posang: 90.82deg
2025-05-14 15:00:01,408 [INFO] (2196341534.py:435->main) dsa2000_cal: Image saved to plots_100freqs_2times_7.0min_natural_with_noise/skamid_b1_1000h_dsa1650_a_P305_v2.4.6/skamid_b1_1000h_dsa1650_a_P305_v2.4.6.fits
2025-05-14 15:00:01,587 [INFO] (2196341534.py:455->main) dsa2000_cal: PSF saved to plots_100freqs_2times_7.0min_natural_with_noise/skamid_b1_1000h_dsa1650_a_P305_v2.4.6/skamid_b1_1000h_dsa1650_a_P305_v2.4.6_psf.fits
