# In [None]:

import os

# These must be set before jax is imported/initialised, which is why they sit
# between the stdlib import and the jax-dependent imports below.
# Restrict JAX to the CUDA and CPU backends.
os.environ['JAX_PLATFORMS'] = 'cuda,cpu'
# Allow the JAX client to preallocate the full GPU memory pool.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
# Expose one XLA host device per core so that main()/run() can spread host-side
# work over jax.devices("cpu"). os.cpu_count() returns None when the count cannot
# be determined; fall back to 1 instead of emitting the invalid value "None".
# Flags already present in XLA_FLAGS are kept instead of being clobbered.
_host_device_flag = f"--xla_force_host_platform_device_count={os.cpu_count() or 1}"
os.environ["XLA_FLAGS"] = f"{os.environ.get('XLA_FLAGS', '')} {_host_device_flag}".strip()

import queue
from dsa2000_common.common.logging import dsa_logger
from dsa2000_fm.bright_souces.evaluate_rms import simulate_rms
import astropy.coordinates as ac
import astropy.units as au
import jax.random


def run(result_idx, cpu_idx, gpu_idx, pointing_offset_stddev, axial_focus_error_stddev,
        horizon_peak_astigmatism_stddev, with_smearing):
    """Run ``simulate_rms`` for one configuration of the systematics sweep.

    Only the three systematic standard deviations and the smearing toggle vary
    between calls; every other ``simulate_rms`` setting is fixed below.

    Args:
        result_idx: Configuration number, passed to ``simulate_rms`` as ``result_num``.
        cpu_idx: Index into ``jax.devices("cpu")`` selecting the host device.
        gpu_idx: Index into ``jax.devices("cuda")`` selecting the GPU.
        pointing_offset_stddev: Passed through unchanged to ``simulate_rms``.
        axial_focus_error_stddev: Passed through unchanged to ``simulate_rms``.
        horizon_peak_astigmatism_stddev: Passed through unchanged to ``simulate_rms``.
        with_smearing: Passed through unchanged to ``simulate_rms``.
    """
    # Fetch both device lists first, then index, mirroring the lookup order.
    cpu_devices, gpu_devices = jax.devices("cpu"), jax.devices("cuda")
    cpu_device, gpu_device = cpu_devices[cpu_idx], gpu_devices[gpu_idx]

    # Settings shared by every configuration in this sweep.
    sweep_constants = dict(
        seed=0,
        save_folder='sky_loss_19mar2025_varying_systematics',
        array_name='dsa2000_optimal_v1',
        pointing=ac.ICRS(0 * au.deg, 0 * au.deg),
        num_measure_points=256,
        image_batch_size=128,
        angular_radius=1.75 * au.deg,
        prior_psf_sidelobe_peak=1e-3,
        bright_source_id='nvss_calibrators',
        turbulent=True,
        dawn=True,
        high_sun_spot=True,
        with_ionosphere=True,
        with_dish_effects=True,
    )

    simulate_rms(
        cpu=cpu_device,
        gpu=gpu_device,
        result_num=result_idx,
        pointing_offset_stddev=pointing_offset_stddev,
        axial_focus_error_stddev=axial_focus_error_stddev,
        horizon_peak_astigmatism_stddev=horizon_peak_astigmatism_stddev,
        with_smearing=with_smearing,
        **sweep_constants,
    )


def main(node_idx: int, num_nodes: int):
    """Run this node's share of the systematics sweep, one worker thread per GPU.

    The full sweep is the Cartesian product of pointing offset, axial focus
    error, horizon-peak astigmatism and smearing on/off. Configuration number
    ``result_idx`` (its position in that product) belongs to node
    ``result_idx % num_nodes``. This node's jobs are dealt round-robin onto one
    queue per GPU, and each queue is drained sequentially by its own thread, so
    jobs on the same GPU never overlap.

    Args:
        node_idx: Index of this node, in ``[0, num_nodes)``.
        num_nodes: Total number of nodes the sweep is split across.

    Raises:
        ValueError: If ``num_nodes < 1`` or ``node_idx`` is out of range.
        RuntimeError: If no CUDA devices are available.
        Exception: The first job exception encountered is re-raised once all
            worker threads have stopped. A failing job ends its own GPU queue;
            the other queues still run to completion.
    """
    import concurrent.futures
    import itertools

    if num_nodes < 1:
        raise ValueError(f"num_nodes must be >= 1, got {num_nodes}")
    if not 0 <= node_idx < num_nodes:
        raise ValueError(f"node_idx must be in [0, {num_nodes}), got {node_idx}")

    cpus = jax.devices("cpu")
    gpus = jax.devices("cuda")
    if not gpus:
        raise RuntimeError("No CUDA devices available")
    dsa_logger.info(f"Launching over {len(gpus)} gpus")
    queues = [queue.Queue() for _ in gpus]

    # itertools.product varies the last factor fastest, matching nested loops in this order.
    sweep = itertools.product(
        [0, 1, 2, 4] * au.arcmin,
        [0, 3, 5] * au.mm,
        [0, 1, 2, 4] * au.mm,
        [True, False],
    )

    # result_idx is the global configuration number. It is used as result_num, so it must be
    # the same on every node. local_idx counts only this node's jobs and drives device
    # assignment. Deriving the device from result_idx (already filtered by % num_nodes)
    # would leave GPUs idle whenever num_nodes and len(gpus) share a factor.
    local_idx = 0
    for result_idx, (pointing_offset_stddev, axial_focus_error_stddev,
                     horizon_peak_astigmatism_stddev, with_smearing) in enumerate(sweep):
        if result_idx % num_nodes != node_idx:
            continue
        gpu_idx = local_idx % len(gpus)
        cpu_idx = local_idx % len(cpus)
        queues[gpu_idx].put((run, result_idx, cpu_idx, gpu_idx, pointing_offset_stddev,
                             axial_focus_error_stddev, horizon_peak_astigmatism_stddev,
                             with_smearing))
        local_idx += 1

    # One sentinel per queue so each worker exits after its last job instead of blocking
    # forever in q.get(), which would also block the executor's shutdown indefinitely.
    for q in queues:
        q.put(None)

    def worker(q):
        # Drain q sequentially until the None sentinel; items are (callable, *args).
        while True:
            item = q.get()
            if item is None:
                break
            f, *args = item
            f(*args)

    # One thread per GPU queue so that all GPUs run concurrently. The default pool size
    # is min(32, cpu_count + 4), which could be smaller than the GPU count.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(gpus)) as executor:
        futures = [executor.submit(worker, q) for q in queues]
        # Read every future so job exceptions surface instead of being silently dropped.
        for future in concurrent.futures.as_completed(futures):
            future.result()


if __name__ == "__main__":
    # Run shard 0 of a 2-node split; the other node runs main(1, 2).
    # The guard keeps an import of this module from launching the sweep.
    main(0, 2)
