# Gaussian Linear Dynamical System

In [None]:
# Activate local environment, see `Project.toml`
import Pkg; Pkg.activate("."); Pkg.instantiate();

In this example, the goal is to estimate the hidden states of a Linear Dynamical process in which all hidden states are Gaussian. A simple multivariate Linear Gaussian State Space Model can be described by the following equations:

$$\begin{aligned}
 p(x_i|x_{i - 1}) & = \mathcal{N}(x_i|A x_{i - 1}, \mathcal{Q}),\\
 p(y_i|x_i) & = \mathcal{N}(y_i|B x_i, \mathcal{P}),
\end{aligned}$$

where $x_i$ are the hidden states, $y_i$ are the noisy observations, $A$ and $B$ are the state transition and observation matrices, and $\mathcal{Q}$ and $\mathcal{P}$ are the state transition noise and observation noise covariance matrices, respectively. For a more rigorous introduction to Linear Gaussian Dynamical systems, we refer the reader to the book [Bayesian Filtering and Smoothing](https://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf) by Simo Särkkä.
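Together with a prior $p(x_0)$ over the initial state (the `x_prior` variable in the `@model` definition), these two conditionals define the joint distribution of all hidden states and observations. Each `~` statement in the model corresponds to one factor:

$$p(x_{0:n}, y_{1:n}) = p(x_0) \prod_{i=1}^{n} p(x_i|x_{i-1})\, p(y_i|x_i).$$

Estimating the hidden states then amounts to computing the posterior marginals $p(x_i|y_{1:n})$ of this joint distribution given all $n$ observations.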

To model this process in `RxInfer`, we first import all the required packages:

In [None]:
using RxInfer, BenchmarkTools, Random, LinearAlgebra, Plots

The next step is to generate some synthetic data:

In [None]:
function generate_data(rng, n, A, B, Q, P)
    # Fixed initial state x_0 (not sampled)
    x_prev = [ 10.0, -10.0 ]

    x = Vector{Vector{Float64}}(undef, n)  # hidden states
    y = Vector{Vector{Float64}}(undef, n)  # noisy observations

    for i in 1:n
        # Transition: x_i ~ N(A * x_{i-1}, Q)
        x[i] = rand(rng, MvNormal(A * x_prev, Q))
        # Observation: y_i ~ N(B * x_i, P)
        y[i] = rand(rng, MvNormal(B * x[i], P))
        x_prev = x[i]
    end
    
    return x, y
end

In [None]:
# Seed for reproducibility
seed = 1234

rng = MersenneTwister(seed)

# We will model 2-dimensional observations with rotation matrix `A`
# To avoid clutter we also assume that matrices `A`, `B`, `P` and `Q`
# are known and fixed for all time-steps
θ = π / 35
A = [ cos(θ) -sin(θ); sin(θ) cos(θ) ]
B = diageye(2)
Q = diageye(2)
P = 25.0 .* diageye(2)

# Number of observations
n = 300;

In [None]:
x, y = generate_data(rng, n, A, B, Q, P);

Let's plot the synthetic dataset. Lines show the hidden states we want to estimate, and dots show the noisy observations we will use to estimate them.

In [None]:
px = plot()

px = plot!(px, getindex.(x, 1), label = "Hidden Signal (dim-1)", color = :orange)
px = scatter!(px, getindex.(y, 1), label = false, markersize = 2, color = :orange)
px = plot!(px, getindex.(x, 2), label = "Hidden Signal (dim-2)", color = :green)
px = scatter!(px, getindex.(y, 2), label = false, markersize = 2, color = :green)

plot(px)
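The oscillating curves come from the transition matrix `A`, a rotation by θ = π/35 per time step. As a quick sanity check, the sketch below (using only the `LinearAlgebra` standard library; `rotation_matrix` is a hypothetical helper, not part of RxInfer) verifies that such a matrix is orthogonal, has unit determinant, and returns to the identity after 70 steps. The noise-free part of the state therefore completes one full turn every 70 steps, a little over four turns across the n = 300 observations. A `let` block is used so the notebook's own `A`, `x` and `y` are not overwritten.

```julia
using LinearAlgebra

# Hypothetical helper (not an RxInfer function): 2x2 rotation by angle θ
rotation_matrix(θ) = [cos(θ) -sin(θ); sin(θ) cos(θ)]

# `let` keeps R, Id and v local to this block
let R = rotation_matrix(π / 35)
    Id = Matrix(1.0I, 2, 2)
    println(R' * R ≈ Id)             # true: orthogonal, no scaling
    println(det(R) ≈ 1.0)            # true: proper rotation, no reflection
    println(R^70 ≈ Id)               # true: 70 steps of π/35 make one full turn
    println(R^35 ≈ -Id)              # true: half a turn flips the sign of the state
    v = [10.0, -10.0]
    println(norm(R * v) ≈ norm(v))   # true: noise-free dynamics keep the state norm
end
```

The growth or decay in amplitude visible in the plot therefore comes from the transition noise `Q`, not from `A` itself.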

To create a model, we use the `@model` macro from the `GraphPPL` package:

In [None]:
@model function rotate_ssm(n, x0, A, B, Q, P)
    
    # We create constvar references for better efficiency
    cA = constvar(A)
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)
    
    # `x` is a sequence of hidden states
    x = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)
    
    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior
    
    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end

end

To run inference, we also specify a prior for our first hidden state:

In [None]:
x0 = MvNormalMeanCovariance(zeros(2), 100.0 * diageye(2));

In [None]:
# For a large number of observations you may need to set the `limit_stack_depth` option, e.g.
# `inference(..., options = (limit_stack_depth = 500, ))`
result = inference(
    model = rotate_ssm(length(y), x0, A, B, Q, P), 
    data = (y = y,),
    free_energy = true
);

xmarginals = result.posteriors[:x]
bfe        = result.free_energy;
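Because the model is linear and Gaussian and its factor graph is a chain, belief propagation on it is exact, so we would expect the marginals in `xmarginals` to coincide with the classical Kalman filter followed by a Rauch-Tung-Striebel (RTS) smoother. The sketch below is an independent cross-check under that assumption, written with only `LinearAlgebra`; `kalman_rts` is a hypothetical helper, not an RxInfer function. It follows the code's convention: `Q` is the state transition noise covariance and `P` the observation noise covariance.

```julia
using LinearAlgebra

# Kalman filter followed by an RTS smoother for
#   x_0 ~ N(m0, V0),  x_i ~ N(A x_{i-1}, Q),  y_i ~ N(B x_i, P)
# Returns smoothed means and covariances of x_1, ..., x_n.
function kalman_rts(ys, A, B, Q, P, m0, V0)
    n  = length(ys)
    mf = Vector{Vector{Float64}}(undef, n)   # filtered means
    Vf = Vector{Matrix{Float64}}(undef, n)   # filtered covariances
    mp = Vector{Vector{Float64}}(undef, n)   # predicted means
    Vp = Vector{Matrix{Float64}}(undef, n)   # predicted covariances

    m, V = m0, V0
    for i in 1:n
        # Predict: push the previous estimate through the dynamics
        mp[i] = A * m
        Vp[i] = A * V * A' + Q
        # Update: correct the prediction with observation y_i
        S = B * Vp[i] * B' + P               # innovation covariance
        K = Vp[i] * B' / S                   # Kalman gain
        m = mp[i] + K * (ys[i] - B * mp[i])
        V = Vp[i] - K * S * K'
        mf[i], Vf[i] = m, V
    end

    # Backward (RTS) pass: refine filtered estimates with later observations
    ms, Vs = copy(mf), copy(Vf)
    for i in n-1:-1:1
        G = Vf[i] * A' / Vp[i+1]             # smoother gain
        ms[i] = mf[i] + G * (ms[i+1] - mp[i+1])
        Vs[i] = Vf[i] + G * (Vs[i+1] - Vp[i+1]) * G'
    end
    return ms, Vs
end
```

With the notebook's variables, one would call `ms, Vs = kalman_rts(y, A, B, Q, P, mean(x0), cov(x0))` and compare `ms` with `mean.(xmarginals)`; if the exactness argument above holds, the differences should be at the level of floating-point round-off.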

In [None]:
px = plot()

px = plot!(px, getindex.(x, 1), label = "Hidden Signal (dim-1)", color = :orange)
px = plot!(px, getindex.(x, 2), label = "Hidden Signal (dim-2)", color = :green)

px = plot!(px, getindex.(mean.(xmarginals), 1), ribbon = getindex.(var.(xmarginals), 1) .|> sqrt, fillalpha = 0.5, label = "Estimated Signal (dim-1)", color = :teal)
px = plot!(px, getindex.(mean.(xmarginals), 2), ribbon = getindex.(var.(xmarginals), 2) .|> sqrt, fillalpha = 0.5, label = "Estimated Signal (dim-2)", color = :violet)

plot(px)
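To put a number on how closely the estimates track the hidden states, one option is the per-dimension root-mean-square error of the posterior means, together with the fraction of true states that fall inside the plotted ±1 standard deviation ribbon. The sketch below assumes each marginal provides a mean vector and a vector of per-dimension variances, as `mean.(xmarginals)` and `var.(xmarginals)` do in the plotting code; `fit_summary` is a hypothetical helper, not part of RxInfer.

```julia
using Statistics

# Hypothetical helper (not part of RxInfer): per-dimension RMSE of the
# posterior means and the fraction of true states inside the ±k·σ band.
function fit_summary(means, vars, truth; k = 1.0)
    # Squared error vectors, one per time step
    sq_errs = [(m .- t) .^ 2 for (m, t) in zip(means, truth)]
    rmse = sqrt.(mean(sq_errs))
    # 1.0 where |mean - truth| <= k * std, else 0.0, per dimension
    hits = [Float64.(abs.(m .- t) .<= k .* sqrt.(v)) for (m, v, t) in zip(means, vars, truth)]
    coverage = mean(hits)
    return (rmse = rmse, coverage = coverage)
end
```

For example, `fit_summary(mean.(xmarginals), var.(xmarginals), x)` summarizes the estimates plotted above. With `k = 1.0`, a well-calibrated Gaussian posterior would place roughly 68% of the true states inside the ribbon in each dimension.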

As we can see from the plot, the estimated signal closely follows the real hidden states, with a narrow uncertainty band (the ribbons show ±1 standard deviation). We may also be interested in the value of the minus log evidence, computed here as the Bethe free energy:

In [None]:
bfe
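For this linear Gaussian model, the minus log evidence also has a closed form, the prediction error decomposition of the Kalman filter, which offers an independent way to sanity-check `bfe`. Following the code's convention (`Q` is the state transition noise covariance, `P` the observation noise covariance), let $m_{i-1}$ and $V_{i-1}$ be the filtered mean and covariance of $x_{i-1}$, with $m_0$ and $V_0$ the mean and covariance of the prior `x0`:

$$\begin{aligned}
 m_i^- & = A m_{i-1}, \qquad V_i^- = A V_{i-1} A^\top + Q,\\
 -\log p(y_{1:n}) & = -\sum_{i=1}^{n} \log \mathcal{N}\big(y_i \,|\, B m_i^-,\; B V_i^- B^\top + P\big).
\end{aligned}$$

Because belief propagation is exact on this chain-structured model, we would expect the value in `bfe` to match this quantity up to numerical precision.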

We may also be interested in the performance of the resulting Belief Propagation algorithm:

In [None]:
@benchmark inference(
    model = rotate_ssm(length($y), $x0, $A, $B, $Q, $P), 
    data = (y = $y,)
)

