## GMM高斯混合模型示例

In [2]:
import numpy as np
from sklearn import mixture
import matplotlib.pyplot as plt

In [None]:
# Generate the sample data
def gendata():
    """Draw a 2-D sample made of four Gaussian blobs.

    Returns:
        An array of shape (1000, 2) holding, in this order: 300 points
        around (0, 0) scaled by 1.6, 300 around (6, 6) scaled by 1.3,
        200 around (-5, 5) scaled by 1.3 and 200 around (2, 7) scaled by 1.1.
    """
    # The draws happen in a fixed order, so a given np.random seed always
    # yields the same sample.
    blob_origin = 1.6 * np.random.randn(300, 2)
    blob_diagonal = 6 + 1.3 * np.random.randn(300, 2)
    blob_left = np.array([-5, 5]) + 1.3 * np.random.randn(200, 2)
    blob_top = np.array([2, 7]) + 1.1 * np.random.randn(200, 2)
    return np.concatenate((blob_origin, blob_diagonal, blob_left, blob_top))

In [9]:
def gaussian_2d(x, y, x0, y0, xsig, ysig):
    """Evaluate an axis-aligned 2-D Gaussian density at (x, y).

    Args:
        x, y: Coordinates to evaluate at (scalars or NumPy arrays).
        x0, y0: Center of the Gaussian along each axis.
        xsig, ysig: Scale (standard deviation) along each axis.

    Returns:
        ``1 / (2*pi*xsig*ysig) * exp(-0.5 * (dx**2 + dy**2))`` where ``dx``
        and ``dy`` are the offsets from the center divided by the scales.
    """
    dx = (x - x0) / xsig
    dy = (y - y0) / ysig
    normalization = 1/(2*np.pi*xsig*ysig)
    return normalization * np.exp(-0.5*(dx**2 + dy**2))

In [14]:
# Build the GMM model that is later fitted to our data set
def gengmm(nc=4, n_iter=2):
    """Create an unfitted ``mixture.GMM`` configured for this demo.

    Args:
        nc: Number of Gaussian components in the mixture.
        n_iter: Number of EM iterations, stored on the model's ``n_iter``
            attribute.

    Returns:
        The configured model; the caller is expected to call ``fit`` on it.
    """
    # NOTE(review): ``mixture.GMM`` appears to be the legacy scikit-learn
    # estimator (newer releases provide ``GaussianMixture``) -- confirm the
    # installed version before running.
    model = mixture.GMM(n_components=nc)
    overrides = (
        # Empty initialisation spec; presumably tells fit() not to force a
        # re-initialisation of any parameter -- verify against sklearn docs.
        ("init_params", ""),
        ("n_iter", n_iter),
    )
    for attr_name, attr_value in overrides:
        setattr(model, attr_name, attr_value)
    return model

In [15]:
def plotGMM(g, n, pt):
    """Draw a fitted mixture model onto the current matplotlib axes.

    Args:
        g: Fitted model exposing ``means_`` and ``covars_``, both indexed as
            ``[component, dimension]``.
        n: Number of components to draw (centers and, optionally, contours).
        pt: When equal to 1, the density contours of each component are drawn
            in addition to the center markers.
    """
    if pt == 1:
        # The evaluation grid is only needed for the contour plots.
        delta = 0.025
        x = np.arange(-10, 10, delta)
        y = np.arange(-6, 12, delta)
        X, Y = np.meshgrid(x, y)

        # range() instead of xrange() so the loop also runs on Python 3.
        for i in range(n):
            # For a diagonal covariance model covars_ stores variances, while
            # gaussian_2d expects standard deviations, hence the square root.
            sig_x = np.sqrt(g.covars_[i, 0])
            sig_y = np.sqrt(g.covars_[i, 1])
            Z1 = gaussian_2d(X, Y, g.means_[i, 0], g.means_[i, 1], sig_x, sig_y)
            plt.contour(X, Y, Z1, linewidths=0.5)

    # Mark the center of every drawn Gaussian; one plot call per component
    # keeps a distinct color for each marker.
    for i in range(n):
        plt.plot(g.means_[i][0], g.means_[i][1], '+', markersize=13, mew=3)

In [16]:
# Sample the data set and open the figure the result is drawn on.
obs = gendata()
fig = plt.figure(1)

# Fit a 4-component mixture with 100 EM iterations.
g = gengmm(nc=4, n_iter=100)
g.fit(obs)

# Scatter the raw points, then overlay only the component centers (pt=0).
plt.plot(obs[:, 0], obs[:, 1], '.', markersize=3)
plotGMM(g, n=4, pt=0)
plt.title('Gaussian Mixture Model')
plt.show()

<img src='GMM.png'>

In [17]:
# Refit the 4-component mixture with an increasing number of EM iterations
# and draw one figure per setting: data points, component contours (pt=1)
# and component centers.
for n_iter in (1, 5, 20, 100):
    g = gengmm(4, n_iter)
    g.fit(obs)
    plt.plot(obs[:, 0], obs[:, 1], '.', markersize=3)
    plotGMM(g, 4, 1)
    plt.title('Gaussian Models (Iter = %d)' % n_iter)
    plt.show()

![iter1](http://2.bp.blogspot.com/-di2iGvjV5hk/VGt45lYoy3I/AAAAAAAAA1M/q0K3arW_j38/s1600/iter_1.png)

![iter5](http://4.bp.blogspot.com/-5Gr6YWHlcpQ/VGt46YCR1oI/AAAAAAAAA1Y/x2lfEKEBBMI/s1600/iter_5.png)

![iter20](http://3.bp.blogspot.com/-HD3sjla6GuI/VGt451OzzoI/AAAAAAAAA1Q/JEmENkHlS5U/s1600/iter_20.png)

![iter100](http://3.bp.blogspot.com/-v4yqPX-ESew/VGt45hsEStI/AAAAAAAAA1U/wy9eCtLLVos/s1600/iter_100.png)

![GMM](http://4.bp.blogspot.com/-zuCQBrN8990/VGt45PZHXhI/AAAAAAAAA1E/jtQQaAj-PMc/s1600/gmm2.png)