
[Autodiff] valueWithPullback operator does not take functions with inout parameters #72982

Open
JaapWijnen opened this issue Apr 11, 2024 · 2 comments

Comments

@JaapWijnen
Contributor

Description

Since #66873 was merged, the compiler can differentiate through functions with multiple results (such as a function with a differentiable inout parameter that also returns a result).
Unfortunately, we cannot yet ask for the pullback of such functions directly, because valueWithPullback has no overloads that accept functions with inout parameters.

A potential function signature would be (for arity 1):

```swift
@inlinable
public func valueWithPullback<T, R>(
  at x: inout T, of f: @differentiable(reverse) (inout T) -> R
) -> (value: R, pullback: (R.TangentVector, inout T.TangentVector) -> Void) {
  return Builtin.applyDerivative_vjp(f, x) // Currently missing Builtin
}
```
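To make the proposed pullback type `(R.TangentVector, inout T.TangentVector) -> Void` more concrete, here is a hand-written, plain-Swift sketch of what such an operator could return for one hypothetical function that doubles `x` in place and returns the square of its original value. It does not use `_Differentiation` or any compiler-generated derivative; the name `exampleValueWithPullback` is made up for illustration. It assumes the convention that the `inout` tangent enters as the cotangent of the final `x` and leaves as the cotangent of the initial `x`.

```swift
// Hypothetical, hand-written analogue of the proposed operator, specialized to
// one example function: x_out = 2 * x_in, result = x_in * x_in.
// Not compiler-generated; the name is illustrative only.
func exampleValueWithPullback(at x: inout Double)
  -> (value: Double, pullback: (Double, inout Double) -> Void) {
  let original = x                 // capture the input before mutating it
  x = 2 * original                 // in-place update: x_out = 2 * x_in
  let value = original * original  // returned result: x_in^2
  let pullback: (Double, inout Double) -> Void = { resultBar, xBar in
    // Assumed convention: on entry xBar is the cotangent of x_out,
    // on exit it is the cotangent of x_in.
    // d(x_out)/d(x_in) = 2, d(result)/d(x_in) = 2 * x_in.
    xBar = 2 * xBar + 2 * original * resultBar
  }
  return (value, pullback)
}

var x = 3.0
let (value, pullback) = exampleValueWithPullback(at: &x)
var xBar = 1.0                     // seed for the final x
pullback(1.0, &xBar)               // seed 1.0 for the returned result
print(x, value, xBar)              // prints 6.0 9.0 8.0
```

Because both the returned result and the mutated `x` are semantic results, the pullback needs a seed for each, which is why it takes an extra `inout` tangent instead of returning a single `T.TangentVector`.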

Currently we can work around this missing feature by wrapping the inout function in a non-inout function that copies its parameter:

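```swift
import _Differentiation

@differentiable(reverse)
func square(x: inout Double) { // we can't directly call valueWithPullback on this function
    x = x * x                  // mutate in place; the inout parameter is the result
}

@differentiable(reverse)
func nonInoutSquare(x: Double) -> Double {
    var x = x      // copy the argument so it can be passed as inout
    square(x: &x)  // inout arguments must be passed with &
    return x
}

let result = valueWithPullback(at: 5.0, of: nonInoutSquare)
```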

Of course, this largely defeats the purpose in terms of both expressivity and performance, since the wrapper makes additional copies that passing the inout parameter directly would avoid.

Potential issue:
There are currently three valueWithPullback implementations, one for each arity from 1 to 3. As far as I can tell, the underlying Builtins prevent us from collapsing these into a single version using parameter packs. Adding overloads for inout parameters would greatly increase the number of overloads, since every combination of "normal" and "inout" parameters, with and without a differentiable result, would need its own version. Inout parameters also don't lend themselves to parameter packs at this time, as far as I know.
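As a rough illustration of how quickly the overloads multiply, the sketch below counts them under a simple model. The model is an assumption, not a description of the standard library: each parameter is either passed normally or as `inout`, differentiation is with respect to all parameters, the function either returns a differentiable result or returns `Void`, and the one combination with no semantic result at all (no `inout` parameter and a `Void` return) is excluded. The helper `overloadCount(arity:)` is hypothetical.

```swift
// Hypothetical counting model (assumptions, not stdlib behavior):
// - each of the n parameters is either passed normally or as `inout`
// - the function either returns a differentiable result or returns Void
// - combinations with no semantic result (no inout, Void return) are excluded
func overloadCount(arity n: Int) -> Int {
  var count = 0
  for mask in 0..<(1 << n) {          // bit i set => parameter i is inout
    let hasInout = mask != 0
    for returnsResult in [true, false] {
      if returnsResult || hasInout { count += 1 }
    }
  }
  return count
}

let perArity = (1...3).map { overloadCount(arity: $0) }
print(perArity)                  // prints [3, 7, 15]
print(perArity.reduce(0, +))     // prints 25
```

Under these assumptions, arities 1 to 3 would need 3, 7 and 15 overloads (2^(n+1) - 1 each), i.e. 25 in total including the three that exist today.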

Do people see any other potential roadblocks for this feature?


JaapWijnen added the task and triage needed labels on Apr 11, 2024
@JaapWijnen
Contributor Author

@asl @jkshtj

@asl
Collaborator

asl commented Apr 12, 2024

Tagging @rxwei @dan-zheng for some historical decisions & rationale.

hborla added the AutoDiff label and removed the triage needed label on Apr 27, 2024