In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os, shutil

### Initialise the agent

In [2]:
from belief_tree import Tree

In [8]:
# Matrix handed to the belief tree as its first argument
# (presumably prior counts per arm — TODO confirm in belief_tree.Tree).
M = np.array([[5, 1], [1, 1]])

# gamma is passed to full_updates / replay_updates and xi to replay_updates
# in the cells below (presumably discount factor and replay threshold — confirm).
gamma, xi = 0.9, 0.0

# Initial Q-values, one float32 entry per action (2 actions here).
Q = np.full(2, 0.2, dtype=np.float32)

# NOTE(review): the third positional argument (1) is not named here; its meaning
# is defined by belief_tree.Tree (possibly a softmax temperature) — confirm.
tree = Tree(M, Q, 1, 'softmax')

### Full Bayesian updates

In [9]:
# Build the belief tree to depth `horizon` and run the full (non-replay) update
# pass named in the section header; the result is inspected in the next cell.
horizon = 5
tree.build_tree(horizon)
# gamma is presumably the discount factor — confirm in belief_tree.Tree.full_updates.
tree.full_updates(gamma)

In [10]:
# Q-values stored on the tree after full_updates. The displayed entry is
# qval_tree[0] at key (0, 0, 0) — presumably the root belief state (TODO confirm).
qval_tree = tree.qval_tree
qval_tree[0][(0, 0, 0)]

array([3.4125834, 3.081275 ], dtype=float32)

### Replay

In [11]:
# Rebuild the tree at the same depth (presumably resetting the values written by
# full_updates — confirm in belief_tree) and run replay-based updates instead.
# The returned histories are consumed by the tex-tree cell below, where
# qval_history[i] and need_history[i] are used alongside replays[:i + 1].
horizon = 5
tree.build_tree(horizon)
qval_history, need_history, replays = tree.replay_updates(gamma, xi)

In [12]:
# Same entry as displayed after full_updates, now after replay — compare the two
# outputs to see how far replay gets towards the full-update values.
qval_tree = tree.qval_tree
qval_tree[0][(0, 0, 0)]

array([2.9937387, 2.6748548], dtype=float32)

### Generate replay tree

In [None]:
# Write one .tex tree per replay step: file i is drawn from the first i+1
# replays together with qval_history[i] and need_history[i].
from tex_tree import generate_big_tex_tree

# Output folder; set TEX_TREE_DIR to redirect it on another machine. The
# default is the original hard-coded location.
save_folder = os.environ.get(
    'TEX_TREE_DIR',
    '/home/georgy/Documents/Dayan_lab/PhD/bandits/Data/example_tree/seq/',
)

# WARNING: the folder is deleted and recreated so that no stale tex_tree_*.tex
# files from an earlier (longer) run survive — point it somewhere disposable.
if os.path.exists(save_folder):
    shutil.rmtree(save_folder)
os.makedirs(save_folder)

# NOTE: `horizon` is read from the kernel and must still equal the depth the
# tree was built with in the replay cell (a later cell rebinds it in a loop).
for idx in range(len(replays)):
    save_path = os.path.join(save_folder, 'tex_tree_%d.tex' % idx)
    generate_big_tex_tree(horizon, replays[:idx + 1], qval_history[idx], need_history[idx], save_path)

### Asymmetry of Beta-distribution updates

In [None]:
from scipy.stats import beta
from scipy.special import kl_div

In [None]:
# Evaluation grid: 100 evenly spaced points from 0.001 to 1 (inclusive).
x = np.linspace(0.001, 1, 100)

# Beta(alpha, beta) parameters for the three panels; alpha grows by one each step.
# Unpacking rows of an integer array keeps the parameters as numpy integers.
(a1b, b1b), (a1a, b1a), (a1aa, b1aa) = np.array([[2, 1], [3, 1], [4, 1]])

# Frozen scipy Beta distributions, in the same order as the parameter pairs.
rv1b, rv1a, rv1aa = (beta(a, b) for a, b in ((a1b, b1b), (a1a, b1a), (a1aa, b1aa)))

In [None]:
def kl_on_grid(p_vals, q_vals, grid):
    """Approximate KL(p || q) = integral of p * log(p / q) dx from sampled densities.

    Parameters
    ----------
    p_vals, q_vals : array_like
        Densities of p and q evaluated at the points of `grid`.
    grid : array_like
        Increasing sample points (need not be uniform).

    Returns
    -------
    float
        Trapezoidal-rule estimate of the KL divergence.

    Notes
    -----
    scipy.special.kl_div is elementwise (p*log(p/q) - p + q), so it has to be
    integrated over x; the -p + q terms integrate to ~0 for two densities. A
    plain sum over grid points gives roughly KL / dx instead (about 99x too
    large on a 100-point grid over (0, 1]).
    """
    integrand = kl_div(p_vals, q_vals)
    return np.sum(0.5 * (integrand[1:] + integrand[:-1]) * np.diff(grid))


beta_rvs = [rv1b, rv1a, rv1aa]
titles = [r'$\alpha=2, \beta=1$', r'$\alpha=3, \beta=1$', r'$\alpha=4, \beta=1$']

fig, axes = plt.subplots(1, len(beta_rvs), figsize=(10, 3))
for ax, rv, title in zip(axes, beta_rvs, titles):
    ax.plot(x, rv.pdf(x))
    ax.set_title(title)
    ax.set_xlabel('x')
axes[0].set_ylabel('density')

# KL between neighbouring panels. Closed form for Beta(a, 1) vs Beta(b, 1) is
# log(a / b) + (b - a) / a, i.e. ~0.0945 for (2 -> 3) and ~0.0457 for (3 -> 4).
for p_rv, q_rv in zip(beta_rvs[:-1], beta_rvs[1:]):
    print(kl_on_grid(p_rv.pdf(x), q_rv.pdf(x), x))

fig.tight_layout()

# plt.savefig('/home/georgy/Documents/Dayan_lab/PhD/bandits/Data/betas.png')

### Convergence of replay values

In [None]:
for horizon in horizons: