# Active Inference cadCAD model

This notebook explores multi-agent active inference simulations by representing agents and their locations in a dictionary format..

## cadCAD Standard Notebook Layout

### 0. Dependencies

In [1]:
import pandas as pd
import numpy as np
from random import normalvariate, random
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns

from radcad import Model, Simulation, Experiment

from model import ActiveGridference

# Additional dependencies

# For analytics
import itertools

# local utils
import utils as u
from control import construct_policies
import random as rand
from pymdp.maths import softmax

### Initializing Agent Network

In [2]:
grid = list(itertools.product(range(10), repeat=2))
print(grid[grid.index((0,0))])

(0, 0)


In [3]:
# create a dict of agents
agents = {}
priors = {}
env_states = {}
inferences = {}
actions = {}

# Number of agents
NUMBER_AGENTS = 10

for a in range(NUMBER_AGENTS):
    # create new agent
    agent = ActiveGridference(grid)
    # generate target state
    target = (rand.randint(0,9), rand.randint(0,9))
    # add target state
    agent.get_C(target)
    # all agents start in the same position
    agent.get_D((0, 0))

    agents[a] = agent
    priors[a] = agent.D
    env_states[a] = agent.env_state
    inferences[a] = agent.current_inference
    actions[a] = agent.current_action

print(f"Agents: {agents}")

Agents: {0: <model.ActiveGridference object at 0x124231370>, 1: <model.ActiveGridference object at 0x1242313d0>, 2: <model.ActiveGridference object at 0x124231550>, 3: <model.ActiveGridference object at 0x124231520>, 4: <model.ActiveGridference object at 0x1194b7a60>, 5: <model.ActiveGridference object at 0x1240a2b80>, 6: <model.ActiveGridference object at 0x1240a2d00>, 7: <model.ActiveGridference object at 0x1242315b0>, 8: <model.ActiveGridference object at 0x124231610>, 9: <model.ActiveGridference object at 0x12409f250>}


### 1. State Variables

In [4]:
initial_state = {
    'agents': agents,
    'priors': priors,
    'env_states': env_states,
    'actions': actions,
    'inferences': inferences
}

### 2. System Parameters

In [5]:
params = {
    'preferred_state': grid,
    'initial_state': grid,
    'noise': [0.00001, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
}

### 3. Policy Functions

- `get_observation`
- `infer_states`
- `calc_efe`
- `calc_action_posterior`
- `sample_action`
- `calc_next_prior`
- `update_env_state`

In [6]:
def p_actinf(params, substep, state_history, previous_state):
    # State Variables
    agents = previous_state['agents']

    # list of all updates to the agents in the network
    agent_updates = []

    for source, agent in agents.items():

        policies = construct_policies([agent.n_states], [len(agent.E)], policy_len = agent.policy_len)
        # get obs_idx
        obs_idx = grid.index(agent.env_state)

        # infer_states
        qs_current = u.infer_states(obs_idx, agent.A, agent.prior, params['noise'])

        # calc efe
        _G = u.calculate_G_policies(agent.A, agent.B, agent.C, qs_current, policies=policies)

        # calc action posterior
        Q_pi = u.softmax(-_G, params['noise'])
        # compute the probability of each action
        P_u = u.compute_prob_actions(agent.E, policies, Q_pi)
        
        # sample action
        chosen_action = u.sample(P_u)

        # calc next prior
        prior = agent.B[:,:,chosen_action].dot(qs_current) 

        # update env state
        # action_label = params['actions'][chosen_action]

        (Y, X) = agent.env_state
        Y_new = Y
        X_new = X
        # here

        if chosen_action == 0: # UP
            
            Y_new = Y - 1 if Y > 0 else Y
            X_new = X

        elif chosen_action == 1: # DOWN

            Y_new = Y + 1 if Y < agent.border else Y
            X_new = X

        elif chosen_action == 2: # LEFT
            Y_new = Y
            X_new = X - 1 if X > 0 else X

        elif chosen_action == 3: # RIGHT
            Y_new = Y
            X_new = X +1 if X < agent.border else X

        elif chosen_action == 4: # STAY
            Y_new, X_new = Y, X 
            
        current_state = (Y_new, X_new) # store the new grid location
        agent_update = {'source': source,
                        'update_prior': prior,
                        'update_env': current_state,
                        'update_action': chosen_action,
                        'update_inference': qs_current}
        agent_updates.append(agent_update)

    return {'agent_updates': agent_updates}

### 4. State Update Functions

In [7]:
def s_agents(params, substep, state_history, previous_state, policy_input):

    agents_new = previous_state['agents'].copy()

    agent_updates = policy_input['agent_updates']

    if agent_updates != []:
        for update in agent_updates:
            s = update['source']
            agent = agents_new[s]
            update_prior = update['update_prior']
            update_env = update['update_env']
            update_action = update['update_action']
            update_inference = update['update_inference']

            agent.prior = update_prior
            agent.env_state = update_env
            agent.current_action = update_action
            agent.current_inference = update_inference

    return 'agents', agents_new

def s_priors(params, substep, state_history, previous_state, policy_input):

    priors_new = previous_state['priors'].copy()

    agent_updates = policy_input['agent_updates']

    if agent_updates != []:
        for update in agent_updates:
            s = update['source']
            update_prior = update['update_prior']
            priors_new[s] = update_prior

    return 'priors', priors_new

def s_env_states(params, substep, state_history, previous_state, policy_input):

    env_states_new = previous_state['env_states'].copy()

    agent_updates = policy_input['agent_updates']

    if agent_updates != []:
        for update in agent_updates:
            s = update['source']
            update_env = update['update_env']
            env_states_new[s] = update_env

    return 'env_states', env_states_new

def s_actions(params, substep, state_history, previous_state, policy_input):

    actions_new = previous_state['actions'].copy()

    agent_updates = policy_input['agent_updates']

    if agent_updates != []:
        for update in agent_updates:
            s = update['source']
            update_action = update['update_action']
            actions_new[s] = update_action

    return 'actions', actions_new

def s_inferences(params, substep, state_history, previous_state, policy_input):

    inferences_new = previous_state['inferences'].copy()

    agent_updates = policy_input['agent_updates']

    if agent_updates != []:
        for update in agent_updates:
            s = update['source']
            update_inference = update['update_inference']
            inferences_new[s] = update_inference

    return 'inferences', inferences_new
            

### 5. Partial State Update Blocks

In [8]:
state_update_blocks = [
    {
        'policies': {
            'p_actinf': p_actinf
        },
        'variables': {
            'agents': s_agents,
            'priors': s_priors,
            'env_states': s_env_states,
            'actions': s_actions,
            'inferences': s_inferences
        }
    }
]

### 6. Configuration

In [9]:
model = Model(
    # Model initial state
    initial_state=initial_state,
    # Model Partial State Update Blocks
    state_update_blocks=state_update_blocks,
    # System Parameters
    params=params
)

### 7. Execution

In [18]:
simulation = Simulation(
    model=model,
    timesteps=100,  # Number of timesteps
    runs=2  # Number of Monte Carlo Runs
)

In [19]:
result = simulation.run()

### 8. Analysis

In [14]:
pd.options.plotting.backend = "plotly"

In [15]:
df = pd.DataFrame(result)
df

Unnamed: 0,agents,priors,env_states,actions,inferences,simulation,subset,run,substep,timestep
0,{0: <model.ActiveGridference object at 0x20da3...,"{0: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0...","{0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0), 4...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...",0,0,1,0,0
1,{0: <model.ActiveGridference object at 0x20da3...,"{0: [0.0, 1.0, 9.999999999999931e-33, 9.999999...","{0: (0, 1), 1: (0, 0), 2: (1, 0), 3: (0, 0), 4...","{0: 3, 1: 4, 2: 1, 3: 4, 4: 3, 5: 0, 6: 3, 7: ...","{0: [1.0, 9.999999999999931e-33, 9.99999999999...",0,0,1,1,1
2,{0: <model.ActiveGridference object at 0x20da3...,"{0: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0...","{0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0), 4...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...",0,1,1,0,0
3,{0: <model.ActiveGridference object at 0x20da3...,"{0: [1.0, 1.9999999999999862e-32, 1.9999999999...","{0: (0, 0), 1: (1, 0), 2: (0, 0), 3: (1, 0), 4...","{0: 0, 1: 1, 2: 0, 3: 1, 4: 4, 5: 3, 6: 4, 7: ...","{0: [1.0, 9.999999999999931e-33, 9.99999999999...",0,1,1,1,1
4,{0: <model.ActiveGridference object at 0x20da4...,"{0: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0...","{0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0), 4...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...",0,2,1,0,0
...,...,...,...,...,...,...,...,...,...,...
195,{0: <model.ActiveGridference object at 0x216c9...,"{0: [1.0, 1.9999999999999862e-32, 1.9999999999...","{0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 1), 4...","{0: 0, 1: 4, 2: 0, 3: 3, 4: 2, 5: 3, 6: 3, 7: ...","{0: [1.0, 9.999999999999931e-33, 9.99999999999...",0,97,1,1,1
196,{0: <model.ActiveGridference object at 0x216ca...,"{0: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0...","{0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0), 4...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...",0,98,1,0,0
197,{0: <model.ActiveGridference object at 0x216ca...,"{0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0...","{0: (1, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0), 4...","{0: 1, 1: 0, 2: 2, 3: 0, 4: 1, 5: 0, 6: 2, 7: ...","{0: [1.0, 9.999999999999931e-33, 9.99999999999...",0,98,1,1,1
198,{0: <model.ActiveGridference object at 0x21018...,"{0: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0...","{0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0), 4...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...","{0: '', 1: '', 2: '', 3: '', 4: '', 5: '', 6: ...",0,99,1,0,0
