[AutoDiff] Implement the closure optimization that is specialized towards the linear map tuples / enums produced by autodiff. #68944
Copied testcase and example from #68901:
import _Differentiation
import Darwin
@differentiable(reverse)
func f(_ x: Float) -> Float {
if (x > 0) {
return sin(x) * cos(x)
} else {
return sin(x) + cos(x)
}
}
So, for the case above we'd turn:
// foo(_:)
sil hidden [noinline] @$s6sincos3fooyS2fF : $@convention(thin) (Float) -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0 "x" // users: %26, %24, %16, %12, %3, %2, %1
bb0(%0 : $Float):
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
debug_value %0 : $Float, let, name "x", argno 1 // id: %2
%3 = struct_extract %0 : $Float, #Float._value // users: %13, %9, %5
%4 = float_literal $Builtin.FPIEEE32, 0x0 // 0 // user: %5
%5 = builtin "fcmp_olt_FPIEEE32"(%4 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %7
%6 = tuple () // users: %22, %8
cond_br %5, bb1, bb2 // id: %7
bb1: // Preds: bb0
%8 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %19
%9 = builtin "int_sin_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %10
%10 = struct $Float (%9 : $Builtin.FPIEEE32) // user: %18
// function_ref closure #1 in _vjpSin(_:)
%11 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %12
%12 = partial_apply [callee_guaranteed] %11(%0) : $@convention(thin) (Float, Float) -> Float // user: %19
%13 = builtin "int_cos_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %14
%14 = struct $Float (%13 : $Builtin.FPIEEE32) // user: %18
// function_ref closure #1 in _vjpCos(_:)
%15 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %16
%16 = partial_apply [callee_guaranteed] %15(%0) : $@convention(thin) (Float, Float) -> Float // user: %19
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%17 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %18
%18 = partial_apply [callee_guaranteed] %17(%14, %10) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %19
%19 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%8, %12, %16, %18) // user: %20
%20 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %19 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %21
br bb3(%20 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %21
bb2: // Preds: bb0
%22 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %29
// function_ref closure #1 in _vjpSin(_:)
%23 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %24
%24 = partial_apply [callee_guaranteed] %23(%0) : $@convention(thin) (Float, Float) -> Float // user: %29
// function_ref closure #1 in _vjpCos(_:)
%25 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %26
%26 = partial_apply [callee_guaranteed] %25(%0) : $@convention(thin) (Float, Float) -> Float // user: %29
// function_ref closure #1 in static Float._vjpAdd(lhs:rhs:)
%27 = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) // user: %28
%28 = thin_to_thick_function %27 : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) // user: %29
%29 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%22, %24, %26, %28) // user: %30
%30 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %29 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %31
br bb3(%30 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %31
// %32 // user: %37
bb3(%32 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
// function_ref pullback of f(_:)
%33 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
%34 = integer_literal $Builtin.Int64, 1 // user: %35
%35 = builtin "sitofp_Int64_FPIEEE32"(%34 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %36
%36 = struct $Float (%35 : $Builtin.FPIEEE32) // user: %37
%37 = apply %33(%36, %32) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %38
return %37 : $Float // id: %38
} // end sil function '$s6sincos3fooyS2fF'
into:
enum _AD__$s6sincos1fyS2fF_bb0__Pred__src_0_wrt_0 {
}
enum _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0 {
case bb0(())
}
enum _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0 {
case bb0(())
}
enum _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0 {
case bb2((predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float))
case bb1((predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)))
}
// foo(_:)
sil hidden [noinline] @$s6sincos3fooyS2fF : $@convention(thin) (Float) -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0 "x" // users: %26, %24, %16, %12, %3, %2, %1
bb0(%0 : $Float):
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
debug_value %0 : $Float, let, name "x", argno 1 // id: %2
%3 = struct_extract %0 : $Float, #Float._value // users: %13, %9, %5
%4 = float_literal $Builtin.FPIEEE32, 0x0 // 0 // user: %5
%5 = builtin "fcmp_olt_FPIEEE32"(%4 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %7
%6 = tuple () // users: %22, %8
cond_br %5, bb1, bb2 // id: %7
bb1: // Preds: bb0
%8 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %19
%9 = builtin "int_sin_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %10
%10 = struct $Float (%9 : $Builtin.FPIEEE32) // user: %18
%13 = builtin "int_cos_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %14
%14 = struct $Float (%13 : $Builtin.FPIEEE32)
%newt = tuple $(Float, Float) (%14, %10)
%19 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)) (%8, %0, %0, %newt) // user: %20
%20 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %19 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)) // user: %21
br bb3(%20 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %21
bb2: // Preds: bb0
%22 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %29
%29 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float) (%22, %0, %0) // user: %30
%30 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %29 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float) // user: %31
br bb3(%30 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %31
// %32 // user: %37
bb3(%32 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
// function_ref pullback of f(_:)
%33 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
%34 = integer_literal $Builtin.Int64, 1 // user: %35
%35 = builtin "sitofp_Int64_FPIEEE32"(%34 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %36
%36 = struct $Float (%35 : $Builtin.FPIEEE32) // user: %37
%37 = apply %33(%36, %32) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %38
return %37 : $Float // id: %38
} // end sil function '$s6sincos3fooyS2fF'
[Design] Autodiff Pullback Closure Specialization Optimization
What is it?
This optimization can help alleviate the heap allocation costs associated with closures by eliminating the use of closures altogether at certain call sites. Given a function call site, if the callee takes a closure as an input argument and then calls the closure in its body, we can eliminate the overhead of the closure context's heap allocation by -
Using this optimization, SIL code like this:
Will look like this:
See here for a high-level description of the general Swift closure specialization optimization.
Limitations of the general closure specialization optimization
The general closure specialization optimization operates under numerous restrictions. The ones most directly affecting AD are:
Design
This optimization is going to be implemented as a SILFunctionTransform.
Pre-optimization steps
Optimization steps
Note - The steps below are a rough estimate and not set in stone.
Post-optimization steps
Location in optimization pipeline
Right after the current position of the general closure specialization optimization in the pipeline seems like a sane default to start with.
Open questions for discussion
Q. What happens to old branch tracing enums?
Q. What happens to old pullback?
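For readers unfamiliar with the general transformation described under "What is it?", the following is a minimal source-level sketch of its shape, not the actual SIL pass. All function names (`applyPullback`, `applyPullbackSpecialized`, `callerBefore`, `callerAfter`) are hypothetical. The assumption illustrated is the usual one for closure specialization: the callee is cloned so that it receives the closure's captured values as ordinary arguments, and the closure is rebuilt inside the clone where later passes can fold it away.

```swift
// Original callee: takes a closure argument and calls it in its body.
func applyPullback(_ seed: Float, _ pullback: (Float) -> Float) -> Float {
    return pullback(seed)
}

// Call site before specialization: the closure captures `scale`,
// so a closure context has to be materialized to pass it along.
func callerBefore(scale: Float) -> Float {
    let pullback: (Float) -> Float = { v in v * scale }
    return applyPullback(1, pullback)
}

// Hypothetical specialized clone: the captured value is passed directly
// and the closure is re-created inside the callee, next to its use.
func applyPullbackSpecialized(_ seed: Float, scale: Float) -> Float {
    let pullback: (Float) -> Float = { v in v * scale }
    return pullback(seed)
}

// Call site after specialization: no closure crosses the call boundary.
func callerAfter(scale: Float) -> Float {
    return applyPullbackSpecialized(1, scale: scale)
}
```

Both callers compute the same result; only the place where the closure is formed differs.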
@asl Could you take a look at the brief design write-up for our AD-specific closure specialization optimization?
Tagging @BradLarson as well
Overall, I think some steps / parts would likely need clarification.
Note that in the AD case the closure is not passed as an input argument directly. Instead, it's buried deep inside a linear map tuple.
What does "branch trace enum is non-trivial" mean? Separate enums are created for each BB. So, if the last BB of the function does not have any calls / predecessors, the corresponding enum will be quite trivial (e.g. no tuple payload).
What does this mean? What is the "node" here? What is the "top node"?
What do you mean by "parent enum"? Note that you need to operate on the function as a whole, as the cases / payloads of the branch trace enum for a basic block are its predecessor basic blocks and the corresponding linear map tuples. Also, you need to build everything at once in RPOT manner; otherwise we'd end up with the same type lowering issues we already faced previously.
Can you please expand on the "can be derived" case? It seems the most important thing here.
What are "arguments" here? Arguments of the original function? I am confused. Do you have an example where this would be necessary? In general, it would help if each of the steps were illustrated by some SIL example, so we can review the meaning of each step.
What "existing code" are you referring to? Certainly we can safely delete the unused enums.
Again, we can safely remove unused code. The pullback function does not exist as a separate entity; it's an internal implementation detail. And if unused, it could be dropped entirely.
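To illustrate the ordering requirement raised above, here is a small sketch that computes a reverse post-order over the CFG of `f` from the example (bb0 branches to bb1 and bb2, both of which jump to bb3). It assumes that RPOT stands for reverse post-order traversal; the successor table and the `reversePostOrder` helper are hypothetical and stand in for whatever CFG utilities the pass would actually use.

```swift
// Hypothetical CFG of `f`: successor lists indexed by block number (bb0...bb3).
let successors: [[Int]] = [[1, 2], [3], [3], []]

// Depth-first walk that records blocks in post-order, then reverses the list.
// In the result, every block appears after all of its predecessors
// (for an acyclic CFG such as this one).
func reversePostOrder(entry: Int, successors: [[Int]]) -> [Int] {
    var visited = Set<Int>()
    var postOrder: [Int] = []
    func visit(_ block: Int) {
        // Skip blocks that were already reached through another path.
        guard visited.insert(block).inserted else { return }
        for succ in successors[block] { visit(succ) }
        postOrder.append(block)
    }
    visit(entry)
    return Array(postOrder.reversed())
}

let order = reversePostOrder(entry: 0, successors: successors)
print(order)
```

Visiting blocks in this order means the predecessor enums of bb1 and bb2 are rebuilt before the enum of bb3, whose payloads refer to them.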
We discussed the intended optimization and its scope with @jkshtj. He will prepare a refined proposal.
@asl @BradLarson I've revised the proposal and rewritten it as a gist here. Could you please take a look?
Here are my comments:
You also need to ensure that the pullback is a private function. While this is currently always the case, it's an implementation detail as of now.
How would you determine it?
Again, you need to define what
Isn't this case covered by the generic closure specialization transformation? Why would we need to reimplement it?
How would you handle various ownership-related things?
There is some contradiction here. This is all possible only if all nested VJPs are inlined into the top-level VJP. Otherwise you won't see these nested pullbacks (the partial_apply instructions).
You definitely need to restart the pipeline after this transformation.
Subtask of #68901
Implement the closure optimization that is specialized towards the linear map tuples / enums produced by autodiff. In particular, as in the example above, if we see a particular closure (partial_apply), then instead of storing the closure in the tuple, we store the closed-over value and move the partial_apply down to the apply site (no need to fold it, as existing passes already do this). We know the place of use because of how the linear map tuples and branch tracing enums are generated.
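A source-level analogy may help make the idea concrete. The sketch below models the branch tracing enum for `f` in plain Swift, first with closure payloads and then with the captured values stored directly, so that the pullback re-creates the computation at the use site. The type and function names (`TraceWithClosures`, `TraceWithValues`, and so on) are hypothetical and only illustrate the shape of the transformation; the real pass would operate on SIL, not on Swift source.

```swift
import Foundation

// Before: each case carries the pullback closures themselves.
enum TraceWithClosures {
    case bb1(sinPB: (Float) -> Float, cosPB: (Float) -> Float, mulPB: (Float) -> (Float, Float))
    case bb2(sinPB: (Float) -> Float, cosPB: (Float) -> Float, addPB: (Float) -> (Float, Float))
}

// After: each case carries only the values the closures would have captured.
enum TraceWithValues {
    case bb1(sinX: Float, cosX: Float, mulOperands: (Float, Float))
    case bb2(sinX: Float, cosX: Float)
}

func vjpWithClosures(_ x: Float) -> TraceWithClosures {
    let sinPB: (Float) -> Float = { v in v * cos(x) }   // captures x
    let cosPB: (Float) -> Float = { v in -v * sin(x) }  // captures x
    if x > 0 {
        let lhs = sin(x), rhs = cos(x)
        return .bb1(sinPB: sinPB, cosPB: cosPB, mulPB: { v in (v * rhs, v * lhs) })
    } else {
        return .bb2(sinPB: sinPB, cosPB: cosPB, addPB: { v in (v, v) })
    }
}

func pullbackWithClosures(_ seed: Float, _ trace: TraceWithClosures) -> Float {
    switch trace {
    case let .bb1(sinPB, cosPB, mulPB):
        let (dl, dr) = mulPB(seed)
        return sinPB(dl) + cosPB(dr)
    case let .bb2(sinPB, cosPB, addPB):
        let (dl, dr) = addPB(seed)
        return sinPB(dl) + cosPB(dr)
    }
}

func vjpWithValues(_ x: Float) -> TraceWithValues {
    if x > 0 {
        // Same operand order as the captured (rhs, lhs) pair above.
        return .bb1(sinX: x, cosX: x, mulOperands: (cos(x), sin(x)))
    } else {
        return .bb2(sinX: x, cosX: x)
    }
}

// The closure bodies now live at the use site, applied to the stored values.
func pullbackWithValues(_ seed: Float, _ trace: TraceWithValues) -> Float {
    switch trace {
    case let .bb1(sinX, cosX, mulOperands):
        let (dl, dr) = (seed * mulOperands.0, seed * mulOperands.1)
        return dl * cos(sinX) + (-dr * sin(cosX))
    case let .bb2(sinX, cosX):
        return seed * cos(sinX) + (-seed * sin(cosX))
    }
}
```

Both variants return the same derivative for a given input; the second one never forms a closure that escapes into the enum payload, which is the allocation this issue aims to avoid.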