Optimization interface #12
My idea is to do something like below, and I successfully tested `autograd.grad` to check that it can do this. This seems to solve part of the parameter issue by passing weights and data as positional arguments and differentiating only with respect to the weights. Or is there a reason we cannot do this with quantum nodes?

```python
class MyGradbasedOptimizer:
    def __init__(self, cost, initial_weights, ...):
        self.grad = autograd.grad(cost, 0)  # computes the derivative w.r.t. the first argument, the weights
        self.past_gradient = 0  # keeps track of gradients at past points, e.g. for a momentum optimizer
        self.global_step = 0    # keeps track of steps performed with this optimizer
        self.weights = initial_weights
        ...

    def step(self, current_data):
        gradient = self.grad(self.weights, current_data)
        # update self.weights using the gradient, e.g. a plain gradient-descent step
        ...

def cost(weights, data):
    '''weights and data can be any kind of (nested) np.array, it seems...'''
    # depends on quantum nodes and classical nodes...
    return scalar_cost

o = MyGradbasedOptimizer(cost, initial_weights, ...)
o.step(data)
new_weights = o.weights
```
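The key claim above — that `autograd.grad(cost, 0)` returns the derivative with respect to the first positional argument only, holding `data` fixed — can be sanity-checked on a toy cost. The sketch below uses a hand-rolled central finite difference in place of autograd so it runs without extra dependencies; the quadratic cost is purely illustrative.

```python
def cost(weights, data):
    # toy quadratic cost: sum over (w * x - 1)^2
    return sum((w * x - 1.0) ** 2 for w, x in zip(weights, data))

def grad_wrt_weights(cost, weights, data, eps=1e-6):
    """Finite-difference gradient w.r.t. weights only; data is held fixed,
    mimicking what autograd.grad(cost, 0) would compute analytically."""
    g = []
    for i in range(len(weights)):
        up = list(weights); up[i] += eps
        dn = list(weights); dn[i] -= eps
        g.append((cost(up, data) - cost(dn, data)) / (2 * eps))
    return g

weights = [0.5, 2.0]
data = [1.0, 1.0]
# analytic gradient is 2 * (w*x - 1) * x, i.e. [-1.0, 2.0] here
g = grad_wrt_weights(cost, weights, data)
print(g)
```

Only the `weights` slot is perturbed, so the data can stay an arbitrary (nested) array that is simply passed through.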
Looks like it does all the basic things we want, while keeping things as simple as possible 👍. I've asked @smite to take a crack at writing a couple of slimmed-down optimizers based on the examples in autograd (which act largely like these in terms of what they do).
I separated the optimize_SGD method from the Optimizer class into a standalone function in 60c5756.
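For illustration only, a standalone function of roughly this shape might look like the sketch below. The signature (`stepsize`, `max_steps`) and body are hypothetical, inferred from the surrounding discussion; the actual code in commit 60c5756 may differ.

```python
def optimize_SGD(cost, grad_fn, weights, data, stepsize=0.1, max_steps=100):
    """Hypothetical standalone (stochastic) gradient descent: repeatedly
    step the weights against the gradient of the cost w.r.t. the weights.
    Setting max_steps=1 forces a single iteration, as noted below."""
    for _ in range(max_steps):
        g = grad_fn(weights, data)
        weights = [w - stepsize * gi for w, gi in zip(weights, g)]
    return weights

# usage on a toy problem: minimize sum(w^2), whose gradient is 2w
w = optimize_SGD(None, lambda w, d: [2.0 * wi for wi in w],
                 [1.0, -3.0], None, stepsize=0.1, max_steps=50)
print(w)  # both entries shrink toward 0
```

Because the function is stateless, all bookkeeping (step counts, past gradients) has to live with the caller, which is what the class-based design below addresses.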
@smite I have a couple of questions regarding:

> 1b. The current function can also be forced to run just one iteration by using `optimize_SGD(max_steps=1)`.
Yep, the optimizers need to be classes, since they need to keep track of internal state. We want something only slightly more complex than autograd's optimizers (https://github.com/HIPS/autograd/blob/master/autograd/misc/optimizers.py); those force you to either predict the full number of iterations you want from the start, or risk losing state between multiple calls. No ML library that I have used has an SGD optimizer; they all have GD optimizers, and it is up to the user to decide what to pass in. In future versions we can build more complex classes, but for the initial release our goal is to have the basic mathematical functions built in, with minimal further assumptions about what the user might want to do.
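A minimal sketch of the kind of stateful, class-based optimizer described above, in the spirit of autograd's `misc/optimizers.py` but keeping its state (velocity, step count) alive between `step()` calls. `MomentumOptimizer` and its parameter names are illustrative, not an actual PennyLane API.

```python
class MomentumOptimizer:
    """Gradient descent with momentum; state survives between step() calls."""

    def __init__(self, grad_fn, lr=0.1, momentum=0.9):
        self.grad_fn = grad_fn    # callable: grad_fn(weights, data) -> gradient
        self.lr = lr
        self.momentum = momentum
        self.velocity = None      # past-gradient state, initialised lazily
        self.global_step = 0      # steps performed with this optimizer

    def step(self, weights, data):
        g = self.grad_fn(weights, data)
        if self.velocity is None:
            self.velocity = [0.0] * len(g)
        # accumulate momentum, then move against the gradient
        self.velocity = [self.momentum * v - self.lr * gi
                         for v, gi in zip(self.velocity, g)]
        self.global_step += 1
        return [w + v for w, v in zip(weights, self.velocity)]

# usage on a toy quadratic: minimize sum(w^2), gradient 2w
opt = MomentumOptimizer(lambda w, d: [2.0 * wi for wi in w], lr=0.1)
w = [1.0, -2.0]
for _ in range(100):
    w = opt.step(w, None)
print(w)  # both entries close to 0
```

Because the velocity and step count live on the instance, the caller can run any number of `step()` calls, one at a time, without pre-committing to a total iteration count.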
This issue reflects an interface change, away from a single `Optimizer` class which the user invokes with a keyword argument, towards something more akin to TensorFlow. The `Optimizer` class would then wrap these individual methods/classes as a convenience to the user. From @mariaschuld: