# 高斯混合模型实践

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

In [3]:
def generate_data(n_samples1=200, n_samples2=200, variance1=1, variance2=1, random_state=42):
    """Sample two isotropic 2-D Gaussian clusters and stack them row-wise.

    The first cluster is centred at (-2, 2), the second at (2, -2); each uses
    a covariance of ``variance * I``.

    Note: this reseeds numpy's *global* random generator with ``random_state``.

    Args:
        n_samples1: number of points drawn for the first cluster.
        n_samples2: number of points drawn for the second cluster.
        variance1: per-axis variance of the first cluster.
        variance2: per-axis variance of the second cluster.
        random_state: seed passed to ``np.random.seed``.

    Returns:
        ndarray of shape (n_samples1 + n_samples2, 2); first-cluster rows come first.
    """
    np.random.seed(random_state)
    cluster_specs = (
        ((-2, 2), variance1, n_samples1),
        ((2, -2), variance2, n_samples2),
    )
    # Clusters are drawn in order, so the RNG stream is consumed identically.
    samples = [
        np.random.multivariate_normal(mean=center, cov=np.eye(2) * spread, size=count)
        for center, spread, count in cluster_specs
    ]
    return np.concatenate(samples, axis=0)

In [4]:
def compute_gauss(X, mu, Sigma):
    """Evaluate the multivariate normal density N(x | mu, Sigma) at each row of X.

    The normalising constant is assembled in log space (``slogdet``) so it does
    not under/overflow in higher dimensions, and the quadratic form uses
    ``solve`` rather than an explicit matrix inverse for better conditioning.

    Args:
        X: (n, d) array of sample points.
        mu: length-d mean vector.
        Sigma: (d, d) covariance matrix.

    Returns:
        Length-n array of density values.

    Raises:
        ValueError: if X is not 2-D.
        numpy.linalg.LinAlgError: if Sigma is singular.
    """
    _, d = X.shape
    diff = X - mu
    sign, logdet = np.linalg.slogdet(Sigma)
    # Row-wise Mahalanobis distance diff_i^T Sigma^{-1} diff_i.
    maha = np.sum(diff * np.linalg.solve(Sigma, diff.T).T, axis=1)
    if sign < 0:
        # A negative determinant is not a valid covariance: propagate NaN
        # instead of returning a plausible-looking density.
        logdet = np.nan
    return np.exp(-0.5 * (d * np.log(2 * np.pi) + logdet + maha))

In [5]:
def _log_gauss(X, mu, Sigma):
    """Return log N(x | mu, Sigma) for each row of X (log-space form of compute_gauss).

    Args:
        X: (n, d) array of sample points.
        mu: length-d mean vector.
        Sigma: (d, d) covariance matrix.

    Returns:
        Length-n array of log-densities.

    Raises:
        numpy.linalg.LinAlgError: if Sigma is singular.
    """
    _, d = X.shape
    diff = X - mu
    _, logdet = np.linalg.slogdet(Sigma)
    # solve() instead of inv(): same quadratic form, better conditioned.
    maha = np.sum(diff * np.linalg.solve(Sigma, diff.T).T, axis=1)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)


def apply_gmm(X, pi, mu, Sigma, max_iter=100, reg_covar=0.0):
    """Fit a Gaussian mixture to X by running a fixed number of EM iterations.

    Args:
        X: (n, d) data array.
        pi: length-k initial mixing weights.
        mu: k initial means of length d (list of arrays or a (k, d) array).
        Sigma: k initial (d, d) covariance matrices.
        max_iter: number of EM iterations to run (there is no early stopping).
        reg_covar: non-negative value added to the diagonal of every updated
            covariance to keep it invertible; 0 (the default) disables it.

    Returns:
        dict with
          'coef':  list of max_iter + 1 dicts {'pi', 'mu', 'Sigma'}; entry 0 is
                   the initial guess, entry i the parameters after iteration i.
          'rvals': list of max_iter + 1 (n, k) responsibility matrices; entry 0
                   is a uniform placeholder, entry i holds the responsibilities
                   computed in the E-step of iteration i (from coef[i - 1]).

    Raises:
        numpy.linalg.LinAlgError: if a covariance matrix becomes singular.
    """
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    pi = np.asarray(pi, dtype=float)
    mu = np.asarray(mu, dtype=float)
    Sigma = np.asarray(Sigma, dtype=float)
    k = len(pi)

    # store the initial values
    coef = [{'pi': pi.copy(), 'mu': mu.copy(), 'Sigma': Sigma.copy()}]
    rvals = [np.ones((n, k)) / k]

    for _ in range(max_iter):
        # E-step: r[i, j] ∝ pi[j] * N(x_i | mu[j], Sigma[j]).  Normalising in
        # log space (subtracting the row max) avoids 0/0 = NaN for points far
        # from every component.
        with np.errstate(divide='ignore'):  # log(0) = -inf for a zero weight is fine
            log_r = np.log(pi) + np.column_stack(
                [_log_gauss(X, mu[j], Sigma[j]) for j in range(k)]
            )
        log_r -= log_r.max(axis=1, keepdims=True)
        r = np.exp(log_r)
        r /= r.sum(axis=1, keepdims=True)

        # M-step: weighted maximum-likelihood updates.
        nk = r.sum(axis=0)  # effective number of points per component, shape (k,)
        pi = nk / n
        # (k, d) / (k, 1): each component's mean is divided by its own count.
        mu = (r.T @ X) / nk[:, None]
        Sigma = np.empty((k, d, d))
        for j in range(k):
            diff = X - mu[j]
            Sigma[j] = (r[:, j, None] * diff).T @ diff / nk[j]
        if reg_covar:
            Sigma += reg_covar * np.eye(d)

        # store the current values
        coef.append({'pi': pi.copy(), 'mu': mu.copy(), 'Sigma': Sigma.copy()})
        rvals.append(r)
    return {'coef': coef, 'rvals': rvals}

In [6]:
# Two clusters with unequal spread: the first has 3x the per-axis variance.
X = generate_data(n_samples1=200, n_samples2=200, variance1=3, variance2=1)

# Starting point for EM: two nearly coincident means near the origin,
# small isotropic covariances and very unequal mixing weights.
mu1 = np.array([0.1, 0])
mu2 = np.array([0, 0])
Sigma1 = np.identity(2) * 0.1
Sigma2 = np.identity(2) * 0.1
pi = [0.1, 0.9]
mu = [mu1, mu2]
Sigma = [Sigma1, Sigma2]

res = apply_gmm(X, pi, mu, Sigma)

In [None]:
hist_index = [0, 1, 5, 10, 25, 40]
fig, ax = plt.subplots(2, 3)
ax = ax.ravel()
for ix, axi in zip(hist_index, ax):
    