Skip to content

Commit

Permalink
ENH respond to more CR
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Dec 16, 2023
1 parent 4d366a6 commit 7f34324
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions jax_galsim/photon_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,15 +694,43 @@ def convolve(self, rhs, rng=None):
self._sort_by_nokeep(sinds=self_sinds)
rhs._sort_by_nokeep(sinds=rhs_sinds)

# When two photon arrays are convolved, you basically perturb the positions of one
# by adding the positions of the other. For example, if you have a delta function
# and want to convolve with a Gaussian, then the photon arrays are an array of zeros
# for the delta function and an array of Gaussian draws for the Gaussian. The convolution
# is then implemented by adding the positions of the two arrays.

# The edge case here is if the photons in anb array are correlated. for example, if
# you draw photons from a sum of two profiles, you could have the photons from one
# of the components only at the start of the array and the photons from the other
# component only at the end of the array like this
#
# [A, A, A, ..., A, B, B, B. ..., B]
#
# where A and B represent which component the photon came from. If you convolve two
# photon arrays where both arrays have intenral correlations in the ordering of the
# photons, then you need to randomly sort one of the arrays before the convolution.
# Otherwise you won't properly be adding a random draew from one profile to the other.

# the indexing and PRNG code snippets below handle this case of convolving two internally
# correlated photon arrays.

# these are indicies that randomly sort the RHS's photons.
rng = BaseDeviate(rng)
rsinds = jrng.choice(
rng._state.split_one(),
self._Ntot,
shape=(self.size(),),
replace=False,
)
# these indices do not randomly sort the RHS's photons
nrsinds = jnp.arange(self.size())

# now we randomly sort if both arrays are internally correlated
# however there is a catch. The RHS may not be keeping all of its photons
# (i.e., rhs._nokeep is True for some photons). In this case, we additionally
# sort the random indices by the value of rhs._nokeep so that the photons to be
# kept are still at the front of the array but are in a new random order.
sinds = jax.lax.cond(
self._is_corr & rhs._is_corr,
lambda nrsinds, rsinds: rsinds.at[
Expand Down

0 comments on commit 7f34324

Please sign in to comment.