Skip to content

Commit

Permalink
add Grouping contrasts (#339)
Browse files Browse the repository at this point in the history
* add Grouping contrasts

* export and add a getproperty patch to work around statsmodels

* warn users to use Grouping when OOM during schema call

* Update src/grouping.jl

Co-authored-by: Phillip Alday <phillip.alday@mpi.nl>
  • Loading branch information
kleinschmidt and palday committed Jul 15, 2020
1 parent 6fabb48 commit 2f9d423
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/MixedModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export @formula,
BlockedSparse,
DummyCoding,
EffectsCoding,
Grouping,
Gamma,
GeneralizedLinearMixedModel,
HelmertCoding,
Expand Down Expand Up @@ -155,5 +156,6 @@ include("linalg.jl")
include("simulate.jl")
include("bootstrap.jl")
include("blockdescription.jl")
include("grouping.jl")

end # module
34 changes: 34 additions & 0 deletions src/grouping.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

"""
    struct Grouping <: StatsModels.AbstractContrasts end

A placeholder contrasts type marking a categorical variable as used only for
grouping, never for contrast coding. When a [`CategoricalTerm`](@ref) is created
with `Grouping`, construction of the contrasts matrix is skipped, which keeps it
robust to very large numbers of levels. The vector of levels and the
level-to-index mapping (the `invindex` field of the [`ContrastsMatrix`](@ref))
are still retained.

Note that calling `modelcols` on a `CategoricalTerm{Grouping}` is an error.

# Examples
```julia
julia> schema((; grp = string.(1:100_000)))
# out-of-memory error

julia> schema((; grp = string.(1:100_000)), Dict(:grp => Grouping()))
```
"""
struct Grouping <: StatsModels.AbstractContrasts end

# Grouping carries no contrast coding, so its contrasts matrix is an empty
# (0×0) Float64 matrix regardless of the base index or the number of levels.
function StatsModels.contrasts_matrix(::Grouping, baseind, n)
    return zeros(Float64, 0, 0)
end
# The term names for a Grouping variable are its levels, passed through as-is;
# the base index is accepted for interface compatibility but not used.
function StatsModels.termnames(::Grouping, levels::AbstractVector, baseind::Integer)
    return levels
end

# Workaround: StatsModels currently assumes every contrasts object exposes a
# `.levels` field. Grouping has none, so report `nothing` for `:levels` and
# defer to `getfield` for any other property name.
function Base.getproperty(g::Grouping, prop::Symbol)
    if prop === :levels
        return nothing
    end
    return getfield(g, prop)
end

# Categorical terms with Grouping contrasts are special-cased: they have no
# contrasts matrix, so asking for their model columns is always an error.
function StatsModels.modelcols(::CategoricalTerm{Grouping}, d::NamedTuple)
    error("can't create model columns directly from a Grouping term")
end
10 changes: 9 additions & 1 deletion src/linearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,15 @@ function LinearMixedModel(
# TODO: perform missing_omit() after apply_schema() when improved
# missing support is in a StatsModels release
tbl, _ = StatsModels.missing_omit(tbl, f)
form = apply_schema(f, schema(f, tbl, contrasts), LinearMixedModel)
sch = try
schema(f, tbl, contrasts)
catch e
if isa(e, OutOfMemoryError)
@warn "Random effects grouping variables with many levels can cause out-of-memory errors. Try manually specifying `Grouping()` contrasts for those variables."
end
rethrow(e)
end
form = apply_schema(f, sch, LinearMixedModel)
# tbl, _ = StatsModels.missing_omit(tbl, form)

y, Xs = modelcols(form, tbl)
Expand Down
21 changes: 21 additions & 0 deletions test/grouping.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Test
using StatsModels

@testset "Grouping pseudo-contrasts" begin
    # Two observations for each of one million distinct grouping levels.
    nlevels = 1_000_000
    ids = string.(1:nlevels)
    data = (y = rand(2 * nlevels), grp = vcat(ids, ids))

    ## OOM seems to result in the process being killed on Mac so this messes up CI
    # @test_throws OutOfMemoryError schema(d)
    grpschema = schema(data, Dict(:grp => Grouping()))
    grpterm = grpschema[term(:grp)]
    @test grpterm isa CategoricalTerm{Grouping}
    @test size(grpterm.contrasts.matrix) == (0, 0)
    @test length(grpterm.contrasts.levels) == 1_000_000

    # Levels are strings, so the expected order is the lexicographic one.
    expected = sort(ids)

    @test all(i -> grpterm.contrasts.invindex[expected[i]] == i, eachindex(expected))
    @test all(i -> grpterm.contrasts.levels[i] == expected[i], eachindex(expected))

    # @test_throws OutOfMemoryError fit(MixedModel, @formula(y ~ 1 + (1 | grp)), d)
    model = fit(
        MixedModel,
        @formula(y ~ 1 + (1 | grp)),
        data;
        contrasts = Dict(:grp => Grouping()),
    )
    @test model isa LinearMixedModel
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ include("gausshermite.jl")
include("fit.jl")
include("missing.jl")
include("likelihoodratiotest.jl")
include("grouping.jl")

0 comments on commit 2f9d423

Please sign in to comment.