Source: https://stats.stackexchange.com/questions/72774/numerical-example-to-understand-expectation-maximization

Suppose we have two groups of points, red and blue. Each group contains values drawn from a normal distribution with the following parameters:

In [None]:
import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible random results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

When we can see the colour of each point (i.e. which group it belongs to), it is easy to estimate the mean and standard deviation of each group: just pass the red and blue values to NumPy's built-in functions, for example np.mean and np.std.
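
As a minimal sketch of that labelled case, the cell below regenerates the two groups with the same seed and parameters as above so that it runs on its own, then calls np.mean and np.std on each. The `*_true` variable names are introduced here for illustration only.

```python
import numpy as np

np.random.seed(110)  # same seed as the first cell

# regenerate the two labelled groups so this cell is self-contained
red = np.random.normal(3, 0.8, size=20)
blue = np.random.normal(7, 2, size=20)

# np.std uses ddof=0 by default (the population / maximum-likelihood form)
red_mean_true, red_std_true = np.mean(red), np.std(red)
blue_mean_true, blue_std_true = np.mean(blue), np.std(blue)

print(f"red:  mean={red_mean_true:.3f}, std={red_std_true:.3f}")
print(f"blue: mean={blue_mean_true:.3f}, std={blue_std_true:.3f}")
```

With only 20 points per group, these sample statistics will be near, but not exactly equal to, the parameters used to generate the data.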

But what if we can't see the colours of the points? That is, instead of red or blue, every point just looks purple to us.

To try and recover the mean and standard deviation parameters for the red and blue groups, we can use Expectation Maximisation.

Our first step (step 1 above) is to guess at the parameter values for each group's mean and standard deviation. We don't have to guess intelligently; we can pick any numbers we like:

In [None]:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

These are bad estimates: both means look far off any kind of "middle" for the groups of points, for instance. We want to improve these estimates.

The next step (step 2) is to compute the likelihood of each data point appearing under the current parameter guesses:

In [None]:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

Here, we have simply put each data point into the probability density function with our current guesses at the mean and standard deviation for red and blue. This tells us, for example, that with our current guesses the data point at 1.761 is much more likely to be red (0.189) than blue (0.00003).
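
To make the quoted numbers concrete, the sketch below evaluates the normal density by hand at 1.761 under the initial guesses and checks it against scipy. The helper `normal_pdf` is a name made up for this example.

```python
import numpy as np
from scipy import stats

x = 1.761  # the data point quoted in the text

def normal_pdf(x, mean, std):
    """Normal density written out explicitly."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

red_density = normal_pdf(x, 1.1, 2)   # red guesses: mean 1.1, std 2
blue_density = normal_pdf(x, 9, 1.7)  # blue guesses: mean 9, std 1.7

# the explicit formula agrees with scipy's implementation
assert np.isclose(red_density, stats.norm(1.1, 2).pdf(x))
assert np.isclose(blue_density, stats.norm(9, 1.7).pdf(x))

print(round(red_density, 3))  # 0.189
print(f"{blue_density:.1e}")  # 2.7e-05
```

Note that these are density values rather than probabilities, so only their relative sizes matter for the next step.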

For each data point, we can turn these two likelihood values into weights (step 3) so that they sum to 1 as follows:

In [None]:
likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
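
The short sketch below applies the same normalisation to three points under the initial guesses. Only 1.761 comes from the text; the values 5.0 and 9.0 are made up to show a mixed case and a clearly blue case. Dividing by the plain sum of the two likelihoods treats both colours as equally common before looking at the data.

```python
import numpy as np
from scipy import stats

# 1.761 is quoted in the text; 5.0 and 9.0 are illustrative values
points = np.array([1.761, 5.0, 9.0])

likelihood_of_red = stats.norm(1.1, 2).pdf(points)
likelihood_of_blue = stats.norm(9, 1.7).pdf(points)
likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

# each pair of weights sums to 1
assert np.allclose(red_weight + blue_weight, 1.0)

for p, r, b in zip(points, red_weight, blue_weight):
    print(f"x={p:.3f}: red weight={r:.4f}, blue weight={b:.4f}")
```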

With our current estimates and our newly-computed weights, we can now compute new estimates for the mean and standard deviation of the red and blue groups (step 4).

We compute the mean and standard deviation twice using all data points, each time with a different weighting: once with the red weights and once with the blue weights.

The key bit of intuition is that the greater the weight of a colour on a data point, the more the data point influences the next estimates for that colour's parameters. This has the effect of "pulling" the parameters in the right direction.
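
Written out, the weighted estimates for a colour $c$ (red or blue) take the form below. The notation is assumed here for illustration and does not come from the source: $x_i$ is the $i$-th data point and $w_{c,i}$ is the weight of colour $c$ on that point.

$$
\hat{\mu}_c = \frac{\sum_i w_{c,i}\, x_i}{\sum_i w_{c,i}},
\qquad
\hat{\sigma}_c = \sqrt{\frac{\sum_i w_{c,i}\,(x_i - \mu_c)^2}{\sum_i w_{c,i}}}
$$

Here $\mu_c$ in the second formula is whichever mean is supplied when the standard deviation is computed. A point with $w_{c,i}$ close to 1 counts almost fully towards colour $c$, while a point with $w_{c,i}$ close to 0 barely moves that colour's estimates.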

In [None]:
def estimate_mean(data, weight):
    # weighted average of the data points
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    # weighted root-mean-square deviation of the data from `mean`
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
# (computed before the means are updated, so these use the previous mean guesses)
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

We have new estimates for the parameters. To improve them again, we can jump back to step 2 and repeat the process. We do this until the estimates converge or a set number of iterations has been performed (step 5).
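
Putting steps 2 to 5 into a loop might look like the sketch below. It is self-contained, so it regenerates the data with the same seed and restarts from the same initial guesses. The iteration cap `N_ITERATIONS`, the threshold `TOLERANCE`, the stopping rule on the means, and the helper names `weighted_mean` and `weighted_std` are all assumptions made for this example, not values from the source.

```python
import numpy as np
from scipy import stats

np.random.seed(110)  # same seed and parameters as the first cell
red = np.random.normal(3, 0.8, size=20)
blue = np.random.normal(7, 2, size=20)
both_colours = np.sort(np.concatenate((red, blue)))

def weighted_mean(data, weight):
    return np.sum(data * weight) / np.sum(weight)

def weighted_std(data, weight, mean):
    return np.sqrt(np.sum(weight * (data - mean) ** 2) / np.sum(weight))

# step 1: the same initial guesses as above
red_mean_guess, red_std_guess = 1.1, 2.0
blue_mean_guess, blue_std_guess = 9.0, 1.7

N_ITERATIONS = 20  # hypothetical iteration cap
TOLERANCE = 1e-6   # hypothetical convergence threshold on the means

for iteration in range(N_ITERATIONS):
    # step 2: likelihood of each point under the current guesses
    likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
    likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

    # step 3: normalise the likelihoods into weights
    likelihood_total = likelihood_of_red + likelihood_of_blue
    red_weight = likelihood_of_red / likelihood_total
    blue_weight = likelihood_of_blue / likelihood_total

    # step 4: new estimates (std first, using the previous means, as in the cell above)
    new_red_std = weighted_std(both_colours, red_weight, red_mean_guess)
    new_blue_std = weighted_std(both_colours, blue_weight, blue_mean_guess)
    new_red_mean = weighted_mean(both_colours, red_weight)
    new_blue_mean = weighted_mean(both_colours, blue_weight)

    # step 5: stop early once the means have stopped moving
    shift = max(abs(new_red_mean - red_mean_guess),
                abs(new_blue_mean - blue_mean_guess))
    red_mean_guess, red_std_guess = new_red_mean, new_red_std
    blue_mean_guess, blue_std_guess = new_blue_mean, new_blue_std
    if shift < TOLERANCE:
        break

print(f"red:  mean={red_mean_guess:.3f}, std={red_std_guess:.3f}")
print(f"blue: mean={blue_mean_guess:.3f}, std={blue_std_guess:.3f}")
```

Stopping on the change in the means is only one possible criterion. Tracking the change in the standard deviations as well, or in the total log-likelihood, would be a reasonable alternative.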

We see that the means are already converging on some values, and the shapes of the curves (governed by the standard deviation) are also becoming more stable.

The EM process has converged to the following values, which turn out to be very close to the actual values (where we can see the colours, i.e. no hidden variables):