Skip to content

Commit 688c5c3

Browse files
committed
Got rid of mult and div forwarding. Fixes #33.
1 parent 9f55130 commit 688c5c3

File tree

4 files changed

+69
-62
lines changed

4 files changed

+69
-62
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"]
4-
version = "0.6.3"
4+
version = "0.6.4"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/broadcasting.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,6 @@ getdata(x::Adjoint) = getdata(x.parent)'
8888
getdata(x::Transpose) = transpose(getdata(x.parent))
8989

9090

91-
# function Base.similar(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}, args...) where {InnerStyle, Axes, N}
92-
# return ComponentArray{Axes}(similar(BC.Broadcasted{InnerStyle}(bc.f, map(getdata, bc.args), bc.axes), args...))
93-
# end
9491
function Base.similar(bc::BC.Broadcasted{<:CAStyle{InnerStyle, Axes, N}}, args...) where {InnerStyle, Axes, N}
9592
return ComponentArray{Axes}(similar(BC.Broadcasted{InnerStyle}(bc.f, bc.args, bc.axes), args...))
9693
end

src/math.jl

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,6 @@ Base.pointer(x::ComponentArray{T,N,<:DenseArray,Axes}) where {T,N,Axes} = pointe
33

44
Base.unsafe_convert(::Type{Ptr{T}}, x::ComponentArray{T,N,<:DenseArray,Axes}) where {T,N,Axes} = Base.unsafe_convert(Ptr{T}, getdata(x))
55

6-
# Avoid slower fallback
7-
for f in [:(*), :(/), :(\)]
8-
@eval begin
9-
# The normal stuff
10-
Base.$f(x::ComponentArray, y::AbstractArray) = $f(getdata(x), y)
11-
Base.$f(x::AbstractArray, y::ComponentArray) = $f(x, getdata(y))
12-
Base.$f(x::ComponentArray, y::ComponentArray) = $f(getdata(x), getdata(y))
13-
14-
# A bunch of special cases to avoid ambiguous method errors
15-
Base.$f(x::ComponentArray, y::AbstractMatrix) = $f(getdata(x), y)
16-
Base.$f(x::AbstractMatrix, y::ComponentArray) = $f(x, getdata(y))
17-
18-
Base.$f(x::ComponentArray, y::AbstractVector) = $f(getdata(x), y)
19-
Base.$f(x::AbstractVector, y::ComponentArray) = $f(x, getdata(y))
20-
end
21-
end
22-
23-
# Adjoint/transpose special cases
24-
for f in [:(*), :(/)]
25-
@eval begin
26-
Base.$f(x::Adjoint, y::ComponentArray) = $f(getdata(x), getdata(y))
27-
Base.$f(x::Transpose, y::ComponentArray) = $f(getdata(x), getdata(y))
28-
29-
Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector) where T = $f(x, getdata(y))
30-
Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector) where T = $f(x, getdata(y))
31-
32-
Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentMatrix{T,A,Axes}) where {T,A,Axes} = $f(x, getdata(y))
33-
Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentMatrix{T,A,Axes}) where {T,A,Axes} = $f(x, getdata(y))
34-
35-
Base.$f(x::Adjoint{T,<:AbstractMatrix{T}}, y::ComponentVector) where {T} = $f(x, getdata(y))
36-
Base.$f(x::Transpose{T,<:AbstractMatrix{T}}, y::ComponentVector) where {T} = $f(x, getdata(y))
37-
38-
Base.$f(x::ComponentArray, y::Adjoint{T,<:AbstractVector{T}}) where T = $f(getdata(x), y)
39-
Base.$f(x::ComponentArray, y::Transpose{T,<:AbstractVector{T}}) where T = $f(getdata(x), y)
40-
41-
Base.$f(x::ComponentArray, y::Adjoint{T,<:ComponentVector}) where T = $f(getdata(x), getdata(y))
42-
Base.$f(x::ComponentArray, y::Transpose{T,<:ComponentVector}) where T = $f(getdata(x), getdata(y))
43-
44-
# There seems to be a new method in Julia > v.1.4 that specializes on this
45-
Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Number,A,Axes} = $f(x, getdata(y))
46-
Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Number,A,Axes} = $f(x, getdata(y))
47-
48-
Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Real,A,Axes} = $f(getdata(x), getdata(y))
49-
Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Real,A,Axes} = $f(getdata(x), getdata(y))
50-
end
51-
end
526

537
#TODO: All this stuff
548
LinearAlgebra.ldiv!(Y::Union{AbstractMatrix, AbstractVector}, A::Factorization, b::Union{ComponentMatrix, ComponentVector}) =
@@ -69,4 +23,56 @@ ArrayInterface.lu_instance(jac_prototype::ComponentArray) = ArrayInterface.lu_in
6923

7024
## Vector to matrix concatenation
7125
Base.hcat(x::ComponentVector...) = ComponentArray(hcat(getdata.(x)...), getaxes(x[1])[1], FlatAxis())
72-
Base.vcat(x::AdjOrTransComponentArray...) = ComponentArray(vcat(map(y->getdata(y.parent)', x)...), getaxes(x[1]))
26+
Base.vcat(x::AdjOrTransComponentArray...) = ComponentArray(vcat(map(y->getdata(y.parent)', x)...), getaxes(x[1]))
27+
28+
29+
# While there are some cases where these were faster, it is going to be almost impossible to
30+
# to keep up with method ambiguity errors due to other array types overloading *, /, and \.
31+
# Leaving these here and commented out for now, but will delete them later.
32+
33+
# # Avoid slower fallback
34+
# for f in [:(*), :(/), :(\)]
35+
# @eval begin
36+
# # The normal stuff
37+
# Base.$f(x::ComponentArray, y::AbstractArray) = $f(getdata(x), y)
38+
# Base.$f(x::AbstractArray, y::ComponentArray) = $f(x, getdata(y))
39+
# Base.$f(x::ComponentArray, y::ComponentArray) = $f(getdata(x), getdata(y))
40+
41+
# # A bunch of special cases to avoid ambiguous method errors
42+
# Base.$f(x::ComponentArray, y::AbstractMatrix) = $f(getdata(x), y)
43+
# Base.$f(x::AbstractMatrix, y::ComponentArray) = $f(x, getdata(y))
44+
45+
# Base.$f(x::ComponentArray, y::AbstractVector) = $f(getdata(x), y)
46+
# Base.$f(x::AbstractVector, y::ComponentArray) = $f(x, getdata(y))
47+
# end
48+
# end
49+
50+
# # Adjoint/transpose special cases
51+
# for f in [:(*), :(/)]
52+
# @eval begin
53+
# Base.$f(x::Adjoint, y::ComponentArray) = $f(getdata(x), getdata(y))
54+
# Base.$f(x::Transpose, y::ComponentArray) = $f(getdata(x), getdata(y))
55+
56+
# Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector) where T = $f(x, getdata(y))
57+
# Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector) where T = $f(x, getdata(y))
58+
59+
# Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentMatrix{T,A,Axes}) where {T,A,Axes} = $f(x, getdata(y))
60+
# Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentMatrix{T,A,Axes}) where {T,A,Axes} = $f(x, getdata(y))
61+
62+
# Base.$f(x::Adjoint{T,<:AbstractMatrix{T}}, y::ComponentVector) where {T} = $f(x, getdata(y))
63+
# Base.$f(x::Transpose{T,<:AbstractMatrix{T}}, y::ComponentVector) where {T} = $f(x, getdata(y))
64+
65+
# Base.$f(x::ComponentArray, y::Adjoint{T,<:AbstractVector{T}}) where T = $f(getdata(x), y)
66+
# Base.$f(x::ComponentArray, y::Transpose{T,<:AbstractVector{T}}) where T = $f(getdata(x), y)
67+
68+
# Base.$f(x::ComponentArray, y::Adjoint{T,<:ComponentVector}) where T = $f(getdata(x), getdata(y))
69+
# Base.$f(x::ComponentArray, y::Transpose{T,<:ComponentVector}) where T = $f(getdata(x), getdata(y))
70+
71+
# # There seems to be a new method in Julia > v.1.4 that specializes on this
72+
# Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Number,A,Axes} = $f(x, getdata(y))
73+
# Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Number,A,Axes} = $f(x, getdata(y))
74+
75+
# Base.$f(x::Adjoint{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Real,A,Axes} = $f(getdata(x), getdata(y))
76+
# Base.$f(x::Transpose{T,<:AbstractVector{T}}, y::ComponentVector{T,A,Axes}) where {T<:Real,A,Axes} = $f(getdata(x), getdata(y))
77+
# end
78+
# end

test/runtests.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -227,27 +227,26 @@ end
227227
a_t = collect(a')
228228

229229
@test all(zeros(cmat) * ca .== zeros(ca))
230-
@test ones(cmat) * ca isa Vector
230+
@test ones(cmat) * ca isa AbstractVector
231231
@test ca * ca' == collect(cmat)
232232
@test ca * ca' == a * a'
233233
@test ca' * ca == a' * a
234234
@test cmat * ca == cmat * a
235-
@test cmat' * ca isa Array
235+
@test cmat' * ca isa AbstractArray
236236
@test a' * ca isa Number
237237
@test cmat'' == cmat
238238
@test ca'' == ca
239239
@test ca.c' * cmat[:c,:c] * ca.c isa Number
240240
@test ca * 1 isa ComponentVector
241241
@test size(ca' * 1) == size(ca')
242242
@test a' * ca isa Number
243-
@test a_t * ca isa Array
243+
@test a_t * ca isa AbstractArray
244244
@test a' * cmat isa Adjoint
245-
@test a_t * cmat isa Array
246-
@test cmat * ca isa Vector
245+
@test a_t * cmat isa AbstractArray
246+
@test cmat * ca isa AbstractVector
247247
@test ca + ca + ca isa typeof(ca)
248248
@test a + ca + ca isa typeof(ca)
249249
@test a*ca' isa AbstractMatrix
250-
251250

252251
@test ca * transpose(ca) == collect(cmat)
253252
@test ca * transpose(ca) == a * transpose(a)
@@ -265,22 +264,27 @@ end
265264
temp .= (cmat+I) \ ca
266265
@test temp isa ComponentArray
267266
@test (ca' / (cmat'+I))' == (cmat+I) \ ca
268-
@test cmat * ((cmat+I) \ ca) isa Array
269-
@test inv(cmat+I) isa Array
267+
@test cmat * ((cmat+I) \ ca) isa AbstractArray
268+
@test inv(cmat+I) isa AbstractArray
270269

271270
tempmat = deepcopy(cmat)
272271

273272
@test ldiv!(temp, lu(cmat+I), ca) isa ComponentVector
274-
@test ldiv!(getdata(temp), lu(cmat+I), ca) isa Vector
273+
@test ldiv!(getdata(temp), lu(cmat+I), ca) isa AbstractVector
275274
@test ldiv!(tempmat, lu(cmat+I), cmat) isa ComponentMatrix
276-
@test ldiv!(getdata(tempmat), lu(cmat+I), cmat) isa Matrix
275+
@test ldiv!(getdata(tempmat), lu(cmat+I), cmat) isa AbstractMatrix
277276

278277
vca2 = vcat(ca2', ca2')
279278
hca2 = hcat(ca2, ca2)
280279
@test all(vca2[1,:] .== ca2)
281280
@test all(hca2[:,1] .== ca2)
282281
@test all(vca2' .== hca2)
283282
@test hca2[:a,:] == vca2[:,:a]
283+
284+
# Issue #33
285+
smat = @SMatrix [1 2; 3 4]
286+
b = ComponentArray(a = 1, b = 2)
287+
@test smat*b isa StaticArray
284288
end
285289

286290
@testset "Plot Utilities" begin
@@ -304,7 +308,7 @@ end
304308
@test label2index(ca2, "c.b") == collect(11:14)
305309
end
306310

307-
@testset "Issues" begin
311+
@testset "Uncategorized Issues" begin
308312
# Issue #25
309313
@test sum(abs2, cmat) == sum(abs2, getdata(cmat))
310314
end

0 commit comments

Comments
 (0)