-
Notifications
You must be signed in to change notification settings - Fork 5
/
common.py
104 lines (87 loc) · 3.14 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Mixture model for collaborative filtering"""
from typing import NamedTuple, Tuple
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Circle, Arc
class GaussianMixture(NamedTuple):
    """Tuple holding a gaussian mixture.

    Describes K spherical gaussian components over d-dimensional data:
    one mean row, one scalar variance, and one scalar weight per component.
    """
    # (K, d) array - each row corresponds to a gaussian component mean
    mu: np.ndarray
    # (K, ) array - each row corresponds to the variance of a component
    var: np.ndarray
    # (K, ) array - each row corresponds to the weight of a component
    p: np.ndarray
def init(X: np.ndarray, K: int,
         seed: int = 0) -> Tuple[GaussianMixture, np.ndarray]:
    """Initializes the mixture model with random points as initial
    means and uniform assignments.

    Picks K distinct data points (without replacement) as the initial
    component means, gives every component the same weight 1/K, and sets
    each component's variance to the mean squared deviation of the whole
    data set from that component's mean.

    Args:
        X: (n, d) array holding the data
        K: number of components
        seed: random seed

    Returns:
        mixture: the initialized gaussian mixture
        post: (n, K) array holding the soft counts
            for all components for all examples
    """
    np.random.seed(seed)
    n_samples, _ = X.shape

    # Uniform mixing weights: every component starts equally likely.
    weights = np.full(K, 1.0 / K)

    # K distinct rows of X serve as the initial component means.
    means = X[np.random.choice(n_samples, K, replace=False)]

    # Spherical variance per component: average squared deviation of
    # every entry of X from that component's mean.
    variances = np.empty(K)
    for k in range(K):
        variances[k] = ((X - means[k]) ** 2).mean()

    # Soft counts start uniform over components for every example.
    soft_counts = np.full((n_samples, K), 1.0 / K)

    return GaussianMixture(means, variances, weights), soft_counts
def plot(X: np.ndarray, mixture: GaussianMixture, post: np.ndarray,
         title: str):
    """Plots the mixture model for 2D data.

    Each example is drawn as a small ring of arcs whose angular extents
    are proportional to its soft counts, and each component is drawn as
    a circle of radius one standard deviation centred on its mean.  The
    figure is saved to ./plots/<title>.png.

    Args:
        X: (n, 2) array holding the data (only 2D data plots sensibly)
        mixture: a mixture of spherical gaussians
        post: (n, K) array holding the soft counts
        title: figure title, also used as the output file name
    """
    _, K = post.shape

    # Normalize soft counts so each row sums to 1 (angular fractions).
    percent = post / post.sum(axis=1).reshape(-1, 1)

    _, ax = plt.subplots(figsize=[9, 6])
    ax.title.set_text(title)
    ax.set_xlim((-20, 20))
    ax.set_ylim((-20, 20))
    r = 0.25  # radius of the per-example arc ring
    color = ["r", "b", "k", "y", "m", "c"]  # NOTE: supports at most 6 components

    # One ring of coloured arcs per data point; each component's arc spans
    # an angle proportional to its responsibility for that point.
    for i, point in enumerate(X):
        theta = 0
        for j in range(K):
            offset = percent[i, j] * 360
            arc = Arc(point, r, r,
                      angle=0,
                      theta1=theta,
                      theta2=theta + offset,
                      edgecolor=color[j])
            ax.add_patch(arc)
            theta += offset

    # Each component: a 1-stddev circle plus a '+' marker at the mean.
    for j in range(K):
        mu = mixture.mu[j]
        sigma = np.sqrt(mixture.var[j])
        circle = Circle(mu, sigma, color=color[j], fill=False)
        ax.add_patch(circle)
        legend = "mu = ({:0.2f}, {:0.2f})\n stdv = {:0.2f}".format(
            mu[0], mu[1], sigma)
        # Bug fix: the original called ax.legend([circle], "legend") here,
        # which treats the literal string "legend" as an iterable of labels
        # (one label per character).  The per-component label is attached
        # below via label=... and collected by plt.legend(), so the broken
        # call is simply removed.
        plt.plot([mu[0]], [mu[1]], '+', label=legend, color=color[j])

    plt.axis('equal')
    plt.legend()
    plt.savefig("./plots/" + title + ".png")
    plt.close()
def rmse(X, Y):
    """Return the root-mean-square error between X and Y."""
    squared_error = np.square(X - Y)
    return np.sqrt(np.mean(squared_error))
def bic(X: np.ndarray, mixture: GaussianMixture,
        log_likelihood: float) -> float:
    """Computes the Bayesian Information Criterion for a
    mixture of gaussians.

    Args:
        X: (n, d) array holding the data
        mixture: a mixture of spherical gaussian
        log_likelihood: the log-likelihood of the data

    Returns:
        float: the BIC for this mixture
    """
    n_samples = X.shape[0]
    # Free-parameter count: every entry of the means, variances and weights.
    # NOTE(review): since the weights sum to 1, a stricter count would use
    # p.size - 1; kept as-is to preserve the original behavior.
    num_params = sum(arr.size for arr in (mixture.mu, mixture.var, mixture.p))
    return log_likelihood - 0.5 * num_params * np.log(n_samples)