In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# Variational Approximation

> 変分近似

In [None]:
#| default_exp VA

In [None]:
#| hide
#| exporti
import jax.numpy as jnp
import jax.lax as lax
import jax.scipy.linalg as jsl
import jax
from jaxtyping import Array, Float, Int
from KalmanPaper import simple as sp
from typing import Tuple, NamedTuple
from functools import partial

In [None]:
#| export
def lam(
    x: Float[Array, ""] # $x$
) -> Float[Array, ""]: # $\lambda(x)$
  r"""$\!$*
  Lambda function of the variational bound, evaluated in a numerically stable form
  $$\lambda(x)=\frac{1}{2x}\left[\sigma(x)-\frac{1}{2}\right]=\frac{\tanh(x/2)}{4x},\qquad\lambda(0)=\lim_{x\to 0}\lambda(x)=\frac{1}{8}$$
  *$\!$"""
  # NOTE(review): assumes the sigma in the formula (previously `sp.losi`) is the
  # standard logistic sigmoid, for which sigma(x) - 1/2 == tanh(x/2) / 2 exactly.
  # The tanh form avoids the cancellation in `sigma(x) - 0.5` for small |x|.
  at_origin = x == 0
  # Substitute a harmless denominator at x == 0 so that neither the value nor
  # the gradient of the branch that is not selected becomes NaN.
  safe_x = jnp.where(at_origin, 1.0, x)
  # lambda(0) is defined by its limit 1/8 instead of evaluating 0/0.
  return jnp.where(at_origin, 0.125, jnp.tanh(safe_x / 2) / (4 * safe_x))

In [None]:
#| hide
sp.reshow_doc(lam)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/VA.py#L19){target="_blank" style="float:right; font-size:smaller"}

### lam

>      lam (x:jaxtyping.Float[Array,''])

*$\!$*
Lambda 関数
$$\lambda(x)=\frac{1}{2x}\left[\sigma(x)-\frac{1}{2}\right]$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| x | Float[Array, ''] | $x$ |
| **Returns** | **Float[Array, '']** | **$\lambda(x)$** |

In [None]:
#| export
@jax.jit
def Ptt(
    Ptm:  Float[Array, "N N"], # $\mathbf P_{t/t-1}$
    x:    Float[Array, "N"], # $\mathbf x_t$
    xi:   Float[Array, ""], # $\xi_t$
) -> Float[Array, "N N"]: # $\mathbf P_{t/t}$
  r"""$\!$*
  Estimation-error covariance matrix $\mathbf P_{t/t}$ (rank-one downdate of $\mathbf P_{t/t-1}$)

  $$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{2\lambda(\xi_t)}{1+2\lambda(\xi_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\left(\mathbf P_{t/t-1}\mathbf x_t\right)\left(\mathbf P_{t/t-1}\mathbf x_t\right)^T $$

  which is the inverse-free form of

  $$\mathbf P_{t/t}^{-1}=\mathbf P_{t/t-1}^{-1}+2\lambda(\xi_t)\mathbf x_t\mathbf x_t^T$$
  *$\!$"""
  # Weight 2*lambda(xi_t) on the rank-one information term x_t x_t^T.
  two_lam = 2*lam(xi)
  # P_{t/t-1} x_t appears in both the scalar denominator and the outer product.
  Px = Ptm @ x
  shrink = two_lam / (1 + two_lam * (x @ Px))
  return Ptm - shrink * jnp.outer(Px, Px)

In [None]:
#| hide
sp.reshow_doc(Ptt)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L19){target="_blank" style="float:right; font-size:smaller"}

### Ptt

>      Ptt (Ptm:jaxtyping.Float[Array,'NN'], x:jaxtyping.Float[Array,'N'],
>           xi:jaxtyping.Float[Array,''])

*$\!$*
推定誤差共分散行列 $\mathbf P_{t/t}$ 

$$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{2\lambda(\xi_t)}{1+2\lambda(\xi_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\left(\mathbf P_{t/t-1}\mathbf x_t\right)\left(\mathbf P_{t/t-1}\mathbf x_t\right)^T $$

$$\mathbf P_{t/t}^{-1}=\mathbf P_{t/t-1}^{-1}+2\lambda(\xi_t)\mathbf x_t\mathbf x_t^T$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Ptm | Float[Array, 'N N'] | $\mathbf P_{t/t-1}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| xi | Float[Array, ''] | $\xi_t$ |
| **Returns** | **Float[Array, 'N N']** | **$\mathbf P_{t/t}$** |

In [None]:
#| export
@jax.jit
def wtt(
    Ptm: Float[Array, "N N"],  # $\mathbf P_{t/t-1}$
    Ptt_: Float[Array, "N N"], # $\mathbf P_{t/t}$
    w: Float[Array, "N"],      # $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],      # $\mathbf x_t$
    y: Float[Array, ""],      # $y_t$
) -> Float[Array, "N"]:        # $\hat{\mathbf w}_{t/t}$
  r"""$\!$*
  Filtered estimate $\hat{\mathbf w}_{t/t}$
  $$\hat{\mathbf w}_{t/t}=\mathbf P_{t/t}\left(\mathbf P_{t/t-1}^{-1}\hat{\mathbf w}_{t/t-1}+(y_t-1/2)\mathbf x_t\right)$$
  *$\!$"""
  # P_{t/t-1}^{-1} w is obtained from a Cholesky factorisation of P_{t/t-1}
  # rather than by forming the inverse explicitly.
  factor = jsl.cho_factor(Ptm)
  prior_term = jsl.cho_solve(factor, w)
  # Contribution of the current observation: (y_t - 1/2) x_t.
  data_term = (y - 1/2) * x
  return Ptt_ @ (prior_term + data_term)

In [None]:
#| hide
sp.reshow_doc(wtt)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/EKF.py#L35){target="_blank" style="float:right; font-size:smaller"}

### wtt

>      wtt (Ptm:jaxtyping.Float[Array,'NN'], Ptt_:jaxtyping.Float[Array,'NN'],
>           w:jaxtyping.Float[Array,'N'], x:jaxtyping.Float[Array,'N'],
>           y:jaxtyping.Float[Array,''])

*$\!$*
濾波推定値 $\hat{\mathbf w}_{t/t}$
$$\hat{\mathbf w}_{t/t}=\mathbf P_{t/t}\left(\mathbf P_{t/t-1}^{-1}\hat{\mathbf w}_{t/t-1}+(y_t-1/2)\mathbf x_t\right)$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Ptm | Float[Array, 'N N'] | $\mathbf P_{t/t-1}$ |
| Ptt_ | Float[Array, 'N N'] | $\mathbf P_{t/t}$ |
| w | Float[Array, 'N'] | $\hat{\mathbf w}_{t/t-1}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| y | Float[Array, ''] | $y_t$ |
| **Returns** | **Float[Array, 'N']** | **$\hat{\mathbf w}_{t/t}$** |

In [None]:
#| export
@jax.jit
def xit(
    Cov: Float[Array, "N N"], # $\boldsymbol\Sigma$
    w: Float[Array, "N"],   # $\hat{\mathbf w}$
    x: Float[Array, "N"],   # $\mathbf x_t$
) -> Float[Array, ""]: # $\xi_t$
  r"""$\!$*
  変分パラメータ $\xi_t$
  $$\xi_t=\sqrt{\mathbf x_t^T\mathbb E[\mathbf w\mathbf w^T]\mathbf x_t}=\sqrt{\mathbf x_t^T\left(\boldsymbol \Sigma+\hat{\mathbf w}\hat{\mathbf w}^T\right)\mathbf x_t}$$
  *$\!$"""
  return jnp.sqrt(x @ (Cov + jnp.outer(w,w)) @ x)

In [None]:
#| hide
sp.reshow_doc(xit)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/VA.py#L63){target="_blank" style="float:right; font-size:smaller"}

### xit

>      xit (Cov:jaxtyping.Float[Array,'NN'], w:jaxtyping.Float[Array,'N'],
>           x:jaxtyping.Float[Array,'N'])

*$\!$*
変分パラメータ $\xi_t$
$$\xi_t=\sqrt{\mathbf x_t^T\mathbb E[\mathbf w\mathbf w^T]\mathbf x_t}=\sqrt{\mathbf x_t^T\left(\boldsymbol \Sigma+\hat{\mathbf w}\hat{\mathbf w}^T\right)\mathbf x_t}$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| Cov | Float[Array, 'N N'] | $\boldsymbol\Sigma$ |
| w | Float[Array, 'N'] | $\hat{\mathbf w}$ |
| x | Float[Array, 'N'] | $\mathbf x_t$ |
| **Returns** | **Float[Array, '']** | **$\xi_t$** |

In [None]:
#| export
# Result container returned by `VApre`: one entry per time step is stacked along
# the leading axis (annotated as T) of every field.
class VApre_out(NamedTuple):
  r"""$\!$*
  Return value of the `VApre` function

  table
  W: $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  P: $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
  Xi: $\{\xi_t\}_{t=0,\ldots,T-1}$

  *$\!$"""
  W: Float[Array, "T N"]
  P: Float[Array, "T N N"]
  Xi: Float[Array, "T"]

# NOTE(review): `sp.rewrite_nt` is defined outside this notebook. Judging by the
# rendered documentation it appears to expand the `table` section of the
# docstring into a Markdown table using the field annotations - confirm before
# changing the docstring layout.
sp.rewrite_nt(VApre_out)

In [None]:
#| hide
sp.reshow_doc(VApre_out)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/VA.py#L75){target="_blank" style="float:right; font-size:smaller"}

### VApre_out

>      VApre_out (W:jaxtyping.Float[Array,'TN'], P:jaxtyping.Float[Array,'TNN'],
>                 Xi:jaxtyping.Float[Array,'T'])

*$\!$*
  `VApre` 関数の返り値

  | $\!$ | Type | Details |
|--|--|--|
  | W | Float[Array, 'T N'] | $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  | P | Float[Array, 'T N N'] | $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
  | Xi | Float[Array, 'T'] | $\{\xi_t\}_{t=0,\ldots,T-1}$

  *$\!$*

In [None]:
#| export
@partial(jax.jit, static_argnames=['N', 'T'])
def VApre(
    N: int, # $N$
    T: int, # $T$
    x: Float[Array, "{T} {N}"], # $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
    y: Float[Array, "{T}"], # $\{ y_t \}_{t=0,\ldots,T-1}$
    G: Float[Array, "{N} {N}"], # $\boldsymbol\Gamma$
    w0: Float[Array, "{N}"], # $\hat{\mathbf w}_{0/-1}$
    P0: Float[Array, "{N} {N}"], # $\mathbf P_{0/-1}$
) -> VApre_out:
    r"""$\!$*
    Variational approximation that uses the one-step-ahead prediction $\hat{\mathbf w}_{t/t-1}$
    $$\xi_t=\sqrt{\mathbf x_t^T\left(\mathbf P_{t/t-1}+\hat{\mathbf w}_{t/t-1}\hat{\mathbf w}_{t/t-1}^T\right)\mathbf x_t}$$
    *$\!$"""
    # Pytrees threaded through `lax.scan`: the predicted state is the carry,
    # one observation is consumed per step, one filtered result is emitted.
    class Predicted(NamedTuple):
        cov: Float[Array, "{N} {N}"]
        mean: Float[Array, "{N}"]

    class Observation(NamedTuple):
        features: Float[Array, "{N}"]
        label: Float[Array, ""]

    class Filtered(NamedTuple):
        mean: Float[Array, "{N}"]
        cov: Float[Array, "{N} {N}"]
        xi: Float[Array, ""]

    def advance(prior: Predicted, obs: Observation) -> Tuple[Predicted, Filtered]:
        # xi_t is computed from the predicted moments only, so no inner loop is needed.
        xi = xit(prior.cov, prior.mean, obs.features)
        cov_post = Ptt(prior.cov, obs.features, xi)
        mean_post = wtt(prior.cov, cov_post, prior.mean, obs.features, obs.label)
        # Time update: the covariance grows by G, the mean is carried over unchanged.
        next_prior = Predicted(cov_post + G, mean_post)
        return next_prior, Filtered(mean_post, cov_post, xi)

    _, history = lax.scan(
        advance,
        Predicted(P0, w0),
        Observation(x, y),
        length=T
    )
    return VApre_out(history.mean, history.cov, history.xi)

In [None]:
#| hide
sp.reshow_doc(VApre)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/VA.py#L93){target="_blank" style="float:right; font-size:smaller"}

### VApre

>      VApre (N:int, T:int, x:jaxtyping.Float[Array,'{T}{N}'],
>             y:jaxtyping.Float[Array,'{T}'], G:jaxtyping.Float[Array,'{N}{N}'],
>             w0:jaxtyping.Float[Array,'{N}'],
>             P0:jaxtyping.Float[Array,'{N}{N}'])

*$\!$*
一段予測推定値 $\hat{\mathbf w}_{t/t-1}$ を使う変分近似法
$$\xi_t=\sqrt{\mathbf x_t^T\left(\mathbf P_{t/t-1}+\hat{\mathbf w}_{t/t-1}\hat{\mathbf w}_{t/t-1}^T\right)\mathbf x_t}$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| N | int | $N$ |
| T | int | $T$ |
| x | Float[Array, '{T} {N}'] | $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$ |
| y | Float[Array, '{T}'] | $\{ y_t \}_{t=0,\ldots,T-1}$ |
| G | Float[Array, '{N} {N}'] | $\boldsymbol\Gamma$ |
| w0 | Float[Array, '{N}'] | $\hat{\mathbf w}_{0/-1}$ |
| P0 | Float[Array, '{N} {N}'] | $\mathbf P_{0/-1}$ |
| **Returns** | **VApre_out** |$\!$|

In [None]:
#| export
# Result container returned by `VAEM`: one entry per time step is stacked along
# the leading axis (annotated as T) of every field; `Iters` holds the inner-loop
# iteration counter reported for each step.
class VAEM_out(NamedTuple):
  r"""$\!$*
  Return value of the `VAEM` function

  table
  W: $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  P: $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
  Xi: $\{\xi_t\}_{t=0,\ldots,T-1}$
  Iters: $\{\mathrm{Iter}_t\}_{t=0,\ldots,T-1}$
  
  *$\!$"""
  W: Float[Array, "T N"]
  P: Float[Array, "T N N"]
  Xi: Float[Array, "T"]
  Iters: Int[Array, "T"]

# NOTE(review): `sp.rewrite_nt` is defined outside this notebook. Judging by the
# rendered documentation it appears to expand the `table` section of the
# docstring into a Markdown table using the field annotations - confirm before
# changing the docstring layout.
sp.rewrite_nt(VAEM_out)


In [None]:
#| hide
sp.reshow_doc(VAEM_out)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/VA.py#L136){target="_blank" style="float:right; font-size:smaller"}

### VAEM_out

>      VAEM_out (W:jaxtyping.Float[Array,'TN'], P:jaxtyping.Float[Array,'TNN'],
>                Xi:jaxtyping.Float[Array,'T'], Iters:jaxtyping.Int[Array,'T'])

*$\!$*
  `VAEM` 関数の返り値

  | $\!$ | Type | Details |
|--|--|--|
  | W | Float[Array, 'T N'] | $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}$
  | P | Float[Array, 'T N N'] | $\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}$
  | Xi | Float[Array, 'T'] | $\{\xi_t\}_{t=0,\ldots,T-1}$
  | Iters | Int[Array, 'T'] | $\{\mathrm{Iter}_t\}_{t=0,\ldots,T-1}$

  *$\!$*

In [None]:
#| export
@partial(jax.jit, static_argnames=['N', 'T'])
def VAEM(
    N: int, # $N$
    T: int, # $T$
    x: Float[Array, "{T} {N}"], # $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
    y: Float[Array, "{T}"], # $\{ y_t \}_{t=0,\ldots,T-1}$
    G: Float[Array, "{N} {N}"], # $\boldsymbol\Gamma$
    w0: Float[Array, "{N}"], # $\hat{\mathbf w}_{0/-1}$
    P0: Float[Array, "{N} {N}"], # $\mathbf P_{0/-1}$
    epsilon: Float[Array, ""], # $\epsilon\ge |\xi\\^{\text{new}}_t-\xi\\^{\text{old}}_t|$
    max_iter: int = 100, # upper bound on the number of iterations
) -> VAEM_out:
    r"""$\!$*
    Variational approximation that uses the filtered estimate $\hat{\mathbf w}_{t/t}$; $\xi_t$ is refined with the EM algorithm.
    $$\xi_t=\sqrt{\mathbf x_t^T\left(\mathbf P_{t/t}+\hat{\mathbf w}_{t/t}\hat{\mathbf w}_{t/t}^T\right)\mathbf x_t}$$
    *$\!$"""

    # Inner-loop state for one time step. Invariant kept by `step` and
    # `inner_iter`: (Ptt_, wtt_) are the posterior moments computed from
    # `xit_pr`, and `xit_af` is xi re-estimated from those moments.
    class State(NamedTuple):
        xit_pr: Float[Array, ""]
        xit_af: Float[Array, ""]
        Ptt_: Float[Array, "{N} {N}"]
        wtt_: Float[Array, "{N}"]
        Ptm: Float[Array, "{N} {N}"]
        wtm: Float[Array, "{N}"]
        xt: Float[Array, "{N}"]
        yt: Float[Array, ""]
        i: Int[Array, ""]

    def inner_iter(s: State) -> State:
        # One refinement: the latest xi becomes the "old" value, the posterior
        # is recomputed from it, and xi is then re-estimated from that fresh
        # posterior. Previously xi was re-estimated from the stale posterior
        # (so it always equalled the old value and the loop stopped after one
        # pass) and `wtt` was given the stale covariance.
        xit_pr = s.xit_af
        Ptt_ = Ptt(s.Ptm, s.xt, xit_pr)
        wtt_ = wtt(s.Ptm, Ptt_, s.wtm, s.xt, s.yt)
        xit_af = xit(Ptt_, wtt_, s.xt)
        return State(xit_pr, xit_af, Ptt_, wtt_, s.Ptm, s.wtm, s.xt, s.yt, s.i + 1)

    def cond_fun(s: State):
        # Keep iterating while xi still moves by at least epsilon and the
        # iteration budget is not exhausted.
        return jnp.logical_and(
            jnp.abs(s.xit_pr - s.xit_af) >= epsilon,
            s.i < max_iter
        )

    class Carry(NamedTuple):
        Ptm: Float[Array, "{N} {N}"]
        wtm: Float[Array, "{N}"]
    
    class Input(NamedTuple):
        xt: Float[Array, "{N}"]
        yt: Float[Array, ""]

    class Output(NamedTuple):
        wtt_: Float[Array, "{N}"]
        Ptt_: Float[Array, "{N} {N}"]
        xit_: Float[Array, ""]
        i: Int[Array, ""]

    def step(carry: Carry, inputs: Input) -> Tuple[Carry, Output]:
        Ptm, wtm = carry
        xt, yt = inputs
        # Initialise xi from the predicted moments, then do the first
        # posterior update and re-estimate xi from the filtered moments.
        xit_pr = xit(Ptm, wtm, xt)
        Ptt_ = Ptt(Ptm, xt, xit_pr)
        wtt_ = wtt(Ptm, Ptt_, wtm, xt, yt)
        xit_af = xit(Ptt_, wtt_, xt)

        init_state = State(
            xit_pr,
            xit_af,
            Ptt_,
            wtt_,
            Ptm,
            wtm,
            xt,
            yt,
            jnp.array(0, dtype=jnp.int32),
        )

        s = lax.while_loop(cond_fun, inner_iter, init_state)

        # `s.xit_pr` is the xi from which the reported (wtt_, Ptt_) were
        # computed. Time update: covariance grows by G, mean is carried over.
        return Carry(s.Ptt_ + G, s.wtt_), Output(s.wtt_, s.Ptt_, s.xit_pr, s.i)

    _, (W, P, Xi, Iters) = lax.scan(step, Carry(P0, w0), Input(x, y), length=T)
    return VAEM_out(W, P, Xi, Iters)

In [None]:
#| hide
sp.reshow_doc(VAEM)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/VA.py#L157){target="_blank" style="float:right; font-size:smaller"}

### VAEM

>      VAEM (N:int, T:int, x:jaxtyping.Float[Array,'{T}{N}'],
>            y:jaxtyping.Float[Array,'{T}'], G:jaxtyping.Float[Array,'{N}{N}'],
>            w0:jaxtyping.Float[Array,'{N}'],
>            P0:jaxtyping.Float[Array,'{N}{N}'],
>            epsilon:jaxtyping.Float[Array,''], max_iter:int=100)

*$\!$*
濾波推定値 $\hat{\mathbf w}_{t/t}$ を使う変分近似法。EMアルゴリズムを使う。
$$\xi_t=\sqrt{\mathbf x_t^T\left(\mathbf P_{t/t}+\hat{\mathbf w}_{t/t}\hat{\mathbf w}_{t/t}^T\right)\mathbf x_t}$$
*$\!$*

|$\!$| **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| N | int |$\!$| $N$ |
| T | int |$\!$| $T$ |
| x | Float[Array, '{T} {N}'] |$\!$| $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$ |
| y | Float[Array, '{T}'] |$\!$| $\{ y_t \}_{t=0,\ldots,T-1}$ |
| G | Float[Array, '{N} {N}'] |$\!$| $\boldsymbol\Gamma$ |
| w0 | Float[Array, '{N}'] |$\!$| $\hat{\mathbf w}_{0/-1}$ |
| P0 | Float[Array, '{N} {N}'] |$\!$| $\mathbf P_{0/-1}$ |
| epsilon | Float[Array, ''] |$\!$| $\epsilon\ge \|\xi\\^{\text{new}}_t-\xi\\^{\text{old}}_t\|$ |
| max_iter | int | 100 | 繰り返し回数の上限 |
| **Returns** | **VAEM_out** |$\!$|$\!$|

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()