Fix MLE with ConditionContext (#2022)
* Fix MLE with ConditionContext

* Update ModeEstimation.jl

* Update tests

* Improve tests

Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>
devmotion and torfjelde committed Jun 21, 2023
1 parent a0b8999 commit 3e8d97f
Showing 3 changed files with 147 additions and 50 deletions.
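Summary of the fix: with `MLE()`, a model conditioned via `|` (that is, through a `DynamicPPL.ConditionContext`) could yield a different likelihood value than the same model with the data passed as an argument. `OptimizationContext` is now a leaf context that decides between the MAP and MLE objectives itself. A minimal repro sketch, adapted from the new tests below (the model names are illustrative):

using Turing, DynamicPPL

@model function model1(x)
    μ ~ Uniform(0, 2)
    x ~ LogNormal(μ, 1)
end

@model function model2()
    μ ~ Uniform(0, 2)
    x ~ LogNormal(μ, 1)
end

ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())

# These two objectives should agree; before this commit the conditioned
# model (`model2() | (x = 1.0,)`) could return a different value.
Turing.OptimLogDensity(model1(1.0), ctx)([1.0])
Turing.OptimLogDensity(model2() | (x = 1.0,), ctx)([1.0])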
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.26.1"
+version = "0.26.2"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
45 changes: 26 additions & 19 deletions src/modes/ModeEstimation.jl
@@ -39,41 +39,48 @@ intended to allow an optimizer to sample in R^n freely.
 """
 struct OptimizationContext{C<:AbstractContext} <: AbstractContext
     context::C
+
+    function OptimizationContext{C}(context::C) where {C<:AbstractContext}
+        if !(context isa Union{DefaultContext,LikelihoodContext})
+            throw(ArgumentError("`OptimizationContext` supports only leaf contexts of type `DynamicPPL.DefaultContext` and `DynamicPPL.LikelihoodContext` (given: `$(typeof(context))`)"))
+        end
+        return new{C}(context)
+    end
 end
 
-DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsParent()
-DynamicPPL.childcontext(context::OptimizationContext) = context.context
-DynamicPPL.setchildcontext(::OptimizationContext, child) = OptimizationContext(child)
+OptimizationContext(context::AbstractContext) = OptimizationContext{typeof(context)}(context)
 
-# assume
-function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, dist, vn, vi)
-    r = vi[vn, dist]
-    return r, 0, vi
-end
+DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsLeaf()
 
+# assume
 function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
     r = vi[vn, dist]
-    return r, Distributions.logpdf(dist, r), vi
+    lp = if ctx.context isa DefaultContext
+        # MAP
+        Distributions.logpdf(dist, r)
+    else
+        # MLE
+        0
+    end
+    return r, lp, vi
 end
 
 # dot assume
-function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, right, left, vns, vi)
-    # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
-    # affect anything.
-    # TODO: Stop using `get_and_set_val!`.
-    r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
-    return r, 0, vi
-end
-
 _loglikelihood(dist::Distribution, x) = loglikelihood(dist, x)
 _loglikelihood(dists::AbstractArray{<:Distribution}, x) = loglikelihood(arraydist(dists), x)
 
 function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns, vi)
     # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
     # affect anything.
     # TODO: Stop using `get_and_set_val!`.
     r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
-    return r, _loglikelihood(right, r), vi
+    lp = if ctx.context isa DefaultContext
+        # MAP
+        _loglikelihood(right, r)
+    else
+        # MLE
+        0
+    end
+    return r, lp, vi
 end
 
 """
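The behavioural change above: `OptimizationContext` is now a leaf context, so the choice between the MAP and MLE objectives is made by inspecting the wrapped `ctx.context` rather than by method dispatch on `OptimizationContext{<:LikelihoodContext}`, and the new inner constructor rejects anything other than the two supported leaf contexts. A short sketch of the resulting behaviour (using `DynamicPPL.PriorContext` only as an example of an unsupported context):

using Turing, DynamicPPL

map_ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())     # assume statements contribute logpdf(prior, value)
mle_ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())  # assume statements contribute 0

# Any other context is rejected up front by the inner constructor:
try
    Turing.OptimizationContext(DynamicPPL.PriorContext())
catch err
    @assert err isa ArgumentError
end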
150 changes: 120 additions & 30 deletions test/modes/OptimInterface.jl
@@ -1,29 +1,60 @@
-function find_map(model::DynamicPPL.TestUtils.DemoModels)
-    # Set up.
-    true_values = rand(NamedTuple, model)
-    d = length(true_values.s)
-    s_size, m_size = size(true_values.s), size(true_values.m)
-    s_isunivariate = true_values.s isa Real
-    m_isunivariate = true_values.m isa Real
-
-    # Cosntruct callable.
-    function f_wrapped(x)
-        s = s_isunivariate ? x[1] : reshape(x[1:d], s_size)
-        m = m_isunivariate ? x[2] : reshape(x[d + 1:end], m_size)
-        return -DynamicPPL.TestUtils.logjoint_true(model, s, m)
-    end
-
-    # Optimize.
-    lbs = vcat(fill(0, d), fill(-Inf, d))
-    ubs = fill(Inf, 2d)
-    result = optimize(f_wrapped, lbs, ubs, rand(2d), Fminbox(NelderMead()))
-    @assert Optim.converged(result) "optimization didn't converge"
-
-    # Extract the result.
-    x = Optim.minimizer(result)
-    s = s_isunivariate ? x[1] : reshape(x[1:d], s_size)
-    m = m_isunivariate ? x[2] : reshape(x[d + 1:end], m_size)
-    return -Optim.minimum(result), (s = s, m = m)
+# TODO: Remove these once the equivalent is present in `DynamicPPL.TestUtils`.
+function likelihood_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels)
+    return (s=1/16, m=7/4)
+end
+function posterior_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels)
+    # TODO: Figure out exact for `s`.
+    return (s=0.907407, m=7/6)
+end
+
+function likelihood_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels)
+    # Get some containers to fill.
+    vals = Random.rand(model)
+
+    # NOTE: These are "as close to zero as we can get".
+    vals.s[1] = 1e-32
+    vals.s[2] = 1e-32
+
+    vals.m[1] = 1.5
+    vals.m[2] = 2.0
+
+    return vals
+end
+function posterior_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels)
+    # Get some containers to fill.
+    vals = Random.rand(model)
+
+    # TODO: Figure out exact for `s[1]`.
+    vals.s[1] = 0.890625
+    vals.s[2] = 1
+    vals.m[1] = 3/4
+    vals.m[2] = 1
+
+    return vals
 end
+
+# Used for testing how well it works with nested contexts.
+struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext
+    context::C
+    logprior_weight::T1
+    loglikelihood_weight::T2
+end
+DynamicPPL.NodeTrait(::OverrideContext) = DynamicPPL.IsParent()
+DynamicPPL.childcontext(parent::OverrideContext) = parent.context
+DynamicPPL.setchildcontext(parent::OverrideContext, child) = OverrideContext(
+    child,
+    parent.logprior_weight,
+    parent.loglikelihood_weight
+)
+
+# Only implement what we need for the models above.
+function DynamicPPL.tilde_assume(context::OverrideContext, right, vn, vi)
+    value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
+    return value, context.logprior_weight, vi
+end
+function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi)
+    logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
+    return context.loglikelihood_weight, vi
+end
 
 @testset "OptimInterface.jl" begin
@@ -126,20 +157,79 @@ end
     # FIXME: Some models doesn't work for Tracker and ReverseDiff.
     if Turing.Essential.ADBACKEND[] === :forwarddiff
         @testset "MAP for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
-            maximum_true, maximizer_true = find_map(model)
+            result_true = posterior_optima(model)
 
             @testset "$(optimizer)" for optimizer in [LBFGS(), NelderMead()]
                 result = optimize(model, MAP(), optimizer)
                 vals = result.values
 
                 for vn in DynamicPPL.TestUtils.varnames(model)
-                    for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(maximizer_true, vn))
-                        sym = DynamicPPL.AbstractPPL.getsym(vn_leaf)
-                        true_value_vn = get(maximizer_true, vn_leaf)
-                        @test vals[Symbol(vn_leaf)] ≈ true_value_vn rtol = 0.05
+                    for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
+                        @test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol=0.05
                     end
                 end
             end
         end
+        @testset "MLE for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
+            result_true = likelihood_optima(model)
+
+            # `NelderMead` seems to struggle with convergence here, so we exclude it.
+            @testset "$(optimizer)" for optimizer in [LBFGS(),]
+                result = optimize(model, MLE(), optimizer)
+                vals = result.values
+
+                for vn in DynamicPPL.TestUtils.varnames(model)
+                    for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
+                        @test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol=0.05
+                    end
+                end
+            end
+        end
     end
+
+    # Issue: https://discourse.julialang.org/t/two-equivalent-conditioning-syntaxes-giving-different-likelihood-values/100320
+    @testset "OptimizationContext" begin
+        @model function model1(x)
+            μ ~ Uniform(0, 2)
+            x ~ LogNormal(μ, 1)
+        end
+
+        @model function model2()
+            μ ~ Uniform(0, 2)
+            x ~ LogNormal(μ, 1)
+        end
+
+        x = 1.0
+        w = [1.0]
+
+        @testset "With ConditionContext" begin
+            m1 = model1(x)
+            m2 = model2() | (x = x,)
+            ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
+            @test Turing.OptimLogDensity(m1, ctx)(w) == Turing.OptimLogDensity(m2, ctx)(w)
+        end
+
+        @testset "With prefixes" begin
+            function prefix_μ(model)
+                return DynamicPPL.contextualize(model, DynamicPPL.PrefixContext{:inner}(model.context))
+            end
+            m1 = prefix_μ(model1(x))
+            m2 = prefix_μ(model2() | (var"inner.x" = x,))
+            ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
+            @test Turing.OptimLogDensity(m1, ctx)(w) == Turing.OptimLogDensity(m2, ctx)(w)
+        end
+
+        @testset "Weighted" begin
+            function override(model)
+                return DynamicPPL.contextualize(
+                    model,
+                    OverrideContext(model.context, 100, 1)
+                )
+            end
+            m1 = override(model1(x))
+            m2 = override(model2() | (x = x,))
+            ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
+            @test Turing.OptimLogDensity(m1, ctx)(w) == Turing.OptimLogDensity(m2, ctx)(w)
+        end
+    end
 end
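A quick sanity check on the hard-coded optima in the new test helpers, assuming the standard DynamicPPL demo models (s ~ InverseGamma(2, 3), m ~ Normal(0, √s), observations x = [1.5, 2.0]): the MLE is the sample mean and the biased maximum-likelihood variance, while the prior on m acts like one extra pseudo-observation at zero:

# Univariate demo models: two observations share a scalar (s, m).
x = [1.5, 2.0]
n = length(x)

m_mle = sum(x) / n                   # 7/4
s_mle = sum(abs2, x .- m_mle) / n    # 1/16

# Posterior mode for m: the Normal(0, √s) prior adds one pseudo-observation at 0,
# so the mode is sum(x) / (n + 1).
m_map = sum(x) / (n + 1)             # 7/6

# Multivariate demo models: one observation per component, so the MLE is
# m[i] = x[i] with s[i] → 0, and the posterior mode for m[i] is x[i] / 2.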

2 comments on commit 3e8d97f

@devmotion (Member, Author):
@JuliaRegistrator register

@JuliaRegistrator:
Registration pull request created: JuliaRegistries/General/85983

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.26.2 -m "<description of version>" 3e8d97f6bf90352cd30624619e5f24c0227c7295
git push origin v0.26.2
