In [14]:
import numpy as np

from IPython.display import Image
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import seaborn as sns
sns.set_style('ticks')

matplotlib.rcParams.update({'font.size': 16})
matplotlib.rc('axes', titlesize=16)

from infomercial.exp import random_bandit
from infomercial.local_gym import bandit
from infomercial.exp.meta_bandit import load_checkpoint

from infomercial.memory import Memory
from infomercial.memory import ConditionalCount
from infomercial.memory import EfficientConditionalCount
from infomercial.memory import ForgetfulConditionalCount

from scipy.stats import entropy

import gym
from pprint import pprint

def information_value(p_new, p_old, base=None):
    """Score how much moving from ``p_old`` to ``p_new`` changed the belief.

    The score is the Kullback-Leibler divergence D(p_old || p_new), computed
    with ``scipy.stats.entropy`` (which normalizes both arguments first).

    Parameters
    ----------
    p_new : array_like
        Distribution after the memory update.
    p_old : array_like
        Distribution before the memory update.
    base : float, optional
        Logarithm base; ``None`` means the natural log (nats).

    Returns
    -------
    float
        The divergence, or 0.0 when ``p_old`` holds (numerically) zero total
        mass -- presumably because nothing was recorded yet for this condition.
    """
    total_old_mass = np.sum(p_old)
    if not np.isclose(total_old_mass, 0.0):
        return entropy(p_old, qk=p_new, base=base)

    # Hack: a massless prior has no defined divergence, so report zero.
    return 0.0

In [15]:
# environments = [
#     ['BanditOneHot2', 'v0', 1],
#     ['BanditOneHot10', 'v0', 1],
#     ['BanditOneHot121', 'v0', 1],
#     ['BanditOneHot1000', 'v0', 1],
#     ['BanditEvenOdds2', 'v0', 1],
#     ['BanditOneHigh2', 'v0', 1],
#     ['BanditOneHigh10', 'v0', 1],
#     ['BanditOneHigh121', 'v0', 1],
#     ['BanditOneHigh1000', 'v0', 1],
#     ['BanditHardAndSparse2', 'v0', 1],
#     ['BanditHardAndSparse10', 'v0', 1],
#     ['BanditHardAndSparse121', 'v0', 1],
#     ['BanditHardAndSparse1000', 'v0', 1],
#     ['BanditGaussian10', 'v0', 1],
# ]

In [39]:
# Experiment: pick a uniformly random arm each episode and log, per pull,
# the reward (R) and the information value (E) of updating the outcome memory.
num_episodes = 100
env_name = "BanditOneHigh10-v0"
SEED = 42  # Fixed seed so "Restart & Run All" reproduces the logged scores.

# Seed NumPy's global RNG (drives np.random.randint below, and anything the
# env draws from the global RNG) before the env is constructed.
np.random.seed(SEED)

# Init gyms
env = gym.make(env_name)
# NOTE(review): uses the classic gym `seed` API (the 4-tuple `env.step` below
# implies pre-0.26 gym); seeds the env's own RNG if it keeps one.
env.seed(SEED)
num_actions = env.action_space.n
best_action = env.best

memory = ConditionalCount()

# Entropy of a uniform distribution over the arms, i.e. ln(num_actions) nats.
# Used as the fallback E value (see the substitution inside the loop).
default_info_value = entropy(np.ones(num_actions) / num_actions)
E_t = default_info_value

# Init vars
num_best = 0
total_R = 0.0
total_E = 0.0
scores_R = []
scores_E = []
actions = []
p_bests = []
visited_states = set()

# Run the exps
for n in range(num_episodes):
    # Reset each episode; the returned state is never read afterwards.
    state = int(env.reset()[0])

    # Uniform random policy.
    action = np.random.randint(num_actions)
    if action == best_action:
        num_best += 1

    visited_states.add(action)  # Action is state here

    # Pull a lever.
    state, reward, _, _ = env.step(action)
    state = int(state[0])
    R_t = reward  # Notation consistency

    # Est E: compare the outcome distribution over rewards {0, 1} for this
    # arm before and after recording the new reward.
    cond_sample = [action, action]
    state_sample = [0, 1]

    p_old = memory.probs(state_sample, cond_sample)
    memory.update(reward, action)
    p_new = memory.probs(state_sample, cond_sample)

    info = information_value(p_new, p_old)
    # NOTE(review): this fires both on an arm's first pull (information_value
    # returns 0.0 when p_old has no mass) and whenever the update leaves the
    # distribution unchanged (presumably a repeated reward), so E jumps to the
    # default whenever nothing new was learned -- confirm this is intended.
    if np.isclose(info, 0.0):
        info = default_info_value

    E_t = info

    # Log data
    actions.append(action)
    total_R += R_t
    total_E += E_t
    scores_E.append(E_t)
    scores_R.append(R_t)
    p_bests.append(num_best / (n + 1))

In [40]:
# Outcome counts held by the ConditionalCount memory after the run: a list of
# OrderedDicts mapping reward value -> count (presumably one entry per arm,
# since updates are conditioned on the action -- confirm).
memory.counts

[OrderedDict([(0, 6), (1, 1)]),
 OrderedDict([(0, 7), (1, 3)]),
 OrderedDict([(0, 8), (1, 1)]),
 OrderedDict([(0, 8), (1, 1)]),
 OrderedDict([(1, 8), (0, 2)]),
 OrderedDict([(0, 9), (1, 2)]),
 OrderedDict([(0, 5), (1, 1)]),
 OrderedDict([(0, 14), (1, 2)]),
 OrderedDict([(0, 11), (1, 1)]),
 OrderedDict([(0, 8), (1, 2)])]

In [41]:
# Per-episode information value E_t in nats (natural-log KL divergence).
# Entries equal to default_info_value (ln(num_actions)) are pulls where the
# computed information value was ~0 and the default was substituted.
scores_E

[2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 0.6931471805599453,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 0.05889151782819174,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 0.4054651081081644,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 0.01737200037967128,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 2.3025850929940455,
 0.007381996975374061,
 0.056633012265132426,
 2.3025850929940455,
 2.3025850929940455,
 0.28768207245178085,
 2.3025850929940455,
 2.3025850929940455,
 0.22314355131420976,
 2.3025850929940455,
 2.3025850929940455,
 0.6931471805599453,
 0.05889151782819174,
 2.3025850929940455,
 0.043692120681965735,
 2.3025850929940455,
 0.15415067982725836,
 0.0014022384868726222,
 0.2231435513142097