In [132]:
import numpy as np
import math

In [133]:
def calculate_Gaussian_P(xn, u, var):
    """
    input: 3 scalar
    """
    return (1/np.sqrt(2*math.pi*var))*np.exp(-(xn - u)**2/2/var)

def calculate_gamma(X, u, var, pi):
    """
    input: 
        X: array, (N, 1)
        u, var, pi: array, (K, 1)
    return:
    gamma: (N, K) array
    """
    gamma = np.zeros((len(X), len(u)))
    for k in range(len(u)):
        for n in range(len(X)):
            sum = 0
            for j in range(len(u)):
                sum += pi[j] * calculate_Gaussian_P(X[n], u[j], var[j])
            gamma_nk = pi[k] * calculate_Gaussian_P(X[n], u[k], var[k])/sum
            gamma[n, k] = gamma_nk
    return gamma

def calculate_log_p(X, u, var, pi):
    log_p = 0
    for n in range(len(X)):
        pn = 0
        for k in range(len(u)):
            pn += pi[k] * calculate_Gaussian_P(X[n], u[k], var[k])
        log_p += np.log(pn)
    return log_p


In [134]:
def GMM(X, K):
    """
    input: 
        X: (N, 1) array
        K: scalar, number of klusters
    """
    N = len(X)
    interval = np.round(100/(K+1)).astype(int)
    u_per = [i for i in range(interval, interval*K+1, interval)]
    u = np.percentile(X, u_per)
    var = np.ones(K) * 1.0
    pi = np.ones(K)/K
    log_p_last = calculate_log_p(X, u, var, pi)

    for round in range(10):
        # Estep:
        gamma = calculate_gamma(X, u, var, pi)
        # M-step:
        for k in range(K):
            gamma_k = gamma[:, k]
            Nk = gamma_k.sum()
            u_new = (gamma_k * X).sum() / Nk
            u[k] = u_new
            xn_uk = X - np.ones(N)*u_new
            var[k] =  (gamma_k * xn_uk *xn_uk).sum() / Nk
            pi[k] = Nk/N
        log_p = calculate_log_p(X, u, var, pi)  
        if log_p < log_p_last * 0.995:
            break
        log_p_last = log_p

        print("log_p:", calculate_log_p(X, u, var, pi))
        print("u:", u)
        print("var:", var)
        print("pi:", pi)

In [135]:
X = np.array([1, 2, 8, 9, 20, 21])
K = 3
GMM(X, K)

log_p: -10.946429301363722
u: [ 1.50014869  8.50001127 20.5       ]
var: [0.25096188 0.25000004 0.25      ]
pi: [0.33334095 0.33332572 0.33333333]
