Skip to content

Commit

Permalink
Merge pull request #1993 from JuliaOpt/bl/uniform_scaling
Browse files Browse the repository at this point in the history
Implement algebra with UniformScaling
  • Loading branch information
blegat committed Jun 28, 2019
2 parents 4fdad2b + 425f007 commit 8e7b117
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 68 deletions.
84 changes: 50 additions & 34 deletions src/operators.jl
Expand Up @@ -9,52 +9,58 @@
#############################################################################

const _JuMPTypes = Union{AbstractJuMPScalar, NonlinearExpression}
const Constant = Union{Number, UniformScaling}
_float(x::Number) = convert(Float64, x)
_float(J::UniformScaling) = _float(J.λ)

# Overloads
#
# Different objects that must all interact:
# 1. Number
# 1. Constant
# 2. AbstractVariableRef
# 4. GenericAffExpr
# 5. GenericQuadExpr

# Number
# Number--Number obviously already taken care of!
# Number--VariableRef
Base.:+(lhs::Number, rhs::AbstractVariableRef) = GenericAffExpr(convert(Float64, lhs), rhs => 1.0)
Base.:-(lhs::Number, rhs::AbstractVariableRef) = GenericAffExpr(convert(Float64, lhs), rhs => -1.0)
Base.:*(lhs::Number, rhs::AbstractVariableRef) = GenericAffExpr(0.0, rhs => convert(Float64,lhs))
# Number--GenericAffExpr
function Base.:+(lhs::Number, rhs::GenericAffExpr)
# Constant
# Constant--Constant obviously already taken care of!
# Constant--VariableRef
Base.:+(lhs::Constant, rhs::AbstractVariableRef) = GenericAffExpr(_float(lhs), rhs => 1.0)
Base.:-(lhs::Constant, rhs::AbstractVariableRef) = GenericAffExpr(_float(lhs), rhs => -1.0)
Base.:*(lhs::Constant, rhs::AbstractVariableRef) = GenericAffExpr(0.0, rhs => _float(lhs))
# Constant--GenericAffExpr
function Base.:+(lhs::Constant, rhs::GenericAffExpr)
result = copy(rhs)
result.constant += lhs
return result
end
function Base.:-(lhs::Number, rhs::GenericAffExpr)
function Base.:-(lhs::Constant, rhs::GenericAffExpr)
result = -rhs
result.constant += lhs
return result
end
Base.:*(lhs::Number, rhs::GenericAffExpr) = map_coefficients(c -> lhs * c, rhs)
# Number--QuadExpr
Base.:+(lhs::Number, rhs::GenericQuadExpr) = GenericQuadExpr(lhs+rhs.aff, copy(rhs.terms))
function Base.:-(lhs::Number, rhs::GenericQuadExpr)
function Base.:*(lhs::Constant, rhs::GenericAffExpr)
f = _float(lhs)
return map_coefficients(c -> f * c, rhs)
end
# Constant--QuadExpr
Base.:+(lhs::Constant, rhs::GenericQuadExpr) = GenericQuadExpr(lhs+rhs.aff, copy(rhs.terms))
function Base.:-(lhs::Constant, rhs::GenericQuadExpr)
result = -rhs
result.aff.constant += lhs
return result
end
Base.:*(lhs::Number, rhs::GenericQuadExpr) = map_coefficients(c -> lhs * c, rhs)
Base.:*(lhs::Constant, rhs::GenericQuadExpr) = map_coefficients(c -> lhs * c, rhs)

# AbstractVariableRef (or, AbstractJuMPScalar)
# TODO: What is the role of AbstractJuMPScalar??
Base.:+(lhs::AbstractJuMPScalar) = lhs
Base.:-(lhs::AbstractVariableRef) = GenericAffExpr(0.0, lhs => -1.0)
Base.:*(lhs::AbstractJuMPScalar) = lhs # make this more generic so extensions don't have to define unary multiplication for our macros
# AbstractVariableRef--Number
Base.:+(lhs::AbstractVariableRef, rhs::Number) = (+)( rhs,lhs)
Base.:-(lhs::AbstractVariableRef, rhs::Number) = (+)(-rhs,lhs)
Base.:*(lhs::AbstractVariableRef, rhs::Number) = (*)(rhs,lhs)
Base.:/(lhs::AbstractVariableRef, rhs::Number) = (*)(1.0/rhs,lhs)
# AbstractVariableRef--Constant
Base.:+(lhs::AbstractVariableRef, rhs::Constant) = (+)( rhs,lhs)
Base.:-(lhs::AbstractVariableRef, rhs::Constant) = (+)(-rhs,lhs)
Base.:*(lhs::AbstractVariableRef, rhs::Constant) = (*)(rhs,lhs)
Base.:/(lhs::AbstractVariableRef, rhs::Constant) = (*)(1.0/rhs,lhs)
# AbstractVariableRef--AbstractVariableRef
Base.:+(lhs::V, rhs::V) where {V <: AbstractVariableRef} = GenericAffExpr(0.0, lhs => 1.0, rhs => 1.0)
Base.:-(lhs::V, rhs::V) where {V <: AbstractVariableRef} = GenericAffExpr(0.0, lhs => 1.0, rhs => -1.0)
Expand Down Expand Up @@ -111,11 +117,11 @@ end
# GenericAffExpr
Base.:+(lhs::GenericAffExpr) = lhs
Base.:-(lhs::GenericAffExpr) = map_coefficients(-, lhs)
# GenericAffExpr--Number
Base.:+(lhs::GenericAffExpr, rhs::Number) = (+)(+rhs,lhs)
Base.:-(lhs::GenericAffExpr, rhs::Number) = (+)(-rhs,lhs)
Base.:*(lhs::GenericAffExpr, rhs::Number) = (*)(rhs,lhs)
Base.:/(lhs::GenericAffExpr, rhs::Number) = map_coefficients(c -> c/rhs, lhs)
# GenericAffExpr--Constant
Base.:+(lhs::GenericAffExpr, rhs::Constant) = (+)(rhs,lhs)
Base.:-(lhs::GenericAffExpr, rhs::Constant) = (+)(-rhs,lhs)
Base.:*(lhs::GenericAffExpr, rhs::Constant) = (*)(rhs,lhs)
Base.:/(lhs::GenericAffExpr, rhs::Constant) = map_coefficients(c -> c/rhs, lhs)
function Base.:^(lhs::Union{AbstractVariableRef, GenericAffExpr}, rhs::Integer)
if rhs == 2
return lhs*lhs
Expand All @@ -127,7 +133,7 @@ function Base.:^(lhs::Union{AbstractVariableRef, GenericAffExpr}, rhs::Integer)
error("Only exponents of 0, 1, or 2 are currently supported. Are you trying to build a nonlinear problem? Make sure you use @NLconstraint/@NLobjective.")
end
end
Base.:^(lhs::Union{AbstractVariableRef, GenericAffExpr}, rhs::Number) = error("Only exponents of 0, 1, or 2 are currently supported. Are you trying to build a nonlinear problem? Make sure you use @NLconstraint/@NLobjective.")
Base.:^(lhs::Union{AbstractVariableRef, GenericAffExpr}, rhs::Constant) = error("Only exponents of 0, 1, or 2 are currently supported. Are you trying to build a nonlinear problem? Make sure you use @NLconstraint/@NLobjective.")
# GenericAffExpr--AbstractVariableRef
function Base.:+(lhs::GenericAffExpr{C, V}, rhs::V) where {C, V <: AbstractVariableRef}
return add_to_expression!(copy(lhs), one(C), rhs)
Expand Down Expand Up @@ -187,11 +193,11 @@ end
# GenericQuadExpr
Base.:+(lhs::GenericQuadExpr) = lhs
Base.:-(lhs::GenericQuadExpr) = map_coefficients(-, lhs)
# GenericQuadExpr--Number
Base.:+(lhs::GenericQuadExpr, rhs::Number) = (+)(+rhs,lhs)
Base.:-(lhs::GenericQuadExpr, rhs::Number) = (+)(-rhs,lhs)
Base.:*(lhs::GenericQuadExpr, rhs::Number) = (*)(rhs,lhs)
Base.:/(lhs::GenericQuadExpr, rhs::Number) = (*)(inv(rhs),lhs)
# GenericQuadExpr--Constant
Base.:+(lhs::GenericQuadExpr, rhs::Constant) = (+)(+rhs,lhs)
Base.:-(lhs::GenericQuadExpr, rhs::Constant) = (+)(-rhs,lhs)
Base.:*(lhs::GenericQuadExpr, rhs::Constant) = (*)(rhs,lhs)
Base.:/(lhs::GenericQuadExpr, rhs::Constant) = (*)(inv(rhs),lhs)
# GenericQuadExpr--AbstractVariableRef
Base.:+(q::GenericQuadExpr, v::AbstractVariableRef) = GenericQuadExpr(q.aff+v, copy(q.terms))
Base.:-(q::GenericQuadExpr, v::AbstractVariableRef) = GenericQuadExpr(q.aff-v, copy(q.terms))
Expand Down Expand Up @@ -271,8 +277,8 @@ end
# for scalars, so instead of defining them one-by-one, we will
# fallback to the multiplication operator
LinearAlgebra.dot(lhs::_JuMPTypes, rhs::_JuMPTypes) = lhs*rhs
LinearAlgebra.dot(lhs::_JuMPTypes, rhs::Number) = lhs*rhs
LinearAlgebra.dot(lhs::Number, rhs::_JuMPTypes) = lhs*rhs
LinearAlgebra.dot(lhs::_JuMPTypes, rhs::Constant) = lhs*rhs
LinearAlgebra.dot(lhs::Constant, rhs::_JuMPTypes) = lhs*rhs

LinearAlgebra.dot(lhs::AbstractVector{T}, rhs::AbstractVector{S}) where {T <: _JuMPTypes, S <: _JuMPTypes} = _dot(lhs,rhs)
LinearAlgebra.dot(lhs::AbstractVector{T}, rhs::AbstractVector{S}) where {T <: _JuMPTypes, S} = _dot(lhs,rhs)
Expand Down Expand Up @@ -606,6 +612,16 @@ function Base.:-(x::AbstractArray{T}) where {T <: _JuMPTypes}
return ret
end

# Fix https://github.com/JuliaLang/julia/issues/32374 as done in
# https://github.com/JuliaLang/julia/pull/32375. This hack should
# be removed once we drop Julia v1.0.
function Base.:-(A::Symmetric{<:JuMP.AbstractVariableRef})
return Symmetric(-A.data, LinearAlgebra.sym_uplo(A.uplo))
end
function Base.:-(A::Hermitian{<:JuMP.AbstractVariableRef})
return Hermitian(-A.data, LinearAlgebra.sym_uplo(A.uplo))
end

###############################################################################
# nonlinear function fallbacks for JuMP built-in types
###############################################################################
Expand All @@ -625,6 +641,6 @@ Base.:*(lhs::GenericQuadExpr, rhs::GenericQuadExpr) =
Base.:*(::S, ::T) where {T <: GenericQuadExpr,
S <: Union{AbstractVariableRef, GenericAffExpr, GenericQuadExpr}} =
error( "*(::$S,::$T) is not defined. $op_hint")
Base.:/(::S, ::T) where {S <: Union{Number, AbstractVariableRef, GenericAffExpr, GenericQuadExpr},
Base.:/(::S, ::T) where {S <: Union{Constant, AbstractVariableRef, GenericAffExpr, GenericQuadExpr},
T <: Union{AbstractVariableRef, GenericAffExpr, GenericQuadExpr}} =
error( "/(::$S,::$T) is not defined. $op_hint")

0 comments on commit 8e7b117

Please sign in to comment.