-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from junchenfeng/master
Fix convergence bug. Add DAO module. Add CI.
- Loading branch information
Showing
25 changed files
with
675 additions
and
1,127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,2 @@ | ||
__all__ = ["irt", "model", "solver", "utl"] | ||
|
||
from ._pyirt import irt, model | ||
|
||
import solver | ||
import utl | ||
__all__ = ["_pyirt", "solver", "util"] | ||
from ._pyirt import irt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,36 @@ | ||
# -*-coding:utf-8-*- | ||
|
||
from .solver import model | ||
from .dao import localDAO | ||
|
||
|
||
def irt(src, theta_bnds=[-4, 4], | ||
def irt(data_src, | ||
theta_bnds=[-4, 4], num_theta=11, | ||
alpha_bnds=[0.25, 2], beta_bnds=[-2, 2], in_guess_param='default', | ||
model_spec='2PL', | ||
mode='memory', is_mount=False, user_name=None): | ||
max_iter=10, tol=1e-3, nargout=2): | ||
|
||
|
||
# load data | ||
dao_instance = localDAO(data_src) | ||
|
||
if model_spec == '2PL': | ||
mod = model.IRT_MMLE_2PL() | ||
mod = model.IRT_MMLE_2PL(dao_instance) | ||
else: | ||
raise Exception('Unknown model specification.') | ||
|
||
# load | ||
mod.load_data(src, is_mount, user_name) | ||
mod.load_param(theta_bnds, alpha_bnds, beta_bnds) | ||
mod.load_guess_param(in_guess_param) | ||
# specify the irt parameters | ||
mod.set_options(theta_bnds, num_theta, alpha_bnds, beta_bnds,max_iter, tol) | ||
mod.set_guess_param(in_guess_param) | ||
|
||
# solve | ||
mod.solve_EM() | ||
|
||
# post | ||
item_param_dict = mod.get_item_param() | ||
user_param_dict = mod.get_user_param() | ||
|
||
return item_param_dict, user_param_dict | ||
if nargout ==1: | ||
return item_param_dict | ||
elif nargout ==2: | ||
user_param_dict = mod.get_user_param() | ||
return item_param_dict, user_param_dict | ||
else: | ||
raise Exception('Invalid number of argument') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# -*- coding:utf-8 -*- | ||
from .util import clib, tools | ||
import numpy as np | ||
|
||
def update_theta_distribution(data, num_theta, theta_prior_val, theta_density, item_param_dict): | ||
''' | ||
data = [(item_idx int, ans_tag binary)] | ||
''' | ||
|
||
''' | ||
Basic Math. | ||
P_t(theta, data |q_param) = p(data|q_param, theta)*p_[t-1](theta) | ||
p_t(data|q_param) = sum(p_t(theta,data|q_param)) over theta | ||
p_t(theta|data, q_param) = P_t(theta, data|q_param)/p_t(data|q_param) | ||
''' | ||
likelihood_vec = np.zeros(num_theta) | ||
|
||
for k in range(num_theta): | ||
theta = theta_prior_val[k] | ||
ell = 0.0 | ||
for log in data: | ||
item_idx = log[0] | ||
ans_tag = log[1] | ||
alpha = item_param_dict[item_idx]['alpha'] | ||
beta = item_param_dict[item_idx]['beta'] | ||
c = item_param_dict[item_idx]['c'] | ||
ell += clib.log_likelihood_2PL(0.0+ans_tag, 1.0 - ans_tag, theta, alpha, beta, c) | ||
likelihood_vec[k] = ell | ||
|
||
# posterior | ||
joint_llk_vec = likelihood_vec + np.log(theta_density) | ||
marginal = tools.logsum(joint_llk_vec) | ||
posterior = np.exp(joint_llk_vec - marginal) | ||
|
||
return posterior | ||
|
Oops, something went wrong.