
[MLIR] Add support for differentiating scalars with the backprop op #225

Merged
merged 19 commits into enzyme_integration from jmp/enzyme-scalars
Aug 1, 2023

Conversation

@pengmai pengmai commented Jul 31, 2023

Description of the Change: Modify the verifier and lowering of the backprop op to support differentiation of scalars. This also adds some documentation that addresses comments on #158.

Benefits: Greater flexibility and generality of the backprop op.

[sc-42844]

@pengmai pengmai requested review from dime10 and rmoyard July 31, 2023 20:27
@pengmai pengmai marked this pull request as ready for review July 31, 2023 20:27
@dime10 dime10 left a comment

Great job 💯

Does this PR re-add all tests that were removed in the base? Does it add any new tests? (just so I can check them)

mlir/lib/Gradient/Utils/GradientShape.cpp (outdated, resolved)
mlir/lib/Gradient/Utils/GradientShape.cpp (outdated, resolved)
mlir/lib/Gradient/Transforms/ConversionPatterns.cpp (outdated, resolved)
mlir/lib/Gradient/Transforms/ConversionPatterns.cpp (outdated, resolved)
}
hybridGradients.push_back(result);
}
}
else {
// Co-tangent is a rank 1 tensor
Contributor

Should the cotangent still be rank 1 when the result is rank 0?

Contributor Author

Yes, specifically in the context of our current hybrid gradient architecture, the quantum co-tangent always has shape [<param_count>, ...<num_results>]. This translates to a tensor<?xf64> when the result is rank 0, which is the case of a single measurement (regardless of whether the measurement is a tensor<f64> or f64).

Contributor

the quantum co-tangent always has shape [<param_count>, ...<num_results>]

I thought the co-tangents are a tuple of size <num_results> where each element has shape (<param_count>,)?

But you are right, this means the co-tangent is always rank 1 then! That's because the result of the function we differentiate is always of rank 1 (the argmap function is what we differentiate).

I think it's this comment further down that confuses me:

If the rank of the cotangent is 1, this implies the primal function returns a rank-0 value (either a scalar or a tensor).

Contributor Author

I thought the co-tangents are a tuple of size <num_results> where each element has shape (<param_count>,)?

Right, sorry, this is correct. I meant that each member of the co-tangents tuple has shape [<param_count>, ...<results[i].shape>]. The name num_results was poorly chosen 😅

I think it's this comment further down that confuses me

So a primal function that returns (f64, tensor<2xf64>, tensor<f64>) will have cotangents (tensor<?xf64>, tensor<?x2xf64>, tensor<?xf64>). The cotangent at index i will always be of rank one greater than its corresponding primal result. I can edit the comment to try to make that clearer.

Contributor

So a primal function that returns (f64, tensor<2xf64>, tensor<f64>) will have cotangents (tensor<?xf64>, tensor<?x2xf64>, tensor<?xf64>). The cotangent at index i will always be of rank one greater than its corresponding primal result.

This sounds strange to me. By primal do you mean the whole QNode or the argmap function?
Because no matter what return types the QNode has, the argmap function always returns tensor<?xf64>. Consequently, the cotangent vector is also always of type tensor<?xf64>. Why would the rank be one higher than the corresponding result? Maybe I'm missing something 😅

If you are talking about a generic function (independent of the QNode/argmap scenario), then a function that returns (f64, tensor<2xf64>, tensor<f64>) should have cotangents like (_, tensor<2xf64>, tensor<f64>) (I think the first one is omitted because Enzyme doesn't accept a cotangent for a scalar result, right?).

Contributor Author

I'm specifically talking about the QNode/argmap situation, where "primal result" means one or more measurements. In order to compute the overall hybrid Jacobian, we need to run a backpropagation computation for each row of the Jacobian, and there are as many rows as there are scalar element outputs of the hybrid circuit.

For example, if we have a hybrid function that returns qml.expval(...) and qml.probs() on 1 wire, the QNode return would be (tensor<f64>, tensor<2xf64>). There is thus one row of the hybrid Jacobian matrix (meaning a cotangent/quantum gradient vector of tensor<?xf64>) from the expval, and two rows from the probs: one for each of its elements.

As each row of the Jacobian matrix has a different quantum gradient vector of type tensor<?xf64>, the expval results in a quantum gradient of tensor<?xf64>, while the probs results in a (transposed, in the current implementation) quantum Jacobian of tensor<?x2xf64>. For both, the rank of the quantum gradient/Jacobian is one greater than that of the corresponding measurement.

I realize the terminology is confusing (which is a common problem in AD literature 😢). It might help to exclusively specify "quantum gradient" instead of "cotangent" because it's a cotangent with respect to the argmap, but not the circuit. This also only applies to the current hybrid gradient architecture.

(For completeness, the cotangents with respect to the circuit in this case would be (1, [0, 0]), (0, [1, 0]), (0, [0, 1]). It looks like the qgrad function uses these implicitly).
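
To make the rank rule above concrete, here is a minimal C++ sketch against the MLIR type API (the helper name getQuantumGradientType is hypothetical, not Catalyst's actual code): each quantum-gradient slice simply prepends a dynamic <param_count> dimension to the primal result's shape, so its rank is always one higher.

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Prepend a dynamic <param_count> dimension to the primal result's shape.
static RankedTensorType getQuantumGradientType(Type primalResultType)
{
    llvm::SmallVector<int64_t> shape{ShapedType::kDynamic};
    Type elementType = primalResultType;
    if (auto tensorType = dyn_cast<RankedTensorType>(primalResultType)) {
        shape.append(tensorType.getShape().begin(), tensorType.getShape().end());
        elementType = tensorType.getElementType();
    }
    // f64 or tensor<f64> -> tensor<?xf64>, tensor<2xf64> -> tensor<?x2xf64>.
    // For the expval + probs(1 wire) example: 1 + 2 = 3 Jacobian rows in total.
    return RankedTensorType::get(shape, elementType);
}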

@dime10 dime10 Aug 1, 2023

I realize the terminology is confusing (which is a common problem in AD literature 😢). It might help to exclusively specify "quantum gradient" instead of "cotangent" because it's a cotangent with respect to the argmap, but not the circuit. This also only applies to the current hybrid gradient architecture.

Ah got it, yes, I think calling the quantum Jacobian the cotangent is the confusing part here; imo they are two separate entities. We do use individual rows of the quantum Jacobian (of rank one greater than the measurement) as cotangents, but the cotangent itself only exists with respect to the argmap function (always fixed rank 1) 😅

Thanks for the clarification :)

pengmai commented Jul 31, 2023

Thanks!

Does this PR re-add all tests that were removed in the base? Does it add any new tests? (just so I can check them)

I'm pretty sure it does, but I can double-check.

pengmai and others added 3 commits July 31, 2023 17:38
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
@rmoyard rmoyard left a comment

Thanks a lot, that looks great 💯 Just a couple of questions

// Conceptually a map from scalar result indices (w.r.t. other scalars) to the position in
// the overall list of returned gradients.
// For instance, a backprop op that returns (tensor, f64, tensor, f64, f64) will have
// scalarIndices = {1, 3, 4}.
Contributor

👍
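
The scalarIndices mapping described in the excerpt above could be computed roughly as follows; this is a hedged sketch with illustrative names (computeScalarIndices is not the actual Catalyst helper).

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Record, for each scalar (float) gradient, its position in the overall list of
// returned gradients.
static llvm::SmallVector<unsigned> computeScalarIndices(TypeRange gradientTypes)
{
    llvm::SmallVector<unsigned> scalarIndices;
    for (const auto &en : llvm::enumerate(gradientTypes)) {
        if (isa<FloatType>(en.value()))
            scalarIndices.push_back(en.index());
    }
    // For gradients (tensor, f64, tensor, f64, f64) this yields {1, 3, 4}.
    return scalarIndices;
}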

rewriter.create<BackpropOp>(loc, TypeRange{}, op.getCalleeAttr(), adaptor.getArgs(),
argShadows, calleeResults, resShadows, diffArgIndicesAttr);
auto bufferizedBackpropOp = rewriter.create<BackpropOp>(
loc, scalarReturnTypes, op.getCalleeAttr(), adaptor.getArgs(), argShadows,
Contributor

Just to make sure I understand why we init the backprop op with results here: is it because we only do the bufferization for the tensors and we still need the floats as results?

Contributor Author

Yes, exactly.
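
A minimal sketch of the point confirmed here, with hypothetical helper and parameter names (not Catalyst's actual bufferization pattern): tensor gradients become memref shadow buffers, while scalar gradients cannot be bufferized and stay as SSA results of the rebuilt BackpropOp.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeRange.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Split gradient result types into scalar results (kept on the op) and memref
// shadow buffers (allocated and passed as operands).
static void splitGradientTypes(TypeRange gradientTypes,
                               llvm::SmallVectorImpl<Type> &scalarReturnTypes,
                               llvm::SmallVectorImpl<Type> &shadowBufferTypes)
{
    for (Type type : gradientTypes) {
        if (isa<FloatType>(type)) {
            scalarReturnTypes.push_back(type); // e.g. the f64 gradient of a scalar arg
        }
        else if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
            shadowBufferTypes.push_back(
                MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
        }
    }
}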


// We need to generate a new __enzyme_autodiff name per different function signature. One
Contributor

So that covers the case where there is a single function and we switch between scalars and tensors?

Contributor Author

This covers the case where we have one function to be differentiated with type (f64, memref<f64>) -> f64 and another function in the same module with type (f64, f64, memref<f64>, f64) -> f64. This is because the first will generate an __enzyme_autodiff that returns f64, while the second will generate an __enzyme_autodiff that returns struct { f64, f64, f64 }.
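
A minimal sketch of the naming idea described above, under the assumption of a simple counter suffix (the helper below is illustrative, not Catalyst's actual implementation): each distinct gradient signature gets its own __enzyme_autodiff declaration because the declared return type differs.

#include <string>

// Produce a fresh __enzyme_autodiff symbol name; one declaration is emitted per
// distinct gradient signature, e.g. one returning f64 and another returning
// struct { f64, f64, f64 }.
static std::string nextEnzymeAutodiffName()
{
    static unsigned counter = 0;
    return "__enzyme_autodiff" + std::to_string(counter++);
}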

// The results of backprop are in data in
rewriter.create<LLVM::CallOp>(loc, backpropFnDecl, callArgs);
rewriter.eraseOp(op);
// The results of backprop are in data in, except scalar derivatives which are in the
Contributor

Minor: data in is not in backprop anymore; I think it is diffArgShadows now.

// 1, this implies the primal function returns a rank-0 value (either a
// scalar or a tensor<scalar>). The Jacobian of a scalar -> scalar should be
// a scalar, but as a special case, the Jacobian of a scalar ->
// tensor<scalar> should be tensor<scalar>.
Contributor

👍
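
The special case in the comment above can be sketched as follows (hypothetical helper and parameter names, assuming the MLIR type API): the gradient entry for a rank-0 primal result mirrors whether that result was a bare scalar or a rank-0 tensor.

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Rank-1 cotangent => rank-0 primal result: return f64 for scalar -> scalar,
// and tensor<f64> for scalar -> tensor<f64>.
static Type getRank0JacobianType(Type diffArgType, Type primalResultType)
{
    if (isa<RankedTensorType>(primalResultType))
        return RankedTensorType::get(/*shape=*/{}, diffArgType);
    return diffArgType;
}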

@@ -14,6 +14,90 @@

// RUN: quantum-opt %s --lower-gradients=only=adj --split-input-file --verify-diagnostics | FileCheck %s

// Check scalar to scalar function
Contributor

Thanks for all the tests!

@pengmai pengmai merged commit 650c849 into enzyme_integration Aug 1, 2023
16 checks passed
@pengmai pengmai deleted the jmp/enzyme-scalars branch August 1, 2023 16:26
rmoyard added a commit that referenced this pull request Aug 2, 2023
* Draft integration

* Shape

* optional

* Add Enzyme

* Draft

* more

* Checkout 0.0.69

* Update

* Right versions

* Compile works

* Wheels and tests

* Typo

* Black

* Change hash

* Add dep version

* Typo

* Remove cache llvm

* Update lit

* Remove bin

* Remove

* Add ls

* Change key

* Enzyme path

* test

* Change

* Add tests

* Black

* Typo

* Compiling

* Revive param count

* Revive param count

* Update

* Remove comment

* Compute backprop res type

* Readd tests

* Update buff

* Change tests

* Changes

* Working insert slice

* Add diff arg indices

* Add copyright

* Update

* Fix tests buff and hybrid

* Test param shift

* Fix classical and adjoint

* Fix parameter shift tests

* Fix tests

* Pylint disable

* Add globals

* Typos

* Update tests openqasm

* Comment from review

* More

* Correct enzyme path

* Add free

* Typos

* Explicitly call preserve-nvvm with Enzyme

* Add loose types

* Remove tests and revive pcount in argmap

* Update to 0.0.74

* Update Enzyme

* Update format

* Fix wheels

* Attempt to fix failing upload to codecov by bumping Enzyme to working build

* [MLIR] Integrated BackpropOp lowering with Enzyme (#193)

* Rework Enzyme calling convention to unpack pointers to fix incorrect zero gradients

* Exclude Enzyme from autoformatting

* Update MemSetOp builder API

* Switch metadata approach from enzyme_function_like to enzyme_allocation_like to avoid double free error with generated derivative code

* Remove Enzyme wrapper function, switch argmap to destination-passing style

* Remove parameter count from argmap parameters

* Fix bug with Enzyme calling convention around inactive memrefs

* Update backprop lowering test

* Add index dialect to dependent dialects of gradient_to_llvm

* Update pinned version of Enzyme to one that builds with LLVM 17, only use 'auto' when clear from context

* Get generic alloc/free names from MLIR utilities, only register them when use-generic-functions option is set

* Update changelog

* Remove enzyme-loose-types flag

Co-authored-by: Romain Moyard <rmoyard@gmail.com>

* Add tests for PreEnzymeOpt pipeline, add explicit enzyme_const annotation for floats

* Remove index -> llvm and memref -> llvm patterns from gradient -> llvm, run those passes separately afterwards

* Autoformatting

---------

Co-authored-by: Romain Moyard <rmoyard@gmail.com>

* Enzyme integration + Deallocation (#201)

* Re-enable buffer deallocation passes.

* Register _mlir_memref_to_llvm_free as "like" free.

Using the allocation registration mechanism for Enzyme fails to
automatically register the free-like function as free, likely due to a
possible difference in arity.

* Explicitly name post enzyme file as ".postenzyme.ll"

* Comments.

* Readd tests

* Fix adjoint

* Fix parameter shift

* Update classical jac

* format

* Smaller fixes to the backprop operation & lowering (#224)

* Fix bug with grad on classical qjit function

If the grad target is a qjit object, the current implementation assumes
that the wrapped function is already a "JAX function", that is a
`QNode` or `Function` object. This is not the case however when a bare
Python function is wrapped in qjit.

* Refactor `compDiffArgIndices` and add...

... `computeDiffArgs` as a convenience method to filter the callee
argument list.

* Tighten backprop type declarations & add verifier

For example, the op didn't check whether the declared number of gradient
results matches the number of differentiable arguments, a bug in one of
the bufferization tests.

The `quantum_jacobian` operand name was switched out to `cotangents`
since it more generally describes the purpose of that argument, and in
the way it is used now only receives a slice of the quantum jacobian
anyways.

* Allocate primal out buffers during bufferization

Reorganizes backprop op arguments into:
- primal `args`
- `diffArgShadows` or gradient result buffers
- `calleeResults` or primal output buffers
- `cotangents` or callee result shadows

* Update mlir/lib/Gradient/Transforms/BufferizationPatterns.cpp

Co-authored-by: Jacob Mai Peng <jacobmpeng@gmail.com>

---------

Co-authored-by: Jacob Mai Peng <jacobmpeng@gmail.com>

* Remove too many free

* Format

* Right indentation

* Change function like

* Update from review

* [MLIR] Add support for differentiating scalars with the backprop op (#225)

* Add docstring explaining computeMemRefSizeInBytes

* Update changelog

* Support scalars in BackpropOp

* WIP: Adding back removed tests

* Add float type returns to backprop op

* Bug fixes with scalar differentiation support

* Add scalar support to verifier, revert changes to test

* Add back more tests

* Try to clarify role of scalarIndices

* Add verifier for shaped differentiable types, fix bug from merge commit

* Update comment

Co-authored-by: David Ittah <dime10@users.noreply.github.com>

* Add back more deleted tests, comment explaining scalar unpacking of the enzyme_autodiff call

* Clarify terminology around quantum gradients/cotangents

* Update shadow vs data in terminology in comments

---------

Co-authored-by: David Ittah <dime10@users.noreply.github.com>

* Update doc/changelog.md

Co-authored-by: David Ittah <dime10@users.noreply.github.com>

* More tests

* Add tests

* Changelog.

---------

Co-authored-by: Jacob Mai Peng <jacobmpeng@gmail.com>
Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: Erick <erick.ochoalopez@xanadu.ai>