Skip to content

Commit

Permalink
exhaustively delete all components
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Mar 31, 2017
1 parent 486cb43 commit 40f487f
Showing 1 changed file with 69 additions and 1 deletion.
70 changes: 69 additions & 1 deletion snob/mixture_ka.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,74 @@ def _split_component(y, mu, cov, weight, responsibility, index,
return (mu, cov, weight, responsibility, meta, dl)


def _delete_component(y, mu, cov, weight, responsibility, index,
    covariance_type="free", regularization=0, N_covpars=None, threshold=1e-5,
    **kwargs):
    """
    Delete a component from the mixture and re-fit the remaining components.

    The component at `index` is removed, the weights and responsibilities of
    the remaining components are re-normalized, and expectation-maximization
    is iterated on the reduced mixture until the relative change in
    log-likelihood is at most `threshold` or the iteration limit is reached.

    :param y:
        The data, as a two-dimensional array (`y.shape` is unpacked into
        two values; the second is passed to
        `_component_covariance_parameters` as the dimensionality).

    :param mu:
        The component means. The first axis indexes components.

    :param cov:
        The component covariances. The first axis indexes components.

    :param weight:
        The component weights.

    :param responsibility:
        The responsibilities. The first axis indexes components.

    :param index:
        The index of the component to delete.

    :param covariance_type: [optional]
        The covariance structure, passed through to `_maximization` (and to
        `_component_covariance_parameters` when `N_covpars` is not given).

    :param regularization: [optional]
        The regularization value, passed through to `_maximization`.

    :param N_covpars: [optional]
        The number of covariance parameters per component. When falsy (the
        default `None`, or `0`) it is computed from the data dimensionality
        and `covariance_type`.

    :param threshold: [optional]
        The relative change in log-likelihood at (or below) which the E-M
        iterations are considered converged.

    :param max_sub_iterations: [optional keyword]
        The maximum number of E-M iterations (default: 10000). Other keyword
        arguments are ignored.

    :returns:
        A six-length tuple containing the means, covariances, weights and
        responsibilities of the reduced mixture, a metadata dictionary with
        keys `warnflag` and `log-likelihood`, and the third value returned by
        the final `_expectation` call (presumably the message length of the
        reduced mixture -- TODO confirm against `_expectation`).
    """

    # Unpacking the shape also enforces that `y` is two-dimensional.
    _, D = y.shape
    N_covpars = N_covpars or _component_covariance_parameters(D, covariance_type)
    max_iterations = kwargs.get("max_sub_iterations", 10000)

    # What the deleted component currently claims.
    parent_weight = weight[index]
    parent_responsibility = responsibility[index]

    # Eq. 54-55: redistribute the deleted component's share among the rest.
    # NOTE(review): these divisions are not guarded against a parent weight
    # or responsibility of exactly one -- confirm callers cannot supply that.
    perturbed_weight = np.delete(weight, index) / (1 - parent_weight)
    perturbed_responsibility = np.delete(responsibility, index, axis=0) \
                             / (1 - parent_responsibility)

    perturbed_mu = np.delete(mu, index, axis=0)
    perturbed_cov = np.delete(cov, index, axis=0)

    # Calculate the current log-likelihood. The responsibilities returned
    # here are discarded: the first maximization step uses the re-normalized
    # responsibilities computed above.
    _, ll, dl = _expectation(
        y, perturbed_mu, perturbed_cov, perturbed_weight, N_covpars)

    iterations = 1

    while True:

        # Perform the maximization step.
        perturbed_mu, perturbed_cov, perturbed_weight = _maximization(
            y, perturbed_mu, perturbed_cov, perturbed_weight,
            perturbed_responsibility, covariance_type, regularization,
            N_covpars)

        # Run the expectation step, remembering the previous log-likelihood
        # so that convergence can be assessed.
        prev_ll = ll
        perturbed_responsibility, ll, dl = _expectation(
            y, perturbed_mu, perturbed_cov, perturbed_weight, N_covpars)

        # Check for convergence.
        relative_delta_ll = np.abs((ll - prev_ll) / prev_ll)
        iterations += 1

        if relative_delta_ll <= threshold or iterations >= max_iterations:
            break

    # The flag is raised whenever the iteration limit was hit, including the
    # case where convergence happened to occur on that final iteration.
    meta = dict(warnflag=iterations >= max_iterations)
    if meta["warnflag"]:
        # `Logger.warn` is a deprecated alias; `warning` is the supported
        # method.
        logger.warning("Maximum number of E-M iterations reached ({}) "\
                       "when deleting component index {}".format(
                           max_iterations, index))

    meta["log-likelihood"] = ll

    return (perturbed_mu, perturbed_cov, perturbed_weight,
        perturbed_responsibility, meta, dl)


class GaussianMixtureEstimator(estimator.Estimator):

r"""
Expand Down Expand Up @@ -826,7 +894,7 @@ def optimize(self):
# Keep best deleted component.
if p_dl < best_perturbations["delete"][-1]:
best_perturbations["delete"] = [m] + list(r)

# Exhaustively merge all components.
for m in range(M):
r = (p_mu, p_cov, p_weight, p_responsibility, p_meta, p_dl) \
Expand Down

0 comments on commit 40f487f

Please sign in to comment.