In [16]:
from pymdp.agent import Agent
from pymdp import utils, maths, control
from pymdp.envs import SceneConstruction, RandomDotMotion, initialize_scene_construction_GM, initialize_RDM_GM
import numpy as np
from copy import deepcopy

#### Initialize the parameters of the high-level POMDP and create a high level agent using the generative model parameters

T_high = 6
reward = 2.0
punishment = -4.0
urgency = -4.0
params_high, mappings_high, dimensions_high = initialize_scene_construction_GM(T=T_high, reward=reward, punishment=punishment, urgency=urgency)
agent_high = Agent(**params_high)

#### Initialize the high-level environment

scene_name = "UP_RIGHT" # options are "UP_RIGHT", "RIGHT_DOWN", "DOWN_LEFT", and "LEFT_UP"
config = "1_2"          # options are all combinations of (1,2,3,4) in pairs of twos in all orders, but no repeats (e.g. "1_2", "2_1", "1_3", "3_1", "1_4", ...)
env_high_level = SceneConstruction(scene_name = scene_name, config = config)
what_obs_h, where_obs_h = env_high_level.reset()

#### Initialize the parameters of the low-level POMDP and create a low-level agent using the generative model parameters

T_low = 16
p_low = 1.0
urgency = 0.001
params_low, mappings_low, dimensions_low = initialize_RDM_GM(T=T_low, A_precis=p_low, break_reward=urgency)
agent_low = Agent(**params_low)

#### initialize the low-level environment

true_direction = "null"
dot_precision = 1.0
env_low_level = RandomDotMotion(precision=dot_precision, dot_direction=true_direction)
what_obs_l, where_obs_l = env_low_level.reset()

#### Write some "linking" functions that link inputs and outputs from different levels to eachother

def get_prior_from_above(beliefs_high):

    expected_obs = maths.spm_dot(params_high['A'][0], beliefs_high) # map `q_states` through the likelihood mapping to make a prior over the hidden states of the lower-level agent
    empirical_prior = utils.obj_array_zeros(dimensions_low['num_states']) # make an empirical prior over hidden states at the low-level
    empirical_prior[0] = expected_obs[:5]
    empirical_prior[1] = utils.onehot(0, dimensions_low['num_states'][1])

    return empirical_prior

def get_obs_from_below(beliefs_low, where_obs_h):

    obs_h = utils.obj_array_zeros(dimensions_high['num_obs'])
    obs_h[0][:5] = beliefs_low[0]

    where_obs_h_idx = mappings_high['where_obs_names'].index(where_obs_h)
    obs_h[1] = utils.onehot(where_obs_h_idx, dimensions_high['num_obs'][1])

    return obs_h

#### Write some functions that map between semantic (e.g. "UP", "2", "choose_UP_RIGHT") and discrete ordinal (0, 1, 2, ..) labels for observations

# For debugging - add print statements to see what's happening
def label_to_indices(mappings, what_obs, where_obs, dimensions):
    """ Maps from two strings `what_obs` and `where_obs` to their corresponding observation index """
    # Convert labels to indices 
    what_idx = int(mappings['what_obs_names'].index(what_obs))
    where_idx = int(mappings['where_obs_names'].index(where_obs))

    # If this is for the low level (checking by dimension), return tuple of indices
    if len(mappings['what_obs_names']) == 5:  # Low level has 5 what_obs states
        return (what_idx, where_idx)
    
    # Otherwise this is high level, return object array with one-hot vectors
    obs_h = utils.obj_array_zeros(dimensions['num_obs']) 
    obs_h[0] = utils.onehot(what_idx, len(mappings['what_obs_names']))
    obs_h[1] = utils.onehot(where_idx, len(mappings['where_obs_names']))
    return obs_h

def indices_to_label(mappings, what_obs_idx, where_obs_idx):
    """ Maps from two indices `what_obs_idx` and `where_obs_idx` to their corresponding observation label """
    what_obs = mappings['what_obs_names'][what_obs_idx]
    where_obs = mappings['where_obs_names'][where_obs_idx]
    return what_obs, where_obs

#### Set up initial high and low level observations

obs_h = label_to_indices(mappings_high, what_obs_h, where_obs_h, dimensions_high)
qs_high = agent_high.D

#### Create some variables to store the history of choices

scene_beliefs_high = np.zeros((T_high, dimensions_high['num_states'][0]))
direction_beliefs_low = np.zeros((T_high, T_low, dimensions_low['num_states'][0]))
search_choices_high = []
sampling_prob_low = np.zeros((T_high, T_low, 2))

#### Hierarchical active inference loop

for t_h in range(T_high):

    empirical_prior = get_prior_from_above(beliefs_high=qs_high)
    agent_low.reset(init_qs=empirical_prior)

    # reset the low-level process
    sampling_action = 'sample'
    t_s = 0
    if where_obs_h in ['1','2','3','4']:
        what_obs_l, where_obs_l = env_low_level.reset(dot_direction = what_obs_h, sampling_state = sampling_action)
    else:
        what_obs_l, where_obs_l = env_low_level.reset(dot_direction = 'null', sampling_state = sampling_action)

    obs_l = label_to_indices(mappings_low, what_obs_l, where_obs_l, dimensions_low)

    while (t_s < T_low) and (sampling_action == 'sample'):
        qs_low = agent_low.infer_states(obs_l)
        direction_beliefs_low[t_h, t_s, :] = qs_low[0].copy()
        q_pi, _ = agent_low.infer_policies()
        sampling_prob_low[t_h, t_s, :] = q_pi.copy()
        chosen_action_id = agent_low.sample_action()
        sampling_action = mappings_low['action_names'][int(chosen_action_id[1])]
        what_obs_l, where_obs_l = env_low_level.step(sampling_action)
        obs_l = label_to_indices(mappings_low, what_obs_l, where_obs_l, dimensions_low)
        t_s += 1 
    
    if where_obs_h in ['1','2','3','4']:
        obs_h = get_obs_from_below(beliefs_low=qs_low, where_obs_h=where_obs_h)
        use_dist = True
    else:
        obs_h = label_to_indices(mappings_high, what_obs_h, where_obs_h, dimensions_high)
        use_dist = False

    qs_high = agent_high.infer_states(obs_h, distr_obs = use_dist)
    scene_beliefs_high[t_h,:] = qs_high[0].copy()
    agent_high.infer_policies()
    chosen_action_id = agent_high.sample_action()
    movement_id = int(chosen_action_id[1])
    search_action = mappings_high['action_names'][movement_id]
    what_obs_h, where_obs_h = env_high_level.step(search_action)

    # get expected states over the next location, given the action just taken
    qs_high = control.get_expected_states(qs_high, agent_high.B, chosen_action_id.reshape(1,-1))[0]

    # append last saccade action to history of choices
    search_choices_high.append(search_action)
      

#### Now do some plotting of the hierarchical agent's inference over time

from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
def plot_beliefs_hierarchical(beliefs_low, beliefs_high, t_high_to_show):
  fig = plt.figure(tight_layout=True,figsize=(10,12))

  num_sub_t_step = len(t_high_to_show)
  gs = gridspec.GridSpec(num_sub_t_step, num_sub_t_step)

  T_high = beliefs_high.shape[0]
  beliefs_high = beliefs_high.reshape(beliefs_high.shape[0], 4, 12).sum(2) # average across configurations
  ax_top = fig.add_subplot(gs[0, :])
  imdata = ax_top.imshow(beliefs_high.T, clim = (0.0, 1.0))
  ax_top.set_xticks(np.arange(T_high))
  ax_top.set_yticks(np.arange(4))
  ax_top.set_yticklabels(labels = mappings_high['scene_names'], rotation=45)
  fig.colorbar(imdata, ax=ax_top)
  
  for t_h in t_high_to_show:
    ax_top.add_patch(Rectangle((t_h-0.5, -0.5), 1, 5, edgecolor='red', fill=False, lw=3))
    ax_top.text(t_h-0.25, 2.0, f't = {t_h}', fontsize=12.5, color = 'white', rotation = 30)

  for (i,t_h) in enumerate(t_high_to_show):
    ax = fig.add_subplot(gs[1, i])

    ax.plot(beliefs_low[t_h,:,1], label = '$P(Up)$')
    ax.plot(beliefs_low[t_h,:,2], label = '$P(Right)$')
    ax.plot(beliefs_low[t_h,:,3], label = '$P(Down)$')
    ax.plot(beliefs_low[t_h,:,4], label = '$P(Left)$')
    ax.set_xlim(0, beliefs_low.shape[1])
    ax.set_ylim(0, 1.0)
    ax.set_ylabel('Inferred probability of motion')
    ax.legend(fontsize=12)
    ax.set_title('$T_h = $' + f'{t_h}')

plot_beliefs_hierarchical(direction_beliefs_low, scene_beliefs_high, [0, 1, 2])

Starting location is start, Scene is UP_RIGHT, Configuration is 1_2

Re-initialized location to Start location
True motion direction is null, motion coherence is 0.0

Qs:[array([1.e+00, 1.e-16, 1.e-16, 1.e-16, 1.e-16])
 array([1.00000000e+00, 1.00000147e-32])]
G: [-2.30308522 -2.30208522]
q_pi: [0.49600009 0.50399991]


IndexError: arrays used as indices must be of integer (or boolean) type