#### Housekeeping (run once per kernel restart)

In [None]:
# Switch the working directory to the parent of the current one and show
# where we ended up (presumably so that project modules such as
# `minimal_environment` can be imported -- confirm against the repo layout).
import os

os.chdir(os.pardir)
print(os.getcwd())

# Imports

In [None]:
import importlib

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.special import softmax

np.random.seed(1)

import minimal_environment as me
importlib.reload(me)

# Environments

Environments are defined as discrete-time Partially Observable Markov Decision Processes ([POMDP](https://en.wikipedia.org/wiki/Partially_observable_Markov_decision_process)s) without a reward function:

- a set of states $\mathcal{S}$
- a set of observations $\mathcal{\Omega}$
- a set of conditional transition probabilities $\mathcal{\tau}: p(s'|s, a)$
- a set of emission/ observation probabilities $\mathcal{O}: p(o|s)$

## MinimalEnv environment

The `MinimalEnv` environment has a discrete set of states and generates discrete outputs: either food `True` or no food `False`.

### Emission probability 
The emission probability distributions $p(o|s)$ are defined as a conditional probability table `p[s,o]`, where each row $i$ defines the probability of observing `False` (column 0) or `True` (column 1) in state $s_i$.

In [None]:
env = me.MinimalEnv(
    N=8,       # number of states
    s_food=0,  # location of the food source
)

# Emission table p[s, o], precomputed by the environment
# (result of env.emission_probability()).
p_o_given_s = env.p_o_given_s
print(p_o_given_s)
print('shape', p_o_given_s.shape)

# Bar chart of p(o=True | s) for every state; the bar of the food-source
# state is drawn again in red on top of the regular bar.
o_food = 1
p_food = p_o_given_s[:, o_food]
fig, ax = plt.subplots(figsize=(8, 3))
ax.bar(range(p_food.shape[0]), p_food)
ax.bar([env.s_food], p_food[env.s_food], color='red', label='food source')
ax.set_xlabel('state')
ax.set_ylabel('$p(o=True|s)$')
ax.legend()

### Transition dynamics
The transition dynamics $p(s'|s, a)$ are defined as a conditional probability table `p[s,a,s']`. The subarray `p[0,0,:]` defines the probability of transitioning from state $0$ to any successor state, given that action `0` (move left) was taken.

In [None]:
# Transition table p[s, a, s'], precomputed by the environment
# (result of env.transition_dynamics()).
p_s1_given_s_a = env.p_s1_given_s_a
print('shape', p_s1_given_s_a.shape)

# Distribution over successor states when starting in state 0 and taking
# action 0 (actions: 0 = left, 1 = right).
p_s1_given_s_a[0, 0, :]

### Random Agent Behavior

Unlike a full POMDP, which includes a reward function, the environment itself does not define a goal, motivation, or purpose for an agent that interacts with it. The environment is indifferent to how agents interact with it. There is therefore no value (good, bad, high or low performance) associated with any individual sequence of behavior.

The code below simulates the interaction between the minimal environment and an agent that behaves randomly.

In [None]:
n_steps = 100

# Start an episode and record the initial state and observation.
# NOTE(review): the list name `os` rebinds the `os` module imported in the
# housekeeping cell; the name is kept because the plotting cell reads `os`.
o = env.reset()
ss, os = [env.s_t], [o]

# Let a uniformly random agent act for n_steps, logging the state and the
# observation after every step.
for t in range(n_steps):
    a = np.random.choice([0, 1])  # random agent
    o = env.step(a)
    ss.append(env.s_t)
    os.append(o)

We inspect the sequence of states and emissions during this interaction.

In [None]:
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(16, 12))
ax_state, ax_obs = ax

# Top panel: visited states, with the food-source location as a dashed
# reference line of the same length.
food_line = env.s_food * np.ones_like(ss)
ax_state.plot(ss, label='agent state $s_t$')
ax_state.plot(food_line, 'r--', label='food source', linewidth=1)
ax_state.set_xlabel('timestep t')
ax_state.set_ylabel('$s$')
ax_state.legend()

# Bottom panel: the observation emitted at every time step.
ax_obs.plot(np.array(os))
ax_obs.set_xlabel('timestep t')
ax_obs.set_ylabel('observation (1=Food)')

# Active Inference

## Belief Update

An active inference agent holds a belief about the state of the environment at time $t$, modelled as a probability distribution $Q(s; \theta_t)$ over states $s \in \mathcal{S}$. There is uncertainty associated with this belief as the state cannot be observed directly and must instead be inferred from observations.

### Update through Time

At any time $t$ an agent holds a prior belief about $s_t$ before taking in an observation from its environment. This belief could be uniform, for example at the start of an interaction. It could also be informed by propagating the belief $Q(s; \theta_{t-1})$ about state $s_{t-1}$ through its model of the environment transition dynamics, taking into account the action $a_{t-1}$ it took at the previous time step.

$$Q(s_t; \theta_{t}) = \mathbb{E}_{s_{t-1}\sim Q(\cdot\,; \theta_{t-1})}[p(s_t|s_{t-1}, a_{t-1})]$$

Note that updating the belief in light of the chosen action involves fitting the parameters $\theta_t$. For our minimal agent, we represent $Q(s;\theta)$ as a softmax with logits $\theta$.



In [None]:
def update_belief_a(env, theta_prev, a, lr=1., n_steps=50):
    """Propagate a belief through the transition model for action `a`.

    The belief is a softmax over the logits `theta_prev`. It is pushed through
    `env.p_s1_given_s_a[:, a, :]` to obtain the predicted distribution over
    successor states, and new logits are then fitted to that distribution by
    gradient descent on the cross-entropy, starting from a uniform belief.

    Args:
        env: object exposing `p_s1_given_s_a`, indexed as `[s, a, s']`.
        theta_prev: 1-D array-like of belief logits before the action.
        a: index of the action that was taken.
        lr: learning rate of the SGD fit.
        n_steps: number of SGD steps used for the fit.

    Returns:
        numpy array holding the fitted logits of the propagated belief.
    """
    # Cast explicitly so that integer-valued logits or tables are accepted;
    # softmax and autograd need floating point inputs.
    theta = torch.tensor(theta_prev, dtype=torch.float64)
    transition = torch.tensor(env.p_s1_given_s_a[:, a, :], dtype=torch.float64)

    # Predicted next-state distribution. On the level of distributions we
    # would be done here, but the belief is represented by its logits, so we
    # still need the parameters that produce this distribution.
    q = torch.softmax(theta, dim=0)
    q1 = torch.matmul(q, transition)

    # Fit new logits to q1, starting from a uniform belief (all-zero logits).
    theta1 = torch.zeros_like(theta, requires_grad=True)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([theta1], lr=lr)

    for _ in range(n_steps):
        loss = loss_fn(theta1, q1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return theta1.detach().numpy()

Let's see the effect of this belief update in action. Say we believed strongly that the environment was in states 1 or 4 with equal probability at time $t$ and we took action 1 (move right). Recall that according to the environment dynamics there is some probability of the state remaining unchanged.

In [None]:
# Build the environment first so that the belief below is sized for this
# environment rather than for whichever `env` an earlier cell left behind.
env = me.MinimalEnv(N=8, s_food=0)

# Logits that favour states 1 and 4 equally over all other states.
theta = np.eye(env.s_N)[1] * 2 + np.eye(env.s_N)[4] * 2
theta_ = update_belief_a(env, theta, a=1, lr=1.0)

# Side-by-side bars per state: belief before and after the update.
states = np.arange(env.s_N)
plt.bar(states - 0.2, softmax(theta), width=0.4, alpha=0.5, label='before')  # belief before update
plt.bar(states + 0.2, softmax(theta_), width=0.4, alpha=0.5, label='after')  # belief after update
plt.title('Propagating prior beliefs through environment dynamics.')
plt.legend()

### Update based on new observation

Now that we have updated our prior taking into account the action taken in the previous time step and the (agent's model of the) environment dynamics, we turn our attention to updating beliefs in light of a new observation.

In active inference, this belief update is cast as minimizing the variational free energy, i.e. minimizing the KL-divergence between $Q(s;\theta')$ and $p(o, s) = Q(s; \theta) p(o|s)$ with respect to $\theta'$.

$$D_{KL}(\quad Q_{\theta'}(s), p(o|s)Q_{\theta}(s) \quad) \quad = \quad \mathbb{E}_{s \sim Q_{\theta'}}[\quad \log Q_{\theta'}(s) - \log p(o|s)Q_{\theta}(s) \quad]$$

In [None]:
def update_belief(env, theta_prev, o, lr=1., n_steps=100):
    """Update a belief in light of a new observation `o`.

    The belief is a softmax over logits. New logits are found by gradient
    descent on the variational free energy, i.e. the KL divergence between
    the updated belief Q'(s) and the unnormalized joint p(o|s) Q(s), where
    Q is the softmax of `theta_prev`. The search starts at `theta_prev`.

    Args:
        env: object exposing `p_o_given_s`, indexed as `[s, o]`.
        theta_prev: 1-D array-like of belief logits before the observation.
        o: index of the observation that was made.
        lr: learning rate of the SGD minimization.
        n_steps: number of SGD steps.

    Returns:
        numpy array holding the logits of the updated belief.
    """
    # Cast explicitly so that integer-valued logits are accepted; autograd
    # only works on floating point tensors.
    theta = torch.tensor(theta_prev, dtype=torch.float64)
    likelihood = torch.tensor(env.p_o_given_s[:, o], dtype=torch.float64)

    # log p(o, s) = log p(o|s) + log Q(s). log_softmax avoids taking the log
    # of a softmax output that has underflowed to zero.
    log_p = torch.log(likelihood) + torch.log_softmax(theta, dim=0)

    # Initialize the updated belief with the current belief.
    theta1 = theta.clone().requires_grad_(True)
    optimizer = torch.optim.SGD([theta1], lr=lr)

    for _ in range(n_steps):
        # Free energy: KL[ Q'(s) || p(o, s) ].
        log_q1 = torch.log_softmax(theta1, dim=0)
        free_energy = torch.sum(torch.exp(log_q1) * (log_q1 - log_p))

        optimizer.zero_grad()
        free_energy.backward()
        optimizer.step()

    return theta1.detach().numpy()



Let's see the effect of this in action. We start with a uniform prior belief. Recall that the food source is in state 0 and that the probability of observing food decreases exponentially with the distance of a state from the food source, with the state space wrapping around.

If we observed no food ($o=0$), then it is most likely that we are in the state furthest away from the food source. If we observed food ($o=1$), then it is most likely that we are at the food source.

In [None]:
o = 0  # observation to condition on (0 = no food, 1 = food)

env = me.MinimalEnv(N=8, s_food=0)
theta = np.zeros(env.s_N)  # all-zero logits: uniform prior belief
theta_ = update_belief(env, theta, o=o, lr=1.)

# Side-by-side bars per state: belief before and after the update.
states = np.arange(env.s_N)
plt.bar(states - 0.2, softmax(theta), width=0.4, alpha=0.5, label='before')  # belief before update
plt.bar(states + 0.2, softmax(theta_), width=0.4, alpha=0.5, label='after')  # belief after update
plt.title('Updating beliefs in light of a new observation.')
plt.legend()

### Accumulating observations over time

Note that, based on a single observation, we cannot refine our belief to become most confident in a state that is not the food source itself or furthest away from the food source. This can be improved if we accumulate observations from a single state over time. Because the probability of observing food decreases symmetrically to the left and right of the food source, the agent is likely to hold strong beliefs about all states that share the same rate of observing food, but cannot disambiguate these states without action.

In [None]:
s = 2  # starting state; no actions are taken in this cell
n_observations = 50  # number of observations folded into the belief

env = me.MinimalEnv(N=8,  # number of states
                    s_food=0,  # location of the food source
                    s_0=s)  # starting location

b = np.zeros(env.s_N)  # initialize state prior as uniform
o = env.reset()  # set state to starting state

# Refine the belief one observation at a time: update on the current
# observation, then sample the next one from the environment.
for i in range(n_observations):
    b = update_belief(env, theta_prev=b, o=int(o))
    o = env.sample_o()

q = softmax(b)
plt.bar(range(env.s_N), q)

## Action Selection

In active inference, sequences of actions (plans, policies) are scored by the negative expected free energy, and selected by exponentiating and normalizing, i.e. sampling from the softmax over plans. Plans $\pi: a_0, a_1, ..., a_K$ define sequences of actions up to a finite horizon $K$.

The expected free energy can be decomposed in various ways and here we chose one that we find most intuitive. The pragmatic term assesses the probability of observing or arriving in states that the agent desires following $\pi$. In this context, $Q_\theta$ is estimated by propagating beliefs through the environment model.

$\mathbb{E}_{s \sim Q_{\theta}}[\quad \log p_c(y|s) \quad]$



In [None]:
env = me.MinimalEnv(
    N=16,      # number of states
    s_food=0,  # location of the food source
    s_0=s,     # starting location (`s` is defined in an earlier cell)
)

# Initial belief: logits peaked on state 1.
theta_start = 5 * np.eye(env.s_N)[1]

# Preference: logits peaked on the target state, stored as log-probabilities.
target_state = 14
theta_star = 5 * np.eye(env.s_N)[target_state]
log_p_c = np.log(softmax(theta_star))

# Candidate plans: each row is a random sequence of k action indices.
n_plans = 32   # number of plans to evaluate
k = 4          # planning horizon (number of sequential actions per plan)
n_actions = 2  # possible actions (assumed to be discrete and indexed 0)
plans = np.random.choice(n_actions, size=(n_plans, k), replace=True).tolist()

def kl(a, b):
    """Return sum(a * (log(a) - log(b))), the KL divergence of `a` from `b`."""
    log_ratio = np.log(a) - np.log(b)
    weighted = a * log_ratio
    return weighted.sum()

def rollout(theta, pi, use_info_gain=True, use_pragmatic_value=True):
    """Score every step of plan `pi` by its negative expected free energy.

    Starting from the belief logits `theta`, the belief is propagated through
    the environment model one planned action at a time. Each step is scored
    by the sum of a pragmatic term (expected log-preference of the predicted
    states) and an information-gain term (expected KL divergence between the
    belief after and before a hypothetical observation).

    Reads the module-level names `env` and `log_p_c`.

    Args:
        theta: belief logits at the start of the plan.
        pi: sequence of action indices.
        use_info_gain: include the information-gain term in every step.
        use_pragmatic_value: include the pragmatic term in every step.

    Returns:
        list with one negative expected free energy value per planned action
        (empty for an empty plan).
    """
    nefes = []
    for action in pi:
        # Do I want to get to where I will be?
        theta = update_belief_a(env, theta, a=action, lr=1.)
        q = softmax(theta)
        pragmatic = np.dot(q, log_p_c)

        # Do I learn from observing what happens after doing what I plan?
        # Predicted observation distribution: marginalize p(o, s) over s.
        p_o = np.dot(q, env.p_o_given_s)
        # Enumerate observations, update the belief on each of them and
        # measure how far the updated belief moves away from the prediction.
        q_o = [softmax(update_belief(env, theta, o=obs))
               for obs in range(p_o.shape[0])]
        d_o = [kl(posterior, q) for posterior in q_o]
        info_gain = np.dot(p_o, d_o)

        # The flags apply to every step of the plan, not only the first one.
        nefes.append(use_pragmatic_value * pragmatic
                     + use_info_gain * info_gain)

    return nefes

# Score every candidate plan by the mean of its per-step negative expected
# free energies.
nefes = []
for pi in plans:
    nefes.append(np.mean(rollout(theta_start, pi)))
    


In [None]:
pi = np.array(plans)  # one row per sampled plan, one column per step
p_pi = softmax(np.array(nefes))  # probability of each plan

# Probability of each first action: sum the plan probabilities over all
# plans that start with that action.
[p_pi[pi[:, 0] == first_action].sum() for first_action in range(n_actions)]

In [None]:
def select_action(env, theta_star, b):
    """Sample the next action from a softmax over one-step pragmatic values.

    Every action is treated as a plan of length one. Its score is the
    expected log-preference of the states predicted by pushing the current
    belief through the transition model for that action.

    Args:
        env: object exposing `p_s1_given_s_a`, indexed as `[s, a, s']`.
        theta_star: logits of the preferred state distribution.
        b: logits of the current belief over states.

    Returns:
        index of the sampled action.
    """
    log_p_c = np.log(softmax(theta_star))  # log probability of preferred states
    q = softmax(b)  # current belief; does not depend on the action

    # With a discrete set of actions we can enumerate all plans of length 1.
    n_actions = env.p_s1_given_s_a.shape[1]
    pi = []
    g = []
    for a in range(n_actions):
        q_ = np.dot(q, env.p_s1_given_s_a[:, a, :])  # next state given action
        pi.append([a])
        g.append(np.sum(q_ * log_p_c))  # pragmatic value of the plan

    # Sample a plan from softmax(g) and return its only action.
    i = np.random.choice(len(g), p=softmax(np.array(g)))
    return pi[i][0]
    



In [None]:
target_state = 5

env = me.MinimalEnv(
    N=16,      # number of states
    s_food=0,  # location of the food source
    s_0=0,     # starting location
)
b_star = 5 * np.eye(env.s_N)[target_state]  # state preference
b = np.zeros(env.s_N)  # initialize state prior as uniform
o = env.reset()  # set state to starting state

# Histories of environment states and belief logits, starting with the
# initial state and the initial belief.
ss, bs = [env.s_t], [b]
for i in range(100):
    b = update_belief(env, theta_prev=b, o=int(o))  # update on the observation
    a = select_action(env, theta_star=b_star, b=b)  # choose an action
    b = update_belief_a(env, theta_prev=b, a=a)     # propagate through the model
    o = env.step(a)                                 # act in the environment
    ss.append(env.s_t)
    bs.append(b)

q = softmax(b)
plt.bar(range(env.s_N), q)

In [None]:
# Heatmap of the recorded belief logits (one row per state, one column per
# time step) with the visited states and the target state drawn on top.
fig, ax = plt.subplots(figsize=(16, 6))
ax.imshow(np.array(bs).T)
ax.plot(ss, label='state')
ax.plot([0, len(ss) - 1], [target_state] * 2, label='target')
ax.legend()