In [1]:
import numpy as np
from pymdp.envs import TMazeEnv

In [2]:
# Automatically reload edited modules before each cell is executed
%load_ext autoreload
%autoreload 2

In [3]:
# Reward probabilities handed to the T-maze environment. In the printed reward
# likelihood further down, 0.98 / 0.02 show up as the probabilities of the two
# reward outcomes in each arm.
REWARD_PROBS = [0.98, 0.02]
env = TMazeEnv(reward_probs=REWARD_PROBS)

# Matrices of the environment's generative process
A = env.get_likelihood_dist()   # one likelihood array per observation modality
B = env.get_transition_dist()   # one transition array per hidden-state factor

`A[i][j, k, l]`: for observation modality `i`, given the location state `k` and the context `l`, what is the probability of observing outcome `j` of that modality? (For the location modality, `j` is an observed location.)

In [4]:
# A is the set of the likelihood distributions, one per observation modality.
# There are 3 modalities: location, reward, cue.
# NOTE: n_obs counts modalities, not the number of outcomes within a modality.
n_obs = len(A) 
print(n_obs)

3


In [5]:
# Location likelihood p(observed location | location, context), shown for the
# first context only (the last axis indexes the context). Rows are observed
# locations, columns are true locations: the identity means the agent always
# observes where it actually is.
print(A[0][..., 0])

[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]


In [6]:
# Reward likelihood p(reward outcome | location, context), shown for the first
# context only (last axis = context). Rows are reward outcomes, columns are
# locations.
print(A[1][..., 0])


[[1.   0.   0.   1.  ]
 [0.   0.98 0.02 0.  ]
 [0.   0.02 0.98 0.  ]]


In [7]:
# Cue likelihood p(cue | location, context), shown for the first context only.
# Every column is uniform except the last one (location 3, the cue location),
# which is the only informative place to look at the cue.
print(A[2][..., 0])

[[0.5 0.5 0.5 1. ]
 [0.5 0.5 0.5 0. ]]


In [8]:
n_state = len(B)
print(n_state)  # Number of hidden-state factors (one transition array each): location, context

2


In [9]:
B[0].shape  # p(next location | current location, action); axes = (next location, current location, action)

(4, 4, 4)

In [10]:
B[0][0, :, 0]  # p(next location = CENTER | current location, action 0 = go to CENTER), for every current location:
# all ones, i.e. the agent ends up in the center with certainty wherever it starts

array([1., 1., 1., 1.])

In [11]:
B[1].shape  # p(next context | current context, dummy action); the single action means the agent has no control over the context

(2, 2, 1)

In [12]:
# Prior beliefs over the hidden states: p(location), p(context).
# Stored as float arrays so both are proper probability vectors.
D = [np.array([1., 0., 0., 0.]),   # Knows it is in the center
     np.array([0.5, 0.5])]         # but doesn't know the context ('reward condition')
# Preferences over the observations: \tilde p(location), \tilde p(reward), \tilde p(cue).
# These are unnormalised scores; they are turned into distributions with softmax
# when the pragmatic value is computed.
C = [np.array([0., 0., 0., 0.]),   # Location: no preference over where it is
     np.array([0., 3., -3.]),      # Reward: prefers reward over no reward, and no reward over punishment
     np.array([0., 0.])]           # Cue: no preference over the cue

In [13]:
obs = env.reset() # reset the environment and get an initial observation
obs # One outcome index per modality: location, reward, cue

[0, 0, 0]

*actions* (location to move to): 0: CENTER, 1: RIGHT ARM, 2: LEFT ARM, 3: CUE LOCATION (bottom)

*observation modalities*: 0: LOCATION, 1: REWARD, 2: CUE

*hidden-state factors*: 0: LOCATION, 1: CONTEXT

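`A_gp[i][j, k, l]` (the environment's likelihood, loaded above as `A`):
in context `l`, for modality `i`, when the agent is at location `k` (reached by the corresponding action), what is the probability of observing outcome `j`?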

In [14]:
def softmax(dist, axis=0):
    """
    Computes the softmax function on a set of values.

    Parameters
    ----------
    dist : array_like
        Unnormalised scores (e.g. preferences C). Lists are accepted and are
        converted to a NumPy array.
    axis : int, optional
        Axis along which to normalise. Defaults to 0, which matches the
        previous behaviour (normalise a vector, or each column of a matrix).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as `dist` whose entries along `axis` are
        non-negative and sum to 1.
    """
    dist = np.asarray(dist)
    # Subtract the max along `axis` so exp cannot overflow; keepdims keeps the
    # broadcast correct for any axis (not only axis 0).
    shifted = dist - dist.max(axis=axis, keepdims=True)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum(axis=axis, keepdims=True)

In [15]:
# One-step evaluation of every action: pragmatic value (expected log-preference
# of the predicted observations) and epistemic value (expected information gain
# about the hidden states), following Equations B.28 and B.29 (p 252).
s_location_idx = 0   # hidden-state factor: location
s_context_idx = 1    # hidden-state factor: context
EPS = 1e-10          # keeps log() finite for zero-probability entries

qs_context = D[s_context_idx]  # Equals D for the first iteration

# A modality contributes pragmatic value only if C expresses some preference in it
have_preference = [not np.all(C_mod == 0.) for C_mod in C]
# Log-preferences do not depend on the action: compute them once, not per action
log_prefs = [np.log(softmax(C_mod)) for C_mod in C]

# Last axis of the location transition array indexes the actions
n_action = B[s_location_idx].shape[2]

epistemic_val = np.zeros(n_action)
pragmatic_val = np.zeros(n_action)
for action in range(n_action):

    # Predicted location after taking `action` from the initial location belief
    qs_location_action = B[s_location_idx][:, :, action].dot(D[s_location_idx])
    # Joint (location, context) belief, flattened row-major so that entry
    # k * n_context + l matches column k * n_context + l of the reshaped A below
    qs_joint = np.outer(qs_location_action, qs_context).ravel()

    for i in range(n_obs):

        A_mod = A[i]

        # Equation B.28 (third line) (p 252): predicted observations for modality i
        qo_mod_i_pi = A_mod.dot(qs_context).dot(qs_location_action)

        # Pragmatic value - Equation B.28 (first line) (p 252)
        if have_preference[i]:
            pragmatic_val[action] += qo_mod_i_pi.dot(log_prefs[i])

        # Epistemic value - Equation B.29
        # Second line, first term (p 252): entropy of the predicted observations
        first_term = -qo_mod_i_pi.dot(np.log(qo_mod_i_pi + EPS))

        # Third line (p 252): entropy of p(o | location, context) for every
        # state combination, i.e. one entropy per column of the flattened A.
        # Column-wise sum avoids building the full A.T @ log(A) matrix.
        A_flat = A_mod.reshape(A_mod.shape[0], -1)
        assert A_flat.shape[1] == qs_joint.size, "A and joint belief sizes disagree"
        h = -np.sum(A_flat * np.log(A_flat + EPS), axis=0)
        # Second line, second term (p 252): expected ambiguity under the joint belief
        second_term = h.dot(qs_joint)

        epistemic_val[action] += first_term - second_term

In [16]:
# Epistemic value per action: entropy of predicted observations minus expected ambiguity
epistemic_val

array([0.        , 0.59510807, 0.59510807, 0.69314718])

In [17]:
# Pragmatic value per action: expected log-preference of the predicted observations
pragmatic_val

array([-3.05094576, -3.05094576, -3.05094576, -3.05094576])

In [18]:
# Total value per action (epistemic + pragmatic); higher is better under this sign
# convention. NOTE(review): presumably the negative expected free energy - confirm
# against Eq. B.28/B.29.
epistemic_val + pragmatic_val

array([-3.05094576, -2.4558377 , -2.4558377 , -2.35779858])