In [31]:
import jax.numpy as jnp
from jax import grad, jacfwd
from utils_v2 import KalmanFilter
from typing import Tuple

In [33]:
class ExtendedKalmanFilter(KalmanFilter):
    def __init__(
        self,
        x_0: jnp.ndarray | float,
        f: callable,
        h: callable,
        R: jnp.ndarray,
        Q: jnp.ndarray,
        Z: jnp.ndarray,
        w_k: jnp.ndarray,
        P_0: jnp.ndarray,
    ):
        super().__init__(x_0, None, None, None, None, R, Q, Z, w_k, P_0)
        self.f = f  # f is the non linear function in EKF and A and B are the linear functions in KF
        self.h = h  # h is the non linear function in EKF and C and H are the linear functions in KF

        self.set_matrices()

    def derivative(self, f: callable, x: float | jnp.ndarray, h: float = 1e-5) -> float:
        if type(x) is jnp.ndarray:
            x = x[0]
        return (f(x + h) - f(x - h)) / (2 * h)

    def gradient(self, f: callable, x: jnp.ndarray) -> jnp.ndarray:
        grad_f = grad(f)
        return jnp.array(grad_f(x))

    def jacobian(self, f: callable, x: jnp.ndarray) -> jnp.ndarray:
        jac_F = jacfwd(f)
        return jac_F(x)

    def set_matrices(self) -> None:
        if type(self.x_k) is float or (
            type(self.x_k) is jnp.ndarray and self.x_k.flatten().shape == (1,)
        ):
            self.A = self.derivative(self.f, self.x_k)
            self.H = self.derivative(self.h, self.x_k)

        elif self.x_k.squeeze().shape == (1,):
            self.A = self.gradient(self.f, self.x_k)
            self.H = self.gradient(self.h, self.x_k)
        else:
            self.A = self.jacobian(self.f, self.x_k)
            self.H = self.jacobian(self.h, self.x_k)

    def step_estimation(self, u_k: jnp.ndarray) -> jnp.ndarray:
        """
        Predict next state using system dynamics and control input.

        Args:
            u_k: Control input vector (m x 1)

        Returns:
            Predicted state vector (n x 1)
        """

        try:
            new_x_k = self.f(self.x_k, u_k) + self.w_k
            return new_x_k
        except Exception as e:
            print(
                "Error in the Extended Kalman Filter step estimation function, the error: ",
                e,
            )

    def current_state_and_process(
        self, x_km: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Update state estimate using measurement and Kalman gain.

        Args:
            x_km: Noisy measurement vector (p x 1)

        Returns:
            Tuple containing:
            - Corrected state estimate (n x 1)
            - Updated error covariance (n x n)
        """
        try:
            Y = self.h(x_km) + self.Z
            x_k = self.x_k + self.K @ (Y - self.H @ self.x_k)
            p_k = (jnp.eye(self.K.shape[0]) - self.K @ self.H) @ self.P
            return x_k, p_k
        except Exception as e:
            print(
                "Error in the Extended Kalman Filter new state and proccess calculation function, the error: ",
                e,
            )


In [None]:
def F(x):
    return x[0] + x[1]

In [None]:
def F(x):
    return jnp.array([x[0] ** 2 + x[1], x[0] + x[1] ** 2])


grad_f = grad(F)  # Computes the gradient function
x0 = jnp.array([2.0, 3.0])
print(grad_f(x0))  # Output: [4. 6.]

TypeError: Gradient only defined for scalar-output functions. Output had shape: (2,).

In [26]:
f

<function __main__.f(x)>