# Extended Kalman Filter (EKF) Demo

Demonstrates EKF with non-linear Range-Bearing tracking:
- Linear motion (constant velocity)
- Non-linear observations (range/bearing in polar coordinates)
- User-provided Jacobian functions

In [1]:
import sys
sys.path.append("/Users/chunchaoma/Desktop/ml_simulation_research")

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from ekf_ukf_pf.ekf import ExtendedKalmanFilter
import ekf_ukf_pf.utils as utils

tf.random.set_seed(42)
np.random.seed(42)

## 1. Generate Data

In [2]:
# Create model and generate trajectory using utils
dt = 1.0
model = utils.create_model(
    dt=dt, 
    process_noise_std_pos=0.1, 
    process_noise_std_vel=0.1,
    range_noise_std=50.0, 
    bearing_noise_std=0.005, 
    seed=42
)

T = 100
true_states, obs = utils.generate_trajectory(model, T=T)

print(f"Generated {T} observations")
print(f"True states: {true_states.shape}, Observations: {obs.shape}")

Generated 100 observations
True states: (4, 101), Observations: (2, 100)


## 2. Define Functions and Jacobians

In [3]:
# Get filter functions from utils
state_transition_fn = utils.create_state_transition_fn(model)
observation_fn = utils.create_observation_fn()
state_transition_jacobian_fn = utils.create_state_jacobian_fn(model)
observation_jacobian_fn = utils.create_observation_jacobian_fn(model)

print("Functions and Jacobians defined")

Functions and Jacobians defined


## 3. Initialize and Run EKF

In [4]:
# Initialize EKF
x0 = true_states[:,1]
Sigma0 = tf.eye(4, dtype=tf.float32) * 100.0

ekf = ExtendedKalmanFilter(
    state_transition_fn=state_transition_fn,
    observation_fn=observation_fn,
    Q=model.Q, 
    R=model.R, 
    x0=x0, 
    Sigma0=Sigma0,
    state_transition_jacobian_fn=state_transition_jacobian_fn,
    observation_jacobian_fn=observation_jacobian_fn
)

In [8]:
observations = tf.convert_to_tensor(obs, dtype=tf.float32)
T = tf.shape(observations)[1]

In [11]:
# Initialize
x = ekf.x0
Sigma_post = ekf.Sigma0
u = None

In [17]:
t = 0
for t in range(T):
    print(t)
    z = tf.reshape(obs[:, t], [-1, 1])
    x_pred, Sigma_pred = ekf.predict(x, Sigma_post, u)
    print(f"Prediction at time {t}: x_pred = {x_pred.numpy().flatten()}, Sigma_pred = {Sigma_pred.numpy()}")
    x, Sigma_post = ekf.update(z, x_pred, Sigma_pred)
    print(f"Update at time {t}: x_post = {x.numpy().flatten()}, Sigma_post = {Sigma_post.numpy()}")
    print("-----")

0
Prediction at time 0: x_pred = [24.191133  -0.8620117 12.575122  -0.5397022], Sigma_pred = [[490.539    165.24588  333.43436   86.34659 ]
 [165.24588   60.12399  114.45805   34.60699 ]
 [333.4344   114.45805  227.92838   61.506626]
 [ 86.3466    34.606983  61.506638  22.18508 ]]
Update at time 0: x_post = [28.000277    0.5209162  15.224242    0.26226825], Sigma_post = [[45.20725     3.4127266  23.630184   -7.614843  ]
 [ 3.4125168   1.2255965   1.822333    0.34506458]
 [23.63014     1.8224441  12.374841   -3.9294996 ]
 [-7.614899    0.34508106 -3.9295201   2.2065802 ]]
-----
1
Prediction at time 1: x_pred = [28.521193    0.5209162  15.48651     0.26226825], Sigma_pred = [[53.26809     4.6383233  18.18274    -7.2697783 ]
 [ 4.6381135   1.2355965   2.1673975   0.34506458]
 [18.182766    2.167525    6.732401   -1.7229195 ]
 [-7.2698183   0.34508106 -1.72294     2.2165802 ]]
Update at time 1: x_post = [-11.709375    -0.78261507   3.7336273    8.608602  ], Sigma_post = [[10.881441    3.23