In [1]:
import numpy as np
import matplotlib.pyplot as plt
import json
import pickle

from modules import *

In [2]:
num_trial = 100

# directory
dir_load = 'data/data_raw'
dir_save = 'data/data_processed'

# subj id list
id_list = [
    'w965a4a3',
    'w2cf6dda',
    'wf010b3e',
    'wd4f37ac',
    'w6b9d6f5',
    'w13af226',
    'wa193ebb',
    'w41fd54e',
    'wf8ea6d2',
    'w932e053',
    'w94abfc5',
    'wb66bb8e',
    'w2745be0',
    'we284b29',
    'wfca9e9d',
    'w65882b1',
    'wd7a2628',
    'w9c2d004',
    'w5718ddc',
    'w9564981',
    'wf4e5037',
    'w8114143',
    'w366d84d',
    'wb680bbe',
    'wb666b62',
    'w372a061',
    'wec7fe05',
    'we6fc6a7',
    'wc6007d5',
    'w66a0c83',
    'w3e17086',
    'w9f25b4d',
    'w96d6a12',
    'w57507c5',
    'w423e302',
    'w3dea2c5',
    'wc35aa68',
    'wc807245',
    'w21ecefc',
    'wb94879d',
    'w84845cc',
    'wfd9b752',
    'wd361f5e',
    'w034037a',
    'w6155833',
    'w802176b',
    'w581d495',
    'w0cafb94',
    'w9349321',
    'w943b3ab',
]


def _elapsed_or_none(time_seq):
    """Return ``time_seq[-1] - time_seq[0]``, or None if the sequence has 2 or fewer entries.

    Shared by the hover-time and visit-time computations below.
    """
    # NOTE(review): a sequence with exactly 2 timestamps is also treated as missing;
    # confirm the `<= 2` threshold is intended rather than `< 2`.
    if len(time_seq) <= 2:
        return None
    return time_seq[-1] - time_seq[0]


# loop through subjects (`subj_id` rather than `id` to avoid shadowing the builtin)
for subj_id in id_list:
    # load the raw JSON record for this subject
    with open(f'{dir_load}/{subj_id}.json', 'r', encoding='utf-8') as file:
        raw_subj = json.load(file)

    # raw record -> per-trial records -> dict of per-field lists indexed by trial
    trials_subj = parse_raw_to_trials(raw_subj, num_trial)
    data_subj = parse_trials_to_data(trials_subj)

    # derived per-trial fields, appended to in trial order below
    for key in ('max_depths', 'child_dicts', 'hover_counts', 'rollout_counts',
                'cum_rewards', 'action_values', 'hover_times', 'visit_times'):
        data_subj[key] = []

    for i in range(num_trial):
        graph = data_subj['graphs'][i]
        start = data_subj['starts'][i]
        rewards = data_subj['rewards'][i]

        # tree depth and node -> children mapping for this trial's graph
        data_subj['max_depths'].append(max_depth(graph, start))
        child_dict = list_to_dict(graph)
        data_subj['child_dicts'].append(child_dict)

        # collapse consecutive repeated nodes
        hover_seq = merge_adjacent(data_subj['hover_seqs'][i])
        visit_seq = merge_adjacent(data_subj['visit_seqs'][i])

        # drop hovered nodes that are not part of this trial's tree
        # NOTE(review): filtering runs after merging, so e.g. [a, off_tree, a] becomes
        # [a, a]; this can inflate hover/rollout counts. Confirm this order is intended.
        hover_seq = [node for node in hover_seq if is_node_in_tree(node, child_dict)]

        # drop a single trailing return to the start node
        if len(hover_seq) > 1 and hover_seq[-1] == start:
            hover_seq.pop()

        data_subj['hover_seqs'][i] = hover_seq
        data_subj['visit_seqs'][i] = visit_seq

        # hover count, and how many times the start node was hovered
        data_subj['hover_counts'].append(len(hover_seq))
        data_subj['rollout_counts'].append(hover_seq.count(start))

        # total reward collected along the visited nodes
        data_subj['cum_rewards'].append(sum(rewards[node] for node in visit_seq))

        # value of each child of the start node
        data_subj['action_values'].append(
            [get_action_value(node, child_dict, rewards) for node in child_dict[start]]
        )

        # elapsed time from first to last hover / visit timestamp
        data_subj['hover_times'].append(_elapsed_or_none(data_subj['hover_time_seqs'][i]))
        data_subj['visit_times'].append(_elapsed_or_none(data_subj['visit_time_seqs'][i]))

    # save data; the context manager guarantees the pickle file is flushed and closed
    with open(f'{dir_save}/{subj_id}.p', 'wb') as file:
        pickle.dump(data_subj, file)

In [3]:
# Spot-check the last processed subject: print its child dict, start node and
# cleaned hover sequence for each of the first five trials, blank line between trials.
for trial_idx in range(5):
    for field in ('child_dicts', 'starts', 'hover_seqs'):
        print(data_subj[field][trial_idx])
    print()

{10: [0, 6], 22: [29, 10], 29: [5, 14]}
22
[22, 29, 5, 22, 29, 22, 10]

{1: [8, 26], 2: [9, 17], 3: [19, 12], 4: [15, 23], 8: [30, 13], 9: [22, 21], 14: [27, 2], 16: [25, 4], 17: [3, 24], 21: [28, 5], 22: [10, 11], 24: [7, 18], 25: [29, 6], 26: [20, 0], 27: [16, 1]}
14
[14, 2, 17, 24, 7, 14, 2]

{9: [17, 13], 14: [7, 27], 21: [14, 9]}
21
[21]

{10: [21, 27], 17: [22, 10], 22: [30, 15]}
17
[17]

{12: [28, 17], 17: [10, 4], 28: [8, 15]}
12
[12, 28]

