# Evidence estimation in a Bayesian cognitive model


In [1]:
%load_ext autoreload
%autoreload 2

import os

# Device selection for JAX: leave SELECTED_DEVICE empty ('') to run on CPU,
# or set it to the index of one of the available GPUs.
# The environment variable is exported here, ahead of the cell that imports jax.
SELECTED_DEVICE = '0'
print(f'Setting CUDA visible devices to [{SELECTED_DEVICE}]')
os.environ['CUDA_VISIBLE_DEVICES'] = SELECTED_DEVICE

Setting CUDA visible devices to [0]


In [5]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import jax
jax.config.update("jax_enable_x64", True)  # 64-bit floats: appears to be required for the LML computations (NaNs get introduced otherwise); not needed for performance

import jax.random as jrnd
import jax.numpy as jnp
import distrax as dx
import blackjax
import pandas as pd
import jax.scipy.special as jsp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
import pyreadr as pr

from distrax._src.distributions.distribution import Distribution

import os
import sys
import requests

from blackjax import normal_random_walk

# Put the directory three levels above the current working directory on the import path.
# NOTE(review): presumably this is the repository root, so that the local `bamojax`
# package imported just below can be found — confirm against the repository layout.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../../')))

import bamojax
from bamojax.base import Node, Model
from bamojax.sampling import gibbs_sampler, inference_loop, run_chain, smc_inference_loop

# Report interpreter/library versions and the active JAX backend, one labelled line each.
for _label, _value in (
    ('Python version:       ', sys.version),
    ('Jax version:          ', jax.__version__),
    ('BlackJax version:     ', blackjax.__version__),
    ('Distrax version:      ', dx.__version__),
    ('BaMoJax version:      ', bamojax.__version__),
    ('Jax default backend:  ', jax.default_backend()),
    ('Jax devices:          ', jax.devices()),
):
    print(_label, _value)

# Three font-size tiers shared by all figures in this notebook.
SMALL_SIZE, MEDIUM_SIZE, LARGE_SIZE = 14, 16, 22

# Default text size.
plt.rc('font', size=SMALL_SIZE)
# Axes: title and x/y label sizes.
plt.rc('axes', titlesize=LARGE_SIZE, labelsize=MEDIUM_SIZE)
# Tick labels on both axes (a tuple of group names applies the setting to each group).
plt.rc(('xtick', 'ytick'), labelsize=SMALL_SIZE)
# Legend entries.
plt.rc('legend', fontsize=MEDIUM_SIZE)
# Figure-level title.
plt.rc('figure', titlesize=LARGE_SIZE)

Python version:        3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
Jax version:           0.4.35
BlackJax version:      1.2.4
Distrax version:       0.1.5
BaMoJax version:       0.0.1
Jax default backend:   gpu
Jax devices:           [CudaDevice(id=0)]


In [3]:
def download_to_disk(url, filepath, timeout=60):
    """Download `url` and write the response body to `filepath`.

    A status message is printed in both outcomes. The file is only written when
    the server answers with HTTP status 200; any other status leaves the disk
    untouched.

    Args:
        url: Address of the file to fetch.
        filepath: Destination path. It is opened in binary write mode, so an
            existing file at that path is overwritten.
        timeout: Seconds to wait on the server before `requests` raises a
            timeout exception. Without a timeout a stalled connection would
            block the notebook indefinitely.

    Returns:
        None.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        with open(filepath, 'wb') as file:
            file.write(response.content)
        print('File downloaded successfully!')
    else:
        print(f'Failed to download the file. Status code: {response.status_code}')

#
# IGT data: DataBusemeyerNoNA.rdata on https://osf.io/f9cq4/
data_busemeyer_url, data_busemeyer_file = 'https://osf.io/download/5vws6/', 'DataBusemeyerNoNA.rdata'

# Steingroever's importance sampling marginal likelihoods
data_steingroever_url, data_steingroever_file = 'https://osf.io/download/bmnsv/', 'DataSteingroever.rdata'

# Gronau's bridge sampling estimates: ind_LogMargLik.txt on https://osf.io/f9cq4/
lml_url, lml_file = 'https://osf.io/download/txnbs/', 'ind_LogMargLik.txt'

# Fetch the three files in the order they are declared above.
for _url, _file in (
    (data_busemeyer_url, data_busemeyer_file),
    (data_steingroever_url, data_steingroever_file),
    (lml_url, lml_file),
):
    download_to_disk(_url, _file)

File downloaded successfully!
File downloaded successfully!
File downloaded successfully!


In [None]:
# Read the IGT data set from the R data file; entries are looked up by name below.
data_file = pr.read_r('DataBusemeyerNoNA.rdata')

# Choices are shifted down by one to match Python's zero-based indexing.
choices = jnp.asarray(data_file['choice'].to_numpy().astype(int)) - 1
# Loss and win tables are converted to JAX arrays as-is.
losses, wins = (jnp.asarray(data_file[key].to_numpy()) for key in ('lo', 'wi'))

# NOTE(review): presumably N participants by T trials — confirm against the data file.
N, T = choices.shape
# Number of options; used as the length of the expected-valence vector.
K = 4

In [None]:
def ev_link_fn(w, a, c_raw):
    """Link function passed to the 'choices' node of the EV model (unfinished stub).

    In its current state the body only rescales `c_raw` and allocates a zero
    vector of length `K`. There is no `return` statement, so a call yields
    `None`, and `w`, `a`, `c` and `ev` are never used.

    Args:
        w: Value of the model parameter 'w'; not used yet.
        a: Value of the model parameter 'a'; not used yet.
        c_raw: Value of the model parameter 'c_raw'; rescaled to `4*c_raw - 2.0`.

    NOTE(review): presumably this should return the logits/probs for the
    Categorical distribution of the 'choices' node — TODO implement and confirm
    the return format expected by the modelling library.
    """
    c = 4*c_raw - 2.0  # affine map: c_raw in [0, 1] becomes c in [-2, 2]
    ev = jnp.zeros((K, ))  # zero vector with one entry per option (K is a module-level constant)

    # construct logits/probs for Categorical

EVModel = Model('Expectance valence model')

# The three parameters 'w', 'a' and 'c_raw' each get their own Beta(1, 1) prior,
# i.e. uniform on the unit interval; nodes are added in this order.
w_node, a_node, c_raw_node = (
    EVModel.add_node(param_name, distribution=dx.Beta(alpha=1.0, beta=1.0))
    for param_name in ('w', 'a', 'c_raw')
)

# Categorical node for the choices, parameterised through the link function.
choice_node = EVModel.add_node(
    'choices',
    distribution=dx.Categorical,
    link_fn=ev_link_fn,
)


    #     def loglikelihood_fn_(state: GibbsState) -> Float:
    #         position = getattr(state, 'position', state)
    #         w = position['w']
    #         a = position['a']
    #         c = 4*position['c_raw'] - 2.0
    #         ev = jnp.zeros((self.K, ))  # expected valence per deck of cards
    #         loglik = dx.Categorical(probs=jnp.ones((self.K, ))).log_prob(value=self.choices[0])

    #         def for_body(t, carry):
    #             ev_, loglik_ = carry
    #             theta = (0.1*(t+1))**c  
    #             current_utility = (1-w) * self.wins[t] + w*self.loss[t]
    #             ev_ = ev_.at[self.choices[t]].add(a * (current_utility - ev_[self.choices[t]]))
    #             loglik_ += dx.Categorical(logits=theta * ev_).log_prob(value=self.choices[t+1])
    #             return (ev_, loglik_)
            
    #         _, loglik = jax.lax.fori_loop(0, self.T - 1, for_body, (ev, loglik))
            
    #         return loglik
        
    #     #
    #     return loglikelihood_fn_

    # #

