In [190]:
from collections import namedtuple, defaultdict, deque, Counter
import numpy as np
import gym
from gym import spaces
import itertools as it
from distributions import cmax, smax, expectation, Normal, PointMass, SampleDist
from toolz import memoize
import random
from contracts import contract

In [191]:
ZERO = PointMass(0)  # value returned by state_value when there are no gambles


class OldMouselabEnv(gym.Env):
    """MetaMDP for the Mouselab task.

    The state is a flat sequence of ``gambles * outcomes`` cells, laid out
    row-major as a (gambles, outcomes) grid.  Each cell is either unobserved
    (an object with a ``sample`` method) or an observed value.  Action ``i``
    (a flat cell index) reveals cell ``i`` at reward ``self.cost``; action
    ``self.term_action`` ends the episode and pays the summed payoff of the
    gamble (row) whose summed expected payoff is highest.

    Args:
        gambles: number of gambles (grid rows).
        attributes: either a sequence of outcome weights, normalized into
            ``self.dist``, or the number of outcomes, in which case the weights
            are drawn from a symmetric Dirichlet with concentration
            ``randomness``.
        reward: payoff distribution used for every cell; defaults to
            ``Normal(1, 1)``.
        cost: per-observation cost; stored as ``-abs(cost)``.
        ground_truth: optional flat sequence of true payoffs, one per cell.
        initial_states: optional list of initial states; ``reset`` picks one
            at random.
        randomness: Dirichlet concentration used when ``attributes`` is a count.

    NOTE(review): ``self.dist`` is computed but not used by the reward
    computation in this class (row payoffs are summed unweighted).
    """

    term_state = '__term_state__'

    def __init__(self, gambles=4, attributes=5, reward=None, cost=0,
                 ground_truth=None, initial_states=None, randomness=1):
        self.gambles = gambles  # number of gambles (grid rows)

        # Outcome weights and number of outcomes (grid columns).
        if hasattr(attributes, '__len__'):
            self.outcomes = len(attributes)
            self.dist = np.array(attributes) / np.sum(attributes)
        else:
            self.outcomes = attributes
            # No ``size`` argument: a 1-D weight vector, matching the branch
            # above (``size=1`` produced a (1, outcomes) array).
            self.dist = np.random.dirichlet(np.ones(attributes) * randomness)

        # Payoff distribution shared by every cell.
        self.reward = reward if reward is not None else Normal(1, 1)
        # Inspect the resolved distribution rather than the raw argument, so
        # the default (reward=None -> Normal(1, 1)) is classified correctly.
        self.iid_rewards = hasattr(self.reward, 'sample')

        self.cost = -abs(cost)
        self.ground_truth = np.array(ground_truth) if ground_truth is not None else None
        self.grid = np.arange(self.gambles * self.outcomes).reshape((self.gambles, self.outcomes))
        self.initial_states = initial_states
        # exact=True keeps the reward distribution itself in every cell and
        # uses cmax; otherwise cells hold sample-based copies and smax is used.
        self.exact = hasattr(self.reward, 'vals')
        n_cells = self.gambles * self.outcomes
        if self.exact:
            assert self.iid_rewards
            self.max = cmax
            self.init = np.array([self.reward,] * n_cells)
        else:
            # Distributions represented as samples.
            self.max = smax
            self.init = np.array([self.reward.to_sampledist(),] * n_cells)
        self.sample_term_reward = False
        # NOTE(review): observation actions are 0..n_cells-1, so n_cells itself
        # is never used; left at n_cells + 1 in case callers rely on the value.
        self.term_action = n_cells + 1
        self.reset()

    def _reset(self):
        """Start a new episode and return the features of the initial state."""
        if self.initial_states:
            self.init = random.choice(self.initial_states)
        self._state = self.init
        return self.features(self._state)

    def _step(self, action):
        """Apply ``action`` and return ``(features, reward, done, info)``.

        Terminating pays the expected (or, with ``sample_term_reward``, a
        sampled / ground-truth) payoff of the best-looking gamble; observing
        an unobserved cell costs ``self.cost``.

        Raises:
            AssertionError: if the state is terminal or ``action`` targets a
                cell that has already been observed.  Raised explicitly so the
                check is not stripped under ``python -O``.
        """
        if self._state is self.term_state:
            raise AssertionError('state is terminal')
        if action == self.term_action:
            if self.sample_term_reward:
                if self.ground_truth is not None:
                    # ground_truth is indexed by flat cell (see _observe), so
                    # view it as the (gambles, outcomes) grid before taking
                    # the chosen gamble's row; indexing the flat array by
                    # gamble would select a single cell.
                    gamble = self.best_gamble()
                    truth = np.reshape(self.ground_truth, (self.gambles, self.outcomes))
                    reward = truth[gamble].sum()
                else:
                    reward = self.term_reward().sample()
            else:
                reward = self.term_reward().expectation()
            done = True
        elif not hasattr(self._state[action], 'sample'):
            # Cell already observed: re-observing it is a caller error.
            raise AssertionError(self._state[action])
        else:
            self._state = self._observe(action)
            reward = self.cost
            done = False
        return self.features(self._state), reward, done, {}

    def _observe(self, action):
        """Return a new tuple state with cell ``action`` revealed.

        The revealed value comes from ``ground_truth`` when given, otherwise
        it is sampled from the cell's distribution.
        """
        if self.ground_truth is not None:
            result = self.ground_truth[action]
        else:
            result = self._state[action].sample()
        s = list(self._state)
        s[action] = result
        return tuple(s)

    def actions(self, state):
        """Yield the actions available in ``state``.

        These are the indices of all unobserved cells, followed by
        ``self.term_action``.  Nothing is yielded for the terminal state.
        """
        if state is self.term_state:
            return
        for i, v in enumerate(state):
            if hasattr(v, 'sample'):
                yield i
        yield self.term_action

    def results(self, state, action):
        """Yield ``(probability, next_state, reward)`` outcomes of ``action``.

        NOTE(review): for observation actions this assumes the cell's
        distribution iterates as ``(value, probability)`` pairs — confirm
        against the ``distributions`` module.
        """
        if action == self.term_action:
            yield (1, self.term_state, self.expected_term_reward(state))
        else:
            for r, p in state[action]:
                s1 = list(state)
                s1[action] = r
                yield (p, tuple(s1), self.cost)

    def features(self, state=None):
        """Return ``state`` (default: the current state) unchanged."""
        state = state if state is not None else self._state
        return state

    def action_features(self, action, state=None):
        """Feature vector for taking ``action`` in ``state``.

        Layout: [cost, myopic VOC, VPI(action), VPI(state), expected term reward].
        The terminal action gets zeros for all but the last entry.

        NOTE(review): ``myopic_voc``, ``vpi_action`` and ``vpi`` are not
        defined in this class; presumably supplied elsewhere — confirm.
        """
        state = state if state is not None else self._state
        assert state is not None

        if action == self.term_action:
            return np.array([
                0,
                0,
                0,
                0,
                self.expected_term_reward(state)
            ])

        return np.array([
            self.cost,
            self.myopic_voc(action, state),
            self.vpi_action(action, state),
            self.vpi(state),
            self.expected_term_reward(state)
        ])

    def term_reward(self, state=None):
        """Distribution of the payoff collected by terminating in ``state``."""
        state = state if state is not None else self._state
        assert state is not None
        # state_value takes only the state; the stray leading ``0`` argument
        # made every call raise TypeError.
        return self.state_value(state)

    def best_gamble(self, state=None):
        """Return the row index of the gamble with the highest summed expected payoff."""
        state = state if state is not None else self._state
        grid = np.array(state).reshape(self.gambles, self.outcomes)
        return max(range(self.gambles), key=lambda g: sum(map(expectation, grid[g])))

    def state_value(self, state=None):
        """Distribution of the summed payoff of the best-looking gamble.

        The chosen gamble is the row whose cell expectations sum highest.
        ``ZERO`` is used when there are no gambles.
        """
        state = state if state is not None else self._state
        grid = np.array(state).reshape(self.gambles, self.outcomes)
        best = max(grid, default=ZERO, key=lambda row: sum(map(expectation, row)))
        return np.sum(best)

    def expected_term_reward(self, state):
        """Expected payoff of terminating in ``state``."""
        return self.term_reward(state).expectation()

In [192]:
# Four gambles scored on four outcomes; OldMouselabEnv rescales the
# outcome weights so they sum to one.
gambles, attributes = 4, [0.25, 0.15, 0.16, 0.19]
env = OldMouselabEnv(gambles=gambles, attributes=attributes)
env.reset()

array([SD(200), SD(200), SD(200), SD(200), SD(200), SD(200), SD(200),
       SD(200), SD(200), SD(200), SD(200), SD(200), SD(200), SD(200),
       SD(200), SD(200)], dtype=object)

In [196]:
# Pick, by hand, the row of the payoff grid whose cell expectations sum
# highest (ZERO if the grid has no rows).
state = env._state
grid = np.array(state).reshape(env.gambles, env.outcomes)
best_gamble = max(grid, default=ZERO,
                  key=lambda row: sum(expectation(cell) for cell in row))

In [197]:
# Display the selected gamble's row of cells.
best_gamble

array([SD(200), SD(200), SD(200), SD(200)], dtype=object)

In [211]:
# Sum the selected row's cells into one distribution and draw a single sample from it.
np.sum(best_gamble).sample()

5.5330709939410223