In [62]:
import numpy as np
from pymdp.envs import TMazeEnvNullOutcome
from pymdp import utils, maths
import copy

In [63]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


*actions*: 0: CENTER, 1: RIGHT ARM, 2: LEFT ARM, 3: CUE LOCATION (bottom)

*observation modalities*: 0: LOCATION, 1: REWARD, 2: CUE

*hidden-state factors*: 0: LOCATION, 1: CONTEXT

In [64]:
# Reward probabilities handed to the environment, i.e. the ones that actually
# generate outcomes (as opposed to whatever the agent's own model assumes).
reward_probabilities = [0.85, 0.15]  # the 'true' reward probabilities
env = TMazeEnvNullOutcome(reward_probs=reward_probabilities)
# Likelihood (A) and transition (B) arrays as reported by the environment;
# the "_gp" suffix marks them as belonging to the generative process.
A_gp = env.get_likelihood_dist()
B_gp = env.get_transition_dist()

In [65]:
# Agent's generative model ("_gm"). Its likelihood A is deliberately wrong:
# start from Dirichlet counts shaped like the true likelihood (presumably the
# true probabilities multiplied by `scale` -- confirm against pymdp), then
# overwrite part of the reward modality (index 1) with small counts of 1.0:
# outcomes 1 and 2, locations 1 and 2, both contexts.
pA = utils.dirichlet_like(A_gp, scale=1e16)
pA[1][1:, 1:3, :] = 1.0
# Normalised counts give the agent's likelihood; transitions are copied
# unchanged from the environment.
A_gm = utils.norm_dist_obj_arr(pA)
B_gm = copy.deepcopy(B_gp)
# Prior beliefs over the hidden-state factors: p(location), p(context)
D = [np.array([1,0,0,0]),     # All mass on location 0: knows it is in the center
     np.array([0.5, 0.5])]    # Uniform: doesn't know the context ('reward condition')
# Preferences over the observation modalities: \tilde p(location), \tilde p(reward), \tilde p(cue)
C = [np.array([0., 0., 0., 0.]), # Location: flat, no preference about where it is
     np.array([ 0., 2., -2.]),   # Reward: outcome 1 preferred (+2), outcome 2 avoided (-2), outcome 0 neutral
     np.array([0., 0., 0.])]         # Cue: flat, no preference about the cue (none, right, left)

# Presumably a learning rate for updating pA -- not used in the cells shown here.
lr = 0.25

In [66]:
# Start a fresh episode; `obs` holds one outcome index per observation modality.
obs = env.reset() # reset the environment and get an initial observation
# Bare expression so the notebook displays the initial observation.
obs # Location, reward, cue

[0, 0, 0]

In [67]:
def softmax(dist):
    """
    Compute the softmax of `dist` along axis 0.

    Parameters
    ----------
    dist : array_like
        Unnormalised log-values. A 1-D input is normalised as a whole; for
        higher-dimensional input each slice along axis 0 (e.g. each column
        of a matrix) is normalised independently. Lists and tuples are
        accepted and converted to an array.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as `dist` whose entries are positive and sum
        to 1 along axis 0. The input is not modified.
    """
    # Coerce so that plain Python sequences work as well as ndarrays.
    values = np.asarray(dist)

    # Subtracting the maximum leaves the result unchanged but keeps np.exp
    # from overflowing on large inputs.
    exponentials = np.exp(values - values.max(axis=0))
    return exponentials / np.sum(exponentials, axis=0)

In [82]:
# Positions of the two hidden-state factors inside B_gm / D.
s_location_idx = 0
s_context_idx = 1
qs_context = D[s_context_idx] # Equals D for the first iteration

# Flags modalities whose preference vector is not all zeros. It is not used
# below: the pragmatic term is accumulated for every modality, flat or not.
have_preference = [not np.all(C[i] == 0.) for i in range(len(C))]

# Actions index the last axis of the location transition array.
n_action = B_gp[s_location_idx].shape[-1]
# Number of observation modalities (one likelihood array per modality).
n_obs = len(A_gp)

# ---------------------------------------------------------------------------
# Quantities that depend only on the modality, not on the action, are computed
# once here instead of once per action.
# ---------------------------------------------------------------------------
log_preferences = []  # ln softmax(C[i])
ambiguities = []      # H[p(o | s)] for every joint state (location, context)
novelty_weights = []  # Dirichlet-count weights, flattened to (n_outcomes, n_states)
for i in range(n_obs):
    # Flatten the state axes so each column is one joint (location, context) state.
    A_flat = A_gm[i].reshape(A_gm[i].shape[0], -1)

    # Equation B.28 (first line) (p 252): log prior preferences
    log_preferences.append(np.log(softmax(C[i])))

    # Equation B.29 (third line) (p 252): entropy of each likelihood column.
    # A column-wise sum gives the same numbers as diag(A^T log A) without
    # building the full (n_states x n_states) product.
    ambiguities.append(-np.sum(A_flat * np.log(A_flat + 1e-10), axis=0))

    # Equation B.33 (first line) (p 253) simplified as in `pymdp.maths.spm_wnorm`
    pA_eps = pA[i] + 1e-16
    w = 1 / np.sum(pA_eps, axis=0) - 1 / pA_eps
    # Equation B.34: entries with zero Dirichlet counts carry no weight.
    w = w * (pA[i] > 0).astype(float)
    novelty_weights.append(w.reshape(A_flat.shape[0], -1))

epistemic_val = np.zeros(n_action)
pragmatic_val = np.zeros(n_action)
epistemic_model_val = np.zeros(n_action)
for action in range(n_action):

    # Predicted location after taking `action` from the prior location belief.
    qs_location_action = B_gm[s_location_idx][:, :, action].dot(D[s_location_idx])
    # Joint belief over (location, context), flattened location-major so that
    # it lines up with the flattened likelihood columns above.
    qs = np.outer(qs_location_action, qs_context).ravel()

    for i in range(n_obs):

        # Equation B.28 (third line) (p 252): predicted outcomes. The first
        # dot contracts the context axis (last), the second the location axis.
        qo_mod_i_pi = A_gm[i].dot(qs_context).dot(qs_location_action)

        # Pragmatic value - Equation B.28 (first line) (p 252)
        pragmatic_val[action] += qo_mod_i_pi.dot(log_preferences[i])

        # Epistemic value - Equation B.29 (second line) (p 252):
        # entropy of predicted outcomes minus expected ambiguity.
        predictive_entropy = -qo_mod_i_pi.dot(np.log(qo_mod_i_pi + 1e-10))
        expected_ambiguity = ambiguities[i].dot(qs)
        epistemic_val[action] += predictive_entropy - expected_ambiguity

        # Epistemic value for model parameters - B.34 p(253)
        ws = novelty_weights[i].dot(qs)
        epistemic_model_val[action] += -qo_mod_i_pi.dot(ws)

In [83]:
# Pragmatic (preference-seeking) value, one entry per action index.
pragmatic_val

array([-4.62783828, -4.62783828, -4.62783828, -4.62783828])

In [84]:
# Epistemic value about hidden states, one entry per action index.
epistemic_val

array([0.        , 0.        , 0.        , 0.69314718])

In [85]:
# Epistemic value about the likelihood parameters (pA), one entry per action index.
epistemic_model_val

array([0. , 0.5, 0.5, 0. ])

In [86]:
# Sum of the three terms for each action. Presumably the negative expected
# free energy, so the largest entry is the preferred action -- confirm the
# sign convention against the referenced equations.
epistemic_val + pragmatic_val + epistemic_model_val

array([-4.62783828, -4.12783828, -4.12783828, -3.9346911 ])