[AutoDiff] Enable same-file retroactive derivative registration.#22871

Merged
rxwei merged 1 commit into swiftlang:tensorflow from rxwei:retrodiff
Feb 25, 2019

@rxwei rxwei commented Feb 25, 2019

Overview

Retroactive derivative registration, a.k.a. the @differentiating attribute, is the direction we ultimately want to take. There are two critical goals:

  1. Eliminate the JVP and VJP concepts from userland. Today, the jvp: and vjp: attribute arguments are still needed to register a derivative via a @differentiable attribute on the original function definition. The user should not need to know what a JVP/VJP is in order to define a derivative. Once @differentiating is added, we would like to eliminate the jvp: and vjp: attribute arguments from the @differentiable attribute.
  2. Make differentiation as extensible as the rest of Swift (retroactive conformances, etc). Users should be able to make a function from an imported module differentiable, and export the derivative to other files and modules that import the module containing the retroactive derivative.
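
For readers unfamiliar with the two terms the first goal wants to hide, here is a minimal plain-Swift sketch. It does not use any AutoDiff attributes; `square`, `squareJVP`, `squareVJP`, and the `differential` label are hypothetical names chosen for illustration. A JVP returns the value plus a closure that pushes an input change forward, while a VJP returns the value plus a pullback that maps an output change back to the input.

```swift
// Hypothetical original function: f(x) = x * x.
func square(_ x: Float) -> Float {
  return x * x
}

// JVP (forward mode): the value plus a differential that maps a change in
// the input to the resulting change in the output.
func squareJVP(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
  return (value: x * x, differential: { dx in 2 * x * dx })
}

// VJP (reverse mode): the value plus a pullback that maps a change in the
// output back to the corresponding change in the input.
func squareVJP(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  return (value: x * x, pullback: { dy in 2 * x * dy })
}

let forward = squareJVP(3)
let reverse = squareVJP(3)
print(forward.value, forward.differential(1)) // 9.0 6.0
print(reverse.value, reverse.pullback(1))     // 9.0 6.0
```

For a scalar function like this one the two closures compute the same number; they differ in direction, which only becomes visible with multiple inputs or outputs.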

Implementation-wise, cross-file/cross-module retroactive derivative registration requires a top-level differentiability witness construct in SIL that behaves like SIL protocol witnesses. For more implementation information, see this forum discussion (https://forums.swift.org/t/help-needed-with-retroactive-differentiability/19927/6?u=rxwei). This PR does not do that. However, same-file retroactive derivative registration is fairly easy: all it takes is to make each @differentiating imply a @differentiable attribute on the original function. In other words, the @differentiable attribute (or more precisely the [differentiable] SIL function attribute) already acts like a differentiability witness in a SIL module, so we can just make use of that.

Changes

This PR adds support for same-file retroactive derivative registration. So long as the original function exists in the same file as the derivative with a @differentiating attribute, derivative registration works. The detailed implementation steps are as follows:

  • When type-checking a @differentiating attribute, after resolving the original function declaration, check to see if the original function is in the same file as the retroactive derivative. If not, emit a diagnostic.
  • Find an existing @differentiable attribute on the original function that is w.r.t. all parameters. If it exists, treat that attribute as the witness. If not, create a new @differentiable attribute that covers all parameters and treat that as the witness.
  • If the @differentiable attribute already has an associated function (JVP/VJP) that the retroactive derivative represents, emit a diagnostic.
  • Set the corresponding associated function in the @differentiable attribute to be the retroactive derivative.
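
The steps above can be sketched in plain Swift. This is only an illustrative model, not the compiler's code: `FunctionDecl`, `DifferentiableAttr`, `RegistrationError`, and `registerRetroactiveDerivative` are hypothetical stand-ins, derivatives are tracked by name, and the JVP/VJP kind is assumed to be passed in explicitly rather than inferred from the derivative's type.

```swift
// Hypothetical, simplified stand-ins for the compiler's data structures.
enum AssociatedFunctionKind { case jvp, vjp }

final class DifferentiableAttr {
  let coversAllParameters: Bool
  var jvp: String?
  var vjp: String?
  init(coversAllParameters: Bool) {
    self.coversAllParameters = coversAllParameters
  }
}

final class FunctionDecl {
  let name: String
  let sourceFile: String
  var differentiableAttrs: [DifferentiableAttr] = []
  init(name: String, sourceFile: String) {
    self.name = name
    self.sourceFile = sourceFile
  }
}

enum RegistrationError: Error, Equatable {
  case originalInDifferentFile
  case derivativeAlreadyRegistered
}

func registerRetroactiveDerivative(
  _ derivative: FunctionDecl,
  kind: AssociatedFunctionKind,
  for original: FunctionDecl
) throws {
  // Step 1: reject an original function that lives in a different file.
  guard original.sourceFile == derivative.sourceFile else {
    throw RegistrationError.originalInDifferentFile
  }
  // Step 2: reuse an all-parameter attribute as the witness, or create one.
  let witness: DifferentiableAttr
  if let existing = original.differentiableAttrs.first(where: { $0.coversAllParameters }) {
    witness = existing
  } else {
    witness = DifferentiableAttr(coversAllParameters: true)
    original.differentiableAttrs.append(witness)
  }
  // Steps 3 and 4: diagnose a duplicate, otherwise record the derivative.
  switch kind {
  case .jvp:
    guard witness.jvp == nil else { throw RegistrationError.derivativeAlreadyRegistered }
    witness.jvp = derivative.name
  case .vjp:
    guard witness.vjp == nil else { throw RegistrationError.derivativeAlreadyRegistered }
    witness.vjp = derivative.name
  }
}

let original = FunctionDecl(name: "functionWithRetroDeriv", sourceFile: "main.swift")
let derivative = FunctionDecl(name: "retroDeriv", sourceFile: "main.swift")
try registerRetroactiveDerivative(derivative, kind: .vjp, for: original)
print(original.differentiableAttrs[0].vjp ?? "none") // prints "retroDeriv"
```

Registering a second VJP for the same original, or a derivative declared in another file, throws in this model, mirroring the two diagnostics described in the list.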

With these changes, the basic case works, as the following example shows. A @differentiating attribute makes the original function behave as if it had a @differentiable attribute.

func functionWithRetroDeriv(x: Float) -> Float {
  return x
}
@differentiating(functionWithRetroDeriv)
func retroDeriv(x: Float) -> (value: Float, pullback: (Float) -> Float) {
  return (value: x, pullback: { _ in 100 })
}
gradient(at: 3, in: functionWithRetroDeriv) // => 100
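
To see why the result is 100 rather than the true derivative of the identity function, the following plain-Swift sketch computes a gradient directly from a value/pullback pair. It does not use the attribute or the real `gradient(at:in:)`; `gradientSketch` is a hypothetical helper, and seeding the pullback with 1 is the assumed convention for a scalar-valued function.

```swift
// The two functions from the example above, without the attribute.
func functionWithRetroDeriv(x: Float) -> Float {
  return x
}
func retroDeriv(x: Float) -> (value: Float, pullback: (Float) -> Float) {
  // Deliberately "wrong" derivative, so it is obvious which one gets used.
  return (value: x, pullback: { _ in 100 })
}

// Hypothetical helper: evaluate the derivative function at `x`, then feed
// the seed 1 into the pullback to obtain the gradient.
func gradientSketch(
  at x: Float,
  usingVJP vjp: (Float) -> (value: Float, pullback: (Float) -> Float)
) -> Float {
  let (_, pullback) = vjp(x)
  return pullback(1)
}

print(gradientSketch(at: 3, usingVJP: retroDeriv)) // 100.0
```

The pullback ignores its argument and always returns 100, so a result of 100 shows that the registered derivative was used instead of one synthesized from the original function's body.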

Known unhandled

  • When the derivative function has a different canonical generic signature from the original function, and the original function does not have an all-parameter @differentiable attribute that specifies the derivative function's generic constraints w.r.t. the original function, the behavior is undefined. This will be fixed in a follow-up PR.

  • We do not support wrt: in a @differentiating attribute yet, so we can't use this on functions with non-differentiable parameters.

  • The @differentiating attribute is not registered in gyb_syntax_support yet, so syntax verification would fail if we use @differentiating in the standard library. But we can live with this for now.
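
As an illustration of the wrt: limitation, here is a hypothetical function with a non-differentiable Int parameter, written in plain Swift without any attributes. The names `scaled` and `scaledVJP` are made up for this sketch. A derivative for such a function would have to be taken with respect to `x` only, which is what a wrt: clause would express.

```swift
// Hypothetical original: `n` is an Int, which is not a differentiable type.
func scaled(_ x: Float, by n: Int) -> Float {
  return x * Float(n)
}

// A derivative with respect to `x` only: the pullback returns a single
// Float for `x` and nothing for `n`.
func scaledVJP(_ x: Float, by n: Int) -> (value: Float, pullback: (Float) -> Float) {
  return (value: x * Float(n), pullback: { dy in dy * Float(n) })
}

print(scaledVJP(2, by: 3).value)       // 6.0
print(scaledVJP(2, by: 3).pullback(1)) // 3.0
```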

Resolves TF-278 (https://bugs.swift.org/browse/TF-278).

@rxwei rxwei requested review from dan-zheng and marcrasi February 25, 2019 01:11
@rxwei rxwei added the tensorflow This is for "tensorflow" branch PRs. label Feb 25, 2019

rxwei commented Feb 25, 2019

@swift-ci please test tensorflow

// Reject different-file retroactive derivatives.
// TODO(TF-136): Full support for cross-file/cross-module retroactive
// differentiability will require SIL differnetiability witnesses and lots of

Suggested change
- // differentiability will require SIL differnetiability witnesses and lots of
+ // differentiability will require SIL differentiability witnesses and lots of

Will fix this later.

@rxwei rxwei merged commit 318ab08 into swiftlang:tensorflow Feb 25, 2019
@rxwei rxwei deleted the retrodiff branch February 25, 2019 19:54
rxwei added a commit to rxwei/swift that referenced this pull request May 11, 2019
…ftlang#22871)
