### **ガウス分布の精度パラメタ$\lambda$の推定(平均$\mu$は既知)**
- 尤度関数(ガウス分布):<br>
$p(x|\mu,\lambda)=\mathcal{N}(x|\mu,\lambda^{-1})$
- 事前分布(ガンマ分布):<br>
$p(\lambda)=\mathrm{Gam}(\lambda|a,b)$
- 観測データ(ガウス分布からの発生値に観測ノイズ$\epsilon_n$が加わる):<br>
$\mathcal{D}=\left\{x_1,x_2,x_3,...,x_N\right\},\ x_n=y_n+\epsilon_n,\ y_n\sim\mathcal{N}(y|\mu,\lambda^{-1}),\ \epsilon_n\sim\mathcal{N}(\epsilon|0,\sigma^2)$

In [39]:
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.animation import ArtistAnimation
from IPython.display import HTML
from tqdm import tqdm
from scipy.stats import norm
from math import gamma
from scipy.special import loggamma

In [40]:
def t_pdf(x,mean,lam,nu):
    """Student-t density St(x | mean, lam, nu) with precision parameter ``lam``.

    Evaluates, elementwise over ``x``,

        Gamma((nu+1)/2) / Gamma(nu/2) * sqrt(lam / (pi*nu))
            * (1 + lam/nu * (x - mean)**2) ** (-(nu+1)/2)

    The Gamma-function ratio is formed as a difference of ``loggamma`` values
    and exponentiated once, so it stays finite even when ``nu`` is large
    enough for Gamma(nu/2) itself to overflow.
    """
    half_nu_plus_one = (nu + 1)/2
    gamma_ratio = np.exp(loggamma(half_nu_plus_one) - loggamma(nu/2))
    normaliser = (lam/(np.pi*nu))**0.5

    sq_dist = (x - mean)**2
    kernel = (1 + (lam/nu)*sq_dist)**(-half_nu_plus_one)

    return gamma_ratio*normaliser*kernel


def KLdiv(x,p,q):
    """Approximate the KL divergence D(p || q) of two densities sampled on a grid.

    The integral  int p(x) * log(p(x)/q(x)) dx  is approximated by a
    rectangle-rule sum: the elementwise terms are summed and multiplied by the
    grid step ``x[1] - x[0]``, so ``x`` is assumed to be evenly spaced.

    Parameters
    ----------
    x : array_like
        Grid points at which ``p`` and ``q`` were evaluated (at least two).
    p, q : array_like
        Density values of the two distributions on ``x``.

    Returns
    -------
    float
        Approximation of D(p || q). Grid points where ``p == 0`` contribute 0
        (the 0*log 0 = 0 convention); a point with ``p > 0`` and ``q == 0``
        makes the result ``inf``.

    Raises
    ------
    ValueError
        If ``x`` has fewer than two points, so no grid step can be formed.
    """
    from scipy.special import rel_entr

    if len(x) < 2:
        raise ValueError("x must contain at least two grid points")
    width = x[1] - x[0]

    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)

    # rel_entr(p, q) == p*log(p/q) elementwise, but yields 0 where p == 0
    # instead of the nan that 0*log(0) would produce.
    return np.sum(rel_entr(p, q))*width

In [44]:
# Ground-truth data-generating Gaussian: its mean and its precision
# ("lambda" = 1 / variance).
true_dist = {
    "mean": 5,
    "lambda": 1,
}

# Observation-noise level added to every sample. NOTE(review): despite the
# name, this is passed as `scale` (a standard deviation) to np.random.normal,
# so it acts as sigma, not as a variance.
noise_var = 0.2

# Gamma prior on the precision. "shape" is the shape a; "lambda" is the rate b
# (its sum with the data term is inverted and passed as `scale` to
# np.random.gamma when sampling the posterior).
prior = {
    "shape": 0.1,
    "lambda": 0.001,
}

In [45]:
# Animate, frame by frame, how the gamma posterior over the precision lambda,
# the Student-t predictive distribution and D(p_true || p_pred) evolve as the
# number of noisy observations grows from 1 to 10**6 (log-spaced).

artist = []
fig, ax = plt.subplots(2,2,figsize=(10,8))
plt.subplots_adjust(wspace=0.4, hspace=0.6)

frame_num = 40
N_list = 10**np.linspace(0,6,frame_num)

# true_dist["lambda"] is a precision, so the standard deviation is
# lambda**(-1/2), not lambda**(-1) (the two only agree when lambda == 1).
true_sd = true_dist["lambda"]**(-0.5)

# x-range of the posterior histogram over lambda (also used as its xlim)
post_range = (0, 1.5)

# Top left: histogram of the observed (noisy) data, shown over mean +/- 3 sd
ax[0,0].grid()
ax[0,0].set_title(r"observed data $\mathcal{D}$")
ax[0,0].set_xlabel("$x$",fontsize=15)
Ucb = true_dist["mean"] + 3*true_sd
Lcb = true_dist["mean"] - 3*true_sd
xlim = ax[0,0].set_xlim(Lcb,Ucb)

# Top right: samples drawn from the gamma posterior over the precision
ax[0,1].grid()
ax[0,1].set_title(r"posterior $p(\lambda|\mathcal{D})$")
ax[0,1].set_xlabel(r"$\lambda$",fontsize=15)
ax[0,1].set_xlim(post_range)

# Bottom left: predictive density vs. true density
ax[1,0].grid()
ax[1,0].set_title(r"predictive distribution $\int p(x_*|\lambda)p(\lambda|\mathcal{D})d\lambda$")
ax[1,0].set_xlim(xlim)
ax[1,0].set_ylabel("$p(x_*)$",fontsize=15)
ax[1,0].set_xlabel("$x_*$",fontsize=15)

# Bottom right: KL divergence as a function of the number of observations
ax[1,1].grid()
ax[1,1].set_title("KL-div between true and pred")
ax[1,1].set_xlim(min(N_list),max(N_list))
ax[1,1].set_xlabel("#data",fontsize=15)
ax[1,1].set_xscale("log")
ax[1,1].set_ylabel("KL-divergence $D(p_{true}||p_{pred})$")

# The evaluation grid and the true density do not depend on N: compute once.
x = np.linspace(xlim[0],xlim[1],100)
true = norm.pdf(x,true_dist["mean"],true_sd)

est = []
for i, N in enumerate(tqdm(N_list)):

    # observe data: samples from the true Gaussian plus observation noise.
    # Re-seeding each frame makes every frame reproducible on its own.
    np.random.seed(0)
    data = np.random.normal(loc=true_dist["mean"],
                            scale=true_sd,
                            size=int(N))

    noise = np.random.normal(loc=0,
                             scale=noise_var,
                             size=int(N))
    data += noise

    _,_,data_plot = ax[0,0].hist(data,bins=100,color="black")

    # likelihood: Gaussian whose mean is treated as known
    likelihood = {"mean":true_dist["mean"]}

    # posterior - gamma distribution; "lambda" holds the rate parameter
    posterior = {}
    posterior["shape"] = int(N)*0.5 + prior["shape"]
    posterior["lambda"] = 0.5*np.sum((data-likelihood["mean"])**2)+prior["lambda"]

    # visualize the posterior via samples (numpy's gamma takes scale = 1/rate)
    posterior_obs = np.random.gamma(shape=posterior["shape"],
                                    scale=posterior["lambda"]**(-1),
                                    size=10000)

    _,_,post_plot = ax[0,1].hist(posterior_obs,bins=100,color="red",density=True,range=post_range)

    # predictive - Student t distribution St(x | mean, a/b, 2a)
    pred_dist = {}
    pred_dist["mean"] = likelihood["mean"]
    pred_dist["lambda"] = posterior["shape"]/posterior["lambda"]
    pred_dist["nu"] = 2*posterior["shape"]

    pred = t_pdf(x,pred_dist["mean"],pred_dist["lambda"],pred_dist["nu"])

    hikaku_plot = ax[1,0].plot(x,pred,color="blue",label="predict")
    hikaku_plot += ax[1,0].plot(x,true,color="black",label="true")
    if i == 0:
        ax[1,0].legend()

    # error curve: KL(true || pred) accumulated over all frames so far
    error = KLdiv(x,true,pred)
    est.append([N,error])
    est_load = np.array(est)
    error_plot = ax[1,1].plot(est_load[:,0],est_load[:,1],color="black")

    frame = data_plot+post_plot+hikaku_plot+error_plot
    artist.append(frame)

# Close the static figure so only the animation is rendered below the cell.
plt.close(fig)
anim = ArtistAnimation(fig, artist, interval=100)
HTML(anim.to_jshtml())
    

100%|██████████| 40/40 [00:13<00:00,  2.99it/s]


In [46]:
# Write the animation to a GIF with Matplotlib's built-in Pillow writer. The
# "imagemagick" writer requires an external ImageMagick install and otherwise
# falls back to Pillow with a warning.
anim.save('gaussian_var_estimation_noise02.gif', writer="pillow")

MovieWriter imagemagick unavailable; trying to use <class 'matplotlib.animation.PillowWriter'> instead.
