Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix predict with confidence intervals #253

Merged
merged 3 commits into from
Nov 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,17 @@ julia> predict(ols)
4.33333
6.83333

```
<!-- Andreas Noack: As of 9 May 2018 this example doesn't work so I've temporarily commented it out
julia> newX = DataFrame(X=[2,3,4]);

julia> predict(ols, newX, :confint)
julia> predict(ols, newX, interval=:confidence)
3×3 Array{Float64,2}:
4.33333 1.33845 7.32821
6.83333 2.09801 11.5687
9.33333 1.40962 17.257
The columns of the matrix are prediction, 95% lower and upper confidence bounds -->
```

The columns of the matrix are the prediction and the 95% lower and upper confidence bounds.

### Probit Regression:
```jldoctest
Expand Down
1 change: 1 addition & 0 deletions src/GLM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,6 @@ module GLM
include("glmfit.jl")
include("ftest.jl")
include("negbinfit.jl")
include("deprecated.jl")

end # module
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Deprecated positional form `predict(mm, newx, interval, level)`: forwards `interval`
# and `level` to the keyword-argument call `predict(mm, newx; interval=..., level=...)`.
@Base.deprecate predict(mm::LinearModel, newx::AbstractMatrix, interval::Symbol, level::Real = 0.95) predict(mm, newx; interval=interval, level=level)
26 changes: 18 additions & 8 deletions src/lm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,19 +186,29 @@ function coeftable(mm::LinearModel)
["x$i" for i = 1:size(mm.pp.X, 2)], 4)
end

predict(mm::LinearModel, newx::AbstractMatrix) = newx * coef(mm)

"""
predict(mm::LinearModel, newx::AbstractMatrix, interval_type::Symbol, level::Real = 0.95)
predict(mm::LinearModel, newx::AbstractMatrix;
interval::Union{Symbol,Nothing} = nothing, level::Real = 0.95)

Specifying `interval_type` will return a 3-column matrix with the prediction and
If `interval` is `nothing` (the default), return a vector with the predicted values
for model `mm` and new data `newx`.
Otherwise, return a 3-column matrix with the prediction and
the lower and upper confidence bounds for a given `level` (a `level` of 0.95 corresponds to alpha = 0.05).
Valid values of `interval_type` are `:confint` delimiting the uncertainty of the
predicted relationship, and `:predint` delimiting estimated bounds for new data points.
Valid values of `interval` are `:confidence` delimiting the uncertainty of the
predicted relationship, and `:prediction` delimiting estimated bounds for new data points.
"""
function predict(mm::LinearModel, newx::AbstractMatrix, interval_type::Symbol, level::Real = 0.95)
function predict(mm::LinearModel, newx::AbstractMatrix;
interval::Union{Symbol,Nothing}=nothing, level::Real = 0.95)
retmean = newx * coef(mm)
interval_type == :confint || error("only :confint is currently implemented") #:predint will be implemented
# `:confint` is the deprecated spelling of `:confidence`: warn, then handle it identically.
if interval === :confint
    # `Base.depwarn` takes the name of the deprecated function as a required second
    # positional argument; calling it with only the message raises a MethodError.
    Base.depwarn("interval=:confint is deprecated in favor of interval=:confidence", :predict)
    interval = :confidence
end
if interval === nothing
    # No interval requested: return only the vector of predicted values.
    return retmean
elseif interval !== :confidence
    error("only interval=:confidence is currently implemented") # :prediction will be implemented
end
length(mm.rr.wts) == 0 || error("prediction with confidence intervals not yet implemented for weighted regression")

R = cholesky!(mm.pp).U #get the R matrix from the QR factorization
Expand Down
26 changes: 13 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ end
@test isapprox(predict(gm12, newX, offset=newoff),
logistic.(newX * coef(gm12) .+ newoff))

# Prediction from DataFrames
# Prediction from DataFrames
d = convert(DataFrame, X)
d[:y] = Y

Expand All @@ -451,18 +451,18 @@ end

Ylm = X * [0.8, 1.6] + 0.8randn(10)
mm = fit(LinearModel, X, Ylm)
pred = predict(mm, newX, :confint)

@test isapprox(pred[1,2], 0.6122189104014528)
@test isapprox(pred[2,2], -0.33530477814532056)
@test isapprox(pred[3,2], 1.340413688904295)
@test isapprox(pred[4,2], 0.02118806218116165)
@test isapprox(pred[5,2], 0.8543142404183606)
@test isapprox(pred[1,3], 2.6853964084909836)
@test isapprox(pred[2,3], 1.2766396685055916)
@test isapprox(pred[3,3], 3.6617479283005894)
@test isapprox(pred[4,3], 0.6477623101170038)
@test isapprox(pred[5,3], 2.564532433982956)
pred1 = predict(mm, newX)
pred2 = predict(mm, newX, interval=:confidence)

@test pred1 == pred2[:, 1] ≈
[1.6488076594462182, 0.4706674451801356, 2.5010808086024423,
0.3344751861490827, 1.7094233372006582]
@test pred2[:, 2:3] ≈ [ 0.6122189104014528 2.6853964084909836
-0.33530477814532056 1.2766396685055916
1.340413688904295 3.6617479283005894
0.02118806218116165 0.6477623101170038
0.8543142404183606 2.564532433982956]

end

@testset "F test for model comparison" begin
Expand Down