diff --git a/.gitignore b/.gitignore index f94dbb2..e30d509 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__* *.ini *.xml *.iml +*.ipynb_checkpoints/ diff --git a/tools/__init__.py b/__init__.py similarity index 100% rename from tools/__init__.py rename to __init__.py diff --git a/blockference/.ipynb_checkpoints/gridference-checkpoint.py b/blockference/.ipynb_checkpoints/gridference-checkpoint.py new file mode 100644 index 0000000..0976e8d --- /dev/null +++ b/blockference/.ipynb_checkpoints/gridference-checkpoint.py @@ -0,0 +1,395 @@ +import enum # not currently used? Is enum used in another script? or can remove. +import sys +from pymdp.control import construct_policies +import pymdp.utils as u +import random as rand +import itertools +import numpy as np + +from matplotlib.pyplot import grid + +# adding tools to the system path +sys.path.insert(0, '../tools/') + + +def actinf_planning_single(agent, env_state, A, B, C, prior): + policies = construct_policies([agent.n_states], + [len(agent.E)], + policy_len=agent.policy_len) + # get obs_idx + obs_idx = grid.index(env_state) + + # infer_states + qs_current = u.infer_states(obs_idx, A, prior) + + # calc efe + G = u.calculate_G_policies(A, B, C, qs_current, policies=policies) + + # calc action posterior + Q_pi = u.softmax(-G) + + # compute the probability of each action + P_u = u.compute_prob_actions(agent.E, policies, Q_pi) + + # sample action + chosen_action = u.sample(P_u) + + # calc next prior + prior = B[:, :, chosen_action].dot(qs_current) + + # update env state + # action_label = params['actions'][chosen_action] + + (Y, X) = env_state + Y_new = Y + X_new = X + + if chosen_action == 0: # UP + + Y_new = Y - 1 if Y > 0 else Y + X_new = X + + elif chosen_action == 1: # DOWN + + Y_new = Y + 1 if Y < agent.border else Y + X_new = X + + elif chosen_action == 2: # LEFT + Y_new = Y + X_new = X - 1 if X > 0 else X + + elif chosen_action == 3: # RIGHT + Y_new = Y + X_new = X + 1 if X < agent.border else X + + elif chosen_action == 4: # STAY + Y_new, X_new = Y, X + + current_state = (Y_new, X_new) # store the new grid location + + return {'update_prior': prior, + 'update_env': current_state, + 'update_action': chosen_action, + 'update_inference': qs_current} + + +def actinf_graph(agent_network): + + # list of all updates to the agents in the network + agent_updates = [] + + for agent in agent_network.nodes: + + policies = construct_policies([agent_network.nodes[agent]['agent'].n_states], [len(agent_network.nodes[agent]['agent'].E)], policy_len=agent_network.nodes[agent]['agent'].policy_len) + # get obs_idx + obs_idx = grid.index(agent_network.nodes[agent]['env_state']) + + # infer_states + qs_current = u.infer_states(obs_idx, agent_network.nodes[agent]['prior_A'], agent_network.nodes[agent]['prior'], noise=1) + + # calc efe + _G = u.calculate_G_policies(agent_network.nodes[agent]['prior_A'], agent_network.nodes[agent]['prior_B'], agent_network.nodes[agent]['prior_C'], qs_current, policies=policies) + + # calc action posterior + Q_pi = u.softmax(-_G) + # compute the probability of each action + P_u = u.compute_prob_actions(agent_network.nodes[agent]['agent'].E, policies, Q_pi) + + # sample action + chosen_action = u.sample(P_u) + + # calc next prior + prior = agent_network.nodes[agent]['prior_B'][:, :, chosen_action].dot(qs_current) + + # update env state + # action_label = params['actions'][chosen_action] + + (Y, X) = agent_network.nodes[agent]['env_state'] + Y_new = Y + X_new = X + # here + + if chosen_action == 0: # UP + + Y_new = Y - 1 if Y > 0 else Y + X_new = X + + elif chosen_action == 1: # DOWN + + Y_new = Y + 1 if Y < agent_network.nodes[agent]['agent'].border else Y + X_new = X + + elif chosen_action == 2: # LEFT + Y_new = Y + X_new = X - 1 if X > 0 else X + + elif chosen_action == 3: # RIGHT + Y_new = Y + X_new = X + 1 if X < agent_network.nodes[agent]['agent'].border else X + + elif chosen_action == 4: # STAY + Y_new, X_new = Y, X + + current_state = (Y_new, X_new) # store the new grid location + agent_update = {'source': agent, + 'update_prior': prior, + 'update_env': current_state, + 'update_action': chosen_action, + 'update_inference': qs_current} + agent_updates.append(agent_update) + + return {'agent_updates': agent_updates} + + +class GridAgent(): + def __init__(self, grid_len, num_agents, grid_dim=2) -> None: + self.grid = self.get_grid(grid_len, grid_dim) + self.grid_dim = grid_dim + self.no_actions = 2 * grid_dim + 1 + self.agents = self.init_agents(num_agents) + + def get_grid(self, grid_len, grid_dim): + g = list(itertools.product(range(grid_len), repeat=grid_dim)) + for i, p in enumerate(g): + g[i] += (0,) + return g + + def move_grid(self, agent, chosen_action): + no_actions = 2 * self.grid_dim + state = list(agent.env_state) + new_state = state.copy() + + # here + + if chosen_action == 0: # STAY + new_state = state + else: + if chosen_action % 2 == 1: + index = (chosen_action+1) / 2 + new_state[index] = state[index] - 1 if state[index] > 0 else state[index] + elif chosen_action % 2 == 0: + index = chosen_action / 2 + new_state[index] = state[index] + 1 if state[index] < agent.border else state[index] + return new_state + + def init_agents(self, no_agents): + # create a dict of agents + agents = {} + + for a in range(no_agents): + # create new agent + agent = ActiveGridference(self.grid) + # generate target state + target = (rand.randint(0, 9), rand.randint(0, 9)) + # add target state + agent.get_C(target + (0,)) + # all agents start in the same position + start = (rand.randint(0, 9), rand.randint(0, 9)) + agent.get_D(start + (1,)) + + agents[a] = agent + + return agents + + def actinf_dict(self, agents_dict, g_agent): + # list of all updates to the agents in the network + agent_updates = [] + + for source, agent in agents_dict.items(): + + policies = construct_policies([agent.n_states], [len(agent.E)], policy_len=agent.policy_len) + # get obs_idx + obs_idx = g_agent.grid.index(agent.env_state) + + # infer_states + qs_current = u.infer_states(obs_idx, agent.A, agent.prior) + + # calc efe + _G = u.calculate_G_policies(agent.A, agent.B, agent.C, qs_current, policies=policies) + + # calc action posterior + Q_pi = u.softmax(-_G) + # compute the probability of each action + P_u = u.compute_prob_actions(agent.E, policies, Q_pi) + + # sample action + chosen_action = u.sample(P_u) + + # calc next prior + prior = agent.B[:, :, chosen_action].dot(qs_current) + + # update env state + # action_label = params['actions'][chosen_action] + + current_state = self.move_2d(agent, chosen_action) # store the new grid location + agent_update = {'source': source, + 'update_prior': prior, + 'update_env': current_state, + 'update_action': chosen_action, + 'update_inference': qs_current} + agent_updates.append(agent_update) + + return {'agent_updates': agent_updates} + + def move_2d(self, agent, chosen_action): + (Y, X) = agent.env_state + Y_new = Y + X_new = X + # here + + if chosen_action == 0: # UP + + Y_new = Y - 1 if Y > 0 else Y + X_new = X + + elif chosen_action == 1: # DOWN + + Y_new = Y + 1 if Y < agent.border else Y + X_new = X + + elif chosen_action == 2: # LEFT + Y_new = Y + X_new = X - 1 if X > 0 else X + + elif chosen_action == 3: # RIGHT + Y_new = Y + X_new = X + 1 if X < agent.border else X + + elif chosen_action == 4: # STAY + Y_new, X_new = Y, X + + return (X_new, Y_new) + + def move_3d(self, agent, chosen_action): + (Y, X, Z) = agent.env_state + Y_new = Y + X_new = X + Z_new = Z + # here + + if chosen_action == 0: # UP + + Y_new = Y - 1 if Y > 0 else Y + X_new = X + Z_new = Z + + elif chosen_action == 1: # DOWN + + Y_new = Y + 1 if Y < agent.border else Y + X_new = X + Z_new = Z + + elif chosen_action == 2: # LEFT + Y_new = Y + X_new = X - 1 if X > 0 else X + Z_new = Z + + elif chosen_action == 3: # RIGHT + Y_new = Y + X_new = X + 1 if X < agent.border else X + Z_new = Z + + elif chosen_action == 4: # IN + X_new = X + Y_new = Y + Z_new = Z + 1 if Z < agent.border else Z + + elif chosen_action == 5: # OUT + X_new = X + Y_new = Y + Z_new = Z - 1 if Z > agent.border else Z + + elif chosen_action == 6: # STAY + Y_new, X_new, Z_new = Y, X, Z + + return (X_new, Y_new, Z_new) + +class ActiveGridference(): + """ + The ActiveInference class is to be used to create a generative model to be used in cadCAD simulations. + The current focus is on discrete spaces. + ------------------------------------------------------ + An actinf generative model consists of the following: + + - (state matrix) A -> the generative model's prior beliefs about how hidden states relate to observations + - (state-transition matrix) B -> the generative model's prior beliefs about controllable transitions between hidden states over time + - (preference matrix) C -> the biased generative model's prior preference for particular observations encoded in terms of probabilities + - (initial state) D -> the generative model's prior belief over hidden states at the first timestep + - (affordances) E -> the generative model's available actions + """ + def __init__(self, grid, planning_length: int = 2, env_state: tuple = (0, 0)) -> None: + super().__init__() + self.A = None + self.B = None + self.C = None + self.D = None + self.E = ["UP", "DOWN", "LEFT", "RIGHT", "STAY"] + self.grid = grid + + self.policy_len = planning_length + + # environment + self.n_states = len(self.grid) + self.n_observations = len(self.grid) + self.border = np.sqrt(self.n_states) - 1 + + # active + self.prior = self.D + self.current_action = '' + self.current_inference = '' + self.env_state = env_state + + if self.grid is not None: + self.get_A() + self.get_B() + + def get_A(self): + """ + State Matrix (identity matrix) + Params: + - n_observations: int: number of possible observations + - n_states: int: number of possible states + """ + self.A = np.eye(self.n_observations, self.n_states) + + def get_B(self): + """State-Transition Matrix""" + self.B = np.zeros((len(self.grid), len(self.grid), len(self.E))) + + for action_id, action_label in enumerate(self.E): + + for curr_state, grid_location in enumerate(self.grid): + + y, x = grid_location + + if action_label == "UP": + next_y = y - 1 if y > 0 else y + next_x = x + elif action_label == "DOWN": + next_y = y + 1 if y < self.border else y + next_x = x + elif action_label == "LEFT": + next_x = x - 1 if x > 0 else x + next_y = y + elif action_label == "RIGHT": + next_x = x + 1 if x < self.border else x + next_y = y + elif action_label == "STAY": + next_x = x + next_y = y + new_location = (next_y, next_x) + next_state = self.grid.index(new_location) + self.B[next_state, curr_state, action_id] = 1.0 + + def get_C(self, preferred_state: tuple): + """Target Location (preferences)""" + self.C = u.onehot(self.grid.index(preferred_state), self.n_observations) + + def get_D(self, initial_state): + """Initial State""" + self.D = u.onehot(self.grid.index(initial_state), self.n_states) + self.prior = self.D + + def get_E(self, actions: list): + self.E = actions + \ No newline at end of file diff --git a/blockference/gridference.py b/blockference/gridference.py index d768dc7..0976e8d 100644 --- a/blockference/gridference.py +++ b/blockference/gridference.py @@ -1,10 +1,10 @@ import enum # not currently used? Is enum used in another script? or can remove. import sys -from tools.model import ActiveGridference -from tools.control import construct_policies -import tools.utils as u +from pymdp.control import construct_policies +import pymdp.utils as u import random as rand import itertools +import numpy as np from matplotlib.pyplot import grid @@ -302,4 +302,94 @@ def move_3d(self, agent, chosen_action): elif chosen_action == 6: # STAY Y_new, X_new, Z_new = Y, X, Z - return (X_new, Y_new, Z_new) \ No newline at end of file + return (X_new, Y_new, Z_new) + +class ActiveGridference(): + """ + The ActiveInference class is to be used to create a generative model to be used in cadCAD simulations. + The current focus is on discrete spaces. + ------------------------------------------------------ + An actinf generative model consists of the following: + + - (state matrix) A -> the generative model's prior beliefs about how hidden states relate to observations + - (state-transition matrix) B -> the generative model's prior beliefs about controllable transitions between hidden states over time + - (preference matrix) C -> the biased generative model's prior preference for particular observations encoded in terms of probabilities + - (initial state) D -> the generative model's prior belief over hidden states at the first timestep + - (affordances) E -> the generative model's available actions + """ + def __init__(self, grid, planning_length: int = 2, env_state: tuple = (0, 0)) -> None: + super().__init__() + self.A = None + self.B = None + self.C = None + self.D = None + self.E = ["UP", "DOWN", "LEFT", "RIGHT", "STAY"] + self.grid = grid + + self.policy_len = planning_length + + # environment + self.n_states = len(self.grid) + self.n_observations = len(self.grid) + self.border = np.sqrt(self.n_states) - 1 + + # active + self.prior = self.D + self.current_action = '' + self.current_inference = '' + self.env_state = env_state + + if self.grid is not None: + self.get_A() + self.get_B() + + def get_A(self): + """ + State Matrix (identity matrix) + Params: + - n_observations: int: number of possible observations + - n_states: int: number of possible states + """ + self.A = np.eye(self.n_observations, self.n_states) + + def get_B(self): + """State-Transition Matrix""" + self.B = np.zeros((len(self.grid), len(self.grid), len(self.E))) + + for action_id, action_label in enumerate(self.E): + + for curr_state, grid_location in enumerate(self.grid): + + y, x = grid_location + + if action_label == "UP": + next_y = y - 1 if y > 0 else y + next_x = x + elif action_label == "DOWN": + next_y = y + 1 if y < self.border else y + next_x = x + elif action_label == "LEFT": + next_x = x - 1 if x > 0 else x + next_y = y + elif action_label == "RIGHT": + next_x = x + 1 if x < self.border else x + next_y = y + elif action_label == "STAY": + next_x = x + next_y = y + new_location = (next_y, next_x) + next_state = self.grid.index(new_location) + self.B[next_state, curr_state, action_id] = 1.0 + + def get_C(self, preferred_state: tuple): + """Target Location (preferences)""" + self.C = u.onehot(self.grid.index(preferred_state), self.n_observations) + + def get_D(self, initial_state): + """Initial State""" + self.D = u.onehot(self.grid.index(initial_state), self.n_states) + self.prior = self.D + + def get_E(self, actions: list): + self.E = actions + \ No newline at end of file diff --git a/blockference/tools/.ipynb_checkpoints/__init__-checkpoint.py b/blockference/tools/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/policy.py b/blockference/tools/.ipynb_checkpoints/policy-checkpoint.py similarity index 100% rename from tools/policy.py rename to blockference/tools/.ipynb_checkpoints/policy-checkpoint.py diff --git a/tools/utils.py b/blockference/tools/.ipynb_checkpoints/utils-checkpoint.py similarity index 100% rename from tools/utils.py rename to blockference/tools/.ipynb_checkpoints/utils-checkpoint.py diff --git a/blockference/tools/__init__.py b/blockference/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/agent.py b/blockference/tools/agent.py similarity index 100% rename from tools/agent.py rename to blockference/tools/agent.py diff --git a/tools/mutual_info.py b/blockference/tools/mutual_info.py similarity index 100% rename from tools/mutual_info.py rename to blockference/tools/mutual_info.py diff --git a/blockference/tools/policy.py b/blockference/tools/policy.py new file mode 100644 index 0000000..2cafc53 --- /dev/null +++ b/blockference/tools/policy.py @@ -0,0 +1,136 @@ +import tools.utils as u +from tools.control import construct_policies + +# Policies for cadCAD actinf simulations + + +# single-agent with planning +def p_actinf(params, substep, state_history, previous_state, act, grid): + + policies = construct_policies([act.n_states], [len(act.E)], policy_len=act.policy_len) + # get obs_idx + obs_idx = grid.index(previous_state['env_state']) + + # infer_states + qs_current = u.infer_states(obs_idx, previous_state['prior_A'], previous_state['prior']) + + # calc efe + G = u.calculate_G_policies(previous_state['prior_A'], previous_state['prior_B'], previous_state['prior_C'], qs_current, policies=policies) + + # calc action posterior + Q_pi = u.softmax(-G) + + # compute the probability of each action + P_u = u.compute_prob_actions(act.E, policies, Q_pi) + + # sample action + chosen_action = u.sample(P_u) + + # calc next prior + prior = previous_state['prior_B'][:, :, chosen_action].dot(qs_current) + + # update env state + # action_label = params['actions'][chosen_action] + + (Y, X) = previous_state['env_state'] + Y_new = Y + X_new = X + + if chosen_action == 0: # UP + + Y_new = Y - 1 if Y > 0 else Y + X_new = X + + elif chosen_action == 1: # DOWN + + Y_new = Y + 1 if Y < act.border else Y + X_new = X + + elif chosen_action == 2: # LEFT + Y_new = Y + X_new = X - 1 if X > 0 else X + + elif chosen_action == 3: # RIGHT + Y_new = Y + X_new = X + 1 if X < act.border else X + + elif chosen_action == 4: # STAY + Y_new, X_new = Y, X + + current_state = (Y_new, X_new) # store the new grid location + + return {'update_prior': prior, + 'update_env': current_state, + 'update_action': chosen_action, + 'update_inference': qs_current} + + +# multi-agent (dict) gridworld +def p_actinf(params, substep, state_history, previous_state, grid): # Is this a useless re-definition from Line 8?? + # State Variables + agents = previous_state['agents'] + + # list of all updates to the agents in the network + agent_updates = [] + + for source, agent in agents.items(): + + policies = construct_policies([agent.n_states], [len(agent.E)], policy_len=agent.policy_len) + # get obs_idx + obs_idx = grid.index(agent.env_state) + + # infer_states + qs_current = u.infer_states(obs_idx, agent.A, agent.prior, 0) + + # calc efe + _G = u.calculate_G_policies(agent.A, agent.B, agent.C, qs_current, policies=policies) + + # calc action posterior + Q_pi = u.softmax(-_G, 0) + # compute the probability of each action + P_u = u.compute_prob_actions(agent.E, policies, Q_pi) + + # sample action + chosen_action = u.sample(P_u) + + # calc next prior + prior = agent.B[:, :, chosen_action].dot(qs_current) + + # update env state + # action_label = params['actions'][chosen_action] + + (Y, X) = agent.env_state + Y_new = Y + X_new = X + # here + + if chosen_action == 0: # UP + + Y_new = Y - 1 if Y > 0 else Y + X_new = X + + elif chosen_action == 1: # DOWN + + Y_new = Y + 1 if Y < agent.border else Y + X_new = X + + elif chosen_action == 2: # LEFT + Y_new = Y + X_new = X - 1 if X > 0 else X + + elif chosen_action == 3: # RIGHT + Y_new = Y + X_new = X + 1 if X < agent.border else X + + elif chosen_action == 4: # STAY + Y_new, X_new = Y, X + + current_state = (Y_new, X_new) # store the new grid location + agent_update = {'source': source, + 'update_prior': prior, + 'update_env': current_state, + 'update_action': chosen_action, + 'update_inference': qs_current} + agent_updates.append(agent_update) + + return {'agent_updates': agent_updates} diff --git a/blockference/tools/utils.py b/blockference/tools/utils.py new file mode 100644 index 0000000..1e545f7 --- /dev/null +++ b/blockference/tools/utils.py @@ -0,0 +1,804 @@ +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +import pandas as pd + +import warnings +import itertools +from pymdp.maths import spm_log_single as log_stable + +EPS_VAL = 1e-16 # global constant for use in norm_dist() + + +def softmax(dist): + """ + Computes the softmax function on a set of values + """ + + output = dist - dist.max(axis=0) + output = np.exp(output) + output = output / np.sum(output, axis=0) + return output + + +def sample(probabilities): + sample_onehot = np.random.multinomial(1, probabilities.squeeze()) + return np.where(sample_onehot == 1)[0][0] + + +def sample_obj_array(arr): + """ + Sample from set of Categorical distributions, stored in the sub-arrays of an object array + """ + + samples = [sample(arr_i) for arr_i in arr] + + return samples + + +def obj_array(num_arr): + """ + Creates a generic object array with the desired number of sub-arrays, given by `num_arr` + """ + return np.empty(num_arr, dtype=object) + + +def obj_array_zeros(shape_list): + """ + Creates a numpy object array whose sub-arrays are 1-D vectors + filled with zeros, with shapes given by shape_list[i] + """ + arr = obj_array(len(shape_list)) + for i, shape in enumerate(shape_list): + arr[i] = np.zeros(shape) + return arr + + +def obj_array_uniform(shape_list): + """ + Creates a numpy object array whose sub-arrays are uniform Categorical + distributions with shapes given by shape_list[i]. The shapes (elements of shape_list) + can either be tuples or lists. + """ + arr = obj_array(len(shape_list)) + for i, shape in enumerate(shape_list): + arr[i] = norm_dist(np.ones(shape)) + return arr + + +def obj_array_ones(shape_list, scale=1.0): + arr = obj_array(len(shape_list)) + for i, shape in enumerate(shape_list): + arr[i] = scale * np.ones(shape) + + return arr + + +def onehot(value, num_values): + arr = np.zeros(num_values) + arr[value] = 1.0 + return arr + + +def random_A_matrix(num_obs, num_states): + if type(num_obs) is int: + num_obs = [num_obs] + if type(num_states) is int: + num_states = [num_states] + num_modalities = len(num_obs) + + A = obj_array(num_modalities) + for modality, modality_obs in enumerate(num_obs): + modality_shape = [modality_obs] + num_states + modality_dist = np.random.rand(*modality_shape) + A[modality] = norm_dist(modality_dist) + return A + + +def random_B_matrix(num_states, num_controls): + if type(num_states) is int: + num_states = [num_states] + if type(num_controls) is int: + num_controls = [num_controls] + num_factors = len(num_states) + assert len(num_controls) == len(num_states) + + B = obj_array(num_factors) + for factor in range(num_factors): + factor_shape = (num_states[factor], num_states[factor], num_controls[factor]) + factor_dist = np.random.rand(*factor_shape) + B[factor] = norm_dist(factor_dist) + return B + + +def random_single_categorical(shape_list): + """ + Creates a random 1-D categorical distribution (or set of 1-D categoricals, e.g. multiple marginals of different factors) and returns them in an object array + """ + + num_sub_arrays = len(shape_list) + + out = obj_array(num_sub_arrays) + + for arr_idx, shape_i in enumerate(shape_list): + out[arr_idx] = norm_dist(np.random.rand(shape_i)) + + return out + + +def construct_controllable_B(num_states, num_controls): + """ + Generates a fully controllable transition likelihood array, where each + action (control state) corresponds to a move to the n-th state from any + other state, for each control factor + """ + + num_factors = len(num_states) + + B = obj_array(num_factors) + for factor, c_dim in enumerate(num_controls): + tmp = np.eye(c_dim)[:, :, np.newaxis] + tmp = np.tile(tmp, (1, 1, c_dim)) + B[factor] = tmp.transpose(1, 2, 0) + + return B + + +def dirichlet_like(template_categorical, scale=1.0): + """ + Helper function to construct a Dirichlet distribution based on an existing Categorical distribution + """ + + if not is_obj_array(template_categorical): + warnings.warn( + "Input array is not an object array...Casting the input to an object array" + ) + template_categorical = to_obj_array(template_categorical) + + n_sub_arrays = len(template_categorical) + + dirichlet_out = obj_array(n_sub_arrays) + + for i, arr in enumerate(template_categorical): + dirichlet_out[i] = scale * arr + + return dirichlet_out + + +def get_model_dimensions(A=None, B=None): + + if A is None and B is None: + raise ValueError( + "Must provide either `A` or `B`" + ) + + if A is not None: + num_obs = [a.shape[0] for a in A] if is_obj_array(A) else [A.shape[0]] + num_modalities = len(num_obs) + else: + num_obs, num_modalities = None, None + + if B is not None: + num_states = [b.shape[0] for b in B] if is_obj_array(B) else [B.shape[0]] + num_factors = len(num_states) + else: + if A is not None: + num_states = list(A[0].shape[1:]) if is_obj_array(A) else list(A.shape[1:]) + num_factors = len(num_states) + else: + num_states, num_factors = None, None + + return num_obs, num_states, num_modalities, num_factors + + +def get_model_dimensions_from_labels(model_labels): + + modalities = model_labels['observations'] + num_modalities = len(modalities.keys()) + num_obs = [len(modalities[modality]) for modality in modalities.keys()] + + factors = model_labels['states'] + num_factors = len(factors.keys()) + num_states = [len(factors[factor]) for factor in factors.keys()] + + if 'actions' in model_labels.keys(): + + controls = model_labels['actions'] + num_control_fac = len(controls.keys()) + num_controls = [len(controls[cfac]) for cfac in controls.keys()] + + return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac + else: + return num_obs, num_modalities, num_states, num_factors + + +def norm_dist(dist): + """ Normalizes a Categorical probability distribution (or set of them) assuming sufficient statistics are stored in leading dimension""" + if dist.ndim == 3: + new_dist = np.zeros_like(dist) + for c in range(dist.shape[2]): + new_dist[:, :, c] = np.divide(dist[:, :, c], dist[:, :, c].sum(axis=0)) + return new_dist + else: + return np.divide(dist, dist.sum(axis=0)) + + +def norm_dist_obj_arr(obj_arr): + + normed_obj_array = obj_array(len(obj_arr)) + for i, arr in enumerate(obj_arr): + normed_obj_array[i] = norm_dist(arr) + + return normed_obj_array + + +def is_normalized(dist): + """ + Utility function for checking whether a single distribution or set of conditional categorical distributions is normalized. + Returns True if all distributions integrate to 1.0 + """ + + if is_obj_array(dist): + normed_arrays = [] + for i, arr in enumerate(dist): + column_sums = arr.sum(axis=0) + normed_arrays.append(np.allclose(column_sums, np.ones_like(column_sums))) + out = all(normed_arrays) + else: + column_sums = dist.sum(axis=0) + out = np.allclose(column_sums, np.ones_like(column_sums)) + + return out + + +def is_obj_array(arr): + return arr.dtype == "object" + + +def to_obj_array(arr): + if is_obj_array(arr): + return arr + obj_array_out = obj_array(1) + obj_array_out[0] = arr.squeeze() + return obj_array_out + + +def obj_array_from_list(list_input): + """ + Takes a list of `numpy.ndarray` and converts them to a `numpy.ndarray` of `dtype = object` + """ + return np.array(list_input, dtype=object) + + +def process_observation_seq(obs_seq, n_modalities, n_observations): + """ + Helper function for formatting observations + Observations can either be `int` (converted to one-hot) + or `tuple` (obs for each modality), or `list` (obs for each modality) + If list, the entries could be object arrays of one-hots, in which + case this function returns `obs_seq` as is. + """ + proc_obs_seq = obj_array(len(obs_seq)) + for t, obs_t in enumerate(obs_seq): + proc_obs_seq[t] = process_observation(obs_t, n_modalities, n_observations) + return proc_obs_seq + + +def process_observation(obs, num_modalities, num_observations): + """ + Helper function for formatting observations + USAGE NOTES: + - If `obs` is a 1D numpy array, it must be a one-hot vector, where one entry (the entry of the observation) is 1.0 + and all other entries are 0. This therefore assumes it's a single modality observation. If these conditions are met, then + this function will return `obs` unchanged. Otherwise, it'll throw an error. + - If `obs` is an int, it assumes this is a single modality observation, whose observation index is given by the value of `obs`. This function will convert + it to be a one hot vector. + - If `obs` is a list, it assumes this is a multiple modality observation, whose len is equal to the number of observation modalities, + and where each entry `obs[m]` is the index of the observation, for that modality. This function will convert it into an object array + of one-hot vectors. + - If `obs` is a tuple, same logic as applies for list (see above). + - if `obs` is a numpy object array (array of arrays), this function will return `obs` unchanged. + """ + + if isinstance(obs, np.ndarray) and not is_obj_array(obs): + assert num_modalities == 1, "If `obs` is a 1D numpy array, `num_modalities` must be equal to 1" + assert len(np.where(obs)[0]) == 1, "If `obs` is a 1D numpy array, it must be a one hot vector (e.g. np.array([0.0, 1.0, 0.0, ....]))" + + if isinstance(obs, (int, np.integer)): + obs = onehot(obs, num_observations[0]) + + if isinstance(obs, tuple) or isinstance(obs, list): + obs_arr_arr = obj_array(num_modalities) + for m in range(num_modalities): + obs_arr_arr[m] = onehot(obs[m], num_observations[m]) + obs = obs_arr_arr + + return obs + + +def convert_observation_array(obs, num_obs): + """ + Converts from SPM-style observation array to infer-actively one-hot object arrays. + + Parameters + ---------- + - 'obs' [numpy 2-D nd.array]: + SPM-style observation arrays are of shape (num_modalities, T), where each row + contains observation indices for a different modality, and columns indicate + different timepoints. Entries store the indices of the discrete observations + within each modality. + - 'num_obs' [list]: + List of the dimensionalities of the observation modalities. `num_modalities` + is calculated as `len(num_obs)` in the function to determine whether we're + dealing with a single- or multi-modality + case. + Returns + ---------- + - `obs_t`[list]: + A list with length equal to T, where each entry of the list is either a) an object + array (in the case of multiple modalities) where each sub-array is a one-hot vector + with the observation for the correspond modality, or b) a 1D numpy array (in the case + of one modality) that is a single one-hot vector encoding the observation for the + single modality. + """ + + T = obs.shape[1] + num_modalities = len(num_obs) + + # Initialise the output + obs_t = [] + # Case of one modality + if num_modalities == 1: + for t in range(T): + obs_t.append(onehot(obs[0, t] - 1, num_obs[0])) + else: + for t in range(T): + obs_AoA = obj_array(num_modalities) + for g in range(num_modalities): + # Subtract obs[g,t] by 1 to account for MATLAB vs. Python indexing + # (MATLAB is 1-indexed) + obs_AoA[g] = onehot(obs[g, t] - 1, num_obs[g]) + obs_t.append(obs_AoA) + + return obs_t + + +def insert_multiple(s, indices, items): + for idx in range(len(items)): + s.insert(indices[idx], items[idx]) + return s + + +def reduce_a_matrix(A): + """ + Utility function for throwing away dimensions (lagging dimensions, hidden state factors) + of a particular A matrix that are independent of the observation. + Parameters: + ========== + - `A` [np.ndarray]: + The A matrix or likelihood array that encodes probabilistic relationship + of the generative model between hidden state factors (lagging dimensions, columns, slices, etc...) + and observations (leading dimension, rows). + Returns: + ========= + - `A_reduced` [np.ndarray]: + The reduced A matrix, missing the lagging dimensions that correspond to hidden state factors + that are statistically independent of observations + - `original_factor_idx` [list]: + List of the indices (in terms of the original dimensionality) of the hidden state factors + that are maintained in the A matrix (and thus have an informative / non-degenerate relationship to observations + """ + + o_dim, num_states = A.shape[0], A.shape[1:] + idx_vec_s = [slice(0, o_dim)] + [slice(ns) for _, ns in enumerate(num_states)] + + original_factor_idx = [] + excluded_factor_idx = [] # the indices of the hidden state factors that are independent of the observation and thus marginalized away + for factor_i, ns in enumerate(num_states): + + level_counter = 0 + break_flag = False + while level_counter < ns and break_flag is False: + idx_vec_i = idx_vec_s.copy() + idx_vec_i[factor_i+1] = slice(level_counter, level_counter+1, None) + if not np.isclose(A.mean(axis=factor_i+1), A[tuple(idx_vec_i)].squeeze()).all(): + break_flag = True # this means they're not independent + original_factor_idx.append(factor_i) + else: + level_counter += 1 + + if break_flag is False: + excluded_factor_idx.append(factor_i+1) + + A_reduced = A.mean(axis=tuple(excluded_factor_idx)).squeeze() + + return A_reduced, original_factor_idx + + +def construct_full_a(A_reduced, original_factor_idx, num_states): + """ + Utility function for reconstruction a full A matrix from a reduced A matrix, using known factor indices + to tile out the reduced A matrix along the 'non-informative' dimensions + Parameters: + ========== + - `A_reduced` [np.ndarray]: + The reduced A matrix or likelihood array that encodes probabilistic relationship + of the generative model between hidden state factors (lagging dimensions, columns, slices, etc...) + and observations (leading dimension, rows). + - `original_factor_idx` [list]: + List of hidden state indices in terms of the full hidden state factor list, that comprise + the lagging dimensions of `A_reduced` + - `num_states` [list]: + The list of all the dimensionalities of hidden state factors in the full generative model. + `A_reduced.shape[1:]` should be equal to `num_states[original_factor_idx]` + Returns: + ========= + - `A` [np.ndarray]: + The full A matrix, containing all the lagging dimensions that correspond to hidden state factors, including + those that are statistically independent of observations + + @ NOTE: This is the "inverse" of the reduce_a_matrix function, + i.e. `reduce_a_matrix(construct_full_a(A_reduced, original_factor_idx, num_states)) == A_reduced, original_factor_idx` + """ + + o_dim = A_reduced.shape[0] # dimensionality of the support of the likelihood distribution (i.e. the number of observation levels) + full_dimensionality = [o_dim] + num_states # full dimensionality of the output (`A`) + fill_indices = [0] + [f+1 for f in original_factor_idx] # these are the indices of the dimensions we need to fill for this modality + fill_dimensions = np.delete(full_dimensionality, fill_indices) + + original_factor_dims = [num_states[f] for f in original_factor_idx] # dimensionalities of the relevant factors + prefilled_slices = [slice(0, o_dim)] + [slice(0, ns) for ns in original_factor_dims] # these are the slices that are filled out by the provided `A_reduced` + + A = np.zeros(full_dimensionality) + + for item in itertools.product(*[list(range(d)) for d in fill_dimensions]): + slice_ = list(item) + A_indices = insert_multiple(slice_, fill_indices, prefilled_slices) # here we insert the correct values for the fill indices for this slice + A[tuple(A_indices)] = A_reduced + + return A + + +def create_A_matrix_stub(model_labels): + + num_obs, _, num_states, _ = get_model_dimensions_from_labels(model_labels) + + obs_labels, state_labels = model_labels['observations'], model_labels['states'] + + state_combinations = pd.MultiIndex.from_product(list(state_labels.values()), names=list(state_labels.keys())) + num_state_combos = np.prod(num_states) # What is num_state_combos?? + # num_rows = (np.array(num_obs) * num_state_combos).sum() + num_rows = sum(num_obs) + + cell_values = np.zeros((num_rows, len(state_combinations))) + + obs_combinations = [] + for modality in obs_labels.keys(): + levels_to_combine = [[modality]] + [obs_labels[modality]] + # obs_combinations += num_state_combos * list(itertools.product(*levels_to_combine)) + obs_combinations += list(itertools.product(*levels_to_combine)) + + obs_combinations = pd.MultiIndex.from_tuples(obs_combinations, names=["Modality", "Level"]) + + A_matrix = pd.DataFrame(cell_values, index=obs_combinations, columns=state_combinations) + + return A_matrix + + +def create_B_matrix_stubs(model_labels): + + _, _, num_states, _, num_controls, _ = get_model_dimensions_from_labels(model_labels) + + state_labels = model_labels['states'] + action_labels = model_labels['actions'] + + B_matrices = {} + + for f_idx, factor in enumerate(state_labels.keys()): + + control_fac_name = list(action_labels)[f_idx] + factor_list = [state_labels[factor]] + [action_labels[control_fac_name]] + + prev_state_action_combos = pd.MultiIndex.from_product(factor_list, names=[factor, list(action_labels.keys())[f_idx]]) + + num_state_action_combos = num_states[f_idx] * num_controls[f_idx] + + num_rows = num_states[f_idx] + + cell_values = np.zeros((num_rows, num_state_action_combos)) + + next_state_list = state_labels[factor] + + B_matrix_f = pd.DataFrame(cell_values, index=next_state_list, columns=prev_state_action_combos) + + B_matrices[factor] = B_matrix_f + + return B_matrices + + +def read_A_matrix(path, num_hidden_state_factors): + raw_table = pd.read_excel(path, header=None) + level_counts = { + "index": raw_table.iloc[0, :].dropna().index[0] + 1, + "header": raw_table.iloc[0, :].dropna().index[0] + num_hidden_state_factors - 1, + } + return pd.read_excel( + path, + index_col=list(range(level_counts["index"])), + header=list(range(level_counts["header"])) + ).astype(np.float64) + + +def read_B_matrices(path): + + all_sheets = pd.read_excel(path, sheet_name=None, header=None) + + level_counts = {} + for sheet_name, raw_table in all_sheets.items(): + + level_counts[sheet_name] = { + "index": raw_table.iloc[0, :].dropna().index[0]+1, + "header": raw_table.iloc[0, :].dropna().index[0]+2, + } + + stub_dict = {} + for sheet_name, level_counts_sheet in level_counts.items(): + sheet_f = pd.read_excel( + path, + sheet_name=sheet_name, + index_col=list(range(level_counts_sheet["index"])), + header=list(range(level_counts_sheet["header"])) + ).astype(np.float64) + stub_dict[sheet_name] = sheet_f + + return stub_dict + + +def convert_A_stub_to_ndarray(A_stub, model_labels): + """ + This function converts a multi-index pandas dataframe `A_stub` into an object array of different + A matrices, one per observation modality. + """ + + num_obs, num_modalities, num_states, num_factors = get_model_dimensions_from_labels(model_labels) + + A = obj_array(num_modalities) + + for g, modality_name in enumerate(model_labels['observations'].keys()): + A[g] = A_stub.loc[modality_name].to_numpy().reshape(num_obs[g], *num_states) + assert (A[g].sum(axis=0) == 1.0).all(), 'A matrix not normalized! Check your initialization....\n' + + return A + + +def convert_B_stubs_to_ndarray(B_stubs, model_labels): + """ + This function converts a list of multi-index pandas dataframes `B_stubs` into an object array + of different B matrices, one per hidden state factor + """ + + _, _, num_states, num_factors, num_controls, num_control_fac = get_model_dimensions_from_labels(model_labels) + + B = obj_array(num_factors) + + for f, factor_name in enumerate(B_stubs.keys()): + + B[f] = B_stubs[factor_name].to_numpy().reshape(num_states[f], num_states[f], num_controls[f]) + assert (B[f].sum(axis=0) == 1.0).all(), 'B matrix not normalized! Check your initialization....\n' + + return B + + +def build_belief_array(qx): + """ + This function constructs array-ified (not nested) versions + of the posterior belief arrays, that are separated + by policy, timepoint, and hidden state factor + """ + + num_policies = len(qx) + num_timesteps = len(qx[0]) + num_factors = len(qx[0][0]) + + if num_factors > 1: + belief_array = obj_array(num_factors) + for factor in range(num_factors): + belief_array[factor] = np.zeros((num_policies, qx[0][0][factor].shape[0], num_timesteps)) + for policy_i in range(num_policies): + for timestep in range(num_timesteps): + for factor in range(num_factors): + belief_array[factor][policy_i, :, timestep] = qx[policy_i][timestep][factor] + else: + num_states = qx[0][0][0].shape[0] + belief_array = np.zeros((num_policies, num_states, num_timesteps)) + for policy_i in range(num_policies): + for timestep in range(num_timesteps): + belief_array[policy_i, :, timestep] = qx[policy_i][timestep][0] + + return belief_array + + +def build_xn_vn_array(xn): + """ + This function constructs array-ified (not nested) versions + of the posterior xn (beliefs) or vn (prediction error) arrays, that are separated + by iteration, hidden state factor, timepoint, and policy + """ + + num_policies = len(xn) + num_itr = len(xn[0]) + num_factors = len(xn[0][0]) + + if num_factors > 1: + xn_array = obj_array(num_factors) + for factor in range(num_factors): + num_states, infer_len = xn[0][0].shape + xn_array[factor] = np.zeros((num_itr, num_states, infer_len, num_policies)) + for policy_i in range(num_policies): + for itr in range(num_itr): + for factor in range(num_factors): + xn_array[factor][itr, :, :, policy_i] = xn[policy_i][itr][factor] + else: + num_states, infer_len = xn[0][0][0].shape + xn_array = np.zeros((num_itr, num_states, infer_len, num_policies)) + for policy_i in range(num_policies): + for itr in range(num_itr): + xn_array[itr, :, :, policy_i] = xn[policy_i][itr][0] + + return xn_array + + +# plotting functions +def plot_likelihood(matrix, xlabels=list(range(9)), ylabels=list(range(9)), title_str="Likelihood distribution (A)"): + """ + Plots a 2-D likelihood matrix as a heatmap + """ + + if not np.isclose(matrix.sum(axis=0), 1.0).all(): + raise ValueError("Distribution not column-normalized! Please normalize (ensure matrix.sum(axis=0) == 1.0 for all columns)") + fig = plt.figure(figsize=(6, 6)) # Unclear what "fig" is doing here. + ax=sns.heatmap(matrix, xticklabels=xlabels, yticklabels=ylabels, cmap='gray', cbar=False, vmin=0.0, vmax=1.0) # Unclear what ax is doing + plt.title(title_str) + plt.show() + + +def plot_grid(grid_locations, num_x=3, num_y=3): + """ + Plots the spatial coordinates of GridWorld as a heatmap, with each (X, Y) coordinate + labeled with its linear index (its `state id`) + """ + + grid_heatmap = np.zeros((num_x, num_y)) + for linear_idx, location in enumerate(grid_locations): + y, x = location + grid_heatmap[y, x] = linear_idx + sns.set(font_scale=1.5) + sns.heatmap(grid_heatmap, annot=True, cbar=False, fmt='.0f', cmap='crest') + + +def plot_point_on_grid(state_vector, grid_locations): + """ + Plots the current location of the agent on the grid world + """ + state_index = np.where(state_vector)[0][0] + y, x = grid_locations[state_index] + grid_heatmap = np.zeros((3, 3)) + grid_heatmap[y, x] = 1.0 + sns.heatmap(grid_heatmap, cbar=False, fmt='.0f') + + +def plot_beliefs(belief_dist, title_str=""): + """ + Plot a categorical distribution or belief distribution, stored in the 1-D numpy vector `belief_dist` + """ + + if not np.isclose(belief_dist.sum(), 1.0): + raise ValueError("Distribution not normalized! Please normalize") + + plt.grid(zorder=0) + plt.bar(range(belief_dist.shape[0]), belief_dist, color='r', zorder=3) + plt.xticks(range(belief_dist.shape[0])) + plt.title(title_str) + plt.show() + +# ActInf functions + + +def infer_states(observation_index, A, prior): + + log_likelihood = log_stable(A[observation_index, :]) + + log_prior = log_stable(prior) + + qs = softmax(log_likelihood + log_prior) + + return qs + + +def get_expected_states(B, qs_current, action): + """ Compute the expected states one step into the future, given a particular action """ + qs_u = B[:, :, action].dot(qs_current) + + return qs_u + +def get_expected_observations(A, qs_u): + """ Compute the expected observations one step into the future, given a particular action """ + + qo_u = A.dot(qs_u) + + return qo_u + +def entropy(A): + """ Compute the entropy of a set of conditional distributions, i.e. one entropy value per column """ + + H_A = - (A * log_stable(A)).sum(axis=0) + + return H_A + +def kl_divergence(qo_u, C): + """ Compute the Kullback-Leibler divergence between two 1-D categorical distributions""" + + return (log_stable(qo_u) - log_stable(C)).dot(qo_u) + +def calculate_G(A, B, C, qs_current, actions): + + G = np.zeros(len(actions)) # vector of expected free energies, one per action + + H_A = entropy(A) # entropy of the observation model, P(o|s) + + for action_i in range(len(actions)): + + qs_u = get_expected_states(B, qs_current, action_i) # expected states, under the action we're currently looping over + qo_u = get_expected_observations(A, qs_u) # expected observations, under the action we're currently looping over + + pred_uncertainty = H_A.dot(qs_u) # predicted uncertainty, i.e. expected entropy of the A matrix + pred_div = kl_divergence(qo_u, C) # predicted divergence + + G[action_i] = pred_uncertainty + pred_div # sum them together to get expected free energy + + return G + +def calculate_G_policies(A, B, C, qs_current, policies): + + G = np.zeros(len(policies)) # initialize the vector of expected free energies, one per policy + H_A = entropy(A) # can calculate the entropy of the A matrix beforehand, since it'll be the same for all policies + + for policy_id, policy in enumerate(policies): # loop over policies - policy_id will be the linear index of the policy (0, 1, 2, ...) and `policy` will be a column vector where `policy[t,0]` indexes the action entailed by that policy at time `t` + + t_horizon = policy.shape[0] # temporal depth of the policy + + G_pi = 0.0 # initialize expected free energy for this policy + + for t in range(t_horizon): # loop over temporal depth of the policy + + action = policy[t,0] # action entailed by this particular policy, at time `t` + + # get the past predictive posterior - which is either your current posterior at the current time (not the policy time) or the predictive posterior entailed by this policy, one timstep ago (in policy time) + if t == 0: + qs_prev = qs_current + else: + qs_prev = qs_pi_t # What is "qs_pi_t"? + + qs_pi_t = get_expected_states(B, qs_prev, action) # expected states, under the action entailed by the policy at this particular time + qo_pi_t = get_expected_observations(A, qs_pi_t) # expected observations, under the action entailed by the policy at this particular time + + kld = kl_divergence(qo_pi_t, C) # Kullback-Leibler divergence between expected observations and the prior preferences C + + G_pi_t = H_A.dot(qs_pi_t) + kld # predicted uncertainty + predicted divergence, for this policy & timepoint + + G_pi += G_pi_t # accumulate the expected free energy for each timepoint into the overall EFE for the policy + + G[policy_id] += G_pi + + return G + +def compute_prob_actions(actions, policies, Q_pi): + P_u = np.zeros(len(actions)) # initialize the vector of probabilities of each action + + for policy_id, policy in enumerate(policies): + P_u[int(policy[0,0])] += Q_pi[policy_id] # get the marginal probability for the given action, entailed by this policy at the first timestep + + P_u = norm_dist(P_u) # normalize the action probabilities + + return P_u diff --git a/notebooks/01_actinf_planning.ipynb b/notebooks/01_actinf_planning.ipynb new file mode 100644 index 0000000..84b00f3 --- /dev/null +++ b/notebooks/01_actinf_planning.ipynb @@ -0,0 +1,847 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Active Inference cadCAD model\n", + "\n", + "This notebook explores active inference agent modeling in arbitrarily large grid worlds." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Developing Active Inference Agents in cadCAD\n", + "\n", + "An active inference agent consists of the following matrices:\n", + "- $A$ -> $P(o|s)$ the generative model's prior beliefs about how hidden states relate to observations\n", + "- $B$ -> $๐(๐ _๐กโฃ๐ _{๐กโ1},๐ข_{๐กโ1})$ the generative model's prior beliefs about controllable transitions between hidden states over time\n", + "- $C$ -> $P(o)$ the biased generative model's prior preference for particular observations encoded in terms of probabilities\n", + "- $D$ -> $P(s)$ the generative model's prior belief over hidden states at the first timestep\n", + "- $E$ -> agent's affordances (in this notebook referred to as 'actions')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## cadCAD Standard Notebook Layout" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 0. Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: inferactively-pymdp in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (0.0.5)\n", + "Requirement already satisfied: autograd>=1.3 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.4)\n", + "Requirement already satisfied: toml>=0.10.2 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (0.10.2)\n", + "Requirement already satisfied: py>=1.10.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.11.0)\n", + "Requirement already satisfied: matplotlib>=3.1.3 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (3.5.2)\n", + "Requirement already satisfied: pytz>=2020.5 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (2022.1)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (2.8.2)\n", + "Requirement already satisfied: attrs>=20.3.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (21.4.0)\n", + "Requirement already satisfied: pandas>=1.2.4 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.4.2)\n", + "Requirement already satisfied: seaborn>=0.11.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (0.11.2)\n", + "Requirement already satisfied: packaging>=20.8 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (21.3)\n", + "Requirement already satisfied: six>=1.15.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.16.0)\n", + "Requirement already satisfied: iniconfig>=1.1.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.1.1)\n", + "Requirement already satisfied: nose>=1.3.7 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.3.7)\n", + "Requirement already satisfied: pyparsing>=2.4.7 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (3.0.9)\n", + "Requirement already satisfied: numpy>=1.19.5 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.22.4)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (3.10.0.2)\n", + "Requirement already satisfied: Pillow>=8.2.0pluggy>=0.13.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (9.1.1)\n", + "Requirement already satisfied: sphinx-rtd-theme>=0.4 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.0.0)\n", + "Requirement already satisfied: xlsxwriter>=1.4.3 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (3.0.3)\n", + "Requirement already satisfied: openpyxl>=3.0.7 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (3.0.10)\n", + "Requirement already satisfied: pytest>=6.2.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (6.2.5)\n", + "Requirement already satisfied: cycler>=0.10.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (0.11.0)\n", + "Requirement already satisfied: myst-nb>=0.13.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (0.15.0)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.4.2)\n", + "Requirement already satisfied: scipy>=1.6.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from inferactively-pymdp) (1.8.1)\n", + "Requirement already satisfied: future>=0.15.2 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from autograd>=1.3->inferactively-pymdp) (0.18.2)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from matplotlib>=3.1.3->inferactively-pymdp) (4.33.3)\n", + "Requirement already satisfied: sphinx<5,>=3.5 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (4.5.0)\n", + "Requirement already satisfied: nbformat~=5.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (5.4.0)\n", + "Requirement already satisfied: pyyaml in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (5.4.1)\n", + "Requirement already satisfied: ipython in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (8.3.0)\n", + "Requirement already satisfied: myst-parser~=0.17.2 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (0.17.2)\n", + "Requirement already satisfied: ipykernel in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (6.13.0)\n", + "Requirement already satisfied: importlib_metadata in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (4.11.4)\n", + "Requirement already satisfied: sphinx-togglebutton~=0.3.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (0.3.1)\n", + "Requirement already satisfied: docutils<0.18,>=0.15 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (0.17.1)\n", + "Requirement already satisfied: nbclient in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (0.5.13)\n", + "Requirement already satisfied: jupyter-cache~=0.5.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-nb>=0.13.1->inferactively-pymdp) (0.5.0)\n", + "Requirement already satisfied: tabulate in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from jupyter-cache~=0.5.0->myst-nb>=0.13.1->inferactively-pymdp) (0.8.9)\n", + "Requirement already satisfied: click in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from jupyter-cache~=0.5.0->myst-nb>=0.13.1->inferactively-pymdp) (8.1.3)\n", + "Requirement already satisfied: sqlalchemy<1.5,>=1.3.12 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from jupyter-cache~=0.5.0->myst-nb>=0.13.1->inferactively-pymdp) (1.4.37)\n", + "Requirement already satisfied: jinja2 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-parser~=0.17.2->myst-nb>=0.13.1->inferactively-pymdp) (3.0.3)\n", + "Requirement already satisfied: markdown-it-py<3.0.0,>=1.0.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-parser~=0.17.2->myst-nb>=0.13.1->inferactively-pymdp) (2.1.0)\n", + "Requirement already satisfied: mdit-py-plugins~=0.3.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from myst-parser~=0.17.2->myst-nb>=0.13.1->inferactively-pymdp) (0.3.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from markdown-it-py<3.0.0,>=1.0.0->myst-parser~=0.17.2->myst-nb>=0.13.1->inferactively-pymdp) (0.1.1)\n", + "Requirement already satisfied: traitlets>=5.0.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from nbclient->myst-nb>=0.13.1->inferactively-pymdp) (5.2.1.post0)\n", + "Requirement already satisfied: nest-asyncio in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from nbclient->myst-nb>=0.13.1->inferactively-pymdp) (1.5.4)\n", + "Requirement already satisfied: jupyter-client>=6.1.5 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from nbclient->myst-nb>=0.13.1->inferactively-pymdp) (7.3.1)\n", + "Requirement already satisfied: jupyter-core>=4.9.2 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from jupyter-client>=6.1.5->nbclient->myst-nb>=0.13.1->inferactively-pymdp) (4.10.0)\n", + "Requirement already satisfied: tornado>=6.0 in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from jupyter-client>=6.1.5->nbclient->myst-nb>=0.13.1->inferactively-pymdp) (6.1)\n", + "Requirement already satisfied: pyzmq>=22.3 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from jupyter-client>=6.1.5->nbclient->myst-nb>=0.13.1->inferactively-pymdp) (23.0.0)\n", + "Requirement already satisfied: entrypoints in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from jupyter-client>=6.1.5->nbclient->myst-nb>=0.13.1->inferactively-pymdp) (0.4)\n", + "Requirement already satisfied: fastjsonschema in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from nbformat~=5.0->myst-nb>=0.13.1->inferactively-pymdp) (2.15.3)\n", + "Requirement already satisfied: jsonschema>=2.6 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from nbformat~=5.0->myst-nb>=0.13.1->inferactively-pymdp) (3.2.0)\n", + "Requirement already satisfied: setuptools in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from jsonschema>=2.6->nbformat~=5.0->myst-nb>=0.13.1->inferactively-pymdp) (61.2.0)\n", + "Requirement already satisfied: pyrsistent>=0.14.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from jsonschema>=2.6->nbformat~=5.0->myst-nb>=0.13.1->inferactively-pymdp) (0.18.1)\n", + "Requirement already satisfied: et-xmlfile in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from openpyxl>=3.0.7->inferactively-pymdp) (1.1.0)\n", + "Requirement already satisfied: pluggy<2.0,>=0.12 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from pytest>=6.2.1->inferactively-pymdp) (1.0.0)\n", + "Requirement already satisfied: sphinxcontrib-jsmath in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (1.0.1)\n", + "Requirement already satisfied: snowballstemmer>=1.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (2.2.0)\n", + "Requirement already satisfied: requests>=2.5.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (2.27.1)\n", + "Requirement already satisfied: babel>=1.3 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (2.10.1)\n", + "Requirement already satisfied: sphinxcontrib-applehelp in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (1.0.2)\n", + "Requirement already satisfied: Pygments>=2.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (2.12.0)\n", + "Requirement already satisfied: sphinxcontrib-htmlhelp>=2.0.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (2.0.0)\n", + "Requirement already satisfied: alabaster<0.8,>=0.7 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (0.7.12)\n", + "Requirement already satisfied: imagesize in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (1.3.0)\n", + "Requirement already satisfied: sphinxcontrib-serializinghtml>=1.1.5 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (1.1.5)\n", + "Requirement already satisfied: sphinxcontrib-devhelp in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (1.0.2)\n", + "Requirement already satisfied: sphinxcontrib-qthelp in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (1.0.3)\n", + "Requirement already satisfied: zipp>=0.5 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from importlib_metadata->myst-nb>=0.13.1->inferactively-pymdp) (3.8.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from jinja2->myst-parser~=0.17.2->myst-nb>=0.13.1->inferactively-pymdp) (2.1.1)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from requests>=2.5.0->sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (3.3)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from requests>=2.5.0->sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (2.0.12)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from requests>=2.5.0->sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (1.26.9)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from requests>=2.5.0->sphinx<5,>=3.5->myst-nb>=0.13.1->inferactively-pymdp) (2022.5.18.1)\n", + "Requirement already satisfied: wheel in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sphinx-togglebutton~=0.3.0->myst-nb>=0.13.1->inferactively-pymdp) (0.37.1)\n", + "Requirement already satisfied: greenlet!=0.4.17 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from sqlalchemy<1.5,>=1.3.12->jupyter-cache~=0.5.0->myst-nb>=0.13.1->inferactively-pymdp) (1.1.2)\n", + "Requirement already satisfied: debugpy>=1.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from ipykernel->myst-nb>=0.13.1->inferactively-pymdp) (1.6.0)\n", + "Requirement already satisfied: psutil in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from ipykernel->myst-nb>=0.13.1->inferactively-pymdp) (5.9.1)\n", + "Requirement already satisfied: matplotlib-inline>=0.1 in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from ipykernel->myst-nb>=0.13.1->inferactively-pymdp) (0.1.3)\n", + "Requirement already satisfied: appnope in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from ipykernel->myst-nb>=0.13.1->inferactively-pymdp) (0.1.3)\n", + "Requirement already satisfied: pexpect>4.3 in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from ipython->myst-nb>=0.13.1->inferactively-pymdp) (4.8.0)\n", + "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /Users/jakubsmekal/miniconda3/envs/dev/lib/python3.8/site-packages (from ipython->myst-nb>=0.13.1->inferactively-pymdp) (3.0.29)\n", + "Requirement already satisfied: backcall in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.2.0)\n", + "Requirement already satisfied: pickleshare in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.7.5)\n", + "Requirement already satisfied: decorator in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from ipython->myst-nb>=0.13.1->inferactively-pymdp) (5.1.1)\n", + "Requirement already satisfied: jedi>=0.16 in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.18.1)\n", + "Requirement already satisfied: stack-data in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.1.4)\n", + "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from jedi>=0.16->ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.8.3)\n", + "Requirement already satisfied: ptyprocess>=0.5 in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from pexpect>4.3->ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.7.0)\n", + "Requirement already satisfied: wcwidth in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.2.5)\n", + "Requirement already satisfied: asttokens in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from stack-data->ipython->myst-nb>=0.13.1->inferactively-pymdp) (2.0.5)\n", + "Requirement already satisfied: executing in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from stack-data->ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.8.2)\n", + "Requirement already satisfied: pure-eval in /Users/jakubsmekal/.local/lib/python3.8/site-packages (from stack-data->ipython->myst-nb>=0.13.1->inferactively-pymdp) (0.2.2)\n" + ] + } + ], + "source": [ + "!pip install inferactively-pymdp" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from random import normalvariate, random\n", + "import plotly.express as px\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import sys\n", + "\n", + "sys.path.insert(0, '../')\n", + "\n", + "from radcad import Model, Simulation, Experiment\n", + "\n", + "\n", + "# Additional dependencies\n", + "from pymdp.control import construct_policies\n", + "\n", + "# For analytics\n", + "import itertools\n", + "\n", + "# For visualization\n", + "import plotly.express as px\n", + "\n", + "# local utils\n", + "import blockference.tools.utils as u\n", + "from blockference.gridference import ActiveGridference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 0. Useful functions" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), (6, 0), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7), (6, 8), (6, 9), (7, 0), (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 0), (8, 1), (8, 2), (8, 3), (8, 4), (8, 5), (8, 6), (8, 7), (8, 8), (8, 9), (9, 0), (9, 1), (9, 2), (9, 3), (9, 4), (9, 5), (9, 6), (9, 7), (9, 8), (9, 9)]\n", + "9.0\n", + "100\n" + ] + } + ], + "source": [ + "grid = list(itertools.product(range(10), repeat=2))\n", + "print(grid)\n", + "act = ActiveGridference(grid)\n", + "act.get_C((9, 3))\n", + "act.get_D((0, 0))\n", + "print(act.border)\n", + "print(len(grid))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.02672363 0.00983107 0.00983107 ... 0.00983107 0.00983107 0.00983107]\n", + " [0.00983107 0.02672363 0.00983107 ... 0.00983107 0.00983107 0.00983107]\n", + " [0.00983107 0.00983107 0.02672363 ... 0.00983107 0.00983107 0.00983107]\n", + " ...\n", + " [0.00983107 0.00983107 0.00983107 ... 0.02672363 0.00983107 0.00983107]\n", + " [0.00983107 0.00983107 0.00983107 ... 0.00983107 0.02672363 0.00983107]\n", + " [0.00983107 0.00983107 0.00983107 ... 0.00983107 0.00983107 0.02672363]]\n" + ] + } + ], + "source": [ + "from pymdp.maths import softmax\n", + "\n", + "print(softmax(act.A))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. State Variables" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "initial_state = {\n", + " 'prior_A': softmax(act.A),\n", + " 'prior_B': act.B,\n", + " 'prior_C': act.C,\n", + " 'prior': softmax(act.D),\n", + " 'env_state': grid[0],\n", + " 'action': '',\n", + " 'current_inference': ''\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. System Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "params = {\n", + " 'prior_A': softmax(act.A),\n", + " 'prior_B': act.B,\n", + " 'prior_C': act.C,\n", + " 'prior': softmax(act.D),\n", + " 'env_state': grid,\n", + " 'actions': act.E\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Policy Functions\n", + "\n", + "- `get_observation`\n", + "- `infer_states`\n", + "- `calc_efe`\n", + "- `calc_action_posterior`\n", + "- `sample_action`\n", + "- `calc_next_prior`\n", + "- `update_env_state`" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def p_actinf(params, substep, state_history, previous_state):\n", + " policies = construct_policies([act.n_states], [len(act.E)], policy_len = act.policy_len)\n", + " # get obs_idx\n", + " obs_idx = grid.index(previous_state['env_state'])\n", + "\n", + " # infer_states\n", + " qs_current = u.infer_states(obs_idx, previous_state['prior_A'], previous_state['prior'])\n", + "\n", + " # calc efe\n", + " G = u.calculate_G_policies(previous_state['prior_A'], previous_state['prior_B'], previous_state['prior_C'], qs_current, policies=policies)\n", + "\n", + " # calc action posterior\n", + " Q_pi = u.softmax(-G)\n", + "\n", + " # compute the probability of each action\n", + " P_u = u.compute_prob_actions(act.E, policies, Q_pi)\n", + "\n", + " # sample action\n", + " chosen_action = u.sample(P_u)\n", + "\n", + " # calc next prior\n", + " prior = previous_state['prior_B'][:,:,chosen_action].dot(qs_current) \n", + "\n", + " # update env state\n", + " # action_label = params['actions'][chosen_action]\n", + "\n", + " (Y, X) = previous_state['env_state']\n", + " Y_new = Y\n", + " X_new = X\n", + "\n", + " if chosen_action == 0: # UP\n", + " \n", + " Y_new = Y - 1 if Y > 0 else Y\n", + " X_new = X\n", + "\n", + " elif chosen_action == 1: # DOWN\n", + "\n", + " Y_new = Y + 1 if Y < act.border else Y\n", + " X_new = X\n", + "\n", + " elif chosen_action == 2: # LEFT\n", + " Y_new = Y\n", + " X_new = X - 1 if X > 0 else X\n", + "\n", + " elif chosen_action == 3: # RIGHT\n", + " Y_new = Y\n", + " X_new = X +1 if X < act.border else X\n", + "\n", + " elif chosen_action == 4: # STAY\n", + " Y_new, X_new = Y, X \n", + " \n", + " current_state = (Y_new, X_new) # store the new grid location\n", + "\n", + " return {'update_prior': prior,\n", + " 'update_env': current_state,\n", + " 'update_action': chosen_action,\n", + " 'update_inference': qs_current}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. State Update Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def s_prior(params, substep, state_history, previous_state, policy_input):\n", + " updated_prior = policy_input['update_prior']\n", + " return 'prior', updated_prior\n", + "\n", + "def s_env(params, substep, state_history, previous_state, policy_input):\n", + " updated_env_state = policy_input['update_env']\n", + " return 'env_state', updated_env_state\n", + "\n", + "def s_action(params, substep, state_history, previous_state, policy_input):\n", + " return 'action', policy_input['update_action']\n", + "\n", + "def s_qs(params, substep, state_history, previous_state, policy_input):\n", + " return 'current_inference', policy_input['update_inference']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Partial State Update Blocks" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "state_update_blocks = [\n", + " {\n", + " 'policies': {\n", + " 'p_actinf': p_actinf\n", + " },\n", + " 'variables': {\n", + " 'prior': s_prior,\n", + " 'env_state': s_env,\n", + " 'action': s_action,\n", + " 'current_inference': s_qs\n", + " }\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "model = Model(\n", + " # Model initial state\n", + " initial_state=initial_state,\n", + " # Model Partial State Update Blocks\n", + " state_update_blocks=state_update_blocks,\n", + " # System Parameters\n", + " params=params\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 7. Execution" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "simulation = Simulation(\n", + " model=model,\n", + " timesteps=100, # Number of timesteps\n", + " runs=1 # Number of Monte Carlo Runs\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "result = simulation.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 8. Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
| \n", + " | prior_A | \n", + "prior_B | \n", + "prior_C | \n", + "prior | \n", + "env_state | \n", + "action | \n", + "current_inference | \n", + "simulation | \n", + "subset | \n", + "run | \n", + "substep | \n", + "timestep | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[0.02672363098939523, 0.009831074434450556, 0.... | \n", + "(0, 0) | \n", + "\n", + " | \n", + " | 0 | \n", + "0 | \n", + "1 | \n", + "0 | \n", + "0 | \n", + "
| 1 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[0.0694531596563796, 0.00939946303377394, 0.00... | \n", + "(0, 0) | \n", + "4 | \n", + "[0.0694531596563796, 0.00939946303377394, 0.00... | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "1 | \n", + "1 | \n", + "
| 2 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "(1, 0) | \n", + "1 | \n", + "[0.16866478870681606, 0.008397325366597817, 0.... | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "1 | \n", + "2 | \n", + "
| 3 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[0.35546098713663643, 0.00651049507942799, 0.0... | \n", + "(0, 0) | \n", + "0 | \n", + "[7.753058021694268e-17, 7.753058021694268e-17,... | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "1 | \n", + "3 | \n", + "
| 4 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[0.603901424016928, 0.00808364440781786, 0.008... | \n", + "(0, 0) | \n", + "0 | \n", + "[0.5998596018130191, 0.004041822203908954, 0.0... | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "1 | \n", + "4 | \n", + "
| ... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "
| 10095 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[8.710941655795045e-17, 8.710941655795045e-17,... | \n", + "(6, 4) | \n", + "0 | \n", + "[3.678794411714443e-17, 3.678794411714443e-17,... | \n", + "0 | \n", + "99 | \n", + "1 | \n", + "1 | \n", + "96 | \n", + "
| 10096 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[1.2596545076772043e-16, 1.2713499590296219e-1... | \n", + "(5, 4) | \n", + "0 | \n", + "[6.883370760125423e-17, 6.883370760125423e-17,... | \n", + "0 | \n", + "99 | \n", + "1 | \n", + "1 | \n", + "97 | \n", + "
| 10097 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[0.0, 8.312804375248272e-17, 8.355829536326019... | \n", + "(5, 5) | \n", + "3 | \n", + "[8.312804375248272e-17, 8.355829536326019e-17,... | \n", + "0 | \n", + "99 | \n", + "1 | \n", + "1 | \n", + "98 | \n", + "
| 10098 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[0.0, 3.678794411714443e-17, 6.736904239848323... | \n", + "(5, 6) | \n", + "3 | \n", + "[3.678794411714443e-17, 6.736904239848323e-17,... | \n", + "0 | \n", + "99 | \n", + "1 | \n", + "1 | \n", + "99 | \n", + "
| 10099 | \n", + "[[0.02672363098939523, 0.009831074434450556, 0... | \n", + "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", + "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", + "[3.678794411714443e-17, 5.0321472440806026e-17... | \n", + "(5, 6) | \n", + "4 | \n", + "[3.678794411714443e-17, 5.0321472440806026e-17... | \n", + "0 | \n", + "99 | \n", + "1 | \n", + "1 | \n", + "100 | \n", + "
10100 rows ร 12 columns
\n", + "| \n", - " | prior_A | \n", - "prior_B | \n", - "prior_C | \n", - "prior | \n", - "env_state | \n", - "action | \n", - "current_inference | \n", - "simulation | \n", - "subset | \n", - "run | \n", - "substep | \n", - "timestep | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] | \n", - "(0, 0) | \n", - "\n", - " | 0 | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "0 | \n", - "0 | \n", - "
| 1 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[0.0, 0.0, 0.0, 0.4999999999999999, 4.99999999... | \n", - "(1, 0) | \n", - "1 | \n", - "[0.4999999999999999, 4.9999999999999814e-17, 4... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "1 | \n", - "
| 2 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[1.0, 4.999999999999962e-32, 4.999999999999962... | \n", - "(0, 0) | \n", - "0 | \n", - "[1.9999999999999897e-32, 1.9999999999999897e-3... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "2 | \n", - "
| 3 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[1.0, 1.9999999999999862e-32, 2.99999999999998... | \n", - "(0, 0) | \n", - "0 | \n", - "[1.0, 9.999999999999931e-33, 9.999999999999931... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "3 | \n", - "
| 4 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[1.0, 1.9999999999999862e-32, 1.99999999999998... | \n", - "(0, 0) | \n", - "0 | \n", - "[1.0, 9.999999999999931e-33, 9.999999999999931... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "4 | \n", - "
| ... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "
| 904 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[0.0, 9.999999999999931e-33, 1.999999999999986... | \n", - "(2, 2) | \n", - "3 | \n", - "[9.999999999999931e-33, 9.999999999999931e-33,... | \n", - "0 | \n", - "8 | \n", - "1 | \n", - "1 | \n", - "96 | \n", - "
| 905 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[0.0, 0.0, 0.0, 9.999999999999931e-33, 9.99999... | \n", - "(2, 2) | \n", - "1 | \n", - "[9.999999999999931e-33, 9.999999999999931e-33,... | \n", - "0 | \n", - "8 | \n", - "1 | \n", - "1 | \n", - "97 | \n", - "
| 906 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[0.0, 0.0, 0.0, 9.999999999999931e-33, 9.99999... | \n", - "(2, 2) | \n", - "1 | \n", - "[9.999999999999931e-33, 9.999999999999931e-33,... | \n", - "0 | \n", - "8 | \n", - "1 | \n", - "1 | \n", - "98 | \n", - "
| 907 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[0.0, 9.999999999999931e-33, 1.999999999999986... | \n", - "(2, 2) | \n", - "3 | \n", - "[9.999999999999931e-33, 9.999999999999931e-33,... | \n", - "0 | \n", - "8 | \n", - "1 | \n", - "1 | \n", - "99 | \n", - "
| 908 | \n", - "[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] | \n", - "[0.0, 9.999999999999931e-33, 1.999999999999986... | \n", - "(2, 2) | \n", - "3 | \n", - "[9.999999999999931e-33, 9.999999999999931e-33,... | \n", - "0 | \n", - "8 | \n", - "1 | \n", - "1 | \n", - "100 | \n", - "
909 rows ร 12 columns
\n", - "| \n", - " | prior_A | \n", - "prior_B | \n", - "prior_C | \n", - "prior | \n", - "env_state | \n", - "action | \n", - "current_inference | \n", - "simulation | \n", - "subset | \n", - "run | \n", - "substep | \n", - "timestep | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.02672363098939523, 0.009831074434450556, 0.... | \n", - "(0, 0) | \n", - "\n", - " | \n", - " | 0 | \n", - "0 | \n", - "1 | \n", - "0 | \n", - "0 | \n", - "
| 1 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.07885262269015354, 0.01879892606754788, 0.0... | \n", - "(0, 0) | \n", - "0 | \n", - "[0.0694531596563796, 0.00939946303377394, 0.00... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "1 | \n", - "
| 2 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "(1, 0) | \n", - "1 | \n", - "[0.18876736668835925, 0.01655576802676808, 0.0... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "2 | \n", - "
| 3 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "(2, 0) | \n", - "1 | \n", - "[7.550842449984315e-17, 7.550842449984315e-17,... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "3 | \n", - "
| 4 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "(3, 0) | \n", - "1 | \n", - "[6.003301649048549e-17, 6.003301649048549e-17,... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "4 | \n", - "
| ... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "
| 50095 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 6.883370760125422e-17, 7.066527149012727... | \n", - "(8, 7) | \n", - "3 | \n", - "[6.883370760125422e-17, 7.066527149012727e-17,... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "496 | \n", - "
| 50096 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[7.357588823428884e-17, 1.1924219316971896e-16... | \n", - "(7, 7) | \n", - "0 | \n", - "[3.678794411714442e-17, 6.211045000325277e-17,... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "497 | \n", - "
| 50097 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[1.1417647320527348e-16, 1.3846023337085264e-1... | \n", - "(6, 7) | \n", - "0 | \n", - "[6.385500076446746e-17, 8.065469550447812e-17,... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "498 | \n", - "
| 50098 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[1.3409130055241973e-16, 1.457780304781232e-16... | \n", - "(5, 7) | \n", - "0 | \n", - "[7.879112127482724e-17, 8.772461739408198e-17,... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "499 | \n", - "
| 50099 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "(6, 7) | \n", - "1 | \n", - "[8.611737683032094e-17, 9.041668450451017e-17,... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "500 | \n", - "
50100 rows ร 12 columns
\n", - "| \n", - " | prior_A | \n", - "prior_B | \n", - "prior_C | \n", - "prior | \n", - "env_state | \n", - "action | \n", - "current_inference | \n", - "simulation | \n", - "subset | \n", - "run | \n", - "substep | \n", - "timestep | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.02672363098939523, 0.009831074434450556, 0.... | \n", - "(0, 0) | \n", - "\n", - " | \n", - " | 0 | \n", - "0 | \n", - "1 | \n", - "0 | \n", - "0 | \n", - "
| 1 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "(1, 0) | \n", - "1 | \n", - "[0.0694531596563796, 0.00939946303377394, 0.00... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "1 | \n", - "
| 2 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "(2, 0) | \n", - "1 | \n", - "[8.933835195079347e-17, 8.933835195079347e-17,... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "2 | \n", - "
| 3 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[2.2432570305759068e-16, 2.2432570305759068e-1... | \n", - "(1, 0) | \n", - "0 | \n", - "[7.753058021694265e-17, 7.753058021694265e-17,... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "3 | \n", - "
| 4 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "(2, 0) | \n", - "1 | \n", - "[2.0134671970778882e-16, 2.0134671970778882e-1... | \n", - "0 | \n", - "0 | \n", - "1 | \n", - "1 | \n", - "4 | \n", - "
| ... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "
| 10095 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 5.032147244080602e-17, 5.032147244080602... | \n", - "(8, 2) | \n", - "3 | \n", - "[5.032147244080602e-17, 5.032147244080602e-17,... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "96 | \n", - "
| 10096 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[9.208812339473694e-17, 5.53001792775925e-17, ... | \n", - "(8, 1) | \n", - "2 | \n", - "[3.678794411714443e-17, 5.53001792775925e-17, ... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "97 | \n", - "
| 10097 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[7.066527149012729e-17, 5.71317431664662e-17, ... | \n", - "(8, 1) | \n", - "4 | \n", - "[7.066527149012729e-17, 5.71317431664662e-17, ... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "98 | \n", - "
| 10098 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "(9, 1) | \n", - "1 | \n", - "[6.278424470316125e-17, 5.780553786637451e-17,... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "99 | \n", - "
| 10099 | \n", - "[[0.02672363098939523, 0.009831074434450556, 0... | \n", - "[[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... | \n", - "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | \n", - "[0.0, 3.678794411714442e-17, 3.678794411714442... | \n", - "(9, 2) | \n", - "3 | \n", - "[3.678794411714442e-17, 3.678794411714442e-17,... | \n", - "0 | \n", - "99 | \n", - "1 | \n", - "1 | \n", - "100 | \n", - "
10100 rows ร 12 columns
\n", - "