diff --git a/src/linalg/rankUpdate.jl b/src/linalg/rankUpdate.jl index a75a8590a..03bc1bbb9 100644 --- a/src/linalg/rankUpdate.jl +++ b/src/linalg/rankUpdate.jl @@ -60,9 +60,11 @@ function rankUpdate!(α::T, A::SparseMatrixCSC{T}, C::Diagonal{T}) where T <: Nu rv = rowvals(A) for j in 1:n nzr = nzrange(A, j) - length(nzr) == 1 || throw(ArgumentError("A*A' has off-diagonal elements")) - k = nzr[1] - @inbounds dd[rv[k]] += α * abs2(nz[k]) + if !isempty(nzr) + length(nzr) == 1 || throw(ArgumentError("A*A' has off-diagonal elements")) + k = nzr[1] + @inbounds dd[rv[k]] += α * abs2(nz[k]) + end end C end diff --git a/src/modelterms.jl b/src/modelterms.jl index 85a3469a2..5ac3b2153 100644 --- a/src/modelterms.jl +++ b/src/modelterms.jl @@ -470,6 +470,24 @@ function Base.Ac_mul_B(A::ScalarFactorReTerm{T,V,R}, end end +function Base.Ac_mul_B(A::VectorFactorReTerm, B::ScalarFactorReTerm) + nzeros = copy(A.wtz) + k, n = size(nzeros) + rowind = Matrix{Int32}(k, n) + refs = A.f.refs + bwtz = B.wtz + for j in 1:n + bwtzj = bwtz[j] + offset = k * (refs[j] - 1) + for i in 1:k + rowind[i, j] = i + offset + nzeros[i, j] *= bwtzj + end + end + sparse(vec(rowind), Vector{Int32}(repeat(B.f.refs, inner=k)), vec(nzeros), + k * nlevs(A), nlevs(B)) +end + function Base.Ac_mul_B(A::VectorFactorReTerm{T}, B::VectorFactorReTerm{T}) where T if A === B l = vsize(A) diff --git a/src/pls.jl b/src/pls.jl index e8e1f6e30..3cf8a39fc 100644 --- a/src/pls.jl +++ b/src/pls.jl @@ -5,10 +5,12 @@ Convert sparse `S` to `Diagonal` if `S` is diagonal or to `full(S)` if the proportion of nonzeros exceeds `threshold`. """ function densify(S::SparseMatrixCSC, threshold::Real = 0.3) + dropzeros!(S) m, n = size(S) if m == n && isdiag(S) # convert diagonal sparse to Diagonal Diagonal(diag(S)) - elseif nnz(S)/(*(size(S)...)) ≤ threshold # very sparse matrices left as is + elseif nnz(S)/(*(size(S)...)) ≤ threshold || # very sparse matrices left as is + all(d -> iszero(d) || d == 1, diff(S.colptr)) S else full(S) diff --git a/test/FactorReTerm.jl b/test/FactorReTerm.jl index f7d5f3618..020728909 100644 --- a/test/FactorReTerm.jl +++ b/test/FactorReTerm.jl @@ -116,6 +116,10 @@ end @test isa(vrp, UniformBlockDiagonal{Float64}) @test size(vrp) == (36, 36) + scl = ScalarFactorReTerm(slp[:G], Array(slp[:U]), Array(slp[:U]), :G, ["U"], 1.0) + + @test sparse(corr)'sparse(scl) == corr'scl + @test MixedModels.Λ_mul_B!(Vector{Float64}(36), corr, ones(36)) == repeat([0.5, 1.0], outer=18) @testset "reweight!" begin