Skip to content

Commit

Permalink
Fix #988
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Mar 8, 2017
1 parent 078a040 commit 197c519
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,8 @@ end
_multiply!{T<:JuMPTypes}(ret::AbstractArray{T}, lhs::SparseMatrixCSC, rhs::SparseMatrixCSC) = _multiply!(ret, lhs, full(rhs))
_multiplyt!{T<:JuMPTypes}(ret::AbstractArray{T}, lhs::SparseMatrixCSC, rhs::SparseMatrixCSC) = _multiplyt!(ret, lhs, full(rhs))

_multiply!(ret, lhs, rhs) = A_mul_B!(ret, lhs, ret)
_multiply!(ret, lhs, rhs) = A_mul_B!(ret, lhs, rhs)
_multiplyt!(ret, lhs, rhs) = At_mul_B!(ret, lhs, rhs)

(*){T<:JuMPTypes}( A::Union{Matrix{T},SparseMatrixCSC{T}}, x::Union{Matrix, Vector, SparseMatrixCSC}) = _matmul(A, x)
(*){T<:JuMPTypes,R<:JuMPTypes}(A::Union{Matrix{T},SparseMatrixCSC{T}}, x::Union{Matrix{R},Vector{R},SparseMatrixCSC{R}}) = _matmul(A, x)
Expand Down Expand Up @@ -479,14 +480,14 @@ function _matmult(A, x)
ret
end

_multiply_type(R,S) = typeof(one(R) * one(S))

# See https://github.com/JuliaLang/julia/pull/18218
_matprod_type(R, S) = typeof(one(R) * one(S) + one(R) * one(S))
# Don't do size checks here in _return_array, defer that to (*)
_return_array{R,S}(A::AbstractMatrix{R}, x::AbstractVector{S}) = _fillwithzeros(Array{_multiply_type(R,S)}(size(A,1)))
_return_array{R,S}(A::AbstractMatrix{R}, x::AbstractMatrix{S}) = _fillwithzeros(Array{_multiply_type(R,S)}(size(A,1), size(x,2)))
_return_array{R,S}(A::AbstractMatrix{R}, x::AbstractVector{S}) = _fillwithzeros(Array{_matprod_type(R,S)}(size(A,1)))
_return_array{R,S}(A::AbstractMatrix{R}, x::AbstractMatrix{S}) = _fillwithzeros(Array{_matprod_type(R,S)}(size(A,1), size(x,2)))
# these are for transpose return matrices
_return_arrayt{R,S}(A::AbstractMatrix{R}, x::AbstractVector{S}) = _fillwithzeros(Array{_multiply_type(R,S)}(size(A,2)))
_return_arrayt{R,S}(A::AbstractMatrix{R}, x::AbstractMatrix{S}) = _fillwithzeros(Array{_multiply_type(R,S)}(size(A,2), size(x, 2)))
_return_arrayt{R,S}(A::AbstractMatrix{R}, x::AbstractVector{S}) = _fillwithzeros(Array{_matprod_type(R,S)}(size(A,2)))
_return_arrayt{R,S}(A::AbstractMatrix{R}, x::AbstractMatrix{S}) = _fillwithzeros(Array{_matprod_type(R,S)}(size(A,2), size(x, 2)))

# helper so we don't fill the buffer array with the same object
function _fillwithzeros{T}(arr::AbstractArray{T})
Expand Down
39 changes: 39 additions & 0 deletions test/operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ const eq = JuMP.repl[:eq]
const Vert = JuMP.repl[:Vert]
const sub2 = JuMP.repl[:sub2]

# For "DimensionMismatch when performing vector-matrix multiplication with custom types #988"
import Base: +, *, one, zero, transpose
immutable MyType{T}
a::T
end
immutable MySumType{T}
a::T
end
Base.one{T}(::Type{MyType{T}}) = MyType(one(T))
Base.zero{T}(::Type{MySumType{T}}) = MySumType(zero(T))
Base.zero{T}(::MySumType{T}) = MySumType(zero(T))
Base.transpose(t::MyType) = MyType(t.a)
Base.transpose(t::MySumType) = MySumType(t.a)
+{MyT<:Union{MyType, MySumType}, MyS<:Union{MyType, MySumType}}(t1::MyT, t2::MyS) = MySumType(t1.a+t2.a)
*{S, T}(t1::MyType{S}, t2::T) = MyType(t1.a*t2)
*{S, T}(t1::S, t2::MyType{T}) = MyType(t1*t2.a)
*{S, T}(t1::MyType{S}, t2::MyType{T}) = MyType(t1.a*t2.a)


@testset "Operator overloads" begin
Expand Down Expand Up @@ -858,4 +875,26 @@ const sub2 = JuMP.repl[:sub2]
@test norm(x).terms == norm(x2).terms
end
end

@testset "DimensionMismatch when performing vector-matrix multiplication with custom types #988" begin
m = Model()
@variable m Q[1:3, 1:3] SDP

x = [MyType(1), MyType(2), MyType(3)]
y = Q * x
z = x' * Q
@test typeof(y) == Vector{MySumType{JuMP.GenericAffExpr{Float64,JuMP.Variable}}}
@test size(y) == (3,)
@test typeof(z) == Matrix{MySumType{JuMP.GenericAffExpr{Float64,JuMP.Variable}}}
@test size(z) == (1, 3)
for i in 1:3
# Q is symmetric
a = zero(JuMP.GenericAffExpr{Float64,JuMP.Variable})
a += Q[1,i]
a += 2Q[2,i]
a += 3Q[3,i]
# Q[1,i] + 2Q[2,i] + 3Q[3,i] is rearranged as 2 Q[2,3] + Q[1,3] + 3 Q[3,3]
@test z[i].a == y[i].a == a
end
end
end

0 comments on commit 197c519

Please sign in to comment.