Skip to content

Commit

Permalink
Remove usage of DynamicPPL.tonamedtuple (#2071)
Browse files Browse the repository at this point in the history
* use getparams also take model as an argument, make use of recently
introduced immutable invlink, and make use of getparams in transition constructions

* remove unnecessary custom ad rules for Tracker for Bijectors.link and Bijectors.invlink

* bump patch version

* remove unnused nt keyword from Transtiion and fixed incorrect calls to
different transition types

* Update src/inference/emcee.jl

* Apply suggestions from code review

* fixed more Transition calls missing model

* fixed IS

* fixed DynamicHMC

* fixed Emcee

* another attempt at fixing Emcee

* more fixes

* more fixes

* fixed the strange NamedTuple bundle_samples

* fixed syntax error

* hopefully last change

* fix transition_to_turing

* fixed GibbsConditional

* `FlattenIterator` is replaced by `varname_and_value_leaves` (#2072)

* Update Utilities.jl

* remove obsolete functionality.

* remove obsolete utility module

* removes more references to utilities.

* Update Project.toml

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

* `Tracker` imports no longer needed.  (#2076)

* Update Turing.jl

* Update Project.toml

* Update Turing.jl

* Update Essential.jl

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
  • Loading branch information
torfjelde and yebai committed Sep 4, 2023
1 parent 9423562 commit e22b77c
Show file tree
Hide file tree
Showing 20 changed files with 102 additions and 203 deletions.
6 changes: 2 additions & 4 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.28.3"
version = "0.29"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -34,7 +34,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[weakdeps]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Expand All @@ -57,7 +56,7 @@ Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.23"
DynamicPPL = "0.23.15"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
Expand All @@ -74,7 +73,6 @@ SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
StatsAPI = "1.6"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
julia = "1.7"

[extras]
Expand Down
6 changes: 3 additions & 3 deletions ext/TuringDynamicHMCExt.jl
Expand Up @@ -89,7 +89,7 @@ function DynamicPPL.initialstep(
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create first sample and state.
sample = Turing.Inference.Transition(vi)
sample = Turing.Inference.Transition(model, vi)
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)

return sample, state
Expand Down Expand Up @@ -119,10 +119,10 @@ function AbstractMCMC.step(
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create next sample and state.
sample = Turing.Inference.Transition(vi)
sample = Turing.Inference.Transition(model, vi)
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)

return sample, newstate
end

end
end
4 changes: 2 additions & 2 deletions ext/TuringOptimExt.jl
Expand Up @@ -250,10 +250,10 @@ function _optimize(

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(
DynamicPPL.tonamedtuple(f.varinfo),
Turing.Inference.getparams(model, f.varinfo),
DynamicPPL.getlogp(f.varinfo)
)]
varnames, _ = Turing.Inference._params_to_array(ts)
varnames, _ = Turing.Inference._params_to_array(model, ts)

# Store the parameters and their names in an array.
vmat = NamedArrays.NamedArray(vals, varnames)
Expand Down
5 changes: 1 addition & 4 deletions src/Turing.jl
@@ -1,11 +1,10 @@
module Turing

using Requires, Reexport, ForwardDiff
using Reexport, ForwardDiff
using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions
using Statistics, LinearAlgebra
using Libtask
@reexport using Distributions, MCMCChains, Libtask, AbstractMCMC, Bijectors
using Tracker: Tracker

import AdvancedVI
using DynamicPPL: DynamicPPL, LogDensityFunction
Expand Down Expand Up @@ -44,8 +43,6 @@ ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::Base.Fix1{typeof(L
# Random probability measures.
include("stdlib/distributions.jl")
include("stdlib/RandomMeasures.jl")
include("utilities/Utilities.jl")
using .Utilities
include("essential/Essential.jl")
Base.@deprecate_binding Core Essential false
using .Essential
Expand Down
6 changes: 3 additions & 3 deletions src/contrib/inference/abstractmcmc.jl
Expand Up @@ -5,11 +5,11 @@ end

state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
# TODO: We should probably rename this `getparams` since it returns something
# very different from `Turing.Inference.getparams`.
θ = getparams(transition)
varinfo = DynamicPPL.unflatten(f.varinfo, θ)
# TODO: `deepcopy` is overkill; make more efficient.
varinfo = DynamicPPL.invlink!!(deepcopy(varinfo), f.model)
return Transition(varinfo, transition)
return Transition(f.model, varinfo, transition)
end

# NOTE: Only thing that depends on the underlying sampler.
Expand Down
12 changes: 6 additions & 6 deletions src/contrib/inference/sghmc.jl
Expand Up @@ -61,7 +61,7 @@ function DynamicPPL.initialstep(
end

# Compute initial sample and state.
sample = Transition(vi)
sample = Transition(model, vi)
= LogDensityProblemsAD.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
state = SGHMCState(ℓ, vi, zero(vi[spl]))

Expand Down Expand Up @@ -94,7 +94,7 @@ function AbstractMCMC.step(
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))

# Compute next sample and state.
sample = Transition(vi)
sample = Transition(model, vi)
newstate = SGHMCState(ℓ, vi, newv)

return sample, newstate
Expand Down Expand Up @@ -184,8 +184,8 @@ struct SGLDTransition{T,F<:Real}
stepsize::F
end

function SGLDTransition(vi::AbstractVarInfo, stepsize)
theta = tonamedtuple(vi)
function SGLDTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, stepsize)
theta = getparams(model, vi)
lp = getlogp(vi)
return SGLDTransition(theta, lp, stepsize)
end
Expand Down Expand Up @@ -214,7 +214,7 @@ function DynamicPPL.initialstep(
end

# Create first sample and state.
sample = SGLDTransition(vi, zero(spl.alg.stepsize(0)))
sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0)))
= LogDensityProblemsAD.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
state = SGLDState(ℓ, vi, 1)

Expand Down Expand Up @@ -242,7 +242,7 @@ function AbstractMCMC.step(
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))

# Compute next sample and state.
sample = SGLDTransition(vi, stepsize)
sample = SGLDTransition(model, vi, stepsize)
newstate = SGLDState(ℓ, vi, state.step + 1)

return sample, newstate
Expand Down
4 changes: 1 addition & 3 deletions src/essential/Essential.jl
Expand Up @@ -3,13 +3,11 @@ module Essential
using DistributionsAD, Bijectors
using Libtask, ForwardDiff, Random
using Distributions, LinearAlgebra
using ..Utilities, Reexport
using Tracker: Tracker
using Reexport
using ..Turing: Turing
using DynamicPPL: Model, AbstractSampler, Sampler, SampleFromPrior
using LinearAlgebra: copytri!
using Bijectors: PDMatDistribution
import Bijectors: link, invlink
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
Expand Down
16 changes: 0 additions & 16 deletions src/essential/ad.jl
Expand Up @@ -132,19 +132,3 @@ function verifygrad(grad::AbstractVector{<:Real})
return true
end
end

# These still seem necessary
for F in (:link, :invlink)
@eval begin
$F(dist::PDMatDistribution, x::Tracker.TrackedArray) = Tracker.track($F, dist, x)
Tracker.@grad function $F(dist::PDMatDistribution, x::Tracker.TrackedArray)
x_data = Tracker.data(x)
T = eltype(x_data)
y = $F(dist, x_data)
return y, Δ -> begin
out = reshape((ForwardDiff.jacobian(x -> $F(dist, x), x_data)::Matrix{T})' * vec(Δ), size(Δ))
return (nothing, out)
end
end
end
end
16 changes: 8 additions & 8 deletions src/inference/AdvancedSMC.jl
Expand Up @@ -52,8 +52,8 @@ struct SMCTransition{T,F<:AbstractFloat}
weight::F
end

function SMCTransition(vi::AbstractVarInfo, weight)
theta = tonamedtuple(vi)
function SMCTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, weight)
theta = getparams(model, vi)

# This is pretty useless since we reset the log probability continuously in the
# particle sweep.
Expand Down Expand Up @@ -128,7 +128,7 @@ function DynamicPPL.initialstep(
weight = AdvancedPS.getweight(particles, 1)

# Compute the first transition and the first state.
transition = SMCTransition(particle.model.f.varinfo, weight)
transition = SMCTransition(model, particle.model.f.varinfo, weight)
state = SMCState(particles, 2, logevidence)

return transition, state
Expand All @@ -150,7 +150,7 @@ function AbstractMCMC.step(
weight = AdvancedPS.getweight(particles, index)

# Compute the transition and the next state.
transition = SMCTransition(particle.model.f.varinfo, weight)
transition = SMCTransition(model, particle.model.f.varinfo, weight)
nextstate = SMCState(state.particles, index + 1, state.average_logevidence)

return transition, nextstate
Expand Down Expand Up @@ -225,8 +225,8 @@ struct PGState
rng::Random.AbstractRNG
end

function PGTransition(vi::AbstractVarInfo, logevidence)
theta = tonamedtuple(vi)
function PGTransition(model::DynamicPPL.Model, vi::AbstractVarInfo, logevidence)
theta = getparams(model, vi)

# This is pretty useless since we reset the log probability continuously in the
# particle sweep.
Expand Down Expand Up @@ -273,7 +273,7 @@ function DynamicPPL.initialstep(

# Compute the first transition.
_vi = reference.model.f.varinfo
transition = PGTransition(_vi, logevidence)
transition = PGTransition(model, _vi, logevidence)

return transition, PGState(_vi, reference.rng)
end
Expand Down Expand Up @@ -317,7 +317,7 @@ function AbstractMCMC.step(

# Compute the transition.
_vi = newreference.model.f.varinfo
transition = PGTransition(_vi, logevidence)
transition = PGTransition(model, _vi, logevidence)

return transition, PGState(_vi, newreference.rng)
end
Expand Down
87 changes: 59 additions & 28 deletions src/inference/Inference.jl
@@ -1,12 +1,11 @@
module Inference

using ..Essential
using ..Utilities
using DynamicPPL: Metadata, VarInfo, TypedVarInfo,
islinked, invlink!, link!,
setindex!!, push!!,
setlogp!!, getlogp,
tonamedtuple, VarName, getsym, vectorize,
VarName, getsym, vectorize,
_getvns, getdist,
Model, Sampler, SampleFromPrior, SampleFromUniform,
DefaultContext, PriorContext,
Expand Down Expand Up @@ -152,9 +151,8 @@ struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
end

Transition(θ, lp) = Transition(θ, lp, nothing)

function Transition(vi::AbstractVarInfo, t=nothing; nt::NamedTuple=NamedTuple())
θ = merge(tonamedtuple(vi), nt)
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
θ = getparams(model, vi)
lp = getlogp(vi)
return Transition(θ, lp, getstats(t))
end
Expand Down Expand Up @@ -291,18 +289,34 @@ end
##########################

"""
getparams(t)
getparams(model, t)
Return a named tuple of parameters.
"""
getparams(t) = t.θ
getparams(t::VarInfo) = tonamedtuple(TypedVarInfo(t))
getparams(model, t) = t.θ
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# Want the end-user to receive parameters in constrained space, so we `link`.
vi = DynamicPPL.invlink(vi, model)

# Extract parameter values in a simple form from the `VarInfo`.
vals = DynamicPPL.values_as(vi, OrderedDict)

# Obtain an iterator over the flattened parameter names and values.
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))

# Materialize the iterators and concatenate.
return mapreduce(collect, vcat, iters)
end


function _params_to_array(ts::Vector)
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
# TODO: Do we really need to use `Symbol` here?
names_set = OrderedSet{Symbol}()
# Extract the parameter names and values from each transition.
dicts = map(ts) do t
nms, vs = flatten_namedtuple(getparams(t))
nms_and_vs = getparams(model, t)
nms = map(Symbol first, nms_and_vs)
vs = map(last, nms_and_vs)
for nm in nms
push!(names_set, nm)
end
Expand All @@ -313,21 +327,8 @@ function _params_to_array(ts::Vector)
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
(j, key) in enumerate(names)]

return names, vals
end

function flatten_namedtuple(nt::NamedTuple)
names_vals = mapreduce(vcat, keys(nt)) do k
v = nt[k]
if length(v) == 1
return [(Symbol(k), v)]
else
return mapreduce(vcat, zip(v[1], v[2])) do (vnval, vn)
return collect(FlattenIterator(vn, vnval))
end
end
end
return [vn[1] for vn in names_vals], [vn[2] for vn in names_vals]
return names, vals
end

function get_transition_extras(ts::AbstractVector{<:VarInfo})
Expand Down Expand Up @@ -384,7 +385,7 @@ function AbstractMCMC.bundle_samples(
)
# Convert transitions to array format.
# Also retrieve the variable names.
nms, vals = _params_to_array(ts)
nms, vals = _params_to_array(model, ts)

# Get the values of the extra parameters in each transition.
extra_params, extra_values = get_transition_extras(ts)
Expand Down Expand Up @@ -435,9 +436,39 @@ function AbstractMCMC.bundle_samples(
kwargs...
)
return map(ts) do t
params = map(first, getparams(t))
return merge(params, metadata(t))
# Construct a dictionary of pairs `vn => value`.
params = OrderedDict(getparams(model, t))
# Group the variable names by their symbol.
sym_to_vns = group_varnames_by_symbol(keys(params))
# Convert the values to a vector.
vals = map(values(sym_to_vns)) do vns
map(Base.Fix1(getindex, params), vns)
end
return merge(NamedTuple(zip(keys(sym_to_vns), vals)), metadata(t))
end
end

"""
group_varnames_by_symbol(vns)
Group the varnames by their symbol.
# Arguments
- `vns`: Iterable of `VarName`.
# Returns
- `OrderedDict{Symbol, Vector{VarName}}`: A dictionary mapping symbol to a vector of varnames.
"""
function group_varnames_by_symbol(vns)
d = OrderedDict{Symbol,Vector{VarName}}()
for vn in vns
sym = DynamicPPL.getsym(vn)
if !haskey(d, sym)
d[sym] = VarName[]
end
push!(d[sym], vn)
end
return d
end

function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples)
Expand Down Expand Up @@ -685,7 +716,7 @@ function transitions_from_chain(
model(rng, vi, sampler)

# Convert `VarInfo` into `NamedTuple` and save.
Transition(vi)
Transition(model, vi)
end

return transitions
Expand Down

0 comments on commit e22b77c

Please sign in to comment.