In [34]:
import numpy as np
import pandas as pd
import scipy.io as io
import math
import os
import scipy

In [87]:
def get_files(task:str, subdir = '\\results_zuco\\'):
    """
        Args: Task number ("task1", "task2", "task3") plus subdirectory.
              The subdirectory is resolved relative to the current working
              directory; backslashes and forward slashes are both accepted
              as separators.
        Return: 12 matlab files (one per subject) for given task, sorted by
                file name so a subject index always refers to the same file.
        Raises: AssertionError if the task folder does not hold exactly
                12 .mat files.
    """
    # Join path components instead of concatenating strings with hard-coded
    # backslashes, which only resolves on Windows.
    parts = [part for part in subdir.replace('\\', '/').split('/') if part]
    path = os.path.join(os.getcwd(), *parts, task)
    # os.listdir returns entries in arbitrary order, so dropping "the first
    # entry" is not a reliable way to skip a stray non-data file. Select the
    # visible .mat files explicitly and sort them instead.
    names = sorted(name for name in os.listdir(path)
                   if name.lower().endswith('.mat') and not name.startswith('.'))
    files = [os.path.join(path, name) for name in names]
    assert len(files) == 12, 'each task must contain 12 .mat files'
    return files

In [228]:
def _sentence_words(sent):
    """Return the word structs of one sentence as a list (possibly empty)."""
    # With squeeze_me=True a one-element struct array is collapsed to a bare
    # struct; np.atleast_1d makes both the array and the scalar case iterable.
    words = np.atleast_1d(getattr(sent, 'word', []))
    # Entries without `_fieldnames` are not word structs and are skipped
    # (presumably placeholders for sentences without word data - TODO confirm).
    return [word for word in words if hasattr(word, '_fieldnames')]


def mk_dataframe(task:str, subject:int):
    """
        Build the word-level DataFrame for one subject of one task.

        Args: Task number ("task1", "task2", "task3"), test subject (0-11)
              used as index into the file list returned by `get_files`.
        Return: DataFrame on word level, one row per word. Columns are
                'sent_id' (sentence index plus '_NR' for task1/task2 or
                '_TSR' otherwise), 'word_id' (position in the sentence) and
                every word field whose name does not start with 'raw', in
                reverse alphabetical order. All columns have object dtype;
                a field that a word does not carry is NaN.
    """
    files = get_files(task)
    data = io.loadmat(files[subject], squeeze_me=True, struct_as_record=False)['sentenceData']

    sentences = [_sentence_words(sent) for sent in np.atleast_1d(data)]

    fields = sorted({field for words in sentences for word in words
                     for field in word._fieldnames if not field.startswith('raw')},
                    reverse=True)
    # Pass a flat list of names: wrapping it in another list (columns=[fields])
    # makes pandas build a one-level MultiIndex instead of plain column labels.
    columns = ['sent_id', 'word_id'] + fields

    suffix = '_NR' if task in ('task1', 'task2') else '_TSR'

    # Collect all rows first and build the frame once; filling a pre-allocated
    # frame cell by cell through .iloc is far slower.
    rows = []
    for i, words in enumerate(sentences):
        for j, word in enumerate(words):
            # Values are stored exactly as loaded (scalars, strings or arrays).
            values = [getattr(word, field, np.nan) for field in fields]
            rows.append([str(i) + suffix, j] + values)

    # dtype=object keeps every column as object, matching the array-valued cells.
    return pd.DataFrame(rows, columns=columns, dtype=object)

In [76]:
# Collect the per-subject .mat paths of each task once, for the cells below.
files_task1 = get_files(task='task1')
files_task2 = get_files(task='task2')
files_task3 = get_files(task='task3')

In [229]:
# Preview the word-level frame of subject 0, task 1. Only the first rows are
# shown so the cell output stays small instead of rendering the whole frame.
mk_dataframe('task1', 0).head(10)

Unnamed: 0,sent_id,word_id,nFixations,meanPupilSize,fixPositions,content,TRT_t2_diff,TRT_t2,TRT_t1_diff,TRT_t1,...,FFD_g1,FFD_b2_diff,FFD_b2,FFD_b1_diff,FFD_b1,FFD_a2_diff,FFD_a2,FFD_a1_diff,FFD_a1,FFD
0,0_NR,0,4,934.25,"[1, 3, 5, 6]",Presents,"[0.17093132436275482, 0.23016424290835857, 0.8...","[0.12807040102779865, 0.15583992563188076, 0.3...","[-0.06478019058704376, 0.34110720455646515, 0....","[0.18953146785497665, 0.16591010801494122, 0.4...",...,"[0.13122937083244324, 0.15754754841327667, 0.3...","[0.31596484780311584, 0.7004818767309189, 0.94...","[0.16326536238193512, 0.2756389379501343, 0.61...","[0.12626001238822937, 0.7299983352422714, 0.25...","[0.07769270241260529, 0.076673723757267, 0.235...","[0.15999606251716614, 0.4642724245786667, 0.54...","[0.0893532782793045, 0.13098077476024628, 0.15...","[0.0516873300075531, 0.5803970471024513, 0.284...","[0.06277229636907578, 0.05697965249419212, 0.0...",119
1,0_NR,1,3,945.667,"[2, 4, 7]",a,"[0.015314777692159018, 0.6190686710178852, 0.5...","[0.0876611980299155, 0.20953752845525742, 0.47...","[-0.1142110824584961, 0.5119785244266192, 0.39...","[0.11784198135137558, 0.1543090914686521, 0.37...",...,"[0.18791787326335907, 0.24760755896568298, 0.5...","[-0.09268638491630554, 0.41456082463264465, 0....","[0.18840184807777405, 0.25742051005363464, 0.4...","[-0.053418368101119995, 0.4655509740114212, 0....","[0.1406080573797226, 0.10803885757923126, 0.18...","[0.1806027889251709, 0.3100277744233608, 0.154...","[0.056175682693719864, 0.15377525985240936, 0....","[0.23943424224853516, 0.5354129057377577, 0.65...","[0.02239406295120716, 0.1753470003604889, 0.37...",73
2,0_NR,2,[],[],[],good,[],[],[],[],...,[],[],[],[],[],[],[],[],[],[]
3,0_NR,3,1,1013,8,case,"[0.07515546679496765, 0.24319761991500854, 0.7...","[0.12766218185424805, 0.11027426272630692, 0.3...","[0.3435903787612915, 0.8332478851079941, 0.801...","[0.1525195986032486, 0.19608567655086517, 0.27...",...,"[0.10230665653944016, 0.1632472574710846, 0.21...","[0.5146050155162811, 0.6119954288005829, 1.463...","[0.180507630109787, 0.21670447289943695, 0.526...","[0.22460457682609558, 0.4381115064024925, 1.17...","[0.049681223928928375, 0.04665345698595047, 0....","[0.06357455253601074, 0.3188253715634346, 0.21...","[0.09262838214635849, 0.0725376158952713, 0.18...","[-0.3331557959318161, 0.4462147355079651, 0.38...","[0.0836104154586792, 0.16386760771274567, 0.29...",106
4,0_NR,4,3,961,"[9, 10, 11]",while,"[0.11732248465220134, 0.6342294986049334, 1.15...","[0.10918454577525456, 0.20640802880128226, 0.5...","[0.20968930423259735, 0.6499944627285004, 0.89...","[0.12572863698005676, 0.26124537984530133, 0.5...",...,"[0.11853289604187012, 0.07562389224767685, 0.3...","[0.2706948518753052, 0.6316809356212616, 0.643...","[0.10269293189048767, 0.18218013644218445, 0.3...","[0.2617609202861786, 0.3947877660393715, 0.720...","[0.11272498220205307, 0.07714755088090897, 0.2...","[0.1020689457654953, 0.26623471081256866, 0.69...","[0.1636156290769577, 0.05232367664575577, 0.18...","[0.13684087991714478, 0.3667342811822891, 0.60...","[0.1464187353849411, 0.05142131447792053, 0.18...",145
5,0_NR,5,[],[],[],failing,[],[],[],[],...,[],[],[],[],[],[],[],[],[],[]
6,0_NR,6,1,921,12,to,"[0.2643273174762726, 0.5289530642330647, 0.649...","[0.03452111408114433, 0.03291948139667511, 0.1...","[-0.02882358431816101, 0.14098352193832397, 1....","[0.14814212918281555, 0.08748437464237213, 0.0...",...,"[0.22659769654273987, 0.167430579662323, 0.440...","[0.4716464579105377, 0.31433725357055664, 0.63...","[0.19104188680648804, 0.25872570276260376, 0.4...","[0.26010188460350037, 0.26982705295085907, 0.5...","[0.11639441549777985, 0.18189920485019684, 0.3...","[0.13140583038330078, 0.32175467908382416, 0.7...","[0.07934348285198212, 0.11735381931066513, 0.3...","[-0.29032693803310394, 0.29215262830257416, 1....","[0.0673527866601944, 0.0704529881477356, 0.230...",79
7,0_NR,7,1,906,13,provide,"[0.3768288493156433, 0.35734689608216286, 0.48...","[0.06083488091826439, 0.08733436465263367, 0.1...","[0.5352238118648529, 0.48722460120916367, 0.68...","[0.08540228754281998, 0.13832443952560425, 0.2...",...,"[0.17450399696826935, 0.17115753889083862, 0.4...","[0.3272908627986908, 0.38392966985702515, 1.35...","[0.22429275512695312, 0.25241222977638245, 0.3...","[0.24216577410697937, 0.23270849883556366, 0.7...","[0.15724726021289825, 0.12310400605201721, 0.2...","[-0.2247772514820099, 0.22716928273439407, 0.5...","[0.07421215623617172, 0.12266665697097778, 0.4...","[0.14065027236938477, 0.26751745492219925, 0.1...","[0.03635331243276596, 0.12545287609100342, 0.2...",95
8,0_NR,8,1,885,14,a,"[-0.09411075711250305, 0.07324101030826569, 0....","[0.1328655481338501, 0.21843639016151428, 0.41...","[-0.48542138934135437, 0.49903615564107895, 0....","[0.10085997730493546, 0.17488892376422882, 0.4...",...,"[0.17040430009365082, 0.11887594312429428, 0.4...","[-0.07762640714645386, 0.4102623760700226, 1.0...","[0.12462028861045837, 0.17401264607906342, 0.3...","[-0.22616493701934814, 0.3008614331483841, 0.6...","[0.2286626547574997, 0.3025731146335602, 0.767...","[0.10765665769577026, 0.6926355510950089, 0.68...","[0.16740505397319794, 0.16033174097537994, 0.3...","[0.04434528946876526, 0.29404187947511673, 0.5...","[0.04694598168134689, 0.060286615043878555, 0....",82
9,0_NR,9,2,839.5,"[15, 16]",reason,"[0.09132316708564758, 0.3811422437429428, 0.36...","[0.1498737782239914, 0.18488851189613342, 0.37...","[-0.04135511815547943, 0.8654781132936478, 0.7...","[0.13018794357776642, 0.12814026325941086, 0.2...",...,"[0.21344870328903198, 0.2817718982696533, 0.65...","[0.3949022889137268, 0.7368169501423836, 0.667...","[0.09777360409498215, 0.14585213363170624, 0.4...","[-0.1948184370994568, 0.6234569251537323, 0.77...","[0.21934983134269714, 0.24145059287548065, 0.4...","[-0.014522910118103027, 0.7211864590644836, 0....","[0.15009433031082153, 0.18924950063228607, 0.3...","[-0.08891907334327698, 0.1918889656662941, 0.6...","[0.06606128066778183, 0.14493362605571747, 0.2...",76


In [86]:
# Load the sentence-level records of subject 0, task 1.
# `data[i]` is the record of the i-th sentence.
data = io.loadmat(
    files_task1[0],
    struct_as_record=False,
    squeeze_me=True,
)['sentenceData']

In [133]:
# NOTE(review): re-reads the same file with the same options as the preceding
# cell; `data` again holds the sentence records of subject 0, task 1.
data = io.loadmat(files_task1[0], squeeze_me=True, struct_as_record=False)['sentenceData']

In [126]:
# example: display the value of the first field of the first sentence record
# (per the output below, this is the sentence text)
getattr(data[0],data[0]._fieldnames[0])

'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.'

In [172]:
# Keep the `allFixations` attribute of the first sentence record for inspection.
fixations = data[0].allFixations

In [234]:
# `nFixations` of the third word of the first sentence. As the output below
# shows, this word holds an empty array here rather than a number.
data[0].word[2].nFixations

array([], dtype=float64)

In [205]:
# get word level data: the word records of the first sentence
word_data = data[0].word

In [222]:
# Largest number of fields carried by any single word record.
n_fieldnames = max({len(word._fieldnames) for sent in data for word in sent.word})

# Column order of the word-level frame: the two id columns first, then every
# field name that does not start with 'raw', in reverse alphabetical order.
fieldnames = ['sent_id', 'word_id'] + sorted(
    {
        field
        for sent in data
        for word in sent.word
        for field in word._fieldnames
        if not field.startswith('raw')
    },
    reverse=True,
)

# Both loops stop after their first iteration, so only the first word of the
# first sentence is printed.
for i, sent in enumerate(data):
    for j, word in enumerate(sent.word):
        print("Index:", j, end="\n\n")
        print("Word:", word.content, end="\n\n")
        # Counts one entry per name in `fieldnames`, i.e. including the two
        # id columns, which are not attributes of the word record.
        print("Number of attributes: ",
              len([getattr(word, field, np.nan) for field in fieldnames]),
              end="\n\n")
        # Row values as they would be stored in the frame (ids excluded).
        print([getattr(word, field, np.nan) for field in fieldnames[2:]])
        break
    break

Index: 0

Word: Presents

Number of attributes:  96

[4, 934.25, array([1, 3, 5, 6], dtype=uint8), 'Presents', array([ 0.17093132,  0.23016424,  0.81667487,  0.13524071,  0.10467273,
        0.83457349,  0.14093339,  0.27756206,  0.15902166,  0.03051959,
        0.27036073,  0.15323352,  0.52587967,  0.35987031,  0.34231071,
        0.14545473,  0.11751574,  0.0904882 ,  0.05618303,  0.14583212,
       -0.07369199,  0.22227728, -0.1217005 , -0.04732322,  0.10044497,
        0.08773544, -0.17775445,  0.04341073, -0.05242344, -0.12650121,
       -0.00308708,  0.43181629,  0.32022797,  1.15886215,  0.26400988,
        0.15506082,  0.19397768,  0.76751815, -0.69729031,  0.27507687,
        0.09322431, -0.19464346,  0.13874976,  0.62543051,  0.34844649,
        0.24407224,  0.05701613, -0.0246141 ]), array([0.1280704 , 0.15583993, 0.37515532, 0.4304162 , 0.43466741,
       0.4133538 , 0.32536288, 0.3613942 , 0.54403603, 0.48743234,
       0.85197657, 0.32456754, 0.44474829, 0.3367801 , 1.20

In [207]:
# Text of the first word of the first sentence.
word_data[0].content

'Presents'

In [189]:
# Number of fields on the first word record (raw fields included).
len(word_data[0]._fieldnames)

96

In [176]:
# Number of whitespace-separated tokens in the first sentence. The recorded
# output of this cell is a count, so the length of the split is taken here
# rather than displaying the token list itself.
len(data[0].content.split())

22

In [181]:
# `data[i]` is the record of the i-th sentence (subject 0, task 1).
# NOTE(review): same load as in the earlier cells; re-running it only rebinds `data`.
data = io.loadmat(files_task1[0], squeeze_me=True, struct_as_record=False)['sentenceData']

In [95]:
# example: read the `omissionRate` attribute of the first sentence and print it
omission_rate = data[0].omissionRate
print(omission_rate)

0.22727272727272727


In [183]:
# get word level data: the word records of the first sentence
# NOTE(review): same assignment as in an earlier cell.
word_data = data[0].word

In [166]:
# get names of all word features (taken from the first word record)
# `word_data[k]` is the record of the k-th word of the sentence
print(word_data[0]._fieldnames)

['content', 'fixPositions', 'nFixations', 'meanPupilSize', 'rawEEG', 'rawET', 'FFD', 'FFD_pupilsize', 'FFD_t1', 'FFD_t2', 'FFD_a1', 'FFD_a2', 'FFD_b1', 'FFD_b2', 'FFD_g1', 'FFD_g2', 'FFD_t1_diff', 'FFD_t2_diff', 'FFD_a1_diff', 'FFD_a2_diff', 'FFD_b1_diff', 'FFD_b2_diff', 'FFD_g1_diff', 'FFD_g2_diff', 'TRT', 'TRT_pupilsize', 'TRT_t1', 'TRT_t2', 'TRT_a1', 'TRT_a2', 'TRT_b1', 'TRT_b2', 'TRT_g1', 'TRT_g2', 'TRT_t1_diff', 'TRT_t2_diff', 'TRT_a1_diff', 'TRT_a2_diff', 'TRT_b1_diff', 'TRT_b2_diff', 'TRT_g1_diff', 'TRT_g2_diff', 'GD', 'GD_pupilsize', 'GD_t1', 'GD_t2', 'GD_a1', 'GD_a2', 'GD_b1', 'GD_b2', 'GD_g1', 'GD_g2', 'GD_t1_diff', 'GD_t2_diff', 'GD_a1_diff', 'GD_a2_diff', 'GD_b1_diff', 'GD_b2_diff', 'GD_g1_diff', 'GD_g2_diff', 'GPT', 'GPT_pupilsize', 'GPT_t1', 'GPT_t2', 'GPT_a1', 'GPT_a2', 'GPT_b1', 'GPT_b2', 'GPT_g1', 'GPT_g2', 'GPT_t1_diff', 'GPT_t2_diff', 'GPT_a1_diff', 'GPT_a2_diff', 'GPT_b1_diff', 'GPT_b2_diff', 'GPT_g1_diff', 'GPT_g2_diff', 'SFD', 'SFD_pupilsize', 'SFD_t1', 'SFD_t2', 

In [170]:
# Field names are taken from the first word record only; the union over all
# words of the sentence would be the alternative.
word_fieldnames = word_data[0]._fieldnames
for word in word_data:
    # Print every attribute of the current word, raw recordings included.
    for field in word_fieldnames:
        print('Attribute:', field, end='\n\n')
        print('Values:', getattr(word, field), end='\n\n')
    # Stop after the first word.
    break

Attribute: content

Values: Presents

Attribute: fixPositions

Values: [1 3 5 6]

Attribute: nFixations

Values: 4

Attribute: meanPupilSize

Values: 934.25

Attribute: rawEEG

Values: [array([[ 0.24008693,  0.21717785, -0.40477642, ..., -0.1208232 ,
         0.6734169 , -0.044492  ],
       [-0.44889867, -0.32277313, -0.43373403, ..., -0.42020953,
         0.2686897 , -0.36028022],
       [-0.61362416, -0.63692635, -1.191164  , ...,  0.18665355,
         1.3690135 ,  0.10849172],
       ...,
       [ 0.41300774,  0.15055656, -0.33241248, ...,  0.48656818,
         1.6300433 ,  0.7753639 ],
       [-0.22323966, -0.27009162, -0.7895932 , ...,  0.07122315,
         1.5149357 ,  0.87533915],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)
 array([[-0.05841855, -0.03879149, -0.01537458, ...,  1.1995841 ,
         1.0480303 , -0.6473091 ],
       [-0.29775155, -0.57203853, -0.3339718 , ...,  1.775105  ,
         1.6121757 ,

In [156]:
# example: print the number of fixations of the first word
print(word_data[0].nFixations)

4


In [154]:
# example: get number of fixations of first word
# NOTE(review): same statement as in the preceding cell.
print(word_data[0].nFixations)

4


In [226]:
class DataFrameLoader():
    """
        DataFrame loader object for ZuCo 1.0.

        Args: Task number ("task1", "task2", "task3"), test subject (0-11).
        Raises: ValueError if task or subject is outside the allowed values.

        Calling the instance is equivalent to calling `mk_dataframe()`.
    """

    def __init__(self, task:str, subject:int):
        tasks = ['task1', 'task2', 'task3']
        if task not in tasks:
            raise ValueError('task can only be one of "task1", "task2", or "task3"')
        # Store the validated task name itself, not the list of allowed names.
        self.task = task
        if subject not in range(12):
            raise ValueError('subject must be an integer value between 0 - 11')
        self.subject = subject

    def __call__(self):
        # Dispatch to the method; the bare name `mk_dataframe` would resolve to
        # a module-level function (if any) and receive `self` as its task.
        return self.mk_dataframe()

    @staticmethod
    def _sentence_words(sent):
        """Return the word structs of one sentence as a list (possibly empty)."""
        # With squeeze_me=True a one-element struct array is collapsed to a
        # bare struct; np.atleast_1d makes both cases iterable.
        words = np.atleast_1d(getattr(sent, 'word', []))
        # Entries without `_fieldnames` are not word structs and are skipped
        # (presumably placeholders for sentences without word data - TODO confirm).
        return [word for word in words if hasattr(word, '_fieldnames')]

    def mk_dataframe(self):
        """
            Build the word-level DataFrame for `self.task` / `self.subject`.

            Return: DataFrame with features (i.e., attributes) on word level,
                    one row per word. Columns are 'sent_id' (sentence index
                    plus '_NR' for task1/task2 or '_TSR' otherwise), 'word_id'
                    (position in the sentence) and every word field whose name
                    does not start with 'raw', in reverse alphabetical order.
                    All columns have object dtype; a field that a word does
                    not carry is NaN.
        """
        files = get_files(self.task)
        data = io.loadmat(files[self.subject], squeeze_me=True, struct_as_record=False)['sentenceData']

        sentences = [self._sentence_words(sent) for sent in np.atleast_1d(data)]

        fields = sorted({field for words in sentences for word in words
                         for field in word._fieldnames if not field.startswith('raw')},
                        reverse=True)
        # Pass a flat list of names: columns=[fields] would make pandas build a
        # one-level MultiIndex instead of plain column labels.
        columns = ['sent_id', 'word_id'] + fields

        # Use the instance's task; a bare `task` is not defined in this scope.
        suffix = '_NR' if self.task in ('task1', 'task2') else '_TSR'

        # Collect all rows first and build the frame once instead of filling a
        # pre-allocated frame cell by cell through .iloc.
        rows = []
        for i, words in enumerate(sentences):
            for j, word in enumerate(words):
                values = [getattr(word, field, np.nan) for field in fields]
                rows.append([str(i) + suffix, j] + values)

        # dtype=object keeps every column as object, matching the array-valued cells.
        return pd.DataFrame(rows, columns=columns, dtype=object)