Skip to content

Commit

Permalink
Add predict method (#185)
Browse files Browse the repository at this point in the history
* Add predict for FixedEffectModel

* Add predict method and tests

* Whitespace fixes

* More whitespace fixes

* Final attempt at fixing whitespace
  • Loading branch information
nilshg committed Nov 17, 2021
1 parent d49a9e0 commit 85c5cb5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
29 changes: 24 additions & 5 deletions src/FixedEffectModel.jl
Expand Up @@ -72,13 +72,33 @@ end

# predict, residuals, modelresponse
"""
    predict(x::FixedEffectModel, df)

Return predictions of the fitted model `x` for the rows of `df`.

The result is a `Vector{Union{Float64, Missing}}` with one entry per row of `df`.
Rows dropped by `StatsModels.missing_omit` for the right-hand-side terms stay
`missing`. For the remaining rows the prediction is the model matrix times `x.coef`,
plus, when the model has fixed effects, the row-wise sum of the fixed-effect estimates
joined from `x.fe` on `x.fekeys`.

# Throws
- A `String` if `df` is not an `AbstractDataFrame`.
- A `String` if the model has fixed effects but `x.fe` has no rows (the model was
  not estimated with `save = :fe` or `save = :all`).
"""
function StatsBase.predict(x::FixedEffectModel, df)
    # Require DataFrame input as we are using leftjoin and select from DataFrames here
    df isa AbstractDataFrame || throw("Predict requires an input of type DataFrame")

    fes = if has_fe(x)
        # Make sure there are FEs saved
        nrow(x.fe) > 0 || throw("No estimates for fixed effects found. Model needs to be estimated with save = :fe or :all for prediction to work.")

        # Join FE estimates onto the key columns of `df`; keys without a match get
        # `missing` estimates from the left join, which propagate through the sum.
        # NOTE(review): the addition below indexes `fes` by row position, so it relies
        # on `leftjoin` returning rows in the order of `df` — confirm for the
        # DataFrames version in use.
        joined = leftjoin(
            select(df, x.fekeys),
            unique(x.fe);
            on = x.fekeys,
            makeunique = true,
            matchmissing = :equal,
        )
        # Sum row-wise over every column that is not a key
        combine(joined, AsTable(Not(x.fekeys)) => sum)
    else
        nothing
    end

    df = StatsModels.columntable(df)
    # Drop rows with missing covariates first, then build the schema from the
    # remaining rows only.
    cols, nonmissings = StatsModels.missing_omit(df, MatrixTerm(x.formula_predict.rhs))
    formula_schema = apply_schema(
        x.formula_predict.rhs,
        schema(x.formula_predict.rhs, cols, x.contrasts),
        StatisticalModel,
    )
    new_x = modelmatrix(formula_schema, cols)

    # Start from all-missing so that omitted rows keep a `missing` prediction
    out = Vector{Union{Float64, Missing}}(missing, length(Tables.rows(df)))
    out[nonmissings] = new_x * x.coef

    if !isnothing(fes)
        out[nonmissings] .+= fes[nonmissings, 1]
    end

    return out
end

Expand Down Expand Up @@ -107,7 +127,6 @@ function StatsBase.residuals(x::FixedEffectModel)
!has_fe(x) && throw("To access residuals, use residuals(x, df::AbstractDataFrame")
x.residuals
end


"""
fe(x::FixedEffectModel; keepkeys = false)
Expand Down
23 changes: 23 additions & 0 deletions test/predict.jl
Expand Up @@ -34,7 +34,30 @@ model = @formula Sales ~ CPI + (Price ~ Pimin) + fe(State)
result = reg(df, model)
show(result)

## Tests for predict method
# Only a DataFrame is accepted as input
@test_throws "Predict requires an input of type DataFrame" predict(result, Matrix(df))

# Predicting from a model estimated without saved fixed effects must throw
@test_throws "No estimates for fixed effects found. Model needs to be estimated with save = :fe or :all for prediction to work." predict(result, df)

# Test basic functionality - adding 1 to price should increase prediction by coef
model = @formula Sales ~ Price + fe(State)
result = reg(df, model, save = :fe)
x = predict(result, DataFrame(Price = [1.0, 2.0], State = [1, 1]))
# Floating-point comparison, hence `≈` rather than `==`
@test last(x) - first(x) ≈ only(result.coef)

# Missing variables in covariates should yield missing prediction
x = predict(result, DataFrame(Price = [1.0, missing], State = [1, 1]))
@test ismissing(last(x))

# Missing variables in fixed effects should yield missing prediction
x = predict(result, DataFrame(Price = [1.0, 2.0], State = [1, missing]))
@test ismissing(last(x))

# Fixed effect levels not in the estimation data should yield missing prediction
x = predict(result, DataFrame(Price = [1.0, 2.0], State = [1, 111]))
@test ismissing(last(x))

##############################################################################
##
Expand Down

0 comments on commit 85c5cb5

Please sign in to comment.