-
Notifications
You must be signed in to change notification settings - Fork 114
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Negative Binomial regression support (#238)
* minor testing * working on adding negative Binomial support, debugging now * debugging a division by 0 error * done with negative Binomial with dispersion parameter fixed. still need to construct more elaborate test cases because conversion may be an issue for non-traditional link functions. the next step is to estimate the dispersion parameter too * fixed the bug caused by mis-calculating the deviance for negative binomial * added a test case * a small test case * fixed typos * fixed typos * added RDatasets as a required package for testing * added another test to clarify the loglink vs the NB canonical link * added more details on the two test cases * done with estimating theta too * removed redundant struct * removed redundant println * took care of PR review comments * took care of review comments * took care of PR review comments * minor * addressed more PR comments * changed 1.0 to 1 * fixed the reparameterization issue * minor * added oftype(v, NaN) for type safety * added oftype(v, NaN) for type safety * trying latex with md * adding theta in html * added negative binomial regression section * added negative binomial regression section * put a listing in alphabetical order * added NegativeBinomialLink to the list * added InverseGaussian to the doc because it is already supported
- Loading branch information
1 parent
886f4fc
commit fc5baf3
Showing
7 changed files
with
264 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""
    mle_for_θ(y, μ, wts; maxIter=30, convTol=1.e-6)

Estimate the negative binomial parameter `θ` from responses `y` and fitted
means `μ`, holding `μ` fixed, by Newton iterations on the derivative of the
log-likelihood with respect to `θ`.

An empty `wts` means every observation has unit weight; otherwise each
per-observation term is multiplied by the matching entry of `wts`.

# Keywords
- `maxIter=30`: maximum number of Newton iterations.
- `convTol=1.e-6`: iteration stops once the absolute Newton step is at most
  this value.

# Returns
The value of `θ` at which the stopping rule was met (the final, below-tolerance
step is not applied).

# Throws
- `ConvergenceException(maxIter)` if the stopping rule is not met within
  `maxIter` iterations.
"""
function mle_for_θ(y::AbstractVector, μ::AbstractVector, wts::AbstractVector;
                   maxIter=30, convTol=1.e-6)
    # Decided up front so the local helpers below only read an already-set value.
    unit_weights = length(wts) == 0

    # Sum `term(yi, μi)` over all observations, weighting each term by the
    # matching entry of `wts` unless no weights were supplied.
    function accumulate_terms(term)
        if unit_weights
            return sum(term(yi, μi) for (yi, μi) in zip(y, μ))
        else
            return sum(wti * term(yi, μi) for (wti, yi, μi) in zip(wts, y, μ))
        end
    end

    # First derivative in θ of the negated log-likelihood. The sign flip is
    # harmless: it cancels in the gradient/curvature ratio used below.
    function gradient(θ::Real)
        return accumulate_terms((yi, μi) ->
            (yi+θ)/(μi+θ) + log(μi+θ) - 1 - log(θ) - digamma(θ+yi) + digamma(θ))
    end

    # Second derivative in θ of the same negated log-likelihood.
    function curvature(θ::Real)
        return accumulate_terms((yi, μi) ->
            -(yi+θ)/(μi+θ)^2 + 2/(μi+θ) - 1/θ - trigamma(θ+yi) + trigamma(θ))
    end

    # Starting value: (total weight) / Σ wᵢ (yᵢ/μᵢ - 1)²
    n = unit_weights ? length(y) : sum(wts)
    θ = n / accumulate_terms((yi, μi) -> (yi/μi - 1)^2)

    converged = false
    for iter = 1:maxIter
        # log(θ) above needs θ > 0, so reflect an overshooting iterate back.
        θ = abs(θ)
        step = gradient(θ) / curvature(θ)
        if abs(step) <= convTol
            converged = true
            break
        end
        θ = θ - step
    end
    converged || throw(ConvergenceException(maxIter))
    return θ
end
|
||
"""
    negbin(F, D, args...; initialθ=Inf, maxIter=30, convTol=1.e-6, verbose=false, kwargs...)

Fit a negative binomial regression in which the parameter `θ` is estimated
along with the regression coefficients. The routine alternates between
refitting the model with `glm` for the current `θ` and re-estimating `θ` with
`mle_for_θ` for the current fitted means, until both the log-likelihood and
`θ` stop changing.

`F`, `D`, `args...` and `kwargs...` are forwarded unchanged to every `glm` call.

# Keywords
- `initialθ::Real=Inf`: starting value of `θ`. When infinite, the first fit is
  a Poisson regression; otherwise it is a `NegativeBinomial(initialθ)` fit.
- `maxIter::Integer=30`: maximum number of alternating iterations. The same
  value is forwarded to every `glm` and `mle_for_θ` call.
- `convTol::Real=1.e-6`: convergence tolerance for the alternating loop. The
  same value is forwarded to every `glm` and `mle_for_θ` call.
- `verbose::Bool=false`: print the current `θ` at each alternating iteration;
  also forwarded to `glm`.

# Returns
The most recently fitted model object returned by `glm`.

# Throws
- `ArgumentError` if `maxIter < 1`, `convTol <= 0`, `initialθ <= 0` (or is
  `NaN`), or the fitted model's weight vector has a length other than `0` or
  `length(y)`.
- `ConvergenceException(maxIter)` if the alternating loop does not meet its
  stopping rule within `maxIter` iterations.
"""
function negbin(F, D, args...;
                initialθ::Real=Inf, maxIter::Integer=30, convTol::Real=1.e-6,
                verbose::Bool=false, kwargs...)
    maxIter >= 1 || throw(ArgumentError("maxIter must be positive"))
    convTol > 0 || throw(ArgumentError("convTol must be positive"))
    initialθ > 0 || throw(ArgumentError("initialθ must be positive"))

    # fit a Poisson regression model if the user does not specify an initial θ
    if isinf(initialθ)
        regmodel = glm(F, D, Poisson(), args...;
                       maxIter=maxIter, convTol=convTol, verbose=verbose, kwargs...)
    else
        regmodel = glm(F, D, NegativeBinomial(initialθ), args...;
                       maxIter=maxIter, convTol=convTol, verbose=verbose, kwargs...)
    end

    μ = regmodel.model.rr.mu
    y = regmodel.model.rr.y
    wts = regmodel.model.rr.wts
    lw, ly = length(wts), length(y)
    if lw != ly && lw != 0
        throw(ArgumentError("length of wts must be either $ly or 0 but was $lw"))
    end

    # Forward the caller's iteration limit and tolerance here as well, so the
    # first θ estimate honours the same settings as the ones inside the loop.
    θ = mle_for_θ(y, μ, wts; maxIter=maxIter, convTol=convTol)
    # Scale used to normalise the change in log-likelihood in the stopping rule.
    d = sqrt(2 * max(1, deviance(regmodel)))
    δ = one(θ)
    ll = loglikelihood(regmodel)
    # Start with a "previous" log-likelihood 2d away from the current one.
    ll0 = ll + 2 * d

    converged = false
    for i = 1:maxIter
        # Stop when the scaled log-likelihood change plus the change in θ is small.
        if abs(ll0 - ll)/d + abs(δ) <= convTol
            converged = true
            break
        end
        verbose && println("[ Alternating iteration ", i, ", θ = ", θ, " ]")
        regmodel = glm(F, D, NegativeBinomial(θ), args...;
                       maxIter=maxIter, convTol=convTol, verbose=verbose, kwargs...)
        μ = regmodel.model.rr.mu
        prevθ = θ
        θ = mle_for_θ(y, μ, wts; maxIter=maxIter, convTol=convTol)
        δ = prevθ - θ
        ll0 = ll
        ll = loglikelihood(regmodel)
    end
    converged || throw(ConvergenceException(maxIter))
    return regmodel
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
CategoricalArrays 0.3.0 | ||
Compat 0.36.0 | ||
DataFrames 0.11.0 | ||
CSV 0.2.0 | ||
CSV 0.2.0 | ||
RDatasets 0.4.0 |
Oops, something went wrong.