
Implement BaseNoise and GaussianNoise #19

Closed
Tracked by #3
EiffL opened this issue Jun 14, 2022 · 4 comments
Labels: help wanted (Extra attention is needed), JAX (An issue that involves a pure JAX question)


@EiffL
Member

EiffL commented Jun 14, 2022

This is a good issue for people with existing knowledge of JAX and its random number generation particularities.

We have a section of the design document here to describe how to handle RNG in JAX-GalSim, but it is currently empty.

We need to discuss a few ideas for the API and see whether there is a way to provide the GalSim API while still complying with JAX's approach to random number generation.

If you are interested in taking on this issue, please just mention your interest in this thread, have a look at the CONTRIBUTING.md document, and feel free to discuss any questions you may have in this issue.

A good starting point would be to demo what the random number generation needed for demo1 would look like with JAX primitives under the hood.
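As a rough illustration of that starting point, here is a minimal sketch of demo1-style Gaussian noise drawn with JAX primitives. The helper name add_gaussian_noise, the image size and the noise level are illustrative assumptions, not JAX-GalSim API. The main difference from GalSim is that the caller owns the PRNG key and splits it explicitly, instead of relying on a stateful deviate object.

```python
import jax
import jax.numpy as jnp


def add_gaussian_noise(image, key, sigma):
    """Hypothetical helper: return image + iid N(0, sigma^2) pixel noise.

    The key is passed in explicitly, so the same key always gives the same noise.
    """
    noise = sigma * jax.random.normal(key, shape=image.shape, dtype=image.dtype)
    return image + noise


# One master key for the whole script; split off a fresh subkey for each use.
key = jax.random.PRNGKey(0)
key, noise_key = jax.random.split(key)

image = jnp.zeros((64, 64), dtype=jnp.float32)  # stand-in for a drawn galaxy image
noisy = add_gaussian_noise(image, noise_key, sigma=30.0)
```

Reusing noise_key reproduces the exact same noise realization, which is the JAX counterpart of reseeding a GalSim deviate.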

@EiffL mentioned this issue Jun 14, 2022
@EiffL changed the title from "BaseNoise" to "Implement BaseNoise and GaussianNoise" Jun 14, 2022
@EiffL added the "help wanted" (Extra attention is needed) and "JAX" (An issue that involves a pure JAX question) labels Jun 14, 2022
@jecampagne
Collaborator

jecampagne commented Jun 16, 2022

Well, I have played with JAX random generators for personal use, so I would be pleased to join in, but maybe I should first be told what exactly is needed, especially if one wants to run multi-GPU simulations, where the RNG streams must be distributed across devices with care, I guess. So feel free to contact me for a discussion.
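One common way to distribute random numbers across devices with care is to derive each object's key deterministically from a single master key and the object's index, for example with jax.random.fold_in. The sketch below assumes one noise stamp per object; per_object_keys and noise_stamps are hypothetical helper names. Because key i depends only on the master key and i, the noise for a given object does not change with how the objects are partitioned across GPUs.

```python
import jax
import jax.numpy as jnp


def per_object_keys(master_key, indices):
    """Hypothetical helper: one key per object, derived from (master_key, index)."""
    return jax.vmap(lambda i: jax.random.fold_in(master_key, i))(indices)


def noise_stamps(keys, shape, sigma):
    """Draw one Gaussian noise stamp per key, vectorised over the leading axis."""
    return jax.vmap(lambda k: sigma * jax.random.normal(k, shape))(keys)


master_key = jax.random.PRNGKey(0)
all_stamps = noise_stamps(per_object_keys(master_key, jnp.arange(8)), (16, 16), sigma=1.0)

# Splitting the objects into two chunks (e.g. one per GPU) reproduces the same stamps,
# since each key depends only on the master key and the object index.
chunk_a = noise_stamps(per_object_keys(master_key, jnp.arange(0, 4)), (16, 16), sigma=1.0)
chunk_b = noise_stamps(per_object_keys(master_key, jnp.arange(4, 8)), (16, 16), sigma=1.0)
```

In a real multi-device run the chunks would be sharded across devices; this sketch only emulates the split on one device.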

For instance, do we have to rebuild a set of "Random Distributions" classes, or can we rely on the NumPyro distributions library? That would at least avoid a trial-and-error process to validate this part of JAX-GalSim, no? Hmm, after reading some parts of the code, I understand that this JAX-GalSim implementation should stay as close as possible to the original GalSim way of doing things.

Q: If I understand correctly, GalSim's random.py implements the random distribution (deviate) classes, and the Noise classes rely on them to generate image noise. So I guess the first thing to do would be to make a JAX version of GalSim's random.py, no?
Q: From which branch should one start? "image", for instance?

@jecampagne
Collaborator

jecampagne commented Nov 27, 2022

Hello,
Here are some functions/classes extracted from NumPyro to glue the "random seeds" together:

import jax
import jax.numpy as jnp
import numpy as np

_PYRO_STACK = []

def default_process_message(msg):
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])

def apply_stack(msg):
    """
    Execute the effect stack at a single site according to the following scheme:
        1. For each ``Messenger`` in the stack from bottom to top,
           execute ``Messenger.process_message`` with the message;
           if the message field "stop" is True, stop;
           otherwise, continue
        2. Apply default behavior (``default_process_message``) to finish remaining
           site execution
        3. For each ``Messenger`` in the stack from top to bottom,
           execute ``Messenger.postprocess_message`` to update the message
           and internal messenger state with the site results
    """
    pointer = 0
    for pointer, handler in enumerate(reversed(_PYRO_STACK)):
        handler.process_message(msg)
        # When a Messenger sets the "stop" field of a message,
        # it prevents any Messengers above it on the stack from being applied.
        if msg.get("stop"):
            break

    default_process_message(msg)

    # A Messenger that sets msg["stop"] == True also prevents application
    # of postprocess_message by Messengers above it on the stack
    # via the pointer variable from the process_message loop
    for handler in _PYRO_STACK[-pointer - 1 :]:
        handler.postprocess_message(msg)
    return msg

class Messenger(object):
    def __init__(self, fn=None):
        if fn is not None and not callable(fn):
            raise ValueError(
                "Expected `fn` to be a Python callable object; "
                "instead found type(fn) = {}.".format(type(fn))
            )
        self.fn = fn

    def __enter__(self):
        _PYRO_STACK.append(self)

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            assert _PYRO_STACK[-1] is self
            _PYRO_STACK.pop()
        else:
            if self in _PYRO_STACK:
                loc = _PYRO_STACK.index(self)
                for i in range(loc, len(_PYRO_STACK)):
                    _PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        if self.fn is None:
            # Assume self is being used as a decorator.
            assert len(args) == 1 and not kwargs
            self.fn = args[0]
            return self
        with self:
            return self.fn(*args, **kwargs)

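class seed(Messenger):
    """
    Handler that gives every ``sample`` site called without an ``rng_key`` a
    fresh key, obtained by splitting the handler's internal key once per site.
    """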

    def __init__(self, fn=None, rng_seed=None):
        if isinstance(rng_seed, int) or (
            isinstance(rng_seed, (np.ndarray, jnp.ndarray)) and not jnp.shape(rng_seed)
        ):
            rng_seed = jax.random.PRNGKey(rng_seed)
        if not (
            isinstance(rng_seed, (np.ndarray, jnp.ndarray))
            and rng_seed.dtype == jnp.uint32
            and rng_seed.shape == (2,)
        ):
            raise TypeError("Incorrect type for rng_seed: {}".format(type(rng_seed)))
        self.rng_key = rng_seed
        super(seed, self).__init__(fn)

    def process_message(self, msg):
        if (
            msg["type"] == "sample" and msg["kwargs"]["rng_key"] is None
        ):
            if msg["value"] is not None:
                # no need to create a new key when value is available
                return
            self.rng_key, rng_key_sample = jax.random.split(self.rng_key)
            msg["kwargs"]["rng_key"] = rng_key_sample
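

def sample(name, fn, rng_key=None, sample_shape=()):
    """
    Draw a value from ``fn`` at the site ``name``. With no active Messengers,
    ``fn`` is called directly with ``rng_key``; otherwise the request goes
    through ``apply_stack`` so that handlers such as ``seed`` can supply the key.
    """
    assert isinstance(
        sample_shape, tuple
    ), "sample_shape needs to be a tuple of integers"
    if not isinstance(fn, Normal):
        raise TypeError("only Normal implemented")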

    # if no active Messengers, draw a sample or return obs as expected:
    if not _PYRO_STACK:
        return fn(rng_key=rng_key, sample_shape=sample_shape)


    # Otherwise, we initialize a message...
    initial_msg = {
        "type": "sample",
        "name": name,
        "fn": fn,
        "value": None,
        "args": (),
        "kwargs": {"rng_key": rng_key, "sample_shape": sample_shape},
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg["value"]


def _reshape(x, shape):
    if isinstance(x, (int, float, np.ndarray, np.generic)):
        return np.reshape(x, shape)
    else:
        return jnp.reshape(x, shape)

def promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    else:
        shapes = [jnp.shape(arg) for arg in args]
        num_dims = len(jax.lax.broadcast_shapes(shape, *shapes))
        return [
            _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
            for arg, s in zip(args, shapes)
        ]

def is_prng_key(key):
    try:
        return key.shape == (2,) and key.dtype == np.uint32
    except AttributeError:
        return False

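class Normal:
    """Minimal normal distribution: ``sample`` returns ``loc + scale * eps``, with
    ``eps ~ N(0, 1)`` drawn from an explicit key. ``validate_args`` is accepted but ignored."""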

    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
        self.loc, self.scale = promote_shapes(loc, scale)
        self.batch_shape = jax.lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key), 'not valid key'
        eps = jax.random.normal(
            key, shape=sample_shape + self.batch_shape
        )
        return self.loc + eps * self.scale

    def __call__(self, *args, **kwargs):
        key = kwargs.pop("rng_key")
        return self.sample(key, *args, **kwargs)

A user model that internally draws random normal samples:

def model(sig=1, data2={'mean':-1}):
  n1 = sample('n1',Normal(loc=0,scale=sig))
  n2 = sample("n2",Normal(loc=data2['mean']))
  return n1,n2,(sig,data2)

Then one can use, for instance,

seed(model, rng_seed=1)(sig=10)

leading to:

(DeviceArray(-11.470195, dtype=float32),
 DeviceArray(-2.092164, dtype=float32),
 (10, {'mean': -1}))

or another example:

seed(model, rng_seed=jax.random.PRNGKey(10))(data2={'mean':10})
(DeviceArray(0.7978776, dtype=float32),
 DeviceArray(10.231143, dtype=float32),
 (1, {'mean': 10}))

Of course, the arguments are returned only to check that they are passed through correctly.
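To spell out what the seed handler does in these calls: it holds one key and, at each sample site it reaches, splits it, keeping the first half as its new state and passing the second half to the distribution. seed(model, rng_seed=1)(sig=10) should therefore be equivalent to the hand-threaded sketch below, which bypasses the handler and draws directly with jax.random.normal, mirroring Normal.sample above. model_by_hand is a hypothetical name used only for this illustration.

```python
import jax


def model_by_hand(rng_seed=1, sig=10.0, mean2=-1.0):
    """Hand-threaded version of model() under seed(): one key split per sample site."""
    key = jax.random.PRNGKey(rng_seed)
    # Site "n1": split, keep the first half as state, sample with the second half.
    key, k1 = jax.random.split(key)
    n1 = 0.0 + sig * jax.random.normal(k1)
    # Site "n2": same pattern, starting from the updated state.
    key, k2 = jax.random.split(key)
    n2 = mean2 + 1.0 * jax.random.normal(k2)
    return n1, n2
```

This makes the cost of the handler visible: it only automates the bookkeeping of splitting keys in call order.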

Now, the NumPyro use case is more advanced and may be overkill for what we want, although it may serve as a guide.

The Normal class above is an archetype of what can be used; we can certainly modify the JAX-GalSim Gaussian class to include the sample method as well.
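A minimal sketch of that idea, assuming the noise class only needs a sigma parameter: registering it as a JAX pytree lets a noise object be passed straight into a jax.jit-compiled function, with the key supplied explicitly instead of through the seed handler. The names GaussianNoise and add_noise here are hypothetical and not the JAX-GalSim API.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class GaussianNoise:
    """Hypothetical noise model: iid Gaussian pixels with standard deviation sigma."""

    def __init__(self, sigma=1.0):
        self.sigma = sigma

    # Pytree protocol: sigma is a (possibly traced) leaf, with no static aux data.
    def tree_flatten(self):
        return (self.sigma,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def sample(self, key, sample_shape=()):
        return self.sigma * jax.random.normal(key, sample_shape)


@jax.jit
def add_noise(image, noise, key):
    # image.shape is static under jit, so it can serve as the sample shape.
    return image + noise.sample(key, image.shape)


key = jax.random.PRNGKey(0)
image = jnp.zeros((32, 32))
noisy = add_noise(image, GaussianNoise(sigma=2.0), key)
```

Because sigma is a pytree leaf rather than static data, calling add_noise with a different sigma of the same type does not trigger recompilation.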

What do you think?

@jecampagne
Collaborator

jecampagne commented Nov 29, 2022

Here is a snippet from a demo that you can run in Colab:

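# Assumed imports: JAX-GalSim imported under the name galsim, plus matplotlib for display
import jax_galsim as galsim
from matplotlib.pyplot import colorbar, imshow

# Define a galsim galaxy as the sum of two objects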
obj1 = galsim.Gaussian(half_light_radius=1.)
obj2 = galsim.Exponential(half_light_radius=0.5)

# Rescale the flux of one object
obj2 = obj2.withFlux(0.4)

# Sum the two components of my galaxy
gal = obj1 + obj2

im = gal.drawImage(nx=128, ny=128, scale=0.02, method='no_pixel')
imshow(im.array); colorbar()


then

def model(sig=1.0):
  n1 = galsim.noise.sample('n1',galsim.noise.GaussianNoise(loc=0.,scale=sig), sample_shape=(128,128))
  return n1

noise = galsim.helpers.seed(model, rng_seed=1)(sig=5e-5)
noise_img = galsim.Image(noise)
imshow((im+noise_img).array); colorbar()


Of course, this is done very much by hand (à la main), and I guess many parameters should be set automatically, but we can discuss the scheme before going ahead, or switch to another strategy.

@ismael-mendoza self-assigned this Oct 23, 2023
@beckermr
Collaborator

I merged noise fields in #79.
