Commit b6236aa

Merge af5ba92 into e32bb71
devmotion committed Jul 26, 2023
2 parents e32bb71 + af5ba92 commit b6236aa
Showing 4 changed files with 83 additions and 64 deletions.
12 changes: 11 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.26.6"
version = "0.27"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -56,6 +56,7 @@ LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
MCMCChains = "5, 6"
NamedArrays = "0.9"
Optim = "1"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1.37.1"
@@ -66,3 +67,12 @@ StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
julia = "1.7"

[weakdeps]
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[extensions]
TuringOptimExt = "Optim"

[extras]
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
92 changes: 47 additions & 45 deletions src/modes/OptimInterface.jl → ext/TuringOptimExt.jl
@@ -1,14 +1,14 @@
using Setfield
using DynamicPPL: DefaultContext, LikelihoodContext
using DynamicPPL: DynamicPPL
import .Optim
import .Optim: optimize
import ..ForwardDiff
import NamedArrays
import StatsBase
import Printf
import StatsAPI

module TuringOptimExt

if isdefined(Base, :get_extension)
import Turing
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import Optim
else
import ..Turing
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import ..Optim
end

"""
ModeResult{
@@ -23,7 +23,7 @@ A wrapper struct to store various results from a MAP or MLE estimation.
struct ModeResult{
V<:NamedArrays.NamedArray,
O<:Optim.MultivariateOptimizationResults,
M<:OptimLogDensity
M<:Turing.OptimLogDensity
} <: StatsBase.StatisticalModel
"A vector with the resulting point estimates."
values::V
@@ -57,10 +57,10 @@ function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
estimates = m.values.array[:, 1]
stderrors = StatsBase.stderror(m)
zscore = estimates ./ stderrors
p = map(z -> StatsAPI.pvalue(Normal(), z; tail=:both), zscore)
p = map(z -> StatsAPI.pvalue(Distributions.Normal(), z; tail=:both), zscore)

# Confidence interval (CI)
q = quantile(Normal(), (1 + level) / 2)
q = Statistics.quantile(Distributions.Normal(), (1 + level) / 2)
ci_low = estimates .- q .* stderrors
ci_high = estimates .+ q .* stderrors

@@ -80,7 +80,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff
# Hessian is computed with respect to the untransformed parameters.
linked = DynamicPPL.istrans(m.f.varinfo)
if linked
@set! m.f.varinfo = invlink!!(m.f.varinfo, m.f.model)
Setfield.@set! m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
end

# Calculate the Hessian.
@@ -90,7 +90,7 @@

# Link it back if we invlinked it.
if linked
@set! m.f.varinfo = link!!(m.f.varinfo, m.f.model)
Setfield.@set! m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
end

return NamedArrays.NamedArray(info, (varnames, varnames))
@@ -126,18 +126,18 @@ mle = optimize(model, MLE())
mle = optimize(model, MLE(), NelderMead())
```
"""
function Optim.optimize(model::Model, ::MLE, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, options::Optim.Options=Optim.Options(); kwargs...)
return _mle_optimize(model, options; kwargs...)
end
function Optim.optimize(model::Model, ::MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
return _mle_optimize(model, init_vals, options; kwargs...)
end
function Optim.optimize(model::Model, ::MLE, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
return _mle_optimize(model, optimizer, options; kwargs...)
end
function Optim.optimize(
model::Model,
::MLE,
model::DynamicPPL.Model,
::Turing.MLE,
init_vals::AbstractArray,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
@@ -146,9 +146,9 @@ function Optim.optimize(
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end

function _mle_optimize(model::Model, args...; kwargs...)
ctx = OptimizationContext(DynamicPPL.LikelihoodContext())
return _optimize(model, OptimLogDensity(model, ctx), args...; kwargs...)
function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
return _optimize(model, Turing.OptimLogDensity(model, ctx), args...; kwargs...)
end

"""
@@ -172,18 +172,18 @@ map_est = optimize(model, MAP(), NelderMead())
```
"""

function Optim.optimize(model::Model, ::MAP, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, options::Optim.Options=Optim.Options(); kwargs...)
return _map_optimize(model, options; kwargs...)
end
function Optim.optimize(model::Model, ::MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
return _map_optimize(model, init_vals, options; kwargs...)
end
function Optim.optimize(model::Model, ::MAP, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
return _map_optimize(model, optimizer, options; kwargs...)
end
function Optim.optimize(
model::Model,
::MAP,
model::DynamicPPL.Model,
::Turing.MAP,
init_vals::AbstractArray,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
@@ -192,9 +192,9 @@ function Optim.optimize(
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end

function _map_optimize(model::Model, args...; kwargs...)
ctx = OptimizationContext(DynamicPPL.DefaultContext())
return _optimize(model, OptimLogDensity(model, ctx), args...; kwargs...)
function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
return _optimize(model, Turing.OptimLogDensity(model, ctx), args...; kwargs...)
end

"""
@@ -203,8 +203,8 @@ end
Estimate a mode, i.e., compute a MLE or MAP estimate.
"""
function _optimize(
model::Model,
f::OptimLogDensity,
model::DynamicPPL.Model,
f::Turing.OptimLogDensity,
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
args...;
kwargs...
@@ -213,8 +213,8 @@ function _optimize(
end

function _optimize(
model::Model,
f::OptimLogDensity,
model::DynamicPPL.Model,
f::Turing.OptimLogDensity,
options::Optim.Options=Optim.Options(),
args...;
kwargs...
@@ -223,8 +223,8 @@ function _optimize(
end

function _optimize(
model::Model,
f::OptimLogDensity,
model::DynamicPPL.Model,
f::Turing.OptimLogDensity,
init_vals::AbstractArray=DynamicPPL.getparams(f),
options::Optim.Options=Optim.Options(),
args...;
@@ -234,8 +234,8 @@ function _optimize(
end

function _optimize(
model::Model,
f::OptimLogDensity,
model::DynamicPPL.Model,
f::Turing.OptimLogDensity,
init_vals::AbstractArray=DynamicPPL.getparams(f),
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
options::Optim.Options=Optim.Options(),
@@ -244,8 +244,8 @@ function _optimize(
)
# Convert the initial values, since it is assumed that users provide them
# in the constrained space.
@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)
init_vals = DynamicPPL.getparams(f)

# Optimize!
@@ -258,10 +258,10 @@ function _optimize(

# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
@set! f.varinfo = invlink!!(f.varinfo, model)
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
Setfield.@set! f.varinfo = DynamicPPL.invlink!!(f.varinfo, model)
vals = DynamicPPL.getparams(f)
@set! f.varinfo = link!!(f.varinfo, model)
Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(
@@ -275,3 +275,5 @@ function _optimize(

return ModeResult(vmat, M, -M.minimum, f)
end

end # module
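
Note: for orientation, here is a short usage sketch of the relocated interface. The `gdemo` model and its data are illustrative only, not part of the commit; `optimize`, `MLE`, `MAP`, and `NelderMead` come from Optim and Turing as shown in the extension above.

```julia
using Turing, Optim, StatsBase

# Hypothetical toy model, for illustration only.
@model function gdemo(x)
    σ² ~ InverseGamma(2, 3)
    μ ~ Normal(0, sqrt(σ²))
    for i in eachindex(x)
        x[i] ~ Normal(μ, sqrt(σ²))
    end
end

model = gdemo([1.5, 2.0, 2.3])

mle = optimize(model, MLE())                    # maximum likelihood estimate
map_est = optimize(model, MAP(), NelderMead())  # MAP estimate with a custom optimizer

# ModeResult implements the StatsBase API defined in this file:
coeftable(map_est)
```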
41 changes: 24 additions & 17 deletions src/Turing.jl
@@ -11,6 +11,12 @@ import AdvancedVI
using DynamicPPL: DynamicPPL, LogDensityFunction
import DynamicPPL: getspace, NoDist, NamedDist
import LogDensityProblems
import NamedArrays
import Setfield
import StatsAPI
import StatsBase

import Printf
import Random

const PROGRESS = Ref(true)
@@ -48,26 +54,9 @@ using .Inference
include("variational/VariationalInference.jl")
using .Variational

@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
@eval Inference begin
import ..DynamicHMC

if isdefined(DynamicHMC, :mcmc_with_warmup)
include("contrib/inference/dynamichmc.jl")
else
error("Please update DynamicHMC, v1.x is no longer supported")
end
end
end

include("modes/ModeEstimation.jl")
using .ModeEstimation

@init @require Optim="429524aa-4258-5aef-a3af-852621145aeb" @eval begin
include("modes/OptimInterface.jl")
export optimize
end

###########
# Exports #
###########
@@ -145,4 +134,22 @@ export @model, # modelling
optim_objective,
optim_function,
optim_problem

function __init__()
@static if !isdefined(Base, :get_extension)
@require Optim="429524aa-4258-5aef-a3af-852621145aeb" include("../ext/TuringOptimExt.jl")
end
@require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
@eval Inference begin
import ..DynamicHMC

if isdefined(DynamicHMC, :mcmc_with_warmup)
include("contrib/inference/dynamichmc.jl")
else
error("Please update DynamicHMC, v1.x is no longer supported")
end
end
end
end

end
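
Note: the `__init__` above selects between the two loading mechanisms — the package-extension machinery on Julia 1.9+ and a Requires-based `include` on older releases. A REPL sketch (illustrative, not from the diff) for checking which path loaded the glue code:

```julia
using Turing, Optim

if isdefined(Base, :get_extension)
    # Julia >= 1.9: loaded as a package extension.
    @assert Base.get_extension(Turing, :TuringOptimExt) !== nothing
else
    # Older Julia: Requires include'd the file into the Turing module.
    @assert isdefined(Turing, :TuringOptimExt)
end
```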
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -47,7 +47,7 @@ LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
MCMCChains = "5, 6"
NamedArrays = "0.9.4"
Optim = "0.22, 1.0"
Optim = "1"
Optimization = "3.5"
OptimizationOptimJL = "0.1"
PDMats = "0.10, 0.11"
