Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wts keyword to lm #341

Merged
merged 22 commits into from
Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
232 changes: 232 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# This file is machine-generated - editing it directly is not advised
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this file.


[[Arpack]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6"
uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
version = "0.3.1"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"

[[BinaryProvider]]
deps = ["Libdl", "SHA"]
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.8"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.2.0"

[[DataAPI]]
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.1.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.6"

[[DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
version = "1.0.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "51d184211a807ddba5f48d06098e93986759ec43"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.21.9"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.3"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"

[[PDMats]]
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
git-tree-sha1 = "035f8d60ba2a22cb1d2580b1e0e5ce0cb05e4563"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.10"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "1af46bf083b9630a5b27d4fd94f496c5fca642a8"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.1.1"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"

[[Rmath]]
deps = ["BinaryProvider", "Libdl", "Random", "Statistics"]
git-tree-sha1 = "2bbddcb984a1d08612d0c4abb5b4774883f6fa98"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.6.0"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[ShiftedArrays]]
git-tree-sha1 = "22395afdcf37d6709a5a0766cc4a5ca52cb85ea0"
uuid = "1277b4bf-5013-50f5-be3d-901d8477a67a"
version = "1.0.0"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl"]
git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.8.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.32.0"

[[StatsFuns]]
deps = ["Rmath", "SpecialFunctions", "Test"]
git-tree-sha1 = "b3a4e86aa13c732b8a8c0ba0c3d3264f55e6bb3e"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.8.0"

[[StatsModels]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "ShiftedArrays", "SparseArrays", "StatsBase", "Tables"]
git-tree-sha1 = "f4a967704cae9426654cf4649561f263a6dc8419"
uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
version = "0.6.7"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[TableTraits]]
deps = ["IteratorInterfaceExtensions"]
git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e"
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
version = "1.0.0"

[[Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"]
git-tree-sha1 = "aaed7b3b00248ff6a794375ad6adf30f30ca5591"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "0.2.11"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
25 changes: 16 additions & 9 deletions src/lm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ mutable struct LmResp{V<:FPVector} <: ModResp # response in a linear model
new{V}(mu, off, wts, y)
end
end
LmResp(y::V) where {V<:FPVector} = LmResp{V}(fill!(similar(y), 0), similar(y, 0), similar(y, 0), y)

LmResp(y::AbstractVector{T}) where T<:Real = LmResp(float(y))
LmResp(y::FPVector; wts::FPVector = similar(y, 0)) = LmResp{typeof(y)}(fill!(similar(y), 0), similar(y, 0), wts, y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make wts a positional argument for consistency with GlmResp.

Also, no spaces around = for arguments, and break lines at 92 chars.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to make wts a positional argument because there are 4 fields in LmResp and wts is like the 4th argument. So it seemed weird to have the arguments taken be different from the fields of the object.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe but then GlmResp should be changed too... Not sure it's worth it.


function updateμ!(r::LmResp{V}, linPr::V) where V<:FPVector
LmResp(y::AbstractVector{T}; wts::AbstractVector{T} = similar(y, 0)) where T<:Real = LmResp(float(y), float(wts))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LmResp(y::AbstractVector{T}; wts::AbstractVector{T} = similar(y, 0)) where T<:Real = LmResp(float(y), float(wts))
LmResp(y::AbstractVector{<:Real}; wts::AbstractVector{<:Real}=similar(y, 0)) = LmResp(float(y), float(wts))


function updateμ!(r::LmResp{V}, linPr::V) where {V<:FPVector}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't needed.

n = length(linPr)
length(r.y) == n || error("length(linPr) is $n, should be $(length(r.y))")
length(r.offset) == 0 ? copyto!(r.mu, linPr) : broadcast!(+, r.mu, linPr, r.offset)
Expand Down Expand Up @@ -126,24 +127,30 @@ end
LinearAlgebra.cholesky(x::LinearModel) = cholesky(x.pp)

function StatsBase.fit!(obj::LinearModel)
installbeta!(delbeta!(obj.pp, obj.rr.y))
if length(obj.rr.wts) == 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if length(obj.rr.wts) == 0
if isempty(obj.rr.wts)

installbeta!(delbeta!(obj.pp, obj.rr.y))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

installbeta!(obj.pp) can be moved outside of the if AFAICT.

else
installbeta!(delbeta!(obj.pp, obj.rr.y, obj.rr.wts))
end
updateμ!(obj.rr, linpred(obj.pp, zero(eltype(obj.rr.y))))
return obj
end

function fit(::Type{LinearModel}, X::AbstractMatrix, y::AbstractVector,
allowrankdeficient::Bool=false)
fit!(LinearModel(LmResp(y), cholpred(X, allowrankdeficient)))
function fit(::Type{LinearModel}, X::AbstractMatrix, y::V,
allowrankdeficient::Bool=false; wts::V = similar(y, 0)) where {V<:FPVector}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why make this signature more restrictive?

fit!(LinearModel(LmResp(y, wts = wts), cholpred(X, allowrankdeficient)))
end

"""
lm(X, y, allowrankdeficient::Bool=false)
lm(X, y, allowrankdeficient::Bool=false; wts = nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lm(X, y, allowrankdeficient::Bool=false; wts = nothing)
lm(X, y, allowrankdeficient::Bool=false; wts=similar(y, 0))


An alias for `fit(LinearModel, X, y, allowrankdeficient)`

The arguments `X` and `y` can be a `Matrix` and a `Vector` or a `Formula` and a `DataFrame`.

The keyword argument `wts` can be a `Vector{Float64}`.
pdeffebach marked this conversation as resolved.
Show resolved Hide resolved
"""
lm(X, y, allowrankdeficient::Bool=false) = fit(LinearModel, X, y, allowrankdeficient)
lm(X, y, allowrankdeficient::Bool=false; kwargs...) = fit(LinearModel, X, y, allowrankdeficient; kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lm(X, y, allowrankdeficient::Bool=false; kwargs...) = fit(LinearModel, X, y, allowrankdeficient; kwargs...)
lm(X, y, allowrankdeficient::Bool=false; kwargs...) =
fit(LinearModel, X, y, allowrankdeficient; kwargs...)


dof(x::LinearModel) = length(coef(x)) + 1

Expand Down