From 89d4611456ad6f0ddcfb4aceb8ccedcc30db5aaf Mon Sep 17 00:00:00 2001 From: Abel Soares Siqueira Date: Thu, 5 Sep 2019 09:14:29 -0300 Subject: [PATCH 1/2] Fix hcat and vcat for variables arguments Closes #97 --- src/LinearOperators.jl | 92 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 8 deletions(-) diff --git a/src/LinearOperators.jl b/src/LinearOperators.jl index 38fea2c6..e2c0a1bf 100644 --- a/src/LinearOperators.jl +++ b/src/LinearOperators.jl @@ -576,7 +576,6 @@ function opDiagonal(nrow :: Int, ncol :: Int, d :: AbstractVector{T}) where T LinearOperator{T,F1,F2,F3}(nrow, ncol, false, false, prod, tprod, ctprod) end - hcat(A :: AbstractLinearOperator, B :: AbstractMatrix) = hcat(A, LinearOperator(B)) hcat(A :: AbstractMatrix, B :: AbstractLinearOperator) = hcat(LinearOperator(A), B) @@ -600,12 +599,50 @@ end function hcat(ops :: OperatorOrMatrix...) op = ops[1] - for i = 2:length(ops) - op = [op ops[i]] + nops = length(ops) + nrow = size(op, 1) + for i = 2:nops + size(ops[i], 1) == nrow || throw(LinearOperatorException("hcat: inconsistent row sizes")) + end + ncol = sum(size(ops[i], 2) for i = 1:nops) + S = promote_type([eltype(op) for op in ops]...) + + function prod(v) + Av = zeros(S, nrow) + k = 0 + for i = 1:nops + s = size(ops[i], 2) + Av .+= ops[i] * v[k+1:k+s] + k += s + end + return Av + end + function tprod(v) + Atv = zeros(S, ncol) + k = 0 + for i = 1:nops + s = size(ops[i], 2) + Atv[k+1:k+s] .= transpose(ops[i]) * v + k += s + end + return Atv + end + function ctprod(v) + Atv = zeros(S, ncol) + k = 0 + for i = 1:nops + s = size(ops[i], 2) + Atv[k+1:k+s] .= ops[i]' * v + k += s + end + return Atv end - return op -end + F1 = typeof(prod) + F2 = typeof(tprod) + F3 = typeof(ctprod) + return LinearOperator{S,F1,F2,F3}(nrow, ncol, false, false, prod, tprod, ctprod) +end vcat(A :: AbstractLinearOperator, B :: AbstractMatrix) = vcat(A, LinearOperator(B)) @@ -630,10 +667,49 @@ end function vcat(ops :: OperatorOrMatrix...) op = ops[1] - for i = 2:length(ops) - op = [op; ops[i]] + nops = length(ops) + ncol = size(op, 2) + for i = 2:nops + size(ops[i], 2) == ncol || throw(LinearOperatorException("hcat: inconsistent row sizes")) + end + nrow = sum(size(ops[i], 1) for i = 1:nops) + S = promote_type([eltype(op) for op in ops]...) + + function prod(v) + Av = zeros(S, nrow) + k = 0 + for i = 1:nops + s = size(ops[i], 1) + Av[k+1:k+s] .= ops[i] * v + k += s + end + return Av + end + function tprod(v) + Atv = zeros(S, ncol) + k = 0 + for i = 1:nops + s = size(ops[i], 1) + Atv .+= transpose(ops[i]) * v[k+1:k+s] + k += s + end + return Atv end - return op + function ctprod(v) + Atv = zeros(S, ncol) + k = 0 + for i = 1:nops + s = size(ops[i], 1) + Atv .+= transpose(ops[i]) * v[k+1:k+s] + k += s + end + return Atv + end + + F1 = typeof(prod) + F2 = typeof(tprod) + F3 = typeof(ctprod) + return LinearOperator{S,F1,F2,F3}(nrow, ncol, false, false, prod, tprod, ctprod) end # Removed by https://github.com/JuliaLang/julia/pull/24017 From 4052165bc8139c5ed74aed6dda50abd6482ab43e Mon Sep 17 00:00:00 2001 From: Abel Soares Siqueira Date: Tue, 17 Sep 2019 17:43:16 -0300 Subject: [PATCH 2/2] Improve type stability of hcat(ops...) --- src/LinearOperators.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/LinearOperators.jl b/src/LinearOperators.jl index e2c0a1bf..ee225af6 100644 --- a/src/LinearOperators.jl +++ b/src/LinearOperators.jl @@ -348,7 +348,6 @@ end -(op :: AbstractLinearOperator, x :: Number) = op + (-x) -(x :: Number, op :: AbstractLinearOperator) = x + (-op) - # Utility functions. """ @@ -598,16 +597,18 @@ function hcat(A :: AbstractLinearOperator, B :: AbstractLinearOperator) end function hcat(ops :: OperatorOrMatrix...) - op = ops[1] nops = length(ops) - nrow = size(op, 1) + nrow = size(ops[1], 1) for i = 2:nops size(ops[i], 1) == nrow || throw(LinearOperatorException("hcat: inconsistent row sizes")) end ncol = sum(size(ops[i], 2) for i = 1:nops) - S = promote_type([eltype(op) for op in ops]...) + S = eltype(ops[1]) + for i = 2:nops + S = promote_type(S, eltype(ops[i])) + end - function prod(v) + prod = @closure v -> begin Av = zeros(S, nrow) k = 0 for i = 1:nops @@ -685,7 +686,7 @@ function vcat(ops :: OperatorOrMatrix...) end return Av end - function tprod(v) + tprod = @closure v -> begin Atv = zeros(S, ncol) k = 0 for i = 1:nops @@ -695,7 +696,7 @@ function vcat(ops :: OperatorOrMatrix...) end return Atv end - function ctprod(v) + ctprod = @closure v -> begin Atv = zeros(S, ncol) k = 0 for i = 1:nops