# A step-by-step tutorial on active inference and its application to empirical data

- [Paper](https://www.sciencedirect.com/science/article/pii/S0022249621000973)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from jax import grad
from jax import numpy as jnp
import copy

In [None]:
# Small constant added inside log() so that log(0) stays finite.
# NOTE: in all their script and experiments, they use this epsilon value (exp(-16), roughly 1.1e-7).
# It seems too large, but will use for now.
eps_ = np.exp(-16)

def log(x):
    """Return the natural log of ``x + eps_``.

    The offset keeps zero probabilities finite: ``log(0)`` evaluates to
    about -16 instead of -inf.
    """
    return np.log(x + eps_)

def softmax(x):
    """Numerically stable softmax taken along axis 0.

    For a 1-D input this is the ordinary softmax. For inputs with more
    dimensions every column (slice along axis 0) is normalised on its own.

    The maximum is subtracted per column rather than globally: with a global
    maximum, a column whose entries all lie far below it underflows to zeros
    and the normalisation becomes 0/0 (NaN).

    Args:
        x: array-like of log-probabilities / scores.

    Returns:
        numpy array of the same shape as ``x`` whose slices along axis 0 sum to 1.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=0, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=0, keepdims=True)

### Variational Free Energy Approximation

$$ F = \sum_{s \in S} q(s) \ln \frac{q(s)}{p(o, s)} $$

#### Example from Figure 3

In [None]:
# prior p(s) over the two hidden states
prior = np.array([.5, .5])

# likelihood p(o|s) of the observed outcome under each hidden state
p_o_given_s = np.array([.8, .2])

# joint p(o, s) = p(o|s) * p(s), element-wise per hidden state
joint_p_o_s = np.multiply(p_o_given_s, prior)
joint_p_o_s

In [None]:
# Compute the exact posterior

p_o = joint_p_o_s.sum() # marginal likelihood p(o)

p_s_given_o = joint_p_o_s / p_o # posterior p(s|o)
p_s_given_o

Minimizing the variational free energy drives $q(s)$ toward the true posterior distribution, because $F$ is an upper bound on the negative log model evidence, $-\ln p(o)$.

In [None]:
# Compute the approximate posterior

# set the initial approximate posterior q(s) to be the prior
q_s = prior.copy() # approximate posterior q(s)

F = q_s.dot(np.log(q_s / joint_p_o_s)) # compute the KL divergence between the joint and the approximate posterior
F

Free energy can also be written in terms of the exact posterior distribution. This form shows that the variational free energy is an upper bound on the negative log model evidence, $-\ln p(o)$ (the KL term is never negative), and that minimizing it drives $q(s)$ toward the true posterior $p(s|o)$.

$$ F = E_{q(s)} \ln \frac{q(s)}{p(s|o)} - \ln p(o) $$

In [None]:
F = np.asarray(q_s).dot(np.log(q_s / p_s_given_o)) - np.log(p_o)
F

In [None]:
# create a jax version of the free energy
grad_F = grad(lambda q_s: q_s.dot(jnp.log(q_s / joint_p_o_s))) # compute the gradient of the KL divergence with respect to the approximate posterior q(s)

In [None]:
# Run this cell ~ 180 times to converge to the exact posterior
learning_rate = .1
# Take one gradient step on F in log space, then convert the updated
# log probabilities back to probabilities via softmax.
q_s = softmax(np.log(q_s) - grad_F(q_s) * learning_rate) # grad_F is evaluated at the current approximate posterior q(s)
q_s

$$ F = E_{q(s)} [\ln q(s) - \ln p(o,s)] $$

In [None]:
F = q_s.dot(np.log(q_s) - np.log(joint_p_o_s))
F

$$ F = E_{q(s)} [\ln q(s) - \ln p(s)] - E_{q(s)}[\ln p (o|s)] $$

In [None]:
F = q_s.dot(np.log(q_s) - np.log(prior)) - q_s.dot(np.log(p_o_given_s))
F

$$ F = D_{KL} [q(s) || p(s)] - E_{q(s)}[\ln p (o|s)] $$

In [None]:
F = q_s.dot(np.log(q_s / prior)) - q_s.dot(np.log(p_o_given_s))
F

### Figure 5 - Static Perception

In [None]:
# prior p(s) over the two hidden states
D = np.array([.5, .5])

# likelihood p(o|s): rows index observations, columns index hidden states
A = np.array([
    [.9, .2],
    [.1, .8]
])

# one-hot observation (the first outcome was observed)
o = np.array([1, 0])

# posterior: softmax of log prior plus log likelihood of the observation
q_s = softmax(np.add(np.log(D), np.log(np.matmul(A.T, o))))
q_s

## Marginal Message Passing

1. Initialize the values of the approximate posteriors $q(s_{\pi,\tau})$ 
   for all hidden variables (i.e., all edges) in the graph. 
2. Fix the value of observed variables (here, $o_\tau$).
3. Choose an edge (V) corresponding to the hidden variable you want to 
   infer (here, $s_{\pi,\tau}$).
4. Calculate the messages, $\mu(s_{\pi,\tau})$, which take on values sent by 
   each factor node connected to V.
5. Pass a message from each connected factor node N to V (often written 
   as $\mu_{N \to V}$). 
6. Update the approximate posterior represented by V according to the 
   following rule: $q(s_{\pi,\tau}) \propto \overrightarrow{\mu}(s_{\pi,\tau}) \cdot \overleftarrow{\mu}(s_{\pi,\tau})$. The arrow 
   notation here indicates messages from two different factors arriving 
   at the same edge. 
    6A. Normalize the product of these messages so that $q(s_{\pi,\tau})$ 
        corresponds to a proper probability distribution. 
    6B. Use this new $q(s_{\pi,\tau})$ to update the messages sent by 
        connected factors (i.e., for the next round of message passing).
7. Repeat steps 4-6 sequentially for each edge.
8. Steps 3-7 are then repeated until the difference between updates 
   converges to some acceptably low value (i.e., resulting in stable 
   posterior beliefs for all edges).

from [Message_passing_example.m](https://github.com/rssmith33/Active-Inference-Tutorial-Scripts/blob/main/Message_passing_example.m)




In [None]:
# Example 1

# Fixed observations and message passing steps. Both observations are fixed from the start / already observed. 
# In full active inference there is another time step variable which acts to sequentially present the observations.

# prior p(s) regarding the initial two states
D = np.array([.5, .5])

# likelihood p(o|s) 
A = np.array([
    [.9, .1],
    [.1, .9]
])

# state to state transition probability matrix p(s_ τ+1 |s_ τ)
B = np.array([
    [1, 0],
    [0, 1]
])

# Transpose the transition matrix and normalize its columns for the backward (future) message.
# The whole matrix is transposed at once: iterating over B row by row would normalize the
# rows of B instead and leave a plain list rather than a matrix.
# Note: this is technically not necessary for this example as B.T is already normalized, but it is included for completeness
B_T = B.T / B.T.sum(axis=0)

# fixed observations o_ τ
# Fix the observations at each time step (Step 2)
o_arr = np.array([
    [1, 0],
    [1, 0]
])

time_steps = len(o_arr) # number of time steps

num_iter = 16 # number of iterations of message passing

# Initialize approximate posteriors q(s) at each time step (Step 1)
qs_arr = np.ones((time_steps, len(D))) / len(D) # array of approximate posteriors q(s) at each time step

# Initialize history of approximate posteriors q(s) for each iteration and time step (This variable is only used for visualization)
qs_history = np.zeros((num_iter, time_steps, len(D)))

# Iterate a set number of times or until convergence (Step 8)
for i in range(num_iter):
    # For each edge (hidden state) (Step 7)
    for tt in range(time_steps):
        # The edge being updated is q(s) at time step tt (Step 3). Its previous value is
        # not needed here: the update below is built from the incoming messages only.

        # get the message sent from the past (Step 4: Message 1)
        if tt == 0: # if this is the first time step we use the prior
            log_B_past = np.log(D)
        else: # otherwise we compute the belief of the current state based on the previous state and the transition matrix
            log_B_past = np.log(B @ qs_arr[tt - 1])

        # get the message sent from the future (Step 4: Message 2)
        if tt == time_steps - 1: # if this is the last time step we use a message of zero (no future states)
            log_B_future = 0
        else: # otherwise we compute the belief of the current state based on the future state and the transition matrix
            log_B_future = np.log(B_T @ qs_arr[tt + 1])

        # get the likelihood of the state given the observation (Step 4: Message 3)
        log_Ao = np.log(A.T @ o_arr[tt])

        # Pass messages and update the posterior (Step 5-6)
        # Since all terms are in log space, this is addition instead of
        # multiplication. This corresponds to  equation 16 in the main
        # text (within the softmax)
        # A dedicated name is used so the notebook-level probability vector `q_s`
        # from the earlier cells is not overwritten with log-space values.
        log_posterior = .5 * (log_B_past + log_B_future) + log_Ao

        # normalize the posterior (Step 6A)
        qs_arr[tt] = softmax(log_posterior)

        # keep a copy of the updated belief for plotting
        qs_history[i, tt] = qs_arr[tt]

qs_arr

In [None]:
#Example 1: Posterior over states 

plt.matshow(qs_arr.T, vmin = 0.0, vmax = 1.0, cmap='binary')
plt.xlabel('Time')
plt.ylabel('Approximate Posterior $q(s)$')

qs_arr

In [None]:
np.vstack([np.array([[D] * 2]), qs_history]).shape

In [None]:
# Plot how the approximate posteriors evolve over message passing iterations.
# Note: the initial prior is not always added to the history when plotting
# One row of D per time step is prepended so the traces start from the prior.
qs_history_with_priors = np.vstack([np.array([[D] * time_steps]), qs_history])

# one trace per (time step, state) pair
plt.plot(qs_history_with_priors.reshape(-1, time_steps * len(D)))
# raw string so that \tau reaches mathtext and is rendered as the Greek letter
plt.ylabel(r'Approximate Posterior Probability, $q(s_{\tau})$')
plt.xlabel('Message Passing Iterations')
plt.show()

In [None]:
# Example 2

# prior p(s) regarding the initial two states
D = np.array([.5, .5])

# likelihood p(o|s) 
A = np.array([
    [.9, .1],
    [.1, .9]
])

# state to state transition probability matrix p(s_ τ+1 |s_ τ)
B = np.array([
    [1, 0],
    [0, 1]
])

# Transpose the transition matrix and normalize its columns for the backward (future) message.
# The whole matrix is transposed at once: iterating over B row by row would normalize the
# rows of B instead and leave a plain list rather than a matrix.
# Note: this is technically not necessary for this example as B.T is already normalized, but it is included for completeness
B_T = B.T / B.T.sum(axis=0)

# In the original Message_passing_example.m code the sequential observations are defined as a matrix; (τ, t)
# this symbolizes that each τ can see all observations up to τ.
# Hence the second observation of the first τ is [0, 0] (not observed)
# o_arr = np.array([
#     [
#         [1, 0],
#         [0, 0]
#     ],
#     [
#         [1, 0],
#         [1, 0]
#     ]
# ])
# For simplicity, we will instead just check in each iteration if τ is less than or equal to t and then use the observation at that time step.
# Otherwise, we will set log_Ao to zero.

# fixed observations o_ τ
# Fix the observations at each time step (Step 2)
o_arr = np.array([
    [1, 0],
    [1, 0]
])

time_steps = len(o_arr) # number of time steps

num_iter = 16 # number of iterations of message passing

# Initialize approximate posteriors q(s) at each time step (Step 1)

qs_arr = np.ones((time_steps, len(D))) / len(D) # array of approximate posteriors q(s) at each time step

# Initialize history of approximate posteriors q(s) for each iteration and time step (This variable is only used for visualization)
qs_history = np.zeros((time_steps, num_iter, time_steps, len(D))) 
# Initialize history of errors for each iteration and time step (This variable is only used for visualization)
err_history = np.zeros((time_steps, num_iter, time_steps, len(D))) 

# for each time step (over all observations)
for t in range(time_steps):
    # for each iteration of message passing (light blue shapes, light green shapes)
    for i in range(num_iter):
        # for each time step (over all observations)
        for tt in range(time_steps):

            # get the log of the approximate posterior q(s) at this time step (Step 3)
            v = log(qs_arr[tt])

            # get the message sent from the past (Step 4: Message 1) 
            if tt == 0: # if this is the first time step we use the prior
                log_B_past = log(D)
            else: # otherwise we compute the belief of the current state based on the previous state and the transition matrix
                log_B_past = log(B @ qs_arr[tt - 1])

            # get the message sent from the future (Step 4: Message 2)
            if tt == time_steps - 1: # if this is the last time step we use a message of zero (no future states)
                log_B_future = 0
            else: # otherwise we compute the belief of the current state based on the future state and the transition matrix
                log_B_future = log(B_T @ qs_arr[tt + 1])

            # get the likelihood of the state given the observation (Step 4: Message 3)
            if tt <= t: # if the observation has been observed
                log_Ao = log(A.T  @ o_arr[tt] )
            else: # if the observation has not been observed
                log_Ao = 0
    
            err = 0.5 * (log_B_past + log_B_future) + log_Ao - v

            v += err

            qs_arr[tt] = softmax(v)

            err_history[t, i, tt] = err
            qs_history[t, i, tt] = qs_arr[tt]


In [None]:
#Example 2: Posterior over states 

plt.matshow(qs_arr.T, vmin = 0.0, vmax = 1.0, cmap='binary')
plt.xlabel('Time')
plt.ylabel('Approximate Posterior $q(s)$')

In [None]:
# Stack the belief traces for plotting: one row per message passing iteration
# (all iterations for t = 0 first, then t = 1, ...), one column per (time step, state) pair.
# Note: the initial prior is not always added to the history when plotting
full_beliefs = [np.array([D] * time_steps).flatten()]
for t in range(time_steps):
    for i in range(num_iter):
        full_beliefs.append(qs_history[t][i].flatten())
full_beliefs = np.asarray(full_beliefs)
plt.plot(full_beliefs)

# raw string so that \tau reaches mathtext and is rendered as the Greek letter
plt.ylabel(r'Approximate Posterior Probability, $q(s_{\tau})$')
plt.xlabel('Message Passing Iterations')
plt.title("Firing Rates (traces)")
plt.show()

In [None]:
states = 2
epochs = 2

# full_beliefs[0] is the prepended prior, so gradients are taken over the remaining rows only
belief_traces = full_beliefs[1:]
# one trace per column of full_beliefs
num_traces = belief_traces.shape[1]

event_related_potentials = []

# The loop variable is deliberately not called `tau`, so re-running this cell
# cannot overwrite the `tau` time constant defined with the model parameters.
for trace in range(num_traces):
    # the rate of change is computed separately within each epoch of num_iter iterations
    epoch_gradients = [
        np.gradient(belief_traces[epoch * num_iter:(epoch + 1) * num_iter, trace])
        for epoch in range(epochs)
    ]
    event_related_potentials.append(np.concatenate(epoch_gradients))

event_related_potentials = np.asarray(event_related_potentials)

# prepend a zero response for the initial (prior) row
event_related_potentials = np.hstack((np.zeros((num_traces, 1)), event_related_potentials))

In [None]:
plt.plot(event_related_potentials.T)
plt.xlabel('Message Passing Iterations')
plt.ylabel('Response')
plt.title("Event-Related Potentials (ERPs) for each state")
plt.show()

# State Prediction Errors

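$$ \varepsilon_{\pi,\tau} \leftarrow \frac{1}{2}\left( \ln\left(B_{\pi,\tau-1}\, s_{\pi,\tau-1}\right) + \ln\left(B^{\dagger}_{\pi,\tau}\, s_{\pi,\tau+1}\right) \right) + \ln A^{T} o_{\tau} - \ln s_{\pi,\tau} $$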

In [None]:
A = np.array([
    [.8, .4],
    [.2, .6]
])

B_past = np.array([
    [.9, .2],
    [.1, .8]
])

B_current = np.array([
    [.2, .3],
    [.8, .7]
])

# NOTE: in shape_patterns.ipynb the normalization is done via division by the sum of the columns instead of softmax
B_T = softmax(B_current.T)

o = np.array([1, 0])

q_s = np.array([.5, .5])

q_s_past, q_s_future = q_s, q_s

v = log(q_s)

In [None]:
# State prediction error: half the sum of the past and future messages, plus the
# likelihood message, minus the current log belief v.
# The past/future beliefs are read from q_s_past / q_s_future, which keep their
# initial values; q_s itself is overwritten a few cells further down, so using it
# here would change the result whenever this cell is re-run.
err = 0.5 * (log(B_past @ q_s_past) + log(B_T @ q_s_future)) + log(A.T @ o) - v
err

In [None]:
v = v + err
v

In [None]:
q_s = softmax(v)
q_s

In [None]:
B_past @ q_s_past

In [None]:
softmax(B_current.T)

## Outcome Prediction Errors

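$$ \varsigma_{\pi,\tau} = A s_{\pi,\tau} \cdot \left( \ln A s_{\pi,\tau} - \ln C_{\tau} \right) - \mathrm{diag}\left(A^{T} \ln A\right) \cdot s_{\pi,\tau} $$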

In [None]:
A = np.array([
    [.9, .1],
    [.1, .9]
])

C = np.array([1, 0])

qs_p1 = np.array([.9, .1])

qs_p2 = np.array([.5, .5])

In [None]:
A_qs_p1 = A @ qs_p1
A_qs_p1

In [None]:
A_qs_p1 @ (log(A_qs_p1) - log(C)) # expected difference between preferred outcomes

In [None]:
A_q_s_2 = A @ qs_p2
A_q_s_2

In [None]:
A_q_s_2 @ (log(A_q_s_2) - log(C)) # expected difference between preferred outcomes

In [None]:
A = np.array([
    [.4, .2],
    [.6, .8]
])

qs_p1 = np.array([.9, .1])
qs_p2 = np.array([.1, .9])

In [None]:
-1 * np.diag(A.T @ log(A)) @ qs_p1

In [None]:
-1 * np.diag(A.T @ log(A)) @ qs_p2

# 3. Building specific task models

#FIXME: copied
In the beginning of the explore–exploit task, the participant
is told that on each trial one machine will tend to pay out more
often, but they will not know which one. They are also told that
the better machine will not always be the same on each trial.
They can choose to select one right away and possibly win $4.
Or they can choose to press a button that gives them a hint about
which slot machine is better on that trial. However, if they choose
to take the hint, they can only win $2 if they pick the correct
machine. Over many trials, the participant can learn which slot
machine tends to pay out more often and either make safe or
risky choices (i.e., take the hint or not).

## Capital letters stand for the generative process and lower case letters stand for the generative model.

In [None]:
# Generative process priors

# context: both machines are equally likely to have the better outcome
D_1 = np.array([.5, .5]) # numbers in paper 
# D_1 = np.array([1, 0]) # numbers in code

# choices:
# 1. start
# 2. hint
# 3. choose left/0
# 4. choose right/1
D_2 = np.array([1, 0, 0, 0])

D = [D_1, D_2]

In [None]:
# Generative model priors

d_1 = np.array([.25, .25])

d_2 = np.array([1, 0, 0, 0])

d = [d_1, d_2]

In [None]:
# A_1 represents the observation likelihoods corresponding to the hint given a state 
# there is a outcome likelihood for each possible choice (start, hint, left, right)
# for each likelihood matrix, the rows correspond to the observations, (no hint, left is better, right is better)
# and the columns correspond to the number of context (left is better, right is better)

phA = 1 # probability of accuracy of the hint

A_1 = np.array([
    # start
    [
        [1, 1], # no hint
        [0, 0], # left is better
        [0, 0], # right is better
    ],
    # hint
    [
        [0, 0], # no hint
        [phA, 1 - phA], # left is better
        [1 - phA, phA], # right is better
    ],
    # choose left
    [
        [1, 1], # no hint
        [0, 0], # left is better
        [0, 0], # right is better
    ],
    # choose right
    [
        [1, 1], # no hint
        [0, 0], # left is better
        [0, 0], # right is better
    ],
])

# A_2 represents the observation likelihoods corresponding to the win/lose outcomes given a state
# there is a outcome likelihood for each possible choice (start, hint, left, right)
# for each likelihood matrix, the rows correspond to the observations, (undetermined, loss, win)
# and the columns correspond to the number of context (left is better, right is better)

pWin = .8 # probability of winning

A_2 = np.array([
    # start
    [
        [1, 1], # undetermined
        [0, 0], # loss
        [0, 0], # win
    ],
    # hint
    [
        [1, 1], # undetermined
        [0, 0], # loss
        [0, 0], # win
    ],
    # choose left
    [
        [0, 0], # undetermined
        [1 - pWin, pWin], # loss
        [pWin, 1 - pWin], # win
    ],
    # choose right
    [
        [0, 0], # undetermined
        [pWin, 1 - pWin], # loss
        [1 - pWin, pWin], # win
    ],
])

# A_3 represents the mapping between the behavior states and observed behaviors

A_3 = np.array([
    [
        [1, 1],
        [0, 0],
        [0, 0],
        [0, 0],
    ],
    [
        [0, 0],
        [1, 1],
        [0, 0],
        [0, 0],
    ],
    [
        [0, 0],
        [0, 0],
        [1, 1],
        [0, 0],
    ],
    [
        [0, 0],
        [0, 0],
        [0, 0],
        [1, 1],
    ],
])

A = [A_1, A_2, A_3]

In [None]:
# B_1 represents the transition probabilities between the context states. Since the agent can not change the context, the transition matrix is the identity matrix

B_1 = np.array([
    [1, 0], # left is better
    [0, 1], # right is better
])

# Note: technically only one matrix is required but to make python operations easier, we will use two matrices
B_1 = np.array([
    [
        [1, 0], # left is better
        [0, 1], # right is better
    ],
    [
        [1, 0], # left is better
        [0, 1], # right is better
    ],
])

B_2 = np.array([
    # move to the start state from any other state
    [
        [1, 1, 1, 1], # start
        [0, 0, 0, 0], # hint
        [0, 0, 0, 0], # choose left
        [0, 0, 0, 0], # choose right
    ],
    # move to the hint state from any other state
    [
        [0, 0, 0, 0], # start
        [1, 1, 1, 1], # hint
        [0, 0, 0, 0], # choose left
        [0, 0, 0, 0], # choose right
    ],
    # move to the choose left state from any other state
    [
        [0, 0, 0, 0], # start
        [0, 0, 0, 0], # hint
        [1, 1, 1, 1], # choose left
        [0, 0, 0, 0], # choose right
    ],
    # move to the choose right state from any other state
    [
        [0, 0, 0, 0], # start
        [0, 0, 0, 0], # hint
        [0, 0, 0, 0], # choose left
        [1, 1, 1, 1], # choose right
    ],
])

# B_2 = np.ones((4, 4)) / 4

B = [B_1, B_2]

# B_1.T / B_1.T.sum(axis =0)
B_future = [b.sum(axis = 0).T / b.sum(axis = 0).T.sum(axis = 0) for b in B]

# B_2_f = 

In [None]:
B_1 = np.array([
    [
        [1, 0], # left is better
        [0, 1], # right is better
    ],
    [
        [1, 0], # left is better
        [0, 1], # right is better
    ],
])

B_1.sum(axis = 0).T / B_1.sum(axis = 0).T.sum(axis = 0)

In [None]:
B_2.sum(axis = 0).T / B_2.sum(axis = 0).T.sum(axis = 0)

In [None]:
np.array([[1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]) @ np.array([1.00000000e+00, 8.27037108e-18, 8.27037108e-18, 8.27037108e-18])

In [None]:
B_future

In [None]:
B_2.sum(axis = 0).T / B_2.sum(axis = 0).T.sum(axis = 0)

In [None]:
B_1.sum(axis = 1).T / B_1.sum(axis = 1).T.sum(axis = 0)

In [None]:

# rows are observations, columns indicate time points

# no preference for observing a hint or not

C_1 = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0, 0, 0]
])

# Preference for observing a win and not a loss

C_2 = np.array([
    [0, 0, 0],
    [0, -1, -1],
    [0, 4, 2]
])

# no preference for observing any action

C_3 = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0, 0, 0],
    [0, 0, 0],
])

C = [C_1, C_2, C_3]

In [None]:
log(softmax(C_2))

In [None]:
# allowable policies

U = np.array([
    [
        [0, 0],
        [0, 1],
        [0, 2],
        [0, 3]
    ]
])

In [None]:
# V[t, k, f] is the action index for state factor f at time step t under policy k (axes: time step, policy, state factor)

V = np.array([
    [
        [0, 0],
        [0, 1],
        [0, 1],
        [0, 2],
        [0, 3]
    ],
    [
        [0, 0],
        [0, 2],
        [0, 3],
        [0, 0],
        [0, 0]
    ]
])

V

In [None]:
eta = 0.5 # learning rate
omega = 1 # forgetting rate
beta = 1 # expected precision of expected free energy (G) over policies. (a positive value, with higher values indicating lower expected precision).
alpha = 32 # An 'inverse temperature' or 'action precision' parameter that controls how much randomness there is when selecting actions
erp = 1 # degree of belief resetting at each time point in a trial when simulating neural responses
tau = 12 # time constant for evidence accumulation.
zeta = 3 # Occam window policies

In [None]:
# in this function they use a smaller epsilon value for the log function
def spm_log(x):
    """Return the natural log of ``x`` after adding a 1e-16 offset (keeps log(0) finite)."""
    offset = 1e-16
    return np.log(x + offset)

def multi_dot(a, b, dim):
    """Contract the matrix ``a`` with one of the two vectors in ``b``.

    Args:
        a: 2-D array.
        b: sequence of two 1-D arrays; ``b[0]`` pairs with the rows of ``a``
            and ``b[1]`` with its columns.
        dim: which axis of ``a`` to keep. ``0`` weights the columns by
            ``b[1]`` and returns one value per row; ``1`` weights the rows by
            ``b[0]`` and returns one value per column.

    Raises:
        ValueError: if ``dim`` is neither 0 nor 1.
    """
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1")

    if dim == 0:
        column_weights = b[1].reshape(-1, 1)
        return np.sum(a @ column_weights, axis = 1)

    row_weights = b[0].reshape(-1, 1)
    return np.sum(a * row_weights, axis = 0)

def multi_dot_no_dim(a, b):
    """Contract the trailing axes of ``a`` with the vectors in ``b``.

    The i-th vector in ``b`` is multiplied along, and then summed over, axis
    ``a.ndim - len(b) + i`` of ``a``. Leading axes that are not paired with a
    vector are kept.

    The axis that is summed is the same axis the vector was broadcast along.
    Summing over a fixed ``i + 1`` is only correct when ``a`` has exactly one
    leading axis.

    Args:
        a: array with at least ``len(b)`` dimensions; it is not modified.
        b: sequence of 1-D arrays whose lengths match the trailing axes of ``a``.

    Returns:
        The contracted array with all length-1 axes squeezed out.

    Raises:
        ValueError: if ``b`` holds more vectors than ``a`` has dimensions.
    """
    result = np.array(a)  # work on a copy so the caller's array is untouched
    first_axis = result.ndim - len(b)
    if first_axis < 0:
        raise ValueError("b has more vectors than a has dimensions")

    for offset, vector in enumerate(b):
        axis = first_axis + offset
        # reshape the vector so it broadcasts along `axis` only
        shape = [1] * result.ndim
        shape[axis] = len(vector)
        weighted = result * np.asarray(vector).reshape(shape)
        result = weighted.sum(axis = axis, keepdims = True)

    return result.squeeze()

def sum_probs(prob_array):
    """Return ``(1 / column_sum - 1 / entry) / 2`` for every entry of ``prob_array``.

    A 1e-16 offset is added first so that zero entries do not divide by zero.
    Note that the result is not a normalised distribution.
    NOTE(review): this looks like a port of SPM's ``spm_wnorm`` - confirm
    against the MATLAB source.
    """
    shifted = prob_array + 1e-16
    inverse_column_sum = 1 / np.sum(shifted, axis=0)
    inverse_entries = 1 / shifted
    return (inverse_column_sum - inverse_entries) / 2

def normalize(x):
    """Divide ``x`` by its sums along axis 0 so that every column sums to 1.

    Entries that come out as NaN (0/0, i.e. an all-zero column) are replaced
    by ``1 / x.shape[0]``. The input array is not modified.
    """
    column_totals = x.sum(axis = 0)
    normalized = x / column_totals
    nan_mask = np.isnan(normalized)
    normalized[nan_mask] = 1/normalized.shape[0]
    return normalized

# MDP_VB_X_TUTORIAL defined & initialized
# Module-level state for the belief-updating loop that follows.
T = 3 # number of time points in a trial

# sampled states / outcomes per (factor or modality, time point); 0 is treated as "not specified yet" by the loop
s = np.zeros((2, 3))
o = np.zeros((3, 3))
n = np.zeros((3, 3))

# outside variables
Ns = [2, 4] # number of states per hidden state factor (lengths of D_1 and D_2)
Nu = 4#[1, 4] # number of actions for the choice factor
Np = 5 # number of policies
No = [3, 3, 4] # number of outcomes per modality
Ni = 16 # number of variational iterations
Ng = 3 # number of outcome modalities
Nf = 2 # number of hidden state factors
M = [1, 1, 1]
S = 3 # NOTE: reassigned inside the loop (S = V.shape[0] + 1)
p = np.arange(Np) #5 # indices of the policies still under consideration


#FIXME: can be replaced with capital U?
# u = np.zeros((T, 1))
u = [] # selected actions; one entry is appended per time point by the loop

# per time point histories of F, Q and H as computed by the loop
F_hist = np.zeros((T, Np))
Q_hist = np.zeros((T, Np))
H_hist = np.zeros(T)

# log of a flat prior over the five policies
qE = spm_log(np.ones(5) / 5)

qb = beta # rate parameters
w = np.array([1/qb] * T) # posterior precision (policy)

# copy of the model priors d taken before any learning updates them
pD = copy.deepcopy(d)
wD = [sum_probs(_d) for _d in d]

# X = copy.deepcopy(D) # TODO: maybe have to be expanded by time points

# X[f][t]: policy-averaged beliefs for factor f at time t; the first time point starts at the prior D
X = [np.zeros((T, *_d.shape)) for _d in D]

# P[t]: action probabilities at time t
P = np.zeros((T, Nu))

for i, _d in enumerate(D):
    X[i][0] = _d


# x[f][:, t, k]: beliefs over the states of factor f at time t under policy k (initialised flat)
x = [np.ones((ns, T, Np)) / ns for ns in Ns]

# the first time point starts at the prior D for every policy
for k in range(Np):
    for f in range(Nf):
        x[f][:, 0, k] = D[f]

# per-iteration records of the state beliefs (xn) and the update terms v (vn)
xn = [np.zeros((Ni, ns, S, T, Np)) + 1/ns for ns in Ns]
vn = [np.zeros((Ni, ns, S, T, Np)) for ns in Ns]

# per-iteration records of precision (wn) and policy posteriors (un); ext_u holds the policy posterior per time point
wn = np.zeros(T * Ni)
un = np.zeros((T * Ni, Np))
ext_u = np.zeros((T, Np))

# likelihoods with the axes reordered by transpose(1, 2, 0) and normalised along the first axis
ext_A = []
for g in range(Ng):
    ext_A.append(normalize(np.transpose(A[g], axes=(1, 2, 0))))

# log of the softmax of the preferences C
ext_C = [spm_log(softmax(c)) for c in C] 
# transition probabilities



# variables initialized and used in the loop
xqq = [np.zeros_like(_d) for _d in D]
xq = [np.zeros_like(_d) for _d in D]

# may have to be expanded to fit factors
# L[t]: likelihood of the state combinations at time t; starts at ones and is multiplied up in the loop
L = np.ones((T, 2, 4))


# fixed seed so the sampled states, outcomes and actions are reproducible
np.random.seed(42)

# belief updating over successive time points
# For every time point t the loop
#   1. samples hidden states and outcomes that have not been specified,
#   2. accumulates the likelihood L[t] of the observed outcomes,
#   3. prunes unlikely policies (Occam's window, zeta),
#   4. runs Ni iterations of variational state updates per remaining policy,
#   5. accumulates the expected free energy Q per policy,
#   6. updates the policy posterior and its precision w,
#   7. averages the state beliefs over policies and samples the next action.
for t in range(T):

    # sample state, if not specified
    for f in range(Nf):

        if s[f, t] == 0:

            if t > 0:
                #TODO: check if out put is correct
                ps = B[f][u[t -1][f]][:,int(s[f, t - 1])] #B[f][int(s[f, t - 1])][u[t -1][f]]
            else:

                ps = D[f] / D[f].sum() # ensure ps is normalized
            # sample state
            s[f, t] = np.argmax(np.random.rand() < np.cumsum(ps)) # FIXME: check if argmax is correct

    # posterior predictive density over hidden (external) states
    for f in range(Nf):
        # under selected action
        if t > 0:
            xqq[f] = B[f][u[t -1][f]] @ X[f][t - 1]
        else:
            xqq[f] = X[f][t]
        # Bayesian model average (xq)
        xq[f] = X[f][t]

    # sample outcome if not specified
    for g in range(Ng):
        # if outcome not specified
        if not o[g, t]:
            if n[g, t]: # outcome is generated by model n if not 0
                pass
            else: # or sampled from likelihood given hidden states
                ind = s[:, t]
                # po = A[g][int(ind[0]), :, int(ind[1])]
                po = A[g][int(ind[1]), :, int(ind[0])]
                o[g, t] = np.argmax(np.random.rand() < np.cumsum(po))
                pass

    O = [] #TODO: May initialize at beginning for all time points
    # get outcome likelihoods
    for g in range(Ng):
        # specified as the sampled outcome
        # TODO: this most likely should be the likelihood of an outcome (matlab using index 1 can be misleading of how their function accomplishes this)
        os = np.zeros(No[g])
        os[int(o[g, t])] = 1
        O.append(os)

    # likelihood of hidden states
    for g in range(Ng):
        L[t] *= np.tensordot(np.transpose(A[g], axes=(2, 0, 1)), O[g], axes=([2], [0]))

    # TODO: remove and find a better way to handle this
    if t == 2:
        # swap the first and second row
        L[t] = L[t][[1, 0]]

    # eliminate unlikely policies
    if len(u) - 1 < t  and t > 0: # TODO: len u has to be changed (should check if U is defined)
        # Only the policies still in p are scored, so the boolean mask below always
        # has the same length as p, even after an earlier time point has pruned it.
        F = np.log(ext_u[t - 1][p])
        p = p[(F - np.max(F)) > -zeta]

    for f in range(Nf):
        x[f] = softmax(spm_log(x[f]) / erp)

    S = V.shape[0] + 1 # horizon

    R = S

    F = np.zeros(Np)
    for k in p: # loop over plausible policies
        dF = 1 # reset criterion for this policy
        for i in range(Ni): # iterate belief updates
            F[k] = 0  # reset free energy for this policy
            for j in range(S): # loop over future time points
                # current posterior over outcome factors
                if j <= t:
                    for f in range(Nf):
                        xq[f] = x[f][:, j, k]

                for f in range(Nf):
                    # hidden states for this time and policy
                    sx = x[f][:, j, k]
                    qL = np.zeros(Ns[f])
                    v = np.zeros(Ns[f])

                    #if f == 0:print(sx, t, k, j, f) # print(sx, t, k, f"{i:<2}", j, f)
                    # t == 2 and k ==1 and i ==2 and j == 0 and f == 0
                    # t == 3 && k == 2 && i == 3 && j == 1 && f == 1

                    # evaluate free energy and gradients
                    if dF > np.exp(-8) or i > 3: 
                        # marginal likelihood over outcome factors
                        if j <= t:
                            # FIXME: why does current f dimension get removed?
                            # qL = xq[abs(1 - f)] @ L[j].T 
                            qL = multi_dot(L[j], xq, f)
                            qL = spm_log(qL)

                        # entropy
                        qx = spm_log(sx) # FIXME: qx values are not the same as in the matlab code

                        # empirical priors (forward messages)
                        if j < 1:
                            px = spm_log(D[f])
                            v = v + px + qL - qx
                        else:
                            #TODO: investigate why B and B_future are defined as they are
                            px = spm_log(B[f][V[j - 1, k, f]] @ x[f][:, j-1, k])
                            # px = np.dot(spm_log(sB[f][:, :,V[j-1, k, f]]), x[f][:, j-1, k])
                            v = v + px + qL - qx

                        # empirical priors (backward messages)
                        if j < R - 1:
                            px = spm_log(B_future[f].T @ x[f][:, j+1, k])
                            v = v + px + qL - qx

                        # (negative) free energy
                        if j == 0 or j == S - 1:
                            F[k] = F[k] + np.dot(sx.T, v * 0.5)
                        else:
                            F[k] = F[k] + np.dot(sx.T, v * 0.5 - (Nf - 1) * qL / Nf)

                        # update posterior
                        v = v - np.mean(v)
                        sx = softmax(qx + v / tau)

                    else:
                        F[k] = G[k]

                    x[f][:, j, k] = sx
                    xq[f] = sx
                    xn[f][i, :, j, t, k] = sx
                    vn[f][i, :, j, t, k] = v

            if i > 0:
                dF = F[k] - G[k]

            G = F.copy()

    # accumulate expected free energy over policies (Q)
    pu = 1 # empirical prior
    qu = 1 # posterior
    Q = np.zeros(Np) # expected free energy over policies

    if Np > 0: #FIXME: check if this should be 1
        for k in p:
            # bayesian surprise about initial conditions
            for f in range(Nf):
                Q[k] = Q[k] - wD[f] @ x[f][:, 0, k] # FIXME: should be spm_dot

            pass

            for j in range(t, S):
                # get expected states for this policy and time point
                for f in range(Nf): # TODO: check if required
                    xq[f] = x[f][:, j, k]

                # (negative) expected free energy

                # bayesian surprise about states
                _qx = xq[0].reshape(-1, 1) * xq[1] 

                #>> compute G
                G = 0 
                qo = 0 # Fixme qo should be calculated with spm dot
                # for i in np.where(_qx > np.exp(-16))[0]:
                for r_i, c_i in zip(*np.where(_qx > np.exp(-16))):
                    po = 1
                    for g in range(len(ext_A)):
                        # kx = np.transpose(ext_A[g], axes=(2, 0, 1))[0, ..., i]
                        kx = ext_A[g][..., r_i, c_i]
                        reshape_last_dim = [1] * (po.ndim) if isinstance(po, np.ndarray) else []
                        reshape_arr = [-1] + reshape_last_dim
                        po = po * kx.reshape(reshape_arr)

                    po = po.flatten()
                    qo = qo + _qx[r_i][c_i] * po
                    G = G + _qx[r_i][c_i] * po @ np.log(po + np.exp(-16))  

                G = G - qo @ np.log(qo + np.exp(-16))
                # << Compute G
                Q[k] = Q[k] + G



                for g in range(Ng):
                    qo = multi_dot_no_dim(ext_A[g], xq)

                    Q[k] =  Q[k] + qo.flatten() @ spm_log(softmax(C[g]))[:,j].reshape(-1, 1)

                pass

    # previous expected precision
    if t > 0:
        w[t] = w[t - 1] # FIXME: w[t] not the same to matlab code

    for i in range(Ni):
        qu = softmax(qE[p] + w[t] * Q[p] + F[p]) #FIXME: tiny difference in outcome of Q
        pu = softmax(qE[p] + w[t] * Q[p])

        # precision (w) with free energy gradients (v = -dF/dw)
        eg = (qu - pu) @ Q[p]
        dFdg = qb - beta + eg
        qb = qb - dFdg / 2
        w[t] = 1/qb

        # simulated dopamine responses (expected precision)
        _n = t * Ni + i #TODO: originally was t - 1
        wn[_n] = w[t]
        un[_n][p] = qu
        ext_u[t][p] = qu


    # bayesian model averaging (over policies)
    for f in range(Nf):
        for i in range(S):
            X[f][i] = x[f][:, i, :] @ ext_u[t] #FIXME: might be cause for troubles


    # processing (reaction) time
    F_hist[t] = F
    Q_hist[t] = Q
    H_hist[t] = qu @ F[p] - qu @ (spm_log(qu) - spm_log(pu))


    if t < T -1:

        # marginal probability of each action: sum the posterior of every policy that selects it at time t
        Pu = np.zeros(Nu)

        for i in range(Np):
            Pu[V[t, i, 1]] += ext_u[t][i] # still tiny differences in outcome

        Pu = softmax(np.log(Pu) * alpha)
        P[t] = Pu

        if len(u) > t: # TODO: maybe should be filled with nan instead?
            u[t] = u[t] # unecessary assignment
        else:
            # sample the next action from Pu
            idx = np.argmax(np.random.rand() < np.cumsum(Pu))
            u.append(np.unravel_index(idx, [1, Nu]))

        pass

    #TODO: There's code to check if U is used
            
# if t == T:
#     if T == 1:
#         u = np.zeros((T, 1))
#         o = None
#         s = None
#         u = None





In [None]:
from scipy import special

def betaln(x):
    """Return ``sum(gammaln(x_i)) - gammaln(sum(x_i))`` over the positive entries of ``x``.

    Entries that are zero or negative are dropped before the log-gamma terms
    are evaluated.
    """
    # only compute the gammaln for elements greater than 0
    positive = x[x > 0]
    log_gamma_terms = special.gammaln(positive)
    return log_gamma_terms.sum() - special.gammaln(positive.sum())

def psi(x):
    """Return the digamma function of each entry of ``x`` minus the digamma of the sum of ``x``."""
    digamma_of_entries = special.psi(x)
    digamma_of_total = special.psi(x.sum())
    return digamma_of_entries - digamma_of_total


# learning - accumulate concentration parameters
# X is different 
# initial hidden states
for f in range(Nf):
    # boolean mask of the entries of d[f] that are currently positive; only those are updated
    i = d[f] > 0
    # d[f][i] = d[f][i] * omega + X[f][0, i] * eta # TODO: (seems like correct output [(1.0000e+00, 1.2525e-08), (1)])
    # scale the old counts by omega and add eta times the policy-averaged belief about the initial state
    d[f] = (d[f] * omega + X[f][0] * eta) * i # TODO: (seems like correct output [(1.0000e+00, 1.2525e-08), (1)])
    print(X[f][0, i])

Fd = np.zeros(Nf)
# (negative) free energy of parameters: state specific
# compute KL divergence between two dirichlet distributions
# pD holds the copy of d taken before the trial, d the values updated above
for f in range(Nf):
    Fd[f] = - ( betaln(pD[f]) - betaln(d[f]) - np.sum((pD[f] - d[f]) * psi(d[f] + 1/32))) 
    # 0.2473
    # -0.04099317344047182

# simulated dopamine (or cholinergic) responses
# wn is different
# NOTE: dn is only defined when there is more than one policy
if Np > 1:
    dn = 8 * np.gradient(wn) + wn / 8 

# Bayesian model averaging of expected hidden states over policies
Xn_arr = []
Vn_arr = []
for f in range(Nf):
    Xn = np.zeros((Ni, Ns[f], T, T))
    Vn = np.zeros((Ni, Ns[f], T, T))

    # weight the per-policy records by the policy posterior ext_u[t, k] and sum over policies
    for t in range(T):
        for k in range(Np):
            Xn[:, :, :, t] = Xn[:, :, :, t] + xn[f][:, :, :, t, k] * ext_u[t, k]
            Vn[:, :, :, t] = Vn[:, :, :, t] + vn[f][:, :, :, t, k] * ext_u[t, k]

    Xn_arr.append(Xn)
    Vn_arr.append(Vn)

# number of belief updates (T)
# number of outcomes (O)
# probability of action at time point (P)
# conditional expectations over policies ((ext_)u) -> R
# conditional expectations over N states (x -> Q)
# Bayesian model averages over T outcomes (X)
# utility (C)
# posterior expectations of precision (policy) (w)
# simulated neuronal prediction error (Vn)
# simulated neuronal encoding of hidden states (Xn)
# simulated neuronal encoding of policies (un)
# simulated neuronal encoding of precision (wn)
# simulated dopamine responses (deconvolved) (dn)
# simulated reaction time (seconds) (rt)
pass

In [None]:
i