[AutoDiff] Enable same-file retroactive derivative registration.#22871
Merged
rxwei merged 1 commit intoswiftlang:tensorflowfrom Feb 25, 2019
Merged
[AutoDiff] Enable same-file retroactive derivative registration.#22871rxwei merged 1 commit intoswiftlang:tensorflowfrom
rxwei merged 1 commit intoswiftlang:tensorflowfrom
Conversation
Contributor
Author
|
@swift-ci please test tensorflow |
Retroactive derivative registration, a.k.a. the `@differentiating` attribute, is the ultimate direction we'd like to go for. There are two critical goals: 1. Eliminating the JVP and VJP concepts from the user land. Today's situation is that the `jvp:` and `vjp:` attribute arguments are still needed for registering a derivative via a `@differentiable` attribute on the original function definition. The user is not supposed to know what a JVP/VJP is in order to define a derivative. When `@differentiating` is added, we would like to eliminate `vjp:` and `vjp:` attribute arguments from the `@differentiable` attribute. 2. Make differentiation as extensible as the rest of Swift (retroactive conformances, etc). Users should be able to make a function from an imported module be differentiable, and export the derivative to other files and modules that import the module containing the retroactive derivative. Implementation-wise, cross-file/cross-module retroactive derivative registration requires a top-level differentiability witness construct in SIL that behaves like SIL protocol witnesses. For more implementation information, see [this forum discussion](https://forums.swift.org/t/help-needed-with-retroactive-differentiability/19927/6?u=rxwei). This PR is **not** doing that. However, same-file retroactive derivative registration is fairly easy to do -- all it takes is to make each `@differentiating` imply a `@differentiable` attribute on the original function. In other words, the `@differentiable` attribute (or more precisely the `[differentiable]` SIL function attribute) is acting like a differentiability witness in a SIL module already, so we can just make use of that. This PR adds support for same-file retroactive derivative registration. So long as the original function exists in the same file as the derivative with a `@differentiating` attribute, derivative registration works. The detailed implementation steps are as follows: * When type-checking a `@differentiating` attribute, after resolving the original function declaration, check to see if the original function is in the same file as the retroactive derivative. If not, emit a diagnostic. * Find an existing `@differentiable` attribute on the original function that is w.r.t. all parameters. If it exists, treat that attribute as the witness. If not, create a new `@differentiable` attribute that covers all parameters and treat that as the witness. * If the `@differentiable` attribute already has an associated function (JVP/VJP) that the retroactive derivative represents, emit a diagnostic. * Set the corresponding associated function in the `@differentiable` attribute to be the retroactive derivative. Now, basic things work as demonstrated in the following example. A `@differentiating` attribute makes it work as if the original function had a `@differentiable` attribute. ```swift func functionWithRetroDeriv(x: Float) -> Float { return x } @differentiating(functionWithRetroDeriv) func retroDeriv(x: Float) -> (value: Float, pullback: (Float) -> Float) { return (value: x, pullback: { _ in 100 }) } gradient(at: 3, in: functionWithRetroDeriv) // => 100 ``` * When the derivative function has a different canonical generic signature as the original function and when the original function does not have an all-parameter `@differentiable` attribute that specifies the derivative function's generic constraints w.r.t. the original function, the behavior is undefined. This will be fixed in a follow-up PR. * We do not support `wrt:` in a `@differentiating` attribute yet, so we can't use this on functions with non-differentiable parameters. * The `@differentiating` attribute is not registered in `gyb_syntax_support` yet, so syntax verification would fail if we use `@differentiating` in the standard library. But we can live with this for now. Resolves [TF-278](https://bugs.swift.org/browse/TF-278).
Contributor
Author
|
@swift-ci please test tensorflow |
dan-zheng
approved these changes
Feb 25, 2019
|
|
||
| // Reject different-file retroactive derivatives. | ||
| // TODO(TF-136): Full support for cross-file/cross-module retroactive | ||
| // differentiability will require SIL differnetiability witnesses and lots of |
Contributor
There was a problem hiding this comment.
Suggested change
| // differentiability will require SIL differnetiability witnesses and lots of | |
| // differentiability will require SIL differentiability witnesses and lots of |
rxwei
added a commit
to rxwei/swift
that referenced
this pull request
May 11, 2019
…ftlang#22871) Retroactive derivative registration, a.k.a. the `@differentiating` attribute, is the ultimate direction we'd like to go for. There are two critical goals: 1. Eliminating the JVP and VJP concepts from the user land. Today's situation is that the `jvp:` and `vjp:` attribute arguments are still needed for registering a derivative via a `@differentiable` attribute on the original function definition. The user is not supposed to know what a JVP/VJP is in order to define a derivative. When `@differentiating` is added, we would like to eliminate `vjp:` and `vjp:` attribute arguments from the `@differentiable` attribute. 2. Make differentiation as extensible as the rest of Swift (retroactive conformances, etc). Users should be able to make a function from an imported module be differentiable, and export the derivative to other files and modules that import the module containing the retroactive derivative. Implementation-wise, cross-file/cross-module retroactive derivative registration requires a top-level differentiability witness construct in SIL that behaves like SIL protocol witnesses. For more implementation information, see [this forum discussion](https://forums.swift.org/t/help-needed-with-retroactive-differentiability/19927/6?u=rxwei). This PR is **not** doing that. However, same-file retroactive derivative registration is fairly easy to do -- all it takes is to make each `@differentiating` imply a `@differentiable` attribute on the original function. In other words, the `@differentiable` attribute (or more precisely the `[differentiable]` SIL function attribute) is acting like a differentiability witness in a SIL module already, so we can just make use of that. This PR adds support for same-file retroactive derivative registration. So long as the original function exists in the same file as the derivative with a `@differentiating` attribute, derivative registration works. The detailed implementation steps are as follows: * When type-checking a `@differentiating` attribute, after resolving the original function declaration, check to see if the original function is in the same file as the retroactive derivative. If not, emit a diagnostic. * Find an existing `@differentiable` attribute on the original function that is w.r.t. all parameters. If it exists, treat that attribute as the witness. If not, create a new `@differentiable` attribute that covers all parameters and treat that as the witness. * If the `@differentiable` attribute already has an associated function (JVP/VJP) that the retroactive derivative represents, emit a diagnostic. * Set the corresponding associated function in the `@differentiable` attribute to be the retroactive derivative. Now, basic things work as demonstrated in the following example. A `@differentiating` attribute makes it work as if the original function had a `@differentiable` attribute. ```swift func functionWithRetroDeriv(x: Float) -> Float { return x } @differentiating(functionWithRetroDeriv) func retroDeriv(x: Float) -> (value: Float, pullback: (Float) -> Float) { return (value: x, pullback: { _ in 100 }) } gradient(at: 3, in: functionWithRetroDeriv) // => 100 ``` * When the derivative function has a different canonical generic signature as the original function and when the original function does not have an all-parameter `@differentiable` attribute that specifies the derivative function's generic constraints w.r.t. the original function, the behavior is undefined. This will be fixed in a follow-up PR. * We do not support `wrt:` in a `@differentiating` attribute yet, so we can't use this on functions with non-differentiable parameters. * The `@differentiating` attribute is not registered in `gyb_syntax_support` yet, so syntax verification would fail if we use `@differentiating` in the standard library. But we can live with this for now. Resolves [TF-278](https://bugs.swift.org/browse/TF-278).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Overview
Retroactive derivative registration, a.k.a. the
@differentiatingattribute, is the ultimate direction we'd like to go for. There are two critical goals:jvp:andvjp:attribute arguments are still needed for registering a derivative via a@differentiableattribute on the original function definition. The user is not supposed to know what a JVP/VJP is in order to define a derivative. When@differentiatingis added, we would like to eliminatevjp:andvjp:attribute arguments from the@differentiableattribute.Implementation-wise, cross-file/cross-module retroactive derivative registration requires a top-level differentiability witness construct in SIL that behaves like SIL protocol witnesses. For more implementation information, see this forum discussion. This PR is not doing that. However, same-file retroactive derivative registration is fairly easy to do -- all it takes is to make each
@differentiatingimply a@differentiableattribute on the original function. In other words, the@differentiableattribute (or more precisely the[differentiable]SIL function attribute) is acting like a differentiability witness in a SIL module already, so we can just make use of that.Changes
This PR adds support for same-file retroactive derivative registration. So long as the original function exists in the same file as the derivative with a
@differentiatingattribute, derivative registration works. The detailed implementation steps are as follows:@differentiatingattribute, after resolving the original function declaration, check to see if the original function is in the same file as the retroactive derivative. If not, emit a diagnostic.@differentiableattribute on the original function that is w.r.t. all parameters. If it exists, treat that attribute as the witness. If not, create a new@differentiableattribute that covers all parameters and treat that as the witness.@differentiableattribute already has an associated function (JVP/VJP) that the retroactive derivative represents, emit a diagnostic.@differentiableattribute to be the retroactive derivative.Now, basic things work as demonstrated in the following example. A
@differentiatingattribute makes it work as if the original function had a@differentiableattribute.Known unhandled
When the derivative function has a different canonical generic signature as the original function and when the original function does not have an all-parameter
@differentiableattribute that specifies the derivative function's generic constraints w.r.t. the original function, the behavior is undefined. This will be fixed in a follow-up PR.We do not support
wrt:in a@differentiatingattribute yet, so we can't use this on functions with non-differentiable parameters.The
@differentiatingattribute is not registered ingyb_syntax_supportyet, so syntax verification would fail if we use@differentiatingin the standard library. But we can live with this for now.Resolves TF-278.