In [9]:
#| hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Variational Approximation

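> Recursive variational-approximation updates of the filtered estimate $\hat{\mathbf w}_{t/t}$, its error covariance $\mathbf P_{t/t}$, and the variational parameter $\xi_t$.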

In [10]:
#| default_exp VA

In [11]:
#| hide
#| export
import jax.numpy as jnp
import jax.lax as lax
import jax.scipy.linalg as jsl
from jaxtyping import Array, Float, Int
from KalmanPaper import simple as sp
from typing import Tuple

### `Ptt`: 推定誤差共分散行列 $\mathbf P_{t/t}$

$$\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\{2\lambda(\xi_t)\}^2}{1+\{2\lambda(\xi_t)\}^2\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\left(\mathbf P_{t/t-1}\mathbf x_t\right)\left(\mathbf P_{t/t-1}\mathbf x_t\right)^T$$

In [12]:
#| export
def Ptt(
    Ptm:  Float[Array, "N N"], # $\mathbf P_{t/t-1}$
    x:    Float[Array, "N"], # $\mathbf x_t$
    xi:   Float[Array, ""], # $\xi_t$
) -> Float[Array, "N N"]: # $\mathbf P_{t/t}$
  """Rank-one update of the predicted covariance `Ptm` along the direction `Ptm @ x`.

  The scalar weight is `(2 * sp.lam(xi)) ** 2`; `sp.lam` is defined in
  `KalmanPaper.simple` (presumably the lambda(xi) of the formula above -- confirm there).
  """
  direction = Ptm @ x                      # P_{t/t-1} x_t, reused three times below
  weight = (2*sp.lam(xi))**2               # {2 lambda(xi_t)}^2
  # Scalar coefficient weight / (1 + weight * x^T P x) of the rank-one correction.
  coeff = weight / (1 + weight * (x @ direction))
  return Ptm - coeff * jnp.outer(direction, direction)

### `wtt`: 濾波推定値 $\hat{\mathbf w}_{t/t}$

$$\hat{\mathbf w}_{t/t}=\mathbf P_{t/t}\left(\mathbf P_{t/t-1}^{-1}\hat{\mathbf w}_{t/t-1}+(y_t-1/2)\mathbf x_t\right)$$

In [13]:
#| export
def wtt(
    Ptm: Float[Array, "N N"],  # $\mathbf P_{t/t-1}$
    Ptt_: Float[Array, "N N"], # $\mathbf P_{t/t}$
    w: Float[Array, "N"],      # $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],      # $\mathbf x_t$
    y: Float[Array, ""],       # $y_t$
) -> Float[Array, "N"]:        # $\hat{\mathbf w}_{t/t}$
  """Filtered estimate `Ptt_ @ (Ptm^{-1} w + (y - 1/2) x)`.

  `Ptm^{-1} w` is obtained through a Cholesky factorisation rather than an
  explicit inverse, so `Ptm` must be symmetric positive definite.
  `y` is the scalar observation `y_t` that scales `x`.
  """
  # Solve Ptm z = w via Cholesky instead of forming Ptm^{-1}.
  prior_term = jsl.cho_solve(jsl.cho_factor(Ptm), w)
  return Ptt_ @ (prior_term + (y - 1/2) * x)

### `xipre`: $(\mathbf P_{t/t-1},\hat{\mathbf w}_{t/t-1})$ を使う変分パラメータ $\xi_t$

$$\xi_t=\sqrt{\mathbf x_t^T\left(\mathbf P_{t/t-1}+\hat{\mathbf w}_{t/t-1}\hat{\mathbf w}_{t/t-1}^T\right)\mathbf x_t}$$

In [14]:
#| export
def xipre(
    Ptm: Float[Array, "N N"], # $\mathbf P_{t/t-1}$
    w: Float[Array, "N"],   # $\hat{\mathbf w}_{t/t-1}$
    x: Float[Array, "N"],   # $\mathbf x_t$
) -> Float[Array, ""]: # $\xi_t$
  """Variational parameter: square root of the quadratic form of `x` in `Ptm + w w^T`."""
  second_moment = Ptm + jnp.outer(w,w)   # P_{t/t-1} + w w^T
  quad_form = x @ second_moment @ x      # x^T (P + w w^T) x, a scalar
  return jnp.sqrt(quad_form)

### `VApre`: `xipre` を使う Variational Approximation

$$N,T,\{ \mathbf x_t \}_{t=0,\ldots,T-1}, \{ y_t \}_{t=0,\ldots,T-1}, \boldsymbol\Gamma, \hat{\mathbf w}_{0/-1}, \mathbf P_{0/-1} $$
$$\to\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1},\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1},\{\xi_t\}_{t=0,\ldots,T-1}$$

In [15]:
#| export
def VApre(
    N: Int, # $N$
    T: Int, # $T$
    x: Float[Array, "{T} {N}"], # $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
    y: Float[Array, "{T}"], # $\{ y_t \}_{t=0,\ldots,T-1}$
    G: Float[Array, "{N} {N}"], # $\boldsymbol\Gamma$
    w0: Float[Array, "{N}"], # $\hat{\mathbf w}_{0/-1}$
    P0: Float[Array, "{N} {N}"], # $\mathbf P_{0/-1}$
) -> Tuple[Float[Array, "{T} {N}"], Float[Array, "{T} {N} {N}"], Float[Array, "{T}"]]: # $\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1},\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1},\{\xi_t\}_{t=0,\ldots,T-1}$
    """Run the variational-approximation filter over `T` steps with `lax.scan`.

    At each step the variational parameter is computed by `xipre` from the
    predicted covariance and estimate, then `Ptt` and `wtt` produce the
    filtered covariance and estimate. The next step is started from
    `(P_{t/t} + G, w_{t/t})`.

    Returns the stacked filtered estimates, filtered covariances and
    variational parameters, each with leading axis of length `T`.
    `N` is used only in the shape annotations.
    """
    def step(carry, inputs):
        Ptm, wtm = carry          # predicted covariance / estimate for this step
        xt, yt = inputs           # one row of `x` and one entry of `y`
        # Use the per-step regressor `xt`, not the full `x` array.
        xit = xipre(Ptm, wtm, xt)
        Ptt_ = Ptt(Ptm, xt, xit)
        wtt_ = wtt(Ptm, Ptt_, wtm, xt, yt)
        # Carry for the next step: covariance inflated by G, estimate unchanged.
        return (Ptt_ + G, wtt_), (wtt_, Ptt_, xit)

    _, (W, P, Xi) = lax.scan(
        step,
        (P0, w0),
        (x, y),
        length=T
    )
    return W, P, Xi

In [16]:
#| hide
# Run nbdev's export step so that cells marked `#| export` are written out
# to the module named by this notebook's `#| default_exp` directive.
import nbdev; nbdev.nbdev_export()