-
Notifications
You must be signed in to change notification settings - Fork 38
/
collapsed_mixture.py
206 lines (171 loc) · 7.81 KB
/
collapsed_mixture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Copyright (c) 2012, 2013, 2014 James Hensman
# Licensed under the GPL v3 (see LICENSE.txt)
import numpy as np
try:
from .utilities import ln_dirichlet_C, softmax_weave
except ImportError:
from .np_utilities import ln_dirichlet_C
from .np_utilities import softmax_numpy as softmax_weave
from scipy.special import gammaln, digamma
from .collapsed_vb import CollapsedVB
class CollapsedMixture(CollapsedVB):
    """
    A base class for collapsed mixture models based on the CollapsedVB class.

    We inherit from this to build mixtures of Gaussians, mixtures of GPs etc.
    This handles the mixing-proportion part of the model, as well as providing
    generic functions for a merge-split approach.
    """
    def __init__(self, N, K, prior_Z='symmetric', alpha=1.0, name='col_mix'):
        """
        Arguments
        =========
        N: the number of data
        K: the (initial) number of clusters (or truncation)
        prior_Z: either 'symmetric' or 'DP'; specifies whether to use a
            symmetric Dirichlet prior for the clusters, or a (truncated)
            Dirichlet Process
        alpha: concentration parameter of the Dirichlet (process)
        """
        CollapsedVB.__init__(self, name)
        self.N, self.K = N, K
        assert prior_Z in ['symmetric','DP']
        self.prior_Z = prior_Z
        self.alpha = alpha
        # Random initial conditions for the variational parameters.
        # phi_ holds unconstrained values; phi is the softmax of phi_ (rows sum
        # to 1), H is the entropy term and logphi the log responsibilities.
        self.phi_ = np.random.randn(self.N, self.K)
        self.phi, logphi, self.H = softmax_weave(self.phi_)
        # phi_hat[k]: expected number of points assigned to cluster k.
        self.phi_hat = self.phi.sum(0)
        self.Hgrad = -logphi
        if self.prior_Z == 'DP':
            # Reversed cumulative sums used by the truncated-DP bound:
            # phi_tilde_plus_hat[k] = sum_{j>=k} phi_hat[j], and
            # phi_tilde[k] = sum_{j>k} phi_hat[j] (counts "to the right" of k).
            self.phi_tilde_plus_hat = self.phi_hat[::-1].cumsum()[::-1]
            self.phi_tilde = self.phi_tilde_plus_hat - self.phi_hat
    def set_vb_param(self,phi_):
        """
        Accept a vector representing the variational parameters, and reshape
        it into self.phi_, refreshing all derived quantities (phi, phi_hat,
        entropy and the DP cumulative sums) before re-running do_computations.
        """
        self.phi_ = phi_.reshape(self.N, self.K)
        self.phi, logphi, self.H = softmax_weave(self.phi_)
        self.phi_hat = self.phi.sum(0)
        self.Hgrad = -logphi
        if self.prior_Z == 'DP':
            # Same reversed cumulative sums as in __init__.
            self.phi_tilde_plus_hat = self.phi_hat[::-1].cumsum()[::-1]
            self.phi_tilde = self.phi_tilde_plus_hat - self.phi_hat
        self.do_computations()
    def get_vb_param(self):
        """Return the (flattened) unconstrained variational parameters."""
        return self.phi_.flatten()
    def mixing_prop_bound(self):
        """
        The portion of the bound which is provided by the mixing proportions.

        For the symmetric prior this is the log ratio of Dirichlet
        normalizers; for the DP it is the collapsed truncated
        stick-breaking expression (terms A-D below).
        """
        if self.prior_Z=='symmetric':
            return ln_dirichlet_C(np.ones(self.K)*self.alpha) -ln_dirichlet_C(self.alpha + self.phi_hat)
        elif self.prior_Z=='DP':
            # NOTE(review): these appear to be the gammaln terms of the
            # collapsed truncated stick-breaking bound — confirm against the
            # accompanying derivation/paper.
            A = gammaln(1. + self.phi_hat)
            B = gammaln(self.alpha + self.phi_tilde)
            C = gammaln(self.alpha + 1. + self.phi_tilde_plus_hat)
            D = self.K*(gammaln(1.+self.alpha) - gammaln(self.alpha))
            return A.sum() + B.sum() - C.sum() + D
        else:
            raise NotImplementedError("invalid mixing proportion prior type: %s" % self.prior_Z)
    def mixing_prop_bound_grad(self):
        """
        The gradient of the portion of the bound which arises from the mixing
        proportions, with respect to self.phi.
        """
        if self.prior_Z=='symmetric':
            return digamma(self.alpha + self.phi_hat)
        elif self.prior_Z=='DP':
            # A, B, C mirror the terms of mixing_prop_bound; B uses a shifted
            # cumulative sum because phi_tilde[k] depends on phi_hat[j] for j>k.
            A = digamma(self.phi_hat + 1.)
            B = np.hstack((0, digamma(self.phi_tilde + self.alpha)[:-1].cumsum()))
            C = digamma(self.phi_tilde_plus_hat + self.alpha + 1.).cumsum()
            return A + B - C
        else:
            raise NotImplementedError("invalid mixing proportion prior type: %s"%self.prior_Z)
    def reorder(self):
        """
        Re-order the clusters so that the biggest one is first. This increases
        the bound if the prior type is a DP (the bound is order-dependent
        there; for the symmetric prior it is invariant, so nothing is done).
        """
        if self.prior_Z=='DP':
            i = np.argsort(self.phi_hat)[::-1]
            self.set_vb_param(self.phi_[:,i])
    def remove_empty_clusters(self,threshold=1e-6):
        """Remove any cluster which has (effectively) no data assigned to it.

        Clusters whose expected count phi_hat is <= threshold are dropped and
        self.K is reduced accordingly.
        """
        i = self.phi_hat>threshold
        phi_ = self.phi_[:,i]
        self.K = i.sum()
        self.set_vb_param(phi_)
    def try_split(self, indexK=None, threshold=0.9, verbose=True, maxiter=100, optimize_params=None):
        """
        Re-initialize one of the clusters as two clusters, optimize, and keep
        the solution if the bound is increased. Kernel parameters stay constant.

        Arguments
        ---------
        indexK: (int) the index of the cluster to split; if None, one is
            drawn at random with probability proportional to its size
        threshold: float [0,1], to assign data to the splitting cluster
        verbose: whether to print status
        maxiter: iterations for the trial optimization of the split
        optimize_params: optional kwargs dict for the final optimize() call
            on success (defaults to maxiter=5000)

        Returns
        -------
        Success: (bool)
        """
        if indexK is None:
            # Sample a cluster with probability proportional to its size.
            indexK = np.random.multinomial(1,self.phi_hat/self.N).argmax()
        if indexK > (self.K-1):
            return False #index exceeds no. of clusters
        elif self.phi_hat[indexK]<1:
            return False # no data to split
        #ensure there's something to split: need at least 2 confident members
        if np.sum(self.phi[:,indexK]>threshold) <2:
            return False
        if verbose:print("\nattempting to split cluster ", indexK)
        # Snapshot the current state so we can roll back if the split fails.
        bound_old = self.bound()
        phi_old = self.get_vb_param().copy()
        self._optimizer_copy_transformed = False # Redo transform in case parameters have been unlinked
        param_old = self.optimizer_array.copy()
        old_K = self.K
        #re-initialize: append a new (initially unattractive) cluster column
        self.K += 1
        self.phi_ = np.hstack((self.phi_,self.phi_.min(1)[:,None]))
        indexN = np.nonzero(self.phi[:,indexK] > threshold)[0]
        #this procedure randomly assigns data to the new and old clusters
        #rand_indexN = np.random.permutation(indexN)
        #n = np.floor(indexN.size/2.)
        #i1 = rand_indexN[:n]
        #self.phi_[i1,indexK], self.phi_[i1,-1] = self.phi_[i1,-1], self.phi_[i1,indexK]
        #self.set_vb_param(self.get_vb_param())
        #this procedure equally assigns data to the new and old clusters, aside from one random point, which is in the new cluster
        special = np.random.permutation(indexN)[0]
        self.phi_[indexN,-1] = self.phi_[indexN,indexK].copy()
        # Force the 'special' point firmly into the new cluster to break symmetry.
        self.phi_[special,-1] = np.max(self.phi_[special])+10
        self.set_vb_param(self.get_vb_param())
        self.optimize(maxiter=maxiter, verbose=verbose)
        self.remove_empty_clusters()
        bound_new = self.bound()
        bound_increase = bound_new-bound_old
        if (bound_increase < 1e-3):
            # Split did not (sufficiently) improve the bound: roll back.
            self.K = old_K
            self.set_vb_param(phi_old)
            self.optimizer_array = param_old
            if verbose:print("split failed, bound changed by: ",bound_increase, '(K=%s)' % self.K)
            return False
        else:
            if verbose:print("split suceeded, bound changed by: ", bound_increase, ',', self.K-old_K,' new clusters', '(K=%s)' % self.K)
            if verbose:print("optimizing new split to convergence:")
            if optimize_params:
                self.optimize(**optimize_params)
            else:
                self.optimize(maxiter=5000, verbose=verbose)
            return True
    def systematic_splits(self, verbose=True):
        """
        Perform recursive splits on each of the existing clusters.

        Note: range(self.K) is evaluated once, while self.K may change during
        the loop; newly created clusters are handled by recursive_splits.
        """
        for kk in range(self.K):
            self.recursive_splits(kk, verbose=verbose)
    def recursive_splits(self,k=0, verbose=True, optimize_params=None):
        """
        A recursive function which attempts to split a cluster (indexed by k),
        and if successful attempts to split the resulting clusters.
        """
        success = self.try_split(k, verbose=verbose, optimize_params=optimize_params)
        if success:
            if not k==(self.K-1):
                # The new cluster was appended at index K-1; try splitting it too.
                self.recursive_splits(self.K-1, verbose=verbose, optimize_params=optimize_params)
            self.recursive_splits(k, verbose=verbose, optimize_params=optimize_params)