diff --git a/src/Tracing.jl b/src/Tracing.jl index 79e22a6511..45dad6426c 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -446,7 +446,7 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(A::Type{AbstractArray{T}}), + A::Type{AbstractArray{T}}, seen, mode::TraceMode, @nospecialize(track_numbers::Type), @@ -455,7 +455,7 @@ Base.@nospecializeinfer function traced_type_inner( ) where {T} if mode == ConcreteToTraced return AbstractArray{ - traced_type_inner(T, seen, mode, track_numbers, sharding, runtime) + traced_type_inner(eltype(A), seen, mode, track_numbers, sharding, runtime) } else return A @@ -463,7 +463,7 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(A::Type{AbstractArray{T,N}}), + A::Type{AbstractArray{T,N}}, seen, mode::TraceMode, @nospecialize(track_numbers::Type), @@ -472,7 +472,8 @@ Base.@nospecializeinfer function traced_type_inner( ) where {T,N} if mode == ConcreteToTraced return AbstractArray{ - traced_type_inner(T, seen, mode, track_numbers, sharding, runtime),N + traced_type_inner(eltype(A), seen, mode, track_numbers, sharding, runtime), + ndims(A), } else return A