In [1]:
import jax.numpy as jnp
from jax import grad, jacfwd
from utils_v2 import KalmanFilter

In [None]:
class ExtendedKalmanFilter(KalmanFilter):
    """Extended Kalman filter driven by a nonlinear transition ``f`` and measurement ``h``.

    The linear matrices of the base ``KalmanFilter`` are replaced by Jacobians of
    ``f`` (prediction) and ``h`` (correction), computed with ``jax.jacfwd`` at the
    current estimate. Scalar, 1-D and column-vector states are all supported:
    states are flattened internally so Jacobians and covariances are always 2-D,
    and results are reshaped back to the shapes the caller supplied.

    NOTE(review): assumes ``KalmanFilter.__init__`` (utils_v2, not visible here)
    stores the state as ``self.x_k``, the covariance as ``self.P_k`` and keeps
    ``R``, ``Q`` and ``w_k`` as attributes of the same names -- TODO confirm.
    After every step the estimate is also mirrored to ``self.x`` / ``self.P``
    for code that reads those names.
    """

    def __init__(
        self,
        x_0: jnp.ndarray | float,
        f: callable,
        h: callable,
        R: jnp.ndarray,
        Q: jnp.ndarray,
        Z: jnp.ndarray,
        w_k: jnp.ndarray,
        P_0: jnp.ndarray,
    ):
        super().__init__(x_0, None, None, None, R, None, Q, Z, w_k, P_0)
        # f replaces the linear transition (A, B) and h the linear measurement (C, H) of the plain KF.
        self.f = f
        self.h = h

    def derivative(self, x: float, h: float = 1e-5):
        """Central finite-difference estimate of d f / d x at scalar ``x``.

        ``h`` is the finite-difference step size, not the measurement function ``self.h``.
        """
        return (self.f(x + h) - self.f(x - h)) / (2 * h)

    def gradient(self, x: jnp.ndarray, fn=None):
        """Reverse-mode gradient of ``fn`` (default ``self.f``) at ``x``.

        ``jax.grad`` requires ``fn`` to return a scalar.
        """
        return grad(self.f if fn is None else fn)(x)

    def jacobian(self, x: jnp.ndarray, fn=None):
        """Forward-mode Jacobian of ``fn`` (default ``self.f``) evaluated at ``x``."""
        return jacfwd(self.f if fn is None else fn)(x)

    @staticmethod
    def _as_float_array(x):
        """Return ``x`` as a JAX array with an inexact dtype (``jacfwd`` rejects integer inputs)."""
        arr = jnp.asarray(x)
        if not jnp.issubdtype(arr.dtype, jnp.inexact):
            arr = arr.astype(float)
        return arr

    @staticmethod
    def _flat_linearize(fn, x):
        """Evaluate ``fn`` at ``x`` together with its Jacobian, both flattened.

        Returns ``(fn(x).ravel(), J)`` where ``J`` has shape ``(fn(x).size, x.size)``,
        so scalar, 1-D and column-vector states all yield an ordinary 2-D matrix.
        """
        shape = x.shape

        def flat_fn(v):
            return jnp.ravel(fn(v.reshape(shape)))

        x_flat = jnp.ravel(x)
        return flat_fn(x_flat), jacfwd(flat_fn)(x_flat)

    def predict(self):
        """Propagate the estimate and covariance one step through ``f``.

        x_k <- f(x_k) + w_k
        P_k <- F P_k F^T + Q,  with F = df/dx evaluated at the *prior* x_k.
        """
        x_prev = self._as_float_array(self.x_k)
        # Linearize before overwriting the state so F is taken at the prior estimate.
        fx, F = self._flat_linearize(self.f, x_prev)
        P = jnp.atleast_2d(self.P_k)
        P_new = F @ P @ F.T + jnp.atleast_2d(self.Q)

        self.F = F
        self.x_k = fx.reshape(x_prev.shape) + self.w_k
        self.P_k = P_new.reshape(jnp.shape(self.P_k))
        # Keep the alternative attribute names consistent with x_k / P_k.
        self.x = self.x_k
        self.P = self.P_k

    def update(self, z):
        """Correct the predicted estimate with measurement ``z``.

        y = z - h(x_k),  H = dh/dx at x_k,  S = H P_k H^T + R,  K = P_k H^T S^-1
        x_k <- x_k + K y,  P_k <- P_k - K H P_k
        """
        x_pred = self._as_float_array(self.x_k)
        # Linearize the measurement model h (not the transition f).
        hx, H = self._flat_linearize(self.h, x_pred)
        y = jnp.ravel(jnp.asarray(z)) - hx
        P = jnp.atleast_2d(self.P_k)
        S = H @ P @ H.T + jnp.atleast_2d(self.R)
        # K = P H^T S^-1 obtained from a linear solve instead of an explicit inverse.
        K = jnp.linalg.solve(S.T, H @ P.T).T

        x_new = jnp.ravel(x_pred) + K @ y
        P_new = P - K @ H @ P
        self.x_k = x_new.reshape(x_pred.shape)
        self.P_k = P_new.reshape(jnp.shape(self.P_k))
        # Keep the alternative attribute names consistent with x_k / P_k.
        self.x = self.x_k
        self.P = self.P_k


In [None]:
def F(x):
    """Toy nonlinear map (a, b) -> (a**2 + b, a + b**2), used to exercise Jacobians."""
    a, b = x[0], x[1]
    out_first = a ** 2 + b
    out_second = a + b ** 2
    return jnp.array([out_first, out_second])


print(F(jnp.array([1, 2])))

[3 5]


In [16]:
# Scratch check: does a (1, 2) array count as a single-element (scalar-like) state?
x = jnp.array([[1, 2]])
print(x.size == 1)

False
