In [1]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# Reference: http://nipunbatra.github.io/2014/04/em/

# One row per experiment: ten tosses of a single coin, written below as
# H (heads) / T (tails) and encoded as 1 / 0. The EM code counts the 1s
# in a row as heads.
observations = np.array(
    [
        [1 if toss == "H" else 0 for toss in experiment]
        for experiment in (
            "HTTTHHTHTH",
            "HHHHTHHHHH",
            "HTHHHHHTHH",
            "HTHTTTHHTT",
            "THHHTHHHTH",
        )
    ]
)

# Per-row coin label (presumably which coin produced each experiment --
# confirm against the referenced write-up); not used by em_single / em.
coins_id = np.array([False, True, True, False, True])

def em_single(priors, observations):
    """
    Performs a single EM step for a mixture of two biased coins.

    Each row of ``observations`` is one experiment whose tosses all come
    from the same coin (A or B). The sum of a row is taken as its number
    of heads and the remaining entries as tails. The E step computes, for
    every row, the posterior responsibility of each coin under the current
    head probabilities. The M step re-estimates each coin's head
    probability from the responsibility-weighted head and toss counts.

    Arguments
    ---------
    priors : [theta_A, theta_B]
        Current estimates of the head probability of coin A and coin B.
    observations : [m X n matrix] or sequence of 1-D 0/1 sequences
        One row of outcomes per experiment (1 = heads, 0 = tails).
        Rows may be plain lists and may differ in length.

    Returns
    --------
    new_priors: [new_theta_A, new_theta_B]

    Raises
    ------
    ValueError
        If ``observations`` contains no tosses at all.
    """
    theta_A = priors[0]
    theta_B = priors[1]

    rows = [np.asarray(observation) for observation in observations]
    num_heads = np.array([row.sum() for row in rows])
    num_tosses = np.array([len(row) for row in rows])
    if num_tosses.sum() == 0:
        raise ValueError("observations must contain at least one toss")

    # E step: responsibility of each coin for each row. Work in log space
    # and normalise with logaddexp. For long rows, the plain binomial pmf
    # underflows to 0.0 under both coins, which would give 0/0 = nan
    # weights. The binomial coefficient is identical for both coins and
    # cancels in the normalisation.
    log_lik_A = stats.binom.logpmf(num_heads, num_tosses, theta_A)
    log_lik_B = stats.binom.logpmf(num_heads, num_tosses, theta_B)
    log_norm = np.logaddexp(log_lik_A, log_lik_B)
    weight_A = np.exp(log_lik_A - log_norm)
    weight_B = np.exp(log_lik_B - log_norm)

    # M step: weighted fraction of heads per coin. Weighted heads + weighted
    # tails equals the weighted toss count, so divide by that directly.
    new_theta_A = np.sum(weight_A * num_heads) / np.sum(weight_A * num_tosses)
    new_theta_B = np.sum(weight_B * num_heads) / np.sum(weight_B * num_tosses)
    return [new_theta_A, new_theta_B]

def em(observations, prior, tol=1e-6, iterations=10000):
    """
    Repeats em_single until the parameter estimates stop changing.

    Arguments
    ---------
    observations : [m X n matrix]
        Passed unchanged to ``em_single`` at every step.
    prior : [theta_A, theta_B]
        Initial estimates of the two head probabilities.
    tol : float
        Convergence threshold. Iteration stops at the first step in which
        no parameter changes by ``tol`` or more.
    iterations : int
        Maximum number of non-converged steps to accept.

    Returns
    --------
    [new_prior, iteration]
        ``new_prior`` is the last estimate produced. It is the initial
        ``prior`` (as a list) when ``iterations`` <= 0. ``iteration`` is
        the number of steps that changed some parameter by at least
        ``tol``; on convergence ``em_single`` was called ``iteration + 1``
        times.
    """
    # Start from the initial guess, so iterations <= 0 returns it instead
    # of raising UnboundLocalError on the return below.
    new_prior = list(prior)
    iteration = 0
    while iteration < iterations:
        new_prior = em_single(prior, observations)
        # Use the largest change across *all* parameters; checking only
        # theta_A can stop while theta_B is still moving.
        delta_change = max(abs(new - old) for old, new in zip(prior, new_prior))
        if delta_change < tol:
            break
        prior = new_prior
        iteration += 1
    return [new_prior, iteration]

# Run EM from theta_A = 0.6, theta_B = 0.5 and show [[theta_A, theta_B], iteration].
print(em(observations, [0.6, 0.5]))


[[0.79678875938310978, 0.51958393567528027], 14]
