Merge 4009f0e into c69faa1
junchenfeng committed Aug 15, 2017
2 parents c69faa1 + 4009f0e commit be57a96
Showing 8 changed files with 533 additions and 147 deletions.
58 changes: 25 additions & 33 deletions README.md
@@ -28,7 +28,7 @@ src_fp = open(file_path,'r')
item_param, user_param = irt(src_fp)

# (2)Supply bounds
-item_param, user-param = irt(src_fp, theta_bnds = [-5,5], alpha_bnds=[0.1,3], beta_bnds = [-3,3])
+item_param, user_param = irt(src_fp, theta_bnds = [-4,4], alpha_bnds=[0.1,3], beta_bnds = [-3,3])

# (3)Supply guess parameter
guessParamDict = {1:{'c':0.0}, 2:{'c':0.25}}
@@ -44,41 +44,41 @@ The current version supports the MMLE algorithm and the unidimensional two-parameter
IRT model. There is a backdoor method to specify the guess parameter, but there
is no active estimation of it.

The prior distribution of theta is **uniform**.

There is no regularization in the alpha and beta estimation. Therefore, the default
algorithm puts bounds on the parameters to prevent over-fitting and to deal with
extreme cases where almost all responses to an item are correct.

The unit test code shows that when there are discriminative items to anchor theta, the model converges reasonably well to the true values.

## Theta estimation
The package offers two methods to estimate theta, given item parameters: Bayesian and MLE. <br>
The estimation procedure is quite primitive. For examples, see the test case.
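
As a rough illustration (not the package's internal routine), a minimal MLE estimator for theta given 2PL item parameters could look like the sketch below; the `alpha*theta + beta` parameterization and all names are assumptions made for the example.

```python
# A minimal sketch of MLE theta estimation for a 2PL model, independent of
# pyirt's internals. All names here are illustrative assumptions.
import numpy as np
from scipy.optimize import minimize_scalar

def estimate_theta_mle(logs, item_params, theta_bnds=(-4, 4)):
    # logs: list of (item_id, ans) with ans in {0, 1}
    # item_params: {item_id: {'alpha': a, 'beta': b}}
    def neg_log_lik(theta):
        nll = 0.0
        for item_id, ans in logs:
            a = item_params[item_id]['alpha']
            b = item_params[item_id]['beta']
            p = 1.0 / (1.0 + np.exp(-(a * theta + b)))  # P(correct | theta)
            nll -= ans * np.log(p) + (1 - ans) * np.log(1.0 - p)
        return nll
    return minimize_scalar(neg_log_lik, bounds=theta_bnds, method='bounded').x
```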

II.Sparse Data Structure
==========

-In non-test learning dataset, missing data is the common. Not all students
-finish all the items. When the number of students and items are large, the data
-can be extremely sparse.
+In non-test learning datasets, missing data is common. Not all students finish all the items. When the numbers of students and items are large, the data can be extremely sparse.

The package deals with the sparse structure in two ways:
-- Efficient memory storage. Use collapsed list to index data. The memory usage
-  is about 3 times of the text data file. If the workstation has 6G free
-  memory, it can handle 2G data file. Most other IRT package will definitely
-  break.
-
-- No joint estimation. Under IRT's conditional independence assumption,
-  estimate each item's parameter is consistent but inefficient. To avoid
-  reverting a giant jacobian matrix, the item parameters are estimated seprately.
+- Efficient memory storage. (See the sketch after this list.)
+
+  Collapsed lists are used to index the data. The memory usage is about 3 times the size of the text data file: if the workstation has 6G of free memory, it can handle a 2G data file.
+
+  If the data are truly big, say billions of logs, a mongo DAO is implemented to avoid memory overflow.
+
+- No joint estimation.
+
+  Under the conditional independence assumption, estimating each item's parameters separately is consistent, although inefficient.
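
As a sketch of the collapsed-list idea (mirroring what `pyirt/dao.py` below does), the response triplets are indexed into two dictionaries of lists, so only observed (user, item) pairs take memory:

```python
# Sketch: collapsed-list storage for sparse response data (names assumed).
from collections import defaultdict

logs = [(0, 0, 1), (0, 1, 0), (1, 0, 1)]  # (user_idx, item_idx, ans_tag)

user2item = defaultdict(list)
item2user = defaultdict(list)
for user_idx, item_idx, ans_tag in logs:
    user2item[user_idx].append((item_idx, ans_tag))  # all responses by a user
    item2user[item_idx].append((user_idx, ans_tag))  # all responses to an item
```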

-II.Default Config
+III.Default Config
===========
## Exposed
The theta parameter ranges over [-4,4] with a step size of 0.8 by default.
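
With the default `num_theta=11` grid points (per the `irt` signature in `pyirt/_pyirt.py` below), that corresponds to, as a sketch:

```python
import numpy as np
theta_grid = np.linspace(-4, 4, 11)  # 11 points, step size 8/10 = 0.8
```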

-Alpha is bounded by [0.25,2] and beta is bounded by [-2,2], which goes with the user ability
-specifiation.
+Alpha is bounded by [0.25,2] and beta is bounded by [-2,2].


Guess parameter is set to 0 by default, but one could supply a dictionary with eid as key and c as value.

@@ -88,37 +88,29 @@ The default solver is L-BFGS-B.

The default max iteration is 10.

-The stop condition threshold is 1e-3 by default. The algorithm computes the
-average likelihood per log at the end of the iteration. If the likelihood
-increment is less than the threshold, it stops.
+The stop condition threshold is 1e-3 by default. The stop condition looks at the geometric mean of the likelihood, which is independent of the data size. (See the sketch below.)
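
As a sketch of that rule (variable names assumed): the geometric mean of the per-log likelihood equals `exp(average log-likelihood)`, so the increment check does not scale with the number of logs.

```python
import numpy as np

def should_stop(last_sum_log_lik, curr_sum_log_lik, num_logs, tol=1e-3):
    # geometric mean of the likelihood = exp(average log-likelihood per log)
    last_avg = np.exp(last_sum_log_lik / num_logs)
    curr_avg = np.exp(curr_sum_log_lik / num_logs)
    return abs(curr_avg - last_avg) < tol
```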

-III.Data Format
+IV.Data Format
=========
-The file is expected to be comma delimited.
-
-The three columns are uid, eid, result-flag.
-
-Currently the model only works well with 0/1 flag but will **NOT** raise error for
-other types.
+If the data are digested through the file system, the file is expected to be **comma delimited**. The three columns are learner id, item id and result flag. Currently the model only works well with 0/1 flags but will **NOT** raise an error for other types.
+
+If using the database DAO, see the unittest for an example.
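
For instance, a minimal input file (hypothetical ids) could look like:

```
u001,item_10,1
u001,item_11,0
u002,item_10,1
```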


-IV.Note
+V.Note
=======


## Minimization solver
-The scipy minimize is as good as cvxopt.cp and matlab fmincon on item parameter
-estimation to the 6th decimal point, which can be viewed as identical for all
-practical purposes.
+The scipy minimize routine is as good as cvxopt.cp and matlab fmincon on item parameter estimation, agreeing to the 6th decimal point, which can be viewed as identical for all practical purposes.

-However, the convergence is pretty slow. It requires about 10k obeverations per
-item to recover the parameter to the 0.01 precision.
+However, the convergence is pretty slow. It requires about 10k observations per item to recover the parameters to 0.01 precision.


VII.Acknowledgement
==============
-The algorithm is described in details by Bradey Hanson(2000), see in the
-literature section. I am grateful to Mr.Hanson's work.
+The algorithm is described in detail by Bradley Hanson (2000); see the literature section. I am grateful for Mr. Hanson's work.

[Chaoqun Fu](https://github.com/fuchaoqun)'s comment leads to the (much better) API design.

17 changes: 13 additions & 4 deletions pyirt/_pyirt.py
@@ -3,17 +3,26 @@
from .dao import localDAO

def irt(data_src,
+       dao_type='memory',
        theta_bnds=[-4, 4], num_theta=11,
-       alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param='default',
+       alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param={},
        model_spec='2PL',
-       max_iter=10, tol=1e-3, nargout=2):
+       max_iter=10, tol=1e-3, nargout=2,
+       is_msg=False,
+       is_parallel=False, num_cpu=6, check_interval=60,
+       mode='debug'):

    # load data
-   dao_instance = localDAO(data_src)
+   if dao_type == 'memory':
+       dao_instance = localDAO(data_src)
+   else:
+       dao_instance = data_src

    # setup the model
    if model_spec == '2PL':
-       mod = model.IRT_MMLE_2PL(dao_instance)
+       mod = model.IRT_MMLE_2PL(dao_instance, dao_type=dao_type,
+                                is_msg=is_msg, is_parallel=is_parallel, num_cpu=num_cpu, check_interval=check_interval, mode=mode)
    else:
        raise Exception('Unknown model specification.')
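
A hypothetical call using the new signature (the non-memory `dao_type` value and the DAO construction are assumptions; see `pyirt/dao.py` below for the mongo DAO):

```python
from pyirt import irt

# in-memory path: pass a handle to a comma-delimited log file
src_fp = open('data.csv', 'r')
item_param, user_param = irt(src_fp, dao_type='memory')

# database path: pass a pre-built DAO instance instead of a file handle
# dao = mongoDAO(connect_config)                    # assumed construction
# item_param, user_param = irt(dao, dao_type='db')  # 'db' is an assumed value
```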

143 changes: 126 additions & 17 deletions pyirt/dao.py
@@ -6,8 +6,117 @@

from .util.dao import loadFromHandle, loadFromTuples, construct_ref_dict

-#TODO: bitmap is a function of DAO. Seperate that with database
+import pymongo
+
+from datetime import datetime
+
+class mongoDAO(object):
+
+    def __init__(self, connect_config, group_id=1, is_msg=False):
+
+        self.connect_config = connect_config
+
+        client = self.open_conn()
+        self.db_name = connect_config['db']
+
+        self.user2item_collection_name = 'irt_user2item'
+        self.item2user_collection_name = 'irt_item2user'
+
+        user2item_conn = client[self.db_name][self.user2item_collection_name]
+        item2user_conn = client[self.db_name][self.item2user_collection_name]
+
+        user_ids = list(set([x['id'] for x in user2item_conn.find({'gid': group_id}, {'id': 1})]))
+        item_ids = list(set([x['id'] for x in item2user_conn.find({'gid': group_id}, {'id': 1})]))
+
+        _, self.user_idx_ref, self.user_reverse_idx_ref = construct_ref_dict(user_ids)
+        _, self.item_idx_ref, self.item_reverse_idx_ref = construct_ref_dict(item_ids)
+
+        self.stat = {'user': len(self.user_idx_ref.keys()), 'item': len(self.item_idx_ref.keys())}
+
+        print('search idx created.')
+        self.gid = group_id
+        self.is_msg = is_msg
+
+        client.close()
+
+    def open_conn(self):
+
+        user_name = self.connect_config['user']
+        password = self.connect_config['password']
+        address = self.connect_config['address']  # IP:PORT
+        if 'authsource' not in self.connect_config:
+            mongouri = 'mongodb://{un}:{pw}@{addr}'.format(un=user_name, pw=password, addr=address)
+        else:
+            authsource = self.connect_config['authsource']
+            mongouri = 'mongodb://{un}:{pw}@{addr}/?authsource={auth_src}'.format(un=user_name, pw=password, addr=address, auth_src=authsource)
+        try:
+            client = pymongo.MongoClient(mongouri, connect=False, serverSelectionTimeoutMS=10, waitQueueTimeoutMS=100, readPreference='secondaryPreferred')
+        except:
+            raise
+        return client
+
+    def get_num(self, name):
+        if name not in ['user', 'item']:
+            raise Exception('Unknown stat source %s' % name)
+        return self.stat[name]
+
+    def get_log(self, user_idx, user2item_conn):
+        user_id = self.translate('user', user_idx)
+        # query
+        if self.is_msg:
+            stime = datetime.now()
+            res = user2item_conn.find({'id': user_id, 'gid': self.gid})
+            etime = datetime.now()
+            search_time = int((etime - stime).microseconds / 1000)
+            if search_time > 100:
+                print('s:%d' % search_time)
+        else:
+            res = user2item_conn.find({'id': user_id, 'gid': self.gid})
+        # parse
+        if res.count() == 0:
+            return_list = []
+        elif res.count() > 1:
+            raise Exception('duplicate doc for (%s, %d) in user2item' % (user_id, self.gid))
+        else:
+            log_list = res[0]['data']
+            return_list = [(self.item_idx_ref[x[0]], x[1]) for x in log_list]
+        return return_list
+
+    def get_map(self, item_idx, ans_key_list, item2user_conn):
+        item_id = self.translate('item', item_idx)
+        # query
+        if self.is_msg:
+            stime = datetime.now()
+            res = item2user_conn.find({'id': item_id, 'gid': self.gid})
+            etime = datetime.now()
+            search_time = int((etime - stime).microseconds / 1000)
+            if search_time > 100:
+                print('s:%d' % search_time)
+        else:
+            res = item2user_conn.find({'id': item_id, 'gid': self.gid})
+        # parse
+        if res.count() == 0:
+            return_list = [[] for ans_key in ans_key_list]
+        elif res.count() > 1:
+            raise Exception('duplicate doc for (%s, %d) in item2user' % (item_id, self.gid))
+        else:
+            doc = res[0]['data']
+            return_list = []
+            for ans_key in ans_key_list:
+                if str(ans_key) in doc:
+                    return_list.append([self.user_idx_ref[x] for x in doc[str(ans_key)]])
+                else:
+                    return_list.append([])
+        return return_list
+
+    def translate(self, data_type, idx):
+        if data_type == 'item':
+            return self.item_reverse_idx_ref[idx]
+        elif data_type == 'user':
+            return self.user_reverse_idx_ref[idx]
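
From the queries above, the implied document shapes appear to be (a reconstruction from the code, not a documented schema):

```
// irt_user2item: one document per (user id, group id)
{"id": "u001", "gid": 1, "data": [["item_10", 1], ["item_11", 0]]}

// irt_item2user: one document per (item id, group id), keyed by answer tag
{"id": "item_10", "gid": 1, "data": {"1": ["u001", "u002"], "0": []}}
```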

class localDAO(object):

    def __init__(self, src):
@@ -21,18 +130,19 @@ def __init__(self, src):
        self.database.setup(user_id_idx_vec, item_id_idx_vec, self.database.ans_tags)

    def get_num(self, name):
-       if name not in ['user','item','log']:
+       if name not in ['user','item']:
            raise Exception('Unknown stat source %s' % name)
        return self.database.stat[name]

    def get_log(self, user_idx):
        return self.database.user2item[user_idx]

-   def get_right_map(self, item_idx):
-       return self.database.right_map[item_idx]
-
-   def get_wrong_map(self, item_idx):
-       return self.database.wrong_map[item_idx]
+   def get_map(self, item_idx, ans_key_list):
+       # NOTE: return empty list for invalid ans key
+       return [self.database.item2user_map[str(ans_key)][item_idx] for ans_key in ans_key_list]
+
+   def close_conn(self):
+       pass

    def translate(self, data_type, idx):
        if data_type == 'item':
@@ -60,7 +170,7 @@ def setup(self, user_idx_vec, item_idx_vec, ans_tags, msg=False):

        # initialize some intermediate variables used in the E step
        start_time = time.time()
-       self._init_right_wrong_map()
+       self._init_item2user_map()
        if msg:
            print("--- Sparse Mapping: %f secs ---" % np.round((time.time() - start_time)))

@@ -75,28 +185,27 @@ def _process_data(self, user_idx_vec, item_idx_vec, ans_tags):
        self.user2item = defaultdict(list)

        self.stat = {}
-       self.stat['log'] = len(user_idx_vec)
+       num_log = len(user_idx_vec)
        self.stat['user'] = max(user_idx_vec)+1  # start count from 0
        self.stat['item'] = max(item_idx_vec)+1

-       for i in range(self.stat['log']):
+       for i in range(num_log):
            item_idx = item_idx_vec[i]
            user_idx = user_idx_vec[i]
            ans_tag = ans_tags[i]
            # add to the data dictionary
            self.item2user[item_idx].append((user_idx, ans_tag))
            self.user2item[user_idx].append((item_idx, ans_tag))

-   def _init_right_wrong_map(self):
-       self.right_map = defaultdict(list)
-       self.wrong_map = defaultdict(list)
+   def _init_item2user_map(self, ans_key_list=['0', '1']):
+
+       self.item2user_map = {}
+       for ans_key in ans_key_list:
+           self.item2user_map[ans_key] = defaultdict(list)

        for item_idx, log_result in self.item2user.items():
            for log in log_result:
                ans_tag = log[1]
                user_idx = log[0]
-               if ans_tag == 1:
-                   self.right_map[item_idx].append(user_idx)
-               else:
-                   self.wrong_map[item_idx].append(user_idx)
+               self.item2user_map[str(ans_tag)][item_idx].append(user_idx)
