In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os, shutil

### Initialise the agent

In [None]:
def policy(q_values, temp=None, policy_type='softmax'):
    '''
    ----
    Agent's policy: map the q values at a state to per-action weights.

    q_values    -- q values at the current state (array-like, one entry per action)
    temp        -- inverse temperature used by the softmax; None means 1.
                   0 is honoured (uniform softmax) rather than treated as "unset".
    policy_type -- 'softmax' or 'greedy'

    Returns
    -------
    np.ndarray with one entry per action.
      * all q values zero  -> uniform distribution (for either policy type)
      * 'softmax'          -> softmax probabilities of q_values * temp
      * 'greedy'           -> 0/1 indicator of the maximal action(s);
                              ties are NOT normalised

    Raises
    ------
    ValueError -- if q_values is empty
    KeyError   -- if policy_type is not 'softmax' or 'greedy'
    ----
    '''

    # Accept lists too: with a plain list, `q_values * t` would repeat the list
    # instead of scaling it.
    q_values = np.asarray(q_values)

    if q_values.size == 0:
        raise ValueError('q_values must contain at least one action')

    if np.all(q_values == 0):
        # No preference yet: uniform over however many actions there are
        return np.full(q_values.shape, 1.0 / q_values.size)

    # `is None` (not truthiness) so that an inverse temperature of 0 is respected
    t = 1 if temp is None else temp

    if policy_type == 'softmax':
        logits = q_values * t
        # Shifting by the max is mathematically a no-op for softmax but keeps
        # np.exp from overflowing to inf (and the result from becoming nan).
        weights = np.exp(logits - logits.max())
        return weights / weights.sum()
    elif policy_type == 'greedy':
        return (q_values >= q_values.max()).astype(int)
    else:
        raise KeyError('Unknown policy type')

In [None]:
from belief_tree import Tree

In [None]:
# Per-arm parameters; the convergence section below labels them in its plot
# titles as row i = (alpha_i, beta_i).
M = np.array([[3, 1],
              [1, 1]])

gamma = 0.9   # presumably the discount factor; passed to full_updates / replay_updates
xi    = 0.0   # passed to replay_updates — TODO(review): document its meaning

# Initial root q values, one per arm
Q = np.zeros(2, dtype=np.float32)

# NOTE(review): the positional 1 and 'softmax' look like the inverse temperature
# and policy type (cf. policy() above) — confirm against belief_tree.Tree.
tree = Tree(M, Q, 1, 'softmax')

### Full Bayesian updates

In [None]:
horizon = 5
tree.build_tree(horizon)
# Start from the initial root q values, as the convergence section does before
# each run. The same `tree` is shared with the replay cell, so without this
# reset a re-run could start from values left by a previous update.
tree.root_q_values = Q.copy()
tree.full_updates(gamma)

In [None]:
# Root q values after the full Bayesian updates (depth 0, key (0, 0, 0))
qval_tree = tree.qval_tree
qval_tree[0][0, 0, 0]

### Replay

In [None]:
horizon = 5
tree.build_tree(horizon)
# Reset the root q values, as the convergence section does, so replay does not
# start from values left behind by the full Bayesian updates on the same `tree`.
tree.root_q_values = Q.copy()
qval_history, need_history, replays = tree.replay_updates(gamma, xi)

In [None]:
# Root q values after replay (depth 0, key (0, 0, 0))
qval_tree = tree.qval_tree
qval_tree[0][0, 0, 0]

### Generate replay-tree figures (one TeX file per replay step)

In [None]:
from tex_tree import generate_big_tex_tree

# NOTE(review): absolute, user-specific path — adjust to your own data directory.
save_folder = '/home/georgy/Documents/Dayan_lab/PhD/bandits/Data/example_tree/seq/'

# Start from an empty folder so .tex files from an earlier (longer) run
# do not linger next to the new ones.
if os.path.exists(save_folder):
    shutil.rmtree(save_folder)
os.makedirs(save_folder, exist_ok=True)

# One .tex file per replay: file `idx` draws the first idx+1 replays together
# with qval_history[idx] / need_history[idx] (presumably the state recorded at
# that replay step — confirm against Tree.replay_updates).
for idx in range(len(replays)):
    save_path = os.path.join(save_folder, 'tex_tree_%u.tex' % idx)
    generate_big_tex_tree(horizon, replays[:idx + 1], qval_history[idx], need_history[idx], save_path)

### Asymmetry of successive Beta-distribution updates

In [None]:
from scipy.stats import beta
from scipy.special import kl_div

In [None]:
# Evaluation grid on [0.001, 1] (100 points)
x = np.linspace(0.001, 1, 100)

# (alpha, beta) of three successive Beta distributions:
# alpha goes 2 -> 3 -> 4 while beta stays at 1.
(a1b, b1b), (a1a, b1a), (a1aa, b1aa) = np.array([[2, 1],
                                                  [3, 1],
                                                  [4, 1]])

# Frozen scipy.stats Beta distributions for those parameter pairs
rv1b, rv1a, rv1aa = (beta(a1b, b1b),
                     beta(a1a, b1a),
                     beta(a1aa, b1aa))

In [None]:
# Plot the three Beta densities side by side and print a grid estimate of the
# KL divergence between consecutive ones, to compare the size of each update.
dists = (rv1b, rv1a, rv1aa)

fig, axes = plt.subplots(1, 3, figsize=(10, 3))
for ax, rv in zip(axes, dists):
    ax.plot(x, rv.pdf(x))
    # Title built from the distribution's own parameters so it cannot drift
    ax.set_title(r'$\alpha=%u, \beta=%u$' % rv.args)

# scipy.special.kl_div is elementwise p*log(p/q) - p + q; its integral over x
# equals KL(p || q) when p and q are normalised densities. A sum over the grid
# therefore has to be scaled by the grid spacing to approximate that integral
# (the bare sum is inflated by 1/dx, about 99 for this grid).
dx = x[1] - x[0]
for p_rv, q_rv in zip(dists[:-1], dists[1:]):
    kl_estimate = np.sum(kl_div(p_rv.pdf(x), q_rv.pdf(x))) * dx
    print('KL(Beta(%u, %u) || Beta(%u, %u)) ~ %.4f' % (*p_rv.args, *q_rv.args, kl_estimate))

fig.tight_layout()

# fig.savefig('/home/georgy/Documents/Dayan_lab/PhD/bandits/Data/betas.png')

### Convergence of replay values

In [None]:
# NOTE(review): this path uses 'data' (lower case), while the replay-tree export
# writes under '.../bandits/Data/...'. On a case-sensitive filesystem these are
# different directories — confirm which one is intended.
save_folder = '/home/georgy/Documents/Dayan_lab/PhD/bandits/data/convergence/root_greedy'
os.makedirs(save_folder, exist_ok=True)  # savefig fails if the folder is missing

horizons = np.arange(2, 7)

# Identical for every prior, so set once outside the loop
gamma = 0.9
xi    = 0.0

for alpha_0 in range(1, 13):

    # Root value per horizon: column 0 = full Bayesian, column 1 = replay
    vals = np.zeros((len(horizons), 2))

    # Row i = (alpha_i, beta_i), as labelled in the plot title below
    M = np.array([
        [alpha_0, 1],
        [1, 1]
    ])

    Q = np.array([0.0, 0.0], dtype=np.float32)

    tree = Tree(M, Q, 1, 'softmax')

    # enumerate instead of `horizon - horizons[0]`, which only works for a
    # contiguous, step-1 range of horizons
    for h_idx, horizon in enumerate(horizons):
        tree.build_tree(horizon)

        # Full Bayesian updates from zero root q values
        tree.root_q_values = np.array([0.0, 0.0], dtype=np.float32)
        tree.full_updates(gamma)
        qval_tree = tree.qval_tree
        qvals     = qval_tree[0][(0, 0, 0)]
        print('Horizon %u'%horizon)
        print('Full bayesian ', qvals)
        vals[h_idx, 0] = np.max(qvals)

        # Replay, again from zero root q values so both methods start alike
        tree.root_q_values = np.array([0.0, 0.0], dtype=np.float32)
        qval_history, need_history, replays = tree.replay_updates(gamma, xi)
        qval_tree = tree.qval_tree
        qvals     = qval_tree[0][(0, 0, 0)]
        print('Replay ', qvals, '\n')
        vals[h_idx, 1] = np.max(qvals)

    fig, ax = plt.subplots(figsize=(8, 5), dpi=100, constrained_layout=True)

    # Plot against `horizons` directly: the original rebound the global `x`,
    # clobbering the Beta evaluation grid used by the cells above.
    ax.scatter(horizons, vals[:, 0], label='Bayes-optimal')
    ax.scatter(horizons, vals[:, 1], label='Replay with ' + r'$\xi = %.2f$'%xi)
    ax.set_title(r'$\alpha_0=%u , \beta_0=%u, \alpha_1=%u, \beta_1=%u$'%(M[0, 0], M[0, 1], M[1, 0], M[1, 1]), fontsize=14)
    ax.set_xlabel('Horizon', fontsize=13)
    ax.set_ylabel('Root value', fontsize=13)
    ax.legend(prop={'size':13})

    fig.savefig(os.path.join(save_folder, '%u.png' % alpha_0))
    # Close each figure so they do not all accumulate in memory and render at
    # the end of the cell. The saved PNGs are tiled by the montage cell below.
    plt.close(fig)


In [None]:
import matplotlib.image as mpimg

In [None]:
# Tile the 12 per-prior convergence plots (1.png ... 12.png) into one 4x3 figure
fig, axes = plt.subplots(4, 3, figsize=(18, 15), constrained_layout=True, dpi=100)
for panel, ax in enumerate(axes.flat, start=1):
    ax.imshow(mpimg.imread(os.path.join(save_folder, '%u.png' % panel)))
    ax.axis('off')

fig.savefig(os.path.join(save_folder, 'all.png'))