Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparseL extractor and use it in condVar #492

Merged
merged 10 commits into from
May 8, 2021
2 changes: 2 additions & 0 deletions src/MixedModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ export @formula,
coeftable,
cond,
condVar,
condVartables,
describeblocks,
deviance,
dispersion,
Expand Down Expand Up @@ -118,6 +119,7 @@ export @formula,
setθ!,
simulate!,
sparse,
sparseL,
std,
stderror,
updateL!,
Expand Down
117 changes: 109 additions & 8 deletions src/linearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,15 +285,48 @@ diagonal blocks from the conditional variance-covariance matrix,
s² Λ(Λ'Z'ZΛ + I)⁻¹Λ'
"""
function condVar(m::LinearMixedModel{T}) where {T}
    # Returns one `Array{T,3}` per random-effects term; slice `[:, :, b]` is the
    # conditional variance-covariance block for level `b` of that term.
    s = sdest(m)
    # NOTE(review): on Julia < 1.6.1 the Int32-indexed factor is converted to `Int`
    # indices before wrapping -- presumably the triangular solve below needs it; confirm.
    @static if VERSION < v"1.6.1"
        spL = LowerTriangular(SparseMatrixCSC{T,Int}(sparseL(m)))
    else
        spL = LowerTriangular(sparseL(m))
    end
    nre = size(spL, 1)
    val = Array{T,3}[]
    offset = 0    # number of rows of `spL` consumed by the preceding terms
    for re in m.reterms
        λt = s * transpose(re.λ)
        vi = size(λt, 2)
        ℓi = length(re.levels)
        vali = Array{T}(undef, (vi, vi, ℓi))
        scratch = Matrix{T}(undef, (nre, vi))
        for b in 1:ℓi
            # place s * λ' in the rows belonging to level `b`, zero elsewhere
            fill!(scratch, zero(T))
            copyto!(view(scratch, (offset + (b - 1) * vi) .+ (1:vi), :), λt)
            # solve against the sparse lower-triangular factor in place
            ldiv!(spL, scratch)
            mul!(view(vali, :, :, b), scratch', scratch)
        end
        push!(val, vali)
        offset += vi * ℓi
    end
    return val
end

"""
    _cvtbl(arr::Array{T,3}, trm) where {T}

Build a column table for the random-effects term `trm` from `arr`.

The result has a column named `fname(trm)` holding `trm.levels`, followed by columns
`σ` and `ρ` obtained by applying `sdcorr` to each slice `arr[:, :, k]`.
"""
function _cvtbl(arr::Array{T,3}, trm) where {T}
    grouping = NamedTuple{(fname(trm),)}((trm.levels,))
    sdcors = map(axes(arr, 3)) do k
        NamedTuple{(:σ, :ρ)}(sdcorr(view(arr, :, :, k)))
    end
    return merge(grouping, columntable(sdcors))
end

"""
condVartables(m::MixedModel)

Return the conditional covariance matrices of the random effects as a `NamedTuple` of columntables
"""
function condVartables(m::MixedModel{T}) where {T}
    # one column table per random-effects term, keyed by the grouping-factor names
    tables = map(_cvtbl, condVar(m), m.reterms)
    return NamedTuple{fnames(m)}((tables...,))
end

function pushALblock!(A, L, blk)
Expand Down Expand Up @@ -910,6 +943,74 @@ end

Base.show(io::IO, m::LinearMixedModel) = Base.show(io, MIME"text/plain"(), m)

"""
_coord(A::AbstractMatrix)

Return the positions and values of the nonzeros in `A` as a
`NamedTuple{(:i, :j, :v), Tuple{Vector{Int32}, Vector{Int32}, Vector{T}}}`,
where `T` is the element type of `A`.
"""
function _coord(A::Diagonal)
    # every diagonal position is reported; `v` is `A.diag` itself, not a copy
    rowinds = Int32.(axes(A, 1))
    colinds = Int32.(axes(A, 2))
    return (i = rowinds, j = colinds, v = A.diag)
end

function _coord(A::UniformBlockDiagonal)
    dat = A.data
    r, c, k = size(dat)
    # index shift of each block, repeated once for every element of that block
    blockshift = repeat(r .* (0:(k - 1)); inner=r * c)
    rowinds = repeat(1:r; outer=c * k) .+ blockshift
    colinds = repeat(1:c; inner=r, outer=k) .+ blockshift
    return (i = Int32.(rowinds), j = Int32.(colinds), v = vec(dat))
end

function _coord(A::SparseMatrixCSC{T,Int32}) where {T}
    rows = rowvals(A)
    cols = similar(rows)
    # expand the compressed column pointers into one column index per stored entry
    for col in axes(A, 2)
        cols[nzrange(A, col)] .= col
    end
    return (i = rows, j = cols, v = nonzeros(A))
end

function _coord(A::Matrix)
    # column-major enumeration of every position, matching the layout of `vec(A)`
    rowinds = Int32[r for _ in axes(A, 2) for r in axes(A, 1)]
    colinds = Int32[c for c in axes(A, 2) for _ in axes(A, 1)]
    return (i = rowinds, j = colinds, v = vec(A))
end

"""
sparseL(m::LinearMixedModel{T}; full::Bool=false) where {T}

Return the lower Cholesky factor `L` as a `SparseMatrixCSC{T,Int32}`.

`full` indicates whether the parts of `L` associated with the fixed-effects and response
are to be included.
"""
function sparseL(m::LinearMixedModel{T}; full::Bool=false) where {T}
    # Gather (row, column, value) triplets from each block in the lower triangle of
    # `m.L`, shifting block-local indices by the running row and column offsets.
    Lblocks = m.L
    nblk = length(m.reterms) + full    # `full` adds one more block row
    rows, cols, vals = Int32[], Int32[], T[]
    rowshift = 0
    for i in 1:nblk
        colshift = 0    # each block row starts again at the first column
        for j in 1:i
            Lblk = Lblocks[block(i, j)]
            coords = _coord(Lblk)
            append!(rows, coords.i .+ Int32(rowshift))
            append!(cols, coords.j .+ Int32(colshift))
            append!(vals, coords.v)
            if j < i
                colshift += size(Lblk, 2)
            else
                # the diagonal block closes this block row
                rowshift += size(Lblk, 1)
            end
        end
    end
    # discard entries above the diagonal and any stored zeros
    return dropzeros!(tril!(sparse(rows, cols, vals)))
end


"""
ssqdenom(m::LinearMixedModel)
Expand Down
26 changes: 25 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end

Return the average of `a` and `b`
"""
function average(a::T, b::T) where {T<:AbstractFloat}
    # the divisor is written in `T` so the result type never depends on literal promotion
    return (a + b) / T(2)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you run into trouble with this somewhere? As an integer, 2 should be promoted to whatever floating-point type (a+b) is.

Copy link
Collaborator Author

@dmbates dmbates May 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought there would be an issue with Float32 in that the literal 2 is an Int, which will get promoted to Float64 on a 64-bit system. But I was wrong:

julia> (1.0f0 + 2.0f0) / 2
1.5f0

so can go back to the original expression.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I would leave that in as it doesn't seem to be harmful in any way and it could be protection against unforeseen circumstances.


"""
cpad(s::AbstractString, n::Integer)
Expand Down Expand Up @@ -137,3 +137,27 @@ function replicate(f::Function, n::Integer;
end
results
end

"""
sdcorr(A::AbstractMatrix{T}) where {T}

Transform a square matrix `A` with positive diagonals into an `NTuple{size(A,1), T}` of
standard deviations and a tuple of correlations.

`A` is assumed to be symmetric and only the lower triangle is used. The order of the
correlations is row-major ordering of the lower triangle (or, equivalently, column-major
in the upper triangle).
"""
function sdcorr(A::AbstractMatrix{T}) where {T}
    nr, nc = size(A)
    if nr != nc
        throw(ArgumentError("matrix A must be square"))
    end
    prs = checkindprsk(nr)
    # standard deviations: square roots of the diagonal, held as a tuple
    sds = sqrt.(NTuple{nr,T}(diag(A)))
    # correlations: each indexed element scaled by its two standard deviations
    cors = ntuple(kchoose2(nr)) do k
        row, col = prs[k]
        return A[row, col] / (sds[row] * sds[col])
    end
    return (sds, cors)
end
4 changes: 1 addition & 3 deletions test/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using LinearAlgebra
using MixedModels
using Random
using StableRNGs
using Statistics
using Tables
using Test

Expand Down Expand Up @@ -34,7 +35,6 @@ include("modelcache.jl")
# restore the original state
refit!(fm, vec(float.(ds.yield)))
end

@testset "Poisson" begin
center(v::AbstractVector) = v .- (sum(v) / length(v))
grouseticks = DataFrame(dataset(:grouseticks))
Expand All @@ -50,7 +50,6 @@ include("modelcache.jl")
gm2sim = refit!(simulate!(StableRNG(42), deepcopy(gm2)), fast=true)
@test isapprox(gm2.β, gm2sim.β; atol=norm(stderror(gm2)))
end

@testset "_rand with dispersion" begin
@test_throws ArgumentError MixedModels._rand(StableRNG(42), Normal(), 1, 1, 1)
@test_throws ArgumentError MixedModels._rand(StableRNG(42), Gamma(), 1, 1, 1)
Expand Down Expand Up @@ -99,7 +98,6 @@ end
@test sum(issingular(bsamp)) == sum(issingular(bsamp_threaded))
end


@testset "Bernoulli simulate! and GLMM boostrap" begin
contra = dataset(:contra)
gm0 = fit(MixedModel, only(gfms[:contra]), contra, Bernoulli(), fast=true)
Expand Down
1 change: 0 additions & 1 deletion test/likelihoodratiotest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ end
fm1 = fit(MixedModel,@formula(reaction ~ 1 + days + (1+days|subj)),slp, REML=true);

@test_throws ArgumentError likelihoodratiotest(fm0,fm1)

contra = MixedModels.dataset(:contra);
# glm doesn't like categorical responses, so we convert it to numeric ourselves
# TODO: upstream fix
Expand Down
3 changes: 0 additions & 3 deletions test/mime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using MixedModels: pirls!, setβθ!, setθ!, updateL!

include("modelcache.jl")


# explicitly setting theta for these so that we can do exact textual comparisons
βθ = [0.1955554704948119, 0.05755412761885973, 0.3207843518569843, -1.0582595252774376,
-2.1047524824609853, -1.0549789653925743, 1.339766125847893, 0.4953047709862237]
Expand All @@ -32,7 +31,6 @@ lrt = likelihoodratiotest(fm0, fm1)
mime = MIME"text/markdown"()
@test_logs (:warn, "Model has not been fit: results will be nonsense") sprint(show, mime, gm3)
gm3.optsum.feval = 1

@testset "lmm" begin
@test sprint(show, mime, fm0) == """
| | Est. | SE | z | p | σ_subj |
Expand Down Expand Up @@ -153,7 +151,6 @@ end
@testset "html" begin
# this is minimal since we're mostly testing that dispatch works
# the stdlib actually handles most of the conversion

@test sprint(show, MIME("text/html"), BlockDescription(gm3)) == """
<table><tr><th align="left">rows</th><th align="left">subj</th><th align="left">item</th><th align="left">fixed</th></tr><tr><td align="left">316</td><td align="left">Diagonal</td><td align="left"></td><td align="left"></td></tr><tr><td align="left">24</td><td align="left">Dense</td><td align="left">Diag/Dense</td><td align="left"></td></tr><tr><td align="left">7</td><td align="left">Dense</td><td align="left">Dense</td><td align="left">Dense</td></tr></table>
"""
Expand Down
41 changes: 40 additions & 1 deletion test/pls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,12 @@ end
@test varest(fm) ≈ 0.3024263987592062 atol=0.0001
@test logdet(fm) ≈ 95.74614821367786 atol=0.001

@test_throws ArgumentError condVar(fm)
cv = condVar(fm)
@test length(cv) == 2
@test size(first(cv)) == (1, 1, 24)
@test size(last(cv)) == (1, 1, 6)
@test first(first(cv)) ≈ 0.07331320237988301 rtol=1.e-4
@test last(last(cv)) ≈ 0.04051547211287544 rtol=1.e-4

rfu = ranef(fm, uscale=true)
@test length(rfu) == 2
Expand Down Expand Up @@ -200,6 +205,13 @@ end
@test varest(fm) ≈ 0.6780020742644107 atol=0.0001
@test logdet(fm) ≈ 101.0381339953986 atol=0.001

cv = condVar(fm)
@test length(cv) == 2
@test size(first(cv)) == (1, 1, 30)
@test first(first(cv)) ≈ 1.111873335663485 rtol=1.e-4
@test size(last(cv)) == (1, 1, 10)
@test last(last(cv)) ≈ 0.850428770978789 rtol=1.e-4

show(io, BlockDescription(fm))
@test countlines(seekstart(io)) == 4
tokens = Set(split(String(take!(io)), r"\s+"))
Expand All @@ -217,6 +229,10 @@ end
@test fm1.optsum.initial == ones(3)
@test lowerbd(fm1) == zeros(3)

spL = sparseL(fm1)
@test size(spL) == (4114, 4114)
@test 733090 < nnz(spL) < 733100
dmbates marked this conversation as resolved.
Show resolved Hide resolved

@test objective(fm1) ≈ 237721.7687745563 atol=0.001
ftd1 = fitted(fm1);
@test size(ftd1) == (73421, )
Expand Down Expand Up @@ -291,6 +307,29 @@ end
@test size(first(u3)) == (2, 18)
@test first(u3)[1, 1] ≈ 3.030300122575336 atol=0.001

cv = condVar(fm)
@test length(cv) == 1
@test size(first(cv)) == (2, 2, 18)
@test first(first(cv)) ≈ 140.96612241084617 rtol=1.e-4
@test last(last(cv)) ≈ 5.157750215432247 rtol=1.e-4
@test first(cv)[2] ≈ -20.60428045516186 rtol=1.e-4

cvt = condVartables(fm)
@test length(cvt) == 1
@test only(keys(cvt)) == :subj
cvtsubj = cvt.subj
@test only(cvt) === cvtsubj
@test keys(cvtsubj) == (:subj, :σ, :ρ)
@test Tables.istable(cvtsubj)
@test first(cvtsubj.subj) == "S308"
cvtsubjσ1 = first(cvtsubj.σ)
@test all(==(cvtsubjσ1), cvtsubj.σ)
@test first(cvtsubjσ1) ≈ 11.87291549750297 atol=1.0e-4
@test last(cvtsubjσ1) ≈ 2.271068078114843 atol=1.0e-4
cvtsubjρ = first(cvtsubj.ρ)
@test all(==(cvtsubjρ), cvtsubj.ρ)
@test only(cvtsubjρ) ≈ -0.7641347018831385 atol=1.0e-4

b3 = ranef(fm)
@test length(b3) == 1
@test size(first(b3)) == (2, 18)
Expand Down