Skip to content

Commit

Permalink
Compute working weights and residuals more carefully to reduce underflow
Browse files Browse the repository at this point in the history
and get the limits right when overflow is unavoidable. To do this, change
inverselink to return μ, 1-μ, dμdη, instead of μ, dμdη, μ*(1-μ) for Link01
in order to have access to accurate μ or 1-μ.

Introduce an absolute tolerance criterion to avoid convergence issues when
deviance is almost zero and rename tol to rtol.
  • Loading branch information
andreasnoack committed Jun 14, 2019
1 parent 0926a95 commit b1ab908
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 161 deletions.
145 changes: 115 additions & 30 deletions src/glmfit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,79 @@ function updateμ!(r::GlmResp{V,D,L}) where {V<:FPVector,D,L}
end
end

# Working residual and working weight of a single observation for the logit link.
#
# Arguments: response `yᵢ`, linear predictor `ηᵢ`, mean `μᵢ`, its complement
# `omμᵢ` (1 - μ), the derivative `dμdηᵢ`, and the link instance used for dispatch.
# Returns the tuple `(wrkresᵢ, wrkwtᵢ)`.
function _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, l::LogitLink)
    isone_y = yᵢ == 1
    # LogitLink is the canonical link for Binomial, so the weight μ(1 - μ) is
    # always well defined; only the residual can break down, namely when dμdη
    # has underflowed to zero. Its value is then taken to be ±1.
    wrkresᵢ = if iszero(dμdηᵢ)
        isone_y ? one(μᵢ) : -one(μᵢ)
    else
        # For yᵢ == 1 use the supplied 1 - μ rather than forming yᵢ - μᵢ
        (isone_y ? omμᵢ : yᵢ - μᵢ) / dμdηᵢ
    end
    wrkwtᵢ = μᵢ * omμᵢ

    return wrkresᵢ, wrkwtᵢ
end

# Working residual and working weight of a single observation for the probit link.
#
# Arguments: response `yᵢ`, linear predictor `ηᵢ`, mean `μᵢ`, its complement
# `omμᵢ` (1 - μ), the derivative `dμdηᵢ`, and the link instance used for dispatch.
# Returns the tuple `(wrkresᵢ, wrkwtᵢ)`.
function _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, l::ProbitLink)
    μomμᵢ = μᵢ * omμᵢ

    # For Probit the product μ(1 - μ) underflows before dμdη does, so testing the
    # product alone tells us whether the regular formulas are still usable.
    if !iszero(μomμᵢ)
        numerator = yᵢ == 1 ? omμᵢ : yᵢ - μᵢ
        return numerator / dμdηᵢ, abs2(dμdηᵢ) / μomμᵢ
    end

    # Tail approximation used once μ(1 - μ) has underflowed to zero
    return 1 / abs(ηᵢ), dμdηᵢ
end

# Working residual and working weight of a single observation for the
# complementary log-log link.
#
# Arguments: response `yᵢ`, linear predictor `ηᵢ`, mean `μᵢ`, its complement
# `omμᵢ` (1 - μ), the derivative `dμdηᵢ` (unused here), and the link instance
# used for dispatch. Returns the tuple `(wrkresᵢ, wrkwtᵢ)`.
function _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, l::CloglogLink)
    emη = exp(-ηᵢ)
    if yᵢ == 1
        wrkresᵢ = emη
    elseif iszero(emη)
        # Diverges to -∞. Take the float type from emη: wrkresᵢ is not yet
        # assigned on this path, so its type cannot be queried.
        wrkresᵢ = oftype(emη, -Inf)
    elseif isinf(emη)
        # converges to -1
        wrkresᵢ = -one(emη)
    else
        wrkresᵢ = (yᵢ - μᵢ)/omμᵢ*emη
    end

    wrkwtᵢ = exp(2*ηᵢ)/expm1(exp(ηᵢ))
    # The ratio is 0/0 or Inf/Inf in the tails; we know that both limits are
    # zero so we'll convert NaNs
    wrkwtᵢ = ifelse(isnan(wrkwtᵢ), zero(wrkwtᵢ), wrkwtᵢ)

    return wrkresᵢ, wrkwtᵢ
end

# Generic method for every `Link01` link that has no specialised method.
# It applies the plain formulas with no protection against underflow in
# `dμdηᵢ` or `μᵢ*omμᵢ`. Returns the tuple `(wrkresᵢ, wrkwtᵢ)`.
function _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, l::Link01)
    # For yᵢ == 1 use the supplied 1 - μ rather than forming yᵢ - μᵢ
    numerator = yᵢ == 1 ? omμᵢ : yᵢ - μᵢ
    μomμᵢ = μᵢ * omμᵢ

    return numerator / dμdηᵢ, abs2(dμdηᵢ) / μomμᵢ
end

# Refresh the mean, working residuals, working weights and deviance residuals
# stored in `r` from its current linear predictor `r.eta`, for Bernoulli and
# Binomial responses with a `Link01` link. All updates are made in place.
function updateμ!(r::GlmResp{V,D,L}) where {V<:FPVector,D<:Union{Bernoulli,Binomial},L<:Link01}
    y, η, μ, wrkres, wrkwt, dres = r.y, r.eta, r.mu, r.wrkresid, r.wrkwt, r.devresid

    @inbounds for i in eachindex(y, η, μ, wrkres, wrkwt, dres)
        yᵢ, ηᵢ = y[i], η[i]
        # inverselink yields μ, 1 - μ and dμ/dη for Link01 links
        μᵢ, omμᵢ, dμdηᵢ = inverselink(L(), ηᵢ)
        μ[i] = μᵢ
        # For large values of ηᵢ the quantities dμdη and μomμ will underflow.
        # The ratios defining (yᵢ - μᵢ)/dμdη and dμdη^2/μomμ have fairly stable
        # tail behavior so we can switch algorithm to avoid 0/0. The behavior
        # is specific to the link function so _weights_residuals dispatches to
        # robust versions for LogitLink, ProbitLink and CloglogLink
        wrkres[i], wrkwt[i] = _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, L())
        dres[i] = devresid(r.d, yᵢ, μᵢ)
    end
end

Expand Down Expand Up @@ -200,7 +263,7 @@ end
dof(x::GeneralizedLinearModel) = dispersion_parameter(x.rr.d) ? length(coef(x)) + 1 : length(coef(x))

function _fit!(m::AbstractGLM, verbose::Bool, maxiter::Integer, minstepfac::Real,
tol::Real, start)
atol::Real, rtol::Real, start)

# Return early if model has the fit flag set
m.fit && return m
Expand Down Expand Up @@ -246,9 +309,9 @@ function _fit!(m::AbstractGLM, verbose::Bool, maxiter::Integer, minstepfac::Real

# Line search
## If the deviance isn't declining then half the step size
## The tol*dev term is to avoid failure when deviance
## The rtol*dev term is to avoid failure when deviance
## is unchanged except for rounding errors.
while dev > devold + tol*dev
while dev > devold + rtol*dev
f /= 2
f > minstepfac || error("step-halving failed at beta0 = $(p.beta0)")
try
Expand All @@ -261,22 +324,26 @@ function _fit!(m::AbstractGLM, verbose::Bool, maxiter::Integer, minstepfac::Real
installbeta!(p, f)

# Test for convergence
crit = (devold - dev)/dev
verbose && println("$i: $dev, $crit")
if crit < tol || dev == 0
verbose && println("Iteration: $i, deviance: $dev, diff.dev.:$(devold - dev)")
if devold - dev < max(rtol*devold, atol)
cvg = true
break
end
@assert isfinite(crit)
@assert isfinite(dev)
devold = dev
end
cvg || throw(ConvergenceException(maxiter))
m.fit = true
m
end

function StatsBase.fit!(m::AbstractGLM; verbose::Bool=false, maxiter::Integer=30,
minstepfac::Real=0.001, tol::Real=1e-6, start=nothing,
function StatsBase.fit!(m::AbstractGLM;
verbose::Bool=false,
maxiter::Integer=30,
minstepfac::Real=0.001,
atol::Real=1e-6,
rtol::Real=1e-6,
start=nothing,
kwargs...)
if haskey(kwargs, :maxIter)
Base.depwarn("'maxIter' argument is deprecated, use 'maxiter' instead", :fit!)
Expand All @@ -287,19 +354,32 @@ function StatsBase.fit!(m::AbstractGLM; verbose::Bool=false, maxiter::Integer=30
minstepfac = kwargs[:minStepFac]
end
if haskey(kwargs, :convTol)
Base.depwarn("'convTol' argument is deprecated, use 'tol' instead", :fit!)
Base.depwarn("'convTol' argument is deprecated, use `atol` and `rtol` instead", :fit!)
tol = kwargs[:convTol]
end
if !issubset(keys(kwargs), (:maxIter, :minStepFac, :convTol))
throw(ArgumentError("unsupported keyword argument"))
end
if haskey(kwargs, :tol)
Base.depwarn("`tol` argument is deprecated, use `atol` and `rtol` instead", :fit!)
tol = kwargs[:tol]
end

_fit!(m, verbose, maxiter, minstepfac, tol, start)
_fit!(m, verbose, maxiter, minstepfac, atol, rtol, start)
end

function StatsBase.fit!(m::AbstractGLM, y; wts=nothing, offset=nothing, dofit::Bool=true,
verbose::Bool=false, maxiter::Integer=30, minstepfac::Real=0.001,
tol::Real=1e-6, start=nothing, kwargs...)
function StatsBase.fit!(m::AbstractGLM,
y;
wts=nothing,
offset=nothing,
dofit::Bool=true,
verbose::Bool=false,
maxiter::Integer=30,
minstepfac::Real=0.001,
atol::Real=1e-6,
rtol::Real=1e-6,
start=nothing,
kwargs...)
if haskey(kwargs, :maxIter)
Base.depwarn("'maxIter' argument is deprecated, use 'maxiter' instead", :fit!)
maxiter = kwargs[:maxIter]
Expand All @@ -309,12 +389,16 @@ function StatsBase.fit!(m::AbstractGLM, y; wts=nothing, offset=nothing, dofit::B
minstepfac = kwargs[:minStepFac]
end
if haskey(kwargs, :convTol)
Base.depwarn("'convTol' argument is deprecated, use 'tol' instead", :fit!)
Base.depwarn("'convTol' argument is deprecated, use `atol` and `rtol` instead", :fit!)
tol = kwargs[:convTol]
end
if !issubset(keys(kwargs), (:maxIter, :minStepFac, :convTol))
throw(ArgumentError("unsupported keyword argument"))
end
if haskey(kwargs, :tol)
Base.depwarn("`tol` argument is deprecated, use `atol` and `rtol` instead", :fit!)
tol = kwargs[:tol]
end

r = m.rr
V = typeof(r.y)
Expand All @@ -326,7 +410,7 @@ function StatsBase.fit!(m::AbstractGLM, y; wts=nothing, offset=nothing, dofit::B
fill!(m.pp.beta0, 0)
m.fit = false
if dofit
_fit!(m, verbose, maxiter, minstepfac, tol, start)
_fit!(m, verbose, maxiter, minstepfac, atol, rtol, start)
else
m
end
Expand All @@ -346,8 +430,10 @@ vector, respectively, or a formula and a data frame. `d` must be a
length 0
- `verbose::Bool=false`: Display convergence information for each iteration
- `maxiter::Integer=30`: Maximum number of iterations allowed to achieve convergence
- `tol::Real=1e-6`: Convergence is achieved when the relative change in
deviance is less than this
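- `atol::Real=1e-6`: Absolute tolerance. Convergence is achieved when the decrease in
deviance between two iterations is less than `max(rtol*dev, atol)`, where `dev` is
the deviance of the previous iteration.
- `rtol::Real=1e-6`: Relative tolerance. Convergence is achieved when the decrease in
deviance between two iterations is less than `max(rtol*dev, atol)`, where `dev` is
the deviance of the previous iteration.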
- `minstepfac::Real=0.001`: Minimum line step fraction. Must be between 0 and 1.
- `start::AbstractVector=nothing`: Starting values for beta. Should have the
same length as the number of columns in the model matrix.
Expand All @@ -373,11 +459,11 @@ function fit(::Type{M},
end

# Convenience method: apply `float` to the model matrix and the response, then
# forward to the `fit` method that takes the converted arguments. The link
# defaults to the canonical link of the distribution `d`.
fit(::Type{M},
    X::Union{Matrix,SparseMatrixCSC},
    y::AbstractVector,
    d::UnivariateDistribution,
    l::Link=canonicallink(d); kwargs...) where {M<:AbstractGLM} =
    fit(M, float(X), float(y), d, l; kwargs...)

"""
glm(F, D, args...; kwargs...)
Expand Down Expand Up @@ -485,4 +571,3 @@ function checky(y, d::Binomial)
end
return nothing
end

0 comments on commit b1ab908

Please sign in to comment.