
Add MultinomialRV JAX implementation #1360

Closed
wants to merge 2 commits into from

Conversation

GStechschulte

This draft PR is a work in progress and contains a JAX implementation of MultinomialRV for issue #1326. The implementation builds on the Multinomial distribution implementation in NumPyro, and its output matches that of the NumPy implementation. Below is a brief outline of the functions used to construct the MultinomialRV.

def _categorical(key, p, shape)

  • returns the outcome $k$ with probability $p_k$ for each trial / experiment $n$.

def _scatter_add_one(operand, indices, update)

  • returns the outcome counts by utilising the jax.lax.scatter_add() function.
  • operand is a zero-filled array.
  • indices is the outcomes array with an added dimension; it specifies the indices to which the update should be applied.
  • update is an array filled with ones and can be thought of as a cnt += 1 for each occurrence of $K = k$.
  • In summary, the operand array is incremented by 1 using the update array at the positions given by the indices array.
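The two steps above can be sketched in plain NumPy (illustrative only; the PR itself uses jax.random.uniform and jax.lax.scatter_add, and np.bincount stands in for the scatter-add step):

```python
import numpy as np

def _categorical(rng, p, shape):
    # Inverse-CDF sampling: count how many cumulative probabilities
    # fall below a uniform draw to obtain the outcome index k.
    s = np.cumsum(p, axis=-1)
    r = rng.uniform(size=shape + (1,))
    return np.sum(s < r, axis=-1)

rng = np.random.default_rng(0)
p = np.array([0.2, 0.3, 0.5])
outcomes = _categorical(rng, p, (1000,))      # 1000 draws in {0, 1, 2}
counts = np.bincount(outcomes, minlength=3)   # the "cnt += 1 per occurrence" step
assert counts.sum() == 1000
```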

I still need to add a test for this. Thanks!

@brandonwillard added the "JAX (Involves JAX transpilation)" and "random variables (Involves random variables and/or sampling)" labels Dec 11, 2022
@rlouf
Member

rlouf commented Dec 12, 2022

Thanks! We also need to figure out whether the licenses are compatible and how to do proper attribution if you took inspiration from someone else's implementation. It looks like NumPyro is licensed under Apache 2.0.

@rlouf rlouf changed the title Add JAX implementation of MultinomialRV Add JAX implementation MultinomialRV Dec 12, 2022
@rlouf rlouf changed the title Add JAX implementation MultinomialRV Add MultinomialRV JAX implementation Dec 12, 2022
@GStechschulte
Author

GStechschulte commented Dec 12, 2022

Thanks! We also need to figure out whether the licenses are compatible and how to do proper attribution if you took inspiration from someone else's implementation. It looks like NumPyro is licensed under Apache 2.0.

Based on section 4 of the NumPyro Apache License 2.0, we may reproduce and distribute copies of the Work or Derivative Works (the JAX implementation of MultinomialRV) provided we:

  1. give any other recipients of the Work or Derivative Works a copy of the License;
  2. ensure that any modified file carries a prominent notice stating that the file was changed;
  3. retain, in the Source form of any Derivative Works we distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work; and
  4. if the Work includes a "NOTICE" text file as part of its distribution, include in any Derivative Works we distribute a readable copy of the attribution notices contained within that NOTICE file (excluding those notices that do not pertain to any part of the Derivative Works) in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works, or within the Source form or documentation.

Since inspiration was drawn from a few functions and not an entire file, I suggest that, in addition to (1), (2), and (3), we include in the documentation for this RV something along the lines of:

MultinomialRV uses source code from the file xyz.py of the NumPyro project, copyright YYYY, licensed under the Apache 2.0 license.

@GStechschulte GStechschulte marked this pull request as ready for review December 12, 2022 21:00
Comment on lines 298 to 470
def _categorical(key, p, shape):
    shape = shape or p.shape[:-1]
    s = jax.numpy.cumsum(p, axis=-1)
    r = jax.random.uniform(key, shape=shape + (1,))

    return jax.numpy.sum(s < r, axis=-1)
Member

I am surprised, because JAX does have an implementation of the categorical distribution that uses its implementation of the Gumbel distribution. Is this justified in the codebase, or is it just because it was implemented at a time when jax.random.categorical was not available (which you should be able to determine with git blame)?
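For context, jax.random.categorical samples via the Gumbel-max trick; a NumPy sketch of the same idea (illustrative only, not the PR code):

```python
import numpy as np

# Gumbel-max trick (what jax.random.categorical does internally):
# argmax(logits + Gumbel noise) yields a categorical sample with
# probabilities softmax(logits).
rng = np.random.default_rng(0)
p = np.array([0.25, 0.25, 0.5])
logits = np.log(p)

gumbel_noise = rng.gumbel(size=(100_000, len(p)))
samples = np.argmax(logits + gumbel_noise, axis=-1)
freq = np.bincount(samples, minlength=len(p)) / len(samples)
assert np.allclose(freq, p, atol=0.02)
```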

Comment on lines 314 to 320
samples_2d = jax.vmap(_scatter_add_one, (0, 0, 0))(
    jax.numpy.zeros((outcomes_2d.shape[0], p.shape[-1]), dtype=outcomes.dtype),
    jax.numpy.expand_dims(outcomes_2d, axis=-1),
    jax.numpy.ones(outcomes_2d.shape, dtype=outcomes.dtype),
)

sample = jax.numpy.reshape(samples_2d, size + p.shape[-1:])
Member

Couldn't we use jax.nn.one_hot on the output of the categorical and then reduce the resulting tensor?
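For illustration, the suggested one-hot reduction in a NumPy sketch (hypothetical outcomes, not the PR code):

```python
import numpy as np

# Hypothetical example: 4 categorical outcomes over K = 3 categories.
outcomes = np.array([2, 0, 2, 1])
K = 3

one_hot = np.eye(K, dtype=np.int64)[outcomes]  # shape (4, 3), one row per outcome
counts = one_hot.sum(axis=0)                   # multinomial counts per category
assert counts.tolist() == [1, 1, 2]
```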

Author

Yeah, your proposal seems to be a more elegant solution; thanks. Just committed.

@rlouf
Member

rlouf commented Dec 14, 2022

I rebased your branch on main to use the new key splitting scheme in the JAX backend. You'll have to pull the changes!

Author

@GStechschulte GStechschulte left a comment

When I did a git pull with the default merge strategy, it kept a copy of the previous code (without the new key splitting). This was likely a mistake on my end, so I deleted it.


@rlouf
Member

rlouf commented Dec 14, 2022

Yes, you need to git pull --rebase in such cases. Here's a good explanation of how rebasing works, and of course the documentation for git pull.

Comment on lines 335 to 470
def _categorical(key, p, shape):
    shape = shape or p.shape[:-1]
    s = jax.numpy.cumsum(p, axis=-1)
    r = jax.random.uniform(key, shape=shape + (1,))

    return jax.numpy.sum(s < r, axis=-1)
Member

Is that different from jax.random.categorical?

Author

Yes, jax.random.categorical uses JAX's implementation of the Gumbel distribution. However, after using git blame, it appears the NumPyro team wrote this implementation before jax.random.categorical was available.

If we would like to update this code to use jax.random.categorical, it would be the following:

def sample_fn(rng, size, dtype, *parameters):
    rng_key = rng["jax_state"]
    rng_key, sampling_key = jax.random.split(rng_key, 2)

    n, p = parameters
    n_max = jax.numpy.max(n)
    size = size or p.shape[:-1]

    # jax.random.categorical expects (possibly unnormalized) log-probabilities
    logits = jax.numpy.log(p)
    indices = jax.random.categorical(sampling_key, logits, shape=(n_max,) + size)
    one_hot = jax.nn.one_hot(indices, p.shape[-1])
    sample = jax.numpy.sum(one_hot, axis=0, dtype=dtype, keepdims=False)

    rng["jax_state"] = rng_key

    return (rng, sample)

Member

That looks much simpler, great!

Since the multinomial distribution is slightly more complex than the other ones when it comes to shapes, we should make sure that the output shape of the samples generated in the JAX backend is identical to that of the samples generated with the other backends.

Contributor

I tend to use jax.random.choice for this kind of thing. jax.random.categorical (via Gumbel) has quadratic complexity.

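A sketch of the suggested alternative, using NumPy's analogue of the weighted-sampling call (illustrative; in JAX this would be something like jax.random.choice(key, K, shape=(n,), p=p)):

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.1, 0.6, 0.3])
n = 10_000

# Weighted draw of n category indices, then tally the counts;
# this avoids materialising the large noise array the Gumbel trick needs.
indices = rng.choice(len(p), size=n, p=p)
counts = np.bincount(indices, minlength=len(p))
assert counts.sum() == n
```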

Contributor

You just need to check the Gumbel implementation: they form an N-by-N array.

.gitignore Outdated
@@ -55,3 +55,4 @@ aesara-venv/
testing-report.html
coverage.xml
.coverage.*
jax_multinomial_test.py
Member

This looks like a job for a "local" Git ignore (see here).
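For reference, a "local" Git ignore lives in .git/info/exclude, which is per-clone and never committed (sketch):

```
# .git/info/exclude -- per-clone ignore list, not tracked by Git
jax_multinomial_test.py
```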

@rlouf
Member

rlouf commented Jan 16, 2023

@GStechschulte I think we should follow @AdrienCorenflos's suggestion here. I'll take another look at it this week.

@rlouf
Member

rlouf commented Feb 21, 2023

@GStechschulte I rebased your branch on main. Do you plan on implementing @AdrienCorenflos's suggestion above?

Member

@brandonwillard brandonwillard left a comment

I've squashed, rebased, and added a few fixes. The JAX steps are not properly accounting for the shapes of the distribution parameters, so that needs to be finished. I refactored the tests so that they cover one of the most basic cases and another that should confirm that the sizes/shapes are handled correctly (when they are).

@rlouf, I had to add a case for Constants in assert_size_argument_jax_compatible. You'll need to confirm that this is valid more generally.

FYI: the tests aren't being run with shape inference, since the jax_mode used by the tests doesn't include "ShapeOpt" (or whatever its tag is). Standard JAX mode should, since it's included with the "fast_run" tag, but, if we want to test for non-trivial shape scenarios (e.g. the shape value isn't explicitly constant, but can be "inferred" as a constant value), we'll need to add it.

@rlouf
Member

rlouf commented Mar 10, 2023

@rlouf, I had to add a case for Constants in assert_size_argument_jax_compatible. You'll need to confirm that this is valid more generally.

That's valid. I was so focused on the complex case that I forgot the simplest one.

@GStechschulte GStechschulte closed this by deleting the head repository Mar 12, 2024

Successfully merging this pull request may close these issues.

Add JAX implementation for MultinomialRV
4 participants