Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extension of the GLM interface #27

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ version = "0.2.5"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
8 changes: 4 additions & 4 deletions src/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import ..DecisionTree # strange syntax b/s we are lazy-loading
"""
DecisionTreeClassifer(; kwargs...)

A variationn on the CART decision tree classifier from
[https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md](https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md).
A variation on the CART decision tree classifier from
[https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md](https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md).

Instead of predicting the mode class at each leaf, a UnivariateFinite
distribution is fit to the leaf training classes, with smoothing
Expand Down Expand Up @@ -120,7 +120,7 @@ function MLJBase.fit(model::DecisionTreeClassifier

yplain = identity.(y) # y as plain not abstract vector
classes_seen = unique(yplain)

tree = DecisionTree.build_tree(yplain,
Xmatrix,
model.n_subfeatures,
Expand Down Expand Up @@ -271,7 +271,7 @@ end

## METADATA

DTTypes=Union{DecisionTreeClassifier,DecisionTreeRegressor}
const DTTypes = Union{DecisionTreeClassifier,DecisionTreeRegressor}

MLJBase.package_name(::Type{<:DTTypes}) = "DecisionTree"
MLJBase.package_uuid(::Type{<:DTTypes}) = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Expand Down
279 changes: 196 additions & 83 deletions src/GLM.jl
Original file line number Diff line number Diff line change
@@ -1,132 +1,245 @@
module GLM_

import MLJBase
# -------------------------------------------------------------------
# TODO
# - return feature names in the report
# - handle binomial case properly, needs MLJ API change for weighted
# samples (y/N ~ Be(p) with weights N)
# - handle levels properly (see GLM.jl/issues/240); if feed something
# with levels, the fit will fail.
# - revisit and test Poisson and Negbin regression once there's a clear
# example we can test on (requires handling levels which deps upon GLM)
# - test Logit, Probit etc on Binomial once binomial case is handled
# -------------------------------------------------------------------

export OLSRegressor, OLS,
GLMCountRegressor, GLMCount
import MLJBase
import Distributions
using DocStringExtensions

export GLMRegressor,
OLSRegressor,
PoissonRegressor,
LogitRegressor, LogisticRegressor,
ProbitRegressor,
CauchitRegressor,
CloglogRegressor,
NegativeBinomialRegressor

import ..GLM

####
#### HELPER FUNCTIONS
####

"""
$SIGNATURES

Augment the matrix `X` with a column of ones if the intercept is to be fitted.

The intercept column uses `eltype(X)` (the original used `Int`), so the
concatenation never promotes the element type of the design matrix.
"""
augment_X(m::MLJBase.Model, X::Matrix)::Matrix =
    m.fit_intercept ? hcat(X, ones(eltype(X), size(X, 1))) : X

## TODO: add feature importance curve to report using `features`
"""
$SIGNATURES

Build the report for a fitted GLM model: its deviance, residual degrees of
freedom, and the standard errors and covariance matrix of the coefficients.
"""
function glm_report(fitresult)
    return (deviance     = GLM.deviance(fitresult),
            dof_residual = GLM.dof_residual(fitresult),
            stderror     = GLM.stderror(fitresult),
            vcov         = GLM.vcov(fitresult))
end

####
#### REGRESSION TYPES
####

mutable struct OLSRegressor <: MLJBase.Probabilistic
fit_intercept::Bool
# allowrankdeficient::Bool
end
"""
GLMRegressor

Generalized Linear model corresponding to:

``y ∼ D(θ),``
``g(θ) = Xβ``

OLSRegressor(;fit_intercept=true) = OLSRegressor(fit_intercept)
where

mutable struct GLMCountRegressor <: MLJBase.Probabilistic
* `y` is the observed response vector
* `D(θ)` is a parametric distribution (typically in the exponential family)
* `g` is the link function
* `X` is the design matrix (possibly with a column of ones if an intercept is to be fitted)
* `β` is the vector of coefficients
"""
mutable struct GLMRegressor <: MLJBase.Probabilistic
distr::Distributions.Distribution
link::GLM.Link
fit_intercept::Bool
# link
allowrankdeficient::Bool # OLS only
end

GLMCountRegressor(;fit_intercept=true) = GLMCountRegressor(fit_intercept)
"""
$SIGNATURES

# synonyms
const OLS = OLSRegressor
const GLMCount = GLMCountRegressor
Model for continuous regression with

####
#### FIT FUNCTIONS
####
``y ∼ N(μ, σ²)``

function MLJBase.fit(model::OLS, verbosity::Int, X, y)
where `μ=Xβ` and `N` denotes a normal distribution.
See also [`GLMRegressor`](@ref).
"""
OLSRegressor(; fit_intercept=true, allowrankdeficient=false) =
GLMRegressor(Distributions.Normal(), GLM.IdentityLink(), fit_intercept, allowrankdeficient)

Xmatrix = MLJBase.matrix(X)
features = MLJBase.schema(X).names
model.fit_intercept && (Xmatrix = hcat(Xmatrix, ones(eltype(Xmatrix), size(Xmatrix, 1), 1)))
"""
$SIGNATURES

fitresult = GLM.lm(Xmatrix, y)
Model for count regression with

## TODO: add feature importance curve to report using `features`
report = (deviance=GLM.deviance(fitresult)
, dof_residual=GLM.dof_residual(fitresult)
, stderror=GLM.stderror(fitresult)
, vcov=GLM.vcov(fitresult))
cache = nothing
``y ∼ Poi(λ)``

return fitresult, cache, report
end
where `log(λ) = Xβ` and `Poi` denotes a Poisson distribution.
See also [`GLMRegressor`](@ref).
"""
PoissonRegressor(; fit_intercept=true) =
GLMRegressor(Distributions.Poisson(), GLM.LogLink(), fit_intercept, false)

function MLJBase.fitted_params(model::OLS, fitresult)
coefs = GLM.coef(fitresult)
return (coef=coefs[1:end-Int(model.fit_intercept)],
intercept=ifelse(model.fit_intercept, coefs[end], nothing))
end
"""
$SIGNATURES

Model for count regression with

function MLJBase.fit(model::GLMCount, verbosity::Int, X, y)
``y ∼ NB(λ)``

Xmatrix = MLJBase.matrix(X)
features = MLJBase.schema(X).names
model.fit_intercept && (Xmatrix = hcat(Xmatrix, ones(eltype(Xmatrix), size(Xmatrix, 1), 1)))
where `log(λ) = Xβ` (the constructor uses a log link) and `NB` denotes a Negative Binomial distribution.
See also [`GLMRegressor`](@ref).
"""
NegativeBinomialRegressor(; fit_intercept=true, r=1.) =
GLMRegressor(Distributions.NegativeBinomial(r), GLM.LogLink(), fit_intercept, false)

fitresult = GLM.glm(Xmatrix, y, GLM.Poisson()) # Log link
"""
$SIGNATURES

## TODO: add feature importance curve to report using `features`
report = (deviance=GLM.deviance(fitresult)
, dof_residual=GLM.dof_residual(fitresult)
, stderror=GLM.stderror(fitresult)
, vcov=GLM.vcov(fitresult))
cache = nothing
Model for bernoulli (binary) and binomial regression with logit link:

``y ∼ Be(p)`` or ``y ∼ Bin(n, p)``

where `logit(p) = Xβ`.
See also [`GLMRegressor`](@ref), [`ProbitRegressor`](@ref), [`CauchitRegressor`](@ref),
[`CloglogRegressor`](@ref).
"""
LogitRegressor(; fit_intercept=true, distr=Distributions.Bernoulli()) =
    GLMRegressor(distr, GLM.LogitLink(), fit_intercept, false)
# Synonym; `const` so the global binding is type-stable (matches the
# `const DTTypes` convention used elsewhere in this package).
const LogisticRegressor = LogitRegressor

"""
$SIGNATURES

Model for bernoulli (binary) and binomial regression using a probit link:

``y ∼ Be(p)`` or ``y ∼ Bin(n, p)``

with `probit(p) = Xβ`.
See also [`GLMRegressor`](@ref), [`LogitRegressor`](@ref), [`CauchitRegressor`](@ref),
[`CloglogRegressor`](@ref).
"""
function ProbitRegressor(; fit_intercept=true, distr=Distributions.Bernoulli())
    return GLMRegressor(distr, GLM.ProbitLink(), fit_intercept, false)
end

"""
$SIGNATURES

Model for bernoulli (binary) and binomial regression using a cauchit link:

``y ∼ Be(p)`` or ``y ∼ Bin(n, p)``

with `cauchit(p) = Xβ`.
See also [`GLMRegressor`](@ref), [`ProbitRegressor`](@ref), [`LogitRegressor`](@ref),
[`CloglogRegressor`](@ref).
"""
function CauchitRegressor(; fit_intercept=true, distr=Distributions.Bernoulli())
    return GLMRegressor(distr, GLM.CauchitLink(), fit_intercept, false)
end

"""
$SIGNATURES

Model for bernoulli (binary) and binomial regression with complementary log-log link:

``y ∼ Be(p)`` or ``y ∼ Bin(n, p)``

where `cloglog(p) = Xβ`.
See also [`GLMRegressor`](@ref), [`ProbitRegressor`](@ref), [`LogitRegressor`](@ref),
[`CauchitRegressor`](@ref).
"""
CloglogRegressor(; fit_intercept=true, distr=Distributions.Bernoulli()) =
    GLMRegressor(distr, GLM.CloglogLink(), fit_intercept, false)

####
#### FIT
####

"""
$SIGNATURES

Fit a `GLMRegressor` to the table `X` and response `y`, returning the
`(fitresult, cache, report)` triple expected by MLJBase.
"""
function MLJBase.fit(model::GLMRegressor, verbosity::Int, X, y)
    # Feature names are extracted for a future feature-importance report.
    features  = MLJBase.schema(X).names
    # Design matrix, with an intercept column appended when requested.
    Xmatrix   = augment_X(model, MLJBase.matrix(X))
    fitresult = GLM.glm(Xmatrix, y, model.distr, model.link)
    # No cache is kept between fits.
    return fitresult, nothing, glm_report(fitresult)
end

"""
$SIGNATURES

Return the fitted coefficients and the intercept (or `nothing` when no
intercept was fitted) of a `GLMRegressor`.
"""
function MLJBase.fitted_params(model::GLMRegressor, fitresult)
    coefs = GLM.coef(fitresult)
    # When an intercept is fitted it is the LAST coefficient (see `augment_X`).
    # A ternary is used instead of `ifelse`, which eagerly evaluates both
    # branches (and would index `coefs[end]` even when unused).
    return (coef = coefs[1:end - Int(model.fit_intercept)],
            intercept = model.fit_intercept ? coefs[end] : nothing)
end

####
#### PREDICT FUNCTIONS
#### PREDICT
####

function MLJBase.predict_mean(model::Union{OLS, GLMCount}
, fitresult
, Xnew)
Xmatrix = MLJBase.matrix(Xnew)
model.fit_intercept && (Xmatrix = hcat(Xmatrix, ones(eltype(Xmatrix), size(Xmatrix, 1), 1)))
"""
$SIGNATURES

Return the mean response for each row of `Xnew` under the fitted model.
"""
function MLJBase.predict_mean(model::GLMRegressor, fitresult, Xnew)
    # Intercept column is appended if the model was fitted with one.
    return GLM.predict(fitresult, augment_X(model, MLJBase.matrix(Xnew)))
end

function MLJBase.predict(model::OLS, fitresult, Xnew)
Xmatrix = MLJBase.matrix(Xnew)
model.fit_intercept && (Xmatrix = hcat(Xmatrix, ones(eltype(Xmatrix), size(Xmatrix, 1), 1)))
μ = GLM.predict(fitresult, Xmatrix)
σ̂ = GLM.dispersion(fitresult, false)
return [GLM.Normal(μᵢ, σ̂) for μᵢ ∈ μ]
end

function MLJBase.predict(model::GLMCount, fitresult, Xnew)
Xmatrix = MLJBase.matrix(Xnew)
model.fit_intercept && (Xmatrix = hcat(Xmatrix, ones(eltype(Xmatrix), size(Xmatrix, 1), 1)))
λ = GLM.predict(fitresult, Xmatrix)
return [GLM.Poisson(λᵢ) for λᵢ ∈ λ]
"""
$SIGNATURES

Return a vector of predictive distributions, one per row of `Xnew`, for the
distribution/link pair stored in `model`. Throws an error for unsupported
pairs (previously these silently returned `nothing`).
"""
function MLJBase.predict(model::GLMRegressor, fitresult, Xnew)
    Xmatrix = augment_X(model, MLJBase.matrix(Xnew))

    # Ordinary least squares: Normal(μᵢ, σ̂) with a shared dispersion estimate.
    if isa(model.distr, Distributions.Normal) && isa(model.link, GLM.IdentityLink)
        μ = GLM.predict(fitresult, Xmatrix)
        σ̂ = GLM.dispersion(fitresult, false)
        return GLM.Normal.(μ, σ̂)

    # Poisson regression with log link.
    elseif isa(model.distr, Distributions.Poisson) && isa(model.link, GLM.LogLink)
        λ = GLM.predict(fitresult, Xmatrix)
        return GLM.Poisson.(λ)

    # Bernoulli/Binomial with one of the standard binary links.
    elseif isa(model.distr, Union{Distributions.Bernoulli,Distributions.Binomial}) &&
           isa(model.link, Union{GLM.LogitLink,GLM.ProbitLink,GLM.CauchitLink,GLM.CloglogLink})
        π = GLM.predict(fitresult, Xmatrix)
        if isa(model.distr, Distributions.Bernoulli)
            return GLM.Bernoulli.(π)
        else
            error("Binomial regression not yet supported")
        end

    # Fail loudly on unsupported combinations instead of falling through
    # and implicitly returning `nothing`.
    else
        error("Prediction not supported for distribution $(typeof(model.distr)) " *
              "with link $(typeof(model.link)).")
    end
end

####
#### METADATA
####

# shared metadata
const GLM_REGS = Union{Type{<:OLS}, Type{<:GLMCount}}
MLJBase.package_name(::GLM_REGS) = "GLM"
MLJBase.package_uuid(::GLM_REGS) = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
MLJBase.package_url(::GLM_REGS) = "https://github.com/JuliaStats/GLM.jl"
MLJBase.is_pure_julia(::GLM_REGS) = true

MLJBase.load_path(::Type{<:OLS}) = "MLJModels.GLM_.OLSRegressor"
MLJBase.input_scitype_union(::Type{<:OLS}) = MLJBase.Continuous
MLJBase.target_scitype_union(::Type{<:OLS}) = MLJBase.Continuous
MLJBase.input_is_multivariate(::Type{<:OLS}) = true

MLJBase.load_path(::Type{<:GLMCount}) = "MLJModels.GLM_.GLMCountRegressor"
MLJBase.input_scitype_union(::Type{<:GLMCount}) = MLJBase.Continuous
MLJBase.target_scitype_union(::Type{<:GLMCount}) = MLJBase.Count
MLJBase.input_is_multivariate(::Type{<:GLMCount}) = true
# Package-level metadata shared by all GLMRegressor instances.
MLJBase.package_name(::Type{<:GLMRegressor}) = "GLM"
MLJBase.package_uuid(::Type{<:GLMRegressor}) = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
MLJBase.package_url(::Type{<:GLMRegressor}) = "https://github.com/JuliaStats/GLM.jl"
MLJBase.is_pure_julia(::Type{<:GLMRegressor}) = true

# Model traits: load path for registry lookup and the scientific types of
# inputs/targets the model accepts.
MLJBase.load_path(::Type{<:GLMRegressor}) = "MLJModels.GLM_.GLMRegressor"
MLJBase.input_scitype_union(::Type{<:GLMRegressor}) = Union{MLJBase.Continuous, MLJBase.Count}
MLJBase.target_scitype_union(::Type{<:GLMRegressor}) = Union{MLJBase.Continuous, MLJBase.Count}
MLJBase.input_is_multivariate(::Type{<:GLMRegressor}) = true

end # module