In [1]:
using Rocket
using ReactiveMP
using GraphPPL
using BenchmarkTools
using Distributions
using MacroTools
using LinearAlgebra

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1278
┌ Info: Precompiling GraphPPL [b3f8163a-e979-4e85-b43e-1f63d8c8b42c]
└ @ Base loading.jl:1278
│ - If you have GraphPPL checked out for development and have
│   added ReactiveMP as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with GraphPPL


In [53]:
@model function transition_model(n)
    
    A ~ MatrixDirichlet(ones(3, 3)) 
    B ~ MatrixDirichlet([ 10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0 ])
    
    s_0 ~ Categorical(fill(1.0 / 3.0, 3))
    
    s = randomvar(n)
    x = datavar(Vector{Float64}, n)
    
    s_prev = s_0
    
    for t in 1:n
        s[t] ~ Transition(s_prev, A) where { q = q(out, in)q(a) }
        x[t] ~ Transition(s[t], B) where { q = MeanField() }
        s_prev = s[t]
    end
    
    return s, x, A, B
end

transition_model (generic function with 1 method)

In [70]:
function inference(data, n_its)
    n = length(data)
    
    model, (s, x, A, B) = transition_model(n)
    
    sbuffer = Vector{Vector{Marginal}}()
    Abuffer = Vector{Marginal}()
    Bbuffer = Vector{Marginal}()
    
    # fe_scheduler = PendingScheduler()
    
    ssub = subscribe!(collectLatest(getmarginals(s)), (ms) -> push!(sbuffer, ms))
    Asub = subscribe!(getmarginal(A), (mA) -> push!(Abuffer, mA))
    Bsub = subscribe!(getmarginal(B), (mB) -> push!(Bbuffer, mB))
    
    setmarginal!(A, vague(MatrixDirichlet, 3, 3))
    setmarginal!(B, vague(MatrixDirichlet, 3, 3))
    
#     foreach(s) do svar
#         setmarginal!(svar, vague(Categorical, 3))
#     end
    # setmarginal!(γ, Gamma(0.01, 100.0))
    
    for i in 1:n_its
        update!(x, data)
        # release!(fe_scheduler)
    end
    
    unsubscribe!(ssub)
    unsubscribe!(Asub)
    unsubscribe!(Bsub)
    
    
    return sbuffer, Abuffer, Bbuffer
end

inference (generic function with 1 method)

In [71]:
import ForneyLab

function generate_data()
    n_samples = 100
    A_data = [0.9 0.0 0.1; 0.1 0.9 0.0; 0.0 0.1 0.9] # Transition probabilities (some transitions are impossible)
    B_data = [0.9 0.05 0.05; 0.05 0.9 0.05; 0.05 0.05 0.9] # Observation noise
    s_0_data = [1.0, 0.0, 0.0] # Initial state
    # Generate some data
    s_data = Vector{Vector{Float64}}(undef, n_samples) # one-hot encoding of the states
    x_data = Vector{Vector{Float64}}(undef, n_samples) # one-hot encoding of the observations
    s_t_min_data = s_0_data
    for t = 1:n_samples
        a = A_data*s_t_min_data
        s_data[t] = ForneyLab.sample(ForneyLab.ProbabilityDistribution(ForneyLab.Categorical, p=a./sum(a))) # Simulate state transition
        b = B_data*s_data[t]
        x_data[t] = ForneyLab.sample(ForneyLab.ProbabilityDistribution(ForneyLab.Categorical, p=b./sum(b))) # Simulate observation
        s_t_min_data = s_data[t]
    end
    ;
    return x_data, s_data
end

generate_data (generic function with 1 method)

In [78]:
x_data, s_data = generate_data();

In [79]:
@time sbuffer, Abuffer, Bbuffer = inference(x_data, 20);

  0.072084 seconds (534.74 k allocations: 38.806 MiB)


In [74]:
using Plots

In [77]:
mean.(Bbuffer)[end]

3×3 Array{Float64,2}:
 0.863823   0.0481804  0.0239232
 0.049749   0.882761   0.0668449
 0.0864279  0.0690589  0.909232

In [66]:
length(sbuffer)

11

In [59]:
sum.(ReactiveMP.probvec.(sbuffer[end]) .- s_data)

100-element Array{Float64,1}:
 -6.591949208711867e-17
  1.7780915628762273e-17
 -5.160802341031001e-17
  7.958043946043603e-17
 -1.8127860323957634e-16
  8.794912497930851e-17
 -2.865343053971847e-17
  5.107608671943431e-17
  1.183593018516324e-16
  1.5597264690740686e-16
 -5.411862906597176e-17
  1.969859822134601e-17
  1.2262665383989957e-16
  ⋮
  9.8879238130678e-17
  1.6501557065229377e-16
 -4.2392304944183223e-17
 -3.3237572850258745e-17
 -1.2835937282691667e-17
  2.4508051295855926e-17
 -1.5670109524204556e-18
 -8.679716017104266e-17
 -8.497434526855141e-17
  3.5778671692021646e-17
  5.346471963069144e-17
 -5.149960319306146e-17

In [29]:
coefs

1-element Array{Float64,1}:
 -0.269936031369128

In [30]:
@show mean.(θbuffer)[end]

mean.(θbuffer)[end] = [-0.2537372243225593]


1-element Array{Float64,1}:
 -0.2537372243225593

In [31]:
@show mean.(γbuffer)

mean.(γbuffer) = Any[]


Any[]

In [32]:
fe

10-element Array{Float64,1}:
 140.0493209580581
 140.0493209580581
 140.0493209580581
 140.0493209580581
 140.0493209580581
 140.0493209580581
 140.0493209580581
 140.0493209580581
 140.0493209580581
 140.0493209580581