Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ function __lookup_unique_name_in_module(mod, name)
new_name = i == 0 ? name : name * "_" * string(i)
MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, new_name)) && return new_name
end
return error("Could not find unique name for $name")
modstr = string(mod)
return error("Mod\n$modstr\nCould not find unique name for $name")
end

function __take_region(compiled_fn)
Expand Down
63 changes: 56 additions & 7 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,18 +291,61 @@ Base.@nospecializeinfer function traced_type_inner(
throw("XLA $T array cannot be traced")
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(A::Type{AbstractArray}),
seen,
mode::TraceMode,
@nospecialize(track_numbers::Type)
)
return A
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(A::Type{AbstractArray{T}}),
seen,
mode::TraceMode,
@nospecialize(track_numbers::Type)
) where {T}
if mode == ConcreteToTraced
return AbstractArray{TracedRNumber{T}}
else
return A
end
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(A::Type{AbstractArray{T,N}}),
seen,
mode::TraceMode,
@nospecialize(track_numbers::Type)
) where {T,N}
if mode == ConcreteToTraced
return AbstractArray{TracedRNumber{T},N}
else
return A
end
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(A::Type{<:Array}),
seen,
mode::TraceMode,
@nospecialize(track_numbers::Type)
)
T = eltype(A)
N = ndims(A)
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
return ConcreteRArray{T,N}
if A isa UnionAll
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
return ConcreteRArray{T}
else
return Array{traced_type_inner(T, seen, mode, track_numbers)}
end
else
return Array{traced_type_inner(T, seen, mode, track_numbers),N}
N = ndims(A)
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
return ConcreteRArray{T,N}
else
return Array{traced_type_inner(T, seen, mode, track_numbers),N}
end
end
end

Expand Down Expand Up @@ -365,6 +408,7 @@ Base.@nospecializeinfer function traced_type_inner(
if isnothing(Base.datatype_fieldcount(aT))
throw(TracedTypeError("Unhandled type $T"))
end
return T
end

if T isa Union
Expand Down Expand Up @@ -457,7 +501,7 @@ Base.@nospecializeinfer function traced_type_inner(
end

name = Symbol[]
throw(NoFieldMatchError(T, TT2))
throw(NoFieldMatchError(T, TT2, subTys))
end

const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}()
Expand Down Expand Up @@ -580,13 +624,18 @@ end
struct NoFieldMatchError <: TracedTypeException
origty
besteffort
subTys
end
function Base.showerror(io::IO, err::NoFieldMatchError)
print(io, "NoFieldMatchError: ")
return print(
println(io, "NoFieldMatchError: ")
println(
io,
"Cannot convert type $(err.origty), best attempt $(err.besteffort) failed.\nThis could be because the type does not capture the fieldtypes that should be converted in its type parameters.",
)
for (i, subty) in zip(1:fieldcount(err.origty), err.subTys)
origty = fieldtype(err.origty, i)
println(io, "idx=", i, " Derived: ", subty, " Existing: ", origty)
end
end

function make_tracer(
Expand Down
Loading