In [1]:
import jax
import jax.numpy as jnp
#import tqdm
# Root PRNG key for the notebook, built from the fixed seed 42; later cells read this name.
key = jax.random.PRNGKey(42)

# JAX Tutorial on Active Inference and the Free Energy Principle

## Discrete Case

In [2]:
# number of reward/cue entries appended after the location entries
n_reward_states = 2
# number of actions; later cells build one transition matrix per action
n_actions = 4
# 4 location entries followed by the reward/cue entries
n_states = 4 + n_reward_states

In [3]:

def set_reward_and_cue(key, left_prob = 0.5):
    """
    Sample the reward layout of the T-maze and the cue that reveals it.

    Args:
        key: JAX PRNG key used for the Bernoulli draw.
        left_prob: probability that the reward (+1) is on the left side.

    Returns:
        reward: length-6 integer array ``[left, right, 0, 0, 0, 0]``. The rewarded
            side is +1 and the other side is -1 (i.e. punishment). The array is
            0-padded so that center and bottom of the T-maze have 0 reward, and
            so that cue states don't *directly* affect reward.
        cue: 2D one-hot encoding of the side that the reward is on
            (``[1, 0]`` for left, ``[0, 1]`` for right).
    """
    reward_on_left = jax.random.bernoulli(key, left_prob)
    # map the boolean draw to +1 (reward) / -1 (punishment)
    reward_left = reward_on_left * 2 - 1
    # the other side always carries the opposite sign
    reward_right = -reward_left
    reward = jnp.array([reward_left, reward_right, 0, 0, 0, 0])
    # select the one-hot cue without branching in Python on an array value
    cue = jnp.where(reward_on_left, jnp.array([1, 0]), jnp.array([0, 1]))
    return reward, cue

# Demo: draw one (reward, cue) pair using the notebook-level key.
print(set_reward_and_cue(key))


(Array([ 1, -1,  0,  0,  0,  0], dtype=int32), Array([1, 0], dtype=int32))


In [4]:
def get_state_obs_transiton(n_states, n_actions, cue, cue_location=3):
    """
    Function that returns the transition matrix from world states to agent observations.

    The top-left ``n_actions x n_actions`` block is the identity and the matrix
    is zero-padded to ``n_states x n_states``. The rows after the identity block
    are then filled, in column ``cue_location``, with ``cue``.

    Args:
        n_states: size of the (square) returned matrix.
        n_actions: size of the identity block.
        cue: array of length ``n_states - n_actions`` written into the rows
            below the identity block.
        cue_location: column that receives the cue; the default 3 is the
            bottom of the T-maze.

    Returns:
        Float array of shape ``(n_states, n_states)``.
    """
    n_cue_states = n_states - n_actions
    # identity on the first n_actions entries, zero-padded to n_states x n_states
    transition = jnp.pad(jnp.eye(n_actions),
                         ((0, n_cue_states), (0, n_cue_states)))
    # if agent moves to bottom, it sees the cue
    transition = transition.at[n_actions:, cue_location].set(cue)
    return transition
# Demo: show the matrix for each of the two possible one-hot cues.
print(get_state_obs_transiton(n_states, n_actions, jnp.array([1, 0])))
print(get_state_obs_transiton(n_states, n_actions, jnp.array([0, 1])))

[[1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]
[[1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]]


In [5]:
def get_action_state_transition(n_states, n_actions):
    """
    Function that returns the transition matrix from actions to world states.

    For action ``i`` the matrix has ones in row ``i`` across the first
    ``n_actions`` columns, and an identity block on the remaining
    ``n_states - n_actions`` entries (those are unchanged by any action).

    Args:
        n_states: size of each (square) per-action matrix.
        n_actions: number of actions, i.e. number of stacked matrices.

    Returns:
        Float array of shape ``(n_actions, n_states, n_states)``.
    """
    n_fixed_states = n_states - n_actions
    per_action = jnp.zeros((n_states, n_states))
    # reward unchanged by action
    per_action = per_action.at[n_actions:, n_actions:].set(jnp.eye(n_fixed_states))
    # duplicate for number of actions
    transition = jnp.stack([per_action] * n_actions)
    # ones for the first n_actions columns (computed once; it does not depend on the action)
    location_mask = jnp.arange(n_states) < n_actions
    action_ids = jnp.arange(n_actions)
    # row i of matrix i receives the mask, for every action i at once
    transition = transition.at[action_ids, action_ids, :].set(location_mask)
    return transition

# The action->state matrix depends only on n_states and n_actions, so it is built once here.
action_state_transition = get_action_state_transition(n_states, n_actions)
# show the matrix for action 0
print(action_state_transition[0])

[[1. 1. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1.]]


In [6]:
# uniform prior with one entry per state (length n_states)
# NOTE(review): the original comment said "prior about reward states", but the
# vector spans all n_states entries — confirm which is intended.
d = jnp.ones(n_states) / n_states
# 4D location 1-hot with 2D cue one-hot; here location index 2 is set and both cue entries are 0
observation_state = jnp.array([0, 0, 1, 0, 0, 0])


Our agents will form models of the environment they are in. Concretely, they will try to learn the state-transition and observation-transition matrices defined above.
**TODO:** write about conjugate priors, Categorical distributions, and the Dirichlet distribution.

In [7]:
class Agent:
    """
    Discrete agent for the T-maze that keeps a belief over world states and
    Dirichlet concentration parameters for its two transition models
    (action -> state and state -> observation).
    """

    # index of the location the agent is assumed to start in (center of the T-maze)
    START_STATE_IDX = 2

    def __init__(self,
                 n_states,
                 n_actions,
                 action_state_transition,
                 key=None):
        """
        Args:
            n_states: length of the state / observation vectors.
            n_actions: number of actions.
            action_state_transition: array of shape
                ``(n_actions, n_states, n_states)``; stored on the agent and
                used as the shape/dtype template for the action->state
                concentration parameters.
            key: optional JAX PRNG key used to sample the initial transition
                priors. Defaults to ``jax.random.PRNGKey(42)`` so that
                construction stays deterministic.
        """
        self.n_states = n_states
        self.n_actions = n_actions
        # derived from the arguments instead of being read from a notebook global
        self.n_reward_states = n_states - n_actions
        self.action_state_transition = action_state_transition

        # prior on the state of the world
        self.state_prior = None
        self.prev_state_prior = None
        self.reset_state_prior()

        # concentration for action-state transition
        self.c_action_state = jnp.ones_like(action_state_transition)
        # concentration for state-observation transition (one entry per observation/state pair)
        self.c_state_obs = jnp.ones((n_states, n_states))

        # use concentration to generate priors; each draw gets its own subkey so
        # the two samples are not generated from the same random stream
        if key is None:
            key = jax.random.PRNGKey(42)
        action_key, obs_key = jax.random.split(key)
        # NOTE(review): jax.random.dirichlet normalizes over the LAST axis, while
        # update_state_belief left-multiplies the belief by these matrices
        # (matrix @ belief). Confirm which axis is meant to sum to one.
        self.action_transition_prior = jax.random.dirichlet(action_key, self.c_action_state)
        self.obs_transition_prior = jax.random.dirichlet(obs_key, self.c_state_obs)

    def reset_state_prior(self):
        """
        Reset the state prior (assumed to be known as center of T-maze)
        """
        self.state_prior = jnp.zeros(self.n_states).at[self.START_STATE_IDX].set(1.0)

    def update_state_belief(self, action_idx, next_observation):
        """
        Update the belief over states after taking an action and receiving an observation.

        The current belief is saved in ``prev_state_prior``, pushed through the
        sampled transition matrix of ``action_idx``, combined in log space with
        the likelihood of ``next_observation``, and renormalized with a softmax.

        Args:
            action_idx: index of the action that was taken.
            next_observation: observation vector of length ``n_states``.

        Returns:
            The new state belief, which is also stored in ``state_prior``.
        """
        self.prev_state_prior = self.state_prior
        action_transition = self.action_transition_prior[action_idx]

        # (prior) state belief given an action
        predicted_belief = action_transition @ self.state_prior
        # expected observation given model of state->observation transition
        likelihood = jnp.dot(next_observation, self.obs_transition_prior) #TODO : this might use self.obs_transition_prior @ state_belief
        # update state belief: softmax of the summed logs renormalizes likelihood * prior
        state_belief = jax.nn.softmax(jnp.log(likelihood) + jnp.log(predicted_belief))

        self.state_prior = state_belief
        return state_belief

    def update_concentration(self, action_idx, observation):
        """
        Update the concentration parameters for the transition models.

        Args:
            action_idx: index of the action that led to the current belief.
            observation: observation vector of length ``n_states``.
        """
        # observation x current-belief outer product, shape (n_states, n_states)
        state_obs_update = observation[:, None] @ self.state_prior[None, :]
        self.c_state_obs = self.c_state_obs + state_obs_update

        # There is no previous belief until update_state_belief has run once,
        # so the action-state counts can only be updated after that.
        if self.prev_state_prior is not None:
            # state prior outer product: current belief x previous belief
            state_prior_outer = self.state_prior[:, None] @ self.prev_state_prior[None, :]
            self.c_action_state = self.c_action_state.at[action_idx].add(state_prior_outer)


In [12]:
# Scratch check of jnp.dot semantics: a length-n_states vector of ones times an
# (n_states, n_states) matrix of ones yields n_states in every entry.
tmp = jnp.ones(n_states)
tmp2 = jnp.ones((n_states, n_states))

jnp.dot(tmp, tmp2)

Array([6., 6., 6., 6., 6., 6.], dtype=float32)

In [11]:
# Sanity check on a fresh agent: the sampled action->state transition prior
# should sum to 1 along its last axis, giving an (n_actions, n_states) array of ~1s.
Agent(n_states, n_actions, action_state_transition).action_transition_prior.sum(axis=-1)

Array([[1.0000001 , 1.        , 1.        , 0.99999994, 0.99999994,
        1.        ],
       [1.        , 1.0000001 , 1.        , 0.99999994, 1.        ,
        1.        ],
       [0.99999994, 1.0000001 , 1.        , 1.        , 1.        ,
        0.9999999 ],
       [1.0000001 , 1.        , 1.        , 1.        , 1.        ,
        1.0000001 ]], dtype=float32)

## Continuous Case

# Deep Active Inference