In [11]:
#Importing library
import numpy as np; #For mathematical operations
import matplotlib.pyplot as plt;
from tqdm import tqdm;

def generate_data(n, pi_1, mu_1, mu_2, v):
    data = np.zeros(n)
    for i in range(n):
        if np.random.rand() < pi_1:
            data[i] = np.random.normal(mu_1, np.sqrt(v))
        else:
            data[i] = np.random.normal(mu_2, np.sqrt(v))
    return data

pi_1 = 0.6
mu_1 = 1.0
mu_2 = 4.0
v = 0.5

data = generate_data(1000, pi_1, mu_1, mu_2, v)

n = len(data)
pi_1_hat = 0.5
mu_1_hat = np.mean(data)
mu_2_hat = np.mean(data) + 2
v_hat = np.var(data)

for i in range(100):
    # E step
    z = pi_1_hat * np.exp(-0.5 * (data - mu_1_hat)**2 / v_hat) / np.sqrt(2*np.pi*v_hat) \
    / ((1-pi_1_hat) * np.exp(-0.5 * (data - mu_2_hat)**2 / v_hat) / np.sqrt(2*np.pi*v_hat) \
        + pi_1_hat * np.exp(-0.5 * (data - mu_1_hat)**2 / v_hat) / np.sqrt(2*np.pi*v_hat))
    # M step
    pi_1_hat = np.sum(z) / n
    mu_1_hat = np.sum(z * data) / np.sum(z)
    mu_2_hat = np.sum((1 - z) * data) / np.sum(1 - z)
    v_hat = np.sum(z * (data - mu_1_hat)**2) / np.sum(z)

print("True Parameters: pi_1 = {}, mu_1 = {}, mu_2 = {}, v = {}".format(pi_1, mu_1, mu_2, v))
print("Estimated Parameters: pi_1 = {}, mu_1 = {}, mu_2 = {}, v = {}".format(pi_1_hat, mu_1_hat, mu_2_hat, v_hat))

True Parameters: pi_1 = 0.6, mu_1 = 1.0, mu_2 = 4.0, v = 0.5
Estimated Parameters: pi_1 = 0.6061072899699816, mu_1 = 1.022634004365792, mu_2 = 4.028892698078109, v = 0.49008621143917147


In [3]:
np.random.rand(100)

array([0.96302087, 0.04592029, 0.65748938, 0.97779399, 0.41489594,
       0.35116969, 0.23629073, 0.44888073, 0.26012716, 0.54543054,
       0.9827764 , 0.02945781, 0.06451938, 0.36528926, 0.97751362,
       0.80301335, 0.92131893, 0.20435805, 0.56322531, 0.20194406,
       0.90533459, 0.47244613, 0.27958007, 0.62811665, 0.151857  ,
       0.30483086, 0.89788558, 0.13625412, 0.8221293 , 0.55357815,
       0.55923884, 0.2131301 , 0.46129033, 0.95792499, 0.85500533,
       0.32203027, 0.18375572, 0.64742483, 0.05741616, 0.78958228,
       0.81916948, 0.02997211, 0.08599509, 0.57823371, 0.64189886,
       0.27687124, 0.7923104 , 0.80332621, 0.20871039, 0.2416921 ,
       0.34051123, 0.03522392, 0.5249886 , 0.46889612, 0.38459808,
       0.33943862, 0.72469184, 0.47104832, 0.87935132, 0.46531067,
       0.5967415 , 0.25805207, 0.29551403, 0.38676356, 0.08270498,
       0.31117048, 0.53093058, 0.59523236, 0.12325987, 0.1382707 ,
       0.2318588 , 0.26715672, 0.1325954 , 0.61606829, 0.48304