In [2]:
from math_helper_functions import log_stable
from math_helper_functions import softmax
from math_helper_functions import kl_div
import numpy as np
import math
from pymdp.maths import spm_log_single as log_stable

EPS_VAL = 1e-16 #negligibleconstant

def entropy(A):
    """ Compute the entropy of a set of condition distributions, i.e. one entropy value per column """
    
    H_A = - (A * log_stable(A)).sum(axis=0)
    return H_A

#Dynamic programming in G (expected free energy)

def action_dist(A, B, C, T, sm_par):
    
    num_modalities = A.shape[0]
    num_factors = B.shape[0]

    num_states = []
    for i in range(num_factors):
        num_states.append(B[i].shape[0])

    num_obs = []
    for i in range(num_modalities):
        num_obs.append(A[i].shape[0])

    num_controls = []
    for i in range(num_factors):
        num_controls.append(B[i].shape[2])

    numS = 1
    for i in num_states:
        numS *= i
    numA = 1
    for i in num_controls:
        numA *= i

    new_num_states = [numS]
    new_num_controls = [numA]

    new_A = utils.random_A_matrix(num_obs, new_num_states) #* 0 + EPS_VAL
    new_B = utils.random_B_matrix(1, 1) #* 0 + EPS_VAL

    for i in range(num_modalities):
        new_A[i] = np.reshape(A[i], [A[i].shape[0], numS])

    for i in range(num_factors):
        new_B[0] = np.kron(new_B[0],B[i])

    #Expected free energy (Only RISK)
    
    G = np.zeros((T-1, numA, numS))
    Q_actions = np.zeros((T-1, numA, numS))

    for mod in range(num_modalities):

        Q_po = np.zeros((A[mod].shape[0], numS, numA))

        for i in range(numS):
            for j in range(numA):
                Q_po[:,i,j] = new_A[mod].dot(new_B[0][:,i,j])

        for k in range(T-2,-1,-1):
            for i in range(numA):
                for j in range(numS):

                    if(k==T-2):
                        G[k,i,j] += kl_div(Q_po[:,j,i],C[mod])

                    else:
                        G[k,i,j] += kl_div(Q_po[:,j,i],C[mod])
                        for jj in range(numS):
                            for kk in range(numA):
                                G[k,i,j] += Q_actions[k+1,kk,jj]*new_B[0][jj,j,i]*G[k+1,kk,jj]

            #Distribution for action-selection
            for ppp in range(numS):
                Q_actions[k,:,ppp] = softmax(sm_par*(-1*G[k,:,ppp]))
                
    return Q_actions

In [5]:
# (Hidden)Factors
s1_size = 42

num_states = [s1_size]
num_factors = len(num_states)

# Rewards
reward_modes = 3 #Max score-5 (assumption)

# Controls
s1_actions = ['Stay', 'Play-Up', 'Play-Down']
num_controls = [len(s1_actions)]

# Observations
#Ball-x
o1_obs_size = s1_size
#Ball-y
o2_obs_size = s1_size
#Ball-vx
o3_obs_size = 2
#Ball-vy
o4_obs_size = 2
#Paddle-pos
o5_obs_size = s1_size
#Paddle-velocity
o6_obs_size = 2
#Reward (Shock, Chocolate, and Nothing)
reward_obs_size = reward_modes

num_obs = [o1_obs_size, o2_obs_size, o3_obs_size, o4_obs_size, o5_obs_size, o6_obs_size, reward_obs_size]
num_modalities = len(num_obs)

EPS_VAL = 1e-16 # Negligibleconstant

# Likelhiood Dynamics
A = utils.random_A_matrix(num_obs, num_states)*0 + EPS_VAL

# Transisition dynamics
# Initialised as random becuase the agent need to learn the dynamics

B = utils.random_B_matrix(num_states, num_controls)*0 + EPS_VAL

numS = 1
for i in num_states:
    numS *= i
numA = 1
for i in num_controls:
    numA *= i

A = normalise_A(A, num_states, num_modalities)
B = normalise_B(B, num_states, num_controls)

# Prior preferences for biasing the generative model to control behaviour

# The preferences are set uniform for all the hidden-states except the reward function
C = utils.obj_array_uniform(num_obs)

# Highest for the high-score and lowest for the lowscore
C_score = np.array([-5.8, 0 , 1])
# Normalising the prior preference
C[6] = pymdp.maths.softmax(1*C_score)

#Prior over hidden-states
D = utils.obj_array_uniform(num_states)

In [4]:
from pymdp import utils
import pymdp
from math_helper_functions import normalise_A, normalise_B

In [6]:
%time Q_pi = action_dist(A, B, C, T=5, sm_par=1)

CPU times: total: 359 ms
Wall time: 356 ms


In [8]:
T=5
sm_par=1

In [9]:
num_modalities = A.shape[0]
num_factors = B.shape[0]

num_states = []
for i in range(num_factors):
    num_states.append(B[i].shape[0])

num_obs = []
for i in range(num_modalities):
    num_obs.append(A[i].shape[0])

num_controls = []
for i in range(num_factors):
    num_controls.append(B[i].shape[2])

numS = 1
for i in num_states:
    numS *= i
numA = 1
for i in num_controls:
    numA *= i

new_num_states = [numS]
new_num_controls = [numA]

new_A = utils.random_A_matrix(num_obs, new_num_states) #* 0 + EPS_VAL
new_B = utils.random_B_matrix(1, 1) #* 0 + EPS_VAL

for i in range(num_modalities):
    new_A[i] = np.reshape(A[i], [A[i].shape[0], numS])

for i in range(num_factors):
    new_B[0] = np.kron(new_B[0],B[i])

#Expected free energy (Only RISK)

G = np.zeros((T-1, numA, numS))
Q_actions = np.zeros((T-1, numA, numS))

for mod in range(num_modalities):

    Q_po = np.zeros((A[mod].shape[0], numS, numA))

    for i in range(numS):
        for j in range(numA):
            Q_po[:,i,j] = new_A[mod].dot(new_B[0][:,i,j])

    for k in range(T-2,-1,-1):
        for i in range(numA):
            for j in range(numS):

                if(k==T-2):
                    G[k,i,j] += kl_div(Q_po[:,j,i],C[mod]) + np.dot(new_B[0][:,j,i],entropy(new_A[0]))

                else:
                    G[k,i,j] += kl_div(Q_po[:,j,i],C[mod]) + np.dot(new_B[0][:,j,i],entropy(new_A[0]))
                    for jj in range(numS):
                        for kk in range(numA):
                            G[k,i,j] += Q_actions[k+1,kk,jj]*new_B[0][jj,j,i]*G[k+1,kk,jj]

        #Distribution for action-selection
        for ppp in range(numS):
            Q_actions[k,:,ppp] = softmax(sm_par*(-1*G[k,:,ppp]))

In [38]:
np.dot(new_B[0][:,0,0],entropy(new_A[0]))

3.737669618283361

In [10]:
Q_actions.shape

(4, 3, 42)

In [27]:
entropy(new_A[0]).shape

(42,)

In [39]:
new_B[0].shape, new_A[0].shape

((42, 42, 3), (42, 42))

In [33]:
(Q_actions[0,:,:]*(new_B[0][:,:,0].dot(entropy(new_A[0])))).shape

(3, 42)