Skip to content

Commit

Permalink
Merge pull request #40 from JuliaOpt/bl/promote_adjoint_transpose
Browse files Browse the repository at this point in the history
Fix promotion with adjoint and transpose
  • Loading branch information
blegat committed Feb 21, 2020
2 parents 82a6d9b + 739bcc5 commit 57cf8d4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
40 changes: 27 additions & 13 deletions src/Test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,35 @@ function matrix_vector_division_test(x)
end

function _xAx_test(x::AbstractVector, A::AbstractMatrix)
@test_rewrite(x' * A)
# Complex expression
@test_rewrite(x' * ones(Int, size(A)...))
@test_rewrite(x' * A * x)
# Complex expression
@test_rewrite(x' * ones(Int, size(A)...) * x)
@test_rewrite reshape(x, (1, length(x))) * A * x .- 1
@test_rewrite x' * A * x .- 1
@test_rewrite x' * A * x - 1
for t in [transpose, adjoint]
@test_rewrite(t(x) * A)
# Complex expression
@test_rewrite(t(x) * ones(Int, size(A)...))
@test_rewrite(t(x) * A * x)
# Complex expression
@test_rewrite(t(x) * ones(Int, size(A)...) * x)
@test_rewrite reshape(x, (1, length(x))) * A * x .- 1
@test_rewrite t(x) * A * x .- 1
@test_rewrite t(x) * A * x - 1
@test_rewrite t(x) * x + t(x) * A * x
@test_rewrite t(x) * x - t(x) * A * x
@test MA.promote_operation(*, typeof(t(x)), typeof(A), typeof(x)) == typeof(t(x) * A * x)
@test MA.promote_operation(*, typeof(t(x)), typeof(x)) == typeof(t(x) * x)
@test_rewrite t(x) * x + 2 * t(x) * A * x
@test_rewrite t(x) * x - 2 * t(x) * A * x
@test_rewrite t(x) * A * x + 2 * t(x) * x
@test_rewrite t(x) * A * x - 2 * t(x) * x
@test MA.promote_operation(*, Int, typeof(t(x)), typeof(A), typeof(x)) == typeof(2 * t(x) * A * x)
@test MA.promote_operation(*, Int, typeof(t(x)), typeof(x)) == typeof(2 * t(x) * x)
end
end
function _xABx_test(x::AbstractVector, A::AbstractMatrix, B::AbstractMatrix)
@test_rewrite (x'A)' + 2B * x
@test_rewrite (x'A)' + 2B * x .- 1
@test_rewrite (x'A)' + 2B * x .- [length(x):-1:1;]
@test_rewrite (x'A)' + 2B * x - [length(x):-1:1;]
for t in [transpose, adjoint]
@test_rewrite t(t(x) * A) + 2B * x
@test_rewrite t(t(x) * A) + 2B * x .- 1
@test_rewrite t(t(x) * A) + 2B * x .- [length(x):-1:1;]
@test_rewrite t(t(x) * A) + 2B * x - [length(x):-1:1;]
end
end

function _matrix_vector_test(x::AbstractVector, A::AbstractMatrix)
Expand Down
7 changes: 7 additions & 0 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,17 @@ end
const TransposeOrAdjoint{T, MT} = Union{LinearAlgebra.Transpose{T, MT}, LinearAlgebra.Adjoint{T, MT}}
_mirror_transpose_or_adjoint(x, ::LinearAlgebra.Transpose) = LinearAlgebra.transpose(x)
_mirror_transpose_or_adjoint(x, ::LinearAlgebra.Adjoint) = LinearAlgebra.adjoint(x)
_mirror_transpose_or_adjoint(A::Type{<:AbstractArray{T}}, ::Type{<:LinearAlgebra.Transpose}) where {T} = LinearAlgebra.Transpose{T, A}
_mirror_transpose_or_adjoint(A::Type{<:AbstractArray{T}}, ::Type{<:LinearAlgebra.Adjoint}) where {T} = LinearAlgebra.Adjoint{T, A}
similar_array_type(TA::Type{<:TransposeOrAdjoint{T, A}}, ::Type{S}) where {S, T, A} = _mirror_transpose_or_adjoint(similar_array_type(A, S), TA)
# dot product
function promote_array_mul(::Type{<:TransposeOrAdjoint{S, <:AbstractVector}}, ::Type{<:AbstractVector{T}}) where {S, T}
return promote_sum_mul(S, T)
end
function promote_array_mul(A::Type{<:TransposeOrAdjoint{S, V}}, M::Type{<:AbstractMatrix{T}}) where {S, T, V <: AbstractVector}
B = promote_array_mul(_mirror_transpose_or_adjoint(M, A), V)
return _mirror_transpose_or_adjoint(B, A)
end
function operate(::typeof(*), x::LinearAlgebra.Adjoint{<:Any, <:AbstractVector}, y::AbstractVector)
return operate(LinearAlgebra.dot, parent(x), y)
end
Expand Down
4 changes: 2 additions & 2 deletions test/dummy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ MA.scaling(x::DummyBigInt) = x

MA.mutable_operate_to!(x::DummyBigInt, op::Function, args::Union{MA.Scaling, DummyBigInt}...) = DummyBigInt(MA.mutable_operate_to!(x.data, op, _data.(args)...))
# Called for instance if `args` is `(v', v)` for a vector `v`.
MA.mutable_operate_to!(output::DummyBigInt, op::typeof(MA.add_mul), x::Union{MA.Scaling, DummyBigInt}, y::Union{MA.Scaling, DummyBigInt}, z::Union{MA.Scaling, DummyBigInt}, args::Union{MA.Scaling, DummyBigInt}...) = MA.mutable_operate_to!(output, +, x, *(y, z, args...))
MA.mutable_operate_to!(output::DummyBigInt, op::typeof(MA.add_mul), x, y, z, args...) = MA.mutable_operate_to!(output, +, x, *(y, z, args...))
MA.mutable_operate_to!(output::DummyBigInt, op::MA.AddSubMul, x::Union{MA.Scaling, DummyBigInt}, y::Union{MA.Scaling, DummyBigInt}, z::Union{MA.Scaling, DummyBigInt}, args::Union{MA.Scaling, DummyBigInt}...) = MA.mutable_operate_to!(output, MA.add_sub_op(op), x, *(y, z, args...))
MA.mutable_operate_to!(output::DummyBigInt, op::MA.AddSubMul, x, y, z, args...) = MA.mutable_operate_to!(output, MA.add_sub_op(op), x, *(y, z, args...))
function MA.mutable_operate!(op::Function, x::DummyBigInt, args::Vararg{Any, N}) where N
MA.mutable_operate_to!(x, op, x, args...)
end
Expand Down

0 comments on commit 57cf8d4

Please sign in to comment.