In [1]:
import numpy as np

In [2]:
# hyper_parameters
max_iter = 1000
N = 100000

In [3]:
class Environment:
    def __init__( self, N, s12, pqr ):
        self.N = N
        self.s1, self.s2 = s12
        self.pqr = pqr

    def __start_experiment( self ):
        self.coins = np.random.rand( self.N )
        for i in range( self.N ):
            if( self.coins[ i ] < self.s1 ):
                self.coins[i] = 0
            elif( self.coins[i] < self.s1 + self.s2 ):
                self.coins[i] = 1
            else:
                
                self.coins[i] = 2
        self.coins = self.coins.astype(np.int64)
        self.seq = np.random.rand( self.N )
        self.seq = [ ( self.seq[i] < self.pqr[ self.coins[i] ] ).astype(int)  for i in range( self.N )]
    
    def observe( self ):
        self.__start_experiment()
        return np.array(self.seq)

In [4]:
np.random.seed(42)

In [5]:
s123 = np.random.rand(3)
s123 = s123 / s123.sum(0)
pqr = np.random.rand(3)
env = Environment( N, s123[:2], pqr )

In [6]:
# random initial parameters
s123_iter = np.random.rand(3)
s123_iter = s123_iter / s123_iter.sum(0)
pqr_iter = np.random.rand(3)

print("before EM")
print(s123_iter)
print(s123)
print(s123_iter - s123)
print(pqr_iter)
print(pqr)
print(pqr_iter - pqr)

# optimized solution
for j in range(max_iter):
    bits_seq = env.observe()
    N = len( bits_seq )
#     print(pqr_iter)
    sum_ones = bits_seq.sum(0)
    one_increment = pqr_iter * s123_iter 
#     print(one_increment)
    one_increment = one_increment / one_increment.sum(0)
    zero_increment = ( 1 - pqr_iter ) * s123_iter 
#     print(zero_increment)
    zero_increment = zero_increment / zero_increment.sum(0)
    coins_k = np.stack([ zero_increment * ( N - sum_ones ), one_increment * sum_ones ])
#     print(coins_k)
    pqr_iter = coins_k[1, :]/coins_k.sum(0)
    if(np.min(pqr_iter)<0):
        pqr_iter -= np.min(pqr_iter)
    if(np.max(pqr_iter)>1.0):
        pqr_iter = pqr_iter / ( np.max(pqr_iter) - np.min(pqr_iter) )
#     print(coins_k)
    s123_iter = coins_k.sum(0)
    if(np.min(s123_iter)<0):
        s123_iter -= np.min(s123_iter)
    s123_iter = s123_iter / s123_iter.sum(0)
#     print(s123_iter)
#     print(s123)
#     print(s123_iter - s123)
#     print(pqr_iter)
#     print(pqr)
#     print(pqr_iter - pqr)

print("after EM")
print(s123_iter)
print(s123)
print(s123_iter - s123)
print(pqr_iter)
print(pqr)
print(pqr_iter - pqr)

before EM
[0.03807826 0.56784481 0.39407693]
[0.18205878 0.46212909 0.35581214]
[-0.14398052  0.10571572  0.03826479]
[0.70807258 0.02058449 0.96990985]
[0.59865848 0.15601864 0.15599452]
[ 0.10941409 -0.13543415  0.81391533]
after EM
[0.02984444 0.73897847 0.23117709]
[0.18205878 0.46212909 0.35581214]
[-0.15221433  0.27684939 -0.12463505]
[0.50950784 0.00892067 0.9324533 ]
[0.59865848 0.15601864 0.15599452]
[-0.08915064 -0.14709797  0.77645878]


In [7]:
# initial parameters near groundtruth
s123_sigma = 1e-1
s123_iter = np.random.normal(s123, s123_sigma)
s123_iter = np.abs(s123_iter)
s123_iter = s123_iter / s123_iter.sum(0)
pqr_sigma = 1e-1
pqr_iter = np.random.normal(pqr, pqr_sigma)
print("before EM")
print(s123_iter)
print(s123)
print(s123_iter - s123)
print(pqr_iter)
print(pqr)
print(pqr_iter - pqr)

# optimized solution
for j in range(max_iter):
    bits_seq = env.observe()
    N = len( bits_seq )
#     print(pqr_iter)
    sum_ones = bits_seq.sum(0)
    one_increment = pqr_iter * s123_iter 
#     print(one_increment)
    one_increment = one_increment / one_increment.sum(0)
    zero_increment = ( 1 - pqr_iter ) * s123_iter 
#     print(zero_increment)
    zero_increment = zero_increment / zero_increment.sum(0)
    coins_k = np.stack([ zero_increment * ( N - sum_ones ), one_increment * sum_ones ])
#     print(coins_k)
    pqr_iter = coins_k[1, :]/coins_k.sum(0)
    if(np.min(pqr_iter)<0):
        pqr_iter -= np.min(pqr_iter)
    if(np.max(pqr_iter)>1):
        pqr_iter = pqr_iter / ( np.max(pqr_iter) - np.min(pqr_iter) )
#     print(coins_k)
    s123_iter = coins_k.sum(0)
    if(np.min(s123_iter)<0):
        s123_iter -= np.min(s123_iter)
    s123_iter = s123_iter / s123_iter.sum(0)
#     print(s123_iter)
#     print(s123)
#     print(s123_iter - s123)
#     print(pqr_iter)
#     print(pqr)
#     print(pqr_iter - pqr)

print("after EM")
print(s123_iter)
print(s123)
print(s123_iter - s123)
print(pqr_iter)
print(pqr)
print(pqr_iter - pqr)

before EM
[0.159291   0.41403962 0.42666938]
[0.18205878 0.46212909 0.35581214]
[-0.02276777 -0.04808947  0.07085724]
[0.39676982 0.08162467 0.12740918]
[0.59865848 0.15601864 0.15599452]
[-0.20188867 -0.07439397 -0.02858534]
after EM
[0.18516575 0.39492868 0.41990558]
[0.18205878 0.46212909 0.35581214]
[ 0.00310697 -0.06720041  0.06409344]
[0.53305352 0.13364302 0.20218184]
[0.59865848 0.15601864 0.15599452]
[-0.06560496 -0.02237562  0.04618732]


In [8]:
print(s123_iter)
print(s123)
print(s123_iter - s123)
print(pqr_iter)
print(pqr)
print(pqr_iter - pqr)

[0.18516575 0.39492868 0.41990558]
[0.18205878 0.46212909 0.35581214]
[ 0.00310697 -0.06720041  0.06409344]
[0.53305352 0.13364302 0.20218184]
[0.59865848 0.15601864 0.15599452]
[-0.06560496 -0.02237562  0.04618732]


In [9]:
# # original solution
# # too slow
# s123_iter = np.random.rand(3)
# s123_iter = s123_iter / s123_iter.sum(0)
# pqr_iter = np.random.rand(3)
# N = len( bits_seq )
# for j in range(max_iter):
#     coins_k = np.zeros( ( 3, 2 ) )
#     for i in range( N ):
#         if( bits_seq[i] ):
#             likelihood_k = pqr_iter
#             posteriori_k = likelihood_k * s123_iter
#             posteriori_k = posteriori_k / posteriori_k.sum(0)
#             coins_k[:, 1] = coins_k[:, 1] + posteriori_k
#         else:
#             likelihood_k = 1 - pqr_iter
#             posteriori_k = likelihood_k * s123_iter
#             posteriori_k = posteriori_k / posteriori_k.sum(0)
#             coins_k[:, 0] = coins_k[:, 0] + posteriori_k
#     pqr_iter = coins_k[:, 1]/coins_k.sum(1)
#     s123_iter = coins_k.sum(1)
#     s123_iter = s123_iter / s123_iter.sum(0)
#     print(pqr_iter)
#     print(s123_iter)