diff --git a/.gitignore b/.gitignore index e30d509..1160bd4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,6 @@ __pycache__* *.xml *.iml *.ipynb_checkpoints/ +./blockference/.ipynb_checkpoints +.DS_Store +**checkpoint.py diff --git a/blockference/.ipynb_checkpoints/gridference-checkpoint.py b/blockference/.ipynb_checkpoints/gridference-checkpoint.py index 0976e8d..ec11aee 100644 --- a/blockference/.ipynb_checkpoints/gridference-checkpoint.py +++ b/blockference/.ipynb_checkpoints/gridference-checkpoint.py @@ -345,7 +345,7 @@ def __init__(self, grid, planning_length: int = 2, env_state: tuple = (0, 0)) -> def get_A(self): """ - State Matrix (identity matrix) + State Matrix (identity matrix for the single agent gridworld) Params: - n_observations: int: number of possible observations - n_states: int: number of possible states diff --git a/blockference/agent.py b/blockference/agent.py index 85b5374..88f72af 100644 --- a/blockference/agent.py +++ b/blockference/agent.py @@ -48,6 +48,33 @@ def __init__( lr_pD=1.0, use_BMA=True, policy_sep_prior=False, - save_belief_hist=False + save_belief_hist=False, ): - super().__init__() \ No newline at end of file + super().__init__(A, + B, + C=None, + D=None, + E=None, + pA=None, + pB=None, + pD=None, + num_controls=None, + policy_len=1, + inference_horizon=1, + control_fac_idx=None, + policies=None, + gamma=16.0, + use_utility=True, + use_states_info_gain=True, + use_param_info_gain=False, + action_selection="deterministic", + inference_algo="VANILLA", + inference_params=None, + modalities_to_learn="all", + lr_pA=1.0, + factors_to_learn="all", + lr_pB=1.0, + lr_pD=1.0, + use_BMA=True, + policy_sep_prior=False, + save_belief_hist=False,) \ No newline at end of file diff --git a/blockference/envs/grid_env.py b/blockference/envs/grid_env.py index d5fbad7..b519b2e 100644 --- a/blockference/envs/grid_env.py +++ b/blockference/envs/grid_env.py @@ -6,12 +6,13 @@ def __init__(self, grid_len, num_agents, grid_dim=2) -> None: self.grid = self.get_grid(grid_len, grid_dim) self.grid_dim = grid_dim self.no_actions = 2 * grid_dim + 1 - self.agents = self.init_agents(num_agents) + self.n_observations = grid_len ** 2 + self.n_states = grid_len ** 2 + self.border = np.sqrt(self.n_states) - 1 + # self.agents = self.init_agents(num_agents) def get_grid(self, grid_len, grid_dim): g = list(itertools.product(range(grid_len), repeat=grid_dim)) - for i, p in enumerate(g): - g[i] += (0,) return g def move_grid(self, agent, chosen_action): @@ -29,7 +30,7 @@ def move_grid(self, agent, chosen_action): new_state[index] = state[index] - 1 if state[index] > 0 else state[index] elif chosen_action % 2 == 0: index = chosen_action / 2 - new_state[index] = state[index] + 1 if state[index] < agent.border else state[index] + new_state[index] = state[index] + 1 if state[index] < self.border else state[index] return new_state def init_agents(self, no_agents): diff --git a/blockference/tools/.ipynb_checkpoints/__init__-checkpoint.py b/blockference/tools/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index e69de29..0000000 diff --git a/blockference/tools/.ipynb_checkpoints/policy-checkpoint.py b/blockference/tools/.ipynb_checkpoints/policy-checkpoint.py deleted file mode 100644 index 2cafc53..0000000 --- a/blockference/tools/.ipynb_checkpoints/policy-checkpoint.py +++ /dev/null @@ -1,136 +0,0 @@ -import tools.utils as u -from tools.control import construct_policies - -# Policies for cadCAD actinf simulations - - -# single-agent with planning -def p_actinf(params, substep, state_history, previous_state, act, grid): - - policies = construct_policies([act.n_states], [len(act.E)], policy_len=act.policy_len) - # get obs_idx - obs_idx = grid.index(previous_state['env_state']) - - # infer_states - qs_current = u.infer_states(obs_idx, previous_state['prior_A'], previous_state['prior']) - - # calc efe - G = u.calculate_G_policies(previous_state['prior_A'], previous_state['prior_B'], previous_state['prior_C'], qs_current, policies=policies) - - # calc action posterior - Q_pi = u.softmax(-G) - - # compute the probability of each action - P_u = u.compute_prob_actions(act.E, policies, Q_pi) - - # sample action - chosen_action = u.sample(P_u) - - # calc next prior - prior = previous_state['prior_B'][:, :, chosen_action].dot(qs_current) - - # update env state - # action_label = params['actions'][chosen_action] - - (Y, X) = previous_state['env_state'] - Y_new = Y - X_new = X - - if chosen_action == 0: # UP - - Y_new = Y - 1 if Y > 0 else Y - X_new = X - - elif chosen_action == 1: # DOWN - - Y_new = Y + 1 if Y < act.border else Y - X_new = X - - elif chosen_action == 2: # LEFT - Y_new = Y - X_new = X - 1 if X > 0 else X - - elif chosen_action == 3: # RIGHT - Y_new = Y - X_new = X + 1 if X < act.border else X - - elif chosen_action == 4: # STAY - Y_new, X_new = Y, X - - current_state = (Y_new, X_new) # store the new grid location - - return {'update_prior': prior, - 'update_env': current_state, - 'update_action': chosen_action, - 'update_inference': qs_current} - - -# multi-agent (dict) gridworld -def p_actinf(params, substep, state_history, previous_state, grid): # Is this a useless re-definition from Line 8?? - # State Variables - agents = previous_state['agents'] - - # list of all updates to the agents in the network - agent_updates = [] - - for source, agent in agents.items(): - - policies = construct_policies([agent.n_states], [len(agent.E)], policy_len=agent.policy_len) - # get obs_idx - obs_idx = grid.index(agent.env_state) - - # infer_states - qs_current = u.infer_states(obs_idx, agent.A, agent.prior, 0) - - # calc efe - _G = u.calculate_G_policies(agent.A, agent.B, agent.C, qs_current, policies=policies) - - # calc action posterior - Q_pi = u.softmax(-_G, 0) - # compute the probability of each action - P_u = u.compute_prob_actions(agent.E, policies, Q_pi) - - # sample action - chosen_action = u.sample(P_u) - - # calc next prior - prior = agent.B[:, :, chosen_action].dot(qs_current) - - # update env state - # action_label = params['actions'][chosen_action] - - (Y, X) = agent.env_state - Y_new = Y - X_new = X - # here - - if chosen_action == 0: # UP - - Y_new = Y - 1 if Y > 0 else Y - X_new = X - - elif chosen_action == 1: # DOWN - - Y_new = Y + 1 if Y < agent.border else Y - X_new = X - - elif chosen_action == 2: # LEFT - Y_new = Y - X_new = X - 1 if X > 0 else X - - elif chosen_action == 3: # RIGHT - Y_new = Y - X_new = X + 1 if X < agent.border else X - - elif chosen_action == 4: # STAY - Y_new, X_new = Y, X - - current_state = (Y_new, X_new) # store the new grid location - agent_update = {'source': source, - 'update_prior': prior, - 'update_env': current_state, - 'update_action': chosen_action, - 'update_inference': qs_current} - agent_updates.append(agent_update) - - return {'agent_updates': agent_updates} diff --git a/blockference/tools/.ipynb_checkpoints/utils-checkpoint.py b/blockference/tools/.ipynb_checkpoints/utils-checkpoint.py deleted file mode 100644 index 1e545f7..0000000 --- a/blockference/tools/.ipynb_checkpoints/utils-checkpoint.py +++ /dev/null @@ -1,804 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns - -import pandas as pd - -import warnings -import itertools -from pymdp.maths import spm_log_single as log_stable - -EPS_VAL = 1e-16 # global constant for use in norm_dist() - - -def softmax(dist): - """ - Computes the softmax function on a set of values - """ - - output = dist - dist.max(axis=0) - output = np.exp(output) - output = output / np.sum(output, axis=0) - return output - - -def sample(probabilities): - sample_onehot = np.random.multinomial(1, probabilities.squeeze()) - return np.where(sample_onehot == 1)[0][0] - - -def sample_obj_array(arr): - """ - Sample from set of Categorical distributions, stored in the sub-arrays of an object array - """ - - samples = [sample(arr_i) for arr_i in arr] - - return samples - - -def obj_array(num_arr): - """ - Creates a generic object array with the desired number of sub-arrays, given by `num_arr` - """ - return np.empty(num_arr, dtype=object) - - -def obj_array_zeros(shape_list): - """ - Creates a numpy object array whose sub-arrays are 1-D vectors - filled with zeros, with shapes given by shape_list[i] - """ - arr = obj_array(len(shape_list)) - for i, shape in enumerate(shape_list): - arr[i] = np.zeros(shape) - return arr - - -def obj_array_uniform(shape_list): - """ - Creates a numpy object array whose sub-arrays are uniform Categorical - distributions with shapes given by shape_list[i]. The shapes (elements of shape_list) - can either be tuples or lists. - """ - arr = obj_array(len(shape_list)) - for i, shape in enumerate(shape_list): - arr[i] = norm_dist(np.ones(shape)) - return arr - - -def obj_array_ones(shape_list, scale=1.0): - arr = obj_array(len(shape_list)) - for i, shape in enumerate(shape_list): - arr[i] = scale * np.ones(shape) - - return arr - - -def onehot(value, num_values): - arr = np.zeros(num_values) - arr[value] = 1.0 - return arr - - -def random_A_matrix(num_obs, num_states): - if type(num_obs) is int: - num_obs = [num_obs] - if type(num_states) is int: - num_states = [num_states] - num_modalities = len(num_obs) - - A = obj_array(num_modalities) - for modality, modality_obs in enumerate(num_obs): - modality_shape = [modality_obs] + num_states - modality_dist = np.random.rand(*modality_shape) - A[modality] = norm_dist(modality_dist) - return A - - -def random_B_matrix(num_states, num_controls): - if type(num_states) is int: - num_states = [num_states] - if type(num_controls) is int: - num_controls = [num_controls] - num_factors = len(num_states) - assert len(num_controls) == len(num_states) - - B = obj_array(num_factors) - for factor in range(num_factors): - factor_shape = (num_states[factor], num_states[factor], num_controls[factor]) - factor_dist = np.random.rand(*factor_shape) - B[factor] = norm_dist(factor_dist) - return B - - -def random_single_categorical(shape_list): - """ - Creates a random 1-D categorical distribution (or set of 1-D categoricals, e.g. multiple marginals of different factors) and returns them in an object array - """ - - num_sub_arrays = len(shape_list) - - out = obj_array(num_sub_arrays) - - for arr_idx, shape_i in enumerate(shape_list): - out[arr_idx] = norm_dist(np.random.rand(shape_i)) - - return out - - -def construct_controllable_B(num_states, num_controls): - """ - Generates a fully controllable transition likelihood array, where each - action (control state) corresponds to a move to the n-th state from any - other state, for each control factor - """ - - num_factors = len(num_states) - - B = obj_array(num_factors) - for factor, c_dim in enumerate(num_controls): - tmp = np.eye(c_dim)[:, :, np.newaxis] - tmp = np.tile(tmp, (1, 1, c_dim)) - B[factor] = tmp.transpose(1, 2, 0) - - return B - - -def dirichlet_like(template_categorical, scale=1.0): - """ - Helper function to construct a Dirichlet distribution based on an existing Categorical distribution - """ - - if not is_obj_array(template_categorical): - warnings.warn( - "Input array is not an object array...Casting the input to an object array" - ) - template_categorical = to_obj_array(template_categorical) - - n_sub_arrays = len(template_categorical) - - dirichlet_out = obj_array(n_sub_arrays) - - for i, arr in enumerate(template_categorical): - dirichlet_out[i] = scale * arr - - return dirichlet_out - - -def get_model_dimensions(A=None, B=None): - - if A is None and B is None: - raise ValueError( - "Must provide either `A` or `B`" - ) - - if A is not None: - num_obs = [a.shape[0] for a in A] if is_obj_array(A) else [A.shape[0]] - num_modalities = len(num_obs) - else: - num_obs, num_modalities = None, None - - if B is not None: - num_states = [b.shape[0] for b in B] if is_obj_array(B) else [B.shape[0]] - num_factors = len(num_states) - else: - if A is not None: - num_states = list(A[0].shape[1:]) if is_obj_array(A) else list(A.shape[1:]) - num_factors = len(num_states) - else: - num_states, num_factors = None, None - - return num_obs, num_states, num_modalities, num_factors - - -def get_model_dimensions_from_labels(model_labels): - - modalities = model_labels['observations'] - num_modalities = len(modalities.keys()) - num_obs = [len(modalities[modality]) for modality in modalities.keys()] - - factors = model_labels['states'] - num_factors = len(factors.keys()) - num_states = [len(factors[factor]) for factor in factors.keys()] - - if 'actions' in model_labels.keys(): - - controls = model_labels['actions'] - num_control_fac = len(controls.keys()) - num_controls = [len(controls[cfac]) for cfac in controls.keys()] - - return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac - else: - return num_obs, num_modalities, num_states, num_factors - - -def norm_dist(dist): - """ Normalizes a Categorical probability distribution (or set of them) assuming sufficient statistics are stored in leading dimension""" - if dist.ndim == 3: - new_dist = np.zeros_like(dist) - for c in range(dist.shape[2]): - new_dist[:, :, c] = np.divide(dist[:, :, c], dist[:, :, c].sum(axis=0)) - return new_dist - else: - return np.divide(dist, dist.sum(axis=0)) - - -def norm_dist_obj_arr(obj_arr): - - normed_obj_array = obj_array(len(obj_arr)) - for i, arr in enumerate(obj_arr): - normed_obj_array[i] = norm_dist(arr) - - return normed_obj_array - - -def is_normalized(dist): - """ - Utility function for checking whether a single distribution or set of conditional categorical distributions is normalized. - Returns True if all distributions integrate to 1.0 - """ - - if is_obj_array(dist): - normed_arrays = [] - for i, arr in enumerate(dist): - column_sums = arr.sum(axis=0) - normed_arrays.append(np.allclose(column_sums, np.ones_like(column_sums))) - out = all(normed_arrays) - else: - column_sums = dist.sum(axis=0) - out = np.allclose(column_sums, np.ones_like(column_sums)) - - return out - - -def is_obj_array(arr): - return arr.dtype == "object" - - -def to_obj_array(arr): - if is_obj_array(arr): - return arr - obj_array_out = obj_array(1) - obj_array_out[0] = arr.squeeze() - return obj_array_out - - -def obj_array_from_list(list_input): - """ - Takes a list of `numpy.ndarray` and converts them to a `numpy.ndarray` of `dtype = object` - """ - return np.array(list_input, dtype=object) - - -def process_observation_seq(obs_seq, n_modalities, n_observations): - """ - Helper function for formatting observations - Observations can either be `int` (converted to one-hot) - or `tuple` (obs for each modality), or `list` (obs for each modality) - If list, the entries could be object arrays of one-hots, in which - case this function returns `obs_seq` as is. - """ - proc_obs_seq = obj_array(len(obs_seq)) - for t, obs_t in enumerate(obs_seq): - proc_obs_seq[t] = process_observation(obs_t, n_modalities, n_observations) - return proc_obs_seq - - -def process_observation(obs, num_modalities, num_observations): - """ - Helper function for formatting observations - USAGE NOTES: - - If `obs` is a 1D numpy array, it must be a one-hot vector, where one entry (the entry of the observation) is 1.0 - and all other entries are 0. This therefore assumes it's a single modality observation. If these conditions are met, then - this function will return `obs` unchanged. Otherwise, it'll throw an error. - - If `obs` is an int, it assumes this is a single modality observation, whose observation index is given by the value of `obs`. This function will convert - it to be a one hot vector. - - If `obs` is a list, it assumes this is a multiple modality observation, whose len is equal to the number of observation modalities, - and where each entry `obs[m]` is the index of the observation, for that modality. This function will convert it into an object array - of one-hot vectors. - - If `obs` is a tuple, same logic as applies for list (see above). - - if `obs` is a numpy object array (array of arrays), this function will return `obs` unchanged. - """ - - if isinstance(obs, np.ndarray) and not is_obj_array(obs): - assert num_modalities == 1, "If `obs` is a 1D numpy array, `num_modalities` must be equal to 1" - assert len(np.where(obs)[0]) == 1, "If `obs` is a 1D numpy array, it must be a one hot vector (e.g. np.array([0.0, 1.0, 0.0, ....]))" - - if isinstance(obs, (int, np.integer)): - obs = onehot(obs, num_observations[0]) - - if isinstance(obs, tuple) or isinstance(obs, list): - obs_arr_arr = obj_array(num_modalities) - for m in range(num_modalities): - obs_arr_arr[m] = onehot(obs[m], num_observations[m]) - obs = obs_arr_arr - - return obs - - -def convert_observation_array(obs, num_obs): - """ - Converts from SPM-style observation array to infer-actively one-hot object arrays. - - Parameters - ---------- - - 'obs' [numpy 2-D nd.array]: - SPM-style observation arrays are of shape (num_modalities, T), where each row - contains observation indices for a different modality, and columns indicate - different timepoints. Entries store the indices of the discrete observations - within each modality. - - 'num_obs' [list]: - List of the dimensionalities of the observation modalities. `num_modalities` - is calculated as `len(num_obs)` in the function to determine whether we're - dealing with a single- or multi-modality - case. - Returns - ---------- - - `obs_t`[list]: - A list with length equal to T, where each entry of the list is either a) an object - array (in the case of multiple modalities) where each sub-array is a one-hot vector - with the observation for the correspond modality, or b) a 1D numpy array (in the case - of one modality) that is a single one-hot vector encoding the observation for the - single modality. - """ - - T = obs.shape[1] - num_modalities = len(num_obs) - - # Initialise the output - obs_t = [] - # Case of one modality - if num_modalities == 1: - for t in range(T): - obs_t.append(onehot(obs[0, t] - 1, num_obs[0])) - else: - for t in range(T): - obs_AoA = obj_array(num_modalities) - for g in range(num_modalities): - # Subtract obs[g,t] by 1 to account for MATLAB vs. Python indexing - # (MATLAB is 1-indexed) - obs_AoA[g] = onehot(obs[g, t] - 1, num_obs[g]) - obs_t.append(obs_AoA) - - return obs_t - - -def insert_multiple(s, indices, items): - for idx in range(len(items)): - s.insert(indices[idx], items[idx]) - return s - - -def reduce_a_matrix(A): - """ - Utility function for throwing away dimensions (lagging dimensions, hidden state factors) - of a particular A matrix that are independent of the observation. - Parameters: - ========== - - `A` [np.ndarray]: - The A matrix or likelihood array that encodes probabilistic relationship - of the generative model between hidden state factors (lagging dimensions, columns, slices, etc...) - and observations (leading dimension, rows). - Returns: - ========= - - `A_reduced` [np.ndarray]: - The reduced A matrix, missing the lagging dimensions that correspond to hidden state factors - that are statistically independent of observations - - `original_factor_idx` [list]: - List of the indices (in terms of the original dimensionality) of the hidden state factors - that are maintained in the A matrix (and thus have an informative / non-degenerate relationship to observations - """ - - o_dim, num_states = A.shape[0], A.shape[1:] - idx_vec_s = [slice(0, o_dim)] + [slice(ns) for _, ns in enumerate(num_states)] - - original_factor_idx = [] - excluded_factor_idx = [] # the indices of the hidden state factors that are independent of the observation and thus marginalized away - for factor_i, ns in enumerate(num_states): - - level_counter = 0 - break_flag = False - while level_counter < ns and break_flag is False: - idx_vec_i = idx_vec_s.copy() - idx_vec_i[factor_i+1] = slice(level_counter, level_counter+1, None) - if not np.isclose(A.mean(axis=factor_i+1), A[tuple(idx_vec_i)].squeeze()).all(): - break_flag = True # this means they're not independent - original_factor_idx.append(factor_i) - else: - level_counter += 1 - - if break_flag is False: - excluded_factor_idx.append(factor_i+1) - - A_reduced = A.mean(axis=tuple(excluded_factor_idx)).squeeze() - - return A_reduced, original_factor_idx - - -def construct_full_a(A_reduced, original_factor_idx, num_states): - """ - Utility function for reconstruction a full A matrix from a reduced A matrix, using known factor indices - to tile out the reduced A matrix along the 'non-informative' dimensions - Parameters: - ========== - - `A_reduced` [np.ndarray]: - The reduced A matrix or likelihood array that encodes probabilistic relationship - of the generative model between hidden state factors (lagging dimensions, columns, slices, etc...) - and observations (leading dimension, rows). - - `original_factor_idx` [list]: - List of hidden state indices in terms of the full hidden state factor list, that comprise - the lagging dimensions of `A_reduced` - - `num_states` [list]: - The list of all the dimensionalities of hidden state factors in the full generative model. - `A_reduced.shape[1:]` should be equal to `num_states[original_factor_idx]` - Returns: - ========= - - `A` [np.ndarray]: - The full A matrix, containing all the lagging dimensions that correspond to hidden state factors, including - those that are statistically independent of observations - - @ NOTE: This is the "inverse" of the reduce_a_matrix function, - i.e. `reduce_a_matrix(construct_full_a(A_reduced, original_factor_idx, num_states)) == A_reduced, original_factor_idx` - """ - - o_dim = A_reduced.shape[0] # dimensionality of the support of the likelihood distribution (i.e. the number of observation levels) - full_dimensionality = [o_dim] + num_states # full dimensionality of the output (`A`) - fill_indices = [0] + [f+1 for f in original_factor_idx] # these are the indices of the dimensions we need to fill for this modality - fill_dimensions = np.delete(full_dimensionality, fill_indices) - - original_factor_dims = [num_states[f] for f in original_factor_idx] # dimensionalities of the relevant factors - prefilled_slices = [slice(0, o_dim)] + [slice(0, ns) for ns in original_factor_dims] # these are the slices that are filled out by the provided `A_reduced` - - A = np.zeros(full_dimensionality) - - for item in itertools.product(*[list(range(d)) for d in fill_dimensions]): - slice_ = list(item) - A_indices = insert_multiple(slice_, fill_indices, prefilled_slices) # here we insert the correct values for the fill indices for this slice - A[tuple(A_indices)] = A_reduced - - return A - - -def create_A_matrix_stub(model_labels): - - num_obs, _, num_states, _ = get_model_dimensions_from_labels(model_labels) - - obs_labels, state_labels = model_labels['observations'], model_labels['states'] - - state_combinations = pd.MultiIndex.from_product(list(state_labels.values()), names=list(state_labels.keys())) - num_state_combos = np.prod(num_states) # What is num_state_combos?? - # num_rows = (np.array(num_obs) * num_state_combos).sum() - num_rows = sum(num_obs) - - cell_values = np.zeros((num_rows, len(state_combinations))) - - obs_combinations = [] - for modality in obs_labels.keys(): - levels_to_combine = [[modality]] + [obs_labels[modality]] - # obs_combinations += num_state_combos * list(itertools.product(*levels_to_combine)) - obs_combinations += list(itertools.product(*levels_to_combine)) - - obs_combinations = pd.MultiIndex.from_tuples(obs_combinations, names=["Modality", "Level"]) - - A_matrix = pd.DataFrame(cell_values, index=obs_combinations, columns=state_combinations) - - return A_matrix - - -def create_B_matrix_stubs(model_labels): - - _, _, num_states, _, num_controls, _ = get_model_dimensions_from_labels(model_labels) - - state_labels = model_labels['states'] - action_labels = model_labels['actions'] - - B_matrices = {} - - for f_idx, factor in enumerate(state_labels.keys()): - - control_fac_name = list(action_labels)[f_idx] - factor_list = [state_labels[factor]] + [action_labels[control_fac_name]] - - prev_state_action_combos = pd.MultiIndex.from_product(factor_list, names=[factor, list(action_labels.keys())[f_idx]]) - - num_state_action_combos = num_states[f_idx] * num_controls[f_idx] - - num_rows = num_states[f_idx] - - cell_values = np.zeros((num_rows, num_state_action_combos)) - - next_state_list = state_labels[factor] - - B_matrix_f = pd.DataFrame(cell_values, index=next_state_list, columns=prev_state_action_combos) - - B_matrices[factor] = B_matrix_f - - return B_matrices - - -def read_A_matrix(path, num_hidden_state_factors): - raw_table = pd.read_excel(path, header=None) - level_counts = { - "index": raw_table.iloc[0, :].dropna().index[0] + 1, - "header": raw_table.iloc[0, :].dropna().index[0] + num_hidden_state_factors - 1, - } - return pd.read_excel( - path, - index_col=list(range(level_counts["index"])), - header=list(range(level_counts["header"])) - ).astype(np.float64) - - -def read_B_matrices(path): - - all_sheets = pd.read_excel(path, sheet_name=None, header=None) - - level_counts = {} - for sheet_name, raw_table in all_sheets.items(): - - level_counts[sheet_name] = { - "index": raw_table.iloc[0, :].dropna().index[0]+1, - "header": raw_table.iloc[0, :].dropna().index[0]+2, - } - - stub_dict = {} - for sheet_name, level_counts_sheet in level_counts.items(): - sheet_f = pd.read_excel( - path, - sheet_name=sheet_name, - index_col=list(range(level_counts_sheet["index"])), - header=list(range(level_counts_sheet["header"])) - ).astype(np.float64) - stub_dict[sheet_name] = sheet_f - - return stub_dict - - -def convert_A_stub_to_ndarray(A_stub, model_labels): - """ - This function converts a multi-index pandas dataframe `A_stub` into an object array of different - A matrices, one per observation modality. - """ - - num_obs, num_modalities, num_states, num_factors = get_model_dimensions_from_labels(model_labels) - - A = obj_array(num_modalities) - - for g, modality_name in enumerate(model_labels['observations'].keys()): - A[g] = A_stub.loc[modality_name].to_numpy().reshape(num_obs[g], *num_states) - assert (A[g].sum(axis=0) == 1.0).all(), 'A matrix not normalized! Check your initialization....\n' - - return A - - -def convert_B_stubs_to_ndarray(B_stubs, model_labels): - """ - This function converts a list of multi-index pandas dataframes `B_stubs` into an object array - of different B matrices, one per hidden state factor - """ - - _, _, num_states, num_factors, num_controls, num_control_fac = get_model_dimensions_from_labels(model_labels) - - B = obj_array(num_factors) - - for f, factor_name in enumerate(B_stubs.keys()): - - B[f] = B_stubs[factor_name].to_numpy().reshape(num_states[f], num_states[f], num_controls[f]) - assert (B[f].sum(axis=0) == 1.0).all(), 'B matrix not normalized! Check your initialization....\n' - - return B - - -def build_belief_array(qx): - """ - This function constructs array-ified (not nested) versions - of the posterior belief arrays, that are separated - by policy, timepoint, and hidden state factor - """ - - num_policies = len(qx) - num_timesteps = len(qx[0]) - num_factors = len(qx[0][0]) - - if num_factors > 1: - belief_array = obj_array(num_factors) - for factor in range(num_factors): - belief_array[factor] = np.zeros((num_policies, qx[0][0][factor].shape[0], num_timesteps)) - for policy_i in range(num_policies): - for timestep in range(num_timesteps): - for factor in range(num_factors): - belief_array[factor][policy_i, :, timestep] = qx[policy_i][timestep][factor] - else: - num_states = qx[0][0][0].shape[0] - belief_array = np.zeros((num_policies, num_states, num_timesteps)) - for policy_i in range(num_policies): - for timestep in range(num_timesteps): - belief_array[policy_i, :, timestep] = qx[policy_i][timestep][0] - - return belief_array - - -def build_xn_vn_array(xn): - """ - This function constructs array-ified (not nested) versions - of the posterior xn (beliefs) or vn (prediction error) arrays, that are separated - by iteration, hidden state factor, timepoint, and policy - """ - - num_policies = len(xn) - num_itr = len(xn[0]) - num_factors = len(xn[0][0]) - - if num_factors > 1: - xn_array = obj_array(num_factors) - for factor in range(num_factors): - num_states, infer_len = xn[0][0].shape - xn_array[factor] = np.zeros((num_itr, num_states, infer_len, num_policies)) - for policy_i in range(num_policies): - for itr in range(num_itr): - for factor in range(num_factors): - xn_array[factor][itr, :, :, policy_i] = xn[policy_i][itr][factor] - else: - num_states, infer_len = xn[0][0][0].shape - xn_array = np.zeros((num_itr, num_states, infer_len, num_policies)) - for policy_i in range(num_policies): - for itr in range(num_itr): - xn_array[itr, :, :, policy_i] = xn[policy_i][itr][0] - - return xn_array - - -# plotting functions -def plot_likelihood(matrix, xlabels=list(range(9)), ylabels=list(range(9)), title_str="Likelihood distribution (A)"): - """ - Plots a 2-D likelihood matrix as a heatmap - """ - - if not np.isclose(matrix.sum(axis=0), 1.0).all(): - raise ValueError("Distribution not column-normalized! Please normalize (ensure matrix.sum(axis=0) == 1.0 for all columns)") - fig = plt.figure(figsize=(6, 6)) # Unclear what "fig" is doing here. - ax=sns.heatmap(matrix, xticklabels=xlabels, yticklabels=ylabels, cmap='gray', cbar=False, vmin=0.0, vmax=1.0) # Unclear what ax is doing - plt.title(title_str) - plt.show() - - -def plot_grid(grid_locations, num_x=3, num_y=3): - """ - Plots the spatial coordinates of GridWorld as a heatmap, with each (X, Y) coordinate - labeled with its linear index (its `state id`) - """ - - grid_heatmap = np.zeros((num_x, num_y)) - for linear_idx, location in enumerate(grid_locations): - y, x = location - grid_heatmap[y, x] = linear_idx - sns.set(font_scale=1.5) - sns.heatmap(grid_heatmap, annot=True, cbar=False, fmt='.0f', cmap='crest') - - -def plot_point_on_grid(state_vector, grid_locations): - """ - Plots the current location of the agent on the grid world - """ - state_index = np.where(state_vector)[0][0] - y, x = grid_locations[state_index] - grid_heatmap = np.zeros((3, 3)) - grid_heatmap[y, x] = 1.0 - sns.heatmap(grid_heatmap, cbar=False, fmt='.0f') - - -def plot_beliefs(belief_dist, title_str=""): - """ - Plot a categorical distribution or belief distribution, stored in the 1-D numpy vector `belief_dist` - """ - - if not np.isclose(belief_dist.sum(), 1.0): - raise ValueError("Distribution not normalized! Please normalize") - - plt.grid(zorder=0) - plt.bar(range(belief_dist.shape[0]), belief_dist, color='r', zorder=3) - plt.xticks(range(belief_dist.shape[0])) - plt.title(title_str) - plt.show() - -# ActInf functions - - -def infer_states(observation_index, A, prior): - - log_likelihood = log_stable(A[observation_index, :]) - - log_prior = log_stable(prior) - - qs = softmax(log_likelihood + log_prior) - - return qs - - -def get_expected_states(B, qs_current, action): - """ Compute the expected states one step into the future, given a particular action """ - qs_u = B[:, :, action].dot(qs_current) - - return qs_u - -def get_expected_observations(A, qs_u): - """ Compute the expected observations one step into the future, given a particular action """ - - qo_u = A.dot(qs_u) - - return qo_u - -def entropy(A): - """ Compute the entropy of a set of conditional distributions, i.e. one entropy value per column """ - - H_A = - (A * log_stable(A)).sum(axis=0) - - return H_A - -def kl_divergence(qo_u, C): - """ Compute the Kullback-Leibler divergence between two 1-D categorical distributions""" - - return (log_stable(qo_u) - log_stable(C)).dot(qo_u) - -def calculate_G(A, B, C, qs_current, actions): - - G = np.zeros(len(actions)) # vector of expected free energies, one per action - - H_A = entropy(A) # entropy of the observation model, P(o|s) - - for action_i in range(len(actions)): - - qs_u = get_expected_states(B, qs_current, action_i) # expected states, under the action we're currently looping over - qo_u = get_expected_observations(A, qs_u) # expected observations, under the action we're currently looping over - - pred_uncertainty = H_A.dot(qs_u) # predicted uncertainty, i.e. expected entropy of the A matrix - pred_div = kl_divergence(qo_u, C) # predicted divergence - - G[action_i] = pred_uncertainty + pred_div # sum them together to get expected free energy - - return G - -def calculate_G_policies(A, B, C, qs_current, policies): - - G = np.zeros(len(policies)) # initialize the vector of expected free energies, one per policy - H_A = entropy(A) # can calculate the entropy of the A matrix beforehand, since it'll be the same for all policies - - for policy_id, policy in enumerate(policies): # loop over policies - policy_id will be the linear index of the policy (0, 1, 2, ...) and `policy` will be a column vector where `policy[t,0]` indexes the action entailed by that policy at time `t` - - t_horizon = policy.shape[0] # temporal depth of the policy - - G_pi = 0.0 # initialize expected free energy for this policy - - for t in range(t_horizon): # loop over temporal depth of the policy - - action = policy[t,0] # action entailed by this particular policy, at time `t` - - # get the past predictive posterior - which is either your current posterior at the current time (not the policy time) or the predictive posterior entailed by this policy, one timstep ago (in policy time) - if t == 0: - qs_prev = qs_current - else: - qs_prev = qs_pi_t # What is "qs_pi_t"? - - qs_pi_t = get_expected_states(B, qs_prev, action) # expected states, under the action entailed by the policy at this particular time - qo_pi_t = get_expected_observations(A, qs_pi_t) # expected observations, under the action entailed by the policy at this particular time - - kld = kl_divergence(qo_pi_t, C) # Kullback-Leibler divergence between expected observations and the prior preferences C - - G_pi_t = H_A.dot(qs_pi_t) + kld # predicted uncertainty + predicted divergence, for this policy & timepoint - - G_pi += G_pi_t # accumulate the expected free energy for each timepoint into the overall EFE for the policy - - G[policy_id] += G_pi - - return G - -def compute_prob_actions(actions, policies, Q_pi): - P_u = np.zeros(len(actions)) # initialize the vector of probabilities of each action - - for policy_id, policy in enumerate(policies): - P_u[int(policy[0,0])] += Q_pi[policy_id] # get the marginal probability for the given action, entailed by this policy at the first timestep - - P_u = norm_dist(P_u) # normalize the action probabilities - - return P_u diff --git a/blockference/tools/utils.py b/blockference/tools/utils.py deleted file mode 100644 index 1e545f7..0000000 --- a/blockference/tools/utils.py +++ /dev/null @@ -1,804 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns - -import pandas as pd - -import warnings -import itertools -from pymdp.maths import spm_log_single as log_stable - -EPS_VAL = 1e-16 # global constant for use in norm_dist() - - -def softmax(dist): - """ - Computes the softmax function on a set of values - """ - - output = dist - dist.max(axis=0) - output = np.exp(output) - output = output / np.sum(output, axis=0) - return output - - -def sample(probabilities): - sample_onehot = np.random.multinomial(1, probabilities.squeeze()) - return np.where(sample_onehot == 1)[0][0] - - -def sample_obj_array(arr): - """ - Sample from set of Categorical distributions, stored in the sub-arrays of an object array - """ - - samples = [sample(arr_i) for arr_i in arr] - - return samples - - -def obj_array(num_arr): - """ - Creates a generic object array with the desired number of sub-arrays, given by `num_arr` - """ - return np.empty(num_arr, dtype=object) - - -def obj_array_zeros(shape_list): - """ - Creates a numpy object array whose sub-arrays are 1-D vectors - filled with zeros, with shapes given by shape_list[i] - """ - arr = obj_array(len(shape_list)) - for i, shape in enumerate(shape_list): - arr[i] = np.zeros(shape) - return arr - - -def obj_array_uniform(shape_list): - """ - Creates a numpy object array whose sub-arrays are uniform Categorical - distributions with shapes given by shape_list[i]. The shapes (elements of shape_list) - can either be tuples or lists. - """ - arr = obj_array(len(shape_list)) - for i, shape in enumerate(shape_list): - arr[i] = norm_dist(np.ones(shape)) - return arr - - -def obj_array_ones(shape_list, scale=1.0): - arr = obj_array(len(shape_list)) - for i, shape in enumerate(shape_list): - arr[i] = scale * np.ones(shape) - - return arr - - -def onehot(value, num_values): - arr = np.zeros(num_values) - arr[value] = 1.0 - return arr - - -def random_A_matrix(num_obs, num_states): - if type(num_obs) is int: - num_obs = [num_obs] - if type(num_states) is int: - num_states = [num_states] - num_modalities = len(num_obs) - - A = obj_array(num_modalities) - for modality, modality_obs in enumerate(num_obs): - modality_shape = [modality_obs] + num_states - modality_dist = np.random.rand(*modality_shape) - A[modality] = norm_dist(modality_dist) - return A - - -def random_B_matrix(num_states, num_controls): - if type(num_states) is int: - num_states = [num_states] - if type(num_controls) is int: - num_controls = [num_controls] - num_factors = len(num_states) - assert len(num_controls) == len(num_states) - - B = obj_array(num_factors) - for factor in range(num_factors): - factor_shape = (num_states[factor], num_states[factor], num_controls[factor]) - factor_dist = np.random.rand(*factor_shape) - B[factor] = norm_dist(factor_dist) - return B - - -def random_single_categorical(shape_list): - """ - Creates a random 1-D categorical distribution (or set of 1-D categoricals, e.g. multiple marginals of different factors) and returns them in an object array - """ - - num_sub_arrays = len(shape_list) - - out = obj_array(num_sub_arrays) - - for arr_idx, shape_i in enumerate(shape_list): - out[arr_idx] = norm_dist(np.random.rand(shape_i)) - - return out - - -def construct_controllable_B(num_states, num_controls): - """ - Generates a fully controllable transition likelihood array, where each - action (control state) corresponds to a move to the n-th state from any - other state, for each control factor - """ - - num_factors = len(num_states) - - B = obj_array(num_factors) - for factor, c_dim in enumerate(num_controls): - tmp = np.eye(c_dim)[:, :, np.newaxis] - tmp = np.tile(tmp, (1, 1, c_dim)) - B[factor] = tmp.transpose(1, 2, 0) - - return B - - -def dirichlet_like(template_categorical, scale=1.0): - """ - Helper function to construct a Dirichlet distribution based on an existing Categorical distribution - """ - - if not is_obj_array(template_categorical): - warnings.warn( - "Input array is not an object array...Casting the input to an object array" - ) - template_categorical = to_obj_array(template_categorical) - - n_sub_arrays = len(template_categorical) - - dirichlet_out = obj_array(n_sub_arrays) - - for i, arr in enumerate(template_categorical): - dirichlet_out[i] = scale * arr - - return dirichlet_out - - -def get_model_dimensions(A=None, B=None): - - if A is None and B is None: - raise ValueError( - "Must provide either `A` or `B`" - ) - - if A is not None: - num_obs = [a.shape[0] for a in A] if is_obj_array(A) else [A.shape[0]] - num_modalities = len(num_obs) - else: - num_obs, num_modalities = None, None - - if B is not None: - num_states = [b.shape[0] for b in B] if is_obj_array(B) else [B.shape[0]] - num_factors = len(num_states) - else: - if A is not None: - num_states = list(A[0].shape[1:]) if is_obj_array(A) else list(A.shape[1:]) - num_factors = len(num_states) - else: - num_states, num_factors = None, None - - return num_obs, num_states, num_modalities, num_factors - - -def get_model_dimensions_from_labels(model_labels): - - modalities = model_labels['observations'] - num_modalities = len(modalities.keys()) - num_obs = [len(modalities[modality]) for modality in modalities.keys()] - - factors = model_labels['states'] - num_factors = len(factors.keys()) - num_states = [len(factors[factor]) for factor in factors.keys()] - - if 'actions' in model_labels.keys(): - - controls = model_labels['actions'] - num_control_fac = len(controls.keys()) - num_controls = [len(controls[cfac]) for cfac in controls.keys()] - - return num_obs, num_modalities, num_states, num_factors, num_controls, num_control_fac - else: - return num_obs, num_modalities, num_states, num_factors - - -def norm_dist(dist): - """ Normalizes a Categorical probability distribution (or set of them) assuming sufficient statistics are stored in leading dimension""" - if dist.ndim == 3: - new_dist = np.zeros_like(dist) - for c in range(dist.shape[2]): - new_dist[:, :, c] = np.divide(dist[:, :, c], dist[:, :, c].sum(axis=0)) - return new_dist - else: - return np.divide(dist, dist.sum(axis=0)) - - -def norm_dist_obj_arr(obj_arr): - - normed_obj_array = obj_array(len(obj_arr)) - for i, arr in enumerate(obj_arr): - normed_obj_array[i] = norm_dist(arr) - - return normed_obj_array - - -def is_normalized(dist): - """ - Utility function for checking whether a single distribution or set of conditional categorical distributions is normalized. - Returns True if all distributions integrate to 1.0 - """ - - if is_obj_array(dist): - normed_arrays = [] - for i, arr in enumerate(dist): - column_sums = arr.sum(axis=0) - normed_arrays.append(np.allclose(column_sums, np.ones_like(column_sums))) - out = all(normed_arrays) - else: - column_sums = dist.sum(axis=0) - out = np.allclose(column_sums, np.ones_like(column_sums)) - - return out - - -def is_obj_array(arr): - return arr.dtype == "object" - - -def to_obj_array(arr): - if is_obj_array(arr): - return arr - obj_array_out = obj_array(1) - obj_array_out[0] = arr.squeeze() - return obj_array_out - - -def obj_array_from_list(list_input): - """ - Takes a list of `numpy.ndarray` and converts them to a `numpy.ndarray` of `dtype = object` - """ - return np.array(list_input, dtype=object) - - -def process_observation_seq(obs_seq, n_modalities, n_observations): - """ - Helper function for formatting observations - Observations can either be `int` (converted to one-hot) - or `tuple` (obs for each modality), or `list` (obs for each modality) - If list, the entries could be object arrays of one-hots, in which - case this function returns `obs_seq` as is. - """ - proc_obs_seq = obj_array(len(obs_seq)) - for t, obs_t in enumerate(obs_seq): - proc_obs_seq[t] = process_observation(obs_t, n_modalities, n_observations) - return proc_obs_seq - - -def process_observation(obs, num_modalities, num_observations): - """ - Helper function for formatting observations - USAGE NOTES: - - If `obs` is a 1D numpy array, it must be a one-hot vector, where one entry (the entry of the observation) is 1.0 - and all other entries are 0. This therefore assumes it's a single modality observation. If these conditions are met, then - this function will return `obs` unchanged. Otherwise, it'll throw an error. - - If `obs` is an int, it assumes this is a single modality observation, whose observation index is given by the value of `obs`. This function will convert - it to be a one hot vector. - - If `obs` is a list, it assumes this is a multiple modality observation, whose len is equal to the number of observation modalities, - and where each entry `obs[m]` is the index of the observation, for that modality. This function will convert it into an object array - of one-hot vectors. - - If `obs` is a tuple, same logic as applies for list (see above). - - if `obs` is a numpy object array (array of arrays), this function will return `obs` unchanged. - """ - - if isinstance(obs, np.ndarray) and not is_obj_array(obs): - assert num_modalities == 1, "If `obs` is a 1D numpy array, `num_modalities` must be equal to 1" - assert len(np.where(obs)[0]) == 1, "If `obs` is a 1D numpy array, it must be a one hot vector (e.g. np.array([0.0, 1.0, 0.0, ....]))" - - if isinstance(obs, (int, np.integer)): - obs = onehot(obs, num_observations[0]) - - if isinstance(obs, tuple) or isinstance(obs, list): - obs_arr_arr = obj_array(num_modalities) - for m in range(num_modalities): - obs_arr_arr[m] = onehot(obs[m], num_observations[m]) - obs = obs_arr_arr - - return obs - - -def convert_observation_array(obs, num_obs): - """ - Converts from SPM-style observation array to infer-actively one-hot object arrays. - - Parameters - ---------- - - 'obs' [numpy 2-D nd.array]: - SPM-style observation arrays are of shape (num_modalities, T), where each row - contains observation indices for a different modality, and columns indicate - different timepoints. Entries store the indices of the discrete observations - within each modality. - - 'num_obs' [list]: - List of the dimensionalities of the observation modalities. `num_modalities` - is calculated as `len(num_obs)` in the function to determine whether we're - dealing with a single- or multi-modality - case. - Returns - ---------- - - `obs_t`[list]: - A list with length equal to T, where each entry of the list is either a) an object - array (in the case of multiple modalities) where each sub-array is a one-hot vector - with the observation for the correspond modality, or b) a 1D numpy array (in the case - of one modality) that is a single one-hot vector encoding the observation for the - single modality. - """ - - T = obs.shape[1] - num_modalities = len(num_obs) - - # Initialise the output - obs_t = [] - # Case of one modality - if num_modalities == 1: - for t in range(T): - obs_t.append(onehot(obs[0, t] - 1, num_obs[0])) - else: - for t in range(T): - obs_AoA = obj_array(num_modalities) - for g in range(num_modalities): - # Subtract obs[g,t] by 1 to account for MATLAB vs. Python indexing - # (MATLAB is 1-indexed) - obs_AoA[g] = onehot(obs[g, t] - 1, num_obs[g]) - obs_t.append(obs_AoA) - - return obs_t - - -def insert_multiple(s, indices, items): - for idx in range(len(items)): - s.insert(indices[idx], items[idx]) - return s - - -def reduce_a_matrix(A): - """ - Utility function for throwing away dimensions (lagging dimensions, hidden state factors) - of a particular A matrix that are independent of the observation. - Parameters: - ========== - - `A` [np.ndarray]: - The A matrix or likelihood array that encodes probabilistic relationship - of the generative model between hidden state factors (lagging dimensions, columns, slices, etc...) - and observations (leading dimension, rows). - Returns: - ========= - - `A_reduced` [np.ndarray]: - The reduced A matrix, missing the lagging dimensions that correspond to hidden state factors - that are statistically independent of observations - - `original_factor_idx` [list]: - List of the indices (in terms of the original dimensionality) of the hidden state factors - that are maintained in the A matrix (and thus have an informative / non-degenerate relationship to observations - """ - - o_dim, num_states = A.shape[0], A.shape[1:] - idx_vec_s = [slice(0, o_dim)] + [slice(ns) for _, ns in enumerate(num_states)] - - original_factor_idx = [] - excluded_factor_idx = [] # the indices of the hidden state factors that are independent of the observation and thus marginalized away - for factor_i, ns in enumerate(num_states): - - level_counter = 0 - break_flag = False - while level_counter < ns and break_flag is False: - idx_vec_i = idx_vec_s.copy() - idx_vec_i[factor_i+1] = slice(level_counter, level_counter+1, None) - if not np.isclose(A.mean(axis=factor_i+1), A[tuple(idx_vec_i)].squeeze()).all(): - break_flag = True # this means they're not independent - original_factor_idx.append(factor_i) - else: - level_counter += 1 - - if break_flag is False: - excluded_factor_idx.append(factor_i+1) - - A_reduced = A.mean(axis=tuple(excluded_factor_idx)).squeeze() - - return A_reduced, original_factor_idx - - -def construct_full_a(A_reduced, original_factor_idx, num_states): - """ - Utility function for reconstruction a full A matrix from a reduced A matrix, using known factor indices - to tile out the reduced A matrix along the 'non-informative' dimensions - Parameters: - ========== - - `A_reduced` [np.ndarray]: - The reduced A matrix or likelihood array that encodes probabilistic relationship - of the generative model between hidden state factors (lagging dimensions, columns, slices, etc...) - and observations (leading dimension, rows). - - `original_factor_idx` [list]: - List of hidden state indices in terms of the full hidden state factor list, that comprise - the lagging dimensions of `A_reduced` - - `num_states` [list]: - The list of all the dimensionalities of hidden state factors in the full generative model. - `A_reduced.shape[1:]` should be equal to `num_states[original_factor_idx]` - Returns: - ========= - - `A` [np.ndarray]: - The full A matrix, containing all the lagging dimensions that correspond to hidden state factors, including - those that are statistically independent of observations - - @ NOTE: This is the "inverse" of the reduce_a_matrix function, - i.e. `reduce_a_matrix(construct_full_a(A_reduced, original_factor_idx, num_states)) == A_reduced, original_factor_idx` - """ - - o_dim = A_reduced.shape[0] # dimensionality of the support of the likelihood distribution (i.e. the number of observation levels) - full_dimensionality = [o_dim] + num_states # full dimensionality of the output (`A`) - fill_indices = [0] + [f+1 for f in original_factor_idx] # these are the indices of the dimensions we need to fill for this modality - fill_dimensions = np.delete(full_dimensionality, fill_indices) - - original_factor_dims = [num_states[f] for f in original_factor_idx] # dimensionalities of the relevant factors - prefilled_slices = [slice(0, o_dim)] + [slice(0, ns) for ns in original_factor_dims] # these are the slices that are filled out by the provided `A_reduced` - - A = np.zeros(full_dimensionality) - - for item in itertools.product(*[list(range(d)) for d in fill_dimensions]): - slice_ = list(item) - A_indices = insert_multiple(slice_, fill_indices, prefilled_slices) # here we insert the correct values for the fill indices for this slice - A[tuple(A_indices)] = A_reduced - - return A - - -def create_A_matrix_stub(model_labels): - - num_obs, _, num_states, _ = get_model_dimensions_from_labels(model_labels) - - obs_labels, state_labels = model_labels['observations'], model_labels['states'] - - state_combinations = pd.MultiIndex.from_product(list(state_labels.values()), names=list(state_labels.keys())) - num_state_combos = np.prod(num_states) # What is num_state_combos?? - # num_rows = (np.array(num_obs) * num_state_combos).sum() - num_rows = sum(num_obs) - - cell_values = np.zeros((num_rows, len(state_combinations))) - - obs_combinations = [] - for modality in obs_labels.keys(): - levels_to_combine = [[modality]] + [obs_labels[modality]] - # obs_combinations += num_state_combos * list(itertools.product(*levels_to_combine)) - obs_combinations += list(itertools.product(*levels_to_combine)) - - obs_combinations = pd.MultiIndex.from_tuples(obs_combinations, names=["Modality", "Level"]) - - A_matrix = pd.DataFrame(cell_values, index=obs_combinations, columns=state_combinations) - - return A_matrix - - -def create_B_matrix_stubs(model_labels): - - _, _, num_states, _, num_controls, _ = get_model_dimensions_from_labels(model_labels) - - state_labels = model_labels['states'] - action_labels = model_labels['actions'] - - B_matrices = {} - - for f_idx, factor in enumerate(state_labels.keys()): - - control_fac_name = list(action_labels)[f_idx] - factor_list = [state_labels[factor]] + [action_labels[control_fac_name]] - - prev_state_action_combos = pd.MultiIndex.from_product(factor_list, names=[factor, list(action_labels.keys())[f_idx]]) - - num_state_action_combos = num_states[f_idx] * num_controls[f_idx] - - num_rows = num_states[f_idx] - - cell_values = np.zeros((num_rows, num_state_action_combos)) - - next_state_list = state_labels[factor] - - B_matrix_f = pd.DataFrame(cell_values, index=next_state_list, columns=prev_state_action_combos) - - B_matrices[factor] = B_matrix_f - - return B_matrices - - -def read_A_matrix(path, num_hidden_state_factors): - raw_table = pd.read_excel(path, header=None) - level_counts = { - "index": raw_table.iloc[0, :].dropna().index[0] + 1, - "header": raw_table.iloc[0, :].dropna().index[0] + num_hidden_state_factors - 1, - } - return pd.read_excel( - path, - index_col=list(range(level_counts["index"])), - header=list(range(level_counts["header"])) - ).astype(np.float64) - - -def read_B_matrices(path): - - all_sheets = pd.read_excel(path, sheet_name=None, header=None) - - level_counts = {} - for sheet_name, raw_table in all_sheets.items(): - - level_counts[sheet_name] = { - "index": raw_table.iloc[0, :].dropna().index[0]+1, - "header": raw_table.iloc[0, :].dropna().index[0]+2, - } - - stub_dict = {} - for sheet_name, level_counts_sheet in level_counts.items(): - sheet_f = pd.read_excel( - path, - sheet_name=sheet_name, - index_col=list(range(level_counts_sheet["index"])), - header=list(range(level_counts_sheet["header"])) - ).astype(np.float64) - stub_dict[sheet_name] = sheet_f - - return stub_dict - - -def convert_A_stub_to_ndarray(A_stub, model_labels): - """ - This function converts a multi-index pandas dataframe `A_stub` into an object array of different - A matrices, one per observation modality. - """ - - num_obs, num_modalities, num_states, num_factors = get_model_dimensions_from_labels(model_labels) - - A = obj_array(num_modalities) - - for g, modality_name in enumerate(model_labels['observations'].keys()): - A[g] = A_stub.loc[modality_name].to_numpy().reshape(num_obs[g], *num_states) - assert (A[g].sum(axis=0) == 1.0).all(), 'A matrix not normalized! Check your initialization....\n' - - return A - - -def convert_B_stubs_to_ndarray(B_stubs, model_labels): - """ - This function converts a list of multi-index pandas dataframes `B_stubs` into an object array - of different B matrices, one per hidden state factor - """ - - _, _, num_states, num_factors, num_controls, num_control_fac = get_model_dimensions_from_labels(model_labels) - - B = obj_array(num_factors) - - for f, factor_name in enumerate(B_stubs.keys()): - - B[f] = B_stubs[factor_name].to_numpy().reshape(num_states[f], num_states[f], num_controls[f]) - assert (B[f].sum(axis=0) == 1.0).all(), 'B matrix not normalized! Check your initialization....\n' - - return B - - -def build_belief_array(qx): - """ - This function constructs array-ified (not nested) versions - of the posterior belief arrays, that are separated - by policy, timepoint, and hidden state factor - """ - - num_policies = len(qx) - num_timesteps = len(qx[0]) - num_factors = len(qx[0][0]) - - if num_factors > 1: - belief_array = obj_array(num_factors) - for factor in range(num_factors): - belief_array[factor] = np.zeros((num_policies, qx[0][0][factor].shape[0], num_timesteps)) - for policy_i in range(num_policies): - for timestep in range(num_timesteps): - for factor in range(num_factors): - belief_array[factor][policy_i, :, timestep] = qx[policy_i][timestep][factor] - else: - num_states = qx[0][0][0].shape[0] - belief_array = np.zeros((num_policies, num_states, num_timesteps)) - for policy_i in range(num_policies): - for timestep in range(num_timesteps): - belief_array[policy_i, :, timestep] = qx[policy_i][timestep][0] - - return belief_array - - -def build_xn_vn_array(xn): - """ - This function constructs array-ified (not nested) versions - of the posterior xn (beliefs) or vn (prediction error) arrays, that are separated - by iteration, hidden state factor, timepoint, and policy - """ - - num_policies = len(xn) - num_itr = len(xn[0]) - num_factors = len(xn[0][0]) - - if num_factors > 1: - xn_array = obj_array(num_factors) - for factor in range(num_factors): - num_states, infer_len = xn[0][0].shape - xn_array[factor] = np.zeros((num_itr, num_states, infer_len, num_policies)) - for policy_i in range(num_policies): - for itr in range(num_itr): - for factor in range(num_factors): - xn_array[factor][itr, :, :, policy_i] = xn[policy_i][itr][factor] - else: - num_states, infer_len = xn[0][0][0].shape - xn_array = np.zeros((num_itr, num_states, infer_len, num_policies)) - for policy_i in range(num_policies): - for itr in range(num_itr): - xn_array[itr, :, :, policy_i] = xn[policy_i][itr][0] - - return xn_array - - -# plotting functions -def plot_likelihood(matrix, xlabels=list(range(9)), ylabels=list(range(9)), title_str="Likelihood distribution (A)"): - """ - Plots a 2-D likelihood matrix as a heatmap - """ - - if not np.isclose(matrix.sum(axis=0), 1.0).all(): - raise ValueError("Distribution not column-normalized! Please normalize (ensure matrix.sum(axis=0) == 1.0 for all columns)") - fig = plt.figure(figsize=(6, 6)) # Unclear what "fig" is doing here. - ax=sns.heatmap(matrix, xticklabels=xlabels, yticklabels=ylabels, cmap='gray', cbar=False, vmin=0.0, vmax=1.0) # Unclear what ax is doing - plt.title(title_str) - plt.show() - - -def plot_grid(grid_locations, num_x=3, num_y=3): - """ - Plots the spatial coordinates of GridWorld as a heatmap, with each (X, Y) coordinate - labeled with its linear index (its `state id`) - """ - - grid_heatmap = np.zeros((num_x, num_y)) - for linear_idx, location in enumerate(grid_locations): - y, x = location - grid_heatmap[y, x] = linear_idx - sns.set(font_scale=1.5) - sns.heatmap(grid_heatmap, annot=True, cbar=False, fmt='.0f', cmap='crest') - - -def plot_point_on_grid(state_vector, grid_locations): - """ - Plots the current location of the agent on the grid world - """ - state_index = np.where(state_vector)[0][0] - y, x = grid_locations[state_index] - grid_heatmap = np.zeros((3, 3)) - grid_heatmap[y, x] = 1.0 - sns.heatmap(grid_heatmap, cbar=False, fmt='.0f') - - -def plot_beliefs(belief_dist, title_str=""): - """ - Plot a categorical distribution or belief distribution, stored in the 1-D numpy vector `belief_dist` - """ - - if not np.isclose(belief_dist.sum(), 1.0): - raise ValueError("Distribution not normalized! Please normalize") - - plt.grid(zorder=0) - plt.bar(range(belief_dist.shape[0]), belief_dist, color='r', zorder=3) - plt.xticks(range(belief_dist.shape[0])) - plt.title(title_str) - plt.show() - -# ActInf functions - - -def infer_states(observation_index, A, prior): - - log_likelihood = log_stable(A[observation_index, :]) - - log_prior = log_stable(prior) - - qs = softmax(log_likelihood + log_prior) - - return qs - - -def get_expected_states(B, qs_current, action): - """ Compute the expected states one step into the future, given a particular action """ - qs_u = B[:, :, action].dot(qs_current) - - return qs_u - -def get_expected_observations(A, qs_u): - """ Compute the expected observations one step into the future, given a particular action """ - - qo_u = A.dot(qs_u) - - return qo_u - -def entropy(A): - """ Compute the entropy of a set of conditional distributions, i.e. one entropy value per column """ - - H_A = - (A * log_stable(A)).sum(axis=0) - - return H_A - -def kl_divergence(qo_u, C): - """ Compute the Kullback-Leibler divergence between two 1-D categorical distributions""" - - return (log_stable(qo_u) - log_stable(C)).dot(qo_u) - -def calculate_G(A, B, C, qs_current, actions): - - G = np.zeros(len(actions)) # vector of expected free energies, one per action - - H_A = entropy(A) # entropy of the observation model, P(o|s) - - for action_i in range(len(actions)): - - qs_u = get_expected_states(B, qs_current, action_i) # expected states, under the action we're currently looping over - qo_u = get_expected_observations(A, qs_u) # expected observations, under the action we're currently looping over - - pred_uncertainty = H_A.dot(qs_u) # predicted uncertainty, i.e. expected entropy of the A matrix - pred_div = kl_divergence(qo_u, C) # predicted divergence - - G[action_i] = pred_uncertainty + pred_div # sum them together to get expected free energy - - return G - -def calculate_G_policies(A, B, C, qs_current, policies): - - G = np.zeros(len(policies)) # initialize the vector of expected free energies, one per policy - H_A = entropy(A) # can calculate the entropy of the A matrix beforehand, since it'll be the same for all policies - - for policy_id, policy in enumerate(policies): # loop over policies - policy_id will be the linear index of the policy (0, 1, 2, ...) and `policy` will be a column vector where `policy[t,0]` indexes the action entailed by that policy at time `t` - - t_horizon = policy.shape[0] # temporal depth of the policy - - G_pi = 0.0 # initialize expected free energy for this policy - - for t in range(t_horizon): # loop over temporal depth of the policy - - action = policy[t,0] # action entailed by this particular policy, at time `t` - - # get the past predictive posterior - which is either your current posterior at the current time (not the policy time) or the predictive posterior entailed by this policy, one timstep ago (in policy time) - if t == 0: - qs_prev = qs_current - else: - qs_prev = qs_pi_t # What is "qs_pi_t"? - - qs_pi_t = get_expected_states(B, qs_prev, action) # expected states, under the action entailed by the policy at this particular time - qo_pi_t = get_expected_observations(A, qs_pi_t) # expected observations, under the action entailed by the policy at this particular time - - kld = kl_divergence(qo_pi_t, C) # Kullback-Leibler divergence between expected observations and the prior preferences C - - G_pi_t = H_A.dot(qs_pi_t) + kld # predicted uncertainty + predicted divergence, for this policy & timepoint - - G_pi += G_pi_t # accumulate the expected free energy for each timepoint into the overall EFE for the policy - - G[policy_id] += G_pi - - return G - -def compute_prob_actions(actions, policies, Q_pi): - P_u = np.zeros(len(actions)) # initialize the vector of probabilities of each action - - for policy_id, policy in enumerate(policies): - P_u[int(policy[0,0])] += Q_pi[policy_id] # get the marginal probability for the given action, entailed by this policy at the first timestep - - P_u = norm_dist(P_u) # normalize the action probabilities - - return P_u diff --git a/blockference/tools/__init__.py b/blockference/utils/__init__.py similarity index 100% rename from blockference/tools/__init__.py rename to blockference/utils/__init__.py diff --git a/blockference/tools/mutual_info.py b/blockference/utils/mutual_info.py similarity index 100% rename from blockference/tools/mutual_info.py rename to blockference/utils/mutual_info.py diff --git a/blockference/tools/policy.py b/blockference/utils/policy.py similarity index 100% rename from blockference/tools/policy.py rename to blockference/utils/policy.py diff --git a/blockference/utils/utils.py b/blockference/utils/utils.py new file mode 100644 index 0000000..1d53a5b --- /dev/null +++ b/blockference/utils/utils.py @@ -0,0 +1,161 @@ +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +import pandas as pd + +import warnings +import itertools +from pymdp.maths import spm_log_single as log_stable +from pymdp.maths import softmax +from pymdp.utils import norm_dist, sample + + +def plot_likelihood(matrix, xlabels = list(range(9)), ylabels = list(range(9)), title_str = "Likelihood distribution (A)"): + """ + Plots a 2-D likelihood matrix as a heatmap + """ + + if not np.isclose(matrix.sum(axis=0), 1.0).all(): + raise ValueError("Distribution not column-normalized! Please normalize (ensure matrix.sum(axis=0) == 1.0 for all columns)") + + fig = plt.figure(figsize = (6,6)) + ax = sns.heatmap(matrix, xticklabels = xlabels, yticklabels = ylabels, cmap = 'gray', cbar = False, vmin = 0.0, vmax = 1.0) + plt.title(title_str) + plt.show() + +def plot_grid(grid_locations, num_x = 3, num_y = 3 ): + """ + Plots the spatial coordinates of GridWorld as a heatmap, with each (X, Y) coordinate + labeled with its linear index (its `state id`) + """ + + grid_heatmap = np.zeros((num_x, num_y)) + for linear_idx, location in enumerate(grid_locations): + y, x = location + grid_heatmap[y, x] = linear_idx + sns.set(font_scale=1.5) + sns.heatmap(grid_heatmap, annot=True, cbar = False, fmt='.0f', cmap='crest') + +def plot_point_on_grid(state_vector, grid_locations): + """ + Plots the current location of the agent on the grid world + """ + state_index = np.where(state_vector)[0][0] + y, x = grid_locations[state_index] + grid_heatmap = np.zeros((3,3)) + grid_heatmap[y,x] = 1.0 + sns.heatmap(grid_heatmap, cbar = False, fmt='.0f') + +def plot_beliefs(belief_dist, title_str=""): + """ + Plot a categorical distribution or belief distribution, stored in the 1-D numpy vector `belief_dist` + """ + + if not np.isclose(belief_dist.sum(), 1.0): + raise ValueError("Distribution not normalized! Please normalize") + + plt.grid(zorder=0) + plt.bar(range(belief_dist.shape[0]), belief_dist, color='r', zorder=3) + plt.xticks(range(belief_dist.shape[0])) + plt.title(title_str) + plt.show() + +def get_expected_states(B, qs_current, action): + """ Compute the expected states one step into the future, given a particular action """ + qs_u = B[:,:,action].dot(qs_current) + + return qs_u + +def get_expected_observations(A, qs_u): + """ Compute the expected observations one step into the future, given a particular action """ + + qo_u = A.dot(qs_u) + + return qo_u + +def entropy(A): + """ Compute the entropy of a set of conditional distributions, i.e. one entropy value per column """ + + H_A = - (A * log_stable(A)).sum(axis=0) + + return H_A + +def kl_divergence(qo_u, C): + """ Compute the Kullback-Leibler divergence between two 1-D categorical distributions""" + + return (log_stable(qo_u) - log_stable(C)).dot(qo_u) + + + +def infer_states(observation_index, A, prior): + + log_likelihood = log_stable(A[observation_index, :]) + + log_prior = log_stable(prior) + + qs = softmax(log_likelihood + log_prior) + + return qs + +def calculate_G(A, B, C, qs_current, actions): + + G = np.zeros(len(actions)) # vector of expected free energies, one per action + + H_A = entropy(A) # entropy of the observation model, P(o|s) + + for action_i in range(len(actions)): + + qs_u = get_expected_states(B, qs_current, action_i) # expected states, under the action we're currently looping over + qo_u = get_expected_observations(A, qs_u) # expected observations, under the action we're currently looping over + + pred_uncertainty = H_A.dot(qs_u) # predicted uncertainty, i.e. expected entropy of the A matrix + pred_div = kl_divergence(qo_u, C) # predicted divergence + + G[action_i] = pred_uncertainty + pred_div # sum them together to get expected free energy + + return G + +def calculate_G_policies(A, B, C, qs_current, policies): + + G = np.zeros(len(policies)) # initialize the vector of expected free energies, one per policy + H_A = entropy(A) # can calculate the entropy of the A matrix beforehand, since it'll be the same for all policies + + for policy_id, policy in enumerate(policies): # loop over policies - policy_id will be the linear index of the policy (0, 1, 2, ...) and `policy` will be a column vector where `policy[t,0]` indexes the action entailed by that policy at time `t` + + t_horizon = policy.shape[0] # temporal depth of the policy + + G_pi = 0.0 # initialize expected free energy for this policy + + for t in range(t_horizon): # loop over temporal depth of the policy + + action = policy[t,0] # action entailed by this particular policy, at time `t` + + # get the past predictive posterior - which is either your current posterior at the current time (not the policy time) or the predictive posterior entailed by this policy, one timstep ago (in policy time) + if t == 0: + qs_prev = qs_current + else: + qs_prev = qs_pi_t # What is "qs_pi_t"? + + qs_pi_t = get_expected_states(B, qs_prev, action) # expected states, under the action entailed by the policy at this particular time + qo_pi_t = get_expected_observations(A, qs_pi_t) # expected observations, under the action entailed by the policy at this particular time + + kld = kl_divergence(qo_pi_t, C) # Kullback-Leibler divergence between expected observations and the prior preferences C + + G_pi_t = H_A.dot(qs_pi_t) + kld # predicted uncertainty + predicted divergence, for this policy & timepoint + + G_pi += G_pi_t # accumulate the expected free energy for each timepoint into the overall EFE for the policy + + G[policy_id] += G_pi + + return G + +def compute_prob_actions(actions, policies, Q_pi): + P_u = np.zeros(len(actions)) # initialize the vector of probabilities of each action + + for policy_id, policy in enumerate(policies): + P_u[int(policy[0,0])] += Q_pi[policy_id] # get the marginal probability for the given action, entailed by this policy at the first timestep + + P_u = norm_dist(P_u) # normalize the action probabilities + + return P_u diff --git a/notebooks/ants/blockferants.ipynb b/notebooks/ants/blockferants.ipynb new file mode 100644 index 0000000..dcb31ed --- /dev/null +++ b/notebooks/ants/blockferants.ipynb @@ -0,0 +1,441 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "734ff500-1fdb-4d6a-91db-8db84714e033", + "metadata": {}, + "source": [ + "## Active BlockferAnts" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "80406ce4-654a-4fe4-befd-029498f57dab", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "ADD_ANT_EVERY = 50\n", + "INIT_X = 20\n", + "INIT_Y = 30\n", + "\n", + "NEST_FACTOR = 0.1\n", + "\n", + "GRID = [40, 40]\n", + "GRID_SIZE = np.prod(GRID)\n", + "\n", + "FOCAL_AREA = [3, 3]\n", + "FOCAL_SIZE = np.prod(FOCAL_AREA)\n", + "ACTION_MAP = [(-1, -1), (0, -1), (1, -1), (-1, 0), (0, 0), (1, 0), (-1, 1), (0, 1), (1, 1)]\n", + "OPPOSITE_ACTIONS = list(reversed(range(len(ACTION_MAP))))\n", + "\n", + "FOOD_LOCATION = [40, 5]\n", + "FOOD_SIZE = [10, 10]\n", + "\n", + "WALL_LEFT = 15\n", + "WALL_RIGHT = 25\n", + "WALL_TOP = 10\n", + "\n", + "NUM_PHEROMONE_LEVELS = 10\n", + "DECAY_FACTOR = 0.01\n", + "\n", + "NUM_OBSERVATIONS = NUM_PHEROMONE_LEVELS\n", + "NUM_STATES = FOCAL_SIZE\n", + "NUM_ACTIONS = FOCAL_SIZE\n", + "\n", + "MAX_LEN = 500" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "abd0cd85-534e-426f-8005-126c9858f8b0", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import imageio\n", + "\n", + "matplotlib.use(\"Agg\")\n", + "\n", + "\n", + "class Ant(object):\n", + " def __init__(self, mdp, init_x, init_y):\n", + " self.mdp = mdp\n", + " self.x_pos = init_x\n", + " self.y_pos = init_y\n", + " self.traj = [(init_x, init_y)]\n", + " self.distance = []\n", + " self.backward_step = 0\n", + " self.is_returning = False\n", + "\n", + " def update_forward(self, x_pos, y_pos):\n", + " self.x_pos = x_pos\n", + " self.y_pos = y_pos\n", + " self.traj.append((x_pos, y_pos))\n", + " self.distance.append(dis(x_pos, y_pos, INIT_X, INIT_Y))\n", + "\n", + " def update_backward(self, x_pos, y_pos):\n", + " self.x_pos = x_pos\n", + " self.y_pos = y_pos\n", + " self.distance.append(dis(x_pos, y_pos, INIT_X, INIT_Y))\n", + "\n", + "\n", + "class Env(object):\n", + " def __init__(self):\n", + " self.visit_matrix = np.zeros((GRID[0], GRID[1]))\n", + " self.obs_matrix = np.zeros((cf.NUM_OBSERVATIONS, GRID[0], GRID[1]))\n", + " self.obs_matrix[0, :, :] = 1.0\n", + "\n", + " def get_A(self, ant):\n", + " A = np.zeros((NUM_OBSERVATIONS, NUM_STATES))\n", + " for s in range(NUM_STATES):\n", + " delta = ACTION_MAP[s]\n", + " A[:, s] = self.obs_matrix[:, ant.x_pos + delta[0], ant.y_pos + delta[1]]\n", + " return A\n", + "\n", + " def get_obs(self, ant):\n", + " obs_vec = self.obs_matrix[:, ant.x_pos, ant.y_pos]\n", + " return np.argmax(obs_vec)\n", + "\n", + " def check_food(self, x_pos, y_pos):\n", + " is_food = False\n", + " if (x_pos > (FOOD_LOCATION[0] - FOOD_SIZE[0])) and (\n", + " x_pos < (FOOD_LOCATION[0] + FOOD_SIZE[0])\n", + " ):\n", + " if (y_pos > (FOOD_LOCATION[1] - FOOD_SIZE[1])) and (\n", + " y_pos < (FOOD_LOCATION[1] + FOOD_SIZE[1])\n", + " ):\n", + " is_food = True\n", + " return is_food\n", + "\n", + " def check_walls(self, orig_x, orig_y, x_pos, y_pos):\n", + " valid = True\n", + " if orig_y > WALL_TOP:\n", + " if orig_x >= WALL_LEFT and x_pos <= WALL_LEFT:\n", + " valid = False\n", + " if orig_x <= WALL_RIGHT and x_pos >= WALL_RIGHT:\n", + " valid = False\n", + " if orig_y <= WALL_TOP:\n", + " if y_pos > WALL_TOP and ((x_pos < WALL_LEFT) or (x_pos > WALL_RIGHT)):\n", + " valid = False\n", + " return valid\n", + "\n", + " def step_forward(self, ant, action):\n", + " delta = ACTION_MAP[action]\n", + " x_pos = np.clip(ant.x_pos + delta[0], 1, cf.GRID[0] - 2)\n", + " y_pos = np.clip(ant.y_pos + delta[1], 1, cf.GRID[1] - 2)\n", + "\n", + " if self.check_food(x_pos, y_pos) and np.random.rand() < cf.NEST_FACTOR:\n", + " ant.is_returning = True\n", + " ant.backward_step = 0\n", + "\n", + " if self.check_walls(ant.x_pos, ant.y_pos, x_pos, y_pos):\n", + " ant.update_forward(x_pos, y_pos)\n", + "\n", + " \"\"\"\n", + " if len(ant.traj) > cf.MAX_LEN:\n", + " pos = ant.traj[0]\n", + " ant.update_backward(pos[0], pos[1])\n", + " ant.traj = []\n", + " \"\"\"\n", + "\n", + " def step_backward(self, ant):\n", + " path_len = len(ant.traj)\n", + " next_step = path_len - (ant.backward_step + 1)\n", + " pos = ant.traj[next_step]\n", + " ant.update_backward(pos[0], pos[1])\n", + "\n", + " self.visit_matrix[pos[0], pos[1]] += 1\n", + " curr_obs = np.argmax(self.obs_matrix[:, pos[0], pos[1]])\n", + " curr_obs = min(curr_obs + 1, cf.NUM_OBSERVATIONS - 1)\n", + "\n", + " self.obs_matrix[:, pos[0], pos[1]] = 0.0\n", + " self.obs_matrix[curr_obs, pos[0], pos[1]] = 1.0\n", + "\n", + " ant.backward_step += 1\n", + " if ant.backward_step >= path_len - 1:\n", + " ant.is_returning = False\n", + " traj = ant.traj\n", + " ant.traj = []\n", + " return True, traj\n", + " else:\n", + " return False, None\n", + "\n", + " def decay(self):\n", + " for x in range(cf.GRID[0]):\n", + " for y in range(cf.GRID[1]):\n", + " curr_obs = np.argmax(self.obs_matrix[:, x, y])\n", + " if (curr_obs > 0) and (np.random.rand() < DECAY_FACTOR):\n", + " curr_obs = curr_obs - 1\n", + " self.obs_matrix[:, x, y] = 0.0\n", + " self.obs_matrix[curr_obs, x, y] = 1.0\n", + "\n", + " def plot(self, ants, savefig=False, name=\"\", ant_only_gif=False):\n", + " x_pos_forward, y_pos_forward = [], []\n", + " x_pos_backward, y_pos_backward = [], []\n", + " for ant in ants:\n", + " if ant.is_returning:\n", + " x_pos_backward.append(ant.x_pos)\n", + " y_pos_backward.append(ant.y_pos)\n", + " else:\n", + " x_pos_forward.append(ant.x_pos)\n", + " y_pos_forward.append(ant.y_pos)\n", + "\n", + " img = np.ones((GRID[0], GRID[1]))\n", + " fig, ax = plt.subplots()\n", + " ax.imshow(img.T, cmap=\"gray\")\n", + " plot_matrix = np.zeros((GRID[0], GRID[1]))\n", + "\n", + " for x in range(cf.GRID[0]):\n", + " for y in range(cf.GRID[1]):\n", + " curr_obs = np.argmax(self.obs_matrix[:, x, y])\n", + " plot_matrix[x, y] = curr_obs\n", + "\n", + " if ant_only_gif == False:\n", + " ax.imshow(plot_matrix.T, alpha=0.7, vmin=0)\n", + "\n", + " if not savefig:\n", + " ax.scatter(x_pos_forward, y_pos_forward, color=\"red\", s=5)\n", + " ax.scatter(x_pos_backward, y_pos_backward, color=\"blue\", s=5)\n", + "\n", + " if not savefig:\n", + " fig.canvas.draw()\n", + " img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=\"uint8\")\n", + " img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n", + " plt.close(\"all\")\n", + " return img\n", + " else:\n", + " plt.savefig(name)\n", + " plt.close(\"all\")\n", + "\n", + "\n", + "class MDP(object):\n", + " def __init__(self, A, B, C):\n", + " self.A = A\n", + " self.B = B\n", + " self.C = C\n", + " self.p0 = np.exp(-16)\n", + "\n", + " self.num_states = self.A.shape[1]\n", + " self.num_obs = self.A.shape[0]\n", + " self.num_actions = self.B.shape[0]\n", + "\n", + " self.A = self.A + self.p0\n", + " self.A = self.normdist(self.A)\n", + " self.lnA = np.log(self.A)\n", + "\n", + " self.B = self.B + self.p0\n", + " for a in range(self.num_actions):\n", + " self.B[a] = self.normdist(self.B[a])\n", + "\n", + " self.C = self.C + self.p0\n", + " self.C = self.normdist(self.C)\n", + "\n", + " self.sQ = np.zeros([self.num_states, 1])\n", + " self.uQ = np.zeros([self.num_actions, 1])\n", + " self.prev_action = None\n", + " self.t = 0\n", + "\n", + " def set_A(self, A):\n", + " self.A = A + self.p0\n", + " self.A = self.normdist(self.A)\n", + " self.lnA = np.log(self.A)\n", + "\n", + " def reset(self, obs):\n", + " self.t = 0\n", + " self.curr_obs = obs\n", + " likelihood = self.lnA[obs, :]\n", + " likelihood = likelihood[:, np.newaxis]\n", + " self.sQ = self.softmax(likelihood)\n", + " self.prev_action = self.random_action()\n", + "\n", + " def step(self, obs):\n", + " \"\"\" state inference \"\"\"\n", + " likelihood = self.lnA[obs, :]\n", + " likelihood = likelihood[:, np.newaxis]\n", + " prior = np.dot(self.B[self.prev_action], self.sQ)\n", + " prior = np.log(prior)\n", + " self.sQ = self.softmax(prior)\n", + "\n", + " \"\"\" action inference \"\"\"\n", + " SCALE = 10\n", + " neg_efe = np.zeros([self.num_actions, 1])\n", + " for a in range(self.num_actions):\n", + " fs = np.dot(self.B[a], self.sQ)\n", + " fo = np.dot(self.A, fs)\n", + " fo = self.normdist(fo + self.p0)\n", + " utility = np.sum(fo * np.log(fo / self.C), axis=0)\n", + " utility = utility[0]\n", + " neg_efe[a] -= utility / SCALE\n", + "\n", + " # priors\n", + " neg_efe[4] -= 20.0\n", + " neg_efe[cf.OPPOSITE_ACTIONS[self.prev_action]] -= 20.0 # type: ignore\n", + "\n", + " # action selection\n", + " self.uQ = self.softmax(neg_efe)\n", + " action = np.argmax(np.random.multinomial(1, self.uQ.squeeze()))\n", + " self.prev_action = action\n", + " return action\n", + "\n", + " def random_action(self):\n", + " return int(np.random.choice(range(self.num_actions)))\n", + "\n", + " @staticmethod\n", + " def softmax(x):\n", + " x = x - x.max()\n", + " x = np.exp(x)\n", + " x = x / np.sum(x)\n", + " return x\n", + "\n", + " @staticmethod\n", + " def normdist(x):\n", + " return np.dot(x, np.diag(1 / np.sum(x, 0)))\n", + "\n", + "\n", + "def create_ant(init_x, init_y, C):\n", + " A = np.zeros((NUM_OBSERVATIONS, NUM_STATES))\n", + " B = np.zeros((NUM_ACTIONS, NUM_STATES, NUM_STATES))\n", + " for a in range(NUM_ACTIONS):\n", + " B[a, a, :] = 1.0\n", + " mdp = MDP(A, B, C)\n", + " ant = Ant(mdp, init_x, init_y)\n", + " return ant\n", + "\n", + "\n", + "def dis(x1, y1, x2, y2):\n", + " return np.sqrt(((x1 - x2) ** 2) + ((y1 - y2) ** 2))\n", + "\n", + "\n", + "def plot_path(path, save_name):\n", + " path = np.array(path)\n", + " _, ax = plt.subplots(1, 1)\n", + " ax.set_xlim(GRID[0])\n", + " ax.set_ylim(GRID[1])\n", + " ax.plot(path[:, 0], path[:, 1], \"-o\", color=\"red\", alpha=0.4)\n", + " plt.savefig(save_name)\n", + " plt.close(\"all\")\n", + "\n", + "\n", + "def save_gif(imgs, path, fps=32):\n", + " imageio.mimsave(path, imgs, fps=fps)\n", + "\n", + "\n", + "def main(num_steps, init_ants, max_ants, C, save=True, switch=False, name=\"\", ant_only_gif=False):\n", + " env = Env()\n", + " ants = []\n", + " paths = []\n", + " for _ in range(init_ants):\n", + " ant = create_ant(INIT_X, INIT_Y, C)\n", + " obs = env.get_obs(ant)\n", + " A = env.get_A(ant)\n", + " ant.mdp.set_A(A)\n", + " ant.mdp.reset(obs)\n", + " ants.append(ant)\n", + "\n", + " imgs = []\n", + " completed_trips = 0\n", + " distance = 0\n", + " ant_locations = []\n", + " round_trips_over_time = []\n", + " for t in range(num_steps):\n", + " t_dis = 0\n", + "\n", + " for ant in ants:\n", + " for ant_2 in ants:\n", + " t_dis += dis(ant.x_pos, ant.y_pos, ant_2.x_pos, ant_2.y_pos)\n", + " distance += t_dis / len(ants)\n", + "\n", + " if t % (num_steps // 100) == 0:\n", + " print(f\"{t}/{num_steps}\")\n", + "\n", + " if t % ADD_ANT_EVERY == 0 and len(ants) < max_ants:\n", + " ant = create_ant(INIT_X, INIT_Y, C)\n", + " obs = env.get_obs(ant)\n", + " A = env.get_A(ant)\n", + " ant.mdp.set_A(A)\n", + " ant.mdp.reset(obs)\n", + " ants.append(ant)\n", + "\n", + " if switch and t % (num_steps // 2) == 0:\n", + " # switch\n", + " FOOD_LOCATION[0] = GRID[0] - FOOD_LOCATION[0]\n", + "\n", + " for ant in ants:\n", + " if not ant.is_returning:\n", + " obs = env.get_obs(ant)\n", + " A = env.get_A(ant)\n", + " ant.mdp.set_A(A)\n", + " action = ant.mdp.step(obs)\n", + " env.step_forward(ant, action)\n", + " else:\n", + " is_complete, traj = env.step_backward(ant)\n", + " completed_trips += int(is_complete)\n", + "\n", + " if is_complete:\n", + " paths.append(traj)\n", + " env.decay()\n", + "\n", + " if save:\n", + " if t in np.arange(0, num_steps, num_steps // 20):\n", + " env.plot(ants, savefig=True, name=f\"imgs/{name}_{t}.png\")\n", + " else:\n", + " img = env.plot(ants, ant_only_gif=ant_only_gif)\n", + " imgs.append(img)\n", + "\n", + " round_trips_over_time.append(completed_trips / max_ants)\n", + " ant_locations.append([[ant.x_pos, ant.y_pos] for ant in ants])\n", + "\n", + " \"\"\"\n", + " dis_coeff = 0\n", + " for ant in ants:\n", + " dis_coeff += sum(ant.distance)\n", + " \"\"\"\n", + "\n", + " if save:\n", + " save_gif(imgs, f\"imgs/{name}.gif\")\n", + "\n", + " ant_locations = np.array(ant_locations)\n", + " round_trips_over_time = np.array(round_trips_over_time)\n", + " np.save(f\"imgs/{name}_locations\", ant_locations)\n", + " np.save(f\"imgs/{name}_round_trips\", round_trips_over_time)\n", + "\n", + " return completed_trips, np.array(paths), distance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a80d2fb-1d36-486e-a7d9-4f397d246b1e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "block", + "language": "python", + "name": "block" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/02_actinf_with_agent.ipynb b/notebooks/simple_gridworld/02_actinf_with_agent.ipynb similarity index 100% rename from notebooks/02_actinf_with_agent.ipynb rename to notebooks/simple_gridworld/02_actinf_with_agent.ipynb diff --git a/notebooks/04_actinf_graphs.ipynb b/notebooks/simple_gridworld/04_actinf_graphs.ipynb similarity index 100% rename from notebooks/04_actinf_graphs.ipynb rename to notebooks/simple_gridworld/04_actinf_graphs.ipynb diff --git a/notebooks/05_actinf_multi_agent.ipynb b/notebooks/simple_gridworld/05_actinf_multi_agent.ipynb similarity index 100% rename from notebooks/05_actinf_multi_agent.ipynb rename to notebooks/simple_gridworld/05_actinf_multi_agent.ipynb diff --git a/notebooks/01_actinf_planning.ipynb b/notebooks/simple_gridworld/gridference_single.ipynb similarity index 65% rename from notebooks/01_actinf_planning.ipynb rename to notebooks/simple_gridworld/gridference_single.ipynb index 84b00f3..cbd4f22 100644 --- a/notebooks/01_actinf_planning.ipynb +++ b/notebooks/simple_gridworld/gridference_single.ipynb @@ -157,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -169,13 +169,14 @@ "import seaborn as sns\n", "import sys\n", "\n", - "sys.path.insert(0, '../')\n", + "sys.path.insert(0, '../../')\n", "\n", "from radcad import Model, Simulation, Experiment\n", "\n", "\n", "# Additional dependencies\n", "from pymdp.control import construct_policies\n", + "from pymdp.maths import softmax\n", "\n", "# For analytics\n", "import itertools\n", @@ -184,7 +185,7 @@ "import plotly.express as px\n", "\n", "# local utils\n", - "import blockference.tools.utils as u\n", + "import blockference.utils.utils as u\n", "from blockference.gridference import ActiveGridference" ] }, @@ -197,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -214,7 +215,7 @@ "grid = list(itertools.product(range(10), repeat=2))\n", "print(grid)\n", "act = ActiveGridference(grid)\n", - "act.get_C((9, 3))\n", + "act.get_C((2, 3))\n", "act.get_D((0, 0))\n", "print(act.border)\n", "print(len(grid))" @@ -222,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -240,7 +241,6 @@ } ], "source": [ - "from pymdp.maths import softmax\n", "\n", "print(softmax(act.A))" ] @@ -254,7 +254,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -278,17 +278,11 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "params = {\n", - " 'prior_A': softmax(act.A),\n", - " 'prior_B': act.B,\n", - " 'prior_C': act.C,\n", - " 'prior': softmax(act.D),\n", - " 'env_state': grid,\n", - " 'actions': act.E\n", "}" ] }, @@ -309,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -325,7 +319,7 @@ " G = u.calculate_G_policies(previous_state['prior_A'], previous_state['prior_B'], previous_state['prior_C'], qs_current, policies=policies)\n", "\n", " # calc action posterior\n", - " Q_pi = u.softmax(-G)\n", + " Q_pi = softmax(-G)\n", "\n", " # compute the probability of each action\n", " P_u = u.compute_prob_actions(act.E, policies, Q_pi)\n", @@ -381,7 +375,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -409,7 +403,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -437,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -460,20 +454,20 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "simulation = Simulation(\n", " model=model,\n", - " timesteps=100, # Number of timesteps\n", + " timesteps=20, # Number of timesteps\n", " runs=1 # Number of Monte Carlo Runs\n", ")" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -489,7 +483,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -548,9 +542,9 @@ " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [0.0694531596563796, 0.00939946303377394, 0.00...\n", + " [0.07885262269015354, 0.00939946303377394, 0.0...\n", " (0, 0)\n", - " 4\n", + " 2\n", " [0.0694531596563796, 0.00939946303377394, 0.00...\n", " 0\n", " 0\n", @@ -563,10 +557,10 @@ " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " (1, 0)\n", - " 1\n", - " [0.16866478870681606, 0.008397325366597817, 0....\n", + " [0.0, 0.18876736668835936, 0.00827788401338408...\n", + " (0, 1)\n", + " 3\n", + " [0.18876736668835936, 0.008277884013384086, 0....\n", " 0\n", " 0\n", " 1\n", @@ -578,10 +572,10 @@ " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [0.35546098713663643, 0.00651049507942799, 0.0...\n", - " (0, 0)\n", + " [1.5101684899968635e-16, 0.3999520191586417, 0...\n", + " (0, 1)\n", " 0\n", - " [7.753058021694268e-17, 7.753058021694268e-17,...\n", + " [7.550842449984317e-17, 0.3874510195577802, 0....\n", " 0\n", " 0\n", " 1\n", @@ -593,194 +587,401 @@ " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [0.603901424016928, 0.00808364440781786, 0.008...\n", - " (0, 0)\n", + " [1.4877450397086345e-16, 0.6443591692583799, 0...\n", + " (0, 1)\n", + " 4\n", + " [1.4877450397086345e-16, 0.6443591692583799, 0...\n", " 0\n", - " [0.5998596018130191, 0.004041822203908954, 0.0...\n", + " 0\n", + " 1\n", + " 1\n", + " 4\n", + " \n", + " \n", + " 5\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " (1, 1)\n", + " 1\n", + " [1.180597986182839e-16, 0.8312251288310746, 0....\n", " 0\n", " 0\n", " 1\n", " 1\n", + " 5\n", + " \n", + " \n", + " 6\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [4.1181428736319827e-17, 4.1181428736319827e-1...\n", + " (1, 1)\n", " 4\n", + " [4.1181428736319827e-17, 4.1181428736319827e-1...\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 6\n", + " \n", + " \n", + " 7\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [0.0, 5.432448290832731e-17, 5.432448290832731...\n", + " (1, 2)\n", + " 3\n", + " [5.432448290832731e-17, 5.432448290832731e-17,...\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 7\n", " \n", " \n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", + " 8\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [7.484111208766037e-17, 1.2249870199770002e-16...\n", + " (0, 2)\n", + " 0\n", + " [3.7420556043830184e-17, 5.7749079616062e-17, ...\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 8\n", " \n", " \n", - " 10095\n", + " 9\n", " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [8.710941655795045e-17, 8.710941655795045e-17,...\n", - " (6, 4)\n", + " [1.156059852219525e-16, 1.4290141914068662e-16...\n", + " (0, 2)\n", + " 0\n", + " [6.472993992872797e-17, 8.23738046650617e-17, ...\n", " 0\n", - " [3.678794411714443e-17, 3.678794411714443e-17,...\n", " 0\n", - " 99\n", " 1\n", " 1\n", - " 96\n", + " 9\n", " \n", " \n", - " 10096\n", + " 10\n", " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [1.2596545076772043e-16, 1.2713499590296219e-1...\n", - " (5, 4)\n", + " [7.949932295795853e-17, 8.956383259645139e-17,...\n", + " (0, 2)\n", + " 4\n", + " [7.949932295795853e-17, 8.956383259645139e-17,...\n", " 0\n", - " [6.883370760125423e-17, 6.883370760125423e-17,...\n", " 0\n", - " 99\n", " 1\n", " 1\n", - " 97\n", + " 10\n", " \n", " \n", - " 10097\n", + " 11\n", " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [0.0, 8.312804375248272e-17, 8.355829536326019...\n", - " (5, 5)\n", + " [0.0, 6.60899954562223e-17, 6.979565509502117e...\n", + " (0, 3)\n", " 3\n", - " [8.312804375248272e-17, 8.355829536326019e-17,...\n", + " [6.60899954562223e-17, 6.979565509502117e-17, ...\n", + " 0\n", " 0\n", - " 99\n", " 1\n", " 1\n", - " 98\n", + " 11\n", " \n", " \n", - " 10098\n", + " 12\n", " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [0.0, 3.678794411714443e-17, 6.736904239848323...\n", - " (5, 6)\n", - " 3\n", - " [3.678794411714443e-17, 6.736904239848323e-17,...\n", + " [9.791952483089828e-17, 6.248378424280491e-17,...\n", + " (0, 2)\n", + " 2\n", + " [3.679940114359091e-17, 6.112012368730737e-17,...\n", + " 0\n", " 0\n", - " 99\n", " 1\n", " 1\n", - " 99\n", + " 12\n", " \n", " \n", - " 10099\n", + " 13\n", " [[0.02672363098939523, 0.009831074434450556, 0...\n", " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", - " [3.678794411714443e-17, 5.0321472440806026e-17...\n", - " (5, 6)\n", + " [1.3260015997376943e-16, 0.9999333151346724, 2...\n", + " (0, 1)\n", + " 2\n", + " [7.281886706856661e-17, 5.978129290520284e-17,...\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 13\n", + " \n", + " \n", + " 14\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " (1, 1)\n", + " 1\n", + " [8.557242399117368e-17, 0.9999754669748677, 7....\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 14\n", + " \n", + " \n", + " 15\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [1.0505785297112441e-16, 0.9999909746644572, 2...\n", + " (0, 1)\n", + " 0\n", + " [3.678851462715861e-17, 3.678851462715861e-17,...\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 15\n", + " \n", + " \n", + " 16\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " (1, 1)\n", + " 1\n", + " [7.543699873446601e-17, 0.9999966797456583, 1....\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 16\n", + " \n", + " \n", + " 17\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [3.678802132788118e-17, 3.678802132788118e-17,...\n", + " (1, 1)\n", " 4\n", - " [3.678794411714443e-17, 5.0321472440806026e-17...\n", + " [3.678802132788118e-17, 3.678802132788118e-17,...\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 17\n", + " \n", + " \n", + " 18\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [1.1085239629760632e-16, 0.999999550651144, 1....\n", + " (0, 1)\n", + " 0\n", + " [5.032153969868005e-17, 5.032153969868005e-17,...\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 18\n", + " \n", + " \n", + " 19\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [1.3594378840824958e-16, 0.9999998382108979, 7...\n", + " (0, 1)\n", + " 0\n", + " [7.756828375232375e-17, 0.9999998346937435, 5....\n", + " 0\n", + " 0\n", + " 1\n", + " 1\n", + " 19\n", + " \n", + " \n", + " 20\n", + " [[0.02672363098939523, 0.009831074434450556, 0...\n", + " [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0...\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n", + " [1.4506197601108719e-16, 0.9999999417749954, 3...\n", + " (0, 1)\n", + " 0\n", + " [8.67988779044398e-17, 0.9999999404811064, 2.5...\n", + " 0\n", " 0\n", - " 99\n", " 1\n", " 1\n", - " 100\n", + " 20\n", " \n", " \n", "\n", - "

10100 rows × 12 columns

\n", "" ], "text/plain": [ - " prior_A \\\n", - "0 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "1 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "2 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "3 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "4 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "... ... \n", - "10095 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "10096 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "10097 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "10098 [[0.02672363098939523, 0.009831074434450556, 0... \n", - "10099 [[0.02672363098939523, 0.009831074434450556, 0... \n", + " prior_A \\\n", + "0 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "1 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "2 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "3 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "4 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "5 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "6 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "7 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "8 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "9 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "10 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "11 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "12 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "13 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "14 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "15 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "16 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "17 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "18 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "19 [[0.02672363098939523, 0.009831074434450556, 0... \n", + "20 [[0.02672363098939523, 0.009831074434450556, 0... \n", "\n", - " prior_B \\\n", - "0 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "1 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "2 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "3 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "4 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "... ... \n", - "10095 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "10096 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "10097 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "10098 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", - "10099 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + " prior_B \\\n", + "0 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "1 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "2 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "3 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "4 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "5 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "6 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "7 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "8 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "9 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "10 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "11 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "12 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "13 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "14 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "15 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "16 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "17 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "18 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "19 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", + "20 [[[1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0... \n", "\n", - " prior_C \\\n", - "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "4 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "... ... \n", - "10095 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "10096 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "10097 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "10098 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", - "10099 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + " prior_C \\\n", + "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "4 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "5 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "6 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "7 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "8 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "9 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "11 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "12 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "13 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "14 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "15 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "16 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "17 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "18 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "19 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "20 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", "\n", - " prior env_state action \\\n", - "0 [0.02672363098939523, 0.009831074434450556, 0.... (0, 0) \n", - "1 [0.0694531596563796, 0.00939946303377394, 0.00... (0, 0) 4 \n", - "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... (1, 0) 1 \n", - "3 [0.35546098713663643, 0.00651049507942799, 0.0... (0, 0) 0 \n", - "4 [0.603901424016928, 0.00808364440781786, 0.008... (0, 0) 0 \n", - "... ... ... ... \n", - "10095 [8.710941655795045e-17, 8.710941655795045e-17,... (6, 4) 0 \n", - "10096 [1.2596545076772043e-16, 1.2713499590296219e-1... (5, 4) 0 \n", - "10097 [0.0, 8.312804375248272e-17, 8.355829536326019... (5, 5) 3 \n", - "10098 [0.0, 3.678794411714443e-17, 6.736904239848323... (5, 6) 3 \n", - "10099 [3.678794411714443e-17, 5.0321472440806026e-17... (5, 6) 4 \n", + " prior env_state action \\\n", + "0 [0.02672363098939523, 0.009831074434450556, 0.... (0, 0) \n", + "1 [0.07885262269015354, 0.00939946303377394, 0.0... (0, 0) 2 \n", + "2 [0.0, 0.18876736668835936, 0.00827788401338408... (0, 1) 3 \n", + "3 [1.5101684899968635e-16, 0.3999520191586417, 0... (0, 1) 0 \n", + "4 [1.4877450397086345e-16, 0.6443591692583799, 0... (0, 1) 4 \n", + "5 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... (1, 1) 1 \n", + "6 [4.1181428736319827e-17, 4.1181428736319827e-1... (1, 1) 4 \n", + "7 [0.0, 5.432448290832731e-17, 5.432448290832731... (1, 2) 3 \n", + "8 [7.484111208766037e-17, 1.2249870199770002e-16... (0, 2) 0 \n", + "9 [1.156059852219525e-16, 1.4290141914068662e-16... (0, 2) 0 \n", + "10 [7.949932295795853e-17, 8.956383259645139e-17,... (0, 2) 4 \n", + "11 [0.0, 6.60899954562223e-17, 6.979565509502117e... (0, 3) 3 \n", + "12 [9.791952483089828e-17, 6.248378424280491e-17,... (0, 2) 2 \n", + "13 [1.3260015997376943e-16, 0.9999333151346724, 2... (0, 1) 2 \n", + "14 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... (1, 1) 1 \n", + "15 [1.0505785297112441e-16, 0.9999909746644572, 2... (0, 1) 0 \n", + "16 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... (1, 1) 1 \n", + "17 [3.678802132788118e-17, 3.678802132788118e-17,... (1, 1) 4 \n", + "18 [1.1085239629760632e-16, 0.999999550651144, 1.... (0, 1) 0 \n", + "19 [1.3594378840824958e-16, 0.9999998382108979, 7... (0, 1) 0 \n", + "20 [1.4506197601108719e-16, 0.9999999417749954, 3... (0, 1) 0 \n", "\n", - " current_inference simulation subset \\\n", - "0 0 0 \n", - "1 [0.0694531596563796, 0.00939946303377394, 0.00... 0 0 \n", - "2 [0.16866478870681606, 0.008397325366597817, 0.... 0 0 \n", - "3 [7.753058021694268e-17, 7.753058021694268e-17,... 0 0 \n", - "4 [0.5998596018130191, 0.004041822203908954, 0.0... 0 0 \n", - "... ... ... ... \n", - "10095 [3.678794411714443e-17, 3.678794411714443e-17,... 0 99 \n", - "10096 [6.883370760125423e-17, 6.883370760125423e-17,... 0 99 \n", - "10097 [8.312804375248272e-17, 8.355829536326019e-17,... 0 99 \n", - "10098 [3.678794411714443e-17, 6.736904239848323e-17,... 0 99 \n", - "10099 [3.678794411714443e-17, 5.0321472440806026e-17... 0 99 \n", + " current_inference simulation subset \\\n", + "0 0 0 \n", + "1 [0.0694531596563796, 0.00939946303377394, 0.00... 0 0 \n", + "2 [0.18876736668835936, 0.008277884013384086, 0.... 0 0 \n", + "3 [7.550842449984317e-17, 0.3874510195577802, 0.... 0 0 \n", + "4 [1.4877450397086345e-16, 0.6443591692583799, 0... 0 0 \n", + "5 [1.180597986182839e-16, 0.8312251288310746, 0.... 0 0 \n", + "6 [4.1181428736319827e-17, 4.1181428736319827e-1... 0 0 \n", + "7 [5.432448290832731e-17, 5.432448290832731e-17,... 0 0 \n", + "8 [3.7420556043830184e-17, 5.7749079616062e-17, ... 0 0 \n", + "9 [6.472993992872797e-17, 8.23738046650617e-17, ... 0 0 \n", + "10 [7.949932295795853e-17, 8.956383259645139e-17,... 0 0 \n", + "11 [6.60899954562223e-17, 6.979565509502117e-17, ... 0 0 \n", + "12 [3.679940114359091e-17, 6.112012368730737e-17,... 0 0 \n", + "13 [7.281886706856661e-17, 5.978129290520284e-17,... 0 0 \n", + "14 [8.557242399117368e-17, 0.9999754669748677, 7.... 0 0 \n", + "15 [3.678851462715861e-17, 3.678851462715861e-17,... 0 0 \n", + "16 [7.543699873446601e-17, 0.9999966797456583, 1.... 0 0 \n", + "17 [3.678802132788118e-17, 3.678802132788118e-17,... 0 0 \n", + "18 [5.032153969868005e-17, 5.032153969868005e-17,... 0 0 \n", + "19 [7.756828375232375e-17, 0.9999998346937435, 5.... 0 0 \n", + "20 [8.67988779044398e-17, 0.9999999404811064, 2.5... 0 0 \n", "\n", - " run substep timestep \n", - "0 1 0 0 \n", - "1 1 1 1 \n", - "2 1 1 2 \n", - "3 1 1 3 \n", - "4 1 1 4 \n", - "... ... ... ... \n", - "10095 1 1 96 \n", - "10096 1 1 97 \n", - "10097 1 1 98 \n", - "10098 1 1 99 \n", - "10099 1 1 100 \n", - "\n", - "[10100 rows x 12 columns]" + " run substep timestep \n", + "0 1 0 0 \n", + "1 1 1 1 \n", + "2 1 1 2 \n", + "3 1 1 3 \n", + "4 1 1 4 \n", + "5 1 1 5 \n", + "6 1 1 6 \n", + "7 1 1 7 \n", + "8 1 1 8 \n", + "9 1 1 9 \n", + "10 1 1 10 \n", + "11 1 1 11 \n", + "12 1 1 12 \n", + "13 1 1 13 \n", + "14 1 1 14 \n", + "15 1 1 15 \n", + "16 1 1 16 \n", + "17 1 1 17 \n", + "18 1 1 18 \n", + "19 1 1 19 \n", + "20 1 1 20 " ] }, - "execution_count": 16, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -790,34 +991,14 @@ "df" ] }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAc40lEQVR4nO3cf5xdVXnv8c8zkx8kGSDa1CmSYPASW1O0wkQoUgtT8BqgQm9NFVp+tWLurQZrESqoRUt9edVWa2/BlhawFi9OkVYNkDaIDJaq0CQKmAAhQ8KvkAD5RTL5NZmZp388z+Yc42RmSE6GYeX7fr3Oa9bee+2911p77Wevvc8+Y+6OiIi88jW93AUQEZHGUEAXESmEArqISCEU0EVECqGALiJSiDEv146nTJni06dP36dtbN26lUmTJh0w6dFSDtVZ9VedG1PPvbFkyZJ17v7zAy5095fl09bW5vuqs7PzgEqPlnKoziOXHi3lUJ0bn95bwGLfQ1zVIxcRkUIooIuIFEIBXUSkEAroIiKFUEAXESnEkAHdzG4ws+fMbOkelpuZ/T8z6zKzB83s2MYXU0REhjKcEfo/ArMHWX4aMCM/c4G/3fdiiYjISzVkQHf3/wA2DJLlLOCf8hXJe4HJZnZYowooIiLDYz6M/4duZtOB29z96AGW3QZ81t3/M6e/C3zU3RcPkHcuMYqntbW1raOjY58K393dTUtLywGTHi3lUJ1Vf9W5MfXcG+3t7UvcfdaAC/f0i6P6DzAdWLqHZbcBv1Y3/V1g1lDb1C9FX2Iaap/RUqYRSI+Wcqj+I5ceLeUYiXruDfbzL0VXA9PqpqfmPBERGUGNCOjzgfPzbZdfBV5w9zUN2K6IiLwEQ/63RTP7OnAyMMXMngY+CYwFcPe/AxYApwNdwDbg9/dXYUVEZM+GDOjufs4Qyx34YMNKJCIie0W/FBURKYQCuohIIRTQRUQKoYAuIlIIBXQRkUIooIuIFEIBXUSkEAroIiKFUEAXESmEArqISCEU0EVECqGALiJSCAV0EZFCKKCLiBRCAV1EpBAK6CIihVBAFxEphAK6iEghFNBFRAqhgC4iUggFdBGRQiigi4gUQgFdRKQQCugiIoVQQBcRKYQCuohIIRTQRUQKoYAuIlIIBXQRkUIooIuIFEIBXUSkEAroIiKFGFZAN7PZZrbczLrM7PIBlh9hZp1m9mMze9DMTm98UUVEZDBDBnQzawauAU4DZgLnmNnM3bJ9ArjZ3Y8Bzga+3OiCiojI4IYzQj8O6HL3le7eA3QAZ+2Wx4FDMn0o8EzjiigiIsNh7j54BrM5wGx3vyinzwOOd/d5dXkOA+4AXgVMAk519yUDbGsuMBegtbW1raOjY58K393dTUtLywGRPrm9/cV6393ZOSrKNBLp0VIO1V913h/13Bvt7e1L3H3WgAvdfdAPMAe4rm76PODq3fJcAnwk0ycADwFNg223ra3N91VnZ+eBk4baZ7SUaQTSo6Ucqv/IpUdLOUainnsDWOx7iKvDeeSyGphWNz0159V7H3BzXiB+CBwETBnGtkVEpEGGE9AXATPM7EgzG0d86Tl/tzxPAqcAmNkbiYD+fCMLKiIigxsyoLt7LzAPWAg8TLzNsszMrjKzMzPbR4D3m9kDwNeBC/PWQERERsiY4WRy9wXAgt3mXVmXfgg4sbFFExGRl0K/FBURKYQCuohIIRTQRUQKoYAuIlIIBXQRkUIooIuIFEIBXUSkEAroIiKFUEAXESmEArqISCEU0EVECqGALiJSCAV0EZFCKKCLiBRCAV1EpBAK6CIihVBAFxEphAK6iEghFNBFRAqhgC4iUggFdBGRQiigi4gUQgFdRKQQCugiIoVQQBcRKYQCuohIIRTQRUQKoYAuIlIIBXQRkUIooIuIFEIBXUSkEAroIiKFUEAXESnEsAK6mc02s+Vm1mVml+8hz3vM7CEzW2ZmNzW2mCIiMpQxQ2Uws2bgGuAdwNPAIjOb7+4P1eWZAVwBnOjuG83sNfurwCIiMrDhjNCPA7rcfaW79wAdwFm75Xk/cI27bwRw9+caW0wRERmKufvgGczmALPd/aKcPg843t3n1eX5FvAocCLQDHzK3f99gG3NBeYCtLa2tnV0dOxT4bu7u2lpaTkg0ie3t79Y77s7O0dFmUYiPVrKofqrzvujnnujvb19ibvPGnChuw/6AeYA19VNnwdcvVue24BvAmOBI4GngMmDbbetrc33VWdn54GThtpntJRpBNKjpRyq/8ilR0s5RqKeewNY7HuIq8N55LIamFY3PTXn1XsamO/uu9x9FTFanzGsy42IiDTEcAL6ImCGmR1pZuOAs4H5u+X5FnAygJlNAd4ArGxcMUVEZChDBnR37wXmAQuBh4Gb3X2ZmV1lZmdmtoXAejN7COgELnP39fur0CIi8rOGfG0RwN0XAAt2m3dlXdqBS/IjIiIvA/1SVESkEAroIiKFUEAXESmEArqISCEU0EVECqGALiJSCAV0EZFCKKCLiBRCAV1EpBAK6CIihVBAFxEphAK6iEghFNBFRAqhgC4iUggFdBGRQiigi4gUQgFdRKQQCugiIoVQQBcRKYQCuohIIRTQRUQKoYAuIlIIBXQRkUIooIuIFEIBXUSkEAroIiKFUEAXESmEArqISCEU0EVECqGALiJSCAV0EZFCKKCLiBRiWAHdzGab2XIz6zKzywfJ924zczOb1bgiiojIcAwZ0M2sGbgGOA2YCZxjZjMHyHcw8EfAfY0upIiIDG04I/TjgC53X+nuPUAHcNYA+f4c+Bywo4HlExGRYRpOQD8ceKpu+umc9yIzOxaY5u63N7BsIiLyEpi7D57BbA4w290vyunzgOPdfV5ONwF3ARe6++NmdjdwqbsvHmBbc4G5AK2trW0dHR37VPju7m5aWloOiPTJ7e0v1vvuzs5RUaaRSI+Wcqj+qvP+qOfeaG9vX+LuA39P6e6DfoATgIV101cAV9RNHwqsAx7Pzw7gGWDWYNtta2vzfdXZ2XngpKH2GS1lGoH0aCmH6j9y6dFSjpGo594AFvse4upwHrksAmaY2ZFmNg44G5hfd0F4wd2nuPt0d58O3Auc6QOM0EVEZP8ZMqC7ey8wD1gIPAzc7O7LzOwqMztzfxdQRESGZ8xwMrn7AmDBbvOu3EPek/e9WCIi8lLpl6IiIoVQQBcRKYQCuohIIRTQRUQKoYAuIlIIBXQRkUIooIuIFEIBXUSkEAroIiKFUEAXESmEArqISCEU0EVECqGALiJSCAV0EZFCKKCLiBRCAV1EpBAK6CIihVBAFxEphAK6iEghFNBFRAqhgC4iUggFdBGRQiigi4gUQgFdRKQQCugiIoVQQBcRKYQCuohIIRTQRUQKoYAuIlIIBXQRkUIooIuIFEIBXUSkEAroIiKFGFZAN7PZZrbczLrM7PIBll9iZg+Z2YNm9l0ze13jiyoiIoMZMqCbWTNwDXAaMBM4x8xm7pbtx8Asd38zcAvw+UYXVEREBjecEfpxQJe7r3T3HqADOKs+g7t3uvu2nLwXmNrYYoqIyFDM3QfPYDYHmO3uF+X0ecDx7j5vD/mvBta6+6cHWDYXmAvQ2tra1tHRsU+F7+7upqWl5YBIn9ze/mK97+7sHBVlGon0aCmH6q8674967o329vYl7j5rwIXuPugHmANcVzd9HnD1HvKeS4zQxw+13ba2Nt9XnZ2dB04aap/RUqYRSI+Wcqj+I5ceLeUYiXruDWCx7yGujhnGBWE1MK1uemrO+ylmdirwceAkd9853KuNiIg0xnCeoS8CZpjZkWY2DjgbmF+fwcyOAa4FznT35xpfTBERGcqQAd3de4F5wELgYeBmd19mZleZ2ZmZ7S+AFuAbZna/mc3fw+ZERGQ/Gc4jF9x9AbBgt3lX1qVPbXC5RETkJdIvRUVECqGALiJSCAV0EZFCKKCLiBRCAV1EpBAK6CIihVBAFxEphAK6iEghFNBFRAqhgC4iUggFdBGRQiigi4gUQgFdRKQQCugiIoVQQBcRKYQCuohIIRTQRUQKoYAuIlIIBXQRkUIooIuIFEIBXUSkEAroIiKFUEAXESmEArqISCEU0EVECqGALiJSCAV0EZFCKKCLiBRCAV1EpBAK6CIihVBAFxEphAK6iEghFNBFRAoxrIBuZrPNbLmZdZnZ5QMsH29m/5zL7zOz6Q0vqYiIDGrIgG5mzcA1wGnATOAcM5u5W7b3ARvd/Sjgr4DPNbqgIiIyuOGM0I8Dutx9pbv3AB3AWbvlOQv4aqZvAU4xM2tcMUVEZChjhpHncOCpuumngeP3lMfde83sBeDngHX1mcxsLjA3J7vNbPneFLrOlLp9HAjpmDZ7ucsx8nV++cuh+qvO+6Oee+N1e1zi7oN+gDnAdXXT5wFX75ZnKTC1bvoxYMpQ297XD7D4QEqPlnKozqq/6tyYejb6M5xHLquBaXXTU3PegHnMbAxwKLB+GNsWEZEGGU5AXwTMMLMjzWwccDYwf7c884ELMj0HuMvzciQiIiNjyGfoHs/E5wELgWbgBndfZmZXEbcP84HrgRvNrAvYQAT9kfD3B1h6tJRDdR659Ggph+rc+HTDmQbSIiJl0C9FRUQKoYAuIlKI4byHPuqY2Wzgr4ln+huJ9zI3AGuA1sw2CdhE1PEW4CpgMfCLxGuVfYBn+ljgtcCTwDbgl3K764FfACYDvcD9wOtz/iG5r13ATuLNnk2AAS25/W2Z3gHcBvxmTm8HxuY2Hwdm5Ho91I7JrtzG2Kraufz53O+YLL8B44DNwASgP+eRy/uAZ4GJuV5f7pcs14S6bfVmmzZnnSzTvbneZuDVuY/m3O8y4s2nSZmvKdfry/JuzTpMyXm7ct/9ud+xue+eujZ5JI9TL/BEbt+z3V+b6YnZ3mOzDY/KbTxOvHF1L9AFzMv2fhR4A/EO8I7MvxNYDhyd9dmRZRuf29qe8ydkWSyXPUYc/yl17TQu07uAg6j1Heq2Oy6Xj8122pFt0ARsyf1Ux4HcrtW1aXWcD6rbjme6CVib7XpUbmNH5qn+HpT5t2d6TJa5qnNvLpuY86rtjiP6XTNx/HtzP1V9erLs3dku1bpVnr7cTlPOt9x31a+qfly18dicX7X35vxbtUm1vZ3AKuAI4FVZDs/PVuBgan2q6m9V2SzXH59/dxDnedUXDwZeyGW9xLGuBsA7gc/n8k8T5/TmrOtzROx4Y+5vV25zTJbh4SxTf27rceD33H0zDfCKG6EP8K8IXgX8YS7+iLvPJH741AOcA7wFmA18kWhMgHZ3fwvwIPDv7v56oiMeB7yLaPyTiC93DwI+SwSHMcBFxMG9NfNvBe4BlhAnw//K+VuABVmWLcSJ8BPiQH4956/N9U/Msr0353uuexJxkXks836W6Cw9wDHA1zK9KOu3hrg4/X+igx4LfAF4TZZ/c5blWOLfNYzPsv5i3brPECdmW+67B2jP8h2U7X470ZF7gI8Qv024P/PdmttoB75CBIHF1C5e7cDFxMXu5CzDgzl/O3FC/Eeu1w3ckMdsM/BlIiA/le2wDLg78zxHBP9/zLb6JeCEbMuHgH/K/TQB1xK/eL4n51dB8Yhst+3EReQh4A7itdwbgBuJk/l/A+dm+0wD/jbbdRrwiWyXh4kTflPOPz3nv97dm7MO03LbjwOXZvm3AVdk2TYDHyX6+5N5DD8GfCPzfLlu3hXUguIL2TYfy+O1HfhE7vebOX81cBNx0Xky1zsK+D4RMP9Hbv/GbJs/AH6PGBwclXXeSgxGPp51uz/33Z3zz8h1Z1K7KMzI4+PALxN9ZRzRJ1YQ/eGDwJ15rNYDl2f5Ls7j0Zzpf85tLM/lnvM/RgywLs423pXp6oL7x5meBMwCVhI/hGzLbU4g+uyfEheMrcSgcFMel78CPgD8EdHvNhOx4a3EwPFoou8dm2351mzvDcRbgGOBXnd/Ux6Py2iQV1xA52f/FcH1Oa/X3X8E4O5biJPxcKLxJgJvA66rNmJmhwK/nuvj7j3uvgn4NeIgrCM6ywaiwwB8mzhQhwB/lsvWESO/XmCbu3fm/K25zgainY8C/jznPZHzX0101mezDPNz/hjgnqzPVOAB4mRZRXSW+7Juf5rbG08E2ftz/q1EJzucOJmNODGdOJkOJ/5dw4+A17h7V6YPJ4L/E5meSnR8J07wCUTA+WK2Hbnsu8RJ5sB/ZXmcGLlMBv4m8zbl/P9DnAj91EaZE/I49WYbbyeC6xnZHpMy/XnigjoB+EviInNG1rkl0yuIu5Htdfs4gzjJnyCC63VZnjPzbz+1u4/qLmIaMWCACG5VnXH3O6ipLkBkOccT/aPeHxLBrienqxFaa7bR9UR7jwOuJi5IUzJ9LXEXupr4Ed+JOX8xMD3Tr87yVm+YvS7nfyinv5R9flbObwX+K/v8EcAmd38C+BWgP9N/Cbwz27Ea/T+WyzYB2zP95tzHh4hjtTHnXwmscPcVwCnEcXmS6GP9xLlzGHH8N2Zdmohz4Pps0z7gB5neQPTTKs+EnF+d19X883O99bm8idpdyebcVxNxgTk9231bpv8z13018EPiYruJuPC8AHRmG27J43ZHbqsr138fcYE4lYgLj+T81iz7b+f+qsHld4B30yj781dL++PDwL9c/SqwtG7edKLjPEiMFh4hrr4n54H7ERHwVxEjuh8TnWISMRK7JdfbQHSAtxAjnh9m/r66/SwlDvi9xIWmmr+ZuAA8Q5zEf5/z+7JsDxPB63O5/63Elfw9RCB7khjF7SSC/pPEKMwzfUhur58YNbyrbv6dWc+n81PdCm/OfS4lOvVmIijcS5yws6mNdqrb138j7iw8y9KcbdFHdPCf5PY/T1xQ+og7iuqWuzPnV49Mtmf6nixjL/APRKDto/aIaBdxwp+S29+Vx/B3Mv0McVFaQ9zVPJrttj7rvDnz9+ZnRZa/utjfR1xUrqwrX0+WeRPRR/pzW6tyumqXR4g7snU5f1Nuu4fa7XuVvy/r3Jv13Zpl2FRX5l05vz+XfS3T/Zl+JNNrsyzP5fwtub+tuY3tOb/a78a6/a7LfBuJILQ92/XpzHtT9t1eYEemq8cSK4ggdgMwL5dV59dTuZ3v5/ydxHnzYNbtO9nWa4g+8utZhqq+TgTN7pzuy/T6LHcV+BdT6webcn5v3f4991c9wllN7bHfT3JZdXyrc3In0feq+f9A3Pl65jmR6Bsbif7aneuszzLvIM7f6pHO8myrXUQfXpz5byXOmar/7QTen+11CbBlJH8p+opiZi3AvwAfdvc3E1fMg6k9K/yBux9L3C5NB37o7scQHexjRJBoBY7Mv13EAZlOHNxqZFVv93c/P5h/fwv4jUxvyL99xMlxek6/i3hMsxa4mRgx9gF/7O7TiFH4p4lO0pLrfDjL8S9Ep4c4WT5MPAJpIx4JHEOMhu7I/AcRj4xOyG29QFzkDiVO3ncTz6i/6O7jsj1OJUa41ajmjdQC/zhi5DmVuBh1EReuJ4F3EIHnsCzLNuIkfxtxAswk7qwuJIJ09WhpPRHEFxEjoL/I/fa7+xLiNrYKvrOI43pR7mstMcI8BOjJ/H1E0DqaGD0252O5RzPfCVmPRVnnG4gL+4Tcx88Rj/KOpRZQvkA8yniWCOzfz7qfSoz6tuSxWJvzZ9e132nEoGQCcWFZkeW6hLiVb8n8PdQC6rmZ7icuQFOIEfziPDbfzHY7KPOvy/z/Su27lW8TjwkmEyP5xcTjqq8SwX22mf2IuGBXfWpsfp7LMp4JfMPMriTOjXdku04AluQPD8cAbyf6wzhi9P/2zPNe4rFNSx7vVuIR1RFZ/k1Z7suI82Yccd6cSNxx3J/HZxdxN7Au63NZlrlq917g54nvT+YSjxS7iIFTdWF/Put4YZZ5a9bvr6k9M/8bYhC3Luu5Otuz+k7oB3msejM9lYgTz+a2m/OYnZLL7ibuPD4KfNbMlhCxqbpr23cv94h7L0boJwAL66avIK6SS7OhFwKX1C3/v0Tg2kCcYNuIUcwvEB3o0sz3dqKTPwBcX7f++cTjiqXAZ4gOuJMIVNOJ0dNyal/CXUiMzpbl+u+kNrJ7mtoI+61EJ1pFbaT/GNF5qi+EqvpUX2rdQ3TmaTn/k7nvDcRjkPdlnT6a+b9PBKHqC77q7qAz97Wqaq+crkaX1e8TLKcvy308TTzrPSHX2whcmXnvyjo0EUHnk8RJ2pPr9uf0pdSC96V1+6hGj9Wx2pptVY2+qnQ1wu+vy9PLT4/wvC5P/br9den67Vaj2K8RweX5LNtT2b6X5vFeTozMvpR1/lZOTwQ+RVx8q+O3gdrz7D8Dvpf1q/rbBiIAVaP1S4k+2U88PnuCONa3E480erP93pn5NxKDjjXEI6/1WZfqrqePuFP7QU7fntvvzeO8mVqfvyOnz8r1f5xlvDDb+G6ij9yR8x4B7sw8HyKOcQ8RyPqJ43tB1nFVbvcO4phvyrJdn+v/Tpap+h6i+q7k49TukP8g23klcQf7vdzuFmpf1vcTfePLefw2Zrra/o5c//ycXx3nDxDH+dFMH5Zt+SgxYLg115uY5f0M0ber/vw4teP8OWp3WB/I/Bvr6vWZ3IcBm3P5G4hHXwfsCH2gf0VwZy67nniUcaOZTc55VxEjw/OJzvg9dz+XaOQmooNBXEUnEf/G4FfNbGL+C+AziEA9lnj+NZ84OBfkepOJ0Q/Eif0nxIixejtlOTHCmU88n+8lRpnPZ7p61j6OGDF2ESfySVmfXcTt2vXEaPHO3N/D1L7k2UZ08C8QX+r8a+ZfRFzw1hDPYZ8hRhwPEBedybmd24hn5s8TJ+MZWab35/YfyfYaS3T+C4iTagLQZWaXEZ3/k8QIfDYxcr0lt3k+cbItz21V3x88QoxWIS6IfcQI6nziC9W1WeeNxLPYFuIWfmmW8yu57GDii8oFxGjrZmClu1dvktyeeR4FnsvtfCLLdGMeg8XZL84j+sVyIkBOJgLRBURQehXwmJm9B/ifxMXzTVnnRXkMNgG/SwS4Fdnmm4hn60vN7Pgszypi1Dc2l5+Y7b2G2oV4DXE3tzWP9xuIfnZ71rUnt/Mpoh/clOs8QwSb1UTwWUP0L8v2X00MACZlu6wgXiLYmPWF+PLwe5l+J/Ho4k+IPvOdnP+2bO/5xEBhI9G/T8vtbM3t3pX7ry7wv2FmE4nvZiD6/ZZsoy7iLrN6m+doYoR7LdFHfoUYPX87l38q22Znrvs8cRfQRYykLct+CHGX8DBxkZhA9FEj7hZuIi5QRpyzpxPH+L3A4WZ2DBEDxufxOZ4Y3KyhdoexMut+k5mdRBznB4i+cG7u47eBFWbWRPTDv6NBXpG/FDWz04EvUXsT4jXELVYzMYrcSlyBn830ze5+lZmdQ3SKVcTJcBe1W+4niBPqSGI08l5q3/w3U3vFawdxElQXw+qLvvr//95ft7yf2ii0h3i84XXrVMurL+Z2Znmc2mtP1WteLxAd5/X89Ktj9apXH5uovVI2iTi5jxhgfvWse3xueyu1NwOaqL0m1p37nEztFbDq1bMx1F5JrMpT7WMbcXK8idrIsf6VO6gdt2biJF5LjCZbcv59xKOpahQ6NdevXqlcRtxWNxGBaQ3xPcdsM+vLOj2Z5Tia6BdNxEXsbOLRwYezvNWXaNWoawe1kXN9nasvgXuyztXjkb6s85rcV/Wst4/aF6fV64jPZ1tMyrpW7WhEcGsmgnfVV6r9VP1kfN32dxFBayIxoq9eudyVdTg4970tPzuJYNdCBHfL9rg22/J1xJ3jc8RFtXr9cUO2R3XHcwgR3C/O4zSeuOjPyPk78tgtzfqeQDza+wpxUa8fVFavblbt0ExtYNRPrd9U3y9sJ/rHWOJCt4Pa66DVK499uf6GTL+W2muPXrfs4Lptj8l9V68K78pyVefOdmKwczHRrw+py/cUMcqfwE9/J1B9Z7Cd2iPOncTg6wpvUCB+RQZ0ERH5Wa/ERy4iIjIABXQRkUIooIuIFEIBXUSkEAroIiKFUEAXESmEArqISCH+G+DZbQFglT8CAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "u.plot_beliefs(df['prior'][9])" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "plot_beliefs(df['prior'][9])" + ] } ], "metadata": { diff --git a/notebooks/modular_single.ipynb b/notebooks/simple_gridworld/modular_single.ipynb similarity index 100% rename from notebooks/modular_single.ipynb rename to notebooks/simple_gridworld/modular_single.ipynb diff --git a/notebooks/multi_agent_exp.ipynb b/notebooks/simple_gridworld/multi_agent_exp.ipynb similarity index 100% rename from notebooks/multi_agent_exp.ipynb rename to notebooks/simple_gridworld/multi_agent_exp.ipynb