-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add Grouping contrasts * export and add a getproperty patch to work around statsmodels * warn users to use Grouping when OOM during schema call * Update src/grouping.jl Co-authored-by: Phillip Alday <phillip.alday@mpi.nl>
- Loading branch information
1 parent
6fabb48
commit 2f9d423
Showing
5 changed files
with
67 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
|
||
""" | ||
struct Grouping <: StatsModels.AbstractContrasts end | ||
A placeholder type to indicate that a categorical variable is only used for | ||
grouping and not for contrasts. When creating a [`CategoricalTerm`](@ref), this | ||
skips constructing the contrasts matrix which makes it robust to large numbers | ||
of levels, while still holding onto the vector of levels and constructing the | ||
level-to-index mapping (`invindex` field of the [`ContrastsMatrix`](@ref).). | ||
Note that calling `modelcols` on a `CategoricalTerm{Grouping}` is an error. | ||
# Examples | ||
```julia | ||
julia> schema((; grp = string.(1:100_000))) | ||
# out-of-memory error | ||
julia> schema((; grp = string.(1:100_000)), Dict(:grp => Grouping())) | ||
""" | ||
struct Grouping <: StatsModels.AbstractContrasts | ||
end | ||
|
||
# return an empty matrix | ||
StatsModels.contrasts_matrix(::Grouping, baseind, n) = zeros(0,0) | ||
StatsModels.termnames(::Grouping, levels::AbstractVector, baseind::Integer) = levels | ||
|
||
# this is needed until StatsModels stops assuming all contrasts have a .levels field | ||
Base.getproperty(g::Grouping, prop::Symbol) = | ||
prop == :levels ? nothing : getfield(g, prop) | ||
|
||
# special-case categorical terms with Grouping contrasts. | ||
StatsModels.modelcols(::CategoricalTerm{Grouping}, d::NamedTuple) = | ||
error("can't create model columns directly from a Grouping term") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
using Test | ||
using StatsModels | ||
|
||
@testset "Grouping pseudo-contrasts" begin | ||
d = (y = rand(2_000_000), grp=string.([1:1_000_000; 1:1_000_000])) | ||
## OOM seems to result in the process being killed on Mac so this messes up CI | ||
# @test_throws OutOfMemoryError schema(d) | ||
sch = schema(d, Dict(:grp => Grouping())) | ||
t = sch[term(:grp)] | ||
@test t isa CategoricalTerm{Grouping} | ||
@test size(t.contrasts.matrix) == (0,0) | ||
@test length(t.contrasts.levels) == 1_000_000 | ||
|
||
levs = sort(string.(1:1_000_000)) | ||
|
||
@test all(t.contrasts.invindex[lev] == i for (i,lev) in enumerate(levs)) | ||
@test all(t.contrasts.levels[i] == lev for (i,lev) in enumerate(levs)) | ||
|
||
# @test_throws OutOfMemoryError fit(MixedModel, @formula(y ~ 1 + (1 | grp)), d) | ||
@test fit(MixedModel, @formula(y ~ 1 + (1 | grp)), d, contrasts=Dict(:grp => Grouping())) isa LinearMixedModel | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters