In [9]:
from scipy.stats import multivariate_normal
import numpy as np
def pdf(points, mean, cov, prior):
    points, mean, cov = np.asarray(points), np.asarray(mean), np.asarray(cov)
    prior = np.asarray(prior)
    n, d = points.shape
    k, d_1 = mean.shape
    k_2, d_2, d_3 = cov.shape
    k_3, = prior.shape
    assert d == d_1 == d_2 == d_3
    assert k == k_2 == k_3, "%s %s %s should be equal" % (k, k_2, k_3)

    # Compute probabilities
    prob = []
    for i in range(k):
        if prior[i] < 1 / k ** 3:
            prob.append(np.zeros(n))
        else:
            prob.append(
                prior[i] *
                multivariate_normal.pdf(
                    mean=mean[i], cov=cov[i], x=points))
    prob = np.transpose(prob)  # n x k
    # Normalize cluster probabilities of each point
    prob = prob / np.sum(prob, axis=1, keepdims=True)  # n x k

    assert prob.shape == (n, k)
    assert np.allclose(prob.sum(axis=1), 1)
    return prob

In [10]:
def most_likely(points, mean, cov, prior):
    prob = pdf(points, mean, cov, prior)
    return np.argmax(prob, axis=1)

In [11]:
testPoints = np.random.rand(150, 4)

In [5]:
Q = np.zeros(4)
Q[:] = np.INF
print(Q)
print(Q**2)
print(np.sqrt(Q**2))

AttributeError: module 'numpy' has no attribute 'INF'

In [13]:
def em(points, k, epsilon, mean=None):
    points = np.asarray(points)
    n, d = points.shape
    # Initialize and validate mean
    if mean is None:
        # Randomly pick k points
        mean = points[np.random.choice(range(n), size=k, replace=False)]
        #mean = points[list(range(k))] #For testing
    # Validate input
    mean = np.asarray(mean)
    k_, d_ = mean.shape
    assert k == k_
    assert d == d_
    # Initialize cov, prior
    cov = np.asarray([np.identity(d)]*k)
    prior = np.ones(shape=(k))/k
    tired = False
    old_mean = np.zeros_like(mean)
    while not tired:
        old_mean[:] = mean

        # Expectation step
        exp = pdf(points, mean, cov, prior)
        # Maximization step
        prior = exp.sum(axis=0)/n
        expSum = exp.sum(axis=0)
        mean =  np.nan_to_num((np.dot(points.T, exp)/expSum).T)
        num = 0
        for i in range(k):
            for j in range(n):
                sub = (points[j, :]-mean[i, :])[:, np.newaxis]
                num += exp[j, i] * (sub * sub.T)
            cov[i] = np.nan_to_num(num/expSum[i])
        # Finish condition
        print("BOOOOM")
        print(mean)
        print(old_mean)
        dist = np.sqrt(((mean - old_mean) ** 2).sum(axis=1))
        
        print(dist)
        tired = np.all(dist < epsilon)
        print("Number of points below threshold: " + str((dist < epsilon).sum()))

    # Validate output
    assert mean.shape == (k, d)
    assert cov.shape == (k, d, d)
    assert prior.shape == (k,)
    return mean, cov, prior


em(testPoints, 3, 0.005)

BOOOOM
[[ 0.47501648  0.47230456  0.47142533  0.48654338]
 [ 0.50312347  0.45625076  0.47943027  0.49063578]
 [ 0.45321229  0.44108806  0.44290565  0.47586246]]
[[ 0.34347158  0.58796616  0.55614386  0.13607501]
 [ 0.73799226  0.41735911  0.64284056  0.14508976]
 [ 0.01707551  0.17112457  0.17112536  0.07968726]]
[ 0.4008578   0.45031198  0.70279079]
Number of points below threshold: 0
BOOOOM
[[ 0.47918842  0.46735608  0.45679605  0.47279508]
 [ 0.48620641  0.44914408  0.47716256  0.49558651]
 [ 0.45931731  0.44310853  0.47074886  0.50079822]]
[[ 0.47501648  0.47230456  0.47142533  0.48654338]
 [ 0.50312347  0.45625076  0.47943027  0.49063578]
 [ 0.45321229  0.44108806  0.44290565  0.47586246]]
[ 0.02109322  0.01914012  0.03792612]
Number of points below threshold: 0
BOOOOM
[[ 0.48235767  0.46361147  0.45189657  0.46803306]
 [ 0.47115182  0.44361664  0.49566691  0.52069052]
 [ 0.4568272   0.44505062  0.50302134  0.5320063 ]]
[[ 0.47918842  0.46735608  0.45679605  0.47279508]
 [ 0.48620



(array([[ 0.47833365,  0.45746353,  0.46575617,  0.48480109],
        [ 0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[[  6.88525370e-002,  -4.38743458e-003,   2.63604111e-003,
            5.75296329e-003],
         [ -4.38743458e-003,   8.14750130e-002,  -2.82132931e-003,
            1.04886393e-004],
         [  2.63604111e-003,  -2.82132931e-003,   7.49488897e-002,
            1.16027102e-002],
         [  5.75296329e-003,   1.04886393e-004,   1.16027102e-002,
            8.25538787e-002]],
 
        [[  1.79769313e+308,  -1.79769313e+308,   1.79769313e+308,
            1.79769313e+308],
         [ -1.79769313e+308,   1.79769313e+308,  -1.79769313e+308,
            1.79769313e+308],
         [  1.79769313e+308,  -1.79769313e+308,   1.79769313e+308,
            1.79769313e+308],
         [  1.79769313e+308,   1.79769313e+308,   1.79769313e+308,
            1.79769313e+308]],
 
        [[  1.79769313e+308,  