Skip to content

Commit

Permalink
fail on broadcast addition with scalar
Browse files Browse the repository at this point in the history
closes #416
closes #417
  • Loading branch information
baggepinnen committed Jul 6, 2022
1 parent 5146ac0 commit d366c69
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 17 deletions.
3 changes: 1 addition & 2 deletions src/matrix_comps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ end
"""
`sysr, G, T = balreal(sys::StateSpace)`
Calculates a balanced realization of the system sys, such that the observability and reachability gramians of the balanced system are equal and `diagm(G)`. `T` is the similarity transform between the old state `x` and the new state `z` such that `Tz = x`.
Calculates a balanced realization of the system sys, such that the observability and reachability gramians of the balanced system are equal and diagonal `diagm(G)`. `T` is the similarity transform between the old state `x` and the new state `z` such that `Tz = x`.
See also `gram`, `baltrunc`
Expand Down Expand Up @@ -548,7 +548,6 @@ For more advanced model reduction, see [RobustAndOptimalControl.jl - Model Reduc
"""
function baltrunc(sys::ST; atol = sqrt(eps()), rtol = 1e-3, n = nothing, residual=false) where ST <: AbstractStateSpace
sysbal, S, T = balreal(sys)
S = diag(S)
if n === nothing
S = S[S .>= atol]
S = S[S .>= S[1]*rtol]
Expand Down
19 changes: 11 additions & 8 deletions src/types/StateSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ function +(s1::ST, s2::ST) where {ST <: AbstractStateSpace}
end


+(sys::ST, n::Number) where ST <: AbstractStateSpace = basetype(ST)(sys.A, sys.B, sys.C, sys.D .+ n, sys.timeevol)
function +(sys::ST, n::Number) where ST <: AbstractStateSpace
issiso(sys) || throw(DimensionMismatch("Numbers and systems can only be added for SISO systems."))
basetype(ST)(sys.A, sys.B, sys.C, sys.D .+ n, sys.timeevol)
end
+(n::Number, sys::ST) where ST <: AbstractStateSpace = +(sys, n)

## SUBTRACTION ##
Expand All @@ -265,9 +268,9 @@ function *(sys1::ST, sys2::ST) where {ST <: AbstractStateSpace}
#Check dimension alignment
#Note: sys1*sys2 = y <- sys1 <- sys2 <- u
if (sys1.nu != sys2.ny) && (sys1.nu == 1 || sys2.ny == 1)
error("sys1*sys2: sys1 must have same number of inputs as sys2 has outputs. If you want to broadcast a scalar system to a diagonal system, use broadcasted multiplication sys1 .* sys2")
throw(DimensionMismatch("sys1*sys2: sys1 must have same number of inputs as sys2 has outputs. If you want to broadcast a scalar system to a diagonal system, use broadcasted multiplication sys1 .* sys2"))
end
sys1.nu == sys2.ny || error("sys1*sys2: sys1 must have same number of inputs as sys2 has outputs")
sys1.nu == sys2.ny || throw(DimensionMismatch("sys1*sys2: sys1 must have same number of inputs as sys2 has outputs"))
timeevol = common_timeevol(sys1,sys2)
T = promote_type(numeric_type(sys1), numeric_type(sys2))

Expand All @@ -280,7 +283,7 @@ function *(sys1::ST, sys2::ST) where {ST <: AbstractStateSpace}
end

function Base.Broadcast.broadcasted(::typeof(*), sys1::AbstractStateSpace, sys2::AbstractStateSpace)
issiso(sys1) || issiso(sys2) || error("Only SISO statespace systems can be broadcasted")
issiso(sys1) || issiso(sys2) || throw(DimensionMismatch("Only SISO statespace systems can be broadcasted"))
if issiso(sys1) && !issiso(sys2) # Check !issiso(sys2) to avoid calling fill if both are siso
sys1 = append(sys1 for i in 1:sys2.ny)
elseif issiso(sys2)
Expand All @@ -290,18 +293,18 @@ function Base.Broadcast.broadcasted(::typeof(*), sys1::AbstractStateSpace, sys2:
end

function Base.Broadcast.broadcasted(::typeof(*), sys1::ST, M::AbstractArray) where {ST <: AbstractStateSpace}
LinearAlgebra.isdiag(M) || error("Broadcasting multiplication of an LTI system with an array is only supported for diagonal arrays. If you want the system to behave like a scalar and multiply each element of the array, wrap the system in a `Ref` to indicate this, i.e., `Ref(sys) .* array`. See also function `array2mimo`.")
LinearAlgebra.isdiag(M) || throw(DimensionMismatch("Broadcasting multiplication of an LTI system with an array is only supported for diagonal arrays. If you want the system to behave like a scalar and multiply each element of the array, wrap the system in a `Ref` to indicate this, i.e., `Ref(sys) .* array`. See also function `array2mimo`."))
sys1 .* ss(M, sys1.timeevol) # If diagonal, broadcast by replicating input channels
end

function Base.Broadcast.broadcasted(::typeof(*), M::AbstractArray, sys1::ST) where {ST <: AbstractStateSpace}
LinearAlgebra.isdiag(M) || error("Broadcasting multiplication of an LTI system with an array is only supported for diagonal arrays. If you want the system to behave like a scalar and multiply each element of the array, wrap the system in a `Ref` to indicate this, i.e., `array .* Ref(sys)`. See also function `array2mimo`.")
LinearAlgebra.isdiag(M) || throw(DimensionMismatch("Broadcasting multiplication of an LTI system with an array is only supported for diagonal arrays. If you want the system to behave like a scalar and multiply each element of the array, wrap the system in a `Ref` to indicate this, i.e., `array .* Ref(sys)`. See also function `array2mimo`."))
ss(M, sys1.timeevol) .* sys1 # If diagonal, broadcast by replicating output channels
end

function Base.Broadcast.broadcasted(::typeof(*), sys1::Base.RefValue{ST}, M::AbstractArray) where {ST <: AbstractStateSpace}
sys1 = sys1[]
issiso(sys1) || error("Only SISO statespace systems can be broadcasted")
issiso(sys1) || throw(DimensionMismatch("Only SISO statespace systems can be broadcasted"))
T = promote_type(numeric_type(sys1), eltype(M))
A,B,C,D = ssdata(sys1)
nx = sys1.nx
Expand All @@ -320,7 +323,7 @@ end

function Base.Broadcast.broadcasted(::typeof(*), M::AbstractArray, sys1::Base.RefValue{ST}) where {ST <: AbstractStateSpace}
sys1 = sys1[]
issiso(sys1) || error("Only SISO statespace systems can be broadcasted")
issiso(sys1) || throw(DimensionMismatch("Only SISO statespace systems can be broadcasted"))
T = promote_type(numeric_type(sys1), eltype(M))
A,B,C,D = ssdata(sys1)
nx = sys1.nx
Expand Down
4 changes: 2 additions & 2 deletions test/test_matrix_comps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ D = 0
sys = ss(A,B,C,D)
sysr, G = balreal(sys)

@test gram(sysr, :c) G
@test gram(sysr, :o) G
@test gram(sysr, :c) diagm(G)
@test gram(sysr, :o) diagm(G)
@test sort(poles(sysr)) sort(poles(sys))

sysb,T = ControlSystems.balance_statespace(sys)
Expand Down
13 changes: 8 additions & 5 deletions test/test_statespace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
@test C_111 + C_111 == SS([-5 0; 0 -5],[2; 2],[3 3],[0])
@test C_222 + C_222 == SS([-5 -3 0 0; 2 -9 0 0; 0 0 -5 -3;
0 0 2 -9],[1 0; 0 2; 1 0; 0 2], [1 0 1 0; 0 1 0 1],[0 0; 0 0])
@test C_222 + 1 == SS([-5 -3; 2 -9],[1 0; 0 2],[1 0; 0 1],[1 1; 1 1])
@test_throws DimensionMismatch C_222 + 1
@test D_111 + D_111 == SS([-0.5 0; 0 -0.5],[2; 2],[3 3],[0], 0.005)

@inferred C_111 + C_111
Expand All @@ -63,12 +63,15 @@

@inferred C_111 + false

@test C_222 + 1.5 == 1.0C_222 + 1.5 # C_222 has eltype Int
@test 1.5 + C_222 == 1.0C_222 + 1.5
@test_throws DimensionMismatch C_222 + 1.5
@test_throws DimensionMismatch 1.5 + C_222

@test_throws DimensionMismatch ss(-ones(2,2), ones(2,2), ones(2,2), zeros(2,2)) + 1
@test ss(-ones(2,2), ones(2,2), ones(2,2), zeros(2,2)) + ones(2, 2) == ss(-ones(2,2), ones(2,2), ones(2,2), ones(2,2))

# Subtraction
@test C_111 - C_211 == SS([-5 0 0; 0 -5 -3; 0 2 -9],[2; 1; 2],[3 -1 -0],[0])
@test 1 - C_222 == SS([-5 -3; 2 -9],[1 0; 0 2],[-1 -0; -0 -1],[1 1; 1 1])
@test_throws DimensionMismatch 1 - C_222
@test D_111 - D_211 == SS([-0.5 0 0; 0 0.2 -0.8; 0 -0.8 0.07],[2; 1; 2],
[3 -1 -0],[0], 0.005)

Expand All @@ -91,7 +94,7 @@
C_111_d = ssrand(1,1,2)
M = ones(2,2)

@test_throws ErrorException C_111_d.*M # We do not allow broadcasting with non-diagonal matrices https://github.com/JuliaControl/ControlSystems.jl/issues/416
@test_throws DimensionMismatch C_111_d.*M # We do not allow broadcasting with non-diagonal matrices https://github.com/JuliaControl/ControlSystems.jl/issues/416
# Unless we wrap the system in a Ref to indicate that we really want it to broadcast like a scalar

@test Ref(C_111_d).*M == [C_111_d C_111_d; C_111_d C_111_d]
Expand Down

0 comments on commit d366c69

Please sign in to comment.