Optimization interface #12

Closed
josh146 opened this issue Aug 29, 2018 · 6 comments

@josh146
Member

josh146 commented Aug 29, 2018

This issue reflects an interface change, away from a single Optimizer class which the user invokes with a keyword argument, towards something more akin to TensorFlow:

openqml.AdamOptimizer(args, kwargs, etc)

The Optimizer class would then wrap these individual methods/classes as a convenience to the user.

From @mariaschuld:

Nathan, Josh and I think that it would be best to define an optimizer as an object or method that defines how to compute one step of updating the parameters, just like TensorFlow does. We can then build the machine learning functionality around it in the next versions.

Issues:

  1. Is there any overhead created by autograd that we would like to be shared between steps? Can we somehow store it in the optimizer class?
  2. How can optimizers using past gradient information keep track of past steps?

Open questions:
Instead of relying on SciPy, shall we hand code the optimizers to reduce overheads? If yes, how do we deal with computational stability and such things?
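
On issue 2, one way an optimizer could track past steps is to hold the running quantity on the instance. The sketch below is a plain-Python momentum optimizer, not code from this repository: `MomentumOptimizer`, `velocity`, `quadratic_grad` and the default hyperparameters are hypothetical names chosen for illustration, and the gradient is supplied by the caller rather than computed by autograd.

```python
class MomentumOptimizer:
    """Gradient descent with momentum (illustrative names and defaults)."""

    def __init__(self, grad_fn, stepsize=0.1, momentum=0.9):
        self.grad_fn = grad_fn    # callable: weights -> gradient (list of floats)
        self.stepsize = stepsize
        self.momentum = momentum
        self.velocity = None      # accumulated past updates, created on first step
        self.global_step = 0      # number of steps performed so far

    def step(self, weights):
        """Perform one update and return the new weights."""
        grad = self.grad_fn(weights)
        if self.velocity is None:
            self.velocity = [0.0] * len(weights)
        # The velocity mixes the previous velocity with the current gradient,
        # which is how information from past steps survives between calls.
        self.velocity = [
            self.momentum * v - self.stepsize * g
            for v, g in zip(self.velocity, grad)
        ]
        self.global_step += 1
        return [w + v for w, v in zip(weights, self.velocity)]


def quadratic_grad(weights):
    # Gradient of sum(w**2), whose minimum is at the origin
    return [2.0 * w for w in weights]


opt = MomentumOptimizer(quadratic_grad, stepsize=0.1, momentum=0.5)
weights = [1.0, -2.0]
for _ in range(50):
    weights = opt.step(weights)
```

Because the state lives on `opt`, the caller can stop and resume the loop at any point without passing the velocity around explicitly.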

@josh146 josh146 added discussion 💬 Requiring extra discussion interface 🔌 Classical machine-learning interfaces labels Aug 29, 2018
@mariaschuld
Contributor

mariaschuld commented Aug 29, 2018

My idea is to do something like the sketch below; I successfully tested autograd.grad to check that it can do this. It seems to solve part of the parameter issue by passing weights and data as positional arguments and differentiating only with respect to the weights. Or is there a reason we cannot do this with quantum nodes?

```python
import autograd


class MyGradbasedOptimizer:

    def __init__(self, cost, initial_weights, stepsize=0.1, ...):
        self.grad = autograd.grad(cost, 0)  # Computes derivative for first argument, the weights
        self.past_gradient = 0  # Keeps track of gradients at past points, e.g. for a momentum optimizer
        self.global_step = 0  # Keeps track of steps performed with this optimizer
        self.weights = initial_weights
        self.stepsize = stepsize  # assumed hyperparameter, needed for the update in step()
        ...

    def step(self, current_data):
        # Differentiate at the stored weights, then take a plain gradient-descent step
        gradient = self.grad(self.weights, current_data)
        self.weights = self.weights - self.stepsize * gradient
        self.global_step += 1
        ...


def cost(weights, data):
    ''' weights and data can be any kind of (nested) np.array it seems...'''
    # Depends on quantum nodes and classical nodes...
    return scalar_cost


o = MyGradbasedOptimizer(cost, initial_weights, ...)
o.step(data)
new_weights = o.weights  # attribute, not a method
```
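
A runnable variant of the sketch above, for reference. To avoid any dependency it swaps `autograd.grad(cost, 0)` for a central finite-difference gradient taken with respect to the first argument only. `grad_wrt_first_arg`, `GradientDescentOptimizer`, the `stepsize` parameter and the toy line-fitting cost are assumptions made for this example, and the cost is purely classical rather than built from quantum nodes.

```python
def grad_wrt_first_arg(func, eps=1e-6):
    """Central finite-difference gradient of func(weights, *rest) w.r.t. weights.

    Stand-in for autograd.grad(func, 0); the remaining arguments (the data)
    are passed through untouched and are not differentiated.
    """
    def gradient(weights, *rest):
        result = []
        for i in range(len(weights)):
            up = list(weights)
            down = list(weights)
            up[i] += eps
            down[i] -= eps
            result.append((func(up, *rest) - func(down, *rest)) / (2 * eps))
        return result
    return gradient


class GradientDescentOptimizer:
    def __init__(self, cost, initial_weights, stepsize=0.1):
        self.grad = grad_wrt_first_arg(cost)
        self.weights = list(initial_weights)
        self.stepsize = stepsize
        self.global_step = 0

    def step(self, data):
        """One gradient-descent update of the stored weights on the given data."""
        gradient = self.grad(self.weights, data)
        self.weights = [w - self.stepsize * g for w, g in zip(self.weights, gradient)]
        self.global_step += 1
        return self.weights


def cost(weights, data):
    # Mean squared error of the line w0 + w1 * x over (x, y) pairs
    return sum((weights[0] + weights[1] * x - y) ** 2 for x, y in data) / len(data)


data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]  # points on y = 1 + 2x
opt = GradientDescentOptimizer(cost, [0.0, 0.0], stepsize=0.1)
for _ in range(500):
    opt.step(data)
new_weights = opt.weights
```

Here the weights converge towards the intercept and slope of the line that generated `data`.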

@co9olguy
Member

Looks like it does all the basic things we want to do, while keeping things as simple as possible 👍 . I've asked @smite to take a crack at writing a couple slimmed-down optimizers based on the examples in autograd (which acts largely like these in terms of what they do)

@co9olguy co9olguy modified the milestones: version 0.1, alpha Aug 31, 2018
@smite
Contributor

smite commented Aug 31, 2018

I separated the optimize_SGD method from the Optimizer class into a standalone function in 60c5756.

@josh146
Member Author

josh146 commented Sep 8, 2018

@smite I have a couple of questions regarding optimize_SGD.

  1. Is it possible to slim it down even further? For instance, have a very basic function openqml.SGDOptimizer, that simply performs a single optimization step given a cost function and initial weights. The user could then use this to define their own optimizer class, or use our provided higher level openqml.Optimizer.

  2. How come optimize_SGD now requires training data? Shouldn't this be able to be called without training data/classification data having to be provided?

@smite
Contributor

smite commented Sep 8, 2018

  1. All nontrivial optimizers need to keep track of internal state variables that typically change at each iteration. In the simple SGD implementation we have now, the state consists of the learning rate (which is a function of the current step number and some fixed parameters). More complex optimization algorithms like BFGS keep track of e.g. an approximate Hessian that is updated based on the gradient information over several iterations. The single step function would need to get all this state information as input arguments, and ideally also return every intermediate variable it computes (like the gradient) to the caller so the caller can update the state variables without having to recompute the intermediates. This doesn't seem very helpful, but maybe there's a use case I'm not thinking of?

1b. The current function can also be forced to run just one iteration by using optimize_SGD(max_steps=1).

  2. As far as I understand SGD always takes a data set as input, otherwise it would be just normal gradient descent. The stochastic part is selecting a random subsample of the data set for each iteration, and passing that to the cost function.
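
To make the two points above concrete, here is a minimal sketch of a single SGD iteration written as a standalone function. It is not the `optimize_SGD` from the commit mentioned earlier: `sgd_step`, `learning_rate`, the decay schedule and `mse_grad` are hypothetical, and the gradient is written by hand instead of coming from autograd. The step number has to be passed in explicitly because the learning rate depends on it, and the random batch is what distinguishes SGD from plain gradient descent.

```python
import random


def learning_rate(step, initial_rate=0.1, decay=0.01):
    # Optimizer state in the sense described above: a function of the current
    # step number and some fixed parameters (this schedule is an assumption).
    return initial_rate / (1.0 + decay * step)


def sgd_step(weights, grad_fn, data, step, batch_size, rng):
    """One SGD iteration: sample a random batch, evaluate the gradient on it, update."""
    batch = rng.sample(data, batch_size)
    grad = grad_fn(weights, batch)
    rate = learning_rate(step)
    return [w - rate * g for w, g in zip(weights, grad)]


def mse_grad(weights, batch):
    # Analytic gradient of the mean squared error of w0 + w1 * x on the batch
    residuals = [(weights[0] + weights[1] * x - y, x) for x, y in batch]
    g0 = sum(2.0 * r for r, _ in residuals) / len(batch)
    g1 = sum(2.0 * r * x for r, x in residuals) / len(batch)
    return [g0, g1]


rng = random.Random(0)
data = [(k / 10.0, 1.0 + 2.0 * k / 10.0) for k in range(20)]  # noiseless y = 1 + 2x
weights = [0.0, 0.0]
for step in range(2000):
    # The caller owns the loop and must thread the step counter through by hand
    weights = sgd_step(weights, mse_grad, data, step, batch_size=5, rng=rng)
```

With only a step counter as state this stays manageable; an optimizer that also carries past gradients or an approximate Hessian would need every such quantity passed in and returned on each call.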

@co9olguy
Member

co9olguy commented Sep 8, 2018

Yep, the optimizers need to be classes, since they need to keep track of internal state. We want something only slightly more complex than they have in autograd (https://github.com/HIPS/autograd/blob/master/autograd/misc/optimizers.py). Those ones force you to either predict the full number of iterations you want from the start, or risk losing state between multiple calls.

No ML library that I have used has an SGD optimizer. They all have GD optimizers. It is up to the user to decide what to pass to it. In future versions, we can build more complex classes, but for the initial release, our goal is to have the basic mathematical functions built in, with minimal further assumptions about what the user might want to do.
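
A rough illustration of that class-based design, using the Adam update rule since `AdamOptimizer` is the example named at the top of this issue. This is a plain-Python sketch under assumptions, not the interface that was eventually implemented: the `step(grad_fn, weights, *args)` signature, the attribute names and the default hyperparameters are all hypothetical. The moment estimates and step count live on the instance, so the caller can run as many iterations as needed and resume later without losing state, and decides what data (if any) to pass through `*args`.

```python
import math


class AdamOptimizer:
    """Adam update rule with its state kept on the instance between calls."""

    def __init__(self, stepsize=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
        self.stepsize = stepsize
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.m = None  # running mean of gradients
        self.v = None  # running mean of squared gradients
        self.t = 0     # number of steps taken so far

    def step(self, grad_fn, weights, *args):
        """One Adam update; extra args (e.g. a data batch) go straight to grad_fn."""
        grad = grad_fn(weights, *args)
        if self.m is None:
            self.m = [0.0] * len(weights)
            self.v = [0.0] * len(weights)
        self.t += 1
        self.m = [self.beta1 * m + (1 - self.beta1) * g for m, g in zip(self.m, grad)]
        self.v = [self.beta2 * v + (1 - self.beta2) * g * g for v, g in zip(self.v, grad)]
        # Bias correction compensates for the zero initialisation of m and v
        m_hat = [m / (1 - self.beta1 ** self.t) for m in self.m]
        v_hat = [v / (1 - self.beta2 ** self.t) for v in self.v]
        return [
            w - self.stepsize * mh / (math.sqrt(vh) + self.eps)
            for w, mh, vh in zip(weights, m_hat, v_hat)
        ]


def grad(weights):
    # Gradient of sum((w - 3)**2), minimised at w = 3
    return [2.0 * (w - 3.0) for w in weights]


opt = AdamOptimizer(stepsize=0.1)
weights = [0.0]
for _ in range(100):   # a first run of iterations
    weights = opt.step(grad, weights)
for _ in range(400):   # resumed later; m, v and t carry over
    weights = opt.step(grad, weights)
```

Nothing in the class fixes the total number of iterations in advance, which is the limitation of the function-style optimizers described above.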

AlbertMitjans added a commit that referenced this issue Dec 5, 2022
antalszava added a commit that referenced this issue Dec 6, 2022
… gradient transforms and device gradients backward mode (#3235)

antalszava added a commit that referenced this issue Dec 7, 2022
… device gradients forward mode (#3445)

antalszava added a commit that referenced this issue Dec 14, 2022