# Infinite Data Stream

In [7]:
# Activate local environment, see `Project.toml`
import Pkg; Pkg.activate("."); Pkg.instantiate();

[32m[1m  Activating[22m[39m project at `~/.julia/dev/RxInfer/examples`


In [121]:
using RxInfer, Plots

We assume that we don't know the shape of our signal in advance, so we try to fit a simple Gaussian random walk with unknown observation noise:

In [122]:
"""
    DataGenerationProcess(previous, process_noise, observation_noise, history, observations)

Mutable state of the synthetic signal generator used in this notebook.

# Fields
- `previous`: running step counter; `getnext!` advances it by `1.0` on every call.
- `process_noise`: stored alongside the other settings (not read by `getnext!`).
- `observation_noise`: scale of the Gaussian noise that `getnext!` adds to each value.
- `history`: noise-free signal values, in generation order.
- `observations`: noisy values, in generation order.
"""
mutable struct DataGenerationProcess
    previous::Float64
    process_noise::Float64
    observation_noise::Float64
    history::Vector{Float64}
    observations::Vector{Float64}
end

"""
    getnext!(process::DataGenerationProcess)

Advance `process` by one step and return a freshly generated noisy observation.

The step counter `process.previous` is incremented by `1.0`, the noise-free value
`10 * sin(0.1 * process.previous)` is computed, and a sample drawn from
`Normal(0.0, process.observation_noise)` is added to it. The noise-free value is
appended to `process.history` and the noisy one to `process.observations`.
"""
function getnext!(process::DataGenerationProcess)
    process.previous += 1.0

    state = 10 * sin(0.1 * process.previous)
    noise = rand(Normal(0.0, process.observation_noise))
    noisy = state + noise

    push!(process.history, state)
    push!(process.observations, noisy)

    return noisy
end

"""
    gethistory(process::DataGenerationProcess)

Return the vector of noise-free values stored in `process.history`
(the stored vector itself, not a copy).
"""
gethistory(process::DataGenerationProcess) = process.history

"""
    getobservations(process::DataGenerationProcess)

Return the vector of noisy values stored in `process.observations`
(the stored vector itself, not a copy).
"""
getobservations(process::DataGenerationProcess) = process.observations

getobservations (generic function with 1 method)

In [136]:
# One-step filtering model with an unknown observation precision `τ`.
# Both the prior over the previous state and the Gamma prior over `τ` are
# parameterised by data variables, so their values are supplied from outside
# the model rather than being fixed here.
@model function kalman_filter()
    
    # Reactive inputs: parameters of the priors, provided as data
    x_t_min_mean = datavar(Float64)
    x_t_min_var  = datavar(Float64)
    τ_shape = datavar(Float64)
    τ_rate  = datavar(Float64)

    # Prior over the observation precision
    τ ~ Gamma(shape = τ_shape, rate = τ_rate)
    
    # Prior over the previous state, then a transition to the current state
    # that is centred on the previous state with fixed precision 1.0
    x_t_min ~ Normal(mean = x_t_min_mean, variance = x_t_min_var)
    x_t     ~ Normal(mean = x_t_min, precision = 1.0)
    
    # Observation: centred on the current state, with precision τ
    y = datavar(Float64)
    y ~ Normal(mean = x_t, precision = τ)
    
end

# One-step filtering model with a fixed observation precision.
# Same structure as `kalman_filter`, except that the observation precision is
# the constant 10.0 instead of a random variable `τ`.
@model function simple_kalman_filter()
    
    # Reactive inputs: parameters of the prior over the previous state, provided as data
    x_t_min_mean = datavar(Float64)
    x_t_min_var  = datavar(Float64)
    
    # Prior over the previous state, then a transition to the current state
    # that is centred on the previous state with fixed precision 1.0
    x_t_min ~ Normal(mean = x_t_min_mean, variance = x_t_min_var)
    x_t     ~ Normal(mean = x_t_min, precision = 1.0)
    
    # Observation: centred on the current state, with fixed precision 10.0
    y = datavar(Float64)
    y ~ Normal(mean = x_t, precision = 10.0)
    
end

# Variational constraints: approximate the joint posterior over the current
# state `x_t` and the precision `τ` by a product of two independent factors.
@constraints function filter_constraints()
    q(x_t, τ) = q(x_t)q(τ)
end

filter_constraints (generic function with 1 method)

In [169]:
# Number of data points to generate: the data stream is built from `1:n`,
# so it completes after `n` values.
n = 1_000_000

1000000

In [170]:
# Generator state: previous = 0.0, process_noise = 1.0, observation_noise = 10.0,
# with empty history and observation buffers.
process = DataGenerationProcess(0.0, 1.0, 10.0, Float64[], Float64[])
# Alternative timer-driven stream (disabled):
# stream = timer(100, 100) |> map_to(process) |> map(Float64, getnext!) |> take(n)
# For every element of 1:n, substitute `process` and call `getnext!` on it,
# so the stream carries one noisy observation per element.
stream = from(1:n) |> map_to(process) |> map(Float64, getnext!)

# Wrap each observation in a named tuple with the single key `y`
keystream = stream |> map(NamedTuple{(:y,), Tuple{Float64}}, (d) -> (y = d, ))

# Record the emitted values with a `keep` actor, then replace `static` by the
# collected result. NOTE(review): `static` changes type on the last line
# (actor -> collected values); presumably a Vector of the named tuples - confirm.
static = keep(eltype(keystream))
subscribe!(keystream, static)
static = collect(static);

In [171]:
"""
    testme(datastream)

Run reactive inference over `datastream` and collect every posterior update.

An inference engine is created with `autostart = false`, listeners are attached
to the `:x_t` and `:τ` posterior streams, the engine is started, and the listeners
are detached again once `RxInfer.start` returns.

Returns the tuple `(x_posteriors, τ_posteriors, engine)`, where the first two
elements are the untyped vectors filled by the listeners.

NOTE(review): the engine is built with `simple_kalman_filter()`, which declares no
`τ` variable in this file, while the constraints, auto-updates, `returnvars` and
`initmarginals` below all refer to `τ`. Confirm whether `kalman_filter()` was intended.
"""
function testme(datastream)
    # After each observation, feed the new posteriors back into the model inputs
    updates = (
        RxInfer.RxInferenceAutoUpdateSpecification(
            (:x_t_min_mean, :x_t_min_var), RxInfer.FromMarginalAutoUpdate(), mean_var, :x_t
        ),
        RxInfer.RxInferenceAutoUpdateSpecification(
            (:τ_shape, ), RxInfer.FromMarginalAutoUpdate(), shape, :τ
        ),
        RxInfer.RxInferenceAutoUpdateSpecification(
            (:τ_rate, ), RxInfer.FromMarginalAutoUpdate(), rate, :τ
        ),
    )

    engine = rxinference(
        model = simple_kalman_filter(),
        constraints = filter_constraints(),
        datastream = datastream,
        autoupdates = updates,
        returnvars = (:x_t, :τ),
        historyvars = (x_t = KeepLast(),),
        initmarginals = (
            x_t = NormalMeanVariance(0.0, 1e3),
            τ = GammaShapeRate(1.0, 1.0),
        ),
        iterations = 1,
        free_energy = true,
        autostart = false,
    )

    x_posteriors = []
    τ_posteriors = []

    # Listeners must be attached before the engine starts consuming the stream
    x_subscription = subscribe!(engine.posteriors[:x_t], (q) -> push!(x_posteriors, q))
    τ_subscription = subscribe!(engine.posteriors[:τ], (q) -> push!(τ_posteriors, q))

    RxInfer.start(engine)

    unsubscribe!(x_subscription)
    unsubscribe!(τ_subscription)

    return x_posteriors, τ_posteriors, engine
end

testme (generic function with 2 methods)

In [None]:
# Time one end-to-end inference run over the pre-generated observations.
# `@time` measures a single call, so the first run also includes compilation.
@time qxt, qτ, engine = testme(from(static))

# Optional visual check of the estimates and the free energy (disabled):
# p1 = plot(mean.(qxt), ribbon = std.(qxt))
# p2 = plot(getvalues(engine.fe_actor))

# plot(p1, p2, size = (800, 200))

 11.973101 seconds (127.00 M allocations: 5.349 GiB, 21.52% gc time)


(Any[NormalWeightedMeanPrecision{Float64}(xi=-133.85154512098273, w=10.000999000999), NormalWeightedMeanPrecision{Float64}(xi=66.04013153390288, w=10.909099164547767), NormalWeightedMeanPrecision{Float64}(xi=-38.75283341487092, w=10.916030592559267), NormalWeightedMeanPrecision{Float64}(xi=66.28009919335321, w=10.916079436668749), NormalWeightedMeanPrecision{Float64}(xi=-35.76367635705593, w=10.91607978065984), NormalWeightedMeanPrecision{Float64}(xi=94.85872470384422, w=10.916079783082434), NormalWeightedMeanPrecision{Float64}(xi=294.4390626181073, w=10.916079783099494), NormalWeightedMeanPrecision{Float64}(xi=109.32617120474714, w=10.916079783099615), NormalWeightedMeanPrecision{Float64}(xi=178.2081675890655, w=10.916079783099615), NormalWeightedMeanPrecision{Float64}(xi=58.59389656177452, w=10.916079783099615)  …  NormalWeightedMeanPrecision{Float64}(xi=180.35364238454898, w=10.916079783099615), NormalWeightedMeanPrecision{Float64}(xi=78.52147834667902, w=10.916079783099615), Normal

In [155]:
using BenchmarkTools

In [168]:
# Benchmark repeated inference runs; `$static` interpolates the global into the
# benchmark expression so that global-variable access is not part of the measurement.
@benchmark testme(from($static))

BenchmarkTools.Trial: 104 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m37.116 ms[22m[39m … [35m85.153 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% …  0.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m46.316 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m48.538 ms[22m[39m ± [32m 9.030 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m7.72% ± 10.98%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m▃[39m▃[39m▂[39m [39m [39m [39m [34m▅[39m[39m [39m▃[39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▇[39m▇[39m▄[39m▄[39m▅[39m▇[3

In [124]:
# Redraw the notebook output with the current estimates.
#
# `posteriors` is a collection of distributions; their means are drawn as a line
# with a ribbon given by their variances, on top of the noise-free states and the
# noisy observations recorded in the global `process`.
# NOTE(review): the ribbon uses `var`, while the commented-out plot elsewhere in
# this notebook uses `std` - confirm which is intended.
plot_callback = function (posteriors)
    # Replace the previous figure instead of appending a new one
    IJulia.clear_output(true)

    estimates = mean.(posteriors)
    spread = var.(posteriors)

    fig = plot(estimates, ribbon = spread, label = "Estimation")
    fig = plot!(gethistory(process), label = "Real states")
    fig = scatter!(getobservations(process), ms = 2, label = "Observations")
    fig = plot(fig, size = (1000, 400), legend = :bottomright)

    display(fig)
end

#127 (generic function with 1 method)