In [97]:
import numpy as np
from pymdp.envs import TMazeEnv

In [2]:
%load_ext autoreload
%autoreload 2

*actions*: 0: CENTER, 1: RIGHT ARM, 2: LEFT ARM, 3: CUE LOCATION (bottom)

*observations*: 0: LOCATION, 1: REWARD, 2: CUE

*states*: 0: LOCATION, 1: CONTEXT

In [4]:
# Build the T-maze environment. `reward_probs` is handed straight to TMazeEnv;
# presumably the two entries are the reward / no-reward probabilities in the
# rewarded arm -- TODO confirm against the pymdp TMazeEnv documentation.
env = TMazeEnv(reward_probs=[0.98, 0.02])
# Observation likelihood. Used further down as a sequence with one array per
# observation modality, where axis 0 indexes the outcomes of that modality and
# the remaining axes index the hidden-state factors.
A = env.get_likelihood_dist()
# Transition distribution. Further down, len(B) is taken as the number of
# hidden-state factors.
B = env.get_transition_dist()

In [6]:
# Prior beliefs over the hidden-state factors: p(location), p(context)
D = [
    np.array([1, 0, 0, 0]),  # location: certain that it starts in the center
    np.full(2, 0.5),         # context ('reward condition'): completely unknown
]

# Preferences over the observation modalities:
# \tilde p(location), \tilde p(reward), \tilde p(cue)
C = [
    np.zeros(4),                # location: indifferent to where it is
    np.array([0., 3., -3.]),    # reward: outcome 1 (reward) is preferred, outcome 2 (punishment) is avoided, outcome 0 is neutral
    np.zeros(2),                # cue: indifferent to which cue (right, left) is seen
]

In [7]:
# Start a fresh episode; reset() returns the first observation.
obs = env.reset() # reset the environment and get an initial observation
# Bare expression so the notebook displays it. One entry per modality, in the
# order: location, reward, cue.
obs # Location, reward, cue

[0, 0, 1]

In [92]:
def softmax(dist):
    """
    Turn an array of scores into probabilities with the softmax function.

    Normalisation runs along axis 0, so a 1-D input yields one distribution
    and a 2-D input yields one distribution per column. The maximum along
    axis 0 is subtracted before exponentiating; this does not change the
    result but keeps ``np.exp`` from overflowing on large scores.
    """
    exp_shifted = np.exp(dist - dist.max(axis=0))
    return exp_shifted / exp_shifted.sum(axis=0)

In [96]:
# Fixed-point iteration for the factorised posterior qs over the hidden-state
# factors, given one observation per modality. Each pass updates every factor's
# marginal and evaluates the variational free energy; iteration stops when the
# free energy changes by less than dF_tol or after num_iter passes.

EPS = 1e-16  # added inside every log so that log(0) never occurs

s_location_idx = 0
s_context_idx = 1
qs_context = D[s_context_idx] # Equals D for the first iteration
n_state = len(B)  # number of hidden-state factors

# One-hot observation for each modality (location, reward, cue).
# Note: this rebinds `obs`, replacing whatever an earlier cell stored in it.
obs = [np.array([1., 0., 0., 0.]), np.array([1., 0., 0.]), np.array([1., 0.])]

old_qs = D
old_vfe = np.inf

# Naming used in this cell: n_factor => n_state, n_modality => n_obs


def joint_from_marginals(marginals):
    """
    Outer product of any number of 1-D marginals.

    Returns an array with one axis per marginal whose entries are the products
    of the corresponding marginal entries. For exactly two marginals this is
    the same as ``np.outer``; unlike ``np.outer`` it also handles one marginal
    or more than two.
    """
    joint = marginals[0]
    for marginal in marginals[1:]:
        joint = joint[..., None] * marginal
    return joint


# Compute joint likelihood -----
dim_state = A[0].shape[1:]
n_uniq_state = np.prod(dim_state)
# Note that np.prod(A[0].shape[1:]) == np.prod(A[i].shape[1:]) for all possible i
likelihood = np.ones(n_uniq_state)
for i in range(len(A)):
    # Flatten the state axes, pick out the likelihood of the observed outcome
    # for every joint state, and multiply the modalities together.
    likelihood *= A[i].reshape(A[i].shape[0], n_uniq_state).T.dot(obs[i])

likelihood = likelihood.reshape(*dim_state)
log_likelihood = np.log(likelihood + EPS)

# The einsum below labels one axis per state factor, so the two must agree.
assert log_likelihood.ndim == n_state, "A and B disagree on the number of state factors"

# ------------------------------

curr_iter = 0
num_iter = 10
dF = 1
dF_tol = 0.001

prior = []  # log prior of each factor
qs = []  # Init posterior: uniform over each factor
for i in range(n_state):
    prior_i = np.log(old_qs[i] + EPS)
    prior.append(prior_i)
    qs.append(np.ones(len(prior_i)) / len(prior_i))

state_axes = list(range(n_state))  # einsum axis labels, one per state factor

while curr_iter < num_iter and dF >= dF_tol:

    free_energy = 0

    # The first pass (curr_iter == 0) only scores the uniform initial posterior
    # against the prior; marginals are updated from the second pass onwards.
    if curr_iter > 0:
        # ll_tensor is built once per pass, so every factor update below uses
        # the marginals from the previous pass.
        qs_all = joint_from_marginals(qs)
        ll_tensor = qs_all * log_likelihood
        for i in range(n_state):
            # Summing ll_tensor over all other axes and dividing by qs[i]
            # removes factor i's own weight, leaving the log-likelihood
            # averaged over the remaining factors.
            qs[i] = softmax(np.einsum(ll_tensor, state_axes, [i]) / qs[i] + prior[i])

    for i in range(n_state):

        # Neg-entropy of posterior marginal H(q[f])
        negH_qs = qs[i].dot(np.log(qs[i] + EPS))

        # Cross entropy of posterior marginal with prior marginal H(q[f],p[f])
        xH_qp = -qs[i].dot(prior[i])

        free_energy += negH_qs + xH_qp

    if curr_iter > 0:
        # Expected log-likelihood under the current factorised posterior.
        accuracy = np.sum(joint_from_marginals(qs) * log_likelihood)
        free_energy -= accuracy

    dF = np.abs(free_energy - old_vfe)
    old_vfe = free_energy
    curr_iter += 1

factor #0
qs[factor] [0.25 0.25 0.25 0.25]
negH_qs -1.3862943611198901
xH_qp 27.63102111592855
factor #1
qs[factor] [0.5 0.5]
negH_qs -0.6931471805599451
xH_qp 0.6931471805599451
free energy 26.24472675480866
--------------------
Curr iter (- 1) 0
qs after [array([1.e+00, 2.e-32, 2.e-32, 2.e-32]), array([0.5, 0.5])]
factor #0
qs[factor] [1.e+00 2.e-32 2.e-32 2.e-32]
negH_qs -2.2104816892742725e-30
xH_qp 2.2104816892742725e-30
factor #1
qs[factor] [0.5 0.5]
negH_qs -0.6931471805599451
xH_qp 0.6931471805599451
accuracy -0.6931471805599451
free energy 0.6931471805599451
--------------------
Curr iter (- 1) 1
qs after [array([1.e+00, 2.e-32, 2.e-32, 2.e-32]), array([0.5, 0.5])]
factor #0
qs[factor] [1.e+00 2.e-32 2.e-32 2.e-32]
negH_qs -2.2104816892742725e-30
xH_qp 2.2104816892742725e-30
factor #1
qs[factor] [0.5 0.5]
negH_qs -0.6931471805599451
xH_qp 0.6931471805599451
accuracy -0.6931471805599451
free energy 0.6931471805599451
