@@ -289,7 +289,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 116,
"metadata": {
"collapsed": false
},
@@ -300,12 +300,13 @@
" mean_s2 = np.mean(preds[1])\n",
" bias_s1 = mean_s1 - s1\n",
" bias_s2 = mean_s2 - s2\n",
" mse = np.mean((preds[0] - s1)**2 + (preds[1] - s2)**2)\n",
" covmat = np.cov(preds)\n",
" var_s1 = covmat[0, 0]\n",
" var_s2 = covmat[1, 1]\n",
" cov = covmat[0, 1]\n",
" corr = cov / (np.sqrt(var_s1) * np.sqrt(var_s2))\n",
" stats = {'mean_s1': mean_s1, 'mean_s2': mean_s2, 'bias_s1': bias_s1, 'bias_s2': bias_s2, 'var_s1': var_s1, 'var_s2': var_s2, 'cov': cov, 'corr': corr}\n",
" stats = {'mean_s1': mean_s1, 'mean_s2': mean_s2, 'bias_s1': bias_s1, 'bias_s2': bias_s2, 'var_s1': var_s1, 'var_s2': var_s2, 'cov': cov, 'corr': corr, 'mse': mse}\n",
" return stats"
]
},
@@ -560,175 +561,27 @@
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n",
"10 trials per contrast level\n"
]
}
],
"source": [
"post_func = posterior_setup(low=1, high=4, discrete_c=3, num_s=60, r_max=1)\n",
"resps = [None] * num_deltas\n",
"for delta_s in range(num_deltas):\n",
" test_data = generate_testset(90, stim_0=s1, stim_1=s1+delta_s, discrete_c=3, low=1, high=4, r_max=1)\n",
" r, _, _ = test_data\n",
" resps[delta_s] = r"
]
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": 120,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import os\n",
"posts = {}\n",
"for s_i in range(90):\n",
" file_name = 'output_post/post_' + str(i+1) + '.pkl'\n",
"testsets = {}\n",
"for s_i in range(91):\n",
" file_name = 'output_post/post_' + str(s_i) + '.pkl'\n",
" if os.path.isfile(file_name):\n",
" delta_s = s_i / 3\n",
" pkl_file = open(file_name, 'rb')\n",
" p, r, c, i = pickle.load(pkl_file)\n",
" posts[c, delta_s] = p"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-28.25227198 -26.0783597 -29.13269856 ..., -24.19848548 -27.67333624\n",
" -23.92919243]\n",
"40 trials per contrast level\n",
"{'var_s1': array([ 10.02719864, 13.57938305, 10.18881273, 15.35320011,\n",
" 9.92542377, 8.73841609, 13.5532021 , 7.4075507 ,\n",
" 19.82254523, 6.60182673, 10.91104413, 10.43637532,\n",
" 13.00056325, 9.63121981, 13.42775892, 11.25597343,\n",
" 15.15762951, 15.39946474, 8.68434282, 8.18976316,\n",
" 10.49936786, 9.49821351, 11.00819366, 10.48726616,\n",
" 14.11987687, 9.52349402, 13.95857104, 14.13976886,\n",
" 13.69102293, 14.79479563, 8.64474225, 11.8561958 ,\n",
" 11.73480233, 9.48192256, 12.2676478 , 11.24749057,\n",
" 10.54574751, 11.19335677, 12.97359198, 10.53390033]), 'mean_s1': array([-31.4199926 , -29.62917108, -24.09271536, -27.81987493,\n",
" -26.19026243, -25.20169875, -26.70929618, -31.22614109,\n",
" -31.47175041, -34.31121231, -28.96156587, -31.96678344,\n",
" -32.59101277, -26.18998739, -26.67456342, -30.23641675,\n",
" -31.70915932, -32.50064772, -21.71534062, -29.43725163,\n",
" -32.70246313, -23.69867675, -32.65928863, -24.90044063,\n",
" -26.81646823, -28.6679632 , -32.46863126, -33.08027459,\n",
" -34.7092647 , -34.36352683, -26.34600899, -28.26650209,\n",
" -35.04056069, -34.8551323 , -30.64173745, -32.2281079 ,\n",
" -38.69393813, -23.10274835, -29.00187485, -32.95364804]), 'mean_s2': array([-11.15110471, -12.98190942, -16.63833366, -12.75736327,\n",
" -18.0554622 , -18.6170213 , -16.12569177, -11.44901575,\n",
" -16.32619926, -13.93837799, -15.66003656, -17.22749691,\n",
" -18.6664141 , -19.08427004, -16.7467258 , -13.92901638,\n",
" -17.28145534, -8.74912276, -16.47964325, -13.93983663,\n",
" -6.72011904, -14.24332821, -16.29415371, -16.47573911,\n",
" -15.78420959, -20.38280712, -13.74020392, -13.00475248,\n",
" -17.40385444, -14.62112069, -20.76851199, -18.68775899,\n",
" -11.36603214, -15.90753738, -14.35119576, -14.27560858,\n",
" -10.96290111, -16.34325921, -16.04698975, -15.64211883]), 'var_s2': array([ 10.53209263, 11.49920029, 8.70864419, 10.53905932,\n",
" 10.5636697 , 9.22604285, 10.62565659, 10.84502265,\n",
" 10.9000927 , 7.04790479, 13.26937026, 19.51965676,\n",
" 11.08654988, 10.34667834, 11.87091724, 10.64527866,\n",
" 13.55874366, 10.21421751, 8.90067403, 9.18431269,\n",
" 12.56627485, 10.43969062, 16.02076414, 10.11089715,\n",
" 14.66110097, 9.66457882, 6.65123817, 8.96445344,\n",
" 11.16987186, 7.36377813, 9.11782738, 11.05166603,\n",
" 11.0345938 , 10.4139157 , 14.46865626, 8.44904709,\n",
" 38.19891018, 10.31044339, 11.40737675, 7.61506414])}\n"
]
}
],
"source": [
"print posts[2, 15]['mean_s2']\n",
"c = 2\n",
"delta_s = 15\n",
"s1=-30\n",
"post_func = posterior_setup(low=c, high=c, discrete_c=1, num_s=60, r_max=1)\n",
"test_data = generate_testset(40, stim_0=s1, stim_1=s1+delta_s, discrete_c=1, low=c, high=c, r_max=1)\n",
"r, _, _ = test_data\n",
"p = get_posteriors_pool(r, post_func)\n",
"print p"
" p, r, c, delta_s = pickle.load(pkl_file)\n",
" posts[c, delta_s] = p\n",
" testsets[c, delta_s] = r"
]
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 117,
"metadata": {
"collapsed": false
},
@@ -746,6 +599,7 @@
" 'var_s2': np.zeros(num_deltas), \n",
" 'cov': np.zeros(num_deltas), \n",
" 'corr': np.zeros(num_deltas),\n",
" 'mse': np.zeros(num_deltas),\n",
" }\n",
"for delta_s in range(num_deltas):\n",
" post_means = np.array((posts[c, delta_s]['mean_s1'], posts[c, delta_s]['mean_s2']))\n",
@@ -757,12 +611,13 @@
" post_stats['var_s1'][delta_s] = stats['var_s1']\n",
" post_stats['var_s2'][delta_s] = stats['var_s2']\n",
" post_stats['cov'][delta_s] = stats['cov']\n",
" post_stats['corr'][delta_s] = stats['corr']"
" post_stats['corr'][delta_s] = stats['corr']\n",
" post_stats['mse'][delta_s] = stats['mse']"
]
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 119,
"metadata": {
"collapsed": false
},
@@ -771,58 +626,25 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'bias_s1': array([-3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n",
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n",
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n",
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n",
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n",
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133]), 'bias_s2': array([ 3.32216764, 2.32216764, 1.32216764, 0.32216764,\n",
" -0.67783236, -1.67783236, -2.67783236, -3.67783236,\n",
" -4.67783236, -5.67783236, -6.67783236, -7.67783236,\n",
" -8.67783236, -9.67783236, -10.67783236, -11.67783236,\n",
" -12.67783236, -13.67783236, -14.67783236, -15.67783236,\n",
" -16.67783236, -17.67783236, -18.67783236, -19.67783236,\n",
" -20.67783236, -21.67783236, -22.67783236, -23.67783236,\n",
" -24.67783236, -25.67783236]), 'corr': array([ 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n",
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n",
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n",
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n",
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n",
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705]), 'cov': array([ 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n",
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n",
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n",
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n",
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n",
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376]), 'var_s1': array([ 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n",
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n",
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n",
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n",
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n",
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954]), 'mean_s1': array([-33.26783133, -33.26783133, -33.26783133, -33.26783133,\n",
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n",
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n",
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n",
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n",
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n",
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n",
" -33.26783133, -33.26783133]), 'mean_s2': array([-26.67783236, -26.67783236, -26.67783236, -26.67783236,\n",
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n",
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n",
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n",
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n",
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n",
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n",
" -26.67783236, -26.67783236]), 'var_s2': array([ 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n",
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n",
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n",
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n",
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n",
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802])}\n"
"20.2534726967\n"
]
}
],
"source": [
"print post_stats"
"print np.mean(post_stats['mse'])"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"pkl_file = open('posts.pkl', 'wb')\n",
"pickle.dump((posts, testsets), pkl_file)\n",
"pkl_file.close()"
]
},
{

Large diffs are not rendered by default.