diff --git a/autolens/lens/substructure_util.py b/autolens/lens/substructure_util.py index 564231639..9555b4d69 100644 --- a/autolens/lens/substructure_util.py +++ b/autolens/lens/substructure_util.py @@ -168,3 +168,75 @@ def simulate_substructure( image_2d = image_2d - background_sky_level return image_2d + + +def los_realizations_to_arrays( + realization_galaxies, + plane_redshifts, + max_n, + profile_cls, +): + import jax.numpy as jnp + + all_params = [] + all_masks = [] + all_kappas = [] + + for galaxies in realization_galaxies: + params, mask, kappas = galaxies_to_halo_arrays( + galaxies=galaxies, + plane_redshifts=plane_redshifts, + max_n=max_n, + profile_cls=profile_cls, + ) + all_params.append(params) + all_masks.append(mask) + all_kappas.append(kappas) + + return jnp.stack(all_params), jnp.stack(all_masks), jnp.stack(all_kappas) + + +def batched_simulate_substructure( + grid, + image_shape, + halo_params_batch, + halo_mask_batch, + scaling_matrix, + macro_deflections_fn, + macro_plane_mask, + sheet_kappas_batch, + source_image_fn, + psf_kernel, + exposure_time, + background_sky_level, + prng_keys, + halo_profile_cls, +): + import jax + import functools + + single_fn = functools.partial( + simulate_substructure, + grid=grid, + image_shape=image_shape, + scaling_matrix=scaling_matrix, + macro_deflections_fn=macro_deflections_fn, + macro_plane_mask=macro_plane_mask, + source_image_fn=source_image_fn, + psf_kernel=psf_kernel, + exposure_time=exposure_time, + background_sky_level=background_sky_level, + halo_profile_cls=halo_profile_cls, + ) + + def call(halo_params, halo_mask, sheet_kappas, prng_key): + return single_fn( + halo_params=halo_params, + halo_mask=halo_mask, + sheet_kappas=sheet_kappas, + prng_key=prng_key, + ) + + return jax.vmap(call)( + halo_params_batch, halo_mask_batch, sheet_kappas_batch, prng_keys, + )