<a href="https://colab.research.google.com/github/a-yuto/numpyro-sandbox/blob/main/%E5%AD%A3%E7%AF%80%E6%80%A7%E3%81%AE%E3%82%B3%E3%83%BC%E3%83%89.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install NumPyro, then upgrade JAX to the latest release.
# NOTE(review): if jax was already imported in this session, the upgraded version is
# presumably only picked up after a runtime restart — confirm (a RuntimeError is
# recorded in a later cell's output).
%pip install numpyro
%pip install --upgrade jax


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random,ops
import numpyro
from numpyro import sample, plate
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
from typing import Callable

# モデル定義
def multiplicative_model(seasonality: np.ndarray, y=None):
    """Seasonal model with a multiplicative (log-link) mean and an AR(1) trend.

    The mean at time t is ``exp(mu + trend_t + season_effect[seasonality[t]])``
    and observations are ``Normal(mean, sigma)``. The trend starts at
    ``trend_0 ~ Normal(0, 1)`` and then follows
    ``trend_t ~ Normal(phi * trend_{t-1}, 1)``.

    Args:
        seasonality: Integer season label for each time step. Labels are used
            directly as indices into ``season_effect``.
        y: Observed series, or ``None`` to generate ``obs`` (e.g. for posterior
            predictive sampling). When ``None`` the series length is
            ``len(seasonality)``.

    Raises:
        ValueError: If ``seasonality`` is shorter than ``y``.
    """
    seasonality = np.asarray(seasonality)
    n_time = len(y) if y is not None else len(seasonality)
    if len(seasonality) < n_time:
        raise ValueError(
            f"seasonality has {len(seasonality)} entries but y has {n_time}"
        )
    # Size the effect vector by the largest label, not the number of distinct
    # labels: labels index season_effect directly and JAX does not raise on
    # out-of-range indices, so a gap in the labels would silently alias seasons.
    num_seasons = int(np.max(seasonality, initial=-1)) + 1

    mu = numpyro.sample('mu', dist.Normal(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(0.5))
    with numpyro.plate('seasonality_plate', num_seasons):
        season_effect = numpyro.sample('season_effect', dist.Normal(0, 1))

    # AR(1) trend: trend_0 is the first value; each later value is drawn
    # around phi times the previous one.
    phi = numpyro.sample('phi', dist.Normal(0., 1.))
    trend = numpyro.sample('trend_0', dist.Normal(0., 1.))

    trends = []
    for t in range(n_time):
        if t > 0:
            trend = numpyro.sample(f'trend_{t}', dist.Normal(phi * trend, 1.))
        trends.append(trend)

    mean_function = jnp.exp(
        mu + jnp.array(trends) + season_effect[seasonality[:n_time]]
    )

    with numpyro.plate('obs_plate', n_time):
        numpyro.sample('obs', dist.Normal(mean_function, sigma), obs=y)

def ar_model(seasonality: np.ndarray, n=1, y=None):
    """Seasonal model with a multiplicative (log-link) mean and an AR(n) trend.

    The mean at time t is ``exp(mu + trend_t + season_effect[seasonality[t]])``
    and observations are ``Normal(mean, sigma)``.

    Trend: the first ``n`` values are the sampled vector ``trend_init``. For
    ``t >= n``, ``trend_t ~ Normal(phi @ [trend_{t-n}, ..., trend_{t-1}], 1)``,
    so ``phi[-1]`` weights lag 1 and ``phi[0]`` weights lag ``n``. For
    ``n == 1`` this matches an AR(1) trend whose first value is ``trend_init[0]``.

    Args:
        seasonality: Integer season label for each time step. Labels are used
            directly as indices into ``season_effect``.
        n: AR order; must be >= 1.
        y: Observed series, or ``None`` to generate ``obs``. When ``None`` the
            series length is ``len(seasonality)``.

    Raises:
        ValueError: If ``n < 1`` or ``seasonality`` is shorter than ``y``.
    """
    if n < 1:
        raise ValueError(f"AR order n must be >= 1, got {n}")
    seasonality = np.asarray(seasonality)
    n_time = len(y) if y is not None else len(seasonality)
    if len(seasonality) < n_time:
        raise ValueError(
            f"seasonality has {len(seasonality)} entries but y has {n_time}"
        )
    # Size by the largest label: labels index season_effect directly and JAX
    # does not raise on out-of-range indices.
    num_seasons = int(np.max(seasonality, initial=-1)) + 1

    mu = numpyro.sample('mu', dist.Normal(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(0.5))

    with numpyro.plate('seasonality_plate', num_seasons):
        season_effect = numpyro.sample('season_effect', dist.Normal(0, 1))

    # AR(n) trend
    phi = numpyro.sample('phi', dist.Normal(0., 1.), sample_shape=(n,))
    trend_init = numpyro.sample('trend_init', dist.Normal(0., 1.), sample_shape=(n,))

    # Rolling window of the last n trend values, oldest first. It starts as
    # trend_init, which is exactly the trend used for t = 0..n-1, so the
    # recursion at t = n conditions on the values that actually enter the mean
    # (extra warm-up draws that never enter the window would break the chain).
    window = trend_init
    trends = []
    for t in range(n_time):
        if t < n:
            trend = trend_init[t]
        else:
            trend = numpyro.sample(
                f'trend_{t}', dist.Normal(jnp.dot(phi, window), 1.)
            )
            window = jnp.roll(window, -1).at[-1].set(trend)
        trends.append(trend)

    mean_function = jnp.exp(
        mu + jnp.array(trends) + season_effect[seasonality[:n_time]]
    )

    with numpyro.plate('obs_plate', n_time):
        numpyro.sample('obs', dist.Normal(mean_function, sigma), obs=y)


def ma_model(seasonality: np.ndarray, n=1, y=None):
    """Seasonal model with a multiplicative (log-link) mean and an MA(n) trend.

    The mean at time t is ``exp(mu + trend_t + season_effect[seasonality[t]])``
    and observations are ``Normal(mean, sigma)``.

    Trend: ``trend_t = error_t + theta @ [e_{t-n}, ..., e_{t-1}]`` with
    ``error_t ~ Normal(0, 1)``. ``theta[-1]`` weights lag 1. The ``n`` errors
    before t = 0 are the sampled vector ``error_init`` (oldest first), so the
    same recursion applies at every time step.

    Args:
        seasonality: Integer season label for each time step. Labels are used
            directly as indices into ``season_effect``.
        n: MA order; must be >= 1.
        y: Observed series, or ``None`` to generate ``obs``. When ``None`` the
            series length is ``len(seasonality)``.

    Raises:
        ValueError: If ``n < 1`` or ``seasonality`` is shorter than ``y``.
    """
    if n < 1:
        raise ValueError(f"MA order n must be >= 1, got {n}")
    seasonality = np.asarray(seasonality)
    n_time = len(y) if y is not None else len(seasonality)
    if len(seasonality) < n_time:
        raise ValueError(
            f"seasonality has {len(seasonality)} entries but y has {n_time}"
        )
    # Size by the largest label: labels index season_effect directly and JAX
    # does not raise on out-of-range indices.
    num_seasons = int(np.max(seasonality, initial=-1)) + 1

    mu = numpyro.sample('mu', dist.Normal(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(0.5))

    with numpyro.plate('seasonality_plate', num_seasons):
        season_effect = numpyro.sample('season_effect', dist.Normal(0, 1))

    # MA(n) trend
    theta = numpyro.sample('theta', dist.Normal(0., 1.), sample_shape=(n,))
    # Rolling window of the last n errors, oldest first. Every new error is
    # pushed in, so later lags use the errors actually drawn at earlier steps.
    error_window = numpyro.sample('error_init', dist.Normal(0., 1.), sample_shape=(n,))

    trends = []
    for t in range(n_time):
        error = numpyro.sample(f'error_{t}', dist.Normal(0., 1.))
        trends.append(jnp.dot(theta, error_window) + error)
        error_window = jnp.roll(error_window, -1).at[-1].set(error)

    mean_function = jnp.exp(
        mu + jnp.array(trends) + season_effect[seasonality[:n_time]]
    )

    with numpyro.plate('obs_plate', n_time):
        numpyro.sample('obs', dist.Normal(mean_function, sigma), obs=y)

def logistic_trend_model(seasonality: np.ndarray, y=None):
    """Seasonal model with a multiplicative (log-link) mean and a logistic trend.

    The mean at time t is ``exp(mu + trend_t + season_effect[seasonality[t]])``
    with ``trend_t = sigmoid(alpha + beta * t)`` (bounded in (0, 1)), and
    observations are ``Normal(mean, sigma)``.

    Args:
        seasonality: Integer season label for each time step. Labels are used
            directly as indices into ``season_effect``.
        y: Observed series, or ``None`` to generate ``obs``. When ``None`` the
            series length is ``len(seasonality)``.

    Raises:
        ValueError: If ``seasonality`` is shorter than ``y``.
    """
    seasonality = np.asarray(seasonality)
    n_time = len(y) if y is not None else len(seasonality)
    if len(seasonality) < n_time:
        raise ValueError(
            f"seasonality has {len(seasonality)} entries but y has {n_time}"
        )
    # Size by the largest label: labels index season_effect directly and JAX
    # does not raise on out-of-range indices.
    num_seasons = int(np.max(seasonality, initial=-1)) + 1

    mu = numpyro.sample('mu', dist.Normal(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(0.5))

    with numpyro.plate('seasonality_plate', num_seasons):
        season_effect = numpyro.sample('season_effect', dist.Normal(0, 1))

    # Logistic (sigmoid) non-linear trend.
    alpha = numpyro.sample('alpha', dist.Normal(0., 1.))
    beta = numpyro.sample('beta', dist.Normal(0., 1.))
    # jax.nn.sigmoid is numerically stable. A hand-written 1 / (1 + exp(-x))
    # overflows in float32 once -x exceeds ~88 (e.g. beta ~ -1 with t ~ 100),
    # which can make NUTS gradients NaN.
    trend = jax.nn.sigmoid(alpha + beta * jnp.arange(n_time))
    mean_function = jnp.exp(mu + trend + season_effect[seasonality[:n_time]])

    with numpyro.plate('obs_plate', n_time):
        numpyro.sample('obs', dist.Normal(mean_function, sigma), obs=y)


def train_model(
    model: Callable,
    seasonality: np.ndarray,
    y: np.ndarray,
    n: "int | None" = None,
    num_samples: int = 1000,
    num_warmup: int = 500,
    rng_seed: int = 0,
) -> MCMC:
    """Fit ``model`` to ``y`` with NUTS and return the finished MCMC object.

    Args:
        model: NumPyro model accepting ``seasonality`` and ``y`` keywords (and
            ``n`` when an order is given).
        seasonality: Season label per time step, forwarded to the model.
        y: Observed series, forwarded to the model.
        n: Model order (e.g. AR/MA order). Only forwarded when not ``None``,
            so models without an ``n`` parameter can be trained too.
        num_samples: Number of post-warmup samples to draw.
        num_warmup: Number of warmup (adaptation) iterations.
        rng_seed: Seed for the ``jax.random.PRNGKey`` used by the sampler.

    Returns:
        The ``MCMC`` object after ``run``; call ``get_samples()`` on it.
    """
    model_kwargs = {"seasonality": seasonality, "y": y}
    if n is not None:
        model_kwargs["n"] = n

    mcmc = MCMC(NUTS(model), num_samples=num_samples, num_warmup=num_warmup)
    mcmc.run(random.PRNGKey(rng_seed), **model_kwargs)
    return mcmc


# Length of the synthetic series.
n = 100

# Seeded generator so the synthetic data, and hence every fit below, is reproducible.
data_rng = np.random.default_rng(0)

# Simulate a weekly cycle: each observation is labelled with its day-of-week
# index (0-6).
seasonality = np.arange(n) % 7

# True parameters
true_mu = 0.5
true_sigma = 0.1
true_season_effect = data_rng.normal(0, 1, 7)  # effect for each day of the week

# Generate y with the same mean structure the models assume
# (mean = exp(mu + season_effect), Gaussian noise around it).
epsilon = data_rng.normal(0, true_sigma, n)
y = np.exp(true_mu + true_season_effect[seasonality]) + epsilon

# Example usage: fit each trend model once, all on the same data.
ar_order = 2
ma_order = 2
ar_mcmc = train_model(ar_model, seasonality, y, n=ar_order)
ma_mcmc = train_model(ma_model, seasonality, y, n=ma_order)
logistic_mcmc = train_model(logistic_trend_model, seasonality, y)

# Posterior samples
ar_samples = ar_mcmc.get_samples()
ma_samples = ma_mcmc.get_samples()
logistic_samples = logistic_mcmc.get_samples()


RuntimeError: ignored

In [None]:
# No variable named `mcmc` is defined by the earlier cells (only ar_mcmc /
# ma_mcmc / logistic_mcmc), and the posterior-predictive cell below replays
# `samples` through `multiplicative_model`, so fit that model here.
mcmc = train_model(multiplicative_model, seasonality, y)
samples = mcmc.get_samples()

In [None]:
import matplotlib.pyplot as plt

# Posterior draws of the seasonal effects, one column per season.
season_effect_samples = np.asarray(samples['season_effect'])
num_seasons = season_effect_samples.shape[-1]

# Posterior mean and standard deviation of each season's effect.
season_effect_mean = season_effect_samples.mean(axis=0)
season_effect_std = season_effect_samples.std(axis=0)

# Derive x positions and labels from the number of seasons actually sampled
# (hard-coding 12 months breaks when the data has e.g. 7 weekly seasons).
season_positions = np.arange(num_seasons)
if num_seasons == 12:
    season_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                     'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    season_name = 'month'
else:
    season_labels = [str(k) for k in season_positions]
    season_name = 'season'

# Plot the posterior mean +/- one standard deviation for each season.
plt.figure(figsize=(10, 5))
plt.errorbar(season_positions, season_effect_mean, yerr=season_effect_std, fmt='o')
plt.xticks(season_positions, season_labels)
plt.xlabel(season_name.capitalize())
plt.ylabel('Season effect')
plt.title(f'Season effect for each {season_name}')
plt.grid(True)
plt.show()


In [None]:
from numpyro.infer import MCMC, NUTS, Predictive

# Sample the posterior predictive distribution. `samples` must come from a fit
# of multiplicative_model, since Predictive replays those site values through it.
predictive = Predictive(multiplicative_model, samples)
rng_key = random.PRNGKey(1)  # use a different random seed
predictive_samples = predictive(rng_key, seasonality)

obs_draws = np.asarray(predictive_samples['obs'])

# Predicted value (posterior predictive mean).
y_pred = obs_draws.mean(axis=0)

# 95% posterior predictive interval (a Bayesian interval, not a frequentist
# confidence interval).
y_lower, y_upper = np.percentile(obs_draws, [2.5, 97.5], axis=0)

# Plot
plt.figure(figsize=(10, 5))
plt.plot(y, label='True')
plt.plot(y_pred, label='Predicted')
plt.fill_between(range(len(y)), y_lower, y_upper, alpha=0.2,
                 label='95% predictive interval')
plt.legend()
plt.show()



In [None]:
# Notebook display of the raw posterior-predictive draws, keyed by sample-site
# name (e.g. 'obs'). Left as a bare trailing expression so the notebook renders it.
predictive_samples

In [None]:
import jax

# Report which JAX release this runtime is actually using.
print(f"{jax.__version__}")
