Implement BaseNoise and GaussianNoise #19
Comments
Well, I have played with JAX random generators for personal use, so I would be pleased to join in. But maybe I should first be introduced to what exactly is needed, especially if one wants to run multi-GPU simulations, where the RNG streams have to be spread across devices with care, I guess. So feel free to contact me for discussion. For instance, do we have to rebuild a sort of "random distributions" set of classes, or do we rely on the NumPyro distributions library? The latter would at least avoid a trial-and-error process to validate this part of JAX-GalSim, no?

Hmm, after reading part of the code, I understand that this JAX-GalSim implementation should be as close as possible to the original GalSim way of doing things. Q: if I understand correctly, GalSim's random.py implements the random distribution classes, and the "Noise" classes rely on them to generate image noise. So I guess the first thing to do would be to make a JAX version of GalSim's random.py, no?
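To make the concern about spreading RNG across GPUs concrete: JAX has no global generator state, so every independent stream needs its own key, derived explicitly from a root key. A minimal sketch, assuming one stream per object is what we want; the helper names `per_object_keys` and `key_for_index` are hypothetical and not part of GalSim or JAX-GalSim.

```python
import jax
import jax.numpy as jnp


def per_object_keys(seed: int, n_objects: int) -> jax.Array:
    """Derive one independent PRNG key per object from a single integer seed."""
    root = jax.random.PRNGKey(seed)
    # split() returns n_objects statistically independent keys in one array
    return jax.random.split(root, n_objects)


def key_for_index(seed: int, index: int) -> jax.Array:
    """Deterministic key for object `index`, independent of how work is sharded."""
    # fold_in mixes the index into the root key, so object i always gets the
    # same key no matter which device or batch it ends up on
    return jax.random.fold_in(jax.random.PRNGKey(seed), index)


keys = per_object_keys(seed=1234, n_objects=8)
# One 4x4 noise stamp per object, drawn in a single vectorized call
stamps = jax.vmap(lambda k: jax.random.normal(k, (4, 4)))(keys)
print(stamps.shape)  # (8, 4, 4)
```

The `fold_in` variant is the one that stays reproducible when objects are redistributed across devices, since each key depends only on the seed and the object index.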
Hello,

import jax
import jax.numpy as jnp
import numpy as np

# Stack of active effect handlers (Messengers), most recently entered last
_PYRO_STACK = []

def default_process_message(msg):
    # Default behavior at a site: call the distribution to produce the value
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])

def apply_stack(msg):
    """
    Execute the effect stack at a single site according to the following scheme:
    1. For each ``Messenger`` in the stack from bottom to top,
       execute ``Messenger.process_message`` with the message;
       if the message field "stop" is True, stop;
       otherwise, continue
    2. Apply default behavior (``default_process_message``) to finish remaining
       site execution
    3. For each ``Messenger`` in the stack from top to bottom,
       execute ``Messenger.postprocess_message`` to update the message
       and internal messenger state with the site results
    """
    pointer = 0
    for pointer, handler in enumerate(reversed(_PYRO_STACK)):
        handler.process_message(msg)
        # When a Messenger sets the "stop" field of a message,
        # it prevents any Messengers above it on the stack from being applied.
        if msg.get("stop"):
            break
    default_process_message(msg)
    # A Messenger that sets msg["stop"] == True also prevents application
    # of postprocess_message by Messengers above it on the stack
    # via the pointer variable from the process_message loop
    for handler in _PYRO_STACK[-pointer - 1 :]:
        handler.postprocess_message(msg)
    return msg
class Messenger(object):
    def __init__(self, fn=None):
        if fn is not None and not callable(fn):
            raise ValueError(
                "Expected `fn` to be a Python callable object; "
                "instead found type(fn) = {}.".format(type(fn))
            )
        self.fn = fn

    def __enter__(self):
        _PYRO_STACK.append(self)

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            assert _PYRO_STACK[-1] is self
            _PYRO_STACK.pop()
        else:
            # On error, drop this handler and every handler pushed after it
            if self in _PYRO_STACK:
                loc = _PYRO_STACK.index(self)
                for i in range(loc, len(_PYRO_STACK)):
                    _PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        if self.fn is None:
            # Assume self is being used as a decorator.
            assert len(args) == 1 and not kwargs
            self.fn = args[0]
            return self
        with self:
            return self.fn(*args, **kwargs)
class seed(Messenger):
    """
    Handler that supplies a fresh PRNG key to every ``sample`` site
    called without one, by splitting its own key.
    """

    def __init__(self, fn=None, rng_seed=None):
        if isinstance(rng_seed, int) or (
            isinstance(rng_seed, (np.ndarray, jnp.ndarray)) and not jnp.shape(rng_seed)
        ):
            rng_seed = jax.random.PRNGKey(rng_seed)
        if not (
            isinstance(rng_seed, (np.ndarray, jnp.ndarray))
            and rng_seed.dtype == jnp.uint32
            and rng_seed.shape == (2,)
        ):
            raise TypeError("Incorrect type for rng_seed: {}".format(type(rng_seed)))
        self.rng_key = rng_seed
        super(seed, self).__init__(fn)

    def process_message(self, msg):
        if msg["type"] == "sample" and msg["kwargs"]["rng_key"] is None:
            if msg["value"] is not None:
                # no need to create a new key when value is available
                return
            self.rng_key, rng_key_sample = jax.random.split(self.rng_key)
            msg["kwargs"]["rng_key"] = rng_key_sample

def sample(name, fn, rng_key=None, sample_shape=()):
    """
    Draw a value from ``fn`` at the site ``name``, routing the request
    through any active handlers (e.g. ``seed``).
    """
    assert isinstance(sample_shape, tuple), "sample_shape needs to be a tuple of integers"
    if not isinstance(fn, Normal):
        raise TypeError("only Normal is implemented")
    # if no active Messengers, draw a sample directly:
    if not _PYRO_STACK:
        return fn(rng_key=rng_key, sample_shape=sample_shape)
    # Otherwise, we initialize a message...
    initial_msg = {
        "type": "sample",
        "name": name,
        "fn": fn,
        "value": None,
        "args": (),
        "kwargs": {"rng_key": rng_key, "sample_shape": sample_shape},
    }
    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg["value"]

def _reshape(x, shape):
    if isinstance(x, (int, float, np.ndarray, np.generic)):
        return np.reshape(x, shape)
    else:
        return jnp.reshape(x, shape)
def promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    else:
        shapes = [jnp.shape(arg) for arg in args]
        num_dims = len(jax.lax.broadcast_shapes(shape, *shapes))
        return [
            _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
            for arg, s in zip(args, shapes)
        ]

def is_prng_key(key):
    try:
        return key.shape == (2,) and key.dtype == np.uint32
    except AttributeError:
        return False
class Normal:
    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
        self.loc, self.scale = promote_shapes(loc, scale)
        self.batch_shape = jax.lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key), "not a valid key"
        eps = jax.random.normal(key, shape=sample_shape + self.batch_shape)
        return self.loc + eps * self.scale

    def __call__(self, *args, **kwargs):
        key = kwargs.pop("rng_key")
        return self.sample(key, *args, **kwargs)

A user model with internal generation of random normal samples:

def model(sig=1, data2={'mean': -1}):
    n1 = sample('n1', Normal(loc=0, scale=sig))
    n2 = sample("n2", Normal(loc=data2['mean']))
    return n1, n2, (sig, data2)

Then one can call, for instance, seed(model, rng_seed=1)(sig=10)
or, as another example, seed(model, rng_seed=jax.random.PRNGKey(10))(data2={'mean': 10}).

Of course, returning the arguments is only there to check that they are correctly updated. Now, the NumPyro use case is more advanced and may be overkill for what we want, although it may serve as a guide. What do you think?
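For comparison, the `seed` handler can be unrolled into a pure function that threads the key by hand. A sketch only: `model_explicit` is a hypothetical name, and the splitting order is chosen to mirror `seed.process_message` (split once per sample site and hand the second half to the site), so with the same seed it should reproduce the draws of `seed(model, rng_seed=1)(sig=10)`. Being pure, it composes directly with `jax.jit` and `jax.vmap`, with no global handler stack.

```python
import jax
import jax.numpy as jnp


def model_explicit(key, sig=1.0, mean2=-1.0):
    """Same two Normal draws as `model`, with the PRNG key threaded explicitly."""
    # Site "n1": split the running key, use the second half for the draw
    key, k1 = jax.random.split(key)
    n1 = 0.0 + jax.random.normal(k1) * sig
    # Site "n2": split again, so n2 is independent of n1
    key, k2 = jax.random.split(key)
    n2 = mean2 + jax.random.normal(k2) * 1.0
    return n1, n2


n1, n2 = model_explicit(jax.random.PRNGKey(1), sig=10.0)
# The pure version can be compiled as-is
n1_jit, n2_jit = jax.jit(model_explicit)(jax.random.PRNGKey(1), sig=10.0)
```

The trade-off is ergonomics: the handler version keeps GalSim-like call sites free of keys, while the explicit version makes the data flow visible to JAX transformations.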
Here is a snippet from a demo that you can run from Colab:

from matplotlib.pyplot import imshow, colorbar

# Define a galsim galaxy as the sum of two objects
obj1 = galsim.Gaussian(half_light_radius=1.)
obj2 = galsim.Exponential(half_light_radius=0.5)
# Rescale the flux of one object
obj2 = obj2.withFlux(0.4)
# Sum the two components of my galaxy
gal = obj1 + obj2
im = gal.drawImage(nx=128, ny=128, scale=0.02, method='no_pixel')
imshow(im.array); colorbar()

Then:

def model(sig=1.0):
    n1 = galsim.noise.sample('n1', galsim.noise.GaussianNoise(loc=0., scale=sig), sample_shape=(128, 128))
    return n1

noise = galsim.helpers.seed(model, rng_seed=1)(sig=5e-5)
noise_img = galsim.Image(noise)
imshow((im + noise_img).array); colorbar()

Now, of course, this is all done very much by hand ("à la main"), and I guess many parameters should be set automatically, but we can discuss the scheme before going further, or switch to another strategy.
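The same noise field can also be produced in plain JAX by passing the key explicitly. A minimal sketch, where `add_gaussian_noise` is a hypothetical helper and a zero array stands in for `im.array`; because the function is pure, `jax.vmap` yields several independent realizations of the noise in one call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_gaussian_noise(key, image, sigma):
    """Return image + N(0, sigma^2) noise, with an independent draw per pixel."""
    return image + sigma * jax.random.normal(key, image.shape, dtype=image.dtype)


image = jnp.zeros((128, 128), dtype=jnp.float32)  # stand-in for im.array
keys = jax.random.split(jax.random.PRNGKey(1), 4)
# Four independent noise realizations of the same image: map over keys only
noisy = jax.vmap(add_gaussian_noise, in_axes=(0, None, None))(keys, image, 5e-5)
print(noisy.shape)  # (4, 128, 128)
```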
I merged noise fields in #79.
This is a good issue for people with existing knowledge of JAX and its random number generation particularities.
We have a section of the design document here for documenting how to handle RNG in JAX-GalSim, but it is currently empty.
We need to discuss a few ideas for the API, and see whether there is a way to provide the GalSim API while still complying with JAX's approach to random number generation.
If you are interested in taking on this issue, please just mention your interest in this thread, have a look at the CONTRIBUTING.md document, and feel free to discuss any questions you may have in this issue.
A good starting point would be to demo what the random number generation needed for demo1 would look like with JAX primitives under the hood.
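As one possible answer to the API question, here is a sketch of how a GalSim-style noise object could keep a familiar constructor while following JAX's explicit-key model. In GalSim the generator state lives inside a `BaseDeviate` that mutates on every draw; this sketch instead stores a key and returns an updated object. Everything here is an assumption for discussion: the class name mirrors `galsim.GaussianNoise`, but the `key` field, the `apply` method and its return convention are hypothetical, not the GalSim or JAX-GalSim API, and the `sigma` value is arbitrary.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class GaussianNoise:
    """Hypothetical JAX-friendly analogue of galsim.GaussianNoise.

    Instead of a hidden mutable generator, the object carries a PRNG key;
    each application returns the noisy array AND a new noise object whose
    key has been advanced.
    """

    key: jax.Array
    sigma: float = 1.0

    def apply(self, array):
        # Advance the key: keep one half for the future, consume the other now
        key, subkey = jax.random.split(self.key)
        noisy = array + self.sigma * jax.random.normal(subkey, array.shape, array.dtype)
        return noisy, dataclasses.replace(self, key=key)


# demo1-like usage: seed once, then apply to an image array
noise = GaussianNoise(key=jax.random.PRNGKey(1), sigma=2.0)
image = jnp.zeros((64, 64), dtype=jnp.float32)  # placeholder for the drawn galaxy
noisy1, noise = noise.apply(image)
noisy2, noise = noise.apply(image)  # a different draw, because the key advanced
```

The design choice is that state only advances by returning a new object, which keeps the code usable under `jax.jit`; the cost is that callers must keep the returned noise object, unlike GalSim's in-place `image.addNoise(noise)`.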