diff --git a/src/LinearOperators.jl b/src/LinearOperators.jl index 38fea2c6..ee225af6 100644 --- a/src/LinearOperators.jl +++ b/src/LinearOperators.jl @@ -348,7 +348,6 @@ end -(op :: AbstractLinearOperator, x :: Number) = op + (-x) -(x :: Number, op :: AbstractLinearOperator) = x + (-op) - # Utility functions. """ @@ -576,7 +575,6 @@ function opDiagonal(nrow :: Int, ncol :: Int, d :: AbstractVector{T}) where T LinearOperator{T,F1,F2,F3}(nrow, ncol, false, false, prod, tprod, ctprod) end - hcat(A :: AbstractLinearOperator, B :: AbstractMatrix) = hcat(A, LinearOperator(B)) hcat(A :: AbstractMatrix, B :: AbstractLinearOperator) = hcat(LinearOperator(A), B) @@ -599,13 +597,53 @@ function hcat(A :: AbstractLinearOperator, B :: AbstractLinearOperator) end function hcat(ops :: OperatorOrMatrix...) - op = ops[1] - for i = 2:length(ops) - op = [op ops[i]] + nops = length(ops) + nrow = size(ops[1], 1) + for i = 2:nops + size(ops[i], 1) == nrow || throw(LinearOperatorException("hcat: inconsistent row sizes")) + end + ncol = sum(size(ops[i], 2) for i = 1:nops) + S = eltype(ops[1]) + for i = 2:nops + S = promote_type(S, eltype(ops[i])) end - return op -end + prod = @closure v -> begin + Av = zeros(S, nrow) + k = 0 + for i = 1:nops + s = size(ops[i], 2) + Av .+= ops[i] * v[k+1:k+s] + k += s + end + return Av + end + function tprod(v) + Atv = zeros(S, ncol) + k = 0 + for i = 1:nops + s = size(ops[i], 2) + Atv[k+1:k+s] .= transpose(ops[i]) * v + k += s + end + return Atv + end + function ctprod(v) + Atv = zeros(S, ncol) + k = 0 + for i = 1:nops + s = size(ops[i], 2) + Atv[k+1:k+s] .= ops[i]' * v + k += s + end + return Atv + end + + F1 = typeof(prod) + F2 = typeof(tprod) + F3 = typeof(ctprod) + return LinearOperator{S,F1,F2,F3}(nrow, ncol, false, false, prod, tprod, ctprod) +end vcat(A :: AbstractLinearOperator, B :: AbstractMatrix) = vcat(A, LinearOperator(B)) @@ -630,10 +668,49 @@ end function vcat(ops :: OperatorOrMatrix...) op = ops[1] - for i = 2:length(ops) - op = [op; ops[i]] + nops = length(ops) + ncol = size(op, 2) + for i = 2:nops + size(ops[i], 2) == ncol || throw(LinearOperatorException("hcat: inconsistent row sizes")) + end + nrow = sum(size(ops[i], 1) for i = 1:nops) + S = promote_type([eltype(op) for op in ops]...) + + function prod(v) + Av = zeros(S, nrow) + k = 0 + for i = 1:nops + s = size(ops[i], 1) + Av[k+1:k+s] .= ops[i] * v + k += s + end + return Av + end + tprod = @closure v -> begin + Atv = zeros(S, ncol) + k = 0 + for i = 1:nops + s = size(ops[i], 1) + Atv .+= transpose(ops[i]) * v[k+1:k+s] + k += s + end + return Atv + end + ctprod = @closure v -> begin + Atv = zeros(S, ncol) + k = 0 + for i = 1:nops + s = size(ops[i], 1) + Atv .+= transpose(ops[i]) * v[k+1:k+s] + k += s + end + return Atv end - return op + + F1 = typeof(prod) + F2 = typeof(tprod) + F3 = typeof(ctprod) + return LinearOperator{S,F1,F2,F3}(nrow, ncol, false, false, prod, tprod, ctprod) end # Removed by https://github.com/JuliaLang/julia/pull/24017