# Extended Kalman Filter (EKF) Demo

Demonstrates EKF with non-linear Range-Bearing tracking:
- Linear motion (constant velocity)
- Non-linear observations (range/bearing in polar coordinates)
- User-provided Jacobian functions

In [1]:
import sys
sys.path.append("/Users/chunchaoma/Desktop/ml_simulation_research")

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from ekf_ukf_pf.ekf import ExtendedKalmanFilter
import ekf_ukf_pf.utils as utils

tf.random.set_seed(42)
np.random.seed(42)

## 1. Generate Data

In [2]:
# Create model and generate trajectory using utils
dt = 1.0
model = utils.create_model(
    dt=dt, 
    process_noise_std_pos=0.1, 
    process_noise_std_vel=0.1,
    range_noise_std=50.0, 
    bearing_noise_std=0.005, 
    seed=42
)

T = 100
true_states, obs = utils.generate_trajectory(model, T=T)

print(f"Generated {T} observations")
print(f"True states: {true_states.shape}, Observations: {obs.shape}")

Generated 100 observations
True states: (4, 101), Observations: (2, 100)


## 2. Define Functions and Jacobians

In [3]:
# Get filter functions from utils
state_transition_fn = utils.create_state_transition_fn(model)  # x_t = A * x_{t-1}
observation_fn = utils.create_observation_fn() #  z_t = h(x_t) 
state_transition_jacobian_fn = utils.create_state_jacobian_fn(model) # A
observation_jacobian_fn = utils.create_observation_jacobian_fn(model)
#  H = [[cos(θ),      0,  sin(θ),     0],
#       [-sin(θ)/r,   0,  cos(θ)/r,   0]]

print("Functions and Jacobians defined")

Functions and Jacobians defined


## 3. Initialize and Run EKF

In [4]:
# Initialize EKF
x0 = true_states[:,1]
Sigma0 = tf.eye(4, dtype=tf.float32) * 100.0

ekf = ExtendedKalmanFilter(
    state_transition_fn=state_transition_fn,
    observation_fn=observation_fn,
    Q=model.Q, 
    R=model.R, 
    x0=x0, 
    Sigma0=Sigma0,
    state_transition_jacobian_fn=state_transition_jacobian_fn,
    observation_jacobian_fn=observation_jacobian_fn
)

In [5]:
observations = tf.convert_to_tensor(obs, dtype=tf.float32)
T = tf.shape(observations)[1]

In [6]:
# Initialize
x = ekf.x0
Sigma_post = ekf.Sigma0
u = None

In [8]:
# x_pred, Sigma_pred = ekf.predict(x, Sigma_post, u)
# self.f = state_transition_fn(x)
# x_{t|t-1} = f(x_{t-1|t-1}, u_t)

# F_t = create_state_jacobian_fn(x, u)
# Σ_{t|t-1} = F_t @ Σ_{t-1|t-1} @ F_t^T + Q_t


In [None]:
# x, Sigma_post = ekf.update(z, x_pred, Sigma_pred)
# H_t = ∂h/∂x|_{x_{t|t-1}}

# Then:
# r_t = z_t - h(x_{t|t-1})
# S_t = H_t @ Σ_{t|t-1} @ H_t^T + R_t
# K_t = Σ_{t|t-1} @ H_t^T @ S_t^{-1}
# x_{t|t} = x_{t|t-1} + K_t @ r_t
# Σ_{t|t} = (I - K_t @ H_t) @ Σ_{t|t-1}  [or Joseph form]

# self.h = observation_fn

# H = self.compute_observation_jacobian(x_pred)
# r = z - self.h(x_pred)
# S = tf.linalg.matmul(tf.linalg.matmul(H, Sigma_pred), H, transpose_b=True) + self.R 


In [7]:
t = 0
for t in range(T):
    print(t)
    z = tf.reshape(obs[:, t], [-1, 1])
    x_pred, Sigma_pred = ekf.predict(x, Sigma_post, u)
    print(f"Prediction at time {t}: x_pred = {x_pred.numpy().flatten()}, Sigma_pred = {Sigma_pred.numpy()}")
    x, Sigma_post = ekf.update(z, x_pred, Sigma_pred)
    print(f"Update at time {t}: x_post = {x.numpy().flatten()}, Sigma_post = {Sigma_post.numpy()}")
    print("-----")

0
Prediction at time 0: x_pred = [21.499626 -5.65718  21.597942  6.68472 ], Sigma_pred = [[200.01 100.     0.     0.  ]
 [100.   100.01   0.     0.  ]
 [  0.     0.   200.01 100.  ]
 [  0.     0.   100.   100.01]]
Update at time 0: x_post = [27.793411  -2.5104446 15.546616   3.659208 ], Sigma_post = [[92.18605  46.09072  92.584305 46.28984 ]
 [46.09072  73.05672  46.289837 23.143764]
 [92.584305 46.289837 93.03091  46.51313 ]
 [46.28984  23.143764 46.51313  73.26791 ]]
-----
1
Prediction at time 1: x_pred = [25.282967  -2.5104446 19.205824   3.659208 ], Sigma_pred = [[257.4342   119.14744  208.30775   69.43361 ]
 [119.14744   73.06672   69.4336    23.143764]
 [208.30774   69.4336   259.33508  119.78104 ]
 [ 69.43361   23.143764 119.78104   73.27791 ]]
Update at time 1: x_post = [22.035982  -5.535907  21.702368   6.3731327], Sigma_post = [[220.15123  105.13189  167.2367    51.739597]
 [105.13188   60.11399   79.85107   34.60699 ]
 [167.2367    79.85107  127.0802    39.331547]
 [ 51.7396