diff --git a/src/Tracing.jl b/src/Tracing.jl index 9616b01306..a439aa93f7 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -7,128 +7,7 @@ NoStopTracedTrack = 6 end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type) -) - if T === Any - return T - end - - if T === Union{} - return T - end - - if Enzyme.Compiler.isghostty(T) || Core.Compiler.isconstType(T) - return T - end - - if T == Type || T == DataType - return T - end - - # unknown number of fields - if T isa UnionAll - aT = Base.argument_datatype(T) - if isnothing(aT) - throw(TracedTypeError("Unhandled type $T")) - end - if isnothing(Base.datatype_fieldcount(aT)) - throw(TracedTypeError("Unhandled type $T")) - end - end - - if T isa Union - return Union{ - traced_type_inner(T.a, seen, mode, track_numbers), - traced_type_inner(T.b, seen, mode, track_numbers), - } - end - - # if abstract it must be by reference - if Base.isabstracttype(T) - if !(T isa UnionAll) && length(T.parameters) == 0 - return T - end - throw(TracedTypeError("Unhandled abstract type $T")) - end - - if !(Base.isconcretetype(T) || T isa UnionAll) - throw(AssertionError("Type $T is not concrete type or concrete tuple")) - end - - if haskey(seen, T) - return seen[T] - end - - seen2 = copy(seen) - seen2[T] = T - - changed = false - subTys = Type[] - for f in 1:fieldcount(T) - subT = fieldtype(T, f) - subTT = traced_type_inner(subT, seen2, mode, track_numbers) - changed |= subT != subTT - push!(subTys, subTT) - end - - if !changed - for (k, v) in seen2 - seen[k] = v - end - return T - end - - wrapped_carray = T <: AbstractArray && ancestor(T) <: ConcreteRArray - wrapped_tracedarray = T <: AbstractArray && ancestor(T) <: TracedRArray - - subParms = [] - for (i, SST) in enumerate(T.parameters) - if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive - TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers) - push!(subParms, TrT) - elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber - TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers) - push!(subParms, TrT) - else - if SST isa Type - TrT = traced_type_inner(SST, seen, mode, track_numbers) - push!(subParms, TrT) - else - push!(subParms, SST) - end - end - end - - if !isempty(subParms) - TT2 = Core.apply_type(T.name.wrapper, subParms...) - else - TT2 = T - end - seen3 = copy(seen) - seen3[T] = TT2 - if fieldcount(T) == fieldcount(TT2) - legal = true - for f in 1:fieldcount(T) - subT = fieldtype(T, f) - subT2 = fieldtype(TT2, f) - subTT = traced_type_inner(subT, seen3, mode, track_numbers) - if subT2 != subTT - legal = false - break - end - end - if legal - for (k, v) in seen3 - seen[k] = v - end - return TT2 - end - end - - name = Symbol[] - throw(NoFieldMatchError(T, TT2)) -end +function traced_type_inner end Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type{Union{}}), @@ -216,28 +95,56 @@ Base.@nospecializeinfer function traced_type_inner( return Core.apply_type(T.name.wrapper, traced_fieldtypes...) end -@inline is_concrete_tuple(x::T2) where {T2} = - (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) - -Base.@nospecializeinfer function traced_type_inner( +Base.@nospecializeinfer function traced_tuple_type_inner( @nospecialize(T::Type{<:Tuple}), seen, mode::TraceMode, @nospecialize(track_numbers::Type) ) - if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll + if T === Tuple + return T + end + if T isa UnionAll + if T.var.lb === Union{} && T.var.ub === Any + return UnionAll(T.var, traced_type_inner(T.body, seen, mode, track_numbers)) + end throw(AssertionError("Type $T is not concrete type or concrete tuple")) - elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) - # Tuple{((T2 isa Core.TypeofVararg ? Any : T2) for T2 in T.parameters)...} - throw(AssertionError("Type tuple of vararg $T is not supported")) - end - TT = [ - traced_type_inner(T.parameters[i], seen, mode, track_numbers) for - i in 1:length(T.parameters) - ] + end + TT = Union{Type,Core.TypeofVararg}[] + for i in 1:length(T.parameters) + st = traced_type_inner(T.parameters[i], seen, mode, track_numbers) + push!(TT, st) + end return Tuple{TT...} end +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Core.TypeofVararg), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) + return Vararg{traced_type_inner(T.T, seen, mode, track_numbers),T.N} +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::TypeVar), seen, mode::TraceMode, @nospecialize(track_numbers::Type) +) + if T.lb === Union{} && T.ub === Any + return T + end + throw(AssertionError("Unsupported Typevar $T lb=$(T.lb) ub=$(T.ub)")) +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Tuple}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) + return traced_tuple_type_inner(T, seen, mode, track_numbers) +end + Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type{<:NamedTuple}), seen, @@ -427,6 +334,132 @@ Base.@nospecializeinfer function traced_type_inner( throw("Val type $(Val{T}) cannot be traced") end +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type) +) + if T === Any + return T + end + + if T === Union{} + return T + end + + if Enzyme.Compiler.isghostty(T) || Core.Compiler.isconstType(T) + return T + end + + if T == Type || T == DataType + return T + end + + # unknown number of fields + if Base.inferencebarrier(T) isa UnionAll + if T.var.lb === Union{} && T.var.ub === Any + return UnionAll(T.var, traced_type_inner(T.body, seen, mode, track_numbers)) + end + aT = Base.argument_datatype(T) + if isnothing(aT) + throw(TracedTypeError("Unhandled type $T")) + end + if isnothing(Base.datatype_fieldcount(aT)) + throw(TracedTypeError("Unhandled type $T")) + end + end + + if T isa Union + return Union{ + traced_type_inner(T.a, seen, mode, track_numbers), + traced_type_inner(T.b, seen, mode, track_numbers), + } + end + + # if abstract it must be by reference + if Base.isabstracttype(T) + if !(T isa UnionAll) && length(T.parameters) == 0 + return T + end + throw(TracedTypeError("Unhandled abstract type $T")) + end + + if T <: Tuple + return traced_tuple_type_inner(T, seen, mode, track_numbers) + end + + if haskey(seen, T) + return seen[T] + end + + seen2 = copy(seen) + seen2[T] = T + + changed = false + subTys = Union{Type,TypeVar}[] + for f in 1:fieldcount(T) + subT = fieldtype(T, f) + subTT = traced_type_inner(subT, seen2, mode, track_numbers) + changed |= subT != subTT + push!(subTys, subTT) + end + + if !changed + for (k, v) in seen2 + seen[k] = v + end + return T + end + + wrapped_carray = T <: AbstractArray && ancestor(T) <: ConcreteRArray + wrapped_tracedarray = T <: AbstractArray && ancestor(T) <: TracedRArray + + subParms = [] + for (i, SST) in enumerate(T.parameters) + if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive + TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers) + push!(subParms, TrT) + elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber + TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers) + push!(subParms, TrT) + else + if SST isa Type + TrT = traced_type_inner(SST, seen, mode, track_numbers) + push!(subParms, TrT) + else + push!(subParms, SST) + end + end + end + + if !isempty(subParms) + TT2 = Core.apply_type(T.name.wrapper, subParms...) + else + TT2 = T + end + seen3 = copy(seen) + seen3[T] = TT2 + if fieldcount(T) == fieldcount(TT2) + legal = true + for f in 1:fieldcount(T) + subT = fieldtype(T, f) + subT2 = fieldtype(TT2, f) + subTT = traced_type_inner(subT, seen3, mode, track_numbers) + if subT2 != subTT + legal = false + break + end + end + if legal + for (k, v) in seen3 + seen[k] = v + end + return TT2 + end + end + + name = Symbol[] + throw(NoFieldMatchError(T, TT2)) +end + const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}() # function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type)) diff --git a/test/tracing.jl b/test/tracing.jl index 342c2d4fc3..98bf1abcef 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -11,6 +11,11 @@ using Reactant: ReactantPrimitive using Test +struct Wrapper{A,B} + a::A + b::B +end + @testset "Tracing" begin @testset "trace_type" begin @testset "mode = ConcreteToTraced" begin @@ -138,6 +143,30 @@ using Test Base.Pairs{Symbol,Union{}}, Base.Pairs{Symbol,Union{}}, ), + ( + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + ), + (Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}), + ( + Wrapper{Float64,Vector{Float64}}, + Wrapper{Float64,Vector{Float64}}, + Wrapper{TracedRNumber{Float64},Vector{Float64}}, + ), + ( + Wrapper{Float64,ConcreteRArray{Float64,1}}, + Wrapper{Float64,TracedRArray{Float64,1}}, + Wrapper{TracedRNumber{Float64},TracedRArray{Float64,1}}, + ), + (Wrapper{Symbol}, Wrapper{Symbol}, Wrapper{Symbol}), + (Wrapper{Float64}, Wrapper{Float64}, Wrapper{TracedRNumber{Float64}}), + ( + Wrapper{ConcreteRArray{Float64,1}}, + Wrapper{TracedRArray{Float64,1}}, + Wrapper{TracedRArray{Float64,1}}, + ), + (Wrapper, Wrapper, Wrapper), ] tracedty = traced_type(origty, Val(ConcreteToTraced), Union{}) @test tracedty == targetty