-
Notifications
You must be signed in to change notification settings - Fork 5
/
kmeans.py
87 lines (66 loc) · 2.32 KB
/
kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Mixture model based on kmeans"""
from typing import Tuple
import numpy as np
from common import GaussianMixture
def estep(X: np.ndarray, mixture: GaussianMixture) -> np.ndarray:
    """E-step: Assigns each datapoint to the gaussian component with the
    closest mean

    Distances are squared Euclidean distances between a datapoint and a
    component mean. When several means are equally close, the component
    with the lowest index wins.

    Args:
        X: (n, d) array holding the data
        mixture: the current gaussian mixture; only its `mu` attribute,
            a (K, d) array of component means, is read here

    Returns:
        np.ndarray: (n, K) array holding the hard (one-hot) assignment
            of every example: 1 for the closest component, 0 elsewhere
    """
    n, _ = X.shape
    K, _ = mixture.mu.shape

    # Iterate over the K components rather than the n datapoints: each pass
    # fills one whole column of distances through broadcasting, so no
    # per-sample np.tile copy and no n-long interpreter loop are needed.
    sq_dist = np.empty((n, K))
    for j in range(K):
        sq_dist[:, j] = ((X - mixture.mu[j]) ** 2).sum(axis=1)

    post = np.zeros((n, K))
    # argmin returns the first minimum, which gives the lowest-index
    # tie-breaking described above.
    post[np.arange(n), sq_dist.argmin(axis=1)] = 1
    return post
def mstep(X: np.ndarray, post: np.ndarray) -> Tuple[GaussianMixture, float]:
    """M-step: Re-estimates the gaussian mixture from the assignments.

    Every component gets a weight (its share of the total count), a mean
    (the count-weighted average of the data) and a single variance (its
    weighted squared error divided by `d` times its count).

    Note: a component whose column in `post` sums to zero is divided by
    zero below, so its mean, its variance and the returned cost become NaN.

    Args:
        X: (n, d) array holding the data
        post: (n, K) array holding the soft counts
            for all components for all examples

    Returns:
        GaussianMixture: the new gaussian mixture
        float: the distortion cost for the current assignment, i.e. the
            weighted squared error summed over all components
    """
    num_points, dim = X.shape
    _, num_components = post.shape

    counts = post.sum(axis=0)
    weights = counts / num_points

    means = np.zeros((num_components, dim))
    variances = np.zeros(num_components)
    total_cost = 0

    for k in range(num_components):
        members = post[:, k]
        size = counts[k]

        # Count-weighted average of the datapoints in this component.
        means[k, :] = members @ X / size
        center = means[k]

        # Squared distance of every datapoint to the new mean, weighted by
        # how much each datapoint belongs to this component.
        sq_dist = ((center - X) ** 2).sum(axis=1)
        component_sse = sq_dist @ members

        total_cost += component_sse
        variances[k] = component_sse / (dim * size)

    return GaussianMixture(means, variances, weights), total_cost
def run(X: np.ndarray, mixture: GaussianMixture,
        post: np.ndarray,
        tol: float = 1e-4) -> Tuple[GaussianMixture, np.ndarray, float]:
    """Runs the mixture model

    Alternates `estep` and `mstep` until the distortion cost stops
    decreasing by more than `tol` between two consecutive iterations.

    Args:
        X: (n, d) array holding the data
        mixture: the initial gaussian mixture
        post: (n, K) array holding the soft counts
            for all components for all examples; its incoming value is
            not read, it is overwritten by the first E-step
        tol: smallest decrease of the cost between two consecutive
            iterations that still triggers another iteration
            (defaults to 1e-4, the previously hard-coded threshold)

    Returns:
        GaussianMixture: the new gaussian mixture
        np.ndarray: (n, K) array holding the soft counts
            for all components for all examples
        float: distortion cost of the current assignment
    """
    prev_cost = None
    cost = None
    # prev_cost is still None after the first pass (it copies the initial
    # None cost), so at least two E/M rounds run before the convergence
    # test on the cost decrease is ever evaluated.
    while prev_cost is None or prev_cost - cost > tol:
        prev_cost = cost
        post = estep(X, mixture)
        mixture, cost = mstep(X, post)

    return mixture, post, cost