-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add differential use handler #1377
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -24,7 +24,9 @@ function get_shadow_type(gutils::GradientUtils, T::LLVM.Type) | |||||||||||||||||
end | ||||||||||||||||||
function get_uncacheable(gutils::GradientUtils, orig::LLVM.CallInst) | ||||||||||||||||||
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | ||||||||||||||||||
API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) | ||||||||||||||||||
if API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) != 1 | ||||||||||||||||||
Comment on lines
25
to
+27
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
uncacheable .= 1 | ||||||||||||||||||
end | ||||||||||||||||||
return uncacheable | ||||||||||||||||||
end | ||||||||||||||||||
|
||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,5 +1,5 @@ | ||||||||||||||||||||
|
||||||||||||||||||||
function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) | ||||||||||||||||||||
function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, mi, @nospecialize(RT), reverse::Bool, isKWCall::Bool) | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||
ops = collect(operands(orig)) | ||||||||||||||||||||
called = ops[end] | ||||||||||||||||||||
ops = ops[1:end-1] | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||
|
@@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) | |||||||||||||||||||
return args, activity, (overwritten...,), actives, kwtup | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) | ||||||||||||||||||||
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt)) | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||
width = get_width(gutils) | ||||||||||||||||||||
mode = get_mode(gutils) | ||||||||||||||||||||
|
||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||
|
@@ -216,7 +216,23 @@ function enzyme_custom_setup_ret(gutils, orig, mi, RealRt) | |||||||||||||||||||
needsShadowP = Ref{UInt8}(0) | ||||||||||||||||||||
needsPrimalP = Ref{UInt8}(0) | ||||||||||||||||||||
|
||||||||||||||||||||
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) | ||||||||||||||||||||
# Conditionally use the get return. This is done because EnzymeGradientUtilsGetReturnDiffeType | ||||||||||||||||||||
# calls differential use analysis to determine needsprimal/shadow. However, since now this function | ||||||||||||||||||||
# is used as part of differential use analysis, we need to avoid an ininite recursion. Thus use | ||||||||||||||||||||
# the version without differential use if actual unreachable results are not available anyways. | ||||||||||||||||||||
uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(orig)))-1) | ||||||||||||||||||||
activep = if mode == API.DEM_ForwardMode || API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) == 1 | ||||||||||||||||||||
API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, mode) | ||||||||||||||||||||
Comment on lines
+223
to
+225
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||
else | ||||||||||||||||||||
actv = API.EnzymeGradientUtilsGetDiffeType(gutils, orig, false) | ||||||||||||||||||||
if !isghostty(RealRt) | ||||||||||||||||||||
needsPrimalP[] = 1 | ||||||||||||||||||||
if actv == API.DFT_DUP_ARG || actv == API.DFT_DUP_NONEED | ||||||||||||||||||||
needsShadowP[] = 1 | ||||||||||||||||||||
end | ||||||||||||||||||||
end | ||||||||||||||||||||
actv | ||||||||||||||||||||
end | ||||||||||||||||||||
needsPrimal = needsPrimalP[] != 0 | ||||||||||||||||||||
origNeedsPrimal = needsPrimal | ||||||||||||||||||||
_, sret, _ = get_return_info(RealRt) | ||||||||||||||||||||
|
@@ -479,7 +495,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR) | |||||||||||||||||||
return false | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
@inline function aug_fwd_mi(orig, gutils) | ||||||||||||||||||||
@inline function aug_fwd_mi(orig::LLVM.CallInst, gutils::GradientUtils) | ||||||||||||||||||||
width = get_width(gutils) | ||||||||||||||||||||
|
||||||||||||||||||||
# 1) extract out the MI from attributes | ||||||||||||||||||||
|
@@ -568,7 +584,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, | |||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||
|
||||||||||||||||||||
alloctx = LLVM.IRBuilder() | ||||||||||||||||||||
position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) | ||||||||||||||||||||
|
||||||||||||||||||||
|
@@ -949,3 +965,12 @@ function enzyme_custom_rev(B, orig, gutils, tape) | |||||||||||||||||||
enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||
return nothing | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode) | ||||||||||||||||||||
# use default | ||||||||||||||||||||
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils) | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||
return (false, true) | ||||||||||||||||||||
end | ||||||||||||||||||||
# don't use default and always require the arg | ||||||||||||||||||||
return (true, false) | ||||||||||||||||||||
end | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1149,13 +1149,24 @@ | |||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
macro fwdfunc(f) | ||||||||||||||||||||||||||||||
:(@cfunction((B, OrigCI, gutils, normalR, shadowR) -> begin | ||||||||||||||||||||||||||||||
Check warning on line 1152 in src/rules/llvmrules.jl
|
||||||||||||||||||||||||||||||
UInt8($f(LLVM.IRBuilder(B), LLVM.CallInst(OrigCI), GradientUtils(gutils), normalR, shadowR)::Bool) | ||||||||||||||||||||||||||||||
end, UInt8, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}) | ||||||||||||||||||||||||||||||
)) | ||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
macro diffusefunc(f) | ||||||||||||||||||||||||||||||
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin | ||||||||||||||||||||||||||||||
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool} | ||||||||||||||||||||||||||||||
unsafe_store!(useDefault, UInt8(res[2])) | ||||||||||||||||||||||||||||||
UInt8(res[1]) | ||||||||||||||||||||||||||||||
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8}) | ||||||||||||||||||||||||||||||
)) | ||||||||||||||||||||||||||||||
Comment on lines
+1160
to
+1165
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
@noinline function register_llvm_rules() | ||||||||||||||||||||||||||||||
API.EnzymeRegisterDiffUseCallHandler("enzyme_custom", @diffusefunc(enzyme_custom_diffuse)) | ||||||||||||||||||||||||||||||
Check warning on line 1169 in src/rules/llvmrules.jl
|
||||||||||||||||||||||||||||||
register_handler!( | ||||||||||||||||||||||||||||||
("julia.call",), | ||||||||||||||||||||||||||||||
@augfunc(jlcall_augfwd), | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶