diff --git a/README.md b/README.md
index 9af865b..816e5f5 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@
 src_fp = open(file_path,'r')
 item_param, user_param = irt(src_fp)
 
 # (2) Supply bounds
-item_param, user-param = irt(src_fp, theta_bnds = [-5,5], alpha_bnds=[0.1,3], beta_bnds = [-3,3])
+item_param, user_param = irt(src_fp, theta_bnds=[-4,4], alpha_bnds=[0.1,3], beta_bnds=[-3,3])
 
 # (3) Supply guess parameter
 guessParamDict = {1: 0.0, 2: 0.25}
 
@@ -44,12 +44,12 @@
 The current version supports the MMLE algorithm and the unidimensional two-parameter IRT model. There is a backdoor method to specify the guess parameter, but no active estimation of it.
 
-The prior distribution of theta is **uniform**.
-
 There is no regularization in the alpha and beta estimation. Therefore, the default algorithm bounds the parameters to prevent over-fitting and to handle extreme cases where almost all responses to an item are correct.
 
+The unit test code shows that when there are discriminative items to anchor theta, the model converges reasonably well to the true values.
+
 ## Theta estimation
 
 The package offers two methods to estimate theta given item parameters: Bayesian and MLE.
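The Bayesian route can be pictured as a posterior mean over the discrete theta grid, which mirrors what the solver's internal __calc_theta computes. A minimal self-contained sketch of that idea (the 2PL probability function, grid, and item ids below are illustrative assumptions, not the package API):

import numpy as np

def irt_2pl_prob(theta, alpha, beta):
    # 2PL response probability; an illustrative stand-in, not the package's internal function
    return 1.0 / (1.0 + np.exp(-alpha * (theta - beta)))

def bayesian_theta(logs, item_param, theta_grid, prior):
    # posterior over the grid: prior times the likelihood of each observed 0/1 log
    posterior = prior.copy()
    for item_id, ans in logs:
        p = irt_2pl_prob(theta_grid, item_param[item_id]['alpha'], item_param[item_id]['beta'])
        posterior *= p if ans == 1 else (1.0 - p)
    posterior /= posterior.sum()
    return np.dot(posterior, theta_grid)  # posterior mean

theta_grid = np.linspace(-4, 4, 11)  # the default theta_bnds and num_theta
prior = np.ones_like(theta_grid) / len(theta_grid)
item_param = {'q0': {'alpha': 1.0, 'beta': 0.0}}  # hypothetical output of irt()
print(bayesian_theta([('q0', 1)], item_param, theta_grid, prior))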
 The estimation procedure is quite primitive. For examples, see the test case.
 
@@ -57,28 +57,28 @@
 II.Sparse Data Structure
 ==========
-In non-test learning dataset, missing data is the common. Not all students
-finish all the items. When the number of students and items are large, the data
-can be extremely sparse.
+In non-test learning datasets, missing data is common: not all students finish all the items. When the numbers of students and items are large, the data can be extremely sparse.
 
 The package deals with the sparse structure in two ways:
 
-- Efficient memory storage. Use collapsed list to index data. The memory usage
-  is about 3 times of the text data file. If the workstation has 6G free
-memory, it can handle 2G data file. Most other IRT package will definitely
-break.
-- No joint estimation. Under IRT's conditional independence assumption,
-  estimate each item's parameter is consistent but inefficient. To avoid
-reverting a giant jacobian matrix, the item parameters are estimated seprately.
+- Efficient memory storage.
+
+Collapsed lists are used to index the data. The memory usage is about 3 times the size of the text data file: a workstation with 6GB of free memory can handle a 2GB data file.
+
+If the data are truly big, say billions of records, a MongoDB DAO is implemented so the data never have to fit in memory.
+
+- No joint estimation.
+
+Under the conditional independence assumption, estimating each item's parameters separately is consistent but inefficient.
 
 III.Default Config
 ===========
 
 ## Exposed
 
 The theta parameter ranges over [-4,4] with a step size of 0.8 by default.
 
-Alpha is bounded by [0.25,2] and beta is bounded by [-2,2], which goes with the user ability
-specifiation.
+Alpha is bounded by [0.25,2] and beta by [-2,2].
+
 The guess parameter is set to 0 by default, but one can supply a dictionary with eid as key and c as value.
 
@@ -88,37 +88,29 @@
 The default solver is L-BFGS-B.
 
 The default max iteration is 10.
 
-The stop condition threshold is 1e-3 by default. The algorithm computes the
-average likelihood per log at the end of the iteration. If the likelihood
-increament is less than the threshold, stops.
+The stop condition threshold is 1e-3 by default. The stop condition looks at the geometric mean of the per-log likelihood, which is independent of the data size.
 
 IV.Data Format
 =========
-The file is expected to be comma delimited.
+If the data are ingested through the file system, the file is expected to be **comma delimited**. The three columns are learner id, item id and result flag.
+
+Currently the model only works well with 0/1 flags but will **NOT** raise an error for other types.
 
-The three columns are uid, eid, result-flag.
+If using the database DAO, see the unit test for an example.
 
-Currently the model only works well with 0/1 flag but will **NOT** raise error for
-other types.
 
 V.Note
 =======
-
 ## Minimization solver
-The scipy minimize is as good as cvxopt.cp and matlab fmincon on item parameter
-estimation to the 6th decimal point, which can be viewed as identical for all
-practical purposes.
+scipy's minimize matches cvxopt.cp and Matlab's fmincon on item parameter estimation to the 6th decimal place, which can be viewed as identical for all practical purposes.
 
-However, the convergence is pretty slow. It requires about 10k obeverations per
-item to recover the parameter to the 0.01 precision.
+However, the convergence is pretty slow. It requires about 10k observations per item to recover the parameters to 0.01 precision.
 
 VII.Acknowledgement
 ==============
-The algorithm is described in details by Bradey Hanson(2000), see in the
-literature section. I am grateful to Mr.Hanson's work.
+The algorithm is described in detail by Bradley Hanson (2000); see the literature section. I am grateful for Mr. Hanson's work.
 
 [Chaoqun Fu](https://github.com/fuchaoqun)'s comment led to the (much better) API design.
diff --git a/pyirt/_pyirt.py b/pyirt/_pyirt.py
index 8708643..a497c60 100644
--- a/pyirt/_pyirt.py
+++ b/pyirt/_pyirt.py
@@ -3,17 +3,26 @@
 from .dao import localDAO
 
 def irt(data_src,
+        dao_type='memory',
         theta_bnds=[-4, 4], num_theta=11,
-        alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param='default',
+        alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param={},
         model_spec='2PL',
-        max_iter=10, tol=1e-3, nargout=2):
+        max_iter=10, tol=1e-3, nargout=2,
+        is_msg=False,
+        is_parallel=False, num_cpu=6, check_interval=60,
+        mode='debug'):
 
     # load data
-    dao_instance = localDAO(data_src)
+    if dao_type == 'memory':
+        dao_instance = localDAO(data_src)
+    else:
+        # for the database-backed run, the caller passes in a ready DAO (e.g. mongoDAO)
+        dao_instance = data_src
+
     # setup the model
     if model_spec == '2PL':
-        mod = model.IRT_MMLE_2PL(dao_instance)
+        mod = model.IRT_MMLE_2PL(dao_instance, dao_type=dao_type,
+                                 is_msg=is_msg, is_parallel=is_parallel, num_cpu=num_cpu,
+                                 check_interval=check_interval, mode=mode)
     else:
         raise Exception('Unknown model specification.')
diff --git a/pyirt/dao.py b/pyirt/dao.py
index b1ead4b..bb9a795 100644
--- a/pyirt/dao.py
+++ b/pyirt/dao.py
@@ -6,8 +6,117 @@
 from .util.dao import loadFromHandle, loadFromTuples, construct_ref_dict
 
-#TODO: bitmap is a function of DAO. Seperate that with database
+import pymongo
+from datetime import datetime
+
+
+class mongoDAO(object):
+
+    def __init__(self, connect_config, group_id=1, is_msg=False):
+        self.connect_config = connect_config
+        client = self.open_conn()
+        self.db_name = connect_config['db']
+
+        self.user2item_collection_name = 'irt_user2item'
+        self.item2user_collection_name = 'irt_item2user'
+
+        user2item_conn = client[self.db_name][self.user2item_collection_name]
+        item2user_conn = client[self.db_name][self.item2user_collection_name]
+
+        user_ids = list(set([x['id'] for x in user2item_conn.find({'gid': group_id}, {'id': 1})]))
+        item_ids = list(set([x['id'] for x in item2user_conn.find({'gid': group_id}, {'id': 1})]))
+
+        _, self.user_idx_ref, self.user_reverse_idx_ref = construct_ref_dict(user_ids)
+        _, self.item_idx_ref, self.item_reverse_idx_ref = construct_ref_dict(item_ids)
+
+        self.stat = {'user': len(self.user_idx_ref.keys()), 'item': len(self.item_idx_ref.keys())}
+
+        print('search idx created.')
+        self.gid = group_id
+        self.is_msg = is_msg
+
+        client.close()
+
+    def open_conn(self):
+        user_name = self.connect_config['user']
+        password = self.connect_config['password']
+        address = self.connect_config['address']  # IP:PORT
+        if 'authsource' not in self.connect_config:
+            mongouri = 'mongodb://{un}:{pw}@{addr}'.format(un=user_name, pw=password, addr=address)
+        else:
+            authsource = self.connect_config['authsource']
+            mongouri = 'mongodb://{un}:{pw}@{addr}/?authsource={auth_src}'.format(un=user_name, pw=password, addr=address, auth_src=authsource)
+        client = pymongo.MongoClient(mongouri, connect=False,
+                                     serverSelectionTimeoutMS=10, waitQueueTimeoutMS=100,
+                                     readPreference='secondaryPreferred')
+        return client
+
+    def get_num(self, name):
+        if name not in ['user', 'item']:
+            raise Exception('Unknown stat source %s' % name)
+        return self.stat[name]
+
+    def get_log(self, user_idx, user2item_conn):
+        user_id = self.translate('user', user_idx)
+        # query
+        if self.is_msg:
+            stime = datetime.now()
+            res = user2item_conn.find({'id': user_id, 'gid': self.gid})
+            etime = datetime.now()
+            search_time = int((etime - stime).total_seconds() * 1000)  # elapsed milliseconds
+            if search_time > 100:
+                print('s:%d' % search_time)
+        else:
+            res = user2item_conn.find({'id': user_id, 'gid': self.gid})
+        # parse
+        if res.count() == 0:
+            return_list = []
+        elif res.count() > 1:
+            raise Exception('duplicate doc for (%s, %d) in user2item' % (user_id, self.gid))
+        else:
+            log_list = res[0]['data']
+            return_list = [(self.item_idx_ref[x[0]], x[1]) for x in log_list]
+        return return_list
+
+    def get_map(self, item_idx, ans_key_list, item2user_conn):
+        item_id = self.translate('item', item_idx)
+        # query
+        if self.is_msg:
+            stime = datetime.now()
+            res = item2user_conn.find({'id': item_id, 'gid': self.gid})
+            etime = datetime.now()
+            search_time = int((etime - stime).total_seconds() * 1000)  # elapsed milliseconds
+            if search_time > 100:
+                print('s:%d' % search_time)
+        else:
+            res = item2user_conn.find({'id': item_id, 'gid': self.gid})
+        # parse
+        if res.count() == 0:
+            return_list = [[] for ans_key in ans_key_list]
+        elif res.count() > 1:
+            raise Exception('duplicate doc for (%s, %d) in item2user' % (item_id, self.gid))
+        else:
+            doc = res[0]['data']
+            return_list = []
+            for ans_key in ans_key_list:
+                if str(ans_key) in doc:
+                    return_list.append([self.user_idx_ref[x] for x in doc[str(ans_key)]])
+                else:
+                    return_list.append([])
+        return return_list
+
+    def translate(self, data_type, idx):
+        if data_type == 'item':
+            return self.item_reverse_idx_ref[idx]
+        elif data_type == 'user':
+            return self.user_reverse_idx_ref[idx]
 
 
 class localDAO(object):
 
     def __init__(self, src):
@@ -21,18 +130,19 @@
         self.database.setup(user_id_idx_vec, item_id_idx_vec, self.database.ans_tags)
 
     def get_num(self, name):
-        if name not in ['user','item','log']:
+        if name not in ['user','item']:
             raise Exception('Unknown stat source %s'%name)
         return self.database.stat[name]
 
     def get_log(self, user_idx):
         return self.database.user2item[user_idx]
 
-    def get_right_map(self, item_idx):
-        return self.database.right_map[item_idx]
-
-    def get_wrong_map(self, item_idx):
-        return self.database.wrong_map[item_idx]
+    def get_map(self, item_idx, ans_key_list):
+        # NOTE: returns an empty list for an item with no logs under an answer key
+        return [self.database.item2user_map[str(ans_key)][item_idx] for ans_key in ans_key_list]
+
+    def close_conn(self):
+        pass
 
     def translate(self, data_type, idx):
         if data_type == 'item':
@@ -60,7 +170,7 @@
     def setup(self, user_idx_vec, item_idx_vec, ans_tags, msg=False):
 
         # initialize some intermediate variables used in the E step
         start_time = time.time()
-        self._init_right_wrong_map()
+        self._init_item2user_map()
         if msg:
             print("--- Sparse Mapping: %f secs ---" % np.round((time.time() - start_time)))
 
@@ -75,11 +185,11 @@
         self.user2item = defaultdict(list)
 
         self.stat = {}
-        self.stat['log'] = len(user_idx_vec)
+        num_log = len(user_idx_vec)
         self.stat['user'] = max(user_idx_vec)+1  # start count from 0
         self.stat['item'] = max(item_idx_vec)+1
 
-        for i in range(self.stat['log']):
+        for i in range(num_log):
             item_idx = item_idx_vec[i]
             user_idx = user_idx_vec[i]
             ans_tag = ans_tags[i]
@@ -87,16 +197,15 @@
             self.item2user[item_idx].append((user_idx, ans_tag))
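            # These two collapsed lists are the core of the sparse storage design:
            #   item2user[item_idx] -> [(user_idx, ans_tag), ...]  (drives the M step)
            #   user2item[user_idx] -> [(item_idx, ans_tag), ...]  (drives the E step)
            # Only observed (user, item) pairs are stored, so memory scales with the
            # number of logs rather than with num_user * num_item.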
             self.user2item[user_idx].append((item_idx, ans_tag))
 
-    def _init_right_wrong_map(self):
-        self.right_map = defaultdict(list)
-        self.wrong_map = defaultdict(list)
+    def _init_item2user_map(self, ans_key_list=['0', '1']):
+        self.item2user_map = {}
+        for ans_key in ans_key_list:
+            self.item2user_map[ans_key] = defaultdict(list)
 
         for item_idx, log_result in self.item2user.items():
             for log in log_result:
                 ans_tag = log[1]
                 user_idx = log[0]
-                if ans_tag == 1:
-                    self.right_map[item_idx].append(user_idx)
-                else:
-                    self.wrong_map[item_idx].append(user_idx)
+                self.item2user_map[str(ans_tag)][item_idx].append(user_idx)
diff --git a/pyirt/solver/model.py b/pyirt/solver/model.py
index 0e11b74..5192c0f 100644
--- a/pyirt/solver/model.py
+++ b/pyirt/solver/model.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python -W ignore
 '''
 The model is an implementation of EM algorithm of IRT
 
@@ -14,10 +15,37 @@
 from copy import deepcopy
 from six import string_types
 
 from ..util import clib, tools
 from ..solver import optimizer
 from ..algo import update_theta_distribution
 
+from datetime import datetime
+import multiprocessing as mp
+from tqdm import tqdm
+import time
+
+
+def procs_operator(procs, TIMEOUT, check_interval):
+    for p in procs:
+        p.start()
+
+    start = time.time()
+    while time.time() - start < TIMEOUT:
+        if any(p.is_alive() for p in procs):
+            time.sleep(check_interval)
+        else:
+            for p in procs:
+                p.join()
+            break
+    else:
+        # the while clause exhausted TIMEOUT without every worker finishing
+        for p in procs:
+            p.terminate()
+            p.join()
+        raise Exception('Time out, killing all processes')
+    return procs
 
 
 class IRT_MMLE_2PL(object):
@@ -27,14 +55,24 @@
     (2) solve
     (3) get estimated result
     '''
-    def __init__(self, dao_instance, is_msg=False):
+    def __init__(self,
+                 dao_instance, dao_type='memory',
+                 is_msg=False,
+                 is_parallel=False, num_cpu=6, check_interval=60,
+                 mode='debug'):
         # interface to data
         self.dao = dao_instance
-        self.is_msg = is_msg
+        self.dao_type = dao_type
         self.num_iter = 1
         self.ell_list = []
         self.last_avg_prob = 0
 
+        self.is_msg = is_msg
+        self.is_parallel = is_parallel
+        self.num_cpu = min(num_cpu, mp.cpu_count())
+        self.check_interval = check_interval
+        self.mode = mode
+
     def set_options(self, theta_bnds, num_theta, alpha_bnds, beta_bnds, max_iter, tol):
         # user
         self.num_theta = num_theta
@@ -49,36 +87,49 @@
 
     def set_guess_param(self, in_guess_param):
         self.guess_param_dict = {}
-        if isinstance(in_guess_param, string_types):
-            for item_idx in range(self.dao.get_num('item')):
-                self.guess_param_dict[item_idx] = {'c': 0.0}  # default set to 0
-        else:
-            for item_idx in range(self.dao.get_num('item')):
-                item_id = self.dao.translate('item', item_idx)
-                self.guess_param_dict[item_idx] = in_guess_param[item_id]
-
+        for item_idx in xrange(self.dao.get_num('item')):
+            item_id = self.dao.translate('item', item_idx)
+            if item_id in in_guess_param:
+                self.guess_param_dict[item_idx] = {'c': float(in_guess_param[item_id])}
+            else:
+                self.guess_param_dict[item_idx] = {'c': 0.0}  # if absent, default to 0
 
     def solve_EM(self):
         # data dependent initialization
         self._init_item_param()
 
         # main routine
         while True:
             #----- E step -----
+            stime = datetime.now()
             self._exp_step()
+            etime = datetime.now()
+            if self.is_msg:
+                runtime = (etime - stime).total_seconds()  # wall-clock seconds
+                print('E step runs for %s sec' % runtime)
 
             #----- M step -----
+            stime = datetime.now()
             self._max_step()
+            etime = datetime.now()
+            if self.is_msg:
+                runtime = (etime - stime).total_seconds()
+                print('M step runs for %s sec' % runtime)
 
             # ---- Stop Condition ----
+            stime = datetime.now()
             is_stop = self._check_stop()
+            etime = datetime.now()
+            if self.is_msg:
+                runtime = (etime - stime).total_seconds()
+                print('stop condition runs for %s sec' % runtime)
 
             if is_stop:
                 break
 
     def get_item_param(self):
         output_item_param = {}
-        for item_idx in range(self.dao.get_num('item')):
+        for item_idx in xrange(self.dao.get_num('item')):
             item_id = self.dao.translate('item', item_idx)
             output_item_param[item_id] = self.item_param_dict[item_idx]
         return output_item_param
@@ -86,7 +137,7 @@
     def get_user_param(self):
         output_user_param = {}
         theta_vec = self.__calc_theta()
-        for user_idx in range(self.dao.get_num('user')):
+        for user_idx in xrange(self.dao.get_num('user')):
             user_id = self.dao.translate('user', user_idx)
             output_user_param[user_id] = theta_vec[user_idx]
         return output_user_param
@@ -106,7 +157,6 @@
             (1-data_[i,j]) * (1 - P(Y=1|param_j, theta_[i,k]))
         By similar logic, it is equivalent to summing (1-p) over all users who got the item wrong
-
         '''
 
         # (1) update the posterior distribution of theta
@@ -122,6 +172,32 @@
         Basic Math
             log likelihood(param_j) = sum_k(log likelihood(param_j, theta_k))
         '''
+        def update(d, start_idx, end_idx):
+            try:
+                for item_idx in tqdm(xrange(start_idx, end_idx)):
+                    # the initial guess mixes the current estimate with a fresh start
+                    # to avoid being trapped in a local maximum
+                    initial_guess_val = (self.item_param_dict[item_idx]['beta'],
+                                         self.item_param_dict[item_idx]['alpha'])
+                    opt_worker.set_initial_guess(initial_guess_val)
+                    opt_worker.set_c(self.item_param_dict[item_idx]['c'])
+
+                    # assemble the expected data
+                    expected_right_count = self.item_expected_right_by_theta[:, item_idx]
+                    expected_wrong_count = self.item_expected_wrong_by_theta[:, item_idx]
+                    input_data = [expected_right_count, expected_wrong_count]
+                    opt_worker.load_res_data(input_data)
+                    try:
+                        est_param = opt_worker.solve_param_mix(self.is_constrained)
+                        d[item_idx] = est_param
+                    except Exception as e:
+                        if self.mode == 'production':
+                            # in production mode, fall back to the previous iteration's estimate
+                            print('Item %d does not fit' % item_idx)
+                            d[item_idx] = (self.item_param_dict[item_idx]['beta'],
+                                           self.item_param_dict[item_idx]['alpha'])
+                        else:
+                            raise e
+            except:
+                raise
 
         # [A] max for item parameter
         opt_worker = optimizer.irt_2PL_Optimizer()
         # the boundary is universal
@@ -131,28 +207,38 @@
         # theta value is universal
         opt_worker.set_theta(self.theta_prior_val)
 
+        num_item = self.dao.get_num('item')
+        if num_item > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
 
-        for item_idx in range(self.dao.get_num('item')):
-            # set the initial guess as a mixture of current value and a new
-            # start to avoid trap in local maximum
-            initial_guess_val = (self.item_param_dict[item_idx]['beta'],
-                                 self.item_param_dict[item_idx]['alpha'])
+        # cut the item list into one chunk per worker process
+        chunk_list = tools.cut_list(num_item, num_chunk)
+
+        procs = []
+        manager = mp.Manager()
+        procs_repo = manager.dict()
+
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0], chunk_list[i][1],))
+            procs.append(p)
 
-            opt_worker.set_initial_guess(initial_guess_val)
-            opt_worker.set_c(self.item_param_dict[item_idx]['c'])
+        if num_chunk > 1:
+            procs = procs_operator(procs, 3600, self.check_interval)
+        else:
+            procs = procs_operator(procs, 7200, 0.1)
 
-            # assemble the expected data
-            expected_right_count = self.item_expected_right_by_theta[:, item_idx]
-            expected_wrong_count = self.item_expected_wrong_by_theta[:, item_idx]
-            input_data = [expected_right_count, expected_wrong_count]
-            opt_worker.load_res_data(input_data)
-            # if one wishes to inspect the model input, print the input data
+        for item_idx in xrange(num_item):
+            self.item_param_dict[item_idx] = {
+                'beta': procs_repo[item_idx][0],
+                'alpha': procs_repo[item_idx][1],
+                'c': self.item_param_dict[item_idx]['c']
+            }
 
-            est_param = opt_worker.solve_param_mix(self.is_constrained)
-            # update
-            self.item_param_dict[item_idx]['beta'] = est_param[0]
-            self.item_param_dict[item_idx]['alpha'] = est_param[1]
 
         # [B] max for theta density
         self.theta_density = self.posterior_theta_distr.sum(axis=0) / self.posterior_theta_distr.sum()
@@ -162,36 +248,36 @@
     def _check_stop(self):
         '''
         preserve the user and item parameters from the last iteration; useful for restoring after an iteration where the log-likelihood declines
         '''
-        avg_prob = np.exp(self.__calc_data_likelihood() / self.dao.get_num('log'))
+        if self.is_msg: print('score calculating')
+        avg_prob = np.exp(self.__calc_data_likelihood())
+        if self.is_msg: print('score calculated.')
+
+        self.ell_list.append(avg_prob)
         if self.is_msg:
             print(avg_prob)
 
-        if self.last_avg_prob < avg_prob and avg_prob - self.last_avg_prob <= self.tol:
+        diff = avg_prob - self.last_avg_prob
+
+        if diff >= 0 and diff <= self.tol:
             print('EM converged at iteration %d.' % self.num_iter)
             return True
-
-        # if the algorithm improves, then ell > ell_t0
-        if self.last_avg_prob > avg_prob:
+        elif diff < 0:
             self.item_param_dict = self.last_item_param_dict
             print('Likelihood decreases, stop at iteration %d.' % self.num_iter)
             return True
-
-        # update the stop condition
-        self.last_avg_prob = avg_prob
-        self.num_iter += 1
-
-        if (self.num_iter > self.max_iter):
-            print('EM does not converge within max iteration')
-            return True
-
-        if self.num_iter != 1:
-            self.last_item_param_dict = self.item_param_dict
-
-        return False
-
-
-    def _init_solver_param(self, is_constrained, boundary,
-                           solver_type, max_iter, tol):
+        else:
+            # the improvement is larger than the tolerance:
+            # update the stop condition and keep iterating
+            self.last_avg_prob = avg_prob
+            self.num_iter += 1
+
+            if self.num_iter > self.max_iter:
+                print('EM does not converge within max iteration')
+                return True
+            if self.num_iter != 1:
+                self.last_item_param_dict = self.item_param_dict
+            return False
+
+    def _init_solver_param(self, is_constrained, boundary, solver_type, max_iter, tol):
         # initialize bounds
         self.is_constrained = is_constrained
         self.alpha_bnds = boundary['alpha']
@@ -199,13 +285,12 @@
         self.solver_type = solver_type
         self.max_iter = max_iter
         self.tol = tol
-
         if solver_type == 'gradient' and not is_constrained:
             raise Exception('BFGS has to be constrained')
 
     def _init_item_param(self):
         self.item_param_dict = {}
-        for item_idx in range(self.dao.get_num('item')):
+        for item_idx in xrange(self.dao.get_num('item')):
             # need to call the old item_id
             c = self.guess_param_dict[item_idx]['c']
             self.item_param_dict[item_idx] = {'alpha': 1.0, 'beta': 0.0, 'c': c}
@@ -215,7 +300,6 @@
     def _init_user_param(self, theta_min, theta_max, num_theta, dist='normal'):
         self.theta_prior_val = np.linspace(theta_min, theta_max, num=num_theta)
         if self.num_theta != len(self.theta_prior_val):
             raise Exception('wrong number of initial theta values')
-        # use a normal approximation
         if dist == 'uniform':
             self.theta_density = np.ones(num_theta) / num_theta
 
@@ -230,11 +314,45 @@
         self.posterior_theta_distr = np.zeros((self.dao.get_num('user'), num_theta))
 
     def __update_theta_distr(self):
+
+        def update(d, start_idx, end_idx):
+            try:
+                if self.dao_type == 'db':
+                    client = self.dao.open_conn()
+                    user2item_conn = client[self.dao.db_name][self.dao.user2item_collection_name]
+                for user_idx in tqdm(xrange(start_idx, end_idx)):
+                    if self.dao_type == 'db':
+                        logs = self.dao.get_log(user_idx, user2item_conn)
+                    else:
+                        logs = self.dao.get_log(user_idx)
+                    d[user_idx] = update_theta_distribution(logs, self.num_theta, self.theta_prior_val, self.theta_density, self.item_param_dict)
+                if self.dao_type == 'db':
+                    client.close()
+            except:
+                raise
+
+        # [A] calculate p(data, param | theta)
+        num_user = self.dao.get_num('user')
+        if num_user > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
 
-        # [A] calculate p(data,param|theta)
-        for user_idx in range(self.dao.get_num('user')):
-            self.posterior_theta_distr[user_idx, :] = update_theta_distribution(self.dao.get_log(user_idx),
-                                                                                self.num_theta, self.theta_prior_val, self.theta_density,
-                                                                                self.item_param_dict)
+        chunk_list = tools.cut_list(num_user, num_chunk)
+        procs = []
+        manager = mp.Manager()
+        procs_repo = manager.dict()
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0], chunk_list[i][1],))
+            procs.append(p)
+
+        if num_chunk > 1:
+            procs = procs_operator(procs, 5400, self.check_interval)
+        else:
+            procs = procs_operator(procs, 7200, 0.1)
+
+        for user_idx in xrange(num_user):
+            self.posterior_theta_distr[user_idx, :] = procs_repo[user_idx]
+
         # when the loop finishes, check that the theta density adds up to unity for each user
         check_user_distr_marginal = np.sum(self.posterior_theta_distr, axis=1)
         if any(abs(check_user_distr_marginal - 1.0) > 0.0001):
@@ -248,34 +366,111 @@
             raise Exception('theta density has wrong shape (%s,%s)' % self.theta_density.shape)
 
     def __get_expect_count(self):
+
+        def update(d, start_idx, end_idx):
+            try:
+                if self.dao_type == 'db':
+                    client = self.dao.open_conn()
+                    item2user_conn = client[self.dao.db_name][self.dao.item2user_collection_name]
+                for item_idx in tqdm(xrange(start_idx, end_idx)):
+                    if self.dao_type == 'db':
+                        map_user_idx_vec = self.dao.get_map(item_idx, ['1', '0'], item2user_conn)
+                    else:
+                        map_user_idx_vec = self.dao.get_map(item_idx, ['1', '0'])
+                    d[item_idx] = {
+                        1: np.sum(self.posterior_theta_distr[map_user_idx_vec[0], :], axis=0),
+                        0: np.sum(self.posterior_theta_distr[map_user_idx_vec[1], :], axis=0)
+                    }
+                if self.dao_type == 'db':
+                    client.close()
+            except:
+                raise
+
+        num_item = self.dao.get_num('item')
+        if num_item > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
+
+        # [A] calculate p(data, param | theta)
+        chunk_list = tools.cut_list(num_item, num_chunk)
+        procs = []
+        manager = mp.Manager()
+        procs_repo = manager.dict()
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(procs_repo, chunk_list[i][0], chunk_list[i][1],))
+            procs.append(p)
+
+        if num_chunk > 1:
+            procs = procs_operator(procs, 5400, self.check_interval)
+        else:
+            procs = procs_operator(procs, 7200, 0.1)
+
         self.item_expected_right_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
         self.item_expected_wrong_by_theta = np.zeros((self.num_theta, self.dao.get_num('item')))
+        for item_idx in xrange(num_item):
+            self.item_expected_right_by_theta[:, item_idx] = procs_repo[item_idx][1]
+            self.item_expected_wrong_by_theta[:, item_idx] = procs_repo[item_idx][0]
 
-        for item_idx in range(self.dao.get_num('item')):
-            right_user_idx_vec = self.dao.get_right_map(item_idx)
-            wrong_user_idx_vec = self.dao.get_wrong_map(item_idx)
-            self.item_expected_right_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[right_user_idx_vec, :], axis=0)
-            self.item_expected_wrong_by_theta[:, item_idx] = np.sum(self.posterior_theta_distr[wrong_user_idx_vec, :], axis=0)
 
     def __calc_data_likelihood(self):
-        # calculate the likelihood for the data set
-
+        # calculate the average likelihood of the data set:
+        # a geometric mean within each learner and then across learners,
+        # i.e. 1/N * sum_i (1/N_i * sum_j log p_ij)
+        def update(tot_llk, cnt, start_idx, end_idx):
+            try:
+                if self.dao_type == 'db':
+                    client = self.dao.open_conn()
+                    user2item_conn = client[self.dao.db_name][self.dao.user2item_collection_name]
+
+                for user_idx in tqdm(xrange(start_idx, end_idx)):
+                    theta = theta_vec[user_idx]
+                    # find all the items the user answered
+                    if self.dao_type == 'db':
+                        logs = self.dao.get_log(user_idx, user2item_conn)
+                    else:
+                        logs = self.dao.get_log(user_idx)
+                    if len(logs) == 0:
+                        continue
+                    ell = 0
+                    for log in logs:
+                        item_idx = log[0]
+                        ans_tag = log[1]
+                        alpha = self.item_param_dict[item_idx]['alpha']
+                        beta = self.item_param_dict[item_idx]['beta']
+                        c = self.item_param_dict[item_idx]['c']
+                        ell += clib.log_likelihood_2PL(0.0 + ans_tag, 1.0 - ans_tag, theta, alpha, beta, c)
+                    with tot_llk.get_lock():
+                        tot_llk.value += ell / len(logs)
+                    with cnt.get_lock():
+                        cnt.value += 1
+
+                if self.dao_type == 'db':
+                    client.close()
+            except:
+                raise
+
         theta_vec = self.__calc_theta()
-        ell = 0
-        for user_idx in range(self.dao.get_num('user')):
-            theta = theta_vec[user_idx]
-            # find all the item_id
-            logs = self.dao.get_log(user_idx)
-            for log in logs:
-                item_idx = log[0]
-                ans_tag = log[1]
-                alpha = self.item_param_dict[item_idx]['alpha']
-                beta = self.item_param_dict[item_idx]['beta']
-                c = self.item_param_dict[item_idx]['c']
-                ell += clib.log_likelihood_2PL(0.0+ans_tag, 1.0-ans_tag, theta, alpha, beta, c)
-        return ell
+        num_user = self.dao.get_num('user')
+        if num_user > self.num_cpu and self.is_parallel:
+            num_chunk = self.num_cpu
+        else:
+            num_chunk = 1
+
+        user_ell = mp.Value('d', 0.0)
+        user_cnt = mp.Value('i', 0)
+        chunk_list = tools.cut_list(num_user, num_chunk)
+        procs = []
+        for i in range(num_chunk):
+            p = mp.Process(target=update, args=(user_ell, user_cnt, chunk_list[i][0], chunk_list[i][1],))
+            procs.append(p)
+        if num_chunk > 1:
+            procs = procs_operator(procs, 1200, self.check_interval)
+        else:
+            procs = procs_operator(procs, 7200, 0.1)
+
+        avg_ell = user_ell.value / user_cnt.value
+        return avg_ell
 
     def __calc_theta(self):
         return np.dot(self.posterior_theta_distr, self.theta_prior_val)
diff --git a/pyirt/util/tools.py b/pyirt/util/tools.py
index 5d47d58..e523e97 100644
--- a/pyirt/util/tools.py
+++ b/pyirt/util/tools.py
@@ -38,3 +38,13 @@
 def logsum(logp):
     w = max(logp)
     logSump = w + np.log(sum(np.exp(logp - w)))
     return logSump
+
+
+def cut_list(list_length, num_chunk):
+    # evenly partition range(list_length) into num_chunk half-open (start, end) pairs;
+    # any remainder is spread across the later chunks
+    chunk_bnd = [0]
+    for i in range(num_chunk):
+        chunk_bnd.append(int(list_length * (i + 1) / num_chunk))
+    chunk_list = [(chunk_bnd[i], chunk_bnd[i + 1]) for i in range(num_chunk)]
+    return chunk_list
diff --git a/setup.py b/setup.py
index 0db1931..02229df 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,9 @@
     install_requires=['numpy',
                       'scipy',
                       'cython',
-                      'six'],
+                      'six',
+                      'pymongo',
+                      'tqdm'],
 
     package_data={'pyirt': ["*.pyx"]},
diff --git a/test/test_model_wrapper.py b/test/test_model_wrapper.py
index 746e36a..ee21576 100644
--- a/test/test_model_wrapper.py
+++ b/test/test_model_wrapper.py
@@ -13,7 +13,6 @@
 alpha = [0.5, 0.5, 0.5, 1, 1, 1, 2, 2, 2]
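# The simulated item bank crosses discrimination alpha in {0.5, 1, 2} with
# difficulty beta in {0, 1, -1}; the c vector below puts a guessing floor of
# 0.5 on items 0, 4 and 8, which the 3PL tests exercise.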
 beta = [0, 1, -1, 0, 1, -1, 0, 1, -1]
 c = [0.5, 0, 0, 0, 0.5, 0, 0, 0, 0.5]
-item_ids = ['a', 'b', 'c', 'd', 'e','g','h','i','j']
 
 N = 1000
 T = len(alpha)
@@ -21,7 +20,9 @@
 guess_param = {}
 for t in range(T):
-    guess_param[item_ids[t]]={'c':c[t]}
+    guess_param['q%d' % t] = c[t]
+
+
 class Test2PLSolver(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -32,17 +33,38 @@
         for i in range(N):
             for t in range(T):
                 prob = irt_fnc(thetas[i,0], beta[t], alpha[t])
-                cls.data.append((i, item_ids[t], np.random.binomial(1, prob)))
-
+                cls.data.append(('u%d' % i, 'q%d' % t, np.random.binomial(1, prob)))
 
     def test_2pl_solver(self):
         item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2, theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30)
         for t in range(T):
-            item_id = item_ids[t]
+            item_id = 'q%d' % t
             print(item_id, item_param[item_id])
             mdl_alpha = item_param[item_id]['alpha']
             mdl_beta = item_param[item_id]['beta']
-            if item_id != 'h':
+            if item_id != 'q6':
                 self.assertTrue(abs(mdl_alpha - alpha[t]) < 0.37)
                 self.assertTrue(abs(mdl_beta - beta[t]) < 0.16)
 
+    def test_2pl_solver_production(self):
+        item_param, user_param = irt(self.data, mode='production', theta_bnds=[-theta_range/2, theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30)
+        for t in range(T):
+            item_id = 'q%d' % t
+            print(item_id, item_param[item_id])
+            mdl_alpha = item_param[item_id]['alpha']
+            mdl_beta = item_param[item_id]['beta']
+            if item_id != 'q6':
+                self.assertTrue(abs(mdl_alpha - alpha[t]) < 0.37)
+                self.assertTrue(abs(mdl_beta - beta[t]) < 0.16)
+
+    def test_2pl_solver_parallel(self):
+        item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2, theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], tol=1e-5, max_iter=30, is_parallel=True, check_interval=0.1)
+        for t in range(T):
+            item_id = 'q%d' % t
+            print(item_id, item_param[item_id])
+            mdl_alpha = item_param[item_id]['alpha']
+            mdl_beta = item_param[item_id]['beta']
+            if item_id != 'q6':
+                self.assertTrue(abs(mdl_alpha - alpha[t]) < 0.37)
+                self.assertTrue(abs(mdl_beta - beta[t]) < 0.16)
 
@@ -56,21 +78,53 @@
     def setUpClass(cls):
         for i in range(N):
             for t in range(T):
                 prob = irt_fnc(thetas[i,0], beta[t], alpha[t], c[t])
-                cls.data.append((i, item_ids[t] ,np.random.binomial(1,prob)))
+                cls.data.append(('u%d' % i, 'q%d' % t, np.random.binomial(1, prob)))
 
     def test_3pl_solver(self):
         item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2, theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3], in_guess_param=guess_param, tol=1e-5, max_iter=30)
         for t in range(T):
-            item_id = item_ids[t]
+            item_id = 'q%d' % t
             print(item_id, item_param[item_id])
             mdl_alpha = item_param[item_id]['alpha']
             mdl_beta = item_param[item_id]['beta']
-            if item_id not in ['h','i']:
+            if item_id not in ['q6','q7']:
                 self.assertTrue(abs(mdl_alpha - alpha[t]) < 0.25)
-            if item_id != 'j':
+            if item_id != 'q8':
                 self.assertTrue(abs(mdl_beta - beta[t]) < 0.15)
 
+    def test_3pl_solver_production(self):
+        item_param, user_param = irt(self.data, mode='production', theta_bnds=[-theta_range/2, theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3],
+                                     in_guess_param=guess_param, tol=1e-5, max_iter=30)
+        for t in range(T):
+            item_id = 'q%d' % t
+            print(item_id, item_param[item_id])
+            mdl_alpha = item_param[item_id]['alpha']
+            mdl_beta = item_param[item_id]['beta']
+            if item_id not in ['q6','q7']:
+                self.assertTrue(abs(mdl_alpha - alpha[t]) < 0.25)
+            if item_id != 'q8':
+                self.assertTrue(abs(mdl_beta - beta[t]) < 0.15)
+
+    def test_3pl_solver_parallel(self):
+        item_param, user_param = irt(self.data, theta_bnds=[-theta_range/2, theta_range/2], num_theta=11, alpha_bnds=[0.25,3], beta_bnds=[-3,3],
+                                     in_guess_param=guess_param, tol=1e-5, max_iter=30, is_parallel=True, check_interval=0.1)
+        for t in range(T):
+            item_id = 'q%d' % t
+            print(item_id, item_param[item_id])
+            mdl_alpha = item_param[item_id]['alpha']
+            mdl_beta = item_param[item_id]['beta']
+            if item_id not in ['q6','q7']:
+                self.assertTrue(abs(mdl_alpha - alpha[t]) < 0.25)
+            if item_id != 'q8':
+                self.assertTrue(abs(mdl_beta - beta[t]) < 0.15)
 
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_tools.py b/test/test_tools.py
index 66904be..9b434f9 100644
--- a/test/test_tools.py
+++ b/test/test_tools.py
@@ -142,6 +142,21 @@
         self.assertTrue(abs(calc_hessian - true_hessian_approx_theta) < 1e-4)
 
+class TestCutList(unittest.TestCase):
+    def test_no_mod(self):
+        test_chunks = tools.cut_list(100, 4)
+        true_chunks = [(0, 25), (25, 50), (50, 75), (75, 100)]
+        for i in range(4):
+            self.assertTrue(test_chunks[i] == true_chunks[i])
+
+    def test_mod(self):
+        test_chunks = tools.cut_list(23, 4)
+        true_chunks = [(0, 5), (5, 11), (11, 17), (17, 23)]
+        for i in range(4):
+            self.assertTrue(test_chunks[i] == true_chunks[i])
+
 
 if __name__ == '__main__':
     unittest.main()
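For reference, a sketch of the MongoDB documents the new mongoDAO expects. The collection names and the 'gid'/'id'/'data' fields are read straight from the dao.py code above; the connection URI, database name, and the seeding script itself are illustrative assumptions, not part of this change:

import pymongo

# hypothetical seeding sketch for the two collections mongoDAO reads
client = pymongo.MongoClient('mongodb://user:password@localhost:27017')  # placeholder URI
db = client['irt_db']  # must match connect_config['db']

# one doc per (user, group): the collapsed (item_id, ans_tag) log list read by get_log()
db['irt_user2item'].insert_one({
    'gid': 1,
    'id': 'u0',
    'data': [['q0', 1], ['q1', 0]],
})

# one doc per (item, group): user ids keyed by answer-tag string, read by get_map()
db['irt_item2user'].insert_one({
    'gid': 1,
    'id': 'q0',
    'data': {'1': ['u0'], '0': []},
})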