# Evidence estimation in a Bayesian cognitive model


In [1]:
%load_ext autoreload
%autoreload 2

import os

# For JAX, we can use CPU (set SELECTED_DEVICE to empty (''), or one of the available GPUs)
SELECTED_DEVICE = '9'
print(f'Setting CUDA visible devices to [{SELECTED_DEVICE}]')
os.environ['CUDA_VISIBLE_DEVICES'] = f'{SELECTED_DEVICE}'

Setting CUDA visible devices to [9]


In [2]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import jax
jax.config.update("jax_enable_x64", True)  # Do we need this here? -> it seems we do for the LML computations (otherwise NaNs get introduced), but not for performance

import jax.random as jrnd
import jax.numpy as jnp
import distrax as dx
import blackjax
import pandas as pd
import jax.scipy.special as jsp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
import pyreadr as pr

from distrax._src.distributions.distribution import Distribution

import os
import sys
import requests

from blackjax import normal_random_walk

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../../')))

import bamojax
from bamojax.base import Node, Model
from bamojax.sampling import gibbs_sampler, inference_loop, run_chain, smc_inference_loop

print('Python version:       ', sys.version)
print('Jax version:          ', jax.__version__)
print('BlackJax version:     ', blackjax.__version__)
print('Distrax version:      ', dx.__version__)
print('BaMoJax version:      ', bamojax.__version__)
print('Jax default backend:  ', jax.default_backend())
print('Jax devices:          ', jax.devices())

SMALL_SIZE = 14
MEDIUM_SIZE = 16
LARGE_SIZE = 22

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=LARGE_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize
plt.rc('figure', titlesize=LARGE_SIZE)  # fontsize of the figure title

Python version:        3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
Jax version:           0.4.35
BlackJax version:      1.2.4
Distrax version:       0.1.5
BaMoJax version:       0.0.1
Jax default backend:   gpu
Jax devices:           [CudaDevice(id=0)]


In [3]:
def download_to_disk(url, filepath):
    response = requests.get(url)
    if response.status_code == 200:
        with open(filepath, 'wb') as file:
            file.write(response.content)
        print('File downloaded successfully!')
    else:
        print(f'Failed to download the file. Status code: {response.status_code}')

#
data_busemeyer_url = 'https://osf.io/download/5vws6/'  # DataBusemeyerNoNA.rdata on https://osf.io/f9cq4/; contains IGT data
data_busemeyer_file = 'DataBusemeyerNoNA.rdata'

data_steingroever_url = 'https://osf.io/download/bmnsv/'  # contains Steingroever's importance sampling marginal likelihoods
data_steingroever_file = 'DataSteingroever.rdata'

lml_url = 'https://osf.io/download/txnbs/' # ind_LogMargLik.txt on https://osf.io/f9cq4/; contains Gronau's bridge sampling estmates
lml_file = 'ind_LogMargLik.txt'

download_to_disk(data_busemeyer_url, data_busemeyer_file)
download_to_disk(data_steingroever_url, data_steingroever_file)
download_to_disk(lml_url, lml_file)

File downloaded successfully!
File downloaded successfully!
File downloaded successfully!


In [4]:
data_file = pr.read_r('DataBusemeyerNoNA.rdata')
choices = jnp.asarray(data_file['choice'].to_numpy().astype(int)) - 1  # Python zero-indexing
losses = jnp.asarray(data_file['lo'].to_numpy())
wins = jnp.asarray(data_file['wi'].to_numpy())

N, T = choices.shape
K = 4

In [5]:
print(wins.shape)
print(losses.shape)
print(choices.shape)

(30, 100)
(30, 100)
(30, 100)


Unrolling the likelihood over time seems very inefficient, but how can we add the contribution of an individual choice to `ev_` during the loop?

In [36]:
key = jrnd.PRNGKey(42)

for subject in range(N):

    def ev_link_fn(w, a, c_raw, obs):
        c = 4*c_raw - 2.0
        ev = jnp.zeros((K, ))
        logits = jnp.zeros((K, T))

        def for_body_fn(t, carry):
            ev_, logits_, logits_all = carry
            theta = (0.1*(t+1))**c
            current_utility = (1-w) * wins[subject, t] + w * losses[subject, t]
            ev_ = ev_.at[obs[t]].add(a * (current_utility - ev_[obs[t]]))
            logits_ = theta * ev_
            logits_all = logits_all.at[:, t].set(logits_)
            return (ev_, logits_, logits_all)

        #
        initial_logits = jax.nn.sigmoid(1/K*jnp.ones((K, )))
        _, _, logits = jax.lax.fori_loop(0, T-1, for_body_fn, (ev, initial_logits, logits) )
        return dict(logits=logits.T)

    #

    EVModel = Model('Expectance valence model')
    w_node = EVModel.add_node('w', distribution=dx.Beta(alpha=1.0, beta=1.0))
    a_node = EVModel.add_node('a', distribution=dx.Beta(alpha=1.0, beta=1.0))
    c_raw_node = EVModel.add_node('c_raw', distribution=dx.Beta(alpha=1.0, beta=1.0))

    choice_node = EVModel.add_node('choices', observations=choices[subject, :], distribution=dx.Categorical, link_fn=ev_link_fn, parents=dict(w=w_node, a=a_node, c_raw=c_raw_node, obs=choices[subject,:]))

    

    num_mcmc_steps = 100
    num_particles = 1_000
    num_chains = 1

    step_fns = dict(a=normal_random_walk, w=normal_random_walk, c_raw=normal_random_walk)
    step_fn_params = dict(a=dict(sigma=0.5), w=dict(sigma=0.5), c_raw=dict(sigma=0.5))

    gibbs = gibbs_sampler(EVModel, step_fns=step_fns, step_fn_params=step_fn_params)

    key, subkey = jrnd.split(key)
    final_state, lml, n_iter, final_info = smc_inference_loop(subkey, model=EVModel, kernel=gibbs, num_particles=num_particles, num_mcmc_steps=num_mcmc_steps, num_chains=num_chains)

    print(f'LML subject {subject}:', lml)

    for c in range(num_chains):
        print('Acceptance rates a:', jnp.mean(final_info.update_info['a'].is_accepted[c, ...]))
        print('Acceptance rates w:', jnp.mean(final_info.update_info['w'].is_accepted[c, ...]))
        print('Acceptance rates c_raw:', jnp.mean(final_info.update_info['c_raw'].is_accepted[c, ...]))


LML subject 0: -126.678533607858
Acceptance rates a: 0.34
Acceptance rates w: 0.04
Acceptance rates c_raw: 0.13
LML subject 1: -105.00576176906173
Acceptance rates a: 0.099999994
Acceptance rates w: 0.04
Acceptance rates c_raw: 0.06
LML subject 2: -115.59157787685521
Acceptance rates a: 0.049999997
Acceptance rates w: 0.06
Acceptance rates c_raw: 0.14
LML subject 3: -114.46374477291202
Acceptance rates a: 0.14999999
Acceptance rates w: 0.12
Acceptance rates c_raw: 0.14
LML subject 4: -105.6579978582115
Acceptance rates a: 0.089999996
Acceptance rates w: 0.08
Acceptance rates c_raw: 0.11
LML subject 5: -135.73854035893558
Acceptance rates a: 0.39
Acceptance rates w: 0.39999998
Acceptance rates c_raw: 0.29999998
LML subject 6: -134.27809000774008
Acceptance rates a: 0.37
Acceptance rates w: 0.39
Acceptance rates c_raw: 0.34
LML subject 7: -123.6594524991902
Acceptance rates a: 0.14999999
Acceptance rates w: 0.06
Acceptance rates c_raw: 0.099999994
LML subject 8: -137.62227411022198
Accep

The acceptance rates are quite diverging. 

In [31]:
K = 4

probs = jnp.ones((K, 5)).T

x = jnp.array([1, 0, 3, 2, 0])
print(probs.shape)

print(jnp.sum(dx.Categorical(probs=probs).log_prob(value=x)))

(5, 4)
-6.931471805599453
