In [1]:
import jax.numpy as jnp
from jax import grad, jacfwd
from utils_v2 import KalmanFilter
from typing import Tuple
import jax.random as jrandom

In [6]:
class ExtendedKalmanFilter(KalmanFilter):
    def __init__(
        self,
        x_0: jnp.ndarray | float | int,
        f: callable,
        h: callable,
        R: jnp.ndarray,
        Q: jnp.ndarray,
        Z: jnp.ndarray,
        w_k: jnp.ndarray,
        P_0: jnp.ndarray,
    ) -> None:
        if not isinstance(x_0, (jnp.ndarray, float, int)):
            raise Exception(
                "The State input must be of type jnp.ndarray, float, or int"
            )
        # Expand scalar state into array if necessary
        x_0 = jnp.expand_dims(x_0, axis=-1) if type(x_0) in (int, float) else x_0
        super().__init__(x_0, None, None, None, None, R, Q, Z, w_k, P_0)
        self.f = f  # Nonlinear state transition function: f(x, u)
        self.h = h  # Nonlinear measurement function: h(x)

    def _derivative(
        self, f: callable, x: float | jnp.ndarray, h: float = 1e-5
    ) -> jnp.ndarray:
        """
        Compute finite-difference derivative for a scalar function.
        If x is a jnp.ndarray, assumes single element.
        """
        if isinstance(x, jnp.ndarray):
            x = x[0]
        return jnp.array((f(x + h) - f(x - h)) / (2 * h))

    def _gradient(
        self, f: callable, x: jnp.ndarray, u: jnp.ndarray = None
    ) -> jnp.ndarray:
        """
        Compute the gradient of a scalar-valued function f at x.
        Returns the gradient as a diagonal matrix.
        """
        grad_f = grad(f)
        # Wrap gradient vector as a diagonal matrix
        return (
            jnp.diag(jnp.array(grad_f(x)))
            if u is None
            else jnp.diag(jnp.array(grad_f(x, u)))
        )

    def _jacobian(
        self, f: callable, x: jnp.ndarray, u: jnp.ndarray = None
    ) -> jnp.ndarray:
        """
        Compute the full Jacobian of a vector-valued function f at x.
        """
        jac_F = jacfwd(f)
        return jnp.array(jac_F(x)) if u is None else jnp.array(jac_F(x, u))

    def _set_matrices(self, u_k: jnp.ndarray) -> None:
        """
        Set the matrices A and H by choosing the appropriate differentiation method
        based on the initial state shape. Also stores the chosen function in _function.
        """
        # Check if state is scalar-like: either a float/int or a 0-dim jnp.ndarray.
        if type(self.x_k) in (float, int) or (
            type(self.x_k) is jnp.ndarray and jnp.array(self.x_k).squeeze().shape == ()
        ):
            self.A = self._derivative(self.f, self.x_k)
            self.H = self._derivative(self.h, self.x_k)
            self._function = self._derivative
        # If state is a vector (1-dimensional array)
        elif self.x_k.squeeze().ndim == 1:
            self.A = self._gradient(self.f, self.x_k, u_k)
            self.H = self._gradient(self.h, self.x_k)
            self._function = self._gradient
        # Otherwise assume state is higher-dimensional and use the Jacobian
        else:
            self.A = self._jacobian(self.f, self.x_k, u_k)
            self.H = self._jacobian(self.h, self.x_k)
            self._function = self._jacobian

    def _step_estimation(self, u_k: jnp.ndarray) -> jnp.ndarray:
        """
        Predicts the next state using the nonlinear system dynamics f and control input u_k.

        Args:
            u_k: Control input vector (m x 1)

        Returns:
            Predicted state vector (n x 1)

        Raises:
            RuntimeError: If an error occurs during the prediction.
        """

        try:
            new_x_k = self.f(self.x_k, u_k) + self.w_k
            return new_x_k
        except Exception as e:
            raise RuntimeError(f"Error in EKF step estimation: {e}") from e

    def _current_state_and_process(
        self, x_km: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Updates the state estimate using a measurement and the Kalman gain.

        This function recomputes the linearization matrices A and H based on the current state.

        Args:
            x_km: Noisy measurement vector (p x 1)

        Returns:
            A tuple containing:
            - Corrected state estimate (n x 1)
            - Updated error covariance matrix (n x n)

        Raises:
            RuntimeError: If an error occurs during state update.
        """
        try:
            # Recompute linearization at current state
            self.A = self._function(self.f, self.x_k)
            self.H = self._function(self.h, self.x_k)
            measurements = self.h(x_km) + self.Z
            x_k = self.x_k + self.K @ (measurements - self.H @ self.x_k)
            p_k = (jnp.eye(self.K.shape[0]) - self.K @ self.H) @ self.P
            return x_k, p_k
        except Exception as e:
            raise RuntimeError(f"Error in EKF state and process update: {e}") from e


In [10]:
def f_nonlinear(x, u):
    return x + u


def h_nonlinear(x):
    return x**2


x0 = jnp.array([1.0, 2.0,3.1,4.8])
delta_t = 0.1
m = 0.1
M = 1
g = 9.8
l = 0.5
A = jnp.array(
    [
        [0, 1, 0, 0],
        [0, 0, m * g * delta_t / M, 0],
        [0, 0, 0, 1],
        [0, 0, (m + M) * g * delta_t / (M * l), 0],
    ]
)

# A=jnp.array([[1,delta_t,0,0],
#             [0,1,m*g*delta_t/M,0],
#             [0,0,1,delta_t],
#             [0,0,(m+M)*g*delta_t/(M*l),1]])

B = jnp.array([[0], [delta_t / M], [0], [delta_t / (M * l)]])

H = jnp.array([[0, 1, 0, 0], [0, 0, 0, 1]])

R = jnp.eye(H.shape[0]) * 0.005
C = H
Q = jnp.eye(A.shape[0]) * 0.005
mean = 0
std_dev = 0.005
# Q=jrandom.normal(key, shape=(A.shape[0],1)) * std_dev + mean     #FAILED
P_0 = jnp.ones(A.shape)
key = jrandom.PRNGKey(42)
Z = jrandom.normal(key, shape=(C.shape[0], 1)) * std_dev + mean
w_k = jnp.ones((A.shape[0], 1)) * 0.005
# w_k=jrandom.normal(key, shape=(A.shape[0],1)) * std_dev + mean    #FAILED

In [11]:
ekf = ExtendedKalmanFilter(x0, f_nonlinear, h_nonlinear, R, Q, Z, w_k, P_0)

In [12]:
ekf.predict(jnp.array([0.5, -0.3,0,0.5]))

Error in the proccess covariance function, the error:  unsupported operand type(s) for @: 'NoneType' and 'jaxlib.xla_extension.ArrayImpl'


Array([[1.505    , 0.705    , 1.005    , 1.505    ],
       [2.505    , 1.705    , 2.005    , 2.505    ],
       [3.605    , 2.805    , 3.105    , 3.605    ],
       [5.3050003, 4.505    , 4.8050003, 5.3050003]], dtype=float32)

In [None]:
def F(x):
    return jnp.array([x[0] ** 2 + x[1], x[0] + x[1] ** 2])


grad_f = grad(F)  # Computes the gradient function
x0 = jnp.array([2.0, 3.0])
print(grad_f(x0))  # Output: [4. 6.]

TypeError: Gradient only defined for scalar-output functions. Output had shape: (2,).