In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
# Global matplotlib defaults for every figure in this notebook.
_rc_overrides = {
    # figure size, base font and line width
    'figure.figsize': (4.8, 2.7),
    'font.size': 15,
    'lines.linewidth': 2,
    # smaller tick labels
    'xtick.labelsize': 'small',
    'ytick.labelsize': 'small',
    # hide the top/right axis spines
    'axes.spines.top': False,
    'axes.spines.right': False,
    # resolution used when saving figures
    'savefig.dpi': 100,
}
plt.rcParams.update(_rc_overrides)

import numpy as np
from pathlib import Path
import os, torch

# Directory for saved figures; created up front so later cells can write into it.
FIG_PATH = Path('figures')
FIG_PATH.mkdir(parents=True, exist_ok=True)
# Root directory for input data and cached results.
STORE_PATH = Path('store')
# Unseeded generator: every kernel session draws a fresh random stream.
rng = np.random.default_rng()

In [2]:
from hexarena.utils import get_valid_blocks, load_monkey_data, align_monkey_data

# Select sessions/blocks with enough valid position and gaze samples.
filename = STORE_PATH/'foraging_masterTable/data_viktor.mat'
meta = get_valid_blocks(filename, min_pos_ratio=0.5, min_gaze_ratio=0.1)
# `meta` maps each session id to its collection of valid block indices.
num_blocks = sum(len(block_idxs) for block_idxs in meta.values())
print(f'{num_blocks} blocks across {len(meta)} sessions found')

36 blocks across 7 sessions found


In [3]:
# Keep only blocks whose boxes use exactly the tau values modelled by `env` below.
REQUIRED_TAUS = {7., 14., 21.}

block_ids = []
for session_id in meta:
    for block_idx in meta[session_id]:
        block_data = load_monkey_data(filename, session_id, block_idx)
        if set(block_data['taus']) == REQUIRED_TAUS:
            block_ids.append((session_id, block_idx))
block_ids = sorted(block_ids)
print(f'{len(block_ids)} blocks collected')

36 blocks of collected


In [4]:
from hexarena.env import SimilarBoxForagingEnv

# Foraging environment with one single-patch gamma-linear box per tau value;
# all boxes share the same template and differ only in `tau`.
box_taus = (21, 14, 7)
env = SimilarBoxForagingEnv(
    box=dict(_target_='hexarena.box.GammaLinearBox', num_patches=1, num_levels=40),
    boxes=[dict(tau=tau) for tau in box_taus],
)

In [5]:
from irc.model import SamplingBeliefModel

# Parameterization spec for the box-state distribution; the same dict object is
# reused for each of the three boxes.
phi = {'embedder._target_': 'hexarena.box.LinearBoxStateEmbedder'}
phi['mlp_features'] = [6]

model = SamplingBeliefModel(env, p_s={'phis': [phi, phi, phi]})

def observation_weight(obs_drawn, obs_actual, env, pos_scale=0.5, gaze_scale=0.5, grade_scale=0.8):
    """Similarity weight between a sampled observation and the actual one.

    Observation layout, as indexed below: ``[0]`` position anchor index, ``[1]`` gaze
    anchor index, ``[2:2+env.num_boxes]`` one grade per box, ``[-1]`` a component that
    must match exactly.

    The weight is a product of per-component kernels:
      * position / gaze: ``exp(-|dxy|^2 / scale^2)`` on the anchor coordinates;
      * box grade: if either grade equals ``num_grades`` the two must be equal,
        otherwise ``exp(-|dgrade| / (grade_scale * num_grades))``.

    Args:
        obs_drawn: Sampled observation.
        obs_actual: Actual observation.
        env: Environment providing ``arena.anchors``, ``num_boxes`` and ``boxes``.
        pos_scale: Length scale of the position kernel.
        gaze_scale: Length scale of the gaze kernel.
        grade_scale: Grade-kernel scale, as a fraction of ``num_grades``.

    Returns:
        Non-negative weight; 0 as soon as any hard-match component disagrees.
    """
    if obs_drawn[-1] != obs_actual[-1]:
        return 0.
    anchors = env.arena.anchors
    dpos = np.array(anchors[obs_drawn[0]])-np.array(anchors[obs_actual[0]])
    weight = np.exp(-(dpos**2).sum()/(pos_scale**2))
    dgaze = np.array(anchors[obs_drawn[1]])-np.array(anchors[obs_actual[1]])
    weight *= np.exp(-(dgaze**2).sum()/(gaze_scale**2))
    for i in range(env.num_boxes):
        num_grades = env.boxes[i].num_grades
        # Convert before subtracting: unsigned integer observations would otherwise
        # wrap around (e.g. uint8 1-3 == 254) and give a spuriously tiny weight.
        grade_drawn, grade_actual = float(obs_drawn[i+2]), float(obs_actual[i+2])
        if grade_drawn == num_grades or grade_actual == num_grades:
            # The `num_grades` value is only compatible with itself.
            if grade_drawn != grade_actual:
                return 0.
        else:
            weight *= np.exp(-abs(grade_drawn-grade_actual)/(grade_scale*num_grades))
    return weight
def _model_weight(obs_drawn, obs_actual):
    """Score a sampled observation against the actual one using the notebook's `env`."""
    return observation_weight(obs_drawn, obs_actual, env)


model.get_weight = _model_weight

# Sampling-based estimation settings; the nested `sga_kw` progress bar is silenced.
model.use_sample = True
model.num_samples = 2000
model.estimate_kw.sga_kw.pbar_kw.disable = True

In [6]:
from jarvis.utils import tqdm

# Compute beliefs for every selected block and cache them; already-cached blocks are skipped.
save_dir = STORE_PATH/'beliefs'/'viktor'
# torch.save does not create missing directories.
save_dir.mkdir(parents=True, exist_ok=True)

# Visit blocks in random order without mutating `block_ids` (it stays sorted, so the cell
# is re-runnable). NOTE(review): the random order presumably lets concurrent runs of this
# cell start on different blocks — confirm.
visit_order = rng.permutation(len(block_ids))
for session_id, block_idx in tqdm([block_ids[i] for i in visit_order], unit='block'):
    savename = save_dir/'{}_{}.pt'.format(session_id, block_idx)
    if savename.exists():
        continue

    block_data = load_monkey_data(filename, session_id, block_idx)
    block_data = align_monkey_data(block_data)
    env_data = env.convert_experiment_data(block_data)
    observations, actions, _ = env.extract_observation_action_reward(env_data)

    knowns, beliefs, infos = model.compute_beliefs(observations, actions, pbar_kw={'leave': False})
    torch.save({
        'knowns': knowns, 'beliefs': beliefs, 'infos': infos,
        **env_data,
    }, savename)

  0%|                                                                                                         …

Compute beliefs:   0%|                                                                                        …

In [None]:
# Run belief inference on the first few steps of one block with a persistent progress bar.
# The block is loaded explicitly rather than reusing `observations`/`actions` left over from
# the caching loop, which are undefined when every block was already cached and skipped.
DEMO_NUM_ACTIONS = 6  # slices keep one more observation than actions (7 and 6 here)
demo_session_id, demo_block_idx = sorted(block_ids)[0]
demo_data = align_monkey_data(load_monkey_data(filename, demo_session_id, demo_block_idx))
observations, actions, _ = env.extract_observation_action_reward(env.convert_experiment_data(demo_data))
knowns, beliefs, infos = model.compute_beliefs(
    observations[:DEMO_NUM_ACTIONS+1], actions[:DEMO_NUM_ACTIONS], pbar_kw={'leave': True},
)

In [None]:
# Log-likelihood of every state of box 0 (`all_xs`) under each belief; one row per belief.
# The distribution object is looked up afresh on each iteration, exactly as before,
# in case `set_param_vec` replaces it.
logps = torch.stack([
    model.p_s.s_dists[0].loglikelihoods(
        model.p_s.s_dists[0].all_xs,
        model.p_s.set_param_vec(0, belief),
    )[0]
    for belief in beliefs
])

In [None]:
# Probability of each (level, timer) state of box 0 per time step, laid out as a
# (num_steps, num_levels, num_levels+1) array; entries with timer > level stay 0.
# Assumes `logps` is (num_steps, num_states), as built by the stacking cell above.
num_levels = env.boxes[0].num_levels
# Valid states: level in 1..num_levels, timer in 0..level.
level_timer_pairs = [
    (level, timer) for level in range(1, num_levels+1) for timer in range(level+1)
]
level_rows = np.array([level-1 for level, _ in level_timer_pairs], dtype=int)
timer_cols = np.array([timer for _, timer in level_timer_pairs], dtype=int)
# The flat state index of each pair does not depend on t, so look it up once and
# gather all time steps in a single indexing operation.
state_idxs = torch.tensor(
    [int(env.boxes[0]._sub2idx(level, timer)) for level, timer in level_timer_pairs],
    dtype=torch.long,
)
ps_box0 = np.zeros((len(logps), num_levels, num_levels+1))
ps_box0[:, level_rows, timer_cols] = logps[:, state_idxs].exp().detach().cpu().numpy()

In [None]:
# Animate the box-0 (level, timer) probability table over time on a shared color scale.
vmin, vmax = 0, ps_box0.max()
fig, ax = plt.subplots(figsize=(5, 5))
h = ax.imshow(ps_box0[0], vmin=vmin, vmax=vmax, cmap='Reds')
ax.set_xlabel('Timer')
ax.set_ylabel('Interval')
h_title = ax.set_title('')

def update(t):
    """Show the probability table at step `t` and label the frame."""
    h.set_data(ps_box0[t])
    h_title.set_text(f'$t$={t}')
    return h, h_title

ani = FuncAnimation(fig, update, frames=range(len(ps_box0)), blit=True)
# Close the figure so the inline backend does not also render a static copy below the animation.
plt.close(fig)

HTML(ani.to_jshtml())