In [4]:
# -*- coding: utf-8 -*- 

import config

import os
import numpy as np
import pandas as pd

from pathlib import Path

from hbayesdm.models import ra_prospect

import nibabel as nib
# import nipype as nip
# import nilearn as nil

from bids import BIDSLayout, BIDSValidator
from tqdm import tqdm

import pickle
from sklearn import linear_model
from scipy.stats import zscore

import matplotlib.pyplot as plt

In [5]:
# Index the raw BIDS dataset. derivatives=True additionally indexes every
# pipeline found in the dataset's own derivatives/ subfolder.
layout = BIDSLayout(root=config.RAW_DATA_DIR, derivatives=True)

# Register the preprocessing outputs (config.PREP_DIR) as a further derivatives
# dataset — presumably the pipeline a later cell queries as
# layout.derivatives['fMRIPrep'].
layout.add_derivatives(config.PREP_DIR)

In [7]:
# Per-subject index of preprocessed BOLD images and raw event files.
# (The "frmi" spelling is kept because later cells refer to this name.)
#
# frmi_subject_run = {
#     subj: {
#         'niis':   [...],  # fMRIPrep *_bold.nii.gz files for this subject
#         'events': [...],  # raw *_events.tsv files for this subject
#     }
# }
#
# NOTE(review): 'niis' is filtered only by suffix/extension; if fMRIPrep wrote
# several output spaces there may be more than one image per run — confirm.
frmi_subject_run = {}

for subj in tqdm(layout.get(target='subject', return_type='id')):
    # A repeated subject ID should never happen; report it and stop indexing.
    if subj in frmi_subject_run:
        print(f'error! {subj}')
        break
    bold_files = layout.derivatives['fMRIPrep'].get(
        subject=subj, return_type='file', suffix='bold', extension='nii.gz')
    event_files = layout.get(
        subject=subj, return_type='file', suffix='events', extension='tsv')
    frmi_subject_run[subj] = {'niis': bold_files, 'events': event_files}

100%|██████████| 16/16 [00:04<00:00,  3.91it/s]


In [8]:
# Stack every subject's raw events.tsv into one long DataFrame, tagging each row
# with its numeric subject ID and its run number.
event_frames = []

for subject_id in layout.get(target='subject', return_type='id'):
    for i, df_path in enumerate(frmi_subject_run[subject_id]['events']):
        run_df = pd.read_table(df_path)
        run_df['subjID'] = int(subject_id)
        # Use the file's BIDS `run` entity instead of its position in the list,
        # so missing or non-1-based runs are labelled correctly; fall back to the
        # 1-based position when the filename carries no run entity.
        run_label = layout.parse_file_entities(df_path).get('run', i + 1)
        run_df['run'] = int(run_label)
        event_frames.append(run_df)

df_all = pd.concat(event_frames)

In [9]:
# Snapshot of the stacked event table's column labels as loaded, taken before
# the next cell adds the model columns. NOTE(review): appears unused below.
columns = df_all.keys().copy()

In [10]:
# Prepare the trial table for hBayesDM's ra_prospect model.
# Drop trials with respcat == -1 (presumably "no response" — TODO confirm with
# the dataset's events.json). .copy() makes the filtered frame independent, so
# the column assignments below do not raise SettingWithCopyWarning.
df_all = df_all[df_all['respcat'] != -1].copy()
# np.int was only an alias of the builtin int and was removed in NumPy 1.24;
# astype(int) performs the same cast.
df_all['onset'] = df_all['onset'].astype(int)
# ra_prospect expects a `cert` column (the amount of the sure option). In this
# accept/reject task the sure option is presumably "nothing", hence 0.
df_all['cert'] = 0
# ra_prospect's `gamble` column is 1 if the gamble was taken, 0 otherwise;
# presumably respcat uses the same coding (1 = accept, 0 = reject).
df_all['gamble'] = df_all['respcat']
df_all.head()

INFO:numexpr.utils:Note: detected 88 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 88 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


Unnamed: 0,onset,duration,parametric loss,distance from indifference,parametric gain,gain,loss,PTval,respnum,respcat,response_time,subjID,run,cert,gamble
1,4,3,-0.0227,-0.4147,-0.189,18,12,6.12,2,1,1.793,1,1,0,1
2,8,3,0.1273,0.2519,-0.389,10,15,-4.85,3,0,1.637,1,1,0,0
3,18,3,0.1773,-0.0814,0.211,34,16,18.16,1,1,1.316,1,1,0,1
4,24,3,-0.3727,-0.0814,-0.189,18,5,13.05,1,1,1.67,1,1,0,1
5,28,3,0.0273,-0.4147,0.011,26,13,13.13,2,1,1.232,1,1,0,1


In [None]:
# Fit hBayesDM's prospect-theory risk-aversion model to all subjects' trials and
# save the per-subject parameter summaries (later cells read `rho` and `lambda`).
# This is the expensive MCMC step; the next cell reloads the saved file instead.
model = ra_prospect(data=df_all, ncore=4)
params = model.all_ind_pars

# to_csv does not create parent directories, so make sure models/ exists.
Path('models').mkdir(exist_ok=True)
# NOTE(review): index=False drops the frame's index; if hBayesDM keeps subject
# IDs there they are not saved (the reload cell reconstructs them) — confirm.
params.to_csv('models/ra_prospect_params.tsv', sep='\t', index=False)
params

In [11]:
# Reload the saved per-subject parameters, so the expensive fit cell above does
# not need to be rerun, and attach each row's subject ID.
params = pd.read_csv('models/ra_prospect_params.tsv', sep='\t')
# The saved file has no subject column, so rows are matched to subjects by order.
# Presumably hBayesDM keeps subjects in order of first appearance in the data it
# was fit on (df_all) — NOTE(review): confirm. Using those IDs, rather than
# index + 1, stays correct when subject IDs are not exactly 1..N.
fitted_subject_ids = df_all['subjID'].unique()
assert len(params) == len(fitted_subject_ids), (
    f'{len(params)} parameter rows vs {len(fitted_subject_ids)} subjects in df_all')
params['subjID'] = fitted_subject_ids

In [62]:
# For every subject and run, copy the raw events.tsv with an extra trial-wise
# `utility` column computed from that subject's fitted rho and lambda, and write
# it to config.BEHAV_RESULT_DIR/sub-<id>/func/ with a BIDS-style filename.
# NOTE(review): config.TASK is inserted verbatim, so it presumably already
# carries the "task-" prefix — confirm.
task_name = config.TASK

for subject_id in tqdm(layout.get(target='subject', return_type='id')):
    # .item() returns the scalar and raises unless exactly one row matches;
    # float() on a Series is deprecated since pandas 2.1.
    iparams = params[params['subjID'] == int(subject_id)]
    irho = float(iparams['rho'].item())
    ilambda = float(iparams['lambda'].item())

    # Create BEHAV_RESULT_DIR/sub-<id>/func, including any missing parents.
    result_subj_dir = config.BEHAV_RESULT_DIR / f'sub-{subject_id}' / 'func'
    result_subj_dir.mkdir(parents=True, exist_ok=True)

    # Only this subject's runs that actually have an events file, rather than
    # every run label found anywhere in the dataset.
    subject_runs = layout.get(subject=subject_id, suffix='events', extension='tsv',
                              target='run', return_type='id')
    for run_id in subject_runs:
        events = layout.get(subject=subject_id, run=run_id, return_type='file',
                            suffix='events', extension='tsv')
        assert len(events) == 1, (
            f'expected one events file for sub-{subject_id} run-{run_id}, got {events}')
        event_df = pd.read_csv(events[0], sep='\t')
        # Mixed-gamble utility: gain^rho - lambda * loss^rho.
        # NOTE(review): hBayesDM's ra_prospect scores the gamble as
        # 0.5 * (gain^rho - lambda * |loss|^rho), so with cert = 0 this column is
        # a constant 2x rescaling of it — confirm this is intended.
        event_df['utility'] = (event_df['gain'] ** irho) - (ilambda * (event_df['loss'] ** irho))
        result_path = result_subj_dir / '_'.join([f'sub-{subject_id}',
                                                  task_name,
                                                  f'run-{run_id}',
                                                  'events.tsv'])
        event_df.to_csv(result_path, sep='\t', index=False)

100%|██████████| 16/16 [00:22<00:00,  1.41s/it]
