Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
[compat]
Adapt = "0.4, 1.0, 2.0, 3.0"
Atomix = "0.1"
EnzymeCore = "0.5, 0.6"
EnzymeCore = "0.6.4"
MacroTools = "0.5"
PrecompileTools = "1"
Requires = "1.3"
Expand Down
78 changes: 71 additions & 7 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,25 @@ module EnzymeExt
fwd_kernel(f, args...; ndrange, workgroupsize)
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing)
# Adapt `f` so that arguments flagged in `ActiveTys` can be passed by reference.
#
# `ActiveTys` is a tuple of `Bool`s, one per kernel argument (excluding `ctx`).
# When no entry is `true`, `f` is returned unchanged. Otherwise a closure is
# returned that dereferences (`[]`) every flagged argument before forwarding
# the call to `f`; unflagged arguments are passed through as-is.
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
    # Fast path: nothing to unwrap, hand back the original function.
    any(ActiveTys) || return f

    function inact(ctx, args2::Vararg{Any, N}) where N
        unwrapped = ntuple(Val(N)) do idx
            Base.@_inline_meta
            arg = args2[idx]
            # Flagged arguments arrive wrapped in a reference; load the value.
            ActiveTys[idx] ? arg[] : arg
        end
        return f(ctx, unwrapped...)
    end
    return inact
end

function EnzymeRules.augmented_primal(config::Config, func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
kernel = func.val
f = kernel.f

Expand All @@ -49,34 +67,80 @@ module EnzymeExt
# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
f = make_active_byref(f, tup)
FT = Const{Core.Typeof(f)}

arg_refs = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Ref(EnzymeCore.make_zero(args[i].val))
else
nothing
end
end
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
else
args[i]
end
end

# TODO in KA backends like CUDAKernels, etc have a version with a parent job type
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args)...)
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), FT, Const, Const{ctxTy}, map(Core.Typeof, args2)...)


subtape = Array{TapeType}(undef, size(blocks(iterspace)))

aug_kernel = similar(kernel, aug_fwd)

aug_kernel(f, ModifiedBetween, subtape, args...; ndrange, workgroupsize)
aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)

# TODO the fact that ctxTy is type unstable means this is all type unstable.
# Since custom rules require a fixed return type, explicitly cast to Any, rather
# than returning a AugmentedReturn{Nothing, Nothing, T} where T.

res = AugmentedReturn{Nothing, Nothing, Array}(nothing, nothing, subtape)
res = AugmentedReturn{Nothing, Nothing, Tuple{Array, typeof(arg_refs)}}(nothing, nothing, (subtape, arg_refs))
return res
end

function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, subtape, args...; ndrange=nothing, workgroupsize=nothing)
# Reverse-pass rule for a `Const`-wrapped `Kernel` call.
#
# `tape` is destructured into `(subtape, arg_refs)`. Every `Active` argument is
# re-wrapped as `Duplicated(Ref(val), arg_refs[i])`, the kernel function is
# adapted with `make_active_byref`, and the `rev` kernel is launched with the
# rewrapped arguments. The result is one entry per argument: `arg_refs[i][]`
# for `Active` arguments and `nothing` for all others.
function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
# Unpack the tape into the per-block subtape and the per-argument refs.
subtape, arg_refs = tape

# Rebuild the argument list: `Active` values are passed as `Duplicated` refs
# whose shadow is the matching entry of `arg_refs`; everything else is untouched.
args2 = ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
Duplicated(Ref(args[i].val), arg_refs[i])
else
args[i]
end
end

kernel = func.val
f = kernel.f

# Per-argument `Bool` flags (lifted into a `Val`) marking which arguments are `Active`.
tup = Val(ntuple(Val(N)) do i
Base.@_inline_meta
args[i] isa Active
end)
# Wrap `f` so flagged arguments are dereferenced before the original function runs.
f = make_active_byref(f, tup)

# Insert `false` as the second entry of the `overwritten(config)` tuple.
# NOTE(review): presumably this slot corresponds to the kernel context argument — confirm.
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

rev_kernel = similar(func.val, rev)
# NOTE(review): the next two lines look like the pre-change version of the
# launch/return that follows (this text appears to come from a rendered diff).
# As written, this `return` makes everything after it unreachable — confirm
# which pair is intended before relying on this listing.
rev_kernel(f, ModifiedBetween, subtape, args...; ndrange, workgroupsize)
return ((nothing for a in args)...,)
rev_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
# Return the contents of each `Active` argument's ref; `nothing` for the rest.
return ntuple(Val(N)) do i
Base.@_inline_meta
if args[i] isa Active
arg_refs[i][]
else
nothing
end
end
end
end
33 changes: 28 additions & 5 deletions test/extensions/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,51 @@ using KernelAbstractions
@inbounds A[I] *= A[I]
end

function caller(A, backend)
# Instantiate the `square!` kernel for `backend`, launch it over every element
# of `A`, and wait for the backend to finish.
function square_caller(A, backend)
    launch = square!(backend)
    launch(A; ndrange=size(A))
    return KernelAbstractions.synchronize(backend)
end


# Kernel: multiply the element of `A` at this work-item's global linear index by `B`, in place.
@kernel function mul!(A, B)
    idx = @index(Global, Linear)
    @inbounds A[idx] = A[idx] * B
end

# Instantiate the `mul!` kernel for `backend`, launch it with `A` and `B` over
# every element of `A`, and wait for the backend to finish.
function mul_caller(A, B, backend)
    launch = mul!(backend)
    launch(A, B; ndrange=size(A))
    return KernelAbstractions.synchronize(backend)
end

function enzyme_testsuite(backend, ArrayT, supports_reverse=true)
@testset "kernels" begin
A = ArrayT{Float64}(undef, 64)
A .= (1:1:64)
dA = ArrayT{Float64}(undef, 64)
dA .= 1

if supports_reverse
Enzyme.autodiff(Reverse, caller, Duplicated(A, dA), Const(backend()))

A .= (1:1:64)
dA .= 1

Enzyme.autodiff(Reverse, square_caller, Duplicated(A, dA), Const(backend()))
@test all(dA .≈ (2:2:128))


A .= (1:1:64)
dA .= 1

_, dB, _ = Enzyme.autodiff(Reverse, mul_caller, Duplicated(A, dA), Active(1.2), Const(backend()))[1]

@test all(dA .≈ 1.2)
@test dB ≈ sum(1:1:64)
end

A .= (1:1:64)
dA .= 1

Enzyme.autodiff(Forward, caller, Duplicated(A, dA), Const(backend()))
Enzyme.autodiff(Forward, square_caller, Duplicated(A, dA), Const(backend()))
@test all(dA .≈ 2:2:128)

end
Expand Down