Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nice union{} error #1479

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
end

if A <: Active
if !allocatedinline(rt) || rt isa Union
if (!allocatedinline(rt) || rt isa Union) && rt != Union{}
forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI)
res = forward(f, args...)
tape = res[1]
Expand All @@ -244,7 +244,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
throw(ErrorException("Duplicated Returns not yet handled"))
end

if A <: Active && rt <: Complex
if (A <: Active && rt <: Complex) && rt != Union{}
if Holomorphic
seen = IdDict()
seen2 = IdDict()
Expand Down
79 changes: 55 additions & 24 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,10 @@ struct AdjointThunk{PT, FA, RT, TT, Width, TapeType} <: AbstractThunk{FA, RT, TT
adjoint::PT
end

struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World} <: AbstractThunk{FA, RT, TT, Width}
adjoint::PT
end

@inline return_type(::AbstractThunk{FA, RT}) where {FA, RT} = RT
@inline return_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = RT

Expand Down Expand Up @@ -5277,7 +5281,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
cf = LLVM.called_operand(tmp)
if isa(cf, LLVM.Function)
nm = LLVM.name(cf)
if nm == "gpu_signal_exception" || nm == "gpu_report_exception"
if nm == "gpu_signal_exception" || nm == "gpu_report_exception" || nm == "ijl_throw" || nm == "jl_throw"
shouldemit = false
break
end
Expand Down Expand Up @@ -5433,6 +5437,9 @@ struct CompileResult{AT, PT}
TapeType::Type
end

@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal, World} =
enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...)

@inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} =
enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...)

Expand Down Expand Up @@ -5536,7 +5543,9 @@ end
end

@inline function default_adjoint(T)
if T <: AbstractFloat
if T == Union{}
return nothing
elseif T <: AbstractFloat
return one(T)
elseif T <: Complex
error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff")
Expand All @@ -5559,7 +5568,7 @@ end

JuliaContext() do ctx
F = eltype(FA)
is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk
is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk
is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk
is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk
needs_tape = CC <: AdjointThunk
Expand All @@ -5569,32 +5578,44 @@ end
argtypes = DataType[argtt.parameters...]
argexprs = Union{Expr, Symbol}[:(args[$i]) for i in 1:N]

if !RawCall
if false && CC <: PrimalErrorThunk
primargs = [quote
convert($(eltype(T)), $(argexprs[i]).val)
end for (i, T) in enumerate(argtypes)]
return quote
fn.val($(primargs...))
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
end
end

if !RawCall && !(CC <: PrimalErrorThunk)
if rettype <: Active
if length(argtypes) + is_adjoint + needs_tape != length(argexprs)
return quote
throw(MethodError($CC($fptr), $args))
throw(MethodError($CC(fptr), $args))
end
end
elseif rettype <: Const
if length(argtypes) + needs_tape != length(argexprs)
return quote
throw(MethodError($CC($fptr), $args))
throw(MethodError($CC(fptr), $args))
end
end
else
if length(argtypes) + needs_tape != length(argexprs)
return quote
throw(MethodError($CC($fptr), $args))
throw(MethodError($CC(fptr), $args))
end
end
end
end

types = DataType[]

if eltype(rettype) === Union{}
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
if eltype(rettype) === Union{} && false
return quote
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
end
end
if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType)
rrt = eltype(rettype)
Expand Down Expand Up @@ -5665,7 +5686,9 @@ end
end
continue
end

if CC <: PrimalErrorThunk
continue
end
if T <: Active
if is_adjoint
if width == 1
Expand Down Expand Up @@ -5752,8 +5775,10 @@ end
end
push!(sret_types, NT)
end

@assert i == length(argexprs)+1

if !(CC <: PrimalErrorThunk)
@assert i == length(argexprs)+1
end

# Tape
if CC <: AugmentedForwardThunk
Expand Down Expand Up @@ -5785,7 +5810,7 @@ end

T_void = convert(LLVMType, Nothing)

combinedReturn = Tuple{sret_types...}
combinedReturn = (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? Union{} : Tuple{sret_types...}
if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types)
combinedReturn = AnonymousStruct(combinedReturn)
end
Expand Down Expand Up @@ -6003,29 +6028,30 @@ end
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI)
tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

sig = Tuple{eltype(FA), map(eltype, TT.parameters)...}

interp = GPUCompiler.get_interpreter(tmp_job)

# TODO check compile return here, early
# rrt = Core.Compiler.return_type(f, primal.tt) # nothing
rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any)
rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype

run_enzyme = true

if rrt == Union{}
estr = "Function to differentiate `$mi` is guaranteed to return an error and doesn't make sense to autodiff. Giving up"
return quote
error($estr)
end
run_enzyme = false
A = Const
end

if !(A <: Const) && guaranteed_const_nongen(rrt, World)
if run_enzyme && !(A <: Const) && guaranteed_const_nongen(rrt, World)
estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant"
return quote
error($estr)
end
end

rt2 = if A isa UnionAll
rt2 = if !run_enzyme
Const{rrt}
elseif A isa UnionAll
A{rrt}
else
@assert A isa DataType
Expand All @@ -6034,7 +6060,7 @@ end
A
end

params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI)
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

# We need to use primal as the key, to lookup the right method
Expand All @@ -6045,7 +6071,13 @@ end


compile_result = cached_compilation(job)
if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
if !run_enzyme
ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World}
return quote
Base.@_inline_meta
$ErrT($(compile_result.adjoint))
end
elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
TapeType = compile_result.TapeType
AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType}
AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType}
Expand Down Expand Up @@ -6086,7 +6118,6 @@ import GPUCompiler: deferred_codegen_jobs
params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI)
tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

sig = Tuple{eltype(FA), map(eltype, TT.parameters)...}
interp = GPUCompiler.get_interpreter(tmp_job)

rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any)
Expand Down
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2602,6 +2602,15 @@ end
@test 2.0 ≈ Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), Const(true))[1][1]
end


function assured_err(x)
throw(AssertionError("foo"))
end

@testset "UnionAll" begin
@test_throws AssertionError Enzyme.autodiff(Reverse, assured_err, Active, Active(2.0))
end

struct MyFlux
end

Expand Down
Loading