In [1]:
import numpy as np
from pymdp.envs import TMazeEnv

In [2]:
%load_ext autoreload
%autoreload 2

*actions*: 0: CENTER, 1: RIGHT ARM, 2: LEFT ARM, 3: CUE LOCATION (bottom)

*observations*: 0: LOCATION, 1: REWARD, 2: CUE

*states*: 0: LOCATION, 1: CONTEXT

In [3]:
# Reward probabilities handed to the T-maze environment constructor.
REWARD_PROBS = [0.98, 0.02]

env = TMazeEnv(reward_probs=REWARD_PROBS)

# Likelihood arrays (A, indexed by observation modality) and transition
# arrays (B, indexed by hidden-state factor) as provided by the environment.
A, B = env.get_likelihood_dist(), env.get_transition_dist()

In [4]:
# Prior beliefs over the hidden-state factors: p(location), p(context)
location_prior = np.zeros(4, dtype=int)
location_prior[0] = 1              # Knows it is in the center (location index 0)
context_prior = np.full(2, 0.5)    # ... but has no idea about the context ('reward condition')
D = [location_prior, context_prior]

# Preferences over the observations, one vector per modality:
# \tilde p(location), \tilde p(reward), \tilde p(cue)
location_pref = np.zeros(4)                # Location: indifferent to where it is
reward_pref = np.array([0., 3., -3.])      # Reward: prefers reward over no reward, and dislikes punishment
cue_pref = np.zeros(2)                     # Cue: indifferent to the cue shown (right, left)
C = [location_pref, reward_pref, cue_pref]

In [5]:
# Reset the environment and keep the initial observation it returns.
obs = env.reset()
# Display it: one entry per observation modality, ordered (location, reward, cue).
obs

[0, 0, 0]

In [6]:
def softmax(dist):
    """
    Normalised exponential of `dist`, taken along axis 0.

    The maximum along axis 0 is subtracted before exponentiating; this leaves
    the result unchanged but keeps `np.exp` from overflowing on large inputs.

    Parameters
    ----------
    dist : numpy.ndarray
        Unnormalised log-values. For a 2-D array every column is normalised
        independently.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as `dist` whose entries sum to 1 along axis 0.
    """
    exp_shifted = np.exp(dist - dist.max(axis=0))
    return exp_shifted / exp_shifted.sum(axis=0)

In [7]:
# PYMDP nomenclature:
# n_factor: number of hidden-state factors, i.e. the number of dimensions of a state
# n_modality: number of observation modalities, i.e. the number of dimensions of an observation

In [8]:
EPS = 1e-16  # added inside np.log so that zero probabilities give a finite value

# Factor indices used below: 0 = location, 1 = context.
# One transition array per hidden-state factor, so len(B) is the number of factors.
n_dim_state = len(B)

# Hand-built one-hot observation, one vector per modality:
# the agent is in the center, sees no reward, and the cue is on the right.
obs = [
    np.array([1., 0., 0., 0.]),
    np.array([1., 0., 0.]),
    np.array([1., 0.]),
]

# The trailing axes of every A[m] index the hidden-state factors, so
# A[m].shape[1:] == dim_state for every modality m (here m = 0, 1, 2).
# In this example dim_state = (4, 2): two factors, with 4 locations and 2 contexts.
dim_state = A[0].shape[1:]
# Total number of distinct joint states (the same whichever modality is used).
n_uniq_state = np.prod(dim_state)

# Joint likelihood of the observation for every joint state:
# product over modalities of A[m] contracted with that modality's observation.
likelihood = np.ones(n_uniq_state)
for m, A_m in enumerate(A):
    A_m_flat = A_m.reshape(A_m.shape[0], n_uniq_state)
    likelihood *= A_m_flat.T.dot(obs[m])

# Back to one axis per state factor, then to log space.
likelihood = likelihood.reshape(*dim_state)
log_likelihood = np.log(likelihood + EPS)

# ------------------------------

# Settings and running state of the fixed-point iteration
num_iter = 10     # the loop below allows at most this many posterior updates
dF_tol = 0.001    # stop once the free energy changes by less than this
curr_iter = 0
dF = 1

# Posterior of the previous time step; at the first time step this is the prior D
old_qs = D

# Log-prior of each state factor (EPS keeps log(0) finite)
prior = [np.log(old_qs[f] + EPS) for f in range(n_dim_state)]
# Posterior marginals, each initialised to a uniform distribution
qs = [np.ones(len(log_p)) / len(log_p) for log_p in prior]

# Free energy of the previous iteration; starts at infinity so the first
# difference dF can never satisfy the stopping tolerance
old_vfe = np.inf

free_energy = None
    
def _joint_posterior(marginals):
    """Outer product of the per-factor marginals, one axis per factor.

    Works for any number of factors (np.outer only accepts exactly two
    vectors); for two factors the result equals np.outer(*marginals).
    """
    joint = marginals[0]
    for marginal in marginals[1:]:
        joint = joint[..., None] * marginal
    return joint


# Axis labels of the joint tensor, used by einsum to marginalise over factors
factor_axes = list(range(n_dim_state))

# Fixed-point iteration on the factorised posterior q(s) = prod_i q(s_i).
# Stops after num_iter posterior updates, or as soon as the free energy
# changes by less than dF_tol between two consecutive iterations.
while curr_iter < num_iter+1 and dF >= dF_tol:

    free_energy = 0

    # Following PYMDP, what follows is not done at the first computation of the VFE (called pre-vfe):
    # iteration 0 only scores the initial posterior (entropy and cross-entropy terms, no accuracy term).
    # Note that in PYMDP the code is structured differently and this first computation is done BEFORE the
    # main loop. Here, to avoid copy-pasting the code, it is done at iteration 0 and the loop runs one
    # iteration more (hence the '+1' in 'curr_iter < num_iter+1').
    if curr_iter > 0:
        # ================================================================
        # Estimate q(s) for the current iteration
        # --------------------------------------------------------------
        # As according to equation B5 (p 245 in the book) q(s) = sigma(v)
        # and v ~ - F
        # We follow equation B4 (p 245 in the book) that gives a definition of F
        # Note #1: as we only look one step ahead, a forward-only message passing technique is enough (as in
        # https://doi.org/10.1162/NECO_a_00912)
        ll_tensor = _joint_posterior(qs) * log_likelihood
        for i in range(n_dim_state):
            # Note #2: sum over all the dimensions of the state ("factors") except the one under
            # consideration (dimension `i`)
            ll_i = np.einsum(ll_tensor, factor_axes, [i])
            # Rigorously following the book, it should be
            # v = ll_i - np.log(qs[i]) + prior[i]
            # as, per B4 and B5, v = - ln(s) + ln(A) . o + ln(D)
            # Dividing by qs[i] removes factor i's own (pre-update) weight from ll_tensor, leaving the
            # expected log-likelihood under the other factors.
            # NOTE(review): this is 0/0 if an entry of qs[i] ever underflows to exactly 0.
            v = ll_i / qs[i] + prior[i]
            # Equation B.5 (first line)
            qs[i] = softmax(v)

        # ================================================================
        # Computing part of VFE (following again Equation B4)
        # => s . (ln(A) . o)
        # equiv. to E_Q(s)[ln P(o | s)] in Eq B2, i.e. the accuracy term,
        # evaluated with the freshly updated marginals
        # ---------------------------------------------------------------
        accuracy = np.sum(_joint_posterior(qs) * log_likelihood)
        free_energy -= accuracy

    # -------------------------------------------------------------------------------------

    for i in range(n_dim_state):

        # Equation B.4 (simplified as we only look one step ahead, and distributing s_\pi)
        # (implementation of B2)

        # s . ln(s)
        # equiv. to E_Q(s)[ln Q(s)] in equation B.2,
        # that is the negative entropy -H(Q(s)) of the posterior marginal
        negH_qs = qs[i].dot(np.log(qs[i] + EPS))

        # - s . ln(D), or if t > 0: - s . ln(s_{t-1}),
        # that is the cross entropy H(Q(s), P(s)) of the posterior marginal with the prior marginal
        # (`prior` already holds log-probabilities)
        xH_qp = -qs[i].dot(prior[i])

        free_energy += negH_qs + xH_qp

    dF = np.abs(free_energy - old_vfe)
    old_vfe = free_energy
    curr_iter += 1
print("free energy: ", free_energy)

free energy:  0.6931471805599451


In [9]:
# Number of hidden-state factors, computed earlier as len(B)
n_dim_state

2

In [24]:
# Number of states in each factor, taken earlier from A[0].shape[1:]
dim_state

(4, 2)

In [25]:
# Shape of the first modality's likelihood array: the observation axis followed by the state-factor axes (dim_state)
A[0].shape

(4, 4, 2)