Skip to content

Commit

Permalink
Add differential use handler
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Apr 1, 2024
1 parent 6de7670 commit 239a88c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ const CustomReversePass = Ptr{Cvoid}
EnzymeRegisterCallHandler(name, fwdhandle, revhandle) = ccall((:EnzymeRegisterCallHandler, libEnzyme), Cvoid, (Cstring, CustomAugmentedForwardPass, CustomReversePass), name, fwdhandle, revhandle)
EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHandler, libEnzyme), Cvoid, (Cstring, CustomForwardPass), name, fwdhandle)

const CustomDiffUse = Ptr{Cvoid}
EnzymeRegisterDiffUseCallHandler(name, handle) = ccall((:EnzymeRegisterDiffUseCallHandler, libEnzyme), Cvoid, (Cstring, CustomDiffUse), name, handle)

Check warning on line 234 in src/api.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/api.jl:234:-EnzymeRegisterDiffUseCallHandler(name, handle) = ccall((:EnzymeRegisterDiffUseCallHandler, libEnzyme), Cvoid, (Cstring, CustomDiffUse), name, handle) src/api.jl:235:-EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove) = ccall((:EnzymeSetCalledFunction, libEnzyme), Cvoid, (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64), ci, fn, toremove, length(toremove)) src/api.jl:236:-EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall((:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme), LLVMValueRef, (LLVMValueRef,UInt8,Ptr{Int64}, Int64), fn, keepret, args, length(args)) src/api.jl:237:-EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T) src/api.jl:238:- src/api.jl:239:-EnzymeGradientUtilsReplaceAWithB(gutils, a, b) = ccall((:EnzymeGradientUtilsReplaceAWithB, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMValueRef), gutils, a, b) src/api.jl:240:-EnzymeGradientUtilsErase(gutils, a) = ccall((:EnzymeGradientUtilsErase, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, a) src/api.jl:241:-EnzymeGradientUtilsEraseWithPlaceholder(gutils, a, orig, erase) = ccall((:EnzymeGradientUtilsEraseWithPlaceholder, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, LLVMValueRef, UInt8), gutils, a, orig, erase) src/api.jl:242:-EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils) src/api.jl:243:-EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils) src/api.jl:244:-EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) src/api.jl:245:-EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig) src/api.jl:246:-EnzymeGradientUtilsLookup(gutils, val, B) = ccall((:EnzymeGradientUtilsLookup, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) src/api.jl:247:-EnzymeGradientUtilsInvertPointer(gutils, val, B) = ccall((:EnzymeGradientUtilsInvertPointer, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) src/api.jl:248:-EnzymeGradientUtilsDiffe(gutils, val, B) = ccall((:EnzymeGradientUtilsDiffe, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, B) src/api.jl:249:-EnzymeGradientUtilsAddToDiffe(gutils, val, diffe, B, T) = ccall((:EnzymeGradientUtilsAddToDiffe, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, LLVMTypeRef), gutils, val, diffe, B, T) src/api.jl:250:-function EnzymeGradientUtilsAddToInvertedPointerDiffeTT(gutils, orig, origVal, vd, size, origptr, prediff, B, align, premask) src/api.jl:251:- ccall((:EnzymeGradientUtilsAddToInvertedPointerDiffeTT, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, CTypeTreeRef, Cuint, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef, Cuint, LLVMValueRef), gutils, orig, origVal, vd, size, origptr, prediff, B, align, premask) src/api.jl:252:-end src/api.jl:253:- src/api.jl:254:-EnzymeGradientUtilsSetDiffe(gutils, val, diffe, B) = ccall((:EnzymeGradientUtilsSetDiffe, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef, LLVM.API.LLVMBuilderRef), gutils, val, diffe, B) src/api.jl:255:-EnzymeGradientUtilsIsConstantValue(gutils, val) = ccall((:EnzymeGradientUtilsIsConstantValue, libEnzyme), UInt8, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) src/api.jl:256:-EnzymeGradientUtilsIsConstantInstruction(gutils, val) = ccall((:EnzymeGradientUtilsIsConstantInstruction, libEnzyme), UInt8, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, va
EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove) = ccall((:EnzymeSetCalledFunction, libEnzyme), Cvoid, (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64), ci, fn, toremove, length(toremove))
EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args) = ccall((:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme), LLVMValueRef, (LLVMValueRef,UInt8,Ptr{Int64}, Int64), fn, keepret, args, length(args))
EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T)
Expand Down
11 changes: 10 additions & 1 deletion src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
end

C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten}

alloctx = LLVM.IRBuilder()
position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils)))

Expand Down Expand Up @@ -949,3 +949,12 @@ function enzyme_custom_rev(B, orig, gutils, tape)
enzyme_custom_common_rev(#=forward=#false, B, orig, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape)
return nothing
end

function enzyme_custom_diffuse(orig, gutils, val, isshadow, mode)
# use default
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) && !has_aug_fwd_rule(orig, gutils)
return (false, true)
end
# don't use default and always require the arg
return (true, false)
end
11 changes: 11 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,18 @@ macro fwdfunc(f)
))
end


macro diffusefunc(f)
:(@cfunction((OrigCI, gutils, val, shadow, mode, useDefault) -> begin
res = $f(LLVM.CallInst(OrigCI), GradientUtils(gutils), LLVM.Value(val), shadow != 0, mode)::Tuple{Bool, Bool}
unsafe_store!(useDefault, UInt8(res[2]))
UInt8(res[1])
end, UInt8, (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt8, API.CDerivativeMode, Ptr{UInt8})
))
end

@noinline function register_llvm_rules()
API.EnzymeRegisterDiffUseCallHandler("enzyme_custom", @diffusefunc(enzyme_custom_diffuse))

Check warning on line 1169 in src/rules/llvmrules.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rules/llvmrules.jl:1169:- API.EnzymeRegisterDiffUseCallHandler("enzyme_custom", @diffusefunc(enzyme_custom_diffuse)) src/rules/llvmrules.jl:1170:- register_handler!( src/rules/llvmrules.jl:1171:- ("julia.call",), src/rules/llvmrules.jl:1172:- @augfunc(jlcall_augfwd), src/rules/llvmrules.jl:1173:- @revfunc(jlcall_rev), src/rules/llvmrules.jl:1174:- @fwdfunc(jlcall_fwd), src/rules/llvmrules.jl:1175:- ) src/rules/llvmrules.jl:1176:- register_handler!( src/rules/llvmrules.jl:1177:- ("julia.call2",), src/rules/llvmrules.jl:1178:- @augfunc(jlcall2_augfwd), src/rules/llvmrules.jl:1179:- @revfunc(jlcall2_rev), src/rules/llvmrules.jl:1180:- @fwdfunc(jlcall2_fwd), src/rules/llvmrules.jl:1181:- ) src/rules/llvmrules.jl:1182:- register_handler!( src/rules/llvmrules.jl:1183:- ("jl_apply_generic", "ijl_apply_generic"), src/rules/llvmrules.jl:1184:- @augfunc(generic_augfwd), src/rules/llvmrules.jl:1185:- @revfunc(generic_rev), src/rules/llvmrules.jl:1186:- @fwdfunc(generic_fwd), src/rules/llvmrules.jl:1187:- ) src/rules/llvmrules.jl:1188:- register_handler!( src/rules/llvmrules.jl:1189:- ("jl_invoke", "ijl_invoke", "jl_f_invoke"), src/rules/llvmrules.jl:1190:- @augfunc(invoke_augfwd), src/rules/llvmrules.jl:1191:- @revfunc(invoke_rev), src/rules/llvmrules.jl:1192:- @fwdfunc(invoke_fwd), src/rules/llvmrules.jl:1193:- ) src/rules/llvmrules.jl:1194:- register_handler!( src/rules/llvmrules.jl:1195:- ("jl_f__apply_latest", "jl_f__call_latest"), src/rules/llvmrules.jl:1196:- @augfunc(apply_latest_augfwd), src/rules/llvmrules.jl:1197:- @revfunc(apply_latest_rev), src/rules/llvmrules.jl:1198:- @fwdfunc(apply_latest_fwd), src/rules/llvmrules.jl:1199:- ) src/rules/llvmrules.jl:1200:- register_handler!( src/rules/llvmrules.jl:1201:- ("jl_threadsfor",), src/rules/llvmrules.jl:1202:- @augfunc(threadsfor_augfwd), src/rules/llvmrules.jl:1203:- @revfunc(threadsfor_rev), src/rules/llvmrules.jl:1204:- @fwdfunc(threadsfor_fwd), src/rules/llvmrules.jl:1205:- ) src/rules/llvmrules.jl:1206:- register_handler!( src/rules/llvmrules.jl:1207:- ("jl_pmap",), src/rules/llvmrules.jl:1208:- @augfunc(pmap_augfwd), src/rules/llvmrules.jl:1209:- @revfunc(pmap_rev), src/rules/llvmrules.jl:1210:- @fwdfunc(pmap_fwd), src/rules/llvmrules.jl:1211:- ) src/rules/llvmrules.jl:1212:- register_handler!( src/rules/llvmrules.jl:1213:- ("jl_new_task", "ijl_new_task"), src/rules/llvmrules.jl:1214:- @augfunc(newtask_augfwd), src/rules/llvmrules.jl:1215:- @revfunc(newtask_rev), src/rules/llvmrules.jl:1216:- @fwdfunc(newtask_fwd), src/rules/llvmrules.jl:1217:- ) src/rules/llvmrules.jl:1218:- register_handler!( src/rules/llvmrules.jl:1219:- ("jl_set_task_threadpoolid", "ijl_set_task_threadpoolid"), src/rules/llvmrules.jl:1220:- @augfunc(set_task_tid_augfwd), src/rules/llvmrules.jl:1221:- @revfunc(set_task_tid_rev), src/rules/llvmrules.jl:1222:- @fwdfunc(set_task_tid_fwd), src/rules/llvmrules.jl:1223:- ) src/rules/llvmrules.jl:1224:- register_handler!( src/rules/llvmrules.jl:1225:- ("jl_enq_work",), src/rules/llvmrules.jl:1226:- @augfunc(enq_work_augfwd), src/rules/llvmrules.jl:1227:- @revfunc(enq_work_rev), src/rules/llvmrules.jl:1228:- @fwdfunc(enq_work_fwd) src/rules/llvmrules.jl:1229:- ) src/rules/llvmrules.jl:1230:- register_handler!( src/rules/llvmrules.jl:1231:- ("enzyme_custom",), src/rules/llvmrules.jl:1232:- @augfunc(enzyme_custom_augfwd), src/rules/llvmrules.jl:1233:- @revfunc(enzyme_custom_rev), src/rules/llvmrules.jl:1234:- @fwdfunc(enzyme_custom_fwd) src/rules/llvmrules.jl:1235:- ) src/rules/llvmrules.jl:1236:- register_handler!( src/rules/llvmrules.jl:1237:- ("jl_wait",), src/rules/llvmrules.jl:1238:- @augfunc(wait_augfwd), sr
register_handler!(
("julia.call",),
@augfunc(jlcall_augfwd),
Expand Down

0 comments on commit 239a88c

Please sign in to comment.