In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats
from scipy.stats import norm
import scipy.integrate as integrate

import gym
from gym import spaces

import random
import itertools as it
from joblib import Parallel, delayed
from toolz import memoize
from contracts import contract
from collections import namedtuple, defaultdict, deque, Counter

import warnings
warnings.filterwarnings("ignore", 
                        message="The objective has been evaluated at this point before.")

from agents import Agent
from oldmouselab import OldMouselabEnv
from policies import FixedPlanPolicy, LiederPolicy
from evaluation import *
from omdc_util import *
from distributions import cmax, smax, sample, expectation, Normal, PointMass, SampleDist, Normal, Categorical

In [2]:
def hd_dist(attributes):
    """Sample a high-dispersion outcome distribution over ``attributes`` outcomes.

    One outcome receives 85-96 % of the probability mass (all of it when there is a
    single outcome); the rest is divided at random among the remaining outcomes,
    each getting at least 1 %.  Probabilities are whole percentages, sum to 1 and
    are returned as a shuffled float array.

    The dominant share is capped at ``100 - attributes`` % so that every other
    outcome can still receive 1 %; for ``attributes <= 4`` the cap is inactive and
    the random draws are exactly as before.

    Raises:
        ValueError: if ``attributes`` is not between 1 and 15 (beyond 15 the other
            outcomes cannot each get 1 % next to an 85 % share).
    """
    if not 1 <= attributes <= 15:
        raise ValueError("hd_dist needs 1 <= attributes <= 15, got {}".format(attributes))
    dist = [1] * attributes
    # randint's upper bound is exclusive. Without the cap, attributes >= 5 could
    # leave 0 % to distribute below and np.random.randint(0, 0) would raise.
    dist[0] = np.random.randint(85, min(97, 101 - attributes))
    for i in range(1, attributes - 1):
        # Each middle outcome gets 1 % plus a random slice of what is still unassigned.
        dist[i] += np.random.randint(0, 100 - np.sum(dist))
    # The last outcome absorbs the remainder so the percentages total exactly 100.
    dist[-1] += 100 - np.sum(dist)
    dist = np.around(np.array(dist) / 100, decimals=2)
    np.random.shuffle(dist)
    return dist

def ld_dist(attributes):
    """Sample a low-dispersion outcome distribution over ``attributes`` outcomes.

    Raw weights are drawn uniformly from 10-49, normalized and rounded to whole
    percentages.  Draws are rejected until every probability lies strictly between
    10 % and 40 %.  Rounding can make the percentages miss 100, so the residue is
    moved onto the largest entry *before* the check.  The result therefore sums
    to 1 like ``hd_dist`` and still satisfies the bounds.

    Raises:
        ValueError: if ``attributes`` is outside 3-9. No whole-percent distribution
            within (10 %, 40 %) exists there, and the rejection loop would never end.
    """
    if not 3 <= attributes <= 9:
        raise ValueError("ld_dist needs 3 <= attributes <= 9, got {}".format(attributes))
    while True:
        weights = np.array([np.random.randint(10, 50) for _ in range(attributes)])
        percents = np.rint(100 * weights / weights.sum()).astype(int)
        # Rounding residue goes to the largest entry so the total is exactly 100 %.
        percents[np.argmax(percents)] += 100 - percents.sum()
        # Integer comparison avoids float edge cases at exactly 10 % / 40 %.
        if percents.min() > 10 and percents.max() < 40:
            break
    dist = percents / 100
    np.random.shuffle(dist)
    return dist

In [3]:
gambles = 7
attributes = 4
# Outcome-value priors: centred on the midpoint of each payoff range, with the
# second Normal argument set to 0.3 x the range (high stakes: 0.01-9.99,
# low stakes: 0.01-0.25).
high_stakes = Normal((9.99+0.01)/2, 0.3*(9.99-0.01))
low_stakes = Normal((0.25+0.01)/2, 0.3*(0.25-0.01))
reward = high_stakes
cost = .03

# Environments per dispersion condition; 20 is a quick sanity-check size.
# NOTE(review): no random seed is set, so the sampled envs (including the
# train_envs[21] inspected below) differ on every fresh run.
n_train = 20
n_test = 20

# Each split mixes high-dispersion (hd) and low-dispersion (ld) outcome probabilities.
train_envs_hd = [OldMouselabEnv(gambles, hd_dist(attributes), reward, cost) for _ in range(n_train)]
train_envs_ld = [OldMouselabEnv(gambles, ld_dist(attributes), reward, cost) for _ in range(n_train)]
train_envs = train_envs_hd + train_envs_ld

# The test split is sized by n_test, independently of the training split.
test_envs_hd = [OldMouselabEnv(gambles, hd_dist(attributes), reward, cost) for _ in range(n_test)]
test_envs_ld = [OldMouselabEnv(gambles, ld_dist(attributes), reward, cost) for _ in range(n_test)]
test_envs = test_envs_hd + test_envs_ld

term_action = train_envs[0].term_action

In [4]:
# Saved weight vector for the LiederPolicy that the "BMPS Run" cells below evaluate.
bo_pol_theta = np.load('data/high_stakes_3cents.npy')
bo_pol = LiederPolicy([weight for weight in bo_pol_theta])

In [5]:
# One Agent is shared by every run_env call; each call re-registers its env and policy.
agent = Agent()
def run_env(policy, env):
    """Run one episode of ``policy`` on ``env`` with the shared agent and summarize it.

    Returns a dict with the episode return ('util'), the full action list
    ('actions', final action included), the number of actions before the final
    one ('observations') and the environment's 'ground_truth'.
    """
    for component in (env, policy):
        agent.register(component)
    episode = agent.run_episode()
    actions = episode['actions']
    return {
        'util': episode['return'],
        'actions': actions,
        'observations': len(actions) - 1,
        'ground_truth': env.ground_truth,
    }

def action_coordinate(env, action):
    """Map a flat click index to its (gamble, outcome) cell, row-major with env.outcomes columns."""
    return divmod(action, env.outcomes)

def p_grid(env, actions):
    """Summarize a click trace as a grid for visual inspection.

    Row 0 holds the outcome probabilities ``env.dist``; row ``g + 1`` belongs to
    gamble ``g``.  Each clicked cell holds the 1-based position of the click in
    ``actions`` (0 = never clicked).  Terminal actions (``env.term_action``) are
    skipped by value. The trace may therefore end with the terminal action (as
    run_env traces do) or not (raw click sequences) without losing a genuine click.
    """
    grid = np.zeros((env.gambles + 1, env.outcomes))
    grid[0, :] = env.dist
    for step, action in enumerate(actions, start=1):
        if action == env.term_action:
            continue
        # Same row-major mapping as action_coordinate.
        gamble, outcome = divmod(action, env.outcomes)
        grid[gamble + 1, outcome] = step
    return grid

# BMPS Run

In [19]:
# BMPS policy on training env 21, starting from a fresh reset.
train_envs[21].reset()
trace = run_env(policy=bo_pol, env=train_envs[21])
trace

{'actions': [2,
  6,
  10,
  14,
  18,
  22,
  26,
  9,
  25,
  11,
  27,
  8,
  13,
  21,
  23,
  20,
  15,
  24,
  1,
  3,
  5,
  28],
 'ground_truth': array([ -0.54 ,   6.997,   3.468,   3.171,  10.512,   3.075,   3.225,   5.417,   2.533,   8.301,   5.649,   2.666,   4.335,   4.55 ,   4.328,   5.184,  12.329,   3.819,   1.904,   5.086,   2.257,
          6.948,   4.081,   4.633,   6.33 ,   3.996,   8.754,   3.026]),
 'observations': 21,
 'util': 4.9380335869757719}

In [20]:
# Outcome probabilities of training env 21.
train_envs[21].dist

array([ 0.17,  0.27,  0.3 ,  0.26])

In [21]:
# Grid of training env 21 after the BMPS episode above.
train_envs[21].grid()

array([[Norm(5.00, 2.99), 6.9968832316956426, 3.4679869287369995, 3.1705854226419756],
       [Norm(5.00, 2.99), 3.0754120800467213, 3.2250433259077607, Norm(5.00, 2.99)],
       [2.5333603734218157, 8.3006676078030939, 5.6486372548577819, 2.6657422797617527],
       [Norm(5.00, 2.99), 4.5495495057475406, 4.3284366648253858, 5.1844002318269613],
       [Norm(5.00, 2.99), Norm(5.00, 2.99), 1.9042385445329972, Norm(5.00, 2.99)],
       [2.2570549631896046, 6.947908273977033, 4.080768922234733, 4.63324843565879],
       [6.3300067548706025, 3.9960887648433654, 8.7538455729102722, 3.026287693334536]], dtype=object)

In [22]:
# Click order of the BMPS trace (row 0: outcome probabilities; 0 = not clicked).
p_grid(env=train_envs[21], actions=trace['actions'])

array([[  0.17,   0.27,   0.3 ,   0.26],
       [  0.  ,  19.  ,   1.  ,  20.  ],
       [  0.  ,  21.  ,   2.  ,   0.  ],
       [ 12.  ,   8.  ,   3.  ,  10.  ],
       [  0.  ,  13.  ,   4.  ,  17.  ],
       [  0.  ,   0.  ,   5.  ,   0.  ],
       [ 16.  ,  14.  ,   6.  ,  15.  ],
       [ 18.  ,   9.  ,   7.  ,  11.  ]])

# DC Run

In [23]:
# Same training env 21, fresh reset, run through run_dc for comparison with BMPS.
# NOTE(review): run_dc is not defined in this notebook; presumably it comes from one of
# the star imports (e.g. omdc_util) — confirm. Its trace also carries an 'options' list.
train_envs[21].reset()
trace = run_dc(train_envs[21])
trace

{'actions': [26,
  25,
  27,
  6,
  18,
  22,
  10,
  9,
  11,
  8,
  14,
  2,
  24,
  13,
  15,
  21,
  23,
  20,
  1,
  3,
  5,
  7,
  28],
 'ground_truth': array([ -0.54 ,   6.997,   3.468,   3.171,  10.512,   3.075,   3.225,   5.417,   2.533,   8.301,   5.649,   2.666,   4.335,   4.55 ,   4.328,   5.184,  12.329,   3.819,   1.904,   5.086,   2.257,
          6.948,   4.081,   4.633,   6.33 ,   3.996,   8.754,   3.026]),
 'observations': 22,
 'options': [(6, 1),
  (6, 1),
  (6, 1),
  (1, 1),
  (4, 1),
  (5, 1),
  (2, 1),
  (2, 1),
  (2, 1),
  (2, 1),
  (3, 1),
  (0, 1),
  (6, 1),
  (3, 2),
  (5, 2),
  (5, 1),
  (0, 2),
  (1, 2),
  (-99, 1)],
 'util': 4.9080335869757716}

In [24]:
# Outcome probabilities of training env 21 (unchanged by the DC run).
train_envs[21].dist

array([ 0.17,  0.27,  0.3 ,  0.26])

In [25]:
# Grid of training env 21 after the DC episode above.
train_envs[21].grid()

array([[Norm(5.00, 2.99), 6.9968832316956426, 3.4679869287369995, 3.1705854226419756],
       [Norm(5.00, 2.99), 3.0754120800467213, 3.2250433259077607, 5.4173103451602973],
       [2.5333603734218157, 8.3006676078030939, 5.6486372548577819, 2.6657422797617527],
       [Norm(5.00, 2.99), 4.5495495057475406, 4.3284366648253858, 5.1844002318269613],
       [Norm(5.00, 2.99), Norm(5.00, 2.99), 1.9042385445329972, Norm(5.00, 2.99)],
       [2.2570549631896046, 6.947908273977033, 4.080768922234733, 4.63324843565879],
       [6.3300067548706025, 3.9960887648433654, 8.7538455729102722, 3.026287693334536]], dtype=object)

In [26]:
# Click order of the DC trace (row 0: outcome probabilities; 0 = not clicked).
p_grid(env=train_envs[21], actions=trace['actions'])

array([[  0.17,   0.27,   0.3 ,   0.26],
       [  0.  ,  19.  ,  12.  ,  20.  ],
       [  0.  ,  21.  ,   4.  ,  22.  ],
       [ 10.  ,   8.  ,   7.  ,   9.  ],
       [  0.  ,  14.  ,  11.  ,  15.  ],
       [  0.  ,   0.  ,   5.  ,   0.  ],
       [ 18.  ,  16.  ,   6.  ,  17.  ],
       [ 13.  ,   2.  ,   1.  ,   3.  ]])

# Parsing

In [32]:
# Same values as the configuration cell.
gambles = 7
attributes = 4
def make_hs_env(gambles=7, cost=.03, ground_truth=False, dist=None):
    """Build a high-stakes OldMouselabEnv.

    Args:
        gambles: number of gambles.
        cost: cost passed through to OldMouselabEnv.
        ground_truth: passed through to OldMouselabEnv unchanged (the default
            False is forwarded as-is; its meaning is defined by OldMouselabEnv).
        dist: outcome probabilities. When omitted, a fresh high-dispersion
            distribution over 4 outcomes is drawn with hd_dist on every call.
    """
    if dist is None:
        # Drawn per call: a hd_dist(4) default in the signature would be evaluated
        # once at definition time and silently shared by every default-built env.
        dist = hd_dist(4)
    # Same prior as `high_stakes` in the configuration cell.
    hs = Normal((9.99+0.01)/2, 0.3*(9.99-0.01))
    return OldMouselabEnv(gambles, dist, hs, cost, ground_truth=ground_truth)

In [34]:
# Rebuild training env 21 through make_hs_env with the same ground truth and outcome
# probabilities. `env` is bound here explicitly: the cell that assigns it
# (`env = train_envs[21]`, In [30]) sits further down and only worked because it ran first.
env = train_envs[21]
env2 = make_hs_env(ground_truth=env.ground_truth, dist=env.dist)

In [40]:
# Ground truth of the original env (compare with env2 below).
env.ground_truth

array([ -0.54 ,   6.997,   3.468,   3.171,  10.512,   3.075,   3.225,   5.417,   2.533,   8.301,   5.649,   2.666,   4.335,   4.55 ,   4.328,   5.184,  12.329,   3.819,   1.904,   5.086,   2.257,
         6.948,   4.081,   4.633,   6.33 ,   3.996,   8.754,   3.026])

In [41]:
# Ground truth of the rebuilt env2; expected to equal env.ground_truth.
env2.ground_truth

array([ -0.54 ,   6.997,   3.468,   3.171,  10.512,   3.075,   3.225,   5.417,   2.533,   8.301,   5.649,   2.666,   4.335,   4.55 ,   4.328,   5.184,  12.329,   3.819,   1.904,   5.086,   2.257,
         6.948,   4.081,   4.633,   6.33 ,   3.996,   8.754,   3.026])

In [42]:
# BMPS policy on the original env (no reset before this run).
trace = run_env(policy=bo_pol, env=env)
trace

{'actions': [2,
  6,
  10,
  14,
  18,
  22,
  26,
  9,
  25,
  11,
  27,
  8,
  13,
  21,
  23,
  20,
  15,
  24,
  1,
  3,
  5,
  28],
 'ground_truth': array([ -0.54 ,   6.997,   3.468,   3.171,  10.512,   3.075,   3.225,   5.417,   2.533,   8.301,   5.649,   2.666,   4.335,   4.55 ,   4.328,   5.184,  12.329,   3.819,   1.904,   5.086,   2.257,
          6.948,   4.081,   4.633,   6.33 ,   3.996,   8.754,   3.026]),
 'observations': 21,
 'util': 4.9380335869757719}

In [43]:
# BMPS policy on the rebuilt env2; compare its actions with the trace above.
trace = run_env(policy=bo_pol, env=env2)
trace

{'actions': [2,
  6,
  10,
  14,
  18,
  22,
  26,
  9,
  25,
  11,
  27,
  24,
  8,
  13,
  21,
  23,
  20,
  1,
  3,
  15,
  5,
  28],
 'ground_truth': array([ -0.54 ,   6.997,   3.468,   3.171,  10.512,   3.075,   3.225,   5.417,   2.533,   8.301,   5.649,   2.666,   4.335,   4.55 ,   4.328,   5.184,  12.329,   3.819,   1.904,   5.086,   2.257,
          6.948,   4.081,   4.633,   6.33 ,   3.996,   8.754,   3.026]),
 'observations': 21,
 'util': 4.9380335869757719}

In [27]:
def wrap_po(env, click_sequence, t=1, p_rand=0):
    """Enumerate every way a click sequence can be segmented into planning options.

    The clicks are replayed from the initial state of an environment with the same
    ground truth and outcome probabilities as ``env``.  The sequence is split
    recursively: each prefix must equal one instantiation of an option returned by
    ``get_all_options`` for the state reached so far, and the remainder is parsed
    from the resulting state.  The terminal action is appended here, so
    ``click_sequence`` must not already end with it.

    Each option choice is scored as a softmax over option utilities at temperature
    ``t``, mixed with weight ``p_rand``. The mixing component contributes only for
    option ``(-1, 1)``.  That score is split evenly among the option's
    instantiations, and a segmentation's likelihood is the product over its options.

    Args:
        env: environment to replay; its ``ground_truth``, ``dist``, ``gambles`` and
            ``term_action`` are read.
        click_sequence: sequence of click actions, without the terminal action.
        t: softmax temperature (must be non-zero).
        p_rand: mixing weight between the softmax and the ``(-1, 1)`` component.

    Returns:
        ``(done, option_seqs, likelihoods)``. ``done`` is True when at least one
        segmentation consumes the whole sequence. ``option_seqs[k]`` is a list of
        options and ``likelihoods[k]`` is its likelihood.
    """
    memo = {}

    def parse_options_clean(init_state, dist, pre_acts, click_sequence, t=1, p_err=0.001):
        # Nothing left to explain: one empty segmentation with likelihood 1.
        if len(click_sequence) == 0:
            return True, [[]], [1]
        # A single key is used for both lookup and store.
        key = (tuple(pre_acts), tuple(click_sequence), tuple(dist), t, p_err)
        if key in memo:
            return memo[key]

        # Replay the already-explained clicks on a fresh environment.
        # NOTE(review): cost is make_hs_env's default, not read from env — confirm it matches.
        envc = make_hs_env(gambles=env.gambles, ground_truth=init_state, dist=dist)
        envc.reset()
        for a in pre_acts:
            envc._step(a)

        options, option_insts, option_utils, _ = get_all_options(envc)

        # Softmax over option utilities at temperature t; subtracting the max keeps
        # np.exp from overflowing without changing the normalized result.
        option_probs = []
        if len(options):
            scaled = np.asarray(option_utils, dtype=float) / t
            weights = np.exp(scaled - scaled.max())
            option_probs = weights / weights.sum()

        option_seqs = []
        likelihoods = []
        done = False
        max_prefix = min(len(envc.paths[0]), len(click_sequence))
        for i in range(1, max_prefix + 1):
            prefix, rest = click_sequence[:i], click_sequence[i:]
            for j, option in enumerate(options):
                insts = option_insts[option]
                alpha = 1 if option == (-1, 1) else 0
                p_option = (1 - p_err) * option_probs[j] + p_err * alpha
                for inst in insts:
                    if not np.array_equal(prefix, inst):
                        continue
                    will_done, remaining, rem_likelihoods = parse_options_clean(
                        init_state, dist, pre_acts + prefix, rest, t, p_err)
                    if not will_done:
                        continue
                    done = True
                    for seq, rem_likelihood in zip(remaining, rem_likelihoods):
                        option_seqs.append([option] + seq)
                        # The option's probability is shared evenly by its instantiations.
                        likelihoods.append(p_option * rem_likelihood / len(insts))

        memo[key] = done, option_seqs, likelihoods
        return done, option_seqs, likelihoods

    # list(...) so a numpy click_sequence is concatenated, not broadcast-added.
    full_sequence = list(click_sequence) + [env.term_action]
    return parse_options_clean(env.ground_truth, env.dist, [], full_sequence, t, p_rand)

In [30]:
# Environment used by the parsing cells: training env 21, the env traced above.
env = train_envs[21]

In [28]:
# Action list of the current `trace` (when this cell ran, the DC trace of env 21).
trace['actions']

[26,
 25,
 27,
 6,
 18,
 22,
 10,
 9,
 11,
 8,
 14,
 2,
 24,
 13,
 15,
 21,
 23,
 20,
 1,
 3,
 5,
 7,
 28]

In [None]:
# Parse a click sequence into planning options (softmax temperature 2, p_rand 0.01).
# wrap_po appends env.term_action itself, so `trial` must not already end with it.
# NOTE(review): `trial` is not defined anywhere in this notebook; bind it before running.
# wrap_po has no `branching` parameter (passing it raised TypeError), so it is not passed.
# a: True if some segmentation explains the whole sequence; b: the option sequences;
# c: their likelihoods (parallel to b).
a, b, c = wrap_po(env, trial, t=2, p_rand=0.01)