In [1]:

using Revise
using Rocket
using ReactiveMP
using GraphPPL
using Distributions
using LinearAlgebra
import ProgressMeter
using PyCall
using HDF5



In [2]:
# Latent multivariate AR(`order`) model for a single time step of online filtering:
# a Gaussian prior on the previous state `x_prev`, an `AR` transition to the new
# state `x`, and a scalar observation `y` whose mean is `dot(c, x)`.
#
# Data variables (supplied by the caller via `update!`):
#   mx_prev, vx_prev — mean / covariance of the Gaussian prior on `x_prev`
#   m_θ, v_θ         — mean / covariance of the Gaussian prior on the AR coefficients `θ`
#   γ_a, γ_b         — shape / rate of the Gamma prior on `γ`
#                      (presumably the AR process-noise precision — confirm)
#   y                — the scalar observation for this step
#
# Arguments:
#   order — AR order, forwarded to `ARMeta`
#   c     — vector dotted with the new state `x` to form the observation mean
#           (presumably a unit vector selecting the first component — confirm at call site)
#   stype — AR "safety" type forwarded to `ARMeta` (callers pass e.g. `ARsafe()`)
#
# Returns the model variables / node handle in the exact order callers destructure them:
#   (x, mx_prev, vx_prev, x_prev, y, θ, m_θ, v_θ, γ, γ_a, γ_b, ar_node)
@model function lar_model_multivariate(order, c, stype)
    mx_prev = datavar(Vector{Float64})
    vx_prev = datavar(Matrix{Float64})
    γ_a = datavar(Float64)
    γ_b = datavar(Float64)
    m_θ = datavar(Vector{Float64})
    v_θ = datavar(Matrix{Float64})
    y = datavar(Float64)
    
    # Priors on the AR coefficients and on the previous state.
    θ  ~ MvNormalMeanCovariance(m_θ, v_θ) where { q = MeanField() }
    x_prev ~ MvNormalMeanCovariance(mx_prev, vx_prev) where { q = MeanField() }
    
    γ  ~ GammaShapeRate(γ_a, γ_b) where { q = MeanField() }

    # Observation noise precision is fixed to 1.0 (not learned).
    γ_y = constvar(1.0)
    ct  = constvar(c)

    meta = ARMeta(Multivariate, order, stype)

    # AR transition x_prev -> x; the node handle is kept so the caller can initialise
    # the node's joint marginal.
    # NOTE(review): the caller initialises this node's joint marginal as `:y_x`, i.e. in
    # terms of the AR node's interface names; confirm that `q(y, x_prev)` here resolves to
    # the intended (y, x) interface pair rather than, e.g., `q(y, x)`.
    ar_node, x ~ AR(x_prev, θ, γ) where { q = q(y, x_prev)q(γ)q(θ), meta = meta }
    # Scalar observation: y ~ N(dot(c, x), precision γ_y).
    y ~ NormalMeanPrecision(dot(ct, x), γ_y) where { q = MeanField() }

    return x, mx_prev, vx_prev, x_prev, y, θ,m_θ,v_θ, γ, γ_a,γ_b,ar_node
end


using BenchmarkTools

# Select the model constructor by variate type. Both variants are created with the same
# model `options`, `(limit_stack_depth = 50,)`.
# NOTE(review): `lar_model_univariate` is not defined in the visible notebook section;
# confirm it exists before calling with `Univariate`.
function lar_model(::Type{Multivariate}, order, c, stype)
    opts = (limit_stack_depth = 50,)
    return lar_model_multivariate(order, c, stype; options = opts)
end

function lar_model(::Type{Univariate}, order, c, stype)
    opts = (limit_stack_depth = 50,)
    return lar_model_univariate(order, c, stype; options = opts)
end



lar_model (generic function with 2 methods)

In [3]:


# setup inference
"""
    start_inference(data, mx_min, vx_min, mθ, vθ, γa, γb, order, niter,
                    artype = Multivariate, stype = ARsafe())

Run variational message passing for a single time step of the latent AR model and
return the resulting posterior statistics.

A fresh model is built with `lar_model(artype, order, c, stype)`, where `c` is
`ReactiveMP.ar_unit(artype, order)`. Initial marginals are set as follows:
- the state `x` gets a zero-mean, identity-covariance Gaussian;
- the AR coefficients `θ` get the prior `mθ`, `vθ`;
- `γ` gets the prior `γa`, `γb`;
- the AR node's joint `:y_x` marginal is initialised as well.

All data variables are then re-fed `niter` times. Presumably each `update!` re-triggers
message passing, which is what drives the variational iterations.

# Arguments
- `data`: scalar observation for this step.
- `mx_min`, `vx_min`: prior mean / covariance of the previous state.
- `mθ`, `vθ`: prior mean / covariance of the AR coefficients `θ`.
- `γa`, `γb`: shape / rate of the Gamma prior on `γ`.
- `order`: AR order (length of the state and coefficient vectors).
- `niter`: number of times the data variables are re-fed.
- `artype`, `stype`: variate type and AR "safety" type forwarded to the model.

# Returns
`(mean_x, cov_x, mean_θ, cov_θ, shape_γ, rate_γ)`, taken from the last posterior marginal
emitted for `x`, `θ` and `γ`. They come in the same order as the prior arguments, so they
can be fed back as the next step's priors.

NOTE(review): if no marginal is emitted (presumably when `niter < 1`), indexing the
recorded streams with `[end]` throws.
"""
function start_inference(data, mx_min, vx_min, mθ, vθ, γa, γb, order, niter,
                         artype = Multivariate, stype = ARsafe())
    c = ReactiveMP.ar_unit(artype, order)

    # The 4th returned variable (x_prev) is not needed here.
    model, (x_t, mx_t_min, vx_t_min, _, y_t, θ, m_θ, v_θ, γ, γ_a, γ_b, ar_node) =
        lar_model(artype, order, c, stype)

    # Record every posterior marginal emitted during inference.
    x_t_stream = keep(Marginal)
    θ_stream = keep(Marginal)
    γ_stream = keep(Marginal)

    subscriptions = (
        subscribe!(getmarginal(x_t), (posterior) -> next!(x_t_stream, posterior)),
        subscribe!(getmarginal(γ), (posterior) -> next!(γ_stream, posterior)),
        subscribe!(getmarginal(θ), (posterior) -> next!(θ_stream, posterior)),
    )

    setmarginal!(x_t, MvNormalMeanCovariance(zeros(order), diageye(order)))
    setmarginal!(γ, GammaShapeRate(γa, γb))
    setmarginal!(θ, MvNormalMeanCovariance(mθ, vθ))
    yx_dim = 2 * order
    yx_init = MvNormalMeanPrecision(zeros(yx_dim), Matrix{Float64}(I, yx_dim, yx_dim))
    setmarginal!(ar_node, :y_x, yx_init)

    try
        for _ in 1:niter
            update!(y_t, data)
            update!(mx_t_min, mx_min)
            update!(vx_t_min, vx_min)
            update!(γ_a, γa)
            update!(γ_b, γb)
            update!(m_θ, mθ)
            update!(v_θ, vθ)
        end
    finally
        # Release the subscriptions even if an update throws; the keep actors retain the
        # values they already received.
        foreach(unsubscribe!, subscriptions)
    end

    x_post, θ_post, γ_post = x_t_stream[end], θ_stream[end], γ_stream[end]
    return mean(x_post), cov(x_post), mean(θ_post), cov(θ_post), shape(γ_post), rate(γ_post)
end

# 

start_inference (generic function with 3 methods)

In [7]:
ar_order = 3
mx_min = zeros(ar_order)
vx_min = diageye(ar_order)
mθ = zeros(ar_order)
vθ = diageye(ar_order)
γa = 0.1
γb = 0.1
vmp_iterations = 20

# Filter the synthetic observation sequence 1.0, 2.0, …, 20.0. Each step's posterior
# statistics are fed back as the priors for the next step.
# The `global` declaration makes the reassignment update the module-level variables both
# in a notebook/REPL and when this runs as a script file. In a script, soft scope would
# otherwise create new loop-local variables, and reading them on the right-hand side
# would throw an UndefVarError.
for i in 1.0:20.0
    global mx_min, vx_min, mθ, vθ, γa, γb
    mx_min, vx_min, mθ, vθ, γa, γb =
        start_inference(i, mx_min, vx_min, mθ, vθ, γa, γb, ar_order, vmp_iterations)
    println(i, mx_min)
end

1.0[0.7326603259265834, 0.0, 0.0]
2.0[1.5918260603345191, 0.7025377880486924, 0.0]
3.0[2.621322950030165, 1.6334367250781978, 0.6988639424503474]
4.0[3.754583033494254, 2.6800063357701074, 1.6368536109472904]
5.0[4.886646472864992, 3.7721048516891784, 2.6641594243141515]
6.0[5.977183065382432, 4.866835123315578, 3.7360779924333496]
7.0[7.03306714761773, 5.935644233796427, 4.819678940792302]
8.0[8.066551461535472, 6.979864045040783, 5.884449048807563]
9.0[9.086653145161849, 8.00683457693946, 6.927985031772058]
10.0[10.099285683677854, 9.022842586885705, 7.955423793905819]
11.0[11.10782216819495, 10.032488121494534, 8.972141567955521]
12.0[12.114122448381313, 11.038537105606588, 9.982377723767732]
13.0[13.119187526245776, 12.042559895678954, 10.98880544839012]
14.0[14.123554281215482, 13.045436819965998, 11.992975559699572]
15.0[15.127390281203994, 14.047739767810258, 12.995835754579318]
16.0[16.130873153753992, 15.049589354782023, 13.9979934889443]
17.0[17.13414341639399, 16.05111040790