Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions Example/Using_Custom_VI/AdvancedPS_SSM_Container.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
## Some function which make the model easier to define.


# A very light weight container for a state space model
# The state space is one dimensional
# Note that this is only for a very simple demonstration.

module AdvancedPS_SSM_Container
using Libtask
using AdvancedPS: current_trace
using NamedTupleTools
# This is a very shallow container solely for testing puropose
abstract type AbstractPFContainer end
export AbstractPFContainer

const initialize = current_trace

mutable struct Container{M<:AbstractMatrix, I<:Integer} <: AbstractPFContainer
x::M
marked::Vector{Bool}
produced_at::Vector{I}
num_produce::I
end



function Base.copy(vi::Container)
return Container(deepcopy(vi.x),deepcopy(vi.marked),deepcopy(vi.produced_at),deepcopy(vi.num_produce))
end


@inline function set_retained_vns_del_by_spl!(container::Container)
for i in 1:length(container.marked)
if container.marked[i]
if container.produced_at[i] > container.num_produce
container.marked[i] = false
end
end
end
container
end



# This is important for initalizaiton

@inline function report_observation!(trace, logp::Real)
produce(logp)
current_trace()
end

# logγ corresponds to the proposal distributoin we are sampling from.
@inline function report_transition!(trace ,logp::Real,logγ::Real)
trace.taskinfo.logp += logp - logγ
trace.taskinfo.logpseq += logp
end

@inline function update_var!(trace, vn::Int, r::Vector{<:Real})
if !trace.vi.marked[vn]
trace.vi.x[vn,:] = r # We assume that it is already vectorized!!
trace.vi.marked[vn] = true
trace.vi.produced_at[vn] = trace.vi.num_produce
return r
end
return trace.vi.x[vn,:]
end



@inline function is_marked(trace, vn::Int)::Bool
if trace.vi.marked[vn]
return true
end
return false
end

@inline function get_vn(trace, vn::Int)
return trace.vi.x[vn,:]
end

@inline function get_traj(trace)
return trace.vi.x[1:trace.vi.num_produce+1,:]
end

# The reason for this is that we need to pass it!
@inline function Base.copy(vi::Container)
Container(deepcopy(vi.x),deepcopy(vi.marked),deepcopy(vi.produced_at),copy(vi.num_produce))
end

@inline function tonamedtuple(vi::Container)
tnames = Tuple([Symbol("x$i") for i in 1:size(vi.x)[1]])
tvalues = Tuple([vi.x[i,:] for i in 1:size(vi.x)[1]])
return namedtuple(tnames, tvalues)
end

function Base.empty!(vi::Container)
for i in 1:length(vi.marked)
vi.marked[i] = false
end
vi
end

export Container,
tonamedtuple,
get_traj,
get_vn,
is_marked,
report_observation!,
report_transition!,
set_retained_vns_del_by_spl,
update_var!,
initialize


end
62 changes: 62 additions & 0 deletions Example/Using_Custom_VI/demonstarte.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

## It is not yet a package...


using Distributions
using AdvancedPS
using BenchmarkTools

dir = splitdir(splitdir(pathof(AdvancedPS))[1])[1]
push!(LOAD_PATH,dir*"/Example/Using_Custom_VI/" )
using AdvancedPS_SSM_Container
const APSCont = AdvancedPS_SSM_Container


# Define a short model.
# The syntax is rather simple. Observations need to be reported with report_observation.
# Transitions must be reported using report_transition.
# The trace contains the variables which we want to infer using particle gibbs.
# Thats all!
n = 200


y = Array{Float64,2}(undef,n-1,1)
for i =1:n-1
y[i,1] = 0
end



function task_f(y)
var = initialize()
x = zeros(n,1)
r = rand(MultivariateNormal(1,1.0))
x[1,:] = update_var!(var, 1, r)
report_transition!(var,0.0,0.0)
for i = 2:n
# Sampling
r = rand(MultivariateNormal(1,1.0))
x[i,:] = update_var!(var, i, r)
logγ = logpdf(MultivariateNormal(1,1.0),x[i,:]) #γ(x_t|x_t-1)
logp = logpdf(MultivariateNormal(1,1.0),x[i,:]) # p(x_t|x_t-1)
report_transition!(var,logp,logγ)
#Proposal and Resampling
logpy = logpdf(MultivariateNormal(x[i,:],1.0), y[i-1,:])
var = report_observation!(var,logpy)
end
end


tcontainer = Container(zeros(n,1),Vector{Bool}(undef,n),Vector{Int}(zeros(n)),0)
model = PFModel(task_f, (y=y,))



alg = AdvancedPS.SMCAlgorithm()
uf = AdvancedPS.SMCUtilityFunctions(APSCont.set_retained_vns_del_by_spl!, tonamedtuple)
@btime sample(model, alg, uf, tcontainer, 10)


alg = AdvancedPS.PGAlgorithm(AdvancedPS.resample_systematic, 1.0, 10)
uf = AdvancedPS.PGUtilityFunctions( APSCont.set_retained_vns_del_by_spl!, APSCont.tonamedtuple)
@btime chn2 =sample(model, alg, uf, tcontainer, 5)
210 changes: 210 additions & 0 deletions Example/Using_Turing_VI/AdvancedPS_Turing_Container.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
module AdvancedPS_Turing_Container

using Turing.Core.RandomVariables
import Turing.Core: @varname
import Turing.Utilities: vectorize
using Turing
using AdvancedPS: current_trace

# This is important for initalizaiton
const TypedVarInfo = VarInfo{<:NamedTuple}
const Selector = Turing.Selector
const BASE_SELECTOR = Selector(:PS)
const initialize = current_trace


Base.copy(vi::VarInfo) = deepcopy(vi)
tonamedtuple(vi::TypedVarInfo) = Turing.tonamedtuple(vi)
tonamedtuple(vi::UntypedVarInfo) = tonamedtuple(TypedVarInfo(vi))


function report_observation!(trace, logp::Float64)
produce(logp)
trace = current_trace()
end

# logγ corresponds to the proposal distributoin we are sampling from.
function report_transition!(trace, logp::Float64, logγ::Float64)
trace.taskinfo.logp += logp - logγ
trace.taskinfo.logpseq += logp
end


# We obtain the new value for our variable
# If the distribution is not specified, we simply set it to be Normal
function update_var!(trace, vn, val, dist= Normal())
# check if the symbol is contained in the varinfo...
if haskey(trace.vi,vn)
if is_flagged(trace.vi, vn, "del")
unset_flag!(trace.vi, vn, "del")
trace.vi[vn] = vectorize(dist,val)
setgid!(trace.vi, BASE_SELECTOR, vn)
setorder!(trace.vi, vn, trace.vi.num_produce)
return val
else
updategid!(trace.vi, BASE_SELECTOR, vn)
val = trace.vi[vn]
end
else
#We do not specify the distribution... Thats why we set it to be Normal()
push!(trace.vi, vn, val, dist)
end
return val
end

#########################################################################
# This is copied from turing compiler3.0, but we need to extract the #
# sampler form set_retained_vns_del_by_spl! #
#########################################################################


# Get all indices of variables belonging to SampleFromPrior:

# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to

# the SampleFromPrior sampler

"""
`getidx(vi::UntypedVarInfo, vn::VarName)`

Returns the index of `vn` in `vi.metadata.vns`.
"""
getidx(vi::UntypedVarInfo, vn::VarName) = vi.idcs[vn]
"""
`getidx(vi::TypedVarInfo, vn::VarName{sym})`

Returns the index of `vn` in `getfield(vi.metadata, sym).vns`.
"""
function getidx(vi::TypedVarInfo, vn::VarName{sym}) where sym
getfield(vi.metadata, sym).idcs[vn]
end

"""
`setgid!(vi::VarInfo, gid::Selector, vn::VarName)`

Adds `gid` to the set of sampler selectors associated with `vn` in `vi`.
"""
setgid!(vi::UntypedVarInfo, gid::Selector, vn::VarName) = push!(vi.gids[getidx(vi, vn)], gid)
function setgid!(vi::TypedVarInfo, gid::Selector, vn::VarName{sym}) where sym
push!(getfield(vi.metadata, sym).gids[getidx(vi, vn)], gid)
end
"""
`updategid!(vi::VarInfo, vn::VarName, spl::Sampler)`

If `vn` doesn't have a sampler selector linked and `vn`'s symbol is in the space of
`spl`, this function will set `vn`'s `gid` to `Set([spl.selector])`.
"""
function updategid!(vi::AbstractVarInfo, sel::Selector, vn::VarName)
setgid!(vi, sel, vn)
end

@generated function _getidcs(metadata::NamedTuple{names}) where {names}
exprs = []
for f in names
push!(exprs, :($f = findinds(metadata.$f)))
end
length(exprs) == 0 && return :(NamedTuple())
return :($(exprs...),)
end

# Get all indices of variables belonging to a given sampler

@inline function _getidcs(vi::UntypedVarInfo, s::Selector)
findinds(vi, s)
end

@inline function _getidcs(vi::TypedVarInfo, s::Selector)
return _getidcs(vi.metadata, s)
end
# Get a NamedTuple for all the indices belonging to a given selector for each symbol
@generated function _getidcs(metadata::NamedTuple{names}, s::Selector) where {names}
exprs = []
# Iterate through each varname in metadata.
for f in names
# If the varname is in the sampler space
# or the sample space is empty (all variables)
# then return the indices for that variable.
push!(exprs, :($f = findinds(metadata.$f, s)))
end
length(exprs) == 0 && return :(NamedTuple())
return :($(exprs...),)
end

@inline function findinds(f_meta, s::Selector)

# Get all the idcs of the vns in `space` and that belong to the selector `s`
return filter((i) ->
(s in f_meta.gids[i] || isempty(f_meta.gids[i])), 1:length(f_meta.gids))
end

@inline function findinds(f_meta)
# Get all the idcs of the vns
return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids))
end

"""
`set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler)`

Sets the `"del"` flag of variables in `vi` with `order > vi.num_produce` to `true`.
"""
function set_retained_vns_del_by_spl!(vi::AbstractVarInfo)
return set_retained_vns_del_by_spl!(vi, BASE_SELECTOR)
end

function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, sel::Selector)
# Get the indices of `vns` that belong to `spl` as a vector
gidcs = _getidcs(vi, sel)
if vi.num_produce == 0
for i = length(gidcs):-1:1
vi.flags["del"][gidcs[i]] = true
end
else
for i in 1:length(vi.orders)
if i in gidcs && vi.orders[i] > vi.num_produce
vi.flags["del"][i] = true
end
end
end
return nothing
end

function set_retained_vns_del_by_spl!(vi::TypedVarInfo, sel::Selector)
# Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol
gidcs = _getidcs(vi, sel)
return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, vi.num_produce)
end

@generated function _set_retained_vns_del_by_spl!(metadata, gidcs::NamedTuple{names}, num_produce) where {names}
expr = Expr(:block)
for f in names
f_gidcs = :(gidcs.$f)
f_orders = :(metadata.$f.orders)
f_flags = :(metadata.$f.flags)
push!(expr.args, quote
# Set the flag for variables with symbol `f`
if num_produce == 0
for i = length($f_gidcs):-1:1
$f_flags["del"][$f_gidcs[i]] = true
end
else
for i in 1:length($f_orders)
if i in $f_gidcs && $f_orders[i] > num_produce
$f_flags["del"][i] = true
end
end
end
end)
end
return expr
end

export VarInfo,
UntypedVarInfo,
TypedVarInfo,
tonamedtuple,
report_observation!,
report_transition!,
set_retained_vns_del_by_spl,
update_var!,
initialize
end
Loading