Composite nodes
=====================

In [1]:
using ForneyLab

# Define new node type called StateTransition, with exposed variables called (x, x_prev, y):
# NOTE(review): the order (x, x_prev, y) appears to be the interface order; the node is
# instantiated with the same order and the custom rule lists its inbound types in it.
@composite StateTransition (x, x_prev, y) begin
    # x is linked to x_prev through a GaussianMeanVariance node with constant second argument 1.0
    x ~ GaussianMeanVariance(x_prev, constant(1.0))
    # y is linked to x through a second GaussianMeanVariance node, again with constant 1.0
    y ~ GaussianMeanVariance(x, constant(1.0))
end

# Create the factor graph; `g` is passed to ForneyLab.draw in a later cell
g = FactorGraph()

# Build factor graph
# x_prev is attached to a GaussianMeanVariance node with constants 0.0 and 1.0
x_prev ~ GaussianMeanVariance(constant(0.0), constant(1.0))

# Create the remaining variables with explicit ids (:y and :x) for easy lookup;
# because the id is given at construction, no separate id assignment is needed
y = Variable(id=:y)
x = Variable(id=:x)

# Connect (x, x_prev, y) through the composite node, in its declared argument order
StateTransition(x, x_prev, y)

# Register y as a placeholder; its value is read from the data dictionary under :y
placeholder(y, :y)

# Trailing semicolon on its own line: keeps the notebook from echoing the last value
;

In [2]:
# Register a custom sum-product rule for the StateTransition composite node.
# The inbound types follow the node's (x, x_prev, y) order: `Void` sits in the
# x position (the message is computed towards x), x_prev supplies a Gaussian
# message and y supplies a point-mass message.
@sumProductRule(
    :node_type     => StateTransition,
    :outbound_type => Message{Gaussian},
    :inbound_types => (Void, Message{Gaussian}, Message{PointMass}),
    :name          => SPStateTransitionVGP
)

# Derive a sum-product schedule for x
schedule = sumProductSchedule(x)

# Draw the graph together with the schedule
ForneyLab.draw(g, schedule=schedule)

In [3]:
# Dummy implementation of the custom rule: the inbound arguments are ignored
# and the same Gaussian message (m=2.0, v=3.0) is always returned.
function ruleSPStateTransitionVGP(::Void, ::Message{Gaussian}, ::Message{PointMass})
    return Message(Gaussian, m=2.0, v=3.0)
end

# Generate the algorithm from the schedule and x
algo = messagePassingAlgorithm(schedule, x)

# Print the generated Julia code for inspection
println(algo)

# Generated source as printed by `println(algo)`; a later cell evaluates the same
# text with `eval(parse(algo))`, so the code below is kept verbatim.
#
# step!(data, marginals, messages) computes two messages in order and stores
# the distribution of the second one as the marginal for :x.
#   data      - Dict; read at key :y
#   marginals - Dict that receives the :x entry and is returned (defaults to a new Dict())
#   messages  - Vector{Message} buffer, overwritten at indices 1 and 2
#               (defaults to a 2-element Array{Message})
function step!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(2))

# Message 1: built-in GaussianMeanVariance rule applied to point masses 0.0 and 1.0
# (the constants of the x_prev node); NOTE(review): the `nothing` argument
# presumably marks the outbound interface - confirm against the rule signature
messages[1] = ruleSPGaussianMeanVariancePPV(Message(PointMass, m=0.0), Message(PointMass, m=1.0), nothing)
# Message 2: custom StateTransition rule, fed message 1 and a point mass holding data[:y]
messages[2] = ruleSPStateTransitionVGP(nothing, messages[1], Message(PointMass, m=data[:y]))

# The marginal for :x is the distribution carried by message 2
marginals[:x] = messages[2].dist

return marginals

end


In [4]:
# Parse the generated source string and evaluate it; this defines step!
eval(parse(algo))

# Value for the :y entry that step! reads from the data dictionary
data = Dict(:y => 1.0)
# Run the generated step function; it returns the marginals dictionary
marginals = step!(data)

# Print the marginal stored under :x (with the dummy rule: the fixed m=2.0, v=3.0 Gaussian)
println(marginals[:x])

ForneyLab.ProbabilityDistribution{ForneyLab.Gaussian}(Dict(:m=>2.0,:v=>3.0))
