In [1]:
# pip install pymc arviz numpy
import numpy as np
import pymc as pm
import arviz as az

# -------------------------
# Data (synthetic example)
# X: [N, D] features, y: [N] binary labels in {0, 1}
# -------------------------
rng = np.random.default_rng(42)

# Ground-truth parameters used only to simulate the demo labels.
# D is derived from w_true so the feature dimension and the true weight
# vector cannot drift out of sync when one of them is edited.
w_true = np.array([0.8, -1.2, 0.5])
b_true = -0.3
N = 200
D = w_true.shape[0]

# Features are the first draw from rng (same draw order as before, so the
# generated data is unchanged).
X = rng.normal(size=(N, D))

# Labels from the logistic model: y_i ~ Bernoulli(sigmoid(b_true + x_i . w_true)).
p_true = 1 / (1 + np.exp(-(b_true + X @ w_true)))
y = rng.binomial(n=1, p=p_true, size=N)

# -------------------------
# Bayesian Logistic Regression
# -------------------------
with pm.Model() as model:
    # Priors: Normal(0, 1) on the weights and the intercept. They also act as
    # an L2-style regularizer on the coefficients.
    w = pm.Normal("w", mu=0.0, sigma=1.0, shape=D)
    b = pm.Normal("b", mu=0.0, sigma=1.0)

    # Linear predictor (log-odds): eta = b + X @ w
    eta = b + pm.math.dot(X, w)

    # Success probability through the logistic (sigmoid) link. It is kept as a
    # Deterministic so idata.posterior["p"] remains available downstream.
    # Note that it records N values for every posterior draw.
    p = pm.Deterministic("p", pm.math.sigmoid(eta))

    # Likelihood y ~ Bernoulli(sigmoid(eta)), parameterized on the logit scale.
    # Passing logit_p lets PyMC compute the log-likelihood directly from the
    # log-odds. This avoids log(0) when sigmoid(eta) saturates to exactly 0 or 1.
    y_obs = pm.Bernoulli("y_obs", logit_p=eta, observed=y)

    # Posterior sampling with NUTS (HMC).
    # NOTE(review): the recorded run of this cell took ~2300 s, which is far
    # more than a 200x3 logistic regression should need. This is possibly
    # PyTensor running without a C compiler; confirm in the environment.
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
    )

# Posterior summary of the intercept and coefficients
print(az.summary(idata, var_names=["b", "w"], round_to=3))

# Plug-in prediction: the sigmoid evaluated at the posterior-mean coefficients,
# i.e. sigmoid(E[b] + X @ E[w]). This ignores parameter uncertainty.
w_post = idata.posterior["w"].mean(("chain", "draw")).values
b_post = idata.posterior["b"].mean(("chain", "draw")).values
p_hat = 1 / (1 + np.exp(-(b_post + X @ w_post)))
print("p_hat (first 5):", p_hat[:5])

# Bayesian posterior-mean probability: average the per-draw probabilities
# recorded by the "p" Deterministic, i.e. E[sigmoid(b + X @ w)].
# Because sigmoid is nonlinear, E[sigmoid(eta)] != sigmoid(E[eta]). This
# estimate propagates parameter uncertainty and is typically pulled toward 0.5
# relative to p_hat.
p_bayes = idata.posterior["p"].mean(("chain", "draw")).values
print("p_bayes (first 5):", p_bayes[:5])


Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [w, b]


Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 2315 seconds.


       mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  \
b    -0.373  0.168  -0.683   -0.054      0.002    0.002  7619.778  6196.065   
w[0]  0.967  0.188   0.610    1.314      0.002    0.002  7387.036  5346.261   
w[1] -1.114  0.205  -1.499   -0.731      0.002    0.002  7494.160  6473.567   
w[2]  0.296  0.185  -0.051    0.635      0.002    0.002  6777.328  5952.392   

      r_hat  
b     1.001  
w[0]  1.002  
w[1]  1.000  
w[2]  1.001  
p_hat (first 5): [0.7863313  0.91087033 0.52452685 0.12491824 0.19367207]
