In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats
from scipy.stats import norm
import scipy.integrate as integrate
import datetime

import gym
from gym import spaces

import random
import itertools as it
from joblib import Parallel, delayed
from toolz import memoize
from contracts import contract
from collections import namedtuple, defaultdict, deque, Counter

import warnings
warnings.filterwarnings("ignore", 
                        message="The objective has been evaluated at this point before.")

from agents import Agent
from oldmouselab import OldMouselabEnv
from policies import FixedPlanPolicy, LiederPolicy
from evaluation import *
from distributions import cmax, smax, sample, expectation, Normal, PointMass, SampleDist, TruncatedNormal, Categorical

In [2]:
# Shared problem setup used by every environment built in this notebook.
cost = 0.03
reward = Normal(2, 1)
gambles = 4
# Presumably per-attribute weights (they sum to 1) -- TODO confirm against OldMouselabEnv.
attributes = [0.24, 0.74, 0.01, 0.01]
# NOTE(review): `cost` is not passed here, and both `attributes` and `env`
# are reassigned in the very next cell.
env = OldMouselabEnv(gambles, attributes, reward)

In [3]:
# Replace the explicit attribute list with a count and rebuild the env,
# this time passing the cost and randomness=1.
attributes = 4
env = OldMouselabEnv(
    gambles,
    attributes,
    reward,
    cost,
    randomness=1,
)

In [4]:
# Display `env.dist`. In the recorded run it is a length-4 array summing to 1
# (presumably the attribute weights drawn for this env -- TODO confirm).
env.dist

array([ 0.073,  0.338,  0.065,  0.524])

In [5]:
# Feature vector for action 4. In the recorded run it has five entries and the
# first one (-0.03) equals -cost.
# NOTE(review): the meaning of the remaining entries is not visible from this notebook.
env.action_features(4)

array([-0.03 ,  0.029,  0.252,  0.658,  2.   ])

In [6]:
# Average 100 evaluations of env.vpi() and time the whole batch.
start = datetime.datetime.now()
means = [env.vpi() for _ in range(100)]
t = datetime.datetime.now() - start
print(np.mean(means))
# The elapsed time used to be computed and then silently dropped; report it
# the same way the later timing cells do.
print(t)

0.649369306803


In [7]:
# Average 100 evaluations of env.vpi2() and time the whole batch.
start = datetime.datetime.now()
means = [env.vpi2() for _ in range(100)]
t = datetime.datetime.now() - start
print(np.mean(means))
# The elapsed time used to be computed and then silently dropped; report it
# the same way the later timing cells do.
print(t)

0.64912782615


In [8]:
# Wall-clock time for 100 calls of env.vpi_action(2); results are discarded.
start = datetime.datetime.now()
for trial in range(100):
    env.vpi_action(2)
t = datetime.datetime.now() - start
print(t)

0:00:01.962793


In [9]:
# Wall-clock time for 100 calls of env.myopic_voi(2); results are discarded.
start = datetime.datetime.now()
for trial in range(100):
    env.myopic_voi(2)
t = datetime.datetime.now() - start
print(t)

0:00:01.535858


In [10]:
# Take actions 0-3 in order and display each step's return value.
# A bare `env._state` used to precede this line; as a non-final expression it
# was neither displayed nor stored, so it has been removed.
[env.step(i) for i in range(4)]

[((4.4420317043077002,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((4.4420317043077002,
   3.149066956978896,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((4.4420317043077002,
   3.149066956978896,
   2.5861928462151309,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 

In [11]:
# Action 16 ends the episode: the recorded return value is
# ('__term_state__', <reward>, True, {}).
env.step(16)

('__term_state__', 2.6111064974273672, True, {})

In [12]:
# Display the ground-truth values. In the recorded run this is a 16-entry array
# whose leading entries match the values revealed by the steps above.
env.ground_truth

array([ 4.442,  3.149,  2.586,  2.013,  1.608,  1.293,  3.802,  3.492,  2.225,  1.875,  1.944,  2.205,  3.222,  1.173,  1.648,  2.634])

In [13]:
# Reset the environment. The recorded return value is a 16-tuple in which
# every cell is back to Norm(2.00, 1.00).
env.reset()

(Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00))

In [14]:
# Display `env.vars`. In the recorded run all four entries are 0.398, which is
# consistent with the sum of squared `env.dist` weights shown earlier
# (presumably one variance per gamble -- TODO confirm).
env.vars

array([ 0.398,  0.398,  0.398,  0.398])

In [15]:
# Expectation of the terminal reward right after the reset. The recorded
# value, 2.0, equals the mean of the Normal(2, 1) prior.
expectation(env.term_reward())

2.0

# GT, SR: ground truth generated (`ground_truth=None`), terminal reward sampled (`sample_term_reward=True`)

In [47]:
# Variant: ground_truth=None together with sample_term_reward=True.
env = OldMouselabEnv(
    gambles,
    attributes,
    reward,
    cost,
    randomness=1,
    ground_truth=None,
    sample_term_reward=True,
)

In [48]:
# Initial state of the freshly built env: in the recorded run, 16 unobserved
# Norm(2.00, 1.00) cells.
env._state

(Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00))

In [49]:
# Take actions 0-5 in order and display each step's return value. In the
# recorded output each step replaces one more Norm cell with a number and
# returns a reward of -0.03 with done=False.
[env.step(i) for i in range(6)]

[((2.6943879453559783,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((2.6943879453559783,
   3.7339911676854687,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((2.6943879453559783,
   3.7339911676854687,
   1.1083413597965912,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00

In [50]:
# Display the state as a grid. The recorded output is a 4x4 object array in
# which the six observed values fill row 0 and the first two cells of row 1.
env.grid()

array([[2.6943879453559783, 3.7339911676854687, 1.1083413597965912, 1.8396801419061977],
       [2.0442896986083343, 1.8209548899478829, Norm(2.00, 1.00), Norm(2.00, 1.00)],
       [Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00)],
       [Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00)]], dtype=object)

In [51]:
# Action 16 ends the episode (recorded output: '__term_state__', done=True).
# In the recorded run the returned reward equals the value of
# env.dist.dot(gt_grid[np.argmax(env.mus)]) computed a few cells below.
env.step(16)

('__term_state__', 1.6116856571274651, True, {})

In [52]:
# Display the ground-truth values. In the recorded run this is a 16-entry array
# whose first six entries match the observed cells in the grid output above.
env.ground_truth

array([ 2.694,  3.734,  1.108,  1.84 ,  2.044,  1.821,  2.324,  1.611,  1.469,  1.389,  2.044,  1.169,  1.103,  3.1  ,  1.072,  2.426])

In [53]:
# Arrange the flat ground-truth array as one row per gamble.
# `attributes` must be the integer count here, not the earlier weight list.
gt_grid = np.reshape(env.ground_truth, (gambles, attributes))

In [56]:
# Weight the ground-truth row of the gamble with the highest `env.mus` entry
# by `env.dist`. In the recorded run this reproduces the terminal reward
# returned by env.step(16) above.
np.dot(env.dist, gt_grid[np.argmax(env.mus)])

1.6116856571274651

In [57]:
# Display `env.mus`. In the recorded run it has four entries; the last two are
# still 2.0, matching the rows with no observed cells in the grid output above
# (presumably one expected value per gamble -- TODO confirm).
env.mus

[1.96518404061367, 1.9904433567878979, 2.0, 2.0]

# $\neg$GT, SR: no ground truth (`ground_truth=False`), terminal reward sampled (`sample_term_reward=True`)

In [71]:
# Variant: ground_truth=False together with sample_term_reward=True.
env = OldMouselabEnv(
    gambles,
    attributes,
    reward,
    cost,
    randomness=1,
    ground_truth=False,
    sample_term_reward=True,
)

In [72]:
# Initial state of the freshly built env: in the recorded run, 16 unobserved
# Norm(2.00, 1.00) cells.
env._state

(Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00))

In [73]:
# Take actions 0-5 in order and display each step's return value. In the
# recorded output each step replaces one more Norm cell with a number and
# returns a reward of -0.03 with done=False.
[env.step(i) for i in range(6)]

[((3.378434054458368,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((3.378434054458368,
   1.8906258685471413,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((3.378434054458368,
   1.8906258685471413,
   1.4907697796379693,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1

In [74]:
# Display the state as a grid. The recorded output is a 4x4 object array in
# which the six observed values fill row 0 and the first two cells of row 1.
env.grid()

array([[3.378434054458368, 1.8906258685471413, 1.4907697796379693, 1.4307838097081216],
       [3.1019516881516203, 1.9588118677227733, Norm(2.00, 1.00), Norm(2.00, 1.00)],
       [Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00)],
       [Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00)]], dtype=object)

In [75]:
# Action 16 ends the episode (recorded output: '__term_state__', done=True).
env.step(16)

('__term_state__', 2.4613271588973094, True, {})

In [76]:
# Display `env.ground_truth`. The recorded output is False, the value that was
# passed to the constructor for this env.
env.ground_truth

False

# $\neg$GT, $\neg$SR: no ground truth (`ground_truth=False`), `sample_term_reward` left at its default

In [109]:
# Variant: ground_truth=False, sample_term_reward left at its default.
env = OldMouselabEnv(
    gambles,
    attributes,
    reward,
    cost,
    randomness=1,
    ground_truth=False,
)

In [110]:
# Initial state of the freshly built env: in the recorded run, 16 unobserved
# Norm(2.00, 1.00) cells.
env._state

(Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00))

In [117]:
# Reset the environment before stepping. The recorded return value is a
# 16-tuple of Norm(2.00, 1.00) cells.
env.reset()

(Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00))

In [118]:
# Take actions 0-5 in order and display each step's return value. In the
# recorded output each step replaces one more Norm cell with a number and
# returns a reward of -0.03 with done=False.
[env.step(i) for i in range(6)]

[((3.0665573493603078,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((3.0665573493603078,
   3.753863666154312,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((3.0665573493603078,
   3.753863666154312,
   1.5725900117679668,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 

In [119]:
# Display the state as a grid. The recorded output is a 4x4 object array in
# which the six observed values fill row 0 and the first two cells of row 1.
env.grid()

array([[3.0665573493603078, 3.753863666154312, 1.5725900117679668, 1.4552015089803207],
       [0.7929337358155275, 3.3725707020205182, Norm(2.00, 1.00), Norm(2.00, 1.00)],
       [Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00)],
       [Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00)]], dtype=object)

In [120]:
# Action 16 ends the episode (recorded output: '__term_state__', done=True).
# In the recorded run the reward is 2.0, equal to the largest entry of
# `env.mus` displayed a few cells below.
env.step(16)

('__term_state__', 2.0, True, {})

In [121]:
# Display `env.ground_truth`. The recorded output is False, the value that was
# passed to the constructor for this env.
env.ground_truth

False

In [122]:
# Display `env.mus`. In the recorded run the four entries are consistent with
# the `env.dist`-weighted means of the grid rows above, with unobserved cells
# counted at their prior mean of 2 -- TODO confirm against OldMouselabEnv.
env.mus

[1.6030738630292483, 1.9537916839850449, 2.0, 2.0]

In [123]:
# Display `env.dist` for this env. In the recorded run it is a length-4 array
# summing to 1, and it differs from the `env.dist` shown near the top of the
# notebook.
env.dist

array([ 0.041,  0.002,  0.654,  0.303])

# GT, $\neg$SR: ground truth generated (`ground_truth=None`), `sample_term_reward` left at its default

In [164]:
# Variant: ground_truth=None, sample_term_reward left at its default.
env = OldMouselabEnv(
    gambles,
    attributes,
    reward,
    cost,
    randomness=1,
    ground_truth=None,
)

In [165]:
# Initial state of the freshly built env: in the recorded run, 16 unobserved
# Norm(2.00, 1.00) cells.
env._state

(Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00),
 Norm(2.00, 1.00))

In [166]:
# Take actions 0-5 in order and display each step's return value. In the
# recorded output each step replaces one more Norm cell with a number and
# returns a reward of -0.03 with done=False.
[env.step(i) for i in range(6)]

[((2.5393907523692065,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((2.5393907523692065,
   1.5629486208841425,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00)),
  -0.03,
  False,
  {}),
 ((2.5393907523692065,
   1.5629486208841425,
   2.9862812006671833,
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00, 1.00),
   Norm(2.00

In [167]:
# Display the state as a grid. The recorded output is a 4x4 object array in
# which the six observed values fill row 0 and the first two cells of row 1.
env.grid()

array([[2.5393907523692065, 1.5629486208841425, 2.9862812006671833, 0.8459585859269918],
       [0.81750436846740149, 1.5596043670470483, Norm(2.00, 1.00), Norm(2.00, 1.00)],
       [Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00)],
       [Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00), Norm(2.00, 1.00)]], dtype=object)

In [168]:
# Action 16 ends the episode (recorded output: '__term_state__', done=True).
# In the recorded run the reward is 2.0, equal to the largest entry of
# `env.mus` displayed a few cells below.
env.step(16)

('__term_state__', 2.0, True, {})

In [169]:
# Display the ground-truth values. In the recorded run this is a 16-entry array
# whose first six entries match the observed cells in the grid output above.
env.ground_truth

array([ 2.539,  1.563,  2.986,  0.846,  0.818,  1.56 ,  1.654,  1.759,  2.059,  1.907,  1.908,  2.345,  2.228,  1.825,  3.274,  2.694])

In [170]:
# Display `env.mus`. In the recorded run it has four entries; the last two are
# still 2.0, matching the rows with no observed cells in the grid output above
# (presumably one expected value per gamble -- TODO confirm).
env.mus

[1.855641065211922, 1.5332240082439264, 2.0, 2.0]

In [171]:
# Display `env.dist` for this env. In the recorded run it is a length-4 array
# whose rounded entries sum to approximately 1.
env.dist

array([ 0.35 ,  0.121,  0.155,  0.375])