In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# Extended Kalman Filter

> Dynamic Logistic Regression で導出された拡張カルマンフィルタ

In [None]:
#| default_exp EKF

In [None]:
#| hide
#| exporti
import jax.numpy as jnp
import jax.lax as lax
import jax
from jaxtyping import Array, Float
from KalmanPaper import simple as sp
from typing import Tuple, NamedTuple
from functools import partial

## 概要

濾波推定値 $\hat{\mathbf w}_{t/t}$
$$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$

推定誤差共分散行列 $\mathbf P_{t/t}$
  $$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$$
  $$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T$$
  $$\mathbf P_{t/t}^{-1}=\mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T$$

各変数の情報の利用

| $\!$ | $\mathbf P_{t/t}$ | $\hat{\mathbf w}_{t/t}$ | $\mathbf P_{t/t-1}$ | $\hat{\mathbf w}_{t/t-1}$ | $\mathbf x_t$ | $y_t$ |
| -- | -- | -- | -- | -- | -- | -- |
| $\mathbf P_{t/t}$ | $\!$ | $\!$ | $\bigcirc$ | $\bigcirc$ | $\bigcirc$ | $\!$ |
| $\hat{\mathbf w}_{t/t}$ | $\!$ | $\!$ | $\bigcirc$ | $\bigcirc$ | $\bigcirc$ | $\bigcirc$ |


## ラプラス近似による導出

### 1. $p(\mathbf w_t\mid\mathbf Y_t)$

$$
p(\mathbf w_t\mid \mathbf Y_t)=\frac{p(y_t\mid\mathbf w_t)p(\mathbf w_t\mid\mathbf Y_{t-1})}{p(y_t\mid\mathbf Y_{t-1})}
$$

$$p(\mathbf w_t\mid\mathbf Y_{t-1})=\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1})$$

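観測モデル（ロジスティック回帰）は

$$p(y_t\mid\mathbf w_t)=\sigma(\mathbf w_t^T\mathbf x_t)^{y_t}\{1-\sigma(\mathbf w_t^T\mathbf x_t)\}^{1-y_t}$$

であり、分母 $p(y_t\mid\mathbf Y_{t-1})$ は $\mathbf w_t$ に依存しない。よって、

$$
\begin{equation*}
p(\mathbf w_t\mid \mathbf Y_t)\propto
\begin{cases}
\sigma(\mathbf w_t^T\mathbf x_t)\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) &( y_t=1) \\
\{1-\sigma(\mathbf w_t^T\mathbf x_t)\}\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) & (y_t=0)
\end{cases}
\end{equation*}
$$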

### 2. $\hat{\mathbf w}_{t/t}$

$\mathbf w_{t}$ をMAP推定する。 $\ln p(\mathbf w_t\mid\mathbf Y_t)$ の微分を $\mathbf 0$ と置く。

$$
\begin{split}
&\phantom{=}\frac{\partial}{\partial \mathbf w_t}\ln p(\mathbf w_t\mid\mathbf Y_t) \\
&= \frac{\partial}{\partial \mathbf w_t}\ln\left[\sigma(\mathbf w_t^T\mathbf x_t)^{y_t}\{1-\sigma(\mathbf w_t^T\mathbf x_t)\}^{1-y_t}\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1})\right] \\
&= \frac{\partial}{\partial \mathbf w_t}y_t\ln\sigma(\mathbf w_t^T\mathbf x_t)+\frac{\partial}{\partial\mathbf w_t}(1-y_t)\ln\{1-\sigma(\mathbf w_t^T\mathbf x_t)\} \\
&\phantom{=} +\frac{\partial}{\partial\mathbf w_t}\ln\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) \\
&= \left\{y_t-\sigma(\mathbf w_t^T\mathbf x_t)\right\}\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})
\end{split}
$$

ここで、$\sigma(\mathbf w_t^T\mathbf x_t)$ をテイラー展開で一次近似する。

$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$ とする。

$$
\begin{split}
\sigma(\mathbf w_t^T\mathbf x_t) &\simeq \sigma_t+\left.\frac{\partial\sigma(\mathbf w_t^T\mathbf x_t)}{\partial \mathbf w_t}\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}}^T(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\
&\phantom{00}=\sigma_t+\sigma_t\{1-\sigma_t\}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t
\end{split}
$$

勾配の式中の $\sigma(\mathbf w_t^T\mathbf x_t)$ をこの一次近似で置き換え、$\mathbf 0$ とおく。

$$
\begin{split}
&\phantom{=}\frac{\partial}{\partial \mathbf w_t}\ln p(\mathbf w_t\mid\mathbf Y_t) \\
&=\left[ y_t-\sigma_t-\sigma_t\left\{1-\sigma_t\right\}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t \right]\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\
&=(y_t-\sigma_t)\mathbf x_t-\left[\mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T\right](\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\
&=\mathbf 0
\end{split}
$$

$$
\begin{split}
\mathbf w_t &= \hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}^{-1}+\sigma_t\left\{1-\sigma_t\right\}\mathbf x_t\mathbf x_t^T\right]^{-1}\mathbf x_t(y_t-\sigma_t) \\
&=\hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)\mathbf P_{t/t-1}\mathbf x_t\mathbf x_t^T\mathbf P_{t/t-1}}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right]\mathbf x_t(y_t-\sigma_t) \\
&=\hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}\mathbf x_t-\frac{\sigma_t(1-\sigma_t)\mathbf P_{t/t-1}\mathbf x_t\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right](y_t-\sigma_t) \\
&=\hat{\mathbf w}_{t/t-1}+\left[1-\frac{\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right]\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t) \\
&=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)
\end{split}
$$

よって $\mathbf w_t$ のMAP推定値 $\hat{\mathbf w}_{t/t}$ は

$$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$

となる。

### 3. $\mathbf P_{t/t}$

$\hat{\mathbf w}_{t/t-1}$ を $p(\mathbf w_t\mid\mathbf Y_t)$ のピークとみなし、そこで負の対数事後分布のヘッセ行列を評価すると $\mathbf P_{t/t}^{-1}$ が得られる。

（本来のラプラス近似では、真のピーク $\hat{\mathbf w}_{t/t}$ でヘッセ行列を評価する）

$$
\begin{split}
\mathbf P_{t/t}^{-1} &= \left.-\frac{\partial^2}{\partial\mathbf w_t^2} \ln p(\mathbf w_t\mid\mathbf Y_t)\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}} \\
&= \left. -\frac{\partial}{\partial\mathbf w_t}\left[\left\{y_t-\sigma(\mathbf w_t^T\mathbf x_t)\right\}\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})\right]\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}} \\
&= \mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T
\end{split}
$$

Sherman-Morrison の公式によって、 $\mathbf P_{t/t}$ が得られる。

$$
\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T
$$

## 関数等

In [None]:
#| export
@jax.jit
def Ptt(
    Ptm: Float[Array, "N N"], # $\mathbf P_{t/t-1}$
    w: Float[Array, "N"],   # $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],   # $\mathbf x_t$
) -> Float[Array, "N N"]:   # $\mathbf P_{t/t}$
  r"""$\!$*
  Estimation-error covariance matrix $\mathbf P_{t/t}$ after the measurement update at time $t$ (it does not depend on $y_t$)
  $$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$$
  $$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T$$
  *$\!$"""
  # u = P_{t/t-1} x_t is shared by the scalar denominator and the rank-one correction.
  u = Ptm @ x
  # `sp.dxlosi` supplies the sigma_t (1 - sigma_t) factor of the formula, at activation w^T x_t.
  lam = sp.dxlosi(w @ x)
  coeff = lam / (1 + lam * (x @ u))
  return Ptm - coeff * jnp.outer(u, u)

In [None]:
#| hide
# Re-render the API documentation of `Ptt` with the project helper `sp.reshow_doc`; the rendered output follows.
sp.reshow_doc(Ptt)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L17){target="_blank" style="float:right; font-size:smaller"}

### Ptt

>      Ptt (Ptm:jaxtyping.Float[Array,'NN'], w:jaxtyping.Float[Array,'N'],
>           x:jaxtyping.Float[Array,'N'])

*$\!$*
推定誤差共分散行列 $\mathbf P_{t/t}$
$$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$$
$$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Ptm | Float[Array, 'N N'] | $\mathbf P_{t/t-1}$ |
| w | Float[Array, 'N'] | $\hat{\mathbf w}_{t/t-1}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| **Returns** | **Float[Array, 'N N']** | **$\mathbf P_{t/t}$** |

In [None]:
#| export
@jax.jit
def wtt(
    Ptm: Float[Array, "N N"], # $\mathbf P_{t/t-1}$
    w: Float[Array, "N"],   # $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],   # $\mathbf x_t$
    y: Float[Array, ""],    # $y_t$
) -> Float[Array, "N"]: # $\hat{\mathbf w}_{t/t}$
  r"""$\!$*
  Filtered estimate $\hat{\mathbf w}_{t/t}$ (MAP estimate under the linearised sigmoid), where $\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$
  $$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$
  *$\!$"""
  # Activation \hat w_{t/t-1}^T x_t, shared by sigma_t (`sp.losi`) and sigma_t (1 - sigma_t) (`sp.dxlosi`).
  a = w @ x
  Ptmx = Ptm @ x
  # Scalar label y_t minus the predicted probability sigma_t.
  innovation = y - sp.losi(a)
  return w + (innovation / (1 + sp.dxlosi(a) * (x @ Ptmx))) * Ptmx

In [None]:
#| hide
# Re-render the API documentation of `wtt` with the project helper `sp.reshow_doc`; the rendered output follows.
sp.reshow_doc(wtt)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L33){target="_blank" style="float:right; font-size:smaller"}

### wtt

>      wtt (Ptm:jaxtyping.Float[Array,'NN'], w:jaxtyping.Float[Array,'N'],
>           x:jaxtyping.Float[Array,'N'], y:jaxtyping.Float[Array,'N'])

*$\!$*
濾波推定値 $\hat{\mathbf w}_{t/t}$
$$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Ptm | Float[Array, 'N N'] | $\mathbf P_{t/t-1}$ |
| w | Float[Array, 'N'] | $\hat{\mathbf w}_{t/t-1}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| y | Float[Array, 'N'] | $y_t$ |
| **Returns** | **Float[Array, 'N']** | **$\hat{\mathbf w}_{t/t}$** |

In [None]:
#| export
# Return container of `EKF`: W holds the filtered means and P the filtered covariances,
# stacked along the leading time axis (one entry per step of the scan in `EKF`).
# The docstring is intentionally kept verbatim: the class is passed to `sp.rewrite_nt` below,
# and the rendered documentation shows its `table` section turned into a markdown table.
class EKF_out(NamedTuple):
  r"""$\!$*
  `EKF` 関数の返り値

  table
  W: $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  P: $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
  
  *$\!$"""
  W: Float[Array, "T N"]
  P: Float[Array, "T N N"]

# Post-process the NamedTuple's documentation (presumably expands the docstring's `table` section).
sp.rewrite_nt(EKF_out)

In [None]:
#| hide
# Re-render the documentation of `EKF_out` with the project helper `sp.reshow_doc`; the rendered output follows.
sp.reshow_doc(EKF_out)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L48){target="_blank" style="float:right; font-size:smaller"}

### EKF_out

>      EKF_out (W:jaxtyping.Float[Array,'TN'], P:jaxtyping.Float[Array,'TNN'])

*$\!$*
  `EKF` 関数の返り値

  | $\!$ | Type | Details |
|--|--|--|
  | W | Float[Array, 'T N'] | $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  | P | Float[Array, 'T N N'] | $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$

  *$\!$*

In [None]:
#| export
@partial(jax.jit, static_argnames=['N', 'T'])
def EKF(
    N: int, # $N$
    T: int, # $T$
    x: Float[Array, "{T} {N}"], # $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
    y: Float[Array, "{T}"], # $\{ y_t \}_{t=0,\ldots,T-1}$
    G: Float[Array, "{N} {N}"], # $\boldsymbol\Gamma$
    w0: Float[Array, "{N}"], # $\hat{\mathbf w}_{0/-1}$
    P0: Float[Array, "{N} {N}"], # $\mathbf P_{0/-1}$
) -> EKF_out:
  r"""$\!$*
  Extended Kalman filter for dynamic logistic regression.

  Each step applies the measurement update (`Ptt`, `wtt`) for the scalar label $y_t\in\{0,1\}$,
  followed by the random-walk prediction
  $\hat{\mathbf w}_{t+1/t}=\hat{\mathbf w}_{t/t}$, $\mathbf P_{t+1/t}=\mathbf P_{t/t}+\boldsymbol\Gamma$.
  *$\!$"""

  class Carry(NamedTuple):
    Ptm: Float[Array, "{N} {N}"]  # P_{t/t-1}
    wtm: Float[Array, "{N}"]      # \hat w_{t/t-1}

  class Input(NamedTuple):
    xt: Float[Array, "{N}"]
    yt: Float[Array, ""]          # one scalar label per step

  class Output(NamedTuple):
    wtt_: Float[Array, "{N}"]
    Ptt_: Float[Array, "{N} {N}"]

  def step(carry: Carry, inputs: Input) -> Tuple[Carry, Output]:
    Ptm, wtm = carry
    xt, yt = inputs
    Ptt_ = Ptt(Ptm, wtm, xt)
    wtt_ = wtt(Ptm, wtm, xt, yt)
    # Prediction: the mean is carried over unchanged, the covariance grows by Gamma.
    return Carry(Ptt_ + G, wtt_), Output(wtt_, Ptt_)

  _, (W, P) = lax.scan(
      step,
      Carry(P0, w0),
      Input(x, y),
      length=T
  )
  return EKF_out(W, P)

In [None]:
#| hide
# Re-render the API documentation of `EKF` with the project helper `sp.reshow_doc`; the rendered output follows.
sp.reshow_doc(EKF)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L64){target="_blank" style="float:right; font-size:smaller"}

### EKF

>      EKF (N:int, T:int, x:jaxtyping.Float[Array,'{T}{N}'],
>           y:jaxtyping.Float[Array,'{T}{N}'],
>           G:jaxtyping.Float[Array,'{N}{N}'], w0:jaxtyping.Float[Array,'{N}'],
>           P0:jaxtyping.Float[Array,'{N}{N}'])

*$\!$*
拡張カルマンフィルタ
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| N | int | $N$ |
| T | int | $T$ |
| x | Float[Array, '{T} {N}'] | $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$ |
| y | Float[Array, '{T} {N}'] | $\{ y_t \}_{t=0,\ldots,T-1}$ |
| G | Float[Array, '{N} {N}'] | $\boldsymbol\Gamma$ |
| w0 | Float[Array, '{N}'] | $\hat{\mathbf w}_{0/-1}$ |
| P0 | Float[Array, '{N} {N}'] | $\mathbf P_{0/-1}$ |
| **Returns** | **EKF_out** |$\!$|

In [None]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import NamedTuple, Tuple
# assume sp.losi, sp.dxlosi exist (sigmoid and its derivative)

# --- helper: MacKay's moderation factor K(s) and its derivative K'(s) (eq.9 and eq.68 in paper) ---
@jax.jit
def K_of_s(s: Float[Array, ""]):
  r"""MacKay's moderation factor $K(s)=\left(1+\frac{\pi s^2}{8}\right)^{-1/2}$.

  `s` is the standard deviation of the activation; the moderated prediction is
  $\sigma(K(s)\,a)$ for an activation with mean $a$. Elementwise, so `s` may be an array.
  """
  # The pi/8 comes from matching the logistic sigmoid to a scaled probit curve.
  return (1.0 + jnp.pi * s**2 / 8.0) ** (-0.5)

@jax.jit
def Kprime_of_s(s: Float[Array, ""]):
  r"""Derivative of `K_of_s` with respect to $s$:
  $K'(s)=-\frac{\pi s}{8}\left(1+\frac{\pi s^2}{8}\right)^{-3/2}$ (elementwise)."""
  return -(jnp.pi * s / 8.0) * (1.0 + jnp.pi * s**2 / 8.0) ** (-1.5)

# --- Nonstationary EKF runner ---
class NS_EKF_Out(NamedTuple):
  r"""Return value of `EKF_nonstationary`: per-step filtered means, filtered covariances and state-noise variances."""
  W: Float[Array, "T N"]   # filtered means \hat w_{t/t}, one row per time step
  P: Float[Array, "T N N"] # filtered covariances P_{t/t}, one matrix per time step
  q: Float[Array, "T"]   # tracked q_t per time step (value after that step's gradient update)

@partial(jax.jit, static_argnames=['N','T','Nw'])
def EKF_nonstationary(
    N: int,
    T: int,
    x: Float[Array, "{T} {N}"],  # inputs sequence
    y: Float[Array, "{T}"],      # scalar class labels (0/1) per t
    w0: Float[Array, "{N}"],     # initial w_{0/-1}
    P0: Float[Array, "{N} {N}"], # initial P_{0/-1}
    q0: float = 1e-6,            # initial state-noise variance
    eta_q: float = 1e-3,         # learning rate for q updates
    Nw: int = 50,                # window size for q gradient estimation
    q_min: float = 0.0,          # lower clip bound for q
    q_max: float = 1.0           # upper clip bound for q
) -> NS_EKF_Out:
  r"""
  Nonstationary EKF for dynamic logistic regression with an adaptive state-noise variance.

  - The state noise is isotropic, $\mathbf Q_t=q_t\mathbf I$, so the prediction step is
    $\hat{\mathbf w}_{t+1/t}=\hat{\mathbf w}_{t/t}$, $\mathbf P_{t+1/t}=\mathbf P_{t/t}+q_t\mathbf I$.
  - After each measurement update, $q$ takes one clipped gradient-ascent step on the mean
    log-evidence of the last `Nw` observations (Appendix E, eq. 70, of the Dynamic Logistic
    Regression paper). For a buffered observation $(\mathbf x_i, y_i)$ with predictive mean
    $\hat{\mathbf w}_{i/i-1}$ and previous posterior covariance $\mathbf P_{i-1/i-1}$
    (`P0` for the first step):

    $$s_i^2=\mathbf x_i^T(\mathbf P_{i-1/i-1}+q\mathbf I)\mathbf x_i,\quad a_i=\hat{\mathbf w}_{i/i-1}^T\mathbf x_i,\quad \tilde y_i=\sigma(K(s_i)a_i)$$
    $$\frac{\partial}{\partial q}\ln p(y_i)=(y_i-\tilde y_i)\,a_i\,K'(s_i)\,\frac{\mathbf x_i^T\mathbf x_i}{2s_i}$$

    (chain rule with $ds/dq=\mathbf x^T\mathbf x/(2s)$; $K$ and $K'$ come from `K_of_s`, `Kprime_of_s`).

  Returns `NS_EKF_Out(W, P, q)`, where `q[t]` is the variance after the gradient step taken at
  time t (it enters the prediction made at step t+1).

  Raises ValueError if `Nw < 1`.
  """
  if Nw < 1:
    raise ValueError("Nw must be >= 1")

  eye = jnp.eye(N, dtype=P0.dtype)
  slots = jnp.arange(Nw)

  class Carry(NamedTuple):
    Ppost: Float[Array, "{N} {N}"]        # P_{t-1/t-1} (P0 at t=0); P_{t/t-1} = Ppost + q_added * I
    q_added: Float[Array, ""]             # variance that was added to Ppost to form P_{t/t-1} (0 at t=0)
    wtm: Float[Array, "{N}"]              # w_{t/t-1}
    q_t: Float[Array, ""]                 # current scalar q_t
    buf_P: Float[Array, "{Nw} {N} {N}"]   # circular buffer of Ppost (covariance *before* adding q)
    buf_w: Float[Array, "{Nw} {N}"]       # circular buffer of w_{i/i-1}
    buf_x: Float[Array, "{Nw} {N}"]       # circular buffer of past x
    buf_y: Float[Array, "{Nw}"]           # circular buffer of past y
    buf_idx: Array                        # int32 scalar: next write slot (0..Nw-1)
    buf_count: Array                      # int32 scalar: number of filled slots (<= Nw)

  class Input(NamedTuple):
    xt: Float[Array, "{N}"]
    yt: Float[Array, ""]

  class Output(NamedTuple):
    wtt_: Float[Array, "{N}"]
    Ptt_: Float[Array, "{N} {N}"]
    q_: Float[Array, ""]

  def log_evidence_grad(buf_P, buf_w, buf_x, buf_y, count, q):
    """Mean of d ln p(y_i)/dq over the `count` filled buffer slots, evaluated at `q`."""
    def per_sample(P_i, w_i, x_i, y_i):
      xTx = x_i @ x_i
      # s^2 = x^T (P_{i-1/i-1} + q I) x, floored to keep sqrt / division finite.
      s2 = jnp.maximum(x_i @ (P_i @ x_i) + q * xTx, 1e-12)
      s = jnp.sqrt(s2)
      a = w_i @ x_i                          # activation with the sample's own predictive mean
      y_tilde = sp.losi(K_of_s(s) * a)       # moderated prediction ~y
      return (y_i - y_tilde) * a * Kprime_of_s(s) * xTx / (2.0 * s)

    grads = jax.vmap(per_sample)(buf_P, buf_w, buf_x, buf_y)
    # `count` is traced, so it cannot bound a slice under jit; mask the unfilled slots instead.
    valid = slots < count
    return jnp.sum(jnp.where(valid, grads, 0.0)) / jnp.maximum(count, 1)

  def step(carry: Carry, inputs: Input) -> Tuple[Carry, Output]:
    xt, yt = inputs

    # P_{t/t-1}; at t=0 this is P0 itself (q_added = 0).
    Ptm = carry.Ppost + carry.q_added * eye
    Ptt_ = Ptt(Ptm, carry.wtm, xt)
    wtt_ = wtt(Ptm, carry.wtm, xt, yt)

    # Store this observation together with the state it was predicted from.
    slot = carry.buf_idx
    buf_P = carry.buf_P.at[slot].set(carry.Ppost)
    buf_w = carry.buf_w.at[slot].set(carry.wtm)
    buf_x = carry.buf_x.at[slot].set(xt)
    buf_y = carry.buf_y.at[slot].set(yt)
    count = jnp.minimum(carry.buf_count + 1, Nw)

    # One clipped gradient-ascent step on q (every step).
    grad_q = log_evidence_grad(buf_P, buf_w, buf_x, buf_y, count, carry.q_t)
    q_new = jnp.clip(carry.q_t + eta_q * grad_q, q_min, q_max)

    # Next step predicts with P_{t+1/t} = P_{t/t} + q_t I.
    new_carry = Carry(
        Ppost=Ptt_, q_added=carry.q_t, wtm=wtt_, q_t=q_new,
        buf_P=buf_P, buf_w=buf_w, buf_x=buf_x, buf_y=buf_y,
        buf_idx=(slot + 1) % Nw, buf_count=count,
    )
    return new_carry, Output(wtt_, Ptt_, q_new)

  # Initial carry: P0 is the prior P_{0/-1} already, w0 is w_{0/-1}.
  q_init = jnp.asarray(q0, dtype=P0.dtype)
  init_carry = Carry(
      Ppost=P0,
      q_added=jnp.zeros_like(q_init),
      wtm=w0,
      q_t=q_init,
      buf_P=jnp.zeros((Nw, N, N), dtype=P0.dtype),
      buf_w=jnp.zeros((Nw, N), dtype=w0.dtype),
      buf_x=jnp.zeros((Nw, N), dtype=x.dtype),
      buf_y=jnp.zeros((Nw,), dtype=P0.dtype),
      buf_idx=jnp.int32(0),
      buf_count=jnp.int32(0),
  )

  _, outputs = lax.scan(step, init_carry, Input(x, y), length=T)

  return NS_EKF_Out(outputs.wtt_, outputs.Ptt_, outputs.q_)


In [None]:
#| hide
# Export the `#| export` cells of this notebook (target module set by `#| default_exp EKF`) with nbdev.
import nbdev; nbdev.nbdev_export()