In [57]:
using Rocket
using ReactiveMP
using GraphPPL
using Distributions
using Random
using BenchmarkTools

In [58]:
# Two-component Gaussian mixture model over `n` scalar observations.
#
# Each observation y[i] is drawn from the mixture component selected by its latent
# switch z[i] ~ Bernoulli(s). Component k has mean mk and a second parameter wk,
# presumably a precision (TODO confirm against ReactiveMP's NormalMixture).
# NOTE(review): `outbound_message_portal = EmptyPortal()` is a @model option whose
# exact effect depends on the ReactiveMP/GraphPPL version in use; confirm.
#
# Returns the variables (s, m1, w1, m2, w2, z, y). Callers destructure them together
# with the model object: `model, (s, m1, w1, m2, w2, z, y) = gaussian_mixture_model(n)`.
@model [ outbound_message_portal = EmptyPortal() ] function gaussian_mixture_model(n)
    
    # Mixing weight: Beta(1, 1) is the uniform prior on [0, 1].
    s ~ Beta(1.0, 1.0)
    
    # First component: broad prior on the mean (variance 100) and a Gamma prior on w1
    # (shape 0.01, scale 100, assuming Distributions.jl's shape/scale parameterisation).
    m1 ~ NormalMeanVariance(0.0, 100.0)
    w1 ~ Gamma(0.01, 100.0)
    
    # Second component: same priors as the first.
    m2 ~ NormalMeanVariance(0.0, 100.0)
    w2 ~ Gamma(0.01, 100.0)
    
    # `n` latent switches and `n` observed Float64 data slots.
    z = randomvar(n)
    y = datavar(Float64, n)
    
    for i in 1:n
        # `q = MeanField()` requests a mean-field (fully factorised) posterior
        # approximation around each of these nodes.
        z[i] ~ Bernoulli(s) where { q = MeanField() }
        y[i] ~ NormalMixture(z[i], (m1, m2), (w1, w2)) where { q = MeanField() }
    end
    
    return s, m1, w1, m2, w2, z, y
end

gaussian_mixture_model (generic function with 1 method)

In [59]:
using Random

In [71]:
"""
    inference(data; niterations = 25)

Run message-passing inference on `gaussian_mixture_model` for the observations `data`
and collect every marginal of the mixing weight `s` emitted during the run.

Steps:
1. Build a model sized to `length(data)`.
2. Set initial marginals for `s`, `m1`, `m2`, `w1` and `w2`.
3. Feed `data` to the observed variables `niterations` times via `update!`.

# Arguments
- `data`: vector of real-valued observations.

# Keywords
- `niterations::Integer = 25`: number of times `data` is passed to `update!`.

# Returns
A `Vector{Marginal}` holding the marginals of `s` in emission order. The last element
is the most recent estimate.

# Throws
- `ArgumentError` if `niterations` is negative.
"""
function inference(data; niterations::Integer = 25)
    niterations >= 0 ||
        throw(ArgumentError("niterations must be non-negative, got $niterations"))

    n = length(data)
    # The model object and the latent switches are not needed after construction.
    _, (s, m1, w1, m2, w2, _, y) = gaussian_mixture_model(n)

    buffer = Vector{Marginal}()

    # Record every marginal of `s` the inference engine emits.
    switchsub = subscribe!(getmarginal(s), (ms) -> push!(buffer, ms))

    # Release the subscription even if initialisation or an update throws.
    try
        setmarginal!(s, vague(Beta))
        # Distinct initial means for the two components; presumably chosen so the
        # components do not start out symmetric (label-switching).
        setmarginal!(m1, NormalMeanVariance(-1.0, 1e4))
        setmarginal!(m2, NormalMeanVariance(2.0, 1e4))
        setmarginal!(w1, vague(Gamma))
        setmarginal!(w2, vague(Gamma))

        # Each pass re-feeds the same observations to drive another round of updates.
        for _ in 1:niterations
            update!(y, data)
        end
    finally
        unsubscribe!(switchsub)
    end

    return buffer
end

inference (generic function with 1 method)

In [72]:
# Synthetic two-cluster data: each of the `n` points is drawn from `d1` with
# probability `switch`, otherwise from `d2` (Normal(μ, σ) with σ = 0.75).
n = 1000

Random.seed!(124)

switch = 0.23
z      = rand(n) .< switch    # z[i] == true  ⇒  y[i] is drawn from d1

d1 = Normal(-2.0, 0.75)
d2 = Normal(+2.0, 0.75)

# Draws are made in index order, one per point.
y = Float64[from_d1 ? rand(d1) : rand(d2) for from_d1 in z]

In [73]:
# Time one inference run (this first call likely includes compilation) and show the
# mean of the final marginal of the mixing weight `s`.
@time marginals = inference(y); mean(last(marginals))

  1.237079 seconds (6.19 M allocations: 299.565 MiB, 25.85% gc time)


0.24442274899442687

In [74]:
# Second synthetic data set: same two clusters, but most points (probability
# `switch`) now come from `d1`. Normal(μ, σ) with σ = 0.75.
n = 1000

Random.seed!(421)

switch = 0.8604
z      = rand(n) .< switch    # z[i] == true  ⇒  y[i] is drawn from d1

d1 = Normal(-2.0, 0.75)
d2 = Normal(+2.0, 0.75)

# One draw per point, in index order.
y = Float64[pick_first ? rand(d1) : rand(d2) for pick_first in z]

In [75]:
# Time inference on the second data set and show the mean of the final marginal of `s`.
@time marginals = inference(y); mean(last(marginals))

  1.061293 seconds (5.92 M allocations: 285.658 MiB, 23.60% gc time)


0.866239770265741