Commit 4009f0e

parallel in all expectation steps.
Junchen Feng committed Aug 15, 2017
1 parent e1a9fa9 · commit 4009f0e
Showing 1 changed file with 50 additions and 23 deletions.
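
The new code in this commit is built around one fork/join recipe: cut an index range into per-CPU chunks, let each worker write its results into a shared multiprocessing.Manager dict, then copy that dict back into numpy arrays in the parent. A minimal, self-contained sketch of that recipe (written in Python 3 for brevity, while the diff itself is Python 2; `work`, `NUM_ITEMS`, and this inline `cut_list` are illustrative stand-ins, not code from the repository):

import multiprocessing as mp

NUM_ITEMS = 100   # illustrative problem size
NUM_CHUNKS = 4    # stand-in for self.num_cpu

def cut_list(n, k):
    # split range(n) into k contiguous (start, end) chunks;
    # pyirt's tools.cut_list is assumed to behave roughly like this
    step = (n + k - 1) // k
    return [(i, min(i + step, n)) for i in range(0, n, step)]

def work(shared, start_idx, end_idx):
    # each worker fills its own slice of the shared dict
    for idx in range(start_idx, end_idx):
        shared[idx] = idx * idx  # stand-in for the real E-step computation

if __name__ == '__main__':
    manager = mp.Manager()
    repo = manager.dict()
    procs = [mp.Process(target=work, args=(repo, a, b))
             for a, b in cut_list(NUM_ITEMS, NUM_CHUNKS)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()   # the diff delegates start/join (plus a timeout) to procs_operator
    results = [repo[i] for i in range(NUM_ITEMS)]
    print(results[:5])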
pyirt/solver/model.py: 73 changes (50 additions & 23 deletions)
@@ -350,8 +350,9 @@ def update(d, start_idx, end_idx):
         else:
             procs = procs_operator(procs, 7200, 0.1)

-        for user_idx in tqdm(xrange(num_user)):
+        for user_idx in xrange(num_user):
             self.posterior_theta_distr[user_idx,:] = procs_repo[user_idx]
+
         # When the loop finishes, check that the theta_density adds up to unity for each user
         check_user_distr_marginal = np.sum(self.posterior_theta_distr, axis=1)
         if any(abs(check_user_distr_marginal - 1.0) > 0.0001):
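
The check that closes this hunk guards the parallel fill-in: every row of `posterior_theta_distr` must still sum to one after being copied out of `procs_repo`. A toy illustration of the same invariant:

import numpy as np

posterior = np.full((3, 5), 0.2)          # toy posterior: 3 users x 5 theta points
marginals = np.sum(posterior, axis=1)     # row sums, as in check_user_distr_marginal
assert np.allclose(marginals, 1.0, atol=0.0001)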
@@ -365,21 +366,52 @@ def __check_theta_density(self):
             raise Exception('theta density has wrong shape (%s,%s)'%self.theta_density.shape)

     def __get_expect_count(self):
+        def update(d, start_idx, end_idx):
+            try:
+                if self.dao_type == 'db':
+                    client = self.dao.open_conn()
+                    item2user_conn = client[self.dao.db_name][self.dao.item2user_collection_name]
+                for item_idx in tqdm(xrange(start_idx, end_idx)):
+                    if self.dao_type == 'db':
+                        map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'], item2user_conn)
+                    else:
+                        map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'])
+                    d[item_idx] = {
+                        1: np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0),
+                        0: np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0)
+                    }
+                if self.dao_type == 'db':
+                    client.close()
+            except:
+                raise
+
+        num_item = self.dao.get_num('item')
+        if num_item > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
+
+        # [A] calculate p(data,param|theta)
+        chunk_list = tools.cut_list(num_item, num_chunk)
+        procs = []
+        manager = mp.Manager()
+        procs_repo = manager.dict()
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
+            procs.append(p)
+
+        if num_chunk>1:
+            procs = procs_operator(procs, 5400, self.check_interval)
+        else:
+            procs = procs_operator(procs, 7200, 0.1)
+
         self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
         self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
-        num_item = self.dao.get_num('item')
-        if self.dao_type == 'db':
-            client = self.dao.open_conn()
-            item2user_conn = client[self.dao.db_name][self.dao.item2user_collection_name]
-        for item_idx in tqdm(xrange(num_item)):
-            if self.dao_type == 'db':
-                map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'], item2user_conn)
-            else:
-                map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'])
-            self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0)
-            self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0)
-        if self.dao_type == 'db':
-            client.close()
+        for item_idx in xrange(num_item):
+            self.item_expected_right_by_theta[:, item_idx] = procs_repo[item_idx][1]
+            self.item_expected_wrong_by_theta[:, item_idx] = procs_repo[item_idx][0]
+
+

     def __calc_data_likelihood(self):
         # calculate the likelihood for the data set
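
In the rewritten `__get_expect_count`, each worker stores, per item, the posterior column sums over the users who answered correctly (key 1) and incorrectly (key 0); the parent then unpacks those into `item_expected_right_by_theta` and `item_expected_wrong_by_theta`. A toy example of that per-item aggregation, assuming `posterior_theta_distr` is users x theta levels:

import numpy as np

posterior = np.array([[0.7, 0.3],    # user 0
                      [0.4, 0.6],    # user 1
                      [0.1, 0.9]])   # user 2
right_users, wrong_users = [0, 2], [1]   # toy answer map for a single item

expected_right = np.sum(posterior[right_users, :], axis=0)  # [0.8, 1.2]
expected_wrong = np.sum(posterior[wrong_users, :], axis=0)  # [0.4, 0.6]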
@@ -412,34 +444,29 @@ def update(tot_llk, cnt, start_idx, end_idx):
                 tot_llk.value += ell/len(logs)
                 with cnt.get_lock():
                     cnt.value += 1
-
                 if self.dao_type == 'db':
                     client.close()
             except:
                 raise

         theta_vec = self.__calc_theta()
         num_user = self.dao.get_num('user')

-
-        if num_user > self.num_cpu and self.is_parallel:
-            num_chunk = self.num_cpu
-        else:
-            num_chunk = 1
+        num_chunk = 1
         user_ell = mp.Value('d', 0.0)
         user_cnt = mp.Value('i', 0)
         chunk_list = tools.cut_list(num_user, num_chunk)
-
         procs = []
         for i in range(num_chunk):
             p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0],chunk_list[i][1],))
             procs.append(p)
-
         if num_chunk>1:
             procs = procs_operator(procs, 1200, self.check_interval)
         else:
             procs = procs_operator(procs, 7200, 0.1)

         avg_ell = user_ell.value/user_cnt.value

         return avg_ell
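
`procs_operator` is called in both branches of every hunk but is not part of this diff. Judging only from its call sites, `procs_operator(procs, timeout, check_interval)` starts the workers and polls them until completion or timeout. A hypothetical reconstruction, offered as a reading aid rather than pyirt's actual implementation:

import time

def procs_operator(procs, timeout, check_interval):
    # hypothetical sketch inferred from the call sites above
    for p in procs:
        p.start()
    deadline = time.time() + timeout
    while any(p.is_alive() for p in procs):
        if time.time() > deadline:
            for p in procs:
                p.terminate()
            raise Exception('process timed out')
        time.sleep(check_interval)
    for p in procs:
        p.join()
    return procs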