In [1]:
from mouselab.mouselab import MouselabEnv
import numpy as np

In [4]:
def get_myopic_action(env: MouselabEnv) -> int:
    """ Computes the action with the highest myopic VOC based on the given environments state. 

    Args:
        env (MouselabEnv): Mouselab environment

    Returns:
        int: Action with highest myopic VOC value
    """
    available_actions = list(env.actions(env._state))
    myopic_vocs = []
    for action in available_actions:
        print(env._state)
        myopic_voc = env.myopic_voc(action, env._state)
        # Term action has the same cost associated for some reason
        if action != env.term_action:
            myopic_voc += env.cost(action)
        myopic_vocs.append(myopic_voc)
    # Will choose first action if multiple have the same VOC
    return available_actions[np.argmax(myopic_vocs)]

def run_episode(env_name="high_increasing", cost=1) -> float:
    """ Evaluates the meta-greedy strategy over 1 episode

    Args:
        env_name (str, optional): Environment (registered) to be evaluated. Defaults to "high_increasing".
        cost (int, optional): Click cost. Defaults to 1.

    Returns:
        float: Total reward of the episode.
    """
    env = MouselabEnv.new_symmetric_registered(env_name, cost=cost)
    done = False
    episode_reward = 0.
    while not done:
        action = get_myopic_action(env)
        _, reward, done, _ = env.step(action)
        episode_reward += reward
    return episode_reward

In [5]:
rewards = []
for i in range(500):
    episode_reward = run_episode()
    rewards.append(episode_reward)
print(f"Mean reward: {np.mean(rewards)}")

(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, 48, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, 48, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat, Cat)
(0, Cat, Cat, 48, Cat, Cat, Cat, Cat, Cat,