[AutoDiff] Remove differentiation order from AD-related instructions. #27579
Conversation
Force-pushed from `87b472e` to `e92f534`: "Adding order to the codebase by removing order from the codebase"
@swift-ci please test tensorflow
Thank you @marcrasi for investigating and fixing the mysterious segfault (16fea91576536f73fe0495c4bbe05af4daa3fcf1)!
The differentiation order field in the `differentiable_function` and `differentiable_function_extract` instructions is unsupported and will not be used by the current design. Quite a lot of dead code exists to try to handle `order`, but it is mostly incomplete and untested. This PR removes the differentiation order from the code base to simplify what we upstream to the 'master' branch.

Changes include:

* Remove `differentiationOrder` from `DifferentiableFunctionInst` and `DifferentiableFunctionExtractInst`.
* Make `DifferentiableFunctionInst::DifferentiableFunctionInst` take an optional pair of JVP and VJP instead of a variable-size array.
* Rename "associated functions" to "derivative functions" in `DifferentiableFunctionInst` to align better with [the design](https://forums.swift.org/t/differentiable-programming-mega-proposal/28547). Filed task [TF-882](https://bugs.swift.org/browse/TF-882) to track the renaming of all other occurrences of "associated functions".

Resolves [TF-880](https://bugs.swift.org/browse/TF-880).
```diff
@@ -896,16 +896,23 @@ namespace {
         DifferentiableFunctionExtractee::Original,
         TC.getTypeLowering(origFnTy, getResilienceExpansion())
       });
-      for (auto kind : {AutoDiffAssociatedFunctionKind::JVP,
-                        AutoDiffAssociatedFunctionKind::VJP}) {
+      for (AutoDiffAssociatedFunctionKind kind :
```
Explanation of how I found this.

The simple reproducer from @rxwei was:

```swift
func main() {
  let grad = gradient { (x0: Float, x1: Float) -> Float in .zero }
  grad(4, 5)
}
main()
```

Compiling that with `swiftc -sanitize=address` and running the binary segfaults.

@rxwei had already looked at the generated code from that, and it was the same between the working and broken versions of the compiler.

I decided to look at the SIL for the `gradient<A, B>(of:)` function in the stdlib. To do this, I found the command for `build lib/swift/linux/x86_64/Swift.swiftmodule` in `swift-linux-x86_64/build.ninja`. I ran that, adding `-emit-sil` and replacing `-o xyz.swiftmodule` with `-o stdlib.sil`.
The result for the good compiler was:

```sil
// gradient<A, B>(of:)
sil @$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lF : $@convention(thin) <T, R where T : Differentiable, R : Differentiable, R : FloatingPoint, R == R.TangentVector> (@guaranteed @differentiable @callee_guaranteed (@in_guaranteed T) -> @out R) -> @owned @callee_guaranteed (@in_guaranteed T) -> @out T.TangentVector {
// %0                                             // users: %5, %4, %3, %6, %1
bb0(%0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R):
  debug_value %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R, let, name "f", argno 1 // id: %1
  // function_ref closure #1 in gradient<A, B>(of:)
  %2 = function_ref @$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lFADxcfU_ : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 == τ_0_1.TangentVector> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @out τ_0_0.TangentVector // user: %6
  %3 = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %7
  %4 = differentiable_function_extract [jvp] [order 1] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %8
  %5 = differentiable_function_extract [vjp] [order 1] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %9
  %6 = partial_apply [callee_guaranteed] %2<T, R>(%0) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 == τ_0_1.TangentVector> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @out τ_0_0.TangentVector // user: %10
  strong_retain %3 : $@callee_guaranteed (@in_guaranteed T) -> @out R // id: %7
  strong_retain %4 : $@callee_guaranteed (@in_guaranteed T) -> (@out R, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> @out R) // id: %8
  strong_retain %5 : $@callee_guaranteed (@in_guaranteed T) -> (@out R, @owned @callee_guaranteed (@in_guaranteed R) -> @out T.TangentVector) // id: %9
  return %6 : $@callee_guaranteed (@in_guaranteed T) -> @out T.TangentVector // id: %10
} // end sil function '$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lF'
```
The result for the bad compiler was:

```sil
// gradient<A, B>(of:)
sil @$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lF : $@convention(thin) <T, R where T : Differentiable, R : Differentiable, R : FloatingPoint, R == R.TangentVector> (@guaranteed @differentiable @callee_guaranteed (@in_guaranteed T) -> @out R) -> @owned @callee_guaranteed (@in_guaranteed T) -> @out T.TangentVector {
// %0                                             // users: %5, %4, %3, %6, %1
bb0(%0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R):
  debug_value %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R, let, name "f", argno 1 // id: %1
  // function_ref closure #1 in gradient<A, B>(of:)
  %2 = function_ref @$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lFADxcfU_ : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 == τ_0_1.TangentVector> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @out τ_0_0.TangentVector // user: %6
  %3 = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %7
  %4 = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %8
  %5 = differentiable_function_extract [jvp] %0 : $@differentiable @callee_guaranteed (@in_guaranteed T) -> @out R // user: %9
  %6 = partial_apply [callee_guaranteed] %2<T, R>(%0) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 == τ_0_1.TangentVector> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @out τ_0_0.TangentVector // user: %10
  strong_retain %3 : $@callee_guaranteed (@in_guaranteed T) -> @out R // id: %7
  strong_retain %4 : $@callee_guaranteed (@in_guaranteed T) -> @out R // id: %8
  strong_retain %5 : $@callee_guaranteed (@in_guaranteed T) -> (@out R, @owned @callee_guaranteed (@in_guaranteed T.TangentVector) -> @out R) // id: %9
  return %6 : $@callee_guaranteed (@in_guaranteed T) -> @out T.TangentVector // id: %10
} // end sil function '$ss8gradient2of13TangentVectorQzxcq_xXG_ts14DifferentiableRzsAER_SFR_ACsAEPQy_Rs_r0_lF'
```
It is a very suspicious difference that the bad compiler extracts `[original]`, `[original]`, and `[jvp]`.

I guessed that this piece of code was responsible for generating those. I added some `llvm::dbgs()` statements to see what `kind` and `DifferentiableFunctionExtractee(kind)` were, and I saw that `DifferentiableFunctionExtractee(kind)` was wrong.
Thanks for the explanation!
@swift-ci please test tensorflow