Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prng_key kwarg to new device API #4596

Merged
merged 24 commits into from
Sep 19, 2023
Merged

Add prng_key kwarg to new device API #4596

merged 24 commits into from
Sep 19, 2023

Conversation

lillian542
Copy link
Contributor

@lillian542 lillian542 commented Sep 14, 2023

Context:
The default.qubit.jax device can take a prng_key, and if it does, it returns the same results from sampling any given state across repeated measurements, rather than randomizing the sampling of the state.

Description of the Change:
Allow the seed on DefaultQubit2 to be a jax.random.PRNGKey; if it is, have the new device replicate the DefaultQubitJax sampling behaviour.

Benefits:
You can get the same behaviour on the new device if you want a set PRNG key when using jax.

Possible Drawbacks:
Previously, the jax device would use jax.choice to generate samples regardless of whether or not a PRNG key was provided; the difference was whether it used a new random PRNG key each time, or the same one over and over.

For the new DefaultQubit, if you provide a PRNG key as the seed, the behaviour will be the same as the default.qubit.jax device with a set prng_key. If you leave the seed as None or an integer, even if the interface is jax, it will use the numpy.random.default_rng to generate the samples, and not jax.choice.

If you want a random PRNGKey and jax.choice, you can fix it by reinitializing the device each time you run, with a randomly generated PRNG key as the seed.

@github-actions
Copy link
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@codecov
Copy link

codecov bot commented Sep 14, 2023

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (3b082f7) 99.64% compared to head (b18178f) 99.64%.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #4596   +/-   ##
=======================================
  Coverage   99.64%   99.64%           
=======================================
  Files         375      375           
  Lines       33370    33398   +28     
=======================================
+ Hits        33252    33280   +28     
  Misses        118      118           
Files Changed Coverage Δ
pennylane/devices/default_qubit.py 100.00% <100.00%> (ø)
pennylane/devices/qubit/sampling.py 100.00% <100.00%> (ø)
pennylane/devices/qubit/simulate.py 100.00% <100.00%> (ø)

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@lillian542 lillian542 marked this pull request as ready for review September 15, 2023 15:02
@lillian542 lillian542 requested a review from a team September 15, 2023 17:30
@lillian542
Copy link
Contributor Author

[sc-44393]

Copy link
Contributor

@timmysilv timmysilv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a small comment on docs! also, don't forget to add a changelog entry.

I had a bigger request, but we can push it to another PR since it'll make tests more awkward

pennylane/devices/qubit/sampling.py Outdated Show resolved Hide resolved
pennylane/devices/qubit/sampling.py Outdated Show resolved Hide resolved
pennylane/devices/qubit/sampling.py Show resolved Hide resolved
lillian542 and others added 2 commits September 15, 2023 14:17
@timmysilv timmysilv requested a review from a team September 15, 2023 18:28
@lillian542 lillian542 requested review from a team and removed request for a team September 15, 2023 18:33
pennylane/devices/default_qubit.py Outdated Show resolved Hide resolved
doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
@timmysilv timmysilv enabled auto-merge (squash) September 19, 2023 18:46
@timmysilv timmysilv merged commit 78a5852 into master Sep 19, 2023
38 checks passed
@timmysilv timmysilv deleted the add_prng_key branch September 19, 2023 20:19
mudit2812 pushed a commit that referenced this pull request Sep 20, 2023
**Context:**
The `default.qubit.jax` device can take a `prng_key`, and if it does, it
returns the same results from sampling any given state across repeated
measurements, rather than randomizing the sampling of the state.

**Description of the Change:**
Allow the `seed` on `DefaultQubit2` to be a `jax.random.PRNGKey`; if it
is, have the new device replicate the `DefaultQubitJax` sampling
behaviour.

**Benefits:**
You can get the same behaviour on the new device if you want a set PRNG
key when using `jax`.

**Possible Drawbacks:**
Previously, the jax device would use `jax.choice` to generate samples
regardless of whether or not a PRNG key was provided; the difference was
whether it used a new random PRNG key each time, or the same one over
and over.

For the new DefaultQubit, if you provide a PRNG key as the seed, the
behaviour will be the same as the `default.qubit.jax` device with a set
`prng_key`. If you leave the seed as None or an integer, even if the
interface is `jax`, it will use the `numpy.random.default_rng` to
generate the samples, and not `jax.choice`.

If you want a random PRNGKey and `jax.choice`, you can fix it by
reinitializing the device each time you run, with a randomly generated
PRNG key as the seed.

---------

Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants