In [2]:
import numpy as np
from pymdp.envs import TMazeEnv

In [3]:
# Enable IPython's autoreload so edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

*actions*: 0: CENTER, 1: RIGHT ARM, 2: LEFT ARM, 3: CUE LOCATION (bottom)

*observations*: 0: LOCATION, 1: REWARD, 2: CUE

*states*: 0: LOCATION, 1: CONTEXT

In [4]:
# T-maze environment. NOTE(review): reward_probs presumably sets the reward probability
# of the rewarding vs. non-rewarding arm — confirm against pymdp's TMazeEnv.
env = TMazeEnv(reward_probs=[0.98, 0.02])

# Generative-model arrays taken from the environment:
# A = observation likelihoods (one array per observation modality),
# B = transition arrays (one per hidden-state factor).
A, B = env.get_likelihood_dist(), env.get_transition_dist()

In [5]:
# Prior beliefs over the initial hidden states, one vector per factor: p(location), p(context)
D = [np.array([1., 0., 0., 0.]),  # Location: knows it starts in the center
     np.array([0.5, 0.5])]        # Context: doesn't know which condition holds ('reward condition')
# Preferences over observations, one vector per modality: \tilde p(location), \tilde p(reward), \tilde p(cue)
C = [np.array([0., 0., 0., 0.]),  # Location: indifferent to where it is
     np.array([0., 3., -3.]),     # Reward: prefers reward (3) over no reward (0), and avoids punishment (-3)
     np.array([0., 0.])]          # Cue: indifferent to the cue (right, left)

In [6]:
# Reset the environment and get its initial observation:
# one outcome index per observation modality (location, reward, cue).
obs = env.reset() # reset the environment and get an initial observation
obs # Location, reward, cue

[0, 0, 1]

In [7]:
def softmax(dist):
    """
    Normalised exponential of `dist` along its first axis.

    The maximum along axis 0 is subtracted before exponentiating. This leaves the
    result unchanged but keeps np.exp from overflowing for large inputs.

    Parameters
    ----------
    dist : np.ndarray
        Unnormalised log-values; normalisation runs over axis 0.

    Returns
    -------
    np.ndarray
        Array of the same shape whose entries along axis 0 sum to 1.
    """
    exponentiated = np.exp(dist - np.max(dist, axis=0))
    return exponentiated / exponentiated.sum(axis=0)

In [8]:
# PYMDP nomenclature:
# n_factors: number of hidden-state factors (the dimensions of a state)
# n_modalities: number of observation modalities (the dimensions of an observation)

In [9]:
EPS = 1e-16  # Small constant added inside every log() to avoid log(0)

# Hidden-state factor indices used below: 0 = location, 1 = context
n_dim_state = len(B)  # Number of hidden-state factors (location, context)

# Observe that they are in the center, that there is no reward, and that the cue is on the right.
# One one-hot vector per observation modality (location, reward, cue); this overrides
# the `obs` returned by env.reset() earlier in the notebook.
obs = [np.array([1., 0., 0., 0.]), np.array([1., 0., 0.]), np.array([1., 0.])]

dim_state = A[0].shape[1:]
# Every likelihood array shares the same trailing (state) dimensions:
# A[i].shape[1:] == dim_state for every modality i (here A[0], A[1] and A[2]).
# In this example, dim_state = (4, 2)
# because the state is defined by two factors, location and context, with 4 locations and 2 contexts.
n_uniq_state = np.prod(dim_state)  # Number of joint state configurations

# Likelihood of the observation under every joint state, p(o | s): the product over modalities
# of A[i] (flattened over states) contracted with the one-hot observation obs[i].
likelihood = np.ones(n_uniq_state)
for i in range(len(A)):
    likelihood *= A[i].reshape(A[i].shape[0], n_uniq_state).T.dot(obs[i])

likelihood = likelihood.reshape(*dim_state)
log_likelihood = np.log(likelihood + EPS)


def joint_posterior(marginals):
    """
    Build the fully factorised joint q(s_1, ..., s_n) = prod_f q(s_f) as an n-D array.

    Generalises np.outer(*marginals), which only works for exactly two factors.
    For two 1-D marginals the result is identical to np.outer.
    """
    joint = marginals[0]
    for marginal in marginals[1:]:
        joint = np.multiply.outer(joint, marginal)
    return joint


# ------------------------------
# Fixed-point iteration settings

curr_iter = 0
num_iter = 10    # Maximum number of update iterations (the loop runs one extra pass, see below)
dF = 1           # Change in free energy; initialised above dF_tol so the loop starts
dF_tol = 0.001   # Stop once the free energy changes by less than this


old_qs = D  # Prior over states: D at the first timestep; afterwards, the posterior from the previous timestep

prior = []  # Log-prior per factor
qs = []     # Approximate posterior per factor, initialised uniform

for i in range(n_dim_state):
    prior_i = np.log(old_qs[i] + EPS)
    prior.append(prior_i)
    qs.append(np.ones(len(prior_i)) / len(prior_i))

old_vfe = np.inf  # Free energy from the previous iteration (inf, so the first dF is infinite)

free_energy = None

while curr_iter < num_iter + 1 and dF >= dF_tol:

    free_energy = 0

    # Following PYMDP, the posterior update and the accuracy term below are skipped at the first
    # computation of the VFE (called pre-vfe).
    # Note that in PYMDP, the code is structured differently and this first computation is done BEFORE the main loop.
    # Here, to avoid copy-pasting the code, we refactored it so that this happens at iteration 0, and we run
    # one iteration more (hence the '+1' in 'curr_iter < num_iter+1').
    if curr_iter > 0:
        # ================================================================
        # Estimate the posterior Q(s) for the current iteration by fixed-point iteration
        # ----------------------------------------------------------------
        # Here, there is a departure from what is expressed in the book in Eq B.5 (p 245 in the book).
        # We are still looking for an estimate of `v` such that `Q(s) = sigma(v)`,
        # but, following PYMDP, we actually follow equation A.42 (p 237) closely (instead of Eq B.5):
        # v_i = E_{Q(s_j), j != i}[ln p(o, s)] = E_{Q(s_j), j != i}[ln p(o | s)] + ln p(s_i)
        # Concretely, we compute:
        # v_i = \sum_j (q(s_i) q(s_j) ln p(o | s_i, s_j)) / q(s_i) + ln p(s_i)
        # Step #1: weight the log-likelihood by the joint of our current best guess of the posterior
        # => q(s_i) q(s_j) * ln p(o | s_i, s_j)
        joint_ll = joint_posterior(qs) * log_likelihood
        for i in range(n_dim_state):
            # Step #2: sum the weighted log-likelihood over every state factor except factor `i`,
            # then divide out q(s_i) to recover the expectation over the other factors
            # => \sum_j (q(s_i) q(s_j) ln p(o | s_i, s_j)) / q(s_i)
            # joint_ll was built from the qs at the start of this iteration, and qs[i] has not been
            # updated yet at this point, so dividing by it is consistent.
            # NOTE: the division assumes every entry of qs[i] is > 0 (true for softmax outputs
            # unless exp underflows).
            ll_i = np.einsum(joint_ll, list(range(n_dim_state)), [i]) / qs[i]
            # Then add the log-prior
            v = ll_i + prior[i]
            # As in B.5, apply a softmax so that the marginal sums to 1
            qs[i] = softmax(v)

        # ================================================================
        # Following Equation B.4:
        # F = s . (ln(s) - ln(A) . o - ln(D))
        #   = - s . (ln(A) . o)   [part #1: negative accuracy]
        #     + s . ln(s)         [part #2: negative entropy]
        #     - s . ln(D)         [part #3: cross-entropy with the prior]
        # Note: the second part of Eq B.4 (the summation over the remaining time steps) is ignored
        # as it will be applied recursively.
        # Computing part #1 of VFE (the accuracy)
        # => s . (ln(A) . o), equiv. to E_Q(s)[ln P(o | s)] in Eq B.2
        accuracy = np.sum(joint_posterior(qs) * log_likelihood)
        free_energy -= accuracy

    # ------------------------------------------------------------------

    for i in range(n_dim_state):

        # Equation B.4 (simplified as we only look one step ahead, and distributing s_\pi)
        # (implementation of B.2)
        # Computing part #2 of VFE
        # => s . ln(s), equiv. to E_Q(s)[ln Q(s)] in equation B.2,
        # i.e. the negative entropy -H(Q(s)) of the posterior marginal
        negH_qs = qs[i].dot(np.log(qs[i] + EPS))

        # Computing part #3 of VFE
        # => - s . ln(D) (or, if t > 0, - s . ln(s_{t-1}))
        # i.e. the cross-entropy of the posterior marginal with the prior marginal:
        # H(Q(s), P(s)) = - E_Q[ln P(s)]
        xH_qp = -qs[i].dot(prior[i])

        free_energy += negH_qs + xH_qp

    dF = np.abs(free_energy - old_vfe)
    old_vfe = free_energy
    curr_iter += 1

print("free energy: ", free_energy)

free energy:  0.6931471805599451


In [10]:
n_dim_state  # Number of hidden-state factors (here: location and context)

2

In [11]:
dim_state  # Number of states per factor: (n_locations, n_contexts)

(4, 2)

In [12]:
A[0].shape  # Location-modality likelihood: (n_location_outcomes, *dim_state)

(4, 4, 2)

In [18]:
# One einsum axis label per state factor, i.e. [0, 1] for (location, context)
state_axes = list(range(n_dim_state))
# Toy tensor shaped like the state space: dim_state = (4, 2) -> 4 locations x 2 contexts
fake_ll_tensor = np.arange(8).reshape(dim_state)
# Sublist-form einsum keeping only axis 0: sums over every other factor axis
np.einsum(fake_ll_tensor, state_axes, [0])

2
(4, 2)


array([ 1,  5,  9, 13])

In [24]:
# String-form einsum: 'ij-> i' keeps index i and sums over j (row sums); the space in the subscripts is harmless.
# NOTE(review): the saved output below shows the unreduced tensor and looks stale — re-run to refresh it.
np.einsum('ij-> i', fake_ll_tensor)

array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]])

In [25]:
fake_ll_tensor = np.arange(8).reshape((4, 2))
# Keep one axis at a time: 'i' gives row sums, 'j' gives column sums
results = [np.einsum(f'ij-> {axis_label}', fake_ll_tensor) for axis_label in 'ij']
print(results)

[array([ 1,  5,  9, 13]), array([12, 16])]


In [26]:
fake_ll_tensor = np.arange(8).reshape((4, 2))
# 'ij-> ij' keeps both indices, so einsum returns the tensor unchanged.
# (Plain string: the previous f-string had no placeholders.)
print(np.einsum('ij-> ij', fake_ll_tensor))

[[0 1]
 [2 3]
 [4 5]
 [6 7]]
