In [1]:
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt

In [158]:
#抽取数据
Z_i=[]
pi_i=[]
Y1_i=[]
Y2_i=[]
np.random.seed(1)
for i in range(200):
    Z=np.random.RandomState(i).binomial(1,0.2,size=200)
    pi=Z.sum()/len(Z)
    Y1=np.random.RandomState(i+1000).normal(0,1,int(pi*200))
    Y2=np.random.RandomState(i+6000).normal(2,2,200-(int(pi*200)))
    Z_i.append(Z)
    pi_i.append(pi)
    Y1_i.append(Y1)
    Y2_i.append(Y2)

In [159]:
#参数平均值
print('pi:%1.3f'%np.array(pi_i).mean())
print('\nmu1:%1.3f'%np.array(mu1_i).mean())
print('\nmu2:%1.3f'%np.array(mu2_i).mean())
print('\nsigma1:%1.3f'%np.array(sigma1_i).mean())
print('\nsigma2:%1.3f'%np.array(sigma2_i).mean())

pi:0.199

mu1:0.195

mu2:2.601

sigma1:1.055

sigma2:1.812


In [160]:
#创建模型
class EM:
    def __init__(self,pi=0.3,mu1=0.2,sigma1=1.1,mu2=0.8,sigma2=2.1,wi=np.ones(np.hstack((Y1,Y2)).shape[0])-0.5):
        self.pi=pi
        self.mu1=mu1
        self.sigma1=sigma1
        self.mu2=mu2
        self.sigma2=sigma2
        self.wi=wi
        
    def Normal_pdf(self,Y,mu,sigma):
        pdf=np.exp(-(Y-mu)**2/(2*sigma**2))/(math.sqrt(2*math.pi)*sigma)
        return pdf
    
    def E_step(self,Y1,Y2):
        Y=np.hstack((Y1,Y2))
        for i in range(len(Y)):
            self.wi[i]=(self.pi*self.Normal_pdf(Y[i],self.mu1,self.sigma1))/(self.pi*self.Normal_pdf(Y[i],self.mu1,self.sigma1)+(1-self.pi)*self.Normal_pdf(Y[i],self.mu2,self.sigma2))
        return self.wi

    def M_step(self,Y1,Y2):
        Y=np.hstack((Y1,Y2))
        self.mu1=sum(self.wi*Y)/sum(self.wi)
        self.mu2=sum((1-self.wi)*Y)/sum(1-self.wi)
        self.sigma1=math.sqrt((sum(self.wi*(Y-self.mu1)**2)/sum(self.wi)))
        self.sigma2=math.sqrt((sum((1-self.wi)*(Y-self.mu2)**2))/sum(1-self.wi))
        self.pi=self.wi.mean()
        return self.mu1,self.mu2,self.sigma1,self.sigma2,self.pi
        
    def fit(self,Y1,Y2):
        while True:
            self.wi=self.E_step(Y1,Y2)
            mu1_old=self.mu1
            mu2_old=self.mu2
            sigma1_old=self.sigma1
            sigma2_old=self.sigma2
            pi_old=self.pi
            M=self.M_step(Y1,Y2)
            self.mu1=M[0]
            self.mu2=M[1]
            self.sigma1=M[2]
            self.sigma2=M[3]
            self.pi=M[4]
            if abs((self.pi-pi_old)+(self.mu1-mu1_old)+(self.mu2-mu2_old)+(self.sigma1-sigma1_old)+(self.sigma2-sigma2_old)) <=0.0001:
                break             
        return self.mu1,self.mu2,self.sigma1,self.sigma2,self.pi

In [161]:
#循环200次计算结果
pi_e=[]
mu1_e=[]
mu2_e=[]
sigma1_e=[]
sigma2_e=[]
for i in range(200):
    em=EM()
    result=em.fit(Y1_i[i],Y2_i[i])
    mu1=result[0]
    mu2=result[1]
    sigma1=result[2]
    sigma2=result[3]
    pi=result[4]
    mu1_e.append(mu1)
    mu2_e.append(mu2)
    sigma1_e.append(sigma1)
    sigma2_e.append(sigma2)
    pi_e.append(pi)

In [167]:
#输出结果
print('estimated_mu1:%1.3f\treal_mu1:%1.3f'%(np.array(mu1_e).mean(),np.array(mu1_i).mean()))
print('\nestimated_mu2:%1.3f\treal_mu2:%1.3f'%(np.array(mu2_e).mean(),np.array(mu1_i).mean()))
print('\nestimated_sigma1:%1.4f\treal_sigma1:%1.3f'%(np.array(sigma1_e).mean(),np.array(sigma1_i).mean()))
print('\nestimated_sigma2:%1.4f\treal_sigma2:%1.3f'%(np.array(sigma2_e).mean(),np.array(sigma2_i).mean()))
print('\nestimated_pi:%1.3f\treal_pi:%1.3f'%(np.array(pi_e).mean(),np.array(pi_i).mean()))

estimated_mu1:0.195	real_mu1:0.195

estimated_mu2:2.601	real_mu2:0.195

estimated_sigma1:1.0553	real_sigma1:1.055

estimated_sigma2:1.8116	real_sigma2:1.812

estimated_pi:0.359	real_pi:0.199


In [169]:
#各个参数的估计值的情况
df=pd.DataFrame((mu1_e,mu2_e,sigma1_e,sigma2_e,pi_e)).T
df.columns=(['mu1','mu2','sigma1','sigma2','pi'])
df.describe()

Unnamed: 0,mu1,mu2,sigma1,sigma2,pi
count,200.0,200.0,200.0,200.0,200.0
mean,0.195285,2.600775,1.0553,1.811582,0.358892
std,0.583863,1.087568,0.414965,0.437609,0.24871
min,-1.242862,1.287023,0.049485,0.332304,0.047621
25%,-0.178072,1.932636,0.730258,1.69739,0.1803
50%,0.126365,2.17516,1.07368,1.928249,0.265072
75%,0.564289,2.86185,1.340437,2.043756,0.507621
max,1.805352,7.186867,1.885471,3.928376,0.981863
