Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 168 additions & 135 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,128 +7,7 @@
NoStopTracedTrack = 6
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type)
)
if T === Any
return T
end

if T === Union{}
return T
end

if Enzyme.Compiler.isghostty(T) || Core.Compiler.isconstType(T)
return T
end

if T == Type || T == DataType
return T
end

# unknown number of fields
if T isa UnionAll
aT = Base.argument_datatype(T)
if isnothing(aT)
throw(TracedTypeError("Unhandled type $T"))
end
if isnothing(Base.datatype_fieldcount(aT))
throw(TracedTypeError("Unhandled type $T"))
end
end

if T isa Union
return Union{
traced_type_inner(T.a, seen, mode, track_numbers),
traced_type_inner(T.b, seen, mode, track_numbers),
}
end

# if abstract it must be by reference
if Base.isabstracttype(T)
if !(T isa UnionAll) && length(T.parameters) == 0
return T
end
throw(TracedTypeError("Unhandled abstract type $T"))
end

if !(Base.isconcretetype(T) || T isa UnionAll)
throw(AssertionError("Type $T is not concrete type or concrete tuple"))
end

if haskey(seen, T)
return seen[T]
end

seen2 = copy(seen)
seen2[T] = T

changed = false
subTys = Type[]
for f in 1:fieldcount(T)
subT = fieldtype(T, f)
subTT = traced_type_inner(subT, seen2, mode, track_numbers)
changed |= subT != subTT
push!(subTys, subTT)
end

if !changed
for (k, v) in seen2
seen[k] = v
end
return T
end

wrapped_carray = T <: AbstractArray && ancestor(T) <: ConcreteRArray
wrapped_tracedarray = T <: AbstractArray && ancestor(T) <: TracedRArray

subParms = []
for (i, SST) in enumerate(T.parameters)
if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive
TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers)
push!(subParms, TrT)
elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber
TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers)
push!(subParms, TrT)
else
if SST isa Type
TrT = traced_type_inner(SST, seen, mode, track_numbers)
push!(subParms, TrT)
else
push!(subParms, SST)
end
end
end

if !isempty(subParms)
TT2 = Core.apply_type(T.name.wrapper, subParms...)
else
TT2 = T
end
seen3 = copy(seen)
seen3[T] = TT2
if fieldcount(T) == fieldcount(TT2)
legal = true
for f in 1:fieldcount(T)
subT = fieldtype(T, f)
subT2 = fieldtype(TT2, f)
subTT = traced_type_inner(subT, seen3, mode, track_numbers)
if subT2 != subTT
legal = false
break
end
end
if legal
for (k, v) in seen3
seen[k] = v
end
return TT2
end
end

name = Symbol[]
throw(NoFieldMatchError(T, TT2))
end
function traced_type_inner end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{Union{}}),
Expand Down Expand Up @@ -216,28 +95,56 @@ Base.@nospecializeinfer function traced_type_inner(
return Core.apply_type(T.name.wrapper, traced_fieldtypes...)
end

@inline is_concrete_tuple(x::T2) where {T2} =
(x <: Tuple) && !(x === Tuple) && !(x isa UnionAll)

Base.@nospecializeinfer function traced_type_inner(
Base.@nospecializeinfer function traced_tuple_type_inner(
@nospecialize(T::Type{<:Tuple}),
seen,
mode::TraceMode,
@nospecialize(track_numbers::Type)
)
if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll
if T === Tuple
return T
end
if T isa UnionAll
if T.var.lb === Union{} && T.var.ub === Any
return UnionAll(T.var, traced_type_inner(T.body, seen, mode, track_numbers))
end
throw(AssertionError("Type $T is not concrete type or concrete tuple"))
elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters)
# Tuple{((T2 isa Core.TypeofVararg ? Any : T2) for T2 in T.parameters)...}
throw(AssertionError("Type tuple of vararg $T is not supported"))
end
TT = [
traced_type_inner(T.parameters[i], seen, mode, track_numbers) for
i in 1:length(T.parameters)
]
end
TT = Union{Type,Core.TypeofVararg}[]
for i in 1:length(T.parameters)
st = traced_type_inner(T.parameters[i], seen, mode, track_numbers)
push!(TT, st)
end
return Tuple{TT...}
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Core.TypeofVararg),
seen,
mode::TraceMode,
@nospecialize(track_numbers::Type)
)
return Vararg{traced_type_inner(T.T, seen, mode, track_numbers),T.N}
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::TypeVar), seen, mode::TraceMode, @nospecialize(track_numbers::Type)
)
if T.lb === Union{} && T.ub === Any
return T
end
throw(AssertionError("Unsupported Typevar $T lb=$(T.lb) ub=$(T.ub)"))
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:Tuple}),
seen,
mode::TraceMode,
@nospecialize(track_numbers::Type)
)
return traced_tuple_type_inner(T, seen, mode, track_numbers)
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:NamedTuple}),
seen,
Expand Down Expand Up @@ -427,6 +334,132 @@ Base.@nospecializeinfer function traced_type_inner(
throw("Val type $(Val{T}) cannot be traced")
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type)
)
if T === Any
return T
end

if T === Union{}
return T
end

if Enzyme.Compiler.isghostty(T) || Core.Compiler.isconstType(T)
return T
end

if T == Type || T == DataType
return T
end

# unknown number of fields
if Base.inferencebarrier(T) isa UnionAll
if T.var.lb === Union{} && T.var.ub === Any
return UnionAll(T.var, traced_type_inner(T.body, seen, mode, track_numbers))
end
aT = Base.argument_datatype(T)
if isnothing(aT)
throw(TracedTypeError("Unhandled type $T"))
end
if isnothing(Base.datatype_fieldcount(aT))
throw(TracedTypeError("Unhandled type $T"))
end
end

if T isa Union
return Union{
traced_type_inner(T.a, seen, mode, track_numbers),
traced_type_inner(T.b, seen, mode, track_numbers),
}
end

# if abstract it must be by reference
if Base.isabstracttype(T)
if !(T isa UnionAll) && length(T.parameters) == 0
return T
end
throw(TracedTypeError("Unhandled abstract type $T"))
end

if T <: Tuple
return traced_tuple_type_inner(T, seen, mode, track_numbers)
end

if haskey(seen, T)
return seen[T]
end

seen2 = copy(seen)
seen2[T] = T

changed = false
subTys = Union{Type,TypeVar}[]
for f in 1:fieldcount(T)
subT = fieldtype(T, f)
subTT = traced_type_inner(subT, seen2, mode, track_numbers)
changed |= subT != subTT
push!(subTys, subTT)
end

if !changed
for (k, v) in seen2
seen[k] = v
end
return T
end

wrapped_carray = T <: AbstractArray && ancestor(T) <: ConcreteRArray
wrapped_tracedarray = T <: AbstractArray && ancestor(T) <: TracedRArray

subParms = []
for (i, SST) in enumerate(T.parameters)
if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive
TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers)
push!(subParms, TrT)
elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber
TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers)
push!(subParms, TrT)
else
if SST isa Type
TrT = traced_type_inner(SST, seen, mode, track_numbers)
push!(subParms, TrT)
else
push!(subParms, SST)
end
end
end

if !isempty(subParms)
TT2 = Core.apply_type(T.name.wrapper, subParms...)
else
TT2 = T
end
seen3 = copy(seen)
seen3[T] = TT2
if fieldcount(T) == fieldcount(TT2)
legal = true
for f in 1:fieldcount(T)
subT = fieldtype(T, f)
subT2 = fieldtype(TT2, f)
subTT = traced_type_inner(subT, seen3, mode, track_numbers)
if subT2 != subTT
legal = false
break
end
end
if legal
for (k, v) in seen3
seen[k] = v
end
return TT2
end
end

name = Symbol[]
throw(NoFieldMatchError(T, TT2))
end

const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}()

# function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type))
Expand Down
29 changes: 29 additions & 0 deletions test/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ using Reactant:
ReactantPrimitive
using Test

struct Wrapper{A,B}
a::A
b::B
end

@testset "Tracing" begin
@testset "trace_type" begin
@testset "mode = ConcreteToTraced" begin
Expand Down Expand Up @@ -138,6 +143,30 @@ using Test
Base.Pairs{Symbol,Union{}},
Base.Pairs{Symbol,Union{}},
),
(
NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D},
NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D},
NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D},
),
(Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}),
(
Wrapper{Float64,Vector{Float64}},
Wrapper{Float64,Vector{Float64}},
Wrapper{TracedRNumber{Float64},Vector{Float64}},
),
(
Wrapper{Float64,ConcreteRArray{Float64,1}},
Wrapper{Float64,TracedRArray{Float64,1}},
Wrapper{TracedRNumber{Float64},TracedRArray{Float64,1}},
),
(Wrapper{Symbol}, Wrapper{Symbol}, Wrapper{Symbol}),
(Wrapper{Float64}, Wrapper{Float64}, Wrapper{TracedRNumber{Float64}}),
(
Wrapper{ConcreteRArray{Float64,1}},
Wrapper{TracedRArray{Float64,1}},
Wrapper{TracedRArray{Float64,1}},
),
(Wrapper, Wrapper, Wrapper),
]
tracedty = traced_type(origty, Val(ConcreteToTraced), Union{})
@test tracedty == targetty
Expand Down
Loading