In [19]:
import jax.numpy as jnp
from jax import vmap
from utils_v2 import KalmanFilter, ExtendedKalmanFilter
from typing import Tuple

In [None]:
class EnsembleKalmanFilter(KalmanFilter):
    """Early, incomplete ensemble Kalman filter built on ``KalmanFilter``.

    Only ``x_0``, ``R``, ``Q``, ``Z`` and ``w_k`` are forwarded to the base
    class; the remaining base-class slots receive ``None`` and the ensemble
    size is merely recorded.

    NOTE(review): ``f``, ``h``, ``H``, ``C`` and ``P_0`` are accepted but never
    used or forwarded, and no ensemble is created. EnsembleKalmanFilter2
    appears to supersede this stub — consider removing it.
    """

    def __init__(self, x_0, ensemble_size, f, h, H, C, R, Q, Z, w_k, P_0):
        # The four positional slots after x_0, and the last one, are left
        # unset. NOTE(review): their meaning depends on KalmanFilter's
        # signature in utils_v2 — confirm there.
        unset_slots = (None,) * 4
        super().__init__(x_0, *unset_slots, R, Q, Z, w_k, None)
        self.ensemble_size = ensemble_size

In [None]:
class EnsembleKalmanFilter2(ExtendedKalmanFilter):
    """Ensemble Kalman filter layered on the project's ExtendedKalmanFilter.

    The ensemble lives in ``self.x_k`` with members stored as columns
    (member axis = 1); per-member operations are vectorised with ``vmap``
    over that axis.

    NOTE(review): ``self.K`` and ``self.H`` are read here but never assigned
    in this class — presumably the ExtendedKalmanFilter base provides them
    (e.g. from ``jaccobian_h``); confirm against utils_v2.

    NOTE(review): every member starts as an identical copy of ``x_0`` and
    ``_step_estimation`` adds the same ``w_k`` to each member, so unless ``f``
    itself injects per-member randomness the members never diverge and the
    sample covariance stays zero. A standard EnKF draws independent process
    and measurement noise per member.
    """

    def __init__(
        self,
        x_0: jnp.ndarray,
        ensemble_size: int,
        f: callable,
        h: callable,
        R: jnp.ndarray,
        Q: jnp.ndarray,
        Z: jnp.ndarray,
        w_k: jnp.ndarray,
        jaccobian_h: callable,
    ) -> None:
        """Build the filter and replicate ``x_0`` into ``ensemble_size`` members.

        Args:
            x_0: Initial state; copied into every ensemble member.
            ensemble_size: Number of members N; must be at least 2.
            f: Process model, called as ``f(x, u_k)`` per member.
            h: Measurement function, applied in ``_current_state_and_process``.
            R, Q, Z, w_k, jaccobian_h: Forwarded to ExtendedKalmanFilter;
                ``Z`` and ``w_k`` are also added directly in this class.

        Raises:
            ValueError: if ``ensemble_size`` < 2 (the sample covariance
                divides by ``ensemble_size - 1``).
        """
        if ensemble_size < 2:
            raise ValueError(
                f"ensemble_size must be at least 2, got {ensemble_size}"
            )
        super().__init__(None, f, h, R, Q, Z, w_k, None, None, jaccobian_h)
        self.ensemble_size = ensemble_size
        self.x_0 = x_0
        # N identical copies of x_0 stacked along a new axis 1 (one column per member).
        self.x_k = jnp.stack([x_0] * self.ensemble_size, axis=1)

    def _step_estimation(self, u_k: jnp.ndarray) -> jnp.ndarray:
        """Propagate every ensemble member through the process model.

        Args:
            u_k: Control input, passed unchanged to ``self.f`` for each member.

        Returns:
            The propagated ensemble (also stored in ``self.x_k``), with the
            member axis at position 1.

        Raises:
            RuntimeError: wrapping any error raised while propagating.
        """
        try:
            # vmap over axis 1: each call of the lambda sees a single member.
            # NOTE(review): the same self.w_k is added to every member.
            self.x_k = vmap(
                lambda x: self.f(x, u_k) + self.w_k, in_axes=1, out_axes=1
            )(self.x_k)
            return self.x_k
        except Exception as e:
            raise RuntimeError(
                f"Error in the EnKF step estimation function: {e}"
            ) from e

    def _process_covariance(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Compute the sample covariance and mean of the current ensemble.

        Assumes ``self.x_k`` has the (n, N) layout that ``predict`` produces.

        Returns:
            Tuple ``(P, mean)``:
            - P: ensemble sample covariance (n x n), normalised by N - 1
            - mean: ensemble mean (n,)

        Raises:
            RuntimeError: wrapping any error raised during the computation.
        """
        try:
            x_k_mean = jnp.mean(self.x_k, axis=1)
            # Deviation of each member (column) from the ensemble mean.
            x_k_perturbation = self.x_k - x_k_mean[:, None]
            # Unbiased sample covariance; __init__ guarantees N >= 2.
            p_k = x_k_perturbation @ x_k_perturbation.T / (self.ensemble_size - 1)
            return p_k, x_k_mean
        except Exception as e:
            raise RuntimeError(
                f"Error in the EnKF proccess covariance function: {e}"
            ) from e

    def _update_x_k(self, x: jnp.ndarray, measurements: jnp.ndarray) -> jnp.ndarray:
        """Apply the Kalman correction ``x + K (y - H x)`` to one member.

        Args:
            x: A single ensemble member (one column of ``self.x_k``).
            measurements: Measurement vector; reshaped to match ``H @ x`` so a
                (p,) vs (p, 1) mismatch cannot silently broadcast to (p, p).

        Returns:
            The corrected member, same shape as ``x``.
        """
        # The innovation must use THIS member's predicted measurement, not
        # the whole ensemble matrix self.x_k.
        predicted = self.H @ x
        innovation = jnp.reshape(measurements, predicted.shape) - predicted
        return x + self.K @ innovation

    def _current_state_and_process(
        self, x_km: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Correct every ensemble member with a shared measurement.

        Args:
            x_km: Passed through ``self.h`` and offset by ``self.Z`` to form
                the measurement. NOTE(review): although described as a noisy
                measurement vector (p x 1), it goes through ``h`` and so is
                presumably a (true) state — confirm with callers.

        Returns:
            Tuple containing:
            - Corrected ensemble, same layout as ``self.x_k`` (it is NOT
              written back to ``self.x_k``; the caller must store it)
            - Updated error covariance (n x n), computed as ``(I - K H) P``

        Raises:
            RuntimeError: wrapping any error raised during the update.
        """
        try:
            measurements = self.h(x_km) + self.Z
            x_k = vmap(
                lambda x: self._update_x_k(x, measurements), in_axes=1, out_axes=1
            )(self.x_k)
            p_k = (jnp.eye(self.K.shape[0]) - self.K @ self.H) @ self.P
            return x_k, p_k
        except Exception as e:
            raise RuntimeError(
                f"Error in the EnKF new state and proccess calculation function, the error: {e}"
            ) from e

    def predict(self, u_k: jnp.ndarray) -> jnp.ndarray:
        """Propagate the ensemble one step and refresh its covariance.

        Parameters:
            u_k (ndarray): Control input forwarded to the process model ``f``.

        Returns:
            ndarray: The propagated ensemble of shape (n, ensemble_size). It is
            also stored in ``self.x_k``, and ``self.P`` is set to its sample
            covariance.

        Raises:
            RuntimeError: wrapping any error from propagation or covariance.
        """
        try:
            x_pred = self._step_estimation(u_k)
            # Drop trailing singleton axes (e.g. when f returns column
            # vectors) while always keeping the (n, N) layout. A plain
            # squeeze() would also drop the state axis when n == 1.
            self.x_k = jnp.reshape(x_pred, (x_pred.shape[0], self.ensemble_size))
            self.P, _ = self._process_covariance()
            return self.x_k
        except Exception as e:
            raise RuntimeError(f"Error in the EnKF predict method: {e}") from e

In [None]:
arr = jnp.array([1, 2, 3])
N = 4

# Replicate `arr` as N identical columns -> shape (3, N): the same
# column-per-member layout the EnKF uses for its ensemble.
stacked = jnp.stack([arr for _ in range(N)], axis=1)

# Adding a (3, 1) column broadcasts across all N columns.
print(stacked + arr[:, None])

[[2 2 2 2]
 [4 4 4 4]
 [6 6 6 6]]


In [None]:
def f(x):
    """Square ``x`` element-wise (toy function for the per-column demo)."""
    return x**2


# Apply `f` column by column into a NEW array. The original rebound `stacked`
# here, so every re-run squared it again (the stale output showed 3 runs);
# leaving `stacked` intact keeps this cell idempotent and the next cell correct.
squared_loop = stacked
for col in range(stacked.shape[1]):  # iterate over however many columns exist
    squared_loop = squared_loop.at[:, col].set(f(stacked[:, col]))
print(squared_loop + jnp.array([1, 2, 3]).reshape(-1, 1))

[[ 2  2  2  2]
 [ 6  6  6  6]
 [12 12 12 12]]


In [33]:
# Same squaring via vmap over columns (in_axes=1), mirroring how the EnKF maps
# over ensemble members. Bound to a new name so re-running the cell does not
# square `stacked` again.
squared_vmap = vmap(f, in_axes=1, out_axes=1)(stacked)
squared_vmap

Array([[1, 1, 1, 1],
       [4, 4, 4, 4],
       [9, 9, 9, 9]], dtype=int32)