-
Notifications
You must be signed in to change notification settings - Fork 604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add prng_key kwarg to new device API #4596
Conversation
Hello. You may have forgotten to update the changelog!
|
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## master #4596 +/- ##
=======================================
Coverage 99.64% 99.64%
=======================================
Files 375 375
Lines 33370 33398 +28
=======================================
+ Hits 33252 33280 +28
Misses 118 118
☔ View full report in Codecov by Sentry. |
[sc-44393] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just a small comment on docs! also, don't forget to add a changelog entry.
I had a bigger request, but we can push it to another PR since it'll make tests more awkward
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
**Context:** The `default.qubit.jax` device can take a `prng_key`, and if it does, it returns the same results from sampling any given state across repeated measurements, rather than randomizing the sampling of the state. **Description of the Change:** Allow the `seed` on `DefaultQubit2` to be a `jax.random.PRNGKey`; if it is, have the new device replicate the `DefaultQubitJax` sampling behaviour. **Benefits:** You can get the same behaviour on the new device if you want a set PRNG key when using `jax`. **Possible Drawbacks:** Previously, the jax device would use `jax.choice` to generate samples regardless of whether or not a PRNG key was provided; the difference was whether it used a new random PRNG key each time, or the same one over and over. For the new DefaultQubit, if you provide a PRNG key as the seed, the behaviour will be the same as the `default.qubit.jax` device with a set `prng_key`. If you leave the seed as None or an integer, even if the interface is `jax`, it will use the `numpy.random.default_rng` to generate the samples, and not `jax.choice`. If you want a random PRNGKey and `jax.choice`, you can fix it by reinitializing the device each time you run, with a randomly generated PRNG key as the seed. --------- Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Context:
The
default.qubit.jax
device can take aprng_key
, and if it does, it returns the same results from sampling any given state across repeated measurements, rather than randomizing the sampling of the state.Description of the Change:
Allow the
seed
onDefaultQubit2
to be ajax.random.PRNGKey
; if it is, have the new device replicate theDefaultQubitJax
sampling behaviour.Benefits:
You can get the same behaviour on the new device if you want a set PRNG key when using
jax
.Possible Drawbacks:
Previously, the jax device would use
jax.choice
to generate samples regardless of whether or not a PRNG key was provided; the difference was whether it used a new random PRNG key each time, or the same one over and over.For the new DefaultQubit, if you provide a PRNG key as the seed, the behaviour will be the same as the
default.qubit.jax
device with a setprng_key
. If you leave the seed as None or an integer, even if the interface isjax
, it will use thenumpy.random.default_rng
to generate the samples, and notjax.choice
.If you want a random PRNGKey and
jax.choice
, you can fix it by reinitializing the device each time you run, with a randomly generated PRNG key as the seed.