From 49fe437537545afb17b90983a0275d892e0d047f Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Tue, 11 Sep 2018 13:05:20 +0200 Subject: [PATCH 1/2] Fix predict with confidence intervals Use keyword arguments, which are passed through by the DataFrameRegressionModel method. --- docs/src/index.md | 9 +++++---- src/GLM.jl | 1 + src/deprecated.jl | 1 + src/lm.jl | 20 +++++++++++++------- test/runtests.jl | 4 ++-- 5 files changed, 22 insertions(+), 13 deletions(-) create mode 100644 src/deprecated.jl diff --git a/docs/src/index.md b/docs/src/index.md index cf07c567..2d7b86d3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -81,16 +81,17 @@ julia> predict(ols) 4.33333 6.83333 -``` - +``` + +The columns of the matrix are prediction, 95% lower and upper confidence bounds +. ### Probit Regression: ```jldoctest diff --git a/src/GLM.jl b/src/GLM.jl index 4cf9bfb3..807019fd 100644 --- a/src/GLM.jl +++ b/src/GLM.jl @@ -99,5 +99,6 @@ module GLM include("glmfit.jl") include("ftest.jl") include("negbinfit.jl") + include("deprecated.jl") end # module diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 00000000..5f81a29f --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1 @@ +@Base.deprecate predict(mm::LinearModel, newx::AbstractMatrix, interval::Symbol, level::Real = 0.95) predict(mm, newx; interval=interval, level=level) \ No newline at end of file diff --git a/src/lm.jl b/src/lm.jl index 812d5110..edf0e19e 100644 --- a/src/lm.jl +++ b/src/lm.jl @@ -186,19 +186,25 @@ function coeftable(mm::LinearModel) ["x$i" for i = 1:size(mm.pp.X, 2)], 4) end -predict(mm::LinearModel, newx::AbstractMatrix) = newx * coef(mm) - """ - predict(mm::LinearModel, newx::AbstractMatrix, interval_type::Symbol, level::Real = 0.95) + predict(mm::LinearModel, newx::AbstractMatrix; + interval::Union{Symbol,Nothing} = nothing, level::Real = 0.95) -Specifying `interval_type` will return a 3-column matrix with the prediction and +If `interval` is `nothing` (the default), return a vector with the predicted values +for model `mm` and new data `newx`. +Otherwise, return a 3-column matrix with the prediction and the lower and upper confidence bounds for a given `level` (0.95 equates alpha = 0.05). -Valid values of `interval_type` are `:confint` delimiting the uncertainty of the +Valid values of `interval` are `:confint` delimiting the uncertainty of the predicted relationship, and `:predint` delimiting estimated bounds for new data points. """ -function predict(mm::LinearModel, newx::AbstractMatrix, interval_type::Symbol, level::Real = 0.95) +function predict(mm::LinearModel, newx::AbstractMatrix; + interval::Union{Symbol,Nothing}=nothing, level::Real = 0.95) retmean = newx * coef(mm) - interval_type == :confint || error("only :confint is currently implemented") #:predint will be implemented + if interval === nothing + return retmean + elseif interval !== :confint + error("only :confint is currently implemented") #:predint will be implemented + end length(mm.rr.wts) == 0 || error("prediction with confidence intervals not yet implemented for weighted regression") R = cholesky!(mm.pp).U #get the R matrix from the QR factorization diff --git a/test/runtests.jl b/test/runtests.jl index 89319505..48dcb84e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -438,7 +438,7 @@ end @test isapprox(predict(gm12, newX, offset=newoff), logistic.(newX * coef(gm12) .+ newoff)) - # Prediction from DataFrames + # Prediction from DataFrames d = convert(DataFrame, X) d[:y] = Y @@ -451,7 +451,7 @@ end Ylm = X * [0.8, 1.6] + 0.8randn(10) mm = fit(LinearModel, X, Ylm) - pred = predict(mm, newX, :confint) + pred = predict(mm, newX, interval=:confint) @test isapprox(pred[1,2], 0.6122189104014528) @test isapprox(pred[2,2], -0.33530477814532056) From d7cc9c2ef75ff1315801427eaee62a20b8cad9b3 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Tue, 11 Sep 2018 13:27:16 +0200 Subject: [PATCH 2/2] Add test for predicted values --- test/runtests.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 48dcb84e..497d5ab3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -451,18 +451,18 @@ end Ylm = X * [0.8, 1.6] + 0.8randn(10) mm = fit(LinearModel, X, Ylm) - pred = predict(mm, newX, interval=:confint) - - @test isapprox(pred[1,2], 0.6122189104014528) - @test isapprox(pred[2,2], -0.33530477814532056) - @test isapprox(pred[3,2], 1.340413688904295) - @test isapprox(pred[4,2], 0.02118806218116165) - @test isapprox(pred[5,2], 0.8543142404183606) - @test isapprox(pred[1,3], 2.6853964084909836) - @test isapprox(pred[2,3], 1.2766396685055916) - @test isapprox(pred[3,3], 3.6617479283005894) - @test isapprox(pred[4,3], 0.6477623101170038) - @test isapprox(pred[5,3], 2.564532433982956) + pred1 = predict(mm, newX) + pred2 = predict(mm, newX, interval=:confint) + + @test pred1 == pred2[:, 1] ≈ + [1.6488076594462182, 0.4706674451801356, 2.5010808086024423, + 0.3344751861490827, 1.7094233372006582] + @test pred2[:, 2:3] ≈ [ 0.6122189104014528 2.6853964084909836 + -0.33530477814532056 1.2766396685055916 + 1.340413688904295 3.6617479283005894 + 0.02118806218116165 0.6477623101170038 + 0.8543142404183606 2.564532433982956] + end @testset "F test for model comparison" begin