### **ガウス分布の平均パラメタ$\mu$の推定**
- 尤度関数(ガウス分布):<br>
$p(x|\mu,\lambda)=\mathcal{N}(x|\mu,\lambda^{-1})$
- 事前分布(ガウス分布):<br>
$p(\mu)=\mathcal{N}(\mu|m,\lambda_{\mu}^{-1})$
- 観測データ(ガウス分布から発生):<br>
$\mathcal{D}=\left\{x_1,x_2,x_3,...,x_N\right\},\ x_n\sim\mathcal{N}(x|\mu,\lambda^{-1})$

In [21]:
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.animation import ArtistAnimation
from IPython.display import HTML
from tqdm import tqdm
from scipy.stats import norm

In [22]:
def gaussian_KLdiv(dist1,dist2):
    """KL divergence D(dist1 || dist2) between two univariate Gaussians.

    Each distribution is a dict with keys ``"mean"`` and ``"lambda"``, where
    ``"lambda"`` is the precision (inverse variance); the standard deviation is
    therefore ``lambda ** (-1/2)``.

    Closed form:
        D(N1 || N2) = log(sig2/sig1) + (sig1^2 + (m1 - m2)^2) / (2 sig2^2) - 1/2
    which is >= 0 and equals 0 only when the two Gaussians coincide.

    Parameters
    ----------
    dist1 : dict
        First argument of the divergence (the distribution the expectation is taken under).
    dist2 : dict
        Second argument of the divergence.

    Returns
    -------
    float
        The KL divergence (an ndarray if the dict entries are arrays).
    """
    m1 = dist1["mean"]
    sig1 = dist1["lambda"]**(-1/2)

    m2 = dist2["mean"]
    sig2 = dist2["lambda"]**(-1/2)

    # The log term is log(sig2/sig1); inverting the ratio yields negative "divergences".
    KLdiv = np.log(sig2/sig1) + ((sig1**2) + (m1-m2)**2)/(2*(sig2**2)) - (1/2)

    return KLdiv

In [25]:
# Ground-truth data-generating Gaussian. "lambda" is a precision (1 / variance).
true_dist = {"mean": 5, "lambda": 1}

# Level of the zero-mean noise added to the synthetic observations.
# NOTE(review): named as a variance; the sampling code should pass sqrt(noise_var)
# as numpy's `scale` (a standard deviation) — confirm at the use site.
noise_var = 4

# Prior over mu: centred at 4 with precision 1e-4 (std 100), i.e. very broad.
prior = {"mean": 4, "lambda": 0.0001}

In [30]:
#%matplotlib nbagg

# Animate how the posterior over mu, the predictive distribution, and the KL
# divergence between the true and predictive distributions evolve as the number
# of observations N grows (log-spaced from 10^0 to 10^4).
# Uses `true_dist`, `noise_var`, `prior` from the configuration cell and
# `gaussian_KLdiv` defined above. Every "lambda" is a precision (1/variance), so
# the standard deviation handed to numpy/scipy as `scale` is lambda**(-1/2).

artist = []
fig, ax = plt.subplots(2,2,figsize=(10,8))
plt.subplots_adjust(wspace=0.4, hspace=0.6)

frame_num = 40
N_list = 10**np.linspace(0,4,frame_num)

true_std = true_dist["lambda"]**(-1/2)
# noise_var is a variance, but numpy's `scale` argument is a standard deviation.
noise_std = np.sqrt(noise_var)

# Top-left: histogram of the observed data
ax[0,0].grid()
ax[0,0].set_title(r"observed data $\mathcal{D}$")
ax[0,0].set_xlabel(r"$x$",fontsize=15)
# Shared plot window: true mean +/- 3 standard deviations.
Ucb = true_dist["mean"] + 3*true_std
Lcb = true_dist["mean"] - 3*true_std
xlim = ax[0,0].set_xlim(Lcb,Ucb)

# Top-right: samples from the posterior p(mu | D)
ax[0,1].grid()
ax[0,1].set_title(r"posterior $p(\mu|\mathcal{D})$")
ax[0,1].set_xlabel(r"$\mu$",fontsize=15)
ax[0,1].set_xlim(xlim)

# Bottom-left: predictive distribution vs. true distribution
ax[1,0].grid()
ax[1,0].set_title(r"predictive distribution $\int p(x_*|\mu)p(\mu|\mathcal{D})d\mu$")
ax[1,0].set_xlim(xlim)
ax[1,0].set_ylabel(r"$p(x_*)$",fontsize=15)
ax[1,0].set_xlabel(r"$x_*$",fontsize=15)

# Bottom-right: KL divergence against the number of data points
ax[1,1].grid()
ax[1,1].set_title("KL-div between true and pred")
ax[1,1].set_xlim(min(N_list),max(N_list))
ax[1,1].set_xlabel("#data",fontsize=15)
ax[1,1].set_xscale("log")
ax[1,1].set_ylabel(r"KL-divergence $D(p_{true}||p_{pred})$")

# The evaluation grid and the true density do not depend on N: compute them once.
x = np.linspace(xlim[0],xlim[1],100)
true = norm.pdf(x,true_dist["mean"],true_std)

est = []
for N in tqdm(N_list):

    # observe data (re-seeded every frame so the animation is reproducible)
    np.random.seed(0)
    data = np.random.normal(loc=true_dist["mean"],
                            scale=true_std,
                            size=int(N))

    noise = np.random.normal(loc=0,
                             scale=noise_std,
                             size=int(N))
    data += noise
    # N is a non-integer point on a log grid; the Bayesian update must use the
    # number of samples actually drawn.
    n_obs = len(data)

    _,_,data_plot = ax[0,0].hist(data,bins=100,color="black")

    # likelihood
    # NOTE(review): the observation precision is treated as known and equal to
    # true_dist["lambda"], i.e. the added noise is ignored by the model. The
    # posterior mean still converges (the noise has zero mean), but the
    # posterior spread is narrower than the noisy data justify.
    likelihood = {"mean":data.mean(),
                  "lambda":true_dist["lambda"]}

    # posterior distribution (conjugate Gaussian update of the mean)
    posterior = {}
    posterior["lambda"] = prior["lambda"] + n_obs*likelihood["lambda"]
    posterior["mean"] = (likelihood["lambda"]*data.sum()+prior["lambda"]*prior["mean"])/posterior["lambda"]

    # visualize the posterior through samples
    posterior_obs = np.random.normal(loc=posterior["mean"],
                                    scale=posterior["lambda"]**(-1/2),
                                    size=10000)

    _,_,post_plot = ax[0,1].hist(posterior_obs,bins=100,color="red",density=True,range=(Lcb,Ucb))

    # predictive distribution: N(x | mu_post, 1/lambda + 1/lambda_post)
    pred_dist = {}
    pred_dist["mean"] = posterior["mean"]
    pred_dist["lambda"] = ((likelihood["lambda"]**(-1))+(posterior["lambda"]**(-1)))**(-1)

    pred = norm.pdf(x,pred_dist["mean"],pred_dist["lambda"]**(-1/2))

    hikaku_plot = ax[1,0].plot(x,pred,color="blue",label="predict")
    hikaku_plot += ax[1,0].plot(x,true,color="black",label="true")
    if N == min(N_list):
        ax[1,0].legend()

    # error load: D(p_true || p_pred), matching the y-axis label
    KLdiv = gaussian_KLdiv(true_dist,pred_dist)
    est.append([n_obs,KLdiv])
    est_load = np.array(est)
    error_plot = ax[1,1].plot(est_load[:,0],est_load[:,1],color="black")

    # hist() returns a BarContainer (a tuple subclass) in recent matplotlib, so
    # `tuple + list` would fail; unpack everything into one list of artists.
    frame = [*data_plot, *post_plot, *error_plot, *hikaku_plot]
    artist.append(frame)

plt.close(fig)  # suppress the static figure; only the animation is shown
anim = ArtistAnimation(fig, artist, interval=100)
HTML(anim.to_jshtml())
    

  return n/db/n.sum(), bin_edges
100%|██████████| 40/40 [00:18<00:00,  2.18it/s]


In [None]:
# Export the animation as a GIF. The recorded run showed ImageMagick was
# unavailable and matplotlib fell back to PillowWriter with a warning, so
# request the Pillow writer directly.
anim.save('gaussian_mean_estimation.gif', writer="pillow")

MovieWriter imagemagick unavailable; trying to use <class 'matplotlib.animation.PillowWriter'> instead.
