In [1]:
import numpy as np
import matplotlib.pyplot as plt
import json
import pickle

from modules import *

In [2]:
def parse_raw_to_trials(raw, num_trial):
    """
    Group the main-experiment events of a raw log into per-trial lists.

    Events are ignored until one whose 'event' field equals
    'timeline.start.main' has been seen (that event itself is not logged).
    Every later event is appended to the bucket of the most recent
    'trial_index' encountered; events seen before any 'trial_index'
    go to bucket 0.

    Parameters
    ----------
    raw : iterable of dict
        Event records; each one is indexed with the 'event' key.
    num_trial : int
        Number of trial buckets to pre-allocate.

    Returns
    -------
    list of list
        `num_trial` lists of events, one per trial index.
    """

    buckets = [[] for _ in range(num_trial)]
    in_main = False
    current = 0

    for record in raw:

        # only events after the main-experiment marker are logged
        if in_main:
            # switch bucket when the event announces a trial index
            if 'trial_index' in record:
                current = record['trial_index']
            buckets[current].append(record)

        # the marker is checked after logging, so it is never logged itself
        if record['event'] == 'timeline.start.main':
            in_main = True

    return buckets

In [3]:
def parse_trials_to_data(trials):
    """
    Parse trials to data.

    Parameters
    ----------
    trials : list of list of dict
        One list of event dicts per trial. Every event is indexed with the
        'event' key; an event that has a 'graph' key is treated as the
        trial-definition event and must also carry 'rewards', 'start' and
        'hover_limit'.

    Returns
    -------
    dict
        Lists keyed by 'graphs', 'rewards', 'starts', 'hover_limits',
        'hover_seqs', 'visit_seqs' and 'rollout_lengths'.
        'hover_seqs' and 'visit_seqs' get exactly one entry per trial.
        The trial-info lists get one entry per trial-definition event, so
        they are shorter than 'hover_seqs' if a trial has no such event.
        'rollout_lengths' is returned empty; this function does not fill it.
    """

    data = {
        'graphs': [],
        'rewards': [],
        'starts': [],
        'hover_limits': [],
        'hover_seqs': [],
        'visit_seqs': [],
        'rollout_lengths': [],
    }

    # loop through trials
    for trial in trials:

        # Fresh buffers for every trial. Without this, a trial lacking a
        # 'graph' event would re-append the previous trial's (aliased) lists,
        # or fail with UnboundLocalError if it is the first trial.
        hover_seq = []
        visit_seq = []

        # loop through events in the trial
        for event in trial:

            # new trial
            if 'graph' in event:
                # (re)initialize trial recording
                hover_seq = []
                visit_seq = []

                # log trial info
                data['graphs'].append(event['graph'])
                data['rewards'].append(event['rewards'])
                data['starts'].append(event['start'])
                data['hover_limits'].append(event['hover_limit'])

            event_type = event['event']

            # imagination
            if event_type == 'graph.imagine':
                hover_seq.append(event['state'])

            # navigation
            elif event_type == 'graph.visit':
                visit_seq.append(event['state'])

        # log trial sequences
        data['hover_seqs'].append(hover_seq)
        data['visit_seqs'].append(visit_seq)

    return data

In [4]:
# number of trial slots pre-allocated per subject; every subject is expected
# to provide trial info (graph, start, rewards) for all of them
num_trial = 100

# directory
dir_load = 'data/data_raw'
dir_save = 'data/data_processed'

# subj id list
# id_list = ['wa0b5c62', 'w0b569b1', 'w3efaab9', 'w5a2c6c0', 'wb67e511', 'w83a6669']#, 'wd72ec4f']
id_list = [
    'wff0e37f',
    'w64a20fb',
    'w930df1c',
    'w1136569',
    'w5f0a4e2',
    'we52e417',
    'w4ee306c',
    'waa3a4a4',
    'w3ba63d6',
    'w1206510',
    'w2ce5641',
    'wfc8657f',
    'w6cba169',
    'w3c4d8b1',
    'waec10f6',
    'wf3bd4f2',
    'w4431e3b',
    'w4fd295b',
    'w651bd37',
    'w7425c1c'
]

# loop through subjects
# (loop variable is `subj_id` rather than `id` so the builtin is not shadowed)
for subj_id in id_list:
    print(subj_id)

    # load data
    with open(f'{dir_load}/{subj_id}.json', 'r', encoding='utf-8') as file:
        raw_subj = json.load(file)

    # parse trials
    trials_subj = parse_raw_to_trials(raw_subj, num_trial)

    # process data
    data_subj = parse_trials_to_data(trials_subj)

    # further processing: derived per-trial measures
    data_subj['max_depths'] = []
    data_subj['child_dicts'] = []
    data_subj['hover_counts'] = []
    data_subj['rollout_counts'] = []
    data_subj['cum_rewards'] = []
    data_subj['action_values'] = []

    for i in range(num_trial):

        graph = data_subj['graphs'][i]
        start = data_subj['starts'][i]
        rewards = data_subj['rewards'][i]

        # add max depth and child dict
        data_subj['max_depths'].append(max_depth(graph, start))
        child_dict = list_to_dict(graph)
        data_subj['child_dicts'].append(child_dict)

        # remove repeating nodes
        hover_seq = merge_adjacent(data_subj['hover_seqs'][i])
        visit_seq = merge_adjacent(data_subj['visit_seqs'][i])

        # remove nodes not in the tree
        # NOTE(review): filtering happens after merge_adjacent, so dropping an
        # out-of-tree node can leave two equal neighbours — confirm intended.
        hover_seq = [node for node in hover_seq if is_node_in_tree(node, child_dict)]

        # remove last node: a trailing hover on the start node is dropped
        # (unless it is the only hover)
        if len(hover_seq) > 1 and hover_seq[-1] == start:
            hover_seq.pop()

        # store the cleaned sequences back
        data_subj['hover_seqs'][i] = hover_seq
        data_subj['visit_seqs'][i] = visit_seq

        # add hover numbers
        data_subj['hover_counts'].append(len(hover_seq))
        # rollouts are counted as the number of hovers on the start node
        data_subj['rollout_counts'].append(hover_seq.count(start))
        data_subj['rollout_lengths'].append(segment_lengths(hover_seq, start))

        # add cumulative rewards over the visited nodes
        data_subj['cum_rewards'].append(sum(rewards[node] for node in visit_seq))

        # add action values, one per child of the start node
        depth_1_nodes = child_dict[start]
        data_subj['action_values'].append(
            [get_action_value(node, child_dict, rewards) for node in depth_1_nodes]
        )

    # save data (context manager guarantees the file is flushed and closed)
    with open(f'{dir_save}/{subj_id}.p', 'wb') as file:
        pickle.dump(data_subj, file)

wff0e37f
w64a20fb
w930df1c
w1136569
w5f0a4e2
we52e417
w4ee306c
waa3a4a4
w3ba63d6
w1206510
w2ce5641
wfc8657f
w6cba169
w3c4d8b1
waec10f6
wf3bd4f2
w4431e3b
w4fd295b
w651bd37
w7425c1c
