In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# Extended Kalman Filter

> Dynamic Logistic Regression で導出された拡張カルマンフィルタ

In [None]:
#| default_exp EKF

In [None]:
#| hide
#| exporti
import jax.numpy as jnp
import jax.lax as lax
import jax
from jaxtyping import Array, Float
from KalmanPaper import simple as sp
from typing import Tuple, NamedTuple
from functools import partial

## 概要

濾波推定値 $\hat{\mathbf w}_{t/t}$
$$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$

推定誤差共分散行列 $\mathbf P_{t/t}$
  $$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$$
  $$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T$$

各変数の情報の利用

| $\!$ | $\mathbf P_{t/t}$ | $\hat{\mathbf w}_{t/t}$ | $\mathbf P_{t/t-1}$ | $\hat{\mathbf w}_{t/t-1}$ | $\mathbf x_t$ | $y_t$ |
| -- | -- | -- | -- | -- | -- | -- |
| $\mathbf P_{t/t}$ | $\!$ | $\!$ | $\bigcirc$ | $\bigcirc$ | $\bigcirc$ | $\!$ |
| $\hat{\mathbf w}_{t/t}$ | $\!$ | $\!$ | $\bigcirc$ | $\bigcirc$ | $\bigcirc$ | $\bigcirc$ |


## ラプラス近似による導出

### 1. $p(\mathbf w_t\mid\mathbf Y_t)$

$$
p(\mathbf w_t\mid \mathbf Y_t)=\frac{p(y_t\mid\mathbf w_t)p(\mathbf w_t\mid\mathbf Y_{t-1})}{p(y_t\mid\mathbf Y_{t-1})}
$$

$$p(\mathbf w_t\mid\mathbf Y_{t-1})=\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1})$$

よって、

$$
\begin{equation*}
p(\mathbf w_t\mid \mathbf Y_t)\propto
\begin{cases}
\sigma(\mathbf w_t^T\mathbf x_t)\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) &( y_t=1) \\
\{1-\sigma(\mathbf w_t^T\mathbf x_t)\}\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) & (y_t=0)
\end{cases}
\end{equation*}
$$

### 2. $\hat{\mathbf w}_{t/t}$

$\mathbf w_{t}$ をMAP推定する。 $\ln p(\mathbf w_t\mid\mathbf Y_t)$ の微分を $\mathbf 0$ と置く。

$$
\begin{split}
&\phantom{=}\frac{\partial}{\partial \mathbf w_t}\ln p(\mathbf w_t\mid\mathbf Y_t) \\
&= \frac{\partial}{\partial \mathbf w_t}\ln\left[\sigma(\mathbf w_t^T\mathbf x_t)^{y_t}\{1-\sigma(\mathbf w_t^T\mathbf x_t)\}^{1-y_t}\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1})\right] \\
&= \frac{\partial}{\partial \mathbf w_t}y_t\ln\sigma(\mathbf w_t^T\mathbf x_t)+\frac{\partial}{\partial\mathbf w_t}(1-y_t)\ln\{1-\sigma(\mathbf w_t^T\mathbf x_t)\} \\
&\phantom{=} +\frac{\partial}{\partial\mathbf w_t}\ln\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) \\
&= \left\{y_t-\sigma(\mathbf w_t^T\mathbf x_t)\right\}\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})
\end{split}
$$

ここで、$\sigma(\mathbf w_t^T\mathbf x_t)$ をテイラー展開で一次近似する。

$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$ とする。

$$
\begin{split}
\sigma(\mathbf w_t^T\mathbf x_t) &\simeq \sigma_t+\left.\frac{\partial\sigma(\mathbf w_t^T\mathbf x_t)}{\partial \mathbf w_t^T}\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\
&\phantom{00}=\sigma_t+\sigma_t\{1-\sigma_t\}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t
\end{split}
$$

$\sigma(\mathbf w_t^T\mathbf x_t)$ を上の一次近似 $\sigma_t+\sigma_t(1-\sigma_t)(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t$ に置き換え、微分を $\mathbf 0$ とおく。

$$
\begin{split}
&\phantom{=}\frac{\partial}{\partial \mathbf w_t}\ln p(\mathbf w_t\mid\mathbf Y_t) \\
&\simeq\left[ y_t-\sigma_t-\sigma_t\left\{1-\sigma_t\right\}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t \right]\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\
&=(y_t-\sigma_t)\mathbf x_t-\left[\mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T\right](\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\
&=\mathbf 0
\end{split}
$$

$$
\begin{split}
\mathbf w_t &= \hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}^{-1}+\sigma_t\left\{1-\sigma_t\right\}\mathbf x_t\mathbf x_t^T\right]^{-1}\mathbf x_t(y_t-\sigma_t) \\
&=\hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)\mathbf P_{t/t-1}\mathbf x_t\mathbf x_t^T\mathbf P_{t/t-1}}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right]\mathbf x_t(y_t-\sigma_t) \\
&=\hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}\mathbf x_t-\frac{\sigma_t(1-\sigma_t)\mathbf P_{t/t-1}\mathbf x_t\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right](y_t-\sigma_t) \\
&=\hat{\mathbf w}_{t/t-1}+\left[1-\frac{\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right]\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t) \\
&=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)
\end{split}
$$

よって $\mathbf w_t$ のMAP推定値 $\hat{\mathbf w}_{t/t}$ は

$$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$

となる。

### 3. $\mathbf P_{t/t}$

$\hat{\mathbf w}_{t/t-1}$ を $p(\mathbf w_t\mid\mathbf Y_t)$ のピークとみなし、その点における $-\ln p(\mathbf w_t\mid\mathbf Y_t)$ のヘッセ行列を $\mathbf P_{t/t}^{-1}$ とすると $\mathbf P_{t/t}$ が得られる。

（ $\hat{\mathbf w}_{t/t}$ をピークとするのが本来のラプラス近似）

$$
\begin{split}
\mathbf P_{t/t}^{-1} &= \left.-\frac{\partial^2}{\partial\mathbf w_t^2} \ln p(\mathbf w_t\mid\mathbf Y_t)\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}} \\
&= \left. -\frac{\partial}{\partial\mathbf w_t}\left[\left\{y_t-\sigma(\mathbf w_t^T\mathbf x_t)\right\}\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})\right]\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}} \\
&= \mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T
\end{split}
$$

Sherman-Morrison の公式によって、 $\mathbf P_{t/t}$ が得られる。

$$
\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T
$$

## 関数等

In [None]:
#| export
@jax.jit
def Ptt(
    Ptm: Float[Array, "N N"], # $\mathbf P_{t/t-1}$
    w: Float[Array, "N"],   # $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],   # $\mathbf x_t$
) -> Float[Array, "N N"]:   # $\mathbf P_{t/t}$
  r"""$\!$*
  Estimation-error covariance matrix $\mathbf P_{t/t}$ (measurement update)
  $$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$$
  $$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T$$
  *$\!$"""
  # P_{t/t-1} x_t appears three times in the formula, so form it once.
  Px = Ptm @ x
  # NOTE(review): sp.dxlosi is defined in KalmanPaper.simple; it stands where
  # the formula has sigma_t (1 - sigma_t) -- confirm against that module.
  slope = sp.dxlosi(w @ x)
  # Scalar coefficient of the rank-one correction term.
  coeff = slope / (1 + slope * (x @ Px))
  return Ptm - coeff * jnp.outer(Px, Px)

In [None]:
#| hide
sp.reshow_doc(Ptt)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L17){target="_blank" style="float:right; font-size:smaller"}

### Ptt

>      Ptt (Ptm:jaxtyping.Float[Array,'NN'], w:jaxtyping.Float[Array,'N'],
>           x:jaxtyping.Float[Array,'N'])

*$\!$*
推定誤差共分散行列 $\mathbf P_{t/t}$
$$\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)$$
$$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Ptm | Float[Array, 'N N'] | $\mathbf P_{t/t-1}$ |
| w | Float[Array, 'N'] | $\hat{\mathbf w}_{t/t-1}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| **Returns** | **Float[Array, 'N N']** | **$\mathbf P_{t/t}$** |

In [None]:
#| export
@jax.jit
def wtt(
    Ptm: Float[Array, "N N"], # $\mathbf P_{t/t-1}$
    w: Float[Array, "N"],   # $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],   # $\mathbf x_t$
    y: Float[Array, "N"],   # $y_t$
) -> Float[Array, "N"]: # $\hat{\mathbf w}_{t/t}$
  r"""$\!$*
  濾波推定値 $\hat{\mathbf w}_{t/t}$
  $$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$
  *$\!$"""
  dsigma = sp.dxlosi(w @ x)
  Ptmx = Ptm @ x
  return w + ((y - sp.losi(w @ x)) / (1 + dsigma * (x @ Ptmx))) * Ptmx

In [None]:
#| hide
sp.reshow_doc(wtt)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L33){target="_blank" style="float:right; font-size:smaller"}

### wtt

>      wtt (Ptm:jaxtyping.Float[Array,'NN'], w:jaxtyping.Float[Array,'N'],
>           x:jaxtyping.Float[Array,'N'], y:jaxtyping.Float[Array,'N'])

*$\!$*
濾波推定値 $\hat{\mathbf w}_{t/t}$
$$\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Ptm | Float[Array, 'N N'] | $\mathbf P_{t/t-1}$ |
| w | Float[Array, 'N'] | $\hat{\mathbf w}_{t/t-1}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| y | Float[Array, 'N'] | $y_t$ |
| **Returns** | **Float[Array, 'N']** | **$\hat{\mathbf w}_{t/t}$** |

In [None]:
#| export
# Return type of `EKF`: the stacked filtered estimates for all T steps.
# The docstring below is deliberately left untouched: the class is passed to
# sp.rewrite_nt right after its definition, and the rendered documentation
# shows the `table` section turned into a Markdown table, so the docstring
# text is input to that helper rather than a free-form comment.
class EKF_out(NamedTuple):
  r"""$\!$*
  `EKF` 関数の返り値

  table
  W: $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  P: $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
  
  *$\!$"""
  # Filtered weight estimates w_{t/t}, one row per time step.
  W: Float[Array, "T N"]
  # Filtered covariance matrices P_{t/t}, one N x N matrix per time step.
  P: Float[Array, "T N N"]

# NOTE(review): sp.rewrite_nt is defined in KalmanPaper.simple; it appears to
# rewrite the class docstring for documentation rendering -- confirm there.
sp.rewrite_nt(EKF_out)

In [None]:
#| hide
sp.reshow_doc(EKF_out)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L48){target="_blank" style="float:right; font-size:smaller"}

### EKF_out

>      EKF_out (W:jaxtyping.Float[Array,'TN'], P:jaxtyping.Float[Array,'TNN'])

*$\!$*
  `EKF` 関数の返り値

  | $\!$ | Type | Details |
|--|--|--|
  | W | Float[Array, 'T N'] | $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  | P | Float[Array, 'T N N'] | $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$

  *$\!$*

In [None]:
#| export
@partial(jax.jit, static_argnames=['N', 'T'])
def EKF(
    N: int, # $N$
    T: int, # $T$
    x: Float[Array, "{T} {N}"], # $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
    y: Float[Array, "{T}"], # $\{ y_t \}_{t=0,\ldots,T-1}$
    G: Float[Array, "{N} {N}"], # $\boldsymbol\Gamma$
    w0: Float[Array, "{N}"], # $\hat{\mathbf w}_{0/-1}$
    P0: Float[Array, "{N} {N}"], # $\mathbf P_{0/-1}$
) -> EKF_out:
    r"""$\!$*
    Extended Kalman filter

    Runs over $t=0,\ldots,T-1$. Each step applies the measurement update
    (`Ptt`, `wtt`) to the one-step-ahead pair
    $(\mathbf P_{t/t-1},\hat{\mathbf w}_{t/t-1})$, records the filtered pair,
    and carries the next prediction forward:
    $$\hat{\mathbf w}_{t+1/t}=\hat{\mathbf w}_{t/t},\qquad\mathbf P_{t+1/t}=\mathbf P_{t/t}+\boldsymbol\Gamma$$
    *$\!$"""

    # State threaded through the scan: the one-step-ahead prediction.
    class Carry(NamedTuple):
        Ptm: Float[Array, "{N} {N}"]
        wtm: Float[Array, "{N}"]

    # Per-step slice of the inputs: one feature vector and one scalar label.
    class Input(NamedTuple):
        xt: Float[Array, "{N}"]
        yt: Float[Array, ""]

    # Per-step result that lax.scan stacks along a new leading axis.
    class Output(NamedTuple):
        wtt_: Float[Array, "{N}"]
        Ptt_: Float[Array, "{N} {N}"]

    def step(carry: Carry, inputs: Input) -> Tuple[Carry, Output]:
        Ptm, wtm = carry
        xt, yt = inputs
        # Both updates are evaluated at the predicted pair (Ptm, wtm).
        Ptt_ = Ptt(Ptm, wtm, xt)
        wtt_ = wtt(Ptm, wtm, xt, yt)
        # Time update: the weights carry over unchanged, the covariance gains G.
        return Carry(Ptt_ + G, wtt_), Output(wtt_, Ptt_)

    # T is the explicit trip count for the scan; N is referenced only by the
    # shape annotations. Both are static arguments of the jitted function.
    _, (W, P) = lax.scan(
        step,
        Carry(P0, w0),
        Input(x, y),
        length=T
    )
    return EKF_out(W, P)

In [None]:
#| hide
sp.reshow_doc(EKF)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L64){target="_blank" style="float:right; font-size:smaller"}

### EKF

>      EKF (N:int, T:int, x:jaxtyping.Float[Array,'{T}{N}'],
>           y:jaxtyping.Float[Array,'{T}{N}'],
>           G:jaxtyping.Float[Array,'{N}{N}'], w0:jaxtyping.Float[Array,'{N}'],
>           P0:jaxtyping.Float[Array,'{N}{N}'])

*$\!$*
拡張カルマンフィルタ
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| N | int | $N$ |
| T | int | $T$ |
| x | Float[Array, '{T} {N}'] | $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$ |
| y | Float[Array, '{T} {N}'] | $\{ y_t \}_{t=0,\ldots,T-1}$ |
| G | Float[Array, '{N} {N}'] | $\boldsymbol\Gamma$ |
| w0 | Float[Array, '{N}'] | $\hat{\mathbf w}_{0/-1}$ |
| P0 | Float[Array, '{N} {N}'] | $\mathbf P_{0/-1}$ |
| **Returns** | **EKF_out** |$\!$|

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()