# Load packages

In [2]:
import ssms
import torch
import pickle
import numpy as np
import os
import lanfactory
from copy import deepcopy
import pandas as pd
import matplotlib.pyplot as plt
from pymc.sampling import jax as pmj
import arviz

wandb not available
wandb not available


# Check model parameters

In [3]:
# Display the configuration entry that ssms registers for the 'angle' model
# (parameter names, bounds, defaults) so it can be checked before fitting.
ssms.config.model_config['angle']

{'name': 'angle',
 'params': ['v', 'a', 'z', 't', 'theta'],
 'param_bounds': [[-3.0, 0.3, 0.1, 0.001, -0.1], [3.0, 3.0, 0.9, 2.0, 1.3]],
 'boundary': <function ssms.basic_simulators.boundary_functions.angle(t=1, theta=1)>,
 'n_params': 5,
 'default_params': [0.0, 1.0, 0.5, 0.001, 0.0],
 'hddm_include': ['z', 'theta'],
 'nchoices': 2}

# Set up the functions for jax and pytensor wrapper

In [4]:
from os import PathLike
from typing import Callable, Tuple

import pytensor 
# Make float32 pytensor's default float type. This assignment sits between the
# two pytensor imports on purpose; keep the statement order as it is.
pytensor.config.floatX = "float32"
import pytensor.tensor as pt
import jax.numpy as jnp
from jax.scipy.special import expit
import numpy as np
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
from jax import jit, vjp, vmap
from jax import grad, jit
from numpy.typing import ArrayLike

# Type aliases used in the signatures below: a JAX log-likelihood function and
# the function that computes its gradient / VJP.
LogLikeFunc = Callable[..., ArrayLike]
LogLikeGrad = Callable[..., ArrayLike]

import pymc as pm
from pytensor.tensor.random.op import RandomVariable

import warnings 
# Silences every Python warning raised after this point in the session.
warnings.filterwarnings('ignore')

# NOTE(review): recent JAX releases no longer provide `jax.config` as an
# importable module; `jax.config.update(...)` is the documented replacement.
# Confirm against the installed JAX version before upgrading.
from jax.config import config

# Keep JAX in 32-bit mode, matching the float32 floatX chosen above.
config.update("jax_enable_x64", False)

class NetworkLike:
    """Builders that expose a JAX network forward pass as a PyMC log-likelihood.

    ``make_logp_jax_funcs`` creates the vectorised JAX log-likelihood and its
    VJP; ``make_jax_logp_ops`` wraps those functions in a pytensor ``Op`` that
    can be passed as ``logp`` to ``pm.CustomDist``.
    """

    @classmethod
    def make_logp_jax_funcs(
        cls,
        params_is_reg: list[bool],
        list_params: list,
        model = None,
        n_params: int | None = None,
        bounds = None,
        kind: str = 'lan',
    ) -> Tuple[LogLikeFunc, LogLikeGrad, LogLikeFunc,]:
        """Make JAX log-likelihood functions from a network forward pass.

        Args:
            params_is_reg: One flag per distribution parameter. ``True`` means
                the parameter is mapped over axis 0 together with the data (one
                value per observation); ``False`` means it is broadcast to
                every observation.
            list_params: Parameter names, in the same order as the positional
                distribution parameters. Used to look up ``bounds``.
            model: Callable network forward pass.
            n_params: Unused; kept so existing callers keep working.
            bounds: Mapping ``name -> (lower, upper)``. A parameter listed here
                is squashed with ``expit`` and rescaled onto that interval;
                parameters not listed are used unchanged. ``None`` means no
                parameter is bounded.
            kind: ``'lan'`` for the choice/RT likelihood network or ``'cpn'``
                for the choice-probability network.

        Returns:
            A triple: the jitted vectorised log-likelihood, the jitted VJP of
            that function, and the non-jitted vectorised log-likelihood.

        Raises:
            ValueError: If ``kind`` is neither ``'lan'`` nor ``'cpn'``.
        """
        # The default used to crash on ``None.keys()``; treat it as "unbounded".
        bounds = {} if bounds is None else bounds

        def _to_bounded(dist_params):
            """Map each sampler-space value onto its configured interval."""
            bounded = []
            for i, value in enumerate(dist_params):
                name = list_params[i]
                if name in bounds:
                    lower, upper = bounds[name][0], bounds[name][1]
                    bounded.append(expit(value) * (upper - lower) + lower)
                else:
                    bounded.append(value)
            return bounded

        def _vectorise(logp_single):
            """Vectorise a single-observation logp and build its VJP."""
            vmap_logp = vmap(
                logp_single,
                in_axes=[0] + [0 if is_regression else None for is_regression in params_is_reg],
            )

            def vjp_vmap_logp(
                data: np.ndarray, *dist_params: list[float | ArrayLike], gz: ArrayLike
            ) -> list[ArrayLike]:
                """Compute the VJP of the vectorised log-likelihood.

                Parameters
                ----------
                data
                    Observed data, mapped over axis 0.
                dist_params
                    The parameters used in the likelihood computation.
                gz
                    The cotangent (upstream gradient) the VJP is evaluated at.

                Returns
                -------
                list[ArrayLike]
                    One VJP per distribution parameter; the entry for ``data``
                    is dropped.
                """
                _, vjp_fn = vjp(vmap_logp, data, *dist_params)
                return vjp_fn(gz)[1:]

            return jit(vmap_logp), jit(vjp_vmap_logp), vmap_logp

        if kind == 'lan':
            def logp_lan(data: np.ndarray, *dist_params) -> ArrayLike:
                """Log-likelihood of one observation under the LAN.

                Args:
                    data: The data vector of a single observation; it is
                        appended to the parameters to form the network input.
                    dist_params: Sampler-space parameters. The last one is a
                        lapse probability and is not fed to the network.
                Returns:
                    The log of the lapse-mixed likelihood.
                """
                params = _to_bounded(dist_params)
                p_outlier = params[-1]
                input_matrix = jnp.concatenate((jnp.array(params[:-1]), data))

                # Mix the network likelihood with a constant lapse density of
                # 1/2.5, weighted by the lapse probability.
                ll = jnp.multiply(jnp.exp(model(input_matrix)), 1 - p_outlier) + p_outlier / 2.5

                return jnp.sum(jnp.squeeze(jnp.log(ll)))

            return _vectorise(logp_lan)

        if kind == 'cpn':
            def logp_cpn(data: np.ndarray, *dist_params) -> ArrayLike:
                """Log-likelihood of one observation under the CPN.

                Args:
                    data: Weight applied to the log-probability (the observed
                        indicator for this observation).
                    dist_params: Sampler-space parameters; all of them are fed
                        to the network after the bounds transform.
                Returns:
                    ``data * log(1 - exp(network output))``.
                """
                params_matrix = jnp.array(_to_bounded(dist_params))

                # AF-TODO (carried over from the original code): a bug fix for
                # the no-go parameter handling was flagged at this point.
                net_out = jnp.squeeze(model(params_matrix))

                # Guard against log(0). The original literal 1e-64 underflows
                # to 0 in float32, so fall back to the smallest normal number
                # of the working dtype; float64 still uses 1e-64.
                eps = max(1e-64, float(jnp.finfo(net_out.dtype).tiny))

                return jnp.sum(jnp.multiply(jnp.log(1 - jnp.exp(net_out) + eps), data))

            return _vectorise(logp_cpn)

        # Previously an unknown kind fell through and returned None, which only
        # failed later when the caller unpacked the result.
        raise ValueError(f"Unknown kind {kind!r}; expected 'lan' or 'cpn'.")

    @staticmethod
    def make_jax_logp_ops(
        logp: LogLikeFunc,
        logp_vjp: LogLikeGrad,
        logp_nojit: LogLikeFunc,
    ) -> Op:
        """Wrap the JAX functions and its gradient in pytensor Ops.

        Parameters
        ----------
        logp
            The (jitted) JAX log-likelihood function; used when the Op is
            evaluated through ``perform``.
        logp_vjp
            The JAX function that calculates the VJP of the logp function.
        logp_nojit
            The non-jit version of logp; returned by the JAX dispatch so the
            JAX backend can trace it directly.

        Returns
        -------
        Op
            A pytensor op that wraps the log-likelihood and can be used with
            pytensor.grad.
        """

        class LANLogpOp(Op):  # pylint: disable=W0223
            """Wraps a JAX function in a pytensor Op."""

            def make_node(self, data, *dist_params):
                """Build the Apply node from the data and the parameters.

                Parameters
                ----------
                data
                    The observed data.
                dist_params
                    The parameters used in the likelihood computation. They
                    can be both scalars and arrays.
                """
                inputs = [
                    pt.as_tensor_variable(data),
                ] + [pt.as_tensor_variable(dist_param) for dist_param in dist_params]

                # One log-likelihood value per observation.
                outputs = [pt.vector()]

                return Apply(self, inputs, outputs)

            def perform(self, node, inputs, output_storage):
                """Evaluate the log-likelihood numerically.

                Parameters
                ----------
                inputs
                    The concrete values of the node inputs (data first, then
                    the parameters).
                output_storage
                    A list with one storage cell per output; each cell is a
                    one-element list that receives the computed value.
                """
                result = logp(*inputs)
                output_storage[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

            def grad(self, inputs, output_gradients):
                """Return the symbolic gradient with respect to every input.

                Parameters
                ----------
                inputs
                    The same as the inputs produced in `make_node`.
                output_gradients
                    The gradients flowing back into this Op's output.

                Notes
                -----
                    The gradient with respect to the data is marked as not
                    implemented; the parameter gradients come from the VJP op.
                """
                results = lan_logp_vjp_op(inputs[0], *inputs[1:], gz=output_gradients[0])
                # An Op with a single output returns a bare variable instead
                # of a list; normalise so one-parameter likelihoods also work.
                if not isinstance(results, (list, tuple)):
                    results = [results]

                return [
                    pytensor.gradient.grad_not_implemented(self, 0, inputs[0]),
                ] + list(results)

        class LANLogpVJPOp(Op):  # pylint: disable=W0223
            """Wraps the VJP operation of a jax function in a pytensor op."""

            def make_node(self, data, *dist_params, gz):
                """Build the Apply node from data, parameters and cotangent.

                Parameters
                ----------
                data:
                    The observed data.
                dist_params:
                    The parameters used in the likelihood computation.
                gz:
                    The upstream gradient the VJP is evaluated at.
                """
                inputs = (
                    [
                        pt.as_tensor_variable(data),
                    ]
                    + [pt.as_tensor_variable(dist_param) for dist_param in dist_params]
                    + [pt.as_tensor_variable(gz)]
                )
                # One output per parameter, typed like that parameter; the
                # data and gz inputs get no output.
                outputs = [inp.type() for inp in inputs[1:-1]]

                return Apply(self, inputs, outputs)

            def perform(self, node, inputs, outputs):
                """Evaluate the VJP numerically.

                Parameters
                ----------
                inputs
                    Concrete values: data first, then the parameters, then gz.
                outputs
                    A list with one storage cell per output; each cell is a
                    one-element list that receives the computed value.
                """
                results = logp_vjp(inputs[0], *inputs[1:-1], gz=inputs[-1])

                for i, result in enumerate(results):
                    outputs[i][0] = np.asarray(result, dtype=node.outputs[i].dtype)

        lan_logp_op = LANLogpOp()
        lan_logp_vjp_op = LANLogpVJPOp()

        # Unwraps the JAX function for sampling with JAX backend.
        @jax_funcify.register(LANLogpOp)
        def logp_op_dispatch(op, **kwargs):  # pylint: disable=W0612,W0613
            return logp_nojit

        return lan_logp_op


# Load LAN

In [5]:
# Rebuild the trained LAN for the 'angle' model from its saved config and state.
model_config = ssms.config.model_config["angle"]

jax_infer_lan = lanfactory.trainers.MLPJaxFactory(
    train=False,
    network_config="../network/angle/lan/96f2b24a933211ee99b9a0423f3e9a40_lan_angle__network_config.pickle",
)

# The LAN input is the model parameters plus two data entries per observation.
forward_pass_lan, forward_pass_jitted_lan = jax_infer_lan.make_forward_partial(
    state="../network/angle/lan/96f2b24a933211ee99b9a0423f3e9a40_lan_angle__train_state.jax",
    input_dim=2 + model_config["n_params"],
    add_jitted=True,
    seed=42,
)

passing through identity


# Load CPN

In [6]:
# Rebuild the trained CPN for the 'angle' model from its saved config and state.
jax_infer_cpn = lanfactory.trainers.MLPJaxFactory(
    train=False,
    network_config="../network/angle/cpn/338ff01ca91911ee91a3a0423f3e9b42_cpn_angle__network_config.pickle",
)

# The CPN input is the model parameters plus one extra entry.
forward_pass_cpn, forward_pass_jitted_cpn = jax_infer_cpn.make_forward_partial(
    state="../network/angle/cpn/338ff01ca91911ee91a3a0423f3e9b42_cpn_angle__train_state.jax",
    input_dim=1 + model_config["n_params"],
    add_jitted=True,
    seed=42,
)

passing through transform


# This is the part where I load the dataset. You can skip this part.

In [5]:
# Subject-level table: copy the participant id into `subj_idx` and derive an
# integer `subject` code from it.
df_srm = pd.read_csv('/users/xleng/TSS_OCD/data/2022-04-11_4.0/fa_subset.csv').assign(
    subj_idx=lambda frame: frame['PROLIFIC_PID'],
    subject=lambda frame: frame['subj_idx'].factorize()[0],
)

# Lookup between the participant id and its integer code.
df_idx = df_srm.loc[:, ['subj_idx', 'subject']]

In [6]:
# Display the subject-level table (pandas truncates the rendered rows).
df_srm

Unnamed: 0,MR1_NO,MR2_NO,MR1_O,MR2_O,PROLIFIC_PID,Bias,Pro,Pre,subj_idx,subject
0,2.436324,1.673657,2.231432,0.999315,60fd2ca6bf8d050ebf440221,-1.328178,-0.683701,0.937111,60fd2ca6bf8d050ebf440221,0
1,0.112875,-0.357938,0.287001,-0.477968,5c3d376ec2d0b700017c7c50,-0.674127,-1.109457,-0.207784,5c3d376ec2d0b700017c7c50,1
2,2.007355,1.324592,1.861469,0.757895,6103159a48c853995ad5039a,-0.968450,-0.470824,0.708132,6103159a48c853995ad5039a,2
3,2.518625,0.056017,3.012727,-0.982447,5b92f80b2f777d000175da5c,0.699379,-1.109457,-1.810637,5b92f80b2f777d000175da5c,3
4,-0.384611,0.733107,-0.772784,1.042902,60712d937752fb8780e89951,0.928297,-0.045068,-1.123700,60712d937752fb8780e89951,4
...,...,...,...,...,...,...,...,...,...,...
311,-1.019517,-0.543781,-0.999801,-0.229666,5d19281012e152001920227f,-0.968450,-0.470824,0.708132,5d19281012e152001920227f,311
312,0.699188,2.345177,-0.145930,2.531872,5e79b84d4633bc575ddb3812,0.699379,0.167810,-0.665742,5e79b84d4633bc575ddb3812,312
313,-0.952481,-0.817260,-0.803674,-0.586848,5ff5f7ad932d56101bf7c90d,-1.099260,0.380687,1.624048,5ff5f7ad932d56101bf7c90d,313
314,-0.570416,-0.832249,-0.336753,-0.764156,5d9d2e421af9ed0011c31894,-0.052779,1.232198,1.166090,5d9d2e421af9ed0011c31894,314


In [7]:

# Trial-level data joined with the integer subject code from df_srm.
# Each categorical predictor is factorized to 0/1 codes and then recoded with
# `* 2 - 1`; `response` gets the same recoding, and `rewpunLevel` is the
# product of the recoded reward and punishment levels.
df = (
    pd.read_csv('/users/xleng/TSS_OCD/data/2022-04-11_4.0/hddm_all.csv')
    .merge(df_srm[['subj_idx','subject']])
    .assign(
        catRewLevel=lambda frame: frame['catRewLevel'].factorize()[0] * 2 - 1,
        catPunLevel=lambda frame: frame['catPunLevel'].factorize()[0] * 2 - 1,
        catCong=lambda frame: frame['catCong'].factorize()[0] * 2 - 1,
        response=lambda frame: frame['response'] * 2 - 1,
        rewpunLevel=lambda frame: frame['catRewLevel'] * frame['catPunLevel'],
    )
)

# Rows whose `oe` flag is False.
df_commission = df.loc[df['oe'].eq(False)]

In [8]:
# Display the rows kept for the choice/RT likelihood (those with oe == False).
df_commission

Unnamed: 0,rt,catRewLevel,catPunLevel,oe,subj_idx,response,catCong,subject,rewpunLevel
0,0.591,-1,-1,False,60fd2ca6bf8d050ebf440221,1,-1,0,1
1,0.538,-1,-1,False,60fd2ca6bf8d050ebf440221,1,1,0,1
2,0.648,-1,-1,False,60fd2ca6bf8d050ebf440221,1,-1,0,1
3,0.330,-1,-1,False,60fd2ca6bf8d050ebf440221,1,1,0,1
4,0.307,-1,-1,False,60fd2ca6bf8d050ebf440221,1,1,0,1
...,...,...,...,...,...,...,...,...,...
172708,0.742,-1,1,False,5daa3d1726488800157a6ffc,1,1,315,-1
172709,0.638,-1,1,False,5daa3d1726488800157a6ffc,1,-1,315,-1
172710,0.918,-1,1,False,5daa3d1726488800157a6ffc,1,-1,315,-1
172711,0.734,-1,1,False,5daa3d1726488800157a6ffc,1,1,315,-1


# Create the pytensor OP

In [None]:
# LAN: five bounded model parameters mapped per observation, plus a single
# shared (non-regression) lapse probability.
lan_logp_jitted, lan_logp_vjp_jitted, lan_logp = NetworkLike.make_logp_jax_funcs(
    kind="lan",
    model=forward_pass_lan,
    n_params=6,
    list_params=['v', 'a', 'z', 't', 'theta', 'p_outlier'],
    params_is_reg=[True, True, True, True, True, False],
    bounds={
        'v': (-3, 3),
        'a': (0.2, 2.5),
        'z': (0.1, 0.9),
        't': (0.01, 0.5),
        'theta': (0, 1.2),
        'p_outlier': (0, 0.05),
    },
)

# Wrap the LAN functions in a pytensor op.
lan_logp_op = NetworkLike.make_jax_logp_ops(
    logp=lan_logp_jitted,
    logp_vjp=lan_logp_vjp_jitted,
    logp_nojit=lan_logp,
)

# CPN: the same five bounded parameters; `deadline` has no bounds entry, so it
# reaches the network untransformed.
cpn_logp_jitted, cpn_logp_vjp_jitted, cpn_logp = NetworkLike.make_logp_jax_funcs(
    kind="cpn",
    model=forward_pass_cpn,
    n_params=6,
    list_params=['v', 'a', 'z', 't', 'theta', 'deadline'],
    params_is_reg=[True, True, True, True, True, False],
    bounds={
        'v': (-3, 3),
        'a': (0.2, 2.5),
        'z': (0.1, 0.9),
        't': (0.01, 0.5),
        'theta': (0, 1.2),
    },
)

# Wrap the CPN functions in a pytensor op.
cpn_logp_op = NetworkLike.make_jax_logp_ops(
    logp=cpn_logp_jitted,
    logp_vjp=cpn_logp_vjp_jitted,
    logp_nojit=cpn_logp,
)

# Run the fitting

In [None]:
# Coordinates: one entry per subject, one per commission trial (choice/RT
# likelihood) and one per trial of the full data set (omission likelihood).
coords = {
    "id": df.subject.unique(),  # actual group names
    "observation1": np.arange(df_commission.shape[0]),  # df_commission.index.values would also work
    "observation2": np.arange(df.shape[0])
}

with pm.Model(coords=coords) as hierarchical:
    # Non-centred hierarchy: every effect has a group mean, a group scale and
    # standard-normal subject offsets that are combined into trial-level values
    # further down. All values live in sampler space; the logp ops map them
    # onto the bounded parameter ranges.
    mu_v = pm.Normal('mu_v',mu=0,sigma=1)
    sigma_v = pm.HalfCauchy('sigma_v',beta=0.2)
    v_subj = pm.Normal('v_subj',mu=0,sigma=1,dims='id')

    mu_v_catRewLevel = pm.Normal('v_catRewLevel',mu=0,sigma=1)
    sigma_v_catRewLevel = pm.HalfCauchy('sigma_v_catRewLevel',beta=0.2)
    v_catRewLevel_subj = pm.Normal('v_catRewLevel_subj',mu=0,sigma=1,dims='id')

    mu_a_catRewLevel = pm.Normal('a_catRewLevel',mu=0,sigma=1)
    sigma_a_catRewLevel = pm.HalfCauchy('sigma_a_catRewLevel',beta=0.2)
    a_catRewLevel_subj = pm.Normal('a_catRewLevel_subj',mu=0,sigma=1,dims='id')

    mu_theta_catRewLevel = pm.Normal('theta_catRewLevel',mu=0,sigma=1)
    sigma_theta_catRewLevel = pm.HalfCauchy('sigma_theta_catRewLevel',beta=0.2)
    theta_catRewLevel_subj = pm.Normal('theta_catRewLevel_subj',mu=0,sigma=1,dims='id')

    mu_v_catRewLevelcatPunLevel = pm.Normal('v_catRewLevelcatPunLevel',mu=0,sigma=0.5)
    # NOTE(review): this variable name lacks the second 'cat' that the matching
    # a/theta scales carry; left unchanged so stored traces keep their names.
    sigma_v_catRewLevelcatPunLevel = pm.HalfCauchy('sigma_v_catRewLevelPunLevel',beta=0.2)
    v_catRewLevelcatPunLevel_subj = pm.Normal('v_catRewLevelcatPunLevel_subj',mu=0,
                                              sigma=1,dims='id')
    mu_a_catRewLevelcatPunLevel = pm.Normal('a_catRewLevelcatPunLevel',mu=0,sigma=0.5)
    sigma_a_catRewLevelcatPunLevel = pm.HalfCauchy('sigma_a_catRewLevelcatPunLevel',beta=0.2)
    a_catRewLevelcatPunLevel_subj = pm.Normal('a_catRewLevelcatPunLevel_subj',mu=0,sigma=1,dims='id')

    mu_theta_catRewLevelcatPunLevel = pm.Normal('theta_catRewLevelcatPunLevel',mu=0,sigma=0.5)
    sigma_theta_catRewLevelcatPunLevel = pm.HalfCauchy('sigma_theta_catRewLevelcatPunLevel',beta=0.2)
    theta_catRewLevelcatPunLevel_subj = pm.Normal('theta_catRewLevelcatPunLevel_subj',mu=0,sigma=1,dims='id')

    mu_v_cong = pm.Normal('v_cong',mu=0,sigma=1)
    sigma_v_cong = pm.HalfCauchy('sigma_v_cong',beta=0.2)
    v_cong_subj = pm.Normal('v_cong_subj',mu=0,sigma=1,dims='id')

    mu_a_cong = pm.Normal('a_cong',mu=0,sigma=1)
    sigma_a_cong = pm.HalfCauchy('sigma_a_cong',beta=0.2)
    a_cong_subj = pm.Normal('a_cong_subj',mu=0,sigma=1,dims='id')

    mu_theta_cong = pm.Normal('theta_cong',mu=0,sigma=1)
    sigma_theta_cong = pm.HalfCauchy('sigma_theta_cong',beta=0.2)
    theta_cong_subj = pm.Normal('theta_cong_subj',mu=0,sigma=1,dims='id')

    mu_z_cong = pm.Normal('z_cong',mu=0,sigma=1)
    sigma_z_cong = pm.HalfCauchy('sigma_z_cong',beta=0.2)
    z_cong_subj = pm.Normal('z_cong_subj',mu=0,sigma=1,dims='id')

    mu_v_catPunLevel = pm.Normal('v_catPunLevel',mu=0,sigma=1)
    sigma_v_catPunLevel = pm.HalfCauchy('sigma_v_catPunLevel',beta=0.2)
    v_catPunLevel_subj = pm.Normal('v_catPunLevel_subj',mu=0,sigma=1,dims='id')

    mu_a_catPunLevel = pm.Normal('a_catPunLevel',mu=0,sigma=1)
    sigma_a_catPunLevel = pm.HalfCauchy('sigma_a_catPunLevel',beta=0.2)
    a_catPunLevel_subj = pm.Normal('a_catPunLevel_subj',mu=0,sigma=1,dims='id')

    mu_theta_catPunLevel = pm.Normal('theta_catPunLevel',mu=0,sigma=1)
    sigma_theta_catPunLevel = pm.HalfCauchy('sigma_theta_catPunLevel',beta=0.2)
    theta_catPunLevel_subj = pm.Normal('theta_catPunLevel_subj',mu=0,sigma=1,dims='id')

    mu_theta = pm.Normal('mu_theta',mu=0,sigma=1)
    sigma_theta = pm.HalfCauchy('sigma_theta',beta=0.2)
    theta_subj = pm.Normal('theta_subj',mu=0,sigma=1,dims='id')

    mu_a = pm.Normal('mu_a',mu=0,sigma=1)
    sigma_a = pm.HalfCauchy('sigma_a',beta=0.2)
    a_subj = pm.Normal('a_subj',mu=0,sigma=1,dims='id')

    mu_z = pm.Normal("mu_z", mu=0,sigma=1)
    sigma_z = pm.HalfCauchy("sigma_z",beta=0.2)
    z_subj = pm.Normal("z_subj", mu=0,sigma=1,dims='id')

    mu_t = pm.Normal("mu_t", mu=0,sigma=1)
    sigma_t = pm.HalfCauchy("sigma_t",beta=0.2)
    t_subj = pm.Normal("t_subj", mu=0,sigma=1,dims='id')

    # NOTE(review): mu_p / sigma_p / p_subj are sampled but never enter either
    # likelihood; confirm whether a per-subject lapse rate was intended.
    mu_p = pm.Normal("mu_p", mu=0,sigma=1)
    sigma_p = pm.HalfCauchy("sigma_p",beta=0.2)
    p_subj = pm.Normal("p_subj", mu=0,sigma=1,dims='id')

    # Both names are used by the likelihood terms below, but their definitions
    # had been commented out, which made the cell fail with a NameError. They
    # are restored exactly as they were written in the commented-out lines.
    p_outlier = pm.Normal("p_outlier", mu=0,sigma=1)
    deadline = pm.ConstantData('deadline',1.25)

    # Suffix 1 = commission trials (df_commission), suffix 2 = all trials (df).
    idx1 = pm.ConstantData('idx1',df_commission.subject,dims='observation1')
    idx2 = pm.ConstantData('idx2',df.subject,dims='observation2')

    rewardLevel1 = pm.ConstantData('rewardLevel1',df_commission.catRewLevel,dims='observation1')
    rewardLevel2 = pm.ConstantData('rewardLevel2',df.catRewLevel,dims='observation2')
    punishLevel1 = pm.ConstantData('punishLevel1',df_commission.catPunLevel,dims='observation1')
    punishLevel2 = pm.ConstantData('punishLevel2',df.catPunLevel,dims='observation2')
    rewpun1 = pm.ConstantData('rewpun1',df_commission.rewpunLevel,dims='observation1')
    rewpun2 = pm.ConstantData('rewpun2',df.rewpunLevel,dims='observation2')
    cong1 = pm.ConstantData('cong1',df_commission.catCong,dims='observation1')
    cong2 = pm.ConstantData('cong2',df.catCong,dims='observation2')

    # Trial-level linear predictors: subject intercept plus reward, punishment,
    # congruency and reward-by-punishment effects.
    v_trial1 = pm.Deterministic('v_trial1',v_subj[idx1] * sigma_v + mu_v +
                                rewardLevel1 * (v_catRewLevel_subj[idx1] * sigma_v_catRewLevel + mu_v_catRewLevel) +
                                punishLevel1 * (v_catPunLevel_subj[idx1] * sigma_v_catPunLevel + mu_v_catPunLevel) +
                                cong1 * (v_cong_subj[idx1] * sigma_v_cong + mu_v_cong) +
                                rewpun1 * (v_catRewLevelcatPunLevel_subj[idx1] * sigma_v_catRewLevelcatPunLevel + mu_v_catRewLevelcatPunLevel))
    v_trial2 = pm.Deterministic('v_trial2',v_subj[idx2] * sigma_v + mu_v +
                                rewardLevel2 * (v_catRewLevel_subj[idx2] * sigma_v_catRewLevel + mu_v_catRewLevel) +
                                punishLevel2 * (v_catPunLevel_subj[idx2] * sigma_v_catPunLevel + mu_v_catPunLevel) +
                                cong2 * (v_cong_subj[idx2] * sigma_v_cong + mu_v_cong) +
                                rewpun2 * (v_catRewLevelcatPunLevel_subj[idx2] * sigma_v_catRewLevelcatPunLevel + mu_v_catRewLevelcatPunLevel))
    a_trial1 = pm.Deterministic('a_trial1',a_subj[idx1] * sigma_a + mu_a +
                                rewardLevel1 * (a_catRewLevel_subj[idx1] * sigma_a_catRewLevel + mu_a_catRewLevel) +
                                punishLevel1 * (a_catPunLevel_subj[idx1] * sigma_a_catPunLevel + mu_a_catPunLevel) +
                                cong1 * (a_cong_subj[idx1] * sigma_a_cong + mu_a_cong) +
                                rewpun1 * (a_catRewLevelcatPunLevel_subj[idx1] * sigma_a_catRewLevelcatPunLevel + mu_a_catRewLevelcatPunLevel))
    a_trial2 = pm.Deterministic('a_trial2',a_subj[idx2] * sigma_a + mu_a +
                                rewardLevel2 * (a_catRewLevel_subj[idx2] * sigma_a_catRewLevel + mu_a_catRewLevel) +
                                punishLevel2 * (a_catPunLevel_subj[idx2] * sigma_a_catPunLevel + mu_a_catPunLevel) +
                                cong2 * (a_cong_subj[idx2] * sigma_a_cong + mu_a_cong) +
                                rewpun2 * (a_catRewLevelcatPunLevel_subj[idx2] * sigma_a_catRewLevelcatPunLevel + mu_a_catRewLevelcatPunLevel))
    theta_trial1 = pm.Deterministic('theta_trial1',theta_subj[idx1] * sigma_theta + mu_theta +
                                rewardLevel1 * (theta_catRewLevel_subj[idx1] * sigma_theta_catRewLevel + mu_theta_catRewLevel) +
                                punishLevel1 * (theta_catPunLevel_subj[idx1] * sigma_theta_catPunLevel + mu_theta_catPunLevel) +
                                cong1 * (theta_cong_subj[idx1] * sigma_theta_cong + mu_theta_cong) +
                                rewpun1 * (theta_catRewLevelcatPunLevel_subj[idx1] * sigma_theta_catRewLevelcatPunLevel + mu_theta_catRewLevelcatPunLevel))
    theta_trial2 = pm.Deterministic('theta_trial2',theta_subj[idx2] * sigma_theta + mu_theta +
                                rewardLevel2 * (theta_catRewLevel_subj[idx2] * sigma_theta_catRewLevel + mu_theta_catRewLevel) +
                                punishLevel2 * (theta_catPunLevel_subj[idx2] * sigma_theta_catPunLevel + mu_theta_catPunLevel) +
                                cong2 * (theta_cong_subj[idx2] * sigma_theta_cong + mu_theta_cong) +
                                rewpun2 * (theta_catRewLevelcatPunLevel_subj[idx2] * sigma_theta_catRewLevelcatPunLevel + mu_theta_catRewLevelcatPunLevel))
    z_trial1 = pm.Deterministic('z_trial1',z_subj[idx1] * sigma_z + mu_z + cong1 * (z_cong_subj[idx1] * sigma_z_cong + mu_z_cong))
    z_trial2 = pm.Deterministic('z_trial2',z_subj[idx2] * sigma_z + mu_z + cong2 * (z_cong_subj[idx2] * sigma_z_cong + mu_z_cong))

    # The original repeated these two lines a second time, once with broken
    # indentation. That is a syntax error, and registering the same
    # Deterministic name twice is rejected by PyMC, so each appears only once.
    t_trial1 = pm.Deterministic('t_trial1',t_subj[idx1] * sigma_t + mu_t)
    t_trial2 = pm.Deterministic('t_trial2',t_subj[idx2] * sigma_t + mu_t)

    # Choice/RT likelihood (LAN) on the commission trials.
    pm.CustomDist("choice_rt", v_trial1,
                               a_trial1,
                               z_trial1,
                               t_trial1,
                               theta_trial1,
                               p_outlier,
                  logp=lan_logp_op,observed=df_commission[['rt','response']])
    # Omission likelihood (CPN) on every trial.
    pm.CustomDist("omission", v_trial2,
                              a_trial2,
                              z_trial2,
                              t_trial2,
                              theta_trial2,
                              deadline,
                  logp=cpn_logp_op,observed=df['oe'])
    # NUTS via numpyro; the listed group-level means start at 0.
    ddm_blog_traces_numpyro_d = pmj.sample_numpyro_nuts(
            chains=2, draws=1000, tune=500,initvals={'mu_v':0,'mu_a':0,'mu_theta':0,"mu_z":0,"mu_t":0,"v_catRewLevel":0,"a_catRewLevel":0,
                                                    'theta_catRewLevel':0,"v_catPunLevel":0,"a_catPunLevel":0,
                                                    'theta_catPunLevel':0,"v_cong":0,"a_cong":0,
                                                    'theta_cong':0,'z_cong':0,'v_catRewLevelcatPunLevel':0,
                                                    'a_catRewLevelcatPunLevel':0,'theta_catRewLevelcatPunLevel':0}
    )

Compiling...


passing through identity
passing through transform


Compilation time = 0:00:08.532185
Sampling...


passing through transform
passing through identity


  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

passing through transform
passing through identity


# Function to extract and plot traces that I'm interested in

In [7]:
def mytraceplot(trace):
    """Plot traces for every posterior variable except the per-trial ones.

    A variable counts as per-trial when the five characters before its last
    character spell 'trial' (for example 'v_trial1').
    """
    selected = []
    for var_name in trace.posterior.data_vars.keys():
        if var_name[-6:-1] == 'trial':
            continue
        selected.append(var_name)
    pm.plot_trace(trace, selected)

    
def mysummary(trace):
    """Return the summary table for every posterior variable except the per-trial ones.

    A variable counts as per-trial when the five characters before its last
    character spell 'trial' (for example 'v_trial1').
    """
    selected = list(
        filter(
            lambda var_name: var_name[-6:-1] != 'trial',
            trace.posterior.data_vars.keys(),
        )
    )
    return pm.summary(trace, selected)

In [8]:
import pymc as pm

# Summarise the group- and subject-level posteriors of the fitted model.
# The original call passed `idata`, a name that is never assigned in this
# notebook, so the cell failed on a fresh kernel. The sampling cell stores its
# result in `ddm_blog_traces_numpyro_d`, which is used here instead.
# NOTE(review): if a different, separately loaded trace was intended, load it
# explicitly in this cell.
a = mysummary(ddm_blog_traces_numpyro_d)

# Tried to estimate LL from model

In [59]:
# Compute log-likelihood values for the sampled trace under the `hierarchical`
# model; the result is bound to `a` and used for the WAIC call below.
with hierarchical:
    a = pm.compute_log_likelihood(ddm_blog_traces_numpyro_d)

In [64]:
# WAIC restricted to the 'omission' likelihood term.
pm.waic(a,var_name='omission')

Computed from 400 posterior samples and 164376 observations log-likelihood matrix.

          Estimate       SE
elpd_waic -17555.59   256.33
p_waic      715.00        -
