In [1]:

# Import the needed packages 
# 
# 1/ the usual suspects
import sys,os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
from jax.tree_util import tree_map

from functools import partial

# 2/ The Active Inference package 
import actynf
from actynf.jaxtynf.jax_toolbox import _normalize,_jaxlog
from actynf.jaxtynf.layer_trial import compute_step_posteriors
from actynf.jaxtynf.layer_learn import learn_after_trial
from actynf.jaxtynf.layer_options import get_learning_options,get_planning_options
from actynf.jaxtynf.shape_tools import to_log_space,get_vectorized_novelty
from actynf.jaxtynf.shape_tools import vectorize_weights

# Weights for the active inference model : 
from simulate.hmm_weights import basic_latent_model

import tensorflow_probability.substrates.jax.distributions as tfd

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
# Old version :
def compute_log_probs(params_list, prior_dist_list):
    """Evaluate every parameter under its matching prior distribution.

    Args:
        params_list: sequence of parameter values, one per prior.
        prior_dist_list: sequence of objects exposing ``log_prob(value)``,
            aligned element-wise with ``params_list``.

    Returns:
        A JAX array built from the per-parameter log-probabilities,
        in the same order as the inputs.

    Raises:
        AssertionError: if the two sequences do not have the same length.
    """
    n_params, n_priors = len(params_list), len(prior_dist_list)
    assert n_params == n_priors, "Lists must have the same length"

    # Pair each value with its prior and query that prior's log-density.
    per_param_lp = [
        prior.log_prob(value)
        for value, prior in zip(params_list, prior_dist_list)
    ]

    # Pack the individual log-probabilities into a single JAX array.
    return jnp.array(per_param_lp)


# One prior per parameter: Normal(5.0, 1.0) for the first entry,
# Beta(10.0, 10.0) for the second.
prior_distributions = [
    tfd.Normal(5.0, 1.0),
    tfd.Beta(10.0, 10.0),
]


def log_prior_param_func(__params):
    """Return the summed log-prior of ``__params`` under ``prior_distributions``."""
    return jnp.sum(compute_log_probs(__params, prior_distributions))


# Callable returning both the summed log-prior and its gradient
# with respect to the parameters.
grad_func = jax.value_and_grad(log_prior_param_func)



# Better version :
def compute_log_prob(_it_param,_it_prior_dist):
    """Compute the log-prior of a pytree of parameters.

    Each leaf of ``_it_param`` is evaluated with the ``log_prob`` method of
    the object found at the same position in ``_it_prior_dist``.

    Args:
        _it_param: pytree of parameter values (list, tuple, dict, or any
            nesting of those).
        _it_prior_dist: pytree with the same structure as ``_it_param`` whose
            entries expose ``log_prob(value)``.

    Returns:
        A tuple ``(total, per_param)`` where ``per_param`` stacks the
        individual log-probabilities along a new leading axis and ``total``
        is their sum. For an empty pytree, ``total`` is 0 and ``per_param``
        is an empty vector.
    """
    _mapped = tree_map(lambda x,y : y.log_prob(x),_it_param,_it_prior_dist)

    # Keep the historical ordering for top-level dicts: entries follow the
    # order of the dict returned by tree_map.
    if isinstance(_mapped,dict):
        _mapped = list(_mapped.values())

    # Flatten whatever nesting remains so that nested containers
    # (e.g. a dict holding a list) do not reach jnp.stack as ragged input.
    _leaves = jax.tree_util.tree_leaves(_mapped)

    # jnp.stack refuses an empty sequence; no parameters means a zero log-prior.
    if not _leaves:
        _params_lp = jnp.zeros((0,))
        return jnp.sum(_params_lp),_params_lp

    _params_lp = jnp.stack(_leaves)
    return jnp.sum(_params_lp),_params_lp



# Demo 1: parameters and priors organised as dicts sharing the same keys.
dict_ = dict(a=0.5, b=0.3)
dict_lp = dict(
    a=tfd.Normal(5.0, 1.0),
    b=tfd.Beta(10.0, 10.0),
)
print(compute_log_prob(dict_, dict_lp))

# Demo 2: the same values and priors organised as positional lists.
list_ = [0.5, 0.3]
list_lp = [
    tfd.Normal(5.0, 1.0),
    tfd.Beta(10.0, 10.0),
]
print(compute_log_prob(list_, list_lp))

(Array(-11.353539, dtype=float32), Array([-11.043939  ,  -0.30960083], dtype=float32))
(Array(-11.353539, dtype=float32), Array([-11.043939  ,  -0.30960083], dtype=float32))
[1.4 0.5] -7.659359
[1.76 0.5 ] -6.1393585
[2.084 0.5  ] -4.908159
[2.3756 0.5   ] -3.9108868
[2.63804 0.5    ] -3.1030965
[2.874236 0.5     ] -2.4487863
[3.0868125 0.5      ] -1.9187951
[3.2781312 0.5      ] -1.489502
[3.450318 0.5     ] -1.1417749
[3.6052864 0.5      ] -0.860116
[3.7447577 0.5      ] -0.63197196
[3.870282 0.5     ] -0.4471755
[3.9832537 0.5      ] -0.29749036
[4.0849285 0.5      ] -0.17624533
[4.1764355 0.5      ] -0.078036785
[4.258792 0.5     ] 0.0015118122
[4.332913 0.5     ] 0.06594646
[4.3996215 0.5      ] 0.11813855
[4.4596596 0.5      ] 0.16041398
[4.513694 0.5     ] 0.1946572
[4.5623245 0.5      ] 0.22239423
[4.606092 0.5     ] 0.24486125
[4.645483 0.5     ] 0.26305938
[4.680935 0.5     ] 0.27779996
[4.7128415 0.5      ] 0.28973985
[4.741557 0.5     ] 0.29941112
[4.767401 0.5     ] 0.3072

In [None]:
def run_log_prior_ascent(value_and_grad_fn, init_params, lr, n_steps=100, lower_bound=1e-10):
    """Run plain gradient ascent on a log-prior and print every step.

    Args:
        value_and_grad_fn: callable mapping a parameter array to
            ``(log_prior_value, gradient)``.
        init_params: starting parameter array.
        lr: step size applied to the gradient.
        n_steps: number of ascent iterations.
        lower_bound: minimum value enforced on every parameter after each
            step (no upper bound is applied).

    Returns:
        The parameter array after ``n_steps`` updates.
    """
    params = init_params
    for _ in range(n_steps):
        log_prior_value, gradient = value_and_grad_fn(params)

        # Ascent step, then clamp from below only.
        params = jnp.clip(params + lr * gradient, lower_bound)

        # The printed value is the log-prior evaluated *before* this update.
        print(params, log_prior_value)
    return params


lr = 0.1

# Test grad for lists :
list_lp = [tfd.Normal(5.0,1.0),tfd.Beta(10.0,10.0)]
encoder = lambda x : list(x)
graded_function = jax.value_and_grad(lambda x  : compute_log_prob(encoder(x),list_lp)[0])

params = run_log_prior_ascent(graded_function, jnp.array([1.0,0.5]), lr)
print(params)


# Test grad for dicts :
dict_lp = {
    "a" : tfd.Normal(5.0,1.0),
    "b" : tfd.Beta(10.0,10.0)
}
encoder = lambda x : {"a":x[0],"b":x[1]}
graded_function = jax.value_and_grad(lambda x  : compute_log_prob(encoder(x),dict_lp)[0])

params = run_log_prior_ascent(graded_function, jnp.array([1.0,0.5]), lr)
print(params)