In [None]:
#| hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| default_exp EKF

In [None]:
#| hide
#| export
import jax.numpy as jnp
import jax.lax as lax
import jax
from jaxtyping import Array, Float, Int
from KalmanPaper import simple as sp
from typing import Tuple
from functools import partial

In [None]:
#| export
@jax.jit
def Ptt(
    Ptm: Float[Array, "N N"], # Predicted covariance $\mathbf P_{t/t-1}$
    w: Float[Array, "N"],   # Predicted weights $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],   # Input vector $\mathbf x_t$
) -> Float[Array, "N N"]:   # Filtered covariance $\mathbf P_{t/t}$
  r"""$\!$*
  Estimation-error covariance matrix $\mathbf P_{t/t}$
  $$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$$
  $$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T$$
  *$\!$"""
  # `slope` plays the role of sigma_t (1 - sigma_t) in the formula above.
  slope = sp.dxlosi(w @ x)
  # P_{t/t-1} x_t is reused in both the denominator and the outer product.
  Px = Ptm @ x
  # Scalar coefficient multiplying the rank-one correction term.
  coeff = slope / (1 + slope * (x @ Px))
  return Ptm - coeff * jnp.outer(Px, Px)

In [None]:
#| hide
# Display the rendered documentation for `Ptt` (helper from KalmanPaper.simple).
sp.reshow_doc(Ptt)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L17){target="_blank" style="float:right; font-size:smaller"}

### Ptt

>      Ptt (Ptm:jaxtyping.Float[Array,'NN'], w:jaxtyping.Float[Array,'N'],
>           x:jaxtyping.Float[Array,'N'])

*$\!$*
推定誤差共分散行列 $\mathbf P_{t/t}$
$$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$$
$$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Ptm | Float[Array, 'N N'] | $\mathbf P_{t/t-1}$ |
| w | Float[Array, 'N'] | $\hat{\mathbf w}_{t/t-1}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| **Returns** | **Float[Array, 'N N']** | **$\mathbf P_{t/t}$** |

In [None]:
#| export
@jax.jit
def wtt(
    Ptm: Float[Array, "N N"], # Predicted covariance $\mathbf P_{t/t-1}$
    w: Float[Array, "N"],   # Predicted weights $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],   # Input vector $\mathbf x_t$
    y: Float[Array, ""],   # Scalar observation $y_t$
) -> Float[Array, "N"]: # Filtered weights $\hat{\mathbf w}_{t/t}$
  r"""$\!$*
  Filtered estimate $\hat{\mathbf w}_{t/t}$
  $$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$
  *$\!$"""
  # Linear predictor w^T x, shared by the sigmoid and its derivative.
  z = w @ x
  # `dsigma` plays the role of sigma_t (1 - sigma_t) in the formula above.
  dsigma = sp.dxlosi(z)
  Ptmx = Ptm @ x
  # Innovation (y_t - sigma_t) scaled by the scalar denominator, applied along P x.
  return w + ((y - sp.losi(z)) / (1 + dsigma * (x @ Ptmx))) * Ptmx

In [None]:
#| hide
# Display the rendered documentation for `wtt` (helper from KalmanPaper.simple).
sp.reshow_doc(wtt)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L33){target="_blank" style="float:right; font-size:smaller"}

### wtt

>      wtt (Ptm:jaxtyping.Float[Array,'NN'], w:jaxtyping.Float[Array,'N'],
>           x:jaxtyping.Float[Array,'N'], y:jaxtyping.Float[Array,'N'])

*$\!$*
濾波推定値 $\hat{\mathbf w}_{t/t}$
$$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Ptm | Float[Array, 'N N'] | $\mathbf P_{t/t-1}$ |
| w | Float[Array, 'N'] | $\hat{\mathbf w}_{t/t-1}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| y | Float[Array, 'N'] | $y_t$ |
| **Returns** | **Float[Array, 'N']** | **$\hat{\mathbf w}_{t/t}$** |

In [None]:
#| export
@partial(jax.jit, static_argnames=['N', 'T'])
def EKF(
    N: int, # Weight dimension $N$ (static; only referenced by the shape annotations)
    T: int, # Number of time steps $T$ (static; passed to `lax.scan` as `length`)
    x: Float[Array, "{T} {N}"], # Inputs $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
    y: Float[Array, "{T}"], # Scalar observations $\{ y_t \}_{t=0,\ldots,T-1}$
    G: Float[Array, "{N} {N}"], # Matrix $\boldsymbol\Gamma$ added to the covariance after each update
    w0: Float[Array, "{N}"], # Initial predicted weights $\hat{\mathbf w}_{0/-1}$
    P0: Float[Array, "{N} {N}"], # Initial predicted covariance $\mathbf P_{0/-1}$
) -> Tuple[Float[Array, "{T} {N}"], Float[Array, "{T} {N} {N}"]]: # $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1},\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
    r"""$\!$*
    Extended Kalman filter

    Runs the filter over the whole sequence with `lax.scan`. Starting from
    $(\mathbf P_{0/-1}, \hat{\mathbf w}_{0/-1})$, each step applies `Ptt` and `wtt`
    to the current predicted quantities and the pair $(\mathbf x_t, y_t)$, then
    forms the next prediction as
    $\mathbf P_{t+1/t}=\mathbf P_{t/t}+\boldsymbol\Gamma$ and
    $\hat{\mathbf w}_{t+1/t}=\hat{\mathbf w}_{t/t}$.
    The filtered weights and covariances of every step are stacked along the
    leading axis and returned.
    *$\!$"""
    def step(carry, inputs):
        # carry holds the predicted covariance and weights for the current step.
        Ptm, wtm = carry
        xt, yt = inputs
        # Both updates read the *predicted* quantities, so their order is irrelevant.
        Ptt_ = Ptt(Ptm, wtm, xt)
        wtt_ = wtt(Ptm, wtm, xt, yt)
        # Next prediction: covariance grows by G, weights carry over unchanged.
        return (Ptt_ + G, wtt_), (wtt_, Ptt_)

    _, (W, P) = lax.scan(
        step,
        (P0, w0),
        (x, y),
        length=T
    )
    return W, P

In [None]:
#| hide
# Display the rendered documentation for `EKF` (helper from KalmanPaper.simple).
sp.reshow_doc(EKF)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L45){target="_blank" style="float:right; font-size:smaller"}

### EKF

>      EKF (N:int, T:int, x:jaxtyping.Float[Array,'{T}{N}'],
>           y:jaxtyping.Float[Array,'{T}{N}'],
>           G:jaxtyping.Float[Array,'{N}{N}'], w0:jaxtyping.Float[Array,'{N}'],
>           P0:jaxtyping.Float[Array,'{N}{N}'])

*$\!$*
拡張カルマンフィルタ
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| N | int | $N$ |
| T | int | $T$ |
| x | Float[Array, '{T} {N}'] | $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$ |
| y | Float[Array, '{T} {N}'] | $\{ y_t \}_{t=0,\ldots,T-1}$ |
| G | Float[Array, '{N} {N}'] | $\boldsymbol\Gamma$ |
| w0 | Float[Array, '{N}'] | $\hat{\mathbf w}_{0/-1}$ |
| P0 | Float[Array, '{N} {N}'] | $\mathbf P_{0/-1}$ |
| **Returns** | **Tuple** | **$\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1},\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$** |

In [None]:
#| hide
# Export this notebook's `#| export` cells to the library module via nbdev.
import nbdev; nbdev.nbdev_export()