[AutoDiff] Remove differentiation order from AD-related instructions. #27579

Merged
merged 4 commits into apple:tensorflow from remove-order on Oct 10, 2019

Conversation

@rxwei (Member) commented Oct 8, 2019

The differentiation order field in `differentiable_function` and `differentiable_function_extract` instructions is unsupported and will not be used by the current design. Quite a lot of dead code exists to handle `order`, but it is mostly incomplete and untested. This PR removes the differentiation order from the code base to simplify what we upstream to the 'master' branch.

Changes include:

  • Remove `differentiationOrder` from `DifferentiableFunctionInst` and `DifferentiableFunctionExtractInst`.
  • Make `DifferentiableFunctionInst::DifferentiableFunctionInst` take an optional pair of JVP and VJP functions instead of a variable-size array (see the sketch below).
  • Rename "associated functions" to "derivative functions" in `DifferentiableFunctionInst` to align better with [the design](https://forums.swift.org/t/differentiable-programming-mega-proposal/28547). Filed task [TF-882](https://bugs.swift.org/browse/TF-882) to track the renaming of all other occurrences of "associated functions".

Resolves [TF-880](https://bugs.swift.org/browse/TF-880).
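
As a rough illustration of the second bullet above, here is a minimal C++ sketch of the storage change. These are hypothetical declarations, not the actual Swift compiler API: `SILValue` is a stand-in handle, the field names are invented, and the real instruction stores its operands differently.

```cpp
#include <optional>
#include <utility>
#include <vector>

struct SILValue {};  // stand-in for the real SIL value handle

// Before (sketch): a differentiation order plus a variable-size array of
// associated functions, one {JVP, VJP} entry per order.
struct DifferentiableFunctionInstBefore {
  unsigned differentiationOrder;
  std::vector<std::pair<SILValue, SILValue>> associatedFunctions;
};

// After (sketch): no order at all; at most a single optional pair of
// derivative functions, {JVP, VJP}.
struct DifferentiableFunctionInstAfter {
  std::optional<std::pair<SILValue, SILValue>> derivativeFunctions;
};
```

Collapsing the variable-size array to one optional pair removes the bookkeeping for higher-order entries that were never actually generated or tested.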

@rxwei rxwei added the tensorflow label ("This is for 'tensorflow' branch PRs.") on Oct 8, 2019
@rxwei rxwei requested review from dan-zheng and bgogul October 8, 2019 20:37
@dan-zheng (Collaborator) left a comment


Adding order to the codebase by removing order from the codebase

Resolved review threads: include/swift/SIL/SILInstruction.h (outdated), lib/ParseSIL/ParseSIL.cpp
@rxwei rxwei force-pushed the remove-order branch 2 times, most recently from 87b472e to e92f534, on October 8, 2019 22:43
@rxwei (Member, Author) commented Oct 8, 2019

@swift-ci please test tensorflow

@rxwei (Member, Author) commented Oct 9, 2019

Thank you @marcrasi for investigating and fixing the mysterious segfault (16fea91576536f73fe0495c4bbe05af4daa3fcf1)!

rxwei and others added 3 commits October 9, 2019 16:09
@@ -896,16 +896,23 @@ namespace {
         DifferentiableFunctionExtractee::Original,
         TC.getTypeLowering(origFnTy, getResilienceExpansion())
       });
-      for (auto kind : {AutoDiffAssociatedFunctionKind::JVP,
-                        AutoDiffAssociatedFunctionKind::VJP}) {
+      for (AutoDiffAssociatedFunctionKind kind :
@marcrasi (Collaborator) left a comment

Explanation of how I found this.

The simple reproducer from @rxwei was

func main() {
  let grad = gradient { (x0: Float, x1: Float) -> Float in .zero }
  grad(4, 5)
}
main()

Compiling that with swiftc -sanitize=address and running the binary produces a segfault.

@rxwei had already looked at the code generated from that reproducer, and it was the same between the working and broken versions of the compiler.

I decided to look at the SIL for the gradient<A, B>(of:) function in the stdlib. To do this, I found the command that builds lib/swift/linux/x86_64/Swift.swiftmodule in swift-linux-x86_64/build.ninja and re-ran it, adding -emit-sil and replacing -o xyz.swiftmodule with -o stdlib.sil.

The result for the good compiler was:

// gradient<A, B>(of:)
sil @$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lF : $@convention(thin) <T, R where T : Differentiable, R : Differentiable, R : FloatingPoint, R == R.TangentVector> (@guaranteed @differentiable @callee_guaranteed (@in_guaranteed T) -> @out R) -> @owned @callee_guaranteed (@in_guaranteed T) -> @out T.TangentVector {
// %0                                             // users: %5, %4, %3, %6, %1
bb0(%0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R):
  debug_value %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R, let, name "f", argno 1 // id: %1
  // function_ref closure #1 in gradient<A, B>(of:)
  %2 = function_ref @$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lFADxcfU_ : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 == τ_0_1.TangentVector> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @out τ_0_0.TangentVector // user: %6
  %3 = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %7
  %4 = differentiable_function_extract [jvp] [order 1] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %8
  %5 = differentiable_function_extract [vjp] [order 1] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %9
  %6 = partial_apply [callee_guaranteed] %2<T, R>(%0) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 == τ_0_1.TangentVector> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @out τ_0_0.TangentVector // user: %10
  strong_retain %3 : $@callee_guaranteed (@in_guaranteed T) -> @out R // id: %7
  strong_retain %4 : $@callee_guaranteed (@in_guaranteed T) -> (@out R, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> @out R) // id: %8
  strong_retain %5 : $@callee_guaranteed (@in_guaranteed T) -> (@out R, @owned @callee_guaranteed (@in_guaranteed R) -> @out T.TangentVector) // id: %9
  return %6 : $@callee_guaranteed (@in_guaranteed T) -> @out T.TangentVector // id: %10
} // end sil function '$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lF'

The result for the bad compiler was:

// gradient<A, B>(of:)
sil @$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lF : $@convention(thin) <T, R where T : Differentiable, R : Differentiable, R : FloatingPoint, R == R.TangentVector> (@guaranteed @differentiable @callee_guaranteed (@in_guaranteed T) -> @out R) -> @owned @callee_guaranteed (@in_guaranteed T) -> @out T.TangentVector {
// %0                                             // users: %5, %4, %3, %6, %1
bb0(%0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R):
  debug_value %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R, let, name "f", argno 1 // id: %1
  // function_ref closure #1 in gradient<A, B>(of:)
  %2 = function_ref @$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lFADxcfU_ : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 == τ_0_1.TangentVector> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @out τ_0_0.TangentVector // user: %6
  %3 = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %7
  %4 = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %8
  %5 = differentiable_function_extract [jvp] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %9
  %6 = partial_apply [callee_guaranteed] %2<T, R>(%0) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 == τ_0_1.TangentVector> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @out τ_0_0.TangentVector // user: %10
  strong_retain %3 : $@callee_guaranteed (@in_guaranteed T) -> @out R // id: %7
  strong_retain %4 : $@callee_guaranteed (@in_guaranteed T) -> @out R // id: %8
  strong_retain %5 : $@callee_guaranteed (@in_guaranteed T) -> (@out R, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> @out R) // id: %9
  return %6 : $@callee_guaranteed (@in_guaranteed T) -> @out T.TangentVector // id: %10
} // end sil function '$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lF'

It is a very suspicious difference that the bad compiler extracts [original], [original], and [jvp], where the good compiler extracts [original], [jvp], and [vjp].

I guessed that this piece of code was responsible for generating those extracts, so I added some llvm::dbgs() statements to print kind and DifferentiableFunctionExtractee(kind), and saw that DifferentiableFunctionExtractee(kind) was wrong.
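
For intuition about the failure mode, here is a small, self-contained C++ sketch. The enum layouts and function names below are hypothetical, not the real compiler declarations; they only illustrate how converting a function kind to an extractee by reinterpreting its raw value shifts every case by one when the extractee enum has an extra leading `Original` case, which matches the [original]/[original]/[jvp] pattern in the bad SIL above.

```cpp
#include <cassert>

// Hypothetical layouts for illustration only.
enum class AutoDiffAssociatedFunctionKind : unsigned { JVP = 0, VJP = 1 };
enum class DifferentiableFunctionExtractee : unsigned { Original = 0, JVP = 1, VJP = 2 };

// Buggy conversion: reinterprets the raw value, ignoring the extra leading
// `Original` case, so JVP maps to Original and VJP maps to JVP.
DifferentiableFunctionExtractee buggyExtractee(AutoDiffAssociatedFunctionKind kind) {
  return static_cast<DifferentiableFunctionExtractee>(static_cast<unsigned>(kind));
}

// Fixed conversion: maps each kind to its named counterpart explicitly.
DifferentiableFunctionExtractee fixedExtractee(AutoDiffAssociatedFunctionKind kind) {
  return kind == AutoDiffAssociatedFunctionKind::JVP
             ? DifferentiableFunctionExtractee::JVP
             : DifferentiableFunctionExtractee::VJP;
}

int main() {
  // The buggy mapping reproduces the shifted extractees seen in the bad SIL.
  assert(buggyExtractee(AutoDiffAssociatedFunctionKind::JVP) ==
         DifferentiableFunctionExtractee::Original);
  assert(buggyExtractee(AutoDiffAssociatedFunctionKind::VJP) ==
         DifferentiableFunctionExtractee::JVP);
  // The explicit mapping produces the intended extractees.
  assert(fixedExtractee(AutoDiffAssociatedFunctionKind::JVP) ==
         DifferentiableFunctionExtractee::JVP);
  assert(fixedExtractee(AutoDiffAssociatedFunctionKind::VJP) ==
         DifferentiableFunctionExtractee::VJP);
  return 0;
}
```

An explicit case-by-case mapping avoids relying on the two enums sharing a raw-value layout.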

@rxwei (Member, Author) replied

Thanks for the explanation!

@rxwei (Member, Author) commented Oct 9, 2019

@swift-ci please test tensorflow

@rxwei rxwei merged commit eeeeee2 into apple:tensorflow Oct 10, 2019
@rxwei rxwei deleted the remove-order branch October 10, 2019 00:10