Skip to content
Closed
2 changes: 2 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,6 @@ struct ForwardMode <: Mode
end
const Forward = ForwardMode()

include("rules.jl")

end # module EnzymeCore
113 changes: 113 additions & 0 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module EnzymeRules

import EnzymeCore: Annotation
export Config, ConfigWidth
export needs_primal, needs_shadow, width, overwritten

"""
forward(func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...)

Calculate the forward derivative. The first argument `func` is the callable
for which the rule applies to. Either wrapped in a [`Const`](@ref)), or
a [`Duplicated`](@ref) if it is a closure.
The second argument is the return type annotation, and all other arguments are
the annotated function arguments.
"""
function forward end

struct Config{NeedsPrimal, NeedsShadow, Width, Overwritten} end
const ConfigWidth{Width} = Config{<:Any,<:Any, Width}

needs_primal(::Config{NeedsPrimal}) where NeedsPrimal = NeedsPrimal
needs_shadow(::Config{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow
width(::Config{<:Any, <:Any, Width}) where Width = Width
overwritten(::Config{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten

"""
augmented_primal(::Config, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...)

Must return a tuple of length 2.
The first-value is primal value and the second is the tape. If no tape is
required return `(val, nothing)`.
"""
function augmented_primal end

"""
reverse(::Config, func::Annotation{typeof(f)}, dret::Annotation, tape, args::Annotation...)

Takes gradient of derivative, activity annotation, and tape
"""
function reverse end

_annotate(T::DataType) = TypeVar(gensym(), Annotation{T})
_annotate(::Type{T}) where T = TypeVar(gensym(), Annotation{T})
function _annotate(VA::Core.TypeofVararg)
T = _annotate(VA.T)
if isdefined(VA, :N)
return Vararg{T, VA.N}
else
return Vararg{T}
end
end

function has_frule_from_sig(@nospecialize(TT); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
ft = TT.parameters[1]
tt = map(_annotate, TT.parameters[2:end])
TT = Tuple{<:Annotation{ft}, Type{<:Annotation}, tt...}
isapplicable(forward, TT; world)
end

function has_rrule_from_sig(@nospecialize(TT); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
ft = TT.parameters[1]
tt = map(_annotate, TT.parameters[2:end])
TT = Tuple{<:Config, <:Annotation{ft}, <:Annotation, <:Any, tt...}
isapplicable(reverse, TT; world)
end

function has_frule(@nospecialize(f); world=Base.get_world_counter())
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, Vararg{<:Annotation}}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, TT.parameters...}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(RT::Type); world=Base.get_world_counter())
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, Vararg{<:Annotation}}
isapplicable(forward, TT; world)
end

# Do we need this one?
function has_frule(@nospecialize(f), @nospecialize(RT::Type), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter())
TT = Base.unwrap_unionall(TT)
TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, TT.parameters...}
isapplicable(forward, TT; world)
end

# Base.hasmethod is a precise match we want the broader query.
function isapplicable(@nospecialize(f), @nospecialize(TT); world=Base.get_world_counter())
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
return !isempty(Base._methods_by_ftype(sig, -1, world)) # TODO cheaper way of querying?
end

function has_rrule(@nospecialize(TT), world=Base.get_world_counter())
return false
end

function issupported()
@static if VERSION < v"1.7.0"
return false
else
return true
end
end

end # EnzymeRules
3 changes: 2 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export markType, batch_size, onehot, chunkedonehot
using LinearAlgebra
import EnzymeCore: ReverseMode, ForwardMode, Annotation, Mode

import EnzymeCore: EnzymeRules

# Independent code, must be loaded before "compiler.jl"
include("pmap.jl")

Expand Down Expand Up @@ -61,7 +63,6 @@ end
end
end


include("logic.jl")
include("typeanalysis.jl")
include("typetree.jl")
Expand Down
5 changes: 5 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHan

EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T)

EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils)
EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils)
EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val)
EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig)
Expand All @@ -211,7 +212,11 @@ EnzymeGradientUtilsAllocationBlock(gutils) = ccall((:EnzymeGradientUtilsAllocati
EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeGradientUtilsRef,), gutils)

EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val)

EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size)

EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign)

EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}), gutils, orig, needsPrimalP, needsShadowP)

EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme),
Expand Down
Loading