-
Notifications
You must be signed in to change notification settings - Fork 571
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dynamic_one_shot
uses tapes with shot-vectors and jitting takes adv…
…antage of it (#5617) ### Before submitting Please complete the following checklist when submitting a PR: - [x] All new features must include a unit test. If you've fixed a bug or added code that should be tested, add a test to the test directory! - [x] All new functions and code must be clearly commented and documented. If you do make documentation changes, make sure that the docs build and render correctly by running `make docs`. - [ ] Ensure that the test suite passes, by running `make test`. - [x] Add a new entry to the `doc/releases/changelog-dev.md` file, summarizing the change, and including a link back to the PR. - [x] The PennyLane source code conforms to [PEP8 standards](https://www.python.org/dev/peps/pep-0008/). We check all of our code against [Pylint](https://www.pylint.org/). To lint modified files, simply `pip install pylint`, and then run `pylint pennylane/path/to/file.py`. When all the above are checked, delete everything above the dashed line and fill in the pull request template. ------------------------------------------------------------------------------------------------------------ **Context:** `dynamic_one_shot` creates n-shots tapes which is wasteful. **Description of the Change:** Create a single tape with a shot-vector which indicates to the device how many times to repeat the tape execution. **Benefits:** For a tape like ``` dev = qml.device("default.qubit", shots=2000, seed=jax.random.PRNGKey(123)) @qml.qnode(dev, diff_method=None) def func(x, y): qml.RX(x, wires=0) m0 = qml.measure(0, reset=False, postselect=1) qml.cond(m0, qml.RY)(y, wires=1) return qml.expval(qml.PauliZ(0)) params = np.pi / 4 * np.ones(2) ``` The execution times are as follows (Latitude laptop): - 12.7 s : vanilla Python - 8.2 s : jax.vmap - 11.1 s : jax.jit + jax.vmap + compilation - 6.89 ms : jax.jit + jax.vmap **Possible Drawbacks:** **Related GitHub Issues:** [sc-62097] --------- Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai> Co-authored-by: Christina Lee <christina@xanadu.ai>
- Loading branch information
1 parent
1d34de9
commit 9c9b6ba
Showing
6 changed files
with
68 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters