# <b> active-pynference </b> : Sophisticated Inference with Jax !

<b>actynf</b> implements Sophisticated Inference with jax functions! This lets us benefit from the just-in-time compilation, auto-vectorization and auto-differentiation abilities of the package. This notebook compares the results of [SPM12's implementation of sophisticated inference](https://github.com/spm/spm/blob/main/toolbox/DEM/spm_MDP_VB_XX.m), the *numpy* implementation of this package, and the *jax* implementation of this package.

**Note :** Writing in Jax comes with a number of constraints that don't exist in classical Python (see [this page](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) for the common pitfalls). This means that jax_actynf does not perform exactly the same operations as classical sophisticated inference implementations (tree pruning, dynamic variable-dependent conditioning, etc.). Depending on your goal, it may be preferable to switch to a Jax-based model or to remain in a classical (numpy) based environment. We give a few details regarding this point at the end of this notebook.
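
To make one of these constraints concrete: jitted code needs array shapes that do not depend on the data, so a tree cannot be pruned to a data-dependent number of branches. The sketch below shows the general idea of keeping a fixed number of branches instead. The function `fixed_width_branches` is a hypothetical illustration written for this note, not the code used inside actynf.

```python
import jax
import jax.numpy as jnp

def fixed_width_branches(qs, n_branches):
    """Keep the n_branches most probable states, plus the leftover probability mass.

    Hypothetical helper (not part of actynf). n_branches is a static Python int,
    so the output shapes never depend on the values stored in qs.
    """
    top_p, top_idx = jax.lax.top_k(qs, n_branches)
    remainder = 1.0 - jnp.sum(top_p)  # mass of all the states that got no branch
    return top_idx, top_p, remainder

# n_branches is marked static: changing it triggers a new compilation
branch_fn = jax.jit(fixed_width_branches, static_argnums=1)

qs = jnp.array([0.05, 0.6, 0.05, 0.3])  # a toy belief over 4 hidden states
idx, p, rem = branch_fn(qs, 2)
print(idx, p, rem)
```

The price of this design is that low-probability branches are still computed, where a recursive implementation could simply skip them.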


## 1 . Environment : simple T-maze

We'll be using a close analog to the T-maze environment here. The [basics about the T-maze environment](Tmaze_demo.ipynb) remain the same, but clue and reward modalities are fused together. The MDP weights for this situation are available in [this](../actynf/demo_tools/tmaze/weights.py) file.

In [1]:
import numpy as np

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
from functools import partial


import actynf
print("Actynf version : " + str(actynf.__version__))
from local_demo_tools.tmaze.weights import get_T_maze_gen_process,get_T_maze_model,get_jax_T_maze_model

Actynf version : 0.1.39



Just like in the other demo, the weights depend on scalar parameters describing the properties of the true environment as well as the initial model of the artificial mouse : 

For the environment (process) :
- $p_{init}$ is the probability of the reward being on the right at the beginning.
- $p_{HA}$ is the probability of the clue showing the right (resp. left) when the reward is on the right (resp. left).
- $p_{win}$ is the probability of getting a positive (resp. aversive) stimulus when picking the reward (resp. the shock).

For the mouse model : 
- $p_{HA}$ is the mouse belief about the mapping of the clue
- *initial_hint_confidence* is the strength of this belief
- $la$, $rs$ are the agent's priors about receiving aversive vs. positive stimuli.
- $p_{win}$ is the mouse belief about probability of getting a positive (resp. negative) stimulus when picking the reward (resp. the shock).
- *context_belief* is where the mouse thinks the reward spawns at the beginning of each trial.
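
As a rough illustration of how scalar parameters like these can turn into weights, here is a minimal sketch for a stand-alone clue modality. The function names and the matrix layout are hypothetical, and the actual `weights.py` differs at least in that it fuses the clue and reward modalities. Treating *initial_hint_confidence* as a multiplier on Dirichlet pseudo-counts is also an assumption.

```python
import numpy as np

def clue_likelihood(p_HA):
    """Hypothetical 2x2 clue mapping (not the actynf weights).

    Rows: clue shown (right, left). Columns: reward side (right, left).
    """
    return np.array([[p_HA, 1.0 - p_HA],
                     [1.0 - p_HA, p_HA]])

def clue_prior_counts(p_HA, initial_hint_confidence):
    """Hypothetical Dirichlet pseudo-counts: the believed mapping scaled by a confidence."""
    return initial_hint_confidence * clue_likelihood(p_HA)

print(clue_likelihood(0.8))         # each column is a probability distribution
print(clue_prior_counts(1.0, 2.0))  # values used in this notebook: p_HA = 1.0, confidence = 2.0
```

Under this reading, a higher confidence leaves the believed mapping unchanged but makes it harder for new observations to move it.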

In [None]:
T = 3
Th = 2

# Those weights will remain the same for the whole notebook :
true_process_pHA = 1.0
true_process_pinit = 1.0
true_process_pwin = 0.98  # For a bit of noise !

true_A,true_B,true_D,U = get_T_maze_gen_process(true_process_pinit,true_process_pHA,true_process_pwin)


true_model_pHA = 1.0
true_model_pwin = 0.98
true_model_context_belief = 0.5
true_model_hint_conf = 2.0
true_model_la = -4.0
true_model_rs = 2.0
true_alpha = 16.0

true_a,true_b,true_c,true_d,true_e,_ = get_T_maze_model(true_model_pHA,true_model_pwin,true_model_hint_conf,
                                        true_model_la,true_model_rs,true_model_context_belief)


### 1. Simulating Active Inference agent trials with SPM, actynf_classic and actynf_jax

We compare the outcomes of three implementations of the same sophisticated inference algorithm, with a few notable differences : 
- The original Matlab script, using a recursive (+ tree pruning) approach
- Our classical Python implementation
- Our Jax-based implementation



1. For the first results, we used modified versions of the original SPM files spm_MDP_VB_XX.m and DEM_demo_MDP_XX.m. These modified files are available [here](./SPM_XX/). You can run them yourself with the same options as in this notebook and compare the displayed results. 
For the values above, we get : 

```
>> DEM_demo_MDP_XX
Computed EFE for t = 1
  -11.4325  -12.8459  -12.8459   -9.5125

Computed EFE for t = 2
   -6.5111  -10.3886   -4.6286   -6.5111

Computed EFE for t = 3
   -1.3863   -1.3863   -1.3863   -1.3863
```

Let's now make the same computation with the classical *actynf* SI paradigm, and then see the actynf_jax result !

In [None]:
# 2. Classical actynf simulations : 
from actynf import layer,link,layer_network

def get_tmaze_net_classic():
    # The T-maze environment : 
    process_layer = layer("T-maze_environment","process",true_A,true_B,None,true_D,None,U,T)
    
    
    # The mouse model :
    model_layer = layer("mouse_model","model",true_a,true_b,true_c,true_d,true_e,U,T,T_horiz=Th)
    # This time, we define our layer as a "model" 

    # Here, we give a few hyperparameters guiding the behaviour of our agent :
    model_layer.hyperparams.alpha = true_alpha # action precision : 
        # for high values the mouse will always perform the action it perceives as optimal, with very little exploration 
        # towards actions with similar but slightly lower interest

    model_layer.learn_options.eta = 1.0 # learning rate (shared by all channels : a,b,c,d,e)
    model_layer.learn_options.learn_a = True  # The agent learns the reliability of the clue
    model_layer.learn_options.learn_b = False # The agent does not learn transitions
    model_layer.learn_options.learn_d = True  # The agent has to learn the initial position of the cheese
    model_layer.learn_options.backwards_pass = True  # When learning, the agent will perform a backward pass, using its perception of 
                                               # states in later trials (e.g. I saw that the cheese was on the right at t=3)
                                               # as well as what actions it performed (e.g. and I know that the cheese position has
                                               # not changed between timesteps) to learn more reliable weights (therefore if my clue was
                                               # a right arrow at time = 2, I should memorize that cheese on the right may correlate with
                                               # right arrow in general)
    model_layer.learn_options.memory_loss = 0.0
                                            # How many trials will be needed to "erase" 50% of the information gathered during one trial
                                            # Used during the learning phase. 0.0 means that the mouse doesn't forget
                                            # anything.

    
    
    # Create a link from observations generated by the environment to the mouse sensory states :
    model_layer.inputs.o = link(process_layer, lambda x : x.o)
    #     the layer from which we get the data | the function extracting the data

    # Create a link from the actions selected by the mouse to the t-maze environment :
    process_layer.inputs.u = link(model_layer,lambda x : x.u)
    
    return layer_network([process_layer,model_layer],"t-maze_network")


# 3. The jax results. 
# in the jax implementation, planning, action selection and learning options
# are stored in dictionaries. 
from actynf.jax_methods.layer_options import get_planning_options,get_action_selection_options,get_learning_options
from actynf.jax_methods.layer_training import synthetic_training

def get_tmaze_net_jaxified(Ntrials):
    # Training options :
    Sh = 2                      # State horizon (or the number of individual states that will create their own branch)
    remainder_state_bool = True # Do we create an additional branch with the remaining potential state density ?
    Ph = 4                      # Policy horizon (or the number of individual actions that will be explored at each node)
    option_a_nov = False
    option_b_nov = False
    additional_options_planning = False    

    planning_options = get_planning_options(Th,"sophisticated",
                            Sh,Ph,remainder_state_bool,
                            option_a_nov,option_b_nov,additional_options_planning)

    as_options = get_action_selection_options("stochastic",alpha=true_alpha)  # same action precision as the classical agent

    learn_options = get_learning_options(True,False,True,run_smoother=True)

    training_parrallel_func = partial(synthetic_training,
        Ntrials=Ntrials,T=T,
        A=true_A,B=true_B,D=true_D,U=U,
        a0=true_a,b0=true_b,c=true_c,d0=true_d,e=true_e,
        planning_options=planning_options,
        action_selection_options = as_options,
        learning_options = learn_options)
    
    return training_parrallel_func
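
To get a feel for what the action precision `alpha = 16` does, here is a minimal sketch that assumes actions are drawn from a softmax over the precision-weighted EFE values, with larger values preferred. This is an assumption about the general scheme, not a copy of actynf's action-selection code. The input values are the ones printed by SPM for t = 1.

```python
import numpy as np

def action_probabilities(G, alpha):
    """Softmax over alpha * G (assumed convention: larger G = preferred action)."""
    z = alpha * (G - np.max(G))  # shift by the max for numerical stability
    p = np.exp(z)
    return p / p.sum()

G_t1 = np.array([-11.4325, -12.8459, -12.8459, -9.5125])  # SPM output for t = 1

for alpha in (0.5, 16.0):
    print(alpha, np.round(action_probabilities(G_t1, alpha), 4))
```

With a low precision the four actions keep a noticeable share of the probability mass, while with `alpha = 16` almost all of it goes to the last action.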



In [None]:
import time 
# Classic version computation time ! 
t0_classic = time.time()
tmaze_net_classic = get_tmaze_net_classic()
stm,weights = tmaze_net_classic.run()
computed_efe_classic = np.sum(stm[1].Gd,axis=1)
delta_t_classic = time.time() - t0_classic


# Jax version computation time ! 
rngkey_training = jr.PRNGKey(300)  # Jax requires a PRNG key before generating pseudo random numbers
                                # This will be used in agent action selection and process outcome generation

t0_jax = time.time()
tmaze_net_jax = get_tmaze_net_jaxified(1) # For a single trial for now !
# Let's ensure the function is compiled !
tmaze_net_jax_jitted = jax.jit(tmaze_net_jax)
[all_obs_arr,all_true_s_arr,all_u_arr,
    all_qs_arr,all_qs_post,all_qpi_arr,efes_arr,
    a_hist,b_hist,d_hist] = tmaze_net_jax_jitted(rngkey_training)
computed_efe_jax = efes_arr.block_until_ready()
delta_t_jax = (time.time() - t0_jax)

print("-------------------------------------------------")
print("EFE computed by the classical(numpy) actynf implementation : ")
print(computed_efe_classic.T)
print(f"It took {delta_t_classic:.3f} seconds.")
print("\n###\n")
print("EFE computed by the jax actynf implementation : ")
print(computed_efe_jax)
print(f"It took {delta_t_jax:.3f} seconds.")
print("-------------------------------------------------")

 Network [t-maze_network] : Timestep 3 / 3
 Done !   -------- (seeds : [3470-0;9143-0])
-------------------------------------------------
EFE computed by the classical(numpy) actynf implementation : 
[[  -11.44075064   -12.86416348   -12.86416348    -9.50241543]
 [   -6.51113554    -4.63113554 -1009.00484118    -6.51113554]]
It took 0.011 seconds.

###

EFE computed by the jax actynf implementation : 
[[[-11.44075   -12.918507  -12.918507   -9.51523  ]
  [ -6.5111356  -4.631136  -10.391135   -6.5111356]]]
It took 2.476 seconds.
-------------------------------------------------


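The results match closely! At t = 1, the three implementations agree to within about 0.1 for every action, and at t = 2 the same values appear (with the second and third actions in a different order in the SPM output). One entry stands out: the classical implementation reports -1009.0 for the third action at t = 2, where the Jax output gives about -10.39. Apart from that entry, our planning algorithms also match the results of SPM's Sophisticated Inference method! Huzzah :D !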

However, the Jax version took about 2.5 seconds on this very basic example, against 0.011 seconds for the numpy version. Is that normal? Isn't jax supposed to be faster than classical numpy?

[The answer isn't that clear cut](https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy), but most of the gap comes from compiling the code, as illustrated below :

In [None]:
Ntrials = 10
tmaze_net_jax_jitted_new = jax.jit(get_tmaze_net_jaxified(Ntrials))  # Get a new function to avoid using a previously compiled function

print(">>>>>>>>>>>>  With just in time compilation (JIT) <<<<<<<<<<<<<<")
print("--- First computation ---")
t0_jax = time.time()
key = rngkey_training
efe_value  = tmaze_net_jax_jitted_new(rngkey_training)[6].block_until_ready()
delta_t_jax_run1 = (time.time() - t0_jax)
print(f"It took {delta_t_jax_run1:.3f} seconds.")

print("--- Subsequent computations ---")
infer_times = []
Nsamples = 1000
for t in range(Nsamples):
    t0_jax = time.time()
    rngkey_training = jr.PRNGKey(np.random.randint(0,100))  # Jax requires a PRNG key before generating pseudo random numbers
                                # This will be used in agent action selection and process outcome generation
    [all_obs_arr,all_true_s_arr,all_u_arr,
    all_qs_arr,all_qs_post,all_qpi_arr,efes_arr,
    a_hist,b_hist,d_hist] = tmaze_net_jax_jitted_new(rngkey_training)
    computed_efe_jax = efes_arr.block_until_ready()
    
    infer_times.append(time.time() - t0_jax)

avg_deltat = 1e3*np.mean(infer_times)
std_deltat = 1e3*np.std(infer_times)
print(f"It took on average {avg_deltat:.4f} ms (\u00B1 {std_deltat:.5f} ms). ({Nsamples} samples)")


print(">>>>>>>>>>>>  Without just in time compilation (JIT) <<<<<<<<<<<<<<")
print("--- First computation ---")
tmaze_nojit = get_tmaze_net_jaxified(1)
t0_jax = time.time()
key = rngkey_training
efe_value  = tmaze_nojit(rngkey_training)[6].block_until_ready()
delta_t_jax_run1 = (time.time() - t0_jax)
print(f"It took {delta_t_jax_run1:.3f} seconds.")
print("--- Subsequent computations ---")
infer_times = []
Nsamples = 100
for t in range(Nsamples):
    t0_jax = time.time()
    rngkey_training = jr.PRNGKey(np.random.randint(0,100))  # Jax requires a PRNG key before generating pseudo random numbers
                                # This will be used in agent action selection and process outcome generation
    [all_obs_arr,all_true_s_arr,all_u_arr,
    all_qs_arr,all_qs_post,all_qpi_arr,efes_arr,
    a_hist,b_hist,d_hist] = tmaze_nojit(rngkey_training)
    computed_efe_jax = efes_arr.block_until_ready()
    
    infer_times.append(time.time() - t0_jax)

avg_deltat = 1e3*np.mean(infer_times)
std_deltat = 1e3*np.std(infer_times)
print(f"It took on average {avg_deltat:.4f} ms (\u00B1 {std_deltat:.5f} ms). ({Nsamples} samples)")

>>>>>>>>>>>>  With just in time compilation (JIT) <<<<<<<<<<<<<<
--- First computation ---
It took 2.668 seconds.
--- Subsequent computations ---
It took on average 1.7435 ms (± 0.44742 ms). (1000 samples)
>>>>>>>>>>>>  Without just in time compilation (JIT) <<<<<<<<<<<<<<
--- First computation ---
It took 2.214 seconds.
--- Subsequent computations ---
It took on average 2093.0567 ms (± 98.25488 ms). (100 samples)


The runtime of the already compiled function is far better: about 1.7 ms per call (for 10 trials) once compiled, against about 2.1 s per call (for a single trial) without JIT. The lesson here is to avoid recompiling too often when using the jaxified version ;) !
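
When does Jax recompile? A small self-contained sketch, using a toy function rather than the actynf training function: a Python side effect inside a jitted function only runs while Jax traces it, so counting those runs shows which calls reuse the cached executable.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: executed only while JAX traces the function
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))   # first call: trace + compile
scaled_sum(jnp.zeros(3))  # same shape and dtype: cached executable, no retrace
print(trace_count)        # 1

scaled_sum(jnp.ones(4))   # new input shape: traced and compiled again
print(trace_count)        # 2
```

The same logic suggests keeping a single jitted function around and calling it with inputs of the same shape and dtype, since wrapping a freshly built function in `jax.jit` starts from an empty cache.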

And as you can see, jitting the functions plays a huge part in allowing faster computations. Running multiple agents in parallel is now much more feasible, which allows us to simulate multiple inference and learning schemes using SI !
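
Here is a minimal sketch of what running agents in parallel can look like with `jax.vmap`, using one PRNG key per agent. The function `toy_trial` is a hypothetical stand-in for a full simulation. Whether the `synthetic_training` partial built above can be vmapped over keys in exactly this way is an assumption that this notebook does not check.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def toy_trial(key):
    """Hypothetical stand-in for one agent's trial: sample one of 4 actions."""
    logits = jnp.array([0.0, 0.0, 0.0, 2.0])  # toy preferences, not actynf output
    return jr.categorical(key, logits)

n_agents = 8
keys = jr.split(jr.PRNGKey(0), n_agents)       # one independent key per agent
run_population = jax.jit(jax.vmap(toy_trial))  # vectorize over the leading key axis

actions = run_population(keys)
print(actions.shape)  # (8,)
```

Each agent gets its own random stream, and the whole population is compiled once and evaluated in a single call.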

### What's missing

At some point, we should compare training-scale results (i.e. how the agents learn across multiple trials). There are some differences in implementation here so the results may not be that close.





### Going further 

We can use this new implementation, together with the automatic differentiation capabilities of the Jax framework, to work towards more extensive parameter estimation schemes ! In [the following notebook](jax_Tmaze_inversion.ipynb), we perform model inversion on the T-maze model !
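
As a taste of what auto-differentiation brings, the sketch below differentiates the negative log-likelihood of an observed action with respect to the action precision. It assumes a softmax action model over the EFE values printed by SPM for t = 1, and it is only an illustration: it does not reproduce the inversion method used in the linked notebook.

```python
import jax
import jax.numpy as jnp

G_t1 = jnp.array([-11.4325, -12.8459, -12.8459, -9.5125])  # SPM output for t = 1
observed_action = 3                                         # hypothetical observation

def neg_log_lik(alpha):
    """Negative log-probability of the observed action under softmax(alpha * G)."""
    log_p = jax.nn.log_softmax(alpha * G_t1)
    return -log_p[observed_action]

grad_fn = jax.grad(neg_log_lik)  # derivative with respect to alpha
print(neg_log_lik(1.0), grad_fn(1.0))
```

The gradient is negative here: the observed action is the one with the best EFE, so a larger precision makes that observation more likely. A gradient-based optimizer could use the same quantity to fit `alpha` to recorded behaviour.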

