In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats
from scipy.stats import norm
import scipy.integrate as integrate

import gym
from gym import spaces

import random
import itertools as it
from joblib import Parallel, delayed
from toolz import memoize
from contracts import contract
from collections import namedtuple, defaultdict, deque, Counter

import warnings
warnings.filterwarnings("ignore", 
                        message="The objective has been evaluated at this point before.")

from agents import Agent
from oldmouselab import OldMouselabEnv
from policies import FixedPlanPolicy, LiederPolicy
from evaluation import *
from distributions import cmax, smax, sample, expectation, Normal, PointMass, SampleDist, Normal, Categorical

In [3]:
def hd_dist(attributes):
    """Sample a "high-dispersion" probability vector over `attributes` outcomes.

    One outcome receives an integer share of 85-96% of the mass. Every other
    outcome receives at least 1%, and the remainder is split randomly. The
    integer percentages always sum to exactly 100. They are returned as
    probabilities rounded to 2 decimals, in random order.

    Parameters
    ----------
    attributes : int
        Number of outcomes (length of the returned vector), >= 1.

    Returns
    -------
    numpy.ndarray
        Float array of length `attributes` summing to 1.

    Raises
    ------
    ValueError
        If `attributes` is so large that the dominant share plus a 1% floor for
        every other outcome would exceed 100%.
    """
    dist = [1, ] * attributes
    dist[0] = np.random.randint(85, 97)
    # With many outcomes the 1% floors alone can overflow the 100% budget,
    # which would otherwise yield a negative last entry.
    if np.sum(dist) > 100:
        raise ValueError(
            "hd_dist: cannot give each of {} outcomes at least 1% when the "
            "dominant outcome takes {}%".format(attributes, dist[0]))
    for i in range(1, attributes - 1):
        remaining = 100 - np.sum(dist)
        # randint(0, 0) raises, so once the budget is used up this outcome
        # simply keeps its 1% floor. For attributes <= 4 `remaining` is always
        # >= 1 here, so the random draws are unchanged.
        if remaining > 0:
            dist[i] += np.random.randint(0, remaining)
    # The last outcome absorbs whatever is left so the shares total exactly 100.
    dist[-1] += 100 - np.sum(dist)
    dist = np.around(np.array(dist) / 100, decimals=2)
    np.random.shuffle(dist)
    return dist

def ld_dist(attributes):
    """Sample a "low-dispersion" probability vector over `attributes` outcomes.

    Each outcome gets an integer weight drawn uniformly from 10..39. The weights
    are normalised and rounded to whole percentages, and the vector is returned
    in random order.

    Rounding each entry on its own can leave the total at 0.99 or 1.01.
    The rounding residual is therefore folded into the largest entry, so the
    result is always a proper distribution, matching the guarantee `hd_dist`
    gives. When no residual arises, the output equals
    ``np.around(weights / total, 2)`` exactly. Intended for a small number of
    outcomes (the notebook uses 4); with many outcomes, whole-percent rounding
    itself becomes too coarse.

    Parameters
    ----------
    attributes : int
        Number of outcomes (length of the returned vector).

    Returns
    -------
    numpy.ndarray
        Float array of length `attributes` that sums to 1 (empty if
        `attributes` is 0).
    """
    weights = np.array([np.random.randint(10, 40) for _ in range(attributes)])
    # Same arithmetic np.around(x, decimals=2) performs: scale by 100, round half-to-even.
    percents = np.rint(weights / weights.sum() * 100)
    if percents.size:
        percents[np.argmax(percents)] += 100 - percents.sum()
    dist = percents / 100
    # Kept so the global RNG advances exactly as before.
    np.random.shuffle(dist)
    return dist

In [4]:
# Task size: each environment is a grid of `gambles` gambles x `attributes` outcomes.
gambles = 7
attributes = 4
# Payoff distributions: mean at the midpoint of the payoff range, sd = 30% of its width.
high_stakes = Normal((9.99+0.01)/2, 0.3*(9.99-0.01))
low_stakes = Normal((0.25+0.01)/2, 0.3*(0.25-0.01))
reward = high_stakes
# Passed to OldMouselabEnv as its cost; presumably the per-observation (click) cost — TODO confirm.
cost = .03

# Environments per dispersion type and split (20 is a quick sanity-check size).
n_train = 20
n_test = 20


def make_envs(dist_fn, n):
    """Build `n` OldMouselabEnv instances, each with its own outcome distribution drawn from `dist_fn`."""
    return [OldMouselabEnv(gambles, dist_fn(attributes), reward, cost) for _ in range(n)]


train_envs_hd = make_envs(hd_dist, n_train)
train_envs_ld = make_envs(ld_dist, n_train)
train_envs = train_envs_hd + train_envs_ld

# Test environments are sized by n_test, independently of the training set.
test_envs_hd = make_envs(hd_dist, n_test)
test_envs_ld = make_envs(ld_dist, n_test)
test_envs = test_envs_hd + test_envs_ld

term_action = train_envs[0].term_action

In [5]:
# Load pre-tuned policy weights and wrap them in a LiederPolicy.
# (The file name suggests they were fitted for the high-stakes, 3-cent-cost
# setting configured above — TODO confirm.)
bo_pol_theta = np.load('data/high_stakes_3cents.npy')
bo_pol = LiederPolicy([weight for weight in bo_pol_theta])

In [6]:
agent = Agent()

def run_env(policy, env):
    """Run a single episode of `policy` on `env` with the shared module-level `agent`.

    The environment is registered first, then the policy, and then one episode
    is run.

    Returns
    -------
    dict
        'util'         : the episode's ``'return'`` entry,
        'actions'      : the list of actions taken,
        'observations' : ``len(actions) - 1``; the last action is presumably the
                         terminating action — TODO confirm,
        'ground_truth' : ``env.ground_truth``.
    """
    for component in (env, policy):
        agent.register(component)
    episode = agent.run_episode()
    taken = episode['actions']
    return {
        'util': episode['return'],
        'actions': taken,
        'observations': len(taken) - 1,
        'ground_truth': env.ground_truth,
    }

def action_coordinate(env, action):
    """Convert a flat cell index into its ``(gamble, outcome)`` pair.

    Cells are numbered row-major with ``env.outcomes`` columns per gamble, so
    ``action == gamble * env.outcomes + outcome``.
    """
    gamble, outcome = divmod(action, env.outcomes)
    return (gamble, outcome)

def p_grid(env, actions):
    """Summarise an episode's clicks as a numeric grid.

    Shape is ``(env.gambles + 1, env.outcomes)``. Row 0 holds ``env.dist``, and
    row ``g + 1`` corresponds to gamble ``g``. A clicked cell holds the 1-based
    position of its click (0 if never clicked; a repeated click keeps the later
    position). The final entry of `actions` is not plotted; presumably it is
    the terminating action — TODO confirm.
    """
    clicks = actions[:-1]
    grid = np.zeros((env.gambles + 1, env.outcomes))
    grid[0, :] = env.dist
    for step, action in enumerate(clicks, start=1):
        gamble, outcome = action_coordinate(env, action)
        grid[gamble + 1, outcome] = step
    return grid

In [7]:
# Demo: roll out the BO policy once on a single training environment and show the trace.
# With n_train = 20, index 21 falls in the second half of train_envs
# (train_envs_hd + train_envs_ld), i.e. a low-dispersion environment.
train_envs[21].reset()
trace = run_env(bo_pol, train_envs[21])
trace

{'actions': [14, 18, 13, 28],
 'ground_truth': array([ 11.322,   4.288,   8.354,   0.735,   9.03 ,   6.419,   3.196,   1.25 ,   8.874,   7.126,   2.795,   1.856,   3.748,  11.894,   8.14 ,   6.474,   6.213,   3.935,   6.348,   6.152,   9.956,
          3.178,   6.248,  10.182,   8.404,   8.996,   3.882,   6.267]),
 'observations': 3,
 'util': 8.2528497145643804}

In [8]:
# Outcome probabilities of the demo environment (one entry per outcome column).
train_envs[21].dist

array([ 0.14,  0.33,  0.34,  0.19])

In [9]:
# Current grid state: revealed cells show their sampled value; unrevealed cells appear to show the prior distribution.
train_envs[21].grid()

array([[Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99)],
       [Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99)],
       [Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99)],
       [Norm(5.00, 2.99), 11.894181316154732, 8.140499647745056, Norm(5.00, 2.99)],
       [Norm(5.00, 2.99), Norm(5.00, 2.99), 6.3483969673903609, Norm(5.00, 2.99)],
       [Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99)],
       [Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99), Norm(5.00, 2.99)]], dtype=object)

In [10]:
# Click-order view of the demo trace: row 0 = outcome probabilities, remaining rows = click positions per gamble.
p_grid(train_envs[21],trace['actions'])

array([[ 0.14,  0.33,  0.34,  0.19],
       [ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  3.  ,  1.  ,  0.  ],
       [ 0.  ,  0.  ,  2.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ]])