diff --git a/pyirt/solver/model.py b/pyirt/solver/model.py
index a7d4ccd..5192c0f 100644
--- a/pyirt/solver/model.py
+++ b/pyirt/solver/model.py
@@ -350,8 +350,9 @@ def update(d, start_idx, end_idx):
         else:
             procs = procs_operator(procs, 7200, 0.1)
 
-        for user_idx in tqdm(xrange(num_user)):
+        for user_idx in xrange(num_user):
             self.posterior_theta_distr[user_idx,:] = procs_repo[user_idx]
+
         # When the loop finishes, check if the theta_density adds up to unity for each user
         check_user_distr_marginal = np.sum(self.posterior_theta_distr, axis=1)
         if any(abs(check_user_distr_marginal - 1.0) > 0.0001):
@@ -365,21 +366,52 @@ def __check_theta_density(self):
             raise Exception('theta density has wrong shape (%s,%s)'%self.theta_density.shape)
 
     def __get_expect_count(self):
+        def update(d, start_idx, end_idx):
+            try:
+                if self.dao_type == 'db':
+                    client = self.dao.open_conn()
+                    item2user_conn = client[self.dao.db_name][self.dao.item2user_collection_name]
+                for item_idx in tqdm(xrange(start_idx, end_idx)):
+                    if self.dao_type == 'db':
+                        map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'], item2user_conn)
+                    else:
+                        map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'])
+                    d[item_idx] = {
+                        1: np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0),
+                        0: np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0)
+                    }
+                if self.dao_type == 'db':
+                    client.close()
+            except:
+                raise
+
+        num_item = self.dao.get_num('item')
+        if num_item > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
+
+        # [A] calculate p(data,param|theta)
+        chunk_list = tools.cut_list(num_item, num_chunk)
+        procs = []
+        manager = mp.Manager()
+        procs_repo = manager.dict()
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
+            procs.append(p)
+
+        if num_chunk>1:
+            procs = procs_operator(procs, 5400, self.check_interval)
+        else:
+            procs = procs_operator(procs, 7200, 0.1)
+
         self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
         self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
-        num_item = self.dao.get_num('item')
-        if self.dao_type == 'db':
-            client = self.dao.open_conn()
-            item2user_conn = client[self.dao.db_name][self.dao.item2user_collection_name]
-        for item_idx in tqdm(xrange(num_item)):
-            if self.dao_type == 'db':
-                map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'], item2user_conn)
-            else:
-                map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'])
-            self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0)
-            self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0)
-        if self.dao_type == 'db':
-            client.close()
+        for item_idx in xrange(num_item):
+            self.item_expected_right_by_theta[:, item_idx] = procs_repo[item_idx][1]
+            self.item_expected_wrong_by_theta[:, item_idx] = procs_repo[item_idx][0]
+
+
 
     def __calc_data_likelihood(self):
         # calculate the likelihood for the data set
@@ -412,6 +444,7 @@ def update(tot_llk, cnt, start_idx, end_idx):
                     tot_llk.value += ell/len(logs)
                 with cnt.get_lock():
                     cnt.value += 1
+
                 if self.dao_type == 'db':
                     client.close()
             except:
                 raise
@@ -419,27 +452,21 @@
 
         theta_vec = self.__calc_theta()
         num_user = self.dao.get_num('user')
-
-
         if num_user > self.num_cpu and self.is_parallel:
             num_chunk = self.num_cpu
         else:
-            num_chunk = 1
-
+            num_chunk = 1
         user_ell = mp.Value('d', 0.0)
         user_cnt = mp.Value('i', 0)
-        chunk_list = tools.cut_list(num_user, num_chunk)
-
+        chunk_list = tools.cut_list(num_user, num_chunk)
         procs = []
         for i in range(num_chunk):
             p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0],chunk_list[i][1],))
-            procs.append(p)
-
+            procs.append(p)
         if num_chunk>1:
             procs = procs_operator(procs, 1200, self.check_interval)
         else:
             procs = procs_operator(procs, 7200, 0.1)
-
 
         avg_ell = user_ell.value/user_cnt.value
         return avg_ell
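
Note on the __get_expect_count rewrite: each worker fills a disjoint key range of an mp.Manager().dict(), and the parent stitches the per-item vectors back into dense arrays after the join. Below is a minimal standalone sketch of that fan-out pattern, not the patch itself; it assumes Python 3 (range rather than xrange), and the names worker, n_items, n_theta, and the toy posterior array are all hypothetical.

    import multiprocessing as mp

    import numpy as np


    def worker(repo, start_idx, end_idx, posterior):
        # Each process owns a disjoint [start_idx, end_idx) slice, so no two
        # workers ever write to the same key of the shared dict.
        for item_idx in range(start_idx, end_idx):
            repo[item_idx] = {
                1: posterior[item_idx] * 0.7,  # stand-in for the expected-right vector
                0: posterior[item_idx] * 0.3,  # stand-in for the expected-wrong vector
            }


    if __name__ == '__main__':
        n_items, n_theta = 100, 4
        posterior = np.random.rand(n_items, n_theta)

        manager = mp.Manager()
        repo = manager.dict()  # proxy dict; values are pickled on assignment
        procs = [mp.Process(target=worker, args=(repo, lo, hi, posterior))
                 for lo, hi in [(0, 50), (50, 100)]]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

        # Stitch per-item vectors back into a (n_theta, n_items) array, as the
        # patch does with item_expected_right_by_theta / _wrong_by_theta.
        expected_right = np.stack([repo[i][1] for i in range(n_items)], axis=1)

One property this relies on: each key is written exactly once with a complete value. Mutating a value already stored in the proxy dict (e.g. repo[i][1] = ...) would only change a local unpickled copy, and every assignment costs an IPC round-trip to the manager process, so the pattern fits the one-write-per-item access the patch performs.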
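
The __calc_data_likelihood changes use the other common multiprocessing idiom: two mp.Value scalars accumulated under their built-in locks and combined after the join. Again a sketch under the same assumptions (Python 3; the per-user log-likelihood below is fabricated, only the locking pattern matters):

    import multiprocessing as mp


    def worker(tot_llk, cnt, start_idx, end_idx):
        for user_idx in range(start_idx, end_idx):
            ell = -0.1 * (user_idx + 1)   # stand-in for one user's log-likelihood
            with tot_llk.get_lock():      # += is a read-modify-write, so lock it
                tot_llk.value += ell
            with cnt.get_lock():
                cnt.value += 1


    if __name__ == '__main__':
        tot_llk = mp.Value('d', 0.0)  # 'd' = C double, like user_ell in the patch
        cnt = mp.Value('i', 0)        # 'i' = C int, like user_cnt
        procs = [mp.Process(target=worker, args=(tot_llk, cnt, lo, hi))
                 for lo, hi in [(0, 50), (50, 100)]]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

        avg_ell = tot_llk.value / cnt.value
        print(avg_ell)

Without the get_lock() blocks, concurrent += updates could drop increments. Note also that both the sketch and the patch divide by the final count, so if no user gets processed (cnt still 0) the division raises ZeroDivisionError.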