In [1]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.tri as tri
import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import Categorical, MultivariateNormalDiag, Dirichlet, OneHotCategorical, MultivariateNormalFullCovariance
from functools import reduce
import time
import distributions
from build_dataset import build_dataset_2dim_Kclass_gmm
from visualize_tools import circle, ellipse, covariance_ellipse, change_aspect_ratio

In [2]:
# Font property dictionary and the global matplotlib theme for the notebook.
font = dict(family='serif', color='black', weight='normal', size=16)
plt.style.use('ggplot')

In [8]:
# Ground-truth parameters of the synthetic 2-D mixture: N samples, K components.
N = 3000
K = 3
mean_true = [[-3.0, -3.0], [-1.0, 6.0], [5.0, -1.0]]
# Every component uses the same full covariance matrix.
covariance_true = [[[2.0, 0.4], [0.4, 1.0]] for _component in range(K)]
# Equal mixing weights.
mix_true = np.full(K, 1.0 / 3.0)
x_data, label, mix = build_dataset_2dim_Kclass_gmm(
    N, K, mean_true, covariance_true, mix_true
)

In [9]:
%matplotlib nbagg
fig = plt.figure()
ax = plt.axes([0.1,0.1,0.8,0.8])
ax.set_xlim(-10.0,10.0)
ax.set_ylim(-10.0,10.0)
plt.gca().set_aspect('equal', adjustable='box')
for k in range(K):
    c1,c2,ca=covariance_ellipse(covariance_true[k])
    ex,ey=ellipse(c1,c2,ca)
    plt.plot(ex+mean_true[k][0],ey+mean_true[k][1],color="black")
color=["#FF0000","#00FF00","#0000FF"]
for n,x in enumerate(x_data):
    plt.scatter(x[0],x[1],s=1.0,color=color[label[n]])
plt.show()

<IPython.core.display.Javascript object>

In [7]:
def gmm_bbvi_z_pi(N,mu,cov,x_data,label,n_iter):
    """Infer q(z) and q(pi) for a Gaussian mixture with fixed components using ed.KLqp.

    The component means and covariances are treated as known constants; only
    the mixing weights ``pi`` and the per-sample one-hot assignments ``z`` are
    latent.

    Args:
        N: number of samples; must equal ``len(x_data)``.
        mu: component means, array-like of shape (K, D).
        cov: component covariance matrices, array-like of shape (K, D, D).
        x_data: observations, array-like of shape (N, D).
        label: accepted for interface compatibility; not used by the inference.
        n_iter: number of KLqp update steps.

    Returns:
        (loss, variational_parameter): ``loss`` holds the ``'loss'`` entry
        reported by every update; ``variational_parameter`` holds
        ``[lambda_pi, lambda_z]`` evaluated before the first update and after
        each update (``n_iter + 1`` entries).

    Raises:
        ValueError: if ``len(x_data)`` differs from ``N``.
    """
    start = time.time()

    if len(x_data) != N:
        raise ValueError("x_data has %d rows but N=%d" % (len(x_data), N))

    # The component count and data dimension come from the arguments, so the
    # function no longer depends on a notebook-level K or on K == 3.
    n_class, n_dim = np.asarray(mu).shape

    mu = tf.constant(mu, dtype = tf.float32)
    sigma = tf.constant(cov, dtype = tf.float32)

    #generative model
    alpha = tf.ones([n_class])
    pi = Dirichlet(concentration = alpha)
    # One one-hot assignment per sample: z has shape [N, n_class].
    z = OneHotCategorical(probs = pi, dtype = tf.float32, sample_shape = N)
    # A one-hot row times the stacked parameters picks that sample's component:
    # loc is [N, n_dim] and the covariance is [N, n_dim, n_dim].
    loc = tf.matmul(z, mu)
    flat_sigma = tf.reshape(sigma, [n_class, n_dim * n_dim])
    covariance = tf.reshape(tf.matmul(z, flat_sigma), [N, n_dim, n_dim])
    x = MultivariateNormalFullCovariance(loc = loc, covariance_matrix = covariance)
    print("generative model")
    #variational model
    lambda_pi = tf.nn.softplus(tf.Variable(tf.zeros([n_class])))
    qpi = Dirichlet(concentration = lambda_pi)

    # The whole data set is fed at every step, so the placeholder is [N, n_dim].
    x_ph =  tf.placeholder(tf.float32,[N, n_dim])

    # Free logits of q(z), one row per sample, starting from uniform
    # responsibilities.
    # NOTE(review): the source left this definition as a TODO; a zero-initialised
    # Variable is the assumed intent - confirm against the original experiment.
    y = tf.Variable(tf.zeros([N, n_class]))
    lambda_z = tf.nn.softmax(y)
    qz = OneHotCategorical(probs = lambda_z, dtype = tf.float32)
    print("variational model")
    latent_vars = {z:qz,pi:qpi}
    data = {x:x_ph}
    inference = ed.KLqp(latent_vars=latent_vars,data=data)
    inference.initialize(n_iter=n_iter)
    print("inference")
    print(time.time()-start)
    sess = ed.get_session()
    tf.global_variables_initializer().run()
    loss =[]
    variational_parameter=[]
    # Snapshot of the initial variational parameters (iteration 0).
    variational_parameter.append(sess.run([lambda_pi,lambda_z]))
    for _ in range(inference.n_iter):
        info_dict = inference.update(feed_dict = {x_ph:x_data})
        loss.append(info_dict['loss'])
        variational_parameter.append(sess.run([lambda_pi,lambda_z]))
    print(time.time()-start)
    return loss,variational_parameter
                                 

In [16]:
# Fit the variational posterior on the synthetic data for 5000 update steps.
loss, variational_parameter = gmm_bbvi_z_pi(
    N, mean_true, covariance_true, x_data, label, 5000
)

generative model
variational model
inference
55.76840400695801
369.28381538391113


In [17]:
%matplotlib nbagg
plt.title("loss(-ELBO)")
plt.plot(loss)
plt.xlabel("iteration number")
plt.show()

<IPython.core.display.Javascript object>

In [18]:
%matplotlib nbagg
import scipy.stats as ss
import matplotlib.tri as mtri
from matplotlib import cm
fig = plt.figure()
ax = plt.axes([0.1,0.1,0.8,0.8])
ax.set_xlim(0.0,1.0)
ax.set_ylim(0.0,1.0)
plt.gca().set_aspect('equal', adjustable='box')
X = []
Y = []

for x in np.arange(0.01,0.99,0.01):
    for y in  np.arange(0.01,0.99-x,0.01):
        X.append(x)
        Y.append(y)
triang = mtri.Triangulation(X, Y)

plt.plot(mix_true[0],mix_true[1],"bx")
artists = []
for t,vp in enumerate(variational_parameter):
    if t %100 == 0:
        dc = ss.dirichlet(np.array(vp[0]))
        
        Z = []
        for x,y in zip(X,Y):
            Z.append(dc.pdf((x,y)))

        sum_pi = vp[0][0]+vp[0][1]+vp[0][2]
        exp = vp[0]/sum_pi
            
        text = [plt.text(0.4,0.9,"iteration number : "+str(t)),
                    plt.text(0.4,0.8,"$\lambda_\pi$ : ["+"{0:.3f}".format(vp[0][0]) + " {0:.3f}".format(vp[0][1]) + " {0:.3f}]".format(vp[0][2])),
                    plt.text(0.4,0.7,"$E[\pi]$ : ["+"{0:.3f}".format(exp[0]) + " {0:.3f}".format(exp[1]) + " {0:.3f}]".format(exp[2]))]

        im_pi = ax.tricontourf(triang, Z,10)
        

        e = plt.plot(exp[0],exp[1],"rx")
        
        artists.append(im_pi.collections+text+e)
ani=animation.ArtistAnimation(fig,artists)
plt.show()
#ani.save("gmm_2dim_3class_pi.gif", writer='imagemagick', fps=4)

<IPython.core.display.Javascript object>

In [19]:
%matplotlib nbagg
fig = plt.figure()
ax = plt.axes([0.1,0.1,0.8,0.8])
ax.set_xlim(-10.0,10.0)
ax.set_ylim(-10.0,10.0)
plt.xticks( np.arange(-10.0, 10.0, 1.0) )
plt.yticks( np.arange(-10.0, 10.0, 1.0) )
plt.gca().set_aspect('equal', adjustable='box')
for k in range(K):
    c1,c2,ca=covariance_ellipse(covariance_true[k])
    ex,ey=ellipse(c1,c2,ca)
    plt.plot(ex+mean_true[k][0],ey+mean_true[k][1],color="black")
    plt.plot(ex*2+mean_true[k][0],ey*2+mean_true[k][1],color="black")
    plt.plot(ex*3+mean_true[k][0],ey*3+mean_true[k][1],color="black")
artists = []
for t,vp in enumerate(variational_parameter):
    text = [plt.text(5.0,9.0,"iteration number : "+str(t))]
    im_z=plt.scatter(x_data.T[0],x_data.T[1],color=vp[1])
    artists.append([im_z]+text)
ani=animation.ArtistAnimation(fig,artists)
plt.show()
#ani.save("gmm_2dim_3class_z.gif", writer='imagemagick', fps=4)

<IPython.core.display.Javascript object>