Skip to content

Commit

Permalink
Fix broadcast with RowVectors of matrices (#20980)
Browse files Browse the repository at this point in the history
Fix #20979. Amusingly, this bug is a direct result of `transpose` being recursive.
  • Loading branch information
mbauman committed Mar 12, 2017
1 parent f0aedc6 commit 3777aab
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
9 changes: 5 additions & 4 deletions base/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,17 @@ end
@inline check_tail_indices(i1, i2, i3, is...) = i3 == 1 ? check_tail_indices(i1, i2, is...) : false

# helper function for below
@inline to_vec(rowvec::RowVector) = transpose(rowvec)
@inline to_vec(rowvec::RowVector) = map(transpose, transpose(rowvec))
@inline to_vec(x::Number) = x
@inline to_vecs(rowvecs...) = (map(to_vec, rowvecs)...)

# map
@inline map(f, rowvecs::RowVector...) = RowVector(map(f, to_vecs(rowvecs...)...))
# map: Preserve the RowVector by un-wrapping and re-wrapping, but note that `f`
# expects to operate within the transposed domain, so to_vec transposes the elements
@inline map(f, rowvecs::RowVector...) = RowVector(map(transposef, to_vecs(rowvecs...)...))

# broacast (other combinations default to higher-dimensional array)
@inline broadcast(f, rowvecs::Union{Number,RowVector}...) =
RowVector(broadcast(f, to_vecs(rowvecs...)...))
RowVector(broadcast(transposef, to_vecs(rowvecs...)...))

# Horizontal concatenation #

Expand Down
17 changes: 16 additions & 1 deletion test/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,21 @@ end
end
end

@testset "issue #20979" begin
f20979(z::Complex) = [z.re -z.im; z.im z.re]
v = [1+2im]'
@test (f20979.(v))[1] == f20979(v[1])
@test f20979.(v) == f20979.(collect(v))

w = rand(Complex128, 3)
@test f20979.(v') == f20979.(collect(v')) == (f20979.(v))'

g20979(x, y) = [x[2,1] x[1,2]; y[1,2] y[2,1]]
v = [rand(2,2), rand(2,2), rand(2,2)]
@test g20979.(v', v') == g20979.(collect(v'), collect(v')) ==
map(g20979, v', v') == map(g20979, collect(v'), collect(v'))
end

@testset "ambiguity between * methods with RowVectors and ConjRowVectors (#20971)" begin
@test RowVector(ConjArray(ones(4))) * ones(4) == 4
end
end

0 comments on commit 3777aab

Please sign in to comment.