In [17]:
import filterpy.kalman as kf
import numpy as np
import plotly.graph_objects as go

In [None]:
f = kf.KalmanFilter(dim_x=1, dim_z=1)
P_0 = np.array([[1.]])  # initial uncertainty
F = np.array([[1.]])    # state transition matrix
H = np.array([[1.]])    # Measurement function
R = np.array([[5.]])    # measurement uncertainty
Q = np.array([[0.1]])  # process uncertainty

f.F = F
f.H = H
f.R = R
f.Q = Q

def sample_traj(x_0 = np.array([[0.]]),L = 10):
    
    # Get real states
    x = np.zeros((L, 1))
    x[0] = x_0
    z = np.zeros((L, 1))
    for i in range(1, L):
        x[i] = F @ x[i-1] + np.random.normal(0, np.sqrt(Q[0, 0]), (1, 1))
        z[i] = H @ x[i] + np.random.normal(0, np.sqrt(R[0, 0]), (1, 1))
        
    return x, z

def rts(x_pred, P_est):
    return f.rts_smoother(x_pred, P_est)

x,z= sample_traj(L=10)
# print("Real states:\n", x)
# print("Measurements:\n", z)
# print("Kalman Filter estimates:\n", x_est)



In [None]:
L = 10
x, z, x_pred, x_est, P_est = sample_traj(L=L)
(xs, P, K, Pp)= f.rts_smoother(x_pred, P_est)

print(x_est[-1])
print(x_pred[-1])
print(xs[-1])


fig = go.Figure()
fig.add_trace(go.Scatter(y=x.flatten(), mode='lines', line=dict(color="black"), name='Real State'))
fig.add_trace(go.Scatter(x=list(range(L)), y=z.flatten(), mode='markers', name='Measurements', marker=dict(color='red')))
fig.add_trace(go.Scatter(y=x_est.flatten(), mode='lines', line=dict(color="green"), name='Kalman Filter Estimate'))
fig.add_trace(go.Scatter(y=xs.flatten(), mode='markers', line=dict(color="blue", dash='dash'), name='RTS Smoother Estimate'))
# fig.add_trace(go.Scatter(y=x_pred.flatten(), mode='lines+markers', line=dict(color="orange", dash='dash'), name='Kalman Filter Prediction'))

# Add propagation lines: from x_est[i-1] to x_pred[i] for i=1..L-1
for i in range(1, L):
    fig.add_trace(go.Scatter(
        x=[i-1, i],
        y=[x_est[i-1, 0, 0], x_pred[i, 0, 0]],
        mode='lines',
        line=dict(color='orange', dash='dot'),
        showlegend=(i==1),
        name='Prediction Propagation'
    ))

fig.update_layout(
    title='Kalman Filter State Estimation',
    xaxis_title='Time Step',
    yaxis_title='State Value',
    legend=dict(x=0, y=1)
)
fig.show()



[[-0.96067627]]
[[-0.91172999]]
[[-0.91172999]]
