# PRAM linear models informed with EEG and pre-trial accuracy

Imports

In [None]:
# One-time setup: uncomment and run once to download and build CmdStan
# if it is not already installed on this machine (CmdStanModel below needs it).
# import cmdstanpy
# cmdstanpy.install_cmdstan()

In [None]:
from cmdstanpy import CmdStanModel
import os
import numpy as np
import pandas as pd
from datetime import datetime
import pickle
import json
import time
from contextlib import redirect_stdout
import seaborn as sns

## Define the model

In [None]:
# Base name of the model; the Stan source file is `<name>_model.stan`, and the
# base name is also reused in the sampler's output directory.
name = 'rt_regression'
model_name = name + '_model.stan'
print('Processing model: ' + model_name)

## Compile the model

In [None]:
# Compile the Stan program into a CmdStan executable.
stan_file = os.path.join(
    '../models/ncognitive_models/TBB_models/',
    model_name,
)
rt_model = CmdStanModel(
    stan_file=stan_file,
    # Rebuild the executable on every run, even if a compiled one already exists.
    force_compile=True,
    # Threading support is required for `threads_per_chain` at sampling time.
    cpp_options=dict(STAN_THREADS=True),
)

In [None]:
# Show build information for the compiled executable. The returned value is
# rendered by Jupyter because it is the cell's last expression.
# NOTE(review): presumably includes the STAN_THREADS flag — check it is enabled.
rt_model.exe_info()

## Define data file

In [None]:
# JSON data file: passed directly to `rt_model.sample(data=...)` and also
# loaded below for inspection.
data_file = os.path.join(
    '../data/current_dataset',
    'sonata_data_standardized_rt.json',
)

### Read the data

In [None]:
# Load the same JSON file that is handed to the sampler, for inspection here.
# JSON text is UTF-8 by specification (RFC 8259). Pass the encoding explicitly
# so decoding does not depend on the platform's locale default.
with open(data_file, 'r', encoding='utf-8') as file:
    data = json.load(file)

In [None]:
# Per-trial table: participant index and the absolute value of the RT.
# NOTE(review): taking the absolute value suggests `rt` is stored signed
# (e.g. the sign encodes the response) — confirm against the data spec.
data_df = pd.DataFrame({
    'participant_index': data['participant'],
    'rt': np.abs(np.array(data['rt'])),
})

In [None]:
# Display the per-trial data frame (Jupyter renders the cell's last expression).
data_df

In [None]:
# Summary statistics (count, mean, std, quantiles) of the RT column for every
# participant. Row truncation is disabled temporarily so no participant is hidden.
with pd.option_context('display.max_rows', None):
    display(
        data_df
        .groupby('participant_index')
        .describe()
    )

## Fit the model

Fit parameters

In [None]:
# Sampler settings consumed by `rt_model.sample(...)` in the fitting cell.
num_chains, random_state = 3, 42      # number of chains / RNG seed
warmup, num_samples = 5_000, 10_000   # per-chain warm-up and sampling iterations
thin = 5                              # keep every 5th draw -> 2_000 saved draws per chain
adapt_delta = 0.99                    # target acceptance stat; higher -> smaller step size

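Define initial values (currently unused: the `inits` argument in the fit below is commented out, so CmdStan's default initialization is used)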

Perform fit

In [None]:
# Run NUTS sampling and time it.
# Everything written to stdout during sampling, including the CmdStan console
# output enabled by show_console=True, is appended to jupyter_logs.txt instead
# of flooding the notebook. Only stdout is redirected; stderr output still
# appears in the notebook. The elapsed time is printed after the log is closed.
with open('jupyter_logs.txt', 'a', encoding='utf-8') as f:
    with redirect_stdout(f):
        # perf_counter is monotonic. Unlike time.time(), the measured duration
        # cannot be distorted by system clock adjustments during a long run.
        start = time.perf_counter()
        fit = rt_model.sample(
            data=data_file,
            chains=num_chains,
            seed=random_state,
            thin=thin,
            adapt_delta=adapt_delta,
            # inits=initials,  # disabled: CmdStan's default initialization is used
            iter_warmup=warmup,
            iter_sampling=num_samples,
            parallel_chains=num_chains,  # run all chains concurrently
            threads_per_chain=12,        # effective because the model is built with STAN_THREADS
            max_treedepth=10,
            show_progress=True,
            show_console=True,
            output_dir=f'../../plgrid_results/pram_results/sonata/sonata_cond_{name}/',
        )
        end = time.perf_counter()

print(f'Fitting took: {end - start}')

In [None]:
# Run CmdStan's `diagnose` utility on the sampler output and print its text
# report (it checks e.g. divergences, tree-depth saturation, E-BFMI, ESS, R-hat).
print(fit.diagnose())

In [None]:
# All posterior draws as a flat table; cmdstanpy adds the bookkeeping columns
# `chain__`, `iter__` and `draw__`.
fit_df = fit.draws_pd()

# Trace plot of `pre_acc_prop` with one line per chain, to eyeball mixing.
# NOTE(review): `pre_acc_prop` must be a variable declared in the Stan model,
# whose source is not shown here — confirm it exists in rt_regression_model.stan.
sns.lineplot(data=fit_df, x='iter__', y='pre_acc_prop', hue='chain__')