In [None]:
#| hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# w:EXP, P:VA

> $\mathbf w$ の推定を拡張カルマンフィルタ (EKF) によって行い、$\mathbf P$ の推定を変分近似 (VA) によって行う。

In [None]:
#| default_exp wEXP_PVA

In [None]:
#| hide
#| exporti
import jax.numpy as jnp
import jax.lax as lax
import jax.scipy.linalg as jsl
import jax
from jaxtyping import Array, Float, Int
from KalmanPaper import simple as sp
from KalmanPaper.EKF import wtt
from KalmanPaper.VA import Ptt, xit
from typing import Tuple, NamedTuple
from functools import partial

In [None]:
#| export
class wEXP_PVA_out(NamedTuple):
  r"""$\!$*
  Return value of the `wEXP_PVA` function

  table
  W: $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  P: $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
  Xi: $\{\xi_t\}_{t=0,\ldots,T-1}$

  *$\!$"""
  # Filtered weight estimates \hat w_{t/t}, one row per time step; shape (T, N).
  W: Float[Array, "T N"]
  # Filtered covariance matrices P_{t/t}, one per time step; shape (T, N, N).
  P: Float[Array, "T N N"]
  # xi_t values, one per time step; shape (T,).
  Xi: Float[Array, "T"]

# NOTE(review): project helper from KalmanPaper.simple. Judging by the rendered
# docs, it appears to turn the docstring's "table" section into per-field
# documentation; confirm against its implementation.
sp.rewrite_nt(wEXP_PVA_out)

In [None]:
#| hide
# Display the API documentation for `wEXP_PVA_out` in this cell's output.
# NOTE(review): `sp.reshow_doc` is a KalmanPaper.simple helper; behavior inferred
# from the rendered output, confirm against its implementation.
sp.reshow_doc(wEXP_PVA_out)

---

### wEXP_PVA_out

>      wEXP_PVA_out (W:jaxtyping.Float[Array,'TN'],
>                    P:jaxtyping.Float[Array,'TNN'],
>                    Xi:jaxtyping.Float[Array,'T'])

*$\!$*
`wEXP_PVA` 関数の返り値

| $\!$ | Type | Details |
|--|--|--|
| W | Float[Array, 'T N'] | $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
| P | Float[Array, 'T N N'] | $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
| Xi | Float[Array, 'T'] | $\{\xi_t\}_{t=0,\ldots,T-1}$

*$\!$*

In [None]:
#| export
@partial(jax.jit, static_argnames=['N', 'T'])
def wEXP_PVA(
    N: int, # $N$
    T: int, # $T$
    x: Float[Array, "{T} {N}"], # $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
    y: Float[Array, "{T}"], # $\{ y_t \}_{t=0,\ldots,T-1}$
    G: Float[Array, "{N} {N}"], # $\boldsymbol\Gamma$
    w0: Float[Array, "{N}"], # $\hat{\mathbf w}_{0/-1}$
    P0: Float[Array, "{N} {N}"], # $\mathbf P_{0/-1}$
) -> wEXP_PVA_out:
    r"""$\!$*
    Filter that estimates $\hat{\mathbf w}_{t/t}$ with the EKF and $\mathbf P_{t/t}$ with the VA.
    $\xi_t$ is evaluated at the one-step-ahead prediction $\hat{\mathbf w}_{t/t-1}$:
    $$\xi_t=\sqrt{\mathbf x_t^T\left(\mathbf P_{t/t-1}+\hat{\mathbf w}_{t/t-1}\hat{\mathbf w}_{t/t-1}^T\right)\mathbf x_t}$$
    *$\!$"""
    def filter_step(predicted, observation):
        # predicted   = (P_{t/t-1}, w_{t/t-1})
        # observation = (x_t, y_t)
        P_pred, w_pred = predicted
        x_t, y_t = observation

        # xi_t uses the *predicted* mean and covariance (see docstring).
        xi_t = xit(P_pred, w_pred, x_t)
        # Covariance update via the VA, mean update via the EKF; both start
        # from the predicted quantities of this step.
        P_filt = Ptt(P_pred, x_t, xi_t)
        w_filt = wtt(P_pred, w_pred, x_t, y_t)

        # Prediction for t+1: the mean is carried over unchanged and the
        # covariance is inflated by Gamma.
        next_predicted = (P_filt + G, w_filt)
        return next_predicted, (w_filt, P_filt, xi_t)

    _, history = lax.scan(filter_step, (P0, w0), (x, y), length=T)
    W, P, Xi = history
    return wEXP_PVA_out(W, P, Xi)

In [None]:
#| hide
# Display the API documentation for `wEXP_PVA` in this cell's output.
# NOTE(review): `sp.reshow_doc` is a KalmanPaper.simple helper; behavior inferred
# from the rendered output, confirm against its implementation.
sp.reshow_doc(wEXP_PVA)

---

### wEXP_PVA

>      wEXP_PVA (N:int, T:int, x:jaxtyping.Float[Array,'{T}{N}'],
>                y:jaxtyping.Float[Array,'{T}'],
>                G:jaxtyping.Float[Array,'{N}{N}'],
>                w0:jaxtyping.Float[Array,'{N}'],
>                P0:jaxtyping.Float[Array,'{N}{N}'])

*$\!$*
$\mathbf w_{t/t}$ を EKF によって推論し、$\mathbf P_{t/t}$ を VA によって推論する手法
$\xi_t$ には一段予測推定値 $\hat{\mathbf w}_{t/t-1}$ を使う
$$\xi_t=\sqrt{\mathbf x_t^T\left(\mathbf P_{t/t-1}+\hat{\mathbf w}_{t/t-1}\hat{\mathbf w}_{t/t-1}^T\right)\mathbf x_t}$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| N | int | $N$ |
| T | int | $T$ |
| x | Float[Array, '{T} {N}'] | $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$ |
| y | Float[Array, '{T}'] | $\{ y_t \}_{t=0,\ldots,T-1}$ |
| G | Float[Array, '{N} {N}'] | $\boldsymbol\Gamma$ |
| w0 | Float[Array, '{N}'] | $\hat{\mathbf w}_{0/-1}$ |
| P0 | Float[Array, '{N} {N}'] | $\mathbf P_{0/-1}$ |
| **Returns** | **wEXP_PVA_out** |$\!$|