@@ -289,7 +289,7 @@
},
{
"cell_type" : " code" ,
"execution_count" : 7 ,
"execution_count" : 116 ,
"metadata" : {
"collapsed" : false
},
@@ -300,12 +300,13 @@
" mean_s2 = np.mean(preds[1])\n " ,
" bias_s1 = mean_s1 - s1\n " ,
" bias_s2 = mean_s2 - s2\n " ,
" mse = np.mean((preds[0] - s1)**2 + (preds[1] - s2)**2)\n " ,
" covmat = np.cov(preds)\n " ,
" var_s1 = covmat[0, 0]\n " ,
" var_s2 = covmat[1, 1]\n " ,
" cov = covmat[0, 1]\n " ,
" corr = cov / (np.sqrt(var_s1) * np.sqrt(var_s2))\n " ,
" stats = {'mean_s1': mean_s1, 'mean_s2': mean_s2, 'bias_s1': bias_s1, 'bias_s2': bias_s2, 'var_s1': var_s1, 'var_s2': var_s2, 'cov': cov, 'corr': corr}\n " ,
" stats = {'mean_s1': mean_s1, 'mean_s2': mean_s2, 'bias_s1': bias_s1, 'bias_s2': bias_s2, 'var_s1': var_s1, 'var_s2': var_s2, 'cov': cov, 'corr': corr, 'mse': mse }\n " ,
" return stats"
]
},
@@ -560,175 +561,27 @@
},
{
"cell_type" : " code" ,
"execution_count" : 67 ,
"metadata" : {
"collapsed" : false
},
"outputs" : [
{
"name" : " stdout" ,
"output_type" : " stream" ,
"text" : [
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n " ,
" 10 trials per contrast level\n "
]
}
],
"source" : [
" post_func = posterior_setup(low=1, high=4, discrete_c=3, num_s=60, r_max=1)\n " ,
" resps = [None] * num_deltas\n " ,
" for delta_s in range(num_deltas):\n " ,
" test_data = generate_testset(90, stim_0=s1, stim_1=s1+delta_s, discrete_c=3, low=1, high=4, r_max=1)\n " ,
" r, _, _ = test_data\n " ,
" resps[delta_s] = r"
]
},
{
"cell_type" : " code" ,
"execution_count" : 97 ,
"execution_count" : 120 ,
"metadata" : {
"collapsed" : false
},
"outputs" : [],
"source" : [
" import os\n " ,
" posts = {}\n " ,
" for s_i in range(90):\n " ,
" file_name = 'output_post/post_' + str(i+1) + '.pkl'\n " ,
" testsets = {}\n " ,
" for s_i in range(91):\n " ,
" file_name = 'output_post/post_' + str(s_i) + '.pkl'\n " ,
" if os.path.isfile(file_name):\n " ,
" delta_s = s_i / 3\n " ,
" pkl_file = open(file_name, 'rb')\n " ,
" p, r, c, i = pickle.load(pkl_file)\n " ,
" posts[c, delta_s] = p"
]
},
{
"cell_type" : " code" ,
"execution_count" : 98 ,
"metadata" : {
"collapsed" : false
},
"outputs" : [
{
"name" : " stdout" ,
"output_type" : " stream" ,
"text" : [
" [-28.25227198 -26.0783597 -29.13269856 ..., -24.19848548 -27.67333624\n " ,
" -23.92919243]\n " ,
" 40 trials per contrast level\n " ,
" {'var_s1': array([ 10.02719864, 13.57938305, 10.18881273, 15.35320011,\n " ,
" 9.92542377, 8.73841609, 13.5532021 , 7.4075507 ,\n " ,
" 19.82254523, 6.60182673, 10.91104413, 10.43637532,\n " ,
" 13.00056325, 9.63121981, 13.42775892, 11.25597343,\n " ,
" 15.15762951, 15.39946474, 8.68434282, 8.18976316,\n " ,
" 10.49936786, 9.49821351, 11.00819366, 10.48726616,\n " ,
" 14.11987687, 9.52349402, 13.95857104, 14.13976886,\n " ,
" 13.69102293, 14.79479563, 8.64474225, 11.8561958 ,\n " ,
" 11.73480233, 9.48192256, 12.2676478 , 11.24749057,\n " ,
" 10.54574751, 11.19335677, 12.97359198, 10.53390033]), 'mean_s1': array([-31.4199926 , -29.62917108, -24.09271536, -27.81987493,\n " ,
" -26.19026243, -25.20169875, -26.70929618, -31.22614109,\n " ,
" -31.47175041, -34.31121231, -28.96156587, -31.96678344,\n " ,
" -32.59101277, -26.18998739, -26.67456342, -30.23641675,\n " ,
" -31.70915932, -32.50064772, -21.71534062, -29.43725163,\n " ,
" -32.70246313, -23.69867675, -32.65928863, -24.90044063,\n " ,
" -26.81646823, -28.6679632 , -32.46863126, -33.08027459,\n " ,
" -34.7092647 , -34.36352683, -26.34600899, -28.26650209,\n " ,
" -35.04056069, -34.8551323 , -30.64173745, -32.2281079 ,\n " ,
" -38.69393813, -23.10274835, -29.00187485, -32.95364804]), 'mean_s2': array([-11.15110471, -12.98190942, -16.63833366, -12.75736327,\n " ,
" -18.0554622 , -18.6170213 , -16.12569177, -11.44901575,\n " ,
" -16.32619926, -13.93837799, -15.66003656, -17.22749691,\n " ,
" -18.6664141 , -19.08427004, -16.7467258 , -13.92901638,\n " ,
" -17.28145534, -8.74912276, -16.47964325, -13.93983663,\n " ,
" -6.72011904, -14.24332821, -16.29415371, -16.47573911,\n " ,
" -15.78420959, -20.38280712, -13.74020392, -13.00475248,\n " ,
" -17.40385444, -14.62112069, -20.76851199, -18.68775899,\n " ,
" -11.36603214, -15.90753738, -14.35119576, -14.27560858,\n " ,
" -10.96290111, -16.34325921, -16.04698975, -15.64211883]), 'var_s2': array([ 10.53209263, 11.49920029, 8.70864419, 10.53905932,\n " ,
" 10.5636697 , 9.22604285, 10.62565659, 10.84502265,\n " ,
" 10.9000927 , 7.04790479, 13.26937026, 19.51965676,\n " ,
" 11.08654988, 10.34667834, 11.87091724, 10.64527866,\n " ,
" 13.55874366, 10.21421751, 8.90067403, 9.18431269,\n " ,
" 12.56627485, 10.43969062, 16.02076414, 10.11089715,\n " ,
" 14.66110097, 9.66457882, 6.65123817, 8.96445344,\n " ,
" 11.16987186, 7.36377813, 9.11782738, 11.05166603,\n " ,
" 11.0345938 , 10.4139157 , 14.46865626, 8.44904709,\n " ,
" 38.19891018, 10.31044339, 11.40737675, 7.61506414])}\n "
]
}
],
"source" : [
" print posts[2, 15]['mean_s2']\n " ,
" c = 2\n " ,
" delta_s = 15\n " ,
" s1=-30\n " ,
" post_func = posterior_setup(low=c, high=c, discrete_c=1, num_s=60, r_max=1)\n " ,
" test_data = generate_testset(40, stim_0=s1, stim_1=s1+delta_s, discrete_c=1, low=c, high=c, r_max=1)\n " ,
" r, _, _ = test_data\n " ,
" p = get_posteriors_pool(r, post_func)\n " ,
" print p"
" p, r, c, delta_s = pickle.load(pkl_file)\n " ,
" posts[c, delta_s] = p\n " ,
" testsets[c, delta_s] = r"
]
},
{
"cell_type" : " code" ,
"execution_count" : 80 ,
"execution_count" : 117 ,
"metadata" : {
"collapsed" : false
},
@@ -746,6 +599,7 @@
" 'var_s2': np.zeros(num_deltas), \n " ,
" 'cov': np.zeros(num_deltas), \n " ,
" 'corr': np.zeros(num_deltas),\n " ,
" 'mse': np.zeros(num_deltas),\n " ,
" }\n " ,
" for delta_s in range(num_deltas):\n " ,
" post_means = np.array((posts[c, delta_s]['mean_s1'], posts[c, delta_s]['mean_s2']))\n " ,
@@ -757,12 +611,13 @@
" post_stats['var_s1'][delta_s] = stats['var_s1']\n " ,
" post_stats['var_s2'][delta_s] = stats['var_s2']\n " ,
" post_stats['cov'][delta_s] = stats['cov']\n " ,
" post_stats['corr'][delta_s] = stats['corr']"
" post_stats['corr'][delta_s] = stats['corr']\n " ,
" post_stats['mse'][delta_s] = stats['mse']"
]
},
{
"cell_type" : " code" ,
"execution_count" : 81 ,
"execution_count" : 119 ,
"metadata" : {
"collapsed" : false
},
@@ -771,58 +626,25 @@
"name" : " stdout" ,
"output_type" : " stream" ,
"text" : [
" {'bias_s1': array([-3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n " ,
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n " ,
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n " ,
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n " ,
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133,\n " ,
" -3.26783133, -3.26783133, -3.26783133, -3.26783133, -3.26783133]), 'bias_s2': array([ 3.32216764, 2.32216764, 1.32216764, 0.32216764,\n " ,
" -0.67783236, -1.67783236, -2.67783236, -3.67783236,\n " ,
" -4.67783236, -5.67783236, -6.67783236, -7.67783236,\n " ,
" -8.67783236, -9.67783236, -10.67783236, -11.67783236,\n " ,
" -12.67783236, -13.67783236, -14.67783236, -15.67783236,\n " ,
" -16.67783236, -17.67783236, -18.67783236, -19.67783236,\n " ,
" -20.67783236, -21.67783236, -22.67783236, -23.67783236,\n " ,
" -24.67783236, -25.67783236]), 'corr': array([ 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n " ,
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n " ,
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n " ,
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n " ,
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705,\n " ,
" 0.27977705, 0.27977705, 0.27977705, 0.27977705, 0.27977705]), 'cov': array([ 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n " ,
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n " ,
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n " ,
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n " ,
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376,\n " ,
" 1.35137376, 1.35137376, 1.35137376, 1.35137376, 1.35137376]), 'var_s1': array([ 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n " ,
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n " ,
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n " ,
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n " ,
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954,\n " ,
" 4.82147954, 4.82147954, 4.82147954, 4.82147954, 4.82147954]), 'mean_s1': array([-33.26783133, -33.26783133, -33.26783133, -33.26783133,\n " ,
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n " ,
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n " ,
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n " ,
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n " ,
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n " ,
" -33.26783133, -33.26783133, -33.26783133, -33.26783133,\n " ,
" -33.26783133, -33.26783133]), 'mean_s2': array([-26.67783236, -26.67783236, -26.67783236, -26.67783236,\n " ,
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n " ,
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n " ,
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n " ,
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n " ,
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n " ,
" -26.67783236, -26.67783236, -26.67783236, -26.67783236,\n " ,
" -26.67783236, -26.67783236]), 'var_s2': array([ 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n " ,
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n " ,
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n " ,
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n " ,
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802,\n " ,
" 4.83889802, 4.83889802, 4.83889802, 4.83889802, 4.83889802])}\n "
" 20.2534726967\n "
]
}
],
"source" : [
" print post_stats"
" print np.mean(post_stats['mse'])"
]
},
{
"cell_type" : " code" ,
"execution_count" : 121 ,
"metadata" : {
"collapsed" : true
},
"outputs" : [],
"source" : [
" pkl_file = open('posts.pkl', 'wb')\n " ,
" pickle.dump((posts, testsets), pkl_file)\n " ,
" pkl_file.close()"
]
},
{