# Continuous Action Spaces

In this notebook we develop the active inference agent for environments with continuous action spaces. 

We start by modifying the minimal environment to accept continuous-valued actions that represent an agent's intention to move by `[-2,2]` cells. Then, we modify the components of the minimal agent that currently exploit the discreteness of the action space, namely the belief update after an action was taken (belief propagation through time) and policy sampling during action selection.

#### Housekeeping (run once per kernel restart)

In [None]:
# change directory to parent
# NOTE: os.chdir('..') is relative to the current working directory, so
# re-running this cell without a kernel restart moves up one more level.
# Presumably this is done so that project modules imported further below
# resolve from the parent directory -- confirm against the repo layout.
import os
os.chdir('..')
print(os.getcwd())  # show the resulting working directory

# Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.markers import CARETUP, CARETDOWN
import torch

# Continuous Action Environment

## From continuous actions to discrete environment steps

The environment transforms a continuous-valued action intent into a probability distribution over discrete actions (steps of size 0, 1 or 2 in either direction) and samples from this distribution to advance the environment state in each step. Below, we show the relevant code and plot the distribution over discrete actions given a continuous-valued intent. The stochasticity of the environment transition dynamics is controlled by the variance of the clipped Gaussian distribution.

In [None]:
# Conditional probability of each discrete step given a fixed continuous
# action intent a: proportional to a Gaussian centred on a and truncated to
# the admissible step range.
s_N = 5

def p_a_discrete_given_a(a, a_lims=(-2, 2), var=0.75**2):
  """ probability distribution of a discrete action (step) given a continuous
  action intent.

  Args:
  a: continuous action intent; clipped to [a_lims[0], a_lims[1]].
  a_lims: (lowest, highest) admissible step. An immutable tuple is used as
    the default so the default can never be modified between calls.
  var: variance of the Gaussian centred on the (clipped) intent.

  Returns:
  a_discrete: candidate steps -(s_N-1), ..., s_N-1 (uses the global s_N).
  p_a: probability of each candidate step; zero outside a_lims, sums to 1.
  """
  lo, hi = a_lims[0], a_lims[1]
  a = np.clip(a, lo, hi)
  a_discrete = np.arange(2*s_N-1) - s_N + 1
  # unnormalised Gaussian weight of every candidate step
  p_a = np.exp(-0.5 * (a_discrete-a)**2 / var)
  # steps beyond the admissible range are impossible
  p_a[(a_discrete > hi) | (a_discrete < lo)] = 0
  p_a = p_a/p_a.sum()
  return a_discrete, p_a
  
# Example: distribution over discrete steps for intent a=2.
# Note that `a` is rebound here to the array of discrete step values
# returned by the function (used as bar positions), not the intent itself.
a, p_a = p_a_discrete_given_a(a=2)
plt.bar(x=a, height=p_a)

## Action-specific state transition table

It might be convenient for the agent (we'll see) to have access to an action-dependent conditional state-transition probability table to avoid relying exclusively on sampling transitions. In order to calculate this, we compute the signed distance between all pairs of states and assign the relevant probability of the corresponding discrete action to each state transition. 

Note that the all-pairs signed wrap-around distances only need to be computed once and can then be reused for each new `a`.

In [None]:
# Signed wrap-around distance from every source state (row) to every
# destination state (column). Computed once and reused for each action.
s_N = 8
s = np.arange(s_N)
other, this = np.meshgrid(s, s)
d = other - this
# the same displacement taken across the wrap-around boundary, either way
d1 = d + s_N
d2 = d - s_N
# keep the candidate with the smaller magnitude (ties keep the current value)
d = np.where(np.abs(d) > np.abs(d1), d1, d)
d = np.where(np.abs(d) > np.abs(d2), d2, d)
d_s = d

def p_s1_given_s_a(a, d_s, var):
  """ Transition probability table p(s'| s, a) for one continuous action a.

  Args:
  a: continuous action intent.
  d_s: integer array of signed state distances, indexed [s, s1].
  var: variance forwarded to p_a_discrete_given_a.

  Returns:
  p[s, s1] with the same shape as d_s.
  """
  steps, p_step = p_a_discrete_given_a(a=a, var=var)
  # steps[0] is the first (most negative) candidate step; subtracting it
  # turns each signed distance into the position of the matching step.
  lookup = d_s - steps[0]
  return p_step[lookup]
  
# Visualise p(s'|s, a) for a=-0.5 with variance 1.5**2 over the steps;
# rows index the source state s, columns the destination state s1.
plt.imshow(p_s1_given_s_a(a=-0.5, d_s=d_s, var=1.5**2), cmap='viridis')

## Full environment specification

In [None]:
# environment
class ContinuousActionEnv(object):
  """ Wrap-around 1D state space with single food source.
  
  The probability of sensing food at locations near the food source decays 
  exponentially with increasing distance.
  
  state (int): 1 of N discrete locations in 1D space.
  observation (bool): food detected yes/ no.
  actions(float): intention to move left (negative) or right (positive);
    clipped to a_lims, which is [-2, 2] by default.
  """
  def __init__(self, 
               N = 16, # how many discrete locations can the agent reside in
               s_0 = 0, # where does the agent start each episode?
               s_food = 0, # where is the food?
               sigma_move = 0.75, # Gaussian stdev around continuous move
               a_lims = (-2, 2), # maximum step in either direction.
               p_o_max = 0.9, # maximum probability of sensing food
               o_decay = 0.2 # decay rate of observing distant food source
               ):
    
    self.o_decay = o_decay
    self.var_move = sigma_move**2
    self.p_o_max = p_o_max
    self.s_0 = s_0
    self.s_food = s_food
    self.s_N = N
    self.o_N = 2 # {False, True} indicating whether food has been found
    # store a private copy so that instances never share (or modify) the
    # default or the caller's sequence
    self.a_lims = list(a_lims)

    # Environment dynamics are governed by two probability distributions
    # 1. state transition probability p(s'|s, a)
    # 2. emission/ observation probability p(o|s)
    #
    # The full conditional emission probability table is pre-computed.
    self.p_o_given_s = self.emission_probability() # Matrix A

    # With continuous-valued actions, (1.) can no longer be represented by a
    # single conditional probability table. Instead, one table of size
    # |S| x |S| is generated per continuous action value from the signed
    # state distances: self.p_s1_given_s_a(a=a) returns p[s, s1] for given a.
    self.d_s = self._signed_state_distances()
    
    self.s_t = None # state at current timestep; set by reset()

  def _signed_state_distances(self):
    """ signed wrap-around distance d[s, s1] from state s to state s1.

    Returns:
    integer array of size (s_N, s_N); for each pair the displacement with
    the smallest magnitude among direct and wrapped-around moves.
    """
    s = np.arange(self.s_N)
    other, this = np.meshgrid(s, s)
    d = other - this
    d1 = other - this + self.s_N
    d2 = other - this - self.s_N
    d[np.abs(d) > np.abs(d1)] = d1[np.abs(d) > np.abs(d1)]
    d[np.abs(d) > np.abs(d2)] = d2[np.abs(d) > np.abs(d2)]
    return d
  
  def _p_a_discrete_given_a(self, a):
    """ probability distribution of a discrete action (step) given a
    continuous action intent.

    Returns:
    a_discrete: candidate steps -(s_N-1), ..., s_N-1.
    p_a: probability of each candidate step; zero outside a_lims.
    """
    a = np.clip(a, self.a_lims[0], self.a_lims[1])
    a_discrete = np.arange(2*self.s_N-1) - self.s_N + 1
    p_a = np.exp(-0.5 * (a_discrete-a)**2 / self.var_move)
    p_a[a_discrete > self.a_lims[1]] = 0
    p_a[a_discrete < self.a_lims[0]] = 0
    p_a = p_a/p_a.sum()
    return a_discrete, p_a
  
  def p_s1_given_s_a(self, a):
    """ computes transition probability p(s'| s, a) for specific a
    
    Note: this is provided for convenience in the agent; it is not used within
    the environment simulation.

    Returns:
    p[s, s1] of size (s_N, s_N)
    """
    a_d, p_a = self._p_a_discrete_given_a(a=a)
    # a_d[0] is the most negative candidate step; subtracting it maps each
    # signed distance onto the index of the matching step in p_a
    return p_a[self.d_s - a_d[0]]

  def emission_probability(self):
    """ computes conditional probability table p(o|s). 
    
    Returns:
    p[s, o] of size (s_N, o_N)
    """
    s = np.arange(self.s_N)
    # wrap-around distance from food source: number of cells going one way
    # around the ring, or the complement going the other way, whichever is
    # shorter. The modulo keeps this correct for every food location.
    d = (s - self.s_food) % self.s_N
    d = np.minimum(d, self.s_N - d)
    p = np.zeros((self.s_N, self.o_N))
    # exponentially decaying concentration ~ probability of detection
    p[:,1] = self.p_o_max * np.exp(-self.o_decay * d)
    p[:,0] = 1 - p[:,1]
    return p

  def reset(self):
    """ moves the agent to the start state and returns a first observation. """
    self.s_t = self.s_0
    return self.sample_o()

  def step(self, a):
    """ advances the state by a discrete step sampled for intent a.

    Returns:
    observation sampled in the new state.
    """
    if (self.s_t is None):
      print("Warning: reset environment before first action.")
      self.reset()
      
    a_discrete = self.sample_a(a)
    # modulo implements the wrap-around of the 1D state space
    self.s_t = (self.s_t + a_discrete) % self.s_N
    return self.sample_o()

  def sample_o(self):
    """ samples whether food is detected in the current state. """
    return np.random.random() < self.p_o_given_s[self.s_t,1]
  
  def sample_a(self, a):
    """ samples a discrete step for continuous action intent a. """
    a_d, p_a = self._p_a_discrete_given_a(a=a)
    return np.random.choice(a_d, p=p_a)

## Random Agent Behavior

To test the environment we simulate a random agent's interactions with it. Here, the random agent samples actions uniformly within the environment's action limits `env.a_lims` (set to `[-3, 3]` in the cell below).

In [None]:
# environment with a wider action range than the default
env = ContinuousActionEnv(
    N=16,            # number of states
    s_food=0,        # location of the food source
    sigma_move=0.75, # Gaussian noise around continuous move
    a_lims=[-3,3],   # maximum number of steps in either direction
    o_decay=0.4,     # decay of observing food away from source
)

n_steps = 100

# histories start with the initial state and the first observation;
# actions are recorded from the first step onwards
o = env.reset()
ss, oo, aa = [env.s_t], [o], []

for i in range(n_steps):
  # random agent: uniform sample within the environment's action limits
  a = np.random.uniform(low=env.a_lims[0], high=env.a_lims[1])
  o = env.step(a)
  aa.append(a)
  ss.append(env.s_t)
  oo.append(o)

We inspect the sequence of states, actions and emissions during this interaction

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(16, 12))

# top panel: state trajectory, food location and the action intent taken
# in each timestep (drawn as a short vertical segment starting at s_t)
ax[0].plot(ss, label='agent state $s_t$')
ax[0].plot(np.ones_like(ss) * env.s_food, 
           'r--', label='food source', linewidth=1)
for i, a_i in enumerate(aa):
  ax[0].plot([i, i], [ss[i], ss[i]+a_i], 
             color='orange', 
             linewidth=0.5,
             marker= CARETUP if a_i > 0 else CARETDOWN,
             label=None if i > 0 else 'action') # single legend entry
  
ax[0].set_xlabel('timestep t')
ax[0].set_ylabel('$s$')
ax[0].legend()

# bottom panel: sequence of observations collected in `oo`
# (previously this plotted the `os` module by mistake)
ax[1].plot(np.array(oo, dtype=int))
ax[1].set_xlabel('timestep t')
ax[1].set_ylabel('observation (1=Food)')



# Continuous Action Agent

We start implementing the continuous action agent by updating its belief through time (whenever a new action is taken).

## Updating belief through time

Here, we can make use of the action-specific state transition probability table, requiring only a minimal change to the code of the minimal agent.

```
q1 = torch.matmul(q, torch.tensor(env.p_s1_given_s_a(a=a)))
```

Instead of indexing the full transition dynamics `p[s, a, s1]` with action `a`, we generate the action-specific transition dynamics `p[s, s1]`.

In [None]:
def update_belief_a(env, theta_prev, a, lr=4., n_steps=10, debug=False):
    """Propagate a belief through the environment dynamics for action a.

    Args:
        env: object providing p_s1_given_s_a(a=...) -> p[s, s1].
        theta_prev: belief before the action, as softmax parameters (logits).
        a: the action taken.
        lr: learning rate of the SGD optimizer.
        n_steps: number of optimization steps.
        debug: if True, also return the loss recorded at each step.

    Returns:
        Logits of the updated belief, or (logits, losses) if debug is True.
    """
    prior_logits = torch.tensor(theta_prev)
    prior = torch.softmax(prior_logits, dim=0)

    # Belief after the action, expressed as a distribution. The logits that
    # produce this distribution are recovered below by gradient descent.
    transition = torch.tensor(env.p_s1_given_s_a(a=a))
    target = torch.matmul(prior, transition)

    # start the optimization from a uniform belief (all-zero logits)
    theta1 = torch.zeros_like(prior_logits, requires_grad=True)
    optimizer = torch.optim.SGD([theta1], lr=lr)
    losses = np.zeros(n_steps) if debug else None

    for step in range(n_steps):
        optimizer.zero_grad()
        # cross entropy between softmax(theta1) and the target distribution
        l = torch.nn.functional.cross_entropy(theta1, target)
        l.backward()
        optimizer.step()
        if debug:
            losses[step] = l.item()

    theta1 = theta1.detach().numpy()
    return (theta1, losses) if debug else theta1

Let's see the effect of this belief update in action. Say we believed strongly that the environment was in states 1 or 4 with equal probability at time $t$ and we took action 2 (move right by two). Recall that according to the environment dynamics there is some probability of the state remaining unchanged, some probability of moving by 1 and some probability of moving by 2.

In [None]:
env = ContinuousActionEnv(
    N=8,       # number of states
    s_food=0,  # location of the food source
)

# prior logits: equal, elevated belief in states 1 and 4
theta = np.zeros(env.s_N)
theta[[1, 4]] = 2
# propagate the belief through action a=2, recording the loss per step
theta1, ll = update_belief_a(env, theta, a=2, lr=4.0, n_steps=20, debug=True)

def softmax(x):
  """ softmax of array x, shifted by its maximum before exponentiation. """
  shifted = x - x.max()
  weights = np.exp(shifted)
  return weights / weights.sum()

fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# left: loss over the optimization, with its minimum as a reference line
loss_ax = ax[0]
loss_ax.plot(ll)
loss_ax.plot([0, ll.shape[0]-1], [ll.min()]*2, 'k--')
loss_ax.set_xlabel('optimization step')
loss_ax.set_ylabel('loss')

# right: belief before and after the update as side-by-side bars
belief_ax = ax[1]
states = np.arange(env.s_N)
belief_ax.bar(states-0.2, width=0.4, height=softmax(theta), alpha=0.5, label='before') # belief before update
belief_ax.bar(states+0.2, width=0.4, height=softmax(theta1), alpha=0.5, label='after') # belief after update
belief_ax.set_xlabel('env state')
belief_ax.set_ylabel('belief')
belief_ax.set_title('Propagating prior beliefs through environment dynamics.')
belief_ax.legend()

## Policy sampling

The minimal agent explored policies by enumerating and evaluating all possible combinations of discrete actions up to a finite time horizon via rollout simulations. With continuous-valued actions, we can no longer evaluate all possible finite-horizon strategies and therefore need to resort to sampling.

Before we start we copy some code unchanged from the previous notebook, because action selection requires belief updating in light of new observations and rollouts.

In [None]:
def update_belief(env, theta_prev, o, lr=4., n_steps=10, debug=False):
    """Update a belief over states in light of a new observation o.

    Args:
        env: object providing the emission table p_o_given_s, indexed [s, o].
        theta_prev: current belief as softmax parameters (logits).
        o: index of the observation.
        lr: learning rate of the SGD optimizer.
        n_steps: number of optimization steps.
        debug: if True, also return the free energy recorded at each step.

    Returns:
        Logits of the updated belief, or (logits, losses) if debug is True.
    """
    to_probs = torch.nn.Softmax(dim=0)

    # log of the unnormalised posterior p(o|s) p(s), with p(s) from the prior
    likelihood = torch.tensor(env.p_o_given_s[:,o])
    log_p = torch.log(likelihood * to_probs(torch.tensor(theta_prev)))

    # the optimization starts from the current belief
    theta1 = torch.tensor(theta_prev, requires_grad=True)
    optimizer = torch.optim.SGD([theta1], lr=lr)
    losses = np.zeros(n_steps)

    for step in range(n_steps):
        q1 = to_probs(theta1)
        # free energy: KL[ q(s) || p(s, o) ]
        fe = torch.sum(q1 * (torch.log(q1) - log_p))
        optimizer.zero_grad()
        fe.backward()
        optimizer.step()
        if debug:
            losses[step] = fe.item()

    theta1 = theta1.detach().numpy()
    return (theta1, losses) if debug else theta1
  
def kl(a, b):
    """ Discrete KL-divergence KL[a || b] between two probability arrays. """
    log_ratio = np.log(a) - np.log(b)
    return np.sum(a * log_ratio)

def rollout_step(env, log_p_c, theta, pi, 
                 use_info_gain, use_pragmatic_value):
    """Score a plan step by step, starting from belief theta.

    Args:
        env: environment providing the emission table p_o_given_s.
        log_p_c: log of the preference distribution over states.
        theta: logits of the belief before the first action of the plan.
        pi: plan, a list of actions.
        use_info_gain: weight (or flag) of the information gain term.
        use_pragmatic_value: weight (or flag) of the pragmatic term.

    Returns:
        List with one negative expected free energy value per plan step.
    """
    # an exhausted plan contributes no further steps
    if pi == []:
        return []

    action = pi[0]
    remaining = pi[1:]

    # Where will I be after taking this action?
    theta_next = update_belief_a(env, theta, a=action, lr=1.)
    q_next = softmax(theta_next)

    # Do I like being there?
    pragmatic = np.dot(q_next, log_p_c)

    # What might I observe there? (marginalize p(o, s) over s)
    p_o = np.dot(q_next, env.p_o_given_s)

    # How much would each possible observation change my belief about s?
    divergences = []
    for o in range(p_o.shape[0]):
        posterior = softmax(update_belief(env, theta_next, o=o))
        divergences.append(kl(posterior, q_next))
    info_gain = np.dot(p_o, divergences) # expectation over observations

    # negative expected free energy for this timestep
    nefe = use_pragmatic_value * pragmatic + use_info_gain * info_gain

    # scores of the remaining steps, evaluated from the propagated belief
    tail = rollout_step(env, log_p_c, theta_next, remaining,
                        use_info_gain=use_info_gain,
                        use_pragmatic_value=use_pragmatic_value)
    return [nefe] + tail

We make the following modifications to the code. First, we replace enumeration of all possible plans (policies) with sampling, where we fetch the limits of continuous-valued actions from the environment and sample uniformly within this range; we take independent samples for each plan and each timestep up to the planning horizon k.

```
n_plans = 32
plans = np.random.uniform(low=env.a_lims[0], high=env.a_lims[1], size=(n_plans, k))
```

For debugging purposes, we can no longer compute the marginal probability of the first action. Instead, we would need to estimate its density from samples, e.g., by sampling with replacement from the evaluated policies with probability proportional to the probability of selecting each policy.

In [None]:
def select_action(env, theta_star, theta_start, 
                  n_plans=32, # number of plans to evaluate
                  k=4, # planning horizon (number of sequential actions per plan)
                  use_info_gain=True, 
                  use_pragmatic_value=True,
                  select_max_pi=False, # replace sampling with best action selection
                  debug=False, # return plans, p of selecting each
                 ):
    """Sample plans, score them by rollout and pick one.

    Note that the returned value is the whole selected plan (a list of k
    actions); callers take its first element as the next action.

    Returns:
        The selected plan, or (plan, plans, p_pi) if debug is True, where
        p_pi holds the probability of selecting each of the sampled plans.
    """
    log_p_c = np.log(softmax(theta_star))

    # sample every action of every plan uniformly within the action limits
    lo, hi = env.a_lims[0], env.a_lims[1]
    plans = np.random.uniform(low=lo, high=hi, size=(n_plans, k)).tolist()

    # negative expected free energy of each plan, averaged over its steps
    nefes = [
        np.array(rollout_step(env, log_p_c, theta_start, pi,
                              use_info_gain=use_info_gain,
                              use_pragmatic_value=use_pragmatic_value)).mean()
        for pi in plans
    ]

    # probability of following each plan
    p_pi = softmax(np.array(nefes)).tolist()

    if select_max_pi:
        chosen = np.argmax(nefes)
    else:
        chosen = np.random.choice(len(plans), p=p_pi)
    a = plans[chosen]

    return (a, plans, p_pi) if debug else a

Let's explore action selection from plans with horizon $k$ by specifying sharp priors on the starting state and target state $k$ steps apart.

If the starting state is to the right of the target (recall the state space wraps around), then policies that take a sequence of left actions ($a<0$) are scored higher. Note that this holds true irrespective of the food source location. 

If the starting state is to the left of the target (e.g., $s_0=11$), then policies that take a sequence of right actions ($a>0$) are scored higher.

In [None]:
starting_state = 11
target_state = 14

import continuous_action_environment as cae
env = cae.ContinuousActionEnv(N=16, # number of states
                          s_food=8) # location of the food source

# initialize belief
theta_start = np.eye(env.s_N)[starting_state] * 10 # sharp belief in starting_state

# initialize preference
theta_star_shaped = 10 * np.ones(env.s_N) - np.abs(env.d_s[target_state])
theta_star = np.eye(env.s_N)[target_state] * 10
theta_star = theta_star_shaped

a, plans, p_pi = select_action(env, theta_star, theta_start, k=2, n_plans=128, debug=True)

# and explore what the agent prefers
fig, ax = plt.subplots(2, 1, figsize=(12, 12))
plt.sca(ax[0])
plt.bar(x = range(len(plans)), height=p_pi)
plt.xlabel('plan id')
# raw string: '\p' is not a valid escape sequence in a regular string literal
plt.ylabel(r'$p(\pi)$')

#print('sum of actions, plans and associated probability of selecting them.')
#for p, pi in zip(p_pi, plans):
#    print(np.sum(pi), pi, p)

# estimate the marginal density of the first action by resampling the first
# actions of the evaluated plans in proportion to each plan's probability
print('marginal probability of next action')
a_sample = np.random.choice([pi[0] for pi in plans], p=p_pi, size=10000, replace=True)
plt.sca(ax[1])
plt.title('marginal probability of next action across all evaluated policies')
# visualise the empirical density as a normalised histogram
plt.hist(a_sample, density=True, bins=20);
#plt.plot(np.linspace(0, 1, a_sample.shape[0]), np.sort(a_sample));
#plt.plot([0,1],[-2,2])

print('best plan')
pi_max = np.argmax(p_pi)
print(plans[pi_max])

## Preference distribution shaping

Note that this sampling approach appears to only work well for very short time horizons ($k=2$) and returns a flat action posterior when it is difficult to find plans with which the target state can be reached within the planning horizon's number of steps. This could be mitigated by shaping the preference distribution or employing more clever rollouts of promising action sequences. The former can be done, for example, by also specifying a preference for states that are close to the target state, for which we can make use of the state distance computation we developed near the start of this notebook.

Running the cell above with its line 14 (`theta_star = theta_star_shaped`) enabled helps find useful policies even when the target state cannot be reached from the start state within the finite time horizon.

The cell below illustrates one way to shape the state preference distribution.

In [None]:
target_state = 14

# sharp preference: a single large logit on the target state
theta_star = 10 * np.eye(env.s_N)[target_state]
# shaped preference: logits decrease with wrap-around distance to the target
theta_star_shaped = 1.0 - np.abs(env.d_s[target_state])

for logits in (theta_star, theta_star_shaped):
    plt.plot(softmax(logits))

## Putting it all together

Now we have all components required to implement an Active Inference agent for environments with discrete state spaces, discrete observation spaces and _continuous_ action spaces. The changes to the minimal agent that interacts with discrete action space environments turned out to be few and small.

1. Updating belief through time required an interface change from accessing the environment's full conditional probability table $p(s,a,s_1)$ to an action-specific table $p_a(s,s_1)$.

2. Action selection can no longer be performed by enumerating all possible finite-horizon action sequences. Instead, we sample finite-length sequences from the action space.

Note that both of these changes could be ported back into the minimal agent to derive an interface that works for both discrete and continuous action spaces. But sampling discrete action sequences is far less efficient than enumerating all combinations.

Let's encapsulate it into a class that manages the target state and current belief state over time and provides a minimal interface with reset and step methods.


In [None]:
class ContinuousActionAgent:
    """Active inference agent for environments with continuous actions.

    The agent keeps a belief over discrete environment states (stored as
    softmax logits in `self.b`) and a preference distribution over states.
    Call `reset()` before the first `step()`.

    Relies on the module-level helpers `softmax` and `kl`.
    """
    
    def __init__(self, 
                 env,
                 target_state, 
                 shape_target=False, # smooth preference distribution using pairwise state distances
                 n_plans=128, # number of plans rolled out during action selection
                 k=2, # planning horizon
                 use_info_gain=True, # score actions by info gain
                 use_pragmatic_value=True, # score actions by pragmatic value
                 select_max_pi=False, # sample plan (False), select max negEFE (True).
                 n_steps_o=20, # optimization steps after new observation
                 n_steps_a=20, # optimization steps after new action
                 lr_o=4., # learning rate of optimization after new observation
                 lr_a=4.): # learning rate of optimization after new action
        
        self.env = env
        self.target_state = target_state
        self.shape_target = shape_target
        self.n_plans = n_plans
        self.k = k
        self.use_info_gain = use_info_gain
        self.use_pragmatic_value = use_pragmatic_value
        self.select_max_pi = select_max_pi
        self.n_steps_o = n_steps_o
        self.n_steps_a = n_steps_a
        self.lr_a = lr_a
        self.lr_o = lr_o
        
    def reset(self):
        """Initialize the state preference and a uniform state belief."""
        # initialize state preference
        if self.shape_target:
            # logits decrease with wrap-around distance to the target state
            self.b_star = np.ones(shape=self.env.s_N) - \
                          np.abs(self.env.d_s[self.target_state])
        else:
            # sharp preference: a single large logit on the target state
            self.b_star = np.eye(self.env.s_N)[self.target_state] * 10
        self.log_p_c = np.log(softmax(self.b_star))
        # initialize state prior as uniform
        self.b = np.zeros(self.env.s_N)
        
    def step(self, o, debug=False):
        """Process observation o and return the next action.

        The belief is updated with the observation, a plan is selected, and
        the belief is propagated through the first action of that plan.

        Returns:
            The selected action, or (a, ll_o, ll_a, max_a) if debug is True
            (see `_step_debug`).
        """
        if debug:
            return self._step_debug(o)
        
        self.b = self._update_belief(theta_prev=self.b, o=int(o))
        # use the agent's own plan selection (not the module-level function,
        # which needs the environment and preference passed explicitly)
        plan = self._select_action(theta_start=self.b)
        a = plan[0] # pop first action of selected plan
        self.b = self._update_belief_a(theta_prev=self.b, a=a)
        return a
    
    def _step_debug(self, o):
        """Like `step`, additionally returning diagnostics.

        Returns:
            a: first action of the selected plan.
            ll_o: losses of the belief update after the observation.
            ll_a: losses of the belief update after the action.
            max_a: first action of the most probable plan.
        """
        self.b, ll_o = self._update_belief(theta_prev=self.b, 
                                           o=int(o), debug=True)
        a, plans, p_pi = self._select_action(theta_start=self.b, debug=True)
        max_a = plans[np.argmax(p_pi)][0]
        a = a[0]
        self.b, ll_a = self._update_belief_a(theta_prev=self.b, a=a, debug=True)
        return a, ll_o, ll_a, max_a
    
    def _update_belief_a(self, theta_prev, a, debug=False):
        """Propagate belief logits theta_prev through the dynamics for action a.

        Returns:
            Logits of the updated belief, or (logits, losses) if debug.
        """
        # prior assumed to be expressed as parameters of the softmax (logits)
        theta = torch.tensor(theta_prev)
        q = torch.nn.Softmax(dim=0)(theta)

        # this is the prior for the distribution at time t
        q1 = torch.matmul(q, torch.tensor(self.env.p_s1_given_s_a(a=a)))

        # initialize parameters of updated belief to uniform
        theta1 = torch.zeros_like(theta, requires_grad=True)
        loss = torch.nn.CrossEntropyLoss() # expects logits and target distribution.
        optimizer = torch.optim.SGD([theta1], lr=self.lr_a)
        if debug:
            ll = np.zeros(self.n_steps_a)

        for i in range(self.n_steps_a):
            l = loss(theta1, q1)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

            if debug:
                ll[i] = l.detach().numpy()

        theta1 = theta1.detach().numpy()
        if debug:
            return theta1, ll

        return theta1
    
    def _update_belief(self, theta_prev, o, debug=False):
        """Update belief logits theta_prev in light of observation index o.

        Returns:
            Logits of the updated belief, or (logits, losses) if debug.
        """
        theta = torch.tensor(theta_prev)

        # make p(s) from b
        q = torch.nn.Softmax(dim=0)
        p = torch.tensor(self.env.p_o_given_s[:,o]) * q(theta) # p(o|s)p(s)
        log_p = torch.log(p)

        # initialize updated belief with current belief
        theta1 = torch.tensor(theta_prev, requires_grad=True)

        # estimate loss
        def forward():
            q1 = q(theta1)
            # free energy: KL[ q(s) || p(s, o) ]
            fe = torch.sum(q1 * (torch.log(q1) - log_p))
            return fe

        optimizer = torch.optim.SGD([theta1], lr=self.lr_o)
        ll = np.zeros(self.n_steps_o)
        for i in range(self.n_steps_o):
            l = forward()
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

            if debug:
                ll[i] = l.detach().numpy()

        theta1 = theta1.detach().numpy()
        if debug:
            return theta1, ll

        return theta1

    def _select_action(self, theta_start, debug=False): # return plans, p of selecting each
        """Sample plans, score them by rollout and pick one.

        Returns:
            The selected plan (list of k actions), or (plan, plans, p_pi) if
            debug is True.
        """
        # sampling
        a_lims = self.env.a_lims
        plans = np.random.uniform(low=a_lims[0], high=a_lims[1], size=(self.n_plans, self.k)).tolist()
        
        # evaluate negative expected free energy of all plans
        nefes = []
        for pi in plans:
            step_nefes = self._rollout_step(theta_start, pi)
            nefe = np.array(step_nefes).mean() # expected value over steps
            nefes.append(nefe)

        # compute probability of following each plan
        p_pi = softmax(np.array(nefes)).tolist()
        if self.select_max_pi:
            a = plans[np.argmax(nefes)]
        else:
            a = plans[np.random.choice(len(plans), p=p_pi)]

        if debug:
            return a, plans, p_pi

        return a

    def _rollout_step(self, theta, pi):
        """Score plan pi (a list of actions) step by step from belief theta.

        Returns:
            List with one negative expected free energy value per plan step.
        """
        if pi == []:
            return []

        a, pi_rest = pi[0], pi[1:]
        # Where will I be after taking action a?
        theta1 = self._update_belief_a(theta, a=a) 
        q = softmax(theta1)
        # Do I like being there?
        pragmatic = np.dot(q, self.log_p_c)
        # What might I observe after taking action a? (marginalize p(o, s) over s)
        p_o = np.dot(q, self.env.p_o_given_s)
        # Do I learn about s by observing o?
        # enumerate observations, update belief and estimate info gain
        q_o = [softmax(self._update_belief(theta1, o=i)) for i in range(p_o.shape[0])]
        d_o = [kl(q_o_i, q) for q_o_i in q_o] # info gain for each observation
        info_gain = np.dot(p_o, d_o) # expected value of info gain
        # negative expected free energy for this timestep
        nefe = self.use_pragmatic_value * pragmatic + \
               self.use_info_gain * info_gain
        # nefe for remainder of policy rollout
        nefe_rest = self._rollout_step(theta1, pi_rest)
        # concatenate expected free energy across future time steps
        return [nefe] + nefe_rest

The code below iterates over all steps involved in the interaction between the environment and the active inference agent. In each interaction step, the agent updates its belief about the current state given a new observation and selects an action to minimise expected free energy. It then updates its belief assuming the selected action was taken and starts anew by updating its belief based on the next observation.

In [None]:
import importlib
import continuous_action_environment as cae
import continuous_action_agent as caa
# reload so that edits to the modules take effect without a kernel restart
importlib.reload(cae)
importlib.reload(caa)

target_state = 4
# planning horizon; NOTE(review): with a fixed number of sampled plans (as in
# the agent class shown above) rollout cost grows with n_plans * k, i.e.
# linearly in k -- confirm against the imported caa module.
k = 4

# runtime increases linearly with optimization steps during belief update
n_steps_o = 10 # optimization steps updating belief after observation
n_steps_a = 10 # optimization steps updating belief after action
lr_o = 4. # learning rate updating belief after observation
lr_a = 4. # learning rate updating belief after action

render_losses = True

env = cae.ContinuousActionEnv(N=16, # number of states
                              s_food=0, # location of the food source
                              s_0=0, # starting location
                              o_decay=0.6) # decay of observing food away from source

agent = caa.ContinuousActionAgent(env=env, 
                             target_state=target_state, 
                             shape_target=True,
                             n_plans=128,
                             k=k, 
                             use_info_gain=True,
                             use_pragmatic_value=True,
                             select_max_pi=True,
                             n_steps_o=n_steps_o, 
                             n_steps_a=n_steps_a, 
                             lr_a=lr_a, 
                             lr_o=lr_o)

o = env.reset() # set state to starting state
agent.reset() # initialize belief state and target state distribution

ss = [env.s_t]
bb = [agent.b]
aa = []
if render_losses:
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].set_title('updates from actions')
    ax[0].set_ylabel('loss')
    ax[0].set_xlabel('optimization step')
    ax[1].set_title('updates from observations')
    ax[1].set_ylabel('loss')
    ax[1].set_xlabel('optimization step')
    
for i in range(64):
    a, ll_o, ll_a, max_a = agent.step(o, debug=True)
    # o is True when food was detected, so index 1 must be the 'FOOD' label
    print(f"step {i}, s: {env.s_t}, o: {['NONE', 'FOOD'][int(o)]}, top a: {max_a}, a: {a}")
    if render_losses:
        ax[0].plot(ll_a)
        ax[1].plot(ll_o)
    
    o = env.step(a)
    
    ss.append(env.s_t)
    bb.append(agent.b)
    aa.append(a)


from matplotlib.markers import CARETUP, CARETDOWN
aa, ss = np.array(aa), np.array(ss)

fig, ax = plt.subplots(figsize=(16, 6))
# belief logits over time: one column per timestep, one row per state
plt.imshow(np.array(bb).T, label='belief')

# one short vertical segment per timestep, from the state to state + action.
# imshow draws the y-axis increasing downwards by default, so the caret for
# a positive action is the downward-pointing one.
for i, (s_i, a_i) in enumerate(zip(ss, aa)):
  caret = CARETUP if a_i < 0 else CARETDOWN
  legend_label = 'action' if i == 0 else None # single legend entry
  plt.plot([i, i], [s_i, s_i+a_i], 
             color='orange', 
             linewidth=0.5,
             marker=caret,
             label=legend_label)

last_t = len(ss)-1
plt.plot(ss, label='state')
plt.plot([0, last_t], [target_state]*2, label='target')
plt.plot([0, last_t], [env.s_food]*2, 'w--', label='food')
plt.legend()