From 1bdeb03223369adccfc4ab8f843a353dc94873ca Mon Sep 17 00:00:00 2001
From: junchenfeng
Date: Tue, 8 Aug 2017 17:21:10 +0800
Subject: [PATCH 1/7] Add the mongoDAO module
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md                  | 58 +++++++++++++++--------------------
 pyirt/_pyirt.py            | 12 ++++++--
 pyirt/dao.py               | 63 +++++++++++++++++++++++++++++++++++++-
 setup.py                   |  3 +-
 test/test_model_wrapper.py | 17 +++++-----
 5 files changed, 106 insertions(+), 47 deletions(-)

diff --git a/README.md b/README.md
index 9af865b..816e5f5 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@ src_fp = open(file_path,'r')
 item_param, user_param = irt(src_fp)
 
 # (2)Supply bounds
-item_param, user_param = irt(src_fp, theta_bnds = [-5,5], alpha_bnds=[0.1,3], beta_bnds = [-3,3])
+item_param, user_param = irt(src_fp, theta_bnds = [-4,4], alpha_bnds=[0.1,3], beta_bnds = [-3,3])
 
 # (3)Supply guess parameter
 guessParamDict = {1:{'c':0.0}, 2:{'c':0.25}}
@@ -44,12 +44,12 @@ The current version supports the MMLE algorithm and the unidimensional
 two-parameter IRT model. There is a backdoor method to specify the guess
 parameter but there is no active estimation.
 
-The prior distribution of theta is **uniform**.
-
 There is no regularization in the alpha and beta estimation. Therefore, the
 default algorithm puts boundaries on the parameters to prevent over-fitting
 and to deal with extreme cases where almost all responses to an item are right.
 
+The unit test code shows that when there are discriminative items to anchor theta, the model converges reasonably well to the true values.
+
 ## Theta estimation
 The package offers two methods to estimate theta, given item parameters: Bayesian and MLE.
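For intuition, here is a minimal numpy sketch of the grid-based Bayesian estimate (a posterior mean over a discrete theta grid), which mirrors the package's internal logic. It is not the pyirt API; the function and parameter names below are illustrative only.

```python
import numpy as np

def irt_prob(theta, alpha, beta, c=0.0):
    # P(correct | theta) under the 2PL model (c > 0 gives the 3PL variant)
    return c + (1.0 - c) / (1.0 + np.exp(-alpha * (theta - beta)))

def bayesian_theta(logs, item_param, theta_grid, prior):
    # logs: [(item_id, 0/1 result flag)]; prior: density over theta_grid
    log_post = np.log(prior)
    for item_id, ans in logs:
        p = irt_prob(theta_grid, item_param[item_id]['alpha'], item_param[item_id]['beta'])
        log_post += np.log(p) if ans == 1 else np.log(1.0 - p)
    post = np.exp(log_post - log_post.max())  # normalize in log space for stability
    return np.dot(post / post.sum(), theta_grid)  # posterior mean

theta_grid = np.linspace(-4, 4, 11)  # default range with 11 grid points
prior = np.ones_like(theta_grid) / len(theta_grid)
item_param = {'q0': {'alpha': 1.0, 'beta': 0.0}, 'q1': {'alpha': 2.0, 'beta': 1.0}}
print(bayesian_theta([('q0', 1), ('q1', 0)], item_param, theta_grid, prior))
```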
The estimation procedure is quite primitive. For examples, see the test case.

@@ -57,28 +57,28 @@ The estimation procedure is quite primitive. For examples, see the test case.
 II.Sparse Data Structure
 ==========
 
-In non-test learning dataset, missing data is the common. Not all students
-finish all the items. When the number of students and items are large, the data
-can be extremely sparse.
+In non-test learning datasets, missing data are common. Not all students finish all the items. When the numbers of students and items are large, the data can be extremely sparse.
 
 The package deals with the sparse structure in two ways:
 
-- Efficient memory storage. Use collapsed list to index data. The memory usage
-  is about 3 times of the text data file. If the workstation has 6G free
-memory, it can handle 2G data file. Most other IRT package will definitely
-break.
-- No joint estimation. Under IRT's conditional independence assumption,
-  estimate each item's parameter is consistent but inefficient. To avoid
-reverting a giant jacobian matrix, the item parameters are estimated seprately.
+- Efficient memory storage.
+
+A collapsed list is used to index the data. The memory usage is about 3 times the size of the text data file. If the workstation has 6G of free memory, it can handle a 2G data file.
+
+If the data are truly big, say billions of records, a mongo DAO is implemented to avoid overflowing memory.
+
+- No joint estimation.
+Under the conditional independence assumption, estimating each item's parameters separately is consistent but inefficient.
 
-III.Default Config
+
+III.Default Config
 ===========
 ## Exposed
 The theta parameter ranges over [-4,4] with a step size of 0.8 by default.
 
-Alpha is bounded by [0.25,2] and beta is bounded by [-2,2], which goes with the user ability
-specifiation.
+Alpha is bounded by [0.25,2] and beta is bounded by [-2,2].
+
 Guess parameter is set to 0 by default, but one could supply a dictionary with eid as key and c as value.
 
@@ -88,37 +88,29 @@ The default solver is L-BFGS-B.
 
 The default max iteration is 10.
 
-The stop condition threshold is 1e-3 by default. The algorithm computes the
-average likelihood per log at the end of the iteration. If the likelihood
-increament is less than the threshold, stops.
+The stop condition threshold is 1e-3 by default. The stop condition tracks the geometric mean of the per-user likelihood, which is independent of the data size.
 
-IV.Data Format
+
+IV.Data Format
 =========
-The file is expected to be comma delimited.
+If the data are ingested through the file system, the file is expected to be **comma delimited**.
+The three columns are learner id, item id and result flag.
+Currently the model only works well with a 0/1 flag but will **NOT** raise an error for other types.
 
-The three columns are uid, eid, result-flag.
+If using the database DAO, see the unit tests for an example.
 
-Currently the model only works well with 0/1 flag but will **NOT** raise error for
-other types.
 
 V.Note
 =======
 ## Minimization solver
-The scipy minimize is as good as cvxopt.cp and matlab fmincon on item parameter
-estimation to the 6th decimal point, which can be viewed as identical for all
-practical purposes.
+scipy's minimize is as good as cvxopt.cp and Matlab's fmincon on item parameter estimation to the 6th decimal point, which can be viewed as identical for all practical purposes.
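As a rough illustration of that equivalence claim, here is a self-contained sketch that fits a single 2PL item with scipy's bounded L-BFGS-B on simulated responses. It is not pyirt's optimizer code; all names in it are made up for the example, and the bounds simply mirror the defaults above.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.RandomState(42)
theta = rng.normal(0.0, 1.0, 10000)      # ~10k observations for one item
true_alpha, true_beta = 1.5, -0.5
y = rng.binomial(1, 1.0 / (1.0 + np.exp(-true_alpha * (theta - true_beta))))

def neg_log_lik(x):
    beta, alpha = x
    p = 1.0 / (1.0 + np.exp(-alpha * (theta - beta)))
    p = np.clip(p, 1e-10, 1.0 - 1e-10)   # guard against log(0)
    return -np.sum(y * np.log(p) + (1 - y) * np.log(1.0 - p))

# bounded search, mirroring the default beta in [-2,2] and alpha in [0.25,2]
res = minimize(neg_log_lik, x0=[0.0, 1.0], method='L-BFGS-B',
               bounds=[(-2.0, 2.0), (0.25, 2.0)])
print(res.x)  # recovered (beta, alpha)
```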
-However, the convergence is pretty slow. It requires about 10k obeverations per
-item to recover the parameter to the 0.01 precision.
+However, the convergence is pretty slow. It requires about 10k observations per item to recover the parameters to 0.01 precision.
 
 VI.Acknowledgement
 ==============
-The algorithm is described in details by Bradey Hanson(2000), see in the
-literature section. I am grateful to Mr.Hanson's work.
+The algorithm is described in detail by Bradley Hanson (2000); see the literature section. I am grateful for Mr. Hanson's work.
 
 [Chaoqun Fu](https://github.com/fuchaoqun)'s comment leads to the (much better) API design.
diff --git a/pyirt/_pyirt.py b/pyirt/_pyirt.py
index 8708643..78b31f9 100644
--- a/pyirt/_pyirt.py
+++ b/pyirt/_pyirt.py
@@ -3,17 +3,23 @@ from .dao import localDAO
 
 def irt(data_src,
+        dao = 'memory',
         theta_bnds=[-4, 4], num_theta=11,
         alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param='default',
         model_spec='2PL',
-        max_iter=10, tol=1e-3, nargout=2):
+        max_iter=10, tol=1e-3, nargout=2,
+        is_msg=False):
 
     # load data
-    dao_instance = localDAO(data_src)
+    if dao=='memory':
+        dao_instance = localDAO(data_src)
+    else:
+        dao_instance = data_src
+
     # setup the model
     if model_spec == '2PL':
-        mod = model.IRT_MMLE_2PL(dao_instance)
+        mod = model.IRT_MMLE_2PL(dao_instance, is_msg=is_msg)
     else:
         raise Exception('Unknown model specification.')
diff --git a/pyirt/dao.py b/pyirt/dao.py
index b1ead4b..2f927db 100644
--- a/pyirt/dao.py
+++ b/pyirt/dao.py
@@ -6,7 +6,68 @@
 from .util.dao import loadFromHandle, loadFromTuples, construct_ref_dict
 
-#TODO: bitmap is a function of DAO. Seperate that with database
+import pymongo
+
+class mongoDAO(object):
+    def __init__(self, connect_config, num_log, group_id=1):
+        user_name = connect_config['user']
+        password = connect_config['password']
+        address = connect_config['address']  # IP:PORT
+        db_name = connect_config['db']
+        if 'authsource' not in connect_config:
+            mongouri = 'mongodb://{un}:{pw}@{addr}'.format(un=user_name, pw=password, addr=address)
+        else:
+            authsource = connect_config['authsource']
+            mongouri = 'mongodb://{un}:{pw}@{addr}/?authsource={auth_src}'.format(un=user_name, pw=password, addr=address, auth_src=authsource)
+        try:
+            self.client = pymongo.MongoClient(mongouri, serverSelectionTimeoutMS=10)
+        except:
+            raise
+
+        user2item_collection_name = 'irt_user2item'
+        item2user_collection_name = 'irt_item2user'
+
+        self.user2item = self.client[db_name][user2item_collection_name]
+        self.item2user = self.client[db_name][item2user_collection_name]
+
+        # TODO: avoid a full collection scan here
+        user_ids = self.user2item.find().distinct('id')
+        item_ids = self.item2user.find().distinct('id')
+
+        _, self.user_idx_ref, self.user_reverse_idx_ref = construct_ref_dict(user_ids)
+        _, self.item_idx_ref, self.item_reverse_idx_ref = construct_ref_dict(item_ids)
+
+        self.stat = {'log':num_log, 'user':len(self.user_idx_ref.keys()), 'item':len(self.item_idx_ref.keys())}
+
+    def get_num(self, name):
+        if name not in ['user','item','log']:
+            raise Exception('Unknown stat source %s'%name)
+        return self.stat[name]
+
+    def get_log(self, user_idx):
+        user_id = self.translate('user', user_idx)
+        res = self.user2item.find({'id':user_id})
+        log_list = res[0]['data']
+        return [(self.item_idx_ref[x[0]], x[1]) for x in log_list]
+
+    def get_right_map(self, item_idx):
+        item_id = self.translate('item', item_idx)
+        res = self.item2user.find({'id':item_id})
+        return [self.user_idx_ref[x] for x in res[0]['data']['1']]
+
+    def get_wrong_map(self, item_idx):
+        item_id = self.translate('item', item_idx)
+        res = self.item2user.find({'id':item_id})
+        return [self.user_idx_ref[x] for x in res[0]['data']['0']]
+
+    def translate(self, data_type, 
idx): + if data_type == 'item': + return self.item_reverse_idx_ref[idx] + elif data_type == 'user': + return self.user_reverse_idx_ref[idx] + + def __del__(self): + self.client.close() class localDAO(object): diff --git a/setup.py b/setup.py index 0db1931..9ded2e6 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,8 @@ install_requires=['numpy', 'scipy', 'cython', - 'six'], + 'six', + 'pymongo'], package_data={'pyirt': ["*.pyx"]}, diff --git a/test/test_model_wrapper.py b/test/test_model_wrapper.py index 746e36a..14bbebe 100644 --- a/test/test_model_wrapper.py +++ b/test/test_model_wrapper.py @@ -13,7 +13,6 @@ alpha = [0.5, 0.5, 0.5, 1, 1, 1, 2, 2, 2] beta = [0, 1, -1, 0, 1, -1, 0, 1, -1] c = [0.5, 0, 0, 0, 0.5,0, 0, 0, 0.5] -item_ids = ['a', 'b', 'c', 'd', 'e','g','h','i','j'] N = 1000 T = len(alpha) @@ -21,7 +20,7 @@ guess_param = {} for t in range(T): - guess_param[item_ids[t]]={'c':c[t]} + guess_param['q%d'%t]={'c':c[t]} class Test2PLSolver(unittest.TestCase): @classmethod def setUpClass(cls): @@ -32,17 +31,17 @@ def setUpClass(cls): for i in range(N): for t in range(T): prob = irt_fnc(thetas[i,0], beta[t], alpha[t]) - cls.data.append((i, item_ids[t], np.random.binomial(1, prob))) + cls.data.append(('u%d'%i, 'q%d'%t, np.random.binomial(1, prob))) def test_2pl_solver(self): item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30) for t in range(T): - item_id = item_ids[t] + item_id = 'q%d'%t print(item_id, item_param[item_id]) mdl_alpha = item_param[item_id]['alpha'] mdl_beta = item_param[item_id]['beta'] - if item_id != 'h': + if item_id != 'q6': self.assertTrue(abs(mdl_alpha - alpha[t])<0.37) self.assertTrue(abs(mdl_beta - beta[t])<0.16) @@ -56,20 +55,20 @@ def setUpClass(cls): for i in range(N): for t in range(T): prob = irt_fnc(thetas[i,0], beta[t], alpha[t], c[t]) - cls.data.append((i, item_ids[t] ,np.random.binomial(1,prob))) + cls.data.append(('u%d'%i, 'q%d'%t ,np.random.binomial(1,prob))) def test_3pl_solver(self): item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], in_guess_param=guess_param, tol=1e-5, max_iter=30) for t in range(T): - item_id = item_ids[t] + item_id = 'q%d'%t print(item_id, item_param[item_id]) mdl_alpha = item_param[item_id]['alpha'] mdl_beta = item_param[item_id]['beta'] - if item_id not in ['h','i']: + if item_id not in ['q6','q7']: self.assertTrue(abs(mdl_alpha - alpha[t])<0.25) - if item_id != 'j': + if item_id != 'q8': self.assertTrue(abs(mdl_beta - beta[t])<0.15) if __name__ == '__main__': From cf7988665b771d698af04e7c4cfebde4ff39dcf3 Mon Sep 17 00:00:00 2001 From: Junchen Feng Date: Fri, 11 Aug 2017 17:44:28 +0800 Subject: [PATCH 2/7] add parallel processing for theta update and parameter esitmation --- pyirt/_pyirt.py | 4 +- pyirt/dao.py | 110 +++++++++++++++++-------- pyirt/solver/model.py | 164 ++++++++++++++++++++++++++----------- pyirt/util/tools.py | 10 +++ setup.py | 3 +- test/test_model_wrapper.py | 13 ++- test/test_tools.py | 15 ++++ 7 files changed, 229 insertions(+), 90 deletions(-) diff --git a/pyirt/_pyirt.py b/pyirt/_pyirt.py index 78b31f9..9e3e206 100644 --- a/pyirt/_pyirt.py +++ b/pyirt/_pyirt.py @@ -8,7 +8,7 @@ def irt(data_src, alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param='default', model_spec='2PL', max_iter=10, tol=1e-3, nargout=2, - is_msg=False): + is_msg=False, is_parallel=False): # load data @@ -19,7 +19,7 @@ def irt(data_src, # setup 
the model if model_spec == '2PL': - mod = model.IRT_MMLE_2PL(dao_instance, is_msg=is_msg) + mod = model.IRT_MMLE_2PL(dao_instance, is_msg=is_msg, is_parallel=is_parallel) else: raise Exception('Unknown model specification.') diff --git a/pyirt/dao.py b/pyirt/dao.py index 2f927db..cf27a92 100644 --- a/pyirt/dao.py +++ b/pyirt/dao.py @@ -8,8 +8,10 @@ import pymongo +from datetime import datetime + class mongoDAO(object): - def __init__(self, connect_config, num_log, group_id=1): + def __init__(self, connect_config, group_id=1, is_msg=False): user_name = connect_config['user'] password = connect_config['password'] address = connect_config['address'] # IP:PORT @@ -20,46 +22,84 @@ def __init__(self, connect_config, num_log, group_id=1): authsource = connect_config['authsource'] mongouri = 'mongodb://{un}:{pw}@{addr}/?authsource={auth_src}'.format(un=user_name, pw=password, addr=address, auth_src=authsource) try: - self.client = pymongo.MongoClient(mongouri, serverSelectionTimeoutMS=10) + self.client = pymongo.MongoClient(mongouri, serverSelectionTimeoutMS=10, readPreference='secondaryPreferred') except: raise + user2item_collection_name = 'irt_user2item' item2user_collection_name = 'irt_item2user' self.user2item = self.client[db_name][user2item_collection_name] self.item2user = self.client[db_name][item2user_collection_name] + - # TODO:不能做全量扫描 - user_ids = self.user2item.find().distinct('id') - item_ids = self.item2user.find().distinct('id') + user_ids = list(set([x['id'] for x in self.user2item.find({'gid':group_id},{'id':1})])) + item_ids = list(set([x['id'] for x in self.item2user.find({'gid':group_id},{'id':1})])) _, self.user_idx_ref, self.user_reverse_idx_ref = construct_ref_dict(user_ids) _, self.item_idx_ref, self.item_reverse_idx_ref = construct_ref_dict(item_ids) - self.stat = {'log':num_log, 'user':len(self.user_idx_ref.keys()), 'item':len(self.item_idx_ref.keys())} + self.stat = {'user':len(self.user_idx_ref.keys()), 'item':len(self.item_idx_ref.keys())} + + print('search idx created.') + self.gid = group_id + self.is_msg = is_msg def get_num(self, name): - if name not in ['user','item','log']: + if name not in ['user','item']: raise Exception('Unknown stat source %s'%name) return self.stat[name] def get_log(self, user_idx): user_id = self.translate('user', user_idx) - res = self.user2item.find({'id':user_id}) - log_list = res[0]['data'] - return [(self.item_idx_ref[x[0]], x[1]) for x in log_list] - - def get_right_map(self, item_idx): - item_id = self.translate('item', item_idx) - res = self.item2user.find({'id':item_id}) - return [self.user_idx_ref[x] for x in res[0]['data']['1']] - - def get_wrong_map(self, item_idx): - item_id = self.translate('item', item_idx) - res = self.item2user.find({'id':item_id}) - return [self.user_idx_ref[x] for x in res[0]['data']['0']] - + # query + if self.is_msg: + stime = datetime.now() + res = self.user2item.find({'id':user_id, 'gid':self.gid}) + etime = datetime.now() + search_time = int((etime-stime).microseconds/1000) + if search_time > 100: + print('s:%d' % search_time) + else: + res = self.user2item.find({'id':user_id, 'gid':self.gid}) + # parse + if res.count() == 0: + return_list = [] + elif res.count() > 1: + raise Exception('duplicate doc for (%s, %d) in user2item' % (user_id, self.gid)) + else: + log_list = res[0]['data'] + return_list = [(self.item_idx_ref[x[0]], x[1]) for x in log_list] + return return_list + + def get_map(self, item_idx, ans_key_list): + item_id = self.translate('item', item_idx) + # query + if self.is_msg: + stime = 
datetime.now() + res = self.item2user.find({'id':item_id, 'gid':self.gid}) + etime = datetime.now() + search_time = int((etime-stime).microseconds/1000) + if search_time > 100: + print('s:%d' % search_time) + else: + res = self.item2user.find({'id':item_id, 'gid':self.gid}) + # parse + if res.count() == 0: + return_list = [[] for ans_key in ans_key_list] + elif res.count() > 1: + raise Exception('duplicate doc for (%s, %d) in item2user' % (item_id, self.gid)) + else: + doc = res[0]['data'] + return_list = [] + for ans_key in ans_key_list: + if str(ans_key) in doc: + return_list.append([self.user_idx_ref[x] for x in doc[str(ans_key)]] ) + else: + return_list.append([]) + return return_list + def translate(self, data_type, idx): if data_type == 'item': return self.item_reverse_idx_ref[idx] @@ -82,18 +122,17 @@ def __init__(self, src): self.database.setup(user_id_idx_vec, item_id_idx_vec, self.database.ans_tags) def get_num(self, name): - if name not in ['user','item','log']: + if name not in ['user','item']: raise Exception('Unknown stat source %s'%name) return self.database.stat[name] def get_log(self, user_idx): return self.database.user2item[user_idx] - def get_right_map(self, item_idx): - return self.database.right_map[item_idx] + def get_map(self, item_idx, ans_key_list): + # NOTE: return empty list for invalid ans key + return [self.database.item2user_map[str(ans_key)][item_idx] for ans_key in ans_key_list] - def get_wrong_map(self, item_idx): - return self.database.wrong_map[item_idx] def translate(self, data_type, idx): if data_type == 'item': @@ -121,7 +160,7 @@ def setup(self, user_idx_vec, item_idx_vec, ans_tags, msg=False): # initialize some intermediate variables used in the E step start_time = time.time() - self._init_right_wrong_map() + self._init_item2user_map() if msg: print("--- Sparse Mapping: %f secs ---" % np.round((time.time() - start_time))) @@ -136,11 +175,11 @@ def _process_data(self, user_idx_vec, item_idx_vec, ans_tags): self.user2item = defaultdict(list) self.stat = {} - self.stat['log'] = len(user_idx_vec) + num_log = len(user_idx_vec) self.stat['user'] = max(user_idx_vec)+1 # start count from 0 self.stat['item'] = max(item_idx_vec)+1 - for i in range(self.stat['log']): + for i in range(num_log): item_idx = item_idx_vec[i] user_idx = user_idx_vec[i] ans_tag = ans_tags[i] @@ -148,16 +187,15 @@ def _process_data(self, user_idx_vec, item_idx_vec, ans_tags): self.item2user[item_idx].append((user_idx, ans_tag)) self.user2item[user_idx].append((item_idx, ans_tag)) - def _init_right_wrong_map(self): - self.right_map = defaultdict(list) - self.wrong_map = defaultdict(list) + def _init_item2user_map(self, ans_key_list = ['0','1']): + + self.item2user_map = {} + for ans_key in ans_key_list: + self.item2user_map[ans_key] = defaultdict(list) for item_idx, log_result in self.item2user.items(): for log in log_result: ans_tag = log[1] user_idx = log[0] - if ans_tag == 1: - self.right_map[item_idx].append(user_idx) - else: - self.wrong_map[item_idx].append(user_idx) + self.item2user_map[str(ans_tag)][item_idx].append(user_idx) diff --git a/pyirt/solver/model.py b/pyirt/solver/model.py index 0e11b74..64b2185 100644 --- a/pyirt/solver/model.py +++ b/pyirt/solver/model.py @@ -13,11 +13,17 @@ import time from copy import deepcopy from six import string_types +from tqdm import tqdm + from ..util import clib, tools from ..solver import optimizer from ..algo import update_theta_distribution +from datetime import datetime +import multiprocessing as mp + + class IRT_MMLE_2PL(object): @@ 
-27,14 +33,15 @@ class IRT_MMLE_2PL(object): (2) solve (3) get esitmated result ''' - def __init__(self, dao_instance, is_msg=False): + def __init__(self, dao_instance, is_msg=False, is_parallel=False): # interface to data self.dao=dao_instance - self.is_msg = is_msg self.num_iter = 1 self.ell_list = [] self.last_avg_prob = 0 + self.is_msg = is_msg + self.is_parallel = is_parallel def set_options(self, theta_bnds, num_theta, alpha_bnds, beta_bnds, max_iter, tol): # user self.num_theta = num_theta @@ -131,28 +138,65 @@ def _max_step(self): # theta value is universal opt_worker.set_theta(self.theta_prior_val) + num_item = self.dao.get_num('item') + num_chunk = min(6, mp.cpu_count()) + + if num_item self.max_iter): print('EM does not converge within max iteration') - return True - + return True if self.num_iter != 1: - self.last_item_param_dict = self.item_param_dict - + self.last_item_param_dict = self.item_param_dict return False - - def _init_solver_param(self, is_constrained, boundary, - solver_type, max_iter, tol): + def _init_solver_param(self, is_constrained, boundary, solver_type, max_iter, tol): # initialize bounds self.is_constrained = is_constrained self.alpha_bnds = boundary['alpha'] @@ -199,7 +239,6 @@ def _init_solver_param(self, is_constrained, boundary, self.solver_type = solver_type self.max_iter = max_iter self.tol = tol - if solver_type == 'gradient' and not is_constrained: raise Exception('BFGS has to be constrained') @@ -215,7 +254,6 @@ def _init_user_param(self, theta_min, theta_max, num_theta, dist='normal'): self.theta_prior_val = np.linspace(theta_min, theta_max, num=num_theta) if self.num_theta != len(self.theta_prior_val): raise Exception('wrong number of inintial theta values') - # use a normal approximation if dist == 'uniform': self.theta_density = np.ones(num_theta) / num_theta @@ -231,10 +269,31 @@ def _init_user_param(self, theta_min, theta_max, num_theta, dist='normal'): def __update_theta_distr(self): # [A] calculate p(data,param|theta) - for user_idx in range(self.dao.get_num('user')): - self.posterior_theta_distr[user_idx, :] = update_theta_distribution(self.dao.get_log(user_idx), - self.num_theta, self.theta_prior_val, self.theta_density, - self.item_param_dict) + num_user = self.dao.get_num('user') + num_chunk = min(6, mp.cpu_count()) + if num_user>num_chunk and self.is_parallel: + def update(d, start_idx, end_idx): + for user_idx in tqdm(xrange(start_idx, end_idx)): + logs = self.dao.get_log(user_idx) + d[user_idx] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict) + # [A] calculate p(data,param|theta) + chunk_list = tools.cut_list(num_user, num_chunk) + pool = [] + manager = mp.Manager() + pool_repo = manager.dict() + for i in range(num_chunk): + p = mp.Process(target=update, args=(pool_repo, chunk_list[i][0],chunk_list[i][1],)) + pool.append(p) + for p in pool: + p.start() + for p in pool: + p.join() + for user_idx in range(num_user): + self.posterior_theta_distr[user_idx,:] = pool_repo[user_idx] + else: + for user_idx in tqdm(range(num_user)): + logs = self.dao.get_log(user_idx) + self.posterior_theta_distr[user_idx, :] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict) # When the loop finish, check if the theta_density adds up to unity for each user check_user_distr_marginal = np.sum(self.posterior_theta_distr, axis=1) if any(abs(check_user_distr_marginal - 1.0) > 0.0001): @@ -248,25 +307,30 @@ def 
__check_theta_density(self): raise Exception('theta desnity has wrong shape (%s,%s)'%self.theta_density.shape) def __get_expect_count(self): - self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item'))) self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item'))) - - for item_idx in range(self.dao.get_num('item')): - right_user_idx_vec = self.dao.get_right_map(item_idx) - wrong_user_idx_vec = self.dao.get_wrong_map(item_idx) - self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[right_user_idx_vec, :], axis=0) - self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[wrong_user_idx_vec, :], axis=0) + num_item = self.dao.get_num('item') + for item_idx in tqdm(range(num_item)): + map_user_idx_vec = self.dao.get_map(item_idx, ['1','0']) # reduce the io time for mongo dao + self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0) + self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0) def __calc_data_likelihood(self): # calculate the likelihood for the data set - + # geometric within learner and across learner + + #1/N * sum[i](1/Ni *sum[j] (log pij)) theta_vec = self.__calc_theta() - ell = 0 - for user_idx in range(self.dao.get_num('user')): + user_ell = 0 + user_cnt = 0 + num_user = self.dao.get_num('user') + for user_idx in tqdm(range(num_user)): theta = theta_vec[user_idx] # find all the item_id logs = self.dao.get_log(user_idx) + if len(logs) == 0: + continue + ell = 0 for log in logs: item_idx = log[0] ans_tag = log[1] @@ -274,8 +338,10 @@ def __calc_data_likelihood(self): beta = self.item_param_dict[item_idx]['beta'] c = self.item_param_dict[item_idx]['c'] ell += clib.log_likelihood_2PL(0.0+ans_tag, 1.0-ans_tag, theta, alpha, beta, c) - return ell + user_ell += ell/len(logs) + user_cnt += 1 + return user_ell / user_cnt def __calc_theta(self): return np.dot(self.posterior_theta_distr, self.theta_prior_val) - + diff --git a/pyirt/util/tools.py b/pyirt/util/tools.py index 5d47d58..e523e97 100644 --- a/pyirt/util/tools.py +++ b/pyirt/util/tools.py @@ -38,3 +38,13 @@ def logsum(logp): w = max(logp) logSump = w + np.log(sum(np.exp(logp - w))) return logSump + + + +def cut_list(list_length, num_chunk): + chunk_bnd = [0] + for i in range(num_chunk): + chunk_bnd.append(int(list_length*(i+1)/num_chunk)) + chunk_bnd.append(list_length) + chunk_list = [(chunk_bnd[i], chunk_bnd[i+1]) for i in range(num_chunk) ] + return chunk_list diff --git a/setup.py b/setup.py index 9ded2e6..02229df 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,8 @@ 'scipy', 'cython', 'six', - 'pymongo'], + 'pymongo', + 'tqdm'], package_data={'pyirt': ["*.pyx"]}, diff --git a/test/test_model_wrapper.py b/test/test_model_wrapper.py index 14bbebe..a65e486 100644 --- a/test/test_model_wrapper.py +++ b/test/test_model_wrapper.py @@ -21,6 +21,7 @@ guess_param = {} for t in range(T): guess_param['q%d'%t]={'c':c[t]} + class Test2PLSolver(unittest.TestCase): @classmethod def setUpClass(cls): @@ -33,7 +34,6 @@ def setUpClass(cls): prob = irt_fnc(thetas[i,0], beta[t], alpha[t]) cls.data.append(('u%d'%i, 'q%d'%t, np.random.binomial(1, prob))) - def test_2pl_solver(self): item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30) for t in range(T): @@ -44,6 +44,16 @@ def test_2pl_solver(self): if item_id != 
'q6': self.assertTrue(abs(mdl_alpha - alpha[t])<0.37) self.assertTrue(abs(mdl_beta - beta[t])<0.16) + def test_2pl_solver_parallel(self): + item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30, is_parallel=True) + for t in range(T): + item_id = 'q%d'%t + print(item_id, item_param[item_id]) + mdl_alpha = item_param[item_id]['alpha'] + mdl_beta = item_param[item_id]['beta'] + if item_id != 'q6': + self.assertTrue(abs(mdl_alpha - alpha[t])<0.37) + self.assertTrue(abs(mdl_beta - beta[t])<0.16) class Test3PLSolver(unittest.TestCase): @classmethod @@ -70,6 +80,5 @@ def test_3pl_solver(self): self.assertTrue(abs(mdl_alpha - alpha[t])<0.25) if item_id != 'q8': self.assertTrue(abs(mdl_beta - beta[t])<0.15) - if __name__ == '__main__': unittest.main() diff --git a/test/test_tools.py b/test/test_tools.py index 66904be..9b434f9 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -142,6 +142,21 @@ def test_log_factor_hessian(self): self.assertTrue(abs(calc_hessian - true_hessian_approx_theta) < 1e-4) +class TestCutList(unittest.TestCase): + def test_no_mod(self): + test_chunks = tools.cut_list(100,4) + true_chunks = [(0,25),(25,50),(50,75),(75,100)] + for i in range(4): + self.assertTrue(test_chunks[i]==true_chunks[i]) + + def test_mod(self): + test_chunks = tools.cut_list(23,4) + true_chunks = [(0,5),(5,11),(11,17),(17,23)] + for i in range(4): + self.assertTrue(test_chunks[i]==true_chunks[i]) + + + if __name__ == '__main__': unittest.main() From 430413af8188a45451943f770bc6b620b451dd72 Mon Sep 17 00:00:00 2001 From: Junchen Feng Date: Fri, 11 Aug 2017 19:48:17 +0800 Subject: [PATCH 3/7] add parallel to all steps. fix the bug --- pyirt/dao.py | 11 +++- pyirt/solver/model.py | 130 ++++++++++++++++++++++++++----------- test/test_model_wrapper.py | 19 +++++- 3 files changed, 119 insertions(+), 41 deletions(-) diff --git a/pyirt/dao.py b/pyirt/dao.py index cf27a92..f87666a 100644 --- a/pyirt/dao.py +++ b/pyirt/dao.py @@ -22,7 +22,7 @@ def __init__(self, connect_config, group_id=1, is_msg=False): authsource = connect_config['authsource'] mongouri = 'mongodb://{un}:{pw}@{addr}/?authsource={auth_src}'.format(un=user_name, pw=password, addr=address, auth_src=authsource) try: - self.client = pymongo.MongoClient(mongouri, serverSelectionTimeoutMS=10, readPreference='secondaryPreferred') + self.client = pymongo.MongoClient(mongouri, connect=False, serverSelectionTimeoutMS=10, readPreference='secondaryPreferred') except: raise @@ -45,6 +45,11 @@ def __init__(self, connect_config, group_id=1, is_msg=False): print('search idx created.') self.gid = group_id self.is_msg = is_msg + + self.close_conn() + + def close_conn(self): + self.client.close() def get_num(self, name): if name not in ['user','item']: @@ -132,7 +137,9 @@ def get_log(self, user_idx): def get_map(self, item_idx, ans_key_list): # NOTE: return empty list for invalid ans key return [self.database.item2user_map[str(ans_key)][item_idx] for ans_key in ans_key_list] - + + def close_conn(self): + pass def translate(self, data_type, idx): if data_type == 'item': diff --git a/pyirt/solver/model.py b/pyirt/solver/model.py index 64b2185..59a15c5 100644 --- a/pyirt/solver/model.py +++ b/pyirt/solver/model.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python -W ignore ''' The model is an implementation of EM algorithm of IRT @@ -13,7 +14,6 @@ import time from copy import deepcopy from six import string_types -from tqdm import tqdm from ..util import clib, tools @@ 
-22,7 +22,7 @@ from datetime import datetime import multiprocessing as mp - +from tqdm import tqdm class IRT_MMLE_2PL(object): @@ -70,15 +70,30 @@ def solve_EM(self): self._init_item_param() # main routine - while True: + while True: #----- E step ----- + stime = datetime.now() self._exp_step() + etime = datetime.now() + if self.is_msg: + runtime = (etime-stime).microseconds / 1000 + print('E step runs for %s sec' %runtime) #----- M step ----- + stime = datetime.now() self._max_step() - + etime = datetime.now() + if self.is_msg: + runtime = (etime-stime).microseconds / 1000 + print('M step runs for %s sec' %runtime) + # ---- Stop Condition ---- + stime = datetime.now() is_stop = self._check_stop() + etime = datetime.now() + if self.is_msg: + runtime = (etime-stime).microseconds / 1000 + print('stop condition runs for %s sec' %runtime) if is_stop: break @@ -141,13 +156,13 @@ def _max_step(self): num_item = self.dao.get_num('item') num_chunk = min(6, mp.cpu_count()) - if num_item num_chunk and self.is_parallel: def update(d, start_idx, end_idx): for item_idx in tqdm(xrange(start_idx, end_idx)): - initial_guess_val = (d[itm_idx]['beta'], - d[item_idx]['alpha']) + initial_guess_val = (self.item_param_dict[item_idx]['beta'], + self.item_param_dict[item_idx]['alpha']) opt_worker.set_initial_guess(initial_guess_val) # the value is a mix of 1/0 and current estimate - opt_worker.set_c(d[item_idx]['c']) + opt_worker.set_c(self.item_param_dict[item_idx]['c']) # estimate expected_right_count = self.item_expected_right_by_theta[:, item_idx] @@ -155,10 +170,7 @@ def update(d, start_idx, end_idx): input_data = [expected_right_count, expected_wrong_count] opt_worker.load_res_data(input_data) est_param = opt_worker.solve_param_mix(self.is_constrained) - - # update - d['beta'] = est_param[0] - d['alpha'] = est_param[1] + d[item_idx] = est_param # [A] calculate p(data,param|theta) chunk_list = tools.cut_list(num_item, num_chunk) @@ -166,9 +178,8 @@ def update(d, start_idx, end_idx): pool = [] manager = mp.Manager() pool_repo = manager.dict() - for item_idx in range(num_item): - pool_repo[item_idx] = self.item_param_dict[item_idx] - + + self.dao.close_conn() for i in range(num_chunk): p = mp.Process(target=update, args=(pool_repo, chunk_list[i][0],chunk_list[i][1],)) pool.append(p) @@ -178,7 +189,12 @@ def update(d, start_idx, end_idx): p.join() for item_idx in range(num_item): - self.item_param_dict[item_idx] = pool_repo[item_idx] + self.item_param_dict[item_idx] = { + 'beta':pool_repo[item_idx][0], + 'alpha':pool_repo[item_idx][1], + 'c':self.item_param_dict[item_idx]['c'] + } + else: for item_idx in range(num_item): @@ -281,6 +297,7 @@ def update(d, start_idx, end_idx): pool = [] manager = mp.Manager() pool_repo = manager.dict() + self.dao.close_conn() for i in range(num_chunk): p = mp.Process(target=update, args=(pool_repo, chunk_list[i][0],chunk_list[i][1],)) pool.append(p) @@ -291,7 +308,7 @@ def update(d, start_idx, end_idx): for user_idx in range(num_user): self.posterior_theta_distr[user_idx,:] = pool_repo[user_idx] else: - for user_idx in tqdm(range(num_user)): + for user_idx in range(num_user): logs = self.dao.get_log(user_idx) self.posterior_theta_distr[user_idx, :] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict) # When the loop finish, check if the theta_density adds up to unity for each user @@ -310,7 +327,7 @@ def __get_expect_count(self): self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item'))) 
self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item'))) num_item = self.dao.get_num('item') - for item_idx in tqdm(range(num_item)): + for item_idx in range(num_item): map_user_idx_vec = self.dao.get_map(item_idx, ['1','0']) # reduce the io time for mongo dao self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0) self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0) @@ -321,26 +338,65 @@ def __calc_data_likelihood(self): #1/N * sum[i](1/Ni *sum[j] (log pij)) theta_vec = self.__calc_theta() - user_ell = 0 - user_cnt = 0 num_user = self.dao.get_num('user') - for user_idx in tqdm(range(num_user)): - theta = theta_vec[user_idx] - # find all the item_id - logs = self.dao.get_log(user_idx) - if len(logs) == 0: - continue - ell = 0 - for log in logs: - item_idx = log[0] - ans_tag = log[1] - alpha = self.item_param_dict[item_idx]['alpha'] - beta = self.item_param_dict[item_idx]['beta'] - c = self.item_param_dict[item_idx]['c'] - ell += clib.log_likelihood_2PL(0.0+ans_tag, 1.0-ans_tag, theta, alpha, beta, c) - user_ell += ell/len(logs) - user_cnt += 1 - return user_ell / user_cnt + num_chunk = min(6, mp.cpu_count()) + if num_user>num_chunk and self.is_parallel: + def update(tot_llk, cnt, start_idx, end_idx): + for user_idx in tqdm(range(start_idx, end_idx)): + theta = theta_vec[user_idx] + # find all the item_id + logs = self.dao.get_log(user_idx) + if len(logs) == 0: + continue + ell = 0 + for log in logs: + item_idx = log[0] + ans_tag = log[1] + alpha = self.item_param_dict[item_idx]['alpha'] + beta = self.item_param_dict[item_idx]['beta'] + c = self.item_param_dict[item_idx]['c'] + ell += clib.log_likelihood_2PL(0.0+ans_tag, 1.0-ans_tag, theta, alpha, beta, c) + with tot_llk.get_lock(): + tot_llk.value += ell/len(logs) + with cnt.get_lock(): + cnt.value += 1 + + user_ell = mp.Value('d', 0.0) + user_cnt = mp.Value('i', 0) + + chunk_list = tools.cut_list(num_user, num_chunk) + pool = [] + self.dao.close_conn() + for i in range(num_chunk): + p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0],chunk_list[i][1],)) + pool.append(p) + for p in pool: + p.start() + for p in pool: + p.join() + + avg_ell = user_ell.value/user_cnt.value + else: + user_ell = 0 + user_cnt = 0 + for user_idx in range(num_user): + theta = theta_vec[user_idx] + # find all the item_id + logs = self.dao.get_log(user_idx) + if len(logs) == 0: + continue + ell = 0 + for log in logs: + item_idx = log[0] + ans_tag = log[1] + alpha = self.item_param_dict[item_idx]['alpha'] + beta = self.item_param_dict[item_idx]['beta'] + c = self.item_param_dict[item_idx]['c'] + ell += clib.log_likelihood_2PL(0.0+ans_tag, 1.0-ans_tag, theta, alpha, beta, c) + user_ell += ell/len(logs) + user_cnt += 1 + avg_ell = user_ell / user_cnt + return avg_ell def __calc_theta(self): return np.dot(self.posterior_theta_distr, self.theta_prior_val) diff --git a/test/test_model_wrapper.py b/test/test_model_wrapper.py index a65e486..54824ec 100644 --- a/test/test_model_wrapper.py +++ b/test/test_model_wrapper.py @@ -33,7 +33,6 @@ def setUpClass(cls): for t in range(T): prob = irt_fnc(thetas[i,0], beta[t], alpha[t]) cls.data.append(('u%d'%i, 'q%d'%t, np.random.binomial(1, prob))) - def test_2pl_solver(self): item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30) for t in range(T): @@ -66,7 
+65,6 @@ def setUpClass(cls): for t in range(T): prob = irt_fnc(thetas[i,0], beta[t], alpha[t], c[t]) cls.data.append(('u%d'%i, 'q%d'%t ,np.random.binomial(1,prob))) - def test_3pl_solver(self): item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], in_guess_param=guess_param, tol=1e-5, max_iter=30) @@ -80,5 +78,22 @@ def test_3pl_solver(self): self.assertTrue(abs(mdl_alpha - alpha[t])<0.25) if item_id != 'q8': self.assertTrue(abs(mdl_beta - beta[t])<0.15) + + def test_3pl_solver_parallel(self): + item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], + in_guess_param=guess_param, tol=1e-5, max_iter=30, is_parallel=True) + + for t in range(T): + item_id = 'q%d'%t + print(item_id, item_param[item_id]) + mdl_alpha = item_param[item_id]['alpha'] + mdl_beta = item_param[item_id]['beta'] + if item_id not in ['q6','q7']: + self.assertTrue(abs(mdl_alpha - alpha[t])<0.25) + if item_id != 'q8': + self.assertTrue(abs(mdl_beta - beta[t])<0.15) + + + if __name__ == '__main__': unittest.main() From 0f2a11558756ba13739bd657d312a9ab0a6425ba Mon Sep 17 00:00:00 2001 From: Junchen Feng Date: Sun, 13 Aug 2017 19:07:25 +0800 Subject: [PATCH 4/7] add production error handling --- pyirt/_pyirt.py | 7 ++-- pyirt/solver/model.py | 65 ++++++++++++++++++++++++-------------- test/test_model_wrapper.py | 33 ++++++++++++++++++- 3 files changed, 77 insertions(+), 28 deletions(-) diff --git a/pyirt/_pyirt.py b/pyirt/_pyirt.py index 9e3e206..c8e05bc 100644 --- a/pyirt/_pyirt.py +++ b/pyirt/_pyirt.py @@ -5,10 +5,11 @@ def irt(data_src, dao = 'memory', theta_bnds=[-4, 4], num_theta=11, - alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param='default', + alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param={}, model_spec='2PL', max_iter=10, tol=1e-3, nargout=2, - is_msg=False, is_parallel=False): + is_msg=False, is_parallel=False, + mode='debug'): # load data @@ -19,7 +20,7 @@ def irt(data_src, # setup the model if model_spec == '2PL': - mod = model.IRT_MMLE_2PL(dao_instance, is_msg=is_msg, is_parallel=is_parallel) + mod = model.IRT_MMLE_2PL(dao_instance, is_msg=is_msg, is_parallel=is_parallel, mode=mode) else: raise Exception('Unknown model specification.') diff --git a/pyirt/solver/model.py b/pyirt/solver/model.py index 59a15c5..2f01561 100644 --- a/pyirt/solver/model.py +++ b/pyirt/solver/model.py @@ -33,7 +33,7 @@ class IRT_MMLE_2PL(object): (2) solve (3) get esitmated result ''' - def __init__(self, dao_instance, is_msg=False, is_parallel=False): + def __init__(self, dao_instance, is_msg=False, is_parallel=False, mode='debug'): # interface to data self.dao=dao_instance self.num_iter = 1 @@ -42,6 +42,8 @@ def __init__(self, dao_instance, is_msg=False, is_parallel=False): self.is_msg = is_msg self.is_parallel = is_parallel + self.mode = mode + def set_options(self, theta_bnds, num_theta, alpha_bnds, beta_bnds, max_iter, tol): # user self.num_theta = num_theta @@ -56,14 +58,12 @@ def set_options(self, theta_bnds, num_theta, alpha_bnds, beta_bnds, max_iter, to def set_guess_param(self, in_guess_param): self.guess_param_dict = {} - if isinstance(in_guess_param, string_types): - for item_idx in range(self.dao.get_num('item')): - self.guess_param_dict[item_idx] = {'c': 0.0} # default set to 0 - else: - for item_idx in range(self.dao.get_num('item')): - item_id = self.dao.translate('item', item_idx) - self.guess_param_dict[item_idx] = in_guess_param[item_id] - + 
for item_idx in xrange(self.dao.get_num('item')): + item_id = self.dao.translate('item', item_idx) + if item_id in in_guess_param: + self.guess_param_dict[item_idx] = {'c': float(in_guess_param[item_id])} + else: + self.guess_param_dict[item_idx] = {'c':0.0} # if null then set as 0 def solve_EM(self): # data dependent initialization @@ -100,7 +100,7 @@ def solve_EM(self): def get_item_param(self): output_item_param = {} - for item_idx in range(self.dao.get_num('item')): + for item_idx in xrange(self.dao.get_num('item')): item_id = self.dao.translate('item', item_idx) output_item_param[item_id] = self.item_param_dict[item_idx] return output_item_param @@ -108,7 +108,7 @@ def get_item_param(self): def get_user_param(self): output_user_param = {} theta_vec = self.__calc_theta() - for user_idx in range(self.dao.get_num('user')): + for user_idx in xrange(self.dao.get_num('user')): user_id = self.dao.translate('user', user_idx) output_user_param[user_id] = theta_vec[user_idx] return output_user_param @@ -169,8 +169,17 @@ def update(d, start_idx, end_idx): expected_wrong_count = self.item_expected_wrong_by_theta[:, item_idx] input_data = [expected_right_count, expected_wrong_count] opt_worker.load_res_data(input_data) - est_param = opt_worker.solve_param_mix(self.is_constrained) - d[item_idx] = est_param + try: + est_param = opt_worker.solve_param_mix(self.is_constrained) + except Exception as e: + if self.mode=='production': + # In production mode, use the previous iteration + print('Item %d does not fit'%item_idx) + d[item_idx] = self.item_param_dict[item_idx] + else: + raise e + finally: + d[item_idx] = est_param # [A] calculate p(data,param|theta) chunk_list = tools.cut_list(num_item, num_chunk) @@ -188,7 +197,7 @@ def update(d, start_idx, end_idx): for p in pool: p.join() - for item_idx in range(num_item): + for item_idx in xrange(num_item): self.item_param_dict[item_idx] = { 'beta':pool_repo[item_idx][0], 'alpha':pool_repo[item_idx][1], @@ -197,7 +206,7 @@ def update(d, start_idx, end_idx): else: - for item_idx in range(num_item): + for item_idx in xrange(num_item): initial_guess_val = (self.item_param_dict[item_idx]['beta'], self.item_param_dict[item_idx]['alpha']) opt_worker.set_initial_guess(initial_guess_val) # the value is a mix of 1/0 and current estimate @@ -210,9 +219,17 @@ def update(d, start_idx, end_idx): opt_worker.load_res_data(input_data) est_param = opt_worker.solve_param_mix(self.is_constrained) - # update - self.item_param_dict[item_idx]['beta'] = est_param[0] - self.item_param_dict[item_idx]['alpha'] = est_param[1] + try: + est_param = opt_worker.solve_param_mix(self.is_constrained) + except Exception as e: + if self.mode=='production': + continue + else: + raise e + finally: + # update + self.item_param_dict[item_idx]['beta'] = est_param[0] + self.item_param_dict[item_idx]['alpha'] = est_param[1] # [B] max for theta density self.theta_density = self.posterior_theta_distr.sum(axis=0)/self.posterior_theta_distr.sum() @@ -260,7 +277,7 @@ def _init_solver_param(self, is_constrained, boundary, solver_type, max_iter, to def _init_item_param(self): self.item_param_dict = {} - for item_idx in range(self.dao.get_num('item')): + for item_idx in xrange(self.dao.get_num('item')): # need to call the old item_id c = self.guess_param_dict[item_idx]['c'] self.item_param_dict[item_idx] = {'alpha': 1.0, 'beta': 0.0, 'c': c} @@ -305,10 +322,10 @@ def update(d, start_idx, end_idx): p.start() for p in pool: p.join() - for user_idx in range(num_user): + for user_idx in tqdm(xrange(num_user)): 
self.posterior_theta_distr[user_idx,:] = pool_repo[user_idx] else: - for user_idx in range(num_user): + for user_idx in xrange(num_user): logs = self.dao.get_log(user_idx) self.posterior_theta_distr[user_idx, :] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict) # When the loop finish, check if the theta_density adds up to unity for each user @@ -327,7 +344,7 @@ def __get_expect_count(self): self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item'))) self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item'))) num_item = self.dao.get_num('item') - for item_idx in range(num_item): + for item_idx in tqdm(xrange(num_item)): map_user_idx_vec = self.dao.get_map(item_idx, ['1','0']) # reduce the io time for mongo dao self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0) self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0) @@ -342,7 +359,7 @@ def __calc_data_likelihood(self): num_chunk = min(6, mp.cpu_count()) if num_user>num_chunk and self.is_parallel: def update(tot_llk, cnt, start_idx, end_idx): - for user_idx in tqdm(range(start_idx, end_idx)): + for user_idx in tqdm(xrange(start_idx, end_idx)): theta = theta_vec[user_idx] # find all the item_id logs = self.dao.get_log(user_idx) @@ -379,7 +396,7 @@ def update(tot_llk, cnt, start_idx, end_idx): else: user_ell = 0 user_cnt = 0 - for user_idx in range(num_user): + for user_idx in xrange(num_user): theta = theta_vec[user_idx] # find all the item_id logs = self.dao.get_log(user_idx) diff --git a/test/test_model_wrapper.py b/test/test_model_wrapper.py index 54824ec..8ae65b7 100644 --- a/test/test_model_wrapper.py +++ b/test/test_model_wrapper.py @@ -20,7 +20,8 @@ guess_param = {} for t in range(T): - guess_param['q%d'%t]={'c':c[t]} + guess_param['q%d'%t]=c[t] + class Test2PLSolver(unittest.TestCase): @classmethod @@ -33,6 +34,7 @@ def setUpClass(cls): for t in range(T): prob = irt_fnc(thetas[i,0], beta[t], alpha[t]) cls.data.append(('u%d'%i, 'q%d'%t, np.random.binomial(1, prob))) + def test_2pl_solver(self): item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30) for t in range(T): @@ -43,6 +45,18 @@ def test_2pl_solver(self): if item_id != 'q6': self.assertTrue(abs(mdl_alpha - alpha[t])<0.37) self.assertTrue(abs(mdl_beta - beta[t])<0.16) + + def test_2pl_solver_production(self): + item_param, user_param = irt(self.data, mode='production',theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30) + for t in range(T): + item_id = 'q%d'%t + print(item_id, item_param[item_id]) + mdl_alpha = item_param[item_id]['alpha'] + mdl_beta = item_param[item_id]['beta'] + if item_id != 'q6': + self.assertTrue(abs(mdl_alpha - alpha[t])<0.37) + self.assertTrue(abs(mdl_beta - beta[t])<0.16) + def test_2pl_solver_parallel(self): item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30, is_parallel=True) for t in range(T): @@ -65,6 +79,7 @@ def setUpClass(cls): for t in range(T): prob = irt_fnc(thetas[i,0], beta[t], alpha[t], c[t]) cls.data.append(('u%d'%i, 'q%d'%t ,np.random.binomial(1,prob))) + def test_3pl_solver(self): item_param, user_param = irt(self.data, 
theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], in_guess_param=guess_param, tol=1e-5, max_iter=30) @@ -79,6 +94,22 @@ def test_3pl_solver(self): if item_id != 'q8': self.assertTrue(abs(mdl_beta - beta[t])<0.15) + + def test_3pl_solver_production(self): + item_param, user_param = irt(self.data, mode='production', theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], + in_guess_param=guess_param, tol=1e-5, max_iter=30) + + for t in range(T): + item_id = 'q%d'%t + print(item_id, item_param[item_id]) + mdl_alpha = item_param[item_id]['alpha'] + mdl_beta = item_param[item_id]['beta'] + if item_id not in ['q6','q7']: + self.assertTrue(abs(mdl_alpha - alpha[t])<0.25) + if item_id != 'q8': + self.assertTrue(abs(mdl_beta - beta[t])<0.15) + + def test_3pl_solver_parallel(self): item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], in_guess_param=guess_param, tol=1e-5, max_iter=30, is_parallel=True) From 76a6074fa9e57164b5e383badfa2d2c4c847d65d Mon Sep 17 00:00:00 2001 From: Junchen Feng Date: Mon, 14 Aug 2017 19:00:54 +0800 Subject: [PATCH 5/7] add parallel procs timeout check and adjustable procs num --- pyirt/_pyirt.py | 5 ++- pyirt/solver/model.py | 80 ++++++++++++++++++++++++-------------- test/test_model_wrapper.py | 4 +- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/pyirt/_pyirt.py b/pyirt/_pyirt.py index c8e05bc..0887de5 100644 --- a/pyirt/_pyirt.py +++ b/pyirt/_pyirt.py @@ -8,7 +8,8 @@ def irt(data_src, alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param={}, model_spec='2PL', max_iter=10, tol=1e-3, nargout=2, - is_msg=False, is_parallel=False, + is_msg=False, + is_parallel=False, num_cpu=6, check_interval = 60, mode='debug'): @@ -20,7 +21,7 @@ def irt(data_src, # setup the model if model_spec == '2PL': - mod = model.IRT_MMLE_2PL(dao_instance, is_msg=is_msg, is_parallel=is_parallel, mode=mode) + mod = model.IRT_MMLE_2PL(dao_instance, is_msg=is_msg, is_parallel=is_parallel, num_cpu=num_cpu, check_interval=check_interval, mode=mode) else: raise Exception('Unknown model specification.') diff --git a/pyirt/solver/model.py b/pyirt/solver/model.py index 2f01561..4458ff5 100644 --- a/pyirt/solver/model.py +++ b/pyirt/solver/model.py @@ -23,7 +23,29 @@ from datetime import datetime import multiprocessing as mp from tqdm import tqdm +import time + + + +def procs_operator(procs, TIMEOUT, check_interval): + for p in procs: + p.start() + + start = time.time() + while time.time() - start < TIMEOUT: + if any(p.is_alive() for p in procs): + time.sleep(check_interval) + else: + for p in procs: + p.join() + break + else: + for p in procs: + p.terminate() + p.join() + raise Exception('Time out, killing all process') + return procs class IRT_MMLE_2PL(object): @@ -33,7 +55,7 @@ class IRT_MMLE_2PL(object): (2) solve (3) get esitmated result ''' - def __init__(self, dao_instance, is_msg=False, is_parallel=False, mode='debug'): + def __init__(self, dao_instance, is_msg=False, is_parallel=False, num_cpu=6, check_interval=60, mode='debug'): # interface to data self.dao=dao_instance self.num_iter = 1 @@ -42,6 +64,8 @@ def __init__(self, dao_instance, is_msg=False, is_parallel=False, mode='debug'): self.is_msg = is_msg self.is_parallel = is_parallel + self.num_cpu = num_cpu + self.check_interval = check_interval self.mode = mode def set_options(self, theta_bnds, num_theta, alpha_bnds, beta_bnds, max_iter, tol): @@ -154,7 +178,7 
@@ def _max_step(self):
         # theta value is universal
         opt_worker.set_theta(self.theta_prior_val)
         num_item = self.dao.get_num('item')
-        num_chunk = min(6, mp.cpu_count())
+        num_chunk = min(self.num_cpu, mp.cpu_count())
 
         if num_item > num_chunk and self.is_parallel:
             def update(d, start_idx, end_idx):
@@ -184,23 +208,22 @@ def update(d, start_idx, end_idx):
             # [A] calculate p(data,param|theta)
             chunk_list = tools.cut_list(num_item, num_chunk)
-            pool = []
+            procs = []
             manager = mp.Manager()
-            pool_repo = manager.dict()
+            procs_repo = manager.dict()
             self.dao.close_conn()
             for i in range(num_chunk):
-                p = mp.Process(target=update, args=(pool_repo, chunk_list[i][0],chunk_list[i][1],))
-                pool.append(p)
-            for p in pool:
-                p.start()
-            for p in pool:
-                p.join()
-
+                p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
+                procs.append(p)
+
+
+            procs = procs_operator(procs, 3600, self.check_interval)
+
             for item_idx in xrange(num_item):
                 self.item_param_dict[item_idx] = {
-                    'beta':pool_repo[item_idx][0],
-                    'alpha':pool_repo[item_idx][1],
+                    'beta':procs_repo[item_idx][0],
+                    'alpha':procs_repo[item_idx][1],
                     'c':self.item_param_dict[item_idx]['c']
                 }
@@ -303,7 +326,7 @@ def _init_user_param(self, theta_min, theta_max, num_theta, dist='normal'):
     def __update_theta_distr(self):
         # [A] calculate p(data,param|theta)
         num_user = self.dao.get_num('user')
-        num_chunk = min(6, mp.cpu_count())
+        num_chunk = min(self.num_cpu, mp.cpu_count())
         if num_user>num_chunk and self.is_parallel:
             def update(d, start_idx, end_idx):
                 for user_idx in tqdm(xrange(start_idx, end_idx)):
@@ -311,19 +334,18 @@ def update(d, start_idx, end_idx):
                     d[user_idx] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict)
             # [A] calculate p(data,param|theta)
             chunk_list = tools.cut_list(num_user, num_chunk)
-            pool = []
+            procs = []
             manager = mp.Manager()
-            pool_repo = manager.dict()
+            procs_repo = manager.dict()
             self.dao.close_conn()
             for i in range(num_chunk):
-                p = mp.Process(target=update, args=(pool_repo, chunk_list[i][0],chunk_list[i][1],))
-                pool.append(p)
-            for p in pool:
-                p.start()
-            for p in pool:
-                p.join()
+                p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
+                procs.append(p)
+
+            procs = procs_operator(procs, 5400, self.check_interval)
+
             for user_idx in tqdm(xrange(num_user)):
-                self.posterior_theta_distr[user_idx,:] = pool_repo[user_idx]
+                self.posterior_theta_distr[user_idx,:] = procs_repo[user_idx]
         else:
             for user_idx in xrange(num_user):
                 logs = self.dao.get_log(user_idx)
                 self.posterior_theta_distr[user_idx, :] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict)
@@ -356,7 +378,7 @@ def __calc_data_likelihood(self):
         #1/N * sum[i](1/Ni *sum[j] (log pij))
         theta_vec = self.__calc_theta()
         num_user = self.dao.get_num('user')
-        num_chunk = min(6, mp.cpu_count())
+        num_chunk = min(self.num_cpu, mp.cpu_count())
         if num_user>num_chunk and self.is_parallel:
             def update(tot_llk, cnt, start_idx, end_idx):
                 for user_idx in tqdm(xrange(start_idx, end_idx)):
@@ -382,15 +404,13 @@ def update(tot_llk, cnt, start_idx, end_idx):
             user_cnt = mp.Value('i', 0)
 
             chunk_list = tools.cut_list(num_user, num_chunk)
-            pool = []
+            procs = []
             self.dao.close_conn()
             for i in range(num_chunk):
                 p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0],chunk_list[i][1],))
-                pool.append(p)
-            for p in pool:
-                p.start()
-            for p in pool:
-                p.join()
+                procs.append(p)
+
+            procs = procs_operator(procs, 3600, self.check_interval)
             avg_ell = user_ell.value/user_cnt.value
         else:
diff --git a/test/test_model_wrapper.py b/test/test_model_wrapper.py
index 8ae65b7..a3e06f0 100644
--- a/test/test_model_wrapper.py
+++ b/test/test_model_wrapper.py
@@ -58,7 +58,7 @@ def test_2pl_solver_production(self):
             self.assertTrue(abs(mdl_beta - beta[t])<0.16)
 
     def test_2pl_solver_parallel(self):
-        item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30, is_parallel=True)
+        item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30, is_parallel=True, check_interval=1)
         for t in range(T):
             item_id = 'q%d'%t
             print(item_id, item_param[item_id])
@@ -112,7 +112,7 @@ def test_3pl_solver_production(self):
 
     def test_3pl_solver_parallel(self):
         item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3],
-                in_guess_param=guess_param, tol=1e-5, max_iter=30, is_parallel=True)
+                in_guess_param=guess_param, tol=1e-5, max_iter=30, is_parallel=True, check_interval=1)
         for t in range(T):
             item_id = 'q%d'%t
 

From e1a9fa9e5415ddc913830a328d2db1f0bccfd134 Mon Sep 17 00:00:00 2001
From: Junchen Feng
Date: Tue, 15 Aug 2017 09:36:21 +0800
Subject: [PATCH 6/7] Initiate db connection within each process. Factor the
 code to reduce code repetition.

---
 pyirt/_pyirt.py            |   7 +-
 pyirt/dao.py               |  65 +++++----
 pyirt/solver/model.py      | 281 +++++++++++++++++++------------------
 test/test_model_wrapper.py |   4 +-
 4 files changed, 185 insertions(+), 172 deletions(-)

diff --git a/pyirt/_pyirt.py b/pyirt/_pyirt.py
index 0887de5..a497c60 100644
--- a/pyirt/_pyirt.py
+++ b/pyirt/_pyirt.py
@@ -3,7 +3,7 @@ from .dao import localDAO
 
 def irt(data_src,
-        dao = 'memory',
+        dao_type = 'memory',
         theta_bnds=[-4, 4], num_theta=11,
         alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param={},
         model_spec='2PL',
@@ -14,14 +14,15 @@
 
     # load data
-    if dao=='memory':
+    if dao_type=='memory':
         dao_instance = localDAO(data_src)
     else:
         dao_instance = data_src
 
     # setup the model
     if model_spec == '2PL':
-        mod = model.IRT_MMLE_2PL(dao_instance, is_msg=is_msg, is_parallel=is_parallel, num_cpu=num_cpu, check_interval=check_interval, mode=mode)
+        mod = model.IRT_MMLE_2PL(dao_instance, dao_type=dao_type,
+                is_msg=is_msg, is_parallel=is_parallel, num_cpu=num_cpu, check_interval=check_interval, mode=mode)
     else:
         raise Exception('Unknown model specification.')
diff --git a/pyirt/dao.py b/pyirt/dao.py
index f87666a..bb9a795 100644
--- a/pyirt/dao.py
+++ b/pyirt/dao.py
@@ -11,31 +11,23 @@ from datetime import datetime
 
 class mongoDAO(object):
+
     def __init__(self, connect_config, group_id=1, is_msg=False):
-        user_name = connect_config['user']
-        password = connect_config['password']
-        address = connect_config['address']  # IP:PORT
-        db_name = connect_config['db']
-        if 'authsource' not in connect_config:
-            mongouri = 'mongodb://{un}:{pw}@{addr}'.format(un=user_name, pw=password, addr=address)
-        else:
-            authsource = connect_config['authsource']
-            mongouri = 'mongodb://{un}:{pw}@{addr}/?authsource={auth_src}'.format(un=user_name, pw=password, addr=address, auth_src=authsource)
-        try:
-            self.client = pymongo.MongoClient(mongouri, connect=False, serverSelectionTimeoutMS=10, readPreference='secondaryPreferred')
-        except:
-            raise
+        self.connect_config = connect_config
+
+        client = self.open_conn()
+        self.db_name = connect_config['db']
 
-        user2item_collection_name = 'irt_user2item'
-        item2user_collection_name = 'irt_item2user'
+        self.user2item_collection_name = 'irt_user2item'
+        self.item2user_collection_name = 'irt_item2user'
 
-        self.user2item = self.client[db_name][user2item_collection_name]
-        self.item2user = self.client[db_name][item2user_collection_name]
+        user2item_conn = client[self.db_name][self.user2item_collection_name]
+        item2user_conn = client[self.db_name][self.item2user_collection_name]
 
-        user_ids = list(set([x['id'] for x in self.user2item.find({'gid':group_id},{'id':1})]))
-        item_ids = list(set([x['id'] for x in self.item2user.find({'gid':group_id},{'id':1})]))
+        user_ids = list(set([x['id'] for x in user2item_conn.find({'gid':group_id},{'id':1})]))
+        item_ids = list(set([x['id'] for x in item2user_conn.find({'gid':group_id},{'id':1})]))
 
         _, self.user_idx_ref, self.user_reverse_idx_ref = construct_ref_dict(user_ids)
         _, self.item_idx_ref, self.item_reverse_idx_ref = construct_ref_dict(item_ids)
@@ -46,28 +38,42 @@ def __init__(self, connect_config, group_id=1, is_msg=False):
         self.gid = group_id
         self.is_msg = is_msg
 
-        self.close_conn()
+        client.close()
+
+    def open_conn(self):
+
+        user_name = self.connect_config['user']
+        password = self.connect_config['password']
+        address = self.connect_config['address']  # IP:PORT
+        if 'authsource' not in self.connect_config:
+            mongouri = 'mongodb://{un}:{pw}@{addr}'.format(un=user_name, pw=password, addr=address)
+        else:
+            authsource = self.connect_config['authsource']
+            mongouri = 'mongodb://{un}:{pw}@{addr}/?authsource={auth_src}'.format(un=user_name, pw=password, addr=address, auth_src=authsource)
+        try:
+            client = pymongo.MongoClient(mongouri, connect=False, serverSelectionTimeoutMS=10, waitQueueTimeoutMS=100, readPreference='secondaryPreferred')
+        except:
+            raise
+        return client
 
-    def close_conn(self):
-        self.client.close()
 
     def get_num(self, name):
        if name not in ['user','item']:
            raise Exception('Unknown stat source %s'%name)
        return self.stat[name]
 
-    def get_log(self, user_idx):
+    def get_log(self, user_idx, user2item_conn):
         user_id = self.translate('user', user_idx)
         # query
         if self.is_msg:
             stime = datetime.now()
-            res = self.user2item.find({'id':user_id, 'gid':self.gid})
+            res = user2item_conn.find({'id':user_id, 'gid':self.gid})
             etime = datetime.now()
             search_time = int((etime-stime).microseconds/1000)
             if search_time > 100:
                 print('s:%d' % search_time)
         else:
-            res = self.user2item.find({'id':user_id, 'gid':self.gid})
+            res = user2item_conn.find({'id':user_id, 'gid':self.gid})
         # parse
         if res.count() == 0:
             return_list = []
@@ -78,18 +84,18 @@ def get_log(self, user_idx):
             return_list = [(self.item_idx_ref[x[0]], x[1]) for x in log_list]
         return return_list
 
-    def get_map(self, item_idx, ans_key_list):
+    def get_map(self, item_idx, ans_key_list, item2user_conn):
         item_id = self.translate('item', item_idx)
         # query
         if self.is_msg:
             stime = datetime.now()
-            res = self.item2user.find({'id':item_id, 'gid':self.gid})
+            res = item2user_conn.find({'id':item_id, 'gid':self.gid})
             etime = datetime.now()
             search_time = int((etime-stime).microseconds/1000)
             if search_time > 100:
                 print('s:%d' % search_time)
         else:
-            res = self.item2user.find({'id':item_id, 'gid':self.gid})
+            res = item2user_conn.find({'id':item_id, 'gid':self.gid})
         # parse
         if res.count() == 0:
             return_list = [[] for ans_key in ans_key_list]
@@ -111,9 +117,6 @@ def translate(self, data_type, idx):
         elif data_type == 'user':
             return self.user_reverse_idx_ref[idx]
 
-    def __del__(self):
-        self.client.close()
-
 
 class localDAO(object):
 
     def __init__(self, src):
diff --git a/pyirt/solver/model.py b/pyirt/solver/model.py
index 4458ff5..a7d4ccd 100644
--- a/pyirt/solver/model.py
+++ b/pyirt/solver/model.py
@@ -55,16 +55,21 @@ class IRT_MMLE_2PL(object):
     (2) solve
     (3) get estimated result
     '''
-    def __init__(self, dao_instance, is_msg=False, is_parallel=False, num_cpu=6, check_interval=60, mode='debug'):
+    def __init__(self,
+            dao_instance, dao_type='memory',
+            is_msg=False,
+            is_parallel=False, num_cpu=6, check_interval=60,
+            mode='debug'):
         # interface to data
         self.dao=dao_instance
+        self.dao_type = dao_type
         self.num_iter = 1
         self.ell_list = []
         self.last_avg_prob = 0
 
         self.is_msg = is_msg
         self.is_parallel = is_parallel
-        self.num_cpu = num_cpu
+        self.num_cpu = min(num_cpu, mp.cpu_count())
         self.check_interval = check_interval
         self.mode = mode
 
@@ -152,7 +157,6 @@ def _exp_step(self):
             (1-data_[i,j]) *(1-P(Y=1|param_j,theta_[i,k]) )
             By similar logic, it is equivalent to sum (1-p) for all done wrong users
-
         '''
 
         # (1) update the posterior distribution of theta
@@ -168,20 +172,8 @@ def _max_step(self):
         Basic Math
             log likelihood(param_j) = sum_k(log likelihood(param_j, theta_k))
         '''
-        # [A] max for item parameter
-        opt_worker = optimizer.irt_2PL_Optimizer()
-        # the boundary is universal
-        # the boundary is set regardless of the constrained option because the
-        # constrained search serves as backup for outlier cases
-        opt_worker.set_bounds([self.beta_bnds, self.alpha_bnds])
-
-        # theta value is universal
-        opt_worker.set_theta(self.theta_prior_val)
-        num_item = self.dao.get_num('item')
-        num_chunk = min(self.num_cpu, mp.cpu_count())
-
-        if num_item > num_chunk and self.is_parallel:
-            def update(d, start_idx, end_idx):
+        def update(d, start_idx, end_idx):
+            try:
                 for item_idx in tqdm(xrange(start_idx, end_idx)):
                     initial_guess_val = (self.item_param_dict[item_idx]['beta'],
                                          self.item_param_dict[item_idx]['alpha'])
@@ -204,55 +196,49 @@ def update(d, start_idx, end_idx):
                         raise e
                     finally:
                         d[item_idx] = est_param
-
-            # [A] calculate p(data,param|theta)
-            chunk_list = tools.cut_list(num_item, num_chunk)
+            except:
+                raise
+        # [A] max for item parameter
+        opt_worker = optimizer.irt_2PL_Optimizer()
+        # the boundary is universal
+        # the boundary is set regardless of the constrained option because the
+        # constrained search serves as backup for outlier cases
+        opt_worker.set_bounds([self.beta_bnds, self.alpha_bnds])
 
-            procs = []
-            manager = mp.Manager()
-            procs_repo = manager.dict()
-
-            self.dao.close_conn()
-            for i in range(num_chunk):
-                p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
-                procs.append(p)
+        # theta value is universal
+        opt_worker.set_theta(self.theta_prior_val)
+        num_item = self.dao.get_num('item')
-
-            procs = procs_operator(procs, 3600, self.check_interval)
+
+        if num_item > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
 
-            for item_idx in xrange(num_item):
-                self.item_param_dict[item_idx] = {
-                    'beta':procs_repo[item_idx][0],
-                    'alpha':procs_repo[item_idx][1],
-                    'c':self.item_param_dict[item_idx]['c']
-                }
+        # [A] calculate p(data,param|theta)
+        chunk_list = tools.cut_list(num_item, num_chunk)
+        procs = []
+        manager = mp.Manager()
+        procs_repo = manager.dict()
+
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
+            procs.append(p)
+        if num_chunk > 1:
+            procs = procs_operator(procs, 3600, self.check_interval)
         else:
-            for item_idx in xrange(num_item):
-                initial_guess_val = (self.item_param_dict[item_idx]['beta'],
-                                     self.item_param_dict[item_idx]['alpha'])
-                opt_worker.set_initial_guess(initial_guess_val)  # the value is a mix of 1/0 and current estimate
-                opt_worker.set_c(self.item_param_dict[item_idx]['c'])
-
-                # estimate
-                expected_right_count = self.item_expected_right_by_theta[:, item_idx]
-                expected_wrong_count = self.item_expected_wrong_by_theta[:, item_idx]
-                input_data = [expected_right_count, expected_wrong_count]
-                opt_worker.load_res_data(input_data)
-                est_param = opt_worker.solve_param_mix(self.is_constrained)
-
-                try:
-                    est_param = opt_worker.solve_param_mix(self.is_constrained)
-                except Exception as e:
-                    if self.mode=='production':
-                        continue
-                    else:
-                        raise e
-                finally:
-                    # update
-                    self.item_param_dict[item_idx]['beta'] = est_param[0]
-                    self.item_param_dict[item_idx]['alpha'] = est_param[1]
+            procs = procs_operator(procs, 7200, 0.1)
+
+        for item_idx in xrange(num_item):
+            self.item_param_dict[item_idx] = {
+                'beta':procs_repo[item_idx][0],
+                'alpha':procs_repo[item_idx][1],
+                'c':self.item_param_dict[item_idx]['c']
+            }
+
+
         # [B] max for theta density
         self.theta_density = self.posterior_theta_distr.sum(axis=0)/self.posterior_theta_distr.sum()
 
@@ -262,30 +248,34 @@ def _check_stop(self):
         '''
         preserve user and item parameter from last iteration. This is useful in restoring after a declining llk iteration
         '''
+        if self.is_msg: print('score calculating')
        avg_prob = np.exp(self.__calc_data_likelihood())
+        if self.is_msg: print('score calculated.')
+
        self.ell_list.append(avg_prob)
        if self.is_msg:
            print(avg_prob)
 
-        if self.last_avg_prob < avg_prob and avg_prob - self.last_avg_prob <= self.tol:
+        diff = avg_prob - self.last_avg_prob
+
+        if diff >= 0 and diff <= self.tol:
             print('EM converged at iteration %d.' % self.num_iter)
-            return True
-
-        # if the algorithm improves, then ell > ell_t0
-        if self.last_avg_prob > avg_prob:
+            return True
+        elif diff <0:
             self.item_param_dict = self.last_item_param_dict
             print('Likelihood decreased, stops at iteration %d.' % self.num_iter)
             return True
-
-        # update the stop condition
-        self.last_avg_prob = avg_prob
-        self.num_iter += 1
-
-        if (self.num_iter > self.max_iter):
-            print('EM does not converge within max iteration')
-            return True
-        if self.num_iter != 1:
-            self.last_item_param_dict = self.item_param_dict
-        return False
+        else:
+            # diff larger than tolerance
+            # update the stop condition
+            self.last_avg_prob = avg_prob
+            self.num_iter += 1
+
+            if (self.num_iter > self.max_iter):
+                print('EM does not converge within max iteration')
+                return True
+            if self.num_iter != 1:
+                self.last_item_param_dict = self.item_param_dict
+            return False
 
     def _init_solver_param(self, is_constrained, boundary, solver_type, max_iter, tol):
         # initialize bounds
@@ -324,32 +314,44 @@ def _init_user_param(self, theta_min, theta_max, num_theta, dist='normal'):
         self.posterior_theta_distr = np.zeros((self.dao.get_num('user'), num_theta))
 
     def __update_theta_distr(self):
-        # [A] calculate p(data,param|theta)
-        num_user = self.dao.get_num('user')
-        num_chunk = min(self.num_cpu, mp.cpu_count())
-        if num_user>num_chunk and self.is_parallel:
-            def update(d, start_idx, end_idx):
+        def update(d, start_idx, end_idx):
+            try:
+                if self.dao_type=='db':
+                    client = self.dao.open_conn()
+                    user2item_conn = client[self.dao.db_name][self.dao.user2item_collection_name]
                 for user_idx in tqdm(xrange(start_idx, end_idx)):
-                    logs = self.dao.get_log(user_idx)
+                    if self.dao_type=='db':
+                        logs = self.dao.get_log(user_idx, user2item_conn)
+                    else:
+                        logs = self.dao.get_log(user_idx)
                     d[user_idx] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict)
-            # [A] calculate p(data,param|theta)
-            chunk_list = tools.cut_list(num_user, num_chunk)
-            procs = []
-            manager = mp.Manager()
-            procs_repo = manager.dict()
-            self.dao.close_conn()
-            for i in range(num_chunk):
-                p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
-                procs.append(p)
-
+                if self.dao_type=='db':
+                    client.close()
+            except:
+                raise
+        # [A] calculate p(data,param|theta)
+        num_user = self.dao.get_num('user')
+        if num_user > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
+
+        # [A] calculate p(data,param|theta)
+        chunk_list = tools.cut_list(num_user, num_chunk)
+        procs = []
+        manager = mp.Manager()
+        procs_repo = manager.dict()
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
+            procs.append(p)
+
+        if num_chunk>1:
             procs = procs_operator(procs, 5400, self.check_interval)
-
-            for user_idx in tqdm(xrange(num_user)):
-                self.posterior_theta_distr[user_idx,:] = procs_repo[user_idx]
         else:
-            for user_idx in xrange(num_user):
-                logs = self.dao.get_log(user_idx)
-                self.posterior_theta_distr[user_idx, :] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict)
+            procs = procs_operator(procs, 7200, 0.1)
+
+        for user_idx in tqdm(xrange(num_user)):
+            self.posterior_theta_distr[user_idx,:] = procs_repo[user_idx]
 
         # When the loop finishes, check if the theta_density adds up to unity for each user
         check_user_distr_marginal = np.sum(self.posterior_theta_distr, axis=1)
         if any(abs(check_user_distr_marginal - 1.0) > 0.0001):
@@ -366,25 +368,36 @@ def __get_expect_count(self):
         self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
         self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
         num_item = self.dao.get_num('item')
+        if self.dao_type == 'db':
+            client = self.dao.open_conn()
+            item2user_conn = client[self.dao.db_name][self.dao.item2user_collection_name]
         for item_idx in tqdm(xrange(num_item)):
-            map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'])  # reduce the io time for mongo dao
+            if self.dao_type == 'db':
+                map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'], item2user_conn)
+            else:
+                map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'])
             self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0)
             self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0)
+        if self.dao_type == 'db':
+            client.close()
 
     def __calc_data_likelihood(self):
         # calculate the likelihood for the data set
         # geometric within learner and across learner
-        #1/N * sum[i](1/Ni *sum[j] (log pij))
-        theta_vec = self.__calc_theta()
-        num_user = self.dao.get_num('user')
-        num_chunk = min(self.num_cpu, mp.cpu_count())
-        if num_user>num_chunk and self.is_parallel:
-            def update(tot_llk, cnt, start_idx, end_idx):
+        def update(tot_llk, cnt, start_idx, end_idx):
+            try:
+                if self.dao_type == 'db':
+                    client = self.dao.open_conn()
+                    user2item_conn = client[self.dao.db_name][self.dao.user2item_collection_name]
+
                 for user_idx in tqdm(xrange(start_idx, end_idx)):
                     theta = theta_vec[user_idx]
                     # find all the item_id
-                    logs = self.dao.get_log(user_idx)
+                    if self.dao_type == 'db':
+                        logs = self.dao.get_log(user_idx, user2item_conn)
+                    else:
+                        logs = self.dao.get_log(user_idx)
                     if len(logs) == 0:
                         continue
                     ell = 0
@@ -399,40 +412,36 @@ def update(tot_llk, cnt, start_idx, end_idx):
                         tot_llk.value += ell/len(logs)
                     with cnt.get_lock():
                         cnt.value += 1
+                if self.dao_type == 'db':
+                    client.close()
+            except:
+                raise
 
-            user_ell = mp.Value('d', 0.0)
-            user_cnt = mp.Value('i', 0)
+        theta_vec = self.__calc_theta()
+        num_user = self.dao.get_num('user')
 
-            chunk_list = tools.cut_list(num_user, num_chunk)
-            procs = []
-            self.dao.close_conn()
-            for i in range(num_chunk):
-                p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0],chunk_list[i][1],))
-                procs.append(p)
-
-            procs = procs_operator(procs, 3600, self.check_interval)
-            avg_ell = user_ell.value/user_cnt.value
+        if num_user > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
+
+        user_ell = mp.Value('d', 0.0)
+        user_cnt = mp.Value('i', 0)
+        chunk_list = tools.cut_list(num_user, num_chunk)
+
+        procs = []
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0],chunk_list[i][1],))
+            procs.append(p)
+
+        if num_chunk>1:
+            procs = procs_operator(procs, 1200, self.check_interval)
         else:
-            user_ell = 0
-            user_cnt = 0
-            for user_idx in xrange(num_user):
-                theta = theta_vec[user_idx]
-                # find all the item_id
-                logs = self.dao.get_log(user_idx)
-                if len(logs) == 0:
-                    continue
-                ell = 0
-                for log in logs:
-                    item_idx = log[0]
-                    ans_tag = log[1]
-                    alpha = self.item_param_dict[item_idx]['alpha']
-                    beta = self.item_param_dict[item_idx]['beta']
-                    c = self.item_param_dict[item_idx]['c']
-                    ell += clib.log_likelihood_2PL(0.0+ans_tag, 1.0-ans_tag, theta, alpha, beta, c)
-                user_ell += ell/len(logs)
-                user_cnt += 1
-            avg_ell = user_ell / user_cnt
+            procs = procs_operator(procs, 7200, 0.1)
+
+        avg_ell = user_ell.value/user_cnt.value
+
         return avg_ell
 
     def __calc_theta(self):
diff --git a/test/test_model_wrapper.py b/test/test_model_wrapper.py
index a3e06f0..ee21576 100644
--- a/test/test_model_wrapper.py
+++ b/test/test_model_wrapper.py
@@ -58,7 +58,7 @@ def test_2pl_solver_production(self):
             self.assertTrue(abs(mdl_beta - beta[t])<0.16)
 
     def test_2pl_solver_parallel(self):
-        item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30, is_parallel=True, check_interval=1)
+        item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30, is_parallel=True, check_interval=0.1)
         for t in range(T):
             item_id = 'q%d'%t
             print(item_id, item_param[item_id])
@@ -112,7 +112,7 @@ def test_3pl_solver_production(self):
 
     def test_3pl_solver_parallel(self):
         item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2,theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3],
-                in_guess_param=guess_param, tol=1e-5, max_iter=30, is_parallel=True, check_interval=1)
+                in_guess_param=guess_param, tol=1e-5, max_iter=30, is_parallel=True, check_interval=0.1)
         for t in range(T):
             item_id = 'q%d'%t
 

From 4009f0e95bb7ff54d0a2c66451a83d015b079c96 Mon Sep 17 00:00:00 2001
From: Junchen Feng
Date: Tue, 15 Aug 2017 11:18:57 +0800
Subject: [PATCH 7/7] Parallelize all expectation steps.

---
 pyirt/solver/model.py | 73 +++++++++++++++++++++++++++++--------------
 1 file changed, 50 insertions(+), 23 deletions(-)

diff --git a/pyirt/solver/model.py b/pyirt/solver/model.py
index a7d4ccd..5192c0f 100644
--- a/pyirt/solver/model.py
+++ b/pyirt/solver/model.py
@@ -350,8 +350,9 @@ def update(d, start_idx, end_idx):
         else:
             procs = procs_operator(procs, 7200, 0.1)
 
-        for user_idx in tqdm(xrange(num_user)):
+        for user_idx in xrange(num_user):
             self.posterior_theta_distr[user_idx,:] = procs_repo[user_idx]
+
         # When the loop finishes, check if the theta_density adds up to unity for each user
         check_user_distr_marginal = np.sum(self.posterior_theta_distr, axis=1)
         if any(abs(check_user_distr_marginal - 1.0) > 0.0001):
@@ -365,21 +366,52 @@ def __check_theta_density(self):
             raise Exception('theta density has wrong shape (%s,%s)'%self.theta_density.shape)
 
     def __get_expect_count(self):
+        def update(d, start_idx, end_idx):
+            try:
+                if self.dao_type == 'db':
+                    client = self.dao.open_conn()
+                    item2user_conn = client[self.dao.db_name][self.dao.item2user_collection_name]
+                for item_idx in tqdm(xrange(start_idx, end_idx)):
+                    if self.dao_type == 'db':
+                        map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'], item2user_conn)
+                    else:
+                        map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'])
+                    d[item_idx] = {
+                        1: np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0),
+                        0: np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0)
+                    }
+                if self.dao_type == 'db':
+                    client.close()
+            except:
+                raise
+
+        num_item = self.dao.get_num('item')
+        if num_item > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
+
+        # [A] calculate p(data,param|theta)
+        chunk_list = tools.cut_list(num_item, num_chunk)
+        procs = []
+        manager = mp.Manager()
+        procs_repo = manager.dict()
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0],chunk_list[i][1],))
+            procs.append(p)
+
+        if num_chunk>1:
+            procs = procs_operator(procs, 5400, self.check_interval)
+        else:
+            procs = procs_operator(procs, 7200, 0.1)
+
         self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
         self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
-        num_item = self.dao.get_num('item')
-        if self.dao_type == 'db':
-            client = self.dao.open_conn()
-            item2user_conn = client[self.dao.db_name][self.dao.item2user_collection_name]
-        for item_idx in tqdm(xrange(num_item)):
-            if self.dao_type == 'db':
-                map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'], item2user_conn)
-            else:
-                map_user_idx_vec = self.dao.get_map(item_idx, ['1','0'])
-            self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0)
-            self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0)
-        if self.dao_type == 'db':
-            client.close()
+        for item_idx in xrange(num_item):
+            self.item_expected_right_by_theta[:, item_idx] = procs_repo[item_idx][1]
+            self.item_expected_wrong_by_theta[:, item_idx] = procs_repo[item_idx][0]
+
+
 
     def __calc_data_likelihood(self):
         # calculate the likelihood for the data set
@@ -412,6 +444,7 @@ def update(tot_llk, cnt, start_idx, end_idx):
                         tot_llk.value += ell/len(logs)
                     with cnt.get_lock():
                         cnt.value += 1
+
                 if self.dao_type == 'db':
                     client.close()
             except:
@@ -419,27 +452,21 @@ def update(tot_llk, cnt, start_idx, end_idx):
 
         theta_vec = self.__calc_theta()
         num_user = self.dao.get_num('user')
-
-
         if num_user > self.num_cpu and self.is_parallel:
             num_chunk = self.num_cpu
         else:
-            num_chunk = 1
-
+            num_chunk = 1
         user_ell = mp.Value('d', 0.0)
         user_cnt = mp.Value('i', 0)
-        chunk_list = tools.cut_list(num_user, num_chunk)
-
+        chunk_list = tools.cut_list(num_user, num_chunk)
         procs = []
         for i in range(num_chunk):
             p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0],chunk_list[i][1],))
-            procs.append(p)
-
+            procs.append(p)
         if num_chunk>1:
             procs = procs_operator(procs, 1200, self.check_interval)
         else:
             procs = procs_operator(procs, 7200, 0.1)
-
         avg_ell = user_ell.value/user_cnt.value
         return avg_ell
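
[Editor's note] procs_operator() is called by every parallel code path in these
patches but is defined outside the diffs shown here (presumably imported
elsewhere in pyirt/solver/model.py). Below is a minimal sketch of the contract
implied by the call sites, namely start the workers, poll them every
check_interval seconds, and terminate them once the timeout budget is
exhausted. This is a reconstruction under those assumptions, not the project's
actual code:

    import time

    def procs_operator(procs, timeout, check_interval):
        # Start every worker process.
        for p in procs:
            p.start()
        # Poll until all workers exit or the timeout budget is used up.
        start = time.time()
        while any(p.is_alive() for p in procs):
            if time.time() - start > timeout:
                for p in procs:
                    p.terminate()
                raise Exception('worker processes timed out after %d seconds' % timeout)
            time.sleep(check_interval)
        # All workers exited; reap them and hand the list back.
        for p in procs:
            p.join()
        return procs

This reading is consistent with the call sites above: the serial fallback
(num_chunk == 1) passes check_interval=0.1 so the single worker is polled
cheaply, while the parallel paths pass per-phase timeout budgets of 1200 to
7200 seconds.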
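[Editor's note] For completeness, a hypothetical end-to-end wiring of the mongo
DAO after patch 6. The connect_config keys ('user', 'password', 'address',
'db', 'authsource') mirror what mongoDAO.open_conn() reads above; the
credentials, host, group id, num_cpu value, and the package-level irt import
are illustrative assumptions, not values taken from the diff:

    from pyirt import irt
    from pyirt.dao import mongoDAO

    connect_config = {
        'user': 'irt_user',            # placeholder credentials
        'password': 'secret',
        'address': '127.0.0.1:27017',  # IP:PORT, the format open_conn() expects
        'db': 'irt',
        'authsource': 'admin',         # optional; omit to build the URI without authsource
    }
    dao = mongoDAO(connect_config, group_id=1)

    # With dao_type='db', irt() uses the DAO instance directly instead of
    # wrapping the input in localDAO; each worker process then opens its own
    # connection via dao.open_conn(), which is the point of patch 6.
    item_param, user_param = irt(dao, dao_type='db', is_parallel=True, num_cpu=4)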