diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index 39027a92..ccc78377 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -4,7 +4,9 @@ const CombinedCombinedAnyDims = Tuple{<:CombinedAxis, <:CombinedAxis, Vararg{<:C # Similar Base.similar(x::ComponentArray) = ComponentArray(similar(getdata(x)), getaxes(x)...) -Base.similar(x::ComponentArray, ::Type{T}) where T = ComponentArray(similar(getdata(x), T), getaxes(x)...) +Base.similar(x::ComponentArray, ::Type{T}) where {T} = ComponentArray(similar(getdata(x), T), getaxes(x)...) +Base.similar(x::ComponentArray, dims::Vararg{Int}) = similar(getdata(x), dims...) +Base.similar(x::ComponentArray, ::Type{T}, dims::Vararg{Int}) where {T} = similar(getdata(x), T, dims...) Base.similar(x::AbstractArray, dims::CombinedAnyDims) = _similar(x, dims) Base.similar(x::AbstractArray, dims::AnyCombinedAnyDims) = _similar(x, dims) Base.similar(x::AbstractArray, dims::CombinedCombinedAnyDims) = _similar(x, dims) diff --git a/test/Project.toml b/test/Project.toml index 3d5d8bc1..7ac6831f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,5 @@ [deps] ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" diff --git a/test/gpu_tests.jl b/test/gpu_tests.jl index 5a1a67f9..3188eafd 100644 --- a/test/gpu_tests.jl +++ b/test/gpu_tests.jl @@ -21,4 +21,7 @@ jlca = ComponentArray(jla, Axis(a=1:2, b=3:4)) @test mapreduce(abs2, +, jlca) == 30 @test all(map(sin, jlca) .== sin.(jlca) .== sin.(jla) .≈ sin.(1:4)) + + # Issue #179 + @test similar(jlca, 5) isa typeof(jla) end diff --git a/test/runtests.jl b/test/runtests.jl index c1b58c49..4531fb86 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -344,10 +344,13 @@ end end @testset "Similar" begin - @test typeof(similar(ca)) == typeof(ca) - @test typeof(similar(ca2)) == typeof(ca2) - @test typeof(similar(ca, Float32)) == typeof(ca_Float32) + @test similar(ca) isa typeof(ca) + @test similar(ca2) isa typeof(ca2) + @test similar(ca, Float32) isa typeof(ca_Float32) @test eltype(similar(ca, ForwardDiff.Dual)) == ForwardDiff.Dual + @test similar(ca, 5) isa typeof(getdata(ca)) + @test similar(ca, Float32, 5) isa typeof(getdata(ca_Float32)) + @test similar(cmat, 5, 5) isa typeof(getdata(cmat)) end @testset "Copy" begin