In [None]:
#| hide
# Reload edited project modules (e.g. KalmanPaper.simple) automatically before each cell runs.
# NOTE(review): the saved output shows autoreload failing to decode KalmanPaper.simple with the
# Windows 'cp932' codec; running Python in UTF-8 mode (PYTHONUTF8=1) is a likely fix — verify.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


[autoreload of KalmanPaper.simple failed: Traceback (most recent call last):
  File "c:\Users\suzun\Dev\Paper\KalmanPaper\venv\Lib\site-packages\IPython\extensions\autoreload.py", line 322, in check
    elif self.deduper_reloader.maybe_reload_module(m):
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\suzun\Dev\Paper\KalmanPaper\venv\Lib\site-packages\IPython\extensions\deduperreload\deduperreload.py", line 545, in maybe_reload_module
    new_source_code = f.read()
                      ^^^^^^^^
UnicodeDecodeError: 'cp932' codec can't decode byte 0x86 in position 524: illegal multibyte sequence
]


# 00_Gen

> 乱数によってデータを生成する。生成過程: $(\mathbf w_t,y_t)\to\mathbf x_t$

In [None]:
#| default_exp gen00

In [None]:
#| hide
#| exporti
import jax.numpy as jnp
import jax.random as jrd
import jax.lax as lax
import jax
from jaxtyping import Array, Float, PRNGKeyArray
from typing import Tuple, NamedTuple
from functools import partial
from KalmanPaper import simple as sp

In [None]:
#| export
@partial(jax.jit, static_argnames=['N', 'T'])
def gen_w(
    key: PRNGKeyArray, # PRNGKeyArray
    N: int, # $N$
    T: int, # $T$
    G: Float[Array, "{N} {N}"], # $\boldsymbol\Gamma$
    w0: Float[Array, "{N}"], # $\mathbf w_{-1}$
) -> Float[Array, "{T} {N}"]: # $\{\mathbf w_t\}_{t=0,\ldots,T-1}$
    r"""$\!$*
    Generate the latent random walk $\{\mathbf w_t\}_{t=0,\ldots,T-1}$
    $$\mathbf w_{t}\sim\mathcal N(\mathbf w_t\mid\mathbf w_{t-1},\boldsymbol\Gamma)$$
    starting from $\mathbf w_{-1}$ (`w0`); the $T$ states are returned stacked row-wise.
    *$\!$"""
    # `N` is not read in the body: it is a static argument used for the
    # shape annotations only; the state dimension actually follows `w0`.

    def _advance(w_prev, step_key):
        # One Gaussian transition centred on the previous state with covariance G.
        w_next = jrd.multivariate_normal(step_key, w_prev, G)
        # Carry the new state forward and also emit it as this step's output.
        return w_next, w_next

    # One independent subkey per time step.
    step_keys = jrd.split(key, T)
    _final_state, trajectory = lax.scan(_advance, w0, step_keys, length=T)
    return trajectory

In [None]:
#| hide
# Re-render the API documentation of `gen_w` with the project helper `sp.reshow_doc`
# (the rendered signature and parameter table below are its output).
sp.reshow_doc(gen_w)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/gen.py#L20){target="_blank" style="float:right; font-size:smaller"}

### gen_w

>      gen_w (key:Union[jaxKey[Array,''],jaxUInt32[Array,'2']], N:int, T:int,
>             G:jaxtyping.Float[Array,'{N}{N}'],
>             w0:jaxtyping.Float[Array,'{N}'])

*$\!$*
潜在変数 $\{\mathbf w_t\}_{t=0,\ldots,T-1}$ の生成
$$\mathbf w_{t}\sim\mathcal N(\mathbf w_t\mid\mathbf w_{t-1},\boldsymbol\Gamma)$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| key | Union | PRNGKeyArray |
| N | int | $N$ |
| T | int | $T$ |
| G | Float[Array, '{N} {N}'] | $\boldsymbol\Gamma$ |
| w0 | Float[Array, '{N}'] | $\mathbf w_{-1}$ |
| **Returns** | **Float[Array, '{T} {N}']** | **$\{\mathbf w_t\}_{t=0,\ldots,T-1}$** |

In [None]:
#| export
@partial(jax.jit, static_argnames=['N', 'T'])
def gen_xy(
    key: PRNGKeyArray, # PRNGKeyArray
    N: int, # $N$
    T: int, # $T$
    Sigma: Float[Array, "{N} {N}"], # $\boldsymbol\Sigma$
    W: Float[Array, "{T} {N}"], # $\{\mathbf w_t\}_{t=0,\ldots,T-1}$
) -> Tuple[Float[Array, "{T} {N}"], Float[Array, "{T}"]]: # $\{\mathbf x_t\}_{t=0,\ldots,T-1}, \{y_t\}_{t=0,\ldots,T-1}$
    r"""$\!$*
    Generate the observations $\{\mathbf x_t\}_{t=0,\ldots,T-1}, \{y_t\}_{t=0,\ldots,T-1}$
    $$y_t\sim\text{Bern}(y_t\mid 1/2)$$
    $$\boldsymbol\Sigma\mathbf w_t=2\boldsymbol\mu_{1,t}$$
    $$\boldsymbol\mu_{2,t}=-\boldsymbol\mu_{1,t}$$
    $$
    \mathbf x_t\sim
    \begin{cases}
    \displaystyle\mathcal N\left(\boldsymbol\mu_{1,t},\boldsymbol\Sigma\right) & (y_t=1) \\
    \displaystyle\mathcal N\left(\boldsymbol\mu_{2,t},\boldsymbol\Sigma\right) & (y_t=0)
    \end{cases}
    $$
    The labels are returned as float32 values 0.0 / 1.0. `Sigma` is assumed to be
    positive definite (it is Cholesky-factorised). Raises `ValueError` when `W`
    is not of shape $(T,N)$ or `Sigma` is not of shape $(N,N)$.
    *$\!$"""
    # Shapes are static under jit, so these checks run once at trace time.
    # They turn silent broadcasting (e.g. W of shape (1, N)) into a clear error.
    if W.shape != (T, N):
        raise ValueError(f"W must have shape (T, N) = {(T, N)}, got {W.shape}")
    if Sigma.shape != (N, N):
        raise ValueError(f"Sigma must have shape (N, N) = {(N, N)}, got {Sigma.shape}")

    # Independent subkeys for the labels and for the Gaussian noise.
    key_y, key_z = jrd.split(key, 2)

    # y_t ~ Bern(1/2); bernoulli returns bool, cast to float labels.
    Y = jrd.bernoulli(key_y, p=0.5, shape=(T,)).astype(jnp.float32)  # (T,)

    # Lower-triangular L with Sigma = L @ L.T (assumes Sigma is positive definite).
    L = jnp.linalg.cholesky(Sigma)  # (N, N)

    # Class-dependent means: mu_t = s_t * mu_{1,t} = s_t * 0.5 * Sigma @ W[t],
    # with s_t = +1 for y_t = 1 and s_t = -1 for y_t = 0 (so mu_{2,t} = -mu_{1,t}).
    # Row t of (W * s) @ Sigma.T equals s_t * (Sigma @ W[t])^T.
    sign = 2 * Y - 1  # (T,)
    mu = 0.5 * ((W * sign[:, None]) @ Sigma.T)  # (T, N)

    # z_t ~ N(0, I) implies L @ z_t ~ N(0, Sigma); stacked as rows this is z @ L.T.
    z = jrd.normal(key_z, shape=(T, N))
    X = mu + (z @ L.T)  # (T, N)

    return X, Y

In [None]:
#| hide
# Re-render the API documentation of `gen_xy` with the project helper `sp.reshow_doc`
# (the rendered signature and parameter table below are its output).
sp.reshow_doc(gen_xy)

---

[source](https://github.com/SuzuSys/KalmanPaper/blob/main/KalmanPaper/gen.py#L41){target="_blank" style="float:right; font-size:smaller"}

### gen_xy

>      gen_xy (key:Union[jaxKey[Array,''],jaxUInt32[Array,'2']], N:int, T:int,
>              Sigma:jaxtyping.Float[Array,'{N}{N}'],
>              W:jaxtyping.Float[Array,'{T}{N}'])

*$\!$*
観測変数 $\{\mathbf x_t\}_{t=0,\ldots,T-1}, \{y_t\}_{t=0,\ldots,T-1}$ の生成
$$y_t\sim\text{Bern}(y_t\mid 1/2)$$
$$\boldsymbol\Sigma\mathbf w_t=2\boldsymbol\mu_{1,t}$$
$$\boldsymbol\mu_{2,t}=-\boldsymbol\mu_{1,t}$$
$$
\mathbf x_t\sim
\begin{cases}
\displaystyle\mathcal N\left(\boldsymbol\mu_{1,t},\boldsymbol\Sigma\right) & (y_t=1) \\
\displaystyle\mathcal N\left(\boldsymbol\mu_{2,t},\boldsymbol\Sigma\right) & (y_t=0)
\end{cases}
$$
*$\!$*

|$\!$| **Type** | **Details** |
| -- | -------- | ----------- |
| key | Union | PRNGKeyArray |
| N | int | $N$ |
| T | int | $T$ |
| Sigma | Float[Array, '{N} {N}'] | $\boldsymbol\Sigma$ |
| W | Float[Array, '{T} {N}'] | $\{\mathbf w_t\}_{t=0,\ldots,T-1}$ |
| **Returns** | **Tuple** | **$\{\mathbf x_t\}_{t=0,\ldots,T-1}, \{y_t\}_{t=0,\ldots,T-1}$** |

In [None]:
#| hide
# Regenerate the library modules from the notebooks' `#| export` / `#| exporti` cells (nbdev).
import nbdev; nbdev.nbdev_export()