Skip to content

Commit

Permalink
Fix predict for drop_intercept models (#76)
Browse files Browse the repository at this point in the history
* - fixed predict bug with drop_intercept models and categorical variables
- add tests that used to fail but now pass

* reverted end of line changes
  • Loading branch information
AsafManela authored and kleinschmidt committed Jan 8, 2019
1 parent f9c247d commit a45ad63
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/statsmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,14 @@ StatsBase.r2(mm::DataFrameRegressionModel, variant::Symbol) = r2(mm.model, varia
StatsBase.adjr2(mm::DataFrameRegressionModel, variant::Symbol) = adjr2(mm.model, variant)

# Predict function that takes data frame as predictor instead of matrix
function StatsBase.predict(mm::DataFrameRegressionModel, df::AbstractDataFrame; kwargs...)
function StatsBase.predict(mm::DataFrameRegressionModel{T}, df::AbstractDataFrame; kwargs...) where T
# copy terms, removing outcome if present (ModelFrame will complain if a
# term is not found in the DataFrame and we don't want to remove elements with missing y)
newTerms = dropresponse!(mm.mf.terms)
# create new model frame/matrix
drop_intercept(T) && (newTerms.intercept = true)
mf = ModelFrame(newTerms, df; contrasts = mm.mf.contrasts)
drop_intercept(T) && (mf.terms.intercept = false)
newX = ModelMatrix(mf).m
yp = predict(mm, newX; kwargs...)
out = missings(eltype(yp), size(df, 1))
Expand Down
8 changes: 8 additions & 0 deletions test/statsmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ StatsBase.coeftable(mod::DummyModNoIntercept) =
["'beta' value"],
["" for n in 1:size(mod.x,2)],
0)
# Predict from the model's own stored `x`: returns `mod.x * mod.beta`.
function StatsBase.predict(mod::DummyModNoIntercept)
    return mod.x * mod.beta
end
# Predict for a caller-supplied matrix `newX`: returns `newX * mod.beta`.
function StatsBase.predict(mod::DummyModNoIntercept, newX::Matrix)
    return newX * mod.beta
end

## Another dummy model type to test fall-through show method
struct DummyModTwo <: RegressionModel
Expand Down Expand Up @@ -122,14 +124,20 @@ Base.show(io::IO, m::DummyModTwo) = println(io, m.msg)
ct1 = coeftable(m1)
ct2 = coeftable(m2)
@test ct1.rownms == ct2.rownms == ["x1", "x2", "x1 & x2"]
@test predict(m1, d[2:4, :]) == predict(m1)[2:4]
@test predict(m2, d[2:4, :]) == predict(m2)[2:4]

# Categorical predictor `x1p`, written with an explicit `1 +` and an explicit `0 +`
# intercept term. Both fits are expected to report the same coefficient names
# (asserted below).
f1 = @formula(y ~ 1 + x1p)
f2 = @formula(y ~ 0 + x1p)
m1 = fit(DummyModNoIntercept, f1, d)
m2 = fit(DummyModNoIntercept, f2, d)
# Same formula as `m1`, but with non-default (effects coding) contrasts for `x1p`,
# so that `predict` on a data frame is also exercised with user-supplied contrasts.
m3 = fit(DummyModNoIntercept, f1, d, contrasts = Dict(:x1p => EffectsCoding()))
ct1 = coeftable(m1)
ct2 = coeftable(m2)
@test ct1.rownms == ct2.rownms == ["x1p: 6", "x1p: 7", "x1p: 8"]
# Predicting on a row subset of the data frame must match the corresponding
# entries of the no-argument `predict` for each model.
@test predict(m1, d[2:4, :]) == predict(m1)[2:4]
@test predict(m2, d[2:4, :]) == predict(m2)[2:4]
@test predict(m3, d[2:4, :]) == predict(m3)[2:4]

m2 = fit(DummyModTwo, f, d)
show(io, m2)
Expand Down

0 comments on commit a45ad63

Please sign in to comment.