diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 0a4e218e0a..08d962a58f 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -299,7 +299,8 @@ function __lookup_unique_name_in_module(mod, name) new_name = i == 0 ? name : name * "_" * string(i) MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, new_name)) && return new_name end - return error("Could not find unique name for $name") + modstr = string(mod) + return error("Mod\n$modstr\nCould not find unique name for $name") end function __take_region(compiled_fn) diff --git a/src/Tracing.jl b/src/Tracing.jl index a439aa93f7..9ffa30aaf4 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -291,6 +291,41 @@ Base.@nospecializeinfer function traced_type_inner( throw("XLA $T array cannot be traced") end +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(A::Type{AbstractArray}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) + return A +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(A::Type{AbstractArray{T}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) where {T} + if mode == ConcreteToTraced + return AbstractArray{TracedRNumber{T}} + else + return A + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(A::Type{AbstractArray{T,N}}), + seen, + mode::TraceMode, + @nospecialize(track_numbers::Type) +) where {T,N} + if mode == ConcreteToTraced + return AbstractArray{TracedRNumber{T},N} + else + return A + end +end + Base.@nospecializeinfer function traced_type_inner( @nospecialize(A::Type{<:Array}), seen, @@ -298,11 +333,19 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(track_numbers::Type) ) T = eltype(A) - N = ndims(A) - if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive - return ConcreteRArray{T,N} + if A isa UnionAll + if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive + return ConcreteRArray{T} + else + return Array{traced_type_inner(T, seen, mode, track_numbers)} + end else - return Array{traced_type_inner(T, seen, mode, track_numbers),N} + N = ndims(A) + if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive + return ConcreteRArray{T,N} + else + return Array{traced_type_inner(T, seen, mode, track_numbers),N} + end end end @@ -365,6 +408,7 @@ Base.@nospecializeinfer function traced_type_inner( if isnothing(Base.datatype_fieldcount(aT)) throw(TracedTypeError("Unhandled type $T")) end + return T end if T isa Union @@ -457,7 +501,7 @@ Base.@nospecializeinfer function traced_type_inner( end name = Symbol[] - throw(NoFieldMatchError(T, TT2)) + throw(NoFieldMatchError(T, TT2, subTys)) end const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}() @@ -580,13 +624,18 @@ end struct NoFieldMatchError <: TracedTypeException origty besteffort + subTys end function Base.showerror(io::IO, err::NoFieldMatchError) - print(io, "NoFieldMatchError: ") - return print( + println(io, "NoFieldMatchError: ") + println( io, "Cannot convert type $(err.origty), best attempt $(err.besteffort) failed.\nThis could be because the type does not capture the fieldtypes that should be converted in its type parameters.", ) + for (i, subty) in zip(1:fieldcount(err.origty), err.subTys) + origty = fieldtype(err.origty, i) + println(io, "idx=", i, " Derived: ", subty, " Existing: ", origty) + end end function make_tracer(