In [1]:
%matplotlib nbagg
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
import matplotlib.animation as animation
import scipy.stats as stats
from tqdm import trange

In [2]:
def counting(datas, size, out=None):
    """Tally how many entries of `datas` equal each label in ``0 .. size-1``.

    Parameters
    ----------
    datas : numpy array or None
        Labels to tally; each label ``i`` is counted as ``(datas == i).sum()``.
        When ``None``, the all-zero count array is returned.
    size : int
        Number of labels to count.
    out : array, optional
        Buffer that is reset to zero and then filled in place. When omitted,
        a new integer array of length `size` is allocated.

    Returns
    -------
    The count array (`out` itself when it was supplied).
    """
    counts = np.zeros(size, dtype=int) if out is None else out
    if out is not None:
        # A caller-supplied buffer may hold stale counts from a previous call.
        counts[:] = 0
    if datas is not None:
        for label in range(size):
            counts[label] = (datas == label).sum()
    return counts

def rad2deg(rad):
    """Convert an angle (scalar or array) from radians to degrees."""
    half_turns_scaled = rad * 180.0
    return half_turns_scaled / np.pi

def plot_mean(axes, mean, color=None, mark="o"):
    """Scatter one point at ``(mean[0], mean[1])`` on `axes`.

    `mark` and `color` are forwarded as the scatter ``marker`` and ``c``
    arguments. Returns whatever ``axes.scatter`` returns.
    """
    x, y = mean[0], mean[1]
    style = dict(marker=mark, c=color, zorder=1)
    return axes.scatter(x, y, **style)

def plot_covariance(axes, mean, covariance, scale=3.0, color=None):
    """Draw the `scale`-sigma ellipse of a 2-D covariance matrix on `axes`.

    Parameters
    ----------
    axes : object with ``add_artist``
        Axes that receives the ellipse.
    mean : sequence
        Ellipse centre; only ``mean[0]`` and ``mean[1]`` are used.
    covariance : (2, 2) array-like
        Covariance matrix. It is symmetrised before decomposition.
    scale : float
        Number of standard deviations covered by each semi-axis.
    color : optional
        Colour passed to the ellipse patch.

    Returns
    -------
    The ``patches.Ellipse`` that was added to `axes`.
    """
    # Symmetrise and use the symmetric eigen-solver: it always returns real
    # eigenpairs, whereas the general solver can return complex values when
    # round-off makes the matrix slightly asymmetric.
    cov = np.asarray(covariance, dtype=float)
    la, v = np.linalg.eigh(0.5 * (cov + cov.T))
    # Round-off can push a zero eigenvalue slightly negative; clamp it so the
    # square root stays finite.
    std = np.sqrt(np.clip(la, 0.0, None))
    # The ellipse "width" lies along the first eigenvector, so rotate by the
    # angle of that eigenvector (column 0 of v).
    angle = rad2deg(np.arctan2(v[1,0], v[0,0]))
    e = patches.Ellipse((mean[0], mean[1]), 2*std[0]*scale, 2*std[1]*scale, angle=angle, linewidth=1, fill=False, color=color, zorder=2)
    axes.add_artist(e)
    return e

def plot_datas(axes, datas, color=None, mark="+"):
    """Scatter the first two columns of the 2-D array `datas` on `axes`.

    `mark` and `color` are forwarded as the scatter ``marker`` and ``c``
    arguments. Returns whatever ``axes.scatter`` returns.
    """
    xs = datas[:, 0]
    ys = datas[:, 1]
    return axes.scatter(xs, ys, marker=mark, c=color, zorder=1)

In [3]:
#%% Initialize truth parameters.
# Problem sizes.
K_truth = 2        # number of mixture components that generate the data
D = 2              # dimensionality of each observation
N = 300            # number of observations
ITERATION = 100    # number of Gibbs sweeps

# Ground-truth mixture: weights, component means, and identity covariances.
pi_truth = np.array([0.7, 0.3])
mean_truth = np.array([[2.5, 0.0], [-2.5, 0.0]])
cov_truth = np.tile(np.identity(D), (K_truth, 1, 1))

In [4]:
#%% Plot color initialization.
# One matplotlib colour name per component index.
colors = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:cyan",
]

In [5]:
#%% Generate datas.
# Seed NumPy's global generator so that re-running this cell (or the whole
# notebook) reproduces the same data set. scipy.stats draws from the same
# global generator when no random_state is passed, so this also fixes the
# multivariate-normal draws below.
SEED = 0
np.random.seed(SEED)

# Draw the latent component of every sample, then the sample itself from
# that component's Gaussian.
datas = np.empty((N, D))
hidden_state_truth = np.random.choice(K_truth, size=N, p=pi_truth)
for n in range(N):
    k = hidden_state_truth[n]
    datas[n] = stats.multivariate_normal.rvs(mean=mean_truth[k], cov=cov_truth[k])

In [6]:
#%% Plot datas with truth label.
ax = plt.gca()
for k in range(K_truth):
    # Samples, mean marker and covariance ellipse of one true component,
    # all drawn in that component's colour.
    component_color = colors[k]
    plot_datas(ax, datas[hidden_state_truth == k], color=component_color)
    plot_mean(ax, mean_truth[k], color=component_color)
    plot_covariance(ax, mean_truth[k], cov_truth[k], color=component_color)
ax.set_title("Truth datas.")

# Remember the auto-scaled limits (later cells read xlim / ylim), then widen
# the view by 20% of the data extent on every side.
xlim, ylim = ax.get_xlim(), ax.get_ylim()
width = xlim[1] - xlim[0]
height = ylim[1] - ylim[0]
ax.set_xlim(xlim[0] - 0.2 * width, xlim[1] + 0.2 * width)
ax.set_ylim(ylim[0] - 0.2 * height, ylim[1] + 0.2 * height)
plt.show()

<IPython.core.display.Javascript object>

In [7]:
#%% Initialize hyper parameters.
K = 2                                            # number of components in the fitted model
alpha_0 = np.ones(K)                             # Dirichlet concentration for pi
mu_0 = np.zeros((K, D))                          # prior mean of each mu_k
Sigma_0 = np.tile(np.identity(D), (K, 1, 1))     # prior covariance of each mu_k
nu_0 = np.full(K, float(D))                      # inverse-Wishart degrees of freedom
Lambda_0 = np.tile(np.identity(D), (K, 1, 1))    # inverse-Wishart scale matrices

In [8]:
#%% Initialize local parameters.
# Starting state of the chain: every parameter is drawn from its prior.
pi = stats.dirichlet.rvs(alpha_0)[0]
mu, Sigma = np.empty((K, D)), np.empty((K, D, D))
for k in range(K):
    prior_mean, prior_cov = mu_0[k], Sigma_0[k]
    mu[k] = stats.multivariate_normal.rvs(mean=prior_mean, cov=prior_cov)
    prior_df, prior_scale = nu_0[k], Lambda_0[k]
    Sigma[k] = stats.invwishart.rvs(df=prior_df, scale=prior_scale)


In [9]:
# Scratch buffers, allocated once and overwritten on every Gibbs sweep.
# Per-sample buffers:
p_yi = np.empty((N, K))                           # per-sample, per-component weights
hidden_state = np.zeros((N,), dtype=int)          # current label of every sample
# Per-component buffers:
k_vec = np.zeros((K,))                            # label probabilities of one sample
hidden_state_count = np.zeros((K,), dtype=int)    # samples currently assigned to each label
# 1 x D row of zeros used as a placeholder when a component has no samples.
D_zeros = np.matrix(np.zeros((D,)))

In [10]:
#%% Define history objects.
# Fresh figure for the animation, framed like the truth plot: the saved
# xlim / ylim widened by 20% of their extent on every side.
fig = plt.figure()
ax = fig.gca()
width, height = xlim[1] - xlim[0], ylim[1] - ylim[0]
ax.set_xlim(xlim[0] - 0.2 * width, xlim[1] + 0.2 * width)
ax.set_ylim(ylim[0] - 0.2 * height, ylim[1] + 0.2 * height)
# One list of artists per Gibbs sweep; each entry becomes one animation frame.
ims = list()

<IPython.core.display.Javascript object>

In [11]:
#%% Start Gibbs sampling.
# Every sweep resamples, in this order: the labels given (pi, mu, Sigma),
# pi given the label counts, then Sigma_k and mu_k for each component given
# the samples currently assigned to it. The artists drawn for the sweep are
# collected as one animation frame.
for t in trange(ITERATION):

    # Calculate phase: unnormalised label weights pi_k * N(x_n | mu_k, Sigma_k).
    for k in range(K):
        p_yi[:,k] = stats.multivariate_normal.pdf(datas, mean=mu[k], cov=Sigma[k])
    p_yi *= pi

    # Resampling labels.
    for n in range(N):
        k_vec[:] = p_yi[n]
        k_vec /= k_vec.sum()
        hidden_state[n] = np.random.choice(K, p=k_vec)
    counting(hidden_state, K, out=hidden_state_count)

    # Resampling parameters phase.
    # Resampling pi.
    alpha_hat = alpha_0 + hidden_state_count
    # rvs returns an array of shape (1, K); take the single draw explicitly.
    pi[:] = stats.dirichlet.rvs(alpha_hat)[0]

    # Resampling mu_k and Sigma_k
    for k in range(K):
        # When no sample is assigned to component k this slice has shape
        # (0, D): its column sum and scatter matrix are then exactly zero, so
        # the draws below use only the prior terms. (Substituting a fake
        # all-zero observation would add mu_k mu_k^T to the scatter.)
        data_k = datas[hidden_state == k]
        M_k = hidden_state_count[k]
        sum_data_k = data_k.sum(axis=0)

        # Resampling Sigma_k (scatter is taken around the current mu_k).
        centered = data_k - mu[k]
        Lambda_hat = centered.T.dot(centered) + Lambda_0[k]
        nu_hat = M_k + nu_0[k]
        Sigma[k] = stats.invwishart.rvs(df=nu_hat, scale=Lambda_hat)

        # Resampling mu_k (uses the Sigma_k drawn just above).
        Sigma_k_inv = np.linalg.inv(Sigma[k])
        Sigma_0_inv = np.linalg.inv(Sigma_0[k])
        Sigma_hat = np.linalg.inv(M_k * Sigma_k_inv + Sigma_0_inv)
        mu_hat = (sum_data_k.dot(Sigma_k_inv) + mu_0[k].dot(Sigma_0_inv)).dot(Sigma_hat)
        mu[k] = stats.multivariate_normal.rvs(mean=mu_hat, cov=Sigma_hat)

    # Save animation phase: every artist created in this sweep forms one frame.
    tmp_im = []
    tmp_im.append(ax.annotate("ITERATION : %2d" % (t,), xy=(0.6,0.1), xycoords='axes fraction', fontsize=15, bbox=dict(boxstyle="round", fc="white"), zorder=3))
    for k in range(K):
        # Skip the scatter for components that currently own no samples.
        if hidden_state_count[k] != 0:
            data_k = datas[hidden_state == k]
            tmp_im.append(plot_datas(ax, data_k, color=colors[k]))
        tmp_im.append(plot_mean(ax, mu[k], color=colors[k]))
        tmp_im.append(plot_covariance(ax, mu[k], Sigma[k], color=colors[k]))
    ims.append(tmp_im)

100%|██████████| 100/100 [00:05<00:00, 17.28it/s]


In [12]:
#%% Plot phase.
# Play back the recorded frames (one list of artists per Gibbs sweep).
anim = animation.ArtistAnimation(
    fig,
    ims,
    interval=100,
    repeat_delay=0,
    blit=True,
)
plt.show()